{-
    Kaya - My favourite toy language.
    Copyright (C) 2004, 2005 Edwin Brady

    This file is distributed under the terms of the GNU General
    Public Licence. See COPYING for licence.
-}

module CodegenCPP where

import TAC
import Language
import IO
import Debug.Trace
import Lib

-- | One piece of generated C++ output.
data Output
    = RawOutput String          -- ^ Text written out as-is, followed by a newline.
    | FNOutput (Name,String)    -- ^ A function name paired with its generated body.
    | ExternOutput Name         -- ^ A function for which only a prototype is written.
  deriving Show

{-
writeCf :: InputType -> [FilePath] -> [CompileResult] -> FilePath -> IO ()
writeCf domain lds xs out = do mprog <- getmain lds domain
			       header <- getheader lds domain
			       let str = ((writeout header mprog).writecpp) xs
			       writeFile out str
-}

-- | Generate the C++ text for a list of compile results and write it to
-- the given handle, which is closed afterwards.
--
-- The startup code is fetched first, then the header; the output is the
-- header, the function declarations, the definitions and finally the
-- startup code (see 'writeout').
writeC :: Name -> -- Module name
          InputType -> [FilePath] -> Context ->
          [CompileResult] -> Handle -> IO ()
writeC mod domain lds ctxt xs out =
    getmain mod lds ctxt domain >>= \mprog ->
    getheader mod lds ctxt domain >>= \header ->
    hPutStr out (writeout header mprog (writecpp xs)) >>
    hClose out

-- | Assemble the complete output text: header, forward declarations of
-- every generated function, the output items themselves, and lastly the
-- startup program text.
writeout :: String -> String -> [Output] -> String
writeout header mprog xs =
    concat [header, writedecls xs, writeout' xs, mprog]

-- Render every output item in order: raw text is followed by a newline,
-- extern items become prototypes, and functions become full definitions.
writeout' items = concatMap render items
  where
    render (RawOutput str) = str ++ "\n"
    render (ExternOutput f) = "void " ++ show f ++ "(VMState* vm);\n"
    render (FNOutput (f,def)) =
        "void " ++ show f ++ "(VMState* vm){\n" ++ def ++ "}\n"

-- Emit one prototype per generated function (FNOutput); every other kind
-- of output item contributes nothing here.
writedecls items =
    concat [ "void " ++ show f ++ "(VMState* vm);\n"
           | FNOutput (f,_) <- items ]

-- getmain: fetch the startup code to append to the generated output.
-- Programs, shebang scripts and webapps all load "startup.vcc" and have
-- the %__start and %__panic placeholders substituted; modules get no
-- startup code at all.
getmain,getheader :: Name -> [FilePath] -> Context -> InputType -> IO String
getmain mod lds ctxt domain = case domain of
    Program -> startup
    Shebang -> startup
    Module  -> return ""
    Webapp  -> startup
  where
    startup = findFile lds "startup.vcc" >>=
              replaceDefs mod ctxt ["__start","__panic"]

-- getmain mod lds ctxt Program = 
--    do str <- findFile lds "program.vcc"
--       str' <- replaceDefs mod ctxt ["main"] str 
--       return str'
-- getmain mod lds ctxt Shebang = 
--    do str <- findFile lds "program.vcc"
--       str' <- replaceDefs mod ctxt ["main"] str 
--       return str'
-- getmain mod lds ctxt Module = return ""
-- getmain mod lds ctxt Webapp = 
--     do str <- findFile lds "webapp.vcc"
--        str' <- replaceDefs mod ctxt ["PreContent","Default","PostContent",
-- 				   "initWebApp","flush","IllegalHandler"] str 
--        return str'

-- Load "header.vcc" and append the definition of this module's global
-- table pointer (initialised to NULL). The input type is ignored.
getheader mod lds ctxt _ =
    fmap withGlobalTable (findFile lds "header.vcc")
  where
    withGlobalTable hf =
        hf ++ "\nValue** globaltable" ++ show mod ++ "=NULL;\n\n"

