Renamed "primcorec" to "primcorecursive" to free up the name "primcorec" for a 'theory -> theory' command (cf. "fun" vs. "function")
1 (* Title: HOL/BNF/Tools/bnf_fp_n2m.ML
2 Author: Dmitriy Traytel, TU Muenchen
5 Flattening of nested to mutual (co)recursion.
10 val construct_mutualized_fp: BNF_FP_Util.fp_kind -> typ list -> BNF_FP_Def_Sugar.fp_sugar list ->
11 binding list -> (string * sort) list -> typ list * typ list list -> BNF_Def.bnf list ->
12 local_theory -> BNF_FP_Util.fp_result * local_theory
15 structure BNF_FP_N2M : BNF_FP_N2M =
23 open BNF_FP_N2M_Tactics
25 fun force_typ ctxt T =
26 map_types Type_Infer.paramify_vars
28 #> Syntax.check_term ctxt
29 #> singleton (Variable.polymorphic ctxt);
33 val ((fAT, fBT), fT) = `dest_funT (fastype_of f);
34 val ((gAT, gBT), gT) = `dest_funT (fastype_of g);
36 Const (@{const_name map_pair},
37 fT --> gT --> HOLogic.mk_prodT (fAT, gAT) --> HOLogic.mk_prodT (fBT, gBT)) $ f $ g
42 val ((fAT, fBT), fT) = `dest_funT (fastype_of f);
43 val ((gAT, gBT), gT) = `dest_funT (fastype_of g);
45 Const (@{const_name sum_map}, fT --> gT --> mk_sumT (fAT, gAT) --> mk_sumT (fBT, gBT)) $ f $ g
48 fun construct_mutualized_fp fp fpTs fp_sugars bs resBs (resDs, Dss) bnfs lthy =
50 fun steal get = map (of_fp_sugar (get o #fp_res)) fp_sugars;
53 val deads = fold (union (op =)) Dss resDs;
54 val As = subtract (op =) deads (map TFree resBs);
55 val names_lthy = fold Variable.declare_typ (As @ deads) lthy;
58 val ((Xs, Bs), names_lthy) = names_lthy
62 val phiTs = map2 mk_pred2T As Bs;
64 val fpTs' = map (Term.typ_subst_atomic theta) fpTs;
65 val pre_phiTs = map2 mk_pred2T fpTs fpTs';
67 fun mk_co_algT T U = fp_case fp (T --> U) (U --> T);
68 fun co_swap pair = fp_case fp I swap pair;
69 val dest_co_algT = co_swap o dest_funT;
70 val co_alg_argT = fp_case fp range_type domain_type;
71 val co_alg_funT = fp_case fp domain_type range_type;
72 val mk_co_product = curry (fp_case fp mk_convol mk_sum_case);
73 val mk_map_co_product = fp_case fp mk_prod_map mk_sum_map;
74 val co_proj1_const = fp_case fp (fst_const o fst) (uncurry Inl_const o dest_sumT o snd);
75 val mk_co_productT = curry (fp_case fp HOLogic.mk_prodT mk_sumT);
76 val dest_co_productT = fp_case fp HOLogic.dest_prodT dest_sumT;
78 val ((ctors, dtors), (xtor's, xtors)) =
80 val ctors = map2 (force_typ names_lthy o (fn T => dummyT --> T)) fpTs (steal #ctors);
81 val dtors = map2 (force_typ names_lthy o (fn T => T --> dummyT)) fpTs (steal #dtors);
83 ((ctors, dtors), `(map (Term.subst_atomic_types theta)) (fp_case fp ctors dtors))
86 val xTs = map (domain_type o fastype_of) xtors;
87 val yTs = map (domain_type o fastype_of) xtor's;
89 val (((((phis, phis'), pre_phis), xs), ys), names_lthy) = names_lthy
90 |> mk_Frees' "R" phiTs
91 ||>> mk_Frees "S" pre_phiTs
93 ||>> mk_Frees "y" yTs;
95 val fp_bnfs = steal #bnfs;
96 val pre_bnfs = map (of_fp_sugar #pre_bnfs) fp_sugars;
97 val pre_bnfss = map #pre_bnfs fp_sugars;
98 val nesty_bnfss = map (fn sugar => #nested_bnfs sugar @ #nesting_bnfs sugar) fp_sugars;
99 val fp_nesty_bnfss = fp_bnfs :: nesty_bnfss;
100 val fp_nesty_bnfs = distinct eq_bnf (flat fp_nesty_bnfss);
103 let val Ts = Term.add_frees t [];
104 in fold_rev Term.absfree (filter (member op = Ts) phis') t end;
108 fun find_rel T As Bs = fp_nesty_bnfss
109 |> map (filter_out (curry eq_bnf BNF_Comp.DEADID_bnf))
110 |> get_first (find_first (fn bnf => Type.could_unify (T_of_bnf bnf, T)))
111 |> Option.map (fn bnf =>
112 let val live = live_of_bnf bnf;
113 in (mk_rel live As Bs (rel_of_bnf bnf), live) end)
114 |> the_default (HOLogic.eq_const T, 0);
116 fun mk_rel (T as Type (_, Ts)) (Type (_, Us)) =
118 val (rel, live) = find_rel T Ts Us;
119 val (Ts', Us') = fastype_of rel |> strip_typeN live |> fst |> map_split dest_pred2T;
120 val rels = map2 mk_rel Ts' Us';
122 Term.list_comb (rel, rels)
124 | mk_rel (T as TFree _) _ = nth phis (find_index (curry op = T) As)
125 | mk_rel _ _ = raise Fail "fpTs contains schematic type variables";
127 map2 (abstract oo mk_rel) fpTs fpTs'
130 val pre_rels = map2 (fn Ds => mk_rel_of_bnf Ds (As @ fpTs) (Bs @ fpTs')) Dss bnfs;
132 val rel_unfoldss = map (maps (fn bnf => no_refl [rel_def_of_bnf bnf])) pre_bnfss;
133 val rel_xtor_co_inducts = steal (split_conj_thm o #rel_xtor_co_induct_thm)
134 |> map2 (fn unfs => unfold_thms lthy (id_apply :: unfs)) rel_unfoldss;
136 val rel_defs = map rel_def_of_bnf bnfs;
137 val rel_monos = map rel_mono_of_bnf bnfs;
139 val rel_xtor_co_induct_thm =
140 mk_rel_xtor_co_induct_thm fp pre_rels pre_phis rels phis xs ys xtors xtor's
141 (mk_rel_xtor_co_induct_tactic fp rel_xtor_co_inducts rel_defs rel_monos) lthy;
143 val rel_eqs = no_refl (map rel_eq_of_bnf fp_nesty_bnfs);
144 val map_id0s = no_refl (map map_id0_of_bnf bnfs);
146 val xtor_co_induct_thm =
150 val (Ps, names_lthy) = names_lthy
151 |> mk_Frees "P" (map (fn T => T --> HOLogic.boolT) fpTs);
153 let val T = domain_type (fastype_of P);
154 in mk_Grp (HOLogic.Collect_const T $ P) (HOLogic.id_const T) end;
155 val cts = map (SOME o certify lthy) (map HOLogic.eq_const As @ map mk_Grp_id Ps);
157 cterm_instantiate_pos cts rel_xtor_co_induct_thm
158 |> singleton (Proof_Context.export names_lthy lthy)
159 |> unfold_thms lthy (@{thms eq_le_Grp_id_iff all_simps(1,2)[symmetric]} @ rel_eqs)
160 |> funpow n (fn thm => thm RS spec)
161 |> unfold_thms lthy (@{thm eq_alt} :: map rel_Grp_of_bnf bnfs @ map_id0s)
162 |> unfold_thms lthy @{thms Grp_id_mono_subst eqTrueI[OF subset_UNIV] simp_thms(22)}
163 |> unfold_thms lthy @{thms subset_iff mem_Collect_eq
164 atomize_conjL[symmetric] atomize_all[symmetric] atomize_imp[symmetric]}
165 |> unfold_thms lthy (maps set_defs_of_bnf bnfs)
169 val cts = NONE :: map (SOME o certify lthy) (map HOLogic.eq_const As);
171 cterm_instantiate_pos cts rel_xtor_co_induct_thm
172 |> unfold_thms lthy (@{thms le_fun_def le_bool_def all_simps(1,2)[symmetric]} @ rel_eqs)
173 |> funpow (2 * n) (fn thm => thm RS spec)
174 |> Conv.fconv_rule Object_Logic.atomize
175 |> funpow n (fn thm => thm RS mp)
178 val fold_preTs = map2 (fn Ds => mk_T_of_bnf Ds allAs) Dss bnfs;
179 val fold_pre_deads_only_Ts = map2 (fn Ds => mk_T_of_bnf Ds (replicate live dummyT)) Dss bnfs;
180 val rec_theta = Xs ~~ map2 mk_co_productT fpTs Xs;
181 val rec_preTs = map (Term.typ_subst_atomic rec_theta) fold_preTs;
183 val fold_strTs = map2 mk_co_algT fold_preTs Xs;
184 val rec_strTs = map2 mk_co_algT rec_preTs Xs;
185 val resTs = map2 mk_co_algT fpTs Xs;
187 val (((fold_strs, fold_strs'), (rec_strs, rec_strs')), names_lthy) = names_lthy
188 |> mk_Frees' "s" fold_strTs
189 ||>> mk_Frees' "s" rec_strTs;
191 val co_iters = steal #xtor_co_iterss;
192 val ns = map (length o #pre_bnfs) fp_sugars;
193 fun substT rho (Type (@{type_name "fun"}, [T, U])) = substT rho T --> substT rho U
194 | substT rho (Type (s, Ts)) = Type (s, map (typ_subst_nonatomic rho) Ts)
196 fun force_iter is_rec i TU TU_rec raw_iters =
198 val approx_fold = un_fold_of raw_iters
199 |> force_typ names_lthy
200 (replicate (nth ns i) dummyT ---> (if is_rec then TU_rec else TU));
201 val TUs = binder_fun_types (Term.typ_subst_atomic (Xs ~~ fpTs) (fastype_of approx_fold));
202 val js = find_indices Type.could_unify
203 TUs (map (Term.typ_subst_atomic (Xs ~~ fpTs)) fold_strTs);
204 val Tpats = map (fn j => mk_co_algT (nth fold_pre_deads_only_Ts j) (nth Xs j)) js;
205 val iter = raw_iters |> (if is_rec then co_rec_of else un_fold_of);
207 force_typ names_lthy (Tpats ---> TU) iter
210 fun mk_iter b_opt is_rec iters lthy TU =
212 val x = co_alg_argT TU;
213 val i = find_index (fn T => x = T) Xs;
215 (case find_first (fn f => body_fun_type (fastype_of f) = TU) iters of
216 NONE => nth co_iters i
217 |> force_iter is_rec i
218 (TU |> (is_none b_opt andalso not is_rec) ? substT (fpTs ~~ Xs))
219 (TU |> (is_none b_opt) ? substT (map2 mk_co_productT fpTs Xs ~~ Xs))
221 val TUs = binder_fun_types (fastype_of TUiter);
222 val iter_preTs = if is_rec then rec_preTs else fold_preTs;
223 val iter_strs = if is_rec then rec_strs else fold_strs;
226 val i = find_index (fn T => co_alg_argT TU' = T) Xs;
227 val sF = co_alg_funT TU';
228 val F = nth iter_preTs i;
229 val s = nth iter_strs i;
234 val smapT = replicate live dummyT ---> mk_co_algT sF F;
235 fun hidden_to_unit t =
236 Term.subst_TVars (map (rpair HOLogic.unitT) (Term.add_tvar_names t [])) t;
237 val smap = map_of_bnf (nth bnfs i)
238 |> force_typ names_lthy smapT
240 val smap_argTs = strip_typeN live (fastype_of smap) |> fst;
242 (if domain_type TU = range_type TU then
243 HOLogic.id_const (domain_type TU)
246 val (TY, (U, X)) = TU |> dest_co_algT ||> dest_co_productT;
247 val T = mk_co_algT TY U;
249 (case try (force_typ lthy T o build_map lthy co_proj1_const o dest_funT) T of
250 SOME f => mk_co_product f
251 (fst (fst (mk_iter NONE is_rec iters lthy (mk_co_algT TY X))))
252 | NONE => mk_map_co_product
253 (build_map lthy co_proj1_const
254 (dest_funT (mk_co_algT (dest_co_productT TY |> fst) U)))
255 (HOLogic.id_const X))
258 fst (fst (mk_iter NONE is_rec iters lthy TU)))
259 val smap_args = map mk_smap_arg smap_argTs;
261 HOLogic.mk_comp (co_swap (s, Term.list_comb (smap, smap_args)))
264 val t = Term.list_comb (TUiter, map mk_s TUs);
267 NONE => ((t, Drule.dummy_thm), lthy)
268 | SOME b => Local_Theory.define ((b, NoSyn), ((Thm.def_binding b, []),
269 fold_rev Term.absfree (if is_rec then rec_strs' else fold_strs') t)) lthy |>> apsnd snd)
272 fun mk_iters is_rec name lthy =
273 fold2 (fn TU => fn b => fn ((iters, defs), lthy) =>
274 mk_iter (SOME b) is_rec iters lthy TU |>> (fn (f, d) => (f :: iters, d :: defs)))
275 resTs (map (Binding.suffix_name ("_" ^ name)) bs) (([], []), lthy)
276 |>> apfst rev o apsnd rev;
277 val foldN = fp_case fp ctor_foldN dtor_unfoldN;
278 val recN = fp_case fp ctor_recN dtor_corecN;
279 val (((raw_un_folds, raw_un_fold_defs), (raw_co_recs, raw_co_rec_defs)), (lthy, raw_lthy)) =
281 |> mk_iters false foldN
282 ||>> mk_iters true recN
283 ||> `Local_Theory.restore;
285 val phi = Proof_Context.export_morphism raw_lthy lthy;
287 val un_folds = map (Morphism.term phi) raw_un_folds;
288 val co_recs = map (Morphism.term phi) raw_co_recs;
290 val (xtor_un_fold_thms, xtor_co_rec_thms) =
292 val folds = map (fn f => Term.list_comb (f, fold_strs)) raw_un_folds;
293 val recs = map (fn r => Term.list_comb (r, rec_strs)) raw_co_recs;
294 val fold_mapTs = co_swap (As @ fpTs, As @ Xs);
295 val rec_mapTs = co_swap (As @ fpTs, As @ map2 mk_co_productT fpTs Xs);
297 map2 (fn Ds => fn bnf =>
298 Term.list_comb (uncurry (mk_map_of_bnf Ds) fold_mapTs bnf,
299 map HOLogic.id_const As @ folds))
302 map2 (fn Ds => fn bnf =>
303 Term.list_comb (uncurry (mk_map_of_bnf Ds) rec_mapTs bnf,
304 map HOLogic.id_const As @ map2 (mk_co_product o HOLogic.id_const) fpTs recs))
307 fun mk_goals f xtor s smap =
308 ((f, xtor), (s, smap))
309 |> pairself (HOLogic.mk_comp o co_swap)
312 val fold_goals = map4 mk_goals folds xtors fold_strs pre_fold_maps
313 val rec_goals = map4 mk_goals recs xtors rec_strs pre_rec_maps;
315 fun mk_thms ss goals tac =
316 Library.foldr1 HOLogic.mk_conj goals
317 |> HOLogic.mk_Trueprop
318 |> fold_rev Logic.all ss
319 |> (fn goal => Goal.prove_sorry raw_lthy [] [] goal tac)
320 |> Thm.close_derivation
323 |> map (fn thm => thm RS @{thm comp_eq_dest});
325 val pre_map_defs = no_refl (map map_def_of_bnf bnfs);
326 val fp_pre_map_defs = no_refl (map map_def_of_bnf pre_bnfs);
328 val map_unfoldss = map (maps (fn bnf => no_refl [map_def_of_bnf bnf])) pre_bnfss;
329 val unfold_map = map2 (fn unfs => unfold_thms lthy (id_apply :: unfs)) map_unfoldss;
331 val fp_xtor_co_iterss = steal #xtor_co_iter_thmss;
332 val fp_xtor_un_folds = map (mk_pointfree lthy o un_fold_of) fp_xtor_co_iterss |> unfold_map;
333 val fp_xtor_co_recs = map (mk_pointfree lthy o co_rec_of) fp_xtor_co_iterss |> unfold_map;
335 val fp_co_iter_o_mapss = steal #xtor_co_iter_o_map_thmss;
336 val fp_fold_o_maps = map un_fold_of fp_co_iter_o_mapss |> unfold_map;
337 val fp_rec_o_maps = map co_rec_of fp_co_iter_o_mapss |> unfold_map;
338 val fold_thms = fp_case fp @{thm o_assoc[symmetric]} @{thm o_assoc} ::
339 @{thms id_apply o_apply o_id id_o map_pair.comp map_pair.id sum_map.comp sum_map.id};
340 val rec_thms = fold_thms @ fp_case fp
341 @{thms fst_convol map_pair_o_convol convol_o}
342 @{thms sum_case_o_inj(1) sum_case_o_sum_map o_sum_case};
343 val map_thms = no_refl (maps (fn bnf =>
344 [map_comp0_of_bnf bnf RS sym, map_id0_of_bnf bnf]) fp_nesty_bnfs);
346 fun mk_tac defs o_map_thms xtor_thms thms {context = ctxt, prems = _} =
348 (flat [thms, defs, pre_map_defs, fp_pre_map_defs, xtor_thms, o_map_thms, map_thms]) THEN
349 CONJ_WRAP (K (HEADGOAL (rtac refl))) bnfs;
351 val fold_tac = mk_tac raw_un_fold_defs fp_fold_o_maps fp_xtor_un_folds fold_thms;
352 val rec_tac = mk_tac raw_co_rec_defs fp_rec_o_maps fp_xtor_co_recs rec_thms;
354 (mk_thms fold_strs fold_goals fold_tac, mk_thms rec_strs rec_goals rec_tac)
357 (* These results are half broken. This is deliberate. We care only about those fields that are
358 used by "primrec_new", "primcorecursive", and "datatype_new_compat". *)
364 xtor_co_iterss = transpose [un_folds, co_recs],
365 xtor_co_induct = xtor_co_induct_thm,
366 dtor_ctors = steal #dtor_ctors (*too general types*),
367 ctor_dtors = steal #ctor_dtors (*too general types*),
368 ctor_injects = steal #ctor_injects (*too general types*),
369 dtor_injects = steal #dtor_injects (*too general types*),
370 xtor_map_thms = steal #xtor_map_thms (*too general types and terms*),
371 xtor_set_thmss = steal #xtor_set_thmss (*too general types and terms*),
372 xtor_rel_thms = steal #xtor_rel_thms (*too general types and terms*),
373 xtor_co_iter_thmss = transpose [xtor_un_fold_thms, xtor_co_rec_thms],
374 xtor_co_iter_o_map_thmss = steal #xtor_co_iter_o_map_thmss (*theorem about old constant*),
375 rel_xtor_co_induct_thm = rel_xtor_co_induct_thm}
376 |> morph_fp_result (Morphism.term_morphism (singleton (Variable.polymorphic lthy))));