give lambda abstractions a chance, as an alternative to function composition, for corecursion via "fun"
1 (* Title: HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
2 Author: Lorenz Panny, TU Muenchen
3 Author: Jasmin Blanchette, TU Muenchen
6 Library for recursor and corecursor sugar.
9 signature BNF_FP_REC_SUGAR_UTIL =
13 Direct_Rec of int (*before*) * int (*after*) |
17 Dummy_No_Corec of int |
19 Direct_Corec of int (*stop?*) * int (*end*) * int (*continue*) |
33 calls: corec_call list,
43 nested_map_idents: thm list,
44 nested_map_comps: thm list,
45 ctr_specs: rec_ctr_spec list}
49 nested_maps: thm list,
50 nested_map_idents: thm list,
51 nested_map_comps: thm list,
52 ctr_specs: corec_ctr_spec list}
54 val massage_indirect_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
55 typ list -> term -> term -> term -> term
56 val massage_direct_corec_call: Proof.context -> (term -> bool) -> (term -> term) -> typ -> term ->
58 val massage_indirect_corec_call: Proof.context -> (term -> bool) ->
59 (typ -> typ -> term -> term) -> typ list -> typ -> term -> term
60 val expand_corec_code_rhs: Proof.context -> (term -> bool) -> typ list -> term -> term
61 val massage_corec_code_rhs: Proof.context -> (term -> term list -> term) -> typ -> term -> term
62 val fold_rev_corec_code_rhs: (term -> term list -> 'a -> 'a) -> term -> 'a -> 'a
63 val simplify_bool_ifs: theory -> term -> term list
64 val rec_specs_of: binding list -> typ list -> typ list -> (term -> int list) ->
65 ((term * term list list) list) list -> local_theory ->
66 (bool * rec_spec list * typ list * thm * thm list) * local_theory
67 val corec_specs_of: binding list -> typ list -> typ list -> (term -> int list) ->
68 ((term * term list list) list) list -> local_theory ->
69 (bool * corec_spec list * typ list * thm * thm * thm list * thm list) * local_theory
72 structure BNF_FP_Rec_Sugar_Util : BNF_FP_REC_SUGAR_UTIL =
84 Direct_Rec of int * int |
88 Dummy_No_Corec of int |
90 Direct_Corec of int * int * int |
91 Indirect_Corec of int;
104 calls: corec_call list,
110 sel_corecs: thm list};
114 nested_map_idents: thm list,
115 nested_map_comps: thm list,
116 ctr_specs: rec_ctr_spec list};
120 nested_maps: thm list,
121 nested_map_idents: thm list,
122 nested_map_comps: thm list,
123 ctr_specs: corec_ctr_spec list};
(* Local ML binding for the theorem named id_def, fetched with the thm
   antiquotation; it is passed to unfold_thms when the nested_map_idents
   fields of the specs are built further down in this structure. *)
125 val id_def = @{thm id_def};
(* Internal control-flow exception carrying the offending term. The massage_map
   helpers in this structure raise it when the term they are given is not
   recognized by dest_map (or the types are not Type applications); the
   surrounding code handles it to fall back to another treatment or to report
   an invalid map. *)
127 exception AINT_NO_MAP of term;
(* Aborts via error with the message "Ill-formed recursive call: " followed by
   quote applied to the Syntax.string_of_term rendering of t in context ctxt.
   Never returns normally. *)
129 fun ill_formed_rec_call ctxt t =
130 error ("Ill-formed recursive call: " ^ quote (Syntax.string_of_term ctxt t));
(* Aborts via error with the message "Ill-formed corecursive call: " followed
   by quote applied to the Syntax.string_of_term rendering of t in context
   ctxt. Never returns normally. *)
131 fun ill_formed_corec_call ctxt t =
132 error ("Ill-formed corecursive call: " ^ quote (Syntax.string_of_term ctxt t));
(* Aborts via error with the message "Invalid map function in " followed by
   quote applied to the Syntax.string_of_term rendering of t in context ctxt.
   Some callers in this structure apply it to ctxt only and let the term
   argument arrive later. *)
