1 (* Title: HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
2 Author: Lorenz Panny, TU Muenchen
3 Author: Jasmin Blanchette, TU Muenchen
6 Library for recursor and corecursor sugar.
9 signature BNF_FP_REC_SUGAR_UTIL =
13 Direct_Rec of int (*before*) * int (*after*) |
17 Dummy_No_Corec of int |
19 Direct_Corec of int (*stop?*) * int (*end*) * int (*continue*) |
33 calls: corec_call list,
43 nested_map_idents: thm list,
44 nested_map_comps: thm list,
45 ctr_specs: rec_ctr_spec list}
49 nested_maps: thm list,
50 nested_map_idents: thm list,
51 nested_map_comps: thm list,
52 ctr_specs: corec_ctr_spec list}
54 val massage_indirect_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
55 typ list -> term -> term -> term -> term
56 val massage_direct_corec_call: Proof.context -> (term -> bool) -> (term -> term) -> typ -> term ->
58 val massage_indirect_corec_call: Proof.context -> (term -> bool) ->
59 (typ -> typ -> term -> term) -> typ list -> typ -> term -> term
60 val expand_corec_code_rhs: Proof.context -> (term -> bool) -> typ list -> term -> term
61 val massage_corec_code_rhs: Proof.context -> (term -> term list -> term) -> typ -> term -> term
62 val rec_specs_of: binding list -> typ list -> typ list -> (term -> int list) ->
63 ((term * term list list) list) list -> local_theory ->
64 (bool * rec_spec list * typ list * thm * thm list) * local_theory
65 val corec_specs_of: binding list -> typ list -> typ list -> (term -> int list) ->
66 ((term * term list list) list) list -> local_theory ->
67 (bool * corec_spec list * typ list * thm * thm * thm list * thm list) * local_theory
70 structure BNF_FP_Rec_Sugar_Util : BNF_FP_REC_SUGAR_UTIL =
82 Direct_Rec of int * int |
86 Dummy_No_Corec of int |
88 Direct_Corec of int * int * int |
89 Indirect_Corec of int;
102 calls: corec_call list,
108 sel_corecs: thm list};
112 nested_map_idents: thm list,
113 nested_map_comps: thm list,
114 ctr_specs: rec_ctr_spec list};
118 nested_maps: thm list,
119 nested_map_idents: thm list,
120 nested_map_comps: thm list,
121 ctr_specs: corec_ctr_spec list};
(* The theorem id_def; used below to unfold "id" in the map-identity theorems
   of nested BNFs. *)
val id_def = @{thm id_def};

(* Raised by the massage_map helpers when a term cannot be taken apart as an
   application of a map function (or its types are not type constructors);
   carries the offending term. *)
exception AINT_NO_MAP of term;

(* Raises an error (via error) whose message is [prefix] followed by the quoted,
   pretty-printed term [t]. *)
fun error_with_term prefix ctxt t =
  error (prefix ^ quote (Syntax.string_of_term ctxt t));

