1.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Wed Sep 12 23:06:39 2012 +0200
1.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML Wed Sep 12 23:06:39 2012 +0200
1.3 @@ -41,17 +41,18 @@
1.4 fun strip_map_type (Type (@{type_name fun}, [T as Type _, T'])) = strip_map_type T' |>> cons T
1.5 | strip_map_type T = ([], T);
1.6
1.7 +fun resort_tfree S (TFree (s, _)) = TFree (s, S);
1.8 +
1.9 fun typ_subst inst (T as Type (s, Ts)) =
1.10 (case AList.lookup (op =) inst T of
1.11 NONE => Type (s, map (typ_subst inst) Ts)
1.12 | SOME T' => T')
1.13 | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
1.14
1.15 -fun resort_tfree S (TFree (s, _)) = TFree (s, S);
1.16 -
1.17 val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
1.18
1.19 fun mk_id T = Const (@{const_name id}, T --> T);
1.20 +fun mk_id_fun T = Abs (Name.uu, T, Bound 0);
1.21
1.22 fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
1.23 fun mk_uncurried_fun f xs = mk_tupled_fun (HOLogic.mk_tuple xs) f xs;
1.24 @@ -65,6 +66,8 @@
1.25 Term.lambda z (mk_sum_case (Term.lambda v v, Term.lambda c (f $ c)) $ z)
1.26 end;
1.27
1.28 +fun fold_def_rule n thm = funpow n (fn thm => thm RS fun_cong) (thm RS meta_eq_to_obj_eq) RS sym;
1.29 +
1.30 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
1.31
1.32 fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
1.33 @@ -153,7 +156,7 @@
1.34 val rhs_As' = fold (fold (fold Term.add_tfreesT)) fake_ctr_Tsss [];
1.35 val _ = (case subtract (op =) (map dest_TFree As) rhs_As' of
1.36 [] => ()
1.37 - | A' :: _ => error ("Extra type variables on rhs: " ^
1.38 + | A' :: _ => error ("Extra type variable on right-hand side: " ^
1.39 quote (Syntax.string_of_typ no_defs_lthy (TFree A'))));
1.40
1.41 fun eq_fpT (T as Type (s, Us)) (Type (s', Us')) =
1.42 @@ -494,6 +497,7 @@
1.43 end;
1.44
1.45 val pre_map_defs = map map_def_of_bnf pre_bnfs;
1.46 + val pre_set_defss = map set_defs_of_bnf pre_bnfs;
1.47 val map_ids = map map_id_of_bnf nested_bnfs;
1.48
1.49 fun mk_map Ts Us t =
1.50 @@ -514,7 +518,29 @@
1.51 let
1.52 val (induct_thms, induct_thm) =
1.53 let
1.54 - val induct_thm = fp_induct;
1.55 + val sym_ctr_defss = map2 (map2 fold_def_rule) mss ctr_defss;
1.56 +
1.57 + val ss = @{simpset} |> fold Simplifier.add_simp
1.58 + @{thms collect_def[abs_def] sum_setl_def[abs_def] sum_setr_def[abs_def]
1.59 + fsts_def[abs_def] snds_def[abs_def] False_imp_eq all_point_1};
1.60 +
1.61 + val induct_thm0 = fp_induct OF (map mk_sumEN_tupled_balanced mss);
1.62 +
1.63 + val spurious_fs =
1.64 + Term.add_vars (prop_of induct_thm0) []
1.65 + |> filter (fn (_, Type (@{type_name fun}, [_, T'])) => T' <> HOLogic.boolT
1.66 + | _ => false);
1.67 +
1.68 + val cxs =
1.69 + map (fn s as (_, T) =>
1.70 + (certify lthy (Var s), certify lthy (mk_id_fun (domain_type T)))) spurious_fs;
1.71 +
1.72 + val induct_thm =
1.73 + Drule.cterm_instantiate cxs induct_thm0
1.74 + |> Tactic.rule_by_tactic lthy (ALLGOALS (REPEAT_DETERM o bound_hyp_subst_tac))
1.75 + |> Local_Defs.unfold lthy
1.76 + (@{thm triv_forall_equality} :: flat sym_ctr_defss @ flat pre_set_defss)
1.77 + |> Simplifier.full_simplify ss;
1.78 in
1.79 `(conj_dests N) induct_thm
1.80 end;
1.81 @@ -540,7 +566,7 @@
1.82 fun mk_U maybe_mk_prodT =
1.83 typ_subst (map2 (fn fpT => fn C => (fpT, maybe_mk_prodT fpT C)) fpTs Cs);
1.84
1.85 - fun repair_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
1.86 + fun intr_calls fiter_likes maybe_cons maybe_tick maybe_mk_prodT (x as Free (_, T)) =
1.87 if member (op =) fpTs T then
1.88 maybe_cons x [build_call fiter_likes (K I) (T, mk_U (K I) T) $ x]
1.89 else if exists_subtype (member (op =) fpTs) T then
1.90 @@ -548,9 +574,8 @@
1.91 else
1.92 [x];
1.93
1.94 - val gxsss = map (map (maps (repair_calls giters (K I) (K I) (K I)))) xsss;
1.95 - val hxsss =
1.96 - map (map (maps (repair_calls hrecs cons tick (curry HOLogic.mk_prodT)))) xsss;
1.97 + val gxsss = map (map (maps (intr_calls giters (K I) (K I) (K I)))) xsss;
1.98 + val hxsss = map (map (maps (intr_calls hrecs cons tick (curry HOLogic.mk_prodT)))) xsss;
1.99
1.100 val goal_iterss = map5 (map4 o mk_goal_iter_like gss) giters xctrss gss xsss gxsss;
1.101 val goal_recss = map5 (map4 o mk_goal_iter_like hss) hrecs xctrss hss xsss hxsss;
1.102 @@ -567,13 +592,13 @@
1.103 end;
1.104
1.105 val common_notes =
1.106 - [(inductN, [induct_thm], []), (*### attribs *)
1.107 - (inductsN, induct_thms, [])] (*### attribs *)
1.108 + (if N > 1 then [(inductN, [induct_thm], [])] (* FIXME: attribs *) else [])
1.109 |> map (fn (thmN, thms, attrs) =>
1.110 ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
1.111
1.112 val notes =
1.113 - [(itersN, iter_thmss, simp_attrs),
1.114 + [(inductN, map single induct_thms, []), (* FIXME: attribs *)
1.115 + (itersN, iter_thmss, simp_attrs),
1.116 (recsN, rec_thmss, Code.add_default_eqn_attrib :: simp_attrs)]
1.117 |> maps (fn (thmN, thmss, attrs) =>
1.118 map2 (fn b => fn thms =>
1.119 @@ -617,7 +642,7 @@
1.120 fun mk_U maybe_mk_sumT =
1.121 typ_subst (map2 (fn C => fn fpT => (maybe_mk_sumT fpT C, fpT)) Cs fpTs);
1.122
1.123 - fun repair_calls fiter_likes maybe_mk_sumT maybe_tack cqf =
1.124 + fun intr_calls fiter_likes maybe_mk_sumT maybe_tack cqf =
1.125 let val T = fastype_of cqf in
1.126 if exists_subtype (member (op =) Cs) T then
1.127 build_call fiter_likes maybe_tack (T, mk_U maybe_mk_sumT T) $ cqf
1.128 @@ -625,8 +650,8 @@
1.129 cqf
1.130 end;
1.131
1.132 - val crgsss' = map (map (map (repair_calls gcoiters (K I) (K I)))) crgsss;
1.133 - val cshsss' = map (map (map (repair_calls hcorecs (curry mk_sumT) (tack z)))) cshsss;
1.134 + val crgsss' = map (map (map (intr_calls gcoiters (K I) (K I)))) crgsss;
1.135 + val cshsss' = map (map (map (intr_calls hcorecs (curry mk_sumT) (tack z)))) cshsss;
1.136
1.137 val goal_coiterss =
1.138 map8 (map4 oooo mk_goal_coiter_like pgss) cs cpss gcoiters ns kss ctrss mss crgsss';
1.139 @@ -672,8 +697,14 @@
1.140 val sel_corec_thmsss =
1.141 map3 (map3 (map2 o mk_sel_coiter_like_thm)) corec_thmss selsss sel_thmsss;
1.142
1.143 + val common_notes =
1.144 + (if N > 1 then [(coinductN, [coinduct_thm], [])] (* FIXME: attribs *) else [])
1.145 + |> map (fn (thmN, thms, attrs) =>
1.146 + ((Binding.qualify true fp_common_name (Binding.name thmN), attrs), [(thms, [])]));
1.147 +
1.148 val notes =
1.149 - [(coitersN, coiter_thmss, []),
1.150 + [(coinductN, map single coinduct_thms, []), (* FIXME: attribs *)
1.151 + (coitersN, coiter_thmss, []),
1.152 (disc_coitersN, disc_coiter_thmss, []),
1.153 (sel_coitersN, map flat sel_coiter_thmsss, []),
1.154 (corecsN, corec_thmss, []),