Browse Source

Implement MLTT elaborator w/ type inference

feature/hits
Amélia Liao 3 years ago
parent
commit
fcd5428ee0
9 changed files with 489 additions and 12 deletions
  1. +11
    -1
      cubical.cabal
  2. +87
    -0
      src/Elab.hs
  3. +201
    -0
      src/Elab/Eval.hs
  4. +48
    -0
      src/Elab/Monad.hs
  5. +6
    -2
      src/Presyntax/Lexer.x
  6. +28
    -8
      src/Presyntax/Parser.y
  7. +7
    -1
      src/Presyntax/Presyntax.hs
  8. +5
    -0
      src/Presyntax/Tokens.hs
  9. +96
    -0
      src/Syntax.hs

+ 11
- 1
cubical.cabal View File

@ -19,8 +19,10 @@ executable cubical
default-language: Haskell2010
build-depends: base ^>= 4.14
, mtl ^>= 2.2
, text ^>= 1.2
, array ^>= 0.5.4
, array ^>= 0.5
, containers ^>= 0.6
, bytestring ^>= 0.10
other-modules: Presyntax.Lexer
@ -28,6 +30,14 @@ executable cubical
, Presyntax.Tokens
, Presyntax.Presyntax
, Syntax
, Elab
, Elab.Eval
, Elab.Monad
build-tool-depends: alex:alex >= 3.2.4 && < 4.0
, happy:happy >= 1.19.12 && < 2.0
ghc-options: -Wall -Wextra -Wno-name-shadowing

+ 87
- 0
src/Elab.hs View File

