src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 50289 ddd606ec45b9
parent 50288 f839ce127a2e
child 50290 ce87d6a901eb
     1.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 13:06:13 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 13:06:13 2012 +0200
     1.3 @@ -210,12 +210,8 @@
     1.4      val fp_iter_fun_Ts = fst (split_last (binder_types (fastype_of fp_iter1)));
     1.5      val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
     1.6  
     1.7 -    fun dest_rec_pair (T as Type (@{type_name prod}, Us as [_, U])) =
     1.8 -        if member (op =) Cs U then Us else [T]
     1.9 -      | dest_rec_pair T = [T];
    1.10 -
    1.11      val ((iter_only as (gss, _, _), rec_only as (hss, _, _)),
    1.12 -         (zs, cs, cpss, coiter_only as ((pgss, cgsss), _), corec_only as ((phss, chsss), _))) =
    1.13 +         (zs, cs, cpss, coiter_only as ((pgss, cgssss), _), corec_only as ((phss, chssss), _))) =
    1.14        if lfp then
    1.15          let
    1.16            val y_Tsss =
    1.17 @@ -227,18 +223,25 @@
    1.18              lthy
    1.19              |> mk_Freess "f" g_Tss
    1.20              ||>> mk_Freesss "x" y_Tsss;
    1.21 +          val yssss = map (map (map single)) ysss;
    1.22 +
    1.23 +          fun dest_rec_prodT (T as Type (@{type_name prod}, Us as [_, U])) =
    1.24 +              if member (op =) Cs U then Us else [T]
    1.25 +            | dest_rec_prodT T = [T];
    1.26  
    1.27            val z_Tssss =
    1.28 -            map3 (fn n => fn ms => map2 (map dest_rec_pair oo dest_tupleT) ms o
    1.29 +            map3 (fn n => fn ms => map2 (map dest_rec_prodT oo dest_tupleT) ms o
    1.30                dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
    1.31            val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
    1.32  
    1.33            val hss = map2 (map2 retype_free) gss h_Tss;
    1.34 -          val (zssss, _) =
    1.35 +          val zssss_hd = map2 (map2 (map2 (fn y => fn T :: _ => retype_free y T))) ysss z_Tssss;
    1.36 +          val (zssss_tl, _) =
    1.37              lthy
    1.38 -            |> mk_Freessss "x" z_Tssss;
    1.39 +            |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
    1.40 +          val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
    1.41          in
    1.42 -          (((gss, g_Tss, map (map (map single)) ysss), (hss, h_Tss, zssss)),
    1.43 +          (((gss, g_Tss, yssss), (hss, h_Tss, zssss)),
    1.44             ([], [], [], (([], []), ([], [])), (([], []), ([], []))))
    1.45          end
    1.46        else
    1.47 @@ -249,20 +252,23 @@
    1.48            val p_Tss =
    1.49              map2 (fn C => fn n => replicate (Int.max (0, n - 1)) (C --> HOLogic.boolT)) Cs ns;
    1.50  
    1.51 -          fun zip_preds_getters [] [fs] = fs
    1.52 -            | zip_preds_getters (p :: ps) (fs :: fss) = p :: fs @ zip_preds_getters ps fss;
    1.53 +          fun zip_getters fss = flat fss;
    1.54  
    1.55 -          fun mk_types fun_Ts =
    1.56 +          fun zip_preds_getters [] [fss] = zip_getters fss
    1.57 +            | zip_preds_getters (p :: ps) (fss :: fsss) =
    1.58 +              p :: zip_getters fss @ zip_preds_getters ps fsss;
    1.59 +
    1.60 +          fun mk_types maybe_dest_sumT fun_Ts =
    1.61              let
    1.62                val f_sum_prod_Ts = map range_type fun_Ts;
    1.63                val f_prod_Tss = map2 dest_sumTN_balanced ns f_sum_prod_Ts;
    1.64                val f_Tsss =
    1.65                  map3 (fn C => map2 (map (curry (op -->) C) oo dest_tupleT)) Cs mss' f_prod_Tss;
    1.66 -              val pf_Tss = map2 zip_preds_getters p_Tss f_Tsss
    1.67 +              val f_Tssss = map (map (map maybe_dest_sumT)) f_Tsss;
    1.68 +              val pf_Tss = map2 zip_preds_getters p_Tss f_Tssss;
    1.69              in (f_sum_prod_Ts, f_Tsss, pf_Tss) end;
    1.70  
    1.71 -          val (g_sum_prod_Ts, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
    1.72 -          val (h_sum_prod_Ts, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
    1.73 +          val (g_sum_prod_Ts, g_Tsss, pg_Tss) = mk_types single fp_iter_fun_Ts;
    1.74  
    1.75            val ((((Free (z, _), cs), pss), gsss), _) =
    1.76              lthy
    1.77 @@ -270,20 +276,28 @@
    1.78              ||>> mk_Frees "a" Cs
    1.79              ||>> mk_Freess "p" p_Tss
    1.80              ||>> mk_Freesss "g" g_Tsss;
    1.81 +          val gssss = map (map (map single)) gsss;
    1.82 +
    1.83 +          fun dest_corec_sumT (T as Type (@{type_name sum}, Us as [_, U])) =
    1.84 +              if member (op =) Cs U then Us else [T]
    1.85 +            | dest_corec_sumT T = [T];
    1.86 +
    1.87 +          val (h_sum_prod_Ts, h_Tsss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
    1.88  
    1.89            val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss;
    1.90 +          val hssss = map (map (map single)) hsss; (*###*)
    1.91  
    1.92            val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
    1.93  
    1.94 -          fun mk_terms fsss =
    1.95 +          fun mk_terms fssss =
    1.96              let
    1.97 -              val pfss = map2 zip_preds_getters pss fsss;
    1.98 -              val cfsss = map2 (fn c => map (map (fn f => f $ c))) cs fsss
    1.99 -            in (pfss, cfsss) end;
   1.100 +              val pfss = map2 zip_preds_getters pss fssss;
   1.101 +              val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
   1.102 +            in (pfss, cfssss) end;
   1.103          in
   1.104            ((([], [], []), ([], [], [])),
   1.105 -           ([z], cs, cpss, (mk_terms gsss, (g_sum_prod_Ts, pg_Tss)),
   1.106 -            (mk_terms hsss, (h_sum_prod_Ts, ph_Tss))))
   1.107 +           ([z], cs, cpss, (mk_terms gssss, (g_sum_prod_Ts, pg_Tss)),
   1.108 +            (mk_terms hssss, (h_sum_prod_Ts, ph_Tss))))
   1.109          end;
   1.110  
   1.111      fun pour_some_sugar_on_type (((((((((((((((((b, fpT), C), fld), unf), fp_iter), fp_rec),
   1.112 @@ -383,11 +397,11 @@
   1.113                        map2 (mk_sum_caseN_balanced oo map2 mk_uncurried2_fun) fss xssss));
   1.114                in (binder, spec) end;
   1.115  
   1.116 -            val iter_likes =
   1.117 +            val iter_like_bundles =
   1.118                [(iterN, fp_iter, iter_only),
   1.119                 (recN, fp_rec, rec_only)];
   1.120  
   1.121 -            val (binders, specs) = map generate_iter_like iter_likes |> split_list;
   1.122 +            val (binders, specs) = map generate_iter_like iter_like_bundles |> split_list;
   1.123  
   1.124              val ((csts, defs), (lthy', lthy)) = no_defs_lthy
   1.125                |> apfst split_list o fold_map2 (fn b => fn spec =>
   1.126 @@ -410,27 +424,29 @@
   1.127            let
   1.128              val B_to_fpT = C --> fpT;
   1.129  
   1.130 -            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfsss), (f_sum_prod_Ts, pf_Tss))) =
   1.131 +            fun mk_preds_getters_join c n cps sum_prod_T cfsss =
   1.132 +              Term.lambda c (mk_IfN sum_prod_T cps
   1.133 +                (map2 (mk_InN_balanced sum_prod_T n) (map (HOLogic.mk_tuple o flat) cfsss) (*###*)
   1.134 +                   (1 upto n)));
   1.135 +
   1.136 +            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cfssss), (f_sum_prod_Ts,
   1.137 +                pf_Tss))) =
   1.138                let
   1.139                  val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
   1.140  
   1.141                  val binder = Binding.suffix_name ("_" ^ suf) b;
   1.142  
   1.143 -                fun mk_preds_getters_join c n cps sum_prod_T cfss =
   1.144 -                  Term.lambda c (mk_IfN sum_prod_T cps
   1.145 -                    (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cfss) (1 upto n)));
   1.146 -
   1.147                  val spec =
   1.148                    mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
   1.149                      Term.list_comb (fp_iter_like,
   1.150 -                      map5 mk_preds_getters_join cs ns cpss f_sum_prod_Ts cfsss));
   1.151 +                      map5 mk_preds_getters_join cs ns cpss f_sum_prod_Ts cfssss));
   1.152                in (binder, spec) end;
   1.153  
   1.154 -            val coiter_likes =
   1.155 +            val coiter_like_bundles =
   1.156                [(coiterN, fp_iter, coiter_only),
   1.157                 (corecN, fp_rec, corec_only)];
   1.158  
   1.159 -            val (binders, specs) = map generate_coiter_like coiter_likes |> split_list;
   1.160 +            val (binders, specs) = map generate_coiter_like coiter_like_bundles |> split_list;
   1.161  
   1.162              val ((csts, defs), (lthy', lthy)) = no_defs_lthy
   1.163                |> apfst split_list o fold_map2 (fn b => fn spec =>
   1.164 @@ -490,14 +506,14 @@
   1.165                    ~1 => build_map (build_call fiter_likes maybe_tick) T U
   1.166                  | j => maybe_tick (nth vs j) (nth fiter_likes j));
   1.167  
   1.168 -            fun mk_U maybe_prodT =
   1.169 -              typ_subst (map2 (fn fpT => fn C => (fpT, maybe_prodT fpT C)) fpTs Cs);
   1.170 +            fun mk_U maybe_mk_prodT =
   1.171 +              typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
   1.172  
   1.173 -            fun repair_calls fiter_likes maybe_cons maybe_tick maybe_prodT (x as Free (_, T)) =
   1.174 +            fun repair_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
   1.175                if member (op =) fpTs T then
   1.176                  maybe_cons x [build_call fiter_likes (K I) (T, mk_U (K I) T) $ x]
   1.177                else if exists_subtype (member (op =) fpTs) T then
   1.178 -                [build_call fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
   1.179 +                [build_call fiter_likes maybe_tick (T, mk_U maybe_mk_prodT T) $ x]
   1.180                else
   1.181                  [x];
   1.182  
   1.183 @@ -544,10 +560,10 @@
   1.184            let
   1.185              fun mk_goal_cond pos = HOLogic.mk_Trueprop o (not pos ? HOLogic.mk_not);
   1.186  
   1.187 -            fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfs' =
   1.188 +            fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfss' =
   1.189                fold_rev (fold_rev Logic.all) ([c] :: pfss)
   1.190                  (Logic.list_implies (seq_conds mk_goal_cond n k cps,
   1.191 -                   mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, take m cfs'))));
   1.192 +                   mk_Trueprop_eq (fcoiter_like $ c, lists_bmoc (take m cfss') ctr)));
   1.193  
   1.194              fun build_call fiter_likes maybe_tack (T, U) =
   1.195                if T = U then
   1.196 @@ -557,23 +573,25 @@
   1.197                    ~1 => build_map (build_call fiter_likes maybe_tack) T U
   1.198                  | j => maybe_tack (nth cs j, nth vs j) (nth fiter_likes j));
   1.199  
   1.200 -            fun mk_U maybe_sumT =
   1.201 -              typ_subst (map2 (fn C => fn fpT => (maybe_sumT fpT C, fpT)) Cs fpTs);
   1.202 +            fun mk_U maybe_mk_sumT =
   1.203 +              typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
   1.204  
   1.205 -            fun repair_calls fiter_likes maybe_sumT maybe_tack
   1.206 +            fun repair_calls fiter_likes maybe_mk_sumT maybe_tack
   1.207                  (cf as Free (_, Type (_, [_, T])) $ _) =
   1.208                if exists_subtype (member (op =) Cs) T then
   1.209 -                build_call fiter_likes maybe_tack (T, mk_U maybe_sumT T) $ cf
   1.210 +                build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cf
   1.211                else
   1.212                  cf;
   1.213  
   1.214 -            val cgsss' = map (map (map (repair_calls gcoiters (K I) (K I)))) cgsss;
   1.215 -            val chsss' = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) chsss;
   1.216 +            val cgssss' =
   1.217 +              map (map (map (map (repair_calls gcoiters (K I) (K I))))) cgssss;
   1.218 +            val chssss' =
   1.219 +              map (map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z))))) chssss;
   1.220  
   1.221              val goal_coiterss =
   1.222 -              map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss cgsss';
   1.223 +              map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss cgssss';
   1.224              val goal_corecss =
   1.225 -              map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss chsss';
   1.226 +              map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss chssss';
   1.227  
   1.228              val coiter_tacss =
   1.229                map3 (map oo mk_coiter_like_tac coiter_defs map_ids) fp_iter_thms pre_map_defs