(*  Title:      HOL/BNF/Tools/bnf_lfp_rec_sugar.ML
    Author:     Lorenz Panny, TU Muenchen
    Author:     Jasmin Blanchette, TU Muenchen
    Copyright   2013

Recursor sugar: the "primrec" command for BNF-based datatypes.
*)

(* Public interface of primitive recursion ("primrec") over least fixed point datatypes. *)
signature BNF_LFP_REC_SUGAR =
sig
  (* Defines functions by primitive recursion from already checked fixes and equations;
     returns the defined terms together with the theorem lists of the generated notes. *)
  val add_primrec: (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
  (* As add_primrec, but reads the fixes and equations from strings. *)
  val add_primrec_cmd: (binding * string option * mixfix) list ->
    (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
  (* As add_primrec, on a global theory. *)
  val add_primrec_global: (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
  (* As add_primrec_global, inside an overloading target for the given operations. *)
  val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
  (* Core definition step (no notes, no Spec_Rules registration): returns the function
     names, the defined terms, and per function the positions of its equations among the
     input equations together with the proved equations. *)
  val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
    local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
end;

structure BNF_LFP_Rec_Sugar : BNF_LFP_REC_SUGAR =
|
blanchet@55698
|
25 |
struct
|
blanchet@55698
|
26 |
|
blanchet@55698
|
27 |
open Ctr_Sugar
|
blanchet@55698
|
28 |
open BNF_Util
|
blanchet@55698
|
29 |
open BNF_Tactics
|
blanchet@55698
|
30 |
open BNF_Def
|
blanchet@55698
|
31 |
open BNF_FP_Util
|
blanchet@55698
|
32 |
open BNF_FP_Def_Sugar
|
blanchet@55698
|
33 |
open BNF_FP_N2M_Sugar
|
blanchet@55698
|
34 |
open BNF_FP_Rec_Sugar_Util
|
blanchet@55698
|
35 |
|
blanchet@55698
|
36 |
(* Attributes attached to every generated equation (see gen_primrec): default code equation,
   nitpick_simp, and simp. *)
val nitpicksimp_attrs = @{attributes [nitpick_simp]};
val simp_attrs = @{attributes [simp]};
val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs;

(* User-level error in a primrec specification: a message and the offending equations (possibly
   none). add_primrec_simple turns it into a formatted "primrec_new error". *)
exception Primrec_Error of string * term list;

(* Raise Primrec_Error without, with one, or with several offending equations. *)
fun primrec_error str = raise Primrec_Error (str, []);
fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);

(* Classification of one constructor argument, computed by call_of in rec_specs_of:
   - No_Rec: the argument type mentions none of the recursor's result type variables;
   - Mutual_Rec: a directly (mutually) recursive argument; the recursor passes both the
     argument itself (first pair) and its recursive result (second pair);
   - Nested_Rec: a single argument whose type mentions a result type variable, i.e. a
     recursive occurrence nested inside another type constructor.
   Each int is the position of the corresponding argument of the recursor's function
   argument for this constructor; each typ is the instantiated argument type. *)
datatype rec_call =
  No_Rec of int * typ |
  Mutual_Rec of (int * typ) * (int * typ) |
  Nested_Rec of int * typ;

(* Per-constructor recursor data:
   - ctr: the constructor, with the argument type variables instantiated;
   - offset: position of the constructor among all constructors of the (mutual) group, used
     to order the recursor arguments;
   - calls: classification of each constructor argument;
   - rec_thm: the recursor's characteristic theorem for this constructor. *)
type rec_ctr_spec =
  {ctr: term,
   offset: int,
   calls: rec_call list,
   rec_thm: thm};

(* Per-datatype recursor data:
   - recx: the instantiated recursor constant;
   - nested_map_idents / nested_map_comps: map identity and composition theorems of the
     nested BNFs, used by mk_primrec_tac;
   - ctr_specs: one entry per constructor. *)
type rec_spec =
  {recx: term,
   nested_map_idents: thm list,
   nested_map_comps: thm list,
   ctr_specs: rec_ctr_spec list};

(* Internal to massage_nested_rec_call: the given term is not an application of a map function
   for the expected type (dest_map failed, or the types are not type constructors). *)
exception AINT_NO_MAP of term;

(* Error reporters for nested recursive calls; each calls error with the printed term. *)
fun ill_formed_rec_call ctxt t =
  error ("Ill-formed recursive call: " ^ quote (Syntax.string_of_term ctxt t));
fun invalid_map ctxt t =
  error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t));
fun unexpected_rec_call ctxt t =
  error ("Unexpected recursive call: " ^ quote (Syntax.string_of_term ctxt t));

(* Rewrites a term that applies recursive calls to the nested constructor argument y so that it
   operates on y' instead. Judging from y_of_y' and massage_mutual_fun, y' has the shape of y
   except that recursive positions hold pairs whose first component is the original value
   (recovered by mapping fst) -- TODO confirm against callers.
   - ctxt: context for map lookups and error messages;
   - has_call: tests whether a term contains a call to a function being defined;
   - raw_massage_fun T U t: rewrites a map argument t that contains recursive calls, where
     T * U is the corresponding pair type;
   - bound_Ts: types of the loose bound variables of the rewritten term.
   Returns the rewriting function. Calls error via ill_formed_rec_call, invalid_map, or
   unexpected_rec_call when the term does not fit this scheme. *)
fun massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y' =
  let
    fun check_no_call t = if has_call t then unexpected_rec_call ctxt t else ();

    val typof = curry fastype_of1 bound_Ts;
    val build_map_fst = build_map ctxt (fst_const o fst);

    val yT = typof y;
    val yU = typof y';

    (* y expressed in terms of y': map fst over y' *)
    fun y_of_y' () = build_map_fst (yU, yT) $ y';
    val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t);

    (* Handles a map argument that is not itself a map: compositions are traversed, functions
       with recursive calls are rewritten by raw_massage_fun (only if U is a pair type whose
       first component is T), and call-free functions are composed with a fst map. *)
    fun massage_mutual_fun U T t =
      (case t of
        Const (@{const_name comp}, _) $ t1 $ t2 =>
        mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2)
      | _ =>
        if has_call t then
          (case try HOLogic.dest_prodT U of
            SOME (U1, U2) => if U1 = T then raw_massage_fun T U2 t else invalid_map ctxt t
          | NONE => invalid_map ctxt t)
        else
          mk_comp bound_Ts (t, build_map_fst (U, T)));

    (* Re-instantiates a map function from type T to type U, rewriting its arguments. *)
    fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
        (case try (dest_map ctxt s) t of
          SOME (map0, fs) =>
          let
            val Type (_, ran_Ts) = range_type (typof t);
            val map' = mk_map (length fs) Us ran_Ts map0;
            val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
          in
            Term.list_comb (map', fs')
          end
        | NONE => raise AINT_NO_MAP t)
      | massage_map _ _ t = raise AINT_NO_MAP t
    and massage_map_or_map_arg U T t =
      if T = U then
        tap check_no_call t
      else
        massage_map U T t
        handle AINT_NO_MAP _ => massage_mutual_fun U T t;

    fun massage_call (t as t1 $ t2) =
        if has_call t then
          if t2 = y then
            massage_map yU yT (elim_y t1) $ y'
            handle AINT_NO_MAP t' => invalid_map ctxt t'
          else
            let val (g, xs) = Term.strip_comb t2 in
              if g = y then
                if exists has_call xs then unexpected_rec_call ctxt t2
                else Term.list_comb (massage_call (mk_compN (length xs) bound_Ts (t1, y)), xs)
              else
                ill_formed_rec_call ctxt t
            end
        else
          elim_y t
      | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
  in
    massage_call
  end;

