You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

452 lines
14 KiB

  1. module M = import "data/map.ml"
  2. module G = import "./lib/graph.ml"
  3. open import "./lang.ml"
  4. open import "amulet/exception.ml"
  5. open import "prelude.ml"
  (* A unification variable, parameterised over what it can be solved to.
     [level] is compared against the current let-level in [generalise];
     [var] is None while unsolved and Some t once solved. *)
  type tc_tyvar 'a = Tv of {
    name : string, level : int, var : ref (option 'a)
  }
  (* Unification variables are equal exactly when their names are equal;
     the level and the current solution are ignored. *)
  instance eq (tc_tyvar 'a) begin
    let Tv x == Tv y = x.name == y.name
  end
  (* Order unification variables by name (consistently with the equality
     instance) so they can be collected in sets. *)
  instance ord (tc_tyvar 'a) begin
    let Tv x `compare` Tv y = x.name `compare` y.name
  end
  (* Kinds: K_arr (a, b) is the arrow kind a -> b, K_star is the kind of
     ordinary types (printed as "*"), and K_var is a kind unification
     variable. *)
  type tc_kappa =
    | K_arr of tc_kappa * tc_kappa
    | K_star
    | K_var of tc_tyvar tc_kappa
  (* Monotypes manipulated by the type checker.
       T_uvar  unification variable (possibly already solved)
       T_var   named type variable, e.g. one quantified by a tc_sigma
       T_con   type constructor, e.g. "Int"
       T_app   type application (f x)
       T_arr   function type (a -> b) *)
  type tc_rho =
    | T_uvar of tc_tyvar tc_rho
    | T_var of string
    | T_con of string
    | T_app of tc_rho * tc_rho
    | T_arr of tc_rho * tc_rho
  (* Render a monotype. Solved unification variables print as their solution,
     unsolved ones as their name. Applications in argument position and
     arrows in domain position are parenthesised. Solved variables are chased
     *before* that decision, so a variable solved to `Maybe Int` in argument
     position prints as `(Maybe Int)` rather than an unbracketed `Maybe Int`. *)
  instance show tc_rho begin
    let show =
      (* Follow solved unification variables at the head of a type. *)
      let rec prune = function
        | T_uvar (Tv n) as t ->
          match !n.var with
          | Some t' -> prune t'
          | None -> t
        | t -> t
      let rec show_arg t =
        match prune t with
        | T_app _ as x -> "(" ^ go x ^ ")"
        | x -> show_domain x
      and show_domain t =
        match prune t with
        | T_arr _ as x -> "(" ^ go x ^ ")"
        | x -> go x
      and go = function
        | T_var v -> v
        | T_con v -> v
        | T_app (f, x) -> go f ^ " " ^ show_arg x
        | T_arr (a, b) -> show_domain a ^ " -> " ^ go b
        | T_uvar (Tv n) ->
          match !n.var with
          | Some t -> go t
          | None -> n.name
      go
  end
  (* Render a kind. Solved kind variables print as their solution, unsolved
     ones as "?name". Arrow kinds in domain position are parenthesised, after
     solved variables have been chased so the decision sees the real kind. *)
  instance show tc_kappa begin
    let show x =
      (* Follow solved kind variables at the head of a kind. *)
      let rec prune = function
        | K_var (Tv v) as k ->
          match !v.var with
          | Some k' -> prune k'
          | None -> k
        | k -> k
      let rec go k =
        match prune k with
        | K_star -> "*"
        | K_var (Tv v) -> "?" ^ v.name
        | K_arr (a, b) -> show_domain a ^ " -> " ^ go b
      and show_domain k =
        match prune k with
        | K_arr _ as arr -> "(" ^ go arr ^ ")"
        | other -> go other
      go x
  end
  (* A type scheme: [body] universally quantified over the names in [vars].
     Quantified variables occur in [body] as T_var and are replaced with fresh
     unification variables by [instantiate]. *)
  type tc_sigma =
    Forall of {
      vars : list string,
      body : tc_rho
    }
  (* The set of unification variables occurring anywhere in a type.
     Solutions of solved variables are not followed; callers zonk first. *)
  let rec free_unif_vars ty =
    match ty with
    | T_uvar tv -> S.singleton tv
    | T_app (l, r) -> free_unif_vars l `S.union` free_unif_vars r
    | T_arr (l, r) -> free_unif_vars l `S.union` free_unif_vars r
    | T_var _ -> S.empty
    | T_con _ -> S.empty
  (* Produce a fresh name on each call: "alpha1", "alpha2", ... The counter
     lives in this closure and is shared by all callers. *)
  let new_name =
    let counter = ref 0
    fun () ->
      let next = !counter + 1
      counter := next
      "alpha" ^ show next
  (* A fresh, unsolved unification variable created at [level]. *)
  let new_tcvar level =
    Tv { name = new_name (), level, var = ref None }
  (* Substitute the solutions of solved unification variables throughout a
     type. Unsolved variables are left in place. *)
  let rec zonk ty =
    match ty with
    | T_var _ -> ty
    | T_con _ -> ty
    | T_app (f, x) -> T_app (zonk f, zonk x)
    | T_arr (d, c) -> T_arr (zonk d, zonk c)
    | T_uvar (Tv r) ->
      match !r.var with
      | None -> ty
      | Some solved -> zonk solved
  (* A unification variable counts as empty when it has no solution, or when
     its solution is merely a reference back to itself. *)
  let empty (Tv r) =
    match !r.var with
    | Some (T_uvar (Tv other)) -> other.name == r.name
    | Some _ -> false
    | None -> true
  (* Generalise [rho] over every empty unification variable whose level is
     deeper than [level]. Each such variable is solved to the rigid T_var of
     the same name (a side effect on the variable), and the resulting scheme
     quantifies over those names. *)
  let generalise level rho =
    let zonked = zonk rho
    let is_generalisable (Tv r) = r.level > level && empty (Tv r)
    let quantified = S.members (S.filter is_generalisable (free_unif_vars zonked))
    iter (fun (Tv r) -> r.var := Some (T_var r.name)) quantified
    let names = map (fun (Tv r) -> r.name) quantified
    Forall { vars = names, body = zonk zonked }
  (* Make two monotypes equal by solving unification variables (destructive
     update of their [var] refs).
     Solved variables are chased first. Unifying a variable with itself is a
     no-op, and binding a variable to a type that contains it is rejected by
     an occurs check. Without these checks a variable could be solved to
     itself or to a cyclic type, which makes [zonk] and [show] loop forever.
     Raises (via [error]) when the types cannot be made equal. *)
  let rec unify a b =
    (* Follow solved unification variables at the head of a type. *)
    let rec prune = function
      | T_uvar (Tv r) as t ->
        match !r.var with
        | Some t' -> prune t'
        | None -> t
      | t -> t
    (* Does the unsolved variable called [name] occur anywhere in [t]? *)
    let rec occurs name t =
      match prune t with
      | T_uvar (Tv r) -> r.name == name
      | T_var _ -> false
      | T_con _ -> false
      | T_app (f, x) -> occurs name f || occurs name x
      | T_arr (d, c) -> occurs name d || occurs name c
    let solve r s =
      if occurs r.name s then
        error @@ "Occurs check: cannot construct the infinite type " ^ r.name ^ " ~ " ^ show s
      else
        r.var := Some s
    match prune a, prune b with
    | T_uvar (Tv r), T_uvar (Tv r') when r.name == r'.name -> ()
    | T_uvar (Tv r), b -> solve r b
    | a, T_uvar (Tv r) -> solve r a
    | T_var a, T_var b when a == b -> ()
    | T_con a, T_con b when a == b -> ()
    | T_app (f, x), T_app (f', x') ->
      unify f f'
      unify x x'
    | T_arr (a, b), T_arr (a', b') ->
      unify a a'
      unify b b'
    | a, b -> error @@ "Types " ^ show a ^ " and " ^ show b ^ " are not equal"
  (* Make two kinds equal by solving kind variables (destructive update of
     their [var] refs).
     Solved variables are chased first. Unifying a variable with itself is a
     no-op, and binding a variable to a kind that contains it is rejected by
     an occurs check, so no self-referential or cyclic solutions are created.
     Raises (via [error]) when the kinds cannot be made equal. *)
  let rec unify_kappa a b =
    (* Follow solved kind variables at the head of a kind. *)
    let rec prune = function
      | K_var (Tv r) as k ->
        match !r.var with
        | Some k' -> prune k'
        | None -> k
      | k -> k
    (* Does the unsolved kind variable called [name] occur in [k]? *)
    let rec occurs name k =
      match prune k with
      | K_var (Tv r) -> r.name == name
      | K_star -> false
      | K_arr (d, c) -> occurs name d || occurs name c
    let solve r s =
      if occurs r.name s then
        error @@ "Occurs check: cannot construct the infinite kind ?" ^ r.name ^ " ~ " ^ show s
      else
        r.var := Some s
    match prune a, prune b with
    | K_var (Tv r), K_var (Tv r') when r.name == r'.name -> ()
    | K_var (Tv r), b -> solve r b
    | a, K_var (Tv r) -> solve r a
    | K_star, K_star -> ()
    | K_arr (a, b), K_arr (a', b') ->
      unify_kappa a a'
      unify_kappa b b'
    | a, b -> error @@ "Kinds " ^ show a ^ " and " ^ show b ^ " are not equal"
  (* What a name is bound to in a scope: a polymorphic type scheme, or a
     monomorphic 'a (a tc_rho for values, a tc_kappa for types). *)
  type scheme 'a = Poly of tc_sigma | Mono of 'a
  (* Render a binding. A scheme with quantified variables prints as
     "forall a b. body"; one with no quantified variables prints as its body
     alone, rather than the odd "forall. body". *)
  instance show 'a => show (scheme 'a) begin
    let show = function
      | Mono x -> show x
      | Poly (Forall {vars,body}) ->
        match vars with
        | [] -> show body
        | _ -> foldl (fun s i -> s ^ " " ^ i) "forall" vars ^ ". " ^ show body
  end
  (* Unwrap a monomorphic binding; [m] is appended to the error raised when
     the binding is polymorphic instead. *)
  let mono m s =
    match s with
    | Poly _ -> error @@ "Unexpected polytype " ^ m
    | Mono x -> x
  (* Look [name] up in [env], raising "Name not in scope" if it is unbound. *)
  let get_scope env name =
    M.lookup name env |> function
    | None -> error @@ "Name not in scope: " ^ name
    | Some v -> v
  (* Split a kind into (domain, codomain). A kind that is not syntactically
     an arrow is unified with a fresh arrow of kind variables at [level]. *)
  let is_function_kind level tau =
    match tau with
    | K_arr (dom, cod) -> (dom, cod)
    | _ ->
      let dom = K_var (new_tcvar level)
      let cod = K_var (new_tcvar level)
      unify_kappa tau (K_arr (dom, cod))
      (dom, cod)
  (* Infer the kind of a source type, elaborating it into a tc_rho.
     Type variables and constructors must be bound in [scope] to a
     monomorphic kind; applications are checked against an arrow kind.
     Only the empty tuple type is supported. *)
  let rec infer_kind scope = function
    | Tyvar name ->
      let k = mono "(kinds aren't ever polymorphic)" (get_scope scope name)
      (T_var name, k)
    | Tycon name ->
      let k = mono "(kinds aren't ever polymorphic)" (get_scope scope name)
      (T_con name, k)
    | Tyapp (fn, arg) ->
      let (fn', fn_kind) = infer_kind scope fn
      let (arg', arg_kind) = infer_kind scope arg
      let (dom, res) = is_function_kind 0 fn_kind
      unify_kappa dom arg_kind
      (T_app (fn', arg'), res)
    | Tyarr (dom, cod) ->
      let dom' = check_is_type scope dom
      let cod' = check_is_type scope cod
      (T_arr (dom', cod'), K_star)
    | Tytup [] -> (T_con "Unit#", K_star)
    | _ -> error "Tuple types not supported"
  (* Elaborate a source type that is required to have kind *. *)
  and check_is_type scope ty =
    let (rho, kappa) = infer_kind scope ty
    unify_kappa kappa K_star
    rho
  (* Resolve a kind, reading any still-unsolved kind variable as *.
     The variables themselves are not updated. *)
  let rec default_to_star kappa =
    match kappa with
    | K_star -> K_star
    | K_arr (dom, cod) -> K_arr (default_to_star dom, default_to_star cod)
    | K_var (Tv r) ->
      match !r.var with
      | None -> K_star
      | Some solved -> default_to_star solved
  (* Per-constructor metadata: [name] is the constructor, [d_args] the
     parameter names of its data type, [c_args] the constructor's argument
     types, and [c_ret] the data type applied to its parameters. *)
  type dt_info <-
    { name : string, d_args : list string, c_args : list tc_rho, c_ret : tc_rho }
  (* Constructor metadata for data type [d_name] with parameters [d_args].
     All constructors share one return type: [d_name] applied, left to right,
     to its parameters as type variables. *)
  let mk_con_info (d_name : string) (d_args : list string) : list (string * list tc_rho) -> list dt_info =
    let applied = foldl (fun f x -> T_app (f, T_var x)) (T_con d_name) d_args
    map (fun (name, args) -> { name, d_args, c_args = args, c_ret = applied })
  (* Kind-check a group of data declarations together.
     Each type first gets a provisional kind built from one fresh kind
     variable per parameter. Those provisional kinds are in scope while the
     constructors' argument types are checked, and leftover unsolved kind
     variables default to * at the end.
     Returns, per type: (name, kind, checked constructors, constructor info). *)
  let infer_data_group_kind scope (group : list _) =
    let init_kind (group, names) (name, args, constr) =
      let args =
        args |> map (fun v -> (v, new_tcvar 0 |> K_var |> Mono))
      (* Kind arrows associate to the right: `T a b` has kind ka -> kb -> *.
         Folding from the left would build (* -> ka) -> kb, which rejects
         every multi-parameter or higher-kinded data type. *)
      let kind = foldr (fun (_, r) t -> K_arr (mono "" r, t)) K_star args
      let scope = M.from_list args
      ((name, kind, constr, scope, args) :: group, M.insert name (Mono kind) names)
    let (group, scope') = foldl init_kind ([], M.empty) group
    let scope = M.union scope scope'
    let group : list (string * tc_kappa * list string * list (string * list tc_rho)) =
      flip map group @@ fun (name, kind, constrs, args, args') ->
        let scope = M.union scope args
        constrs
        |> map (fun (Constr (name, args)) -> (name, map (check_is_type scope) args))
        |> (name,kind,[x|with (x,_)<-args'],)
    flip map group @@ fun (name, kind, args, constrs) ->
      (name, default_to_star kind, constrs, mk_con_info name args constrs)
  (* Replace each named type variable v with t wherever [f v] is Some t.
     Solved unification variables are followed; unsolved ones are kept. *)
  let rec subst_vars f ty =
    match ty with
    | T_con _ -> ty
    | T_app (l, r) -> T_app (subst_vars f l, subst_vars f r)
    | T_arr (l, r) -> T_arr (subst_vars f l, subst_vars f r)
    | T_uvar (Tv v) ->
      match !v.var with
      | Some solved -> subst_vars f solved
      | None -> ty
    | T_var name ->
      match f name with
      | Some replacement -> replacement
      | None -> ty
  (* Instantiate a scheme at [level]: each quantified variable is replaced
     with its own fresh unification variable. *)
  let instantiate level (Forall { vars, body }) =
    let fresh = M.from_list (map (fun v -> (v, T_uvar (new_tcvar level))) vars)
    subst_vars (fun v -> M.lookup v fresh) body
  (* The type of variable [v] in [scope], instantiated at [level] if it is
     bound to a polymorphic scheme. *)
  let lookup_ty level scope v =
    match get_scope scope v with
    | Poly s -> instantiate level s
    | Mono t -> t
  (* Split a type into (argument, result). A type that is not syntactically
     an arrow is unified with a fresh arrow of unification variables at
     [level]. *)
  let is_function_type level tau =
    match tau with
    | T_arr (dom, cod) -> (dom, cod)
    | _ ->
      let dom = T_uvar (new_tcvar level)
      let cod = T_uvar (new_tcvar level)
      unify tau (T_arr (dom, cod))
      (dom, cod)
  (* Subsumption check: is [sub] at least as general as [super]? With only
     rank-1 types this reduces to plain unification.
     TODO: Rank-N types *)
  let is_subtype sub super = unify sub super
  (* Bidirectional type checking of expressions.
     [dt_info] maps constructor names to the metadata of their data type's
     constructors (used by case); [level] is the let-nesting level used for
     generalisation; [scope] maps variables to schemes.
     [infer] returns the elaborated expression together with its type. *)
  let rec infer dt_info level scope = function
    | Ref v -> lookup_ty level scope v |> (Ref v,)
    | App (f, x) ->
      let (f, arg, res) =
        infer dt_info level scope f
        |> second (is_function_type level)
      let x = check dt_info level scope arg x
      (App (f, x), res)
    | Lit i -> (Lit i, T_con "Int")
    | Let (bindings, body) ->
      let (bindings, scope') =
        infer_binding_group dt_info level scope bindings
      let (body, body_t) = infer dt_info level (scope `M.union` map force scope') body
      (Let (bindings, body), body_t)
    | x ->
      (* Forms with no inference rule are checked against a fresh variable. *)
      let t = new_tcvar level |> T_uvar
      let x = check dt_info level scope t x
      (x, t)
  (* Check an expression against [wanted], returning the elaborated
     expression. Forms with no checking rule are inferred, and the inferred
     type must then be subsumed by [wanted]. *)
  and check dt_info level scope wanted = function
    | Lam (arg, body) ->
      let (arg_t, body_t) = is_function_type level wanted
      let body =
        (* TODO: Rank-N types *)
        check dt_info level (M.insert arg (Mono arg_t) scope) body_t body
      Lam (arg, body)
    | Case (_, []) -> error "Empty case"
    | Case (scrutinee, Cons ((con, _, _), _) as patterns) ->
      let data =
        match M.lookup con dt_info with
        | Some data -> data
        | None -> error @@ "Constructor " ^ con ^ " doesn't belong to a type"
      (* Arms are paired positionally with the constructors. Without this
         check, zip_with would silently drop arms beyond the last constructor
         and they would never be type checked. *)
      if length patterns > length data then
        error @@ "Case has " ^ show (length patterns)
          ^ " arms but the type of " ^ con ^ " only has "
          ^ show (length data) ^ " constructors"
      else ()
      let (scrutinee, scrut_t) = infer dt_info level scope scrutinee
      let go_arm {name, d_args, c_args, c_ret} (con, args, expr) =
        if name <> con then
          error @@ "Constructors out of order: expected this pattern to match " ^ name
        else ()
        if length c_args <> length args then
          error @@ "Constructor "
            ^ con ^ " has "
            ^ show (length c_args)
            ^ " arguments but is being matched with " ^ show (length args)
            ^ " variables"
        else ()
        (* Instantiate the data type's parameters freshly for this arm. *)
        let d_args =
          d_args
          |> map (fun v -> (v, new_tcvar level |> T_uvar))
          |> M.from_list
        let c_args = map (Mono # subst_vars (flip M.lookup d_args)) c_args
        let c_ret = subst_vars (flip M.lookup d_args) c_ret
        unify c_ret scrut_t
        let scope' = M.from_list (zip args c_args) `M.union` scope
        (con, args, check dt_info level scope' wanted expr)
      Case (scrutinee, zip_with go_arm data patterns)
    | x ->
      let (x, t) = infer dt_info level scope x
      is_subtype t wanted
      x
  (* Infer a group of (possibly mutually recursive) bindings.
     Every binding gets a fresh monomorphic type at level [level + 1], which
     recursive references use while the bodies are checked. Each body's
     inferred type is then unified with that type, and the result is
     generalised lazily (on force) at [level].
     Returns the checked bindings and a map from names to lazy schemes. *)
  and infer_binding_group dt_info level (scope : M.t string _) bindings : _ * M.t string _ =
    let inner = level + 1
    let self_types =
      bindings
      |> map (fun (name, _) -> (name, new_tcvar inner |> T_uvar |> Mono))
      |> M.from_list
    let initial_types = self_types |> M.union scope
    let go_binding (bindings : list _, scope' : M.t _ _) (name : string, body : expr) =
      let (body, body_ty) =
        (fun () -> infer dt_info inner initial_types body)
        `catch` fun (e : some exception) ->
          error (describe_exception e ^ "\nwhen type checking " ^ name)
      (* Tie the knot with the type recursive uses were checked against.
         This must consult the group's own fresh types, not the outer
         [scope]; otherwise recursion would be unconstrained, or would be
         unified with an unrelated outer binding of the same name. *)
      M.lookup name self_types
      |> function
      | Some (Mono t) -> unify t body_ty
      | _ -> ()
      (
        (name, body) :: bindings,
        M.insert name (lazy (generalise level body_ty |> Poly)) scope'
      )
    foldl go_binding ([], M.empty) bindings
  (* Build the dependency graph of a program's top-level declarations.
     Returns (G.groups_of_sccs graph, defs), where [defs] maps each defined
     name (value, foreign import or data type) to its declaration.
     Defining the same name twice raises "Redefinition of value". *)
  let dependency_graph defs =
    (* Add a graph node for one constructor of data type [t]. It depends on
       [t] itself and on whatever [free_cons] (presumably from lang.ml:
       the type constructors mentioned in a type) reports for each of the
       constructor's argument types. *)
    let rec free_vars_of_cons t m (Constr (name, args)) =
      let cons =
        foldl (fun s t -> S.union s (free_cons t)) (S.singleton t)
          args
      M.insert name cons m
    (* Insert n -> x into m, failing if n is already defined. *)
    let define n x m =
      M.alter (function
        | Some _ -> error @@ "Redefinition of value " ^ n
        | None -> Some x)
        n m
    let go (graph, defs) = function
      (* Foreign imports have no dependencies. *)
      | Foreign (Fimport { var }) as x ->
        (M.insert var S.empty graph, define var x defs)
      (* A declaration depends on its free variables, minus its own
         parameters and minus itself (no self-edge for direct recursion). *)
      | Decl (name, args, expr) as x ->
        let fvs =
          free_vars expr
          |> flip S.difference (S.from_list args)
          |> S.delete name
        (M.insert name fvs graph, define name x defs)
      (* A data type node has no dependencies. Its constructors become
         separate nodes, but only the type name is recorded in [defs]. *)
      | Data (name, _, cons) as x ->
        M.union graph (foldl (free_vars_of_cons name) M.empty cons)
        |> M.insert name S.empty
        |> (, define name x defs)
    let (graph, defs) = foldl go (M.empty, M.empty) defs
    (G.groups_of_sccs graph, defs)
  (* Wrap [body] in one lambda per argument, with the first argument outermost. *)
  let mk_lam args body = foldr (fun arg acc -> Lam (arg, acc)) body args
  (* Peel off leading lambdas: returns (arguments outermost-first, the body
     under the last lambda). The inverse of [mk_lam]. *)
  let rec unlambda expr =
    match expr with
    | Lam (v, inner) ->
      let (vs, body) = unlambda inner
      (v :: vs, body)
    | _ -> ([], expr)
  (* A list of [n] copies of [x]; the empty list when n <= 0. *)
  let replicate n x =
    let rec loop k acc =
      if k <= 0 then acc else loop (k - 1) (x :: acc)
    loop n []
  (* Extend [scope] with a fresh kind variable (level 0) for every type
     variable occurring in the type that [scope] does not already bind.
     The right-hand side of an application or arrow is visited first. *)
  let rec add_missing_vars scope ty =
    match ty with
    | Tycon _ -> scope
    | Tyapp (l, r) -> add_missing_vars (add_missing_vars scope r) l
    | Tyarr (l, r) -> add_missing_vars (add_missing_vars scope r) l
    | Tytup parts -> foldl add_missing_vars scope parts
    | Tyvar v ->
      match M.lookup v scope with
      | Some _ -> scope
      | None -> M.insert v (Mono (K_var (new_tcvar 0))) scope
  (* Type check a whole program, one strongly-connected group at a time in
     the order produced by [dependency_graph]. The fold threads:
       dt_info   - constructor name -> constructor info of its data type
       val_scope - value name -> type scheme
       ty_scope  - type name -> kind
       out       - checked declarations accumulated so far
     Returns the accumulated checked declarations. *)
  let tc_program (prog : list decl) =
    let (plan, defs) = dependency_graph prog
    let tc_one (dt_info, val_scope, ty_scope, out) group =
      let defs = [ x | with name <- S.members group, with Some x <- [M.lookup name defs] ]
      match defs with
      (* Groups with no definitions (e.g. nodes for data constructors, which
         are never entered in [defs]) contribute nothing. The output built so
         far must be kept; returning the empty local [defs] would discard it. *)
      | [] -> (dt_info, val_scope, ty_scope, out)
      | [Foreign (Fimport {var, ftype}) as def] ->
        (* Every type variable in a foreign type is implicitly quantified. *)
        let ty_scope' = add_missing_vars M.empty ftype
        let t = check_is_type (M.union ty_scope' ty_scope) ftype
        (dt_info, M.insert var (Forall { vars = M.keys ty_scope', body = t } |> Poly) val_scope, ty_scope, def :: out)
      | Cons (Foreign (Fimport {var}), _) ->
        error @@ "Foreign definition " ^ var ^ " is part of a group. How?"
      | Cons (Decl (name, args, body), ds) ->
        let bindings =
          (name, mk_lam args body)
          :: [ (name, mk_lam args body) | with Decl (name, args, body) <- ds ]
        let (bindings, scope') = infer_binding_group dt_info -1 val_scope bindings
        let decs =
          [ Decl (name, unlambda expr) | with (name, expr) <- bindings ]
        (* Prepend the checked declarations to the accumulated output; the
           unchecked group members in [defs] must not be emitted. *)
        (dt_info, M.union (map force scope') val_scope, ty_scope, decs ++ out)
      | Cons (Data d, ds) ->
        let datas = d :: [ d | with Data d <- ds ]
        let r = infer_data_group_kind ty_scope datas
        let fix_constr (name, rhos : list tc_rho) =
          Constr (name, replicate (length rhos) (Tycon "#"))
        let rec go dt ty (vl : M.t string (scheme tc_rho)) ds = function
          | [] -> (dt, vl, ty, reverse ds ++ out)
          | Cons ((name, kind, constrs, info : list dt_info), rest) ->
            go
              (foldl (fun i {name} -> M.insert name info i) dt info)
              (M.insert name (Mono kind) ty)
              (foldl
                (fun s {name,d_args,c_args,c_ret} ->
                  M.insert name (Forall { vars = d_args, body = foldr (curry T_arr) c_ret c_args} |> Poly) s)
                vl info)
              (Data (name, [], fix_constr <$> constrs) :: ds)
              rest
        go dt_info ty_scope val_scope [] r
    let (_, _, _, p) = foldl tc_one (M.empty, M.empty, M.empty, []) plan
    p