133 fun invalid_map ctxt t =
134 error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t));
(* Aborts via error with the message "Unexpected recursive call: " followed by
   quote applied to the Syntax.string_of_term rendering of t in context ctxt.
   Never returns normally. *)
135 fun unexpected_rec_call ctxt t =
136 error ("Unexpected recursive call: " ^ quote (Syntax.string_of_term ctxt t));
(* Aborts via error with the message "Unexpected corecursive call: " followed
   by quote applied to the Syntax.string_of_term rendering of t in context
   ctxt. Never returns normally. *)
137 fun unexpected_corec_call ctxt t =
138 error ("Unexpected corecursive call: " ^ quote (Syntax.string_of_term ctxt t));
140 fun factor_out_types ctxt massage destU U T =
142 SOME (U1, U2) => if U1 = T then massage T U2 else invalid_map ctxt
143 | NONE => invalid_map ctxt);
145 fun map_flattened_map_args ctxt s map_args fs =
147 val flat_fs = flatten_type_args_of_bnf (the (bnf_of ctxt s)) Term.dummy fs;
148 val flat_fs' = map_args flat_fs;
150 permute_like (op aconv) flat_fs fs flat_fs'
153 fun massage_indirect_rec_call ctxt has_call massage_unapplied_direct_call bound_Ts y y' =
155 val typof = curry fastype_of1 bound_Ts;
156 val build_map_fst = build_map ctxt (fst_const o fst);
161 fun y_of_y' () = build_map_fst (yU, yT) $ y';
162 val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t);
164 fun check_and_massage_unapplied_direct_call U T t =
166 factor_out_types ctxt massage_unapplied_direct_call HOLogic.dest_prodT U T t
168 HOLogic.mk_comp (t, build_map_fst (U, T));
170 fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
171 (case try (dest_map ctxt s) t of
174 val Type (_, ran_Ts) = range_type (typof t);
175 val map' = mk_map (length fs) Us ran_Ts map0;
176 val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
178 list_comb (map', fs')
180 | NONE => raise AINT_NO_MAP t)
181 | massage_map _ _ t = raise AINT_NO_MAP t
182 and massage_map_or_map_arg U T t =
184 if has_call t then unexpected_rec_call ctxt t else t
187 handle AINT_NO_MAP _ => check_and_massage_unapplied_direct_call U T t;
189 fun massage_call (t as t1 $ t2) =
191 massage_map yU yT (elim_y t1) $ y'
192 handle AINT_NO_MAP t' => invalid_map ctxt t'
194 ill_formed_rec_call ctxt t
195 | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
200 fun massage_let_and_if ctxt has_call massage_leaf U =
202 val check_cond = ((not o has_call) orf unexpected_corec_call ctxt);
204 (case Term.strip_comb t of
205 (Const (@{const_name Let}, _), [arg1, arg2]) => massage_rec (betapply (arg2, arg1))
206 | (Const (@{const_name If}, _), arg :: args) =>
207 list_comb (If_const U $ tap check_cond arg, map massage_rec args)
208 | _ => massage_leaf t)
213 fun fold_rev_let_and_if f =
216 (case Term.strip_comb t of
217 (Const (@{const_name Let}, _), [arg1, arg2]) => fld (betapply (arg2, arg1))
218 | (Const (@{const_name If}, _), _ :: args) => fold_rev fld args
(* Thin wrapper around massage_let_and_if: forwards ctxt, has_call, U and t
   unchanged and supplies massage_direct_call in the position of the
   leaf-massaging function. No other processing happens here. *)
224 fun massage_direct_corec_call ctxt has_call massage_direct_call U t =
225 massage_let_and_if ctxt has_call massage_direct_call U t;
227 fun massage_indirect_corec_call ctxt has_call massage_direct_call bound_Ts U t =
229 val typof = curry fastype_of1 bound_Ts;
230 val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o snd)
232 fun check_and_massage_direct_call U T t =
233 if has_call t then factor_out_types ctxt massage_direct_call dest_sumT U T t
234 else build_map_Inl (T, U) $ t;
236 fun check_and_massage_unapplied_direct_call U T t =
237 let val var = Var ((Name.uu, Term.maxidx_of_term t + 1), domain_type (typof t)) in
238 Term.lambda var (check_and_massage_direct_call U T (t $ var))
241 fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
242 (case try (dest_map ctxt s) t of
245 val Type (_, dom_Ts) = domain_type (typof t);
246 val map' = mk_map (length fs) dom_Ts Us map0;
247 val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
249 list_comb (map', fs')
251 | NONE => raise AINT_NO_MAP t)
252 | massage_map _ _ t = raise AINT_NO_MAP t
253 and massage_map_or_map_arg U T t =
255 if has_call t then unexpected_corec_call ctxt t else t
258 handle AINT_NO_MAP _ => check_and_massage_unapplied_direct_call U T t;
260 fun massage_call U T =
261 massage_let_and_if ctxt has_call (fn t =>
265 (case try (dest_ctr ctxt s) t of
267 let val f' = mk_ctr Us f in
268 list_comb (f', map3 massage_call (binder_types (typof f')) (map typof args) args)
274 check_and_massage_direct_call U T t
276 massage_map U T t1 $ t2
277 handle AINT_NO_MAP _ => check_and_massage_direct_call U T t)
278 | Abs (s, T', t') => Abs (s, T', massage_call (range_type U) (range_type T) t')
279 | _ => check_and_massage_direct_call U T t))
280 | _ => ill_formed_corec_call ctxt t)
282 build_map_Inl (T, U) $ t) U
284 massage_call U (typof t) t
287 fun expand_ctr_term ctxt s Ts t =
288 (case fp_sugar_of ctxt s of
291 val T = Type (s, Ts);
293 val {ctrs = ctrs0, discs = discs0, selss = selss0, ...} = of_fp_sugar #ctr_sugars fp_sugar;
294 val ctrs = map (mk_ctr Ts) ctrs0;
295 val discs = map (mk_disc_or_sel Ts) discs0;
296 val selss = map (map (mk_disc_or_sel Ts)) selss0;
297 val xdiscs = map (rapp x) discs;
298 val xselss = map (map (rapp x)) selss;
299 val xsel_ctrs = map2 (curry Term.list_comb) ctrs xselss;
300 val xif = mk_IfN T xdiscs xsel_ctrs;
302 Const (@{const_name Let}, T --> (T --> T) --> T) $ t $ Abs (Name.uu, T, xif)
304 | NONE => raise Fail "expand_ctr_term");
306 fun expand_corec_code_rhs ctxt has_call bound_Ts t =
307 (case fastype_of1 (bound_Ts, t) of
309 massage_let_and_if ctxt has_call (fn t =>
310 if can (dest_ctr ctxt s) t then t
311 else massage_let_and_if ctxt has_call I T (expand_ctr_term ctxt s Ts t)) T t
312 | _ => raise Fail "expand_corec_code_rhs");
(* Partial application of massage_let_and_if; the remaining arguments (the
   type and the term, per the signature) are supplied by the caller.
   - has_call is fixed to K false, so the call predicate used by
     massage_let_and_if always answers false here.
   - the leaf function splits its term with Term.strip_comb and hands the head
     and the argument list to massage_ctr as two curried arguments. *)
