src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 50291 59fa53ed7507
parent 50290 ce87d6a901eb
child 50292 aee77001243f
     1.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 13:06:14 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 13:06:14 2012 +0200
     1.3 @@ -213,8 +213,7 @@
     1.4      val fp_rec_fun_Ts = fst (split_last (binder_types (fastype_of fp_rec1)));
     1.5  
     1.6      val ((iter_only as (gss, _, _), rec_only as (hss, _, _)),
     1.7 -         (zs, cs, cpss, coiter_only as ((pgss, _, cgssss), _),
     1.8 -          corec_only as ((phss, _, chssss), _))) =
     1.9 +         (zs, cs, cpss, coiter_only as ((pgss, crgsss), _), corec_only as ((phss, cshsss), _))) =
    1.10        if lfp then
    1.11          let
    1.12            val y_Tsss =
    1.13 @@ -245,7 +244,7 @@
    1.14            val zssss = map2 (map2 (map2 cons)) zssss_hd zssss_tl;
    1.15          in
    1.16            (((gss, g_Tss, yssss), (hss, h_Tss, zssss)),
    1.17 -           ([], [], [], (([], [], []), ([], [])), (([], [], []), ([], []))))
    1.18 +           ([], [], [], (([], []), ([], [])), (([], []), ([], []))))
    1.19          end
    1.20        else
    1.21          let
    1.22 @@ -254,11 +253,11 @@
    1.23  
    1.24            val p_Tss = map2 (fn n => replicate (Int.max (0, n - 1)) o mk_predT) ns Cs;
    1.25  
    1.26 -          fun zip_getterss qss fss = maps (op @) (qss ~~ fss);
    1.27 +          fun zip_predss_getterss qss fss = maps (op @) (qss ~~ fss);
    1.28  
    1.29 -          fun zip_preds_gettersss [] [qss] [fss] = zip_getterss qss fss
    1.30 -            | zip_preds_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
    1.31 -              p :: zip_getterss qss fss @ zip_preds_gettersss ps qsss fsss;
    1.32 +          fun zip_preds_predsss_gettersss [] [qss] [fss] = zip_predss_getterss qss fss
    1.33 +            | zip_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
    1.34 +              p :: zip_predss_getterss qss fss @ zip_preds_predsss_gettersss ps qsss fsss;
    1.35  
    1.36            fun mk_types maybe_dest_sumT fun_Ts =
    1.37              let
    1.38 @@ -269,7 +268,7 @@
    1.39                    Cs mss' f_prod_Tss;
    1.40                val q_Tssss =
    1.41                  map (map (map (fn [_] => [] | [_, C] => [mk_predT (domain_type C)]))) f_Tssss;
    1.42 -              val pf_Tss = map3 zip_preds_gettersss p_Tss q_Tssss f_Tssss;
    1.43 +              val pf_Tss = map3 zip_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
    1.44              in (q_Tssss, f_sum_prod_Ts, f_Tssss, pf_Tss) end;
    1.45  
    1.46            val (r_Tssss, g_sum_prod_Ts, g_Tssss, pg_Tss) = mk_types single fp_iter_fun_Ts;
    1.47 @@ -297,12 +296,17 @@
    1.48  
    1.49            val cpss = map2 (fn c => map (fn p => p $ c)) cs pss;
    1.50  
    1.51 +          fun mk_preds_getters_join [] [cf] = cf
    1.52 +            | mk_preds_getters_join [cq] [cf, cf'] =
    1.53 +              mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
    1.54 +
    1.55            fun mk_terms qssss fssss =
    1.56              let
    1.57 -              val pfss = map3 zip_preds_gettersss pss qssss fssss;
    1.58 +              val pfss = map3 zip_preds_predsss_gettersss pss qssss fssss;
    1.59                val cqssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs qssss;
    1.60                val cfssss = map2 (fn c => map (map (map (fn f => f $ c)))) cs fssss;
    1.61 -            in (pfss, cqssss, cfssss) end;
    1.62 +              val cqfsss = map2 (map2 (map2 mk_preds_getters_join)) cqssss cfssss;
    1.63 +            in (pfss, cqfsss) end;
    1.64          in
    1.65            ((([], [], []), ([], [], [])),
    1.66             ([z], cs, cpss, (mk_terms rssss gssss, (g_sum_prod_Ts, pg_Tss)),
    1.67 @@ -433,16 +437,11 @@
    1.68            let
    1.69              val B_to_fpT = C --> fpT;
    1.70  
    1.71 -            fun mk_getters_join [] [cf] = cf
    1.72 -              | mk_getters_join [cq] [cf, cf'] =
    1.73 -                mk_If cq (mk_Inl (fastype_of cf') cf) (mk_Inr (fastype_of cf) cf');
    1.74 +            fun mk_preds_getterss_join c n cps sum_prod_T cqfss =
    1.75 +              Term.lambda c (mk_IfN sum_prod_T cps
    1.76 +                (map2 (mk_InN_balanced sum_prod_T n) (map HOLogic.mk_tuple cqfss) (1 upto n)));
    1.77  
    1.78 -            fun mk_preds_gettersss_join c n cps sum_prod_T cqsss cfsss =
    1.79 -              Term.lambda c (mk_IfN sum_prod_T cps
    1.80 -                (map2 (mk_InN_balanced sum_prod_T n)
    1.81 -                   (map2 (HOLogic.mk_tuple oo map2 mk_getters_join) cqsss cfsss) (1 upto n)));
    1.82 -
    1.83 -            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cqssss, cfssss), (f_sum_prod_Ts,
    1.84 +            fun generate_coiter_like (suf, fp_iter_like, ((pfss, cqfsss), (f_sum_prod_Ts,
    1.85                  pf_Tss))) =
    1.86                let
    1.87                  val res_T = fold_rev (curry (op --->)) pf_Tss B_to_fpT;
    1.88 @@ -452,7 +451,7 @@
    1.89                  val spec =
    1.90                    mk_Trueprop_eq (lists_bmoc pfss (Free (Binding.name_of binder, res_T)),
    1.91                      Term.list_comb (fp_iter_like,
    1.92 -                      map6 mk_preds_gettersss_join cs ns cpss f_sum_prod_Ts cqssss cfssss));
    1.93 +                      map5 mk_preds_getterss_join cs ns cpss f_sum_prod_Ts cqfsss));
    1.94                in (binder, spec) end;
    1.95  
    1.96              val coiter_like_bundles =
    1.97 @@ -542,11 +541,9 @@
    1.98              val rec_tacss =
    1.99                map2 (map o mk_iter_like_tac pre_map_defs map_ids rec_defs) fp_rec_thms ctr_defss;
   1.100            in
   1.101 -            (map2 (map2 (fn goal => fn tac =>
   1.102 -                 Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
   1.103 +            (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   1.104                 goal_iterss iter_tacss,
   1.105 -             map2 (map2 (fn goal => fn tac =>
   1.106 -                 Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
   1.107 +             map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   1.108                 goal_recss rec_tacss)
   1.109            end;
   1.110  
   1.111 @@ -573,10 +570,10 @@
   1.112            let
   1.113              fun mk_goal_cond pos = HOLogic.mk_Trueprop o (not pos ? HOLogic.mk_not);
   1.114  
   1.115 -            fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfss' =
   1.116 +            fun mk_goal_coiter_like pfss c cps fcoiter_like n k ctr m cfs' =
   1.117                fold_rev (fold_rev Logic.all) ([c] :: pfss)
   1.118                  (Logic.list_implies (seq_conds mk_goal_cond n k cps,
   1.119 -                   mk_Trueprop_eq (fcoiter_like $ c, lists_bmoc (take m cfss') ctr)));
   1.120 +                   mk_Trueprop_eq (fcoiter_like $ c, Term.list_comb (ctr, take m cfs'))));
   1.121  
   1.122              fun build_call fiter_likes maybe_tack (T, U) =
   1.123                if T = U then
   1.124 @@ -589,22 +586,21 @@
   1.125              fun mk_U maybe_mk_sumT =
   1.126                typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
   1.127  
   1.128 -            fun repair_calls fiter_likes maybe_mk_sumT maybe_tack
   1.129 -                (cf as Free (_, Type (_, [_, T])) $ _) =
   1.130 -              if exists_subtype (member (op =) Cs) T then
   1.131 -                build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cf
   1.132 -              else
   1.133 -                cf;
   1.134 +            fun repair_calls fiter_likes maybe_mk_sumT maybe_tack cqf =
   1.135 +              let val T = fastype_of cqf in
   1.136 +                if exists_subtype (member (op =) Cs) T then
   1.137 +                  build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
   1.138 +                else
   1.139 +                  cqf
   1.140 +              end;
   1.141  
   1.142 -            val cgssss' =
   1.143 -              map (map (map (map (repair_calls gcoiters (K I) (K I))))) cgssss;
   1.144 -            val chssss' =
   1.145 -              map (map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z))))) chssss;
   1.146 +            val crgsss' = map (map (map (repair_calls gcoiters (K I) (K I)))) crgsss;
   1.147 +            val cshsss' = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) cshsss;
   1.148  
   1.149              val goal_coiterss =
   1.150 -              map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss cgssss';
   1.151 +              map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss crgsss';
   1.152              val goal_corecss =
   1.153 -              map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss chssss';
   1.154 +              map8 (map4 oooo mk_goal_coiter_like phss) cs cpss hcorecs ns kss ctrss mss cshsss';
   1.155  
   1.156              val coiter_tacss =
   1.157                map3 (map oo mk_coiter_like_tac coiter_defs map_ids) fp_iter_thms pre_map_defs
   1.158 @@ -613,9 +609,12 @@
   1.159                map3 (map oo mk_coiter_like_tac corec_defs map_ids) fp_rec_thms pre_map_defs
   1.160                  ctr_defss;
   1.161            in
   1.162 -            (map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   1.163 +            (map2 (map2 (fn goal => fn tac =>
   1.164 +                 Skip_Proof.prove lthy [] [] goal (tac o #context) |> Thm.close_derivation))
   1.165                 goal_coiterss coiter_tacss,
   1.166 -             map2 (map2 (fn goal => fn tac => Skip_Proof.prove lthy [] [] goal (tac o #context)))
   1.167 +             map2 (map2 (fn goal => fn tac =>
   1.168 +                 Skip_Proof.prove lthy [] [] goal (tac o #context)
   1.169 +                 |> Local_Defs.unfold lthy @{thms sum_case_if} |> Thm.close_derivation))
   1.170                 goal_corecss corec_tacss)
   1.171            end;
   1.172