(* The reporters are eta-expanded rather than defined as partial applications of
   error_with_term: callers use them at several result types (term, bool,
   term -> term), which a non-generalizable partial application would forbid
   under ML's value restriction. *)
fun ill_formed_rec_call ctxt t = error_with_term "Ill-formed recursive call: " ctxt t;
fun ill_formed_corec_call ctxt t = error_with_term "Ill-formed corecursive call: " ctxt t;
fun invalid_map ctxt t = error_with_term "Invalid map function in " ctxt t;
fun unexpected_rec_call ctxt t = error_with_term "Unexpected recursive call: " ctxt t;
fun unexpected_corec_call ctxt t = error_with_term "Unexpected corecursive call: " ctxt t;
138 fun factor_out_types ctxt massage destU U T =
140 SOME (U1, U2) => if U1 = T then massage T U2 else invalid_map ctxt
141 | NONE => invalid_map ctxt);
143 fun map_flattened_map_args ctxt s map_args fs =
145 val flat_fs = flatten_type_args_of_bnf (the (bnf_of ctxt s)) Term.dummy fs;
146 val flat_fs' = map_args flat_fs;
148 permute_like (op aconv) flat_fs fs flat_fs'
151 fun massage_indirect_rec_call ctxt has_call massage_unapplied_direct_call bound_Ts y y' =
153 val typof = curry fastype_of1 bound_Ts;
154 val build_map_fst = build_map ctxt (fst_const o fst);
159 fun y_of_y' () = build_map_fst (yU, yT) $ y';
160 val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t);
162 fun check_and_massage_unapplied_direct_call U T t =
164 factor_out_types ctxt massage_unapplied_direct_call HOLogic.dest_prodT U T t
166 HOLogic.mk_comp (t, build_map_fst (U, T));
168 fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
169 (case try (dest_map ctxt s) t of
172 val Type (_, ran_Ts) = range_type (typof t);
173 val map' = mk_map (length fs) Us ran_Ts map0;
174 val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
176 list_comb (map', fs')
178 | NONE => raise AINT_NO_MAP t)
179 | massage_map _ _ t = raise AINT_NO_MAP t
180 and massage_map_or_map_arg U T t =
182 if has_call t then unexpected_rec_call ctxt t else t
185 handle AINT_NO_MAP _ => check_and_massage_unapplied_direct_call U T t;
187 fun massage_call (t as t1 $ t2) =
189 massage_map yU yT (elim_y t1) $ y'
190 handle AINT_NO_MAP t' => invalid_map ctxt t'
192 ill_formed_rec_call ctxt t
193 | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
198 fun massage_let_and_if ctxt has_call massage_leaf U =
200 val check_cond = ((not o has_call) orf unexpected_corec_call ctxt);
202 (case Term.strip_comb t of
203 (Const (@{const_name Let}, _), [arg1, arg2]) => massage_rec (betapply (arg2, arg1))
204 | (Const (@{const_name If}, _), arg :: args) =>
205 list_comb (If_const U $ tap check_cond arg, map massage_rec args)
206 | _ => massage_leaf t)
(* Massages the right-hand side (of type U) of a directly corecursive equation,
   as implemented by massage_let_and_if: Let-abstractions are beta-reduced, an
   If-condition must not contain a corecursive call (checked with has_call) while
   the If-branches are processed recursively, and every other term reached is
   handed to massage_direct_call. *)
fun massage_direct_corec_call ctxt has_call massage_direct_call U =
  massage_let_and_if ctxt has_call massage_direct_call U;
