1.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Mon Sep 10 17:32:39 2012 +0200
1.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Mon Sep 10 17:35:53 2012 +0200
1.3 @@ -50,11 +50,33 @@
1.4 fun mk_uncurried2_fun f xss =
1.5 mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
1.6
1.7 +val mk_sumTN_balanced = Balanced_Tree.make mk_sumT;
1.8 +val dest_sumTN_balanced = Balanced_Tree.dest dest_sumT;
1.9 +
1.10 +fun mk_InN_balanced ctxt sum_T Ts t k =
1.11 + let
1.12 + val u =
1.13 + Balanced_Tree.access {left = mk_Inl dummyT, right = mk_Inr dummyT, init = t} (length Ts) k;
1.14 + in singleton (Type_Infer_Context.infer_types ctxt) (Type.constraint sum_T u) end;
1.15 +
1.16 +val mk_sum_caseN_balanced = Balanced_Tree.make mk_sum_case;
1.17 +
1.18 +fun mk_sumEN_balanced n =
1.19 + let
1.20 + val thm =
1.21 + Balanced_Tree.make (fn (thm1, thm2) => thm1 RSN (1, thm2 RSN (2, @{thm obj_sumE_f})))
1.22 + (replicate n asm_rl) OF (replicate n (impI RS allI));
1.23 + val f as (_, f_T) =
1.24 + Term.add_vars (prop_of thm) []
1.25 + |> filter (fn ((s, _), _) => s = "f") |> the_single;
1.26 + val inst = [pairself (cterm_of @{theory}) (Var f, Abs (Name.uu, domain_type f_T, Bound 0))];
1.27 + in cterm_instantiate inst thm end;
1.28 +
1.29 fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v));
1.30
1.31 fun tack z_name (c, v) f =
1.32 let val z = Free (z_name, mk_sumT (fastype_of v, fastype_of c)) in
1.33 - Term.lambda z (mk_sum_case (Term.lambda v v) (Term.lambda c (f $ c)) $ z)
1.34 + Term.lambda z (mk_sum_case (Term.lambda v v, Term.lambda c (f $ c)) $ z)
1.35 end;
1.36
1.37 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
1.38 @@ -148,9 +170,9 @@
1.39 | freeze_fp T = T;
1.40
1.41 val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss;
1.42 - val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs;
1.43 + val ctr_sum_prod_TsXs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssXs;
1.44
1.45 - val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
1.46 + val eqs = map dest_TFree Xs ~~ ctr_sum_prod_TsXs;
1.47
1.48 val (pre_bnfs, ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects,
1.49 fp_iter_thms, fp_rec_thms), lthy)) =
1.50 @@ -215,7 +237,7 @@
1.51 if lfp then
1.52 let
1.53 val y_Tsss =
1.54 - map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN n o domain_type)
1.55 + map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
1.56 ns mss fp_iter_fun_Ts;
1.57 val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
1.58
1.59 @@ -225,8 +247,8 @@
1.60 ||>> mk_Freesss "x" y_Tsss;
1.61
1.62 val z_Tssss =
1.63 - map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n
1.64 - o domain_type) ns mss fp_rec_fun_Ts;
1.65 + map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o
1.66 + dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
1.67 val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
1.68
1.69 val hss = map2 (map2 retype_free) gss h_Tss;
1.70 @@ -251,7 +273,7 @@
1.71 fun mk_types fun_Ts =
1.72 let
1.73 val f_sum_prod_Ts = map range_type fun_Ts;
1.74 - val f_prod_Tss = map2 dest_sumTN ns f_sum_prod_Ts;
1.75 + val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
1.76 val f_Tsss =
1.77 map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss' f_prod_Tss;
1.78 val pf_Tss = map2 zip_preds_getters p_Tss f_Tsss
1.79 @@ -288,6 +310,7 @@
1.80 let
1.81 val unfT = domain_type (fastype_of fld);
1.82 val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
1.83 + val ctr_sum_prod_T = mk_sumTN_balanced ctr_prod_Ts;
1.84 val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
1.85
1.86 val ((((u, v), fs), xss), _) =
1.87 @@ -299,12 +322,15 @@
1.88
1.89 val ctr_rhss =
1.90 map2 (fn k => fn xs =>
1.91 - fold_rev Term.lambda xs (fld $ mk_InN ctr_prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
1.92 + fold_rev Term.lambda xs (fld $ mk_InN_balanced no_defs_lthy ctr_sum_prod_T ctr_prod_Ts
1.93 + (HOLogic.mk_tuple xs) k))
1.94 + ks xss;
1.95
1.96 val case_binder = Binding.suffix_name ("_" ^ caseN) b;
1.97
1.98 val case_rhs =
1.99 - fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (map2 mk_uncurried_fun fs xss) $ (unf $ v));
1.100 + fold_rev Term.lambda (fs @ [v])
1.101 + (mk_sum_caseN_balanced (map2 mk_uncurried_fun fs xss) $ (unf $ v));
1.102
1.103 val ((raw_case :: raw_ctrs, raw_case_def :: raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
1.104 |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
1.105 @@ -340,7 +366,8 @@
1.106
1.107 val sumEN_thm' =
1.108 Local_Defs.unfold lthy @{thms all_unit_eq}
1.109 - (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) [] (mk_sumEN n))
1.110 + (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) []
1.111 + (mk_sumEN_balanced n))
1.112 |> Morphism.thm phi;
1.113 in
1.114 mk_exhaust_tac ctxt n ctr_defs fld_iff_unf_thm sumEN_thm'
1.115 @@ -373,7 +400,7 @@
1.116 val spec =
1.117 mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binder, res_T)),
1.118 Term.list_comb (fp_iter_like,
1.119 - map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) fss xssss));
1.120 + map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss));
1.121 in (binder, spec) end;
1.122
1.123 val iter_likes =
1.124 @@ -411,7 +438,8 @@
1.125
1.126 fun mk_preds_getters_join c n cps sum_prod_T prod_Ts cfss =
1.127 Term.lambda c (mk_IfN sum_prod_T cps
1.128 - (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cfss) (1 upto n)));
1.129 + (map2 (mk_InN_balanced no_defs_lthy sum_prod_T prod_Ts)
1.130 + (map HOLogic.mk_tuple cfss) (1 upto n)));
1.131
1.132 val spec =
1.133 mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),