1.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Tue Sep 03 21:46:42 2013 +0100
1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Wed Sep 04 02:11:50 2013 +0200
1.3 @@ -36,8 +36,7 @@
1.4
1.5 fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
1.6 |> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n);
1.7 -fun abs_tuple t = if t = undef_const then t else
1.8 - strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
1.9 +val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple;
1.10
1.11 val simp_attrs = @{attributes [simp]};
1.12
1.13 @@ -107,7 +106,7 @@
1.14 user_eqn = eqn'}
1.15 end;
1.16
1.17 -fun rewrite_map_arg fun_name_ctr_pos_list rec_type res_type =
1.18 +fun rewrite_map_arg get_ctr_pos rec_type res_type =
1.19 let
1.20 val pT = HOLogic.mk_prodT (rec_type, res_type);
1.21
1.22 @@ -117,11 +116,9 @@
1.23 | subst d t =
1.24 let
1.25 val (u, vs) = strip_comb t;
1.26 - val maybe_fun_name_ctr_pos =
1.27 - find_first (equal (free_name u) o SOME o fst) fun_name_ctr_pos_list;
1.28 - val (fun_name, ctr_pos) = the_default ("", ~1) maybe_fun_name_ctr_pos;
1.29 + val ctr_pos = try (get_ctr_pos o the) (free_name u) |> the_default ~1;
1.30 in
1.31 - if is_some maybe_fun_name_ctr_pos then
1.32 + if ctr_pos >= 0 then
1.33 if d = SOME ~1 andalso length vs = ctr_pos then
1.34 list_comb (permute_args ctr_pos (snd_const pT), vs)
1.35 else if length vs > ctr_pos andalso is_some d
1.36 @@ -138,7 +135,7 @@
1.37 subst (SOME ~1)
1.38 end;
1.39
1.40 -fun subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls t =
1.41 +fun subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls t =
1.42 let
1.43 fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
1.44 | subst bound_Ts (t as g' $ y) =
1.45 @@ -146,19 +143,18 @@
1.46 val maybe_direct_y' = AList.lookup (op =) direct_calls y;
1.47 val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
1.48 val (g, g_args) = strip_comb g';
1.49 - val maybe_ctr_pos =
1.50 - try (snd o the o find_first (equal (free_name g) o SOME o fst)) fun_name_ctr_pos_list;
1.51 - val _ = is_none maybe_ctr_pos orelse length g_args >= the maybe_ctr_pos orelse
1.52 + val ctr_pos = try (get_ctr_pos o the) (free_name g) |> the_default ~1;
1.53 + val _ = ctr_pos < 0 orelse length g_args >= ctr_pos orelse
1.54 primrec_error_eqn "too few arguments in recursive call" t;
1.55 in
1.56 if not (member (op =) ctr_args y) then
1.57 pairself (subst bound_Ts) (g', y) |> (op $)
1.58 - else if is_some maybe_ctr_pos then
1.59 + else if ctr_pos >= 0 then
1.60 list_comb (the maybe_direct_y', g_args)
1.61 else if is_some maybe_indirect_y' then
1.62 (if has_call g' then t else y)
1.63 |> massage_indirect_rec_call lthy has_call
1.64 - (rewrite_map_arg fun_name_ctr_pos_list) bound_Ts y (the maybe_indirect_y')
1.65 + (rewrite_map_arg get_ctr_pos) bound_Ts y (the maybe_indirect_y')
1.66 |> (if has_call g' then I else curry (op $) g')
1.67 else
1.68 t
1.69 @@ -211,16 +207,17 @@
1.70 nth_map arg_idx (K (nth ctr_args ctr_arg_idx |> map_types make_indirect_type)))
1.71 indirect_calls';
1.72
1.73 + val fun_name_ctr_pos_list =
1.74 + map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
1.75 + val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
1.76 val direct_calls = map (apfst (nth ctr_args) o apsnd (nth args)) direct_calls';
1.77 val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';
1.78
1.79 - val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data);
1.80 - val fun_name_ctr_pos_list =
1.81 - map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
1.82 + val abstractions = args @ #left_args eqn_data @ #right_args eqn_data;
1.83 in
1.84 t
1.85 - |> subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls
1.86 - |> fold_rev absfree abstractions
1.87 + |> subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls
1.88 + |> fold_rev lambda abstractions
1.89 end;
1.90
1.91 fun build_defs lthy bs mxs funs_data rec_specs has_call =
1.92 @@ -372,15 +369,16 @@
1.93
1.94 type co_eqn_data_disc = {
1.95 fun_name: string,
1.96 + fun_args: term list,
1.97 ctr_no: int, (*###*)
1.98 cond: term,
1.99 user_eqn: term
1.100 };
1.101 type co_eqn_data_sel = {
1.102 fun_name: string,
1.103 + fun_args: term list,
1.104 ctr: term,
1.105 sel: term,
1.106 - fun_args: term list,
1.107 rhs_term: term,
1.108 user_eqn: term
1.109 };
1.110 @@ -388,11 +386,10 @@
1.111 Disc of co_eqn_data_disc |
1.112 Sel of co_eqn_data_sel;
1.113
1.114 -fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps =
1.115 +fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds =
1.116 let
1.117 fun find_subterm p = let (* FIXME \<exists>? *)
1.118 - fun f (t as u $ v) =
1.119 - fold_rev (curry merge_options) [if p t then SOME t else NONE, f u, f v] NONE
1.120 + fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
1.121 | f t = if p t then SOME t else NONE
1.122 in f end;
1.123
1.124 @@ -406,9 +403,8 @@
1.125
1.126 val discs = #ctr_specs corec_spec |> map #disc;
1.127 val ctrs = #ctr_specs corec_spec |> map #ctr;
1.128 - val n_ctrs = length ctrs;
1.129 val not_disc = head_of imp_rhs = @{term Not};
1.130 - val _ = not_disc andalso n_ctrs <> 2 andalso
1.131 + val _ = not_disc andalso length ctrs <> 2 andalso
1.132 primrec_error_eqn "\<not>ed discriminator for a type with \<noteq> 2 constructors" imp_rhs;
1.133 val disc = find_subterm (member (op =) discs o head_of) imp_rhs;
1.134 val eq_ctr0 = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (HOLogic.dest_eq #> snd)
1.135 @@ -428,32 +424,28 @@
1.136 val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
1.137 val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
1.138 val catch_all = try (fst o dest_Free o the_single) imp_lhs' = SOME Name.uu_;
1.139 - val matched_conds = filter (equal fun_name o fst) matched_conds_ps |> map snd;
1.140 - val imp_lhs = mk_conjs imp_lhs';
1.141 + val matched_cond = filter (equal fun_name o fst) matched_conds |> map snd |> mk_disjs;
1.142 + val imp_lhs = mk_conjs imp_lhs'
1.143 + |> incr_boundvars (length fun_args)
1.144 + |> subst_atomic (fun_args ~~ map Bound (length fun_args - 1 downto 0))
1.145 val cond =
1.146 if catch_all then
1.147 - if null matched_conds then fold_rev absfree (map dest_Free fun_args) @{const True} else
1.148 - (strip_abs_vars (hd matched_conds),
1.149 - mk_disjs (map strip_abs_body matched_conds) |> HOLogic.mk_not)
1.150 - |-> fold_rev (fn (v, T) => fn u => Abs (v, T, u))
1.151 + matched_cond |> HOLogic.mk_not
1.152 else if sequential then
1.153 - HOLogic.mk_conj (HOLogic.mk_not (mk_disjs (map strip_abs_body matched_conds)), imp_lhs)
1.154 - |> fold_rev absfree (map dest_Free fun_args)
1.155 + HOLogic.mk_conj (HOLogic.mk_not matched_cond, imp_lhs)
1.156 else
1.157 - imp_lhs |> fold_rev absfree (map dest_Free fun_args);
1.158 - val matched_cond =
1.159 - if sequential then fold_rev absfree (map dest_Free fun_args) imp_lhs else cond;
1.160 + imp_lhs;
1.161
1.162 - val matched_conds_ps' = if catch_all
1.163 - then (fun_name, cond) :: filter (not_equal fun_name o fst) matched_conds_ps
1.164 - else (fun_name, matched_cond) :: matched_conds_ps;
1.165 + val matched_conds' =
1.166 + (fun_name, if catch_all orelse not sequential then cond else imp_lhs) :: matched_conds;
1.167 in
1.168 (Disc {
1.169 fun_name = fun_name,
1.170 + fun_args = fun_args,
1.171 ctr_no = ctr_no,
1.172 cond = cond,
1.173 user_eqn = eqn'
1.174 - }, matched_conds_ps')
1.175 + }, matched_conds')
1.176 end;
1.177
1.178 fun co_dissect_eqn_sel fun_name_corec_spec_list eqn' eqn =
1.179 @@ -473,15 +465,15 @@
1.180 in
1.181 Sel {
1.182 fun_name = fun_name,
1.183 + fun_args = fun_args,
1.184 ctr = #ctr ctr_spec,
1.185 sel = sel,
1.186 - fun_args = fun_args,
1.187 rhs_term = rhs,
1.188 user_eqn = eqn'
1.189 }
1.190 end;
1.191
1.192 -fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps =
1.193 +fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds =
1.194 let
1.195 val (lhs, rhs) = HOLogic.dest_eq imp_rhs;
1.196 val fun_name = head_of lhs |> fst o dest_Free;
1.197 @@ -491,10 +483,10 @@
1.198 handle Option.Option => primrec_error_eqn "not a constructor" ctr;
1.199
1.200 val disc_imp_rhs = betapply (#disc ctr_spec, lhs);
1.201 - val (maybe_eqn_data_disc, matched_conds_ps') = if length (#ctr_specs corec_spec) = 1
1.202 - then (NONE, matched_conds_ps)
1.203 + val (maybe_eqn_data_disc, matched_conds') = if length (#ctr_specs corec_spec) = 1
1.204 + then (NONE, matched_conds)
1.205 else apfst SOME (co_dissect_eqn_disc
1.206 - sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps);
1.207 + sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds);
1.208
1.209 val sel_imp_rhss = (#sels ctr_spec ~~ ctr_args)
1.210 |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
1.211 @@ -506,10 +498,10 @@
1.212 val eqns_data_sel =
1.213 map (co_dissect_eqn_sel fun_name_corec_spec_list eqn') sel_imp_rhss;
1.214 in
1.215 - (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds_ps')
1.216 + (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds')
1.217 end;
1.218
1.219 -fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds_ps =
1.220 +fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds =
1.221 let
1.222 val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev,
1.223 strip_qnt_body @{const_name all} eqn')
1.224 @@ -531,65 +523,68 @@
1.225 if member (op =) discs head orelse
1.226 is_some maybe_rhs andalso
1.227 member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
1.228 - co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps
1.229 + co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds
1.230 |>> single
1.231 else if member (op =) sels head then
1.232 - ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds_ps)
1.233 + ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds)
1.234 else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
1.235 - co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps
1.236 + co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds
1.237 else
1.238 primrec_error_eqn "malformed function equation" eqn
1.239 end;
1.240
1.241 fun build_corec_args_discs disc_eqns ctr_specs =
1.242 - let
1.243 - val conds = map #cond disc_eqns;
1.244 - val args' =
1.245 - if length ctr_specs = 1 then []
1.246 - else if length disc_eqns = length ctr_specs then
1.247 - fst (split_last conds)
1.248 - else if length disc_eqns = length ctr_specs - 1 then
1.249 - let val n = 0 upto length ctr_specs - 1
1.250 - |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)) (*###*) in
1.251 - if n = length ctr_specs - 1 then
1.252 - conds
1.253 - else
1.254 - split_last conds
1.255 - ||> (fn t => fold_rev absfree (strip_abs_vars t) (strip_abs_body t |> HOLogic.mk_not))
1.256 - |>> chop n
1.257 - |> (fn ((l, r), x) => l @ (x :: r))
1.258 - end
1.259 - else
1.260 - 0 upto length ctr_specs - 1
1.261 - |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns
1.262 - |> Option.map #cond
1.263 - |> the_default undef_const)
1.264 - |> fst o split_last;
1.265 - in
1.266 - (* FIXME: deal with #preds above *)
1.267 - fold2 (fn idx => nth_map idx o K o abs_tuple) (map_filter #pred ctr_specs) args'
1.268 - end;
1.269 + if null disc_eqns then I else
1.270 + let
1.271 + val conds = map #cond disc_eqns;
1.272 + val fun_args = #fun_args (hd disc_eqns);
1.273 + val args =
1.274 + if length ctr_specs = 1 then []
1.275 + else if length disc_eqns = length ctr_specs then
1.276 + fst (split_last conds)
1.277 + else if length disc_eqns = length ctr_specs - 1 then
1.278 + let val n = 0 upto length ctr_specs - 1
1.279 + |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)) (*###*) in
1.280 + if n = length ctr_specs - 1 then
1.281 + conds
1.282 + else
1.283 + split_last conds
1.284 + ||> HOLogic.mk_not
1.285 + |>> chop n
1.286 + |> (fn ((l, r), x) => l @ (x :: r))
1.287 + end
1.288 + else
1.289 + 0 upto length ctr_specs - 1
1.290 + |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns
1.291 + |> Option.map #cond
1.292 + |> the_default undef_const)
1.293 + |> fst o split_last;
1.294 + in
1.295 + (* FIXME deal with #preds above *)
1.296 + (map_filter #pred ctr_specs, args)
1.297 + |-> fold2 (fn idx => fn t => nth_map idx
1.298 + (K (subst_bounds (List.rev fun_args, t)
1.299 + |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args))))
1.300 + end;
1.301
1.302 fun build_corec_arg_no_call sel_eqns sel = find_first (equal sel o #sel) sel_eqns
1.303 - |> try (fn SOME sel_eqn => (#fun_args sel_eqn |> map dest_Free, #rhs_term sel_eqn))
1.304 + |> try (fn SOME sel_eqn => (#fun_args sel_eqn, #rhs_term sel_eqn))
1.305 |> the_default ([], undef_const)
1.306 - |-> abs_tuple oo fold_rev absfree;
1.307 + |-> abs_tuple;
1.308
1.309 fun build_corec_arg_direct_call lthy has_call sel_eqns sel =
1.310 let
1.311 val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns
1.312 -
1.313 - fun rewrite U T t =
1.314 + fun rewrite is_end U T t =
1.315 if U = @{typ bool} then @{term True} |> has_call t ? K @{term False} (* stop? *)
1.316 - else if T = U = has_call t then undef_const
1.317 - else if T = U then t (* end *)
1.318 + else if is_end = has_call t then undef_const
1.319 + else if is_end then t (* end *)
1.320 else HOLogic.mk_tuple (snd (strip_comb t)); (* continue *)
1.321 - fun massage rhs_term t =
1.322 - massage_direct_corec_call lthy has_call rewrite [] (body_type (fastype_of t)) rhs_term;
1.323 - val abstract = abs_tuple oo fold_rev absfree o map dest_Free;
1.324 + fun massage rhs_term is_end t = massage_direct_corec_call
1.325 + lthy has_call (rewrite is_end) [] (range_type (fastype_of t)) rhs_term;
1.326 in
1.327 - if is_none maybe_sel_eqn then I else
1.328 - massage (#rhs_term (the maybe_sel_eqn)) #> abstract (#fun_args (the maybe_sel_eqn))
1.329 + if is_none maybe_sel_eqn then K I else
1.330 + abs_tuple (#fun_args (the maybe_sel_eqn)) oo massage (#rhs_term (the maybe_sel_eqn))
1.331 end;
1.332
1.333 fun build_corec_arg_indirect_call sel_eqns sel =
1.334 @@ -614,7 +609,7 @@
1.335 (build_corec_arg_no_call sel_eqns sel |> K)) no_calls'
1.336 #> fold (fn (sel, (q, g, h)) =>
1.337 let val f = build_corec_arg_direct_call lthy has_call sel_eqns sel in
1.338 - nth_map h f o nth_map g f o nth_map q f end) direct_calls'
1.339 + nth_map h (f false) o nth_map g (f true) o nth_map q (f true) end) direct_calls'
1.340 #> fold (fn (sel, n) => nth_map n
1.341 (build_corec_arg_indirect_call sel_eqns sel |> K)) indirect_calls'
1.342 end
1.343 @@ -651,24 +646,25 @@
1.344 |> fold2 build_corec_args_discs disc_eqnss ctr_specss
1.345 |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
1.346
1.347 + fun currys Ts t = if length Ts <= 1 then t else
1.348 + t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v)
1.349 + (length Ts - 1 downto 0 |> map Bound)
1.350 + |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts;
1.351 +
1.352 val _ = tracing ("corecursor arguments:\n \<cdot> " ^
1.353 space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) corec_args));
1.354
1.355 fun uneq_pairs_rev xs = xs (* FIXME \<exists>? *)
1.356 |> these o try (split_last #> (fn (ys, y) => uneq_pairs_rev ys @ map (pair y) ys));
1.357 val proof_obligations = if sequential then [] else
1.358 - maps (uneq_pairs_rev o map #cond) disc_eqnss
1.359 - |> map (fn (x, y) => ((strip_abs_body x, strip_abs_body y), strip_abs_vars x))
1.360 - |> map (apfst (apsnd HOLogic.mk_not #> pairself HOLogic.mk_Trueprop
1.361 - #> apfst (curry (op $) @{const ==>}) #> (op $)))
1.362 - |> map (fn (t, abs_vars) => fold_rev (fn (v, T) => fn u =>
1.363 - Const (@{const_name all}, (T --> @{typ prop}) --> @{typ prop}) $
1.364 - Abs (v, T, u)) abs_vars t);
1.365 + maps (uneq_pairs_rev o map (fn {fun_args, cond, ...} => (fun_args, cond))) disc_eqnss
1.366 + |> map (fn ((fun_args, x), (_, y)) => [x, HOLogic.mk_not y]
1.367 + |> map (HOLogic.mk_Trueprop o curry subst_bounds (List.rev fun_args))
1.368 + |> curry list_comb @{const ==>});
1.369
1.370 - fun currys Ts t = if length Ts <= 1 then t else
1.371 - t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v)
1.372 - (length Ts - 1 downto 0 |> map Bound)
1.373 - |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts;
1.374 +val _ = tracing ("proof obligations:\n \<cdot> " ^
1.375 + space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) proof_obligations));
1.376 +
1.377 in
1.378 map (list_comb o rpair corec_args) corecs
1.379 |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss