src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 50270 2ecc533d6697
parent 50269 edc322ac5279
child 50271 df98aeb80a19
     1.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Mon Sep 10 17:32:39 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Mon Sep 10 17:35:53 2012 +0200
     1.3 @@ -50,11 +50,33 @@
     1.4  fun mk_uncurried2_fun f xss =
     1.5    mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
     1.6  
     1.7 +val mk_sumTN_balanced = Balanced_Tree.make mk_sumT;
     1.8 +val dest_sumTN_balanced = Balanced_Tree.dest dest_sumT;
     1.9 +
    1.10 +fun mk_InN_balanced ctxt sum_T Ts t k =
    1.11 +  let
    1.12 +    val u =
    1.13 +      Balanced_Tree.access {left = mk_Inl dummyT, right = mk_Inr dummyT, init = t} (length Ts) k;
    1.14 +  in singleton (Type_Infer_Context.infer_types ctxt) (Type.constraint sum_T u) end;
    1.15 +
    1.16 +val mk_sum_caseN_balanced = Balanced_Tree.make mk_sum_case;
    1.17 +
    1.18 +fun mk_sumEN_balanced n =
    1.19 +  let
    1.20 +    val thm =
    1.21 +      Balanced_Tree.make (fn (thm1, thm2) => thm1 RSN (1, thm2 RSN (2, @{thm obj_sumE_f})))
    1.22 +        (replicate n asm_rl) OF (replicate n (impI RS allI));
    1.23 +    val f as (_, f_T) =
    1.24 +      Term.add_vars (prop_of thm) []
    1.25 +      |> filter (fn ((s, _), _) => s = "f") |> the_single;
    1.26 +    val inst = [pairself (cterm_of @{theory}) (Var f, Abs (Name.uu, domain_type f_T, Bound 0))];
    1.27 +  in cterm_instantiate inst thm end;
    1.28 +
    1.29  fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v));
    1.30  
    1.31  fun tack z_name (c, v) f =
    1.32    let val z = Free (z_name, mk_sumT (fastype_of v, fastype_of c)) in
    1.33 -    Term.lambda z (mk_sum_case (Term.lambda v v) (Term.lambda c (f $ c)) $ z)
    1.34 +    Term.lambda z (mk_sum_case (Term.lambda v v, Term.lambda c (f $ c)) $ z)
    1.35    end;
    1.36  
    1.37  fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
    1.38 @@ -148,9 +170,9 @@
    1.39        | freeze_fp T = T;
    1.40  
    1.41      val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss;
    1.42 -    val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs;
    1.43 +    val ctr_sum_prod_TsXs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssXs;
    1.44  
    1.45 -    val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
    1.46 +    val eqs = map dest_TFree Xs ~~ ctr_sum_prod_TsXs;
    1.47  
    1.48      val (pre_bnfs, ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects,
    1.49          fp_iter_thms, fp_rec_thms), lthy)) =
    1.50 @@ -215,7 +237,7 @@
    1.51        if lfp then
    1.52          let
    1.53            val y_Tsss =
    1.54 -            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN n o domain_type)
    1.55 +            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
    1.56                ns mss fp_iter_fun_Ts;
    1.57            val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
    1.58  
    1.59 @@ -225,8 +247,8 @@
    1.60              ||>> mk_Freesss "x" y_Tsss;
    1.61  
    1.62            val z_Tssss =
    1.63 -            map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n
    1.64 -              o domain_type) ns mss fp_rec_fun_Ts;
    1.65 +            map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o
    1.66 +              dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
    1.67            val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
    1.68  
    1.69            val hss = map2 (map2 retype_free) gss h_Tss;
    1.70 @@ -251,7 +273,7 @@
    1.71            fun mk_types fun_Ts =
    1.72              let
    1.73                val f_sum_prod_Ts = map range_type fun_Ts;
    1.74 -              val f_prod_Tss = map2 dest_sumTN ns f_sum_prod_Ts;
    1.75 +              val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
    1.76                val f_Tsss =
    1.77                  map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss' f_prod_Tss;
    1.78                val pf_Tss = map2 zip_preds_getters p_Tss f_Tsss
    1.79 @@ -288,6 +310,7 @@
    1.80        let
    1.81          val unfT = domain_type (fastype_of fld);
    1.82          val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
    1.83 +        val ctr_sum_prod_T = mk_sumTN_balanced ctr_prod_Ts;
    1.84          val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
    1.85  
    1.86          val ((((u, v), fs), xss), _) =
    1.87 @@ -299,12 +322,15 @@
    1.88  
    1.89          val ctr_rhss =
    1.90            map2 (fn k => fn xs =>
    1.91 -            fold_rev Term.lambda xs (fld $ mk_InN ctr_prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
    1.92 +              fold_rev Term.lambda xs (fld $ mk_InN_balanced no_defs_lthy ctr_sum_prod_T ctr_prod_Ts
    1.93 +                (HOLogic.mk_tuple xs) k))
    1.94 +            ks xss;
    1.95  
    1.96          val case_binder = Binding.suffix_name ("_" ^ caseN) b;
    1.97  
    1.98          val case_rhs =
    1.99 -          fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (map2 mk_uncurried_fun fs xss) $ (unf $ v));
   1.100 +          fold_rev Term.lambda (fs @ [v])
   1.101 +            (mk_sum_caseN_balanced (map2 mk_uncurried_fun fs xss) $ (unf $ v));
   1.102  
   1.103          val ((raw_case :: raw_ctrs, raw_case_def :: raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
   1.104            |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
   1.105 @@ -340,7 +366,8 @@
   1.106  
   1.107              val sumEN_thm' =
   1.108                Local_Defs.unfold lthy @{thms all_unit_eq}
   1.109 -                (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) [] (mk_sumEN n))
   1.110 +                (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) []
   1.111 +                   (mk_sumEN_balanced n))
   1.112                |> Morphism.thm phi;
   1.113            in
   1.114              mk_exhaust_tac ctxt n ctr_defs fld_iff_unf_thm sumEN_thm'
   1.115 @@ -373,7 +400,7 @@
   1.116                  val spec =
   1.117                    mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binder, res_T)),
   1.118                      Term.list_comb (fp_iter_like,
   1.119 -                      map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) fss xssss));
   1.120 +                      map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss));
   1.121                in (binder, spec) end;
   1.122  
   1.123              val iter_likes =
   1.124 @@ -411,7 +438,8 @@
   1.125  
   1.126                  fun mk_preds_getters_join c n cps sum_prod_T prod_Ts cfss =
   1.127                    Term.lambda c (mk_IfN sum_prod_T cps
   1.128 -                    (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cfss) (1 upto n)));
   1.129 +                    (map2 (mk_InN_balanced no_defs_lthy sum_prod_T prod_Ts)
   1.130 +                      (map HOLogic.mk_tuple cfss) (1 upto n)));
   1.131  
   1.132                  val spec =
   1.133                    mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),