-- | Load the startup source for the given input type: "program.ks" for
-- programs, "webapp.ks" for webapps, and the empty string otherwise.
getStartup :: InputType -> [FilePath] -> IO String
getStartup domain lds = case domain of
    Program -> findFile lds "program.ks"
    Webapp  -> findFile lds "webapp.ks"
    _       -> return ""

-- | Substitute placeholders in a template. For each name @x@ in the
-- list, every occurrence of @%x@ in the text is replaced by the shown
-- form of the function name that 'ctxtlookup' returns for @UN x@.
--
-- Names are processed from the end of the list towards the front, so
-- the lookup for the last name runs first.
replaceDefs :: Name -> Context -> [String] -> String -> IO String
replaceDefs mod ctxt names str = foldr substitute (return str) names
  where
    substitute x rest =
        do newstr <- rest
           (fname,_) <- ctxtlookup mod (UN x) ctxt
           return (replace ("%"++x) (show fname) newstr)

-- | @replace old new xs@ replaces every non-overlapping occurrence of
-- @old@ in @xs@ with @new@, scanning left to right. Text inserted by a
-- replacement is never rescanned.
--
-- An empty @old@ leaves @xs@ unchanged. An empty pattern matches at
-- every position without consuming any input, so without this case the
-- function would never terminate on non-empty input.
replace :: String -> String -> String -> String
replace "" _ xs = xs
replace old new xs = go xs
  where
    go "" = ""
    go s@(c:cs) = case dropPrefix old s of
                    Just rest -> new ++ go rest
                    Nothing   -> c : go cs

    -- Strip the first list from the front of the second, if it is there.
    -- One pass, rather than a separate length/take/drop at every position.
    dropPrefix [] ys = Just ys
    dropPrefix (p:ps) (y:ys)
        | p == y = dropPrefix ps ys
    dropPrefix _ _ = Nothing

-- | Read a file, trying each directory prefix in turn. A prefix is
-- prepended to the file name as-is, with no separator inserted. If
-- reading under one prefix raises an I/O error the next prefix is tried;
-- once the list is exhausted the action fails with "Can't find <name>".
findFile :: [FilePath] -> FilePath -> IO String
findFile [] path = fail ("Can't find " ++ path)
findFile (dir:dirs) path =
    readFile (dir ++ path) `catch` (\_ -> findFile dirs path)

-- | Convert compile results to output items: raw code passes through
-- unchanged, bytecode is rendered to C++ text with 'cpp', and extern
-- definitions become extern outputs.
writecpp :: [CompileResult] -> [Output]
writecpp = foldr convert []
  where
    convert (RawCode str) rest = RawOutput str : rest
    convert (ByteCode (n,def)) rest = FNOutput (n,cpp def) : rest
    convert (ExternDef n) rest = ExternOutput n : rest

-- | Render a sequence of TAC instructions as C++ statements: one per
-- line, indented by a tab and terminated with a semicolon.
cpp :: [TAC] -> String
cpp code = concatMap statement code
  where
    statement x = "\t" ++ instr x ++ ";\n"

-- The C++ spelling of each binary operator.
printOp op = case op of
    Plus      -> "+"
    Minus     -> "-"
    Times     -> "*"
    Divide    -> "/"
    Modulo    -> "%"
    OpLT      -> "<"
    OpGT      -> ">"
    OpLE      -> "<="
    OpGE      -> ">="
    Equal     -> "=="
    NEqual    -> "!="
    OpAnd     -> "&"
    OpOr      -> "|"
    OpAndBool -> "&&"
    OpOrBool  -> "||"
    OpXOR     -> "^"
    OpShLeft  -> "<<"
    OpShRight -> ">>"
    BAnd      -> "&&"
    BOr       -> "||"

-- The C++ spelling of each unary operator.
printUnOp op = case op of
    Not -> "!"
    Neg -> "-"

