1 (* Title: HOL/Tools/datatype_codegen.ML
2 Author: Stefan Berghofer and Florian Haftmann, TU Muenchen
4 Code generator facilities for inductive datatypes.
(* Public interface of the datatype code generator.
   NOTE(review): this copy lacks original lines 8 and 12-13 (see the embedded
   numbering), so the `sig` ... `end` delimiters are not shown here. *)
7 signature DATATYPE_CODEGEN =
(* mk_eq_eqns thy dtco: equations for eq_class.eq on datatype dtco, each
   paired with a boolean flag. *)
9 val mk_eq_eqns: theory -> string -> (thm * bool) list
(* mk_case_cert thy tyco: a single certificate theorem for the case
   combinator of datatype tyco. *)
10 val mk_case_cert: theory -> string -> thm
(* setup: registers the term/type code generators and the datatype
   interpretation with the theory. *)
11 val setup: theory -> theory
(* Implementation of DATATYPE_CODEGEN.  The first part (up to the
   "generic code generator" banner) is the SML code generator.
   NOTE(review): original lines 15-16 and 18-20 are missing from this copy
   (presumably the `struct` keyword and blank lines -- confirm upstream). *)
14 structure DatatypeCodegen : DATATYPE_CODEGEN =
17 (** SML code generator **)
21 (**** datatype definition ****)
23 (* find shortest path to constructor with no recursive arguments *)
(* find_nonempty descr is i: searches the constructors of datatype number i
   in descr.
   - `is` lists the datatype indices already on the current path; a
     recursive argument (DtRec) pointing into `is` makes its constructor
     unusable (NONE).
   - Every other recursive argument has depth 1 + the depth found for its
     datatype; non-recursive arguments have depth 0.
   - A constructor's depth is the maximum over its arguments (starting from
     SOME 0).
   Returns SOME (cname, depth) for a constructor of minimal depth, or NONE if
   no constructor is usable.
   NOTE(review): original lines 24, 26 and 32 are missing from this copy --
   presumably a blank line, the `let`, and the first (NONE) case of `max`. *)
25 fun find_nonempty (descr: DatatypeAux.descr) is i =
27 val (_, _, constrs) = valOf (AList.lookup (op =) descr i);
(* depth of one argument, after DatatypeAux.strip_dtyp has been applied *)
28 fun arg_nonempty (_, DatatypeAux.DtRec i) = if i mem is then NONE
29 else Option.map (curry op + 1 o snd) (find_nonempty descr (i::is) i)
30 | arg_nonempty _ = SOME 0;
(* maximum of a list of optional depths; any NONE makes the result NONE *)
31 fun max xs = Library.foldl
33 | (SOME i, SOME j) => SOME (Int.max (i, j))
34 | (_, NONE) => NONE) (SOME 0, xs);
(* usable constructors paired with their depth, sorted by ascending depth
   so the head is a shortest path *)