(* Computes recursor specifications for the functions being defined.
   - bs: function bindings; arg_Ts: the datatype each function recurses on; res_Ts: the
     corresponding result types;
   - get_indices, callssss0: recursive-call information, passed to nested_to_mutual_fps.
   The (possibly nested) datatypes are reduced to a mutual fixed point by nested_to_mutual_fps.
   Its permuted results are brought back into the original order (unpermute, unpermute0), and
   the recursor's type variables are instantiated: the datatype variables from arg_Ts, the
   result variables (Cs) from res_Ts padded with unit.
   Returns ((n2m, specs, missing_arg_Ts, induct_thm, induct_thms), lthy'). n2m holds iff
   nested_to_mutual_fps returned LFP sugar theorems. *)
fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
  let
    val thy = Proof_Context.theory_of lthy;

    val ((missing_arg_Ts, perm0_kks,
          fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = ctor_iters1 :: _, ...},
            co_inducts = [induct_thm], ...} :: _, (lfp_sugar_thms, _)), lthy') =
      nested_to_mutual_fps Least_FP bs arg_Ts get_indices callssss0 lthy;

    val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;

    val indices = map #index fp_sugars;
    val perm_indices = map #index perm_fp_sugars;

    val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
    val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
    val perm_lfpTs = map (body_type o fastype_of o hd) perm_ctrss;

    val nn0 = length arg_Ts;
    val nn = length perm_lfpTs;
    val kks = 0 upto nn - 1;
    val perm_ns = map length perm_ctr_Tsss;
    val perm_mss = map (map length) perm_ctr_Tsss;

    val perm_Cs = map (body_type o fastype_of o co_rec_of o of_fp_sugar (#xtor_co_iterss o #fp_res))
      perm_fp_sugars;
    val perm_fun_arg_Tssss =
      mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1);

    fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
    fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;

    val induct_thms = unpermute0 (conj_dests nn induct_thm);

    val lfpTs = unpermute perm_lfpTs;
    val Cs = unpermute perm_Cs;

    val As_rho = tvar_subst thy (take nn0 lfpTs) arg_Ts;
    val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;

    val substA = Term.subst_TVars As_rho;
    val substAT = Term.typ_subst_TVars As_rho;
    val substCT = Term.typ_subst_TVars Cs_rho;
    val substACT = substAT o substCT;

    val perm_Cs' = map substCT perm_Cs;

    (* number of constructors of the first n datatypes *)
    fun offset_of_ctr 0 _ = 0
      | offset_of_ctr n (({ctrs, ...} : ctr_sugar) :: ctr_sugars) =
        length ctrs + offset_of_ctr (n - 1) ctr_sugars;

    (* one index: plain or nested argument; two indices: argument plus recursive result *)
    fun call_of [i] [T] = (if exists_subtype_in Cs T then Nested_Rec else No_Rec) (i, substACT T)
      | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T'));

    fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
      let
        val (fun_arg_hss, _) = indexedd fun_arg_Tss 0;
        val fun_arg_hs = flat_rec_arg_args fun_arg_hss;
        val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss;
      in
        {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss,
         rec_thm = rec_thm}
      end;

    fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) iter_thmsss =
      let
        val ctrs = #ctrs (nth ctr_sugars index);
        val rec_thmss = co_rec_of (nth iter_thmsss index);
        val k = offset_of_ctr index ctr_sugars;
        val n = length ctrs;
      in
        map4 mk_ctr_spec ctrs (k upto k + n - 1) (nth perm_fun_arg_Tssss index) rec_thmss
      end;

    fun mk_spec ({T, index, ctr_sugars, co_iterss = iterss, co_iter_thmsss = iter_thmsss, ...}
      : fp_sugar) =
      {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' (co_rec_of (nth iterss index)),
       nested_map_idents = map (unfold_thms lthy @{thms id_def} o map_id0_of_bnf) nested_bnfs,
       nested_map_comps = map map_comp_of_bnf nested_bnfs,
       ctr_specs = mk_ctr_specs index ctr_sugars iter_thmsss};
  in
    ((is_some lfp_sugar_thms, map mk_spec fp_sugars, missing_arg_Ts, induct_thm, induct_thms),
     lthy')
  end;

(* Recursor argument for constructors that have no equation (see build_rec_arg); its dummy type
   is resolved when the definitions are type-checked (Syntax.check_terms in build_defs). *)
val undef_const = Const (@{const_name undefined}, dummyT);

(* permute_args n t is the (n + 1)-ary abstraction  fn x0 ... xn => t xn x0 ... x(n-1),
   i.e. the argument at position n (counting from 0) is passed to t first. All binders get
   dummyT, resolved later by type checking. *)
fun permute_args n t =
  list_comb (t, map Bound (0 :: (n downto 1)))
  |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n);