-- | Render a single TAC instruction as the text of one C++ statement.
-- No trailing semicolon is produced here; 'cpp' adds the indentation and
-- the ";\n" terminator around each instruction.
--
-- Most instructions are written as an upper-case identifier, optionally
-- with an argument list. These are presumably macros supplied by
-- header.vcc; that file is not visible from here, so confirm there.
-- Any instruction without an equation below is rendered as "NOP".
instr :: TAC -> String
instr (DECLARE v) = "DECLARE("++show v++")"
instr (DECLAREARG v) = "DECLAREARG("++show v++")"
-- NOTE(review): DECLAREQUICK and USETMP both emit DECLAREARG, not a form
-- named after themselves. Looks deliberate, but confirm against the header.
instr (DECLAREQUICK v) = "DECLAREARG("++show v++")"
instr (USETMP v) = "DECLAREARG("++show v++")"
instr (TMPINT i) = "TMPINT(t"++show i ++")"
instr (TMPREAL i) = "TMPREAL(t"++show i ++")"
instr (SET var idx val) = "SET("++show var++","++show idx ++ ","++show val++")"
instr TOINDEX = "TOINDEX"
instr SETTOP = "SETTOP"
instr ADDTOP = "ADDTOP"
instr SUBTOP = "SUBTOP"
instr MULTOP = "MULTOP"
instr DIVTOP = "DIVTOP"
instr (MKARRAY i) = "MKARRAY("++show i++")"
-- Temporaries are assigned directly, as "t<n>=<value>".
instr (TMPSET tmp val) = "t"++show tmp++"="++show val
instr (RTMPSET tmp val) = "t"++show tmp++"="++show val
instr (CALL v) = "CALLFUN("++show v++")"
instr (CALLNAME f) = "CALL("++show f++")"
instr CALLTOP = "CALLTOP"
instr (TAILCALL v) = "TAILCALLFUN("++show v++")"
instr (TAILCALLNAME f) = "TAILCALL("++show f++")"
instr TAILCALLTOP = "TAILCALLTOP"
instr (CLOSURE n i) = "CLOSURE("++show n ++","++show i++")"
-- Foreign calls have their own, more involved rendering in mkfcall.
instr (FOREIGNCALL n lib ty args) = mkfcall n lib ty args 
instr (MKCON t i) = "MKCON("++show t ++ "," ++ show i ++ ")"
instr MKEXCEPT = "MKEXCEPT"
instr EQEXCEPT = "EQEXCEPT"
instr NEEXCEPT = "NEEXCEPT"
instr EQSTRING = "EQSTRING"
instr NESTRING = "NESTRING"
instr (GETVAL v) = "GETVAL(t"++show v++")"
instr (GETRVAL v) = "GETRVAL(t"++show v++")"
instr GETINDEX = "GETINDEX"
instr (PROJARG a t) = "PROJARG("++ show a ++ "," ++ show t ++ ")"
-- Arithmetic forms: the operator is passed through in its C++ spelling
-- (printOp / printUnOp) and temporaries are rendered with tmp.
instr (INFIX t op x y) = "INTINFIX(" ++ tmp t ++ "," ++ printOp op ++ "," ++ 
			 tmp x ++ "," ++ tmp y ++ ")"
instr (INTPOWER t x y) = "INTPOWER(" ++ tmp t ++ "," ++ 
			 tmp x ++ "," ++ tmp y ++ ")"
instr (REALINFIX t op x y) = "REALINFIX(" ++ tmp t ++ "," ++
			     printOp op ++ "," ++ 
			     tmp x ++ "," ++ tmp y ++ ")"
instr (REALINFIXBOOL op x y) = "REALINFIXBOOL(" ++ printOp op ++ "," ++ 
			       tmp x ++ "," ++ tmp y ++ ")"
instr (REALPOWER t x y) = "REALPOWER(" ++ tmp t ++ "," ++ 
			 tmp x ++ "," ++ tmp y ++ ")"
