Browse Source

support proper case expressions and constructors

master
Amélia Liao 4 years ago
parent
commit
e08eae14db
7 changed files with 197 additions and 90 deletions
  1. +45
    -5
      assemble.ml
  2. +107
    -28
      compile.ml
  3. +10
    -6
      driver.ml
  4. +10
    -38
      lang.ml
  5. +14
    -2
      parser.ml
  6. +11
    -5
      preamble.lua
  7. +0
    -6
      test.hs

+ 45
- 5
assemble.ml View File

@ -7,6 +7,7 @@ open import "./lang.ml"
let resolve_addr = function let resolve_addr = function
| Combinator n -> n ^ "_combinator" | Combinator n -> n ^ "_combinator"
| Arg i -> "stack[sp - " ^ show (i + 1) ^ "][3]" | Arg i -> "stack[sp - " ^ show (i + 1) ^ "][3]"
| Local i -> "stack[sp - " ^ show i ^ "]"
| Int i -> show i | Int i -> show i
let rec gm2lua = function let rec gm2lua = function
@ -15,22 +16,61 @@ let rec gm2lua = function
| Pop n -> | Pop n ->
" sp = sp - " ^ show n " sp = sp - " ^ show n
| Update n -> | Update n ->
" stack[sp - " ^ show (n + 1) ^ "] = { I, stack[sp] }; sp = sp - 1"
let it = "stack[sp - " ^ show (n + 1) ^ "]"
" if type(" ^ it ^ ") == 'table' then\n"
^ " " ^ it ^ "[1] = I; " ^ it ^ "[2] = stack[sp]\n"
^ " else " ^ it ^ " = stack[sp] end\n"
^ " sp = sp - 1"
| Mkap -> | Mkap ->
" stack[sp - 1] = { A, stack[sp - 1], stack[sp] }; sp = sp - 1" " stack[sp - 1] = { A, stack[sp - 1], stack[sp] }; sp = sp - 1"
| Unwind -> | Unwind ->
" return unwind(stack, sp)" " return unwind(stack, sp)"
| Eval -> " stack[sp] = eval(stack[sp])"
| Eval -> " stack[sp] = eval(stack, sp)"
| Add -> " stack[sp - 1] = stack[sp - 1] + stack[sp]; sp = sp - 1" | Add -> " stack[sp - 1] = stack[sp - 1] + stack[sp]; sp = sp - 1"
| Sub -> " stack[sp - 1] = stack[sp - 1] - stack[sp]; sp = sp - 1" | Sub -> " stack[sp - 1] = stack[sp - 1] - stack[sp]; sp = sp - 1"
| Div -> " stack[sp - 1] = stack[sp - 1] / stack[sp]; sp = sp - 1" | Div -> " stack[sp - 1] = stack[sp - 1] / stack[sp]; sp = sp - 1"
| Mul -> " stack[sp - 1] = stack[sp - 1] * stack[sp]; sp = sp - 1" | Mul -> " stack[sp - 1] = stack[sp - 1] * stack[sp]; sp = sp - 1"
| Alloc lim ->
let rec go acc n =
if n > 0 then
go (acc ^ ";\n stack[sp + " ^ show n ^ "] = {}") (n - 1)
else
acc ^ "; sp = sp + " ^ show lim
go "--" lim
| Slide n ->
" stack[sp - " ^ show n ^ "] = stack[sp]; sp = sp - " ^ show n
| Iszero (yes, no) -> | Iszero (yes, no) ->
" if stack[sp] == 0 then\n" " if stack[sp] == 0 then\n"
^ foldl (fun x i -> x ^ " " ^ gm2lua i) "" yes ^ "\n"
^ foldl (fun x i -> x ^ " " ^ gm2lua i ^ ";\n") "" yes
^ " else\n" ^ " else\n"
^ foldl (fun x i -> x ^ " " ^ gm2lua i) "" no ^ "\n"
^ foldl (fun x i -> x ^ " " ^ gm2lua i ^ ";\n") "" no
^ " end" ^ " end"
| Pack (arity, tag) ->
let rec go acc i =
if i > 0 then
go (acc ^ ", stack[sp - " ^ show (i - 1) ^ "]") (i - 1)
else
acc
let values = go "" arity
" stack[sp + 1] = {" ^ show tag ^ values ^ "}; sp = sp + 1"
| Casejump alts ->
let rec go con = function
| [] -> " error('unmatched case')"
| Cons ((arity, code : list _), alts) ->
(* Where is the constructor? stack[sp], then it moves to
* stack[sp - 1]. Generally: stack[sp - i], 0 <= i < arity *)
let rec go_arg i =
if i < arity then
" stack[sp + 1] = stack[sp - " ^ show i ^ "][" ^ show (i + 2) ^ "]; sp = sp + 1;\n"
^ go_arg (i + 1)
else
foldl (fun x i -> x ^ " " ^ gm2lua i ^ ";\n") "" code
" if stack[sp][1] == " ^ show con ^ " then\n"
^ go_arg 0
^ " else"
^ go (con + 1) alts
^ " end"
go 0 alts
let compute_local_set xs = let compute_local_set xs =
let rec go i (s : S.t string) = function let rec go i (s : S.t string) = function
@ -50,7 +90,7 @@ let compute_local_set xs =
let sc2lua (name, arity, body) = let sc2lua (name, arity, body) =
let body = let body =
body body
|> foldl (fun x s -> x ^ gm2lua s ^ ";\n") (name ^ " = function(stack, sp)\n")
|> foldl (fun x s -> x ^ "-- " ^ show s ^ "\n" ^ gm2lua s ^ ";\n") (name ^ " = function(stack, sp)\n")
|> (^ "end") |> (^ "end")
let dec = let dec =
name ^ "_combinator = { F, " ^ name ^ ", " ^ show arity ^ ", " ^ show name ^ " };" name ^ "_combinator = { F, " ^ name ^ ", " ^ show arity ^ ", " ^ show name ^ " };"


