1.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Fri Oct 04 17:00:35 2013 +0200
1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Fri Oct 04 18:27:07 2013 +0200
1.3 @@ -34,10 +34,10 @@
1.4 open BNF_FP_Rec_Sugar_Util
1.5 open BNF_FP_Rec_Sugar_Tactics
1.6
1.7 -val codeN = "code"
1.8 -val ctrN = "ctr"
1.9 -val discN = "disc"
1.10 -val selN = "sel"
1.11 +val codeN = "code";
1.12 +val ctrN = "ctr";
1.13 +val discN = "disc";
1.14 +val selN = "sel";
1.15
1.16 val nitpick_attrs = @{attributes [nitpick_simp]};
1.17 val simp_attrs = @{attributes [simp]};
1.18 @@ -472,7 +472,7 @@
1.19 Disc of co_eqn_data_disc |
1.20 Sel of co_eqn_data_sel;
1.21
1.22 -fun co_dissect_eqn_disc sequential fun_names (corec_specs : corec_spec list) prems' concl
1.23 +fun co_dissect_eqn_disc seq fun_names (corec_specs : corec_spec list) prems' concl
1.24 matchedsss =
1.25 let
1.26 fun find_subterm p = let (* FIXME \<exists>? *)
1.27 @@ -507,11 +507,11 @@
1.28 val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default [];
1.29 val prems = map (abstract (List.rev fun_args)) prems';
1.30 val real_prems =
1.31 - (if catch_all orelse sequential then maps negate_disj matchedss else []) @
1.32 + (if catch_all orelse seq then maps negate_disj matchedss else []) @
1.33 (if catch_all then [] else prems);
1.34
1.35 val matchedsss' = AList.delete (op =) fun_name matchedsss
1.36 - |> cons (fun_name, if sequential then matchedss @ [prems] else matchedss @ [real_prems]);
1.37 + |> cons (fun_name, if seq then matchedss @ [prems] else matchedss @ [real_prems]);
1.38
1.39 val user_eqn =
1.40 (real_prems, betapply (#disc (nth ctr_specs ctr_no), applied_fun))
1.41 @@ -560,49 +560,72 @@
1.42 }
1.43 end;
1.44
1.45 -fun co_dissect_eqn_ctr sequential fun_names (corec_specs : corec_spec list) eqn' imp_prems imp_rhs
1.46 +fun co_dissect_eqn_ctr seq fun_names (corec_specs : corec_spec list) eqn' prems concl
1.47 matchedsss =
1.48 let
1.49 - val (lhs, rhs) = HOLogic.dest_eq imp_rhs;
1.50 + val (lhs, rhs) = HOLogic.dest_eq concl;
1.51 val fun_name = head_of lhs |> fst o dest_Free;
1.52 val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
1.53 val (ctr, ctr_args) = strip_comb rhs;
1.54 val {disc, sels, ...} = the (find_first (equal ctr o #ctr) ctr_specs)
1.55 handle Option.Option => primrec_error_eqn "not a constructor" ctr;
1.56
1.57 - val disc_imp_rhs = betapply (disc, lhs);
1.58 + val disc_concl = betapply (disc, lhs);
1.59 val (maybe_eqn_data_disc, matchedsss') = if length ctr_specs = 1
1.60 then (NONE, matchedsss)
1.61 else apfst SOME (co_dissect_eqn_disc
1.62 - sequential fun_names corec_specs imp_prems disc_imp_rhs matchedsss);
1.63 + seq fun_names corec_specs prems disc_concl matchedsss);
1.64
1.65 - val sel_imp_rhss = (sels ~~ ctr_args)
1.66 + val sel_concls = (sels ~~ ctr_args)
1.67 |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
1.68
1.69 (*
1.70 -val _ = tracing ("reduced\n " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n \<cdot> " ^
1.71 - (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_imp_rhs ^ "\n \<cdot> ")) "" ^
1.72 - space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) sel_imp_rhss));
1.73 +val _ = tracing ("reduced\n " ^ Syntax.string_of_term @{context} concl ^ "\nto\n \<cdot> " ^
1.74 + (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_concl ^ "\n \<cdot> ")) "" ^
1.75 + space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) sel_concls));
1.76 *)
1.77
1.78 - val eqns_data_sel =
1.79 - map (co_dissect_eqn_sel fun_names corec_specs eqn' (SOME ctr)) sel_imp_rhss;
1.80 + val eqns_data_sel = map (co_dissect_eqn_sel fun_names corec_specs eqn' (SOME ctr)) sel_concls;
1.81 in
1.82 (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
1.83 end;
1.84
1.85 -fun co_dissect_eqn sequential fun_names (corec_specs : corec_spec list) eqn' of_spec matchedsss =
1.86 +fun co_dissect_eqn_code lthy has_call fun_names corec_specs eqn' concl matchedsss =
1.87 + let
1.88 + val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
1.89 + val fun_name = head_of lhs |> fst o dest_Free;
1.90 + val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
1.91 +
1.92 + val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ =>
1.93 + if member ((op =) o apsnd #ctr) ctr_specs ctr
1.94 + then cons (ctr, cs)
1.95 + else primrec_error_eqn "not a constructor" ctr) [] rhs' []
1.96 + |> AList.group (op =);
1.97 +
1.98 + val ctr_premss = map (single o mk_disjs o map mk_conjs o snd) cond_ctrs;
1.99 + val ctr_concls = cond_ctrs |> map (fn (ctr, _) =>
1.100 + binder_types (fastype_of ctr)
1.101 + |> map_index (fn (n, T) => massage_corec_code_rhs lthy (fn _ => fn ctr' => fn args =>
1.102 + if ctr' = ctr then nth args n else Const (@{const_name undefined}, T)) [] rhs')
1.103 + |> curry list_comb ctr
1.104 + |> curry HOLogic.mk_eq lhs);
1.105 + in
1.106 + fold_map2 (co_dissect_eqn_ctr false fun_names corec_specs eqn') ctr_premss ctr_concls matchedsss
1.107 + end;
1.108 +
1.109 +fun co_dissect_eqn lthy has_call seq fun_names (corec_specs : corec_spec list) eqn' of_spec
1.110 + matchedsss =
1.111 let
1.112 val eqn = drop_All eqn'
1.113 handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
1.114 - val (imp_prems, imp_rhs) = Logic.strip_horn eqn
1.115 + val (prems, concl) = Logic.strip_horn eqn
1.116 |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop;
1.117
1.118 - val head = imp_rhs
1.119 + val head = concl
1.120 |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
1.121 |> head_of;
1.122
1.123 - val maybe_rhs = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
1.124 + val maybe_rhs = concl |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
1.125
1.126 val discs = maps #ctr_specs corec_specs |> map #disc;
1.127 val sels = maps #ctr_specs corec_specs |> maps #sels;
1.128 @@ -611,12 +634,17 @@
1.129 if member (op =) discs head orelse
1.130 is_some maybe_rhs andalso
1.131 member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
1.132 - co_dissect_eqn_disc sequential fun_names corec_specs imp_prems imp_rhs matchedsss
1.133 + co_dissect_eqn_disc seq fun_names corec_specs prems concl matchedsss
1.134 |>> single
1.135 else if member (op =) sels head then
1.136 - ([co_dissect_eqn_sel fun_names corec_specs eqn' of_spec imp_rhs], matchedsss)
1.137 - else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
1.138 - co_dissect_eqn_ctr sequential fun_names corec_specs eqn' imp_prems imp_rhs matchedsss
1.139 + ([co_dissect_eqn_sel fun_names corec_specs eqn' of_spec concl], matchedsss)
1.140 + else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
1.141 + member (op =) ctrs (head_of (the maybe_rhs)) then
1.142 + co_dissect_eqn_ctr seq fun_names corec_specs eqn' prems concl matchedsss
1.143 + else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
1.144 + null prems then
1.145 + co_dissect_eqn_code lthy has_call fun_names corec_specs eqn' concl matchedsss
1.146 + |>> flat
1.147 else
1.148 primrec_error_eqn "malformed function equation" eqn
1.149 end;
1.150 @@ -646,7 +674,7 @@
1.151 fun rewrite_g _ t = if has_call t then undef_const else t;
1.152 fun rewrite_h bound_Ts t =
1.153 if has_call t then mk_tuple1 bound_Ts (snd (strip_comb t)) else undef_const;
1.154 - fun massage f t = massage_direct_corec_call lthy has_call f [] rhs_term |> abs_tuple fun_args;
1.155 + fun massage f _ = massage_direct_corec_call lthy has_call f [] rhs_term |> abs_tuple fun_args;
1.156 in
1.157 (massage rewrite_q,
1.158 massage rewrite_g,
1.159 @@ -763,7 +791,7 @@
1.160 chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
1.161 end;
1.162
1.163 -fun add_primcorec simple sequential fixes specs of_specs lthy =
1.164 +fun add_primcorec simple seq fixes specs of_specs lthy =
1.165 let
1.166 val (bs, mxs) = map_split (apfst fst) fixes;
1.167 val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;
1.168 @@ -778,8 +806,10 @@
1.169 val fun_names = map Binding.name_of bs;
1.170 val corec_specs = take actual_nn corec_specs'; (*###*)
1.171
1.172 + val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
1.173 val eqns_data =
1.174 - fold_map2 (co_dissect_eqn sequential fun_names corec_specs) (map snd specs) of_specs []
1.175 + fold_map2 (co_dissect_eqn lthy has_call seq fun_names corec_specs)
1.176 + (map snd specs) of_specs []
1.177 |> flat o fst;
1.178
1.179 val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
1.180 @@ -796,14 +826,13 @@
1.181 |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
1.182 |> map (flat o snd);
1.183
1.184 - val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
1.185 val arg_Tss = map (binder_types o snd o fst) fixes;
1.186 val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss';
1.187 val (defs, exclss') =
1.188 co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;
1.189
1.190 fun excl_tac (c, c', a) =
1.191 - if a orelse c = c' orelse sequential then
1.192 + if a orelse c = c' orelse seq then
1.193 SOME (K (HEADGOAL (mk_primcorec_assumption_tac lthy [])))
1.194 else if simple then
1.195 SOME (K (auto_tac lthy))