Browse Source

add a new code-gen based on the STG machine

master
Amélia Liao 3 years ago
parent
commit
5595b4946a
8 changed files with 563 additions and 12 deletions
  1. +13
    -7
      assemble.ml
  2. +2
    -2
      compile.ml
  3. +32
    -0
      driver.ml
  4. +2
    -2
      lib/parsers.ml
  5. +186
    -0
      stg/lower.ml
  6. +270
    -0
      stg/output.ml
  7. +53
    -0
      stg/stg.ml
  8. +5
    -1
      tc.ml

+ 13
- 7
assemble.ml View File

@ -96,6 +96,8 @@ let sc2lua (name, arity, body) =
name ^ "_combinator = { F, " ^ name ^ ", " ^ show arity ^ ", " ^ show name ^ " };"
body ^ "\n" ^ dec
let private pasted_files : ref (S.t string) = ref S.empty
let foreign2lua (Fimport { cc, fent = fspec, var, ftype }) =
let (file, fspec) =
match Strings.split_on " " fspec with
@ -123,12 +125,16 @@ let foreign2lua (Fimport { cc, fent = fspec, var, ftype }) =
let contents =
match file with
| Some path ->
let f = open_for_reading path
let x = read_all f
close_file f
match x with
| Some s -> "--- " ^ path ^ "\n" ^ s ^ "\n"
| None -> ""
if path `S.member` !pasted_files then
""
else
pasted_files := S.insert path !pasted_files
let f = open_for_reading path
let x = read_all f
close_file f
match x with
| Some s -> "--- " ^ path ^ "\n" ^ s ^ "\n"
| None -> ""
| None -> ""
contents ^ wrapper ^ "\n" ^ dec
@ -153,4 +159,4 @@ let assm_program decs =
let local_decs =
foldl (fun x v -> x ^ ", " ^ v) ("local " ^ local1) locals
let body = foldl (fun x s -> x ^ codegen s ^ "\n") "" decs
preamble ^ local_decs ^ "\n" ^ body ^ "unwind({{ A, main_combinator, 123 }}, 1)"
preamble ^ local_decs ^ "\n" ^ body ^ "unwind({main_combinator}, 1)"

+ 2
- 2
compile.ml View File