214 fun massage_indirect_corec_call ctxt has_call massage_direct_call bound_Ts U t =
216 val typof = curry fastype_of1 bound_Ts;
217 val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o snd)
219 fun check_and_massage_direct_call U T t =
220 if has_call t then factor_out_types ctxt massage_direct_call dest_sumT U T t
221 else build_map_Inl (T, U) $ t;
223 fun check_and_massage_unapplied_direct_call U T t =
224 let val var = Var ((Name.uu, Term.maxidx_of_term t + 1), domain_type (typof t)) in
225 Term.lambda var (check_and_massage_direct_call U T (t $ var))
228 fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
229 (case try (dest_map ctxt s) t of
232 val Type (_, dom_Ts) = domain_type (typof t);
233 val map' = mk_map (length fs) dom_Ts Us map0;
234 val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
236 list_comb (map', fs')
238 | NONE => raise AINT_NO_MAP t)
239 | massage_map _ _ t = raise AINT_NO_MAP t
240 and massage_map_or_map_arg U T t =
242 if has_call t then unexpected_corec_call ctxt t else t
245 handle AINT_NO_MAP _ => check_and_massage_unapplied_direct_call U T t;
247 fun massage_call U T =
248 massage_let_and_if ctxt has_call (fn t =>
251 (case try (dest_ctr ctxt s) t of
253 let val f' = mk_ctr Us f in
254 list_comb (f', map3 massage_call (binder_types (typof f')) (map typof args) args)
260 check_and_massage_direct_call U T t
262 massage_map U T t1 $ t2
263 handle AINT_NO_MAP _ => check_and_massage_direct_call U T t)
264 | _ => check_and_massage_direct_call U T t))
265 | _ => ill_formed_corec_call ctxt t)) U
267 massage_call U (typof t) t
270 fun expand_ctr_term ctxt s Ts t =
271 (case fp_sugar_of ctxt s of
274 val T = Type (s, Ts);
276 val {ctrs = ctrs0, discs = discs0, selss = selss0, ...} = of_fp_sugar #ctr_sugars fp_sugar;
277 val ctrs = map (mk_ctr Ts) ctrs0;
278 val discs = map (mk_disc_or_sel Ts) discs0;
279 val selss = map (map (mk_disc_or_sel Ts)) selss0;
280 val xdiscs = map (rapp x) discs;
281 val xselss = map (map (rapp x)) selss;
282 val xsel_ctrs = map2 (curry Term.list_comb) ctrs xselss;
283 val xif = mk_IfN T xdiscs xsel_ctrs;
285 Const (@{const_name Let}, T --> (T --> T) --> T) $ t $ Abs (Name.uu, T, xif)
287 | NONE => raise Fail "expand_ctr_term");
289 fun expand_corec_code_rhs ctxt has_call bound_Ts t =
290 (case fastype_of1 (bound_Ts, t) of
292 massage_let_and_if ctxt has_call (fn t =>
293 if can (dest_ctr ctxt s) t then t
294 else massage_let_and_if ctxt has_call I T (expand_ctr_term ctxt s Ts t)) T t
295 | _ => raise Fail "expand_corec_code_rhs");
(* Rewrites the right-hand side t (of type U) of a corecursive code equation.
   It descends through Let and If via massage_let_and_if; no subterm is regarded
   as a call (has_call is K false), so If-conditions are never rejected. Every
   remaining term is split into its head and arguments, which are passed to
   massage_ctr. *)
fun massage_corec_code_rhs ctxt massage_ctr U t =
  let
    fun massage_leaf leaf =
      let val (head, args) = Term.strip_comb leaf in massage_ctr head args end;
  in
    massage_let_and_if ctxt (K false) massage_leaf U t
  end;
(* [indexed xs h] numbers the elements of xs consecutively starting at h. It
   returns the indices [h, ..., h + length xs - 1] together with the next unused
   index, so the counter can be threaded through fold_map. *)
fun indexed xs h = (h upto h + length xs - 1, h + length xs);

(* Nested variants: index lists of lists (of lists, ...), threading the counter
   through every level so that all leaves receive distinct consecutive indices. *)
fun indexedd xss h = fold_map indexed xss h;
fun indexeddd xsss h = fold_map indexedd xsss h;
fun indexedddd xssss h = fold_map indexeddd xssss h;

(* Position of the first element of hs that is equal to h, as computed by
   find_index. *)