314 fun massage_corec_code_rhs ctxt massage_ctr =
315 massage_let_and_if ctxt (K false) (uncurry massage_ctr o Term.strip_comb);
(* Specializes fold_rev_let_and_if: the function passed on first decomposes its
   term with Term.strip_comb and then calls f with the head and the argument
   list as separate curried arguments (followed by the accumulator, per the
   signature). *)
317 fun fold_rev_corec_code_rhs f = fold_rev_let_and_if (uncurry f o Term.strip_comb);
(* Returns a function that prepends the conjuncts of a term to an accumulator
   list. For an application of the conj constant to two arguments, the
   functions for both arguments are composed so that the right argument is
   processed first and the conjuncts of the left argument end up in front of
   it. Any other term is consed onto the accumulator as a single conjunct. *)
319 fun add_conjuncts (Const (@{const_name conj}, _) $ t $ t') = add_conjuncts t o add_conjuncts t'
320 | add_conjuncts t = cons t;
(* Flattens a (possibly nested) conjunction into the list of its conjuncts by
   running add_conjuncts on the empty accumulator. A term that is not a
   conjunction yields a singleton list. *)
322 fun conjuncts t = add_conjuncts t [];
(* Pipeline of three steps applied to a term (result type term list, per the
   signature):
   1. rewrite it with Raw_Simplifier.rewrite_term using the bool_if_simps
      theorems, each combined with eq_reflection via THEN; the trailing empty
      list is the third argument of rewrite_term (presumably no additional
      rewrite procedures -- TODO confirm against the Raw_Simplifier interface);
   2. split the result into its conjuncts;
   3. if the only conjunct is the term True, return the empty list instead;
      any other list is returned unchanged. *)
324 fun simplify_bool_ifs thy =
325 Raw_Simplifier.rewrite_term thy @{thms bool_if_simps[THEN eq_reflection]} []
326 #> conjuncts #> (fn [@{term True}] => [] | ts => ts);
(* Assigns consecutive integer indices to the elements of xs, starting at h.
   Returns the pair (h upto h' - 1, h') where h' = h + length xs, i.e. the
   indices for xs together with the next unused index. The elements of xs
   themselves are not inspected, only their number. *)
328 fun indexed xs h = let val h' = h + length xs in (h upto h' - 1, h') end;
(* Lifts of indexed to nested lists of depth 2, 3 and 4. Each level uses
   fold_map over the next-shallower function, so the running index is threaded
   through the sublists and the final index is returned alongside the nested
   index structure. The starting index is the (eta-reduced) second argument. *)
329 fun indexedd xss = fold_map indexed xss;
330 fun indexeddd xsss = fold_map indexedd xsss;
331 fun indexedddd xssss = fold_map indexeddd xssss;
(* Looks up h in the list hs: calls find_index with the predicate that tests
   equality with h and returns whatever find_index reports for hs.
   NOTE(review): the result for an absent h is whatever find_index yields in
   that case -- confirm against the library before relying on it. *)
333 fun find_index_eq hs h = find_index (curry (op =) h) hs;
335 (*FIXME: remove special cases for products and sum once they are registered as datatypes*)
336 fun map_thms_of_typ ctxt (Type (s, _)) =
337 if s = @{type_name prod} then
338 @{thms map_pair_simp}
339 else if s = @{type_name sum} then
340 @{thms sum_map.simps}
342 (case fp_sugar_of ctxt s of
343 SOME {index, mapss, ...} => nth mapss index
345 | map_thms_of_typ _ _ = [];
(* Flag passed as the first argument to nested_to_mutual_fps by both
   rec_specs_of and corec_specs_of; currently fixed to false. Its effect is
   defined by nested_to_mutual_fps, which is not part of this file. *)
347 val lose_co_rec = false (*FIXME: try true?*);
349 fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
351 val thy = Proof_Context.theory_of lthy;
353 val ((nontriv, missing_arg_Ts, perm0_kks,
354 fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = ctor_iters1 :: _, ...},
355 co_inducts = [induct_thm], ...} :: _), lthy') =
356 nested_to_mutual_fps lose_co_rec Least_FP bs arg_Ts get_indices callssss0 lthy;
358 val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;
360 val indices = map #index fp_sugars;
361 val perm_indices = map #index perm_fp_sugars;
363 val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
364 val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
365 val perm_fpTs = map (body_type o fastype_of o hd) perm_ctrss;
367 val nn0 = length arg_Ts;
368 val nn = length perm_fpTs;
369 val kks = 0 upto nn - 1;
370 val perm_ns = map length perm_ctr_Tsss;
371 val perm_mss = map (map length) perm_ctr_Tsss;
373 val perm_Cs = map (body_type o fastype_of o co_rec_of o of_fp_sugar (#xtor_co_iterss o #fp_res))
375 val perm_fun_arg_Tssss =
376 mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1);
378 fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
379 fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
381 val induct_thms = unpermute0 (conj_dests nn induct_thm);
383 val fpTs = unpermute perm_fpTs;
384 val Cs = unpermute perm_Cs;
386 val As_rho = tvar_subst thy (take nn0 fpTs) arg_Ts;
387 val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;
389 val substA = Term.subst_TVars As_rho;
390 val substAT = Term.typ_subst_TVars As_rho;
391 val substCT = Term.typ_subst_TVars Cs_rho;
393 val perm_Cs' = map substCT perm_Cs;
(* Sums length ctrs over the first n records of the given ctr_sugars list:
   0 for n = 0, otherwise the constructor count of the head record plus the
   offset computed for n - 1 on the tail.
   NOTE(review): there is no clause for a non-zero n with an empty list, so
   such a call fails with a pattern-match error; presumably callers only pass
   an index within the list -- confirm. *)
395 fun offset_of_ctr 0 _ = 0
396 | offset_of_ctr n ({ctrs, ...} :: ctr_sugars) =
397 length ctrs + offset_of_ctr (n - 1) ctr_sugars;
(* Classifies one constructor argument from its list of indices and its list
   of types:
   - one index i and one type T: Indirect_Rec i when exists_subtype_in Cs T
     holds (Cs comes from the enclosing scope), otherwise No_Rec i;
   - two indices i and i': Direct_Rec (i, i'); the type list is ignored.
   Other argument shapes are not matched by any clause. *)
399 fun call_of [i] [T] = (if exists_subtype_in Cs T then Indirect_Rec else No_Rec) i
400 | call_of [i, i'] _ = Direct_Rec (i, i');
402 fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
404 val (fun_arg_hss, _) = indexedd fun_arg_Tss 0;
405 val fun_arg_hs = flat_rec_arg_args fun_arg_hss;
406 val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss;
408 {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss,
412 fun mk_ctr_specs index ctr_sugars iter_thmsss =
414 val ctrs = #ctrs (nth ctr_sugars index);
415 val rec_thmss = co_rec_of (nth iter_thmsss index);
416 val k = offset_of_ctr index ctr_sugars;
419 map4 mk_ctr_spec ctrs (k upto k + n - 1) (nth perm_fun_arg_Tssss index) rec_thmss
422 fun mk_spec {T, index, ctr_sugars, co_iterss = iterss, co_iter_thmsss = iter_thmsss, ...} =
423 {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' (co_rec_of (nth iterss index)),
424 nested_map_idents = map (unfold_thms lthy [id_def] o map_id0_of_bnf) nested_bnfs,
425 nested_map_comps = map map_comp_of_bnf nested_bnfs,
426 ctr_specs = mk_ctr_specs index ctr_sugars iter_thmsss};
428 ((nontriv, map mk_spec fp_sugars, missing_arg_Ts, induct_thm, induct_thms), lthy')
431 fun corec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
433 val thy = Proof_Context.theory_of lthy;
435 val ((nontriv, missing_res_Ts, perm0_kks,
436 fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = dtor_coiters1 :: _, ...},
437 co_inducts = coinduct_thms, ...} :: _), lthy') =
438 nested_to_mutual_fps lose_co_rec Greatest_FP bs res_Ts get_indices callssss0 lthy;
440 val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;
442 val indices = map #index fp_sugars;
443 val perm_indices = map #index perm_fp_sugars;
445 val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
446 val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
447 val perm_fpTs = map (body_type o fastype_of o hd) perm_ctrss;
449 val nn0 = length res_Ts;
450 val nn = length perm_fpTs;
451 val kks = 0 upto nn - 1;
452 val perm_ns = map length perm_ctr_Tsss;
454 val perm_Cs = map (domain_type o body_fun_type o fastype_of o co_rec_of o
455 of_fp_sugar (#xtor_co_iterss o #fp_res)) perm_fp_sugars;
456 val (perm_p_Tss, (perm_q_Tssss, _, perm_f_Tssss, _)) =
457 mk_coiter_fun_arg_types perm_ctr_Tsss perm_Cs perm_ns (co_rec_of dtor_coiters1);
459 val (perm_p_hss, h) = indexedd perm_p_Tss 0;
460 val (perm_q_hssss, h') = indexedddd perm_q_Tssss h;
461 val (perm_f_hssss, _) = indexedddd perm_f_Tssss h';
464 flat (map3 flat_corec_preds_predsss_gettersss perm_p_hss perm_q_hssss perm_f_hssss);
466 fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
467 fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
469 val coinduct_thmss = map (unpermute0 o conj_dests nn) coinduct_thms;
471 val p_iss = map (map (find_index_eq fun_arg_hs)) (unpermute perm_p_hss);
472 val q_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_q_hssss);
473 val f_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_f_hssss);
475 val f_Tssss = unpermute perm_f_Tssss;
476 val fpTs = unpermute perm_fpTs;
477 val Cs = unpermute perm_Cs;
479 val As_rho = tvar_subst thy (take nn0 fpTs) res_Ts;
480 val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn arg_Ts;
482 val substA = Term.subst_TVars As_rho;
483 val substAT = Term.typ_subst_TVars As_rho;
484 val substCT = Term.typ_subst_TVars Cs_rho;
486 val perm_Cs' = map substCT perm_Cs;
488 fun call_of nullary [] [g_i] [Type (@{type_name fun}, [_, T])] =
489 (if exists_subtype_in Cs T then Indirect_Corec
490 else if nullary then Dummy_No_Corec
492 | call_of _ [q_i] [g_i, g_i'] _ = Direct_Corec (q_i, g_i, g_i');
494 fun mk_ctr_spec ctr disc sels p_ho q_iss f_iss f_Tss discI sel_thms collapse corec_thm
495 disc_corec sel_corecs =
496 let val nullary = not (can dest_funT (fastype_of ctr)) in
497 {ctr = substA ctr, disc = substA disc, sels = map substA sels, pred = p_ho,
498 calls = map3 (call_of nullary) q_iss f_iss f_Tss, discI = discI, sel_thms = sel_thms,
499 collapse = collapse, corec_thm = corec_thm, disc_corec = disc_corec,
500 sel_corecs = sel_corecs}
503 fun mk_ctr_specs index ctr_sugars p_is q_isss f_isss f_Tsss coiter_thmsss disc_coitersss
506 val ctrs = #ctrs (nth ctr_sugars index);
507 val discs = #discs (nth ctr_sugars index);
508 val selss = #selss (nth ctr_sugars index);
509 val p_ios = map SOME p_is @ [NONE];
510 val discIs = #discIs (nth ctr_sugars index);
511 val sel_thmss = #sel_thmss (nth ctr_sugars index);
512 val collapses = #collapses (nth ctr_sugars index);
513 val corec_thms = co_rec_of (nth coiter_thmsss index);
514 val disc_corecs = (case co_rec_of (nth disc_coitersss index) of [] => [TrueI]
516 val sel_corecss = co_rec_of (nth sel_coiterssss index);
518 map13 mk_ctr_spec ctrs discs selss p_ios q_isss f_isss f_Tsss discIs sel_thmss collapses
519 corec_thms disc_corecs sel_corecss
522 fun mk_spec {T, index, ctr_sugars, co_iterss = coiterss, co_iter_thmsss = coiter_thmsss,
523 disc_co_itersss = disc_coitersss, sel_co_iterssss = sel_coiterssss, ...}
524 p_is q_isss f_isss f_Tsss =
525 {corec = mk_co_iter thy Greatest_FP (substAT T) perm_Cs' (co_rec_of (nth coiterss index)),
526 nested_maps = maps (map_thms_of_typ lthy o T_of_bnf) nested_bnfs,
527 nested_map_idents = map (unfold_thms lthy [id_def] o map_id0_of_bnf) nested_bnfs,
528 nested_map_comps = map map_comp_of_bnf nested_bnfs,
529 ctr_specs = mk_ctr_specs index ctr_sugars p_is q_isss f_isss f_Tsss coiter_thmsss
530 disc_coitersss sel_coiterssss};
532 ((nontriv, map5 mk_spec fp_sugars p_iss q_issss f_issss f_Tssss, missing_res_Ts,
533 co_induct_of coinduct_thms, strong_co_induct_of coinduct_thms, co_induct_of coinduct_thmss,
534 strong_co_induct_of coinduct_thmss), lthy')