src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 50248 7f412734fbb3
parent 50247 9ea11f0c53e4
child 50249 4626ff7cbd2c
     1.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sun Sep 09 17:14:39 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sun Sep 09 18:55:10 2012 +0200
     1.3 @@ -48,7 +48,12 @@
     1.4  fun mk_uncurried2_fun f xss =
     1.5    mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat xss);
     1.6  
     1.7 -fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v))
     1.8 +fun tick v f = Term.lambda v (HOLogic.mk_prod (v, f $ v));
     1.9 +
    1.10 +fun tack z_name (c, v) f =
    1.11 +  let val z = Free (z_name, mk_sumT (fastype_of v, fastype_of c)) in
    1.12 +    Term.lambda z (mk_sum_case (Term.lambda v v) (Term.lambda c (f $ c)) $ z)
    1.13 +  end;
    1.14  
    1.15  fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
    1.16  
    1.17 @@ -204,7 +209,7 @@
    1.18        | dest_rec_pair T = [T];
    1.19  
    1.20      val ((iter_only as (gss, g_Tss, yssss), rec_only as (hss, h_Tss, zssss)),
    1.21 -         (cs, cpss, p_Tss, coiter_only as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss),
    1.22 +         (zs, cs, cpss, p_Tss, coiter_only as ((pgss, cgsss), g_sum_prod_Ts, g_prod_Tss, g_Tsss),
    1.23            corec_only as ((phss, chsss), h_sum_prod_Ts, h_prod_Tss, h_Tsss))) =
    1.24        if lfp then
    1.25          let
    1.26 @@ -229,7 +234,7 @@
    1.27              |> mk_Freessss "x" z_Tssss;
    1.28          in
    1.29            (((gss, g_Tss, map (map (map single)) ysss), (hss, h_Tss, zssss)),
    1.30 -           ([], [], [], (([], []), [], [], []), (([], []), [], [], [])))
    1.31 +           ([], [], [], [], (([], []), [], [], []), (([], []), [], [], [])))
    1.32          end
    1.33        else
    1.34          let
    1.35 @@ -254,15 +259,15 @@
    1.36            val (g_sum_prod_Ts, g_prod_Tss, g_Tsss, pg_Tss) = mk_types fp_iter_fun_Ts;
    1.37            val (h_sum_prod_Ts, h_prod_Tss, h_Tsss, ph_Tss) = mk_types fp_rec_fun_Ts;
    1.38  
    1.39 -          val (((c, pss), gsss), _) =
    1.40 +          val ((((Free (z, _), cs), pss), gsss), _) =
    1.41              lthy
    1.42 -            |> yield_singleton (mk_Frees "c") dummyT
    1.43 +            |> yield_singleton (mk_Frees "z") dummyT
    1.44 +            ||>> mk_Frees "a" Cs
    1.45              ||>> mk_Freess "p" p_Tss
    1.46              ||>> mk_Freesss "g" g_Tsss;
    1.47  
    1.48            val hsss = map2 (map2 (map2 retype_free)) gsss h_Tsss;
    1.49  
    1.50 -          val cs = map (retype_free c) Cs;
    1.51            val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
    1.52  
    1.53            fun mk_terms fsss =
    1.54 @@ -272,7 +277,7 @@
    1.55              in (pfss, cfsss) end;
    1.56          in
    1.57            ((([], [], []), ([], [], [])),
    1.58 -           (cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss),
    1.59 +           ([z], cs, cpss, p_Tss, (mk_terms gsss, g_sum_prod_Ts, g_prod_Tss, pg_Tss),
    1.60              (mk_terms hsss, h_sum_prod_Ts, h_prod_Tss, ph_Tss)))
    1.61          end;
    1.62  
    1.63 @@ -447,24 +452,6 @@
    1.64          Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
    1.65        end;
    1.66  
    1.67 -    fun build_iter_like_call vs basic_Ts fiter_likes maybe_tick =
    1.68 -      let
    1.69 -        fun build (T, U) =
    1.70 -          if T = U then
    1.71 -            Const (@{const_name id}, T --> T)
    1.72 -          else
    1.73 -            (case (find_index (curry (op =) T) basic_Ts, (T, U)) of
    1.74 -              (~1, (Type (s, Ts), Type (_, Us))) =>
    1.75 -              let
    1.76 -                val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
    1.77 -                val mapx = mk_map Ts Us map0;
    1.78 -                val TUs =
    1.79 -                  map dest_funT (fst (split_last (fst (strip_map_type (fastype_of mapx)))));
    1.80 -                val args = map build TUs;
    1.81 -              in Term.list_comb (mapx, args) end
    1.82 -            | (j, _) => maybe_tick (nth vs j) (nth fiter_likes j))
    1.83 -      in build end;
    1.84 -
    1.85      fun pour_more_sugar_on_lfps ((ctrss, iters, recs, vs, xsss, ctr_defss, iter_defs, rec_defs),
    1.86          lthy) =
    1.87        let
    1.88 @@ -478,14 +465,32 @@
    1.89                fold_rev (fold_rev Logic.all) (xs :: fss)
    1.90                  (mk_Trueprop_eq (fiter_like $ xctr, Term.list_comb (f, fxs)));
    1.91  
    1.92 +            fun build_call fiter_likes maybe_tick =
    1.93 +              let
    1.94 +                fun build (T, U) =
    1.95 +                  if T = U then
    1.96 +                    Const (@{const_name id}, T --> T)
    1.97 +                  else
    1.98 +                    (case (find_index (curry (op =) T) fpTs, (T, U)) of
    1.99 +                      (~1, (Type (s, Ts), Type (_, Us))) =>
   1.100 +                      let
   1.101 +                        val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
   1.102 +                        val mapx = mk_map Ts Us map0;
   1.103 +                        val TUs =
   1.104 +                          map dest_funT (fst (split_last (fst (strip_map_type (fastype_of mapx)))));
   1.105 +                        val args = map build TUs;
   1.106 +                      in Term.list_comb (mapx, args) end
   1.107 +                    | (j, _) => maybe_tick (nth vs j) (nth fiter_likes j))
   1.108 +              in build end;
   1.109 +
   1.110              fun mk_U maybe_prodT =
   1.111                typ_subst (map2 (fn fpT => fn C => (fpT, maybe_prodT fpT C)) fpTs Cs);
   1.112  
   1.113              fun repair_calls fiter_likes maybe_cons maybe_tick maybe_prodT (x as Free (_, T)) =
   1.114                if member (op =) fpTs T then
   1.115 -                maybe_cons x [build_iter_like_call vs fpTs fiter_likes (K I) (T, mk_U (K I) T) $ x]
   1.116 +                maybe_cons x [build_call fiter_likes (K I) (T, mk_U (K I) T) $ x]
   1.117                else if exists_subtype (member (op =) fpTs) T then
   1.118 -                [build_iter_like_call vs fpTs fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
   1.119 +                [build_call fiter_likes maybe_tick (T, mk_U maybe_prodT T) $ x]
   1.120                else
   1.121                  [x];
   1.122  
   1.123 @@ -521,6 +526,8 @@
   1.124      fun pour_more_sugar_on_gfps ((ctrss, coiters, corecs, vs, xsss, ctr_defss, coiter_defs,
   1.125          corec_defs), lthy) =
   1.126        let
   1.127 +        val z = the_single zs;
   1.128 +
   1.129          val gcoiters = map (lists_bmoc pgss) coiters;
   1.130          val hcorecs = map (lists_bmoc phss) corecs;
   1.131  
   1.132 @@ -533,32 +540,58 @@
   1.133                  (Logic.list_implies (seq_conds mk_goal_cond n k cps,
   1.134                     mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, cfs'))));
   1.135  
   1.136 +            fun build_call fiter_likes maybe_tack =
   1.137 +              let
   1.138 +                fun build (T, U) =
   1.139 +                  if T = U then
   1.140 +                    Const (@{const_name id}, T --> T)
   1.141 +                  else
   1.142 +                    (case (find_index (curry (op =) U) fpTs, (T, U)) of
   1.143 +                      (~1, (Type (s, Ts), Type (_, Us))) =>
   1.144 +                      let
   1.145 +                        val map0 = map_of_bnf (the (bnf_of lthy (Long_Name.base_name s)));
   1.146 +                        val mapx = mk_map Ts Us map0;
   1.147 +                        val TUs =
   1.148 +                          map dest_funT (fst (split_last (fst (strip_map_type (fastype_of mapx)))));
   1.149 +                        val args = map build TUs;
   1.150 +                      in Term.list_comb (mapx, args) end
   1.151 +                    | (j, _) => maybe_tack (nth cs j, nth vs j) (nth fiter_likes j))
   1.152 +              in build end;
   1.153 +
   1.154              fun mk_U maybe_sumT =
   1.155 -              typ_subst (map2 (fn C => fn fpT => (C, maybe_sumT C fpT)) Cs fpTs);
   1.156 +              typ_subst (map2 (fn C => fn fpT => (maybe_sumT fpT C, fpT)) Cs fpTs);
   1.157  
   1.158              fun repair_calls fiter_likes maybe_sumT maybe_tack
   1.159                  (cf as Free (_, Type (_, [_, T])) $ _) =
   1.160                if exists_subtype (member (op =) Cs) T then
   1.161 -                build_iter_like_call vs Cs fiter_likes maybe_tack (T, mk_U maybe_sumT T) $ cf
   1.162 +                build_call fiter_likes maybe_tack (T, mk_U maybe_sumT T) $ cf
   1.163                else
   1.164                  cf;
   1.165  
   1.166              val cgsss = map (map (map (repair_calls gcoiters (K I) (K I)))) cgsss;
   1.167 +            val chsss = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) chsss;
   1.168  
   1.169              val goal_coiterss =
   1.170                map7 (map3 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss cgsss;
   1.171 +            val goal_corecss =
   1.172 +              map7 (map3 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss chsss;
   1.173  
   1.174              val coiter_tacss =
   1.175                map3 (map oo mk_coiter_like_tac coiter_defs map_ids) fp_iter_thms pre_map_defs
   1.176                  ctr_defss;
   1.177 +            val corec_tacss =
   1.178 +              map3 (map oo mk_coiter_like_tac corec_defs map_ids) fp_rec_thms pre_map_defs
   1.179 +                ctr_defss;
   1.180            in
   1.181              (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   1.182                 goal_coiterss coiter_tacss,
   1.183 -             [])
   1.184 +             map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   1.185 +               goal_corecss corec_tacss)
   1.186            end;
   1.187  
   1.188          val notes =
   1.189 -          [(coitersN, coiter_thmss)]
   1.190 +          [(coitersN, coiter_thmss),
   1.191 +           (corecsN, corec_thmss)]
   1.192            |> maps (fn (thmN, thmss) =>
   1.193              map2 (fn b => fn thms =>
   1.194                  ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))