src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 50357 8ea4bad49ed5
parent 50353 4a922800531d
child 50376 cc1d39529dd1
     1.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Wed Sep 12 23:06:39 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Wed Sep 12 23:06:39 2012 +0200
     1.3 @@ -41,17 +41,18 @@
     1.4  fun strip_map_type (Type (@{type_name fun}, [T as Type _, T'])) = strip_map_type T' |>> cons T
     1.5    | strip_map_type T = ([], T);
     1.6  
     1.7 +fun resort_tfree S (TFree (s, _)) = TFree (s, S);
     1.8 +
     1.9  fun typ_subst inst (T as Type (s, Ts)) =
    1.10      (case AList.lookup (op =) inst T of
    1.11        NONE => Type (s, map (typ_subst inst) Ts)
    1.12      | SOME T' => T')
    1.13    | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
    1.14  
    1.15 -fun resort_tfree S (TFree (s, _)) = TFree (s, S);
    1.16 -
    1.17  val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
    1.18  
    1.19  fun mk_id T = Const (@{const_name id}, T --> T);
    1.20 +fun mk_id_fun T = Abs (Name.uu, T, Bound 0);
    1.21  
    1.22  fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
    1.23  fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
    1.24 @@ -65,6 +66,8 @@
    1.25      Term.lambda z (mk_sum_case (Term.lambda v v, Term.lambda c (f $ c)) $ z)
    1.26    end;
    1.27  
    1.28 +fun fold_def_rule n thm = funpow n (fn thm => thm RS fun_cong) (thm RS meta_eq_to_obj_eq) RS sym;
    1.29 +
    1.30  fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
    1.31  
    1.32  fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
    1.33 @@ -153,7 +156,7 @@
    1.34      val rhs_As' = fold (fold (fold Term.add_tfreesT)) fake_ctr_Tsss [];
    1.35      val _ = (case subtract (op =) (map dest_TFree As) rhs_As' of
    1.36          [] => ()
    1.37 -      | A' :: _ => error ("Extra type variables on rhs: " ^
    1.38 +      | A' :: _ => error ("Extra type variable on right-hand side: " ^
    1.39            quote (Syntax.string_of_typ no_defs_lthy (TFree A'))));
    1.40  
    1.41      fun eq_fpT (T as Type (s, Us)) (Type (s', Us')) =
    1.42 @@ -494,6 +497,7 @@
    1.43        end;
    1.44  
    1.45      val pre_map_defs = map map_def_of_bnf pre_bnfs;
    1.46 +    val pre_set_defss = map set_defs_of_bnf pre_bnfs;
    1.47      val map_ids = map map_id_of_bnf nested_bnfs;
    1.48  
    1.49      fun mk_map Ts Us t =
    1.50 @@ -514,7 +518,29 @@
    1.51        let
    1.52          val (induct_thms, induct_thm) =
    1.53            let
    1.54 -            val induct_thm = fp_induct;
    1.55 +            val sym_ctr_defss = map2 (map2 fold_def_rule) mss ctr_defss;
    1.56 +
    1.57 +            val ss = @{simpset} |> fold Simplifier.add_simp
    1.58 +              @{thms collect_def[abs_def] sum_setl_def[abs_def] sum_setr_def[abs_def]
    1.59 +                 fsts_def[abs_def] snds_def[abs_def] False_imp_eq all_point_1};
    1.60 +
    1.61 +            val induct_thm0 = fp_induct OF (map mk_sumEN_tupled_balanced mss);
    1.62 +
    1.63 +            val spurious_fs =
    1.64 +              Term.add_vars (prop_of induct_thm0) []
    1.65 +              |> filter (fn (_, Type (@{type_name fun}, [_, T'])) => T' <> HOLogic.boolT
    1.66 +                | _ => false);
    1.67 +
    1.68 +            val cxs =
    1.69 +              map (fn s as (_, T) =>
    1.70 +                (certify lthy (Var s), certify lthy (mk_id_fun (domain_type T)))) spurious_fs;
    1.71 +
    1.72 +            val induct_thm =
    1.73 +              Drule.cterm_instantiate cxs induct_thm0
    1.74 +              |> Tactic.rule_by_tactic lthy (ALLGOALS (REPEAT_DETERM o bound_hyp_subst_tac))
    1.75 +              |> Local_Defs.unfold lthy
    1.76 +                (@{thm triv_forall_equality} :: flat sym_ctr_defss @ flat pre_set_defss)
    1.77 +              |> Simplifier.full_simplify ss;
    1.78            in
    1.79              `(conj_dests N) induct_thm
    1.80            end;
    1.81 @@ -540,7 +566,7 @@
    1.82              fun mk_U maybe_mk_prodT =
    1.83                typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
    1.84  
    1.85 -            fun repair_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
    1.86 +            fun intr_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
    1.87                if member (op =) fpTs T then
    1.88                  maybe_cons x [build_call fiter_likes (K I) (T, mk_U (K I) T) $ x]
    1.89                else if exists_subtype (member (op =) fpTs) T then
    1.90 @@ -548,9 +574,8 @@
    1.91                else
    1.92                  [x];
    1.93  
    1.94 -            val gxsss = map (map (maps (repair_calls giters (K I) (K I) (K I)))) xsss;
    1.95 -            val hxsss =
    1.96 -              map (map (maps (repair_calls hrecs cons tick (curry HOLogic.mk_prodT)))) xsss;
    1.97 +            val gxsss = map (map (maps (intr_calls giters (K I) (K I) (K I)))) xsss;
    1.98 +            val hxsss = map (map (maps (intr_calls hrecs cons tick (curry HOLogic.mk_prodT)))) xsss;
    1.99  
   1.100              val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss gxsss;
   1.101              val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss hxsss;
   1.102 @@ -567,13 +592,13 @@
   1.103            end;
   1.104  
   1.105          val common_notes =
   1.106 -          [(inductN, [induct_thm], []), (*### attribs *)
   1.107 -           (inductsN, induct_thms, [])] (*### attribs *)
   1.108 +          (if N > 1 then [(inductN, [induct_thm], [])] (* FIXME: attribs *) else [])
   1.109            |> map (fn (thmN, thms, attrs) =>
   1.110                ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
   1.111  
   1.112          val notes =
   1.113 -          [(itersN, iter_thmss, simp_attrs),
   1.114 +          [(inductN, map single induct_thms, []), (* FIXME: attribs *)
   1.115 +           (itersN, iter_thmss, simp_attrs),
   1.116             (recsN, rec_thmss, Code.add_default_eqn_attrib :: simp_attrs)]
   1.117            |> maps (fn (thmN, thmss, attrs) =>
   1.118              map2 (fn b => fn thms =>
   1.119 @@ -617,7 +642,7 @@
   1.120              fun mk_U maybe_mk_sumT =
   1.121                typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
   1.122  
   1.123 -            fun repair_calls fiter_likes maybe_mk_sumT maybe_tack cqf =
   1.124 +            fun intr_calls fiter_likes maybe_mk_sumT maybe_tack cqf =
   1.125                let val T = fastype_of cqf in
   1.126                  if exists_subtype (member (op =) Cs) T then
   1.127                    build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
   1.128 @@ -625,8 +650,8 @@
   1.129                    cqf
   1.130                end;
   1.131  
   1.132 -            val crgsss' = map (map (map (repair_calls gcoiters (K I) (K I)))) crgsss;
   1.133 -            val cshsss' = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) cshsss;
   1.134 +            val crgsss' = map (map (map (intr_calls gcoiters (K I) (K I)))) crgsss;
   1.135 +            val cshsss' = map (map (map (intr_calls hcorecs (curry mk_sumT) (tack z)))) cshsss;
   1.136  
   1.137              val goal_coiterss =
   1.138                map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss crgsss';
   1.139 @@ -672,8 +697,14 @@
   1.140          val sel_corec_thmsss =
   1.141            map3 (map3 (map2 o mk_sel_coiter_like_thm)) corec_thmss selsss sel_thmsss;
   1.142  
   1.143 +        val common_notes =
   1.144 +          (if N > 1 then [(coinductN, [coinduct_thm], [])] (* FIXME: attribs *) else [])
   1.145 +          |> map (fn (thmN, thms, attrs) =>
   1.146 +              ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
   1.147 +
   1.148          val notes =
   1.149 -          [(coitersN, coiter_thmss, []),
   1.150 +          [(coinductN, map single coinduct_thms, []), (* FIXME: attribs *)
   1.151 +           (coitersN, coiter_thmss, []),
   1.152             (disc_coitersN, disc_coiter_thmss, []),
   1.153             (sel_coitersN, map flat sel_coiter_thmsss, []),
   1.154             (corecsN, corec_thmss, []),