+ 107
- 28
compile.ml View File

@ -7,6 +7,7 @@ open import "./lib/monads.ml"
type addr = type addr =
| Combinator of string | Combinator of string
| Local of int
| Arg of int | Arg of int
| Int of int | Int of int
@ -14,32 +15,47 @@ type gm_code =
| Push of addr | Push of addr
| Update of int | Update of int
| Pop of int | Pop of int
| Slide of int
| Alloc of int
| Unwind | Unwind
| Mkap | Mkap
| Add | Sub | Mul | Div | Eval | Add | Sub | Mul | Div | Eval
| Iszero of list gm_code * list gm_code | Iszero of list gm_code * list gm_code
| Pack of int * int
| Casejump of list (int * list gm_code)
instance show gm_code begin instance show gm_code begin
let show = function let show = function
| Mkap -> "Mkap" | Mkap -> "Mkap"
| Unwind -> "Unwind" | Unwind -> "Unwind"
| Push (Combinator k) -> "Push " ^ k | Push (Combinator k) -> "Push " ^ k
| Push (Arg i) -> "Arg " ^ show i
| Push (Int i) -> "Int " ^ show i
| Push (Arg i) -> "Pusharg " ^ show i
| Push (Local i) -> "Pushlocal " ^ show i
| Push (Int i) -> "Pushint " ^ show i
| Update n -> "Update " ^ show n | Update n -> "Update " ^ show n
| Pop n -> "Pop " ^ show n
| Pop n -> "Pop " ^ show n
| Slide n -> "Slide " ^ show n
| Alloc n -> "Alloc " ^ show n
| Add -> "Add" | Add -> "Add"
| Mul -> "Mul" | Mul -> "Mul"
| Sub -> "Sub" | Sub -> "Sub"
| Div -> "Div" | Div -> "Div"
| Eval -> "Eval" | Eval -> "Eval"
| Iszero p -> "Iszero " ^ show p
| Pack (arity, tag) -> "Pack{" ^ show arity ^ "," ^ show tag ^ "}"
| Casejump xs -> "Casejump " ^ show xs
| Iszero xs -> "Iszero " ^ show xs
end end
type program_item = type program_item =
| Sc of string * int * list gm_code | Sc of string * int * list gm_code
| Fd of fdecl | Fd of fdecl
instance show program_item begin
let show = function
| Sc p -> show p
| Fd _ -> "<foreign item>"
end
let rec lambda_lift = function let rec lambda_lift = function
| Ref v -> pure (Ref v) | Ref v -> pure (Ref v)
| Lit v -> pure (Lit v) | Lit v -> pure (Lit v)
@ -60,10 +76,24 @@ let rec lambda_lift = function
put (i + 1, Decl def :: defs, known_sc) put (i + 1, Decl def :: defs, known_sc)
|> map (const app) |> map (const app)
| Case (sc, alts) -> | Case (sc, alts) ->
alts
|> map (fun (_, x) -> x)
|> foldl app sc
|> lambda_lift
let! sc = lambda_lift sc
let! alts = traverse (fun (c, args, e) -> (c,args,) <$> lambda_lift e) alts
let case = Case (sc, alts)
let! (i, defs, known_sc) = get
let vars =
case
|> free_vars
|> flip S.difference known_sc
|> S.members
let def = ("Lam" ^ show i, vars, case)
let app = foldl (fun f -> app f # Ref) (Ref ("Lam" ^ show i)) vars
put (i + 1, Decl def :: defs, known_sc)
|> map (const app)
| Let (vs, e) ->
let! vs = flip traverse vs @@ fun (v, e) ->
(v,) <$> lambda_lift e
let! e = lambda_lift e
pure (Let (vs, e))
let rec eta_contract = function let rec eta_contract = function
| Decl (n, a, e) as dec -> | Decl (n, a, e) as dec ->
@ -87,7 +117,9 @@ let rec lambda_lift_sc = function
let! _ = modify (fun (a, b, s) -> (a, b, S.insert n s)) let! _ = modify (fun (a, b, s) -> (a, b, S.insert n s))
pure (Decl (n, a, e)) pure (Decl (n, a, e))
| Data c -> Data c |> pure | Data c -> Data c |> pure
| Foreign i -> Foreign i |> pure
| Foreign (Fimport { var } as i) ->
let! _ = modify (second (second (S.insert var)))
Foreign i |> pure
type dlist 'a <- list 'a -> list 'a type dlist 'a <- list 'a -> list 'a
@ -99,42 +131,88 @@ let cg_prim (Fimport { var, fent }) =
, Push (Arg 2), Eval (* y, x, arg0, arg1, arg2, arg3 *) , Push (Arg 2), Eval (* y, x, arg0, arg1, arg2, arg3 *)
, Sub (* y - x, arg0, arg1, arg2, arg3 *) , Sub (* y - x, arg0, arg1, arg2, arg3 *)
, Iszero ([ Push (Arg 3) ], [ Push (Arg 4) ]) , Iszero ([ Push (Arg 3) ], [ Push (Arg 4) ])
, Push (Int 0), Mkap, Update 4, Pop 4, Unwind ]
, Update 4, Pop 4, Unwind ]
match fent with match fent with
| "add" -> Sc (var, 2, prim_math_op Add)
| "sub" -> Sc (var, 2, prim_math_op Sub)
| "mul" -> Sc (var, 2, prim_math_op Mul)
| "div" -> Sc (var, 2, prim_math_op Div)
| "equ" -> Sc (var, 4, prim_equality)
| "add" -> (Sc (var, 2, prim_math_op Add), Add)
| "sub" -> (Sc (var, 2, prim_math_op Sub), Sub)
| "mul" -> (Sc (var, 2, prim_math_op Mul), Mul)
| "div" -> (Sc (var, 2, prim_math_op Div), Div)
| "equ" -> (Sc (var, 4, prim_equality), Unwind)
| "seq" -> (Sc (var, 2, [ Push (Arg 0), Eval, Update 0, Push (Arg 2), Update 2, Pop 2, Unwind]), Eval)
| e -> error @@ "No such primitive " ^ e | e -> error @@ "No such primitive " ^ e
let rec compile (env : M.t string int) = function
type slot = As of int | Ls of int
let offs n = function
| As x -> As (x + n)
| Ls x -> Ls (x + n)
let incr = offs 1
let rec compile (env : M.t string slot) = function
| Ref v -> | Ref v ->
match M.lookup v env with match M.lookup v env with
| Some i -> (Push (Arg i) ::)
| Some (As i) -> (Push (Arg i) ::)
| Some (Ls i) -> (Push (Local i) ::)
| None -> (Push (Combinator v) ::) | None -> (Push (Combinator v) ::)
| App (f, x) -> | App (f, x) ->
let f = compile env f let f = compile env f
let x = compile (map (1 +) env) x
let x = compile (map incr env) x
f # x # (Mkap ::) f # x # (Mkap ::)
| Lam _ -> | Lam _ ->
error "Can not compile lambda expression, did you forget to lift?" error "Can not compile lambda expression, did you forget to lift?"
| Case _ ->
error "Can not compile case expression, did you forget to lift?"
| Case (sc, alts) ->
let rec go_alts = function
| [] -> []
| Cons ((_, args, exp), rest) ->
let c_arity = length args
let env =
args
|> flip zip [Ls k | with k <- [c_arity - 1, c_arity - 2 .. 0]]
|> M.from_list
|> M.union (offs (c_arity + 1) <$> env)
(c_arity, compile env exp [Slide c_arity]) :: go_alts rest
compile env sc # (Eval ::) # (Casejump (go_alts alts) ::)
| Lit i -> (Push (Int i) ::) | Lit i -> (Push (Int i) ::)
| Let (vs, e) ->
let n = length vs
let env =
vs
|> map (fun (x, _) -> x)
|> flip zip [Ls x | with x <- [n - 1, n - 2 .. 0]]
|> M.from_list
|> M.union (offs n <$> env)
let defs = zip [1..n] vs
let rec go : list (int * string * expr) -> dlist gm_code = function
| [] -> id
| Cons ((k, (_, exp)), rest) ->
compile env exp # (Update (n - k) ::) # go rest
(Alloc n ::) # go defs # compile env e # (Slide n ::)
let supercomb (_, args, exp) = let supercomb (_, args, exp) =
let env = M.from_list (zip args [0..length args]) let env = M.from_list (zip args [0..length args])
let k = compile (M.from_list (zip args [0..length args])) exp
let k = compile (M.from_list (zip args (As <$> [0..length args]))) exp
k [Update (length env), Pop (length env), Unwind] k [Update (length env), Pop (length env), Unwind]
let known_scs = S.from_list [ "getchar", "putchar" ]
let compile_cons =
let rec go i = function
| [] -> []
| Cons (Constr (n, args), rest) ->
let arity = length args
let rec pushargs i =
if i < arity then
Push (Arg (2 * i)) :: pushargs (i + 1)
else
[]
Sc (n, arity, pushargs 0 ++ [ Pack (arity, i), Update (2 * arity), Pop (2 * arity), Unwind ])
:: go (i + 1) rest
go 0
let program decs = let program decs =
let (decs, (_, lams, _)) = let (decs, (_, lams, _)) =
run_state (traverse (lambda_lift_sc # eta_contract) decs) (0, [], known_scs)
run_state (traverse (lambda_lift_sc # eta_contract) decs)
(0, [], S.empty)
let define nm = let define nm =
let! x = get let! x = get
if nm `S.member` x then if nm `S.member` x then
@ -147,11 +225,12 @@ let program decs =
| Decl ((nm, args, _) as sc) -> | Decl ((nm, args, _) as sc) ->
let! _ = define nm let! _ = define nm
let code = supercomb sc let code = supercomb sc
Sc (nm, length args, code) |> pure
| Data _ -> error "data declaration in compiler"
[Sc (nm, length args, code)] |> pure
| Data (_, _, cs) -> pure (compile_cons cs)
| Foreign (Fimport { cc = Prim, var } as fi) -> | Foreign (Fimport { cc = Prim, var } as fi) ->
let! _ = define var let! _ = define var
pure (cg_prim fi)
| Foreign f -> pure (Fd f)
let (code, _) = cg_prim fi
pure [code]
| Foreign f -> pure [Fd f]
let (out, _) = run_state go S.empty let (out, _) = run_state go S.empty
out
join out

+ 10
- 6
driver.ml View File

@ -4,6 +4,8 @@ open import "./parser.ml"
open import "prelude.ml" open import "prelude.ml"
open import "lua/io.ml" open import "lua/io.ml"
external val dofile : string -> () = "dofile"
let printerror (e, { line, col }) = let printerror (e, { line, col }) =
put_line @@ "line " ^ show line ^ ", col " ^ show col ^ ":" put_line @@ "line " ^ show line ^ ", col " ^ show col ^ ":"
print e print e
@ -16,7 +18,6 @@ let go infile outfile =
match lex prog str with match lex prog str with
| Right (ds, _) -> | Right (ds, _) ->
ds ds
|> ds_prog
|> C.program |> C.program
|> A.assm_program |> A.assm_program
|> write_bytes outfile |> write_bytes outfile
@ -25,14 +26,17 @@ let go infile outfile =
close_file infile close_file infile
close_file outfile close_file outfile
let go' infile outfile =
go infile outfile
dofile outfile
let test str = let test str =
match lex prog str with match lex prog str with
| Right (ds, _) -> | Right (ds, _) ->
ds
|> ds_prog
|> C.program
|> A.assm_program
|> put_line
let code = ds |> C.program
let lua = code |> A.assm_program
print code
put_line lua
| Left e -> printerror e | Left e -> printerror e
let test_file infile = let test_file infile =


+ 10
- 38
lang.ml View File

@ -6,8 +6,9 @@ type expr =
| Ref of string | Ref of string
| App of expr * expr | App of expr * expr
| Lam of string * expr | Lam of string * expr
| Case of expr * list (string * expr)
| Case of expr * list (string * list string * expr)
| Lit of int | Lit of int
| Let of list (string * expr) * expr
let app = curry App let app = curry App
let lam = curry Lam let lam = curry Lam
@ -18,21 +19,18 @@ let rec free_vars = function
| Lam (v, x) -> v `S.delete` free_vars x | Lam (v, x) -> v `S.delete` free_vars x
| Case (e, bs) -> | Case (e, bs) ->
bs bs
|> map (fun (_, e) -> free_vars e)
|> map (fun (_, a, e) -> free_vars e `S.difference` S.from_list a)
|> foldl S.union S.empty |> foldl S.union S.empty
|> S.union (free_vars e) |> S.union (free_vars e)
| Let (vs, b) ->
let bound = S.from_list (map (fun (x, _) -> x) vs)
vs
|> map (fun (_, e) -> free_vars e)
|> foldl S.union S.empty
|> S.union (free_vars b)
|> flip S.difference bound
| Lit _ -> S.empty | Lit _ -> S.empty
let rec subst m = function
| Ref v ->
match M.lookup v m with
| Some s -> s
| None -> Ref v
| App (f, x) -> App (subst m f, subst m x)
| Lam (v, x) -> Lam (v, subst (M.delete v m) x)
| Case (e, xs) -> Case (subst m e, map (second (subst m)) xs)
| Lit x -> Lit x
type hstype = type hstype =
| Tycon of string | Tycon of string
| Tyvar of string | Tyvar of string
@ -62,29 +60,3 @@ type decl =
| Foreign of fdecl | Foreign of fdecl
type prog <- list decl type prog <- list decl
let rec xs !! i =
match xs, i with
| Cons (x, _), 0 -> x
| Cons (_, xs), i when i > 0 -> xs !! (i - 1)
| _, _ -> error "empty list and/or negative index"
let ds_data (_, _, cs) =
let ncons = length cs
let alts = map (("c" ^) # show) [1..ncons]
let ds_con i (Constr (n, args)) =
let arity = length args
let alt = alts !! i
let args = map (("x" ^) # show) [1..arity]
Decl (n, args, foldr lam (foldl app (Ref alt) (map Ref args)) alts)
let rec go i = function
| [] -> []
| Cons (x, xs) -> ds_con i x :: go (i + 1) xs
go 0 cs
let ds_prog prog =
let! c = prog
match c with
| Data c -> ds_data c
| Decl (n, args, e) -> [Decl (n, args, e)]
| Foreign d -> [Foreign d]

+ 14
- 2
parser.ml View File

@ -32,10 +32,22 @@ and expr : forall 'm. monad 'm => parser_t 'm expr =
let! vs = many (try varid) let! vs = many (try varid)
let! _ = arrow let! _ = arrow
let! e = expr let! e = expr
pure (c, foldr ((Lam #) # curry id) e vs)
pure (c, vs, e)
) )
pure (Case (e, arms)) pure (Case (e, arms))
try lam <|> try case <+> fexp
let hslet =
let binding =
let! c = varid
let! vs = many (try varid)
let! _ = equals
let! e = expr
pure (c, foldr ((Lam #) # curry id) e vs)
let! _ = keyword "let"
let! bs = laid_out_block binding
let! _ = keyword "in"
let! b = expr
pure (Let (bs, b))
try lam <|> try case <|> try hslet <+> fexp
let rec ty_atom : forall 'm. monad 'm => parser_t 'm hstype = let rec ty_atom : forall 'm. monad 'm => parser_t 'm hstype =
map Tyvar (try varid) map Tyvar (try varid)


+ 11
- 5
preamble.lua View File

@ -16,6 +16,7 @@ local function unwind(stack, sp)
error("insufficient arguments for supercombinator " .. x[4]) error("insufficient arguments for supercombinator " .. x[4])
end end
end end
return x
else else
return x, stack, sp return x, stack, sp
end end
@ -24,21 +25,26 @@ end
local function repr(x) local function repr(x)
if type(x) == 'table' then if type(x) == 'table' then
if x[1] == A then if x[1] == A then
return repr(x[2]) .. '(' .. repr(x[3])
return repr(x[2]) .. '(' .. repr(x[3]) .. ')'
elseif x[1] == F then elseif x[1] == F then
return x[4] return x[4]
elseif x[1] == I then elseif x[1] == I then
return '&' .. repr(x[2]) return '&' .. repr(x[2])
end end
return '<bad node>'
local r = {}
for k, v in pairs(x) do
r[k] = repr(v)
end
return '{' .. table.concat(r, ', ') .. '}'
else else
return tostring(x) return tostring(x)
end end
end end
local function eval(node)
local stack, sp = { node }, 1
return (unwind(stack, sp))
local function eval(stack, sp)
local nf = (unwind({ stack[sp] }, 1))
stack[sp] = { I, nf }
return nf
end end
local function getchar(stack, sp) local function getchar(stack, sp)


+ 0
- 6
test.hs View File

@ -1,6 +0,0 @@
data List a = Nil | Cons a (List a);
map f xs = case xs of { Nil -> Nil; Cons x xs -> Cons (f x) (map f xs) };
readall k = getchar (\ch -> readall (\xs -> k (Cons ch xs))) (\ch -> k Nil);
putall x xs = case xs of { Nil -> x; Cons x xs -> putchar x (\ch -> putall x xs) };
id x = x;
main x = readall (putall x);

Loading…
Cancel
Save