split 'primrec_new' and 'primcorec' code (to ease bootstrapping, e.g. dependency on datatype 'String' in 'primcorec')
1.1 --- a/src/HOL/BNF/BNF_FP_Base.thy Mon Nov 04 15:44:43 2013 +0100
1.2 +++ b/src/HOL/BNF/BNF_FP_Base.thy Mon Nov 04 16:53:43 2013 +0100
1.3 @@ -172,7 +172,5 @@
1.4 ML_file "Tools/bnf_fp_n2m.ML"
1.5 ML_file "Tools/bnf_fp_n2m_sugar.ML"
1.6 ML_file "Tools/bnf_fp_rec_sugar_util.ML"
1.7 -ML_file "Tools/bnf_fp_rec_sugar_tactics.ML"
1.8 -ML_file "Tools/bnf_fp_rec_sugar.ML"
1.9
1.10 end
2.1 --- a/src/HOL/BNF/BNF_GFP.thy Mon Nov 04 15:44:43 2013 +0100
2.2 +++ b/src/HOL/BNF/BNF_GFP.thy Mon Nov 04 16:53:43 2013 +0100
2.3 @@ -308,6 +308,8 @@
2.4 lemma fun_rel_image2p: "(fun_rel R (image2p f g R)) f g"
2.5 unfolding fun_rel_def image2p_def by auto
2.6
2.7 +ML_file "Tools/bnf_gfp_rec_sugar_tactics.ML"
2.8 +ML_file "Tools/bnf_gfp_rec_sugar.ML"
2.9 ML_file "Tools/bnf_gfp_util.ML"
2.10 ML_file "Tools/bnf_gfp_tactics.ML"
2.11 ML_file "Tools/bnf_gfp.ML"
3.1 --- a/src/HOL/BNF/BNF_LFP.thy Mon Nov 04 15:44:43 2013 +0100
3.2 +++ b/src/HOL/BNF/BNF_LFP.thy Mon Nov 04 16:53:43 2013 +0100
3.3 @@ -230,6 +230,7 @@
3.4 lemma predicate2D_vimage2p: "\<lbrakk>R \<le> vimage2p f g S; R x y\<rbrakk> \<Longrightarrow> S (f x) (g y)"
3.5 unfolding vimage2p_def by auto
3.6
3.7 +ML_file "Tools/bnf_lfp_rec_sugar.ML"
3.8 ML_file "Tools/bnf_lfp_util.ML"
3.9 ML_file "Tools/bnf_lfp_tactics.ML"
3.10 ML_file "Tools/bnf_lfp.ML"
4.1 --- a/src/HOL/BNF/Tools/bnf_def.ML Mon Nov 04 15:44:43 2013 +0100
4.2 +++ b/src/HOL/BNF/Tools/bnf_def.ML Mon Nov 04 16:53:43 2013 +0100
4.3 @@ -81,6 +81,9 @@
4.4 val mk_rel: int -> typ list -> typ list -> term -> term
4.5 val build_map: Proof.context -> (typ * typ -> term) -> typ * typ -> term
4.6 val build_rel: Proof.context -> (typ * typ -> term) -> typ * typ -> term
4.7 + val flatten_type_args_of_bnf: bnf -> 'a -> 'a list -> 'a list
4.8 + val map_flattened_map_args: Proof.context -> string -> (term list -> 'a list) -> term list ->
4.9 + 'a list
4.10
4.11 val mk_witness: int list * term -> thm list -> nonemptiness_witness
4.12 val minimize_wits: (''a list * 'b) list -> (''a list * 'b) list
4.13 @@ -88,8 +91,6 @@
4.14
4.15 val zip_axioms: 'a -> 'a -> 'a -> 'a list -> 'a -> 'a -> 'a list -> 'a -> 'a -> 'a list
4.16
4.17 - val flatten_type_args_of_bnf: bnf -> 'a -> 'a list -> 'a list
4.18 -
4.19 datatype const_policy = Dont_Inline | Hardly_Inline | Smart_Inline | Do_Inline
4.20 datatype fact_policy = Dont_Note | Note_Some | Note_All
4.21
4.22 @@ -524,6 +525,14 @@
4.23 val build_map = build_map_or_rel mk_map HOLogic.id_const map_of_bnf dest_funT;
4.24 val build_rel = build_map_or_rel mk_rel HOLogic.eq_const rel_of_bnf dest_pred2T;
4.25
4.26 +fun map_flattened_map_args ctxt s map_args fs =
4.27 + let
4.28 + val flat_fs = flatten_type_args_of_bnf (the (bnf_of ctxt s)) Term.dummy fs;
4.29 + val flat_fs' = map_args flat_fs;
4.30 + in
4.31 + permute_like (op aconv) flat_fs fs flat_fs'
4.32 + end;
4.33 +
4.34
4.35 (* Names *)
4.36
5.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Mon Nov 04 15:44:43 2013 +0100
5.2 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000
5.3 @@ -1,1128 +0,0 @@
5.4 -(* Title: HOL/BNF/Tools/bnf_fp_rec_sugar.ML
5.5 - Author: Lorenz Panny, TU Muenchen
5.6 - Copyright 2013
5.7 -
5.8 -Recursor and corecursor sugar.
5.9 -*)
5.10 -
5.11 -signature BNF_FP_REC_SUGAR =
5.12 -sig
5.13 - val add_primrec: (binding * typ option * mixfix) list ->
5.14 - (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
5.15 - val add_primrec_cmd: (binding * string option * mixfix) list ->
5.16 - (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
5.17 - val add_primrec_global: (binding * typ option * mixfix) list ->
5.18 - (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
5.19 - val add_primrec_overloaded: (string * (string * typ) * bool) list ->
5.20 - (binding * typ option * mixfix) list ->
5.21 - (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
5.22 - val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
5.23 - local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
5.24 - val add_primcorecursive_cmd: bool ->
5.25 - (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
5.26 - Proof.context -> Proof.state
5.27 - val add_primcorec_cmd: bool ->
5.28 - (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
5.29 - local_theory -> local_theory
5.30 -end;
5.31 -
5.32 -structure BNF_FP_Rec_Sugar : BNF_FP_REC_SUGAR =
5.33 -struct
5.34 -
5.35 -open BNF_Util
5.36 -open BNF_FP_Util
5.37 -open BNF_FP_N2M_Sugar
5.38 -open BNF_FP_Rec_Sugar_Util
5.39 -open BNF_FP_Rec_Sugar_Tactics
5.40 -
5.41 -val codeN = "code";
5.42 -val ctrN = "ctr";
5.43 -val discN = "disc";
5.44 -val selN = "sel";
5.45 -
5.46 -val nitpicksimp_attrs = @{attributes [nitpick_simp]};
5.47 -val simp_attrs = @{attributes [simp]};
5.48 -val code_nitpicksimp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs;
5.49 -val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs;
5.50 -
5.51 -exception Primrec_Error of string * term list;
5.52 -
5.53 -fun primrec_error str = raise Primrec_Error (str, []);
5.54 -fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
5.55 -fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
5.56 -
5.57 -fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
5.58 -
5.59 -val free_name = try (fn Free (v, _) => v);
5.60 -val const_name = try (fn Const (v, _) => v);
5.61 -val undef_const = Const (@{const_name undefined}, dummyT);
5.62 -
5.63 -fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
5.64 - |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n);
5.65 -val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple;
5.66 -fun drop_All t = subst_bounds (strip_qnt_vars @{const_name all} t |> map Free |> rev,
5.67 - strip_qnt_body @{const_name all} t)
5.68 -fun abstract vs =
5.69 - let fun a n (t $ u) = a n t $ a n u
5.70 - | a n (Abs (v, T, b)) = Abs (v, T, a (n + 1) b)
5.71 - | a n t = let val idx = find_index (equal t) vs in
5.72 - if idx < 0 then t else Bound (n + idx) end
5.73 - in a 0 end;
5.74 -fun mk_prod1 Ts (t, u) = HOLogic.pair_const (fastype_of1 (Ts, t)) (fastype_of1 (Ts, u)) $ t $ u;
5.75 -fun mk_tuple1 Ts = the_default HOLogic.unit o try (foldr1 (mk_prod1 Ts));
5.76 -
5.77 -fun get_indices fixes t = map (fst #>> Binding.name_of #> Free) fixes
5.78 - |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
5.79 - |> map_filter I;
5.80 -
5.81 -
5.82 -(* Primrec *)
5.83 -
5.84 -type eqn_data = {
5.85 - fun_name: string,
5.86 - rec_type: typ,
5.87 - ctr: term,
5.88 - ctr_args: term list,
5.89 - left_args: term list,
5.90 - right_args: term list,
5.91 - res_type: typ,
5.92 - rhs_term: term,
5.93 - user_eqn: term
5.94 -};
5.95 -
5.96 -fun dissect_eqn lthy fun_names eqn' =
5.97 - let
5.98 - val eqn = drop_All eqn' |> HOLogic.dest_Trueprop
5.99 - handle TERM _ =>
5.100 - primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
5.101 - val (lhs, rhs) = HOLogic.dest_eq eqn
5.102 - handle TERM _ =>
5.103 - primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
5.104 - val (fun_name, args) = strip_comb lhs
5.105 - |>> (fn x => if is_Free x then fst (dest_Free x)
5.106 - else primrec_error_eqn "malformed function equation (does not start with free)" eqn);
5.107 - val (left_args, rest) = take_prefix is_Free args;
5.108 - val (nonfrees, right_args) = take_suffix is_Free rest;
5.109 - val num_nonfrees = length nonfrees;
5.110 - val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then
5.111 - primrec_error_eqn "constructor pattern missing in left-hand side" eqn else
5.112 - primrec_error_eqn "more than one non-variable argument in left-hand side" eqn;
5.113 - val _ = member (op =) fun_names fun_name orelse
5.114 - primrec_error_eqn "malformed function equation (does not start with function name)" eqn
5.115 -
5.116 - val (ctr, ctr_args) = strip_comb (the_single nonfrees);
5.117 - val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
5.118 - primrec_error_eqn "partially applied constructor in pattern" eqn;
5.119 - val _ = let val d = duplicates (op =) (left_args @ ctr_args @ right_args) in null d orelse
5.120 - primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
5.121 - "\" in left-hand side") eqn end;
5.122 - val _ = forall is_Free ctr_args orelse
5.123 - primrec_error_eqn "non-primitive pattern in left-hand side" eqn;
5.124 - val _ =
5.125 - let val b = fold_aterms (fn x as Free (v, _) =>
5.126 - if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso
5.127 - not (member (op =) fun_names v) andalso
5.128 - not (Variable.is_fixed lthy v)) then cons x else I | _ => I) rhs []
5.129 - in
5.130 - null b orelse
5.131 - primrec_error_eqn ("extra variable(s) in right-hand side: " ^
5.132 - commas (map (Syntax.string_of_term lthy) b)) eqn
5.133 - end;
5.134 - in
5.135 - {fun_name = fun_name,
5.136 - rec_type = body_type (type_of ctr),
5.137 - ctr = ctr,
5.138 - ctr_args = ctr_args,
5.139 - left_args = left_args,
5.140 - right_args = right_args,
5.141 - res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
5.142 - rhs_term = rhs,
5.143 - user_eqn = eqn'}
5.144 - end;
5.145 -
5.146 -fun rewrite_map_arg get_ctr_pos rec_type res_type =
5.147 - let
5.148 - val pT = HOLogic.mk_prodT (rec_type, res_type);
5.149 -
5.150 - val maybe_suc = Option.map (fn x => x + 1);
5.151 - fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
5.152 - | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b)
5.153 - | subst d t =
5.154 - let
5.155 - val (u, vs) = strip_comb t;
5.156 - val ctr_pos = try (get_ctr_pos o the) (free_name u) |> the_default ~1;
5.157 - in
5.158 - if ctr_pos >= 0 then
5.159 - if d = SOME ~1 andalso length vs = ctr_pos then
5.160 - list_comb (permute_args ctr_pos (snd_const pT), vs)
5.161 - else if length vs > ctr_pos andalso is_some d
5.162 - andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
5.163 - list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
5.164 - else
5.165 - primrec_error_eqn ("recursive call not directly applied to constructor argument") t
5.166 - else
5.167 - list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
5.168 - end
5.169 - in
5.170 - subst (SOME ~1)
5.171 - end;
5.172 -
5.173 -fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls =
5.174 - let
5.175 - fun try_nested_rec bound_Ts y t =
5.176 - AList.lookup (op =) nested_calls y
5.177 - |> Option.map (fn y' =>
5.178 - massage_nested_rec_call lthy has_call (rewrite_map_arg get_ctr_pos) bound_Ts y y' t);
5.179 -
5.180 - fun subst bound_Ts (t as g' $ y) =
5.181 - let
5.182 - fun subst_rec () = subst bound_Ts g' $ subst bound_Ts y;
5.183 - val y_head = head_of y;
5.184 - in
5.185 - if not (member (op =) ctr_args y_head) then
5.186 - subst_rec ()
5.187 - else
5.188 - (case try_nested_rec bound_Ts y_head t of
5.189 - SOME t' => t'
5.190 - | NONE =>
5.191 - let val (g, g_args) = strip_comb g' in
5.192 - (case try (get_ctr_pos o the) (free_name g) of
5.193 - SOME ctr_pos =>
5.194 - (length g_args >= ctr_pos orelse
5.195 - primrec_error_eqn "too few arguments in recursive call" t;
5.196 - (case AList.lookup (op =) mutual_calls y of
5.197 - SOME y' => list_comb (y', g_args)
5.198 - | NONE => subst_rec ()))
5.199 - | NONE => subst_rec ())
5.200 - end)
5.201 - end
5.202 - | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
5.203 - | subst _ t = t
5.204 -
5.205 - fun subst' t =
5.206 - if has_call t then
5.207 - (* FIXME detect this case earlier? *)
5.208 - primrec_error_eqn "recursive call not directly applied to constructor argument" t
5.209 - else
5.210 - try_nested_rec [] (head_of t) t |> the_default t
5.211 - in
5.212 - subst' o subst []
5.213 - end;
5.214 -
5.215 -fun build_rec_arg lthy (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec)
5.216 - (maybe_eqn_data : eqn_data option) =
5.217 - (case maybe_eqn_data of
5.218 - NONE => undef_const
5.219 - | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} =>
5.220 - let
5.221 - val calls = #calls ctr_spec;
5.222 - val n_args = fold (Integer.add o (fn Mutual_Rec _ => 2 | _ => 1)) calls 0;
5.223 -
5.224 - val no_calls' = tag_list 0 calls
5.225 - |> map_filter (try (apsnd (fn No_Rec p => p | Mutual_Rec (p, _) => p)));
5.226 - val mutual_calls' = tag_list 0 calls
5.227 - |> map_filter (try (apsnd (fn Mutual_Rec (_, p) => p)));
5.228 - val nested_calls' = tag_list 0 calls
5.229 - |> map_filter (try (apsnd (fn Nested_Rec p => p)));
5.230 -
5.231 - val args = replicate n_args ("", dummyT)
5.232 - |> Term.rename_wrt_term t
5.233 - |> map Free
5.234 - |> fold (fn (ctr_arg_idx, (arg_idx, _)) =>
5.235 - nth_map arg_idx (K (nth ctr_args ctr_arg_idx)))
5.236 - no_calls'
5.237 - |> fold (fn (ctr_arg_idx, (arg_idx, T)) =>
5.238 - nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx))))
5.239 - mutual_calls'
5.240 - |> fold (fn (ctr_arg_idx, (arg_idx, T)) =>
5.241 - nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx))))
5.242 - nested_calls';
5.243 -
5.244 - val fun_name_ctr_pos_list =
5.245 - map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
5.246 - val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
5.247 - val mutual_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) mutual_calls';
5.248 - val nested_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) nested_calls';
5.249 - in
5.250 - t
5.251 - |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls
5.252 - |> fold_rev lambda (args @ left_args @ right_args)
5.253 - end);
5.254 -
5.255 -fun build_defs lthy bs mxs (funs_data : eqn_data list list) (rec_specs : rec_spec list) has_call =
5.256 - let
5.257 - val n_funs = length funs_data;
5.258 -
5.259 - val ctr_spec_eqn_data_list' =
5.260 - (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
5.261 - |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
5.262 - ##> (fn x => null x orelse
5.263 - primrec_error_eqns "excess equations in definition" (map #rhs_term x)) #> fst);
5.264 - val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
5.265 - primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x));
5.266 -
5.267 - val ctr_spec_eqn_data_list =
5.268 - ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
5.269 -
5.270 - val recs = take n_funs rec_specs |> map #recx;
5.271 - val rec_args = ctr_spec_eqn_data_list
5.272 - |> sort ((op <) o pairself (#offset o fst) |> make_ord)
5.273 - |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
5.274 - val ctr_poss = map (fn x =>
5.275 - if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
5.276 - primrec_error ("inconstant constructor pattern position for function " ^
5.277 - quote (#fun_name (hd x)))
5.278 - else
5.279 - hd x |> #left_args |> length) funs_data;
5.280 - in
5.281 - (recs, ctr_poss)
5.282 - |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
5.283 - |> Syntax.check_terms lthy
5.284 - |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t)))
5.285 - bs mxs
5.286 - end;
5.287 -
5.288 -fun find_rec_calls ctxt has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
5.289 - let
5.290 - fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
5.291 - | find bound_Ts (t as _ $ _) ctr_arg =
5.292 - let
5.293 - val typof = curry fastype_of1 bound_Ts;
5.294 - val (f', args') = strip_comb t;
5.295 - val n = find_index (equal ctr_arg o head_of) args';
5.296 - in
5.297 - if n < 0 then
5.298 - find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args'
5.299 - else
5.300 - let
5.301 - val (f, args as arg :: _) = chop n args' |>> curry list_comb f'
5.302 - val (arg_head, arg_args) = Term.strip_comb arg;
5.303 - in
5.304 - if has_call f then
5.305 - mk_partial_compN (length arg_args) (typof f) (typof arg_head) f ::
5.306 - maps (fn x => find bound_Ts x ctr_arg) args
5.307 - else
5.308 - find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args
5.309 - end
5.310 - end
5.311 - | find _ _ _ = [];
5.312 - in
5.313 - map (find [] rhs_term) ctr_args
5.314 - |> (fn [] => NONE | callss => SOME (ctr, callss))
5.315 - end;
5.316 -
5.317 -fun prepare_primrec fixes specs lthy =
5.318 - let
5.319 - val (bs, mxs) = map_split (apfst fst) fixes;
5.320 - val fun_names = map Binding.name_of bs;
5.321 - val eqns_data = map (dissect_eqn lthy fun_names) specs;
5.322 - val funs_data = eqns_data
5.323 - |> partition_eq ((op =) o pairself #fun_name)
5.324 - |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
5.325 - |> map (fn (x, y) => the_single y handle List.Empty =>
5.326 - primrec_error ("missing equations for function " ^ quote x));
5.327 -
5.328 - val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
5.329 - val arg_Ts = map (#rec_type o hd) funs_data;
5.330 - val res_Ts = map (#res_type o hd) funs_data;
5.331 - val callssss = funs_data
5.332 - |> map (partition_eq ((op =) o pairself #ctr))
5.333 - |> map (maps (map_filter (find_rec_calls lthy has_call)));
5.334 -
5.335 - val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') =
5.336 - rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
5.337 -
5.338 - val actual_nn = length funs_data;
5.339 -
5.340 - val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
5.341 - map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
5.342 - primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
5.343 - " is not a constructor in left-hand side") user_eqn) eqns_data end;
5.344 -
5.345 - val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
5.346 -
5.347 - fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
5.348 - (fun_data : eqn_data list) =
5.349 - let
5.350 - val def_thms = map (snd o snd) def_thms';
5.351 - val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
5.352 - |> fst
5.353 - |> map_filter (try (fn (x, [y]) =>
5.354 - (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
5.355 - |> map (fn (user_eqn, num_extra_args, rec_thm) =>
5.356 - mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
5.357 - |> K |> Goal.prove lthy [] [] user_eqn
5.358 - |> Thm.close_derivation);
5.359 - val poss = find_indices (fn (x, y) => #ctr x = #ctr y) fun_data eqns_data;
5.360 - in
5.361 - (poss, simp_thmss)
5.362 - end;
5.363 -
5.364 - val notes =
5.365 - (if n2m then map2 (fn name => fn thm =>
5.366 - (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms) else [])
5.367 - |> map (fn (prefix, thmN, thms, attrs) =>
5.368 - ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])]));
5.369 -
5.370 - val common_name = mk_common_name fun_names;
5.371 -
5.372 - val common_notes =
5.373 - (if n2m then [(inductN, [induct_thm], [])] else [])
5.374 - |> map (fn (thmN, thms, attrs) =>
5.375 - ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
5.376 - in
5.377 - (((fun_names, defs),
5.378 - fn lthy => fn defs =>
5.379 - split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
5.380 - lthy' |> Local_Theory.notes (notes @ common_notes) |> snd)
5.381 - end;
5.382 -
5.383 -(* primrec definition *)
5.384 -
5.385 -fun add_primrec_simple fixes ts lthy =
5.386 - let
5.387 - val (((names, defs), prove), lthy) = prepare_primrec fixes ts lthy
5.388 - handle ERROR str => primrec_error str;
5.389 - in
5.390 - lthy
5.391 - |> fold_map Local_Theory.define defs
5.392 - |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
5.393 - end
5.394 - handle Primrec_Error (str, eqns) =>
5.395 - if null eqns
5.396 - then error ("primrec_new error:\n " ^ str)
5.397 - else error ("primrec_new error:\n " ^ str ^ "\nin\n " ^
5.398 - space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns));
5.399 -
5.400 -local
5.401 -
5.402 -fun gen_primrec prep_spec (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy =
5.403 - let
5.404 - val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
5.405 - val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
5.406 -
5.407 - val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
5.408 -
5.409 - val mk_notes =
5.410 - flat ooo map3 (fn poss => fn prefix => fn thms =>
5.411 - let
5.412 - val (bs, attrss) = map_split (fst o nth specs) poss;
5.413 - val notes =
5.414 - map3 (fn b => fn attrs => fn thm =>
5.415 - ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs), [([thm], [])]))
5.416 - bs attrss thms;
5.417 - in
5.418 - ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
5.419 - end);
5.420 - in
5.421 - lthy
5.422 - |> add_primrec_simple fixes (map snd specs)
5.423 - |-> (fn (names, (ts, (posss, simpss))) =>
5.424 - Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
5.425 - #> Local_Theory.notes (mk_notes posss names simpss)
5.426 - #>> pair ts o map snd)
5.427 - end;
5.428 -
5.429 -in
5.430 -
5.431 -val add_primrec = gen_primrec Specification.check_spec;
5.432 -val add_primrec_cmd = gen_primrec Specification.read_spec;
5.433 -
5.434 -end;
5.435 -
5.436 -fun add_primrec_global fixes specs thy =
5.437 - let
5.438 - val lthy = Named_Target.theory_init thy;
5.439 - val ((ts, simps), lthy') = add_primrec fixes specs lthy;
5.440 - val simps' = burrow (Proof_Context.export lthy' lthy) simps;
5.441 - in ((ts, simps'), Local_Theory.exit_global lthy') end;
5.442 -
5.443 -fun add_primrec_overloaded ops fixes specs thy =
5.444 - let
5.445 - val lthy = Overloading.overloading ops thy;
5.446 - val ((ts, simps), lthy') = add_primrec fixes specs lthy;
5.447 - val simps' = burrow (Proof_Context.export lthy' lthy) simps;
5.448 - in ((ts, simps'), Local_Theory.exit_global lthy') end;
5.449 -
5.450 -
5.451 -
5.452 -(* Primcorec *)
5.453 -
5.454 -type coeqn_data_disc = {
5.455 - fun_name: string,
5.456 - fun_T: typ,
5.457 - fun_args: term list,
5.458 - ctr: term,
5.459 - ctr_no: int, (*###*)
5.460 - disc: term,
5.461 - prems: term list,
5.462 - auto_gen: bool,
5.463 - maybe_ctr_rhs: term option,
5.464 - maybe_code_rhs: term option,
5.465 - user_eqn: term
5.466 -};
5.467 -
5.468 -type coeqn_data_sel = {
5.469 - fun_name: string,
5.470 - fun_T: typ,
5.471 - fun_args: term list,
5.472 - ctr: term,
5.473 - sel: term,
5.474 - rhs_term: term,
5.475 - user_eqn: term
5.476 -};
5.477 -
5.478 -datatype coeqn_data =
5.479 - Disc of coeqn_data_disc |
5.480 - Sel of coeqn_data_sel;
5.481 -
5.482 -fun dissect_coeqn_disc seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list)
5.483 - maybe_ctr_rhs maybe_code_rhs prems' concl matchedsss =
5.484 - let
5.485 - fun find_subterm p = let (* FIXME \<exists>? *)
5.486 - fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
5.487 - | f t = if p t then SOME t else NONE
5.488 - in f end;
5.489 -
5.490 - val applied_fun = concl
5.491 - |> find_subterm (member ((op =) o apsnd SOME) fun_names o try (fst o dest_Free o head_of))
5.492 - |> the
5.493 - handle Option.Option => primrec_error_eqn "malformed discriminator formula" concl;
5.494 - val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free;
5.495 - val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name;
5.496 -
5.497 - val discs = map #disc basic_ctr_specs;
5.498 - val ctrs = map #ctr basic_ctr_specs;
5.499 - val not_disc = head_of concl = @{term Not};
5.500 - val _ = not_disc andalso length ctrs <> 2 andalso
5.501 - primrec_error_eqn "negated discriminator for a type with \<noteq> 2 constructors" concl;
5.502 - val disc' = find_subterm (member (op =) discs o head_of) concl;
5.503 - val eq_ctr0 = concl |> perhaps (try HOLogic.dest_not) |> try (HOLogic.dest_eq #> snd)
5.504 - |> (fn SOME t => let val n = find_index (equal t) ctrs in
5.505 - if n >= 0 then SOME n else NONE end | _ => NONE);
5.506 - val _ = is_some disc' orelse is_some eq_ctr0 orelse
5.507 - primrec_error_eqn "no discriminator in equation" concl;
5.508 - val ctr_no' =
5.509 - if is_none disc' then the eq_ctr0 else find_index (equal (head_of (the disc'))) discs;
5.510 - val ctr_no = if not_disc then 1 - ctr_no' else ctr_no';
5.511 - val {ctr, disc, ...} = nth basic_ctr_specs ctr_no;
5.512 -
5.513 - val catch_all = try (fst o dest_Free o the_single) prems' = SOME Name.uu_;
5.514 - val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default [];
5.515 - val prems = map (abstract (List.rev fun_args)) prems';
5.516 - val real_prems =
5.517 - (if catch_all orelse seq then maps s_not_conj matchedss else []) @
5.518 - (if catch_all then [] else prems);
5.519 -
5.520 - val matchedsss' = AList.delete (op =) fun_name matchedsss
5.521 - |> cons (fun_name, if seq then matchedss @ [prems] else matchedss @ [real_prems]);
5.522 -
5.523 - val user_eqn =
5.524 - (real_prems, concl)
5.525 - |>> map HOLogic.mk_Trueprop ||> HOLogic.mk_Trueprop o abstract (List.rev fun_args)
5.526 - |> curry Logic.list_all (map dest_Free fun_args) o Logic.list_implies;
5.527 - in
5.528 - (Disc {
5.529 - fun_name = fun_name,
5.530 - fun_T = fun_T,
5.531 - fun_args = fun_args,
5.532 - ctr = ctr,
5.533 - ctr_no = ctr_no,
5.534 - disc = disc,
5.535 - prems = real_prems,
5.536 - auto_gen = catch_all,
5.537 - maybe_ctr_rhs = maybe_ctr_rhs,
5.538 - maybe_code_rhs = maybe_code_rhs,
5.539 - user_eqn = user_eqn
5.540 - }, matchedsss')
5.541 - end;
5.542 -
5.543 -fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn'
5.544 - maybe_of_spec eqn =
5.545 - let
5.546 - val (lhs, rhs) = HOLogic.dest_eq eqn
5.547 - handle TERM _ =>
5.548 - primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn;
5.549 - val sel = head_of lhs;
5.550 - val ((fun_name, fun_T), fun_args) = dest_comb lhs |> snd |> strip_comb |> apfst dest_Free
5.551 - handle TERM _ =>
5.552 - primrec_error_eqn "malformed selector argument in left-hand side" eqn;
5.553 - val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name)
5.554 - handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn;
5.555 - val {ctr, ...} =
5.556 - (case maybe_of_spec of
5.557 - SOME of_spec => the (find_first (equal of_spec o #ctr) basic_ctr_specs)
5.558 - | NONE => filter (exists (equal sel) o #sels) basic_ctr_specs |> the_single
5.559 - handle List.Empty => primrec_error_eqn "ambiguous selector - use \"of\"" eqn);
5.560 - val user_eqn = drop_All eqn';
5.561 - in
5.562 - Sel {
5.563 - fun_name = fun_name,
5.564 - fun_T = fun_T,
5.565 - fun_args = fun_args,
5.566 - ctr = ctr,
5.567 - sel = sel,
5.568 - rhs_term = rhs,
5.569 - user_eqn = user_eqn
5.570 - }
5.571 - end;
5.572 -
5.573 -fun dissect_coeqn_ctr seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn'
5.574 - maybe_code_rhs prems concl matchedsss =
5.575 - let
5.576 - val (lhs, rhs) = HOLogic.dest_eq concl;
5.577 - val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
5.578 - val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name;
5.579 - val (ctr, ctr_args) = strip_comb (unfold_let rhs);
5.580 - val {disc, sels, ...} = the (find_first (equal ctr o #ctr) basic_ctr_specs)
5.581 - handle Option.Option => primrec_error_eqn "not a constructor" ctr;
5.582 -
5.583 - val disc_concl = betapply (disc, lhs);
5.584 - val (maybe_eqn_data_disc, matchedsss') = if length basic_ctr_specs = 1
5.585 - then (NONE, matchedsss)
5.586 - else apfst SOME (dissect_coeqn_disc seq fun_names basic_ctr_specss
5.587 - (SOME (abstract (List.rev fun_args) rhs)) maybe_code_rhs prems disc_concl matchedsss);
5.588 -
5.589 - val sel_concls = sels ~~ ctr_args
5.590 - |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
5.591 -
5.592 -(*
5.593 -val _ = tracing ("reduced\n " ^ Syntax.string_of_term @{context} concl ^ "\nto\n \<cdot> " ^
5.594 - (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_concl ^ "\n \<cdot> ")) "" ^
5.595 - space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) sel_concls) ^
5.596 - "\nfor premise(s)\n \<cdot> " ^
5.597 - space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) prems));
5.598 -*)
5.599 -
5.600 - val eqns_data_sel =
5.601 - map (dissect_coeqn_sel fun_names basic_ctr_specss eqn' (SOME ctr)) sel_concls;
5.602 - in
5.603 - (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
5.604 - end;
5.605 -
5.606 -fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss =
5.607 - let
5.608 - val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
5.609 - val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
5.610 - val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name;
5.611 -
5.612 - val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ =>
5.613 - if member ((op =) o apsnd #ctr) basic_ctr_specs ctr
5.614 - then cons (ctr, cs)
5.615 - else primrec_error_eqn "not a constructor" ctr) [] rhs' []
5.616 - |> AList.group (op =);
5.617 -
5.618 - val ctr_premss = (case cond_ctrs of [_] => [[]] | _ => map (s_dnf o snd) cond_ctrs);
5.619 - val ctr_concls = cond_ctrs |> map (fn (ctr, _) =>
5.620 - binder_types (fastype_of ctr)
5.621 - |> map_index (fn (n, T) => massage_corec_code_rhs lthy (fn _ => fn ctr' => fn args =>
5.622 - if ctr' = ctr then nth args n else Const (@{const_name undefined}, T)) [] rhs')
5.623 - |> curry list_comb ctr
5.624 - |> curry HOLogic.mk_eq lhs);
5.625 - in
5.626 - fold_map2 (dissect_coeqn_ctr false fun_names basic_ctr_specss eqn'
5.627 - (SOME (abstract (List.rev fun_args) rhs)))
5.628 - ctr_premss ctr_concls matchedsss
5.629 - end;
5.630 -
5.631 -fun dissect_coeqn lthy seq has_call fun_names (basic_ctr_specss : basic_corec_ctr_spec list list)
5.632 - eqn' maybe_of_spec matchedsss =
5.633 - let
5.634 - val eqn = drop_All eqn'
5.635 - handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
5.636 - val (prems, concl) = Logic.strip_horn eqn
5.637 - |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop;
5.638 -
5.639 - val head = concl
5.640 - |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
5.641 - |> head_of;
5.642 -
5.643 - val maybe_rhs = concl |> perhaps (try HOLogic.dest_not) |> try (snd o HOLogic.dest_eq);
5.644 -
5.645 - val discs = maps (map #disc) basic_ctr_specss;
5.646 - val sels = maps (maps #sels) basic_ctr_specss;
5.647 - val ctrs = maps (map #ctr) basic_ctr_specss;
5.648 - in
5.649 - if member (op =) discs head orelse
5.650 - is_some maybe_rhs andalso
5.651 - member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
5.652 - dissect_coeqn_disc seq fun_names basic_ctr_specss NONE NONE prems concl matchedsss
5.653 - |>> single
5.654 - else if member (op =) sels head then
5.655 - ([dissect_coeqn_sel fun_names basic_ctr_specss eqn' maybe_of_spec concl], matchedsss)
5.656 - else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
5.657 - member (op =) ctrs (head_of (unfold_let (the maybe_rhs))) then
5.658 - dissect_coeqn_ctr seq fun_names basic_ctr_specss eqn' NONE prems concl matchedsss
5.659 - else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
5.660 - null prems then
5.661 - dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss
5.662 - |>> flat
5.663 - else
5.664 - primrec_error_eqn "malformed function equation" eqn
5.665 - end;
5.666 -
5.667 -fun build_corec_arg_disc (ctr_specs : corec_ctr_spec list)
5.668 - ({fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
5.669 - if is_none (#pred (nth ctr_specs ctr_no)) then I else
5.670 - s_conjs prems
5.671 - |> curry subst_bounds (List.rev fun_args)
5.672 - |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args)
5.673 - |> K |> nth_map (the (#pred (nth ctr_specs ctr_no)));
5.674 -
5.675 -fun build_corec_arg_no_call (sel_eqns : coeqn_data_sel list) sel =
5.676 - find_first (equal sel o #sel) sel_eqns
5.677 - |> try (fn SOME {fun_args, rhs_term, ...} => abs_tuple fun_args rhs_term)
5.678 - |> the_default undef_const
5.679 - |> K;
5.680 -
5.681 -fun build_corec_args_mutual_call lthy has_call (sel_eqns : coeqn_data_sel list) sel =
5.682 - (case find_first (equal sel o #sel) sel_eqns of
5.683 - NONE => (I, I, I)
5.684 - | SOME {fun_args, rhs_term, ... } =>
5.685 - let
5.686 - val bound_Ts = List.rev (map fastype_of fun_args);
5.687 - fun rewrite_stop _ t = if has_call t then @{term False} else @{term True};
5.688 - fun rewrite_end _ t = if has_call t then undef_const else t;
5.689 - fun rewrite_cont bound_Ts t =
5.690 - if has_call t then mk_tuple1 bound_Ts (snd (strip_comb t)) else undef_const;
5.691 - fun massage f _ = massage_mutual_corec_call lthy has_call f bound_Ts rhs_term
5.692 - |> abs_tuple fun_args;
5.693 - in
5.694 - (massage rewrite_stop, massage rewrite_end, massage rewrite_cont)
5.695 - end);
5.696 -
5.697 -fun build_corec_arg_nested_call lthy has_call (sel_eqns : coeqn_data_sel list) sel =
5.698 - (case find_first (equal sel o #sel) sel_eqns of
5.699 - NONE => I
5.700 - | SOME {fun_args, rhs_term, ...} =>
5.701 - let
5.702 - val bound_Ts = List.rev (map fastype_of fun_args);
5.703 - fun rewrite bound_Ts U T (Abs (v, V, b)) = Abs (v, V, rewrite (V :: bound_Ts) U T b)
5.704 - | rewrite bound_Ts U T (t as _ $ _) =
5.705 - let val (u, vs) = strip_comb t in
5.706 - if is_Free u andalso has_call u then
5.707 - Inr_const U T $ mk_tuple1 bound_Ts vs
5.708 - else if const_name u = SOME @{const_name prod_case} then
5.709 - map (rewrite bound_Ts U T) vs |> chop 1 |>> HOLogic.mk_split o the_single |> list_comb
5.710 - else
5.711 - list_comb (rewrite bound_Ts U T u, map (rewrite bound_Ts U T) vs)
5.712 - end
5.713 - | rewrite _ U T t =
5.714 - if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t;
5.715 - fun massage t =
5.716 - rhs_term
5.717 - |> massage_nested_corec_call lthy has_call rewrite bound_Ts (range_type (fastype_of t))
5.718 - |> abs_tuple fun_args;
5.719 - in
5.720 - massage
5.721 - end);
5.722 -
5.723 -fun build_corec_args_sel lthy has_call (all_sel_eqns : coeqn_data_sel list)
5.724 - (ctr_spec : corec_ctr_spec) =
5.725 - (case filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns of
5.726 - [] => I
5.727 - | sel_eqns =>
5.728 - let
5.729 - val sel_call_list = #sels ctr_spec ~~ #calls ctr_spec;
5.730 - val no_calls' = map_filter (try (apsnd (fn No_Corec n => n))) sel_call_list;
5.731 - val mutual_calls' = map_filter (try (apsnd (fn Mutual_Corec n => n))) sel_call_list;
5.732 - val nested_calls' = map_filter (try (apsnd (fn Nested_Corec n => n))) sel_call_list;
5.733 - in
5.734 - I
5.735 - #> fold (fn (sel, n) => nth_map n (build_corec_arg_no_call sel_eqns sel)) no_calls'
5.736 - #> fold (fn (sel, (q, g, h)) =>
5.737 - let val (fq, fg, fh) = build_corec_args_mutual_call lthy has_call sel_eqns sel in
5.738 - nth_map q fq o nth_map g fg o nth_map h fh end) mutual_calls'
5.739 - #> fold (fn (sel, n) => nth_map n
5.740 - (build_corec_arg_nested_call lthy has_call sel_eqns sel)) nested_calls'
5.741 - end);
5.742 -
5.743 -fun build_codefs lthy bs mxs has_call arg_Tss (corec_specs : corec_spec list)
5.744 - (disc_eqnss : coeqn_data_disc list list) (sel_eqnss : coeqn_data_sel list list) =
5.745 - let
5.746 - val corecs = map #corec corec_specs;
5.747 - val ctr_specss = map #ctr_specs corec_specs;
5.748 - val corec_args = hd corecs
5.749 - |> fst o split_last o binder_types o fastype_of
5.750 - |> map (Const o pair @{const_name undefined})
5.751 - |> fold2 (fold o build_corec_arg_disc) ctr_specss disc_eqnss
5.752 - |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
5.753 - fun currys [] t = t
5.754 - | currys Ts t = t $ mk_tuple1 (List.rev Ts) (map Bound (length Ts - 1 downto 0))
5.755 - |> fold_rev (Term.abs o pair Name.uu) Ts;
5.756 -
5.757 -(*
5.758 -val _ = tracing ("corecursor arguments:\n \<cdot> " ^
5.759 - space_implode "\n \<cdot> " (map (Syntax.string_of_term lthy) corec_args));
5.760 -*)
5.761 -
5.762 - val exclss' =
5.763 - disc_eqnss
5.764 - |> map (map (fn x => (#fun_args x, #ctr_no x, #prems x, #auto_gen x))
5.765 - #> fst o (fn xs => fold_map (fn x => fn ys => ((x, ys), ys @ [x])) xs [])
5.766 - #> maps (uncurry (map o pair)
5.767 - #> map (fn ((fun_args, c, x, a), (_, c', y, a')) =>
5.768 - ((c, c', a orelse a'), (x, s_not (s_conjs y)))
5.769 - ||> apfst (map HOLogic.mk_Trueprop) o apsnd HOLogic.mk_Trueprop
5.770 - ||> Logic.list_implies
5.771 - ||> curry Logic.list_all (map dest_Free fun_args))))
5.772 - in
5.773 - map (list_comb o rpair corec_args) corecs
5.774 - |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss
5.775 - |> map2 currys arg_Tss
5.776 - |> Syntax.check_terms lthy
5.777 - |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t)))
5.778 - bs mxs
5.779 - |> rpair exclss'
5.780 - end;
5.781 -
5.782 -fun mk_real_disc_eqns fun_binding arg_Ts ({ctr_specs, ...} : corec_spec)
5.783 - (sel_eqns : coeqn_data_sel list) (disc_eqns : coeqn_data_disc list) =
5.784 - if length disc_eqns <> length ctr_specs - 1 then disc_eqns else
5.785 - let
5.786 - val n = 0 upto length ctr_specs
5.787 - |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
5.788 - val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
5.789 - |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
5.790 - val extra_disc_eqn = {
5.791 - fun_name = Binding.name_of fun_binding,
5.792 - fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
5.793 - fun_args = fun_args,
5.794 - ctr = #ctr (nth ctr_specs n),
5.795 - ctr_no = n,
5.796 - disc = #disc (nth ctr_specs n),
5.797 - prems = maps (s_not_conj o #prems) disc_eqns,
5.798 - auto_gen = true,
5.799 - maybe_ctr_rhs = NONE,
5.800 - maybe_code_rhs = NONE,
5.801 - user_eqn = undef_const};
5.802 - in
5.803 - chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
5.804 - end;
5.805 -
5.806 -fun find_corec_calls ctxt has_call basic_ctr_specs ({ctr, sel, rhs_term, ...} : coeqn_data_sel) =
5.807 - let
5.808 - val sel_no = find_first (equal ctr o #ctr) basic_ctr_specs
5.809 - |> find_index (equal sel) o #sels o the;
5.810 - fun find t = if has_call t then snd (fold_rev_corec_call ctxt (K cons) [] t []) else [];
5.811 - in
5.812 - find rhs_term
5.813 - |> K |> nth_map sel_no |> AList.map_entry (op =) ctr
5.814 - end;
5.815 -
5.816 -fun add_primcorec_ursive maybe_tac seq fixes specs maybe_of_specs lthy =
5.817 - let
5.818 - val (bs, mxs) = map_split (apfst fst) fixes;
5.819 - val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;
5.820 -
5.821 - val fun_names = map Binding.name_of bs;
5.822 - val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts;
5.823 - val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
5.824 - val eqns_data =
5.825 - fold_map2 (dissect_coeqn lthy seq has_call fun_names basic_ctr_specss) (map snd specs)
5.826 - maybe_of_specs []
5.827 - |> flat o fst;
5.828 -
5.829 - val callssss =
5.830 - map_filter (try (fn Sel x => x)) eqns_data
5.831 - |> partition_eq ((op =) o pairself #fun_name)
5.832 - |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
5.833 - |> map (flat o snd)
5.834 - |> map2 (fold o find_corec_calls lthy has_call) basic_ctr_specss
5.835 - |> map2 (curry (op |>)) (map (map (fn {ctr, sels, ...} =>
5.836 - (ctr, map (K []) sels))) basic_ctr_specss);
5.837 -
5.838 -(*
5.839 -val _ = tracing ("callssss = " ^ @{make_string} callssss);
5.840 -*)
5.841 -
5.842 - val ((n2m, corec_specs', _, coinduct_thm, strong_coinduct_thm, coinduct_thms,
5.843 - strong_coinduct_thms), lthy') =
5.844 - corec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
5.845 - val actual_nn = length bs;
5.846 - val corec_specs = take actual_nn corec_specs'; (*###*)
5.847 - val ctr_specss = map #ctr_specs corec_specs;
5.848 -
5.849 - val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
5.850 - |> partition_eq ((op =) o pairself #fun_name)
5.851 - |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
5.852 - |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd);
5.853 - val _ = disc_eqnss' |> map (fn x =>
5.854 - let val d = duplicates ((op =) o pairself #ctr_no) x in null d orelse
5.855 - primrec_error_eqns "excess discriminator formula in definition"
5.856 - (maps (fn t => filter (equal (#ctr_no t) o #ctr_no) x) d |> map #user_eqn) end);
5.857 -
5.858 - val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data
5.859 - |> partition_eq ((op =) o pairself #fun_name)
5.860 - |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
5.861 - |> map (flat o snd);
5.862 -
5.863 - val arg_Tss = map (binder_types o snd o fst) fixes;
5.864 - val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss';
5.865 - val (defs, exclss') =
5.866 - build_codefs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;
5.867 -
5.868 - fun excl_tac (c, c', a) =
5.869 - if a orelse c = c' orelse seq then SOME (K (HEADGOAL (mk_primcorec_assumption_tac lthy [])))
5.870 - else maybe_tac;
5.871 -
5.872 -(*
5.873 -val _ = tracing ("exclusiveness properties:\n \<cdot> " ^
5.874 - space_implode "\n \<cdot> " (maps (map (Syntax.string_of_term lthy o snd)) exclss'));
5.875 -*)
5.876 -
5.877 - val exclss'' = exclss' |> map (map (fn (idx, t) =>
5.878 - (idx, (Option.map (Goal.prove lthy [] [] t #> Thm.close_derivation) (excl_tac idx), t))));
5.879 - val taut_thmss = map (map (apsnd (the o fst)) o filter (is_some o fst o snd)) exclss'';
5.880 - val (goal_idxss, goalss) = exclss''
5.881 - |> map (map (apsnd (rpair [] o snd)) o filter (is_none o fst o snd))
5.882 - |> split_list o map split_list;
5.883 -
5.884 - fun prove thmss' def_thms' lthy =
5.885 - let
5.886 - val def_thms = map (snd o snd) def_thms';
5.887 -
5.888 - val exclss' = map (op ~~) (goal_idxss ~~ thmss');
5.889 - fun mk_exclsss excls n =
5.890 - (excls, map (fn k => replicate k [TrueI] @ replicate (n - k) []) (0 upto n - 1))
5.891 - |-> fold (fn ((c, c', _), thm) => nth_map c (nth_map c' (K [thm])));
5.892 - val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs)
5.893 - |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs));
5.894 -
5.895 - fun prove_disc ({ctr_specs, ...} : corec_spec) exclsss
5.896 - ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
5.897 - if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\<lambda>x. x = x"}) then [] else
5.898 - let
5.899 - val {disc_corec, ...} = nth ctr_specs ctr_no;
5.900 - val k = 1 + ctr_no;
5.901 - val m = length prems;
5.902 - val t =
5.903 - list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
5.904 - |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*)
5.905 - |> HOLogic.mk_Trueprop
5.906 - |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
5.907 - |> curry Logic.list_all (map dest_Free fun_args);
5.908 - in
5.909 - if prems = [@{term False}] then [] else
5.910 - mk_primcorec_disc_tac lthy def_thms disc_corec k m exclsss
5.911 - |> K |> Goal.prove lthy [] [] t
5.912 - |> Thm.close_derivation
5.913 - |> pair (#disc (nth ctr_specs ctr_no))
5.914 - |> single
5.915 - end;
5.916 -
5.917 - fun prove_sel ({nested_maps, nested_map_idents, nested_map_comps, ctr_specs, ...}
5.918 - : corec_spec) (disc_eqns : coeqn_data_disc list) exclsss
5.919 - ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, ...} : coeqn_data_sel) =
5.920 - let
5.921 - val SOME ctr_spec = find_first (equal ctr o #ctr) ctr_specs;
5.922 - val ctr_no = find_index (equal ctr o #ctr) ctr_specs;
5.923 - val prems = the_default (maps (s_not_conj o #prems) disc_eqns)
5.924 - (find_first (equal ctr_no o #ctr_no) disc_eqns |> Option.map #prems);
5.925 - val sel_corec = find_index (equal sel) (#sels ctr_spec)
5.926 - |> nth (#sel_corecs ctr_spec);
5.927 - val k = 1 + ctr_no;
5.928 - val m = length prems;
5.929 - val t =
5.930 - list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
5.931 - |> curry betapply sel
5.932 - |> rpair (abstract (List.rev fun_args) rhs_term)
5.933 - |> HOLogic.mk_Trueprop o HOLogic.mk_eq
5.934 - |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
5.935 - |> curry Logic.list_all (map dest_Free fun_args);
5.936 - val (distincts, _, sel_splits, sel_split_asms) = case_thms_of_term lthy [] rhs_term;
5.937 - in
5.938 - mk_primcorec_sel_tac lthy def_thms distincts sel_splits sel_split_asms nested_maps
5.939 - nested_map_idents nested_map_comps sel_corec k m exclsss
5.940 - |> K |> Goal.prove lthy [] [] t
5.941 - |> Thm.close_derivation
5.942 - |> pair sel
5.943 - end;
5.944 -
5.945 - fun prove_ctr disc_alist sel_alist (disc_eqns : coeqn_data_disc list)
5.946 - (sel_eqns : coeqn_data_sel list) ({ctr, disc, sels, collapse, ...} : corec_ctr_spec) =
5.947 - (* don't try to prove theorems when some sel_eqns are missing *)
5.948 - if not (exists (equal ctr o #ctr) disc_eqns)
5.949 - andalso not (exists (equal ctr o #ctr) sel_eqns)
5.950 - orelse
5.951 - filter (equal ctr o #ctr) sel_eqns
5.952 - |> fst o finds ((op =) o apsnd #sel) sels
5.953 - |> exists (null o snd)
5.954 - then [] else
5.955 - let
5.956 - val (fun_name, fun_T, fun_args, prems, maybe_rhs) =
5.957 - (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
5.958 - |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x,
5.959 - #maybe_ctr_rhs x))
5.960 - ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], NONE))
5.961 - |> the o merge_options;
5.962 - val m = length prems;
5.963 - val t = (if is_some maybe_rhs then the maybe_rhs else
5.964 - filter (equal ctr o #ctr) sel_eqns
5.965 - |> fst o finds ((op =) o apsnd #sel) sels
5.966 - |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
5.967 - |> curry list_comb ctr)
5.968 - |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
5.969 - map Bound (length fun_args - 1 downto 0)))
5.970 - |> HOLogic.mk_Trueprop
5.971 - |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
5.972 - |> curry Logic.list_all (map dest_Free fun_args);
5.973 - val maybe_disc_thm = AList.lookup (op =) disc_alist disc;
5.974 - val sel_thms = map snd (filter (member (op =) sels o fst) sel_alist);
5.975 - in
5.976 - if prems = [@{term False}] then [] else
5.977 - mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
5.978 - |> K |> Goal.prove lthy [] [] t
5.979 - |> Thm.close_derivation
5.980 - |> pair ctr
5.981 - |> single
5.982 - end;
5.983 -
5.984 - fun prove_code disc_eqns sel_eqns ctr_alist ctr_specs =
5.985 - let
5.986 - val (fun_name, fun_T, fun_args, maybe_rhs) =
5.987 - (find_first (member (op =) (map #ctr ctr_specs) o #ctr) disc_eqns,
5.988 - find_first (member (op =) (map #ctr ctr_specs) o #ctr) sel_eqns)
5.989 - |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #maybe_code_rhs x))
5.990 - ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, NONE))
5.991 - |> the o merge_options;
5.992 -
5.993 - val bound_Ts = List.rev (map fastype_of fun_args);
5.994 -
5.995 - val lhs = list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0));
5.996 - val maybe_rhs_info =
5.997 - (case maybe_rhs of
5.998 - SOME rhs =>
5.999 - let
5.1000 - val raw_rhs = expand_corec_code_rhs lthy has_call bound_Ts rhs;
5.1001 - val cond_ctrs =
5.1002 - fold_rev_corec_code_rhs lthy (K oo (cons oo pair)) bound_Ts raw_rhs [];
5.1003 - val ctr_thms = map (the o AList.lookup (op =) ctr_alist o snd) cond_ctrs;
5.1004 - in SOME (rhs, raw_rhs, ctr_thms) end
5.1005 - | NONE =>
5.1006 - let
5.1007 - fun prove_code_ctr {ctr, sels, ...} =
5.1008 - if not (exists (equal ctr o fst) ctr_alist) then NONE else
5.1009 - let
5.1010 - val prems = find_first (equal ctr o #ctr) disc_eqns
5.1011 - |> Option.map #prems |> the_default [];
5.1012 - val t =
5.1013 - filter (equal ctr o #ctr) sel_eqns
5.1014 - |> fst o finds ((op =) o apsnd #sel) sels
5.1015 - |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x))
5.1016 - #-> abstract)
5.1017 - |> curry list_comb ctr;
5.1018 - in
5.1019 - SOME (prems, t)
5.1020 - end;
5.1021 - val maybe_ctr_conds_argss = map prove_code_ctr ctr_specs;
5.1022 - in
5.1023 - if exists is_none maybe_ctr_conds_argss then NONE else
5.1024 - let
5.1025 - val rhs = fold_rev (fn SOME (prems, u) => fn t => mk_If (s_conjs prems) u t)
5.1026 - maybe_ctr_conds_argss
5.1027 - (Const (@{const_name Code.abort}, @{typ String.literal} -->
5.1028 - (@{typ unit} --> body_type fun_T) --> body_type fun_T) $
5.1029 - HOLogic.mk_literal fun_name $
5.1030 - absdummy @{typ unit} (incr_boundvars 1 lhs));
5.1031 - in SOME (rhs, rhs, map snd ctr_alist) end
5.1032 - end);
5.1033 - in
5.1034 - (case maybe_rhs_info of
5.1035 - NONE => []
5.1036 - | SOME (rhs, raw_rhs, ctr_thms) =>
5.1037 - let
5.1038 - val ms = map (Logic.count_prems o prop_of) ctr_thms;
5.1039 - val (raw_t, t) = (raw_rhs, rhs)
5.1040 - |> pairself
5.1041 - (curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
5.1042 - map Bound (length fun_args - 1 downto 0)))
5.1043 - #> HOLogic.mk_Trueprop
5.1044 - #> curry Logic.list_all (map dest_Free fun_args));
5.1045 - val (distincts, discIs, sel_splits, sel_split_asms) =
5.1046 - case_thms_of_term lthy bound_Ts raw_rhs;
5.1047 -
5.1048 - val raw_code_thm = mk_primcorec_raw_code_of_ctr_tac lthy distincts discIs sel_splits
5.1049 - sel_split_asms ms ctr_thms
5.1050 - |> K |> Goal.prove lthy [] [] raw_t
5.1051 - |> Thm.close_derivation;
5.1052 - in
5.1053 - mk_primcorec_code_of_raw_code_tac lthy distincts sel_splits raw_code_thm
5.1054 - |> K |> Goal.prove lthy [] [] t
5.1055 - |> Thm.close_derivation
5.1056 - |> single
5.1057 - end)
5.1058 - end;
5.1059 -
5.1060 - val disc_alists = map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss;
5.1061 - val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss exclssss sel_eqnss;
5.1062 - val disc_thmss = map (map snd) disc_alists;
5.1063 - val sel_thmss = map (map snd) sel_alists;
5.1064 -
5.1065 - val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss
5.1066 - ctr_specss;
5.1067 - val ctr_thmss = map (map snd) ctr_alists;
5.1068 -
5.1069 - val code_thmss = map4 prove_code disc_eqnss sel_eqnss ctr_alists ctr_specss;
5.1070 -
5.1071 - val simp_thmss = map2 append disc_thmss sel_thmss
5.1072 -
5.1073 - val common_name = mk_common_name fun_names;
5.1074 -
5.1075 - val notes =
5.1076 - [(coinductN, map (if n2m then single else K []) coinduct_thms, []),
5.1077 - (codeN, code_thmss, code_nitpicksimp_attrs),
5.1078 - (ctrN, ctr_thmss, []),
5.1079 - (discN, disc_thmss, simp_attrs),
5.1080 - (selN, sel_thmss, simp_attrs),
5.1081 - (simpsN, simp_thmss, []),
5.1082 - (strong_coinductN, map (if n2m then single else K []) strong_coinduct_thms, [])]
5.1083 - |> maps (fn (thmN, thmss, attrs) =>
5.1084 - map2 (fn fun_name => fn thms =>
5.1085 - ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])]))
5.1086 - fun_names (take actual_nn thmss))
5.1087 - |> filter_out (null o fst o hd o snd);
5.1088 -
5.1089 - val common_notes =
5.1090 - [(coinductN, if n2m then [coinduct_thm] else [], []),
5.1091 - (strong_coinductN, if n2m then [strong_coinduct_thm] else [], [])]
5.1092 - |> filter_out (null o #2)
5.1093 - |> map (fn (thmN, thms, attrs) =>
5.1094 - ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
5.1095 - in
5.1096 - lthy |> Local_Theory.notes (notes @ common_notes) |> snd
5.1097 - end;
5.1098 -
5.1099 - fun after_qed thmss' = fold_map Local_Theory.define defs #-> prove thmss';
5.1100 - in
5.1101 - (goalss, after_qed, lthy')
5.1102 - end;
5.1103 -
5.1104 -fun add_primcorec_ursive_cmd maybe_tac seq (raw_fixes, raw_specs') lthy =
5.1105 - let
5.1106 - val (raw_specs, maybe_of_specs) =
5.1107 - split_list raw_specs' ||> map (Option.map (Syntax.read_term lthy));
5.1108 - val ((fixes, specs), _) = Specification.read_spec raw_fixes raw_specs lthy;
5.1109 - in
5.1110 - add_primcorec_ursive maybe_tac seq fixes specs maybe_of_specs lthy
5.1111 - handle ERROR str => primrec_error str
5.1112 - end
5.1113 - handle Primrec_Error (str, eqns) =>
5.1114 - if null eqns
5.1115 - then error ("primcorec error:\n " ^ str)
5.1116 - else error ("primcorec error:\n " ^ str ^ "\nin\n " ^
5.1117 - space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns));
5.1118 -
5.1119 -val add_primcorecursive_cmd = (fn (goalss, after_qed, lthy) =>
5.1120 - lthy
5.1121 - |> Proof.theorem NONE after_qed goalss
5.1122 - |> Proof.refine (Method.primitive_text I)
5.1123 - |> Seq.hd) ooo add_primcorec_ursive_cmd NONE;
5.1124 -
5.1125 -val add_primcorec_cmd = (fn (goalss, after_qed, lthy) =>
5.1126 - lthy
5.1127 - |> after_qed (map (fn [] => []
5.1128 - | _ => primrec_error "need exclusiveness proofs - use primcorecursive instead of primcorec")
5.1129 - goalss)) ooo add_primcorec_ursive_cmd (SOME (fn {context = ctxt, ...} => auto_tac ctxt));
5.1130 -
5.1131 -end;
6.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_tactics.ML Mon Nov 04 15:44:43 2013 +0100
6.2 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000
6.3 @@ -1,142 +0,0 @@
6.4 -(* Title: HOL/BNF/Tools/bnf_fp_rec_sugar_tactics.ML
6.5 - Author: Jasmin Blanchette, TU Muenchen
6.6 - Copyright 2013
6.7 -
6.8 -Tactics for recursor and corecursor sugar.
6.9 -*)
6.10 -
6.11 -signature BNF_FP_REC_SUGAR_TACTICS =
6.12 -sig
6.13 - val mk_primcorec_assumption_tac: Proof.context -> thm list -> int -> tactic
6.14 - val mk_primcorec_code_of_raw_code_tac: Proof.context -> thm list -> thm list -> thm -> tactic
6.15 - val mk_primcorec_ctr_of_dtr_tac: Proof.context -> int -> thm -> thm option -> thm list -> tactic
6.16 - val mk_primcorec_disc_tac: Proof.context -> thm list -> thm -> int -> int -> thm list list list ->
6.17 - tactic
6.18 - val mk_primcorec_raw_code_of_ctr_tac: Proof.context -> thm list -> thm list -> thm list ->
6.19 - thm list -> int list -> thm list -> tactic
6.20 - val mk_primcorec_sel_tac: Proof.context -> thm list -> thm list -> thm list -> thm list ->
6.21 - thm list -> thm list -> thm list -> thm -> int -> int -> thm list list list -> tactic
6.22 - val mk_primrec_tac: Proof.context -> int -> thm list -> thm list -> thm list -> thm -> tactic
6.23 -end;
6.24 -
6.25 -structure BNF_FP_Rec_Sugar_Tactics : BNF_FP_REC_SUGAR_TACTICS =
6.26 -struct
6.27 -
6.28 -open BNF_Util
6.29 -open BNF_Tactics
6.30 -
6.31 -val falseEs = @{thms not_TrueE FalseE};
6.32 -val Let_def = @{thm Let_def};
6.33 -val neq_eq_eq_contradict = @{thm neq_eq_eq_contradict};
6.34 -val split_if = @{thm split_if};
6.35 -val split_if_asm = @{thm split_if_asm};
6.36 -val split_connectI = @{thms allI impI conjI};
6.37 -
6.38 -fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx =
6.39 - unfold_thms_tac ctxt fun_defs THEN
6.40 - HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
6.41 - unfold_thms_tac ctxt (@{thms id_def split o_def fst_conv snd_conv} @ map_comps @ map_idents) THEN
6.42 - HEADGOAL (rtac refl);
6.43 -
6.44 -fun mk_primcorec_assumption_tac ctxt discIs =
6.45 - SELECT_GOAL (unfold_thms_tac ctxt
6.46 - @{thms not_not not_False_eq_True not_True_eq_False de_Morgan_conj de_Morgan_disj} THEN
6.47 - SOLVE (HEADGOAL (REPEAT o (rtac refl ORELSE' atac ORELSE' etac conjE ORELSE'
6.48 - eresolve_tac falseEs ORELSE'
6.49 - resolve_tac @{thms TrueI conjI disjI1 disjI2} ORELSE'
6.50 - dresolve_tac discIs THEN' atac ORELSE'
6.51 - etac notE THEN' atac ORELSE'
6.52 - etac disjE))));
6.53 -
6.54 -fun mk_primcorec_same_case_tac m =
6.55 - HEADGOAL (if m = 0 then rtac TrueI
6.56 - else REPEAT_DETERM_N (m - 1) o (rtac conjI THEN' atac) THEN' atac);
6.57 -
6.58 -fun mk_primcorec_different_case_tac ctxt m excl =
6.59 - HEADGOAL (if m = 0 then mk_primcorec_assumption_tac ctxt []
6.60 - else dtac excl THEN' (REPEAT_DETERM_N (m - 1) o atac) THEN' mk_primcorec_assumption_tac ctxt []);
6.61 -
6.62 -fun mk_primcorec_cases_tac ctxt k m exclsss =
6.63 - let val n = length exclsss in
6.64 - EVERY (map (fn [] => if k = n then all_tac else mk_primcorec_same_case_tac m
6.65 - | [excl] => mk_primcorec_different_case_tac ctxt m excl)
6.66 - (take k (nth exclsss (k - 1))))
6.67 - end;
6.68 -
6.69 -fun mk_primcorec_prelude ctxt defs thm =
6.70 - unfold_thms_tac ctxt defs THEN HEADGOAL (rtac thm) THEN
6.71 - unfold_thms_tac ctxt @{thms Let_def split};
6.72 -
6.73 -fun mk_primcorec_disc_tac ctxt defs disc_corec k m exclsss =
6.74 - mk_primcorec_prelude ctxt defs disc_corec THEN mk_primcorec_cases_tac ctxt k m exclsss;
6.75 -
6.76 -fun mk_primcorec_sel_tac ctxt defs distincts splits split_asms maps map_idents map_comps f_sel k m
6.77 - exclsss =
6.78 - mk_primcorec_prelude ctxt defs (f_sel RS trans) THEN
6.79 - mk_primcorec_cases_tac ctxt k m exclsss THEN
6.80 - HEADGOAL (REPEAT_DETERM o (rtac refl ORELSE' rtac ext ORELSE'
6.81 - eresolve_tac falseEs ORELSE'
6.82 - resolve_tac split_connectI ORELSE'
6.83 - Splitter.split_asm_tac (split_if_asm :: split_asms) ORELSE'
6.84 - Splitter.split_tac (split_if :: splits) ORELSE'
6.85 - eresolve_tac (map (fn thm => thm RS neq_eq_eq_contradict) distincts) THEN' atac ORELSE'
6.86 - etac notE THEN' atac ORELSE'
6.87 - (CHANGED o SELECT_GOAL (unfold_thms_tac ctxt
6.88 - (@{thms id_def o_def split_def sum.cases} @ maps @ map_comps @ map_idents)))));
6.89 -
6.90 -fun mk_primcorec_ctr_of_dtr_tac ctxt m collapse maybe_disc_f sel_fs =
6.91 - HEADGOAL (rtac ((if null sel_fs then collapse else collapse RS sym) RS trans) THEN'
6.92 - (the_default (K all_tac) (Option.map rtac maybe_disc_f)) THEN' REPEAT_DETERM_N m o atac) THEN
6.93 - unfold_thms_tac ctxt (Let_def :: sel_fs) THEN HEADGOAL (rtac refl);
6.94 -
6.95 -fun inst_split_eq ctxt split =
6.96 - (case prop_of split of
6.97 - @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ (Var (_, Type (_, [T, _])) $ _) $ _) =>
6.98 - let
6.99 - val s = Name.uu;
6.100 - val eq = Abs (Name.uu, T, HOLogic.mk_eq (Free (s, T), Bound 0));
6.101 - val split' = Drule.instantiate' [] [SOME (certify ctxt eq)] split;
6.102 - in
6.103 - Thm.generalize ([], [s]) (Thm.maxidx_of split' + 1) split'
6.104 - end
6.105 - | _ => split);
6.106 -
6.107 -fun distinct_in_prems_tac distincts =
6.108 - eresolve_tac (map (fn thm => thm RS neq_eq_eq_contradict) distincts) THEN' atac;
6.109 -
6.110 -(* TODO: reduce code duplication with selector tactic above *)
6.111 -fun mk_primcorec_raw_code_of_ctr_single_tac ctxt distincts discIs splits split_asms m f_ctr =
6.112 - let
6.113 - val splits' =
6.114 - map (fn th => th RS iffD2) (@{thm split_if_eq2} :: map (inst_split_eq ctxt) splits)
6.115 - in
6.116 - HEADGOAL (REPEAT o (resolve_tac (splits' @ split_connectI))) THEN
6.117 - mk_primcorec_prelude ctxt [] (f_ctr RS trans) THEN
6.118 - HEADGOAL ((REPEAT_DETERM_N m o mk_primcorec_assumption_tac ctxt discIs) THEN'
6.119 - SELECT_GOAL (SOLVE (HEADGOAL (REPEAT_DETERM o
6.120 - (rtac refl ORELSE' atac ORELSE'
6.121 - resolve_tac (@{thm Code.abort_def} :: split_connectI) ORELSE'
6.122 - Splitter.split_tac (split_if :: splits) ORELSE'
6.123 - Splitter.split_asm_tac (split_if_asm :: split_asms) ORELSE'
6.124 - mk_primcorec_assumption_tac ctxt discIs ORELSE'
6.125 - distinct_in_prems_tac distincts ORELSE'
6.126 - (TRY o dresolve_tac discIs) THEN' etac notE THEN' atac)))))
6.127 - end;
6.128 -
6.129 -fun mk_primcorec_raw_code_of_ctr_tac ctxt distincts discIs splits split_asms ms f_ctrs =
6.130 - EVERY (map2 (mk_primcorec_raw_code_of_ctr_single_tac ctxt distincts discIs splits split_asms) ms
6.131 - f_ctrs) THEN
6.132 - IF_UNSOLVED (unfold_thms_tac ctxt @{thms Code.abort_def} THEN
6.133 - HEADGOAL (REPEAT_DETERM o resolve_tac (refl :: split_connectI)));
6.134 -
6.135 -fun mk_primcorec_code_of_raw_code_tac ctxt distincts splits raw =
6.136 - HEADGOAL (rtac raw ORELSE' rtac (raw RS trans) THEN'
6.137 - SELECT_GOAL (unfold_thms_tac ctxt [Let_def]) THEN' REPEAT_DETERM o
6.138 - (rtac refl ORELSE' atac ORELSE'
6.139 - resolve_tac split_connectI ORELSE'
6.140 - Splitter.split_tac (split_if :: splits) ORELSE'
6.141 - distinct_in_prems_tac distincts ORELSE'
6.142 - rtac sym THEN' atac ORELSE'
6.143 - etac notE THEN' atac));
6.144 -
6.145 -end;
7.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML Mon Nov 04 15:44:43 2013 +0100
7.2 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML Mon Nov 04 16:53:43 2013 +0100
7.3 @@ -8,504 +8,26 @@
7.4
7.5 signature BNF_FP_REC_SUGAR_UTIL =
7.6 sig
7.7 - datatype rec_call =
7.8 - No_Rec of int * typ |
7.9 - Mutual_Rec of (int * typ) (*before*) * (int * typ) (*after*) |
7.10 - Nested_Rec of int * typ
7.11 + val indexed: 'a list -> int -> int list * int
7.12 + val indexedd: 'a list list -> int -> int list list * int
7.13 + val indexeddd: ''a list list list -> int -> int list list list * int
7.14 + val indexedddd: 'a list list list list -> int -> int list list list list * int
7.15 + val find_index_eq: ''a list -> ''a -> int
7.16 + val finds: ('a * 'b -> bool) -> 'a list -> 'b list -> ('a * 'b list) list * 'b list
7.17
7.18 - datatype corec_call =
7.19 - Dummy_No_Corec of int |
7.20 - No_Corec of int |
7.21 - Mutual_Corec of int (*stop?*) * int (*end*) * int (*continue*) |
7.22 - Nested_Corec of int
7.23 + val drop_All: term -> term
7.24
7.25 - type rec_ctr_spec =
7.26 - {ctr: term,
7.27 - offset: int,
7.28 - calls: rec_call list,
7.29 - rec_thm: thm}
7.30 + val mk_partial_compN: int -> typ -> term -> term
7.31 + val mk_partial_comp: typ -> typ -> term -> term
7.32 + val mk_compN: int -> typ list -> term * term -> term
7.33 + val mk_comp: typ list -> term * term -> term
7.34
7.35 - type basic_corec_ctr_spec =
7.36 - {ctr: term,
7.37 - disc: term,
7.38 - sels: term list}
7.39 -
7.40 - type corec_ctr_spec =
7.41 - {ctr: term,
7.42 - disc: term,
7.43 - sels: term list,
7.44 - pred: int option,
7.45 - calls: corec_call list,
7.46 - discI: thm,
7.47 - sel_thms: thm list,
7.48 - collapse: thm,
7.49 - corec_thm: thm,
7.50 - disc_corec: thm,
7.51 - sel_corecs: thm list}
7.52 -
7.53 - type rec_spec =
7.54 - {recx: term,
7.55 - nested_map_idents: thm list,
7.56 - nested_map_comps: thm list,
7.57 - ctr_specs: rec_ctr_spec list}
7.58 -
7.59 - type corec_spec =
7.60 - {corec: term,
7.61 - nested_maps: thm list,
7.62 - nested_map_idents: thm list,
7.63 - nested_map_comps: thm list,
7.64 - ctr_specs: corec_ctr_spec list}
7.65 -
7.66 - val s_not: term -> term
7.67 - val s_not_conj: term list -> term list
7.68 - val s_conjs: term list -> term
7.69 - val s_disjs: term list -> term
7.70 - val s_dnf: term list list -> term list
7.71 -
7.72 - val mk_partial_compN: int -> typ -> typ -> term -> term
7.73 -
7.74 - val massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
7.75 - typ list -> term -> term -> term -> term
7.76 - val massage_mutual_corec_call: Proof.context -> (term -> bool) -> (typ list -> term -> term) ->
7.77 - typ list -> term -> term
7.78 - val massage_nested_corec_call: Proof.context -> (term -> bool) ->
7.79 - (typ list -> typ -> typ -> term -> term) -> typ list -> typ -> term -> term
7.80 - val fold_rev_corec_call: Proof.context -> (term list -> term -> 'a -> 'a) -> typ list -> term ->
7.81 - 'a -> string list * 'a
7.82 - val expand_corec_code_rhs: Proof.context -> (term -> bool) -> typ list -> term -> term
7.83 - val massage_corec_code_rhs: Proof.context -> (typ list -> term -> term list -> term) ->
7.84 - typ list -> term -> term
7.85 - val fold_rev_corec_code_rhs: Proof.context -> (term list -> term -> term list -> 'a -> 'a) ->
7.86 - typ list -> term -> 'a -> 'a
7.87 - val case_thms_of_term: Proof.context -> typ list -> term ->
7.88 - thm list * thm list * thm list * thm list
7.89 -
7.90 - val rec_specs_of: binding list -> typ list -> typ list -> (term -> int list) ->
7.91 - ((term * term list list) list) list -> local_theory ->
7.92 - (bool * rec_spec list * typ list * thm * thm list) * local_theory
7.93 - val basic_corec_specs_of: Proof.context -> typ -> basic_corec_ctr_spec list
7.94 - val corec_specs_of: binding list -> typ list -> typ list -> (term -> int list) ->
7.95 - ((term * term list list) list) list -> local_theory ->
7.96 - (bool * corec_spec list * typ list * thm * thm * thm list * thm list) * local_theory
7.97 + val get_indices: ((binding * typ) * 'a) list -> term -> int list
7.98 end;
7.99
7.100 structure BNF_FP_Rec_Sugar_Util : BNF_FP_REC_SUGAR_UTIL =
7.101 struct
7.102
7.103 -open Ctr_Sugar
7.104 -open BNF_Util
7.105 -open BNF_Def
7.106 -open BNF_FP_Util
7.107 -open BNF_FP_Def_Sugar
7.108 -open BNF_FP_N2M_Sugar
7.109 -
7.110 -datatype rec_call =
7.111 - No_Rec of int * typ |
7.112 - Mutual_Rec of (int * typ) * (int * typ) |
7.113 - Nested_Rec of int * typ;
7.114 -
7.115 -datatype corec_call =
7.116 - Dummy_No_Corec of int |
7.117 - No_Corec of int |
7.118 - Mutual_Corec of int * int * int |
7.119 - Nested_Corec of int;
7.120 -
7.121 -type rec_ctr_spec =
7.122 - {ctr: term,
7.123 - offset: int,
7.124 - calls: rec_call list,
7.125 - rec_thm: thm};
7.126 -
7.127 -type basic_corec_ctr_spec =
7.128 - {ctr: term,
7.129 - disc: term,
7.130 - sels: term list};
7.131 -
7.132 -type corec_ctr_spec =
7.133 - {ctr: term,
7.134 - disc: term,
7.135 - sels: term list,
7.136 - pred: int option,
7.137 - calls: corec_call list,
7.138 - discI: thm,
7.139 - sel_thms: thm list,
7.140 - collapse: thm,
7.141 - corec_thm: thm,
7.142 - disc_corec: thm,
7.143 - sel_corecs: thm list};
7.144 -
7.145 -type rec_spec =
7.146 - {recx: term,
7.147 - nested_map_idents: thm list,
7.148 - nested_map_comps: thm list,
7.149 - ctr_specs: rec_ctr_spec list};
7.150 -
7.151 -type corec_spec =
7.152 - {corec: term,
7.153 - nested_maps: thm list,
7.154 - nested_map_idents: thm list,
7.155 - nested_map_comps: thm list,
7.156 - ctr_specs: corec_ctr_spec list};
7.157 -
7.158 -val id_def = @{thm id_def};
7.159 -
7.160 -exception AINT_NO_MAP of term;
7.161 -
7.162 -fun not_codatatype ctxt T =
7.163 - error ("Not a codatatype: " ^ Syntax.string_of_typ ctxt T);
7.164 -fun ill_formed_rec_call ctxt t =
7.165 - error ("Ill-formed recursive call: " ^ quote (Syntax.string_of_term ctxt t));
7.166 -fun ill_formed_corec_call ctxt t =
7.167 - error ("Ill-formed corecursive call: " ^ quote (Syntax.string_of_term ctxt t));
7.168 -fun invalid_map ctxt t =
7.169 - error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t));
7.170 -fun unexpected_rec_call ctxt t =
7.171 - error ("Unexpected recursive call: " ^ quote (Syntax.string_of_term ctxt t));
7.172 -fun unexpected_corec_call ctxt t =
7.173 - error ("Unexpected corecursive call: " ^ quote (Syntax.string_of_term ctxt t));
7.174 -
7.175 -val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
7.176 -val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
7.177 -
7.178 -val conjuncts_s = filter_out (curry (op =) @{const True}) o HOLogic.conjuncts;
7.179 -
7.180 -fun s_not @{const True} = @{const False}
7.181 - | s_not @{const False} = @{const True}
7.182 - | s_not (@{const Not} $ t) = t
7.183 - | s_not (@{const conj} $ t $ u) = @{const disj} $ s_not t $ s_not u
7.184 - | s_not (@{const disj} $ t $ u) = @{const conj} $ s_not t $ s_not u
7.185 - | s_not t = @{const Not} $ t;
7.186 -
7.187 -val s_not_conj = conjuncts_s o s_not o mk_conjs;
7.188 -
7.189 -fun s_conj c @{const True} = c
7.190 - | s_conj c d = HOLogic.mk_conj (c, d);
7.191 -
7.192 -fun propagate_unit_pos u cs = if member (op aconv) cs u then [@{const False}] else cs;
7.193 -
7.194 -fun propagate_unit_neg not_u cs = remove (op aconv) not_u cs;
7.195 -
7.196 -fun propagate_units css =
7.197 - (case List.partition (can the_single) css of
7.198 - ([], _) => css
7.199 - | ([u] :: uss, css') =>
7.200 - [u] :: propagate_units (map (propagate_unit_neg (s_not u))
7.201 - (map (propagate_unit_pos u) (uss @ css'))));
7.202 -
7.203 -fun s_conjs cs =
7.204 - if member (op aconv) cs @{const False} then @{const False}
7.205 - else mk_conjs (remove (op aconv) @{const True} cs);
7.206 -
7.207 -fun s_disjs ds =
7.208 - if member (op aconv) ds @{const True} then @{const True}
7.209 - else mk_disjs (remove (op aconv) @{const False} ds);
7.210 -
7.211 -fun s_dnf css0 =
7.212 - let val css = propagate_units css0 in
7.213 - if null css then
7.214 - [@{const False}]
7.215 - else if exists null css then
7.216 - []
7.217 - else
7.218 - map (fn c :: cs => (c, cs)) css
7.219 - |> AList.coalesce (op =)
7.220 - |> map (fn (c, css) => c :: s_dnf css)
7.221 - |> (fn [cs] => cs | css => [s_disjs (map s_conjs css)])
7.222 - end;
7.223 -
7.224 -fun mk_partial_comp gT fT g =
7.225 - let val T = domain_type fT --> range_type gT in
7.226 - Const (@{const_name Fun.comp}, gT --> fT --> T) $ g
7.227 - end;
7.228 -
7.229 -fun mk_partial_compN 0 _ _ g = g
7.230 - | mk_partial_compN n gT fT g =
7.231 - let val g' = mk_partial_compN (n - 1) gT (range_type fT) g in
7.232 - mk_partial_comp (fastype_of g') fT g'
7.233 - end;
7.234 -
7.235 -fun mk_compN n bound_Ts (g, f) =
7.236 - let val typof = curry fastype_of1 bound_Ts in
7.237 - mk_partial_compN n (typof g) (typof f) g $ f
7.238 - end;
7.239 -
7.240 -val mk_comp = mk_compN 1;
7.241 -
7.242 -fun factor_out_types ctxt massage destU U T =
7.243 - (case try destU U of
7.244 - SOME (U1, U2) => if U1 = T then massage T U2 else invalid_map ctxt
7.245 - | NONE => invalid_map ctxt);
7.246 -
7.247 -fun map_flattened_map_args ctxt s map_args fs =
7.248 - let
7.249 - val flat_fs = flatten_type_args_of_bnf (the (bnf_of ctxt s)) Term.dummy fs;
7.250 - val flat_fs' = map_args flat_fs;
7.251 - in
7.252 - permute_like (op aconv) flat_fs fs flat_fs'
7.253 - end;
7.254 -
7.255 -fun massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y' =
7.256 - let
7.257 - fun check_no_call t = if has_call t then unexpected_rec_call ctxt t else ();
7.258 -
7.259 - val typof = curry fastype_of1 bound_Ts;
7.260 - val build_map_fst = build_map ctxt (fst_const o fst);
7.261 -
7.262 - val yT = typof y;
7.263 - val yU = typof y';
7.264 -
7.265 - fun y_of_y' () = build_map_fst (yU, yT) $ y';
7.266 - val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t);
7.267 -
7.268 - fun massage_mutual_fun U T t =
7.269 - (case t of
7.270 - Const (@{const_name comp}, comp_T) $ t1 $ t2 =>
7.271 - mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2)
7.272 - | _ =>
7.273 - if has_call t then factor_out_types ctxt raw_massage_fun HOLogic.dest_prodT U T t
7.274 - else mk_comp bound_Ts (t, build_map_fst (U, T)));
7.275 -
7.276 - fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
7.277 - (case try (dest_map ctxt s) t of
7.278 - SOME (map0, fs) =>
7.279 - let
7.280 - val Type (_, ran_Ts) = range_type (typof t);
7.281 - val map' = mk_map (length fs) Us ran_Ts map0;
7.282 - val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
7.283 - in
7.284 - Term.list_comb (map', fs')
7.285 - end
7.286 - | NONE => raise AINT_NO_MAP t)
7.287 - | massage_map _ _ t = raise AINT_NO_MAP t
7.288 - and massage_map_or_map_arg U T t =
7.289 - if T = U then
7.290 - tap check_no_call t
7.291 - else
7.292 - massage_map U T t
7.293 - handle AINT_NO_MAP _ => massage_mutual_fun U T t;
7.294 -
7.295 - fun massage_call (t as t1 $ t2) =
7.296 - if has_call t then
7.297 - if t2 = y then
7.298 - massage_map yU yT (elim_y t1) $ y'
7.299 - handle AINT_NO_MAP t' => invalid_map ctxt t'
7.300 - else
7.301 - let val (g, xs) = Term.strip_comb t2 in
7.302 - if g = y then
7.303 - if exists has_call xs then unexpected_rec_call ctxt t2
7.304 - else Term.list_comb (massage_call (mk_compN (length xs) bound_Ts (t1, y)), xs)
7.305 - else
7.306 - ill_formed_rec_call ctxt t
7.307 - end
7.308 - else
7.309 - elim_y t
7.310 - | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
7.311 - in
7.312 - massage_call
7.313 - end;
7.314 -
7.315 -fun fold_rev_let_if_case ctxt f bound_Ts t =
7.316 - let
7.317 - val thy = Proof_Context.theory_of ctxt;
7.318 -
7.319 - fun fld conds t =
7.320 - (case Term.strip_comb t of
7.321 - (Const (@{const_name Let}, _), [_, _]) => fld conds (unfold_let t)
7.322 - | (Const (@{const_name If}, _), [cond, then_branch, else_branch]) =>
7.323 - fld (conds @ conjuncts_s cond) then_branch o fld (conds @ s_not_conj [cond]) else_branch
7.324 - | (Const (c, _), args as _ :: _ :: _) =>
7.325 - let val n = num_binder_types (Sign.the_const_type thy c) - 1 in
7.326 - if n >= 0 andalso n < length args then
7.327 - (case fastype_of1 (bound_Ts, nth args n) of
7.328 - Type (s, Ts) =>
7.329 - (case dest_case ctxt s Ts t of
7.330 - NONE => apsnd (f conds t)
7.331 - | SOME (conds', branches) =>
7.332 - apfst (cons s) o fold_rev (uncurry fld)
7.333 - (map (append conds o conjuncts_s) conds' ~~ branches))
7.334 - | _ => apsnd (f conds t))
7.335 - else
7.336 - apsnd (f conds t)
7.337 - end
7.338 - | _ => apsnd (f conds t))
7.339 - in
7.340 - fld [] t o pair []
7.341 - end;
7.342 -
7.343 -fun case_of ctxt = ctr_sugar_of ctxt #> Option.map (fst o dest_Const o #casex);
7.344 -
7.345 -fun massage_let_if_case ctxt has_call massage_leaf =
7.346 - let
7.347 - val thy = Proof_Context.theory_of ctxt;
7.348 -
7.349 - fun check_no_call t = if has_call t then unexpected_corec_call ctxt t else ();
7.350 -
7.351 - fun massage_abs bound_Ts 0 t = massage_rec bound_Ts t
7.352 - | massage_abs bound_Ts m (Abs (s, T, t)) = Abs (s, T, massage_abs (T :: bound_Ts) (m - 1) t)
7.353 - | massage_abs bound_Ts m t =
7.354 - let val T = domain_type (fastype_of1 (bound_Ts, t)) in
7.355 - Abs (Name.uu, T, massage_abs (T :: bound_Ts) (m - 1) (incr_boundvars 1 t $ Bound 0))
7.356 - end
7.357 - and massage_rec bound_Ts t =
7.358 - let val typof = curry fastype_of1 bound_Ts in
7.359 - (case Term.strip_comb t of
7.360 - (Const (@{const_name Let}, _), [_, _]) => massage_rec bound_Ts (unfold_let t)
7.361 - | (Const (@{const_name If}, _), obj :: (branches as [_, _])) =>
7.362 - let val branches' = map (massage_rec bound_Ts) branches in
7.363 - Term.list_comb (If_const (typof (hd branches')) $ tap check_no_call obj, branches')
7.364 - end
7.365 - | (Const (c, _), args as _ :: _ :: _) =>
7.366 - (case try strip_fun_type (Sign.the_const_type thy c) of
7.367 - SOME (gen_branch_Ts, gen_body_fun_T) =>
7.368 - let
7.369 - val gen_branch_ms = map num_binder_types gen_branch_Ts;
7.370 - val n = length gen_branch_ms;
7.371 - in
7.372 - if n < length args then
7.373 - (case gen_body_fun_T of
7.374 - Type (_, [Type (T_name, _), _]) =>
7.375 - if case_of ctxt T_name = SOME c then
7.376 - let
7.377 - val (branches, obj_leftovers) = chop n args;
7.378 - val branches' = map2 (massage_abs bound_Ts) gen_branch_ms branches;
7.379 - val branch_Ts' = map typof branches';
7.380 - val body_T' = snd (strip_typeN (hd gen_branch_ms) (hd branch_Ts'));
7.381 - val casex' = Const (c, branch_Ts' ---> map typof obj_leftovers ---> body_T');
7.382 - in
7.383 - Term.list_comb (casex', branches' @ tap (List.app check_no_call) obj_leftovers)
7.384 - end
7.385 - else
7.386 - massage_leaf bound_Ts t
7.387 - | _ => massage_leaf bound_Ts t)
7.388 - else
7.389 - massage_leaf bound_Ts t
7.390 - end
7.391 - | NONE => massage_leaf bound_Ts t)
7.392 - | _ => massage_leaf bound_Ts t)
7.393 - end
7.394 - in
7.395 - massage_rec
7.396 - end;
7.397 -
7.398 -val massage_mutual_corec_call = massage_let_if_case;
7.399 -
7.400 -fun curried_type (Type (@{type_name fun}, [Type (@{type_name prod}, Ts), T])) = Ts ---> T;
7.401 -
7.402 -fun massage_nested_corec_call ctxt has_call raw_massage_call bound_Ts U t =
7.403 - let
7.404 - fun check_no_call t = if has_call t then unexpected_corec_call ctxt t else ();
7.405 -
7.406 - val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o snd);
7.407 -
7.408 - fun massage_mutual_call bound_Ts U T t =
7.409 - if has_call t then factor_out_types ctxt (raw_massage_call bound_Ts) dest_sumT U T t
7.410 - else build_map_Inl (T, U) $ t;
7.411 -
7.412 - fun massage_mutual_fun bound_Ts U T t =
7.413 - (case t of
7.414 - Const (@{const_name comp}, comp_T) $ t1 $ t2 =>
7.415 - mk_comp bound_Ts (massage_mutual_fun bound_Ts U T t1, tap check_no_call t2)
7.416 - | _ =>
7.417 - let
7.418 - val var = Var ((Name.uu, Term.maxidx_of_term t + 1),
7.419 - domain_type (fastype_of1 (bound_Ts, t)));
7.420 - in
7.421 - Term.lambda var (massage_mutual_call bound_Ts U T (t $ var))
7.422 - end);
7.423 -
7.424 - fun massage_map bound_Ts (Type (_, Us)) (Type (s, Ts)) t =
7.425 - (case try (dest_map ctxt s) t of
7.426 - SOME (map0, fs) =>
7.427 - let
7.428 - val Type (_, dom_Ts) = domain_type (fastype_of1 (bound_Ts, t));
7.429 - val map' = mk_map (length fs) dom_Ts Us map0;
7.430 - val fs' =
7.431 - map_flattened_map_args ctxt s (map3 (massage_map_or_map_arg bound_Ts) Us Ts) fs;
7.432 - in
7.433 - Term.list_comb (map', fs')
7.434 - end
7.435 - | NONE => raise AINT_NO_MAP t)
7.436 - | massage_map _ _ _ t = raise AINT_NO_MAP t
7.437 - and massage_map_or_map_arg bound_Ts U T t =
7.438 - if T = U then
7.439 - tap check_no_call t
7.440 - else
7.441 - massage_map bound_Ts U T t
7.442 - handle AINT_NO_MAP _ => massage_mutual_fun bound_Ts U T t;
7.443 -
7.444 - fun massage_call bound_Ts U T =
7.445 - massage_let_if_case ctxt has_call (fn bound_Ts => fn t =>
7.446 - if has_call t then
7.447 - (case U of
7.448 - Type (s, Us) =>
7.449 - (case try (dest_ctr ctxt s) t of
7.450 - SOME (f, args) =>
7.451 - let
7.452 - val typof = curry fastype_of1 bound_Ts;
7.453 - val f' = mk_ctr Us f
7.454 - val f'_T = typof f';
7.455 - val arg_Ts = map typof args;
7.456 - in
7.457 - Term.list_comb (f', map3 (massage_call bound_Ts) (binder_types f'_T) arg_Ts args)
7.458 - end
7.459 - | NONE =>
7.460 - (case t of
7.461 - Const (@{const_name prod_case}, _) $ t' =>
7.462 - let
7.463 - val U' = curried_type U;
7.464 - val T' = curried_type T;
7.465 - in
7.466 - Const (@{const_name prod_case}, U' --> U) $ massage_call bound_Ts U' T' t'
7.467 - end
7.468 - | t1 $ t2 =>
7.469 - (if has_call t2 then
7.470 - massage_mutual_call bound_Ts U T t
7.471 - else
7.472 - massage_map bound_Ts U T t1 $ t2
7.473 - handle AINT_NO_MAP _ => massage_mutual_call bound_Ts U T t)
7.474 - | Abs (s, T', t') =>
7.475 - Abs (s, T', massage_call (T' :: bound_Ts) (range_type U) (range_type T) t')
7.476 - | _ => massage_mutual_call bound_Ts U T t))
7.477 - | _ => ill_formed_corec_call ctxt t)
7.478 - else
7.479 - build_map_Inl (T, U) $ t) bound_Ts;
7.480 -
7.481 - val T = fastype_of1 (bound_Ts, t);
7.482 - in
7.483 - if has_call t then massage_call bound_Ts U T t else build_map_Inl (T, U) $ t
7.484 - end;
7.485 -
7.486 -val fold_rev_corec_call = fold_rev_let_if_case;
7.487 -
7.488 -fun expand_to_ctr_term ctxt s Ts t =
7.489 - (case ctr_sugar_of ctxt s of
7.490 - SOME {ctrs, casex, ...} =>
7.491 - Term.list_comb (mk_case Ts (Type (s, Ts)) casex, map (mk_ctr Ts) ctrs) $ t
7.492 - | NONE => raise Fail "expand_to_ctr_term");
7.493 -
7.494 -fun expand_corec_code_rhs ctxt has_call bound_Ts t =
7.495 - (case fastype_of1 (bound_Ts, t) of
7.496 - Type (s, Ts) =>
7.497 - massage_let_if_case ctxt has_call (fn _ => fn t =>
7.498 - if can (dest_ctr ctxt s) t then t else expand_to_ctr_term ctxt s Ts t) bound_Ts t
7.499 - | _ => raise Fail "expand_corec_code_rhs");
7.500 -
7.501 -fun massage_corec_code_rhs ctxt massage_ctr =
7.502 - massage_let_if_case ctxt (K false)
7.503 - (fn bound_Ts => uncurry (massage_ctr bound_Ts) o Term.strip_comb);
7.504 -
7.505 -fun fold_rev_corec_code_rhs ctxt f =
7.506 - snd ooo fold_rev_let_if_case ctxt (fn conds => uncurry (f conds) o Term.strip_comb);
7.507 -
7.508 -fun case_thms_of_term ctxt bound_Ts t =
7.509 - let
7.510 - val (caseT_names, _) = fold_rev_let_if_case ctxt (K (K I)) bound_Ts t ();
7.511 - val ctr_sugars = map (the o ctr_sugar_of ctxt) caseT_names;
7.512 - in
7.513 - (maps #distincts ctr_sugars, maps #discIs ctr_sugars, maps #sel_splits ctr_sugars,
7.514 - maps #sel_split_asms ctr_sugars)
7.515 - end;
7.516 -
7.517 fun indexed xs h = let val h' = h + length xs in (h upto h' - 1, h') end;
7.518 fun indexedd xss = fold_map indexed xss;
7.519 fun indexeddd xsss = fold_map indexedd xsss;
7.520 @@ -513,224 +35,32 @@
7.521
7.522 fun find_index_eq hs h = find_index (curry (op =) h) hs;
7.523
7.524 -(*FIXME: remove special cases for product and sum once they are registered as datatypes*)
7.525 -fun map_thms_of_typ ctxt (Type (s, _)) =
7.526 - if s = @{type_name prod} then
7.527 - @{thms map_pair_simp}
7.528 - else if s = @{type_name sum} then
7.529 - @{thms sum_map.simps}
7.530 - else
7.531 - (case fp_sugar_of ctxt s of
7.532 - SOME {index, mapss, ...} => nth mapss index
7.533 - | NONE => [])
7.534 - | map_thms_of_typ _ _ = [];
7.535 +fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
7.536
7.537 -fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
7.538 - let
7.539 - val thy = Proof_Context.theory_of lthy;
7.540 +fun drop_All t =
7.541 + subst_bounds (strip_qnt_vars @{const_name all} t |> map Free |> rev,
7.542 + strip_qnt_body @{const_name all} t);
7.543
7.544 - val ((missing_arg_Ts, perm0_kks,
7.545 - fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = ctor_iters1 :: _, ...},
7.546 - co_inducts = [induct_thm], ...} :: _, (lfp_sugar_thms, _)), lthy') =
7.547 - nested_to_mutual_fps Least_FP bs arg_Ts get_indices callssss0 lthy;
7.548 -
7.549 - val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;
7.550 -
7.551 - val indices = map #index fp_sugars;
7.552 - val perm_indices = map #index perm_fp_sugars;
7.553 -
7.554 - val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
7.555 - val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
7.556 - val perm_lfpTs = map (body_type o fastype_of o hd) perm_ctrss;
7.557 -
7.558 - val nn0 = length arg_Ts;
7.559 - val nn = length perm_lfpTs;
7.560 - val kks = 0 upto nn - 1;
7.561 - val perm_ns = map length perm_ctr_Tsss;
7.562 - val perm_mss = map (map length) perm_ctr_Tsss;
7.563 -
7.564 - val perm_Cs = map (body_type o fastype_of o co_rec_of o of_fp_sugar (#xtor_co_iterss o #fp_res))
7.565 - perm_fp_sugars;
7.566 - val perm_fun_arg_Tssss =
7.567 - mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1);
7.568 -
7.569 - fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
7.570 - fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
7.571 -
7.572 - val induct_thms = unpermute0 (conj_dests nn induct_thm);
7.573 -
7.574 - val lfpTs = unpermute perm_lfpTs;
7.575 - val Cs = unpermute perm_Cs;
7.576 -
7.577 - val As_rho = tvar_subst thy (take nn0 lfpTs) arg_Ts;
7.578 - val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;
7.579 -
7.580 - val substA = Term.subst_TVars As_rho;
7.581 - val substAT = Term.typ_subst_TVars As_rho;
7.582 - val substCT = Term.typ_subst_TVars Cs_rho;
7.583 - val substACT = substAT o substCT;
7.584 -
7.585 - val perm_Cs' = map substCT perm_Cs;
7.586 -
7.587 - fun offset_of_ctr 0 _ = 0
7.588 - | offset_of_ctr n (({ctrs, ...} : ctr_sugar) :: ctr_sugars) =
7.589 - length ctrs + offset_of_ctr (n - 1) ctr_sugars;
7.590 -
7.591 - fun call_of [i] [T] = (if exists_subtype_in Cs T then Nested_Rec else No_Rec) (i, substACT T)
7.592 - | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T'));
7.593 -
7.594 - fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
7.595 - let
7.596 - val (fun_arg_hss, _) = indexedd fun_arg_Tss 0;
7.597 - val fun_arg_hs = flat_rec_arg_args fun_arg_hss;
7.598 - val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss;
7.599 - in
7.600 - {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss,
7.601 - rec_thm = rec_thm}
7.602 - end;
7.603 -
7.604 - fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) iter_thmsss =
7.605 - let
7.606 - val ctrs = #ctrs (nth ctr_sugars index);
7.607 - val rec_thmss = co_rec_of (nth iter_thmsss index);
7.608 - val k = offset_of_ctr index ctr_sugars;
7.609 - val n = length ctrs;
7.610 - in
7.611 - map4 mk_ctr_spec ctrs (k upto k + n - 1) (nth perm_fun_arg_Tssss index) rec_thmss
7.612 - end;
7.613 -
7.614 - fun mk_spec ({T, index, ctr_sugars, co_iterss = iterss, co_iter_thmsss = iter_thmsss, ...}
7.615 - : fp_sugar) =
7.616 - {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' (co_rec_of (nth iterss index)),
7.617 - nested_map_idents = map (unfold_thms lthy [id_def] o map_id0_of_bnf) nested_bnfs,
7.618 - nested_map_comps = map map_comp_of_bnf nested_bnfs,
7.619 - ctr_specs = mk_ctr_specs index ctr_sugars iter_thmsss};
7.620 - in
7.621 - ((is_some lfp_sugar_thms, map mk_spec fp_sugars, missing_arg_Ts, induct_thm, induct_thms),
7.622 - lthy')
7.623 +fun mk_partial_comp gT fT g =
7.624 + let val T = domain_type fT --> range_type gT in
7.625 + Const (@{const_name Fun.comp}, gT --> fT --> T) $ g
7.626 end;
7.627
7.628 -fun basic_corec_specs_of ctxt res_T =
7.629 - (case res_T of
7.630 - Type (T_name, _) =>
7.631 - (case Ctr_Sugar.ctr_sugar_of ctxt T_name of
7.632 - NONE => not_codatatype ctxt res_T
7.633 - | SOME {ctrs, discs, selss, ...} =>
7.634 - let
7.635 - val thy = Proof_Context.theory_of ctxt;
7.636 - val gfpT = body_type (fastype_of (hd ctrs));
7.637 - val As_rho = tvar_subst thy [gfpT] [res_T];
7.638 - val substA = Term.subst_TVars As_rho;
7.639 +fun mk_partial_compN 0 _ g = g
7.640 + | mk_partial_compN n fT g =
7.641 + let val g' = mk_partial_compN (n - 1) (range_type fT) g in
7.642 + mk_partial_comp (fastype_of g') fT g'
7.643 + end;
7.644
7.645 - fun mk_spec ctr disc sels = {ctr = substA ctr, disc = substA disc, sels = map substA sels};
7.646 - in
7.647 - map3 mk_spec ctrs discs selss
7.648 - end)
7.649 - | _ => not_codatatype ctxt res_T);
7.650 -
7.651 -fun corec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
7.652 - let
7.653 - val thy = Proof_Context.theory_of lthy;
7.654 -
7.655 - val ((missing_res_Ts, perm0_kks,
7.656 - fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = dtor_coiters1 :: _, ...},
7.657 - co_inducts = coinduct_thms, ...} :: _, (_, gfp_sugar_thms)), lthy') =
7.658 - nested_to_mutual_fps Greatest_FP bs res_Ts get_indices callssss0 lthy;
7.659 -
7.660 - val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;
7.661 -
7.662 - val indices = map #index fp_sugars;
7.663 - val perm_indices = map #index perm_fp_sugars;
7.664 -
7.665 - val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
7.666 - val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
7.667 - val perm_gfpTs = map (body_type o fastype_of o hd) perm_ctrss;
7.668 -
7.669 - val nn0 = length res_Ts;
7.670 - val nn = length perm_gfpTs;
7.671 - val kks = 0 upto nn - 1;
7.672 - val perm_ns = map length perm_ctr_Tsss;
7.673 -
7.674 - val perm_Cs = map (domain_type o body_fun_type o fastype_of o co_rec_of o
7.675 - of_fp_sugar (#xtor_co_iterss o #fp_res)) perm_fp_sugars;
7.676 - val (perm_p_Tss, (perm_q_Tssss, _, perm_f_Tssss, _)) =
7.677 - mk_coiter_fun_arg_types perm_ctr_Tsss perm_Cs perm_ns (co_rec_of dtor_coiters1);
7.678 -
7.679 - val (perm_p_hss, h) = indexedd perm_p_Tss 0;
7.680 - val (perm_q_hssss, h') = indexedddd perm_q_Tssss h;
7.681 - val (perm_f_hssss, _) = indexedddd perm_f_Tssss h';
7.682 -
7.683 - val fun_arg_hs =
7.684 - flat (map3 flat_corec_preds_predsss_gettersss perm_p_hss perm_q_hssss perm_f_hssss);
7.685 -
7.686 - fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
7.687 - fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
7.688 -
7.689 - val coinduct_thmss = map (unpermute0 o conj_dests nn) coinduct_thms;
7.690 -
7.691 - val p_iss = map (map (find_index_eq fun_arg_hs)) (unpermute perm_p_hss);
7.692 - val q_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_q_hssss);
7.693 - val f_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_f_hssss);
7.694 -
7.695 - val f_Tssss = unpermute perm_f_Tssss;
7.696 - val gfpTs = unpermute perm_gfpTs;
7.697 - val Cs = unpermute perm_Cs;
7.698 -
7.699 - val As_rho = tvar_subst thy (take nn0 gfpTs) res_Ts;
7.700 - val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn arg_Ts;
7.701 -
7.702 - val substA = Term.subst_TVars As_rho;
7.703 - val substAT = Term.typ_subst_TVars As_rho;
7.704 - val substCT = Term.typ_subst_TVars Cs_rho;
7.705 -
7.706 - val perm_Cs' = map substCT perm_Cs;
7.707 -
7.708 - fun call_of nullary [] [g_i] [Type (@{type_name fun}, [_, T])] =
7.709 - (if exists_subtype_in Cs T then Nested_Corec
7.710 - else if nullary then Dummy_No_Corec
7.711 - else No_Corec) g_i
7.712 - | call_of _ [q_i] [g_i, g_i'] _ = Mutual_Corec (q_i, g_i, g_i');
7.713 -
7.714 - fun mk_ctr_spec ctr disc sels p_ho q_iss f_iss f_Tss discI sel_thms collapse corec_thm
7.715 - disc_corec sel_corecs =
7.716 - let val nullary = not (can dest_funT (fastype_of ctr)) in
7.717 - {ctr = substA ctr, disc = substA disc, sels = map substA sels, pred = p_ho,
7.718 - calls = map3 (call_of nullary) q_iss f_iss f_Tss, discI = discI, sel_thms = sel_thms,
7.719 - collapse = collapse, corec_thm = corec_thm, disc_corec = disc_corec,
7.720 - sel_corecs = sel_corecs}
7.721 - end;
7.722 -
7.723 - fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) p_is q_isss f_isss f_Tsss coiter_thmsss
7.724 - disc_coitersss sel_coiterssss =
7.725 - let
7.726 - val ctrs = #ctrs (nth ctr_sugars index);
7.727 - val discs = #discs (nth ctr_sugars index);
7.728 - val selss = #selss (nth ctr_sugars index);
7.729 - val p_ios = map SOME p_is @ [NONE];
7.730 - val discIs = #discIs (nth ctr_sugars index);
7.731 - val sel_thmss = #sel_thmss (nth ctr_sugars index);
7.732 - val collapses = #collapses (nth ctr_sugars index);
7.733 - val corec_thms = co_rec_of (nth coiter_thmsss index);
7.734 - val disc_corecs = co_rec_of (nth disc_coitersss index);
7.735 - val sel_corecss = co_rec_of (nth sel_coiterssss index);
7.736 - in
7.737 - map13 mk_ctr_spec ctrs discs selss p_ios q_isss f_isss f_Tsss discIs sel_thmss collapses
7.738 - corec_thms disc_corecs sel_corecss
7.739 - end;
7.740 -
7.741 - fun mk_spec ({T, index, ctr_sugars, co_iterss = coiterss, co_iter_thmsss = coiter_thmsss,
7.742 - disc_co_itersss = disc_coitersss, sel_co_iterssss = sel_coiterssss, ...} : fp_sugar)
7.743 - p_is q_isss f_isss f_Tsss =
7.744 - {corec = mk_co_iter thy Greatest_FP (substAT T) perm_Cs' (co_rec_of (nth coiterss index)),
7.745 - nested_maps = maps (map_thms_of_typ lthy o T_of_bnf) nested_bnfs,
7.746 - nested_map_idents = map (unfold_thms lthy [id_def] o map_id0_of_bnf) nested_bnfs,
7.747 - nested_map_comps = map map_comp_of_bnf nested_bnfs,
7.748 - ctr_specs = mk_ctr_specs index ctr_sugars p_is q_isss f_isss f_Tsss coiter_thmsss
7.749 - disc_coitersss sel_coiterssss};
7.750 - in
7.751 - ((is_some gfp_sugar_thms, map5 mk_spec fp_sugars p_iss q_issss f_issss f_Tssss, missing_res_Ts,
7.752 - co_induct_of coinduct_thms, strong_co_induct_of coinduct_thms, co_induct_of coinduct_thmss,
7.753 - strong_co_induct_of coinduct_thmss), lthy')
7.754 +fun mk_compN n bound_Ts (g, f) =
7.755 + let val typof = curry fastype_of1 bound_Ts in
7.756 + mk_partial_compN n (typof f) g $ f
7.757 end;
7.758
7.759 +val mk_comp = mk_compN 1;
7.760 +
7.761 +fun get_indices fixes t = map (fst #>> Binding.name_of #> Free) fixes
7.762 + |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
7.763 + |> map_filter I;
7.764 +
7.765 end;
8.1 --- a/src/HOL/BNF/Tools/bnf_gfp.ML Mon Nov 04 15:44:43 2013 +0100
8.2 +++ b/src/HOL/BNF/Tools/bnf_gfp.ML Mon Nov 04 16:53:43 2013 +0100
8.3 @@ -23,7 +23,7 @@
8.4 open BNF_Comp
8.5 open BNF_FP_Util
8.6 open BNF_FP_Def_Sugar
8.7 -open BNF_FP_Rec_Sugar
8.8 +open BNF_GFP_Rec_Sugar
8.9 open BNF_GFP_Util
8.10 open BNF_GFP_Tactics
8.11
9.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
9.2 +++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML Mon Nov 04 16:53:43 2013 +0100
9.3 @@ -0,0 +1,1165 @@
9.4 +(* Title: HOL/BNF/Tools/bnf_gfp_rec_sugar.ML
9.5 + Author: Lorenz Panny, TU Muenchen
9.6 + Author: Jasmin Blanchette, TU Muenchen
9.7 + Copyright 2013
9.8 +
9.9 +Corecursor sugar.
9.10 +*)
9.11 +
9.12 +signature BNF_GFP_REC_SUGAR =
9.13 +sig
9.14 + val add_primcorecursive_cmd: bool ->
9.15 + (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
9.16 + Proof.context -> Proof.state
9.17 + val add_primcorec_cmd: bool ->
9.18 + (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
9.19 + local_theory -> local_theory
9.20 +end;
9.21 +
9.22 +structure BNF_GFP_Rec_Sugar : BNF_GFP_REC_SUGAR =
9.23 +struct
9.24 +
9.25 +open Ctr_Sugar
9.26 +open BNF_Util
9.27 +open BNF_Def
9.28 +open BNF_FP_Util
9.29 +open BNF_FP_Def_Sugar
9.30 +open BNF_FP_N2M_Sugar
9.31 +open BNF_FP_Rec_Sugar_Util
9.32 +open BNF_GFP_Rec_Sugar_Tactics
9.33 +
9.34 +val codeN = "code"
9.35 +val ctrN = "ctr"
9.36 +val discN = "disc"
9.37 +val selN = "sel"
9.38 +
9.39 +val nitpicksimp_attrs = @{attributes [nitpick_simp]};
9.40 +val simp_attrs = @{attributes [simp]};
9.41 +val code_nitpicksimp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs;
9.42 +
9.43 +exception Primcorec_Error of string * term list;
9.44 +
9.45 +fun primcorec_error str = raise Primcorec_Error (str, []);
9.46 +fun primcorec_error_eqn str eqn = raise Primcorec_Error (str, [eqn]);
9.47 +fun primcorec_error_eqns str eqns = raise Primcorec_Error (str, eqns);
9.48 +
9.49 +datatype corec_call =
9.50 + Dummy_No_Corec of int |
9.51 + No_Corec of int |
9.52 + Mutual_Corec of int * int * int |
9.53 + Nested_Corec of int;
9.54 +
9.55 +type basic_corec_ctr_spec =
9.56 + {ctr: term,
9.57 + disc: term,
9.58 + sels: term list};
9.59 +
9.60 +type corec_ctr_spec =
9.61 + {ctr: term,
9.62 + disc: term,
9.63 + sels: term list,
9.64 + pred: int option,
9.65 + calls: corec_call list,
9.66 + discI: thm,
9.67 + sel_thms: thm list,
9.68 + collapse: thm,
9.69 + corec_thm: thm,
9.70 + disc_corec: thm,
9.71 + sel_corecs: thm list};
9.72 +
9.73 +type corec_spec =
9.74 + {corec: term,
9.75 + nested_maps: thm list,
9.76 + nested_map_idents: thm list,
9.77 + nested_map_comps: thm list,
9.78 + ctr_specs: corec_ctr_spec list};
9.79 +
9.80 +exception AINT_NO_MAP of term;
9.81 +
9.82 +fun not_codatatype ctxt T =
9.83 + error ("Not a codatatype: " ^ Syntax.string_of_typ ctxt T);
9.84 +fun ill_formed_corec_call ctxt t =
9.85 + error ("Ill-formed corecursive call: " ^ quote (Syntax.string_of_term ctxt t));
9.86 +fun invalid_map ctxt t =
9.87 + error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t));
9.88 +fun unexpected_corec_call ctxt t =
9.89 + error ("Unexpected corecursive call: " ^ quote (Syntax.string_of_term ctxt t));
9.90 +
9.91 +val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
9.92 +val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
9.93 +
9.94 +val conjuncts_s = filter_out (curry (op =) @{const True}) o HOLogic.conjuncts;
9.95 +
9.96 +fun s_not @{const True} = @{const False}
9.97 + | s_not @{const False} = @{const True}
9.98 + | s_not (@{const Not} $ t) = t
9.99 + | s_not (@{const conj} $ t $ u) = @{const disj} $ s_not t $ s_not u
9.100 + | s_not (@{const disj} $ t $ u) = @{const conj} $ s_not t $ s_not u
9.101 + | s_not t = @{const Not} $ t;
9.102 +
9.103 +val s_not_conj = conjuncts_s o s_not o mk_conjs;
9.104 +
9.105 +fun propagate_unit_pos u cs = if member (op aconv) cs u then [@{const False}] else cs;
9.106 +
9.107 +fun propagate_unit_neg not_u cs = remove (op aconv) not_u cs;
9.108 +
9.109 +fun propagate_units css =
9.110 + (case List.partition (can the_single) css of
9.111 + ([], _) => css
9.112 + | ([u] :: uss, css') =>
9.113 + [u] :: propagate_units (map (propagate_unit_neg (s_not u))
9.114 + (map (propagate_unit_pos u) (uss @ css'))));
9.115 +
9.116 +fun s_conjs cs =
9.117 + if member (op aconv) cs @{const False} then @{const False}
9.118 + else mk_conjs (remove (op aconv) @{const True} cs);
9.119 +
9.120 +fun s_disjs ds =
9.121 + if member (op aconv) ds @{const True} then @{const True}
9.122 + else mk_disjs (remove (op aconv) @{const False} ds);
9.123 +
9.124 +fun s_dnf css0 =
9.125 + let val css = propagate_units css0 in
9.126 + if null css then
9.127 + [@{const False}]
9.128 + else if exists null css then
9.129 + []
9.130 + else
9.131 + map (fn c :: cs => (c, cs)) css
9.132 + |> AList.coalesce (op =)
9.133 + |> map (fn (c, css) => c :: s_dnf css)
9.134 + |> (fn [cs] => cs | css => [s_disjs (map s_conjs css)])
9.135 + end;
9.136 +
9.137 +fun fold_rev_let_if_case ctxt f bound_Ts t =
9.138 + let
9.139 + val thy = Proof_Context.theory_of ctxt;
9.140 +
9.141 + fun fld conds t =
9.142 + (case Term.strip_comb t of
9.143 + (Const (@{const_name Let}, _), [_, _]) => fld conds (unfold_let t)
9.144 + | (Const (@{const_name If}, _), [cond, then_branch, else_branch]) =>
9.145 + fld (conds @ conjuncts_s cond) then_branch o fld (conds @ s_not_conj [cond]) else_branch
9.146 + | (Const (c, _), args as _ :: _ :: _) =>
9.147 + let val n = num_binder_types (Sign.the_const_type thy c) - 1 in
9.148 + if n >= 0 andalso n < length args then
9.149 + (case fastype_of1 (bound_Ts, nth args n) of
9.150 + Type (s, Ts) =>
9.151 + (case dest_case ctxt s Ts t of
9.152 + NONE => apsnd (f conds t)
9.153 + | SOME (conds', branches) =>
9.154 + apfst (cons s) o fold_rev (uncurry fld)
9.155 + (map (append conds o conjuncts_s) conds' ~~ branches))
9.156 + | _ => apsnd (f conds t))
9.157 + else
9.158 + apsnd (f conds t)
9.159 + end
9.160 + | _ => apsnd (f conds t))
9.161 + in
9.162 + fld [] t o pair []
9.163 + end;
9.164 +
9.165 +fun case_of ctxt = ctr_sugar_of ctxt #> Option.map (fst o dest_Const o #casex);
9.166 +
9.167 +fun massage_let_if_case ctxt has_call massage_leaf =
9.168 + let
9.169 + val thy = Proof_Context.theory_of ctxt;
9.170 +
9.171 + fun check_no_call t = if has_call t then unexpected_corec_call ctxt t else ();
9.172 +
9.173 + fun massage_abs bound_Ts 0 t = massage_rec bound_Ts t
9.174 + | massage_abs bound_Ts m (Abs (s, T, t)) = Abs (s, T, massage_abs (T :: bound_Ts) (m - 1) t)
9.175 + | massage_abs bound_Ts m t =
9.176 + let val T = domain_type (fastype_of1 (bound_Ts, t)) in
9.177 + Abs (Name.uu, T, massage_abs (T :: bound_Ts) (m - 1) (incr_boundvars 1 t $ Bound 0))
9.178 + end
9.179 + and massage_rec bound_Ts t =
9.180 + let val typof = curry fastype_of1 bound_Ts in
9.181 + (case Term.strip_comb t of
9.182 + (Const (@{const_name Let}, _), [_, _]) => massage_rec bound_Ts (unfold_let t)
9.183 + | (Const (@{const_name If}, _), obj :: (branches as [_, _])) =>
9.184 + let val branches' = map (massage_rec bound_Ts) branches in
9.185 + Term.list_comb (If_const (typof (hd branches')) $ tap check_no_call obj, branches')
9.186 + end
9.187 + | (Const (c, _), args as _ :: _ :: _) =>
9.188 + (case try strip_fun_type (Sign.the_const_type thy c) of
9.189 + SOME (gen_branch_Ts, gen_body_fun_T) =>
9.190 + let
9.191 + val gen_branch_ms = map num_binder_types gen_branch_Ts;
9.192 + val n = length gen_branch_ms;
9.193 + in
9.194 + if n < length args then
9.195 + (case gen_body_fun_T of
9.196 + Type (_, [Type (T_name, _), _]) =>
9.197 + if case_of ctxt T_name = SOME c then
9.198 + let
9.199 + val (branches, obj_leftovers) = chop n args;
9.200 + val branches' = map2 (massage_abs bound_Ts) gen_branch_ms branches;
9.201 + val branch_Ts' = map typof branches';
9.202 + val body_T' = snd (strip_typeN (hd gen_branch_ms) (hd branch_Ts'));
9.203 + val casex' = Const (c, branch_Ts' ---> map typof obj_leftovers ---> body_T');
9.204 + in
9.205 + Term.list_comb (casex', branches' @ tap (List.app check_no_call) obj_leftovers)
9.206 + end
9.207 + else
9.208 + massage_leaf bound_Ts t
9.209 + | _ => massage_leaf bound_Ts t)
9.210 + else
9.211 + massage_leaf bound_Ts t
9.212 + end
9.213 + | NONE => massage_leaf bound_Ts t)
9.214 + | _ => massage_leaf bound_Ts t)
9.215 + end
9.216 + in
9.217 + massage_rec
9.218 + end;
9.219 +
9.220 +val massage_mutual_corec_call = massage_let_if_case;
9.221 +
9.222 +fun curried_type (Type (@{type_name fun}, [Type (@{type_name prod}, Ts), T])) = Ts ---> T;
9.223 +
9.224 +fun massage_nested_corec_call ctxt has_call raw_massage_call bound_Ts U t =
9.225 + let
9.226 + fun check_no_call t = if has_call t then unexpected_corec_call ctxt t else ();
9.227 +
9.228 + val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o snd);
9.229 +
9.230 + fun massage_mutual_call bound_Ts U T t =
9.231 + if has_call t then
9.232 + (case try dest_sumT U of
9.233 + SOME (U1, U2) => if U1 = T then raw_massage_call bound_Ts T U2 t else invalid_map ctxt t
9.234 + | NONE => invalid_map ctxt t)
9.235 + else
9.236 + build_map_Inl (T, U) $ t;
9.237 +
9.238 + fun massage_mutual_fun bound_Ts U T t =
9.239 + (case t of
9.240 + Const (@{const_name comp}, _) $ t1 $ t2 =>
9.241 + mk_comp bound_Ts (massage_mutual_fun bound_Ts U T t1, tap check_no_call t2)
9.242 + | _ =>
9.243 + let
9.244 + val var = Var ((Name.uu, Term.maxidx_of_term t + 1),
9.245 + domain_type (fastype_of1 (bound_Ts, t)));
9.246 + in
9.247 + Term.lambda var (massage_mutual_call bound_Ts U T (t $ var))
9.248 + end);
9.249 +
9.250 + fun massage_map bound_Ts (Type (_, Us)) (Type (s, Ts)) t =
9.251 + (case try (dest_map ctxt s) t of
9.252 + SOME (map0, fs) =>
9.253 + let
9.254 + val Type (_, dom_Ts) = domain_type (fastype_of1 (bound_Ts, t));
9.255 + val map' = mk_map (length fs) dom_Ts Us map0;
9.256 + val fs' =
9.257 + map_flattened_map_args ctxt s (map3 (massage_map_or_map_arg bound_Ts) Us Ts) fs;
9.258 + in
9.259 + Term.list_comb (map', fs')
9.260 + end
9.261 + | NONE => raise AINT_NO_MAP t)
9.262 + | massage_map _ _ _ t = raise AINT_NO_MAP t
9.263 + and massage_map_or_map_arg bound_Ts U T t =
9.264 + if T = U then
9.265 + tap check_no_call t
9.266 + else
9.267 + massage_map bound_Ts U T t
9.268 + handle AINT_NO_MAP _ => massage_mutual_fun bound_Ts U T t;
9.269 +
9.270 + fun massage_call bound_Ts U T =
9.271 + massage_let_if_case ctxt has_call (fn bound_Ts => fn t =>
9.272 + if has_call t then
9.273 + (case U of
9.274 + Type (s, Us) =>
9.275 + (case try (dest_ctr ctxt s) t of
9.276 + SOME (f, args) =>
9.277 + let
9.278 + val typof = curry fastype_of1 bound_Ts;
9.279 + val f' = mk_ctr Us f
9.280 + val f'_T = typof f';
9.281 + val arg_Ts = map typof args;
9.282 + in
9.283 + Term.list_comb (f', map3 (massage_call bound_Ts) (binder_types f'_T) arg_Ts args)
9.284 + end
9.285 + | NONE =>
9.286 + (case t of
9.287 + Const (@{const_name prod_case}, _) $ t' =>
9.288 + let
9.289 + val U' = curried_type U;
9.290 + val T' = curried_type T;
9.291 + in
9.292 + Const (@{const_name prod_case}, U' --> U) $ massage_call bound_Ts U' T' t'
9.293 + end
9.294 + | t1 $ t2 =>
9.295 + (if has_call t2 then
9.296 + massage_mutual_call bound_Ts U T t
9.297 + else
9.298 + massage_map bound_Ts U T t1 $ t2
9.299 + handle AINT_NO_MAP _ => massage_mutual_call bound_Ts U T t)
9.300 + | Abs (s, T', t') =>
9.301 + Abs (s, T', massage_call (T' :: bound_Ts) (range_type U) (range_type T) t')
9.302 + | _ => massage_mutual_call bound_Ts U T t))
9.303 + | _ => ill_formed_corec_call ctxt t)
9.304 + else
9.305 + build_map_Inl (T, U) $ t) bound_Ts;
9.306 +
9.307 + val T = fastype_of1 (bound_Ts, t);
9.308 + in
9.309 + if has_call t then massage_call bound_Ts U T t else build_map_Inl (T, U) $ t
9.310 + end;
9.311 +
9.312 +val fold_rev_corec_call = fold_rev_let_if_case;
9.313 +
9.314 +fun expand_to_ctr_term ctxt s Ts t =
9.315 + (case ctr_sugar_of ctxt s of
9.316 + SOME {ctrs, casex, ...} =>
9.317 + Term.list_comb (mk_case Ts (Type (s, Ts)) casex, map (mk_ctr Ts) ctrs) $ t
9.318 + | NONE => raise Fail "expand_to_ctr_term");
9.319 +
9.320 +fun expand_corec_code_rhs ctxt has_call bound_Ts t =
9.321 + (case fastype_of1 (bound_Ts, t) of
9.322 + Type (s, Ts) =>
9.323 + massage_let_if_case ctxt has_call (fn _ => fn t =>
9.324 + if can (dest_ctr ctxt s) t then t else expand_to_ctr_term ctxt s Ts t) bound_Ts t
9.325 + | _ => raise Fail "expand_corec_code_rhs");
9.326 +
9.327 +fun massage_corec_code_rhs ctxt massage_ctr =
9.328 + massage_let_if_case ctxt (K false)
9.329 + (fn bound_Ts => uncurry (massage_ctr bound_Ts) o Term.strip_comb);
9.330 +
9.331 +fun fold_rev_corec_code_rhs ctxt f =
9.332 + snd ooo fold_rev_let_if_case ctxt (fn conds => uncurry (f conds) o Term.strip_comb);
9.333 +
9.334 +fun case_thms_of_term ctxt bound_Ts t =
9.335 + let
9.336 + val (caseT_names, _) = fold_rev_let_if_case ctxt (K (K I)) bound_Ts t ();
9.337 + val ctr_sugars = map (the o ctr_sugar_of ctxt) caseT_names;
9.338 + in
9.339 + (maps #distincts ctr_sugars, maps #discIs ctr_sugars, maps #sel_splits ctr_sugars,
9.340 + maps #sel_split_asms ctr_sugars)
9.341 + end;
9.342 +
9.343 +fun basic_corec_specs_of ctxt res_T =
9.344 + (case res_T of
9.345 + Type (T_name, _) =>
9.346 + (case Ctr_Sugar.ctr_sugar_of ctxt T_name of
9.347 + NONE => not_codatatype ctxt res_T
9.348 + | SOME {ctrs, discs, selss, ...} =>
9.349 + let
9.350 + val thy = Proof_Context.theory_of ctxt;
9.351 + val gfpT = body_type (fastype_of (hd ctrs));
9.352 + val As_rho = tvar_subst thy [gfpT] [res_T];
9.353 + val substA = Term.subst_TVars As_rho;
9.354 +
9.355 + fun mk_spec ctr disc sels = {ctr = substA ctr, disc = substA disc, sels = map substA sels};
9.356 + in
9.357 + map3 mk_spec ctrs discs selss
9.358 + end)
9.359 + | _ => not_codatatype ctxt res_T);
9.360 +
9.361 +(*FIXME: remove special cases for product and sum once they are registered as datatypes*)
9.362 +fun map_thms_of_typ ctxt (Type (s, _)) =
9.363 + if s = @{type_name prod} then
9.364 + @{thms map_pair_simp}
9.365 + else if s = @{type_name sum} then
9.366 + @{thms sum_map.simps}
9.367 + else
9.368 + (case fp_sugar_of ctxt s of
9.369 + SOME {index, mapss, ...} => nth mapss index
9.370 + | NONE => [])
9.371 + | map_thms_of_typ _ _ = [];
9.372 +
9.373 +fun corec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
9.374 + let
9.375 + val thy = Proof_Context.theory_of lthy;
9.376 +
9.377 + val ((missing_res_Ts, perm0_kks,
9.378 + fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = dtor_coiters1 :: _, ...},
9.379 + co_inducts = coinduct_thms, ...} :: _, (_, gfp_sugar_thms)), lthy') =
9.380 + nested_to_mutual_fps Greatest_FP bs res_Ts get_indices callssss0 lthy;
9.381 +
9.382 + val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;
9.383 +
9.384 + val indices = map #index fp_sugars;
9.385 + val perm_indices = map #index perm_fp_sugars;
9.386 +
9.387 + val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
9.388 + val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
9.389 + val perm_gfpTs = map (body_type o fastype_of o hd) perm_ctrss;
9.390 +
9.391 + val nn0 = length res_Ts;
9.392 + val nn = length perm_gfpTs;
9.393 + val kks = 0 upto nn - 1;
9.394 + val perm_ns = map length perm_ctr_Tsss;
9.395 +
9.396 + val perm_Cs = map (domain_type o body_fun_type o fastype_of o co_rec_of o
9.397 + of_fp_sugar (#xtor_co_iterss o #fp_res)) perm_fp_sugars;
9.398 + val (perm_p_Tss, (perm_q_Tssss, _, perm_f_Tssss, _)) =
9.399 + mk_coiter_fun_arg_types perm_ctr_Tsss perm_Cs perm_ns (co_rec_of dtor_coiters1);
9.400 +
9.401 + val (perm_p_hss, h) = indexedd perm_p_Tss 0;
9.402 + val (perm_q_hssss, h') = indexedddd perm_q_Tssss h;
9.403 + val (perm_f_hssss, _) = indexedddd perm_f_Tssss h';
9.404 +
9.405 + val fun_arg_hs =
9.406 + flat (map3 flat_corec_preds_predsss_gettersss perm_p_hss perm_q_hssss perm_f_hssss);
9.407 +
9.408 + fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
9.409 + fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
9.410 +
9.411 + val coinduct_thmss = map (unpermute0 o conj_dests nn) coinduct_thms;
9.412 +
9.413 + val p_iss = map (map (find_index_eq fun_arg_hs)) (unpermute perm_p_hss);
9.414 + val q_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_q_hssss);
9.415 + val f_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_f_hssss);
9.416 +
9.417 + val f_Tssss = unpermute perm_f_Tssss;
9.418 + val gfpTs = unpermute perm_gfpTs;
9.419 + val Cs = unpermute perm_Cs;
9.420 +
9.421 + val As_rho = tvar_subst thy (take nn0 gfpTs) res_Ts;
9.422 + val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn arg_Ts;
9.423 +
9.424 + val substA = Term.subst_TVars As_rho;
9.425 + val substAT = Term.typ_subst_TVars As_rho;
9.426 + val substCT = Term.typ_subst_TVars Cs_rho;
9.427 +
9.428 + val perm_Cs' = map substCT perm_Cs;
9.429 +
9.430 + fun call_of nullary [] [g_i] [Type (@{type_name fun}, [_, T])] =
9.431 + (if exists_subtype_in Cs T then Nested_Corec
9.432 + else if nullary then Dummy_No_Corec
9.433 + else No_Corec) g_i
9.434 + | call_of _ [q_i] [g_i, g_i'] _ = Mutual_Corec (q_i, g_i, g_i');
9.435 +
9.436 + fun mk_ctr_spec ctr disc sels p_ho q_iss f_iss f_Tss discI sel_thms collapse corec_thm
9.437 + disc_corec sel_corecs =
9.438 + let val nullary = not (can dest_funT (fastype_of ctr)) in
9.439 + {ctr = substA ctr, disc = substA disc, sels = map substA sels, pred = p_ho,
9.440 + calls = map3 (call_of nullary) q_iss f_iss f_Tss, discI = discI, sel_thms = sel_thms,
9.441 + collapse = collapse, corec_thm = corec_thm, disc_corec = disc_corec,
9.442 + sel_corecs = sel_corecs}
9.443 + end;
9.444 +
9.445 + fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) p_is q_isss f_isss f_Tsss coiter_thmsss
9.446 + disc_coitersss sel_coiterssss =
9.447 + let
9.448 + val ctrs = #ctrs (nth ctr_sugars index);
9.449 + val discs = #discs (nth ctr_sugars index);
9.450 + val selss = #selss (nth ctr_sugars index);
9.451 + val p_ios = map SOME p_is @ [NONE];
9.452 + val discIs = #discIs (nth ctr_sugars index);
9.453 + val sel_thmss = #sel_thmss (nth ctr_sugars index);
9.454 + val collapses = #collapses (nth ctr_sugars index);
9.455 + val corec_thms = co_rec_of (nth coiter_thmsss index);
9.456 + val disc_corecs = co_rec_of (nth disc_coitersss index);
9.457 + val sel_corecss = co_rec_of (nth sel_coiterssss index);
9.458 + in
9.459 + map13 mk_ctr_spec ctrs discs selss p_ios q_isss f_isss f_Tsss discIs sel_thmss collapses
9.460 + corec_thms disc_corecs sel_corecss
9.461 + end;
9.462 +
9.463 + fun mk_spec ({T, index, ctr_sugars, co_iterss = coiterss, co_iter_thmsss = coiter_thmsss,
9.464 + disc_co_itersss = disc_coitersss, sel_co_iterssss = sel_coiterssss, ...} : fp_sugar)
9.465 + p_is q_isss f_isss f_Tsss =
9.466 + {corec = mk_co_iter thy Greatest_FP (substAT T) perm_Cs' (co_rec_of (nth coiterss index)),
9.467 + nested_maps = maps (map_thms_of_typ lthy o T_of_bnf) nested_bnfs,
9.468 + nested_map_idents = map (unfold_thms lthy @{thms id_def} o map_id0_of_bnf) nested_bnfs,
9.469 + nested_map_comps = map map_comp_of_bnf nested_bnfs,
9.470 + ctr_specs = mk_ctr_specs index ctr_sugars p_is q_isss f_isss f_Tsss coiter_thmsss
9.471 + disc_coitersss sel_coiterssss};
9.472 + in
9.473 + ((is_some gfp_sugar_thms, map5 mk_spec fp_sugars p_iss q_issss f_issss f_Tssss, missing_res_Ts,
9.474 + co_induct_of coinduct_thms, strong_co_induct_of coinduct_thms, co_induct_of coinduct_thmss,
9.475 + strong_co_induct_of coinduct_thmss), lthy')
9.476 + end;
9.477 +
9.478 +val const_name = try (fn Const (v, _) => v);
9.479 +val undef_const = Const (@{const_name undefined}, dummyT);
9.480 +
9.481 +val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple;
9.482 +fun abstract vs =
9.483 + let fun a n (t $ u) = a n t $ a n u
9.484 + | a n (Abs (v, T, b)) = Abs (v, T, a (n + 1) b)
9.485 + | a n t = let val idx = find_index (equal t) vs in
9.486 + if idx < 0 then t else Bound (n + idx) end
9.487 + in a 0 end;
9.488 +fun mk_prod1 Ts (t, u) = HOLogic.pair_const (fastype_of1 (Ts, t)) (fastype_of1 (Ts, u)) $ t $ u;
9.489 +fun mk_tuple1 Ts = the_default HOLogic.unit o try (foldr1 (mk_prod1 Ts));
9.490 +
9.491 +type coeqn_data_disc = {
9.492 + fun_name: string,
9.493 + fun_T: typ,
9.494 + fun_args: term list,
9.495 + ctr: term,
9.496 + ctr_no: int, (*###*)
9.497 + disc: term,
9.498 + prems: term list,
9.499 + auto_gen: bool,
9.500 + maybe_ctr_rhs: term option,
9.501 + maybe_code_rhs: term option,
9.502 + user_eqn: term
9.503 +};
9.504 +
9.505 +type coeqn_data_sel = {
9.506 + fun_name: string,
9.507 + fun_T: typ,
9.508 + fun_args: term list,
9.509 + ctr: term,
9.510 + sel: term,
9.511 + rhs_term: term,
9.512 + user_eqn: term
9.513 +};
9.514 +
9.515 +datatype coeqn_data =
9.516 + Disc of coeqn_data_disc |
9.517 + Sel of coeqn_data_sel;
9.518 +
9.519 +fun dissect_coeqn_disc seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list)
9.520 + maybe_ctr_rhs maybe_code_rhs prems' concl matchedsss =
9.521 + let
9.522 + fun find_subterm p = let (* FIXME \<exists>? *)
9.523 + fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
9.524 + | f t = if p t then SOME t else NONE
9.525 + in f end;
9.526 +
9.527 + val applied_fun = concl
9.528 + |> find_subterm (member ((op =) o apsnd SOME) fun_names o try (fst o dest_Free o head_of))
9.529 + |> the
9.530 + handle Option.Option => primcorec_error_eqn "malformed discriminator formula" concl;
9.531 + val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free;
9.532 + val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name;
9.533 +
9.534 + val discs = map #disc basic_ctr_specs;
9.535 + val ctrs = map #ctr basic_ctr_specs;
9.536 + val not_disc = head_of concl = @{term Not};
9.537 + val _ = not_disc andalso length ctrs <> 2 andalso
9.538 + primcorec_error_eqn "negated discriminator for a type with \<noteq> 2 constructors" concl;
9.539 + val disc' = find_subterm (member (op =) discs o head_of) concl;
9.540 + val eq_ctr0 = concl |> perhaps (try HOLogic.dest_not) |> try (HOLogic.dest_eq #> snd)
9.541 + |> (fn SOME t => let val n = find_index (equal t) ctrs in
9.542 + if n >= 0 then SOME n else NONE end | _ => NONE);
9.543 + val _ = is_some disc' orelse is_some eq_ctr0 orelse
9.544 + primcorec_error_eqn "no discriminator in equation" concl;
9.545 + val ctr_no' =
9.546 + if is_none disc' then the eq_ctr0 else find_index (equal (head_of (the disc'))) discs;
9.547 + val ctr_no = if not_disc then 1 - ctr_no' else ctr_no';
9.548 + val {ctr, disc, ...} = nth basic_ctr_specs ctr_no;
9.549 +
9.550 + val catch_all = try (fst o dest_Free o the_single) prems' = SOME Name.uu_;
9.551 + val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default [];
9.552 + val prems = map (abstract (List.rev fun_args)) prems';
9.553 + val real_prems =
9.554 + (if catch_all orelse seq then maps s_not_conj matchedss else []) @
9.555 + (if catch_all then [] else prems);
9.556 +
9.557 + val matchedsss' = AList.delete (op =) fun_name matchedsss
9.558 + |> cons (fun_name, if seq then matchedss @ [prems] else matchedss @ [real_prems]);
9.559 +
9.560 + val user_eqn =
9.561 + (real_prems, concl)
9.562 + |>> map HOLogic.mk_Trueprop ||> HOLogic.mk_Trueprop o abstract (List.rev fun_args)
9.563 + |> curry Logic.list_all (map dest_Free fun_args) o Logic.list_implies;
9.564 + in
9.565 + (Disc {
9.566 + fun_name = fun_name,
9.567 + fun_T = fun_T,
9.568 + fun_args = fun_args,
9.569 + ctr = ctr,
9.570 + ctr_no = ctr_no,
9.571 + disc = disc,
9.572 + prems = real_prems,
9.573 + auto_gen = catch_all,
9.574 + maybe_ctr_rhs = maybe_ctr_rhs,
9.575 + maybe_code_rhs = maybe_code_rhs,
9.576 + user_eqn = user_eqn
9.577 + }, matchedsss')
9.578 + end;
9.579 +
9.580 +fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn'
9.581 + maybe_of_spec eqn =
9.582 + let
9.583 + val (lhs, rhs) = HOLogic.dest_eq eqn
9.584 + handle TERM _ =>
9.585 + primcorec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn;
9.586 + val sel = head_of lhs;
9.587 + val ((fun_name, fun_T), fun_args) = dest_comb lhs |> snd |> strip_comb |> apfst dest_Free
9.588 + handle TERM _ =>
9.589 + primcorec_error_eqn "malformed selector argument in left-hand side" eqn;
9.590 + val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name)
9.591 + handle Option.Option => primcorec_error_eqn "malformed selector argument in left-hand side" eqn;
9.592 + val {ctr, ...} =
9.593 + (case maybe_of_spec of
9.594 + SOME of_spec => the (find_first (equal of_spec o #ctr) basic_ctr_specs)
9.595 + | NONE => filter (exists (equal sel) o #sels) basic_ctr_specs |> the_single
9.596 + handle List.Empty => primcorec_error_eqn "ambiguous selector - use \"of\"" eqn);
9.597 + val user_eqn = drop_All eqn';
9.598 + in
9.599 + Sel {
9.600 + fun_name = fun_name,
9.601 + fun_T = fun_T,
9.602 + fun_args = fun_args,
9.603 + ctr = ctr,
9.604 + sel = sel,
9.605 + rhs_term = rhs,
9.606 + user_eqn = user_eqn
9.607 + }
9.608 + end;
9.609 +
9.610 +fun dissect_coeqn_ctr seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn'
9.611 + maybe_code_rhs prems concl matchedsss =
9.612 + let
9.613 + val (lhs, rhs) = HOLogic.dest_eq concl;
9.614 + val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
9.615 + val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name;
9.616 + val (ctr, ctr_args) = strip_comb (unfold_let rhs);
9.617 + val {disc, sels, ...} = the (find_first (equal ctr o #ctr) basic_ctr_specs)
9.618 + handle Option.Option => primcorec_error_eqn "not a constructor" ctr;
9.619 +
9.620 + val disc_concl = betapply (disc, lhs);
9.621 + val (maybe_eqn_data_disc, matchedsss') = if length basic_ctr_specs = 1
9.622 + then (NONE, matchedsss)
9.623 + else apfst SOME (dissect_coeqn_disc seq fun_names basic_ctr_specss
9.624 + (SOME (abstract (List.rev fun_args) rhs)) maybe_code_rhs prems disc_concl matchedsss);
9.625 +
9.626 + val sel_concls = sels ~~ ctr_args
9.627 + |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
9.628 +
9.629 +(*
9.630 +val _ = tracing ("reduced\n " ^ Syntax.string_of_term @{context} concl ^ "\nto\n \<cdot> " ^
9.631 + (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_concl ^ "\n \<cdot> ")) "" ^
9.632 + space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) sel_concls) ^
9.633 + "\nfor premise(s)\n \<cdot> " ^
9.634 + space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) prems));
9.635 +*)
9.636 +
9.637 + val eqns_data_sel =
9.638 + map (dissect_coeqn_sel fun_names basic_ctr_specss eqn' (SOME ctr)) sel_concls;
9.639 + in
9.640 + (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
9.641 + end;
9.642 +
9.643 +fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss =
9.644 + let
9.645 + val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
9.646 + val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
9.647 + val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name;
9.648 +
9.649 + val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ =>
9.650 + if member ((op =) o apsnd #ctr) basic_ctr_specs ctr
9.651 + then cons (ctr, cs)
9.652 + else primcorec_error_eqn "not a constructor" ctr) [] rhs' []
9.653 + |> AList.group (op =);
9.654 +
9.655 + val ctr_premss = (case cond_ctrs of [_] => [[]] | _ => map (s_dnf o snd) cond_ctrs);
9.656 + val ctr_concls = cond_ctrs |> map (fn (ctr, _) =>
9.657 + binder_types (fastype_of ctr)
9.658 + |> map_index (fn (n, T) => massage_corec_code_rhs lthy (fn _ => fn ctr' => fn args =>
9.659 + if ctr' = ctr then nth args n else Const (@{const_name undefined}, T)) [] rhs')
9.660 + |> curry list_comb ctr
9.661 + |> curry HOLogic.mk_eq lhs);
9.662 + in
9.663 + fold_map2 (dissect_coeqn_ctr false fun_names basic_ctr_specss eqn'
9.664 + (SOME (abstract (List.rev fun_args) rhs)))
9.665 + ctr_premss ctr_concls matchedsss
9.666 + end;
9.667 +
9.668 +fun dissect_coeqn lthy seq has_call fun_names (basic_ctr_specss : basic_corec_ctr_spec list list)
9.669 + eqn' maybe_of_spec matchedsss =
9.670 + let
9.671 + val eqn = drop_All eqn'
9.672 + handle TERM _ => primcorec_error_eqn "malformed function equation" eqn';
9.673 + val (prems, concl) = Logic.strip_horn eqn
9.674 + |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop;
9.675 +
9.676 + val head = concl
9.677 + |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
9.678 + |> head_of;
9.679 +
9.680 + val maybe_rhs = concl |> perhaps (try HOLogic.dest_not) |> try (snd o HOLogic.dest_eq);
9.681 +
9.682 + val discs = maps (map #disc) basic_ctr_specss;
9.683 + val sels = maps (maps #sels) basic_ctr_specss;
9.684 + val ctrs = maps (map #ctr) basic_ctr_specss;
9.685 + in
9.686 + if member (op =) discs head orelse
9.687 + is_some maybe_rhs andalso
9.688 + member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
9.689 + dissect_coeqn_disc seq fun_names basic_ctr_specss NONE NONE prems concl matchedsss
9.690 + |>> single
9.691 + else if member (op =) sels head then
9.692 + ([dissect_coeqn_sel fun_names basic_ctr_specss eqn' maybe_of_spec concl], matchedsss)
9.693 + else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
9.694 + member (op =) ctrs (head_of (unfold_let (the maybe_rhs))) then
9.695 + dissect_coeqn_ctr seq fun_names basic_ctr_specss eqn' NONE prems concl matchedsss
9.696 + else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
9.697 + null prems then
9.698 + dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss
9.699 + |>> flat
9.700 + else
9.701 + primcorec_error_eqn "malformed function equation" eqn
9.702 + end;
9.703 +
9.704 +fun build_corec_arg_disc (ctr_specs : corec_ctr_spec list)
9.705 + ({fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
9.706 + if is_none (#pred (nth ctr_specs ctr_no)) then I else
9.707 + s_conjs prems
9.708 + |> curry subst_bounds (List.rev fun_args)
9.709 + |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args)
9.710 + |> K |> nth_map (the (#pred (nth ctr_specs ctr_no)));
9.711 +
9.712 +fun build_corec_arg_no_call (sel_eqns : coeqn_data_sel list) sel =
9.713 + find_first (equal sel o #sel) sel_eqns
9.714 + |> try (fn SOME {fun_args, rhs_term, ...} => abs_tuple fun_args rhs_term)
9.715 + |> the_default undef_const
9.716 + |> K;
9.717 +
9.718 +fun build_corec_args_mutual_call lthy has_call (sel_eqns : coeqn_data_sel list) sel =
9.719 + (case find_first (equal sel o #sel) sel_eqns of
9.720 + NONE => (I, I, I)
9.721 + | SOME {fun_args, rhs_term, ... } =>
9.722 + let
9.723 + val bound_Ts = List.rev (map fastype_of fun_args);
9.724 + fun rewrite_stop _ t = if has_call t then @{term False} else @{term True};
9.725 + fun rewrite_end _ t = if has_call t then undef_const else t;
9.726 + fun rewrite_cont bound_Ts t =
9.727 + if has_call t then mk_tuple1 bound_Ts (snd (strip_comb t)) else undef_const;
9.728 + fun massage f _ = massage_mutual_corec_call lthy has_call f bound_Ts rhs_term
9.729 + |> abs_tuple fun_args;
9.730 + in
9.731 + (massage rewrite_stop, massage rewrite_end, massage rewrite_cont)
9.732 + end);
9.733 +
9.734 +fun build_corec_arg_nested_call lthy has_call (sel_eqns : coeqn_data_sel list) sel =
9.735 + (case find_first (equal sel o #sel) sel_eqns of
9.736 + NONE => I
9.737 + | SOME {fun_args, rhs_term, ...} =>
9.738 + let
9.739 + val bound_Ts = List.rev (map fastype_of fun_args);
9.740 + fun rewrite bound_Ts U T (Abs (v, V, b)) = Abs (v, V, rewrite (V :: bound_Ts) U T b)
9.741 + | rewrite bound_Ts U T (t as _ $ _) =
9.742 + let val (u, vs) = strip_comb t in
9.743 + if is_Free u andalso has_call u then
9.744 + Inr_const U T $ mk_tuple1 bound_Ts vs
9.745 + else if const_name u = SOME @{const_name prod_case} then
9.746 + map (rewrite bound_Ts U T) vs |> chop 1 |>> HOLogic.mk_split o the_single |> list_comb
9.747 + else
9.748 + list_comb (rewrite bound_Ts U T u, map (rewrite bound_Ts U T) vs)
9.749 + end
9.750 + | rewrite _ U T t =
9.751 + if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t;
9.752 + fun massage t =
9.753 + rhs_term
9.754 + |> massage_nested_corec_call lthy has_call rewrite bound_Ts (range_type (fastype_of t))
9.755 + |> abs_tuple fun_args;
9.756 + in
9.757 + massage
9.758 + end);
9.759 +
9.760 +fun build_corec_args_sel lthy has_call (all_sel_eqns : coeqn_data_sel list)
9.761 + (ctr_spec : corec_ctr_spec) =
9.762 + (case filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns of
9.763 + [] => I
9.764 + | sel_eqns =>
9.765 + let
9.766 + val sel_call_list = #sels ctr_spec ~~ #calls ctr_spec;
9.767 + val no_calls' = map_filter (try (apsnd (fn No_Corec n => n))) sel_call_list;
9.768 + val mutual_calls' = map_filter (try (apsnd (fn Mutual_Corec n => n))) sel_call_list;
9.769 + val nested_calls' = map_filter (try (apsnd (fn Nested_Corec n => n))) sel_call_list;
9.770 + in
9.771 + I
9.772 + #> fold (fn (sel, n) => nth_map n (build_corec_arg_no_call sel_eqns sel)) no_calls'
9.773 + #> fold (fn (sel, (q, g, h)) =>
9.774 + let val (fq, fg, fh) = build_corec_args_mutual_call lthy has_call sel_eqns sel in
9.775 + nth_map q fq o nth_map g fg o nth_map h fh end) mutual_calls'
9.776 + #> fold (fn (sel, n) => nth_map n
9.777 + (build_corec_arg_nested_call lthy has_call sel_eqns sel)) nested_calls'
9.778 + end);
9.779 +
9.780 +fun build_codefs lthy bs mxs has_call arg_Tss (corec_specs : corec_spec list)
9.781 + (disc_eqnss : coeqn_data_disc list list) (sel_eqnss : coeqn_data_sel list list) =
9.782 + let
9.783 + val corecs = map #corec corec_specs;
9.784 + val ctr_specss = map #ctr_specs corec_specs;
9.785 + val corec_args = hd corecs
9.786 + |> fst o split_last o binder_types o fastype_of
9.787 + |> map (Const o pair @{const_name undefined})
9.788 + |> fold2 (fold o build_corec_arg_disc) ctr_specss disc_eqnss
9.789 + |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
9.790 + fun currys [] t = t
9.791 + | currys Ts t = t $ mk_tuple1 (List.rev Ts) (map Bound (length Ts - 1 downto 0))
9.792 + |> fold_rev (Term.abs o pair Name.uu) Ts;
9.793 +
9.794 +(*
9.795 +val _ = tracing ("corecursor arguments:\n \<cdot> " ^
9.796 + space_implode "\n \<cdot> " (map (Syntax.string_of_term lthy) corec_args));
9.797 +*)
9.798 +
9.799 + val exclss' =
9.800 + disc_eqnss
9.801 + |> map (map (fn x => (#fun_args x, #ctr_no x, #prems x, #auto_gen x))
9.802 + #> fst o (fn xs => fold_map (fn x => fn ys => ((x, ys), ys @ [x])) xs [])
9.803 + #> maps (uncurry (map o pair)
9.804 + #> map (fn ((fun_args, c, x, a), (_, c', y, a')) =>
9.805 + ((c, c', a orelse a'), (x, s_not (s_conjs y)))
9.806 + ||> apfst (map HOLogic.mk_Trueprop) o apsnd HOLogic.mk_Trueprop
9.807 + ||> Logic.list_implies
9.808 + ||> curry Logic.list_all (map dest_Free fun_args))))
9.809 + in
9.810 + map (list_comb o rpair corec_args) corecs
9.811 + |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss
9.812 + |> map2 currys arg_Tss
9.813 + |> Syntax.check_terms lthy
9.814 + |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t)))
9.815 + bs mxs
9.816 + |> rpair exclss'
9.817 + end;
9.818 +
9.819 +fun mk_real_disc_eqns fun_binding arg_Ts ({ctr_specs, ...} : corec_spec)
9.820 + (sel_eqns : coeqn_data_sel list) (disc_eqns : coeqn_data_disc list) =
9.821 + if length disc_eqns <> length ctr_specs - 1 then disc_eqns else
9.822 + let
9.823 + val n = 0 upto length ctr_specs
9.824 + |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
9.825 + val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
9.826 + |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
9.827 + val extra_disc_eqn = {
9.828 + fun_name = Binding.name_of fun_binding,
9.829 + fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
9.830 + fun_args = fun_args,
9.831 + ctr = #ctr (nth ctr_specs n),
9.832 + ctr_no = n,
9.833 + disc = #disc (nth ctr_specs n),
9.834 + prems = maps (s_not_conj o #prems) disc_eqns,
9.835 + auto_gen = true,
9.836 + maybe_ctr_rhs = NONE,
9.837 + maybe_code_rhs = NONE,
9.838 + user_eqn = undef_const};
9.839 + in
9.840 + chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
9.841 + end;
9.842 +
9.843 +fun find_corec_calls ctxt has_call basic_ctr_specs ({ctr, sel, rhs_term, ...} : coeqn_data_sel) =
9.844 + let
9.845 + val sel_no = find_first (equal ctr o #ctr) basic_ctr_specs
9.846 + |> find_index (equal sel) o #sels o the;
9.847 + fun find t = if has_call t then snd (fold_rev_corec_call ctxt (K cons) [] t []) else [];
9.848 + in
9.849 + find rhs_term
9.850 + |> K |> nth_map sel_no |> AList.map_entry (op =) ctr
9.851 + end;
9.852 +
9.853 +fun add_primcorec_ursive maybe_tac seq fixes specs maybe_of_specs lthy =
9.854 + let
9.855 + val (bs, mxs) = map_split (apfst fst) fixes;
9.856 + val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;
9.857 +
9.858 + val fun_names = map Binding.name_of bs;
9.859 + val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts;
9.860 + val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
9.861 + val eqns_data =
9.862 + fold_map2 (dissect_coeqn lthy seq has_call fun_names basic_ctr_specss) (map snd specs)
9.863 + maybe_of_specs []
9.864 + |> flat o fst;
9.865 +
9.866 + val callssss =
9.867 + map_filter (try (fn Sel x => x)) eqns_data
9.868 + |> partition_eq ((op =) o pairself #fun_name)
9.869 + |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
9.870 + |> map (flat o snd)
9.871 + |> map2 (fold o find_corec_calls lthy has_call) basic_ctr_specss
9.872 + |> map2 (curry (op |>)) (map (map (fn {ctr, sels, ...} =>
9.873 + (ctr, map (K []) sels))) basic_ctr_specss);
9.874 +
9.875 +(*
9.876 +val _ = tracing ("callssss = " ^ @{make_string} callssss);
9.877 +*)
9.878 +
9.879 + val ((n2m, corec_specs', _, coinduct_thm, strong_coinduct_thm, coinduct_thms,
9.880 + strong_coinduct_thms), lthy') =
9.881 + corec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
9.882 + val actual_nn = length bs;
9.883 + val corec_specs = take actual_nn corec_specs'; (*###*)
9.884 + val ctr_specss = map #ctr_specs corec_specs;
9.885 +
9.886 + val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
9.887 + |> partition_eq ((op =) o pairself #fun_name)
9.888 + |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
9.889 + |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd);
9.890 + val _ = disc_eqnss' |> map (fn x =>
9.891 + let val d = duplicates ((op =) o pairself #ctr_no) x in null d orelse
9.892 + primcorec_error_eqns "excess discriminator formula in definition"
9.893 + (maps (fn t => filter (equal (#ctr_no t) o #ctr_no) x) d |> map #user_eqn) end);
9.894 +
9.895 + val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data
9.896 + |> partition_eq ((op =) o pairself #fun_name)
9.897 + |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
9.898 + |> map (flat o snd);
9.899 +
9.900 + val arg_Tss = map (binder_types o snd o fst) fixes;
9.901 + val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss';
9.902 + val (defs, exclss') =
9.903 + build_codefs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;
9.904 +
9.905 + fun excl_tac (c, c', a) =
9.906 + if a orelse c = c' orelse seq then SOME (K (HEADGOAL (mk_primcorec_assumption_tac lthy [])))
9.907 + else maybe_tac;
9.908 +
9.909 +(*
9.910 +val _ = tracing ("exclusiveness properties:\n \<cdot> " ^
9.911 + space_implode "\n \<cdot> " (maps (map (Syntax.string_of_term lthy o snd)) exclss'));
9.912 +*)
9.913 +
9.914 + val exclss'' = exclss' |> map (map (fn (idx, t) =>
9.915 + (idx, (Option.map (Goal.prove lthy [] [] t #> Thm.close_derivation) (excl_tac idx), t))));
9.916 + val taut_thmss = map (map (apsnd (the o fst)) o filter (is_some o fst o snd)) exclss'';
9.917 + val (goal_idxss, goalss) = exclss''
9.918 + |> map (map (apsnd (rpair [] o snd)) o filter (is_none o fst o snd))
9.919 + |> split_list o map split_list;
9.920 +
9.921 + fun prove thmss' def_thms' lthy =
9.922 + let
9.923 + val def_thms = map (snd o snd) def_thms';
9.924 +
9.925 + val exclss' = map (op ~~) (goal_idxss ~~ thmss');
9.926 + fun mk_exclsss excls n =
9.927 + (excls, map (fn k => replicate k [TrueI] @ replicate (n - k) []) (0 upto n - 1))
9.928 + |-> fold (fn ((c, c', _), thm) => nth_map c (nth_map c' (K [thm])));
9.929 + val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs)
9.930 + |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs));
9.931 +
9.932 + fun prove_disc ({ctr_specs, ...} : corec_spec) exclsss
9.933 + ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
9.934 + if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\<lambda>x. x = x"}) then [] else
9.935 + let
9.936 + val {disc_corec, ...} = nth ctr_specs ctr_no;
9.937 + val k = 1 + ctr_no;
9.938 + val m = length prems;
9.939 + val t =
9.940 + list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
9.941 + |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*)
9.942 + |> HOLogic.mk_Trueprop
9.943 + |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
9.944 + |> curry Logic.list_all (map dest_Free fun_args);
9.945 + in
9.946 + if prems = [@{term False}] then [] else
9.947 + mk_primcorec_disc_tac lthy def_thms disc_corec k m exclsss
9.948 + |> K |> Goal.prove lthy [] [] t
9.949 + |> Thm.close_derivation
9.950 + |> pair (#disc (nth ctr_specs ctr_no))
9.951 + |> single
9.952 + end;
9.953 +
9.954 + fun prove_sel ({nested_maps, nested_map_idents, nested_map_comps, ctr_specs, ...}
9.955 + : corec_spec) (disc_eqns : coeqn_data_disc list) exclsss
9.956 + ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, ...} : coeqn_data_sel) =
9.957 + let
9.958 + val SOME ctr_spec = find_first (equal ctr o #ctr) ctr_specs;
9.959 + val ctr_no = find_index (equal ctr o #ctr) ctr_specs;
9.960 + val prems = the_default (maps (s_not_conj o #prems) disc_eqns)
9.961 + (find_first (equal ctr_no o #ctr_no) disc_eqns |> Option.map #prems);
9.962 + val sel_corec = find_index (equal sel) (#sels ctr_spec)
9.963 + |> nth (#sel_corecs ctr_spec);
9.964 + val k = 1 + ctr_no;
9.965 + val m = length prems;
9.966 + val t =
9.967 + list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
9.968 + |> curry betapply sel
9.969 + |> rpair (abstract (List.rev fun_args) rhs_term)
9.970 + |> HOLogic.mk_Trueprop o HOLogic.mk_eq
9.971 + |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
9.972 + |> curry Logic.list_all (map dest_Free fun_args);
9.973 + val (distincts, _, sel_splits, sel_split_asms) = case_thms_of_term lthy [] rhs_term;
9.974 + in
9.975 + mk_primcorec_sel_tac lthy def_thms distincts sel_splits sel_split_asms nested_maps
9.976 + nested_map_idents nested_map_comps sel_corec k m exclsss
9.977 + |> K |> Goal.prove lthy [] [] t
9.978 + |> Thm.close_derivation
9.979 + |> pair sel
9.980 + end;
9.981 +
9.982 + fun prove_ctr disc_alist sel_alist (disc_eqns : coeqn_data_disc list)
9.983 + (sel_eqns : coeqn_data_sel list) ({ctr, disc, sels, collapse, ...} : corec_ctr_spec) =
9.984 + (* don't try to prove theorems when some sel_eqns are missing *)
9.985 + if not (exists (equal ctr o #ctr) disc_eqns)
9.986 + andalso not (exists (equal ctr o #ctr) sel_eqns)
9.987 + orelse
9.988 + filter (equal ctr o #ctr) sel_eqns
9.989 + |> fst o finds ((op =) o apsnd #sel) sels
9.990 + |> exists (null o snd)
9.991 + then [] else
9.992 + let
9.993 + val (fun_name, fun_T, fun_args, prems, maybe_rhs) =
9.994 + (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
9.995 + |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x,
9.996 + #maybe_ctr_rhs x))
9.997 + ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], NONE))
9.998 + |> the o merge_options;
9.999 + val m = length prems;
9.1000 + val t = (if is_some maybe_rhs then the maybe_rhs else
9.1001 + filter (equal ctr o #ctr) sel_eqns
9.1002 + |> fst o finds ((op =) o apsnd #sel) sels
9.1003 + |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
9.1004 + |> curry list_comb ctr)
9.1005 + |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
9.1006 + map Bound (length fun_args - 1 downto 0)))
9.1007 + |> HOLogic.mk_Trueprop
9.1008 + |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
9.1009 + |> curry Logic.list_all (map dest_Free fun_args);
9.1010 + val maybe_disc_thm = AList.lookup (op =) disc_alist disc;
9.1011 + val sel_thms = map snd (filter (member (op =) sels o fst) sel_alist);
9.1012 + in
9.1013 + if prems = [@{term False}] then [] else
9.1014 + mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
9.1015 + |> K |> Goal.prove lthy [] [] t
9.1016 + |> Thm.close_derivation
9.1017 + |> pair ctr
9.1018 + |> single
9.1019 + end;
9.1020 +
9.1021 + fun prove_code disc_eqns sel_eqns ctr_alist ctr_specs =
9.1022 + let
9.1023 + val (fun_name, fun_T, fun_args, maybe_rhs) =
9.1024 + (find_first (member (op =) (map #ctr ctr_specs) o #ctr) disc_eqns,
9.1025 + find_first (member (op =) (map #ctr ctr_specs) o #ctr) sel_eqns)
9.1026 + |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #maybe_code_rhs x))
9.1027 + ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, NONE))
9.1028 + |> the o merge_options;
9.1029 +
9.1030 + val bound_Ts = List.rev (map fastype_of fun_args);
9.1031 +
9.1032 + val lhs = list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0));
9.1033 + val maybe_rhs_info =
9.1034 + (case maybe_rhs of
9.1035 + SOME rhs =>
9.1036 + let
9.1037 + val raw_rhs = expand_corec_code_rhs lthy has_call bound_Ts rhs;
9.1038 + val cond_ctrs =
9.1039 + fold_rev_corec_code_rhs lthy (K oo (cons oo pair)) bound_Ts raw_rhs [];
9.1040 + val ctr_thms = map (the o AList.lookup (op =) ctr_alist o snd) cond_ctrs;
9.1041 + in SOME (rhs, raw_rhs, ctr_thms) end
9.1042 + | NONE =>
9.1043 + let
9.1044 + fun prove_code_ctr {ctr, sels, ...} =
9.1045 + if not (exists (equal ctr o fst) ctr_alist) then NONE else
9.1046 + let
9.1047 + val prems = find_first (equal ctr o #ctr) disc_eqns
9.1048 + |> Option.map #prems |> the_default [];
9.1049 + val t =
9.1050 + filter (equal ctr o #ctr) sel_eqns
9.1051 + |> fst o finds ((op =) o apsnd #sel) sels
9.1052 + |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x))
9.1053 + #-> abstract)
9.1054 + |> curry list_comb ctr;
9.1055 + in
9.1056 + SOME (prems, t)
9.1057 + end;
9.1058 + val maybe_ctr_conds_argss = map prove_code_ctr ctr_specs;
9.1059 + in
9.1060 + if exists is_none maybe_ctr_conds_argss then NONE else
9.1061 + let
9.1062 + val rhs = fold_rev (fn SOME (prems, u) => fn t => mk_If (s_conjs prems) u t)
9.1063 + maybe_ctr_conds_argss
9.1064 + (Const (@{const_name Code.abort}, @{typ String.literal} -->
9.1065 + (@{typ unit} --> body_type fun_T) --> body_type fun_T) $
9.1066 + HOLogic.mk_literal fun_name $
9.1067 + absdummy @{typ unit} (incr_boundvars 1 lhs));
9.1068 + in SOME (rhs, rhs, map snd ctr_alist) end
9.1069 + end);
9.1070 + in
9.1071 + (case maybe_rhs_info of
9.1072 + NONE => []
9.1073 + | SOME (rhs, raw_rhs, ctr_thms) =>
9.1074 + let
9.1075 + val ms = map (Logic.count_prems o prop_of) ctr_thms;
9.1076 + val (raw_t, t) = (raw_rhs, rhs)
9.1077 + |> pairself
9.1078 + (curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
9.1079 + map Bound (length fun_args - 1 downto 0)))
9.1080 + #> HOLogic.mk_Trueprop
9.1081 + #> curry Logic.list_all (map dest_Free fun_args));
9.1082 + val (distincts, discIs, sel_splits, sel_split_asms) =
9.1083 + case_thms_of_term lthy bound_Ts raw_rhs;
9.1084 +
9.1085 + val raw_code_thm = mk_primcorec_raw_code_of_ctr_tac lthy distincts discIs sel_splits
9.1086 + sel_split_asms ms ctr_thms
9.1087 + |> K |> Goal.prove lthy [] [] raw_t
9.1088 + |> Thm.close_derivation;
9.1089 + in
9.1090 + mk_primcorec_code_of_raw_code_tac lthy distincts sel_splits raw_code_thm
9.1091 + |> K |> Goal.prove lthy [] [] t
9.1092 + |> Thm.close_derivation
9.1093 + |> single
9.1094 + end)
9.1095 + end;
9.1096 +
9.1097 + val disc_alists = map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss;
9.1098 + val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss exclssss sel_eqnss;
9.1099 + val disc_thmss = map (map snd) disc_alists;
9.1100 + val sel_thmss = map (map snd) sel_alists;
9.1101 +
9.1102 + val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss
9.1103 + ctr_specss;
9.1104 + val ctr_thmss = map (map snd) ctr_alists;
9.1105 +
9.1106 + val code_thmss = map4 prove_code disc_eqnss sel_eqnss ctr_alists ctr_specss;
9.1107 +
9.1108 + val simp_thmss = map2 append disc_thmss sel_thmss
9.1109 +
9.1110 + val common_name = mk_common_name fun_names;
9.1111 +
9.1112 + val notes =
9.1113 + [(coinductN, map (if n2m then single else K []) coinduct_thms, []),
9.1114 + (codeN, code_thmss, code_nitpicksimp_attrs),
9.1115 + (ctrN, ctr_thmss, []),
9.1116 + (discN, disc_thmss, simp_attrs),
9.1117 + (selN, sel_thmss, simp_attrs),
9.1118 + (simpsN, simp_thmss, []),
9.1119 + (strong_coinductN, map (if n2m then single else K []) strong_coinduct_thms, [])]
9.1120 + |> maps (fn (thmN, thmss, attrs) =>
9.1121 + map2 (fn fun_name => fn thms =>
9.1122 + ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])]))
9.1123 + fun_names (take actual_nn thmss))
9.1124 + |> filter_out (null o fst o hd o snd);
9.1125 +
9.1126 + val common_notes =
9.1127 + [(coinductN, if n2m then [coinduct_thm] else [], []),
9.1128 + (strong_coinductN, if n2m then [strong_coinduct_thm] else [], [])]
9.1129 + |> filter_out (null o #2)
9.1130 + |> map (fn (thmN, thms, attrs) =>
9.1131 + ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
9.1132 + in
9.1133 + lthy |> Local_Theory.notes (notes @ common_notes) |> snd
9.1134 + end;
9.1135 +
9.1136 + fun after_qed thmss' = fold_map Local_Theory.define defs #-> prove thmss';
9.1137 + in
9.1138 + (goalss, after_qed, lthy')
9.1139 + end;
9.1140 +
9.1141 +fun add_primcorec_ursive_cmd maybe_tac seq (raw_fixes, raw_specs') lthy =
9.1142 + let
9.1143 + val (raw_specs, maybe_of_specs) =
9.1144 + split_list raw_specs' ||> map (Option.map (Syntax.read_term lthy));
9.1145 + val ((fixes, specs), _) = Specification.read_spec raw_fixes raw_specs lthy;
9.1146 + in
9.1147 + add_primcorec_ursive maybe_tac seq fixes specs maybe_of_specs lthy
9.1148 + handle ERROR str => primcorec_error str
9.1149 + end
9.1150 + handle Primcorec_Error (str, eqns) =>
9.1151 + if null eqns
9.1152 + then error ("primcorec error:\n " ^ str)
9.1153 + else error ("primcorec error:\n " ^ str ^ "\nin\n " ^
9.1154 + space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns));
9.1155 +
9.1156 +val add_primcorecursive_cmd = (fn (goalss, after_qed, lthy) =>
9.1157 + lthy
9.1158 + |> Proof.theorem NONE after_qed goalss
9.1159 + |> Proof.refine (Method.primitive_text I)
9.1160 + |> Seq.hd) ooo add_primcorec_ursive_cmd NONE;
9.1161 +
9.1162 +val add_primcorec_cmd = (fn (goalss, after_qed, lthy) =>
9.1163 + lthy
9.1164 + |> after_qed (map (fn [] => []
9.1165 + | _ => primcorec_error "need exclusiveness proofs - use primcorecursive instead of primcorec")
9.1166 + goalss)) ooo add_primcorec_ursive_cmd (SOME (fn {context = ctxt, ...} => auto_tac ctxt));
9.1167 +
9.1168 +end;
10.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
10.2 +++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML Mon Nov 04 16:53:43 2013 +0100
10.3 @@ -0,0 +1,135 @@
10.4 +(* Title: HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML
10.5 + Author: Jasmin Blanchette, TU Muenchen
10.6 + Copyright 2013
10.7 +
10.8 +Tactics for corecursor sugar.
10.9 +*)
10.10 +
10.11 +signature BNF_GFP_REC_SUGAR_TACTICS =
10.12 +sig
10.13 + val mk_primcorec_assumption_tac: Proof.context -> thm list -> int -> tactic
10.14 + val mk_primcorec_code_of_raw_code_tac: Proof.context -> thm list -> thm list -> thm -> tactic
10.15 + val mk_primcorec_ctr_of_dtr_tac: Proof.context -> int -> thm -> thm option -> thm list -> tactic
10.16 + val mk_primcorec_disc_tac: Proof.context -> thm list -> thm -> int -> int -> thm list list list ->
10.17 + tactic
10.18 + val mk_primcorec_raw_code_of_ctr_tac: Proof.context -> thm list -> thm list -> thm list ->
10.19 + thm list -> int list -> thm list -> tactic
10.20 + val mk_primcorec_sel_tac: Proof.context -> thm list -> thm list -> thm list -> thm list ->
10.21 + thm list -> thm list -> thm list -> thm -> int -> int -> thm list list list -> tactic
10.22 +end;
10.23 +
10.24 +structure BNF_GFP_Rec_Sugar_Tactics : BNF_GFP_REC_SUGAR_TACTICS =
10.25 +struct
10.26 +
10.27 +open BNF_Util
10.28 +open BNF_Tactics
10.29 +
10.30 +val falseEs = @{thms not_TrueE FalseE};
10.31 +val Let_def = @{thm Let_def};
10.32 +val neq_eq_eq_contradict = @{thm neq_eq_eq_contradict};
10.33 +val split_if = @{thm split_if};
10.34 +val split_if_asm = @{thm split_if_asm};
10.35 +val split_connectI = @{thms allI impI conjI};
10.36 +
10.37 +fun mk_primcorec_assumption_tac ctxt discIs =
10.38 + SELECT_GOAL (unfold_thms_tac ctxt
10.39 + @{thms not_not not_False_eq_True not_True_eq_False de_Morgan_conj de_Morgan_disj} THEN
10.40 + SOLVE (HEADGOAL (REPEAT o (rtac refl ORELSE' atac ORELSE' etac conjE ORELSE'
10.41 + eresolve_tac falseEs ORELSE'
10.42 + resolve_tac @{thms TrueI conjI disjI1 disjI2} ORELSE'
10.43 + dresolve_tac discIs THEN' atac ORELSE'
10.44 + etac notE THEN' atac ORELSE'
10.45 + etac disjE))));
10.46 +
10.47 +fun mk_primcorec_same_case_tac m =
10.48 + HEADGOAL (if m = 0 then rtac TrueI
10.49 + else REPEAT_DETERM_N (m - 1) o (rtac conjI THEN' atac) THEN' atac);
10.50 +
10.51 +fun mk_primcorec_different_case_tac ctxt m excl =
10.52 + HEADGOAL (if m = 0 then mk_primcorec_assumption_tac ctxt []
10.53 + else dtac excl THEN' (REPEAT_DETERM_N (m - 1) o atac) THEN' mk_primcorec_assumption_tac ctxt []);
10.54 +
10.55 +fun mk_primcorec_cases_tac ctxt k m exclsss =
10.56 + let val n = length exclsss in
10.57 + EVERY (map (fn [] => if k = n then all_tac else mk_primcorec_same_case_tac m
10.58 + | [excl] => mk_primcorec_different_case_tac ctxt m excl)
10.59 + (take k (nth exclsss (k - 1))))
10.60 + end;
10.61 +
10.62 +fun mk_primcorec_prelude ctxt defs thm =
10.63 + unfold_thms_tac ctxt defs THEN HEADGOAL (rtac thm) THEN
10.64 + unfold_thms_tac ctxt @{thms Let_def split};
10.65 +
10.66 +fun mk_primcorec_disc_tac ctxt defs disc_corec k m exclsss =
10.67 + mk_primcorec_prelude ctxt defs disc_corec THEN mk_primcorec_cases_tac ctxt k m exclsss;
10.68 +
10.69 +fun mk_primcorec_sel_tac ctxt defs distincts splits split_asms maps map_idents map_comps f_sel k m
10.70 + exclsss =
10.71 + mk_primcorec_prelude ctxt defs (f_sel RS trans) THEN
10.72 + mk_primcorec_cases_tac ctxt k m exclsss THEN
10.73 + HEADGOAL (REPEAT_DETERM o (rtac refl ORELSE' rtac ext ORELSE'
10.74 + eresolve_tac falseEs ORELSE'
10.75 + resolve_tac split_connectI ORELSE'
10.76 + Splitter.split_asm_tac (split_if_asm :: split_asms) ORELSE'
10.77 + Splitter.split_tac (split_if :: splits) ORELSE'
10.78 + eresolve_tac (map (fn thm => thm RS neq_eq_eq_contradict) distincts) THEN' atac ORELSE'
10.79 + etac notE THEN' atac ORELSE'
10.80 + (CHANGED o SELECT_GOAL (unfold_thms_tac ctxt
10.81 + (@{thms id_def o_def split_def sum.cases} @ maps @ map_comps @ map_idents)))));
10.82 +
10.83 +fun mk_primcorec_ctr_of_dtr_tac ctxt m collapse maybe_disc_f sel_fs =
10.84 + HEADGOAL (rtac ((if null sel_fs then collapse else collapse RS sym) RS trans) THEN'
10.85 + (the_default (K all_tac) (Option.map rtac maybe_disc_f)) THEN' REPEAT_DETERM_N m o atac) THEN
10.86 + unfold_thms_tac ctxt (Let_def :: sel_fs) THEN HEADGOAL (rtac refl);
10.87 +
10.88 +fun inst_split_eq ctxt split =
10.89 + (case prop_of split of
10.90 + @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ (Var (_, Type (_, [T, _])) $ _) $ _) =>
10.91 + let
10.92 + val s = Name.uu;
10.93 + val eq = Abs (Name.uu, T, HOLogic.mk_eq (Free (s, T), Bound 0));
10.94 + val split' = Drule.instantiate' [] [SOME (certify ctxt eq)] split;
10.95 + in
10.96 + Thm.generalize ([], [s]) (Thm.maxidx_of split' + 1) split'
10.97 + end
10.98 + | _ => split);
10.99 +
10.100 +fun distinct_in_prems_tac distincts =
10.101 + eresolve_tac (map (fn thm => thm RS neq_eq_eq_contradict) distincts) THEN' atac;
10.102 +
10.103 +(* TODO: reduce code duplication with selector tactic above *)
10.104 +fun mk_primcorec_raw_code_of_ctr_single_tac ctxt distincts discIs splits split_asms m f_ctr =
10.105 + let
10.106 + val splits' =
10.107 + map (fn th => th RS iffD2) (@{thm split_if_eq2} :: map (inst_split_eq ctxt) splits)
10.108 + in
10.109 + HEADGOAL (REPEAT o (resolve_tac (splits' @ split_connectI))) THEN
10.110 + mk_primcorec_prelude ctxt [] (f_ctr RS trans) THEN
10.111 + HEADGOAL ((REPEAT_DETERM_N m o mk_primcorec_assumption_tac ctxt discIs) THEN'
10.112 + SELECT_GOAL (SOLVE (HEADGOAL (REPEAT_DETERM o
10.113 + (rtac refl ORELSE' atac ORELSE'
10.114 + resolve_tac (@{thm Code.abort_def} :: split_connectI) ORELSE'
10.115 + Splitter.split_tac (split_if :: splits) ORELSE'
10.116 + Splitter.split_asm_tac (split_if_asm :: split_asms) ORELSE'
10.117 + mk_primcorec_assumption_tac ctxt discIs ORELSE'
10.118 + distinct_in_prems_tac distincts ORELSE'
10.119 + (TRY o dresolve_tac discIs) THEN' etac notE THEN' atac)))))
10.120 + end;
10.121 +
10.122 +fun mk_primcorec_raw_code_of_ctr_tac ctxt distincts discIs splits split_asms ms f_ctrs =
10.123 + EVERY (map2 (mk_primcorec_raw_code_of_ctr_single_tac ctxt distincts discIs splits split_asms) ms
10.124 + f_ctrs) THEN
10.125 + IF_UNSOLVED (unfold_thms_tac ctxt @{thms Code.abort_def} THEN
10.126 + HEADGOAL (REPEAT_DETERM o resolve_tac (refl :: split_connectI)));
10.127 +
10.128 +fun mk_primcorec_code_of_raw_code_tac ctxt distincts splits raw =
10.129 + HEADGOAL (rtac raw ORELSE' rtac (raw RS trans) THEN'
10.130 + SELECT_GOAL (unfold_thms_tac ctxt [Let_def]) THEN' REPEAT_DETERM o
10.131 + (rtac refl ORELSE' atac ORELSE'
10.132 + resolve_tac split_connectI ORELSE'
10.133 + Splitter.split_tac (split_if :: splits) ORELSE'
10.134 + distinct_in_prems_tac distincts ORELSE'
10.135 + rtac sym THEN' atac ORELSE'
10.136 + etac notE THEN' atac));
10.137 +
10.138 +end;
11.1 --- a/src/HOL/BNF/Tools/bnf_lfp.ML Mon Nov 04 15:44:43 2013 +0100
11.2 +++ b/src/HOL/BNF/Tools/bnf_lfp.ML Mon Nov 04 16:53:43 2013 +0100
11.3 @@ -22,7 +22,7 @@
11.4 open BNF_Comp
11.5 open BNF_FP_Util
11.6 open BNF_FP_Def_Sugar
11.7 -open BNF_FP_Rec_Sugar
11.8 +open BNF_LFP_Rec_Sugar
11.9 open BNF_LFP_Util
11.10 open BNF_LFP_Tactics
11.11
12.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
12.2 +++ b/src/HOL/BNF/Tools/bnf_lfp_rec_sugar.ML Mon Nov 04 16:53:43 2013 +0100
12.3 @@ -0,0 +1,598 @@
12.4 +(* Title: HOL/BNF/Tools/bnf_lfp_rec_sugar.ML
12.5 + Author: Lorenz Panny, TU Muenchen
12.6 + Author: Jasmin Blanchette, TU Muenchen
12.7 + Copyright 2013
12.8 +
12.9 +Recursor sugar.
12.10 +*)
12.11 +
12.12 +signature BNF_LFP_REC_SUGAR =
12.13 +sig
12.14 + val add_primrec: (binding * typ option * mixfix) list ->
12.15 + (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
12.16 + val add_primrec_cmd: (binding * string option * mixfix) list ->
12.17 + (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
12.18 + val add_primrec_global: (binding * typ option * mixfix) list ->
12.19 + (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
12.20 + val add_primrec_overloaded: (string * (string * typ) * bool) list ->
12.21 + (binding * typ option * mixfix) list ->
12.22 + (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
12.23 + val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
12.24 + local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
12.25 +end;
12.26 +
12.27 +structure BNF_LFP_Rec_Sugar : BNF_LFP_REC_SUGAR =
12.28 +struct
12.29 +
12.30 +open Ctr_Sugar
12.31 +open BNF_Util
12.32 +open BNF_Tactics
12.33 +open BNF_Def
12.34 +open BNF_FP_Util
12.35 +open BNF_FP_Def_Sugar
12.36 +open BNF_FP_N2M_Sugar
12.37 +open BNF_FP_Rec_Sugar_Util
12.38 +
12.39 +val nitpicksimp_attrs = @{attributes [nitpick_simp]};
12.40 +val simp_attrs = @{attributes [simp]};
12.41 +val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs;
12.42 +
12.43 +exception Primrec_Error of string * term list;
12.44 +
12.45 +fun primrec_error str = raise Primrec_Error (str, []);
12.46 +fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
12.47 +fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
12.48 +
12.49 +datatype rec_call =
12.50 + No_Rec of int * typ |
12.51 + Mutual_Rec of (int * typ) * (int * typ) |
12.52 + Nested_Rec of int * typ;
12.53 +
12.54 +type rec_ctr_spec =
12.55 + {ctr: term,
12.56 + offset: int,
12.57 + calls: rec_call list,
12.58 + rec_thm: thm};
12.59 +
12.60 +type rec_spec =
12.61 + {recx: term,
12.62 + nested_map_idents: thm list,
12.63 + nested_map_comps: thm list,
12.64 + ctr_specs: rec_ctr_spec list};
12.65 +
12.66 +exception AINT_NO_MAP of term;
12.67 +
12.68 +fun ill_formed_rec_call ctxt t =
12.69 + error ("Ill-formed recursive call: " ^ quote (Syntax.string_of_term ctxt t));
12.70 +fun invalid_map ctxt t =
12.71 + error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t));
12.72 +fun unexpected_rec_call ctxt t =
12.73 + error ("Unexpected recursive call: " ^ quote (Syntax.string_of_term ctxt t));
12.74 +
12.75 +fun massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y' =
12.76 + let
12.77 + fun check_no_call t = if has_call t then unexpected_rec_call ctxt t else ();
12.78 +
12.79 + val typof = curry fastype_of1 bound_Ts;
12.80 + val build_map_fst = build_map ctxt (fst_const o fst);
12.81 +
12.82 + val yT = typof y;
12.83 + val yU = typof y';
12.84 +
12.85 + fun y_of_y' () = build_map_fst (yU, yT) $ y';
12.86 + val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t);
12.87 +
12.88 + fun massage_mutual_fun U T t =
12.89 + (case t of
12.90 + Const (@{const_name comp}, _) $ t1 $ t2 =>
12.91 + mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2)
12.92 + | _ =>
12.93 + if has_call t then
12.94 + (case try HOLogic.dest_prodT U of
12.95 + SOME (U1, U2) => if U1 = T then raw_massage_fun T U2 t else invalid_map ctxt t
12.96 + | NONE => invalid_map ctxt t)
12.97 + else
12.98 + mk_comp bound_Ts (t, build_map_fst (U, T)));
12.99 +
12.100 + fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
12.101 + (case try (dest_map ctxt s) t of
12.102 + SOME (map0, fs) =>
12.103 + let
12.104 + val Type (_, ran_Ts) = range_type (typof t);
12.105 + val map' = mk_map (length fs) Us ran_Ts map0;
12.106 + val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
12.107 + in
12.108 + Term.list_comb (map', fs')
12.109 + end
12.110 + | NONE => raise AINT_NO_MAP t)
12.111 + | massage_map _ _ t = raise AINT_NO_MAP t
12.112 + and massage_map_or_map_arg U T t =
12.113 + if T = U then
12.114 + tap check_no_call t
12.115 + else
12.116 + massage_map U T t
12.117 + handle AINT_NO_MAP _ => massage_mutual_fun U T t;
12.118 +
12.119 + fun massage_call (t as t1 $ t2) =
12.120 + if has_call t then
12.121 + if t2 = y then
12.122 + massage_map yU yT (elim_y t1) $ y'
12.123 + handle AINT_NO_MAP t' => invalid_map ctxt t'
12.124 + else
12.125 + let val (g, xs) = Term.strip_comb t2 in
12.126 + if g = y then
12.127 + if exists has_call xs then unexpected_rec_call ctxt t2
12.128 + else Term.list_comb (massage_call (mk_compN (length xs) bound_Ts (t1, y)), xs)
12.129 + else
12.130 + ill_formed_rec_call ctxt t
12.131 + end
12.132 + else
12.133 + elim_y t
12.134 + | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
12.135 + in
12.136 + massage_call
12.137 + end;
12.138 +
12.139 +fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
12.140 + let
12.141 + val thy = Proof_Context.theory_of lthy;
12.142 +
12.143 + val ((missing_arg_Ts, perm0_kks,
12.144 + fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = ctor_iters1 :: _, ...},
12.145 + co_inducts = [induct_thm], ...} :: _, (lfp_sugar_thms, _)), lthy') =
12.146 + nested_to_mutual_fps Least_FP bs arg_Ts get_indices callssss0 lthy;
12.147 +
12.148 + val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;
12.149 +
12.150 + val indices = map #index fp_sugars;
12.151 + val perm_indices = map #index perm_fp_sugars;
12.152 +
12.153 + val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
12.154 + val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
12.155 + val perm_lfpTs = map (body_type o fastype_of o hd) perm_ctrss;
12.156 +
12.157 + val nn0 = length arg_Ts;
12.158 + val nn = length perm_lfpTs;
12.159 + val kks = 0 upto nn - 1;
12.160 + val perm_ns = map length perm_ctr_Tsss;
12.161 + val perm_mss = map (map length) perm_ctr_Tsss;
12.162 +
12.163 + val perm_Cs = map (body_type o fastype_of o co_rec_of o of_fp_sugar (#xtor_co_iterss o #fp_res))
12.164 + perm_fp_sugars;
12.165 + val perm_fun_arg_Tssss =
12.166 + mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1);
12.167 +
12.168 + fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
12.169 + fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
12.170 +
12.171 + val induct_thms = unpermute0 (conj_dests nn induct_thm);
12.172 +
12.173 + val lfpTs = unpermute perm_lfpTs;
12.174 + val Cs = unpermute perm_Cs;
12.175 +
12.176 + val As_rho = tvar_subst thy (take nn0 lfpTs) arg_Ts;
12.177 + val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;
12.178 +
12.179 + val substA = Term.subst_TVars As_rho;
12.180 + val substAT = Term.typ_subst_TVars As_rho;
12.181 + val substCT = Term.typ_subst_TVars Cs_rho;
12.182 + val substACT = substAT o substCT;
12.183 +
12.184 + val perm_Cs' = map substCT perm_Cs;
12.185 +
12.186 + fun offset_of_ctr 0 _ = 0
12.187 + | offset_of_ctr n (({ctrs, ...} : ctr_sugar) :: ctr_sugars) =
12.188 + length ctrs + offset_of_ctr (n - 1) ctr_sugars;
12.189 +
12.190 + fun call_of [i] [T] = (if exists_subtype_in Cs T then Nested_Rec else No_Rec) (i, substACT T)
12.191 + | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T'));
12.192 +
12.193 + fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
12.194 + let
12.195 + val (fun_arg_hss, _) = indexedd fun_arg_Tss 0;
12.196 + val fun_arg_hs = flat_rec_arg_args fun_arg_hss;
12.197 + val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss;
12.198 + in
12.199 + {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss,
12.200 + rec_thm = rec_thm}
12.201 + end;
12.202 +
12.203 + fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) iter_thmsss =
12.204 + let
12.205 + val ctrs = #ctrs (nth ctr_sugars index);
12.206 + val rec_thmss = co_rec_of (nth iter_thmsss index);
12.207 + val k = offset_of_ctr index ctr_sugars;
12.208 + val n = length ctrs;
12.209 + in
12.210 + map4 mk_ctr_spec ctrs (k upto k + n - 1) (nth perm_fun_arg_Tssss index) rec_thmss
12.211 + end;
12.212 +
12.213 + fun mk_spec ({T, index, ctr_sugars, co_iterss = iterss, co_iter_thmsss = iter_thmsss, ...}
12.214 + : fp_sugar) =
12.215 + {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' (co_rec_of (nth iterss index)),
12.216 + nested_map_idents = map (unfold_thms lthy @{thms id_def} o map_id0_of_bnf) nested_bnfs,
12.217 + nested_map_comps = map map_comp_of_bnf nested_bnfs,
12.218 + ctr_specs = mk_ctr_specs index ctr_sugars iter_thmsss};
12.219 + in
12.220 + ((is_some lfp_sugar_thms, map mk_spec fp_sugars, missing_arg_Ts, induct_thm, induct_thms),
12.221 + lthy')
12.222 + end;
12.223 +
12.224 +val undef_const = Const (@{const_name undefined}, dummyT);
12.225 +
12.226 +fun permute_args n t =
12.227 + list_comb (t, map Bound (0 :: (n downto 1))) |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n);
12.228 +
12.229 +type eqn_data = {
12.230 + fun_name: string,
12.231 + rec_type: typ,
12.232 + ctr: term,
12.233 + ctr_args: term list,
12.234 + left_args: term list,
12.235 + right_args: term list,
12.236 + res_type: typ,
12.237 + rhs_term: term,
12.238 + user_eqn: term
12.239 +};
12.240 +
12.241 +fun dissect_eqn lthy fun_names eqn' =
12.242 + let
12.243 + val eqn = drop_All eqn' |> HOLogic.dest_Trueprop
12.244 + handle TERM _ =>
12.245 + primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
12.246 + val (lhs, rhs) = HOLogic.dest_eq eqn
12.247 + handle TERM _ =>
12.248 + primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
12.249 + val (fun_name, args) = strip_comb lhs
12.250 + |>> (fn x => if is_Free x then fst (dest_Free x)
12.251 + else primrec_error_eqn "malformed function equation (does not start with free)" eqn);
12.252 + val (left_args, rest) = take_prefix is_Free args;
12.253 + val (nonfrees, right_args) = take_suffix is_Free rest;
12.254 + val num_nonfrees = length nonfrees;
12.255 + val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then
12.256 + primrec_error_eqn "constructor pattern missing in left-hand side" eqn else
12.257 + primrec_error_eqn "more than one non-variable argument in left-hand side" eqn;
12.258 + val _ = member (op =) fun_names fun_name orelse
12.259 + primrec_error_eqn "malformed function equation (does not start with function name)" eqn
12.260 +
12.261 + val (ctr, ctr_args) = strip_comb (the_single nonfrees);
12.262 + val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
12.263 + primrec_error_eqn "partially applied constructor in pattern" eqn;
12.264 + val _ = let val d = duplicates (op =) (left_args @ ctr_args @ right_args) in null d orelse
12.265 + primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
12.266 + "\" in left-hand side") eqn end;
12.267 + val _ = forall is_Free ctr_args orelse
12.268 + primrec_error_eqn "non-primitive pattern in left-hand side" eqn;
12.269 + val _ =
12.270 + let val b = fold_aterms (fn x as Free (v, _) =>
12.271 + if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso
12.272 + not (member (op =) fun_names v) andalso
12.273 + not (Variable.is_fixed lthy v)) then cons x else I | _ => I) rhs []
12.274 + in
12.275 + null b orelse
12.276 + primrec_error_eqn ("extra variable(s) in right-hand side: " ^
12.277 + commas (map (Syntax.string_of_term lthy) b)) eqn
12.278 + end;
12.279 + in
12.280 + {fun_name = fun_name,
12.281 + rec_type = body_type (type_of ctr),
12.282 + ctr = ctr,
12.283 + ctr_args = ctr_args,
12.284 + left_args = left_args,
12.285 + right_args = right_args,
12.286 + res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
12.287 + rhs_term = rhs,
12.288 + user_eqn = eqn'}
12.289 + end;
12.290 +
12.291 +fun rewrite_map_arg get_ctr_pos rec_type res_type =
12.292 + let
12.293 + val pT = HOLogic.mk_prodT (rec_type, res_type);
12.294 +
12.295 + val maybe_suc = Option.map (fn x => x + 1);
12.296 + fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
12.297 + | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b)
12.298 + | subst d t =
12.299 + let
12.300 + val (u, vs) = strip_comb t;
12.301 + val ctr_pos = try (get_ctr_pos o fst o dest_Free) u |> the_default ~1;
12.302 + in
12.303 + if ctr_pos >= 0 then
12.304 + if d = SOME ~1 andalso length vs = ctr_pos then
12.305 + list_comb (permute_args ctr_pos (snd_const pT), vs)
12.306 + else if length vs > ctr_pos andalso is_some d
12.307 + andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
12.308 + list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
12.309 + else
12.310 + primrec_error_eqn ("recursive call not directly applied to constructor argument") t
12.311 + else
12.312 + list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
12.313 + end
12.314 + in
12.315 + subst (SOME ~1)
12.316 + end;
12.317 +
12.318 +fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls =
12.319 + let
12.320 + fun try_nested_rec bound_Ts y t =
12.321 + AList.lookup (op =) nested_calls y
12.322 + |> Option.map (fn y' =>
12.323 + massage_nested_rec_call lthy has_call (rewrite_map_arg get_ctr_pos) bound_Ts y y' t);
12.324 +
12.325 + fun subst bound_Ts (t as g' $ y) =
12.326 + let
12.327 + fun subst_rec () = subst bound_Ts g' $ subst bound_Ts y;
12.328 + val y_head = head_of y;
12.329 + in
12.330 + if not (member (op =) ctr_args y_head) then
12.331 + subst_rec ()
12.332 + else
12.333 + (case try_nested_rec bound_Ts y_head t of
12.334 + SOME t' => t'
12.335 + | NONE =>
12.336 + let val (g, g_args) = strip_comb g' in
12.337 + (case try (get_ctr_pos o fst o dest_Free) g of
12.338 + SOME ctr_pos =>
12.339 + (length g_args >= ctr_pos orelse
12.340 + primrec_error_eqn "too few arguments in recursive call" t;
12.341 + (case AList.lookup (op =) mutual_calls y of
12.342 + SOME y' => list_comb (y', g_args)
12.343 + | NONE => subst_rec ()))
12.344 + | NONE => subst_rec ())
12.345 + end)
12.346 + end
12.347 + | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
12.348 + | subst _ t = t
12.349 +
12.350 + fun subst' t =
12.351 + if has_call t then
12.352 + (* FIXME detect this case earlier? *)
12.353 + primrec_error_eqn "recursive call not directly applied to constructor argument" t
12.354 + else
12.355 + try_nested_rec [] (head_of t) t |> the_default t
12.356 + in
12.357 + subst' o subst []
12.358 + end;
12.359 +
12.360 +fun build_rec_arg lthy (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec)
12.361 + (maybe_eqn_data : eqn_data option) =
12.362 + (case maybe_eqn_data of
12.363 + NONE => undef_const
12.364 + | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} =>
12.365 + let
12.366 + val calls = #calls ctr_spec;
12.367 + val n_args = fold (Integer.add o (fn Mutual_Rec _ => 2 | _ => 1)) calls 0;
12.368 +
12.369 + val no_calls' = tag_list 0 calls
12.370 + |> map_filter (try (apsnd (fn No_Rec p => p | Mutual_Rec (p, _) => p)));
12.371 + val mutual_calls' = tag_list 0 calls
12.372 + |> map_filter (try (apsnd (fn Mutual_Rec (_, p) => p)));
12.373 + val nested_calls' = tag_list 0 calls
12.374 + |> map_filter (try (apsnd (fn Nested_Rec p => p)));
12.375 +
12.376 + val args = replicate n_args ("", dummyT)
12.377 + |> Term.rename_wrt_term t
12.378 + |> map Free
12.379 + |> fold (fn (ctr_arg_idx, (arg_idx, _)) =>
12.380 + nth_map arg_idx (K (nth ctr_args ctr_arg_idx)))
12.381 + no_calls'
12.382 + |> fold (fn (ctr_arg_idx, (arg_idx, T)) =>
12.383 + nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx))))
12.384 + mutual_calls'
12.385 + |> fold (fn (ctr_arg_idx, (arg_idx, T)) =>
12.386 + nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx))))
12.387 + nested_calls';
12.388 +
12.389 + val fun_name_ctr_pos_list =
12.390 + map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
12.391 + val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
12.392 + val mutual_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) mutual_calls';
12.393 + val nested_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) nested_calls';
12.394 + in
12.395 + t
12.396 + |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls
12.397 + |> fold_rev lambda (args @ left_args @ right_args)
12.398 + end);
12.399 +
12.400 +fun build_defs lthy bs mxs (funs_data : eqn_data list list) (rec_specs : rec_spec list) has_call =
12.401 + let
12.402 + val n_funs = length funs_data;
12.403 +
12.404 + val ctr_spec_eqn_data_list' =
12.405 + (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
12.406 + |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
12.407 + ##> (fn x => null x orelse
12.408 + primrec_error_eqns "excess equations in definition" (map #rhs_term x)) #> fst);
12.409 + val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
12.410 + primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x));
12.411 +
12.412 + val ctr_spec_eqn_data_list =
12.413 + ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
12.414 +
12.415 + val recs = take n_funs rec_specs |> map #recx;
12.416 + val rec_args = ctr_spec_eqn_data_list
12.417 + |> sort ((op <) o pairself (#offset o fst) |> make_ord)
12.418 + |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
12.419 + val ctr_poss = map (fn x =>
12.420 + if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
12.421 + primrec_error ("inconstant constructor pattern position for function " ^
12.422 + quote (#fun_name (hd x)))
12.423 + else
12.424 + hd x |> #left_args |> length) funs_data;
12.425 + in
12.426 + (recs, ctr_poss)
12.427 + |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
12.428 + |> Syntax.check_terms lthy
12.429 + |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t)))
12.430 + bs mxs
12.431 + end;
12.432 +
12.433 +fun find_rec_calls has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
12.434 + let
12.435 + fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
12.436 + | find bound_Ts (t as _ $ _) ctr_arg =
12.437 + let
12.438 + val typof = curry fastype_of1 bound_Ts;
12.439 + val (f', args') = strip_comb t;
12.440 + val n = find_index (equal ctr_arg o head_of) args';
12.441 + in
12.442 + if n < 0 then
12.443 + find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args'
12.444 + else
12.445 + let
12.446 + val (f, args as arg :: _) = chop n args' |>> curry list_comb f'
12.447 + val (arg_head, arg_args) = Term.strip_comb arg;
12.448 + in
12.449 + if has_call f then
12.450 + mk_partial_compN (length arg_args) (typof arg_head) f ::
12.451 + maps (fn x => find bound_Ts x ctr_arg) args
12.452 + else
12.453 + find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args
12.454 + end
12.455 + end
12.456 + | find _ _ _ = [];
12.457 + in
12.458 + map (find [] rhs_term) ctr_args
12.459 + |> (fn [] => NONE | callss => SOME (ctr, callss))
12.460 + end;
12.461 +
12.462 +fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx =
12.463 + unfold_thms_tac ctxt fun_defs THEN
12.464 + HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
12.465 + unfold_thms_tac ctxt (@{thms id_def split o_def fst_conv snd_conv} @ map_comps @ map_idents) THEN
12.466 + HEADGOAL (rtac refl);
12.467 +
12.468 +fun prepare_primrec fixes specs lthy =
12.469 + let
12.470 + val (bs, mxs) = map_split (apfst fst) fixes;
12.471 + val fun_names = map Binding.name_of bs;
12.472 + val eqns_data = map (dissect_eqn lthy fun_names) specs;
12.473 + val funs_data = eqns_data
12.474 + |> partition_eq ((op =) o pairself #fun_name)
12.475 + |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
12.476 + |> map (fn (x, y) => the_single y handle List.Empty =>
12.477 + primrec_error ("missing equations for function " ^ quote x));
12.478 +
12.479 + val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
12.480 + val arg_Ts = map (#rec_type o hd) funs_data;
12.481 + val res_Ts = map (#res_type o hd) funs_data;
12.482 + val callssss = funs_data
12.483 + |> map (partition_eq ((op =) o pairself #ctr))
12.484 + |> map (maps (map_filter (find_rec_calls has_call)));
12.485 +
12.486 + val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') =
12.487 + rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
12.488 +
12.489 + val actual_nn = length funs_data;
12.490 +
12.491 + val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
12.492 + map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
12.493 + primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
12.494 + " is not a constructor in left-hand side") user_eqn) eqns_data end;
12.495 +
12.496 + val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
12.497 +
12.498 + fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
12.499 + (fun_data : eqn_data list) =
12.500 + let
12.501 + val def_thms = map (snd o snd) def_thms';
12.502 + val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
12.503 + |> fst
12.504 + |> map_filter (try (fn (x, [y]) =>
12.505 + (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
12.506 + |> map (fn (user_eqn, num_extra_args, rec_thm) =>
12.507 + mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
12.508 + |> K |> Goal.prove lthy [] [] user_eqn
12.509 + |> Thm.close_derivation);
12.510 + val poss = find_indices (fn (x, y) => #ctr x = #ctr y) fun_data eqns_data;
12.511 + in
12.512 + (poss, simp_thmss)
12.513 + end;
12.514 +
12.515 + val notes =
12.516 + (if n2m then map2 (fn name => fn thm =>
12.517 + (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms) else [])
12.518 + |> map (fn (prefix, thmN, thms, attrs) =>
12.519 + ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])]));
12.520 +
12.521 + val common_name = mk_common_name fun_names;
12.522 +
12.523 + val common_notes =
12.524 + (if n2m then [(inductN, [induct_thm], [])] else [])
12.525 + |> map (fn (thmN, thms, attrs) =>
12.526 + ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
12.527 + in
12.528 + (((fun_names, defs),
12.529 + fn lthy => fn defs =>
12.530 + split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
12.531 + lthy' |> Local_Theory.notes (notes @ common_notes) |> snd)
12.532 + end;
12.533 +
12.534 +(* primrec definition *)
12.535 +
12.536 +fun add_primrec_simple fixes ts lthy =
12.537 + let
12.538 + val (((names, defs), prove), lthy) = prepare_primrec fixes ts lthy
12.539 + handle ERROR str => primrec_error str;
12.540 + in
12.541 + lthy
12.542 + |> fold_map Local_Theory.define defs
12.543 + |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
12.544 + end
12.545 + handle Primrec_Error (str, eqns) =>
12.546 + if null eqns
12.547 + then error ("primrec_new error:\n " ^ str)
12.548 + else error ("primrec_new error:\n " ^ str ^ "\nin\n " ^
12.549 + space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns));
12.550 +
12.551 +local
12.552 +
12.553 +fun gen_primrec prep_spec (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy =
12.554 + let
12.555 + val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
12.556 + val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
12.557 +
12.558 + val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
12.559 +
12.560 + val mk_notes =
12.561 + flat ooo map3 (fn poss => fn prefix => fn thms =>
12.562 + let
12.563 + val (bs, attrss) = map_split (fst o nth specs) poss;
12.564 + val notes =
12.565 + map3 (fn b => fn attrs => fn thm =>
12.566 + ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs), [([thm], [])]))
12.567 + bs attrss thms;
12.568 + in
12.569 + ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
12.570 + end);
12.571 + in
12.572 + lthy
12.573 + |> add_primrec_simple fixes (map snd specs)
12.574 + |-> (fn (names, (ts, (posss, simpss))) =>
12.575 + Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
12.576 + #> Local_Theory.notes (mk_notes posss names simpss)
12.577 + #>> pair ts o map snd)
12.578 + end;
12.579 +
12.580 +in
12.581 +
12.582 +val add_primrec = gen_primrec Specification.check_spec;
12.583 +val add_primrec_cmd = gen_primrec Specification.read_spec;
12.584 +
12.585 +end;
12.586 +
12.587 +fun add_primrec_global fixes specs thy =
12.588 + let
12.589 + val lthy = Named_Target.theory_init thy;
12.590 + val ((ts, simps), lthy') = add_primrec fixes specs lthy;
12.591 + val simps' = burrow (Proof_Context.export lthy' lthy) simps;
12.592 + in ((ts, simps'), Local_Theory.exit_global lthy') end;
12.593 +
12.594 +fun add_primrec_overloaded ops fixes specs thy =
12.595 + let
12.596 + val lthy = Overloading.overloading ops thy;
12.597 + val ((ts, simps), lthy') = add_primrec fixes specs lthy;
12.598 + val simps' = burrow (Proof_Context.export lthy' lthy) simps;
12.599 + in ((ts, simps'), Local_Theory.exit_global lthy') end;
12.600 +
12.601 +end;