From 6d065cdddd1220e602dd8a2d26a9b3b7f2ebcabd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Am=C3=A9lia=20Liao?= Date: Thu, 4 Mar 2021 16:28:17 -0300 Subject: [PATCH] Use glued evaluation to get shorter normal forms --- intro.tt | 14 ++++++------ src/Elab.hs | 8 +++---- src/Elab/Eval.hs | 53 ++++++++++++++++++++++++++++---------------- src/Elab/Monad.hs | 5 ++++- src/Elab/WiredIn.hs | 37 ++++++++++++++++++------------- src/Main.hs | 4 ++-- src/Syntax.hs | 23 ++++++++++++++++++- src/Syntax/Pretty.hs | 3 +++ 8 files changed, 97 insertions(+), 50 deletions(-) diff --git a/intro.tt b/intro.tt index 6f1e316..e181a47 100644 --- a/intro.tt +++ b/intro.tt @@ -534,21 +534,21 @@ IsoToEquiv {A} {B} iso = fill1 : I -> I -> A fill1 i j = comp (\i -> A) (\k [ (i = i0) -> t x1 (iand j k) - , (i = i1) -> g y - , (j = i0) -> g (p1 i) ]) + , (i = i1) -> g y + , (j = i0) -> g (p1 i) ]) (inS (g (p1 i))) fill2 : I -> I -> A fill2 i j = comp (\i -> A) (\k [ (i = i0) -> rem0 (ior j (inot k)) - , (i = i1) -> rem1 (ior j (inot k)) - , (j = i1) -> g y ]) + , (i = i1) -> rem1 (ior j (inot k)) + , (j = i1) -> g y ]) (inS (g y)) sq : I -> I -> A sq i j = comp (\i -> A) (\k [ (i = i0) -> fill0 j (inot k) - , (i = i1) -> fill1 j (inot k) - , (j = i1) -> g y - , (j = i0) -> t (p i) (inot k) ]) + , (i = i1) -> fill1 j (inot k) + , (j = i1) -> g y + , (j = i0) -> t (p i) (inot k) ]) (inS (fill2 i j)) sq1 : I -> I -> B diff --git a/src/Elab.hs b/src/Elab.hs index 5f453a2..2342021 100644 --- a/src/Elab.hs +++ b/src/Elab.hs @@ -208,14 +208,14 @@ checkLetItems map (P.LetBind name rhs:xs) cont = do Nothing -> do (tm, ty) <- infer rhs tm_nf <- eval tm - define (Defined name 0) ty tm_nf \name' -> + makeLetDef (Defined name 0) ty tm_nf \name' -> checkLetItems map xs \xs -> cont ((name', quote ty, tm):xs) Just Nothing -> throwElab $ Redefinition (Defined name 0) Just (Just ty_nf) -> do rhs <- check rhs ty_nf rhs_nf <- eval rhs - define (Defined name 0) ty_nf rhs_nf \name' -> + makeLetDef (Defined name 0) ty_nf rhs_nf \name' -> checkLetItems (Map.insert name Nothing map) xs \xs -> cont ((name', quote ty_nf, rhs):xs) @@ -326,7 +326,7 @@ checkStatement (P.Defn name rhs) k = do Nothing -> do (tm, ty) <- infer rhs tm_nf <- eval tm - define (Defined name 0) ty tm_nf (const k) + makeLetDef (Defined name 0) ty tm_nf (const k) Just nm -> do ty_nf <- getNfType nm t <- asks (Set.member nm . definedNames) @@ -334,7 +334,7 @@ checkStatement (P.Defn name rhs) k = do rhs <- check rhs ty_nf rhs_nf <- eval rhs - define (Defined name 0) ty_nf rhs_nf $ \name -> + makeLetDef (Defined name 0) ty_nf rhs_nf $ \name -> local (\e -> e { definedNames = Set.insert name (definedNames e) }) k checkStatement (P.Builtin winame var) k = do diff --git a/src/Elab/Eval.hs b/src/Elab/Eval.hs index 2f6e4cd..8acf9a1 100644 --- a/src/Elab/Eval.hs +++ b/src/Elab/Eval.hs @@ -41,15 +41,16 @@ eval :: Term -> ElabM Value eval t = asks (flip eval' t) forceIO :: MonadIO m => Value -> m Value -forceIO mv@(VNe (HMeta (mvCell -> cell)) args) = do +forceIO mv@(VNe hd@(HMeta (mvCell -> cell)) args) = do solved <- liftIO $ readIORef cell case solved of - Just vl -> forceIO $ foldl applProj vl args + Just vl -> forceIO (foldl applProj vl args) Nothing -> pure mv forceIO vl@(VSystem fs) = case Map.lookup VI1 fs of Just x -> forceIO x Nothing -> pure vl +forceIO (GluedVl _ _ vl) = forceIO vl forceIO (VComp line phi u a0) = comp line <$> forceIO phi <*> pure u <*> pure a0 forceIO x = pure x @@ -74,12 +75,8 @@ zonkIO (VNe hd sp) = do Just vl -> zonkIO $ foldl applProj vl sp' Nothing -> pure $ VNe hd sp' hd -> pure $ VNe hd sp' - where - zonkSp (PApp p x) = PApp p <$> zonkIO x - zonkSp (PIElim l x y i) = PIElim <$> zonkIO l <*> zonkIO x <*> zonkIO y <*> zonkIO i - zonkSp (POuc a phi u) = POuc <$> zonkIO a <*> zonkIO phi <*> zonkIO u - zonkSp PProj1 = pure PProj1 - zonkSp PProj2 = pure PProj2 + +zonkIO (GluedVl h sp vl) = GluedVl h <$> traverse zonkSp sp <*> zonkIO vl zonkIO (VLam p (Closure s k)) = pure $ VLam p (Closure s (zonk . k)) zonkIO (VPi p d (Closure s k)) = VPi p <$> zonkIO d <*> pure (Closure s (zonk . k)) @@ -124,11 +121,19 @@ zonkIO VTt = pure VTt zonkIO VFf = pure VFf zonkIO (VIf a b c d) = elimBool <$> zonkIO a <*> zonkIO b <*> zonkIO c <*> zonkIO d +zonkSp :: Projection -> IO Projection +zonkSp (PApp p x) = PApp p <$> zonkIO x +zonkSp (PIElim l x y i) = PIElim <$> zonkIO l <*> zonkIO x <*> zonkIO y <*> zonkIO i +zonkSp (POuc a phi u) = POuc <$> zonkIO a <*> zonkIO phi <*> zonkIO u +zonkSp PProj1 = pure PProj1 +zonkSp PProj2 = pure PProj2 + mkVSystem :: Map.Map Value Value -> Value -mkVSystem map = - case Map.lookup VI1 map of +mkVSystem vals = + let map' = Map.fromList (map (\(a, b) -> (force a, b)) (Map.toList vals)) in + case Map.lookup VI1 map' of Just x -> x - Nothing -> VSystem (Map.filterWithKey (\k _ -> k /= VI0) map) + Nothing -> VSystem (Map.filterWithKey (\k _ -> k /= VI0) map') zonk :: Value -> Value zonk = unsafePerformIO . zonkIO @@ -204,7 +209,8 @@ vApp :: HasCallStack => Plicity -> Value -> Value -> Value vApp p (VLam p' k) arg | p == p' = clCont k arg | otherwise = error $ "wrong plicity " ++ show p ++ " vs " ++ show p' ++ " in app " ++ show (App p (quote (VLam p' k)) (quote arg)) -vApp p (VNe h sp) arg = VNe h (sp Seq.:|> PApp p arg) +vApp p (VNe h sp) arg = VNe h (sp Seq.:|> PApp p arg) +vApp p (GluedVl h sp vl) arg = GluedVl h (sp Seq.:|> PApp p arg) (vApp p vl arg) vApp p (VSystem fs) arg = VSystem (fmap (flip (vApp p) arg) fs) vApp p (VInc (VPi _ _ (Closure _ r)) phi f) arg = VInc (r (vApp p f arg)) phi (vApp p f arg) vApp _ x _ = error $ "can't apply " ++ show (prettyTm (quote x)) @@ -216,6 +222,7 @@ infixl 9 @@ vProj1 :: HasCallStack => Value -> Value vProj1 (VPair a _) = a vProj1 (VNe h sp) = VNe h (sp Seq.:|> PProj1) +vProj1 (GluedVl h sp vl) = GluedVl h (sp Seq.:|> PProj1) (vProj1 vl) vProj1 (VSystem fs) = VSystem (fmap vProj1 fs) vProj1 (VInc (VSigma a _) b c) = VInc a b (vProj1 c) vProj1 x = error $ "can't proj1 " ++ show (prettyTm (quote x)) @@ -223,6 +230,7 @@ vProj1 x = error $ "can't proj1 " ++ show (prettyTm (quote x)) vProj2 :: HasCallStack => Value -> Value vProj2 (VPair _ b) = b vProj2 (VNe h sp) = VNe h (sp Seq.:|> PProj2) +vProj2 (GluedVl h sp vl) = GluedVl h (sp Seq.:|> PProj2) (vProj2 vl) vProj2 (VSystem fs) = VSystem (fmap vProj2 fs) vProj2 (VInc (VSigma _ (Closure _ r)) b c) = VInc (r (vProj1 c)) b (vProj2 c) vProj2 x = error $ "can't proj2 " ++ show (prettyTm (quote x)) @@ -239,12 +247,12 @@ unify' topa topb = join $ go <$> forceIO topa <*> forceIO topb where | x == x', length a == length a' = traverse_ (uncurry unify'Spine) (Seq.zip a a') - go (VLam p (Closure _ k)) vl = do - t <- VVar <$> newName + go (VLam p (Closure n k)) vl = do + t <- VVar <$> newName' n unify' (k t) (vApp p vl t) - go vl (VLam p (Closure _ k)) = do - t <- VVar <$> newName + go vl (VLam p (Closure n k)) = do + t <- VVar <$> newName' n unify' (vApp p vl t) (k t) go (VPair a b) vl = unify' a (vProj1 vl) *> unify' b (vProj2 vl) @@ -297,8 +305,8 @@ unify' topa topb = join $ go <$> forceIO topa <*> forceIO topb where go (VComp a phi u a0) (VComp a' phi' u' a0') = traverse_ (uncurry unify') [(a, a'), (phi, phi'), (u, u'), (a0, a0')] - go (VGlueTy _ VI1 u _0) rhs = unify' (u @@ VItIsOne) rhs - go lhs (VGlueTy _ VI1 u _0) = unify' lhs (u @@ VItIsOne) + go (VGlueTy _ (force -> VI1) u _0) rhs = unify' (u @@ VItIsOne) rhs + go lhs (VGlueTy _ (force -> VI1) u _0) = unify' lhs (u @@ VItIsOne) go (VGlueTy a phi u a0) (VGlueTy a' phi' u' a0') = traverse_ (uncurry unify') [(a, a'), (phi, phi'), (u, u'), (a0, a0')] @@ -348,7 +356,7 @@ unify' topa topb = join $ go <$> forceIO topa <*> forceIO topb where | otherwise = fail unify :: HasCallStack => Value -> Value -> ElabM () -unify a b = unify' a b `catchElab` \(_ :: NotEqual) -> liftIO $ throwIO (NotEqual a b) +unify a b = unify' a b `catchElab` \(_ :: SomeException) -> liftIO $ throwIO (NotEqual a b) isConvertibleTo :: Value -> Value -> ElabM (Term -> Term) isConvertibleTo a b = isConvertibleTo (force a) (force b) where @@ -392,6 +400,11 @@ newName = liftIO $ do x <- atomicModifyIORef _nameCounter $ \x -> (x + 1, x + 1) pure (Bound (T.pack (show x)) x) +newName' :: Name -> ElabM Name +newName' n = do + ~(Bound _ x) <- newName + pure (Bound (getNameText n) x) + _nameCounter :: IORef Int _nameCounter = unsafePerformIO $ newIORef 0 {-# NOINLINE _nameCounter #-} @@ -432,6 +445,8 @@ checkScope scope (VNe h sp) = checkProj PProj1 = pure () checkProj PProj2 = pure () +checkScope scope (GluedVl _ _p vl) = checkScope scope vl + checkScope scope (VLam _ (Closure n k)) = checkScope (Set.insert n scope) (k (VVar n)) diff --git a/src/Elab/Monad.hs b/src/Elab/Monad.hs index 323fc7b..217854d 100644 --- a/src/Elab/Monad.hs +++ b/src/Elab/Monad.hs @@ -48,7 +48,10 @@ assume :: Name -> Value -> (Name -> ElabM a) -> ElabM a assume nm ty k = defineInternal nm ty VVar k define :: Name -> Value -> Value -> (Name -> ElabM a) -> ElabM a -define nm vty val = defineInternal nm vty (const val) +define nm vty val = defineInternal nm vty (\nm -> val) + +makeLetDef :: Name -> Value -> Value -> (Name -> ElabM a) -> ElabM a +makeLetDef nm vty val = defineInternal nm vty (\nm -> GluedVl (HVar nm) mempty val) assumes :: [Name] -> Value -> ([Name] -> ElabM a) -> ElabM a assumes nms ty k = do diff --git a/src/Elab/WiredIn.hs b/src/Elab/WiredIn.hs index 6883495..ba3b669 100644 --- a/src/Elab/WiredIn.hs +++ b/src/Elab/WiredIn.hs @@ -24,6 +24,7 @@ import Syntax import System.IO.Unsafe import Syntax.Pretty (prettyTm) import GHC.Stack (HasCallStack) +import Debug.Trace wiType :: WiredIn -> NFType wiType WiType = VType @@ -169,32 +170,32 @@ newtype NoSuchPrimitive = NoSuchPrimitive { getUnknownPrim :: Text } -- Interval operations iand, ior :: Value -> Value -> Value -iand = \case +iand x = case force x of VI1 -> id VI0 -> const VI0 - VIAnd x y -> \case + VIAnd x y -> \z -> case force z of VI0 -> VI0 VI1 -> VI1 z -> iand x (iand y z) - x -> \case + x -> \y -> case force y of VI0 -> VI0 VI1 -> x y -> VIAnd x y -ior = \case +ior x = case force x of VI0 -> id VI1 -> const VI1 - VIOr x y -> \case + VIOr x y -> \z -> case force z of VI1 -> VI1 VI0 -> VIOr x y - z -> ior x (ior y z) - x -> \case + _ -> ior x (ior y z) + x -> \y -> case force y of VI1 -> VI1 VI0 -> x y -> VIOr x y inot :: Value -> Value -inot = \case +inot x = case force x of VI0 -> VI1 VI1 -> VI0 VIOr x y -> VIAnd (inot x) (inot y) @@ -203,10 +204,12 @@ inot = \case x -> VINot x ielim :: Value -> Value -> Value -> Value -> NFEndp -> Value +ielim line left right (GluedVl h sp vl) i = + GluedVl h (sp Seq.:|> PIElim line left right i) (ielim line left right vl i) ielim line left right fn i = case force fn of VLine _ _ _ fun -> fun @@ i - x -> case i of + x -> case force i of VI1 -> right VI0 -> left _ -> case x of @@ -221,6 +224,7 @@ outS _ (force -> VI1) u _ = u @@ VItIsOne outS _ _phi _ (VInc _ _ x) = x outS _ VI0 _ x = x +outS a phi u (GluedVl x sp vl) = GluedVl x (sp Seq.:|> POuc a phi u) (outS a phi u vl) outS a phi u (VNe x sp) = VNe x (sp Seq.:|> POuc a phi u) outS _ _ _ v = error $ "can't outS " ++ show (prettyTm (quote v)) @@ -315,9 +319,10 @@ comp a psi@phi u (compOutS (a @@ VI1) phi (u @@ VI1 @@ VItIsOne) -> a0) = _ -> VComp a phi u (VInc (a @@ VI0) phi a0) compOutS :: NFSort -> NFEndp -> Value -> Value -> Value -compOutS _ _hi _0 vl@VComp{} = vl -compOutS _ _hi _0 (VInc _ _ x) = x -compOutS _ _ _ v = v +compOutS a b c d = compOutS a b c (force d) where + compOutS _ _hi _0 vl@VComp{} = vl + compOutS _ _hi _0 (VInc _ _ x) = x + compOutS _ _ _ v = v system :: (Value -> Value -> Value) -> Value system k = fun \i -> fun \isone -> k i isone @@ -334,13 +339,13 @@ glueType :: NFSort -> NFEndp -> NFPartial -> NFPartial -> Value glueType a phi tys eqvs = VGlueTy a phi tys eqvs glueElem :: NFSort -> NFEndp -> NFPartial -> NFPartial -> NFPartial -> Value -> Value -glueElem _a VI1 _tys _eqvs t _vl = t @@ VItIsOne +glueElem _a (force -> VI1) _tys _eqvs t _vl = t @@ VItIsOne glueElem a phi tys eqvs t vl = VGlue a phi tys eqvs t vl unglue :: NFSort -> NFEndp -> NFPartial -> NFPartial -> Value -> Value -unglue _a VI1 _tys eqvs x = vProj1 (eqvs @@ VItIsOne) @@ x -unglue _a _phi _tys _eqvs (VGlue _ _ _ _ _ vl) = vl -unglue _ _ _ _ (VSystem (Map.toList -> [])) = VSystem (Map.fromList []) +unglue _a (force -> VI1) _tys eqvs x = vProj1 (eqvs @@ VItIsOne) @@ x +unglue _a _phi _tys _eqvs (force -> VGlue _ _ _ _ _ vl) = vl +unglue a phi tys eqvs (force -> VSystem fs) = VSystem (fmap (unglue a phi tys eqvs) fs) unglue a phi tys eqvs vl = VUnglue a phi tys eqvs vl -- Definition of equivalences diff --git a/src/Main.hs b/src/Main.hs index fc10548..9eec467 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -157,8 +157,8 @@ displayExceptions lines = , Handler \(NotEqual ta tb) -> do putStrLn . unlines $ [ "\x1b[1;31merror\x1b[0m: Mismatch between actual and expected types:" - , " * \x1b[1mActual\x1b[0m: " ++ show (zonk ta) - , " * \x1b[1mExpected\x1b[0m: " ++ show (zonk tb) + , " * \x1b[1mActual\x1b[0m: " ++ showValue (zonk ta) + , " * \x1b[1mExpected\x1b[0m: " ++ showValue (zonk tb) ] , Handler \(NoSuchPrimitive x) -> do putStrLn $ "Unknown primitive: " ++ T.unpack x diff --git a/src/Syntax.hs b/src/Syntax.hs index a76e710..e029f27 100644 --- a/src/Syntax.hs +++ b/src/Syntax.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BlockArguments #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE DeriveDataTypeable #-} module Syntax where @@ -15,6 +16,8 @@ import Data.Set (Set) import Data.Data import Presyntax.Presyntax (Plicity(..), Posn) +import Data.Monoid +import Debug.Trace (traceShow) data WiredIn = WiType @@ -144,6 +147,8 @@ data Value | VSigma Value Closure | VPair Value Value + | GluedVl Head (Seq Projection) Value + | VType | VTypeω | VI @@ -193,6 +198,15 @@ quoteWith names (VNe h sp) = foldl goSpine (goHead h) sp where goSpine t PProj2 = Proj2 t goSpine t (POuc a phi u) = Ouc (quote a) (quote phi) (quote u) t +quoteWith names (GluedVl h sp (VLam p (Closure n k))) = + quoteWith names (VLam p (Closure n (\a -> GluedVl h (sp Seq.:|> PApp p a) (k a)))) + +quoteWith names (GluedVl h sp vl) + | GluedVl _ _ inner <- vl = quoteWith names (GluedVl h sp inner) + | Seq.Empty <- sp = quoteWith names vl + | alwaysShort vl = quoteWith names vl + | otherwise = quoteWith names (VNe h sp) + quoteWith names (VLam p (Closure n k)) = let n' = refresh Nothing names n in Lam p n' (quoteWith (Set.insert n' names) (k (VVar n'))) @@ -240,6 +254,13 @@ quoteWith _ames VTt = Tt quoteWith _ames VFf = Ff quoteWith names (VIf a b c d) = If (quoteWith names a) (quoteWith names b) (quoteWith names c) (quoteWith names d) +alwaysShort :: Value -> Bool +alwaysShort VBool{} = True +alwaysShort VTt{} = True +alwaysShort VFf{} = True +alwaysShort VVar{} = True +alwaysShort _ = False + refresh :: Maybe Value -> Set Name -> Name -> Name refresh (Just VI) n _ = refresh Nothing n (Bound (T.pack "phi") 0) refresh x s n @@ -278,4 +299,4 @@ data Projection | PProj1 | PProj2 | POuc NFSort NFEndp Value - deriving (Eq, Show, Ord) + deriving (Eq, Show, Ord) \ No newline at end of file diff --git a/src/Syntax/Pretty.hs b/src/Syntax/Pretty.hs index 9d44e9c..b727c66 100644 --- a/src/Syntax/Pretty.hs +++ b/src/Syntax/Pretty.hs @@ -91,6 +91,9 @@ prettyTm = prettyTm . everywhere (mkT beautify) where beautify Ff = Ref (Bound (T.pack ".false") 0) beautify (If a b c d) = toFun "if" [a, b, c, d] + beautify (Lam Ex v (App Ex f (Ref v'))) + | v == v', v `Set.notMember` freeVars f = f + beautify (Partial phi a) = toFun "Partial" [phi, a] beautify (PartialP phi a) = toFun "PartialP" [phi, a] beautify (Comp a phi u a0) = toFun "comp" [a, phi, u, a0]