(* Decomposition of one user equation  f l1 ... lk (C a1 ... am) r1 ... rj = rhs,
   as produced by dissect_eqn. *)
type eqn_data = {
  fun_name: string,       (* name of the function f being defined *)
  rec_type: typ,          (* body type of C: the datatype being recursed on *)
  ctr: term,              (* the constructor C *)
  ctr_args: term list,    (* a1 ... am *)
  left_args: term list,   (* l1 ... lk *)
  right_args: term list,  (* r1 ... rj *)
  res_type: typ,          (* types of l1 ... lk r1 ... rj ---> type of rhs *)
  rhs_term: term,         (* rhs *)
  user_eqn: term          (* the equation as given by the user *)
};

(* Decomposes the user equation eqn' (after drop_All, which presumably strips leading
   quantifiers) of the form  f l1 ... lk (C a1 ... am) r1 ... rj = rhs  into an eqn_data
   record. f must be one of fun_names. All li, ai, rj must be distinct free variables. C must be
   fully applied. Every free variable of rhs must be such a variable, a function name, or
   fixed in lthy. Raises Primrec_Error (via primrec_error_eqn) for the first violated
   condition. *)
fun dissect_eqn lthy fun_names eqn' =
  let
    val eqn = drop_All eqn' |> HOLogic.dest_Trueprop
      handle TERM _ =>
        primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
    val (lhs, rhs) = HOLogic.dest_eq eqn
      handle TERM _ =>
        primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
    val (fun_name, args) = strip_comb lhs
      |>> (fn x => if is_Free x then fst (dest_Free x)
        else primrec_error_eqn "malformed function equation (does not start with free)" eqn);
    (* leading free variables, then the pattern(s), then trailing free variables *)
    val (left_args, rest) = take_prefix is_Free args;
    val (nonfrees, right_args) = take_suffix is_Free rest;
    val num_nonfrees = length nonfrees;
    val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then
      primrec_error_eqn "constructor pattern missing in left-hand side" eqn else
      primrec_error_eqn "more than one non-variable argument in left-hand side" eqn;
    val _ = member (op =) fun_names fun_name orelse
      primrec_error_eqn "malformed function equation (does not start with function name)" eqn

    val (ctr, ctr_args) = strip_comb (the_single nonfrees);
    val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
      primrec_error_eqn "partially applied constructor in pattern" eqn;

    (* all variables bound by the left-hand side *)
    val lhs_vars = left_args @ ctr_args @ right_args;
    val _ = let val d = duplicates (op =) lhs_vars in null d orelse
      primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
        "\" in left-hand side") eqn end;
    val _ = forall is_Free ctr_args orelse
      primrec_error_eqn "non-primitive pattern in left-hand side" eqn;
    val _ =
      let
        (* insert rather than cons: a variable occurring several times in rhs is reported
           only once *)
        fun add_extra_var (x as Free (v, _)) =
            if not (member (op =) lhs_vars x) andalso not (member (op =) fun_names v) andalso
              not (Variable.is_fixed lthy v) then insert (op =) x else I
          | add_extra_var _ = I;
        val extra_vars = fold_aterms add_extra_var rhs [];
      in
        null extra_vars orelse
        primrec_error_eqn ("extra variable(s) in right-hand side: " ^
          commas (map (Syntax.string_of_term lthy) extra_vars)) eqn
      end;
  in
    {fun_name = fun_name,
     rec_type = body_type (type_of ctr),
     ctr = ctr,
     ctr_args = ctr_args,
     left_args = left_args,
     right_args = right_args,
     res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
     rhs_term = rhs,
     user_eqn = eqn'}
  end;

(* Returns a function that rewrites a map argument of a nested recursive call so that it
   operates on pairs of type pT = rec_type * res_type:
   - the binder of the outermost lambda is retyped to pT, and its plain occurrences are
     wrapped in fst;
   - a recursive call whose constructor-position argument is that bound variable becomes snd
     applied to it, followed by the call's other arguments;
   - at the top level, a recursive function applied to exactly ctr_pos arguments becomes a
     suitably permuted snd.
   get_ctr_pos maps a free variable name to the constructor-pattern position of that function,
   or a negative value if it is not a function being defined.
   In subst, d tracks the de Bruijn index of that binder: SOME ~1 before the lambda is entered,
   NONE once it can no longer be tracked. Other recursive calls raise Primrec_Error. *)
fun rewrite_map_arg get_ctr_pos rec_type res_type =
  let
    val pT = HOLogic.mk_prodT (rec_type, res_type);

    val maybe_suc = Option.map (fn x => x + 1);
    fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
      | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b)
      | subst d t =
        let
          val (u, vs) = strip_comb t;
          val ctr_pos = try (get_ctr_pos o fst o dest_Free) u |> the_default ~1;
        in
          if ctr_pos >= 0 then
            if d = SOME ~1 andalso length vs = ctr_pos then
              list_comb (permute_args ctr_pos (snd_const pT), vs)
            else if length vs > ctr_pos andalso is_some d
                andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
              list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
            else
              primrec_error_eqn ("recursive call not directly applied to constructor argument") t
          else
            list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
        end
  in
    subst (SOME ~1)
  end;