35 val xs = sort (int_ord o pairself snd)
36 (List.mapPartial (fn (s, dts) => Option.map (pair s)
37 (max (map (arg_nonempty o DatatypeAux.strip_dtyp) dts))) constrs)
38 in case xs of [] => NONE | x :: _ => SOME x end;
(* add_dt_defs thy defs dep module descr sorts gr: makes sure the SML code
   for the datatypes in descr is present in the code graph gr, and adds the
   edge (node_id, dep).  Returns (module', gr').
   - Only entries of descr whose type arguments are all type variables
     (DtTFree) are emitted (descr').
   - If adding the edge succeeds, nothing else is generated; a
     Graph.CYCLES failure leaves gr unchanged.
   - On Graph.UNDEF (presumably: node not yet in the graph) the node is
     created and filled with a `datatype` declaration.  It also gets
     term_of functions when "term_of" is in !mode, and random generators
     when "test" is in !mode.
   NOTE(review): several `val` bindings below use non-exhaustive patterns
   and raise Bind on mismatch (e.g. empty descr', or find_nonempty
   returning NONE).
   NOTE(review): many original lines are missing from this copy (the
   embedded numbering skips e.g. 41, 45, 49, 52, 59, 62, 74-75, 80, 83,
   101-103, 110, 124, 154-155, 160, 177, 184-186, 200, 204-208). *)
40 fun add_dt_defs thy defs dep module (descr: DatatypeAux.descr) sorts gr =
(* keep only the datatypes whose parameters are all type variables *)
42 val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;
(* names of kept datatypes that have a constructor with a recursive argument *)
43 val rtnames = map (#1 o snd) (List.filter (fn (_, (_, _, cs)) =>
44 exists (exists DatatypeAux.is_rec_type o snd) cs) descr');
(* the first kept datatype names the graph node; raises Bind if descr' is empty *)
46 val (_, (tname, _, _)) :: _ = descr';
47 val node_id = tname ^ " (type)";
(* NOTE(review): if_library's semantics are not visible here -- presumably
   it picks the theory's module instead of `module` in library mode. *)
48 val module' = if_library (thyname_of_type thy tname) module;
(* mk_dtdef prfx entries gr: one pretty-printed clause per entry.  The first
   clause is prefixed by prfx, the rest by "and ".  Each clause reads
   `<prfx><tvs> type_id = C1 of T1 * T2 | C2 ...`.  gr is threaded through
   mk_type_id, invoke_tycodegen and mk_const_id.
   NOTE(review): line 59 (presumably pairing each constructor's generated
   argument types with its id) is missing. *)
50 fun mk_dtdef prfx [] gr = ([], gr)
51 | mk_dtdef prfx ((_, (tname, dts, cs))::xs) gr =
53 val tvs = map DatatypeAux.dest_DtTFree dts;
54 val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
55 val ((_, type_id), gr') = mk_type_id module' tname gr;
56 val (ps, gr'') = gr' |>
57 fold_map (fn (cname, cargs) =>
58 fold_map (invoke_tycodegen thy defs node_id module' false)
60 mk_const_id module' cname) cs';
61 val (rest, gr''') = mk_dtdef "and " xs gr''
63 (Pretty.block (str prfx ::
64 (if null tvs then [] else
65 [mk_tuple (map str tvs), str " "]) @
66 [str (type_id ^ " ="), Pretty.brk 1] @
67 List.concat (separate [Pretty.brk 1, str "| "]
68 (map (fn (ps', (_, cname)) => [Pretty.block
70 (if null ps' then [] else
71 List.concat ([str " of", Pretty.brk 1] ::
72 separate [str " *", Pretty.brk 1]
73 (map single ps'))))]) ps))) :: rest, gr''')
(* mk_constr_term cname Ts T ps: pretty-prints the ML expression
   `Const ("cname", <Ts ---> T>) $ p1 $ ... $ pn` for the argument blocks ps. *)
76 fun mk_constr_term cname Ts T ps =
77 List.concat (separate [str " $", Pretty.brk 1]
78 ([str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
79 mk_type false (Ts ---> T), str ")"] :: ps));
(* mk_term_of_def gr prfx entries: for each datatype T, one equation per
   constructor, roughly `<term_of T> <C args> = Const (...) $ <term_of T1> x1 $ ...`.
   Arguments are named x1, x2, ...  Within a datatype, equations after the
   first are prefixed with " | "; later datatypes use prefix "and ".
   NOTE(review): lines 83 and 101-103 (presumably the let, the initial
   foldl_map accumulator, and the combination of eqs with rest) are missing. *)
81 fun mk_term_of_def gr prfx [] = []
82 | mk_term_of_def gr prfx ((_, (tname, dts, cs)) :: xs) =
84 val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
85 val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
86 val T = Type (tname, dts');
87 val rest = mk_term_of_def gr "and " xs;
88 val (_, eqs) = Library.foldl_map (fn (prfx, (cname, Ts)) =>
89 let val args = map (fn i =>
90 str ("x" ^ string_of_int i)) (1 upto length Ts)
91 in (" | ", Pretty.blk (4,
92 [str prfx, mk_term_of gr module' false T, Pretty.brk 1,
93 if null Ts then str (snd (get_const_id gr cname))
94 else parens (Pretty.block
95 [str (snd (get_const_id gr cname)),
96 Pretty.brk 1, mk_tuple args]),
97 str " =", Pretty.brk 1] @
98 mk_constr_term cname Ts T
99 (map (fn (x, U) => [Pretty.block [mk_term_of gr module' false U,
100 Pretty.brk 1, x]]) (args ~~ Ts))))
(* mk_gen_of_def gr prfx entries: random generator definitions
   `gen_<type id>`.
   - Each generator takes the arguments `<s'>G <s'>T` for every type
     variable s, plus size arguments.
   - Constructors are split with List.partition into those with a recursive
     argument and the rest; presumably these are bound as (cs1, cs2) on the
     missing line 110. *)
104 fun mk_gen_of_def gr prfx [] = []
105 | mk_gen_of_def gr prfx ((i, (tname, dts, cs)) :: xs) =
107 val tvs = map DatatypeAux.dest_DtTFree dts;
108 val Us = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
109 val T = Type (tname, Us);
111 List.partition (exists DatatypeAux.is_rec_type o snd) cs;
(* a constructor on a shortest path to a value without further recursion;
   raises Bind if find_nonempty returns NONE *)
112 val SOME (cname, _) = find_nonempty descr [i] i;
(* mk_delay p prints `fn () => p`; mk_force p prints `p ()` *)
114 fun mk_delay p = Pretty.block
115 [str "fn () =>", Pretty.brk 1, p];
117 fun mk_force p = Pretty.block [p, Pretty.brk 1, str "()"];
(* mk_constr s b (cname, dts): generator code for constructor cname with
   size expression s.
   - When b is set, recursive arguments get size "0"; the other branch of
     that `if` is on missing line 124.
   - The visible tail builds the constructor applied to xs and a delayed
     constructor term (via mk_constr_term); presumably they are combined
     into a pair.
   NOTE(review): lines 120, 124, 126, 128, 131-132 and 134-135 of its
   let-binding are missing here. *)
119 fun mk_constr s b (cname, dts) =
121 val gs = map (fn dt => mk_app false (mk_gen gr module' false rtnames s
122 (DatatypeAux.typ_of_dtyp descr sorts dt))
123 [str (if b andalso DatatypeAux.is_rec_type dt then "0"
125 val Ts = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
127 (DatatypeProp.indexify_names (replicate (length dts) "x"));
129 (DatatypeProp.indexify_names (replicate (length dts) "t"));
130 val (_, id) = get_const_id gr cname
133 (map2 (fn p => fn q => mk_tuple [p, q]) xs ts ~~ gs)
136 _ :: _ :: _ => Pretty.block
137 [str id, Pretty.brk 1, mk_tuple xs]
138 | _ => mk_app false (str id) xs,
139 mk_delay (Pretty.block (mk_constr_term cname Ts T
140 (map (single o mk_force) ts)))])
(* mk_choice: a single constructor is generated directly at size "(i-1)";
   several are wrapped as `one_of [fn () => ..., ...] ()` *)
143 fun mk_choice [c] = mk_constr "(i-1)" false c
144 | mk_choice cs = Pretty.block [str "one_of",
145 Pretty.brk 1, Pretty.blk (1, str "[" ::
146 List.concat (separate [str ",", Pretty.fbrk]
147 (map (single o mk_delay o mk_constr "(i-1)" false) cs)) @
148 [str "]"]), Pretty.brk 1, str "()"];
(* generator parameters: `<s'>G` and `<s'>T` per type variable, s' = strip_tname s *)
150 val gs = maps (fn s =>
151 let val s' = strip_tname s
152 in [str (s' ^ "G"), str (s' ^ "T")] end) tvs;
153 val gen_name = "gen_" ^ snd (get_type_id gr tname)
(* main clause.
   - Name: gen_name, primed and taking an extra `i` when cs1 is non-empty.
     Line 160, presumably appending a further argument, is missing.
   - Body when both cs1 and cs2 are non-empty:
     `frequency [(i, fn () => <choice cs1>), (1, fn () => <choice cs2>)] ()`.
   - Body when cs2 is empty:
     `(case i of 0 => <shortest-path constructor at size 0> | _ => <choice cs1>)`.
   - Body otherwise: a choice among cs2. *)
156 Pretty.blk (4, separate (Pretty.brk 1)
157 (str (prfx ^ gen_name ^
158 (if null cs1 then "" else "'")) :: gs @
159 (if null cs1 then [] else [str "i"]) @
161 [str " =", Pretty.brk 1] @
162 (if not (null cs1) andalso not (null cs2)
163 then [str "frequency", Pretty.brk 1,
164 Pretty.blk (1, [str "[",
165 mk_tuple [str "i", mk_delay (mk_choice cs1)],
166 str ",", Pretty.fbrk,
167 mk_tuple [str "1", mk_delay (mk_choice cs2)],
168 str "]"]), Pretty.brk 1, str "()"]
169 else if null cs2 then
170 [Pretty.block [str "(case", Pretty.brk 1,
171 str "i", Pretty.brk 1, str "of",
172 Pretty.brk 1, str "0 =>", Pretty.brk 1,
173 mk_constr "0" true (cname, valOf (AList.lookup (op =) cs cname)),
174 Pretty.brk 1, str "| _ =>", Pretty.brk 1,
175 mk_choice cs1, str ")"]]
176 else [mk_choice cs2])) ::
(* presumably the branch for non-empty cs1 (its condition is on missing
   line 177): the wrapper `and gen_name <gs> i = gen_name' <gs> i i` *)
178 else [Pretty.blk (4, separate (Pretty.brk 1)
179 (str ("and " ^ gen_name) :: gs @ [str "i"]) @
180 [str " =", Pretty.brk 1] @
181 separate (Pretty.brk 1) (str (gen_name ^ "'") :: gs @
182 [str "i", str "i"]))]) @
(* remaining datatypes, joined with "and " *)
183 mk_gen_of_def gr "and " xs
(* Record the dependency edge.  On Graph.UNDEF the node is created with
   empty contents, and map_node then stores the generated code: the
   datatype declaration, then optionally the term_of and generator
   definitions, separated by blank lines. *)
187 (module', (add_edge_acyclic (node_id, dep) gr
188 handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ =>
190 val gr1 = add_edge (node_id, dep)
191 (new_node (node_id, (NONE, "", "")) gr);
192 val (dtdef, gr2) = mk_dtdef "datatype " descr' gr1 ;
194 map_node node_id (K (NONE, module',
195 string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @
196 [str ";"])) ^ "\n\n" ^
197 (if "term_of" mem !mode then
198 string_of (Pretty.blk (0, separate Pretty.fbrk
199 (mk_term_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n"
201 (if "test" mem !mode then
202 string_of (Pretty.blk (0, separate Pretty.fbrk
203 (mk_gen_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n"
209 (**** case expressions ****)
(* pretty_case thy defs dep module brack constrs c ts gr: SML code for the
   case combinator c applied to ts.
   - constrs lists the datatype's constructors with their argument lists.
     Their number i is the number of branch arguments, so ts should hold i
     branches, then the scrutinee, then possibly extra arguments.
   - With at most i arguments (no scrutinee), c is eta-expanded to i+1
     arguments and generated again.
   - Otherwise the output is `(case <scrutinee> of <pat1> => <b1> | ...)`,
     followed by the generated extra arguments.  It is parenthesized when
     brack is set and extra arguments exist.
   - Each pattern is the generated code of the constructor applied to fresh
     Free variables.  Their names avoid those occurring in the branch terms
     ts1.  The branch body is the branch term applied to the same Frees and
     beta-normalised.
   NOTE(review): original lines 215-216, 222, 225, 235, 237-238 and 247-250
   are missing from this copy (presumably let/in/end scaffolding). *)
211 fun pretty_case thy defs dep module brack constrs (c as Const (_, T)) ts gr =
212 let val i = length constrs
213 in if length ts <= i then
214 invoke_codegen thy defs dep module brack (eta_expand c ts (i+1)) gr
217 val ts1 = Library.take (i, ts);
(* non-exhaustive pattern, but length ts > i holds on this branch *)
218 val t :: ts2 = Library.drop (i, ts);
219 val names = List.foldr OldTerm.add_term_names
220 (map (fst o fst o dest_Var) (List.foldr OldTerm.add_term_vars [] ts1)) ts1;
(* Ts: the i branch types; dT: the scrutinee (datatype) type *)
221 val (Ts, dT) = split_last (Library.take (i+1, fst (strip_type T)));
(* pcase: one `pattern => body` block per constructor/branch/branch-type
   triple, threading gr *)
223 fun pcase [] [] [] gr = ([], gr)
224 | pcase ((cname, cargs)::cs) (t::ts) (U::Us) gr =
226 val j = length cargs;
227 val xs = Name.variant_list names (replicate j "x");
228 val Us' = Library.take (j, fst (strip_type U));
229 val frees = map Free (xs ~~ Us');
230 val (cp, gr0) = invoke_codegen thy defs dep module false
231 (list_comb (Const (cname, Us' ---> dT), frees)) gr;
232 val t' = Envir.beta_norm (list_comb (t, frees));
233 val (p, gr1) = invoke_codegen thy defs dep module false t' gr0;
234 val (ps, gr2) = pcase cs ts Us gr1;
236 ([Pretty.block [cp, str " =>", Pretty.brk 1, p]] :: ps, gr2)
239 val (ps1, gr1) = pcase constrs ts1 Ts gr ;
240 val ps = List.concat (separate [Pretty.brk 1, str "| "] ps1);
241 val (p, gr2) = invoke_codegen thy defs dep module false t gr1;
242 val (ps2, gr3) = fold_map (invoke_codegen thy defs dep module true) ts2 gr2;
243 in ((if not (null ts2) andalso brack then parens else I)
244 (Pretty.block (separate (Pretty.brk 1)
245 (Pretty.block ([str "(case ", p, str " of",
246 Pretty.brk 1] @ ps @ [str ")"]) :: ps2))), gr3)
251 (**** constructors ****)
(* pretty_constr thy defs dep module brack args c ts gr: SML code for the
   constructor c applied to ts.  Only the length i of args is used here.
   - A constructor with more than one argument that is under-applied is
     eta-expanded to i arguments and generated again.
   - Otherwise the arguments are generated, with brack set only when i = 1.
     The result is `id (p1, ..., pn)` for two or more generated arguments
     (parenthesized when brack is set), or mk_app for fewer.
   NOTE(review): original lines 257-258, 262 and 266-269 are missing from
   this copy; the scrutinee of the final case (line 262) is presumably ps. *)
253 fun pretty_constr thy defs dep module brack args (c as Const (s, T)) ts gr =
254 let val i = length args
255 in if i > 1 andalso length ts < i then
256 invoke_codegen thy defs dep module brack (eta_expand c ts i) gr
259 val id = mk_qual_id module (get_const_id gr s);
260 val (ps, gr') = fold_map
261 (invoke_codegen thy defs dep module (i = 1)) ts gr;
263 _ :: _ :: _ => (if brack then parens else I)
264 (Pretty.block [str id, Pretty.brk 1, mk_tuple ps])
265 | _ => (mk_app brack (str id) ps), gr')
270 (**** code generators for terms and types ****)
(* Term code generator for datatypes (registered under "datatype" in setup).
   For a term whose head is a constant s of type T:
   - If s is the case combinator of a datatype, delegate to pretty_case with
     that datatype's constructor list.  Returns NONE instead when
     get_assoc_code yields SOME for (s, T).
   - Otherwise, if s is a constructor whose result type name equals the
     datatype it belongs to, delegate to pretty_constr (same
     get_assoc_code check).  The type code generator is first run on the
     result type U, and its resulting graph is passed on.
   NOTE(review): original lines 282, 285 and 289-292 are missing from this
   copy; presumably they contain let/in and the NONE fall-through cases. *)
272 fun datatype_codegen thy defs dep module brack t gr = (case strip_comb t of
273 (c as Const (s, T), ts) =>
274 (case DatatypePackage.datatype_of_case thy s of
275 SOME {index, descr, ...} =>
276 if is_some (get_assoc_code thy (s, T)) then NONE else
277 SOME (pretty_case thy defs dep module brack
278 (#3 (the (AList.lookup op = descr index))) c ts gr )
279 | NONE => case (DatatypePackage.datatype_of_constr thy s, strip_type T) of
280 (SOME {index, descr, ...}, (_, U as Type (tyname, _))) =>
281 if is_some (get_assoc_code thy (s, T)) then NONE else
(* both bindings raise Bind if the lookup fails *)
283 val SOME (tyname', _, constrs) = AList.lookup op = descr index;
284 val SOME args = AList.lookup op = constrs s
286 if tyname <> tyname' then NONE
287 else SOME (pretty_constr thy defs
288 dep module brack args c ts (snd (invoke_tycodegen thy defs dep module false U gr)))
(* Type code generator for datatypes (registered under "datatype" in setup).
   For Type (s, Ts), where s is a datatype with no associated type code
   (get_assoc_type), it:
   - generates the argument types;
   - makes sure the datatype definitions are in the graph (add_dt_defs);
   - emits the generated argument types as a tuple followed by the type id,
     qualified relative to `module` (just the id when Ts is empty).
   Non-Type arguments yield NONE.
   NOTE(review): original lines 295, 298 and 306 are missing from this copy;
   line 295 presumably handles the NONE (not a datatype) case. *)
293 fun datatype_tycodegen thy defs dep module brack (Type (s, Ts)) gr =
294 (case DatatypePackage.get_datatype thy s of
296 | SOME {descr, sorts, ...} =>
297 if is_some (get_assoc_type thy s) then NONE else
299 val (ps, gr') = fold_map
300 (invoke_tycodegen thy defs dep module false) Ts gr;
301 val (module', gr'') = add_dt_defs thy defs dep module descr sorts gr' ;
302 val (tyid, gr''') = mk_type_id module' s gr''
303 in SOME (Pretty.block ((if null Ts then [] else
304 [mk_tuple ps, str " "]) @
305 [str (mk_qual_id module tyid)]), gr''')
307 | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
310 (** generic code generator **)
312 (* case certificates *)
(* mk_case_cert thy tyco: derives one certificate theorem from the
   case_rewrites of datatype tyco.  Visible steps:
   - the rewrites are combined (Conjunction.intr_balanced) and split again,
     then turned into meta equations with zero-based variable indices;
   - the Free variable names of the first one become `params`;
   - a fresh Free named like "case" (avoiding params) is assumed equal to
     rhs;
   - the combined rewrites are rewritten with the symmetric assumption, the
     assumption is discharged, params are generalized, and the result is
     unoverloaded.
   NOTE(review): original lines 315-316, 320, 326-330, 332, 335-336 and
   342-347 are missing from this copy, including the computation of rhs, so
   this description is partial. *)
314 fun mk_case_cert thy tyco =
317 (#case_rewrites o DatatypePackage.the_datatype thy) tyco;
(* non-exhaustive pattern: raises Bind if there are no case rewrites *)
318 val thms as hd_thm :: _ = raw_thms
319 |> Conjunction.intr_balanced
321 |> Conjunction.elim_balanced (length raw_thms)
322 |> map Simpdata.mk_meta_eq
323 |> map Drule.zero_var_indexes
324 val params = fold_aterms (fn (Free (v, _)) => insert (op =) v
325 | _ => I) (Thm.prop_of hd_thm) [];
331 |> apsnd (fst o split_last)
333 val lhs = Free (Name.variant params "case", Term.fastype_of rhs);
334 val asm = (Thm.cterm_of thy o Logic.mk_equals) (lhs, rhs);
337 |> Conjunction.intr_balanced
338 |> MetaSimplifier.rewrite_rule [(Thm.symmetric o Thm.assume) asm]
339 |> Thm.implies_intr asm
340 |> Thm.generalize ([], params) 0
341 |> AxClass.unoverload thy
(* mk_eq_eqns thy dtco: proves equations for eq_class.eq on datatype dtco
   and returns them paired with a flag.
   - Flagged `true`:
     - `eq C C = True` for each nullary constructor;
     - the injectivity statements from DatatypeProp.make_injs, with their
       equality replaced by eq;
     - `eq a b = False` in both orders for each distinctness statement.
   - Flagged `false`: the final reflexivity equation `eq x x = True`.
   Every statement is proved by simp_tac with eq, eq_True, the injection
   theorems and the distinctness simproc.
   NOTE(review): original lines 349, 354, 359, 371 and 373 are missing from
   this copy -- presumably the `let`, the second line of mk_eq, and the
   fallback case of the triv_injects filter. *)
348 fun mk_eq_eqns thy dtco =
350 val (vs, cos) = DatatypePackage.the_datatype_spec thy dtco;
351 val { descr, index, inject = inject_thms, ... } = DatatypePackage.the_datatype thy dtco;
352 val ty = Type (dtco, map TFree vs);
353 fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT)
355 fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const);
356 fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const);
357 val triv_injects = map_filter
358 (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty))))
(* replace the HOL equality on the left of an injectivity statement by eq *)
360 fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) =
361 trueprop $ (equiv $ mk_eq (t1, t2) $ rhs);
362 val injects = map prep_inject (nth (DatatypeProp.make_injs [descr] vs) index);
(* each distinctness fact yields eq ... = False in both argument orders *)
363 fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) =
364 [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)];
365 val distincts = maps prep_distinct (snd (nth (DatatypeProp.make_distincts [descr] vs) index));
366 val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty)));
367 val simpset = Simplifier.context (ProofContext.init thy) (HOL_basic_ss
368 addsimps (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms))
369 addsimprocs [DatatypePackage.distinct_simproc]);
370 fun prove prop = Goal.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
372 in map (rpair true o prove) (triv_injects @ injects @ distincts) @ [(prove refl, false)] end;
(* add_equality vs dtcos thy: instantiates class HOLogic.class_eq for the
   datatypes dtcos with type parameters vs.
   - add_def defines `eq_class.eq x y = (x = y)` for each dtco in the
     instantiation target and exports the definition theorem.
   - The instance is proved with intro_classes followed by fact_tac on the
     definition theorems.
   - The definition theorems are then removed as code equations.
   - For each dtco, the equations of mk_eq_eqns (in reversed order) are
     added lazily as code equations for the eq parameter constant.  The
     lazy computation dereferences a theory reference taken from thy
     (Theory.check_thy) when it is forced.
   NOTE(review): original lines 375, 377, 392, 396, 398-400 and 407-409 are
   missing from this copy (presumably let/in/end scaffolding). *)
374 fun add_equality vs dtcos thy =
376 fun add_def dtco lthy =
378 val ty = Type (dtco, map TFree vs);
379 fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
380 $ Free ("x", ty) $ Free ("y", ty);
381 val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
382 (mk_side @{const_name eq_class.eq}, mk_side @{const_name "op ="}));
383 val def' = Syntax.check_term lthy def;
384 val ((_, (_, thm)), lthy') = Specification.definition
385 (NONE, (Attrib.empty_binding, def')) lthy;
386 val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy);
387 val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
388 in (thm', lthy') end;
389 fun tac thms = Class.intro_classes_tac []
390 THEN ALLGOALS (ProofContext.fact_tac thms);
391 fun add_eq_thms dtco thy =
393 val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco);
394 val thy_ref = Theory.check_thy thy;
395 fun mk_thms () = rev ((mk_eq_eqns (Theory.deref thy_ref) dtco));
397 Code.add_eqnl (const, Lazy.lazy mk_thms) thy
401 |> TheoryTarget.instantiation (dtcos, vs, [HOLogic.class_eq])
402 |> fold_map add_def dtcos
403 |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm)
404 (fn _ => fn def_thms => tac def_thms) def_thms)
405 |-> (fn def_thms => fold Code.del_eqn def_thms)
406 |> fold add_eq_thms dtcos
410 (* liberal addition of code data for datatypes *)
(* mk_constr_consts thy vs dtco cos: the constructor constants of dtco with
   their full types (argument types ---> dtco applied to vs), unoverloaded.
   The result depends on whether Code_Unit.constrset_of_consts accepts them.
   NOTE(review): original lines 413 and 417-420, including the branches of
   the final `if`, are missing from this copy.  From its use in add_all_code
   (is_none, map_filter I) the result appears to be an option. *)
412 fun mk_constr_consts thy vs dtco cos =
414 val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
415 val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs;
416 in if is_some (try (Code_Unit.constrset_of_consts thy) cs')
(* add_all_code dtcos thy: registers code-generator data for the datatypes
   dtcos:
   - the constructor sets (css), via Code.add_datatype;
   - the case rewrites, as default code equations;
   - the case certificates from mk_case_cert;
   - the eq instance and equations, via add_equality.
   css is empty unless mk_constr_consts accepts the constructors of every
   datatype in dtcos.
   NOTE(review): the type variables of the first datatype are used for all
   of dtcos, and the `vs :: _` pattern raises Bind on an empty dtcos.
   NOTE(review): original lines 422, 429-431 and 436-440 are missing from
   this copy; they presumably contain let/in and possibly a guard on css --
   confirm upstream. *)
421 fun add_all_code dtcos thy =
423 val (vs :: _, coss) = (split_list o map (DatatypePackage.the_datatype_spec thy)) dtcos;
424 val any_css = map2 (mk_constr_consts thy vs) dtcos coss;
425 val css = if exists is_none any_css then []
426 else map_filter I any_css;
427 val case_rewrites = maps (#case_rewrites o DatatypePackage.the_datatype thy) dtcos;
428 val certs = map (mk_case_cert thy) dtcos;
432 |> fold Code.add_datatype css
433 |> fold_rev Code.add_default_eqn case_rewrites
434 |> fold Code.add_case certs
435 |> add_equality vs dtcos
(* Theory setup: registers datatype_codegen and datatype_tycodegen under the
   name "datatype", and add_all_code as a DatatypePackage interpretation.
   NOTE(review): original lines 441-442 (presumably the `val setup =`
   header) are missing from this copy. *)
443 add_codegen "datatype" datatype_codegen
444 #> add_tycodegen "datatype" datatype_tycodegen
445 #> DatatypePackage.interpretation add_all_code