fun find_index_eq hs h = find_index (fn h' => h = h') hs;
307 (*FIXME: remove special cases for products and sum once they are registered as datatypes*)
308 fun map_thms_of_typ ctxt (Type (s, _)) =
309 if s = @{type_name prod} then
310 @{thms map_pair_simp}
311 else if s = @{type_name sum} then
312 @{thms sum_map.simps}
314 (case fp_sugar_of ctxt s of
315 SOME {index, mapss, ...} => nth mapss index
317 | map_thms_of_typ _ _ = [];
(* Flag passed as the first argument of nested_to_mutual_fps by both
   rec_specs_of and corec_specs_of. *)
319 val lose_co_rec = false (*FIXME: try true?*);
321 fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
323 val thy = Proof_Context.theory_of lthy;
325 val ((nontriv, missing_arg_Ts, perm0_kks,
326 fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = ctor_iters1 :: _, ...},
327 co_inducts = [induct_thm], ...} :: _), lthy') =
328 nested_to_mutual_fps lose_co_rec Least_FP bs arg_Ts get_indices callssss0 lthy;
330 val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;
332 val indices = map #index fp_sugars;
333 val perm_indices = map #index perm_fp_sugars;
335 val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
336 val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
337 val perm_fpTs = map (body_type o fastype_of o hd) perm_ctrss;
339 val nn0 = length arg_Ts;
340 val nn = length perm_fpTs;
341 val kks = 0 upto nn - 1;
342 val perm_ns = map length perm_ctr_Tsss;
343 val perm_mss = map (map length) perm_ctr_Tsss;
345 val perm_Cs = map (body_type o fastype_of o co_rec_of o of_fp_sugar (#xtor_co_iterss o #fp_res))
347 val perm_fun_arg_Tssss =
348 mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1);
350 fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
351 fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
353 val induct_thms = unpermute0 (conj_dests nn induct_thm);
355 val fpTs = unpermute perm_fpTs;
356 val Cs = unpermute perm_Cs;
358 val As_rho = tvar_subst thy (take nn0 fpTs) arg_Ts;
359 val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;
361 val substA = Term.subst_TVars As_rho;
362 val substAT = Term.typ_subst_TVars As_rho;
363 val substCT = Term.typ_subst_TVars Cs_rho;
365 val perm_Cs' = map substCT perm_Cs;
367 fun offset_of_ctr 0 _ = 0
368 | offset_of_ctr n ({ctrs, ...} :: ctr_sugars) =
369 length ctrs + offset_of_ctr (n - 1) ctr_sugars;
371 fun call_of [i] [T] = (if exists_subtype_in Cs T then Indirect_Rec else No_Rec) i
372 | call_of [i, i'] _ = Direct_Rec (i, i');
374 fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
376 val (fun_arg_hss, _) = indexedd fun_arg_Tss 0;
377 val fun_arg_hs = flat_rec_arg_args fun_arg_hss;
378 val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss;
380 {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss,
384 fun mk_ctr_specs index ctr_sugars iter_thmsss =
386 val ctrs = #ctrs (nth ctr_sugars index);
387 val rec_thmss = co_rec_of (nth iter_thmsss index);
388 val k = offset_of_ctr index ctr_sugars;
391 map4 mk_ctr_spec ctrs (k upto k + n - 1) (nth perm_fun_arg_Tssss index) rec_thmss
394 fun mk_spec {T, index, ctr_sugars, co_iterss = iterss, co_iter_thmsss = iter_thmsss, ...} =
395 {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' (co_rec_of (nth iterss index)),
396 nested_map_idents = map (unfold_thms lthy [id_def] o map_id0_of_bnf) nested_bnfs,
397 nested_map_comps = map map_comp_of_bnf nested_bnfs,
398 ctr_specs = mk_ctr_specs index ctr_sugars iter_thmsss};
400 ((nontriv, map mk_spec fp_sugars, missing_arg_Ts, induct_thm, induct_thms), lthy')
403 fun corec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
405 val thy = Proof_Context.theory_of lthy;
407 val ((nontriv, missing_res_Ts, perm0_kks,
408 fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = dtor_coiters1 :: _, ...},
409 co_inducts = coinduct_thms, ...} :: _), lthy') =
410 nested_to_mutual_fps lose_co_rec Greatest_FP bs res_Ts get_indices callssss0 lthy;
412 val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;
414 val indices = map #index fp_sugars;
415 val perm_indices = map #index perm_fp_sugars;
417 val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
418 val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
419 val perm_fpTs = map (body_type o fastype_of o hd) perm_ctrss;
421 val nn0 = length res_Ts;
422 val nn = length perm_fpTs;
423 val kks = 0 upto nn - 1;
424 val perm_ns = map length perm_ctr_Tsss;
426 val perm_Cs = map (domain_type o body_fun_type o fastype_of o co_rec_of o
427 of_fp_sugar (#xtor_co_iterss o #fp_res)) perm_fp_sugars;
428 val (perm_p_Tss, (perm_q_Tssss, _, perm_f_Tssss, _)) =
429 mk_coiter_fun_arg_types perm_ctr_Tsss perm_Cs perm_ns (co_rec_of dtor_coiters1);
431 val (perm_p_hss, h) = indexedd perm_p_Tss 0;
432 val (perm_q_hssss, h') = indexedddd perm_q_Tssss h;
433 val (perm_f_hssss, _) = indexedddd perm_f_Tssss h';
436 flat (map3 flat_corec_preds_predsss_gettersss perm_p_hss perm_q_hssss perm_f_hssss);
438 fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
439 fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
441 val coinduct_thmss = map (unpermute0 o conj_dests nn) coinduct_thms;
443 val p_iss = map (map (find_index_eq fun_arg_hs)) (unpermute perm_p_hss);
444 val q_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_q_hssss);
445 val f_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_f_hssss);
447 val f_Tssss = unpermute perm_f_Tssss;
448 val fpTs = unpermute perm_fpTs;
449 val Cs = unpermute perm_Cs;
451 val As_rho = tvar_subst thy (take nn0 fpTs) res_Ts;
452 val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn arg_Ts;
454 val substA = Term.subst_TVars As_rho;
455 val substAT = Term.typ_subst_TVars As_rho;
456 val substCT = Term.typ_subst_TVars Cs_rho;
458 val perm_Cs' = map substCT perm_Cs;
460 fun call_of nullary [] [g_i] [Type (@{type_name fun}, [_, T])] =
461 (if exists_subtype_in Cs T then Indirect_Corec
462 else if nullary then Dummy_No_Corec
464 | call_of _ [q_i] [g_i, g_i'] _ = Direct_Corec (q_i, g_i, g_i');
466 fun mk_ctr_spec ctr disc sels p_ho q_iss f_iss f_Tss discI sel_thms collapse corec_thm
467 disc_corec sel_corecs =
468 let val nullary = not (can dest_funT (fastype_of ctr)) in
469 {ctr = substA ctr, disc = substA disc, sels = map substA sels, pred = p_ho,
470 calls = map3 (call_of nullary) q_iss f_iss f_Tss, discI = discI, sel_thms = sel_thms,
471 collapse = collapse, corec_thm = corec_thm, disc_corec = disc_corec,
472 sel_corecs = sel_corecs}
475 fun mk_ctr_specs index ctr_sugars p_is q_isss f_isss f_Tsss coiter_thmsss disc_coitersss
478 val ctrs = #ctrs (nth ctr_sugars index);
479 val discs = #discs (nth ctr_sugars index);
480 val selss = #selss (nth ctr_sugars index);
481 val p_ios = map SOME p_is @ [NONE];
482 val discIs = #discIs (nth ctr_sugars index);
483 val sel_thmss = #sel_thmss (nth ctr_sugars index);
484 val collapses = #collapses (nth ctr_sugars index);
485 val corec_thms = co_rec_of (nth coiter_thmsss index);
486 val disc_corecs = (case co_rec_of (nth disc_coitersss index) of [] => [TrueI]
488 val sel_corecss = co_rec_of (nth sel_coiterssss index);
490 map13 mk_ctr_spec ctrs discs selss p_ios q_isss f_isss f_Tsss discIs sel_thmss collapses
491 corec_thms disc_corecs sel_corecss
494 fun mk_spec {T, index, ctr_sugars, co_iterss = coiterss, co_iter_thmsss = coiter_thmsss,
495 disc_co_itersss = disc_coitersss, sel_co_iterssss = sel_coiterssss, ...}
496 p_is q_isss f_isss f_Tsss =
497 {corec = mk_co_iter thy Greatest_FP (substAT T) perm_Cs' (co_rec_of (nth coiterss index)),
498 nested_maps = maps (map_thms_of_typ lthy o T_of_bnf) nested_bnfs,
499 nested_map_idents = map (unfold_thms lthy [id_def] o map_id0_of_bnf) nested_bnfs,
500 nested_map_comps = map map_comp_of_bnf nested_bnfs,
501 ctr_specs = mk_ctr_specs index ctr_sugars p_is q_isss f_isss f_Tsss coiter_thmsss
502 disc_coitersss sel_coiterssss};
504 ((nontriv, map5 mk_spec fp_sugars p_iss q_issss f_issss f_Tssss, missing_res_Ts,
505 co_induct_of coinduct_thms, strong_co_induct_of coinduct_thms, co_induct_of coinduct_thmss,
506 strong_co_induct_of coinduct_thmss), lthy')