generate 'rec_o_map' and 'size_o_map' in size extension
authorblanchet
Wed, 23 Apr 2014 10:23:26 +0200
changeset 579820a35354137a5
parent 57981 c9d6b581bd3b
child 57983 029997d3b5d8
generate 'rec_o_map' and 'size_o_map' in size extension
src/HOL/BNF_FP_Base.thy
src/HOL/BNF_LFP.thy
src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
src/HOL/Tools/BNF/bnf_fp_n2m.ML
src/HOL/Tools/BNF/bnf_lfp_size.ML
src/HOL/Tools/BNF/bnf_util.ML
     1.1 --- a/src/HOL/BNF_FP_Base.thy	Wed Apr 23 10:23:26 2014 +0200
     1.2 +++ b/src/HOL/BNF_FP_Base.thy	Wed Apr 23 10:23:26 2014 +0200
     1.3 @@ -89,34 +89,34 @@
     1.4  lemma spec2: "\<forall>x y. P x y \<Longrightarrow> P x y"
     1.5  by blast
     1.6  
     1.7 -lemma rewriteR_comp_comp: "\<lbrakk>g o h = r\<rbrakk> \<Longrightarrow> f o g o h = f o r"
     1.8 +lemma rewriteR_comp_comp: "\<lbrakk>g \<circ> h = r\<rbrakk> \<Longrightarrow> f \<circ> g \<circ> h = f \<circ> r"
     1.9    unfolding comp_def fun_eq_iff by auto
    1.10  
    1.11 -lemma rewriteR_comp_comp2: "\<lbrakk>g o h = r1 o r2; f o r1 = l\<rbrakk> \<Longrightarrow> f o g o h = l o r2"
    1.12 +lemma rewriteR_comp_comp2: "\<lbrakk>g \<circ> h = r1 \<circ> r2; f \<circ> r1 = l\<rbrakk> \<Longrightarrow> f \<circ> g \<circ> h = l \<circ> r2"
    1.13    unfolding comp_def fun_eq_iff by auto
    1.14  
    1.15 -lemma rewriteL_comp_comp: "\<lbrakk>f o g = l\<rbrakk> \<Longrightarrow> f o (g o h) = l o h"
    1.16 +lemma rewriteL_comp_comp: "\<lbrakk>f \<circ> g = l\<rbrakk> \<Longrightarrow> f \<circ> (g \<circ> h) = l \<circ> h"
    1.17    unfolding comp_def fun_eq_iff by auto
    1.18  
    1.19 -lemma rewriteL_comp_comp2: "\<lbrakk>f o g = l1 o l2; l2 o h = r\<rbrakk> \<Longrightarrow> f o (g o h) = l1 o r"
    1.20 +lemma rewriteL_comp_comp2: "\<lbrakk>f \<circ> g = l1 \<circ> l2; l2 \<circ> h = r\<rbrakk> \<Longrightarrow> f \<circ> (g \<circ> h) = l1 \<circ> r"
    1.21    unfolding comp_def fun_eq_iff by auto
    1.22  
    1.23 -lemma convol_o: "<f, g> o h = <f o h, g o h>"
    1.24 +lemma convol_o: "<f, g> \<circ> h = <f \<circ> h, g \<circ> h>"
    1.25    unfolding convol_def by auto
    1.26  
    1.27 -lemma map_prod_o_convol: "map_prod h1 h2 o <f, g> = <h1 o f, h2 o g>"
    1.28 +lemma map_prod_o_convol: "map_prod h1 h2 \<circ> <f, g> = <h1 \<circ> f, h2 \<circ> g>"
    1.29    unfolding convol_def by auto
    1.30  
    1.31  lemma map_prod_o_convol_id: "(map_prod f id \<circ> <id , g>) x = <id \<circ> f , g> x"
    1.32    unfolding map_prod_o_convol id_comp comp_id ..
    1.33  
    1.34 -lemma o_case_sum: "h o case_sum f g = case_sum (h o f) (h o g)"
    1.35 +lemma o_case_sum: "h \<circ> case_sum f g = case_sum (h \<circ> f) (h \<circ> g)"
    1.36    unfolding comp_def by (auto split: sum.splits)
    1.37  
    1.38 -lemma case_sum_o_map_sum: "case_sum f g o map_sum h1 h2 = case_sum (f o h1) (g o h2)"
    1.39 +lemma case_sum_o_map_sum: "case_sum f g \<circ> map_sum h1 h2 = case_sum (f \<circ> h1) (g \<circ> h2)"
    1.40    unfolding comp_def by (auto split: sum.splits)
    1.41  
    1.42 -lemma case_sum_o_map_sum_id: "(case_sum id g o map_sum f id) x = case_sum (f o id) g x"
    1.43 +lemma case_sum_o_map_sum_id: "(case_sum id g \<circ> map_sum f id) x = case_sum (f \<circ> id) g x"
    1.44    unfolding case_sum_o_map_sum id_comp comp_id ..
    1.45  
    1.46  lemma rel_fun_def_butlast:
    1.47 @@ -144,7 +144,7 @@
    1.48  
    1.49  lemma
    1.50    assumes "type_definition Rep Abs UNIV"
    1.51 -  shows type_copy_Rep_o_Abs: "Rep \<circ> Abs = id" and type_copy_Abs_o_Rep: "Abs o Rep = id"
    1.52 +  shows type_copy_Rep_o_Abs: "Rep \<circ> Abs = id" and type_copy_Abs_o_Rep: "Abs \<circ> Rep = id"
    1.53    unfolding fun_eq_iff comp_apply id_apply
    1.54      type_definition.Abs_inverse[OF assms UNIV_I] type_definition.Rep_inverse[OF assms] by simp_all
    1.55  
    1.56 @@ -152,7 +152,7 @@
    1.57    assumes "type_definition Rep Abs UNIV"
    1.58            "type_definition Rep' Abs' UNIV"
    1.59            "type_definition Rep'' Abs'' UNIV"
    1.60 -  shows "Abs' o M o Rep'' = (Abs' o M1 o Rep) o (Abs o M2 o Rep'') \<Longrightarrow> M1 o M2 = M"
    1.61 +  shows "Abs' \<circ> M \<circ> Rep'' = (Abs' \<circ> M1 \<circ> Rep) \<circ> (Abs \<circ> M2 \<circ> Rep'') \<Longrightarrow> M1 \<circ> M2 = M"
    1.62    by (rule sym) (auto simp: fun_eq_iff type_definition.Abs_inject[OF assms(2) UNIV_I UNIV_I]
    1.63      type_definition.Abs_inverse[OF assms(1) UNIV_I]
    1.64      type_definition.Abs_inverse[OF assms(3) UNIV_I] dest: spec[of _ "Abs'' x" for x])
    1.65 @@ -160,7 +160,7 @@
    1.66  lemma vimage2p_id: "vimage2p id id R = R"
    1.67    unfolding vimage2p_def by auto
    1.68  
    1.69 -lemma vimage2p_comp: "vimage2p (f1 o f2) (g1 o g2) = vimage2p f2 g2 o vimage2p f1 g1"
    1.70 +lemma vimage2p_comp: "vimage2p (f1 \<circ> f2) (g1 \<circ> g2) = vimage2p f2 g2 \<circ> vimage2p f1 g1"
    1.71    unfolding fun_eq_iff vimage2p_def o_apply by simp
    1.72  
    1.73  ML_file "Tools/BNF/bnf_fp_util.ML"
     2.1 --- a/src/HOL/BNF_LFP.thy	Wed Apr 23 10:23:26 2014 +0200
     2.2 +++ b/src/HOL/BNF_LFP.thy	Wed Apr 23 10:23:26 2014 +0200
     2.3 @@ -186,12 +186,29 @@
     2.4  lemma ssubst_Pair_rhs: "\<lbrakk>(r, s) \<in> R; s' = s\<rbrakk> \<Longrightarrow> (r, s') \<in> R"
     2.5    by (rule ssubst)
     2.6  
     2.7 +lemma fun_cong_unused_0: "f = (\<lambda>x. g) \<Longrightarrow> f (\<lambda>x. 0) = g"
     2.8 +  by (erule arg_cong)
     2.9 +
    2.10  lemma snd_o_convol: "(snd \<circ> (\<lambda>x. (f x, g x))) = g"
    2.11    by (rule ext) simp
    2.12  
    2.13  lemma inj_on_convol_id: "inj_on (\<lambda>x. (x, f x)) X"
    2.14    unfolding inj_on_def by simp
    2.15  
    2.16 +lemma case_prod_app: "case_prod f x y = case_prod (\<lambda>l r. f l r y) x"
    2.17 +  by (case_tac x) simp
    2.18 +
    2.19 +lemma case_sum_o_map_sum: "case_sum l r (map_sum f g x) = case_sum (l \<circ> f) (r \<circ> g) x"
    2.20 +  by (case_tac x) simp+
    2.21 +
    2.22 +lemma case_prod_o_map_prod: "case_prod h (map_prod f g x) = case_prod (\<lambda>l r. h (f l) (g r)) x"
    2.23 +  by (case_tac x) simp+
    2.24 +
    2.25 +lemma prod_inj_map: "inj f \<Longrightarrow> inj g \<Longrightarrow> inj (map_prod f g)"
    2.26 +  by (simp add: inj_on_def)
    2.27 +
    2.28 +declare [[ML_print_depth = 10000]] (*###*)
    2.29 +
    2.30  ML_file "Tools/BNF/bnf_lfp_util.ML"
    2.31  ML_file "Tools/BNF/bnf_lfp_tactics.ML"
    2.32  ML_file "Tools/BNF/bnf_lfp.ML"
    2.33 @@ -201,4 +218,58 @@
    2.34  
    2.35  hide_fact (open) id_transfer
    2.36  
    2.37 +datatype_new ('a, 'b) j = J0 | J 'a "('a, 'b) j"
    2.38 +thm j.size j.rec_o_map j.size_o_map
    2.39 +
    2.40 +datatype_new 'a l = N nat nat | C 'a "'a l"
    2.41 +thm l.size l.rec_o_map l.size_o_map
    2.42 +
    2.43 +datatype_new ('a, 'b) x = XN 'b | XC 'a "('a, 'b) x"
    2.44 +thm x.size x.rec_o_map x.size_o_map
    2.45 +
    2.46 +datatype_new
    2.47 +  'a tl = TN | TC "'a mt" "'a tl" and
    2.48 +  'a mt = MT 'a "'a tl"
    2.49 +thm tl.size tl.rec_o_map tl.size_o_map
    2.50 +thm mt.size mt.rec_o_map mt.size_o_map
    2.51 +
    2.52 +datatype_new 'a t = T nat 'a "'a t l"
    2.53 +thm t.size t.rec_o_map t.size_o_map
    2.54 +
    2.55 +datatype_new 'a fset = FSet0 | FSet 'a "'a fset"
    2.56 +thm fset.size fset.rec_o_map fset.size_o_map
    2.57 +
    2.58 +datatype_new 'a u = U 'a "'a u fset"
    2.59 +thm u.size u.rec_o_map u.size_o_map
    2.60 +
    2.61 +datatype_new
    2.62 +  ('a, 'b) v = V "nat l" | V' 'a "('a, 'b) w" and
    2.63 +  ('a, 'b) w = W 'b "('a, 'b) v fset l"
    2.64 +thm v.size v.rec_o_map v.size_o_map
    2.65 +thm w.size w.rec_o_map w.size_o_map
    2.66 +
    2.67 +(*TODO:
    2.68 +* deal with *unused* dead variables and other odd cases (e.g. recursion through fun)
    2.69 +* what happens if recursion through arbitrary bnf, like 'fsize'?
    2.70 +  * by default
    2.71 +  * offer possibility to register size function and theorems
    2.72 +* non-recursive types use 'case' instead of 'rec', causes trouble (revert?)
    2.73 +* compat with old size?
    2.74 +  * recursion of old through new (e.g. through list)?
    2.75 +  * recursion of new through old?
    2.76 +  * should they share theory data?
    2.77 +* code generator setup?
    2.78 +*)
    2.79 +
    2.80 +
    2.81  end
    2.82 +datatype_new 'a x = X0 | X 'a (*###*)
    2.83 +thm x.size
    2.84 +thm x.size_o_map
    2.85 +datatype_new 'a x = X0 | X 'a "'a x" (*###*)
    2.86 +thm x.size
    2.87 +thm x.size_o_map
    2.88 +datatype_new 'a l = N | C 'a "'a l"
    2.89 +datatype_new ('a, 'b) tl = TN 'b | TC 'a "'a l"
    2.90 +
    2.91 +end
     3.1 --- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Wed Apr 23 10:23:26 2014 +0200
     3.2 +++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Wed Apr 23 10:23:26 2014 +0200
     3.3 @@ -22,6 +22,7 @@
     3.4       ctr_defs: thm list,
     3.5       ctr_sugar: Ctr_Sugar.ctr_sugar,
     3.6       co_rec: term,
     3.7 +     co_rec_def: thm,
     3.8       maps: thm list,
     3.9       common_co_inducts: thm list,
    3.10       co_inducts: thm list,
    3.11 @@ -137,6 +138,7 @@
    3.12     ctr_defs: thm list,
    3.13     ctr_sugar: Ctr_Sugar.ctr_sugar,
    3.14     co_rec: term,
    3.15 +   co_rec_def: thm,
    3.16     maps: thm list,
    3.17     common_co_inducts: thm list,
    3.18     co_inducts: thm list,
    3.19 @@ -161,6 +163,7 @@
    3.20     ctr_defs = map (Morphism.thm phi) ctr_defs,
    3.21     ctr_sugar = morph_ctr_sugar phi ctr_sugar,
    3.22     co_rec = Morphism.term phi co_rec,
    3.23 +   co_rec_def = Morphism.thm phi co_rec_def,
    3.24     maps = map (Morphism.thm phi) maps,
    3.25     common_co_inducts = map (Morphism.thm phi) common_co_inducts,
    3.26     co_inducts = map (Morphism.thm phi) co_inducts,
     4.1 --- a/src/HOL/Tools/BNF/bnf_fp_n2m.ML	Wed Apr 23 10:23:26 2014 +0200
     4.2 +++ b/src/HOL/Tools/BNF/bnf_fp_n2m.ML	Wed Apr 23 10:23:26 2014 +0200
     4.3 @@ -239,7 +239,7 @@
     4.4      val co_recs = of_fp_res #xtor_co_recs;
     4.5      val ns = map (length o #Ts o #fp_res) fp_sugars;
     4.6  
     4.7 -    fun substT rho (Type (@{type_name "fun"}, [T, U])) = substT rho T --> substT rho U
     4.8 +    fun substT rho (Type (@{type_name fun}, [T, U])) = substT rho T --> substT rho U
     4.9        | substT rho (Type (s, Ts)) = Type (s, map (typ_subst_nonatomic rho) Ts)
    4.10        | substT _ T = T;
    4.11  
     5.1 --- a/src/HOL/Tools/BNF/bnf_lfp_size.ML	Wed Apr 23 10:23:26 2014 +0200
     5.2 +++ b/src/HOL/Tools/BNF/bnf_lfp_size.ML	Wed Apr 23 10:23:26 2014 +0200
     5.3 @@ -9,13 +9,15 @@
     5.4  struct
     5.5  
     5.6  open BNF_Util
     5.7 +open BNF_Tactics
     5.8  open BNF_Def
     5.9  open BNF_FP_Def_Sugar
    5.10  
    5.11  val size_N = "size_"
    5.12  
    5.13 +val rec_o_mapN = "rec_o_map"
    5.14  val sizeN = "size"
    5.15 -val size_mapN = "size_map"
    5.16 +val size_o_mapN = "size_o_map"
    5.17  
    5.18  structure Data = Theory_Data
    5.19  (
    5.20 @@ -34,209 +36,287 @@
    5.21  
    5.22  fun mk_abs_zero_nat T = Term.absdummy T zero_nat;
    5.23  
    5.24 -fun generate_size (fp_sugars as ({T = Type (_, As), BT = Type (_, Bs),
    5.25 -    fp_res = {bnfs = fp_bnfs, ...}, common_co_inducts = common_inducts, ...} : fp_sugar) :: _) thy =
    5.26 -  let
    5.27 -    val data = Data.get thy;
    5.28 +fun pointfill ctxt th = unfold_thms ctxt [o_apply] (th RS fun_cong);
    5.29  
    5.30 -    val Ts = map #T fp_sugars
    5.31 -    val T_names = map (fst o dest_Type) Ts;
    5.32 -    val nn = length Ts;
    5.33 +fun mk_unabs_def_unused_0 n =
    5.34 +  funpow n (fn thm => thm RS @{thm fun_cong_unused_0} handle THM _ => thm RS fun_cong);
    5.35  
    5.36 -    val B_ify = Term.typ_subst_atomic (As ~~ Bs);
    5.37 +val rec_o_map_simp_thms =
    5.38 +  @{thms o_def id_apply case_prod_app case_sum_o_map_sum case_prod_o_map_prod
    5.39 +      BNF_Comp.id_bnf_comp_def};
    5.40  
    5.41 -    val recs = map #co_rec fp_sugars;
    5.42 -    val rec_thmss = map #co_rec_thms fp_sugars;
    5.43 -    val rec_Ts = map fastype_of recs;
    5.44 -    val Cs = map body_type rec_Ts;
    5.45 -    val Cs_rho = map (rpair HOLogic.natT) Cs;
    5.46 -    val substCT = Term.subst_atomic_types Cs_rho;
    5.47 +fun mk_rec_o_map_tac ctxt rec_def pre_map_defs abs_inverses ctor_rec_o_map =
    5.48 +  unfold_thms_tac ctxt [rec_def] THEN
    5.49 +  HEADGOAL (rtac (ctor_rec_o_map RS trans) THEN'
    5.50 +    K (PRIMITIVE (Conv.fconv_rule Thm.eta_long_conversion)) THEN' asm_simp_tac
    5.51 +      (ss_only (pre_map_defs @ distinct Thm.eq_thm_prop abs_inverses @ rec_o_map_simp_thms) ctxt));
    5.52  
    5.53 -    val f_Ts = map mk_to_natT As;
    5.54 -    val f_TsB = map mk_to_natT Bs;
    5.55 -    val num_As = length As;
    5.56 +val size_o_map_simp_thms =
    5.57 +  @{thms o_apply prod_inj_map inj_on_id snd_comp_apfst[unfolded apfst_def]};
    5.58  
    5.59 -    val f_names = map (prefix "f" o string_of_int) (1 upto num_As);
    5.60 -    val fs = map2 (curry Free) f_names f_Ts;
    5.61 -    val fsB = map2 (curry Free) f_names f_TsB;
    5.62 -    val As_fs = As ~~ fs;
    5.63 +fun mk_size_o_map_tac ctxt size_def rec_o_map inj_maps size_maps =
    5.64 +  unfold_thms_tac ctxt [size_def] THEN
    5.65 +  HEADGOAL (rtac (rec_o_map RS trans) THEN'
    5.66 +    asm_simp_tac (ss_only (inj_maps @ size_maps @ size_o_map_simp_thms) ctxt));
    5.67  
    5.68 -    val gen_size_names = map (Long_Name.map_base_name (prefix size_N)) T_names;
    5.69 +fun generate_size (fp_sugars as ({T = Type (_, As), BT = Type (_, Bs), fp = Least_FP,
    5.70 +        fp_res = {bnfs = fp_bnfs, xtor_co_rec_o_map_thms = ctor_rec_o_maps, ...}, nested_bnfs, ...}
    5.71 +      : fp_sugar) :: _) thy =
    5.72 +    let
    5.73 +      val data = Data.get thy;
    5.74  
    5.75 -    fun is_pair_C @{type_name prod} [_, T'] = member (op =) Cs T'
    5.76 -      | is_pair_C _ _ = false;
    5.77 +      val Ts = map #T fp_sugars
    5.78 +      val T_names = map (fst o dest_Type) Ts;
    5.79 +      val nn = length Ts;
    5.80  
    5.81 -    fun mk_size_of_typ (T as TFree _) =
    5.82 -        pair (case AList.lookup (op =) As_fs T of
    5.83 -            SOME f => f
    5.84 -          | NONE => if member (op =) Cs T then Term.absdummy T (Bound 0) else mk_abs_zero_nat T)
    5.85 -      | mk_size_of_typ (T as Type (s, Ts)) =
    5.86 -        if is_pair_C s Ts then
    5.87 -          pair (snd_const T)
    5.88 -        else if exists (exists_subtype_in As) Ts then
    5.89 -          (case Symtab.lookup data s of
    5.90 -            SOME (gen_size_name, (_, gen_size_maps)) =>
    5.91 -            let
    5.92 -              val (args, gen_size_mapss') = split_list (map (fn T => mk_size_of_typ T []) Ts);
    5.93 -              val gen_size_const = Const (gen_size_name, map fastype_of args ---> mk_to_natT T);
    5.94 -            in
    5.95 -              fold (union Thm.eq_thm) (gen_size_maps :: gen_size_mapss')
    5.96 -              #> pair (Term.list_comb (gen_size_const, args))
    5.97 -            end
    5.98 -          | NONE => pair (mk_abs_zero_nat T))
    5.99 +      val B_ify = Term.typ_subst_atomic (As ~~ Bs);
   5.100 +
   5.101 +      val recs = map #co_rec fp_sugars;
   5.102 +      val rec_thmss = map #co_rec_thms fp_sugars;
   5.103 +      val rec_Ts as rec_T1 :: _ = map fastype_of recs;
   5.104 +      val rec_arg_Ts = binder_fun_types rec_T1;
   5.105 +      val Cs = map body_type rec_Ts;
   5.106 +      val Cs_rho = map (rpair HOLogic.natT) Cs;
   5.107 +      val substCnatT = Term.subst_atomic_types Cs_rho;
   5.108 +
   5.109 +      val f_Ts = map mk_to_natT As;
   5.110 +      val f_TsB = map mk_to_natT Bs;
   5.111 +      val num_As = length As;
   5.112 +
   5.113 +      val f_names = map (prefix "f" o string_of_int) (1 upto num_As);
   5.114 +      val fs = map2 (curry Free) f_names f_Ts;
   5.115 +      val fsB = map2 (curry Free) f_names f_TsB;
   5.116 +      val As_fs = As ~~ fs;
   5.117 +
   5.118 +      val size_names = map (Long_Name.map_base_name (prefix size_N)) T_names;
   5.119 +
   5.120 +      fun is_pair_C @{type_name prod} [_, T'] = member (op =) Cs T'
   5.121 +        | is_pair_C _ _ = false;
   5.122 +
   5.123 +      fun mk_size_of_typ (T as TFree _) =
   5.124 +          pair (case AList.lookup (op =) As_fs T of
   5.125 +              SOME f => f
   5.126 +            | NONE => if member (op =) Cs T then Term.absdummy T (Bound 0) else mk_abs_zero_nat T)
   5.127 +        | mk_size_of_typ (T as Type (s, Ts)) =
   5.128 +          if is_pair_C s Ts then
   5.129 +            pair (snd_const T)
   5.130 +          else if exists (exists_subtype_in As) Ts then
   5.131 +            (case Symtab.lookup data s of
   5.132 +              SOME (size_name, (_, size_o_maps)) =>
   5.133 +              let
   5.134 +                val (args, size_o_mapss') = split_list (map (fn T => mk_size_of_typ T []) Ts);
   5.135 +                val size_const = Const (size_name, map fastype_of args ---> mk_to_natT T);
   5.136 +              in
   5.137 +                fold (union Thm.eq_thm) (size_o_maps :: size_o_mapss')
   5.138 +                #> pair (Term.list_comb (size_const, args))
   5.139 +              end
   5.140 +            | NONE => pair (mk_abs_zero_nat T))
   5.141 +          else
   5.142 +            pair (mk_abs_zero_nat T);
   5.143 +
   5.144 +      fun mk_size_of_arg t =
   5.145 +        mk_size_of_typ (fastype_of t) #>> (fn s => substCnatT (betapply (s, t)));
   5.146 +
   5.147 +      fun mk_size_arg rec_arg_T size_o_maps =
   5.148 +        let
   5.149 +          val x_Ts = binder_types rec_arg_T;
   5.150 +          val m = length x_Ts;
   5.151 +          val x_names = map (prefix "x" o string_of_int) (1 upto m);
   5.152 +          val xs = map2 (curry Free) x_names x_Ts;
   5.153 +          val (summands, size_o_maps') =
   5.154 +            fold_map mk_size_of_arg xs size_o_maps
   5.155 +            |>> remove (op =) zero_nat;
   5.156 +          val sum =
   5.157 +            if null summands then HOLogic.zero
   5.158 +            else foldl1 mk_plus_nat (summands @ [HOLogic.Suc_zero]);
   5.159 +        in
   5.160 +          (fold_rev Term.lambda (map substCnatT xs) sum, size_o_maps')
   5.161 +        end;
   5.162 +
   5.163 +      fun mk_size_rhs recx size_o_maps =
   5.164 +        let val (args, size_o_maps') = fold_map mk_size_arg rec_arg_Ts size_o_maps in
   5.165 +          (fold_rev Term.lambda fs (Term.list_comb (substCnatT recx, args)), size_o_maps')
   5.166 +        end;
   5.167 +
   5.168 +      fun mk_def_binding f =
   5.169 +        Binding.conceal o Binding.name o Thm.def_name o f o Long_Name.base_name;
   5.170 +
   5.171 +      val (size_rhss, nested_size_o_maps) = fold_map mk_size_rhs recs [];
   5.172 +      val size_Ts = map fastype_of size_rhss;
   5.173 +      val size_consts = map2 (curry Const) size_names size_Ts;
   5.174 +      val size_constsB = map (Term.map_types B_ify) size_consts;
   5.175 +      val size_def_bs = map (mk_def_binding I) size_names;
   5.176 +
   5.177 +      val (size_defs, thy2) =
   5.178 +        thy
   5.179 +        |> Sign.add_consts (map (fn (s, T) => (Binding.name (Long_Name.base_name s), T, NoSyn))
   5.180 +          (size_names ~~ size_Ts))
   5.181 +        |> Global_Theory.add_defs false (map Thm.no_attributes (size_def_bs ~~
   5.182 +          map Logic.mk_equals (size_consts ~~ size_rhss)));
   5.183 +
   5.184 +      val zeros = map mk_abs_zero_nat As;
   5.185 +
   5.186 +      val overloaded_size_rhss = map (fn c => Term.list_comb (c, zeros)) size_consts;
   5.187 +      val overloaded_size_Ts = map fastype_of overloaded_size_rhss;
   5.188 +      val overloaded_size_consts = map (curry Const @{const_name size}) overloaded_size_Ts;
   5.189 +      val overloaded_size_def_bs = map (mk_def_binding (suffix "_overloaded")) size_names;
   5.190 +
   5.191 +      fun define_overloaded_size def_b lhs0 rhs lthy =
   5.192 +        let
   5.193 +          val Free (c, _) = Syntax.check_term lthy lhs0;
   5.194 +          val (thm, lthy') = lthy
   5.195 +            |> Local_Theory.define ((Binding.name c, NoSyn), ((def_b, []), rhs))
   5.196 +            |-> (fn (t, (_, thm)) => Spec_Rules.add Spec_Rules.Equational ([t], [thm]) #> pair thm);
   5.197 +          val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy');
   5.198 +          val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm;
   5.199 +        in (thm', lthy') end;
   5.200 +
   5.201 +      val (overloaded_size_defs, thy3) = thy2
   5.202 +        |> Class.instantiation (T_names, map dest_TFree As, [HOLogic.class_size])
   5.203 +        |> fold_map3 define_overloaded_size overloaded_size_def_bs overloaded_size_consts
   5.204 +          overloaded_size_rhss
   5.205 +        ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
   5.206 +        ||> Local_Theory.exit_global;
   5.207 +
   5.208 +      val thy3_ctxt = Proof_Context.init_global thy3;
   5.209 +
   5.210 +      val size_defs' =
   5.211 +        map (mk_unabs_def (num_As + 1) o (fn thm => thm RS meta_eq_to_obj_eq)) size_defs;
   5.212 +      val size_defs_unused_0 =
   5.213 +        map (mk_unabs_def_unused_0 (num_As + 1) o (fn thm => thm RS meta_eq_to_obj_eq)) size_defs;
   5.214 +      val overloaded_size_defs' =
   5.215 +        map (mk_unabs_def 1 o (fn thm => thm RS meta_eq_to_obj_eq)) overloaded_size_defs;
   5.216 +
   5.217 +      val nested_size_maps = map (pointfill thy3_ctxt) nested_size_o_maps @ nested_size_o_maps;
   5.218 +      val all_inj_maps = map inj_map_of_bnf (fp_bnfs @ nested_bnfs);
   5.219 +
   5.220 +      fun derive_size_simp size_def' simp0 =
   5.221 +        (trans OF [size_def', simp0])
   5.222 +        |> Simplifier.asm_full_simplify (ss_only (@{thms inj_on_convol_id snd_o_convol} @
   5.223 +          all_inj_maps @ nested_size_maps) thy3_ctxt)
   5.224 +        |> fold_thms thy3_ctxt size_defs_unused_0;
   5.225 +      fun derive_overloaded_size_simp size_def' simp0 =
   5.226 +        (trans OF [size_def', simp0])
   5.227 +        |> unfold_thms thy3_ctxt @{thms add_0_left add_0_right}
   5.228 +        |> fold_thms thy3_ctxt overloaded_size_defs';
   5.229 +
   5.230 +      val size_simpss = map2 (map o derive_size_simp) size_defs' rec_thmss;
   5.231 +      val size_simps = flat size_simpss;
   5.232 +      val overloaded_size_simpss =
   5.233 +        map2 (map o derive_overloaded_size_simp) overloaded_size_defs' size_simpss;
   5.234 +
   5.235 +      val ABs = As ~~ Bs;
   5.236 +      val g_names = map (prefix "g" o string_of_int) (1 upto num_As);
   5.237 +      val gs = map2 (curry Free) g_names (map (op -->) ABs);
   5.238 +
   5.239 +      val liveness = map (op <>) ABs;
   5.240 +      val live_gs = AList.find (op =) (gs ~~ liveness) true;
   5.241 +      val live = length live_gs;
   5.242 +
   5.243 +      val maps0 = map map_of_bnf fp_bnfs;
   5.244 +
   5.245 +      val (rec_o_map_thmss, size_o_map_thmss) =
   5.246 +        if live = 0 then
   5.247 +          `I (replicate nn [])
   5.248          else
   5.249 -          pair (mk_abs_zero_nat T);
   5.250 +          let
   5.251 +            val pre_bnfs = map #pre_bnf fp_sugars;
   5.252 +            val pre_map_defs = map map_def_of_bnf pre_bnfs;
   5.253 +            val abs_inverses = map (#abs_inverse o #absT_info) fp_sugars;
   5.254 +            val rec_defs = map #co_rec_def fp_sugars;
   5.255  
   5.256 -    fun mk_size_of_arg t =
   5.257 -      mk_size_of_typ (fastype_of t) #>> (fn s => substCT (betapply (s, t)));
   5.258 +            val gmaps = map (fn map0 => Term.list_comb (mk_map live As Bs map0, live_gs)) maps0;
   5.259  
   5.260 -    fun mk_gen_size_arg arg_T gen_size_maps =
   5.261 -      let
   5.262 -        val x_Ts = binder_types arg_T;
   5.263 -        val m = length x_Ts;
   5.264 -        val x_names = map (prefix "x" o string_of_int) (1 upto m);
   5.265 -        val xs = map2 (curry Free) x_names x_Ts;
   5.266 -        val (summands, gen_size_maps') =
   5.267 -          fold_map mk_size_of_arg xs gen_size_maps
   5.268 -          |>> remove (op =) zero_nat;
   5.269 -        val sum =
   5.270 -          if null summands then HOLogic.zero
   5.271 -          else foldl1 mk_plus_nat (summands @ [HOLogic.Suc_zero]);
   5.272 -      in
   5.273 -        (fold_rev Term.lambda (map substCT xs) sum, gen_size_maps')
   5.274 -      end;
   5.275 +            val num_rec_args = length rec_arg_Ts;
   5.276 +            val h_Ts = map B_ify rec_arg_Ts;
   5.277 +            val h_names = map (prefix "h" o string_of_int) (1 upto num_rec_args);
   5.278 +            val hs = map2 (curry Free) h_names h_Ts;
   5.279 +            val hrecs = map (fn recx => Term.list_comb (Term.map_types B_ify recx, hs)) recs;
   5.280  
   5.281 -    fun mk_gen_size_rhs rec_T recx gen_size_maps =
   5.282 -      let
   5.283 -        val arg_Ts = binder_fun_types rec_T;
   5.284 -        val (args, gen_size_maps') = fold_map mk_gen_size_arg arg_Ts gen_size_maps;
   5.285 -      in
   5.286 -        (fold_rev Term.lambda fs (Term.list_comb (substCT recx, args)), gen_size_maps')
   5.287 -      end;
   5.288 +            val rec_o_map_lhss = map2 (curry HOLogic.mk_comp) hrecs gmaps;
   5.289  
   5.290 -    fun mk_def_binding f = Binding.conceal o Binding.name o Thm.def_name o f o Long_Name.base_name;
   5.291 +            val ABgs = ABs ~~ gs;
   5.292  
   5.293 -    val (gen_size_rhss, nested_gen_size_maps) = fold_map2 mk_gen_size_rhs rec_Ts recs [];
   5.294 -    val gen_size_Ts = map fastype_of gen_size_rhss;
   5.295 -    val gen_size_consts = map2 (curry Const) gen_size_names gen_size_Ts;
   5.296 -    val gen_size_constsB = map (Term.map_types B_ify) gen_size_consts;
   5.297 -    val gen_size_def_bs = map (mk_def_binding I) gen_size_names;
   5.298 +            fun mk_rec_arg_arg (x as Free (_, T)) =
   5.299 +              let val U = B_ify T in
   5.300 +                if T = U then x else build_map thy3_ctxt (the o AList.lookup (op =) ABgs) (T, U) $ x
   5.301 +              end;
   5.302  
   5.303 -    val (gen_size_defs, thy2) =
   5.304 -      thy
   5.305 -      |> Sign.add_consts (map (fn (s, T) => (Binding.name (Long_Name.base_name s), T, NoSyn))
   5.306 -        (gen_size_names ~~ gen_size_Ts))
   5.307 -      |> Global_Theory.add_defs false (map Thm.no_attributes (gen_size_def_bs ~~
   5.308 -        map Logic.mk_equals (gen_size_consts ~~ gen_size_rhss)));
   5.309 +            fun mk_rec_o_map_arg rec_arg_T h =
   5.310 +              let
   5.311 +                val x_Ts = binder_types rec_arg_T;
   5.312 +                val m = length x_Ts;
   5.313 +                val x_names = map (prefix "x" o string_of_int) (1 upto m);
   5.314 +                val xs = map2 (curry Free) x_names x_Ts;
   5.315 +                val xs' = map mk_rec_arg_arg xs;
   5.316 +              in
   5.317 +                fold_rev Term.lambda xs (Term.list_comb (h, xs'))
   5.318 +              end;
   5.319  
   5.320 -    val zeros = map mk_abs_zero_nat As;
   5.321 +            fun mk_rec_o_map_rhs recx =
   5.322 +              let val args = map2 mk_rec_o_map_arg rec_arg_Ts hs in
   5.323 +                Term.list_comb (recx, args)
   5.324 +              end;
   5.325  
   5.326 -    val spec_size_rhss = map (fn c => Term.list_comb (c, zeros)) gen_size_consts;
   5.327 -    val spec_size_Ts = map fastype_of spec_size_rhss;
   5.328 -    val spec_size_consts = map (curry Const @{const_name size}) spec_size_Ts;
   5.329 -    val spec_size_def_bs = map (mk_def_binding (suffix "_overloaded")) gen_size_names;
   5.330 +            val rec_o_map_rhss = map mk_rec_o_map_rhs recs;
   5.331  
   5.332 -    fun define_spec_size def_b lhs0 rhs lthy =
   5.333 -      let
   5.334 -        val Free (c, _) = Syntax.check_term lthy lhs0;
   5.335 -        val (thm, lthy') = lthy
   5.336 -          |> Local_Theory.define ((Binding.name c, NoSyn), ((def_b, []), rhs))
   5.337 -          |-> (fn (t, (_, thm)) => Spec_Rules.add Spec_Rules.Equational ([t], [thm]) #> pair thm);
   5.338 -        val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy');
   5.339 -        val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm;
   5.340 -      in (thm', lthy') end;
   5.341 +            val rec_o_map_goals =
   5.342 +              map2 (HOLogic.mk_Trueprop oo curry HOLogic.mk_eq) rec_o_map_lhss rec_o_map_rhss;
   5.343 +            val rec_o_map_thms =
   5.344 +              map3 (fn goal => fn rec_def => fn ctor_rec_o_map =>
   5.345 +                  Goal.prove_global thy3 [] [] goal (fn {context = ctxt, ...} =>
   5.346 +                    mk_rec_o_map_tac ctxt rec_def pre_map_defs abs_inverses ctor_rec_o_map)
   5.347 +                  |> Thm.close_derivation)
   5.348 +                rec_o_map_goals rec_defs ctor_rec_o_maps;
   5.349  
   5.350 -    val (spec_size_defs, thy3) = thy2
   5.351 -      |> Class.instantiation (T_names, map dest_TFree As, [HOLogic.class_size])
   5.352 -      |> fold_map3 define_spec_size spec_size_def_bs spec_size_consts spec_size_rhss
   5.353 -      ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
   5.354 -      ||> Local_Theory.exit_global;
   5.355 +            val size_o_map_conds =
   5.356 +              if exists (can Logic.dest_implies o Thm.prop_of) nested_size_o_maps then
   5.357 +                map (HOLogic.mk_Trueprop o mk_inj) live_gs
   5.358 +              else
   5.359 +                [];
   5.360  
   5.361 -    val thy3_ctxt = Proof_Context.init_global thy3;
   5.362 +            val fsizes = map (fn size_constB => Term.list_comb (size_constB, fsB)) size_constsB;
   5.363 +            val size_o_map_lhss = map2 (curry HOLogic.mk_comp) fsizes gmaps;
   5.364  
   5.365 -    val gen_size_defs' =
   5.366 -      map (mk_unabs_def (num_As + 1) o (fn thm => thm RS meta_eq_to_obj_eq)) gen_size_defs;
   5.367 -    val spec_size_defs' =
   5.368 -      map (mk_unabs_def 1 o (fn thm => thm RS meta_eq_to_obj_eq)) spec_size_defs;
   5.369 +            val fgs = map2 (fn fB => fn g as Free (_, Type (_, [A, B])) =>
   5.370 +              if A = B then fB else HOLogic.mk_comp (fB, g)) fsB gs;
   5.371 +            val size_o_map_rhss = map (fn c => Term.list_comb (c, fgs)) size_consts;
   5.372  
   5.373 -    fun derive_size_simp unfolds folds size_def' simp0 =
   5.374 -      fold_thms thy3_ctxt folds (unfold_thms thy3_ctxt unfolds (trans OF [size_def', simp0]));
   5.375 -    val derive_gen_size_simp =
   5.376 -      derive_size_simp (@{thm snd_o_convol} :: nested_gen_size_maps) gen_size_defs';
   5.377 -    val derive_spec_size_simp = derive_size_simp @{thms add_0_left add_0_right} spec_size_defs';
   5.378 +            val size_o_map_goals =
   5.379 +              map2 (curry Logic.list_implies size_o_map_conds o HOLogic.mk_Trueprop oo
   5.380 +                curry HOLogic.mk_eq) size_o_map_lhss size_o_map_rhss;
   5.381 +            val size_o_map_thms =
   5.382 +              map3 (fn goal => fn size_def => fn rec_o_map =>
   5.383 +                  Goal.prove_global thy3 [] [] goal (fn {context = ctxt, ...} =>
   5.384 +                    mk_size_o_map_tac ctxt size_def rec_o_map all_inj_maps nested_size_maps)
   5.385 +                  |> Thm.close_derivation)
   5.386 +                size_o_map_goals size_defs rec_o_map_thms;
   5.387 +          in
   5.388 +            pairself (map single) (rec_o_map_thms, size_o_map_thms)
   5.389 +          end;
   5.390  
   5.391 -    val gen_size_simpss = map2 (map o derive_gen_size_simp) gen_size_defs' rec_thmss;
   5.392 -    val gen_size_simps = flat gen_size_simpss;
   5.393 -    val spec_size_simpss = map2 (map o derive_spec_size_simp) spec_size_defs' gen_size_simpss;
   5.394 +      val (_, thy4) = thy3
   5.395 +        |> fold_map4 (fn T_name => fn size_simps => fn rec_o_map_thms => fn size_o_map_thms =>
   5.396 +            let val qualify = Binding.qualify true (Long_Name.base_name T_name) in
   5.397 +              Global_Theory.note_thmss ""
   5.398 +                ([((qualify (Binding.name rec_o_mapN), []), [(rec_o_map_thms, [])]),
   5.399 +                  ((qualify (Binding.name sizeN),
   5.400 +                     [Simplifier.simp_add, Nitpick_Simps.add, Thm.declaration_attribute
   5.401 +                        (fn thm => Context.mapping (Code.add_default_eqn thm) I)]),
   5.402 +                   [(size_simps, [])]),
   5.403 +                  ((qualify (Binding.name size_o_mapN), []), [(size_o_map_thms, [])])]
   5.404 +                 |> filter_out (forall (null o fst) o snd))
   5.405 +            end)
   5.406 +          T_names (map2 append size_simpss overloaded_size_simpss) rec_o_map_thmss size_o_map_thmss
   5.407 +        ||> Spec_Rules.add_global Spec_Rules.Equational (size_consts, size_simps);
   5.408 +    in
   5.409 +      thy4
   5.410 +      |> Data.map (fold2 (fn T_name => fn size_name =>
   5.411 +          Symtab.update_new (T_name, (size_name, (size_simps, flat size_o_map_thmss))))
   5.412 +        T_names size_names)
   5.413 +    end
   5.414 +  | generate_size _ thy = thy;
   5.415  
   5.416 -    val ABs = As ~~ Bs;
   5.417 -    val g_names = map (prefix "g" o string_of_int) (1 upto num_As);
   5.418 -    val gs = map2 (curry Free) g_names (map (op -->) ABs);
   5.419 -
   5.420 -    val liveness = map (op <>) ABs;
   5.421 -    val live_gs = AList.find (op =) (gs ~~ liveness) true;
   5.422 -    val live = length live_gs;
   5.423 -
   5.424 -    val u_names = map (prefix "u" o string_of_int) (1 upto nn);
   5.425 -    val us = map2 (curry Free) u_names Ts;
   5.426 -
   5.427 -    val maps0 = map map_of_bnf fp_bnfs;
   5.428 -    val map_thms = maps #maps fp_sugars;
   5.429 -
   5.430 -    fun mk_gen_size_map_tac ctxt =
   5.431 -      HEADGOAL (rtac (co_induct_of common_inducts)) THEN
   5.432 -      ALLGOALS (asm_simp_tac (ss_only (o_apply :: map_thms @ gen_size_simps) ctxt));
   5.433 -
   5.434 -    val gen_size_map_thmss =
   5.435 -      if live = 0 then
   5.436 -        replicate nn []
   5.437 -      else if null nested_gen_size_maps then
   5.438 -        let
   5.439 -          val xgmaps =
   5.440 -            map2 (fn map0 => fn u => Term.list_comb (mk_map live As Bs map0, live_gs) $ u) maps0 us;
   5.441 -          val fsizes =
   5.442 -            map (fn gen_size_constB => Term.list_comb (gen_size_constB, fsB)) gen_size_constsB;
   5.443 -          val lhss = map2 (curry (op $)) fsizes xgmaps;
   5.444 -
   5.445 -          val fgs = map2 (fn fB => fn g as Free (_, Type (_, [A, B])) =>
   5.446 -            if A = B then fB else HOLogic.mk_comp (fB, g)) fsB gs;
   5.447 -          val rhss = map2 (fn gen_size_const => fn u => Term.list_comb (gen_size_const, fgs) $ u)
   5.448 -            gen_size_consts us;
   5.449 -
   5.450 -          val goal = Library.foldr1 HOLogic.mk_conj (map2 (curry HOLogic.mk_eq) lhss rhss)
   5.451 -            |> HOLogic.mk_Trueprop;
   5.452 -        in
   5.453 -          Goal.prove_global thy3 [] [] goal (mk_gen_size_map_tac o #context)
   5.454 -          |> Thm.close_derivation
   5.455 -          |> conj_dests nn
   5.456 -          |> map single
   5.457 -        end
   5.458 -      else
   5.459 -        (* TODO: implement general case, with nesting of datatypes that themselves nest other
   5.460 -           types *)
   5.461 -        replicate nn [];
   5.462 -
   5.463 -    val (_, thy4) = thy3
   5.464 -      |> fold_map3 (fn T_name => fn size_simps => fn gen_size_map_thms =>
   5.465 -          let val qualify = Binding.qualify true (Long_Name.base_name T_name) in
   5.466 -            Global_Theory.note_thmss ""
   5.467 -              ([((qualify (Binding.name sizeN),
   5.468 -                   [Simplifier.simp_add, Nitpick_Simps.add, Thm.declaration_attribute
   5.469 -                      (fn thm => Context.mapping (Code.add_default_eqn thm) I)]),
   5.470 -                 [(size_simps, [])]),
   5.471 -                ((qualify (Binding.name size_mapN), []), [(gen_size_map_thms, [])])]
   5.472 -               |> filter_out (forall (null o fst) o snd))
   5.473 -          end)
   5.474 -        T_names (map2 append gen_size_simpss spec_size_simpss) gen_size_map_thmss
   5.475 -      ||> Spec_Rules.add_global Spec_Rules.Equational (gen_size_consts, gen_size_simps);
   5.476 -  in
   5.477 -    thy4
   5.478 -    |> Data.map (fold2 (fn T_name => fn gen_size_name =>
   5.479 -        Symtab.update_new (T_name, (gen_size_name, (gen_size_simps, flat gen_size_map_thmss))))
   5.480 -      T_names gen_size_names)
   5.481 -  end;
   5.482 -
   5.483 -(* FIXME: get rid of "perhaps o try" once the code is stable *)
   5.484 -val _ = Theory.setup (fp_sugar_interpretation (perhaps o try o generate_size));
   5.485 +val _ = Theory.setup (fp_sugar_interpretation generate_size);
   5.486  
   5.487  end;
     6.1 --- a/src/HOL/Tools/BNF/bnf_util.ML	Wed Apr 23 10:23:26 2014 +0200
     6.2 +++ b/src/HOL/Tools/BNF/bnf_util.ML	Wed Apr 23 10:23:26 2014 +0200
     6.3 @@ -215,7 +215,6 @@
     6.4        map15 f x1s x2s x3s x4s x5s x6s x7s x8s x9s x10s x11s x12s x13s x14s x15s
     6.5    | map15 _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ = raise ListPair.UnequalLengths;
     6.6  
     6.7 -
     6.8  fun fold_map4 _ [] [] [] [] acc = ([], acc)
     6.9    | fold_map4 f (x1::x1s) (x2::x2s) (x3::x3s) (x4::x4s) acc =
    6.10      let