@ -0,0 +1,87 @@
{-# LANGUAGE TupleSections, OverloadedStrings #-}
module Elab where
import Elab.Monad
import qualified Presyntax.Presyntax as P
import Syntax
import Elab.Eval
infer :: P.Expr -> ElabM (Term, NFType)
infer (P.Var t) = (Ref (Bound t),) <$> getNfType (Bound t)
infer (P.App p f x) = do
(f, f_ty) <- infer f
(d, r, w) <- isPiType p f_ty
x <- check x d
x_nf <- eval x
pure (App p (w f) x, r x_nf)
infer (P.Pi p s d r) = do
d <- check d VType
d_nf <- eval d
assume (Bound s) d_nf $ do
r <- check r VType
pure (Pi p s d r, VType)
infer (P.Sigma s d r) = do
d <- check d VType
d_nf <- eval d
assume (Bound s) d_nf $ do
r <- check r VType
pure (Sigma s d r, VType)
infer exp = do
t <- newMeta VType
tm <- check exp t
pure (tm, t)
check :: P.Expr -> NFType -> ElabM Term
check (P.Lam p var body) (VPi p' dom (Closure _ rng)) | p == p' =
assume (Bound var) dom $
Lam p var <$> check body (rng (VVar (Bound var)))
check tm (VPi P.Im dom (Closure var rng)) =
assume (Bound var) dom $
Lam P.Im var <$> check tm (rng (VVar (Bound var)))
check (P.Lam p v b) ty = do
(d, r, wp) <- isPiType p ty
assume (Bound v) d $
wp . Lam P.Im v <$> check b (r (VVar (Bound v)))
check (P.Pair a b) ty = do
(d, r, wp) <- isSigmaType ty
a <- check a d
a_nf <- eval a
b <- check b (r a_nf)
pure (wp (Pair a b))
check exp ty = do
(tm, has) <- infer exp
unify has ty
pure tm
isPiType :: P.Plicity -> NFType -> ElabM (Value, NFType -> NFType, Term -> Term)
isPiType p (VPi p' d (Closure _ k)) | p == p' = pure (d, k, id)
isPiType p t = do
dom <- newMeta VType
name <- newName
assume (Bound name) dom $ do
rng <- newMeta VType
wp <- isConvertibleTo t (VPi p dom (Closure name (const rng)))
pure (dom, const rng, wp)
isSigmaType :: NFType -> ElabM (Value, NFType -> NFType, Term -> Term)
isSigmaType (VSigma d (Closure _ k)) = pure (d, k, id)
isSigmaType t = do
dom <- newMeta VType
name <- newName
assume (Bound name) dom $ do
rng <- newMeta VType
wp <- isConvertibleTo t (VSigma dom (Closure name (const rng)))
pure (dom, const rng, wp)
identityTy :: NFType
identityTy = VPi P.Im VType (Closure "A" $ \t -> VPi P.Ex t (Closure "_" (const t)))

+ 201
- 0
src/Elab/Eval.hs View File

@ -0,0 +1,201 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DeriveAnyClass #-}
module Elab.Eval where
import Control.Monad.Reader
import Control.Exception
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Text as T
import Data.Traversable
import Data.Set (Set)
import Data.Typeable
import Data.Foldable
import Data.IORef
import Data.Maybe
import Elab.Monad
import Presyntax.Presyntax (Plicity(..))
import Syntax
import System.IO.Unsafe
eval :: Term -> ElabM Value
eval t = asks (flip evalWithEnv t)
forceIO :: MonadIO m => Value -> m Value
forceIO vl@(VNe (HMeta (MV _ cell)) args) = do
solved <- liftIO $ readIORef cell
case solved of
Just vl -> forceIO $ foldl applProj vl (reverse args)
Nothing -> pure vl
forceIO x = pure x
applProj :: Value -> Projection -> Value
applProj fun (PApp p arg) = vApp p fun arg
applProj fun PProj1 = vProj1 fun
applProj fun PProj2 = vProj2 fun
force :: Value -> Value
force = unsafePerformIO . forceIO
evalWithEnv :: ElabEnv -> Term -> Value
evalWithEnv env (Ref x) =
case Map.lookup x (getEnv env) of
Just (_, vl) -> vl
_ -> error "variable not in scope when evaluating"
evalWithEnv env (App p f x) = vApp p (evalWithEnv env f) (evalWithEnv env x)
evalWithEnv env (Lam p s t) =
VLam p $ Closure s $ \a ->
evalWithEnv (ElabEnv (Map.insert (Bound s) (error "type of abs", a) (getEnv env))) t
evalWithEnv env (Pi p s d t) =
VPi p (evalWithEnv env d) $ Closure s $ \a ->
evalWithEnv (ElabEnv (Map.insert (Bound s) (error "type of abs", a) (getEnv env))) t
evalWithEnv _ (Meta m) = VNe (HMeta m) []
evalWithEnv env (Sigma s d t) =
VSigma (evalWithEnv env d) $ Closure s $ \a ->
evalWithEnv (ElabEnv (Map.insert (Bound s) (error "type of abs", a) (getEnv env))) t
evalWithEnv e (Pair a b) = VPair (evalWithEnv e a) (evalWithEnv e b)
evalWithEnv e (Proj1 a) = vProj1 (evalWithEnv e a)
evalWithEnv e (Proj2 a) = vProj2 (evalWithEnv e a)
evalWithEnv _ Type = VType
vApp :: Plicity -> Value -> Value -> Value
vApp p (VLam p' k) arg = assert (p == p') $ clCont k arg
vApp p (VNe h sp) arg = VNe h (PApp p arg:sp)
vApp _ x _ = error $ "can't apply " ++ show x
vProj1 :: Value -> Value
vProj1 (VPair a _) = a
vProj1 (VNe h sp) = VNe h (PProj1:sp)
vProj1 x = error $ "can't proj1 " ++ show x
vProj2 :: Value -> Value
vProj2 (VPair _ b) = b
vProj2 (VNe h sp) = VNe h (PProj2:sp)
vProj2 x = error $ "can't proj2 " ++ show x
data NotEqual = NotEqual Value Value
deriving (Show, Typeable, Exception)
unify :: Value -> Value -> ElabM ()
unify topa topb = join $ go <$> forceIO topa <*> forceIO topb where
go (VNe (HMeta mv) sp) rhs = solveMeta mv sp rhs
go (VNe x a) (VNe x' a')
| x == x', length a == length a' =
traverse_ (uncurry unifySpine) (zip a a')
| otherwise = fail
go (VLam p (Closure _ k)) (VLam p' (Closure _ k')) | p == p' = do
t <- VVar . Bound <$> newName
unify (k t) (k' t)
go (VPi p d (Closure _ k)) (VPi p' d' (Closure _ k')) | p == p' = do
t <- VVar . Bound <$> newName
unify d d'
unify (k t) (k' t)
go _ _ = fail
fail = liftIO . throwIO $ NotEqual topa topb
unifySpine (PApp a v) (PApp a' v')
| a == a' = unify v v'
unifySpine _ _ = fail
isConvertibleTo :: Value -> Value -> ElabM (Term -> Term)
VPi Im d (Closure _v k) `isConvertibleTo` ty = do
meta <- newMeta d
cont <- k meta `isConvertibleTo` ty
pure (\f -> cont (App Ex f (quote meta)))
isConvertibleTo a b = do
unify a b
pure id
newMeta :: Value -> ElabM Value
newMeta _dom = do
n <- newName
c <- liftIO $ newIORef Nothing
let m = MV n c
env <- asks getEnv
t <- for (Map.toList env) $ \(n, (_, c)) -> pure $
case c of
VVar n' | n == n' -> Just (PApp Ex (VVar n'))
_ -> Nothing
pure (VNe (HMeta m) (catMaybes t))
newName :: MonadIO m => m T.Text
newName = liftIO $ do
x <- atomicModifyIORef _nameCounter $ \x -> (x + 1, x + 1)
pure (T.pack (show x))
_nameCounter :: IORef Int
_nameCounter = unsafePerformIO $ newIORef 0
{-# NOINLINE _nameCounter #-}
solveMeta :: MV -> [Projection] -> Value -> ElabM ()
solveMeta m@(MV _ cell) sp rhs = do
liftIO $ print (m, sp, rhs)
names <- checkSpine Set.empty sp
checkScope (Set.fromList (Bound <$> names)) rhs
let tm = quote rhs
lam = evalWithEnv emptyEnv $ foldr (Lam Ex) tm names
liftIO . atomicModifyIORef' cell $ \case
Just _ -> error "filled cell in solvedMeta"
Nothing -> (Just lam, ())
checkScope :: Set Name -> Value -> ElabM ()
checkScope scope (VNe h sp) =
do
case h of
HVar v ->
unless (v `Set.member` scope) . liftIO . throwIO $
NotInScope v
HMeta{} -> pure ()
traverse_ checkProj sp
where
checkProj (PApp _ t) = checkScope scope t
checkProj PProj1 = pure ()
checkProj PProj2 = pure ()
checkScope scope (VLam _ (Closure n k)) =
checkScope (Set.insert (Bound n) scope) (k (VVar (Bound n)))
checkScope scope (VPi _ d (Closure n k)) = do
checkScope scope d
checkScope (Set.insert (Bound n) scope) (k (VVar (Bound n)))
checkScope scope (VSigma d (Closure n k)) = do
checkScope scope d
checkScope (Set.insert (Bound n) scope) (k (VVar (Bound n)))
checkScope s (VPair a b) = traverse_ (checkScope s) [a, b]
checkScope _ VType = pure ()
checkSpine :: Set Name -> [Projection] -> ElabM [T.Text]
checkSpine scope (PApp Ex (VVar n@(Bound t)):xs)
| n `Set.member` scope = liftIO . throwIO $ NonLinearSpine n
| otherwise = (t:) <$> checkSpine scope xs
checkSpine _ (p:_) = liftIO . throwIO $ SpineProj p
checkSpine _ [] = pure []
newtype NonLinearSpine = NonLinearSpine { getDupeName :: Name }
deriving (Show, Typeable, Exception)
newtype SpineProjection = SpineProj { getSpineProjection :: Projection }
deriving (Show, Typeable, Exception)

+ 48
- 0
src/Elab/Monad.hs View File

@ -0,0 +1,48 @@
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DerivingVia #-}
module Elab.Monad where
import Control.Monad.Reader
import Control.Exception
import qualified Data.Map.Strict as Map
import Data.Map.Strict (Map)
import Data.Typeable
import Syntax
newtype ElabEnv = ElabEnv { getEnv :: Map Name (NFType, Value) }
newtype ElabM a = ElabM { runElab :: ElabEnv -> IO a }
deriving (Functor, Applicative, Monad, MonadReader ElabEnv, MonadIO)
via ReaderT ElabEnv IO
newtype NotInScope = NotInScope { getName :: Name }
deriving (Show, Typeable)
deriving anyclass (Exception)
emptyEnv :: ElabEnv
emptyEnv = ElabEnv mempty
assume :: Name -> Value -> ElabM a -> ElabM a
assume nm ty = local go where
go = ElabEnv . Map.insert nm (ty, VVar nm) . getEnv
define :: Name -> Value -> Value -> ElabM a -> ElabM a
define nm ty vl = local go where
go = ElabEnv . Map.insert nm (ty, vl) . getEnv
getValue :: Name -> ElabM Value
getValue nm = do
vl <- asks (Map.lookup nm . getEnv)
case vl of
Just v -> pure (snd v)
Nothing -> liftIO . throwIO $ NotInScope nm
getNfType :: Name -> ElabM NFType
getNfType nm = do
vl <- asks (Map.lookup nm . getEnv)
case vl of
Just v -> pure (fst v)
Nothing -> liftIO . throwIO $ NotInScope nm

+ 6
- 2
src/Presyntax/Lexer.x View File

@ -3,7 +3,6 @@ module Presyntax.Lexer where
import qualified Data.ByteString.Lazy as Lbs
import qualified Data.Text.Encoding as T
import qualified Data.ByteString as Sbs
import Presyntax.Tokens
}
@ -19,6 +18,11 @@ tokens :-
\= { always TokEqual }
\: { always TokColon }
\, { always TokComma }
\* { always TokStar }
".1" { always TokPi1 }
".2" { always TokPi2 }
\\ { always TokLambda }
"->" { always TokArrow }
@ -35,6 +39,6 @@ alexEOF = do
(AlexPn _ l c, _, _, _) <- alexGetInput
pure $ Token TokEof l c
yield k t@(AlexPn _ l c, _, s, _) i = pure (Token (k $! (T.decodeUtf8 (Lbs.toStrict (Lbs.take i s)))) l c)
yield k (AlexPn _ l c, _, s, _) i = pure (Token (k $! (T.decodeUtf8 (Lbs.toStrict (Lbs.take i s)))) l c)
always k x i = yield (const k) x i
}

+ 28
- 8
src/Presyntax/Parser.y View File

@ -33,18 +33,28 @@ import Presyntax.Lexer
'->' { Token TokArrow _ _ }
':' { Token TokColon _ _ }
'=' { Token TokEqual _ _ }
',' { Token TokComma _ _ }
'*' { Token TokStar _ _ }
'.1' { Token TokPi1 _ _ }
'.2' { Token TokPi2 _ _ }
%%
Exp :: { Expr }
Exp
: ExpFun Exp { App Ex $1 $2 }
| ExpFun '{' Exp '}' { App Im $1 $3 }
| '\\' LambdaList '->' Exp { makeLams $2 $4 }
| '(' VarList ':' Exp ')' '->' Exp { makePis Ex $2 $4 $7 }
| '{' VarList ':' Exp '}' '->' Exp { makePis Im $2 $4 $7 }
| ExpFun '->' Exp { Pi Ex (T.singleton '_') $1 $3 }
| ExpFun { $1 }
: ExpProj Exp { App Ex $1 $2 }
| ExpProj '{' Exp '}' { App Im $1 $3 }
| '\\' LambdaList '->' Exp { makeLams $2 $4 }
| '(' VarList ':' Exp ')' '->' Exp { makePis Ex $2 $4 $7 }
| '{' VarList ':' Exp '}' '->' Exp { makePis Im $2 $4 $7 }
| ExpProj '->' Exp { Pi Ex (T.singleton '_') $1 $3 }
| '(' VarList ':' Exp ')' '*' Exp { makeSigmas $2 $4 $7 }
| ExpProj '*' Exp { Sigma (T.singleton '_') $1 $3 }
| ExpProj { $1 }
LambdaList :: { [(Plicity, Text)] }
: var { [(Ex, $1)] }
@ -57,9 +67,18 @@ VarList :: { [Text] }
: var { [$1] }
| var VarList { $1:$2 }
ExpProj :: { Expr }
: ExpFun '.1' { Proj1 $1 }
| ExpFun '.2' { Proj2 $1 }
| ExpFun { $1 }
ExpFun :: { Expr }
: Atom { $1 }
| '(' Exp ')' { $2 }
| '(' Tuple ')' { $2 }
Tuple :: { Expr }
: Exp { $1 }
| Exp ',' Tuple { Pair $1 $3 }
Atom :: { Expr }
: var { Var $1 }
@ -71,4 +90,5 @@ parseError x = alexError (show x)
makeLams xs b = foldr (uncurry Lam) b xs
makePis p xs t b = foldr (flip (Pi p) t) b xs
makeSigmas xs t b = foldr (flip Sigma t) b xs
}

+ 7
- 1
src/Presyntax/Presyntax.hs View File

@ -8,9 +8,15 @@ data Plicity
data Expr
= Var Text
| App Plicity Expr Expr
| Lam Plicity Text Expr
| Pi Plicity Text Expr Expr
| Lam Plicity Text Expr
| Sigma Text Expr Expr
| Pair Expr Expr
| Proj1 Expr
| Proj2 Expr
deriving (Eq, Show, Ord)
data Statement


+ 5
- 0
src/Presyntax/Tokens.hs View File

@ -14,8 +14,13 @@ data TokenClass
| TokCParen
| TokCBrace
| TokStar
| TokColon
| TokEqual
| TokComma
| TokPi1
| TokPi2
deriving (Eq, Show, Ord)
data Token


+ 96
- 0
src/Syntax.hs View File

@ -0,0 +1,96 @@
{-# LANGUAGE PatternSynonyms #-}
module Syntax where
import Data.Function (on)
import Data.Text (Text)
import Presyntax.Presyntax (Plicity(..))
import qualified Data.Text as T
import Data.IORef (IORef)
data Term
= Ref Name
| App Plicity Term Term
| Lam Plicity Text Term
| Pi Plicity Text Term Term
| Meta MV
| Type
| Sigma Text Term Term
| Pair Term Term
| Proj1 Term
| Proj2 Term
deriving (Eq, Show, Ord)
data MV =
MV { mvName :: Text
, mvCell :: {-# UNPACK #-} !(IORef (Maybe Value))
}
instance Eq MV where
(==) = (==) `on` mvName
instance Ord MV where
(<=) = (<=) `on` mvName
instance Show MV where
show = ('?':) . T.unpack . mvName
data Name
= Bound Text
deriving (Eq, Show, Ord)
type NFType = Value
data Value
= VNe Head [Projection]
| VLam Plicity Closure
| VPi Plicity Value Closure
| VSigma Value Closure
| VPair Value Value
| VType
deriving (Eq, Show, Ord)
pattern VVar :: Name -> Value
pattern VVar x = VNe (HVar x) []
quote :: Value -> Term
quote (VNe h sp) = foldl goSpine (goHead h) (reverse sp) where
goHead (HVar v) = Ref v
goHead (HMeta m) = Meta m
goSpine t (PApp p v) = App p t (quote v)
goSpine t PProj1 = Proj1 t
goSpine t PProj2 = Proj2 t
quote (VLam p (Closure n k)) = Lam p n (quote (k (VVar (Bound n))))
quote (VPi p d (Closure n k)) = Pi p n (quote d) (quote (k (VVar (Bound n))))
quote (VSigma d (Closure n k)) = Sigma n (quote d) (quote (k (VVar (Bound n))))
quote (VPair a b) = Pair (quote a) (quote b)
quote VType = Type
data Closure
= Closure
{ clArgName :: Text
, clCont :: Value -> Value
}
instance Show Closure where
show (Closure n c) = "Closure \\" ++ show n ++ " -> " ++ show (c (VVar (Bound n)))
instance Eq Closure where
Closure _ k == Closure _ k' =
k (VVar (Bound (T.pack "_"))) == k' (VVar (Bound (T.pack "_")))
instance Ord Closure where
Closure _ k <= Closure _ k' =
k (VVar (Bound (T.pack "_"))) <= k' (VVar (Bound (T.pack "_")))
data Head
= HVar Name
| HMeta MV
deriving (Eq, Show, Ord)
data Projection
= PApp Plicity Value
| PProj1
| PProj2
deriving (Eq, Show, Ord)

Loading…
Cancel
Save