instr (UNARY t op x) = "INTUNARY("++tmp t++","++
		       printUnOp op ++ "," ++ tmp x ++ ")"
instr (REALUNARY t op x) = "REALUNARY("++tmp t++","++
			   printUnOp op ++ "," ++ tmp x ++ ")"
instr APPEND = "APPEND"
instr PRINTINT = "PRINTINT"
instr PRINTSTR = "PRINTSTR"
instr PRINTEXC = "PRINTEXC"
instr NEWLINE = "NEWLINE"
instr (LABEL l) = "LABEL("++show l++")"
instr (JUMP l) = "JUMP("++show l++")"
instr (JFALSE l) = "JFALSE("++show l++")"
instr (JTRUE l) = "JTRUE("++show l++")"
instr (JTFALSE t l) = "JTFALSE("++tmp t++","++show l++")"
instr (JTTRUE t l) = "JTTRUE("++tmp t++","++show l++")"
instr (TRY n) = "TRY("++ show n ++ ")"
instr TRIED = "TRIED"
instr THROW = "THROW"
instr RESTORE = "RESTORE"
-- Pushed operands are wrapped according to their kind by pushitem.
instr (PUSH i) = "PUSH("++pushitem i++")"
instr (PUSH2 x y) = "PUSH2("++pushitem x++","++pushitem y++")"
instr (PUSH3 x y z) = "PUSH3("++pushitem x++","++pushitem y++","++
		                pushitem z++")"
instr (PUSH4 x y z w) = "PUSH4("++pushitem x ++ "," ++ pushitem y ++ "," ++
		                pushitem z ++ "," ++ pushitem w ++ ")"
instr (PUSHSETTOP x) = "PUSHSETTOP("++pushitem x++")"
instr (PUSHGETVAL x t) = "PUSHGETVAL("++pushitem x++",t"++show t++")"
instr (PUSHGETRVAL x t) = "PUSHGETRVAL("++pushitem x++",t"++show t++")"

-- The table name is "globaltable" followed directly by x, which is used
-- here as a plain string (no show).
instr (PUSHGLOBAL x i) = "PUSHGLOBAL(globaltable"++x++"," ++show i ++ ")"
instr (CREATEGLOBAL x i) = "CREATEGLOBAL("++show x ++ "," ++ show i++")"
-- RETURN is the one instruction emitted as a bare C++ keyword.
instr RETURN = "return"
instr (SETVAL v x) = "SETVAL("++show v++","++show x++")"
instr (SETINT v x) = "SETINT("++show v++","++tmp x++")"
instr (SETVAR v x) = "SETVAR("++show v++","++show x++")"
instr (GETLENGTH) = "GETLENGTH"
instr (POP v) = "POP("++show v++")"
instr (POPARG v) = "POPARG("++show v++")"
instr (POPINDEX v) = "POPINDEX("++show v++")"
instr (ARRAY v) = "ARRAY("++show v++")"
instr (PROJ v i) = "PROJ("++show v ++ ","++show i++")"
instr DISCARD = "DISCARD"
-- A CASE becomes a C++ switch on TAG; the alternatives are numbered
-- from 0 in list order by instrCases.
instr (CASE as) = "switch(TAG) {\n" ++ (instrCases 0 as) ++ "\t}"
instr STR2INT = "STR2INT"
instr INT2STR = "INT2STR"
instr REAL2STR = "REAL2STR"
instr BOOL2STR = "BOOL2STR"
instr STR2REAL = "STR2REAL"
instr CHR2STR = "CHR2STR"
instr INT2REAL = "INT2REAL"
instr REAL2INT = "REAL2INT"
instr VMPTR = "VMPTR"
{- CIM 12/7/05 changed to KERROR to avoid clashes with MinGW -}
instr ERROR = "KERROR"
-- Catch-all: every remaining instruction is emitted as NOP.
instr _ = "NOP"