@ -89,8 +89,8 @@ let rec lambda_lift strict = function
|> free_vars
|> flip S.difference known_sc
|> S.members
let def = ("Lam" ^ show i, vars, case)
let app = foldl (fun f -> app f # Ref) (Ref ("Lam" ^ show i)) vars
let def = ("Case" ^ show i, vars, case)
let app = foldl (fun f -> app f # Ref) (Ref ("Case" ^ show i)) vars
put (i + 1, Decl def :: defs, known_sc)
|> map (const app)
| Let (vs, e) ->


+ 32
- 0
driver.ml View File

@ -5,6 +5,9 @@ open import "./parser.ml"
open import "prelude.ml"
open import "lua/io.ml"
module Stg = import "./stg/lower.ml"
module Out = import "./stg/output.ml"
external val dofile : string -> () = "dofile"
let printerror (e, { line, col }) =
@ -53,6 +56,35 @@ let test_file infile =
| None -> ()
close_file infile
let rec take n xs =
match n, xs with
| _, [] -> []
| 0, _ -> []
| n, Cons (x, xs) -> Cons (x, take (n - 1) xs)
let go_stg infile outfile =
let infile = open_for_reading infile
let outfile = open_file outfile Write_m
match read_all infile with
| Some str ->
match lex prog str with
| Right (ds, _) ->
let decs =
ds |> T.tc_program [] []
|> fun (_, _, z) -> z
|> flip (>>=) Stg.lower_dec
let (_, sts, locals) = foldl Out.stmts_of_def (M.empty, [], []) decs
write_bytes outfile "local Constr_mt = { __call = function(x) return x end }\n"
Out.get_file_contents () |> (^"\n") |> write_bytes outfile
write_bytes outfile (Out.mk_pap_def ^ "\n")
write_bytes outfile (Out.ppr_stmt "" (Out.Local (take 100 locals, [])) ^ "\n")
iter (write_bytes outfile # (^"\n") # Out.ppr_stmt "") sts
write_bytes outfile "main_entry(function() return 'the real world is fake' end)\n"
| Left e ->
printerror e
| None -> ()
close_file infile
close_file outfile
external val args : string * string =
"{ _1 = select(1, ...), _2 = select(2, ...) }"


+ 2
- 2
lib/parsers.ml View File

@ -234,7 +234,7 @@ let chainr1 p op =
pure (f x y)
) <|> pure x
let _ = rest (* shut up, amc *)
scan
force scan
let parse (P x) s =
let! x = x { line = 0, col = 0 } s
@ -254,5 +254,5 @@ let lift m = P @@ fun pos s ->
let! x = m
pure @@ Ok (E, x, s, pos)
let morph (k : forall 'a. 'm 'a -> 'n 'a) (P x) = P @@ fun s -> k (x s)
let morph (k : forall 'a. 'm 'a -> 'n 'a) (P x) = P @@ fun p s -> k (x p s)
let void x = map (const ()) x

+ 186
- 0
stg/lower.ml View File

@ -0,0 +1,186 @@
open import "prelude.ml"
module Src = import "../lang.ml"
open Src
module Map = import "data/map.ml"
module Set = import "data/set.ml"
module Stg = import "./stg.ml"
let spine f =
let rec spine = function
| App (f, x) ->
let (f, args) = spine f
(f, x :: args)
| e -> (e, [])
let (f, args) = spine f
(f, reverse args)
let rec napp f = function
| [] -> f
| Cons (x, xs) -> napp (App (f, x)) xs
let get_con_arities prog =
let go_con m (Constr (name, tys)) = M.insert name (length tys) m
let go m = function
| Data (_, _, cons) -> foldl go_con m cons
| _ -> m
foldl go M.empty prog
let gensym =
let counter = ref 0
fun () ->
counter := !counter + 1
"_s" ^ show !counter
let rec add_n_args da exp =
if da <= 0 then
exp
else
let var = gensym ()
Lam (var, App (add_n_args (da - 1) exp, Ref var))
let rec eta_expand_cons arities =
let rec go = function
| Case (exp, alts) -> Case (go exp, map (second (second go)) alts)
| Lam (var, alts) -> Lam (var, go alts)
| If (a, b, c) -> If (go a, go b, go c)
| Let (decs, body) -> Let (map (second go) decs, go body)
| exp ->
match spine exp with
| Ref func, args ->
let arg_len = length args
match Map.lookup func arities with
| Some arity when arity > arg_len -> add_n_args (arity - arg_len) (napp (Ref func) args)
| _ -> exp
| _, _ -> error @@ "What?"
go
let build_stg_app func = function
| [] -> Stg.Atom func
| args -> Stg.(App (func, args))
let mk_lambda_form name exp =
let free_vars = Stg.stg_fv exp
{ name, free_vars, update = Stg.Updatable, args = [], body = exp }
let mk_function name args exp =
let free_vars = foldl (flip Set.delete) (Stg.stg_fv exp) args
{ name, free_vars, update = Stg.Function, args, body = exp }
let rec unlambda = function
| Lam (var, body) ->
let (args, body) = unlambda body
(var :: args, body)
| e -> ([], e)
let rec lower_spine (func, args) kont =
lower_atom func @@ fun func ->
let rec go kont lowered = function
| [] -> kont (build_stg_app func (reverse lowered))
| Cons (Ref e, args) ->
go kont (Stg.Var e :: lowered) args
| Cons (Lit i, args) ->
go kont (Stg.Int i :: lowered) args
| Cons (arg, args) ->
lower_atom arg @@ fun arg ->
go kont (arg :: lowered) args
go kont [] args
and lower exp kont =
match spine exp with
| exp, [] ->
match exp with
| App _ -> error @@ "Impossible lower App with empty spine"
(* STG atoms *)
| Ref e -> kont Stg.(Atom (Var e))
| Lit e -> kont Stg.(Con (0, 1, [Int e]))
(* Lambdas need to be bound as lambda-forms *)
| Lam _ as lam ->
let name = gensym ()
let (args, body) = unlambda lam
let body = lower_body body
Stg.Let ([mk_function name args body], kont Stg.(Atom (Var name)))
| If (cond, th, el) ->
lower cond @@ fun cond ->
lower th @@ fun th ->
lower el @@ fun el ->
Stg.( Case (cond, "binder" ^ gensym(), [(Con_pat (0, []), th), (Default, el)]) )
|> kont
| Let (bindings, body) ->
lower_binds bindings @@ fun lambda_forms ->
Stg.Let (lambda_forms, lower body kont)
| Case (scrut, arms) ->
lower scrut @@ fun scrut ->
lower_arms arms @@ fun arms ->
Stg.Case (scrut, "binder" ^ gensym(), arms) |> kont
| e -> lower_spine e kont
and lower_atom exp kont =
lower exp @@ function
| Stg.Atom at -> kont at
| e ->
let name = gensym ()
Stg.(Let ([mk_lambda_form name e], kont (Var name)))
and lower_binds bindings kont =
let rec go acc = function
| [] -> kont (reverse acc)
| Cons ((name, bind), rest) ->
go (lower_rhs name bind :: acc) rest
go [] bindings
and lower_arms arms kont =
let rec go i acc = function
| [] -> kont (reverse acc)
| Cons ((_, args, exp), rest) ->
let body = lower_body exp
go (i + 1) ((Stg.(Con_pat (i, args)), body) :: acc) rest
go 0 [] arms
and lower_rhs name exp =
match exp with
| Lam _ as lam ->
let (args, body) = unlambda lam
let body = lower_body body
mk_function name args body
| _ ->
let body = lower_body exp
mk_lambda_form name body
and lower_body exp = lower exp (fun x -> x)
let mk_stg_prim name prim =
let binary_prim x =
let open Stg
let body =
Case (Atom (Var "x"), "x",
[( Default, Case (Atom (Var "y"), "y",
[(Default, Prim (x, [Var "x", Var "y"]))]))])
Fun { name, args = ["x", "y"], body, is_con = None }
match prim with
| "add" -> binary_prim Stg.Add
| "sub" -> binary_prim Stg.Sub
| "mul" -> binary_prim Stg.Mul
| "div" -> binary_prim Stg.Div
| "equ" -> binary_prim Stg.Equ
| e -> error @@ "No such primitive " ^ e
let lower_dec = function
| Decl (name, manifest_args, expr) ->
let (args, body) = unlambda expr
let args = manifest_args ++ args
let body = lower_body body
[ Stg.Fun { name, args, body, is_con = None } ]
| Data (_, _, cons) ->
let mk_stg_con (Constr (name, args), i) =
let args = [ gensym () | with _ <- args ]
Stg.Fun { name, args, body = build_stg_app (Stg.Var name) (Stg.Var <$> args), is_con = Some i }
[ mk_stg_con c | with c <- zip cons [0 .. length cons - 1] ]
| Foreign (Fimport { cc = Prim, fent = prim, var = name }) ->
[ mk_stg_prim name prim ]
| Foreign (Fimport { cc = Lua, fent, var, ftype }) ->
[ Stg.Ffi_def { name = var, fent, arity = arity ftype }]

+ 270
- 0
stg/output.ml View File

@ -0,0 +1,270 @@
module Stg = import "./stg.ml"
module Map = import "data/map.ml"
module Set = import "data/set.ml"
module Strings = import "../lib/strings.ml"
open Stg
open import "lua/io.ml"
open import "prelude.ml"
type lua_ref 'expr =
| Lvar of string
| Lindex of lua_ref 'expr * 'expr
type lua_expr 'stmt =
| Lfunc of list string * list 'stmt
| Lcall_e of lua_expr 'stmt * list (lua_expr 'stmt)
| Lstr of string
| Lint of int
| Lref of lua_ref (lua_expr 'stmt)
| Lbop of lua_expr 'stmt * string * lua_expr 'stmt
| Ltable of list (lua_expr 'stmt * lua_expr 'stmt)
| Ltrue
| Ldots
type lua_stmt =
| Return of lua_expr lua_stmt
| Local of list string * list (lua_expr lua_stmt)
| Func of string * list string * list lua_stmt
| Assign of list (lua_ref (lua_expr lua_stmt)) * list (lua_expr lua_stmt)
| If of lua_expr lua_stmt * list lua_stmt * list lua_stmt
let rec ppr_ref indl = function
| Lvar v -> v
| Lindex (e, Lstr x) -> ppr_ref indl e ^ "." ^ x
| Lindex (e, e') -> ppr_ref indl e ^ "[" ^ ppr_expr indl e' ^ "]"
and ppr_expr indl = function
| Lfunc (args, body) ->
"function(" ^ ppr_args args ^ ")\n" ^ ppr_body (indl ^ " ") body ^ indl ^ "end"
| Lcall_e (Lref _ as func, args) ->
ppr_expr indl func ^ "(" ^ ppr_args (ppr_expr indl <$> args) ^ ")"
| Lcall_e (func, args) ->
"(" ^ ppr_expr indl func ^ ")(" ^ ppr_args (ppr_expr indl <$> args) ^ ")"
| Lstr s -> show s
| Lint i -> show i
| Ldots -> "..."
| Lref r -> ppr_ref indl r
| Ltrue -> "true"
| Lbop (l, o, r) -> ppr_expr indl l ^ " " ^ o ^ " " ^ ppr_expr indl r
| Ltable entries -> "{" ^ ppr_args (ppr_pair indl <$> entries) ^ "}"
and ppr_stmt indl = function
| Return r -> "return " ^ ppr_expr indl r
| If (c, t, []) ->
"if " ^ ppr_expr indl c ^ " then\n"
^ ppr_body (indl ^ " ") t
^ indl ^ "end"
| If (c, [], e) ->
"if not (" ^ ppr_expr indl c ^ ") then\n"
^ ppr_body (indl ^ " ") e
^ indl ^ "end"
| If (c, t, e) ->
"if " ^ ppr_expr indl c ^ " then\n"
^ ppr_body (indl ^ " ") t
^ indl ^ "else\n"
^ ppr_body (indl ^ " ") e
^ indl ^ "end"
| Local ([], []) -> ""
| Local (vs, []) -> "local " ^ ppr_args vs
| Local (vs, es) ->
"local " ^ ppr_args vs ^ " = " ^ ppr_args (ppr_expr indl <$> es)
| Assign (vs, es) ->
ppr_args (ppr_ref indl <$> vs) ^ " = " ^ ppr_args (ppr_expr indl <$> es)
| Func (n, args, body) ->
"function " ^ n ^ "(" ^ ppr_args args ^ ")\n" ^ ppr_body (indl ^ " ") body ^ indl ^ "end"
and ppr_args = function
| [] -> ""
| Cons (a, args) -> foldl (fun a b -> a ^ ", " ^ b) a args
and ppr_body indl = function
| [] -> "\n"
| Cons (a, args) ->
foldl (fun a b -> a ^ "\n" ^ indl ^ b) (indl ^ ppr_stmt indl a) (ppr_stmt indl <$> args) ^ "\n"
and ppr_pair indl (k, v) = "[" ^ ppr_expr indl k ^ "] = " ^ ppr_expr indl v
let gensym =
let counter = ref 0
fun () ->
counter := !counter + 1
"_a" ^ show !counter
let escape = function
| "nil" -> "_Lnil"
| x -> x
let var x = Lref (Lvar (escape x))
let mk_pap_def =
"\
local function mk_pap(fun, ...) \
local pending = { ... }\
return setmetatable({}, { __call = function(...) \
local args = table.pack(...)\
for i = 1, #pending do\
table.insert(args, i, pending[i])\
end\
return fun(unpack(args, 1, args.n + #pending))\
end}) \
end"
let make_lambda name args body =
let name = escape name
let args = map escape args
let arity = length args
[ Local ([name, name ^ "_entry" ], []),
Func (name ^ "_entry", args, body),
Func (name, ["..."], [
If (Lbop (Lcall_e (var "select", [Lstr "#", Ldots]), "==", Lint arity), [
Return (Lcall_e (var (name ^ "_entry"), [Ldots]))
], [
If (Lbop (Lcall_e (var "select", [Lstr "#", Ldots]), ">", Lint arity), [
Local (["_spill"], [Lcall_e (var "table.pack", [Ldots])]),
Return (Lcall_e (Lcall_e (var (name ^ "_entry"), [Ldots]),
[Lcall_e (var "table.unpack", [var "_spill", Lint arity, var "_spill.n"])]))
], [
Return (Lcall_e (var "mk_pap", [var name, Ldots]))
])])])]
let expr_of_atom = function
| Var v -> var v
| Int i -> Lfunc ([], [Return (Lint i)])
let return x = [Return x]
let rec stmts_of_expr arities = function
| Atom _ as a -> expr_of_expr arities a |> return
| App _ as a -> expr_of_expr arities a |> return
| Prim (f, xs) -> stmts_of_prim (f, expr_of_atom <$> xs)
| Con _ as a -> expr_of_expr arities a |> return
| Case (expr, binder, alts) ->
let rec make_cases = function
| [] -> [Return (Lcall_e (var "error", [Lstr "Unmatched case"]))]
| Cons ((Default, tail), _) -> stmts_of_expr arities tail
| Cons ((Con_pat (tag, names), tail), rest) ->
let accesses =
[ Lref (Lindex (Lvar binder, Lint (i + 1)))
| with i <- [1 .. length names]
]
[If (Lbop (Lref (Lindex (Lvar binder, Lint 1)), "==", Lint tag),
Local (names, accesses) :: stmts_of_expr arities tail,
make_cases rest
)]
Local ([binder], [enter arities expr]) :: make_cases alts
| Let (binders, rest) ->
let names = map (.name) binders
Local (names, []) :: gen_lambda_forms arities binders ++ stmts_of_expr arities rest
and expr_of_expr arities = function
| Atom (Var v) ->
match Map.lookup v arities with
| Some (Left (0, tag)) -> Lcall_e (var "setmetatable", [ Ltable [(Lint 1, Lint tag)], var "Constr_mt" ])
| _ -> expr_of_atom (Var v)
| Atom a -> expr_of_atom a
| App (f, xs) ->
match f with
| Int _ -> error "Attempt to call int"
| Var v ->
match Map.lookup v arities with
| Some (Right x) when x == length xs ->
(Lcall_e (var (v ^ "_entry"), expr_of_atom <$> xs))
| Some (Left (x, tag)) when x == length xs ->
let go i a = (Lint (i + 1), expr_of_atom a)
Lcall_e (var "setmetatable", [
Ltable ((Lint 1, Lint tag) :: zip_with go [1..length xs] xs),
var "Constr_mt"
])
| _ -> Lcall_e (var v, expr_of_atom <$> xs)
| Prim (f, xs) -> expr_of_prim (f, expr_of_atom <$> xs)
| Con (tag, _, atoms) ->
let go i a = (Lint (i + 1), expr_of_atom a)
Lcall_e (var "setmetatable", [
Ltable ((Lint 1, Lint tag) :: zip_with go [1..length atoms] atoms),
var "Constr_mt"
])
| e -> Lcall_e (Lfunc ([], stmts_of_expr arities e), [])
and enter arities expr =
Lcall_e (expr_of_expr arities expr, [])
and expr_of_prim = function
| Add, [a, b] -> Lfunc ([], [Return (Lbop (a, "+", b))])
| Sub, [a, b] -> Lfunc ([], [Return (Lbop (a, "-", b))])
| Mul, [a, b] -> Lfunc ([], [Return (Lbop (a, "*", b))])
| Div, [a, b] -> Lfunc ([], [Return (Lbop (a, "/", b))])
| e -> Lcall_e (Lfunc ([], stmts_of_prim e), [])
and stmts_of_prim = function
| Equ, [a, b] -> [
If (Lbop (a, "==", b),
stmts_of_expr Map.empty (Con (0, 0, [])),
stmts_of_expr Map.empty (Con (1, 0, [])))
]
| e -> expr_of_prim e |> return
and gen_lambda_forms arities : list (lambda_form stg_expr) -> list lua_stmt = function
| [] -> []
| Cons ({update = Function, name, args, body}, rest) ->
let arities = Map.insert name (Right (length args)) arities
let bst = stmts_of_expr arities body
tail (make_lambda name args bst) ++ gen_lambda_forms arities rest
| Cons ({update = Updatable, name, args, body}, rest) ->
let body = expr_of_expr arities body
let s = Assign ([Lvar name], [
Lcall_e (var "setmetatable", [
Ltable [],
Ltable [ (Lstr "__call", Lfunc (["_self"], [
If (Lref (Lindex (Lvar "_self", Lint 1)), [
Return (Lref (Lindex (Lvar "_self", Lint 1)))
], [
Local (["val"], [Lcall_e (body, [])]),
Assign ([Lindex (Lvar "_self", Lint 1)], [var "val"]),
Return (var "val")
])
]))
]
])
])
s :: gen_lambda_forms arities rest
let private pasted_files : ref (Set.t string) = ref Set.empty
let stmts_of_def (arities, code, locals) = function
| Fun { name, args, body, is_con } ->
let arities = Map.insert name (match is_con with | Some i -> Left (length args, i) | None -> Right (length args)) arities
let body = stmts_of_expr arities body
let Cons (Local (locals', _), def) = make_lambda name args body
(arities, def ++ code, locals' ++ locals)
| Ffi_def { name, fent, arity } ->
let fspec =
match Strings.split_on " " fent with
| [file, func] ->
pasted_files := Set.insert file !pasted_files
func
| [func] -> func
| _ -> error @@ "Foreign spec too big: " ^ fent
let args = [ gensym () | with _ <- [1 .. arity] ]
let Cons (Local (locals', _), def) = make_lambda name args [Return (Lcall_e (var fspec, var <$> args))]
(arities, def ++ code, locals' ++ locals)
let get_file_contents () =
let files = Set.members !pasted_files
let go contents path =
let f = open_for_reading path
let x = read_all f
close_file f
match x with
| Some s -> "--- foreign file: " ^ path ^ "\n" ^ s ^ "\n" ^ contents
| None -> contents
foldl go "" files

+ 53
- 0
stg/stg.ml View File

@ -0,0 +1,53 @@
module Set = import "data/set.ml"
open import "prelude.ml"
type update_flag = Updatable | Function
type lambda_form 'a <- { name : string, free_vars : Set.t string, args : list string, update : update_flag, body : 'a }
type stg_atom =
| Var of string
| Int of int
type stg_pattern =
| Con_pat of int * list string
| Default
type stg_primitive =
| Add
| Sub
| Mul
| Div
| Equ
type stg_expr =
| Let of list (lambda_form stg_expr) * stg_expr
| Case of stg_expr * string * list (stg_pattern * stg_expr)
| App of stg_atom * list stg_atom
| Con of int * int * list stg_atom
| Prim of stg_primitive * list stg_atom
| Atom of stg_atom
type stg_def =
| Fun of { name : string, args : list string, body : stg_expr, is_con : option int }
| Ffi_def of { name : string, fent : string, arity : int }
let stg_fv =
let rec go = function
| Atom a -> go_atom a
| Let (lfs, e) ->
let fv = go e
fv `Set.difference` Set.from_list (map (.name) lfs)
| App (a, args) -> foldl Set.union Set.empty (map go_atom (a::args))
| Con (_, _, i) -> foldl Set.union Set.empty (map go_atom i)
| Case (ex, binder, pats) ->
foldl go_pat (go ex) pats
|> Set.delete binder
| Prim (_, args) -> foldl Set.union Set.empty (map go_atom args)
and go_atom = function
| Int _ -> Set.empty
| Var e -> Set.singleton e
and go_pat set = function
| Default, e -> Set.union set (go e)
| Con_pat (_, args), e -> Set.union set (go e `Set.difference` Set.from_list args)
go

+ 5
- 1
tc.ml View File

@ -425,7 +425,11 @@ let tc_program value_exports type_exports (prog : list decl) =
| [Foreign (Fimport {var, ftype}) as def] ->
let ty_scope' = add_missing_vars M.empty ftype
let t = check_is_type (M.union ty_scope' ty_scope) ftype
(dt_info, M.insert var (Forall { vars = M.keys ty_scope', body = t } |> Poly) val_scope, ty_scope, def :: out)
(
dt_info,
M.insert var (Forall { vars = M.keys ty_scope', body = t } |> Poly) val_scope,
ty_scope, def :: out
)
| Cons (Foreign (Fimport {var}), _) ->
error @@ "Foreign definition " ^ var ^ " is part of a group. How?"
| Cons (Decl (name, args, body), ds) ->


Loading…
Cancel
Save