diff --git a/compile.ml b/compile.ml index 259a00f..c059777 100644 --- a/compile.ml +++ b/compile.ml @@ -213,24 +213,28 @@ let program decs = let (decs, (_, lams, _)) = run_state (traverse (lambda_lift_sc # eta_contract) decs) (0, [], S.empty) - let define nm = + + let define nm k = let! x = get if nm `S.member` x then - error @@ "Redefinition of value " ^ nm + pure [] else - modify (S.insert nm) + let! _ = modify (S.insert nm) + k let go = flip traverse (lams ++ decs) @@ function | Decl ((nm, args, _) as sc) -> - let! _ = define nm - let code = supercomb sc - [Sc (nm, length args, code)] |> pure + define nm ( + let code = supercomb sc + [Sc (nm, length args, code)] |> pure + ) | Data (_, _, cs) -> pure (compile_cons cs) | Foreign (Fimport { cc = Prim, var } as fi) -> - let! _ = define var - let (code, _) = cg_prim fi - pure [code] + define var ( + let (code, _) = cg_prim fi + pure [code] + ) | Foreign f -> pure [Fd f] let (out, _) = run_state go S.empty join out diff --git a/driver.ml b/driver.ml index dee5b8e..17cc99f 100644 --- a/driver.ml +++ b/driver.ml @@ -1,5 +1,6 @@ module C = import "./compile.ml" module A = import "./assemble.ml" +module T = import "./tc.ml" open import "./parser.ml" open import "prelude.ml" open import "lua/io.ml" @@ -33,7 +34,7 @@ let go' infile outfile = let test str = match lex prog str with | Right (ds, _) -> - let code = ds |> C.program + let code = ds |> T.tc_program |> C.program let lua = code |> A.assm_program print code put_line lua diff --git a/lang.ml b/lang.ml index 0d7ad15..aeabba8 100644 --- a/lang.ml +++ b/lang.ml @@ -38,6 +38,13 @@ type hstype = | Tyarr of hstype * hstype | Tytup of list hstype +let rec free_cons = function + | Tycon v -> S.singleton v + | Tyvar _ -> S.empty + | Tyapp (f, x) -> S.union (free_cons f) (free_cons x) + | Tyarr (f, x) -> S.union (free_cons f) (free_cons x) + | Tytup xs -> foldl (fun s x -> S.union s (free_cons x)) S.empty xs + let rec arity = function | Tyarr (_, i) -> 1 + arity i | _ -> 0 diff --git a/lib/graph.ml b/lib/graph.ml new file mode 100644 index 0000000..c14fd53 --- /dev/null +++ b/lib/graph.ml @@ -0,0 +1,81 @@ +module M = import "data/map.ml" +module S = import "data/set.ml" + +open import "prelude.ml" + +type t 'a <- M.t 'a (S.t 'a) + +let sccs (graph : t 'a) = + let rec dfs (node : 'a) (path : M.t 'a int) (sccs : M.t 'a 'a) = + let shallower old candidate = + match M.lookup old path, M.lookup candidate path with + | _, None -> old + | None, _ -> candidate + | Some a, Some b -> + if b < a then candidate else old + let children = + match M.lookup node graph with + | Some t -> t + | None -> error "Node not in graph?" + let go (folded, shallowest) child = + match M.lookup child path with + | Some _ -> + (folded, shallower shallowest child) + | _ -> + let scc = dfs child (M.insert node (length path) path) folded + let sfc = + match M.lookup child scc with + | Some x -> x + | None -> error "no child in scc?" + (scc, shallower shallowest sfc) + let (new, shallowest) = + S.members children |> foldl go (sccs, node) + M.insert node shallowest new + let go sccs next = + match M.lookup next sccs with + | Some _ -> sccs + | _ -> dfs next M.empty sccs + graph + |> M.keys + |> foldl go M.empty + +let toposort (graph : t 'a) : list 'a = + let nodes = M.keys graph + let l = ref [] + let temp = ref S.empty + let perm = ref S.empty + let rec visit n = + if n `S.member` !perm then + () + else if n `S.member` !temp then + error "not a dag" + else + let o_temp = !temp + temp := S.insert n o_temp + match M.lookup n graph with + | None -> () + | Some xs -> iter visit (S.members xs) + temp := o_temp + perm := S.insert n !perm + l := n :: !l + iter visit nodes + reverse !l + +let groups_of_sccs (graph : t 'a) = + let sccs = sccs graph + let edges_of n = + match M.lookup n graph with + | None -> error "not in graph" + | Some v -> v + let components = + sccs + |> M.assocs + |> map (fun (k, s) -> M.singleton s (S.singleton k)) + |> foldl (M.union_by (fun _ -> S.union)) M.empty + let comp_deps = + components + |> M.assocs + |> map (fun (node, edges) -> (node, edges_of node `S.difference` edges)) + |> M.from_list + let ordering = toposort comp_deps + [ x | with k <- ordering, with Some x <- [M.lookup k components] ] diff --git a/parser.ml b/parser.ml index bd1e3fa..df0e0df 100644 --- a/parser.ml +++ b/parser.ml @@ -60,7 +60,10 @@ and ty_fexp : forall 'm. monad 'm => parser_t 'm hstype = and ty_exp : forall 'm. monad 'm => parser_t 'm hstype = chainr1 ty_fexp (map (const (curry Tyarr)) arrow) and ty_tup : forall 'm. monad 'm => parser_t 'm hstype = - Tytup <$> sep_by comma ty_exp + let tytup = function + | [x] -> x + | x -> Tytup x + tytup <$> sep_by comma ty_exp let datadec : forall 'm. monad 'm => parser_t 'm decl = let! _ = try (keyword "data") @@ -70,9 +73,11 @@ let datadec : forall 'm. monad 'm => parser_t 'm decl = pure (Constr (nm, args)) let! nm = conid let! args = many (try varid) - let! _ = equals - let! c = sep_by_1 pipe (try datacon) - pure (Data (nm, args, c)) + let! cs = optionally ( + let! _ = equals + sep_by_1 pipe (try datacon) + ) + pure (Data (nm, args, match cs with | Some cs -> cs | None -> [])) let fdecl : forall 'm. monad 'm => parser_t 'm fdecl = let! _ = try (keyword "import") diff --git a/tc.ml b/tc.ml new file mode 100644 index 0000000..153888e --- /dev/null +++ b/tc.ml @@ -0,0 +1,452 @@ +module M = import "data/map.ml" +module G = import "./lib/graph.ml" +open import "./lang.ml" +open import "amulet/exception.ml" +open import "prelude.ml" + +type tc_tyvar 'a = Tv of { + name : string, level : int, var : ref (option 'a) +} + +instance eq (tc_tyvar 'a) begin + let Tv x == Tv y = x.name == y.name +end + +instance ord (tc_tyvar 'a) begin + let Tv x `compare` Tv y = x.name `compare` y.name +end + +type tc_kappa = + | K_arr of tc_kappa * tc_kappa + | K_star + | K_var of tc_tyvar tc_kappa + +type tc_rho = + | T_uvar of tc_tyvar tc_rho + | T_var of string + | T_con of string + | T_app of tc_rho * tc_rho + | T_arr of tc_rho * tc_rho + +instance show tc_rho begin + let show = + let rec show_arg = function + | T_app _ as x -> "(" ^ go x ^ ")" + | x -> show_domain x + and show_domain = function + | T_arr _ as x -> "(" ^ go x ^ ")" + | x -> go x + and go = function + | T_uvar (Tv n) -> + match !n.var with + | Some t -> go t + | None -> n.name + | T_var v -> v + | T_con v -> v + | T_app (f, x) -> go f ^ " " ^ show_arg x + | T_arr (a, b) -> show_domain a ^ " -> " ^ go b + go +end + +instance show tc_kappa begin + let show x = + let rec go = function + | K_star -> "*" + | K_var (Tv v) -> "?" ^ v.name + | K_arr (a, b) -> show_domain a ^ " -> " ^ go b + and show_domain = function + | K_arr _ as x -> "(" ^ show x ^ ")" + | x -> go x + go x +end + +type tc_sigma = + Forall of { + vars : list string, + body : tc_rho + } + +let rec free_unif_vars = function + | T_uvar v -> S.singleton v + | T_var _ -> S.empty + | T_con _ -> S.empty + | T_app (f, x) -> S.union (free_unif_vars f) (free_unif_vars x) + | T_arr (a, b) -> S.union (free_unif_vars a) (free_unif_vars b) + +let new_name = + let c = ref 0 + fun () -> + c := !c + 1 + "alpha" ^ show !c + +let new_tcvar level = + let name = new_name () + Tv { name, level, var = ref None } + +let rec zonk = function + | T_uvar (Tv r) as rho -> + match !r.var with + | Some rho -> zonk rho + | None -> rho + | T_var v -> T_var v + | T_con v -> T_con v + | T_app (f, x) -> T_app (zonk f, zonk x) + | T_arr (f, x) -> T_arr (zonk f, zonk x) + +let empty (Tv r) = + match !r.var with + | None -> true + | Some (T_uvar (Tv r')) -> r.name == r'.name + | _ -> false + +let generalise level rho = + let rho = zonk rho + let vars = + free_unif_vars rho + |> S.filter (fun (Tv r) -> r.level > level && empty (Tv r)) + |> S.members + flip iter vars @@ fun (Tv r) -> + r.var := Some (T_var r.name) + Forall { vars = map (fun (Tv r) -> r.name) vars, body = zonk rho } + +let rec unify a b = + let solve r s = + match !r.var with + | Some t -> unify t s + | None -> r.var := Some s + match a, b with + | T_uvar (Tv r), b -> solve r b + | a, T_uvar (Tv r) -> solve r a + | T_var a, T_var b when a == b -> () + | T_con a, T_con b when a == b -> () + | T_app (f, x), T_app (f', x') -> + unify f f' + unify x x' + | T_arr (a, b), T_arr (a', b') -> + unify a a' + unify b b' + | a, b -> error @@ "Types " ^ show a ^ " and " ^ show b ^ " are not equal" + +let rec unify_kappa a b = + let solve r s = + match !r.var with + | Some t -> unify_kappa t s + | None -> r.var := Some s + match a, b with + | K_var (Tv r), b -> solve r b + | a, K_var (Tv r) -> solve r a + | K_star, K_star -> () + | K_arr (a, b), K_arr (a', b') -> + unify_kappa a a' + unify_kappa b b' + | a, b -> error @@ "Kinds " ^ show a ^ " and " ^ show b ^ " are not equal" + +type scheme 'a = Poly of tc_sigma | Mono of 'a + +instance show 'a => show (scheme 'a) begin + let show = function + | Poly (Forall {vars,body}) -> + foldl (fun s i -> s ^ " " ^ i) "forall" vars ^ ". " ^ show body + | Mono x -> show x +end + +let mono m = function + | Mono x -> x + | Poly _ -> error @@ "Unexpected polytype " ^ m + +let get_scope map var = + match M.lookup var map with + | Some v -> v + | None -> error @@ "Name not in scope: " ^ var + +let is_function_kind level tau = + match tau with + | K_arr (a, b) -> (a, b) + | _ -> + let a = new_tcvar level |> K_var + let b = new_tcvar level |> K_var + unify_kappa tau (K_arr (a, b)) + (a, b) + +let rec infer_kind scope = function + | Tyvar v -> + let kappa = get_scope scope v |> mono "(kinds aren't ever polymorphic)" + (T_var v, kappa) + | Tycon v -> + let kappa = get_scope scope v |> mono "(kinds aren't ever polymorphic)" + (T_con v, kappa) + | Tyapp (f, x) -> + let (f, k_f) = infer_kind scope f + let (x, k_x) = infer_kind scope x + let (domain, result) = is_function_kind 0 k_f + unify_kappa domain k_x + (T_app (f, x), result) + | Tyarr (a, b) -> + let a = check_is_type scope a + let b = check_is_type scope b + (T_arr (a, b), K_star) + | Tytup [] -> (T_con "Unit#", K_star) + | _ -> error "Tuple types not supported" +and check_is_type scope t = + let (t, k) = infer_kind scope t + unify_kappa k K_star + t + +let rec default_to_star = function + | K_var (Tv r) -> + match !r.var with + | Some k -> default_to_star k + | None -> K_star + | K_star -> K_star + | K_arr (a, b) -> K_arr (default_to_star a, default_to_star b) + + +type dt_info <- + { name : string, d_args : list string, c_args : list tc_rho, c_ret : tc_rho } + +let mk_con_info (d_name : string) (d_args : list string) : list (string * list tc_rho) -> list dt_info = + let go (name, args) = + { name, c_args = args, d_args, c_ret = foldl (fun f x -> T_app (f, T_var x)) (T_con d_name) d_args } + map go + +let infer_data_group_kind scope (group : list _) = + let init_kind (group, names) (name, args, constr) = + let args = + args |> map (fun v -> (v, new_tcvar 0 |> K_var |> Mono)) + let kind = foldl (fun t (_, r) -> K_arr (t, mono "" r)) K_star args + let scope = M.from_list args + ((name, kind, constr, scope, args) :: group, M.insert name (Mono kind) names) + + let (group, scope') = foldl init_kind ([], M.empty) group + + let scope = M.union scope scope' + + let group : list (string * tc_kappa * list string * list (string * list tc_rho)) = + flip map group @@ fun (name, kind, constrs, args, args') -> + let scope = M.union scope args + constrs + |> map (fun (Constr (name, args)) -> (name, map (check_is_type scope) args)) + |> (name,kind,[x|with (x,_)<-args'],) + + flip map group @@ fun (name, kind, args, constrs) -> + (name, default_to_star kind, constrs, mk_con_info name args constrs) + +let rec subst_vars f = function + | T_var v as t -> + match f v with + | None -> t + | Some t -> t + | T_uvar (Tv v) as t -> + match !v.var with + | Some t -> subst_vars f t + | None -> t + | T_con v -> T_con v + | T_app (a, b) -> T_app (subst_vars f a, subst_vars f b) + | T_arr (a, b) -> T_arr (subst_vars f a, subst_vars f b) + +let instantiate level (Forall { vars, body }) = + let vars = + vars + |> map (fun v -> (v, new_tcvar level |> T_uvar)) + |> M.from_list + subst_vars (flip M.lookup vars) body + +let lookup_ty level scope v = + get_scope scope v |> function + | Mono t -> t + | Poly s -> instantiate level s + +let is_function_type level tau = + match tau with + | T_arr (a, b) -> (a, b) + | _ -> + let a = new_tcvar level |> T_uvar + let b = new_tcvar level |> T_uvar + unify tau (T_arr (a, b)) + (a, b) + +(* TODO: Rank-N types *) +let is_subtype = unify + +let rec infer dt_info level scope = function + | Ref v -> lookup_ty level scope v |> (Ref v,) + | App (f, x) -> + let (f, arg, res) = + infer dt_info level scope f + |> second (is_function_type level) + let x = check dt_info level scope arg x + (App (f, x), res) + | Lit i -> (Lit i, T_con "Int") + | Let (bindings, body) -> + let (bindings, scope') = + infer_binding_group dt_info level scope bindings + let (body, body_t) = infer dt_info level (scope `M.union` map force scope') body + (Let (bindings, body), body_t) + | x -> + let t = new_tcvar level |> T_uvar + let x = check dt_info level scope t x + (x, t) + +and check dt_info level scope wanted = function + | Lam (arg, body) -> + let (arg_t, body_t) = is_function_type level wanted + let body = + (* TODO: Rank-N types *) + check dt_info level (M.insert arg (Mono arg_t) scope) body_t body + Lam (arg, body) + | Case (_, []) -> error "Empty case" + | Case (scrutinee, Cons ((con, _, _), _) as patterns) -> + let data = + match M.lookup con dt_info with + | Some data -> data + | None -> error @@ "Constructor " ^ con ^ " doesn't belong to a type" + + let (scrutinee, scrut_t) = infer dt_info level scope scrutinee + + let go_arm {name, d_args, c_args, c_ret} (con, args, expr) = + if name <> con then + error @@ "Constructors out of order: expected this pattern to match " ^ name + else () + + if length c_args <> length args then + error @@ "Constructor " + ^ con ^ " has " + ^ show (length c_args) + ^ " but is being matched against with " ^ show (length args) + ^ " variables" + else () + + let d_args = + d_args + |> map (fun v -> (v, new_tcvar level |> T_uvar)) + |> M.from_list + let c_args = map (Mono # subst_vars (flip M.lookup d_args)) c_args + let c_ret = subst_vars (flip M.lookup d_args) c_ret + + unify c_ret scrut_t + + let scope' = M.from_list (zip args c_args) `M.union` scope + (con, args, check dt_info level scope' wanted expr) + + Case (scrutinee, zip_with go_arm data patterns) + | x -> + let (x, t) = infer dt_info level scope x + is_subtype t wanted + x + +and infer_binding_group dt_info level (scope : M.t string _) bindings : _ * M.t string _ = + let inner = level + 1 + let initial_types = + bindings + |> map (fun (name, _) -> (name, new_tcvar inner |> T_uvar |> Mono)) + |> M.from_list + + let initial_types = initial_types |> M.union scope + + let go_binding (bindings : list _, scope' : M.t _ _) (name : string, body : expr) = + let (body, body_ty) = + (fun () -> infer dt_info inner initial_types body) + `catch` fun (e : some exception) -> + error (describe_exception e ^ "\nwhen type checking " ^ name) + M.lookup name scope + |> function + | Some (Mono t) -> unify t body_ty + | _ -> () + ( + (name, body) :: bindings, + M.insert name (lazy (generalise level body_ty |> Poly)) scope' + ) + foldl go_binding ([], M.empty) bindings + +let dependency_graph defs = + let rec free_vars_of_cons t m (Constr (name, args)) = + let cons = + foldl (fun s t -> S.union s (free_cons t)) (S.singleton t) + args + M.insert name cons m + let define n x m = + M.alter (function + | Some _ -> error @@ "Redefinition of value " ^ n + | None -> Some x) + n m + let go (graph, defs) = function + | Foreign (Fimport { var }) as x -> + (M.insert var S.empty graph, define var x defs) + | Decl (name, args, expr) as x -> + let fvs = + free_vars expr + |> flip S.difference (S.from_list args) + |> S.delete name + (M.insert name fvs graph, define name x defs) + | Data (name, _, cons) as x -> + M.union graph (foldl (free_vars_of_cons name) M.empty cons) + |> M.insert name S.empty + |> (, define name x defs) + let (graph, defs) = foldl go (M.empty, M.empty) defs + (G.groups_of_sccs graph, defs) + +let mk_lam args body = foldr (curry Lam) body args +let rec unlambda = function + | Lam (v, x) -> + let (args, x) = unlambda x + (v :: args, x) + | e -> ([], e) + +let rec replicate n x = + if n <= 0 then + [] + else + x :: replicate (n - 1) x + +let rec add_missing_vars scope = function + | Tyvar v -> + match M.lookup v scope with + | Some _ -> scope + | None -> + let k = new_tcvar 0 |> K_var + M.insert v (Mono k) scope + | Tycon _ -> scope + | Tyapp (a, b) -> add_missing_vars (add_missing_vars scope b) a + | Tyarr (a, b) -> add_missing_vars (add_missing_vars scope b) a + | Tytup xs -> foldl add_missing_vars scope xs + +let tc_program (prog : list decl) = + let (plan, defs) = dependency_graph prog + let tc_one (dt_info, val_scope, ty_scope, out) group = + let defs = [ x | with name <- S.members group, with Some x <- [M.lookup name defs] ] + match defs with + | [] -> (dt_info, val_scope, ty_scope, defs) + | [Foreign (Fimport {var, ftype}) as def] -> + let ty_scope' = add_missing_vars M.empty ftype + let t = check_is_type (M.union ty_scope' ty_scope) ftype + (dt_info, M.insert var (Forall { vars = M.keys ty_scope', body = t } |> Poly) val_scope, ty_scope, def :: out) + | Cons (Foreign (Fimport {var}), _) -> + error @@ "Foreign definition " ^ var ^ " is part of a group. How?" + | Cons (Decl (name, args, body), ds) -> + let bindings = + (name, mk_lam args body) + :: [ (name, mk_lam args body) | with Decl (name, args, body) <- ds ] + let (bindings, scope') = infer_binding_group dt_info -1 val_scope bindings + let decs = + [ Decl (name, unlambda expr) | with (name, expr) <- bindings ] + (dt_info, M.union (map force scope') val_scope, ty_scope, decs ++ defs) + | Cons (Data d, ds) -> + let datas = d :: [ d | with Data d <- ds ] + let r = infer_data_group_kind ty_scope datas + let fix_constr (name, rhos : list tc_rho) = + Constr (name, replicate (length rhos) (Tycon "#")) + let rec go dt ty (vl : M.t string (scheme tc_rho)) ds = function + | [] -> (dt, vl, ty, reverse ds ++ out) + | Cons ((name, kind, constrs, info : list dt_info), rest) -> + go + (foldl (fun i {name} -> M.insert name info i) dt info) + (M.insert name (Mono kind) ty) + (foldl + (fun s {name,d_args,c_args,c_ret} -> + M.insert name (Forall { vars = d_args, body = foldr (curry T_arr) c_ret c_args} |> Poly) s) + vl info) + (Data (name, [], fix_constr <$> constrs) :: ds) + rest + go dt_info ty_scope val_scope [] r + let (_, _, _, p) = foldl tc_one (M.empty, M.empty, M.empty, []) plan + p