-- Render an operand of a PUSH-style instruction. Literals and
-- temporaries are wrapped in the matching MK* constructor form, and
-- variables are emitted as their shown name.
pushitem item = case item of
    NAME n -> "MKFUN(" ++ show n ++ ")"
    VAL x  -> "MKINT(" ++ show x ++ ")"
    RVAL x -> "MKREAL(" ++ show x ++ ")"
    STR x  -> "MKSTR(" ++ show x ++ ")"
    INT t  -> "MKINT(" ++ tmp t ++ ")"
    REAL t -> "MKREAL(" ++ tmp t ++ ")"
    VAR v  -> show v

-- Render a call to a foreign function as C++ text.
--
--   n    : name of the foreign function, emitted verbatim
--   lib  : unused here; kept so the call site in instr stays unchanged
--   ty   : return type, which selects how the result is wrapped (conv)
--   args : argument types, which select how each popped value is
--          unwrapped before being passed (stackconv')
--
-- The output first pops one value per argument into tmpval 0, 1, ...,
-- then emits <wrapper>(<n>(<args>)). Every string produced by conv
-- leaves exactly one parenthesis open, so together with the two "("
-- added here the trailing ")))" always balances.
mkfcall n lib ty args = popvals 0 (length args) ++
                        conv ty ++ "(" ++ n ++ "(" ++ stackconv 0 args ++ ")))"
 where
    -- One pop statement per argument, counting temporaries up from i.
    -- Uses an ordinary counter rather than an n+k pattern, which was
    -- removed from the language in Haskell 2010; the count comes from
    -- 'length', so it is never negative and the output is unchanged.
    popvals _ 0 = ""
    popvals i k = show (tmpval i) ++
                  " = vm->doPop(); " ++ popvals (i+1) (k-1)

    -- How the foreign function's result is handed back.
    -- CIM 12/7/05 changed to KVOID to avoid clashes with MinGW
    conv (Prim Void) = "KVOID("
    conv (Prim Number) = "PUSH(MKINT"
    conv (Prim RealNum) = "PUSH(MKREAL"
    conv (Prim Boolean) = "PUSH(MKINT"
    conv (Prim Character) = "PUSH(MKCHAR"
    conv (Prim StringType) = "PUSH(MKSTR"
    conv (Prim File) = "PUSH(MKINT"
    conv (Prim Pointer) = "PUSH(MKINT"
    conv (TyVar _) = "PUSH(" -- Enough rope to hang yourself with!
    conv (Array _) = "PUSH(MKARRAYVAL"
    conv t = "PUSH(" -- error $ "Can't deal with that type in foreign calls" ++ show t

    -- Comma-separated argument expressions, one per argument type, each
    -- reading from the temporary that popvals filled for that position.
    stackconv _ [] = ""
    stackconv i [x] = stackconv' i x
    stackconv i (x:xs) = stackconv' i x ++ "," ++ stackconv (i+1) xs

    -- Unwrap the value held in temporary i according to the argument type.
    -- Types with no specific accessor are passed through as they are.
    stackconv' i argty = case argty of
        Prim Number     -> val ++ "->getInt()"
        Prim RealNum    -> val ++ "->getReal()"
        Prim Boolean    -> val ++ "->getInt()"
        Prim Character  -> val ++ "->getInt()"
        Prim StringType -> val ++ "->getString()->getVal()"
        Prim File       -> "(FILE*)(" ++ val ++ "->getRaw())"
        Prim Pointer    -> val ++ "->getRaw()"
        Array _         -> val ++ "->getArray()"
        TyVar _         -> val
        _               -> val -- error $ "Can't deal with that type (" ++ show t ++ ") in foreign calls"
      where
        val = show (tmpval i)



-- | Render the alternatives of a switch statement. Alternatives are
-- numbered consecutively starting from the given tag, and each one ends
-- with a @break@.
instrCases :: Int -> [[TAC]] -> String
instrCases v alts =
    concat [ "\tcase " ++ show tag ++ ":\n" ++ cpp body ++ "\tbreak;\n"
           | (tag, body) <- zip [v..] alts ]
