1 (* Title: HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
2 Author: Jasmin Blanchette, TU Muenchen
5 Sugared flattening of nested to mutual (co)recursion.
8 signature BNF_FP_N2M_SUGAR =
10 val unfold_let: term -> term
11 val dest_map: Proof.context -> string -> term -> term * term list
13 val mutualize_fp_sugars: bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
14 (term -> int list) -> term list list list list -> BNF_FP_Def_Sugar.fp_sugar list ->
16 (BNF_FP_Def_Sugar.fp_sugar list
17 * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
19 val indexify_callsss: BNF_FP_Def_Sugar.fp_sugar -> (term * term list list) list ->
21 val nested_to_mutual_fps: BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
22 (term * term list list) list list -> local_theory ->
23 (typ list * int list * BNF_FP_Def_Sugar.fp_sugar list
24 * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
28 structure BNF_FP_N2M_Sugar : BNF_FP_N2M_SUGAR =
40 type n2m_sugar = fp_sugar list * (lfp_sugar_thms option * gfp_sugar_thms option);
42 structure Data = Generic_Data
44 type T = n2m_sugar Typtab.table;
45 val empty = Typtab.empty;
47 val merge = Typtab.merge (eq_fst (eq_list eq_fp_sugar));
(* Applies the morphism phi to every component of an n2m_sugar: each fp_sugar in the
   list, and the optional least/greatest-fixpoint theorem bundles (a NONE bundle stays NONE). *)
50 fun morph_n2m_sugar phi (fp_sugars, (lfp_sugar_thms_opt, gfp_sugar_thms_opt)) =
51 (map (morph_fp_sugar phi) fp_sugars,
52 (Option.map (morph_lfp_sugar_thms phi) lfp_sugar_thms_opt,
53 Option.map (morph_gfp_sugar_thms phi) gfp_sugar_thms_opt));
(* transfer_n2m_sugar ctxt n2m_sugar: moves the theorems stored in n2m_sugar to the theory
   of ctxt by mapping Thm.transfer over them (via a theorem-only morphism). *)
55 val transfer_n2m_sugar =
56 morph_n2m_sugar o Morphism.thm_morphism o Thm.transfer o Proof_Context.theory_of;
(* Looks up the n2m_sugar registered under a type key in the context's Data table.
   Returns NONE when no entry exists; otherwise the entry, with its theorems transferred
   to the theory of ctxt. *)
58 fun n2m_sugar_of ctxt =
59 Typtab.lookup (Data.get (Context.Proof ctxt))
60 #> Option.map (transfer_n2m_sugar ctxt);
(* Records n2m_sugar under key through a local-theory declaration
   (syntax = false, pervasive = false). The stored value is n2m_sugar transformed by the
   declaration's morphism phi. Typtab.default is used, so an entry already present for
   key is kept rather than overwritten. *)
62 fun register_n2m_sugar key n2m_sugar =
63 Local_Theory.declaration {syntax = false, pervasive = false}
64 (fn phi => Data.map (Typtab.default (key, morph_n2m_sugar phi n2m_sugar)));
66 fun unfold_let (Const (@{const_name Let}, _) $ arg1 $ arg2) = unfold_let (betapply (arg2, arg1))
67 | unfold_let (Const (@{const_name prod_case}, _) $ t) =
69 t' as Abs (s1, T1, Abs (s2, T2, _)) =>
71 val x = (s1 ^ s2, Term.maxidx_of_term t + 1);
72 val v = Var (x, HOLogic.mk_prodT (T1, T2));
74 lambda v (unfold_let (betapplys (t', [HOLogic.mk_fst v, HOLogic.mk_snd v])))
77 | unfold_let (t $ u) = betapply (unfold_let t, unfold_let u)
78 | unfold_let (Abs (s, T, t)) = Abs (s, T, unfold_let t)
81 fun mk_map_pattern ctxt s =
83 val bnf = the (bnf_of ctxt s);
84 val mapx = map_of_bnf bnf;
85 val live = live_of_bnf bnf;
86 val (f_Ts, _) = strip_typeN live (fastype_of mapx);
87 val fs = map_index (fn (i, T) => Var (("?f", i), T)) f_Ts;
89 (mapx, betapplys (mapx, fs))
92 fun dest_map ctxt s call =
94 val (map0, pat) = mk_map_pattern ctxt s;
95 val (_, tenv) = fo_match ctxt call pat;
97 (map0, Vartab.fold_rev (fn (_, (_, f)) => cons f) tenv [])
(* Destructs a term that is either a lambda abstraction or an application:
   - Abs (_, _, t)  ~> (Term.dummy, [t]), i.e. no map constant, just the body;
   - t1 $ _         ~> dest_map ctxt s t1, matching the function part against the map
                       of the BNF named s.
   Any other term raises Match; the visible caller (freeze_fpTs) wraps it in try. *)
100 fun dest_abs_or_applied_map _ _ (Abs (_, _, t)) = (Term.dummy, [t])
101 | dest_abs_or_applied_map ctxt s (t1 $ _) = dest_map ctxt s t1;
103 fun map_partition f xs =
104 fold_rev (fn x => fn (ys, (good, bad)) =>
105 case f x of SOME y => (y :: ys, (x :: good, bad)) | NONE => (ys, (good, x :: bad)))
(* Builds the table key under which a mutualized n2m_sugar is cached (see n2m_sugar_of /
   register_n2m_sugar). The key is a pseudo-type:
   - its name is fp_case fp "l" "g" (presumably "l" for least, "g" for greatest fixpoints);
   - its arguments are fpTs, followed by each fixpoint equation's type variable (TFree x)
     and its right-hand side T. *)
108 fun key_of_fp_eqs fp fpTs fp_eqs =
109 Type (fp_case fp "l" "g", fpTs @ maps (fn (x, T) => [TFree x, T]) fp_eqs);
111 (* TODO: test with sort constraints on As *)
112 (* TODO: use right sorting order for "fp_sort" w.r.t. original BNFs (?) -- treat new variables
114 fun mutualize_fp_sugars has_nested fp bs fpTs get_indices callssss fp_sugars0 no_defs_lthy0 =
117 val thy = Proof_Context.theory_of no_defs_lthy0;
119 val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
121 fun incompatible_calls t1 t2 =
122 error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ qsotm t1 ^ " vs. " ^
125 val b_names = map Binding.name_of bs;
126 val fp_b_names = map base_name_of_typ fpTs;
128 val nn = length fpTs;
130 fun target_ctr_sugar_of_fp_sugar fpT ({T, index, ctr_sugars, ...} : fp_sugar) =
132 val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T, fpT) Vartab.empty) [];
133 val phi = Morphism.term_morphism (Term.subst_TVars rho);
135 morph_ctr_sugar phi (nth ctr_sugars index)
138 val ctr_defss = map (of_fp_sugar #ctr_defss) fp_sugars0;
139 val mapss = map (of_fp_sugar #mapss) fp_sugars0;
140 val ctr_sugars0 = map2 target_ctr_sugar_of_fp_sugar fpTs fp_sugars0;
142 val ctrss = map #ctrs ctr_sugars0;
143 val ctr_Tss = map (map fastype_of) ctrss;
145 val As' = fold (fold Term.add_tfreesT) ctr_Tss [];
146 val As = map TFree As';
148 val ((Cs, Xs), no_defs_lthy) =
150 |> fold Variable.declare_typ As
152 ||>> variant_tfrees fp_b_names;
154 fun check_call_dead live_call call =
155 if null (get_indices call) then () else incompatible_calls live_call call;
157 fun freeze_fpTs_simple (T as Type (s, Ts)) =
158 (case find_index (curry (op =) T) fpTs of
159 ~1 => Type (s, map freeze_fpTs_simple Ts)
161 | freeze_fpTs_simple T = T;
163 fun freeze_fpTs_map (callss, (live_call :: _, dead_calls)) s Ts =
164 (List.app (check_call_dead live_call) dead_calls;
165 Type (s, map2 freeze_fpTs (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
166 (transpose callss)) Ts))
167 and freeze_fpTs calls (T as Type (s, Ts)) =
168 (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
170 (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
171 ([], _) => freeze_fpTs_simple T
172 | callsp => freeze_fpTs_map callsp s Ts)
173 | callsp => freeze_fpTs_map callsp s Ts)
174 | freeze_fpTs _ T = T;
176 val ctr_Tsss = map (map binder_types) ctr_Tss;
177 val ctrXs_Tsss = map2 (map2 (map2 freeze_fpTs)) callssss ctr_Tsss;
178 val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
179 val Ts = map (body_type o hd) ctr_Tss;
181 val ns = map length ctr_Tsss;
182 val kss = map (fn n => 1 upto n) ns;
183 val mss = map (map length) ctr_Tsss;
185 val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;
186 val key = key_of_fp_eqs fp fpTs fp_eqs;
188 (case n2m_sugar_of no_defs_lthy key of
189 SOME n2m_sugar => (n2m_sugar, no_defs_lthy)
192 val base_fp_names = Name.variant_list [] fp_b_names;
193 val fp_bs = map2 (fn b_name => fn base_fp_name =>
194 Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
195 b_names base_fp_names;
197 val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct,
198 dtor_injects, dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
199 fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
201 val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
202 val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
204 val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
205 mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
207 fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
209 val ((co_iterss, co_iter_defss), lthy) =
211 (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
212 else define_coiters [unfoldN, corecN] (the coiters_args_types))
213 (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
216 val rho = tvar_subst thy Ts fpTs;
218 Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
219 (Morphism.term_morphism (Term.subst_TVars rho));
220 val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
222 val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
224 val ((co_inducts, un_fold_thmss, co_rec_thmss, disc_unfold_thmss, disc_corec_thmss,
225 sel_unfold_thmsss, sel_corec_thmsss), fp_sugar_thms) =
226 if fp = Least_FP then
227 derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
228 xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
229 co_iterss co_iter_defss lthy
230 |> `(fn ((_, induct, _), (fold_thmss, rec_thmss, _)) =>
231 ([induct], fold_thmss, rec_thmss, [], [], [], []))
232 ||> (fn info => (SOME info, NONE))
234 derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types)
235 xtor_co_induct dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs Xs
236 ctrXs_Tsss kss mss ns ctr_defss ctr_sugars co_iterss co_iter_defss
237 (Proof_Context.export lthy no_defs_lthy) lthy
238 |> `(fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _),
239 (disc_unfold_thmss, disc_corec_thmss, _), _,
240 (sel_unfold_thmsss, sel_corec_thmsss, _)) =>
241 (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss, disc_unfold_thmss,
242 disc_corec_thmss, sel_unfold_thmsss, sel_corec_thmsss))
243 ||> (fn info => (NONE, SOME info));
245 val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
247 fun mk_target_fp_sugar (kk, T) =
248 {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
249 nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
250 ctr_sugars = ctr_sugars, co_iterss = co_iterss, mapss = mapss, co_inducts = co_inducts,
251 co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss],
252 disc_co_itersss = transpose [disc_unfold_thmss, disc_corec_thmss],
253 sel_co_iterssss = transpose [sel_unfold_thmsss, sel_corec_thmsss]}
254 |> morph_fp_sugar phi;
256 val n2m_sugar = (map_index mk_target_fp_sugar fpTs, fp_sugar_thms);
258 (n2m_sugar, lthy |> register_n2m_sugar key n2m_sugar)
262 ((fp_sugars0, (NONE, NONE)), no_defs_lthy0);
264 fun indexify_callsss fp_sugar callsss =
266 val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
268 (case AList.lookup Term.aconv_untyped callsss ctr of
269 NONE => replicate (num_binder_types (fastype_of ctr)) []
270 | SOME callss => map (map (Envir.beta_eta_contract o unfold_let)) callss);
275 fun nested_to_mutual_fps fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
277 val qsoty = quote o Syntax.string_of_typ lthy;
278 val qsotys = space_implode " or " o map qsoty;
280 fun duplicate_datatype T = error (qsoty T ^ " is not mutually recursive with itself");
281 fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
282 fun not_co_datatype (T as Type (s, _)) =
283 if fp = Least_FP andalso
284 is_some (Datatype_Data.get_info (Proof_Context.theory_of lthy) s) then
285 error (qsoty T ^ " is not a new-style datatype (cf. \"datatype_new\")")
288 | not_co_datatype T = not_co_datatype0 T;
289 fun not_mutually_nested_rec Ts1 Ts2 =
290 error (qsotys Ts1 ^ " is neither mutually recursive with nor nested recursive via " ^
293 val _ = (case Library.duplicates (op =) actual_Ts of [] => () | T :: _ => duplicate_datatype T);
295 val perm_actual_Ts as Type (_, tyargs0) :: _ =
296 sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts;
298 fun check_enrich_with_mutuals _ [] = []
299 | check_enrich_with_mutuals seen ((T as Type (T_name, tyargs)) :: Ts) =
300 (case fp_sugar_of lthy T_name of
301 SOME ({fp = fp', fp_res = {Ts = Ts', ...}, ...}) =>
304 val mutual_Ts = map (fn Type (s, _) => Type (s, tyargs)) Ts';
306 seen = [] orelse exists (exists_subtype_in seen) mutual_Ts orelse
307 not_mutually_nested_rec mutual_Ts seen;
308 val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts;
310 mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts'
314 | NONE => not_co_datatype T)
315 | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T;
317 val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts;
318 val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
319 val Ts = actual_Ts @ missing_Ts;
322 val kks = 0 upto nn - 1;
324 val callssss0 = pad_list [] nn actual_callssss0;
326 val common_name = mk_common_name (map Binding.name_of actual_bs);
327 val bs = pad_list (Binding.name common_name) nn actual_bs;
329 fun permute xs = permute_like (op =) Ts perm_Ts xs;
330 fun unpermute perm_xs = permute_like (op =) perm_Ts Ts perm_xs;
332 val perm_bs = permute bs;
333 val perm_kks = permute kks;
334 val perm_callssss0 = permute callssss0;
335 val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;
337 val has_nested = exists (fn Type (_, tyargs) => tyargs <> tyargs0) Ts;
338 val perm_callssss = map2 indexify_callsss perm_fp_sugars0 perm_callssss0;
340 val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
342 val ((perm_fp_sugars, fp_sugar_thms), lthy) =
343 mutualize_fp_sugars has_nested fp perm_bs perm_Ts get_perm_indices perm_callssss
344 perm_fp_sugars0 lthy;
346 val fp_sugars = unpermute perm_fp_sugars;
348 ((missing_Ts, perm_kks, fp_sugars, fp_sugar_thms), lthy)