less prototype, less bad code implementation of CCHM type theory
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

177 lines
5.4 KiB

  1. {-# LANGUAGE DeriveAnyClass #-}
  2. {-# LANGUAGE GeneralizedNewtypeDeriving #-}
  3. {-# LANGUAGE DerivingVia #-}
  4. {-# LANGUAGE DeriveAnyClass #-}
  5. module Elab.Monad where
  6. import Control.Monad.Reader
  7. import Control.Exception
  8. import Data.Text.Prettyprint.Doc.Render.Terminal (AnsiStyle)
  9. import qualified Data.Map.Strict as Map
  10. import Data.Text.Prettyprint.Doc
  11. import Data.Map.Strict (Map)
  12. import Data.Sequence (Seq)
  13. import Data.Text (Text)
  14. import Data.Set (Set)
  15. import Data.Typeable
  16. import Data.IORef
  17. import qualified Presyntax.Presyntax as P
  18. import Syntax
  19. data ElabEnv =
  20. ElabEnv { getEnv :: Map Name (NFType, Value)
  21. , nameMap :: Map Text Name
  22. , pingPong :: {-# UNPACK #-} !Int
  23. , commHook :: Value -> IO ()
  24. , currentSpan :: Maybe (P.Posn, P.Posn)
  25. , currentFile :: Maybe Text
  26. , whereBound :: Map Name (P.Posn, P.Posn)
  27. , definedNames :: Set Name
  28. , boundaries :: Map Name Boundary
  29. , unsolvedMetas :: {-# UNPACK #-} !(IORef (Map MV [(Seq Projection, Value)]))
  30. }
  31. newtype ElabM a = ElabM { runElab :: ElabEnv -> IO a }
  32. deriving (Functor, Applicative, Monad, MonadReader ElabEnv, MonadIO)
  33. via ReaderT ElabEnv IO
  34. emptyEnv :: IO ElabEnv
  35. emptyEnv = ElabEnv mempty mempty 0 (const (pure ())) Nothing Nothing mempty mempty mempty <$> newIORef mempty
  36. addBoundary :: Name -> Boundary -> ElabM a -> ElabM a
  37. addBoundary nm boundary = local (\e -> e { boundaries = Map.insert nm boundary (boundaries e)} )
  38. assume :: Name -> Value -> (Name -> ElabM a) -> ElabM a
  39. assume nm ty k = defineInternal nm ty VVar k
  40. define :: Name -> Value -> Value -> (Name -> ElabM a) -> ElabM a
  41. define nm vty val = defineInternal nm vty (const val)
  42. makeLetDef :: Name -> Value -> Value -> (Name -> ElabM a) -> ElabM a
  43. makeLetDef nm vty val = defineInternal nm vty (\nm -> GluedVl (HVar nm) mempty val)
  44. assumes :: [Name] -> Value -> ([Name] -> ElabM a) -> ElabM a
  45. assumes nms ty k = do
  46. let
  47. go acc [] k = k acc
  48. go acc (x:xs) k = assume x ty $ \n -> go (n:acc) xs k
  49. in go [] nms k
  50. defineInternal :: Name -> Value -> (Name -> Value) -> (Name -> ElabM a) -> ElabM a
  51. defineInternal nm vty val k =
  52. do
  53. env <- ask
  54. let (env', name') = go env
  55. local (const env') (k name')
  56. where
  57. go x =
  58. let
  59. nm' = case Map.lookup (getNameText nm) (nameMap x) of
  60. Just name -> incName nm name
  61. Nothing -> nm
  62. in ( x { getEnv = Map.insert nm' (vty, val nm') (getEnv x)
  63. , nameMap = Map.insert (getNameText nm) nm' (nameMap x)
  64. , whereBound = maybe (whereBound x) (flip (Map.insert nm') (whereBound x)) (currentSpan x)
  65. }
  66. , nm')
  67. redefine :: Name -> Value -> Value -> ElabM a -> ElabM a
  68. redefine nm vty val = local go where
  69. go x = x { getEnv = Map.insert nm (vty, val) (getEnv x)
  70. , nameMap = Map.insert (getNameText nm) nm (nameMap x)
  71. , whereBound = maybe (whereBound x) (flip (Map.insert nm) (whereBound x)) (currentSpan x)
  72. }
  73. getValue :: Name -> ElabM Value
  74. getValue nm = do
  75. vl <- asks (Map.lookup nm . getEnv)
  76. case vl of
  77. Just v -> pure (snd v)
  78. Nothing -> throwElab $ NotInScope nm
  79. getNfType :: Name -> ElabM NFType
  80. getNfType nm = do
  81. vl <- asks (Map.lookup nm . getEnv)
  82. case vl of
  83. Just v -> pure (fst v)
  84. Nothing -> throwElab $ NotInScope nm
  85. getNameFor :: Text -> ElabM Name
  86. getNameFor x = do
  87. vl <- asks (Map.lookup x . nameMap)
  88. case vl of
  89. Just v -> pure v
  90. Nothing -> liftIO . throwIO $ NotInScope (Bound x 0)
  91. switch :: ElabM a -> ElabM a
  92. switch k =
  93. do
  94. depth <- asks pingPong
  95. when (depth >= 128) $ throwElab StackOverflow
  96. local go k
  97. where go e = e { pingPong = pingPong e + 1 }
  98. newtype NotInScope = NotInScope { nameNotInScope :: Name }
  99. deriving (Show, Typeable)
  100. deriving anyclass (Exception)
  101. data AttachedNote = AttachedNote { getNote :: Doc AnsiStyle, getExc :: SomeException }
  102. deriving (Show, Typeable)
  103. deriving anyclass (Exception)
  104. withNote :: ElabM a -> Doc AnsiStyle -> ElabM a
  105. withNote k note = do
  106. env <- ask
  107. liftIO $
  108. runElab k env
  109. `catch` \e -> throwIO (AttachedNote note e)
  110. data WhileChecking = WhileChecking { startPos :: P.Posn, endPos :: P.Posn, exc :: SomeException }
  111. deriving (Show, Typeable, Exception)
  112. withSpan :: P.Posn -> P.Posn -> ElabM a -> ElabM a
  113. withSpan a b k = do
  114. env <- ask
  115. liftIO $
  116. runElab k env{ currentSpan = Just (a, b) }
  117. `catches` [ Handler $ \e@WhileChecking{} -> throwIO e
  118. , Handler $ \e -> throwIO (WhileChecking a b e)
  119. ]
  120. data SeeAlso = SeeAlso { saStartPos :: P.Posn, saEndPos :: P.Posn, saExc :: SomeException }
  121. deriving (Show, Typeable, Exception)
  122. seeAlso :: ElabM a -> Name -> ElabM a
  123. seeAlso k nm = do
  124. env <- ask
  125. case Map.lookup nm (whereBound env) of
  126. Just l ->
  127. liftIO $ runElab k env
  128. `catch` \e -> throwIO (SeeAlso (fst l) (snd l) e)
  129. Nothing -> k
  130. catchElab :: Exception e => ElabM a -> (e -> ElabM a) -> ElabM a
  131. catchElab k h = do
  132. env <- ask
  133. liftIO $ runElab k env `catch` \e -> runElab (h e) env
  134. tryElab :: Exception e => ElabM a -> ElabM (Either e a)
  135. tryElab k = do
  136. env <- ask
  137. liftIO $ (Right <$> runElab k env) `catch` \e -> pure (Left e)
  138. throwElab :: Exception e => e -> ElabM a
  139. throwElab = liftIO . throwIO
  140. incName :: Name -> Name -> Name
  141. incName (Bound x _) n = Bound x (getNameNum n + 1)
  142. incName (Defined x _) n = Defined x (getNameNum n + 1)
  143. incName (ConName x _ s a) n = ConName x (getNameNum n + 1) s a