1.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
1.2 +++ b/src/HOL/Imperative_HOL/Heap_Monad.thy Thu Jan 08 17:10:41 2009 +0100
1.3 @@ -0,0 +1,425 @@
1.4 +(* Title: HOL/Library/Heap_Monad.thy
1.5 + ID: $Id$
1.6 + Author: John Matthews, Galois Connections; Alexander Krauss, Lukas Bulwahn & Florian Haftmann, TU Muenchen
1.7 +*)
1.8 +
1.9 +header {* A monad with a polymorphic heap *}
1.10 +
1.11 +theory Heap_Monad
1.12 +imports Heap
1.13 +begin
1.14 +
1.15 +subsection {* The monad *}
1.16 +
1.17 +subsubsection {* Monad combinators *}
1.18 +
1.19 +datatype exception = Exn
1.20 +
1.21 +text {* Monadic heap actions either produce values
1.22 + and transform the heap, or fail *}
1.23 +datatype 'a Heap = Heap "heap \<Rightarrow> ('a + exception) \<times> heap"
1.24 +
1.25 +primrec
1.26 + execute :: "'a Heap \<Rightarrow> heap \<Rightarrow> ('a + exception) \<times> heap" where
1.27 + "execute (Heap f) = f"
1.28 +lemmas [code del] = execute.simps
1.29 +
1.30 +lemma Heap_execute [simp]:
1.31 + "Heap (execute f) = f" by (cases f) simp_all
1.32 +
1.33 +lemma Heap_eqI:
1.34 + "(\<And>h. execute f h = execute g h) \<Longrightarrow> f = g"
1.35 + by (cases f, cases g) (auto simp: expand_fun_eq)
1.36 +
1.37 +lemma Heap_eqI':
1.38 + "(\<And>h. (\<lambda>x. execute (f x) h) = (\<lambda>y. execute (g y) h)) \<Longrightarrow> f = g"
1.39 + by (auto simp: expand_fun_eq intro: Heap_eqI)
1.40 +
1.41 +lemma Heap_strip: "(\<And>f. PROP P f) \<equiv> (\<And>g. PROP P (Heap g))"
1.42 +proof
1.43 + fix g :: "heap \<Rightarrow> ('a + exception) \<times> heap"
1.44 + assume "\<And>f. PROP P f"
1.45 + then show "PROP P (Heap g)" .
1.46 +next
1.47 + fix f :: "'a Heap"
1.48 + assume assm: "\<And>g. PROP P (Heap g)"
1.49 + then have "PROP P (Heap (execute f))" .
1.50 + then show "PROP P f" by simp
1.51 +qed
1.52 +
1.53 +definition
1.54 + heap :: "(heap \<Rightarrow> 'a \<times> heap) \<Rightarrow> 'a Heap" where
1.55 + [code del]: "heap f = Heap (\<lambda>h. apfst Inl (f h))"
1.56 +
1.57 +lemma execute_heap [simp]:
1.58 + "execute (heap f) h = apfst Inl (f h)"
1.59 + by (simp add: heap_def)
1.60 +
1.61 +definition
1.62 + bindM :: "'a Heap \<Rightarrow> ('a \<Rightarrow> 'b Heap) \<Rightarrow> 'b Heap" (infixl ">>=" 54) where
1.63 + [code del]: "f >>= g = Heap (\<lambda>h. case execute f h of
1.64 + (Inl x, h') \<Rightarrow> execute (g x) h'
1.65 + | r \<Rightarrow> r)"
1.66 +
1.67 +notation
1.68 + bindM (infixl "\<guillemotright>=" 54)
1.69 +
1.70 +abbreviation
1.71 + chainM :: "'a Heap \<Rightarrow> 'b Heap \<Rightarrow> 'b Heap" (infixl ">>" 54) where
1.72 + "f >> g \<equiv> f >>= (\<lambda>_. g)"
1.73 +
1.74 +notation
1.75 + chainM (infixl "\<guillemotright>" 54)
1.76 +
1.77 +definition
1.78 + return :: "'a \<Rightarrow> 'a Heap" where
1.79 + [code del]: "return x = heap (Pair x)"
1.80 +
1.81 +lemma execute_return [simp]:
1.82 + "execute (return x) h = apfst Inl (x, h)"
1.83 + by (simp add: return_def)
1.84 +
1.85 +definition
1.86 + raise :: "string \<Rightarrow> 'a Heap" where -- {* the string is just decoration *}
1.87 + [code del]: "raise s = Heap (Pair (Inr Exn))"
1.88 +
1.89 +notation (latex output)
1.90 + "raise" ("\<^raw:{\textsf{raise}}>")
1.91 +
1.92 +lemma execute_raise [simp]:
1.93 + "execute (raise s) h = (Inr Exn, h)"
1.94 + by (simp add: raise_def)
1.95 +
1.96 +
1.97 +subsubsection {* do-syntax *}
1.98 +
1.99 +text {*
1.100 + We provide a convenient do-notation for monadic expressions
1.101 + well-known from Haskell. @{const Let} is printed
1.102 + specially in do-expressions.
1.103 +*}
1.104 +
1.105 +nonterminals do_expr
1.106 +
1.107 +syntax
1.108 + "_do" :: "do_expr \<Rightarrow> 'a"
1.109 + ("(do (_)//done)" [12] 100)
1.110 + "_bindM" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
1.111 + ("_ <- _;//_" [1000, 13, 12] 12)
1.112 + "_chainM" :: "'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
1.113 + ("_;//_" [13, 12] 12)
1.114 + "_let" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
1.115 + ("let _ = _;//_" [1000, 13, 12] 12)
1.116 + "_nil" :: "'a \<Rightarrow> do_expr"
1.117 + ("_" [12] 12)
1.118 +
1.119 +syntax (xsymbols)
1.120 + "_bindM" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
1.121 + ("_ \<leftarrow> _;//_" [1000, 13, 12] 12)
1.122 +syntax (latex output)
1.123 + "_do" :: "do_expr \<Rightarrow> 'a"
1.124 + ("(\<^raw:{\textsf{do}}> (_))" [12] 100)
1.125 + "_let" :: "pttrn \<Rightarrow> 'a \<Rightarrow> do_expr \<Rightarrow> do_expr"
1.126 + ("\<^raw:\textsf{let}> _ = _;//_" [1000, 13, 12] 12)
1.127 +notation (latex output)
1.128 + "return" ("\<^raw:{\textsf{return}}>")
1.129 +
1.130 +translations
1.131 + "_do f" => "f"
1.132 + "_bindM x f g" => "f \<guillemotright>= (\<lambda>x. g)"
1.133 + "_chainM f g" => "f \<guillemotright> g"
1.134 + "_let x t f" => "CONST Let t (\<lambda>x. f)"
1.135 + "_nil f" => "f"
1.136 +
1.137 +print_translation {*
1.138 +let
1.139 + fun dest_abs_eta (Abs (abs as (_, ty, _))) =
1.140 + let
1.141 + val (v, t) = Syntax.variant_abs abs;
1.142 + in (Free (v, ty), t) end
1.143 + | dest_abs_eta t =
1.144 + let
1.145 + val (v, t) = Syntax.variant_abs ("", dummyT, t $ Bound 0);
1.146 + in (Free (v, dummyT), t) end;
1.147 + fun unfold_monad (Const (@{const_syntax bindM}, _) $ f $ g) =
1.148 + let
1.149 + val (v, g') = dest_abs_eta g;
1.150 + val vs = fold_aterms (fn Free (v, _) => insert (op =) v | _ => I) v [];
1.151 + val v_used = fold_aterms
1.152 + (fn Free (w, _) => (fn s => s orelse member (op =) vs w) | _ => I) g' false;
1.153 + in if v_used then
1.154 + Const ("_bindM", dummyT) $ v $ f $ unfold_monad g'
1.155 + else
1.156 + Const ("_chainM", dummyT) $ f $ unfold_monad g'
1.157 + end
1.158 + | unfold_monad (Const (@{const_syntax chainM}, _) $ f $ g) =
1.159 + Const ("_chainM", dummyT) $ f $ unfold_monad g
1.160 + | unfold_monad (Const (@{const_syntax Let}, _) $ f $ g) =
1.161 + let
1.162 + val (v, g') = dest_abs_eta g;
1.163 + in Const ("_let", dummyT) $ v $ f $ unfold_monad g' end
1.164 + | unfold_monad (Const (@{const_syntax Pair}, _) $ f) =
1.165 + Const (@{const_syntax return}, dummyT) $ f
1.166 + | unfold_monad f = f;
1.167 + fun contains_bindM (Const (@{const_syntax bindM}, _) $ _ $ _) = true
1.168 + | contains_bindM (Const (@{const_syntax Let}, _) $ _ $ Abs (_, _, t)) =
1.169 + contains_bindM t;
1.170 + fun bindM_monad_tr' (f::g::ts) = list_comb
1.171 + (Const ("_do", dummyT) $ unfold_monad (Const (@{const_syntax bindM}, dummyT) $ f $ g), ts);
1.172 + fun Let_monad_tr' (f :: (g as Abs (_, _, g')) :: ts) = if contains_bindM g' then list_comb
1.173 + (Const ("_do", dummyT) $ unfold_monad (Const (@{const_syntax Let}, dummyT) $ f $ g), ts)
1.174 + else raise Match;
1.175 +in [
1.176 + (@{const_syntax bindM}, bindM_monad_tr'),
1.177 + (@{const_syntax Let}, Let_monad_tr')
1.178 +] end;
1.179 +*}
1.180 +
1.181 +
1.182 +subsection {* Monad properties *}
1.183 +
1.184 +subsubsection {* Monad laws *}
1.185 +
1.186 +lemma return_bind: "return x \<guillemotright>= f = f x"
1.187 + by (simp add: bindM_def return_def)
1.188 +
1.189 +lemma bind_return: "f \<guillemotright>= return = f"
1.190 +proof (rule Heap_eqI)
1.191 + fix h
1.192 + show "execute (f \<guillemotright>= return) h = execute f h"
1.193 + by (auto simp add: bindM_def return_def split: sum.splits prod.splits)
1.194 +qed
1.195 +
1.196 +lemma bind_bind: "(f \<guillemotright>= g) \<guillemotright>= h = f \<guillemotright>= (\<lambda>x. g x \<guillemotright>= h)"
1.197 + by (rule Heap_eqI) (auto simp add: bindM_def split: split: sum.splits prod.splits)
1.198 +
1.199 +lemma bind_bind': "f \<guillemotright>= (\<lambda>x. g x \<guillemotright>= h x) = f \<guillemotright>= (\<lambda>x. g x \<guillemotright>= (\<lambda>y. return (x, y))) \<guillemotright>= (\<lambda>(x, y). h x y)"
1.200 + by (rule Heap_eqI) (auto simp add: bindM_def split: split: sum.splits prod.splits)
1.201 +
1.202 +lemma raise_bind: "raise e \<guillemotright>= f = raise e"
1.203 + by (simp add: raise_def bindM_def)
1.204 +
1.205 +
1.206 +lemmas monad_simp = return_bind bind_return bind_bind raise_bind
1.207 +
1.208 +
1.209 +subsection {* Generic combinators *}
1.210 +
1.211 +definition
1.212 + liftM :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b Heap"
1.213 +where
1.214 + "liftM f = return o f"
1.215 +
1.216 +definition
1.217 + compM :: "('a \<Rightarrow> 'b Heap) \<Rightarrow> ('b \<Rightarrow> 'c Heap) \<Rightarrow> 'a \<Rightarrow> 'c Heap" (infixl ">>==" 54)
1.218 +where
1.219 + "(f >>== g) = (\<lambda>x. f x \<guillemotright>= g)"
1.220 +
1.221 +notation
1.222 + compM (infixl "\<guillemotright>==" 54)
1.223 +
1.224 +lemma liftM_collapse: "liftM f x = return (f x)"
1.225 + by (simp add: liftM_def)
1.226 +
1.227 +lemma liftM_compM: "liftM f \<guillemotright>== g = g o f"
1.228 + by (auto intro: Heap_eqI' simp add: expand_fun_eq liftM_def compM_def bindM_def)
1.229 +
1.230 +lemma compM_return: "f \<guillemotright>== return = f"
1.231 + by (simp add: compM_def monad_simp)
1.232 +
1.233 +lemma compM_compM: "(f \<guillemotright>== g) \<guillemotright>== h = f \<guillemotright>== (g \<guillemotright>== h)"
1.234 + by (simp add: compM_def monad_simp)
1.235 +
1.236 +lemma liftM_bind:
1.237 + "(\<lambda>x. liftM f x \<guillemotright>= liftM g) = liftM (\<lambda>x. g (f x))"
1.238 + by (rule Heap_eqI') (simp add: monad_simp liftM_def bindM_def)
1.239 +
1.240 +lemma liftM_comp:
1.241 + "liftM f o g = liftM (f o g)"
1.242 + by (rule Heap_eqI') (simp add: liftM_def)
1.243 +
1.244 +lemmas monad_simp' = monad_simp liftM_compM compM_return
1.245 + compM_compM liftM_bind liftM_comp
1.246 +
1.247 +primrec
1.248 + mapM :: "('a \<Rightarrow> 'b Heap) \<Rightarrow> 'a list \<Rightarrow> 'b list Heap"
1.249 +where
1.250 + "mapM f [] = return []"
1.251 + | "mapM f (x#xs) = do y \<leftarrow> f x;
1.252 + ys \<leftarrow> mapM f xs;
1.253 + return (y # ys)
1.254 + done"
1.255 +
1.256 +primrec
1.257 + foldM :: "('a \<Rightarrow> 'b \<Rightarrow> 'b Heap) \<Rightarrow> 'a list \<Rightarrow> 'b \<Rightarrow> 'b Heap"
1.258 +where
1.259 + "foldM f [] s = return s"
1.260 + | "foldM f (x#xs) s = f x s \<guillemotright>= foldM f xs"
1.261 +
1.262 +definition
1.263 + assert :: "('a \<Rightarrow> bool) \<Rightarrow> 'a \<Rightarrow> 'a Heap"
1.264 +where
1.265 + "assert P x = (if P x then return x else raise (''assert''))"
1.266 +
1.267 +lemma assert_cong [fundef_cong]:
1.268 + assumes "P = P'"
1.269 + assumes "\<And>x. P' x \<Longrightarrow> f x = f' x"
1.270 + shows "(assert P x >>= f) = (assert P' x >>= f')"
1.271 + using assms by (auto simp add: assert_def return_bind raise_bind)
1.272 +
1.273 +hide (open) const heap execute
1.274 +
1.275 +
1.276 +subsection {* Code generator setup *}
1.277 +
1.278 +subsubsection {* Logical intermediate layer *}
1.279 +
1.280 +definition
1.281 + Fail :: "message_string \<Rightarrow> exception"
1.282 +where
1.283 + [code del]: "Fail s = Exn"
1.284 +
1.285 +definition
1.286 + raise_exc :: "exception \<Rightarrow> 'a Heap"
1.287 +where
1.288 + [code del]: "raise_exc e = raise []"
1.289 +
1.290 +lemma raise_raise_exc [code, code inline]:
1.291 + "raise s = raise_exc (Fail (STR s))"
1.292 + unfolding Fail_def raise_exc_def raise_def ..
1.293 +
1.294 +hide (open) const Fail raise_exc
1.295 +
1.296 +
1.297 +subsubsection {* SML and OCaml *}
1.298 +
1.299 +code_type Heap (SML "unit/ ->/ _")
1.300 +code_const Heap (SML "raise/ (Fail/ \"bare Heap\")")
1.301 +code_const "op \<guillemotright>=" (SML "!(fn/ f'_/ =>/ fn/ ()/ =>/ f'_/ (_/ ())/ ())")
1.302 +code_const return (SML "!(fn/ ()/ =>/ _)")
1.303 +code_const "Heap_Monad.Fail" (SML "Fail")
1.304 +code_const "Heap_Monad.raise_exc" (SML "!(fn/ ()/ =>/ raise/ _)")
1.305 +
1.306 +code_type Heap (OCaml "_")
1.307 +code_const Heap (OCaml "failwith/ \"bare Heap\"")
1.308 +code_const "op \<guillemotright>=" (OCaml "!(fun/ f'_/ ()/ ->/ f'_/ (_/ ())/ ())")
1.309 +code_const return (OCaml "!(fun/ ()/ ->/ _)")
1.310 +code_const "Heap_Monad.Fail" (OCaml "Failure")
1.311 +code_const "Heap_Monad.raise_exc" (OCaml "!(fun/ ()/ ->/ raise/ _)")
1.312 +
1.313 +setup {* let
1.314 + open Code_Thingol;
1.315 +
1.316 + fun lookup naming = the o Code_Thingol.lookup_const naming;
1.317 +
1.318 + fun imp_monad_bind'' bind' return' unit' ts =
1.319 + let
1.320 + val dummy_name = "";
1.321 + val dummy_type = ITyVar dummy_name;
1.322 + val dummy_case_term = IVar dummy_name;
1.323 + (*assumption: dummy values are not relevant for serialization*)
1.324 + val unitt = IConst (unit', ([], []));
1.325 + fun dest_abs ((v, ty) `|-> t, _) = ((v, ty), t)
1.326 + | dest_abs (t, ty) =
1.327 + let
1.328 + val vs = Code_Thingol.fold_varnames cons t [];
1.329 + val v = Name.variant vs "x";
1.330 + val ty' = (hd o fst o Code_Thingol.unfold_fun) ty;
1.331 + in ((v, ty'), t `$ IVar v) end;
1.332 + fun force (t as IConst (c, _) `$ t') = if c = return'
1.333 + then t' else t `$ unitt
1.334 + | force t = t `$ unitt;
1.335 + fun tr_bind' [(t1, _), (t2, ty2)] =
1.336 + let
1.337 + val ((v, ty), t) = dest_abs (t2, ty2);
1.338 + in ICase (((force t1, ty), [(IVar v, tr_bind'' t)]), dummy_case_term) end
1.339 + and tr_bind'' t = case Code_Thingol.unfold_app t
1.340 + of (IConst (c, (_, ty1 :: ty2 :: _)), [x1, x2]) => if c = bind'
1.341 + then tr_bind' [(x1, ty1), (x2, ty2)]
1.342 + else force t
1.343 + | _ => force t;
1.344 + in (dummy_name, dummy_type) `|-> ICase (((IVar dummy_name, dummy_type),
1.345 + [(unitt, tr_bind' ts)]), dummy_case_term) end
1.346 + and imp_monad_bind' bind' return' unit' (const as (c, (_, tys))) ts = if c = bind' then case (ts, tys)
1.347 + of ([t1, t2], ty1 :: ty2 :: _) => imp_monad_bind'' bind' return' unit' [(t1, ty1), (t2, ty2)]
1.348 + | ([t1, t2, t3], ty1 :: ty2 :: _) => imp_monad_bind'' bind' return' unit' [(t1, ty1), (t2, ty2)] `$ t3
1.349 + | (ts, _) => imp_monad_bind bind' return' unit' (eta_expand 2 (const, ts))
1.350 + else IConst const `$$ map (imp_monad_bind bind' return' unit') ts
1.351 + and imp_monad_bind bind' return' unit' (IConst const) = imp_monad_bind' bind' return' unit' const []
1.352 + | imp_monad_bind bind' return' unit' (t as IVar _) = t
1.353 + | imp_monad_bind bind' return' unit' (t as _ `$ _) = (case unfold_app t
1.354 + of (IConst const, ts) => imp_monad_bind' bind' return' unit' const ts
1.355 + | (t, ts) => imp_monad_bind bind' return' unit' t `$$ map (imp_monad_bind bind' return' unit') ts)
1.356 + | imp_monad_bind bind' return' unit' (v_ty `|-> t) = v_ty `|-> imp_monad_bind bind' return' unit' t
1.357 + | imp_monad_bind bind' return' unit' (ICase (((t, ty), pats), t0)) = ICase
1.358 + (((imp_monad_bind bind' return' unit' t, ty), (map o pairself) (imp_monad_bind bind' return' unit') pats), imp_monad_bind bind' return' unit' t0);
1.359 +
1.360 + fun imp_program naming = (Graph.map_nodes o map_terms_stmt)
1.361 + (imp_monad_bind (lookup naming @{const_name bindM})
1.362 + (lookup naming @{const_name return})
1.363 + (lookup naming @{const_name Unity}));
1.364 +
1.365 +in
1.366 +
1.367 + Code_Target.extend_target ("SML_imp", ("SML", imp_program))
1.368 + #> Code_Target.extend_target ("OCaml_imp", ("OCaml", imp_program))
1.369 +
1.370 +end
1.371 +*}
1.372 +
1.373 +
1.374 +code_reserved OCaml Failure raise
1.375 +
1.376 +
1.377 +subsubsection {* Haskell *}
1.378 +
1.379 +text {* Adaption layer *}
1.380 +
1.381 +code_include Haskell "STMonad"
1.382 +{*import qualified Control.Monad;
1.383 +import qualified Control.Monad.ST;
1.384 +import qualified Data.STRef;
1.385 +import qualified Data.Array.ST;
1.386 +
1.387 +type RealWorld = Control.Monad.ST.RealWorld;
1.388 +type ST s a = Control.Monad.ST.ST s a;
1.389 +type STRef s a = Data.STRef.STRef s a;
1.390 +type STArray s a = Data.Array.ST.STArray s Int a;
1.391 +
1.392 +runST :: (forall s. ST s a) -> a;
1.393 +runST s = Control.Monad.ST.runST s;
1.394 +
1.395 +newSTRef = Data.STRef.newSTRef;
1.396 +readSTRef = Data.STRef.readSTRef;
1.397 +writeSTRef = Data.STRef.writeSTRef;
1.398 +
1.399 +newArray :: (Int, Int) -> a -> ST s (STArray s a);
1.400 +newArray = Data.Array.ST.newArray;
1.401 +
1.402 +newListArray :: (Int, Int) -> [a] -> ST s (STArray s a);
1.403 +newListArray = Data.Array.ST.newListArray;
1.404 +
1.405 +lengthArray :: STArray s a -> ST s Int;
1.406 +lengthArray a = Control.Monad.liftM snd (Data.Array.ST.getBounds a);
1.407 +
1.408 +readArray :: STArray s a -> Int -> ST s a;
1.409 +readArray = Data.Array.ST.readArray;
1.410 +
1.411 +writeArray :: STArray s a -> Int -> a -> ST s ();
1.412 +writeArray = Data.Array.ST.writeArray;*}
1.413 +
1.414 +code_reserved Haskell RealWorld ST STRef Array
1.415 + runST
1.416 + newSTRef reasSTRef writeSTRef
1.417 + newArray newListArray lengthArray readArray writeArray
1.418 +
1.419 +text {* Monad *}
1.420 +
1.421 +code_type Heap (Haskell "ST/ RealWorld/ _")
1.422 +code_const Heap (Haskell "error/ \"bare Heap\"")
1.423 +code_monad "op \<guillemotright>=" Haskell
1.424 +code_const return (Haskell "return")
1.425 +code_const "Heap_Monad.Fail" (Haskell "_")
1.426 +code_const "Heap_Monad.raise_exc" (Haskell "error")
1.427 +
1.428 +end