use balanced sums for constructors (to gracefully handle 100 constructors or more)
authorblanchet
Mon, 10 Sep 2012 17:35:53 +0200
changeset 502702ecc533d6697
parent 50269 edc322ac5279
child 50271 df98aeb80a19
use balanced sums for constructors (to gracefully handle 100 constructors or more)
src/HOL/Codatatype/BNF_Library.thy
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_util.ML
src/HOL/Codatatype/Tools/bnf_gfp.ML
src/HOL/Codatatype/Tools/bnf_util.ML
     1.1 --- a/src/HOL/Codatatype/BNF_Library.thy	Mon Sep 10 17:32:39 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/BNF_Library.thy	Mon Sep 10 17:35:53 2012 +0200
     1.3 @@ -19,7 +19,7 @@
     1.4  
     1.5  lemma all_unit_eq: "(\<And>x. PROP P x) \<equiv> PROP P ()" by simp
     1.6  
     1.7 -lemma all_prod_eq: "(\<And>x. PROP P x) \<equiv> (\<And>a b. PROP P (a, b))" by auto
     1.8 +lemma all_prod_eq: "(\<And>x. PROP P x) \<equiv> (\<And>a b. PROP P (a, b))" by clarsimp
     1.9  
    1.10  lemma False_imp_eq: "(False \<Longrightarrow> P) \<equiv> Trueprop True"
    1.11  by presburger
    1.12 @@ -89,23 +89,23 @@
    1.13  by (rule ext) (auto simp add: collect_def)
    1.14  
    1.15  lemma conj_subset_def: "A \<subseteq> {x. P x \<and> Q x} = (A \<subseteq> {x. P x} \<and> A \<subseteq> {x. Q x})"
    1.16 -by auto
    1.17 +by blast
    1.18  
    1.19  lemma subset_emptyI: "(\<And>x. x \<in> A \<Longrightarrow> False) \<Longrightarrow> A \<subseteq> {}"
    1.20 -by auto
    1.21 +by blast
    1.22  
    1.23  lemma rev_bspec: "a \<in> A \<Longrightarrow> \<forall>z \<in> A. P z \<Longrightarrow> P a"
    1.24  by simp
    1.25  
    1.26  lemma Un_cong: "\<lbrakk>A = B; C = D\<rbrakk> \<Longrightarrow> A \<union> C = B \<union> D"
    1.27 -by auto
    1.28 +by simp
    1.29  
    1.30  lemma UN_image_subset: "\<Union>f ` g x \<subseteq> X = (g x \<subseteq> {x. f x \<subseteq> X})"
    1.31 -by auto
    1.32 +by blast
    1.33  
    1.34  lemma image_Collect_subsetI:
    1.35    "(\<And>x. P x \<Longrightarrow> f x \<in> B) \<Longrightarrow> f ` {x. P x} \<subseteq> B"
    1.36 -by auto
    1.37 +by blast
    1.38  
    1.39  lemma comp_set_bd_Union_o_collect: "|\<Union>\<Union>(\<lambda>f. f x) ` X| \<le>o hbd \<Longrightarrow> |(Union \<circ> collect X) x| \<le>o hbd"
    1.40  by (unfold o_apply collect_def SUP_def)
    1.41 @@ -826,17 +826,21 @@
    1.42    "sum_case f (sum_case f' g') (Inr p) = sum_case f' g' p"
    1.43  by auto
    1.44  
    1.45 +lemma one_pointE: "\<lbrakk>\<And>x. s = x \<Longrightarrow> P\<rbrakk> \<Longrightarrow> P"
    1.46 +by simp
    1.47 +
    1.48 +lemma obj_sumE_f:
    1.49 +"\<lbrakk>\<forall>x. s = f (Inl x) \<longrightarrow> P; \<forall>x. s = f (Inr x) \<longrightarrow> P\<rbrakk> \<Longrightarrow> \<forall>x. s = f x \<longrightarrow> P"
    1.50 +by (metis sum.exhaust)
    1.51 +
    1.52  lemma obj_sumE: "\<lbrakk>\<forall>x. s = Inl x \<longrightarrow> P; \<forall>x. s = Inr x \<longrightarrow> P\<rbrakk> \<Longrightarrow> P"
    1.53  by (cases s) auto
    1.54  
    1.55 -lemma obj_sum_base: "\<lbrakk>\<And>x. s = x \<Longrightarrow> P\<rbrakk> \<Longrightarrow> P"
    1.56 -by auto
    1.57 -
    1.58  lemma obj_sum_step:
    1.59    "\<lbrakk>\<forall>x. s = f (Inr (Inl x)) \<longrightarrow> P; \<forall>x. s = f (Inr (Inr x)) \<longrightarrow> P\<rbrakk> \<Longrightarrow> \<forall>x. s = f (Inr x) \<longrightarrow> P"
    1.60  by (metis obj_sumE)
    1.61  
    1.62  lemma not_arg_cong_Inr: "x \<noteq> y \<Longrightarrow> Inr x \<noteq> Inr y"
    1.63 -by auto
    1.64 +by simp
    1.65  
    1.66  end
     2.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Mon Sep 10 17:32:39 2012 +0200
     2.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Mon Sep 10 17:35:53 2012 +0200
     2.3 @@ -50,11 +50,33 @@
     2.4  fun mk_uncurried2_fun f xss =
     2.5    mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
     2.6  
     2.7 +val mk_sumTN_balanced = Balanced_Tree.make mk_sumT;
     2.8 +val dest_sumTN_balanced = Balanced_Tree.dest dest_sumT;
     2.9 +
    2.10 +fun mk_InN_balanced ctxt sum_T Ts t k =
    2.11 +  let
    2.12 +    val u =
    2.13 +      Balanced_Tree.access {left = mk_Inl dummyT, right = mk_Inr dummyT, init = t} (length Ts) k;
    2.14 +  in singleton (Type_Infer_Context.infer_types ctxt) (Type.constraint sum_T u) end;
    2.15 +
    2.16 +val mk_sum_caseN_balanced = Balanced_Tree.make mk_sum_case;
    2.17 +
    2.18 +fun mk_sumEN_balanced n =
    2.19 +  let
    2.20 +    val thm =
    2.21 +      Balanced_Tree.make (fn (thm1, thm2) => thm1 RSN (1, thm2 RSN (2, @{thm obj_sumE_f})))
    2.22 +        (replicate n asm_rl) OF (replicate n (impI RS allI));
    2.23 +    val f as (_, f_T) =
    2.24 +      Term.add_vars (prop_of thm) []
    2.25 +      |> filter (fn ((s, _), _) => s = "f") |> the_single;
    2.26 +    val inst = [pairself (cterm_of @{theory}) (Var f, Abs (Name.uu, domain_type f_T, Bound 0))];
    2.27 +  in cterm_instantiate inst thm end;
    2.28 +
    2.29  fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v));
    2.30  
    2.31  fun tack z_name (c, v) f =
    2.32    let val z = Free (z_name, mk_sumT (fastype_of v, fastype_of c)) in
    2.33 -    Term.lambda z (mk_sum_case (Term.lambda v v) (Term.lambda c (f $ c)) $ z)
    2.34 +    Term.lambda z (mk_sum_case (Term.lambda v v, Term.lambda c (f $ c)) $ z)
    2.35    end;
    2.36  
    2.37  fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
    2.38 @@ -148,9 +170,9 @@
    2.39        | freeze_fp T = T;
    2.40  
    2.41      val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss;
    2.42 -    val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs;
    2.43 +    val ctr_sum_prod_TsXs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssXs;
    2.44  
    2.45 -    val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
    2.46 +    val eqs = map dest_TFree Xs ~~ ctr_sum_prod_TsXs;
    2.47  
    2.48      val (pre_bnfs, ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects,
    2.49          fp_iter_thms, fp_rec_thms), lthy)) =
    2.50 @@ -215,7 +237,7 @@
    2.51        if lfp then
    2.52          let
    2.53            val y_Tsss =
    2.54 -            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN n o domain_type)
    2.55 +            map3 (fn n => fn ms => map2 dest_tupleT ms o dest_sumTN_balanced n o domain_type)
    2.56                ns mss fp_iter_fun_Ts;
    2.57            val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
    2.58  
    2.59 @@ -225,8 +247,8 @@
    2.60              ||>> mk_Freesss "x" y_Tsss;
    2.61  
    2.62            val z_Tssss =
    2.63 -            map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o dest_sumTN n
    2.64 -              o domain_type) ns mss fp_rec_fun_Ts;
    2.65 +            map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o
    2.66 +              dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
    2.67            val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
    2.68  
    2.69            val hss = map2 (map2 retype_free) gss h_Tss;
    2.70 @@ -251,7 +273,7 @@
    2.71            fun mk_types fun_Ts =
    2.72              let
    2.73                val f_sum_prod_Ts = map range_type fun_Ts;
    2.74 -              val f_prod_Tss = map2 dest_sumTN ns f_sum_prod_Ts;
    2.75 +              val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
    2.76                val f_Tsss =
    2.77                  map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss' f_prod_Tss;
    2.78                val pf_Tss = map2 zip_preds_getters p_Tss f_Tsss
    2.79 @@ -288,6 +310,7 @@
    2.80        let
    2.81          val unfT = domain_type (fastype_of fld);
    2.82          val ctr_prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
    2.83 +        val ctr_sum_prod_T = mk_sumTN_balanced ctr_prod_Ts;
    2.84          val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
    2.85  
    2.86          val ((((u, v), fs), xss), _) =
    2.87 @@ -299,12 +322,15 @@
    2.88  
    2.89          val ctr_rhss =
    2.90            map2 (fn k => fn xs =>
    2.91 -            fold_rev Term.lambda xs (fld $ mk_InN ctr_prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
    2.92 +              fold_rev Term.lambda xs (fld $ mk_InN_balanced no_defs_lthy ctr_sum_prod_T ctr_prod_Ts
    2.93 +                (HOLogic.mk_tuple xs) k))
    2.94 +            ks xss;
    2.95  
    2.96          val case_binder = Binding.suffix_name ("_" ^ caseN) b;
    2.97  
    2.98          val case_rhs =
    2.99 -          fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (map2 mk_uncurried_fun fs xss) $ (unf $ v));
   2.100 +          fold_rev Term.lambda (fs @ [v])
   2.101 +            (mk_sum_caseN_balanced (map2 mk_uncurried_fun fs xss) $ (unf $ v));
   2.102  
   2.103          val ((raw_case :: raw_ctrs, raw_case_def :: raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
   2.104            |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
   2.105 @@ -340,7 +366,8 @@
   2.106  
   2.107              val sumEN_thm' =
   2.108                Local_Defs.unfold lthy @{thms all_unit_eq}
   2.109 -                (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) [] (mk_sumEN n))
   2.110 +                (Drule.instantiate' (map (SOME o certifyT lthy) ctr_prod_Ts) []
   2.111 +                   (mk_sumEN_balanced n))
   2.112                |> Morphism.thm phi;
   2.113            in
   2.114              mk_exhaust_tac ctxt n ctr_defs fld_iff_unf_thm sumEN_thm'
   2.115 @@ -373,7 +400,7 @@
   2.116                  val spec =
   2.117                    mk_Trueprop_eq (lists_bmoc fss (Free (Binding.name_of binder, res_T)),
   2.118                      Term.list_comb (fp_iter_like,
   2.119 -                      map2 (mk_sum_caseN oo map2 mk_uncurried2_fun) fss xssss));
   2.120 +                      map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss));
   2.121                in (binder, spec) end;
   2.122  
   2.123              val iter_likes =
   2.124 @@ -411,7 +438,8 @@
   2.125  
   2.126                  fun mk_preds_getters_join c n cps sum_prod_T prod_Ts cfss =
   2.127                    Term.lambda c (mk_IfN sum_prod_T cps
   2.128 -                    (map2 (mk_InN prod_Ts) (map HOLogic.mk_tuple cfss) (1 upto n)));
   2.129 +                    (map2 (mk_InN_balanced no_defs_lthy sum_prod_T prod_Ts)
   2.130 +                      (map HOLogic.mk_tuple cfss) (1 upto n)));
   2.131  
   2.132                  val spec =
   2.133                    mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
     3.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Mon Sep 10 17:32:39 2012 +0200
     3.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Mon Sep 10 17:35:53 2012 +0200
     3.3 @@ -81,15 +81,18 @@
     3.4    val split_conj_thm: thm -> thm list
     3.5    val split_conj_prems: int -> thm -> thm
     3.6  
     3.7 +  val mk_sumTN: typ list -> typ
     3.8 +
     3.9    val Inl_const: typ -> typ -> term
    3.10    val Inr_const: typ -> typ -> term
    3.11  
    3.12 -  val mk_Inl: term -> typ -> term
    3.13 -  val mk_Inr: term -> typ -> term
    3.14 +  val mk_Inl: typ -> term -> term
    3.15 +  val mk_Inr: typ -> term -> term
    3.16    val mk_InN: typ list -> term -> int -> term
    3.17 -  val mk_sum_case: term -> term -> term
    3.18 +  val mk_sum_case: term * term -> term
    3.19    val mk_sum_caseN: term list -> term
    3.20  
    3.21 +  val dest_sumT: typ -> typ * typ
    3.22    val dest_sumTN: int -> typ -> typ list
    3.23    val dest_tupleT: int -> typ -> typ list
    3.24  
    3.25 @@ -197,18 +200,20 @@
    3.26  val set_inclN = "set_incl"
    3.27  val set_set_inclN = "set_set_incl"
    3.28  
    3.29 +fun mk_sumTN Ts = Library.foldr1 mk_sumT Ts;
    3.30 +
    3.31  fun Inl_const LT RT = Const (@{const_name Inl}, LT --> mk_sumT (LT, RT));
    3.32 -fun mk_Inl t RT = Inl_const (fastype_of t) RT $ t;
    3.33 +fun mk_Inl RT t = Inl_const (fastype_of t) RT $ t;
    3.34  
    3.35  fun Inr_const LT RT = Const (@{const_name Inr}, RT --> mk_sumT (LT, RT));
    3.36 -fun mk_Inr t LT = Inr_const LT (fastype_of t) $ t;
    3.37 +fun mk_Inr LT t = Inr_const LT (fastype_of t) $ t;
    3.38  
    3.39  fun mk_InN [_] t 1 = t
    3.40 -  | mk_InN (_ :: Ts) t 1 = mk_Inl t (mk_sumTN Ts)
    3.41 -  | mk_InN (LT :: Ts) t m = mk_Inr (mk_InN Ts t (m - 1)) LT
    3.42 +  | mk_InN (_ :: Ts) t 1 = mk_Inl (mk_sumTN Ts) t
    3.43 +  | mk_InN (LT :: Ts) t m = mk_Inr LT (mk_InN Ts t (m - 1))
    3.44    | mk_InN Ts t _ = raise (TYPE ("mk_InN", Ts, [t]));
    3.45  
    3.46 -fun mk_sum_case f g =
    3.47 +fun mk_sum_case (f, g) =
    3.48    let
    3.49      val fT = fastype_of f;
    3.50      val gT = fastype_of g;
    3.51 @@ -218,7 +223,9 @@
    3.52    end;
    3.53  
    3.54  fun mk_sum_caseN [f] = f
    3.55 -  | mk_sum_caseN (f :: fs) = mk_sum_case f (mk_sum_caseN fs);
    3.56 +  | mk_sum_caseN (f :: fs) = mk_sum_case (f, mk_sum_caseN fs);
    3.57 +
    3.58 +fun dest_sumT (Type (@{type_name sum}, [T, T'])) = (T, T');
    3.59  
    3.60  fun dest_sumTN 1 T = [T]
    3.61    | dest_sumTN n (Type (@{type_name sum}, [T, T'])) = T :: dest_sumTN (n - 1) T';
    3.62 @@ -247,7 +254,7 @@
    3.63        if i = n then th else split n (i + 1) (conjI RSN (i, th)) handle THM _ => th;
    3.64    in split limit 1 th end;
    3.65  
    3.66 -fun mk_sumEN 1 = @{thm obj_sum_base}
    3.67 +fun mk_sumEN 1 = @{thm one_pointE}
    3.68    | mk_sumEN 2 = @{thm sumE}
    3.69    | mk_sumEN n =
    3.70      (fold (fn i => fn thm => @{thm obj_sum_step} RSN (i, thm)) (2 upto n - 1) @{thm obj_sumE}) OF
     4.1 --- a/src/HOL/Codatatype/Tools/bnf_gfp.ML	Mon Sep 10 17:32:39 2012 +0200
     4.2 +++ b/src/HOL/Codatatype/Tools/bnf_gfp.ML	Mon Sep 10 17:35:53 2012 +0200
     4.3 @@ -506,7 +506,7 @@
     4.4      val mor_sum_case_thm =
     4.5        let
     4.6          val maps = map3 (fn s => fn sum_s => fn map =>
     4.7 -          mk_sum_case (HOLogic.mk_comp (Term.list_comb (map, passive_ids @ Inls), s)) sum_s)
     4.8 +          mk_sum_case (HOLogic.mk_comp (Term.list_comb (map, passive_ids @ Inls), s), sum_s))
     4.9            s's sum_ss map_Inls;
    4.10        in
    4.11          Skip_Proof.prove lthy [] []
    4.12 @@ -2144,7 +2144,7 @@
    4.13        let
    4.14          val corecT = Library.foldr (op -->) (corec_sTs, AT --> T);
    4.15          val maps = map3 (fn unf => fn sum_s => fn map => mk_sum_case
    4.16 -            (HOLogic.mk_comp (Term.list_comb (map, passive_ids @ corec_Inls), unf)) sum_s)
    4.17 +            (HOLogic.mk_comp (Term.list_comb (map, passive_ids @ corec_Inls), unf), sum_s))
    4.18            unfs corec_ss corec_maps;
    4.19  
    4.20          val lhs = Term.list_comb (Free (corec_name i, corecT), corec_ss);
    4.21 @@ -2171,7 +2171,7 @@
    4.22      val corec_defs = map (Morphism.thm phi) corec_def_frees;
    4.23  
    4.24      val sum_cases =
    4.25 -      map2 (fn T => fn i => mk_sum_case (HOLogic.id_const T) (mk_corec corec_ss i)) Ts ks;
    4.26 +      map2 (fn T => fn i => mk_sum_case (HOLogic.id_const T, mk_corec corec_ss i)) Ts ks;
    4.27      val corec_thms =
    4.28        let
    4.29          fun mk_goal i corec_s corec_map unf z =
     5.1 --- a/src/HOL/Codatatype/Tools/bnf_util.ML	Mon Sep 10 17:32:39 2012 +0200
     5.2 +++ b/src/HOL/Codatatype/Tools/bnf_util.ML	Mon Sep 10 17:35:53 2012 +0200
     5.3 @@ -62,7 +62,6 @@
     5.4    val mk_relT: typ * typ -> typ
     5.5    val dest_relT: typ -> typ * typ
     5.6    val mk_sumT: typ * typ -> typ
     5.7 -  val mk_sumTN: typ list -> typ
     5.8  
     5.9    val ctwo: term
    5.10    val fst_const: typ -> term
    5.11 @@ -304,7 +303,6 @@
    5.12  val mk_relT = HOLogic.mk_setT o HOLogic.mk_prodT;
    5.13  val dest_relT = HOLogic.dest_prodT o HOLogic.dest_setT;
    5.14  fun mk_sumT (LT, RT) = Type (@{type_name Sum_Type.sum}, [LT, RT]);
    5.15 -fun mk_sumTN Ts = Library.foldr1 mk_sumT Ts;
    5.16  fun mk_partial_funT (ranT, domT) = domT --> mk_optionT ranT;
    5.17  
    5.18