(* Returns a function that replaces the recursive calls in a right-hand side by recursor
   argument variables:
   - a call  g a1 ... an y  to a function being defined (at least ctr_pos arguments before y),
     where y is a mutually recursive constructor argument, becomes  y' a1 ... an  with y'
     taken from mutual_calls;
   - an application to a term whose head is a nested recursive constructor argument is
     rewritten by massage_nested_rec_call, using the y' from nested_calls.
   Any recursive call left afterwards raises Primrec_Error. *)
fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls =
  let
    fun try_nested_rec bound_Ts y t =
      AList.lookup (op =) nested_calls y
      |> Option.map (fn y' =>
        massage_nested_rec_call lthy has_call (rewrite_map_arg get_ctr_pos) bound_Ts y y' t);

    fun subst bound_Ts (t as g' $ y) =
        let
          fun subst_rec () = subst bound_Ts g' $ subst bound_Ts y;
          val y_head = head_of y;
        in
          if not (member (op =) ctr_args y_head) then
            subst_rec ()
          else
            (case try_nested_rec bound_Ts y_head t of
              SOME t' => t'
            | NONE =>
              let val (g, g_args) = strip_comb g' in
                (case try (get_ctr_pos o fst o dest_Free) g of
                  (* ~1: g is a free variable but not a function being defined, so this is
                     not a recursive call and must not be replaced by a recursive result *)
                  SOME ~1 => subst_rec ()
                | SOME ctr_pos =>
                  (length g_args >= ctr_pos orelse
                   primrec_error_eqn "too few arguments in recursive call" t;
                   (case AList.lookup (op =) mutual_calls y of
                     SOME y' => list_comb (y', g_args)
                   | NONE => subst_rec ()))
                | NONE => subst_rec ())
              end)
        end
      | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
      | subst _ t = t

    fun subst' t =
      if has_call t then
        (* FIXME detect this case earlier? *)
        primrec_error_eqn "recursive call not directly applied to constructor argument" t
      else
        try_nested_rec [] (head_of t) t |> the_default t
  in
    subst' o subst []
  end;

(* Builds the recursor argument for one constructor. Without an equation (NONE) it is
   undefined. Otherwise the equation's rhs, with its recursive calls replaced by
   subst_rec_calls, is abstracted over:
   - the recursor's per-constructor arguments (constructor arguments, plus a retyped
     recursive-result variable for each mutual or nested argument);
   - then left_args and right_args. *)
