{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
-- | Evaluation of terms into values, forcing/zonking of metavariables,
-- and unification (including metavariable solving) of values.
module Elab.Eval where

import Control.Monad.Reader
import Control.Exception

import qualified Data.Map.Strict as Map
import qualified Data.Sequence as Seq
import qualified Data.Set as Set
import qualified Data.Text as T
import Data.Sequence (Seq)
import Data.Traversable
import Data.Set (Set)
import Data.Typeable
import Data.Foldable
import Data.IORef
import Data.Maybe

import Elab.Eval.Formula
import Elab.Monad

import Presyntax.Presyntax (Plicity(..))

import Prettyprinter

import Syntax.Pretty
import Syntax

import System.IO.Unsafe

import {-# SOURCE #-} Elab.WiredIn

import GHC.Stack

-- | Evaluate a term in the environment of the current elaboration context.
eval :: Term -> ElabM Value
eval t = asks (flip eval' t)

-- | Force a value to weak-head form with respect to metavariables.
--
-- If the head is a solved metavariable, its solution is applied to the
-- pending spine (via 'applProj') and the result is forced again. A stuck
-- 'VComp' has its formula forced and 'comp' is re-run. Every other value
-- is returned unchanged.
forceIO :: MonadIO m => Value -> m Value
forceIO mv@(VNe (HMeta (MV _ cell)) args) = do
  solved <- liftIO $ readIORef cell
  case solved of
    Just vl -> forceIO $ foldl applProj vl args
    Nothing -> pure mv
forceIO (VComp line phi u a0) = comp line <$> forceIO phi <*> pure u <*> pure a0
forceIO x = pure x

-- | Apply a single spine projection (eliminator) to a value.
applProj :: Value -> Projection -> Value
applProj fun (PApp p arg) = vApp p fun arg
applProj fun (PIElim l x y i) = ielim l x y fun i
applProj fun (POuc a phi u) = outS a phi u fun
applProj fun PProj1 = vProj1 fun
applProj fun PProj2 = vProj2 fun

-- | Pure wrapper around 'forceIO'.
--
-- NOTE(review): this reads metavariable cells through 'unsafePerformIO',
-- so the result depends on which metavariables are solved at the moment
-- the thunk is evaluated.
force :: Value -> Value
force = unsafePerformIO . forceIO

-- | "Force everywhere": replace solved metavariables anywhere inside a
-- value, including inside spines.
--
-- Closures are not entered eagerly. Their continuations are instead
-- post-composed with 'zonk', so bodies are zonked when instantiated.
-- Interval connectives and compositions are rebuilt with their smart
-- constructors ('iand', 'ior', 'inot', 'comp'), and systems with
-- 'mkVSystem'.
zonkIO :: Value -> IO Value
zonkIO (VNe hd sp) = do
  sp' <- traverse zonkSp sp
  case hd of
    HMeta (MV _ cell) -> do
      solved <- readIORef cell
      case solved of
        Just vl -> zonkIO $ foldl applProj vl sp'
        Nothing -> pure $ VNe hd sp'
    _ -> pure $ VNe hd sp'
  where
    zonkSp (PApp p x) = PApp p <$> zonkIO x
    zonkSp (PIElim l x y i) = PIElim <$> zonkIO l <*> zonkIO x <*> zonkIO y <*> zonkIO i
    zonkSp (POuc a phi u) = POuc <$> zonkIO a <*> zonkIO phi <*> zonkIO u
    zonkSp PProj1 = pure PProj1
    zonkSp PProj2 = pure PProj2
zonkIO (VLam p (Closure s k)) = pure $ VLam p (Closure s (zonk . k))
zonkIO (VPi p d (Closure s k)) = VPi p <$> zonkIO d <*> pure (Closure s (zonk . k))
zonkIO (VSigma d (Closure s k)) = VSigma <$> zonkIO d <*> pure (Closure s (zonk . k))
zonkIO (VPair a b) = VPair <$> zonkIO a <*> zonkIO b
zonkIO (VPath line x y) = VPath <$> zonkIO line <*> zonkIO x <*> zonkIO y
zonkIO (VLine line x y f) = VLine <$> zonkIO line <*> zonkIO x <*> zonkIO y <*> zonkIO f
-- Sorts
zonkIO VType = pure VType
zonkIO VTypeω = pure VTypeω
-- The interval and its connectives
zonkIO VI = pure VI
zonkIO VI0 = pure VI0
zonkIO VI1 = pure VI1
zonkIO (VIAnd x y) = iand <$> zonkIO x <*> zonkIO y
zonkIO (VIOr x y) = ior <$> zonkIO x <*> zonkIO y
zonkIO (VINot x) = inot <$> zonkIO x
-- Cofibrations and partial elements
zonkIO (VIsOne x) = VIsOne <$> zonkIO x
zonkIO (VIsOne1 x) = VIsOne1 <$> zonkIO x
zonkIO (VIsOne2 x) = VIsOne2 <$> zonkIO x
zonkIO VItIsOne = pure VItIsOne
zonkIO (VPartial x y) = VPartial <$> zonkIO x <*> zonkIO y
zonkIO (VPartialP x y) = VPartialP <$> zonkIO x <*> zonkIO y
zonkIO (VSystem fs) = do
  faces <- for (Map.toList fs) $ \(a, b) -> (,) <$> zonkIO a <*> zonkIO b
  pure (mkVSystem (Map.fromList faces))
zonkIO (VSub a b c) = VSub <$> zonkIO a <*> zonkIO b <*> zonkIO c
zonkIO (VInc a b c) = VInc <$> zonkIO a <*> zonkIO b <*> zonkIO c
zonkIO (VComp a b c d) = comp <$> zonkIO a <*> zonkIO b <*> zonkIO c <*> zonkIO d

-- | Smart constructor for systems.
--
-- If some face is literally 'VI1', the system is that face's branch.
-- Otherwise every face that is literally 'VI0' is dropped.
mkVSystem :: Map.Map Value Value -> Value
mkVSystem faces =
  case Map.lookup VI1 faces of
    Just x -> x
    Nothing -> VSystem (Map.filterWithKey (\k _ -> k /= VI0) faces)

-- | Pure wrapper around 'zonkIO' (see the caveat on 'force').
zonk :: Value -> Value
zonk = unsafePerformIO . zonkIO

-- | Evaluate a term under an explicit elaboration environment.
--
-- Variables missing from the environment evaluate to 'VVar'. Binders
-- ('Lam', 'Pi', 'Sigma') become closures that extend the environment.
-- The type slot stored for the bound variable is an 'error' thunk
-- (\"type of abs\") and must never be forced.
eval' :: ElabEnv -> Term -> Value
eval' env (Ref x) =
  case Map.lookup x (getEnv env) of
    Just (_, vl) -> vl
    _ -> VVar x
eval' env (App p f x) = vApp p (eval' env f) (eval' env x)
eval' env (Lam p s t) =
  VLam p $ Closure s $ \a ->
    eval' env { getEnv = Map.insert (Bound s) (error "type of abs", a) (getEnv env) } t
eval' env (Pi p s d t) =
  VPi p (eval' env d) $ Closure s $ \a ->
    eval' env { getEnv = Map.insert (Bound s) (error "type of abs", a) (getEnv env) } t
eval' _ (Meta m) = VNe (HMeta m) mempty
eval' env (Sigma s d t) =
  VSigma (eval' env d) $ Closure s $ \a ->
    eval' env { getEnv = Map.insert (Bound s) (error "type of abs", a) (getEnv env) } t
eval' e (Pair a b) = VPair (eval' e a) (eval' e b)
eval' e (Proj1 a) = vProj1 (eval' e a)
eval' e (Proj2 a) = vProj2 (eval' e a)
eval' _ Type = VType
eval' _ Typeω = VTypeω
eval' _ I = VI
eval' _ I0 = VI0
eval' _ I1 = VI1
eval' e (IAnd x y) = iand (eval' e x) (eval' e y)
eval' e (IOr x y) = ior (eval' e x) (eval' e y)
eval' e (INot x) = inot (eval' e x)
eval' e (PathP l a b) = VPath (eval' e l) (eval' e a) (eval' e b)
eval' e (IElim l x y f i) = ielim (eval' e l) (eval' e x) (eval' e y) (eval' e f) (eval' e i)
eval' e (PathIntro p x y f) = VLine (eval' e p) (eval' e x) (eval' e y) (eval' e f)
eval' e (IsOne i) = VIsOne (eval' e i)
eval' e (IsOne1 i) = VIsOne1 (eval' e i)
eval' e (IsOne2 i) = VIsOne2 (eval' e i)
eval' _ ItIsOne = VItIsOne
eval' e (Partial x y) = VPartial (eval' e x) (eval' e y)
eval' e (PartialP x y) = VPartialP (eval' e x) (eval' e y)
eval' e (System fs) = VSystem (Map.fromList [ (eval' e x, eval' e y) | (x, y) <- Map.toList fs ])
eval' e (Sub a phi u) = VSub (eval' e a) (eval' e phi) (eval' e u)
eval' e (Inc a phi u) = VInc (eval' e a) (eval' e phi) (eval' e u)
eval' e (Ouc a phi u x) = outS (eval' e a) (eval' e phi) (eval' e u) (eval' e x)
eval' e (Comp a phi u a0) = comp (eval' e a) (eval' e phi) (eval' e u) (eval' e a0)

-- | Apply a value to an argument with the given plicity.
--
-- * A lambda reduces when the plicities agree; a mismatch is an internal
--   'error'.
-- * Neutrals grow their spine.
-- * Systems distribute the application over their branches.
-- * Anything else is an internal 'error'.
vApp :: HasCallStack => Plicity -> Value -> Value -> Value
vApp p (VLam p' k) arg
  | p == p' = clCont k arg
  | otherwise =
      error $ "wrong plicity " ++ show p ++ " vs " ++ show p' ++ " in app "
           ++ show (App p (quote (VLam p' k)) (quote arg))
vApp p (VNe h sp) arg = VNe h (sp Seq.:|> PApp p arg)
vApp p (VSystem fs) arg = VSystem (fmap (flip (vApp p) arg) fs)
vApp _ x _ = error $ "can't apply " ++ show x

-- | Explicit ('Ex') application, as a left-associative operator.
(@@) :: HasCallStack => Value -> Value -> Value
(@@) = vApp Ex
infixl 9 @@

-- | First projection.
--
-- Reduces on pairs, extends neutral spines, and distributes over
-- systems. Anything else is an internal 'error'.
vProj1 :: Value -> Value
vProj1 (VPair a _) = a
vProj1 (VNe h sp) = VNe h (sp Seq.:|> PProj1)
vProj1 (VSystem fs) = VSystem (fmap vProj1 fs)
vProj1 x = error $ "can't proj1 " ++ show x

-- | Second projection.
--
-- Reduces on pairs, extends neutral spines, and distributes over
-- systems. Anything else is an internal 'error'.
vProj2 :: Value -> Value
vProj2 (VPair _ b) = b
vProj2 (VNe h sp) = VNe h (sp Seq.:|> PProj2)
vProj2 (VSystem fs) = VSystem (fmap vProj2 fs)
vProj2 x = error $ "can't proj2 " ++ show x

-- | Thrown when two values fail to unify.
data NotEqual = NotEqual Value Value
  deriving (Show, Typeable, Exception)

-- | Unify two values, solving metavariables as a side effect.
--
-- Both sides are first forced with 'forceIO'. On failure this throws
-- 'NotEqual' for the pair currently being compared, which for nested
-- calls is the innermost failing pair. Use 'unify' to report the outer
-- pair instead.
unify' :: HasCallStack => Value -> Value -> ElabM ()
unify' topa topb = join $ go <$> forceIO topa <*> forceIO topb
  where
    -- Flex-rigid: solve the metavariable.
    go (VNe (HMeta mv) sp) rhs = solveMeta mv sp rhs
    go rhs (VNe (HMeta mv) sp) = solveMeta mv sp rhs

    -- Rigid-rigid: same head, spines compared pointwise.
    go (VNe x a) (VNe x' a')
      | x == x', length a == length a' = traverse_ (uncurry unify'Spine) (Seq.zip a a')

    -- A neutral stuck on an interval application: reduce at the
    -- endpoints, otherwise compare against a system branch-wise.
    go lhs@(VNe _hd (_ Seq.:|> PIElim _l x y i)) rhs =
      case force i of
        VI0 -> unify' x rhs
        VI1 -> unify' y rhs
        _ -> case rhs of
          VSystem sys -> goSystem (flip unify') sys lhs
          _ -> fail

    go lhs rhs@(VNe _hd (_ Seq.:|> PIElim _l x y i)) =
      case force i of
        VI0 -> unify' lhs x
        VI1 -> unify' lhs y
        _ -> case lhs of
          VSystem sys -> goSystem unify' sys rhs
          _ -> fail

    -- Eta for functions: apply both sides to a fresh variable.
    go (VLam p (Closure _ k)) vl = do
      t <- VVar . Bound <$> newName
      unify' (k t) (vApp p vl t)

    go vl (VLam p (Closure _ k)) = do
      t <- VVar . Bound <$> newName
      unify' (vApp p vl t) (k t)

    -- Eta for pairs.
    go (VPair a b) vl = unify' a (vProj1 vl) *> unify' b (vProj2 vl)
    go vl (VPair a b) = unify' (vProj1 vl) a *> unify' (vProj2 vl) b

    go (VPi p d (Closure _ k)) (VPi p' d' (Closure _ k')) | p == p' = do
      t <- VVar . Bound <$> newName
      unify' d d'
      unify' (k t) (k' t)

    go (VSigma d (Closure _ k)) (VSigma d' (Closure _ k')) = do
      t <- VVar . Bound <$> newName
      unify' d d'
      unify' (k t) (k' t)

    go VType VType = pure ()
    go VTypeω VTypeω = pure ()
    go VI VI = pure ()

    go (VPath l x y) (VPath l' x' y') = do
      unify' l l'
      unify' x x'
      unify' y y'

    -- Eta for paths: compare both sides at a fresh interval variable.
    go (VLine l x y p) p' = do
      n <- VVar . Bound <$> newName
      unify (p @@ n) (ielim l x y p' n)

    go p' (VLine l x y p) = do
      n <- VVar . Bound <$> newName
      unify (ielim l x y p' n) (p @@ n)

    go (VIsOne x) (VIsOne y) = unify' x y

    -- IsOne is proof-irrelevant:
    go VItIsOne _ = pure ()
    go _ VItIsOne = pure ()
    go VIsOne1{} _ = pure ()
    go _ VIsOne1{} = pure ()
    go VIsOne2{} _ = pure ()
    go _ VIsOne2{} = pure ()

    go (VPartial phi r) (VPartial phi' r') = unify' phi phi' *> unify r r'
    go (VPartialP phi r) (VPartialP phi' r') = unify' phi phi' *> unify r r'

    go (VSub a phi u) (VSub a' phi' u') =
      traverse_ (uncurry unify') [(a, a'), (phi, phi'), (u, u')]
    go (VInc a phi u) (VInc a' phi' u') =
      traverse_ (uncurry unify') [(a, a'), (phi, phi'), (u, u')]
    go (VComp a phi u a0) (VComp a' phi' u' a0') =
      traverse_ (uncurry unify') [(a, a'), (phi, phi'), (u, u'), (a0, a0')]

    go (VSystem sys) rhs = goSystem unify' sys rhs
    go rhs (VSystem sys) = goSystem (flip unify') sys rhs

    -- Last resort: compare as interval formulas in disjunctive normal form.
    go x y =
      case (toDnf x, toDnf y) of
        (Just xs, Just ys) -> unify'Formula xs ys
        _ -> fail

    -- Compare every branch of a system against the other side. For each
    -- face, this is done under every truth assignment that the face
    -- produces, by re-evaluating the quoted terms in the updated
    -- environment.
    goSystem :: (Value -> Value -> ElabM ()) -> Map.Map Value Value -> Value -> ElabM ()
    goSystem k sys rhs = do
      let rhs_q = quote rhs
      env <- ask
      for_ (Map.toList sys) $ \(f, i) -> do
        let i_q = quote i
        for (truthAssignments f (getEnv env)) $ \e ->
          k (eval' env{getEnv = e} i_q) (eval' env{getEnv = e} rhs_q)

    fail = throwElab $ NotEqual topa topb

    unify'Spine (PApp a v) (PApp a' v')
      | a == a' = unify' v v'
    unify'Spine PProj1 PProj1 = pure ()
    unify'Spine PProj2 PProj2 = pure ()
    unify'Spine (PIElim _ _ _ i) (PIElim _ _ _ j) = unify' i j
    unify'Spine (POuc a phi u) (POuc a' phi' u') =
      traverse_ (uncurry unify') [(a, a'), (phi, phi'), (u, u')]
    unify'Spine _ _ = fail

    unify'Formula x y
      | compareDNFs x y = pure ()
      | otherwise = fail

-- | Like 'unify'', but any 'NotEqual' raised while unifying is replaced
-- by one that mentions this call's own arguments.
unify :: HasCallStack => Value -> Value -> ElabM ()
unify a b = unify' a b `catchElab` \(_ :: NotEqual) -> liftIO $ throwIO (NotEqual a b)

-- | Check that something of the first type can be used at the second.
--
-- Returns a term-level coercion that adapts an inhabitant of the first
-- type to the second. The clauses, in order:
--
-- * An implicit Pi on the left is instantiated with a fresh metavariable,
--   and the coercion supplies it as an implicit argument. Because this
--   clause comes first, it also fires when the target is an implicit Pi.
-- * 'VType' is accepted at 'VTypeω' unchanged.
-- * Pis of equal plicity: the domain is converted contravariantly and the
--   codomain covariantly. The coercion eta-expands the function.
-- * Anything else falls back to 'unify'' with the identity coercion.
isConvertibleTo :: Value -> Value -> ElabM (Term -> Term)
VPi Im d (Closure _v k) `isConvertibleTo` ty = do
  meta <- newMeta d
  cont <- k meta `isConvertibleTo` ty
  pure (\f -> cont (App Im f (quote meta)))
VType `isConvertibleTo` VTypeω = pure id
VPi p d (Closure _ k) `isConvertibleTo` VPi p' d' (Closure _ k') | p == p' = do
  wp <- d' `isConvertibleTo` d
  n <- newName
  wp_n <- eval (Lam Ex n (wp (Ref (Bound n))))
  wp' <- k (VVar (Bound n)) `isConvertibleTo` k' (wp_n @@ VVar (Bound n))
  pure (\f -> Lam p n (wp' (App p f (wp (Ref (Bound n))))))
isConvertibleTo a b = do
  unify' a b
  pure id

-- | Create a fresh, unsolved metavariable.
--
-- The metavariable is explicitly applied to every 'Bound' variable of the
-- current environment, in the environment map's key order. The argument
-- (presumably the metavariable's type) is currently unused.
newMeta :: Value -> ElabM Value
newMeta _dom = do
  name <- newName
  cell <- liftIO $ newIORef Nothing
  env <- asks getEnv
  -- Only locally bound variables become arguments; other names are skipped.
  let args = [ PApp Ex (VVar v) | v@Bound{} <- Map.keys env ]
  pure (VNe (HMeta (MV name cell)) (Seq.fromList args))

-- | Produce a fresh name: the decimal text of a global, monotonically
-- increasing counter.
newName :: MonadIO m => m T.Text
newName = liftIO $ do
  -- Strict variant so the counter never accumulates thunks.
  x <- atomicModifyIORef' _nameCounter $ \c -> (c + 1, c + 1)
  pure (T.pack (show x))

-- | Global counter backing 'newName'.
_nameCounter :: IORef Int
_nameCounter = unsafePerformIO $ newIORef 0
-- NOINLINE ensures there is exactly one counter IORef.
{-# NOINLINE _nameCounter #-}

-- | Solve a metavariable applied to a spine, setting it equal to @rhs@.
--
-- Steps:
--
-- 1. 'checkSpine' requires the spine to be explicit applications to
--    distinct bound variables.
-- 2. 'checkScope' requires @rhs@ to mention no other bound variables.
--    Failures are annotated with the equation being solved.
-- 3. The solution is @rhs@ abstracted over the spine variables, evaluated
--    in the current environment, and written into the metavariable's cell.
--    Writing into an already-filled cell is an internal 'error'.
--
-- NOTE(review): 'checkScope' accepts every metavariable head, so no
-- occurs check is performed here.
solveMeta :: MV -> Seq Projection -> Value -> ElabM ()
solveMeta m@(MV _ cell) sp rhs = do
  env <- ask
  names <- checkSpine Set.empty sp
  checkScope (Set.fromList (Bound <$> names)) rhs
    `withNote` hsep [prettyTm (quote (VNe (HMeta m) sp)), pretty '≡', prettyTm (quote rhs)]
  let tm = quote rhs
      lam = eval' env $ foldr (Lam Ex) tm names
  liftIO . atomicModifyIORef' cell $ \case
    Just _ -> error "filled cell in solvedMeta"
    Nothing -> (Just lam, ())

-- | Check that every 'Bound' variable occurring in a value is in scope.
--
-- Throws 'NotInScope' otherwise. Non-'Bound' variable heads and
-- metavariable heads are accepted. Binders are entered by instantiating
-- the closure with a variable named after the binder, which is added to
-- the scope. Only the function component of a 'VLine' is checked, not
-- its type or endpoints.
checkScope :: Set Name -> Value -> ElabM ()
checkScope scope (VNe h sp) = do
  case h of
    HVar v@Bound{} ->
      unless (v `Set.member` scope) . throwElab $
        NotInScope v
    HVar{} -> pure ()
    HMeta{} -> pure ()
  traverse_ checkProj sp
  where
    checkProj (PApp _ t) = checkScope scope t
    checkProj (PIElim l x y i) = traverse_ (checkScope scope) [l, x, y, i]
    checkProj (POuc a phi u) = traverse_ (checkScope scope) [a, phi, u]
    checkProj PProj1 = pure ()
    checkProj PProj2 = pure ()
checkScope scope (VLam _ (Closure n k)) =
  checkScope (Set.insert (Bound n) scope) (k (VVar (Bound n)))
checkScope scope (VPi _ d (Closure n k)) = do
  checkScope scope d
  checkScope (Set.insert (Bound n) scope) (k (VVar (Bound n)))
checkScope scope (VSigma d (Closure n k)) = do
  checkScope scope d
  checkScope (Set.insert (Bound n) scope) (k (VVar (Bound n)))
checkScope s (VPair a b) = traverse_ (checkScope s) [a, b]
checkScope _ VType = pure ()
checkScope _ VTypeω = pure ()
checkScope _ VI = pure ()
checkScope _ VI0 = pure ()
checkScope _ VI1 = pure ()
checkScope s (VIAnd x y) = traverse_ (checkScope s) [x, y]
checkScope s (VIOr x y) = traverse_ (checkScope s) [x, y]
checkScope s (VINot x) = checkScope s x
checkScope s (VPath line a b) = traverse_ (checkScope s) [line, a, b]
checkScope s (VLine _ _ _ line) = checkScope s line
checkScope s (VIsOne x) = checkScope s x
checkScope s (VIsOne1 x) = checkScope s x
checkScope s (VIsOne2 x) = checkScope s x
checkScope _ VItIsOne = pure ()
checkScope s (VPartial x y) = traverse_ (checkScope s) [x, y]
checkScope s (VPartialP x y) = traverse_ (checkScope s) [x, y]
checkScope s (VSystem fs) =
  for_ (Map.toList fs) $ \(x, y) ->
    traverse_ (checkScope s) [x, y]
checkScope s (VSub a b c) = traverse_ (checkScope s) [a, b, c]
checkScope s (VInc a b c) = traverse_ (checkScope s) [a, b, c]
checkScope s (VComp a phi u a0) = traverse_ (checkScope s) [a, phi, u, a0]

-- | Check that a metavariable's spine is a pattern, returning the names of
-- its variables in order.
--
-- A pattern spine consists solely of explicit applications to pairwise
-- distinct 'Bound' variables. The first argument is the set of variables
-- already seen. Errors:
--
-- * 'NonLinearSpine' if a variable occurs twice.
-- * 'SpineProj' for any other kind of projection.
checkSpine :: Set Name -> Seq Projection -> ElabM [T.Text]
checkSpine scope (PApp Ex (VVar n@(Bound t)) Seq.:<| xs)
  | n `Set.member` scope = throwElab $ NonLinearSpine n
  -- Record the variable so a later repeat of it is detected.
  | otherwise = (t:) <$> checkSpine (Set.insert n scope) xs
checkSpine _ (p Seq.:<| _) = throwElab $ SpineProj p
checkSpine _ Seq.Empty = pure []

-- | Thrown by 'checkSpine' when a variable occurs twice in a spine.
newtype NonLinearSpine = NonLinearSpine { getDupeName :: Name }
  deriving (Show, Typeable, Exception)

-- | Thrown by 'checkSpine' when a spine contains something other than an
-- explicit application to a bound variable.
newtype SpineProjection = SpineProj { getSpineProjection :: Projection }
  deriving (Show, Typeable, Exception)