fun build_rec_arg lthy (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec)
    (maybe_eqn_data : eqn_data option) =
  (case maybe_eqn_data of
    NONE => undef_const
  | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} =>
    let
      val calls = #calls ctr_spec;
      (* a mutual argument contributes both the argument and its recursive result *)
      val n_args = fold (Integer.add o (fn Mutual_Rec _ => 2 | _ => 1)) calls 0;

      val no_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn No_Rec p => p | Mutual_Rec (p, _) => p)));
      val mutual_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn Mutual_Rec (_, p) => p)));
      val nested_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn Nested_Rec p => p)));

      (* fresh placeholders, overwritten with the constructor arguments and, for recursive
         positions, with constructor arguments retyped to their recursive-result types *)
      val args = replicate n_args ("", dummyT)
        |> Term.rename_wrt_term t
        |> map Free
        |> fold (fn (ctr_arg_idx, (arg_idx, _)) =>
            nth_map arg_idx (K (nth ctr_args ctr_arg_idx)))
          no_calls'
        |> fold (fn (ctr_arg_idx, (arg_idx, T)) =>
            nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx))))
          (mutual_calls' @ nested_calls');

      val fun_name_ctr_pos_list =
        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
      (* ~1 for names that are not functions being defined *)
      val get_ctr_pos = AList.lookup (op =) fun_name_ctr_pos_list #> the_default ~1;
      val mutual_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) mutual_calls';
      val nested_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) nested_calls';
    in
      t
      |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls
      |> fold_rev lambda (args @ left_args @ right_args)
    end);

(* Builds the definitions of the functions being defined.
   - For each function, its recursor is applied to the recursor arguments of all
     constructors, ordered by offset.
   - The result is permuted (permute_args) so the constructor pattern sits at the position
     used in the equations.
   Raises Primrec_Error when:
   - an equation matches no constructor of its function's datatype;
   - there are several equations for one constructor;
   - a function's equations place the pattern at different positions.
   Returns, per function, ((binding, mixfix), ((concealed def binding, []), rhs)) for
   Local_Theory.define. *)
fun build_defs lthy bs mxs (funs_data : eqn_data list list) (rec_specs : rec_spec list) has_call =
  let
    val n_funs = length funs_data;

    val ctr_spec_eqn_data_list' =
      (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
      |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
          ##> (fn x => null x orelse
            (* report the full user equations, as for "multiple equations" below *)
            primrec_error_eqns "excess equations in definition" (map #user_eqn x)) #> fst);
    val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
      primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x));

    (* datatypes beyond the defined functions (from the nested-to-mutual reduction) get
       no equations *)
    val ctr_spec_eqn_data_list =
      ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));

    val recs = take n_funs rec_specs |> map #recx;
    val rec_args = ctr_spec_eqn_data_list
      |> sort ((op <) o pairself (#offset o fst) |> make_ord)
      |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
    val ctr_poss = map (fn x =>
      if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
        primrec_error ("inconstant constructor pattern position for function " ^
          quote (#fun_name (hd x)))
      else
        hd x |> #left_args |> length) funs_data;
  in
    (recs, ctr_poss)
    |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
    |> Syntax.check_terms lthy
    |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t)))
      bs mxs
  end;

(* For each constructor argument of the equation, collects (as partial compositions) the
   functions that are applied to it in the rhs and contain a recursive call.
   Returns SOME (ctr, one list per constructor argument), or NONE if the constructor takes no
   arguments. *)
fun find_rec_calls has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
  let
    fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
      | find bound_Ts (t as _ $ _) ctr_arg =
        let
          val typof = curry fastype_of1 bound_Ts;
          val (f', args') = strip_comb t;
          (* first argument whose head is ctr_arg *)
          val n = find_index (equal ctr_arg o head_of) args';
        in
          if n < 0 then
            find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args'
          else
            let
              val (f, args as arg :: _) = chop n args' |>> curry list_comb f'
              val (arg_head, arg_args) = Term.strip_comb arg;
            in
              if has_call f then
                mk_partial_compN (length arg_args) (typof arg_head) f ::
                maps (fn x => find bound_Ts x ctr_arg) args
              else
                find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args
            end
        end
      | find _ _ _ = [];
  in
    map (find [] rhs_term) ctr_args
    |> (fn [] => NONE | callss => SOME (ctr, callss))
  end;

(* Tactic proving one user equation. It unfolds the function definitions (fun_defs) and
   applies the recursor theorem recx, extended by num_extra_args fun_cong steps for the extra
   arguments, via trans. It then normalizes with id/split/o/fst/snd and the nested map
   composition and identity theorems, and closes the goal by reflexivity. *)
fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx =
  unfold_thms_tac ctxt fun_defs THEN
  HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
  unfold_thms_tac ctxt (@{thms id_def split o_def fst_conv snd_conv} @ map_comps @ map_idents) THEN
  HEADGOAL (rtac refl);

(* Front end of add_primrec_simple. Steps:
   - dissect the equations and group them by function, in the order of fixes;
   - compute the recursor specifications;
   - check that every pattern is a constructor;
   - build the definitions.
   Returns (((fun_names, defs), prove), lthy'):
   - prove, given a context and the results of Local_Theory.define for defs, proves each
     function's user equations and returns their positions among all equations;
   - lthy' additionally contains the induction theorems noted when a nested-to-mutual
     reduction took place (n2m). *)
fun prepare_primrec fixes specs lthy =
  let
    val (bs, mxs) = map_split (apfst fst) fixes;
    val fun_names = map Binding.name_of bs;
    val eqns_data = map (dissect_eqn lthy fun_names) specs;
    val funs_data = eqns_data
      |> partition_eq ((op =) o pairself #fun_name)
      |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
      |> map (fn (x, y) => the_single y handle List.Empty =>
          primrec_error ("missing equations for function " ^ quote x));

    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
    val arg_Ts = map (#rec_type o hd) funs_data;
    val res_Ts = map (#res_type o hd) funs_data;
    val callssss = funs_data
      |> map (partition_eq ((op =) o pairself #ctr))
      |> map (maps (map_filter (find_rec_calls has_call)));

    val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') =
      rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;

    val actual_nn = length funs_data;

    val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
      map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
        primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
          " is not a constructor in left-hand side") user_eqn) eqns_data end;

    val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;

    fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
        (fun_data : eqn_data list) =
      let
        val def_thms = map (snd o snd) def_thms';
        val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
          |> fst
          |> map_filter (try (fn (x, [y]) =>
              (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
          |> map (fn (user_eqn, num_extra_args, rec_thm) =>
              mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
              |> K |> Goal.prove lthy [] [] user_eqn
              |> Thm.close_derivation);
        val poss = find_indices (fn (x, y) => #ctr x = #ctr y) fun_data eqns_data;
      in
        (poss, simp_thmss)
      end;

    val notes =
      (if n2m then map2 (fn name => fn thm =>
        (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms) else [])
      |> map (fn (prefix, thmN, thms, attrs) =>
        ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])]));

    val common_name = mk_common_name fun_names;

    val common_notes =
      (if n2m then [(inductN, [induct_thm], [])] else [])
      |> map (fn (thmN, thms, attrs) =>
        ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
  in
    (((fun_names, defs),
      fn lthy => fn defs =>
        split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
      lthy' |> Local_Theory.notes (notes @ common_notes) |> snd)
  end;

(* primrec definition *)

(* Defines the functions (Local_Theory.define) and proves their equations.
   Errors are reported as a formatted "primrec_new error", followed by the offending equations
   if any:
   - Primrec_Error raised during preparation or definition;
   - ERROR from prepare_primrec, which is first converted to Primrec_Error. *)
fun add_primrec_simple fixes ts lthy =
  let
    val (((names, defs), prove), lthy') = prepare_primrec fixes ts lthy
      handle ERROR str => primrec_error str;
  in
    lthy'
    |> fold_map Local_Theory.define defs
    |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
  end
  (* the equations are printed in the caller's context lthy *)
  handle Primrec_Error (str, eqns) =>
    if null eqns
    then error ("primrec_new error:\n  " ^ str)
    else error ("primrec_new error:\n  " ^ str ^ "\nin\n  " ^
      space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));

local

(* Shared implementation of add_primrec and add_primrec_cmd; prep_spec checks or reads the
   specification. After the definition step it:
   - registers the equations with Spec_Rules;
   - notes, for each function, a "simps" fact;
   - notes one fact per named equation, carrying the code, nitpick_simp, and simp
     attributes. *)
fun gen_primrec prep_spec (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy =
  let
    val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
    (* checked outside add_primrec_simple's Primrec_Error handler, so report it directly in
       the same format instead of raising Primrec_Error (which would escape uncaught) *)
    val _ = null d orelse error ("primrec_new error:\n  duplicate function name(s): " ^ commas d);

    val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);

    val mk_notes =
      flat ooo map3 (fn poss => fn prefix => fn thms =>
        let
          val (bs, attrss) = map_split (fst o nth specs) poss;
          val notes =
            map3 (fn b => fn attrs => fn thm =>
              ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs), [([thm], [])]))
            bs attrss thms;
        in
          ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
        end);
  in
    lthy
    |> add_primrec_simple fixes (map snd specs)
    |-> (fn (names, (ts, (posss, simpss))) =>
      Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
      #> Local_Theory.notes (mk_notes posss names simpss)
      #>> pair ts o map snd)
  end;

in

val add_primrec = gen_primrec Specification.check_spec;
val add_primrec_cmd = gen_primrec Specification.read_spec;

end;

(* Theory-level wrapper: runs add_primrec in a named target for thy, exports the resulting
   theorems back to the target's base context, and exits to the global theory. *)
fun add_primrec_global fixes specs thy =
  let
    val lthy = Named_Target.theory_init thy;
    val ((ts, simps), lthy') = add_primrec fixes specs lthy;
    val simps' = burrow (Proof_Context.export lthy' lthy) simps;
  in ((ts, simps'), Local_Theory.exit_global lthy') end;

(* As add_primrec_global, but runs add_primrec inside an overloading target for ops. *)
fun add_primrec_overloaded ops fixes specs thy =
  let
    val lthy = Overloading.overloading ops thy;
    val ((ts, simps), lthy') = add_primrec fixes specs lthy;
    val simps' = burrow (Proof_Context.export lthy' lthy) simps;
  in ((ts, simps'), Local_Theory.exit_global lthy') end;

end;
|