renamed old-style Drule.standard to Drule.export_without_context, to emphasize that this is in no way a standard operation;
1 (* Title: HOL/Tools/Predicate_Compile/predicate_compile_core.ML
2 Author: Lukas Bulwahn, TU Muenchen
4 A compiler from predicates specified by intro/elim rules to equations.
(* Interface of the predicate compiler core: registration of predicates with
   intro/elim rules, mode queries, printing, and the generic compilation functions.
   NOTE(review): this copy is incomplete - the embedded line numbers skip lines
   (e.g. 8, 35, 43, 46, 50, 54, 57-58), so "sig", the closing "end" and some
   members are not shown here. *)
7 signature PREDICATE_COMPILE_CORE =
9 val setup : theory -> theory
10 val code_pred : Predicate_Compile_Aux.options -> string -> Proof.context -> Proof.state
11 val code_pred_cmd : Predicate_Compile_Aux.options -> string -> Proof.context -> Proof.state
12 val values_cmd : string list -> Predicate_Compile_Aux.mode option list option
13 -> (string option * (Predicate_Compile_Aux.compilation * int list))
14 -> int -> string -> Toplevel.state -> unit
15 val register_predicate : (string * thm list * thm) -> theory -> theory
16 val register_intros : string * thm list -> theory -> theory
17 val is_registered : theory -> string -> bool
18 val function_name_of : Predicate_Compile_Aux.compilation -> theory
19 -> string -> Predicate_Compile_Aux.mode -> string
20 val predfun_intro_of: theory -> string -> Predicate_Compile_Aux.mode -> thm
21 val predfun_elim_of: theory -> string -> Predicate_Compile_Aux.mode -> thm
22 val all_preds_of : theory -> string list
23 val modes_of: Predicate_Compile_Aux.compilation
24 -> theory -> string -> Predicate_Compile_Aux.mode list
25 val all_modes_of : Predicate_Compile_Aux.compilation
26 -> theory -> (string * Predicate_Compile_Aux.mode list) list
27 val all_random_modes_of : theory -> (string * Predicate_Compile_Aux.mode list) list
28 val intros_of : theory -> string -> thm list
29 val add_intro : thm -> theory -> theory
30 val set_elim : thm -> theory -> theory
31 val preprocess_intro : theory -> thm -> thm
32 val print_stored_rules : theory -> unit
33 val print_all_modes : Predicate_Compile_Aux.compilation -> theory -> unit
34 val mk_casesrule : Proof.context -> term -> thm list -> term
36 val eval_ref : (unit -> term Predicate.pred) option Unsynchronized.ref
37 val random_eval_ref : (unit -> int * int -> term Predicate.pred * (int * int))
38 option Unsynchronized.ref
39 val dseq_eval_ref : (unit -> term DSequence.dseq) option Unsynchronized.ref
40 val random_dseq_eval_ref : (unit -> int -> int -> int * int -> term DSequence.dseq * (int * int))
41 option Unsynchronized.ref
42 val code_pred_intro_attrib : attribute
44 (* used by Quickcheck_Generator *)
45 (* temporary for testing of the compilation *)
(* NOTE(review): compared with the structure's compilation_funs datatype (original
   lines 672-681), the fields mk_bot and mk_if presumably sit on the missing lines
   50 and 54 here. *)
47 datatype compilation_funs = CompilationFuns of {
48 mk_predT : typ -> typ,
49 dest_predT : typ -> typ,
51 mk_single : term -> term,
52 mk_bind : term * term -> term,
53 mk_sup : term * term -> term,
55 mk_not : term -> term,
56 mk_map : typ -> typ -> term -> term -> term
59 val pred_compfuns : compilation_funs
60 val randompred_compfuns : compilation_funs
61 val add_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
62 val add_random_dseq_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
63 val mk_tracing : string -> term -> term
66 structure Predicate_Compile_Core : PREDICATE_COMPILE_CORE =
69 open Predicate_Compile_Aux;
(* Tracing tactic that ignores its message and returns the goal state unchanged. *)
fun print_tac _ = Seq.single;

(* Prints the goal state with message msg when the options enable proof tracing;
   otherwise returns the goal state unchanged. *)
fun print_tac' options msg =
  (case show_proof_trace options of
    true => Tactical.print_tac msg
  | false => Seq.single);

(* Debugging hook; currently a no-op tactic. A variant that also traces msg is:
   (fn st => (Output.tracing msg; Seq.single st)) *)
fun debug_tac _ = Seq.single;

(* Raises an error when b is false; otherwise emits the warning "Assertion holds". *)
fun assert b =
  if b then warning "Assertion holds" else error "Assertion failed"
(* Properties that assert_tac can check on a goal state. *)
datatype assertion = Max_number_of_subgoals of int

(* assert_tac (Max_number_of_subgoals limit) st returns st unchanged when st has at
   most limit premises (subgoals). Otherwise it raises an error whose message
   includes the printed goal state. *)
fun assert_tac (Max_number_of_subgoals limit) st =
  let
    fun goal_text () =
      Pretty.string_of (Pretty.chunks
        (Goal_Display.pretty_goals_without_context (! Goal_Display.goals_limit) st))
  in
    if nprems_of st > limit then
      error ("assert_tac: Numbers of subgoals mismatch at goal state :" ^ "\n" ^ goal_text ())
    else Seq.single st
  end;
93 (* syntactic operations *)
(* NOTE(review): fragment - original lines 95, 97 and 99-100 are missing here.
   The visible clause conses HOLogic.mk_eq (Free (a, fastype_of b), b) onto the
   equations built for the rest of the list. *)
96 let fun mk_eqs _ [] = []
98 HOLogic.mk_eq (Free (a, fastype_of b), b) :: mk_eqs a cs
(* mk_scomp (t, u): builds the application scomp t u, where A is the single binder
   type of T. NOTE(review): the lines computing T, U and D (102-104, 106-107) are
   missing here. *)
101 fun mk_scomp (t, u) =
105 val [A] = binder_types T
108 Const (@{const_name "scomp"}, T --> U --> A --> D) $ t $ u
(* Splits a function type S => T into (S, T); raises TYPE for any other type. *)
fun dest_funT T =
  (case T of
    Type ("fun", [dom, ran]) => (dom, ran)
  | _ => raise TYPE ("dest_funT", [T], []))
(* mk_fun_comp (t, u): builds Fun.comp t u. The constant is typed from the
   range B of t and from u's type C => A. The domain of t's type is not inspected.
   NOTE(review): the let/in/end lines (115, 118, 120) are missing in this copy. *)
114 fun mk_fun_comp (t, u) =
116 val (_, B) = dest_funT (fastype_of t)
117 val (C, A) = dest_funT (fastype_of u)
119 Const(@{const_name "Fun.comp"}, (A --> B) --> (C --> A) --> C --> B) $ t $ u
(* Extracts T from a random generator type
     Random.seed => ((T * (unit => Code_Evaluation.term)) * Random.seed);
   raises TYPE for any other type. *)
fun dest_randomT T =
  (case T of
    Type ("fun", [@{typ Random.seed}, Type ("*", [Type ("*", [elemT,
      @{typ "unit => Code_Evaluation.term"}]), @{typ Random.seed}])]) => elemT
  | _ => raise TYPE ("dest_randomT", [T], []))
(* Presumably the body of mk_tracing s t (signature: string -> term -> term; the
   header line 126 is missing here). It wraps t in Code_Evaluation.tracing with s as
   a string literal, and the result keeps t's type. *)
127 Const(@{const_name Code_Evaluation.tracing},
128 @{typ String.literal} --> (fastype_of t) --> (fastype_of t)) $ (HOLogic.mk_literal s) $ t
(* Head and arguments of the HOL conclusion of a rule. It strips the premises,
   then Trueprop, then splits the application. *)
fun strip_intro_concl th =
  strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of th)))
132 (* derivation trees for modes of premises *)
(* Derivation trees recording how the mode of a premise was obtained. *)
datatype mode_derivation =
    Mode_App of mode_derivation * mode_derivation
  | Context of mode
  | Mode_Pair of mode_derivation * mode_derivation
  | Term of mode
(* Renders a derivation tree, e.g. "App (Context (...), Term (...))". *)
fun string_of_derivation deriv =
  let
    fun binary label (d1, d2) =
      label ^ " (" ^ string_of_derivation d1 ^ ", " ^ string_of_derivation d2 ^ ")"
    fun unary label m = label ^ " (" ^ string_of_mode m ^ ")"
  in
    (case deriv of
      Mode_App ds => binary "App" ds
    | Mode_Pair ds => binary "Pair" ds
    | Term m => unary "Term" m
    | Context m => unary "Context" m)
  end
(* Splits a derivation into its head and the list of derivations applied to it
   through nested Mode_App nodes (arguments in application order).
   NOTE(review): lines 145 and 148-150 (let/in/end and the initial call) are
   missing in this copy. *)
144 fun strip_mode_derivation deriv =
146 fun strip (Mode_App (deriv1, deriv2)) ds = strip deriv1 (deriv2 :: ds)
147 | strip deriv ds = (deriv, ds)
(* Computes the mode described by a derivation. Applying a derivation of mode
   Fun (m, m') to one of mode m yields m'. Raises "mode_of" on an argument-mode
   mismatch and "mode_of2" when the applied part does not have a Fun mode. *)
fun mode_of deriv =
  (case deriv of
    Context m => m
  | Term m => m
  | Mode_Pair (d1, d2) => Pair (mode_of d1, mode_of d2)
  | Mode_App (d1, d2) =>
      (case mode_of d1 of
        Fun (arg_mode, res_mode) =>
          if arg_mode = mode_of d2 then res_mode else error "mode_of"
      | _ => error "mode_of2"))
(* Mode of the head of a derivation, i.e. after stripping all Mode_App arguments. *)
fun head_mode_of deriv =
  let val (head, _) = strip_mode_derivation deriv
  in mode_of head end
(* Collects derivations among the arguments of deriv, presumably those of its
   parameters. Mode_Pair arguments are flattened, Term arguments are dropped, and
   every other derivation is kept.
   NOTE(review): let/in/end lines (164, 170, 172) are missing in this copy. *)
163 fun param_derivations_of deriv =
165 val (_, argument_derivs) = strip_mode_derivation deriv
166 fun param_derivation (Mode_Pair (m1, m2)) =
167 param_derivation m1 @ param_derivation m2
168 | param_derivation (Term _) = []
169 | param_derivation m = [m]
171 maps param_derivation argument_derivs
(* Collects the modes at all Context leaves of a derivation, from left to right.
   Mode_App and Mode_Pair nodes are traversed; Term leaves contribute nothing. *)
fun collect_context_modes deriv =
  let
    fun collect (Mode_App (d1, d2)) acc = collect d1 (collect d2 acc)
      | collect (Mode_Pair (d1, d2)) acc = collect d1 (collect d2 acc)
      | collect (Context m) acc = m :: acc
      | collect (Term _) acc = acc
  in
    collect deriv []
  end
181 (* representation of inferred clauses with modes *)
(* A clause with inferred modes: a list of terms (presumably the conclusion's
   arguments) together with each premise paired with its mode derivation. *)
type moded_clause =
  term list * (indprem * mode_derivation) list

(* Maps each predicate name to a list of (mode, entry) pairs. *)
type 'a pred_mode_table =
  (string * (mode * 'a) list) list
(* Per-mode predicate function data. From mk_predfun_data, its record has the
   fields definition, intro and elim.
   NOTE(review): the field lines 190-194 are missing in this copy. *)
189 datatype predfun_data = PredfunData of {
(* Exposes the record inside a PredfunData value. *)
fun rep_predfun_data (PredfunData record) = record;

(* Builds predfun data from a (definition, intro, elim) triple. *)
fun mk_predfun_data (defn, intr, elm) =
  PredfunData {intro = intr, elim = elm, definition = defn}
(* Per-predicate data: intro rules, an optional elim rule, compiled function names
   per compilation, predfun data per mode, and the modes recorded as needing random.
   NOTE(review): lines 201-202 (presumably the intros/elim fields) and 206 are
   missing in this copy. *)
200 datatype pred_data = PredData of {
203 function_names : (compilation * (mode * string) list) list,
204 predfun_data : (mode * predfun_data) list,
205 needs_random : mode list
(* Exposes the record inside a PredData value. *)
fun rep_pred_data (PredData record) = record;

(* Builds pred data from ((intros, elim), (function_names, predfun_data, needs_random)). *)
fun mk_pred_data (rules, comp_data) =
  let
    val (intros, elim) = rules
    val (function_names, predfun_data, needs_random) = comp_data
  in
    PredData {intros = intros, elim = elim, function_names = function_names,
      predfun_data = predfun_data, needs_random = needs_random}
  end

(* Applies f to the nested-tuple view of a PredData value and rebuilds the result. *)
fun map_pred_data f (PredData record) =
  mk_pred_data (f ((#intros record, #elim record),
    (#function_names record, #predfun_data record, #needs_random record)))
(* Lifts an element equality eq to options. NOTE(review) aside, the rules are:
   NONE equals NONE, two SOME values are compared with eq, and mixed cases are
   unequal. *)
fun eq_option eq pair =
  (case pair of
    (SOME x, SOME y) => eq (x, y)
  | (NONE, NONE) => true
  | _ => false)
(* Pred data are considered equal when their intro rules and optional elim rules
   agree under Thm.eq_thm. Compilation data are not compared. *)
fun eq_pred_data (PredData d1, PredData d2) =
  let
    val same_intros = eq_list Thm.eq_thm (#intros d1, #intros d2)
  in
    same_intros andalso eq_option Thm.eq_thm (#elim d1, #elim d2)
  end
(* Theory data: a graph of pred_data indexed by predicate name. Merging compares
   nodes with eq_pred_data.
   NOTE(review): lines 226, 229 and 231-233 are missing in this copy. *)
225 structure PredData = Theory_Data
227 type T = pred_data Graph.T;
228 val empty = Graph.empty;
230 val merge = Graph.merge eq_pred_data;
(* Pred data registered for name in thy, or NONE if the graph has no such node. *)
fun lookup_pred_data thy name =
  (case try (Graph.get_node (PredData.get thy)) name of
    SOME data => SOME (rep_pred_data data)
  | NONE => NONE)
(* Like lookup_pred_data, but raises an error for unknown predicates.
   NOTE(review): the SOME branch (line 240) is missing in this copy. *)
238 fun the_pred_data thy name = case lookup_pred_data thy name
239 of NONE => error ("No such predicate " ^ quote name)
(* True iff pred data is registered for name. *)
fun is_registered thy name = is_some (lookup_pred_data thy name)

(* Names of all predicates in the PredData graph. *)
fun all_preds_of thy = Graph.keys (PredData.get thy)

(* Intro rules of a predicate, each transferred to thy. *)
fun intros_of thy name = map (Thm.transfer thy) (#intros (the_pred_data thy name))
(* Elim rule of a predicate, transferred to thy; raises an error if none is set. *)
fun the_elim_of thy name =
  let val {elim, ...} = the_pred_data thy name
  in
    (case elim of
      SOME thm => Thm.transfer thy thm
    | NONE => error ("No elimination rule for predicate " ^ quote name))
  end

(* True iff an elim rule is set for the predicate. *)
fun has_elim thy name = is_some (#elim (the_pred_data thy name));
(* Returns the (mode, function name) pairs registered for predicate name under the
   given compilation.
   Raises ERROR if the predicate is unknown (via the_pred_data) or if no functions
   were defined for this compilation.
   In the message, the compilation description is followed by a space only when it
   is non-empty. Previously it was glued directly onto "functions". *)
fun function_names_of compilation thy name =
  case AList.lookup (op =) (#function_names (the_pred_data thy name)) compilation of
    NONE =>
      let
        val descr = string_of_compilation compilation
        val prefix = if descr = "" then "No " else "No " ^ descr ^ " "
      in
        error (prefix ^ "functions defined for predicate " ^ quote name)
      end
  | SOME fun_names => fun_names
(* Returns the name of the function compiled for mode of predicate name under the
   given compilation.
   Raises ERROR (directly or via function_names_of) when no such function exists.
   As in function_names_of, the compilation description is separated from
   "function" by a space only when it is non-empty. *)
fun function_name_of compilation thy name mode =
  case AList.lookup (op =) (function_names_of compilation thy name) mode of
    NONE =>
      let
        val descr = string_of_compilation compilation
        val prefix = if descr = "" then "No " else "No " ^ descr ^ " "
      in
        error (prefix ^ "function defined for mode " ^ string_of_mode mode
          ^ " of predicate " ^ quote name)
      end
  | SOME function_name => function_name
(* Modes for which a function of the given compilation is registered for name. *)
fun modes_of compilation thy name =
  map (fn (mode, _) => mode) (function_names_of compilation thy name)
(* Pairs predicate names with their modes for the given compilation, skipping names
   for which modes_of fails (try).
   NOTE(review): the list argument (line 270, presumably all_preds_of thy) is
   missing in this copy. *)
268 fun all_modes_of compilation thy =
269 map_filter (fn name => Option.map (pair name) (try (modes_of compilation thy) name))
(* All (predicate, modes) pairs for the Random compilation. *)
fun all_random_modes_of thy = all_modes_of Random thy

(* True iff function names for the given compilation are recorded for name. *)
fun defined_functions compilation thy name =
  let val {function_names, ...} = the_pred_data thy name
  in AList.defined (op =) function_names compilation end

(* Predfun data stored for mode of predicate name, if any. *)
fun lookup_predfun_data thy name mode =
  let val {predfun_data, ...} = the_pred_data thy name
  in
    (case AList.lookup (op =) predfun_data mode of
      SOME data => SOME (rep_predfun_data data)
    | NONE => NONE)
  end
(* Like lookup_predfun_data, but raises an error when no data exists for the mode.
   NOTE(review): the SOME branch (line 285) is missing in this copy. Unlike the
   other error messages here, the predicate name is not passed through quote. *)
281 fun the_predfun_data thy name mode =
282 case lookup_predfun_data thy name mode of
283 NONE => error ("No function defined for mode " ^ string_of_mode mode ^
284 " of predicate " ^ name)
(* Accessors for the stored definition, intro rule and elim rule of a predicate
   function, selected by theory, predicate name and mode. *)
fun predfun_definition_of thy name mode = #definition (the_predfun_data thy name mode)

fun predfun_intro_of thy name mode = #intro (the_predfun_data thy name mode)

fun predfun_elim_of thy name mode = #elim (the_predfun_data thy name mode)
293 (* diagnostic display functions *)
(* Traces the inferred modes of each predicate when show_modes is enabled.
   NOTE(review): the else branch (line 300) is missing in this copy. *)
295 fun print_modes options thy modes =
296 if show_modes options then
297 tracing ("Inferred modes:\n" ^
298 cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
299 string_of_mode ms)) modes))
(* Traces a pred_mode_table, rendering each entry with string_of_entry.
   NOTE(review): lines 303 and 309 (let/in/end) are missing in this copy. *)
302 fun print_pred_mode_table string_of_entry thy pred_mode_table =
304 fun print_mode pred (mode, entry) = "mode : " ^ string_of_mode mode
305 ^ string_of_entry pred mode entry
306 fun print_pred (pred, modes) =
307 "predicate " ^ pred ^ ": " ^ cat_lines (map (print_mode pred) modes)
308 val _ = tracing (cat_lines (map print_pred pred_mode_table))
(* Pretty-prints a premise followed by a tag naming its kind. Negative premises are
   printed negated. Any other premise form raises an error. *)
fun string_of_prem thy prem =
  let
    fun show t = Syntax.string_of_term_global thy t
  in
    (case prem of
      Prem t => show t ^ "(premise)"
    | Negprem t => show (HOLogic.mk_not t) ^ "(negative premise)"
    | Sidecond t => show t ^ "(sidecondition)"
    | _ => error "string_of_prem: unexpected input")
  end
(* Renders a clause as "prem_1 --> ... --> prem_n --> pred arg_1 ... arg_k".
   With no premises the result still starts with " --> ". *)
fun string_of_clause thy pred (ts, prems) =
  let
    val premises = space_implode " --> " (map (string_of_prem thy) prems)
    val args = space_implode " " (map (Syntax.string_of_term_global thy) ts)
  in
    premises ^ " --> " ^ pred ^ " " ^ args
  end
(* Traces compiled terms as a pred_mode_table when show_compilation is enabled.
   NOTE(review): the else branch (line 327) is missing in this copy. *)
324 fun print_compiled_terms options thy =
325 if show_compilation options then
326 print_pred_mode_table (fn _ => fn _ => Syntax.string_of_term_global thy) thy
(* Writes each predicate's intro rules (printed in reverse of intros_of order, via
   rev) and its elim rule, or "no elimrule defined".
   NOTE(review): several lines (330, 337, 340, 342-345) are missing, including the
   loop over preds. *)
329 fun print_stored_rules thy =
331 val preds = (Graph.keys o PredData.get) thy
332 fun print pred () = let
333 val _ = writeln ("predicate: " ^ pred)
334 val _ = writeln ("introrules: ")
335 val _ = fold (fn thm => fn u => writeln (Display.string_of_thm_global thy thm))
336 (rev (intros_of thy pred)) ()
338 if (has_elim thy pred) then
339 writeln ("elimrule: " ^ Display.string_of_thm_global thy (the_elim_of thy pred))
341 writeln ("no elimrule defined")
(* Writes the inferred modes of every predicate for the given compilation.
   NOTE(review): lines 348, 351, 354-355 and 357 are missing in this copy. *)
347 fun print_all_modes compilation thy =
349 val _ = writeln ("Inferred modes:")
350 fun print (pred, modes) u =
352 val _ = writeln ("predicate: " ^ pred)
353 val _ = writeln ("modes: " ^ (commas (map string_of_mode modes)))
356 fold print (all_modes_of compilation thy) ()
359 (* validity checks *)
360 (* EXPECTED MODE and PROPOSED_MODE are largely the same; define a clear semantics for those! *)
(* If expected_modes options names a predicate s, checks that the modes inferred for
   s equal (as a set, eq_mode) the expected ones, and raises an error otherwise.
   NOTE(review): lines 365-368 and 373-376 are missing in this copy. *)
362 fun check_expected_modes preds options modes =
363 case expected_modes options of
364 SOME (s, ms) => (case AList.lookup (op =) modes s of
369 if not (eq_set eq_mode (ms, modes')) then
370 error ("expected modes were not inferred:\n"
371 ^ " inferred modes for " ^ s ^ ": " ^ commas (map string_of_mode modes') ^ "\n"
372 ^ " expected modes for " ^ s ^ ": " ^ commas (map string_of_mode ms))
(* The same check for proposed_modes. The error also lists predicates without any
   inferred mode.
   NOTE(review): the message still says "expected modes" although it checks
   proposed_modes. Lines 381-382, 385, 391 and 394-398 are missing in this copy. *)
378 fun check_proposed_modes preds options modes extra_modes errors =
379 case proposed_modes options of
380 SOME (s, ms) => (case AList.lookup (op =) modes s of
383 val preds_without_modes = map fst (filter (null o snd) (modes @ extra_modes))
384 val modes' = inferred_ms
386 if not (eq_set eq_mode (ms, modes')) then
387 error ("expected modes were not inferred:\n"
388 ^ " inferred modes for " ^ s ^ ": " ^ commas (map string_of_mode modes') ^ "\n"
389 ^ " expected modes for " ^ s ^ ": " ^ commas (map string_of_mode ms) ^ "\n"
390 ^ "For the following clauses, the following modes could not be inferred: " ^ "\n"
392 (if not (null preds_without_modes) then
393 "\n" ^ "No mode inferred for the predicates " ^ commas preds_without_modes
400 (* importing introduction rules *)
(* Makes type variables in cs and intr_ts schematic with disjoint indices. It then
   unifies the type of each constant in cs with the types of its occurrences in
   intr_ts and applies the resulting type substitution to both lists.
   On Type.TUNIFY it warns and returns (cs, intr_ts) unchanged.
   NOTE(review): line 403 (the opening "(let") is missing in this copy. *)
402 fun unify_consts thy cs intr_ts =
404 val add_term_consts_2 = fold_aterms (fn Const c => insert (op =) c | _ => I);
405 fun varify (t, (i, ts)) =
406 let val t' = map_types (Logic.incr_tvar (i + 1)) (#2 (Type.varify [] t))
407 in (maxidx_of_term t', t'::ts) end;
408 val (i, cs') = List.foldr varify (~1, []) cs;
409 val (i', intr_ts') = List.foldr varify (i, []) intr_ts;
410 val rec_consts = fold add_term_consts_2 cs' [];
411 val intr_consts = fold add_term_consts_2 intr_ts' [];
412 fun unify (cname, cT) =
413 let val consts = map snd (filter (fn c => fst c = cname) intr_consts)
414 in fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts) end;
415 val (env, _) = fold unify rec_consts (Vartab.empty, i');
416 val subst = map_types (Envir.norm_type env)
417 in (map subst cs', map subst intr_ts')
418 end) handle Type.TUNIFY =>
419 (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts));
(* Imports a predicate and its intro rules into ctxt.
   - With no rules: returns the imported predicate with fresh parameters "p1".."pn",
     typed by ho_argsT_of of the predicate's binder types.
   - With rules: imports the first rule, then instantiates the type variables and
     the higher-order arguments of the remaining rules to match the first rule's
     conclusion. Instantiating a rule for a different predicate raises an error.
   Result: (((predicate, params or ho_args), rules), ctxt).
   NOTE(review): in the [] case, ctxt'' (which fixes the fresh parameter names) is
   computed but ctxt' is returned (line 431) - confirm this is intended.
   Several let/in/end lines are missing in this copy. *)
421 fun import_intros inp_pred [] ctxt =
423 val ([outp_pred], ctxt') = Variable.import_terms true [inp_pred] ctxt
424 val T = fastype_of outp_pred
425 (* TODO: put in a function for this next line! *)
426 val paramTs = ho_argsT_of (hd (all_modes_of_typ T)) (binder_types T)
427 val (param_names, ctxt'') = Variable.variant_fixes
428 (map (fn i => "p" ^ (string_of_int i)) (1 upto (length paramTs))) ctxt'
429 val params = map2 (curry Free) param_names paramTs
431 (((outp_pred, params), []), ctxt')
433 | import_intros inp_pred (th :: ths) ctxt =
435 val ((_, [th']), ctxt') = Variable.import true [th] ctxt
436 val thy = ProofContext.theory_of ctxt'
437 val (pred, args) = strip_intro_concl th'
438 val T = fastype_of pred
439 val ho_args = ho_args_of (hd (all_modes_of_typ T)) args
440 fun subst_of (pred', pred) =
442 val subst = Sign.typ_match thy (fastype_of pred', fastype_of pred) Vartab.empty
443 in map (fn (indexname, (s, T)) => ((indexname, s), T)) (Vartab.dest subst) end
444 fun instantiate_typ th =
446 val (pred', _) = strip_intro_concl th
447 val _ = if not (fst (dest_Const pred) = fst (dest_Const pred')) then
448 error "Trying to instantiate another predicate" else ()
449 in Thm.certify_instantiate (subst_of (pred', pred), []) th end;
450 fun instantiate_ho_args th =
(* NOTE(review): the composition below is the same as strip_intro_concl. *)
452 val (_, args') = (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl o prop_of) th
453 val ho_args' = map dest_Var (ho_args_of (hd (all_modes_of_typ T)) args')
454 in Thm.certify_instantiate ([], ho_args' ~~ ho_args) th end
456 Term_Subst.instantiate (subst_of (inp_pred, pred), []) inp_pred
457 val ((_, ths'), ctxt1) =
458 Variable.import false (map (instantiate_typ #> instantiate_ho_args) ths) ctxt'
460 (((outp_pred, ho_args), th' :: ths'), ctxt1)
463 (* generation of case rules from user-given introduction rules *)
(* Builds an argument term for an argument type, threading (params, ctxt).
   - Product types are split recursively and rebuilt with HOLogic.mk_prod.
   - Bool-valued function types take the next term from params.
   - Other function types, and all other types, get a fresh Free named after "x",
     fixed in ctxt.
   NOTE(review): let/in/end and else lines are missing in this copy. *)
465 fun mk_args2 (Type ("*", [T1, T2])) st =
467 val (t1, st') = mk_args2 T1 st
468 val (t2, st'') = mk_args2 T2 st'
470 (HOLogic.mk_prod (t1, t2), st'')
472 | mk_args2 (T as Type ("fun", _)) (params, ctxt) =
474 val (S, U) = strip_type T
476 if U = HOLogic.boolT then
477 (hd params, (tl params, ctxt))
480 val ([x], ctxt') = Variable.variant_fixes ["x"] ctxt
482 (Free (x, T), (params, ctxt'))
485 | mk_args2 T (params, ctxt) =
487 val ([x], ctxt') = Variable.variant_fixes ["x"] ctxt
489 (Free (x, T), (params, ctxt'))
(* Builds the case (elimination) rule of pred from its intro rules:
     pred argvs ==> case_1 ==> ... ==> case_n ==> thesis
   Each case universally quantifies the intro rule's free variables (excluding
   params) over  argvs = args ==> prems ==> thesis.
   NOTE(review): lines 493 and 500-501 (presumably "fun mk_case intro = let") are
   missing in this copy. *)
492 fun mk_casesrule ctxt pred introrules =
494 val (((pred, params), intros_th), ctxt1) = import_intros pred introrules ctxt
495 val intros = map prop_of intros_th
496 val ([propname], ctxt2) = Variable.variant_fixes ["thesis"] ctxt1
497 val prop = HOLogic.mk_Trueprop (Free (propname, HOLogic.boolT))
498 val argsT = binder_types (fastype_of pred)
499 val (argvs, _) = fold_map mk_args2 argsT (params, ctxt2)
502 val (_, args) = (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl) intro
503 val prems = Logic.strip_imp_prems intro
504 val eqprems = map2 (HOLogic.mk_Trueprop oo (curry HOLogic.mk_eq)) argvs args
505 val frees = (fold o fold_aterms)
507 if member (op aconv) params t then I else insert (op aconv) t
508 | _ => I) (args @ prems) []
509 in fold Logic.all frees (Logic.list_implies (eqprems @ prems, prop)) end
510 val assm = HOLogic.mk_Trueprop (list_comb (pred, argvs))
511 val cases = map mk_case intros
512 in Logic.list_implies (assm :: cases, prop) end;
514 (** preprocessing rules **)
(* Applies conversion cv to every premise of a meta-implication chain
   A1 ==> ... ==> C, leaving the final conclusion unchanged (all_conv). *)
fun imp_prems_conv cv ct =
  let
    val is_imp = (case Thm.term_of ct of Const ("==>", _) $ _ $ _ => true | _ => false)
  in
    if is_imp then Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
    else Conv.all_conv ct
  end
(* Applies cv to the argument of Trueprop; raises an error for any other term. *)
fun Trueprop_conv cv ct =
  let
    fun under_trueprop (Const ("Trueprop", _) $ _) = true
      | under_trueprop _ = false
  in
    if under_trueprop (Thm.term_of ct) then Conv.arg_conv cv ct
    else error "Trueprop_conv"
  end
(* Transfers rule to thy and rewrites with the symmetric form of
   Predicate.eq_is_eq under Trueprop. try_conv leaves non-matching terms unchanged.
   NOTE(review): lines 527-528 are missing; presumably they apply imp_prems_conv
   to the rule. *)
526 fun preprocess_intro thy rule =
529 (Trueprop_conv (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq})))))
530 (Thm.transfer thy rule)
(* Turns HOL equalities in the cases of an elimination rule into Predicate.eq.
   In each case (all premises but the first), hypotheses after the first nargs go
   through replace_eqs, where nargs is the number of arguments of the first premise.
   The equation between the old and new rule is accepted with
   Skip_Proof.cheat_tac, i.e. it is NOT actually proved.
   NOTE(review): bigeq is computed but not used in the visible lines. Lines 533,
   536 (presumably replace_eqs' fallback clause), 542, 546, 548 and 556 are missing. *)
532 fun preprocess_elim thy elimrule =
534 fun replace_eqs (Const ("Trueprop", _) $ (Const ("op =", T) $ lhs $ rhs)) =
535 HOLogic.mk_Trueprop (Const (@{const_name Predicate.eq}, T) $ lhs $ rhs)
537 val ctxt = ProofContext.init thy
538 val ((_, [elimrule]), ctxt') = Variable.import false [elimrule] ctxt
539 val prems = Thm.prems_of elimrule
540 val nargs = length (snd (strip_comb (HOLogic.dest_Trueprop (hd prems))))
541 fun preprocess_case t =
543 val params = Logic.strip_params t
544 val (assums1, assums2) = chop nargs (Logic.strip_assums_hyp t)
545 val assums_hyp' = assums1 @ (map replace_eqs assums2)
547 list_all (params, Logic.list_implies (assums_hyp', Logic.strip_assums_concl t))
549 val cases' = map preprocess_case (tl prems)
550 val elimrule' = Logic.list_implies ((hd prems) :: cases', Thm.concl_of elimrule)
551 val bigeq = (Thm.symmetric (Conv.implies_concl_conv
552 (MetaSimplifier.rewrite true [@{thm Predicate.eq_is_eq}])
553 (cterm_of thy elimrule')))
554 val tac = (fn _ => Skip_Proof.cheat_tac thy)
555 val eq = Goal.prove ctxt' [] [] (Logic.mk_equals ((Thm.prop_of elimrule), elimrule')) tac
557 Thm.equal_elim eq elimrule |> singleton (Variable.export ctxt' ctxt)
(* Placeholder: returns the elimination rule unchanged. *)
fun expand_tuples_elim elim_rule = elim_rule

(* Empty compilation data: no function names, no predfun data, no needs_random modes. *)
val no_compilation = (nil, nil, nil)
(* Builds pred_data for the inductive predicate name.
   - Intros: the inductive's intro rules whose conclusion head is name, passed
     through expand_tuples and preprocess_intro.
   - Elim: built with mk_casesrule and turned into a theorem by
     Skip_Proof.make_thm (i.e. not proved), then Drule.export_without_context.
   Raises an error if name is not an inductive predicate.
   NOTE(review): pre_elim is only used in the commented-out code. The error text
   "No such predicate: " differs from the_pred_data's "No such predicate ".
   Lines 567, 569, 572, 579, 582 and 584 are missing. *)
564 fun fetch_pred_data thy name =
565 case try (Inductive.the_inductive (ProofContext.init thy)) name of
566 SOME (info as (_, result)) =>
568 fun is_intro_of intro =
570 val (const, _) = strip_comb (HOLogic.dest_Trueprop (concl_of intro))
571 in (fst (dest_Const const) = name) end;
573 (map (expand_tuples thy #> preprocess_intro thy) (filter is_intro_of (#intrs result)))
574 val index = find_index (fn s => s = name) (#names (fst info))
575 val pre_elim = nth (#elims result) index
576 val pred = nth (#preds result) index
577 (*val elim = singleton (Inductive_Set.codegen_preproc thy) (preprocess_elim thy nparams
578 (expand_tuples_elim pre_elim))*)
580 (Drule.export_without_context o Skip_Proof.make_thm thy)
581 (mk_casesrule (ProofContext.init thy) pred intros)
583 mk_pred_data ((intros, SOME elim), no_compilation)
585 | NONE => error ("No such predicate: " ^ quote name)
(* Prepends (mode, predfun data built from data) to the predfun_data of predicate
   name in the theory's graph.
   NOTE(review): line 588 (presumably "let") is missing in this copy. *)
587 fun add_predfun_data name mode data =
589 val add = (apsnd o apsnd3) (cons (mode, mk_predfun_data data))
590 in PredData.map (Graph.map_node name (map_pred_data add)) end
(* True iff Inductive.the_inductive succeeds for name in thy. *)
fun is_inductive_predicate thy name =
  (case try (Inductive.the_inductive (ProofContext.init thy)) name of
    SOME _ => true
  | NONE => false)
(* Names of constants other than key that occur in the intro rules of a graph node
   and are inductive or registered predicates.
   NOTE(review): lines 596, 598 and 602 are missing in this copy. *)
595 fun depending_preds_of thy (key, value) =
597 val intros = (#intros o rep_pred_data) value
599 fold Term.add_const_names (map Thm.prop_of intros) []
600 |> filter (fn c => (not (c = key)) andalso
601 (is_inductive_predicate thy c orelse is_registered thy c))
(* Appends thm to the intro rules of the predicate in its conclusion. If that
   predicate has no graph node yet, creates one with no elim rule and no
   compilation data.
   NOTE(review): T is bound but unused in the visible lines. Lines 605 and 607
   (presumably "fun cons_intro gr =") are missing. *)
604 fun add_intro thm thy =
606 val (name, T) = dest_Const (fst (strip_intro_concl thm))
608 case try (Graph.get_node gr) name of
609 SOME pred_data => Graph.map_node name (map_pred_data
610 (apfst (fn (intros, elim) => (intros @ [thm], elim)))) gr
611 | NONE => Graph.new_node (name, mk_pred_data (([thm], NONE), no_compilation)) gr
612 in PredData.map cons_intro thy end
(* Presumably the body of set_elim thm (header lines 614-615 missing). Sets thm as
   the elim rule of the predicate at the head of its first premise. *)
616 val (name, _) = dest_Const (fst
617 (strip_comb (HOLogic.dest_Trueprop (hd (prems_of thm)))))
618 fun set (intros, _) = (intros, SOME thm)
619 in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end
(* Adds a graph node for constname with preprocessed intro and elim rules, if
   constname is not already a key.
   NOTE(review): lines 622, 625, 627 and 630-631 (the else branch) are missing;
   presumably thy is returned unchanged for known predicates. *)
621 fun register_predicate (constname, pre_intros, pre_elim) thy =
623 val intros = map (preprocess_intro thy) pre_intros
624 val elim = preprocess_elim thy pre_elim
626 if not (member (op =) (Graph.keys (PredData.get thy)) constname) then
628 (Graph.new_node (constname,
629 mk_pred_data ((intros, SOME elim), no_compilation))) thy
(* Registers intro rules for constname.
   - Raises an error unless every rule concludes in constname.
   - Builds the elim rule with mk_casesrule, accepted via Skip_Proof.make_thm
     (not proved).
   - Delegates to register_predicate.
   NOTE(review): lines 634, 641 and 643 (presumably "val pre_elim =") are missing. *)
633 fun register_intros (constname, pre_intros) thy =
635 val T = Sign.the_const_type thy constname
636 fun constname_of_intro intr = fst (dest_Const (fst (strip_intro_concl intr)))
637 val _ = if not (forall (fn intr => constname_of_intro intr = constname) pre_intros) then
638 error ("register_intros: Introduction rules of different constants are used\n" ^
639 "expected rules for " ^ constname ^ ", but received rules for " ^
640 commas (map constname_of_intro pre_intros))
642 val pred = Const (constname, T)
644 (Drule.export_without_context o Skip_Proof.make_thm thy)
645 (mk_casesrule (ProofContext.init thy) pred pre_intros)
646 in register_predicate (constname, pre_intros, pre_elim) thy end
(* Prepends (compilation, []) to pred's function_names. The cons shadows any earlier
   entry for compilation in lookups rather than replacing it.
   NOTE(review): let/in/end lines (649, 651, 653) are missing in this copy. *)
648 fun defined_function_of compilation pred =
650 val set = (apsnd o apfst3) (cons (compilation, []))
652 PredData.map (Graph.map_node pred (map_pred_data set))
(* Records name as the function for (compilation, mode) of pred, prepending
   (mode, name) to that compilation's list via AList.map_default.
   NOTE(review): check how AList.map_default treats its default. If it applies the
   update function to the default, a missing compilation entry would get
   (mode, name) twice. Let/in/end lines are missing in this copy. *)
655 fun set_function_name compilation pred mode name =
657 val set = (apsnd o apfst3)
658 (AList.map_default (op =) (compilation, [(mode, name)]) (cons (mode, name)))
660 PredData.map (Graph.map_node pred (map_pred_data set))
(* Replaces the needs_random modes (third component of the compilation data) of
   predicate name with modes.
   NOTE(review): let/in/end lines are missing in this copy. *)
663 fun set_needs_random name modes =
665 val set = (apsnd o aptrd3) (K modes)
667 PredData.map (Graph.map_node name (map_pred_data set))
670 (* datastructures and setup for generic compilation *)
(* Constructors parameterizing the generic compilation:
   - the monad type constructor (mk_predT) and destructor (dest_predT);
   - term builders for the empty result (mk_bot), single, bind, sup, if, not, map.
   NOTE(review): the closing line 682 is missing in this copy. *)
672 datatype compilation_funs = CompilationFuns of {
673 mk_predT : typ -> typ,
674 dest_predT : typ -> typ,
675 mk_bot : typ -> term,
676 mk_single : term -> term,
677 mk_bind : term * term -> term,
678 mk_sup : term * term -> term,
679 mk_if : term -> term,
680 mk_not : term -> term,
681 mk_map : typ -> typ -> term -> term -> term
(* Field accessors for compilation_funs: each returns one of the stored functions. *)
fun mk_predT (CompilationFuns {mk_predT = f, ...}) = f
fun dest_predT (CompilationFuns {dest_predT = f, ...}) = f
fun mk_bot (CompilationFuns {mk_bot = f, ...}) = f
fun mk_single (CompilationFuns {mk_single = f, ...}) = f
fun mk_bind (CompilationFuns {mk_bind = f, ...}) = f
fun mk_sup (CompilationFuns {mk_sup = f, ...}) = f
fun mk_if (CompilationFuns {mk_if = f, ...}) = f
fun mk_not (CompilationFuns {mk_not = f, ...}) = f
fun mk_map (CompilationFuns {mk_map = f, ...}) = f
(* Compilation functions for the Predicate.pred monad. It also has Pred/eval
   helpers (lines 723-735; their headers are missing here).
   NOTE(review): many lines of this structure are missing in this copy, including
   struct/end, several function headers and the final "mk_map = mk_map}" field. *)
694 structure PredicateCompFuns =
697 fun mk_predT T = Type (@{type_name Predicate.pred}, [T])
699 fun dest_predT (Type (@{type_name Predicate.pred}, [T])) = T
700 | dest_predT T = raise TYPE ("dest_predT", [T], []);
702 fun mk_bot T = Const (@{const_name Orderings.bot}, mk_predT T);
705 let val T = fastype_of t
706 in Const(@{const_name Predicate.single}, T --> mk_predT T) $ t end;
709 let val T as Type ("fun", [_, U]) = fastype_of f
711 Const (@{const_name Predicate.bind}, fastype_of x --> T --> U) $ x $ f
714 val mk_sup = HOLogic.mk_binop @{const_name sup};
716 fun mk_if cond = Const (@{const_name Predicate.if_pred},
717 HOLogic.boolT --> mk_predT HOLogic.unitT) $ cond;
719 fun mk_not t = let val T = mk_predT HOLogic.unitT
720 in Const (@{const_name Predicate.not_pred}, T --> T) $ t end
723 let val T as Type ("fun", [T', _]) = fastype_of f
725 Const (@{const_name Predicate.Pred}, T --> mk_predT T') $ f
732 Const (@{const_name Predicate.eval}, mk_predT T --> T --> HOLogic.boolT) $ f $ x
735 fun dest_Eval (Const (@{const_name Predicate.eval}, _) $ f $ x) = (f, x)
737 fun mk_map T1 T2 tf tp = Const (@{const_name Predicate.map},
738 (T1 --> T2) --> mk_predT T1 --> mk_predT T2) $ tf $ tp;
740 val compfuns = CompilationFuns {mk_predT = mk_predT, dest_predT = dest_predT, mk_bot = mk_bot,
741 mk_single = mk_single, mk_bind = mk_bind, mk_sup = mk_sup, mk_if = mk_if, mk_not = mk_not,
(* Compilation functions for random predicates of type
     Random.seed => (T Predicate.pred * Random.seed),
   built from Quickcheck constants.
   NOTE(review): struct/end and several function headers are missing in this copy. *)
746 structure RandomPredCompFuns =
749 fun mk_randompredT T =
750 @{typ Random.seed} --> HOLogic.mk_prodT (PredicateCompFuns.mk_predT T, @{typ Random.seed})
752 fun dest_randompredT (Type ("fun", [@{typ Random.seed}, Type (@{type_name "*"},
753 [Type (@{type_name "Predicate.pred"}, [T]), @{typ Random.seed}])])) = T
754 | dest_randompredT T = raise TYPE ("dest_randompredT", [T], []);
756 fun mk_bot T = Const(@{const_name Quickcheck.empty}, mk_randompredT T)
762 Const (@{const_name Quickcheck.single}, T --> mk_randompredT T) $ t
767 val T as (Type ("fun", [_, U])) = fastype_of f
769 Const (@{const_name Quickcheck.bind}, fastype_of x --> T --> U) $ x $ f
772 val mk_sup = HOLogic.mk_binop @{const_name Quickcheck.union}
774 fun mk_if cond = Const (@{const_name Quickcheck.if_randompred},
775 HOLogic.boolT --> mk_randompredT HOLogic.unitT) $ cond;
777 fun mk_not t = let val T = mk_randompredT HOLogic.unitT
778 in Const (@{const_name Quickcheck.not_randompred}, T --> T) $ t end
780 fun mk_map T1 T2 tf tp = Const (@{const_name Quickcheck.map},
781 (T1 --> T2) --> mk_randompredT T1 --> mk_randompredT T2) $ tf $ tp
783 val compfuns = CompilationFuns {mk_predT = mk_randompredT, dest_predT = dest_randompredT,
784 mk_bot = mk_bot, mk_single = mk_single, mk_bind = mk_bind, mk_sup = mk_sup, mk_if = mk_if,
785 mk_not = mk_not, mk_map = mk_map};
(* Compilation functions for depth-limited sequences of type
     code_numeral => bool => T Lazy_Sequence.lazy_sequence option,
   built from the DSequence constants (referenced by plain name strings).
   NOTE(review): struct/end and several function headers are missing in this copy. *)
789 structure DSequence_CompFuns =
792 fun mk_dseqT T = Type ("fun", [@{typ code_numeral}, Type ("fun", [@{typ bool},
793 Type (@{type_name Option.option}, [Type ("Lazy_Sequence.lazy_sequence", [T])])])])
795 fun dest_dseqT (Type ("fun", [@{typ code_numeral}, Type ("fun", [@{typ bool},
796 Type (@{type_name Option.option}, [Type ("Lazy_Sequence.lazy_sequence", [T])])])])) = T
797 | dest_dseqT T = raise TYPE ("dest_dseqT", [T], []);
799 fun mk_bot T = Const ("DSequence.empty", mk_dseqT T);
802 let val T = fastype_of t
803 in Const("DSequence.single", T --> mk_dseqT T) $ t end;
806 let val T as Type ("fun", [_, U]) = fastype_of f
808 Const ("DSequence.bind", fastype_of x --> T --> U) $ x $ f
811 val mk_sup = HOLogic.mk_binop "DSequence.union";
813 fun mk_if cond = Const ("DSequence.if_seq",
814 HOLogic.boolT --> mk_dseqT HOLogic.unitT) $ cond;
816 fun mk_not t = let val T = mk_dseqT HOLogic.unitT
817 in Const ("DSequence.not_seq", T --> T) $ t end
819 fun mk_map T1 T2 tf tp = Const ("DSequence.map",
820 (T1 --> T2) --> mk_dseqT T1 --> mk_dseqT T2) $ tf $ tp
822 val compfuns = CompilationFuns {mk_predT = mk_dseqT, dest_predT = dest_dseqT,
823 mk_bot = mk_bot, mk_single = mk_single, mk_bind = mk_bind, mk_sup = mk_sup, mk_if = mk_if,
824 mk_not = mk_not, mk_map = mk_map}
(* Compilation functions for random depth-limited sequences of type
     code_numeral => code_numeral => Random.seed => (T dseq * Random.seed).
   dest_random_dseqT unwraps the outer layers and then applies
   DSequence_CompFuns.dest_dseqT.
   NOTE(review): struct/end and several function headers are missing in this copy. *)
828 structure Random_Sequence_CompFuns =
831 fun mk_random_dseqT T =
832 @{typ code_numeral} --> @{typ code_numeral} --> @{typ Random.seed} -->
833 HOLogic.mk_prodT (DSequence_CompFuns.mk_dseqT T, @{typ Random.seed})
835 fun dest_random_dseqT (Type ("fun", [@{typ code_numeral}, Type ("fun", [@{typ code_numeral},
836 Type ("fun", [@{typ Random.seed},
837 Type (@{type_name "*"}, [T, @{typ Random.seed}])])])])) = DSequence_CompFuns.dest_dseqT T
838 | dest_random_dseqT T = raise TYPE ("dest_random_dseqT", [T], []);
840 fun mk_bot T = Const ("Random_Sequence.empty", mk_random_dseqT T);
843 let val T = fastype_of t
844 in Const("Random_Sequence.single", T --> mk_random_dseqT T) $ t end;
848 val T as Type ("fun", [_, U]) = fastype_of f
850 Const ("Random_Sequence.bind", fastype_of x --> T --> U) $ x $ f
853 val mk_sup = HOLogic.mk_binop "Random_Sequence.union";
855 fun mk_if cond = Const ("Random_Sequence.if_random_dseq",
856 HOLogic.boolT --> mk_random_dseqT HOLogic.unitT) $ cond;
858 fun mk_not t = let val T = mk_random_dseqT HOLogic.unitT
859 in Const ("Random_Sequence.not_random_dseq", T --> T) $ t end
861 fun mk_map T1 T2 tf tp = Const ("Random_Sequence.map",
862 (T1 --> T2) --> mk_random_dseqT T1 --> mk_random_dseqT T2) $ tf $ tp
864 val compfuns = CompilationFuns {mk_predT = mk_random_dseqT, dest_predT = dest_random_dseqT,
865 mk_bot = mk_bot, mk_single = mk_single, mk_bind = mk_bind, mk_sup = mk_sup, mk_if = mk_if,
866 mk_not = mk_not, mk_map = mk_map}
(* Fragment (surrounding header lines 867-873 and 877 are missing). It wraps
   Quickcheck's random generator for T, of type
     code_numeral => Random.seed => ((T * (unit => term)) * Random.seed),
   into Random_Sequence.Random, yielding a random dseq of T. *)
874 val random = Const ("Quickcheck.random_class.random",
875 @{typ code_numeral} --> @{typ Random.seed} -->
876 HOLogic.mk_prodT (HOLogic.mk_prodT (T, @{typ "unit => term"}), @{typ Random.seed}))
878 Const ("Random_Sequence.Random", (@{typ code_numeral} --> @{typ Random.seed} -->
879 HOLogic.mk_prodT (HOLogic.mk_prodT (T, @{typ "unit => term"}), @{typ Random.seed})) -->
880 Random_Sequence_CompFuns.mk_random_dseqT T) $ random
(* Compilation functions exported for external use with interactive mode. *)
val pred_compfuns = PredicateCompFuns.compfuns;

(* NOTE(review): despite its name, randompred_compfuns is bound to the random
   sequence compilation functions (Random_Sequence_CompFuns), not to
   RandomPredCompFuns - confirm this is intended. *)
val randompred_compfuns = Random_Sequence_CompFuns.compfuns;
889 (* function types and names of different compilations *)
(* Type of the function compiled for mode from predicate type T:
     input types ---> monad type (mk_predT compfuns) of the tuple of output types.
   Arguments that split_map_modeT maps with the callback (presumably higher-order
   ones) get their own compiled function type, recursively.
   NOTE(review): let/in/end lines (892, 895, 897) are missing in this copy. *)
891 fun funT_of compfuns mode T =
893 val Ts = binder_types T
894 val (inTs, outTs) = split_map_modeT (fn m => fn T => (SOME (funT_of compfuns m T), NONE)) mode Ts
896 inTs ---> (mk_predT compfuns (HOLogic.mk_tupleT outTs))
899 (** mode analysis **)
(* Presumably the body of is_constrt thy (header lines 900-902 missing).
   - cnstrs maps each datatype constructor name to (arity, datatype name).
   - check accepts an application of a constant whose arity and result datatype
     match the constructor table and whose arguments all pass check.
   NOTE(review): line 907 (presumably the Free case) and lines 911-914 are missing. *)
903 val cnstrs = flat (maps
904 (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
905 (Symtab.dest (Datatype.get_all thy)));
906 fun check t = (case strip_comb t of
908 | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
909 (SOME (i, Tname), Type (Tname', _)) =>
910 length ts = i andalso Tname = Tname' andalso forall check ts
915 (*** check if a type is an equality type (i.e. doesn't contain fun)
916 FIXME this is only an approximation ***)
(* NOTE(review): only the Type clause is visible; the remaining clause(s)
   (line 918) are missing in this copy. *)
917 fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts
(* Names of the free variables of tm, with duplicates. *)
fun term_vs tm =
  let
    fun add (Free (x, _)) acc = x :: acc
      | add _ acc = acc
  in
    fold_aterms add tm []
  end;

(* Distinct free-variable names occurring in a list of terms. *)
fun terms_vs ts = distinct (op =) (maps term_vs ts);
923 (** collect all Frees in a term (with duplicates!) **)
(* NOTE(review): the header (line 924, presumably "fun term_vTs tm =") is missing. *)
925 fold_aterms (fn Free xT => cons xT | _ => I) tm [];
(* Fragment of subsets i j (lines 926-931 missing): merges the subsets containing i
   with those that do not. merge interleaves two lists, taking the longer head first.
   NOTE(review): the surrounding let/if lines are missing in this copy. *)
932 | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys)
933 else y::merge (x::xs) ys;
934 val is = subsets (i+1) j
935 in merge (map (fn ks => i::ks) is) is end
(* When show_mode_inference is enabled, traces which clauses of p (indices printed
   1-based) violate mode m.
   NOTE(review): thy, modes and rs are unused in the visible lines. Lines 940 and
   943-945 are missing. *)
938 fun print_failed_mode options thy modes p m rs is =
939 if show_mode_inference options then
941 val _ = tracing ("Clauses " ^ commas (map (fn i => string_of_int (i + 1)) is) ^ " of " ^
942 p ^ " violates mode " ^ string_of_mode m)
(* Message stating that the clauses with the given 0-based indices is (printed
   1-based) of predicate p violate mode m. *)
fun error_of p m is =
  let
    val clause_numbers = commas (map (string_of_int o (fn i => i + 1)) is)
  in
    String.concat [" Clauses ", clause_numbers, " of ", p, " violates mode ", string_of_mode m]
  end
(* True iff every argument mode in strip_fun_mode mode is input-only: Fun and Input
   count as input, Pair needs both components to be input, Output does not.
   NOTE(review): let/in/end lines (951, 956, 958) are missing in this copy. *)
950 fun is_all_input mode =
952 fun is_all_input' (Fun _) = true
953 | is_all_input' (Pair (m1, m2)) = is_all_input' m1 andalso is_all_input' m2
954 | is_all_input' Input = true
955 | is_all_input' Output = false
957 forall is_all_input' (strip_fun_mode mode)
962 val (Ts, U) = strip_type T
963 fun input_of (Type ("*", [T1, T2])) = Pair (input_of T1, input_of T2)
966 if U = HOLogic.boolT then
967 fold_rev (curry Fun) (map input_of Ts) Bool
969 error "all_input_of: not a predicate"
972 fun partial_hd [] = NONE (* total variant of hd: NONE for the empty list *)
973 | partial_hd (x :: xs) = SOME x
975 fun term_vs tm = fold_aterms (fn Free (x, T) => cons x | _ => I) tm []; (* NOTE(review): identical redefinition of term_vs from line 920 *)
976 val terms_vs = distinct (op =) o maps term_vs; (* NOTE(review): identical redefinition of terms_vs from line 921 *)
980 val (Ts, U) = strip_type T
982 fold_rev (curry Fun) (map (K Input) Ts) Input
987 val (Ts, U) = strip_type T
989 fold_rev (curry Fun) (map (K Output) Ts) Output
992 fun is_invertible_function thy (Const (f, _)) = is_constr thy f (* a constant is invertible iff is_constr accepts it (presumably: it is a datatype constructor) *)
993 | is_invertible_function thy _ = false
995 fun non_invertible_subterms thy (Free _) = []
996 | non_invertible_subterms thy t =
997 case (strip_comb t) of (f, args) =>
998 if is_invertible_function thy f then
999 maps (non_invertible_subterms thy) args
1003 fun collect_non_invertible_subterms thy (f as Free _) (names, eqs) = (f, (names, eqs))
1004 | collect_non_invertible_subterms thy t (names, eqs) =
1005 case (strip_comb t) of (f, args) =>
1006 if is_invertible_function thy f then
1008 val (args', (names', eqs')) =
1009 fold_map (collect_non_invertible_subterms thy) args (names, eqs)
1011 (list_comb (f, args'), (names', eqs'))
1015 val s = Name.variant names "x"
1016 val v = Free (s, fastype_of t)
1018 (v, (s :: names, HOLogic.mk_eq (v, t) :: eqs))
1021 if is_constrt thy t then (t, (names, eqs)) else
1023 val s = Name.variant names "x"
1024 val v = Free (s, fastype_of t)
1025 in (v, (s::names, HOLogic.mk_eq (v, t)::eqs)) end;
1028 fun is_possible_output thy vs t =
1030 (fn t => is_eqT (fastype_of t) andalso forall (member (op =) vs) (term_vs t))
1031 (non_invertible_subterms thy t)
1033 fun vars_of_destructable_term thy (Free (x, _)) = [x]
1034 | vars_of_destructable_term thy t =
1035 case (strip_comb t) of (f, args) =>
1036 if is_invertible_function thy f then
1037 maps (vars_of_destructable_term thy) args
1041 fun is_constructable thy vs t = forall (member (op =) vs) (term_vs t) (* true iff every Free variable of t is already in vs *)
1043 fun missing_vars vs t = subtract (op =) vs (term_vs t) (* Free variables of t not in vs (repetitions from term_vs are kept) *)
1045 fun derivations_of thy modes vs t Input =
1046 [(Term Input, missing_vars vs t)]
1047 | derivations_of thy modes vs t Output =
1048 if is_possible_output thy vs t then [(Term Output, [])] else []
1049 | derivations_of thy modes vs (Const ("Pair", _) $ t1 $ t2) (Pair (m1, m2)) =
1051 (fn (m1, mvars1) => fn (m2, mvars2) => (Mode_Pair (m1, m2), union (op =) mvars1 mvars2))
1052 (derivations_of thy modes vs t1 m1) (derivations_of thy modes vs t2 m2)
1053 | derivations_of thy modes vs t m =
1054 (case try (all_derivations_of thy modes vs) t of
1055 SOME derivs => filter (fn (d, mvars) => mode_of d = m) derivs
1056 | NONE => (if is_all_input m then [(Context m, [])] else []))
1057 and all_derivations_of thy modes vs (Const ("Pair", _) $ t1 $ t2) =
1059 val derivs1 = all_derivations_of thy modes vs t1
1060 val derivs2 = all_derivations_of thy modes vs t2
1063 (fn (m1, mvars1) => fn (m2, mvars2) => (Mode_Pair (m1, m2), union (op =) mvars1 mvars2))
1066 | all_derivations_of thy modes vs (t1 $ t2) =
1068 val derivs1 = all_derivations_of thy modes vs t1
1070 maps (fn (d1, mvars1) =>
1072 Fun (m', _) => map (fn (d2, mvars2) =>
1073 (Mode_App (d1, d2), union (op =) mvars1 mvars2)) (derivations_of thy modes vs t2 m')
1074 | _ => error "Something went wrong") derivs1
1076 | all_derivations_of thy modes vs (Const (s, T)) =
1077 (case (AList.lookup (op =) modes s) of
1078 SOME ms => map (fn m => (Context m, [])) ms
1079 | NONE => error ("No mode for constant " ^ s))
1080 | all_derivations_of _ modes vs (Free (x, _)) =
1081 (case (AList.lookup (op =) modes x) of
1082 SOME ms => map (fn m => (Context m , [])) ms
1083 | NONE => error ("No mode for parameter variable " ^ x))
1084 | all_derivations_of _ modes vs _ = error "all_derivations_of"
1086 fun rev_option_ord ord (NONE, NONE) = EQUAL (* lifts ord to options; SOME sorts before NONE *)
1087 | rev_option_ord ord (NONE, SOME _) = GREATER
1088 | rev_option_ord ord (SOME _, NONE) = LESS
1089 | rev_option_ord ord (SOME x, SOME y) = ord (x, y)
1091 fun term_of_prem (Prem t) = t (* underlying term of a premise; non-exhaustive: any other premise form (e.g. Generator) raises Match *)
1092 | term_of_prem (Negprem t) = t
1093 | term_of_prem (Sidecond t) = t
1095 fun random_mode_in_deriv modes t deriv =
1096 case try dest_Const (fst (strip_comb t)) of
1098 (case AList.lookup (op =) modes s of
1100 (case AList.lookup (op =) ms (head_mode_of deriv) of
1106 fun number_of_output_positions mode =
1108 val args = strip_fun_mode mode
1109 fun contains_output (Fun _) = false
1110 | contains_output Input = false
1111 | contains_output Output = true
1112 | contains_output (Pair (m1, m2)) = contains_output m1 orelse contains_output m2
1114 length (filter contains_output args)
1117 fun lex_ord ord1 ord2 (x, x') =
1118 case ord1 (x, x') of
1119 EQUAL => ord2 (x, x')
1122 fun deriv_ord2' thy modes t1 t2 ((deriv1, mvars1), (deriv2, mvars2)) =
1124 fun mvars_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) =
1125 int_ord (length mvars1, length mvars2)
1126 fun random_mode_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = (* prefer derivations without a random mode: 0 (non-random) sorts before 1 *)
1127 int_ord (if random_mode_in_deriv modes t1 deriv1 then 1 else 0,
1128 if random_mode_in_deriv modes t2 deriv2 then 1 else 0) (* second operand must inspect t2/deriv2, otherwise the order is always EQUAL *)
1129 fun output_mode_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) =
1130 int_ord (number_of_output_positions (head_mode_of deriv1),
1131 number_of_output_positions (head_mode_of deriv2))
1133 lex_ord mvars_ord (lex_ord random_mode_ord output_mode_ord)
1134 ((t1, deriv1, mvars1), (t2, deriv2, mvars2))
1137 fun deriv_ord2 thy modes t = deriv_ord2' thy modes t t (* deriv_ord2' for two derivations of the same term t *)
1139 fun deriv_ord ((deriv1, mvars1), (deriv2, mvars2)) = (* orders derivations by how many missing variables they leave; fewer first *)
1140 int_ord (length mvars1, length mvars2)
1142 fun premise_ord thy modes ((prem1, a1), (prem2, a2)) = (* orders premises by their candidate derivation; premises with none (NONE) sort last *)
1143 rev_option_ord (deriv_ord2' thy modes (term_of_prem prem1) (term_of_prem prem2)) (a1, a2)
1145 fun print_mode_list modes = (* traces each predicate's modes, tagging each as needing random generation or not *)
1146 tracing ("modes: " ^ (commas (map (fn (s, ms) => s ^ ": " ^
1147 commas (map (fn (m, r) => string_of_mode m ^ (if r then " random " else " not ")) ms)) modes)))
1149 fun select_mode_prem' thy modes vs ps =
1151 val modes' = map (fn (s, ms) => (s, map fst ms)) modes
1153 partial_hd (sort (premise_ord thy modes) (ps ~~ map
1156 (sort (deriv_ord2 thy modes t) (all_derivations_of thy modes' vs t))
1157 | Sidecond t => SOME (Context Bool, missing_vars vs t)
1160 (sort (deriv_ord2 thy modes t) (filter (fn (d, missing_vars) => is_all_input (head_mode_of d))
1161 (all_derivations_of thy modes' vs t)))
1162 | p => error (string_of_prem thy p))
1166 fun check_mode_clause' use_random thy param_vs modes mode (ts, ps) =
1168 val vTs = distinct (op =) (fold Term.add_frees (map term_of_prem ps) (fold Term.add_frees ts []))
1169 val modes' = modes @ (param_vs ~~ map (fn x => [(x, false)]) (ho_arg_modes_of mode))
1170 val (in_ts, out_ts) = split_mode mode ts
1171 val in_vs = maps (vars_of_destructable_term thy) in_ts
1172 val out_vs = terms_vs out_ts
1173 fun check_mode_prems acc_ps rnd vs [] = SOME (acc_ps, vs, rnd)
1174 | check_mode_prems acc_ps rnd vs ps =
1175 (case select_mode_prem' thy modes' vs ps of
1176 SOME (p, SOME (deriv, [])) => check_mode_prems ((p, deriv) :: acc_ps) rnd (*TODO: uses random? *)
1178 Prem t => union (op =) vs (term_vs t)
1180 | Negprem t => union (op =) vs (term_vs t)
1181 | _ => error "I do not know")
1182 (filter_out (equal p) ps)
1183 | SOME (p, SOME (deriv, missing_vars)) =>
1185 check_mode_prems ((p, deriv) :: (map
1186 (fn v => (Generator (v, the (AList.lookup (op =) vTs v)), Term Output)) missing_vars)
1189 Prem t => union (op =) vs (term_vs t)
1190 | Sidecond t => union (op =) vs (term_vs t)
1191 | Negprem t => union (op =) vs (term_vs t)
1192 | _ => error "I do not know")
1193 (filter_out (equal p) ps)
1195 | SOME (p, NONE) => NONE
1198 case check_mode_prems [] false in_vs ps of
1200 | SOME (acc_ps, vs, rnd) =>
1201 if forall (is_constructable thy vs) (in_ts @ out_ts) then
1202 SOME (ts, rev acc_ps, rnd)
1206 val generators = map
1207 (fn v => (Generator (v, the (AList.lookup (op =) vTs v)), Term Output))
1208 (subtract (op =) vs (terms_vs out_ts))
1210 SOME (ts, rev (generators @ acc_ps), true)
1216 datatype result = Success of bool | Error of string (* outcome of checking one mode: Success tells whether random generation is used, Error carries the message *)
1218 fun check_modes_pred' use_random options thy param_vs clauses modes (p, ms) =
1222 fun split' [] (ys, zs) = (rev ys, rev zs)
1223 | split' ((m, Error z) :: xs) (ys, zs) = split' xs (ys, z :: zs)
1224 | split' ((m, Success rnd) :: xs) (ys, zs) = split' xs ((m, rnd) :: ys, zs)
1228 val rs = these (AList.lookup (op =) clauses p)
1231 val res = map (check_mode_clause' use_random thy param_vs modes m) rs
1233 case find_indices is_none res of
1234 [] => Success (exists (fn SOME (_, _, true) => true | _ => false) res)
1235 | is => (print_failed_mode options thy modes p m rs is; Error (error_of p m is))
1237 val res = map (fn (m, _) => (m, check_mode m)) ms
1238 val (ms', errors) = split res
1243 fun get_modes_pred' use_random thy param_vs clauses modes (p, ms) =
1245 val rs = these (AList.lookup (op =) clauses p)
1247 (p, map (fn (m, rnd) =>
1248 (m, map ((fn (ts, ps, rnd) => (ts, ps)) o the o check_mode_clause' use_random thy param_vs modes m) rs)) ms)
1253 in if x = y then x else fixp f y end;
1255 fun fixp_with_state f (x, state) =
1257 val (y, state') = f (x, state)
1259 if x = y then (y, state') else fixp_with_state f (y, state')
1262 fun infer_modes use_random options preds extra_modes param_vs clauses thy =
1264 val all_modes = map (fn (s, T) => (s, map (rpair false) (all_modes_of_typ T))) preds
1265 fun needs_random s m = (m, member (op =) (#needs_random (the_pred_data thy s)) m)
1266 val extra_modes = map (fn (s, ms) => (s, map (needs_random s) ms)) extra_modes
1267 val (modes, errors) =
1268 fixp_with_state (fn (modes, errors) =>
1271 (check_modes_pred' use_random options thy param_vs clauses (modes @ extra_modes)) modes
1272 in (map fst res, errors @ maps snd res) end)
1274 val thy' = fold (fn (s, ms) => if member (op =) (map fst preds) s then
1275 set_needs_random s (map fst (filter (fn (_, rnd) => rnd = true) ms)) else I) modes thy
1277 ((map (get_modes_pred' use_random thy param_vs clauses (modes @ extra_modes)) modes, errors), thy')
1280 (* term construction *)
1282 fun mk_v (names, vs) s T = (case AList.lookup (op =) vs s of
1283 NONE => (Free (s, T), (names, (s, [])::vs))
1286 val s' = Name.variant names s;
1287 val v = Free (s', T)
1289 (v, (s'::names, AList.update (op =) (s, v::xs) vs))
1292 fun distinct_v (Free (s, T)) nvs = mk_v nvs s T
1293 | distinct_v (t $ u) nvs =
1295 val (t', nvs') = distinct_v t nvs;
1296 val (u', nvs'') = distinct_v u nvs';
1297 in (t' $ u', nvs'') end
1298 | distinct_v x nvs = (x, nvs);
1300 (** specific rpred functions -- move them to the correct place in this file *)
1302 fun mk_Eval_of additional_arguments ((x, T), NONE) names = (x, names)
1303 | mk_Eval_of additional_arguments ((x, T), SOME mode) names =
1305 val Ts = binder_types T
1306 fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t
1307 | mk_split_lambda [x] t = lambda x t
1308 | mk_split_lambda xs t =
1310 fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
1311 | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
1313 mk_split_lambda' xs t
1317 val vname = Name.variant names ("x" ^ string_of_int i)
1318 val default = Free (vname, T)
1320 case AList.lookup (op =) mode i of
1321 NONE => (([], [default]), [default])
1322 | SOME NONE => (([default], []), [default])
1323 | SOME (SOME pis) =>
1324 case HOLogic.strip_tupleT T of
1325 [] => error "pair mode but unit tuple" (*(([default], []), [default])*)
1326 | [_] => error "pair mode but not a tuple" (*(([default], []), [default])*)
1329 val vnames = Name.variant_list names
1330 (map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
1332 val args = map2 (curry Free) vnames Ts
1333 fun split_args (i, arg) (ins, outs) =
1334 if member (op =) pis i then
1338 val (inargs, outargs) = fold_rev split_args ((1 upto length Ts) ~~ args) ([], [])
1339 fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
1340 in ((tuple inargs, tuple outargs), args) end
1342 val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
1343 val (inargs, outargs) = pairself flat (split_list inoutargs)
1344 val r = PredicateCompFuns.mk_Eval
1345 (list_comb (x, inargs @ additional_arguments), HOLogic.mk_tuple outargs)
1346 val t = fold_rev mk_split_lambda args r
1351 structure Comp_Mod =
1354 datatype comp_modifiers = Comp_Modifiers of
1356 compilation : compilation,
1357 function_name_prefix : string,
1358 compfuns : compilation_funs,
1359 additional_arguments : string list -> term list,
1360 wrap_compilation : compilation_funs -> string -> typ -> mode -> term list -> term -> term,
1361 transform_additional_arguments : indprem -> term list -> term list
1364 fun dest_comp_modifiers (Comp_Modifiers c) = c
1366 val compilation = #compilation o dest_comp_modifiers
1367 val function_name_prefix = #function_name_prefix o dest_comp_modifiers
1368 val compfuns = #compfuns o dest_comp_modifiers
1369 val funT_of = funT_of o compfuns
1370 val additional_arguments = #additional_arguments o dest_comp_modifiers
1371 val wrap_compilation = #wrap_compilation o dest_comp_modifiers
1372 val transform_additional_arguments = #transform_additional_arguments o dest_comp_modifiers
1376 (* TODO: uses param_vs -- change necessary for compilation with new modes *)
1377 fun compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss arg =
1379 fun map_params (t as Free (f, T)) =
1380 if member (op =) param_vs f then
1381 case (AList.lookup (op =) (param_vs ~~ iss) f) of
1384 val _ = error "compile_arg: A parameter in a input position -- do we have a test case?"
1385 val T' = Comp_Mod.funT_of compilation_modifiers is T
1386 in t(*fst (mk_Eval_of additional_arguments ((Free (f, T'), T), is) [])*) end
1390 in map_aterms map_params arg end
1392 fun compile_match compilation_modifiers compfuns additional_arguments
1393 param_vs iss thy eqs eqs' out_ts success_t =
1395 val eqs'' = maps mk_eq eqs @ eqs'
1397 map (compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss) eqs''
1398 val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
1399 val name = Name.variant names "x";
1400 val name' = Name.variant (name :: names) "y";
1401 val T = HOLogic.mk_tupleT (map fastype_of out_ts);
1402 val U = fastype_of success_t;
1403 val U' = dest_predT compfuns U;
1404 val v = Free (name, T);
1405 val v' = Free (name', T);
1407 lambda v (fst (Datatype.make_case
1408 (ProofContext.init thy) Datatype_Case.Quiet [] v
1409 [(HOLogic.mk_tuple out_ts,
1410 if null eqs'' then success_t
1411 else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
1412 foldr1 HOLogic.mk_conj eqs'' $ success_t $
1413 mk_bot compfuns U'),
1414 (v', mk_bot compfuns U')]))
1417 fun compile_expr compilation_modifiers compfuns thy (t, deriv) additional_arguments =
1419 fun expr_of (t, deriv) =
1421 (t, Term Input) => SOME t
1422 | (t, Term Output) => NONE
1423 | (Const (name, T), Context mode) =>
1424 SOME (Const (function_name_of (Comp_Mod.compilation compilation_modifiers) thy name mode,
1425 Comp_Mod.funT_of compilation_modifiers mode T))
1426 | (Free (s, T), Context m) =>
1427 SOME (Free (s, Comp_Mod.funT_of compilation_modifiers m T))
1430 val bs = map (pair "x") (binder_types (fastype_of t))
1431 val bounds = map Bound (rev (0 upto (length bs) - 1))
1432 in SOME (list_abs (bs, mk_if compfuns (list_comb (t, bounds)))) end
1433 | (Const ("Pair", _) $ t1 $ t2, Mode_Pair (d1, d2)) =>
1434 (case (expr_of (t1, d1), expr_of (t2, d2)) of
1435 (NONE, NONE) => NONE
1436 | (NONE, SOME t) => SOME t
1437 | (SOME t, NONE) => SOME t
1438 | (SOME t1, SOME t2) => SOME (HOLogic.mk_prod (t1, t2)))
1439 | (t1 $ t2, Mode_App (deriv1, deriv2)) =>
1440 (case (expr_of (t1, deriv1), expr_of (t2, deriv2)) of
1441 (SOME t, NONE) => SOME t
1442 | (SOME t, SOME u) => SOME (t $ u)
1443 | _ => error "something went wrong here!"))
1445 the (expr_of (t, deriv))
1448 fun compile_clause compilation_modifiers compfuns thy all_vs param_vs additional_arguments
1449 mode inp (ts, moded_ps) =
1451 val iss = ho_arg_modes_of mode
1452 val compile_match = compile_match compilation_modifiers compfuns
1453 additional_arguments param_vs iss thy
1454 val (in_ts, out_ts) = split_mode mode ts;
1455 val (in_ts', (all_vs', eqs)) =
1456 fold_map (collect_non_invertible_subterms thy) in_ts (all_vs, []);
1457 fun compile_prems out_ts' vs names [] =
1459 val (out_ts'', (names', eqs')) =
1460 fold_map (collect_non_invertible_subterms thy) out_ts' (names, []);
1461 val (out_ts''', (names'', constr_vs)) = fold_map distinct_v
1462 out_ts'' (names', map (rpair []) vs);
1464 compile_match constr_vs (eqs @ eqs') out_ts'''
1465 (mk_single compfuns (HOLogic.mk_tuple out_ts))
1467 | compile_prems out_ts vs names ((p, deriv) :: ps) =
1469 val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
1470 val (out_ts', (names', eqs)) =
1471 fold_map (collect_non_invertible_subterms thy) out_ts (names, [])
1472 val (out_ts'', (names'', constr_vs')) = fold_map distinct_v
1473 out_ts' ((names', map (rpair []) vs))
1474 val mode = head_mode_of deriv
1475 val additional_arguments' =
1476 Comp_Mod.transform_additional_arguments compilation_modifiers p additional_arguments
1477 val (compiled_clause, rest) = case p of
1481 compile_expr compilation_modifiers compfuns thy
1482 (t, deriv) additional_arguments'
1483 val (_, out_ts''') = split_mode mode (snd (strip_comb t))
1484 val rest = compile_prems out_ts''' vs' names'' ps
1490 val u = mk_not compfuns
1491 (compile_expr compilation_modifiers compfuns thy
1492 (t, deriv) additional_arguments')
1493 val (_, out_ts''') = split_mode mode (snd (strip_comb t))
1494 val rest = compile_prems out_ts''' vs' names'' ps
1500 val t = compile_arg compilation_modifiers compfuns additional_arguments
1502 val rest = compile_prems [] vs' names'' ps;
1504 (mk_if compfuns t, rest)
1506 | Generator (v, T) =>
1509 val rest = compile_prems [Free (v, T)] vs' names'' ps;
1514 compile_match constr_vs' eqs out_ts''
1515 (mk_bind compfuns (compiled_clause, rest))
1517 val prem_t = compile_prems in_ts' param_vs all_vs' moded_ps;
1519 mk_bind compfuns (mk_single compfuns inp, prem_t)
1522 fun compile_pred compilation_modifiers thy all_vs param_vs s T mode moded_cls =
1524 (* TODO: val additional_arguments = Comp_Mod.additional_arguments compilation_modifiers
1527 val compfuns = Comp_Mod.compfuns compilation_modifiers
1528 fun is_param_type (T as Type ("fun",[_ , T'])) =
1529 is_some (try (dest_predT compfuns) T) orelse is_param_type T'
1530 | is_param_type T = is_some (try (dest_predT compfuns) T)
1531 val additional_arguments = []
1532 val (inpTs, outTs) = split_map_modeT (fn m => fn T => (SOME (funT_of compfuns m T), NONE)) mode
1534 val predT = mk_predT compfuns (HOLogic.mk_tupleT outTs)
1535 val funT = Comp_Mod.funT_of compilation_modifiers mode T
1537 val (in_ts, _) = fold_map (fold_map_aterms_prodT (curry HOLogic.mk_prod)
1538 (fn T => fn (param_vs, names) =>
1539 if is_param_type T then
1540 (Free (hd param_vs, T), (tl param_vs, names))
1543 val new = Name.variant names "x"
1544 in (Free (new, T), (param_vs, new :: names)) end)) inpTs
1545 (param_vs, (all_vs @ param_vs))
1546 val in_ts' = map_filter (map_filter_prod
1547 (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
1549 map (compile_clause compilation_modifiers compfuns
1550 thy all_vs param_vs additional_arguments mode (HOLogic.mk_tuple in_ts')) moded_cls;
1551 val compilation = Comp_Mod.wrap_compilation compilation_modifiers compfuns
1552 s T mode additional_arguments
1554 mk_bot compfuns (HOLogic.mk_tupleT outTs)
1555 else foldr1 (mk_sup compfuns) cl_ts)
1557 Const (function_name_of (Comp_Mod.compilation compilation_modifiers) thy s mode, funT)
1560 (HOLogic.mk_eq (list_comb (fun_const, in_ts @ additional_arguments), compilation))
1563 (* special setup for simpset *)
1564 val HOL_basic_ss' = HOL_basic_ss addsimps (@{thms HOL.simp_thms} @ [@{thm Pair_eq}]) (* HOL_basic_ss plus simp_thms and Pair_eq *)
1565 setSolver (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac)) (* NOTE(review): setSolver is applied twice; confirm whether the next line replaces this solver *)
1566 setSolver (mk_solver "True_solver" (fn _ => rtac @{thm TrueI}))
1568 (* Definition of executable functions and their intro and elim rules *)
1570 fun print_arities arities = tracing ("Arities:\n" ^ (* one line per constant: argument arities then result arity, NONE shown as X *)
1571 cat_lines (map (fn (s, (ks, k)) => s ^ ": " ^
1572 space_implode " -> " (map
1573 (fn NONE => "X" | SOME k' => string_of_int k')
1574 (ks @ [SOME k]))) arities));
1576 fun split_lambda (x as Free _) t = lambda x t (* abstracts t over a pattern of Frees, Pairs (via split) and Unity; raises TERM for other patterns *)
1577 | split_lambda (Const ("Pair", _) $ t1 $ t2) t =
1578 HOLogic.mk_split (split_lambda t1 (split_lambda t2 t))
1579 | split_lambda (Const ("Product_Type.Unity", _)) t = Abs ("x", HOLogic.unitT, t)
1580 | split_lambda t _ = raise (TERM ("split_lambda", [t]))
1582 fun strip_split_abs (Const ("split", _) $ t) = strip_split_abs t (* removes leading split applications and abstractions; the body may contain loose bound variables *)
1583 | strip_split_abs (Abs (_, _, t)) = strip_split_abs t
1584 | strip_split_abs t = t
1586 fun mk_args is_eval (Pair (m1, m2), Type ("*", [T1, T2])) names =
1588 val (t1, names') = mk_args is_eval (m1, T1) names
1589 val (t2, names'') = mk_args is_eval (m2, T2) names'
1591 (HOLogic.mk_prod (t1, t2), names'')
1593 | mk_args is_eval ((m as Fun _), T) names =
1595 val funT = funT_of PredicateCompFuns.compfuns m T
1596 val x = Name.variant names "x"
1597 val (args, _) = fold_map (mk_args is_eval) (strip_fun_mode m ~~ binder_types T) (x :: names)
1598 val (inargs, outargs) = split_map_mode (fn _ => fn t => (SOME t, NONE)) m args
1599 val t = fold_rev split_lambda args (PredicateCompFuns.mk_Eval
1600 (list_comb (Free (x, funT), inargs), HOLogic.mk_tuple outargs))
1602 (if is_eval then t else Free (x, funT), x :: names)
1604 | mk_args is_eval (_, T) names =
1606 val x = Name.variant names "x"
1608 (Free (x, T), x :: names)
1611 fun create_intro_elim_rule mode defthm mode_id funT pred thy =
1613 val funtrm = Const (mode_id, funT)
1614 val Ts = binder_types (fastype_of pred)
1615 val (args, argnames) = fold_map (mk_args true) (strip_fun_mode mode ~~ Ts) []
1616 fun strip_eval _ t =
1618 val t' = strip_split_abs t
1619 val (r, _) = PredicateCompFuns.dest_Eval t'
1620 in (SOME (fst (strip_comb r)), NONE) end
1621 val (inargs, outargs) = split_map_mode strip_eval mode args
1622 val eval_hoargs = ho_args_of mode args
1623 val hoargTs = ho_argsT_of mode Ts
1625 Name.variant_list argnames ((map (fn i => "x" ^ string_of_int i)) (1 upto (length hoargTs)))
1626 val hoargs' = map2 (curry Free) hoarg_names' hoargTs
1627 val args' = replace_ho_args mode hoargs' args
1628 val predpropI = HOLogic.mk_Trueprop (list_comb (pred, args'))
1629 val predpropE = HOLogic.mk_Trueprop (list_comb (pred, args))
1630 val param_eqs = map2 (HOLogic.mk_Trueprop oo (curry HOLogic.mk_eq)) eval_hoargs hoargs'
1631 val funpropE = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, inargs),
1632 if null outargs then Free("y", HOLogic.unitT) else HOLogic.mk_tuple outargs))
1633 val funpropI = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, inargs),
1634 HOLogic.mk_tuple outargs))
1635 val introtrm = Logic.list_implies (predpropI :: param_eqs, funpropI)
1636 val simprules = [defthm, @{thm eval_pred},
1637 @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}, @{thm pair_collapse}]
1638 val unfolddef_tac = Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simprules) 1
1639 val introthm = Goal.prove (ProofContext.init thy)
1640 (argnames @ hoarg_names' @ ["y"]) [] introtrm (fn _ => unfolddef_tac)
1641 val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT));
1642 val elimtrm = Logic.list_implies ([funpropE, Logic.mk_implies (predpropE, P)], P)
1643 val elimthm = Goal.prove (ProofContext.init thy)
1644 (argnames @ ["y", "P"]) [] elimtrm (fn _ => unfolddef_tac)
1649 fun create_constname_of_mode options thy prefix name T mode =
1651 val system_proposal = prefix ^ (Long_Name.base_name name)
1652 ^ "_" ^ ascii_string_of_mode mode
1653 val name = the_default system_proposal (proposed_names options name mode)
1655 Sign.full_bname thy name
1658 fun create_definitions options preds (name, modes) thy =
1660 val compfuns = PredicateCompFuns.compfuns
1661 val T = AList.lookup (op =) preds name |> the
1662 fun create_definition mode thy =
1664 val mode_cname = create_constname_of_mode options thy "" name T mode
1665 val mode_cbasename = Long_Name.base_name mode_cname
1666 val funT = funT_of compfuns mode T
1667 val (args, _) = fold_map (mk_args true) ((strip_fun_mode mode) ~~ (binder_types T)) []
1668 fun strip_eval m t =
1670 val t' = strip_split_abs t
1671 val (r, _) = PredicateCompFuns.dest_Eval t'
1672 in (SOME (fst (strip_comb r)), NONE) end
1673 val (inargs, outargs) = split_map_mode strip_eval mode args
1674 val predterm = fold_rev split_lambda inargs
1675 (PredicateCompFuns.mk_Enum (split_lambda (HOLogic.mk_tuple outargs)
1676 (list_comb (Const (name, T), args))))
1677 val lhs = Const (mode_cname, funT)
1678 val def = Logic.mk_equals (lhs, predterm)
1679 val ([definition], thy') = thy |>
1680 Sign.add_consts_i [(Binding.name mode_cbasename, funT, NoSyn)] |>
1681 PureThy.add_defs false [((Binding.name (mode_cbasename ^ "_def"), def), [])]
1683 create_intro_elim_rule mode definition mode_cname funT (Const (name, T)) thy'
1685 |> set_function_name Pred name mode mode_cname
1686 |> add_predfun_data name mode (definition, intro, elim)
1687 |> PureThy.store_thm (Binding.name (mode_cbasename ^ "I"), intro) |> snd
1688 |> PureThy.store_thm (Binding.name (mode_cbasename ^ "E"), elim) |> snd
1689 |> Theory.checkpoint
1692 thy |> defined_function_of Pred name |> fold create_definition modes
1695 fun define_functions comp_modifiers compfuns options preds (name, modes) thy =
1697 val T = AList.lookup (op =) preds name |> the
1698 fun create_definition mode thy =
1700 val function_name_prefix = Comp_Mod.function_name_prefix comp_modifiers
1701 val mode_cname = create_constname_of_mode options thy function_name_prefix name T mode
1702 val funT = Comp_Mod.funT_of comp_modifiers mode T
1704 thy |> Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_cname), funT, NoSyn)]
1705 |> set_function_name (Comp_Mod.compilation comp_modifiers) name mode mode_cname
1709 |> defined_function_of (Comp_Mod.compilation comp_modifiers) name
1710 |> fold create_definition modes
1713 (* Proving equivalence of term *)
1715 fun is_Type (Type _) = true
1718 (* returns true if t is an application of a datatype constructor *)
1719 (* which consequently would be split *)
1721 fun is_constructor thy t =
1722 if (is_Type (fastype_of t)) then
1723 (case Datatype.get_info thy ((fst o dest_Type o fastype_of) t) of
1726 val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
1727 val (c, _) = strip_comb t
1729 Const (name, _) => name mem_string constr_consts
1733 (* MAJOR FIXME: prove_param should be simpler
1734 - perhaps use a different form of intro rule for parameters? *)
1736 fun prove_param options thy t deriv =
1738 val (f, args) = strip_comb (Envir.eta_contract t)
1739 val mode = head_mode_of deriv
1740 val param_derivations = param_derivations_of deriv
1741 val ho_args = ho_args_of mode args
1742 val f_tac = case f of
1743 Const (name, T) => simp_tac (HOL_basic_ss addsimps
1744 ([@{thm eval_pred}, (predfun_definition_of thy name mode),
1745 @{thm "split_eta"}, @{thm "split_beta"}, @{thm "fst_conv"},
1746 @{thm "snd_conv"}, @{thm pair_collapse}, @{thm "Product_Type.split_conv"}])) 1
1747 | Free _ => TRY (rtac @{thm refl} 1)
1748 | Abs _ => error "prove_param: No valid parameter term"
1750 REPEAT_DETERM (rtac @{thm ext} 1)
1751 THEN print_tac' options "prove_param"
1753 THEN print_tac' options "after simplification in prove_args"
1754 THEN (REPEAT_DETERM (atac 1))
1755 THEN (EVERY (map2 (prove_param options thy) ho_args param_derivations))
1758 fun prove_expr options thy (premposition : int) (t, deriv) =
1759 case strip_comb t of
1760 (Const (name, T), args) =>
1762 val mode = head_mode_of deriv
1763 val introrule = predfun_intro_of thy name mode
1764 val param_derivations = param_derivations_of deriv
1765 val ho_args = ho_args_of mode args
1767 print_tac' options "before intro rule:"
1768 (* for the right assumption in first position *)
1769 THEN rotate_tac premposition 1
1770 THEN debug_tac (Display.string_of_thm (ProofContext.init thy) introrule)
1771 THEN rtac introrule 1
1772 THEN print_tac' options "after intro rule"
1773 (* work with parameter arguments *)
1775 THEN print_tac' options "parameter goal"
1776 THEN (EVERY (map2 (prove_param options thy) ho_args param_derivations))
1777 THEN (REPEAT_DETERM (atac 1))
1781 (HOL_basic_ss' addsimps [@{thm "split_eta"}, @{thm "split_beta"}, @{thm "fst_conv"},
1782 @{thm "snd_conv"}, @{thm pair_collapse}]) 1
1784 THEN print_tac' options "after prove parameter call"
1787 fun SOLVED tac st = FILTER (fn st' => nprems_of st' = nprems_of st - 1) tac st; (* keeps only results of tac that discharge exactly one subgoal *)
1789 fun SOLVEDALL tac st = FILTER (fn st' => nprems_of st' = 0) tac st (* keeps only results of tac with no subgoals left *)
1791 fun check_format thy st =
1793 val concl' = Logic.strip_assums_concl (hd (prems_of st))
1794 val concl = HOLogic.dest_Trueprop concl'
1795 val expr = fst (strip_comb (fst (PredicateCompFuns.dest_Eval concl)))
1796 fun valid_expr (Const (@{const_name Predicate.bind}, _)) = true
1797 | valid_expr (Const (@{const_name Predicate.single}, _)) = true
1798 | valid_expr _ = false
1800 if valid_expr expr then
1801 ((*tracing "expression is valid";*) Seq.single st)
1803 ((*tracing "expression is not valid";*) Seq.empty) (*error "check_format: wrong format"*)
1806 fun prove_match options thy (out_ts : term list) =
1808 fun get_case_rewrite t =
1809 if (is_constructor thy t) then let
1810 val case_rewrites = (#case_rewrites (Datatype.the_info thy
1811 ((fst o dest_Type o fastype_of) t)))
1812 in case_rewrites @ maps get_case_rewrite (snd (strip_comb t)) end
1814 val simprules = @{thm "unit.cases"} :: @{thm "prod.cases"} :: maps get_case_rewrite out_ts
1815 (* replace TRY by determining if it necessary - are there equations when calling compile match? *)
1817 (* make this simpset better! *)
1818 asm_full_simp_tac (HOL_basic_ss' addsimps simprules) 1
1819 THEN print_tac' options "after prove_match:"
1820 THEN (DETERM (TRY (EqSubst.eqsubst_tac (ProofContext.init thy) [0] [@{thm HOL.if_P}] 1
1821 THEN (REPEAT_DETERM (rtac @{thm conjI} 1 THEN (SOLVED (asm_simp_tac HOL_basic_ss' 1))))
1822 THEN print_tac' options "if condition to be solved:"
1823 THEN (SOLVED (asm_simp_tac HOL_basic_ss' 1 THEN print_tac' options "after if simp; in SOLVED:"))
1824 THEN check_format thy
1825 THEN print_tac' options "after if simplification - a TRY block")))
1826 THEN print_tac' options "after if simplification"
1829 (* corresponds to compile_fun -- maybe call that also compile_sidecond? *)
1831 fun prove_sidecond thy modes t =
1833 fun preds_of t nameTs = case strip_comb t of
1834 (f as Const (name, T), args) =>
1835 if AList.defined (op =) modes name then (name, T) :: nameTs
1836 else fold preds_of args nameTs
1838 val preds = preds_of t []
1840 (fn (pred, T) => predfun_definition_of thy pred
1844 (* remove not_False_eq_True when simpset in prove_match is better *)
1845 simp_tac (HOL_basic_ss addsimps
1846 (@{thms HOL.simp_thms} @ (@{thm not_False_eq_True} :: @{thm eval_pred} :: defs))) 1
1847 (* need better control here! *)
(* prove_clause options thy nargs modes mode (_, clauses) (ts, moded_ps):
   tactic for one moded clause, used by prove_one_direction. in_ts and
   clause_out_ts are the input and output arguments of the clause head
   under mode. prove_prems walks the moded premises in order. Before each
   step it matches the current output terms out_ts with prove_match.
   It then uses introduction rules (singleI, not_predI, if_predI).
   Some lines of this definition are not visible in this view. *)
1850 fun prove_clause options thy nargs modes mode (_, clauses) (ts, moded_ps) =
1852 val (in_ts, clause_out_ts) = split_mode mode ts;
(* Base case: all premises are consumed. Match the outputs, simplify, and
   close the goal with singleI, or with singleI_unit when the clause head
   has no output arguments. *)
1853 fun prove_prems out_ts [] =
1854 (prove_match options thy out_ts)
1855 THEN print_tac' options "before simplifying assumptions"
1856 THEN asm_full_simp_tac HOL_basic_ss' 1
1857 THEN print_tac' options "before single intro rule"
1858 THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1)
(* Step case: premise p with derivation deriv. premposition is the index of
   p in the clause's premise list, shifted by nargs; it is passed on to
   prove_expr. *)
1859 | prove_prems out_ts ((p, deriv) :: ps) =
1861 val premposition = (find_index (equal p) clauses) + nargs
1862 val mode = head_mode_of deriv
1865 THEN (case p of Prem t =>
1867 val (_, us) = strip_comb t
1868 val (_, out_ts''') = split_mode mode us
1869 val rec_tac = prove_prems out_ts''' ps
1871 print_tac' options "before clause:"
1872 (*THEN asm_simp_tac HOL_basic_ss 1*)
1873 THEN print_tac' options "before prove_expr:"
1874 THEN prove_expr options thy premposition (t, deriv)
1875 THEN print_tac' options "after prove_expr:"
(* Negated premise. NOTE: t is rebound here to the head of the premise, so
   name below is SOME c exactly when that head is a constant c. args are
   the premise's arguments. *)
1880 val (t, args) = strip_comb t
1881 val (_, out_ts''') = split_mode mode args
1882 val rec_tac = prove_prems out_ts''' ps
1883 val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
1884 val param_derivations = param_derivations_of deriv
1885 val params = ho_args_of mode args
1887 print_tac' options "before prove_neg_expr:"
1888 THEN full_simp_tac (HOL_basic_ss addsimps
1889 [@{thm split_eta}, @{thm split_beta}, @{thm fst_conv},
1890 @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1
1891 THEN (if (is_some name) then
1892 print_tac' options ("before unfolding definition " ^
1893 (Display.string_of_thm_global thy
1894 (predfun_definition_of thy (the name) mode)))
1896 THEN simp_tac (HOL_basic_ss addsimps
1897 [predfun_definition_of thy (the name) mode]) 1
1898 THEN rtac @{thm not_predI} 1
1899 THEN print_tac' options "after applying rule not_predI"
1900 THEN full_simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True},
1901 @{thm split_eta}, @{thm split_beta}, @{thm fst_conv},
1902 @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1
1903 THEN (REPEAT_DETERM (atac 1))
1904 THEN (EVERY (map2 (prove_param options thy) params param_derivations))
1905 THEN (REPEAT_DETERM (atac 1))
1907 rtac @{thm not_predI'} 1)
1908 THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
(* Side condition: apply if_predI, prove the condition with prove_sidecond,
   and continue with no new output terms. *)
1912 rtac @{thm if_predI} 1
1913 THEN print_tac' options "before sidecond:"
1914 THEN prove_sidecond thy modes t
1915 THEN print_tac' options "after sidecond:"
1916 THEN prove_prems [] ps)
1917 in (prove_match options thy out_ts)
1920 val prems_tac = prove_prems in_ts moded_ps
(* Clause entry: apply bindI and singleI. The remainder of the body is not
   visible in this view. *)
1922 print_tac' options "Proving clause..."
1923 THEN rtac @{thm bindI} 1
1924 THEN rtac @{thm singleI} 1
(* select_sup n i: list of rtac tactics that select the i-th (1-based) of n
   sup-joined alternatives. It is i - 1 applications of supI2 followed by
   supI1. For the last alternative (i = n) the list is n - 1 applications
   of supI2 and no supI1. *)
1928 fun select_sup 1 1 = []
1929 | select_sup _ 1 = [rtac @{thm supI1}]
1930 | select_sup n i = (rtac @{thm supI2})::(select_sup (n - 1) (i - 1));
(* prove_one_direction: first half of the proof in prove_pred (applied after
   pred_iffI). The steps are:
   - split paired variables;
   - eliminate with predfun_elim_of and then with the predicate's stored
     elimination rule (the_elim_of);
   - select each sup-alternative via select_sup;
   - prove each clause with prove_clause.
   nargs is the number of argument types of pred's type.
   Some lines of this definition are not visible in this view. *)
1932 fun prove_one_direction options thy clauses preds modes pred mode moded_clauses =
1934 val T = the (AList.lookup (op =) preds pred)
1935 val nargs = length (binder_types T)
1936 val pred_case_rule = the_elim_of thy pred
1938 REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
1939 THEN print_tac' options "before applying elim rule"
1940 THEN etac (predfun_elim_of thy pred mode) 1
1941 THEN etac pred_case_rule 1
1943 (fn i => EVERY' (select_sup (length moded_clauses) i) i)
1944 (1 upto (length moded_clauses))))
1945 THEN (EVERY (map2 (prove_clause options thy nargs modes mode) clauses moded_clauses))
1946 THEN print_tac' options "proved one direction"
1949 (** Proof in the other direction **)
(* prove_match2 thy out_ts: case-splits the hypotheses along the
   constructor structure of out_ts.
   - For a Free term, split_term_tac does nothing.
   - For a constructor term, it tries the datatype's split_asm rule
     (prod.split_asm for the product type). It removes the branches of the
     other constructors with botE, asserts that at most two subgoals
     remain, and recurses into the constructor's arguments.
   Afterwards, split_if_asm is tried once on the hypotheses.
   Some lines of this definition (e.g. the non-constructor case) are not
   visible in this view. *)
1951 fun prove_match2 thy out_ts = let
1952 fun split_term_tac (Free _) = all_tac
1953 | split_term_tac t =
1954 if (is_constructor thy t) then let
1955 val info = Datatype.the_info thy ((fst o dest_Type o fastype_of) t)
1956 val num_of_constrs = length (#case_rewrites info)
1957 (* special treatment of pairs -- because of fishing *)
1958 val split_rules = case (fst o dest_Type o fastype_of) t of
1959 "*" => [@{thm prod.split_asm}]
1960 | _ => PureThy.get_thms thy (((fst o dest_Type o fastype_of) t) ^ ".split_asm")
1961 val (_, ts) = strip_comb t
1963 (print_tac ("Term " ^ (Syntax.string_of_term_global thy t) ^
1964 "splitting with rules \n" ^
1965 commas (map (Display.string_of_thm_global thy) split_rules)))
1966 THEN TRY ((Splitter.split_asm_tac split_rules 1)
1967 THEN (print_tac "after splitting with split_asm rules")
1968 (* THEN (Simplifier.asm_full_simp_tac HOL_basic_ss 1)
1969 THEN (DETERM (TRY (etac @{thm Pair_inject} 1)))*)
1970 THEN (REPEAT_DETERM_N (num_of_constrs - 1)
1971 (etac @{thm botE} 1 ORELSE etac @{thm botE} 2)))
1972 THEN (assert_tac (Max_number_of_subgoals 2))
1973 THEN (EVERY (map split_term_tac ts))
1977 split_term_tac (HOLogic.mk_tuple out_ts)
1978 THEN (DETERM (TRY ((Splitter.split_asm_tac [@{thm "split_if_asm"}] 1)
1979 THEN (etac @{thm botE} 2))))
1982 (* VERY LARGE SIMILARITY to function prove_param
1983 -- join both functions
1985 (* TODO: remove function *)
(* prove_param2 thy t deriv: tactic for a higher-order parameter t.
   If the head of the eta-contracted t is a constant, its predfun
   definition is unfolded with full_simp_tac, together with eval_pred and
   split_conv. The tactic then recurses into t's own higher-order arguments
   with their derivations.
   The catch-all case raises an error as soon as f_tac is bound, i.e. when
   prove_param2 is applied, not when the tactic runs.
   Some lines (possibly further cases) are not visible in this view. *)
1987 fun prove_param2 thy t deriv =
1989 val (f, args) = strip_comb (Envir.eta_contract t)
1990 val mode = head_mode_of deriv
1991 val param_derivations = param_derivations_of deriv
1992 val ho_args = ho_args_of mode args
1993 val f_tac = case f of
1994 Const (name, T) => full_simp_tac (HOL_basic_ss addsimps
1995 (@{thm eval_pred}::(predfun_definition_of thy name mode)
1996 :: @{thm "Product_Type.split_conv"}::[])) 1
1998 | _ => error "prove_param2: illegal parameter term"
2000 print_tac "before simplification in prove_args:"
2002 THEN print_tac "after simplification in prove_args"
2003 THEN EVERY (map2 (prove_param2 thy) ho_args param_derivations)
(* prove_expr2 thy (t, deriv): tactic for a positive premise t in
   prove_clause2.
   - If t applies a constant, it splits paired variables, eliminates with
     predfun_elim_of for the derivation's head mode, and handles the
     higher-order arguments with prove_param2.
   - Otherwise it only applies bindE.
   Some lines of this definition are not visible in this view. *)
2006 fun prove_expr2 thy (t, deriv) =
2007 (case strip_comb t of
2008 (Const (name, T), args) =>
2010 val mode = head_mode_of deriv
2011 val param_derivations = param_derivations_of deriv
2012 val ho_args = ho_args_of mode args
2015 THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})))
2016 THEN print_tac "prove_expr2-before"
2017 THEN (debug_tac (Syntax.string_of_term_global thy
2018 (prop_of (predfun_elim_of thy name mode))))
2019 THEN (etac (predfun_elim_of thy name mode) 1)
2020 THEN print_tac "prove_expr2"
2021 THEN (EVERY (map2 (prove_param2 thy) ho_args param_derivations))
2022 THEN print_tac "finished prove_expr2"
2024 | _ => etac @{thm bindE} 1)
2026 (* FIXME: what is this for? *)
2027 (* replace defined by has_mode thy pred *)
2028 (* TODO: rewrite function *)
(* prove_sidecond2 thy modes t: counterpart of prove_sidecond for
   prove_clause2. It collects the moded predicates occurring in t in the
   same way and applies full_simp_tac on subgoal 1 with HOL_basic_ss',
   eval_pred and defs.
   NOTE(review): the binding of defs is not visible in this view --
   presumably the predfun definitions of preds; confirm. *)
2029 fun prove_sidecond2 thy modes t = let
2030 fun preds_of t nameTs = case strip_comb t of
2031 (f as Const (name, T), args) =>
2032 if AList.defined (op =) modes name then (name, T) :: nameTs
2033 else fold preds_of args nameTs
2035 val preds = preds_of t []
2037 (fn (pred, T) => predfun_definition_of thy pred
2041 (* only simplify the one assumption *)
2042 full_simp_tac (HOL_basic_ss' addsimps @{thm eval_pred} :: defs) 1
2043 (* need better control here! *)
2044 THEN print_tac "after sidecond2 simplification"
(* prove_clause2 thy modes pred mode (ts, ps) i: tactic for the i-th
   (1-based) sup-alternative in prove_other_direction. It eliminates the
   compiled-function structure (bindE, singleE') and finally applies the
   i-th introduction rule of pred. prove_prems2 mirrors prove_prems in
   prove_clause, but uses elimination rules (bindE, not_predE, not_predE',
   if_predE) instead of introduction rules.
   Some lines of this definition are not visible in this view. *)
2047 fun prove_clause2 thy modes pred mode (ts, ps) i =
2049 val pred_intro_rule = nth (intros_of thy pred) (i - 1)
2050 val (in_ts, clause_out_ts) = split_mode mode ts;
(* Base case: all premises are handled. Eliminate singleE and Pair_inject,
   then close the goal with pred_intro_rule, discharging its premises by
   assumption or by simplification. *)
2051 fun prove_prems2 out_ts [] =
2052 print_tac "before prove_match2 - last call:"
2053 THEN prove_match2 thy out_ts
2054 THEN print_tac "after prove_match2 - last call:"
2055 THEN (etac @{thm singleE} 1)
2056 THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1))
2057 THEN (asm_full_simp_tac HOL_basic_ss' 1)
2058 THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1))
2059 THEN (asm_full_simp_tac HOL_basic_ss' 1)
2060 THEN SOLVED (print_tac "state before applying intro rule:"
2061 THEN (rtac pred_intro_rule 1)
2062 (* How to handle equality correctly? *)
2063 THEN (print_tac "state before assumption matching")
2064 THEN (REPEAT (atac 1 ORELSE
2065 (CHANGED (asm_full_simp_tac (HOL_basic_ss' addsimps
2066 [@{thm split_eta}, @{thm "split_beta"}, @{thm "fst_conv"},
2067 @{thm "snd_conv"}, @{thm pair_collapse}]) 1)
2068 THEN print_tac "state after simp_tac:"))))
2069 | prove_prems2 out_ts ((p, deriv) :: ps) =
2071 val mode = head_mode_of deriv
2072 val rest_tac = (case p of
2075 val (_, us) = strip_comb t
2076 val (_, out_ts''') = split_mode mode us
2077 val rec_tac = prove_prems2 out_ts''' ps
2079 (prove_expr2 thy (t, deriv)) THEN rec_tac
(* Negated premise: unlike in prove_clause, t is not rebound here, so name
   is the head constant of t, if any. *)
2083 val (_, args) = strip_comb t
2084 val (_, out_ts''') = split_mode mode args
2085 val rec_tac = prove_prems2 out_ts''' ps
2086 val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
2087 val param_derivations = param_derivations_of deriv
2088 val ho_args = ho_args_of mode args
2090 print_tac "before neg prem 2"
2091 THEN etac @{thm bindE} 1
2092 THEN (if is_some name then
2093 full_simp_tac (HOL_basic_ss addsimps
2094 [predfun_definition_of thy (the name) mode]) 1
2095 THEN etac @{thm not_predE} 1
2096 THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
2097 THEN (EVERY (map2 (prove_param2 thy) ho_args param_derivations))
2099 etac @{thm not_predE'} 1)
(* Side condition: eliminate with if_predE, prove the condition with
   prove_sidecond2, and continue with no new output terms. *)
2104 THEN etac @{thm if_predE} 1
2105 THEN prove_sidecond2 thy modes t
2106 THEN prove_prems2 [] ps)
2107 in print_tac "before prove_match2:"
2108 THEN prove_match2 thy out_ts
2109 THEN print_tac "after prove_match2:"
2112 val prems_tac = prove_prems2 in_ts ps
(* Clause entry: eliminate bindE and singleE', and try Pair_inject. The
   remainder of the body is not visible in this view. *)
2114 print_tac "starting prove_clause2"
2115 THEN etac @{thm bindE} 1
2116 THEN (etac @{thm singleE'} 1)
2117 THEN (TRY (etac @{thm Pair_inject} 1))
2118 THEN print_tac "after singleE':"
(* prove_other_direction: second half of the proof in prove_pred. The steps
   are:
   - optionally apply unit.induct;
   - split paired variables;
   - apply predfun_intro_of, then repeatedly close subgoal 2 with refl;
   - prove the moded clauses one by one with prove_clause2, where every
     clause but the last first eliminates a supE.
   The branch for an empty clause list is not visible in this view.
   options is not used in the visible lines. *)
2122 fun prove_other_direction options thy modes pred mode moded_clauses =
2124 fun prove_clause clause i =
2125 (if i < length moded_clauses then etac @{thm supE} 1 else all_tac)
2126 THEN (prove_clause2 thy modes pred mode clause i)
2128 (DETERM (TRY (rtac @{thm unit.induct} 1)))
2129 THEN (REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all})))
2130 THEN (rtac (predfun_intro_of thy pred mode) 1)
2131 THEN (REPEAT_DETERM (rtac @{thm refl} 2))
2132 THEN (if null moded_clauses then
2134 else EVERY (map2 prove_clause moded_clauses (1 upto (length moded_clauses))))
2137 (** proof procedure **)
(* prove_pred options thy clauses preds modes pred mode
   (moded_clauses, compiled_term): proves compiled_term with Goal.prove,
   fixing all of its free variables.
   - Unless skip_proof is set, the proof is pred_iffI followed by
     prove_one_direction and prove_other_direction.
   - Otherwise Skip_Proof.cheat_tac is used.
   Note that clauses is rebound to the clause list of pred ([] if absent).
   Some lines of this definition are not visible in this view. *)
2139 fun prove_pred options thy clauses preds modes pred mode (moded_clauses, compiled_term) =
2141 val ctxt = ProofContext.init thy
2142 val clauses = case AList.lookup (op =) clauses pred of SOME rs => rs | NONE => []
2144 Goal.prove ctxt (Term.add_free_names compiled_term []) [] compiled_term
2145 (if not (skip_proof options) then
2147 rtac @{thm pred_iffI} 1
2148 THEN print_tac' options "after pred_iffI"
2149 THEN prove_one_direction options thy clauses preds modes pred mode moded_clauses
2150 THEN print_tac' options "proved one direction"
2151 THEN prove_other_direction options thy modes pred mode moded_clauses
2152 THEN print_tac' options "proved other direction")
2153 else (fn _ => Skip_Proof.cheat_tac thy))
2156 (* composition of mode inference, definition, compilation and proof *)
2158 (** auxiliary combinators for table of preds and modes **)
(* map_preds_modes f table: replaces every value in a table of shape
   (pred, (mode, value) list) list by f pred mode value, keeping preds and
   modes unchanged. *)
2160 fun map_preds_modes f preds_modes_table =
2161 map (fn (pred, modes) =>
2162 (pred, map (fn (mode, value) => (mode, f pred mode value)) modes)) preds_modes_table
(* join_preds_modes table1 table2: pairs every value of table1 with the value
   stored in table2 under the same pred and mode. Fails (via the) if table2
   lacks that pred or mode. *)
2164 fun join_preds_modes table1 table2 =
2165 map_preds_modes (fn pred => fn mode => fn value =>
2166 (value, the (AList.lookup (op =) (the (AList.lookup (op =) table2 pred)) mode))) table1
(* maps_modes table: drops the modes, keeping for each pred the list of its
   values in order. *)
2168 fun maps_modes preds_modes_table =
2169 map (fn (pred, modes) =>
2170 (pred, map (fn (mode, value) => value) modes)) preds_modes_table
(* compile_preds: compiles the moded clauses of every predicate and mode
   with compile_pred, passing the predicate's type as looked up in preds. *)
2172 fun compile_preds comp_modifiers thy all_vs param_vs preds moded_clauses =
2173 map_preds_modes (fn pred => compile_pred comp_modifiers thy all_vs param_vs pred
2174 (the (AList.lookup (op =) preds pred))) moded_clauses
(* prove: for every predicate and mode, proves the compiled term from its
   moded clauses with prove_pred. The two tables are joined by pred and
   mode. *)
2176 fun prove options thy clauses preds modes moded_clauses compiled_terms =
2177 map_preds_modes (prove_pred options thy clauses preds modes)
2178 (join_preds_modes moded_clauses compiled_terms)
(* prove_by_skip: turns each compiled term into a theorem without a proof
   (Skip_Proof.make_thm), exported with Drule.export_without_context.
   All arguments other than thy and compiled_terms are ignored.
   NOTE(review): the line applying this function to compiled_terms is not
   visible here -- presumably map_preds_modes; confirm. *)
2180 fun prove_by_skip options thy _ _ _ _ compiled_terms =
2182 (fn pred => fn mode => fn t => Drule.export_without_context (Skip_Proof.make_thm thy t))
2185 (* preparation of introduction rules into special datastructures *)
(* dest_prem thy params t: classifies a premise, given without Trueprop.
   - Head is a Free that occurs in params: Prem t, otherwise Sidecond t.
   - Not applied to an argument: the argument is classified first.
     A doubly negated premise raises an error. A negated side condition
     stays a Sidecond. The Prem case is not visible here -- presumably it
     yields Negprem.
   - Head is a constant: Prem t if it is registered with the predicate
     compiler, otherwise Sidecond t.
   The remaining fallback case is not visible in this view. *)
2187 fun dest_prem thy params t =
2188 (case strip_comb t of
2189 (v as Free _, ts) => if member (op =) params v then Prem t else Sidecond t
2190 | (c as Const (@{const_name Not}, _), [t]) => (case dest_prem thy params t of
2192 | Negprem _ => error ("Double negation not allowed in premise: " ^
2193 Syntax.string_of_term_global thy (c $ t))
2194 | Sidecond t => Sidecond (c $ t))
2195 | (c as Const (s, _), ts) =>
2196 if is_registered thy s then Prem t else Sidecond t
(* prepare_intrs options compilation thy prednames intros: converts the
   introduction rules of prednames into the compiler's clause format.
   - The rules and predicate constants are unified (unify_consts) and
     imported with fresh variables.
   - params are the higher-order parameters. They are taken from the
     conclusion of the first rule. In the other case (presumably no rules;
     the case head is not visible) they are fresh frees p1, p2, ... whose
     types come from the first mode of the first predicate's type.
   - Each rule becomes a clause (conclusion arguments, classified
     premises), grouped by predicate name.
   Returns (preds, all_vs, param_vs, extra_modes, clauses). extra_modes is
   presumably the all_modes_of list excluding prednames.
   NOTE(review): in add_clause the pattern binding raises Bind if a rule's
   conclusion is not headed by a constant.
   Some lines of this definition are not visible in this view. *)
2199 fun prepare_intrs options compilation thy prednames intros =
2201 val intrs = map prop_of intros
2202 val preds = map (fn c => Const (c, Sign.the_const_type thy c)) prednames
2203 val (preds, intrs) = unify_consts thy preds intrs
2204 val ([preds, intrs], _) = fold_burrow (Variable.import_terms false) [preds, intrs]
2205 (ProofContext.init thy)
2206 val preds = map dest_Const preds
2208 all_modes_of compilation thy |> filter_out (fn (name, _) => member (op =) prednames name)
2209 val all_vs = terms_vs intrs
2214 val T = snd (hd preds)
2216 ho_argsT_of (hd (all_modes_of_typ T)) (binder_types T)
2217 val param_names = Name.variant_list [] (map (fn i => "p" ^ string_of_int i)
2218 (1 upto length paramTs))
2220 map2 (curry Free) param_names paramTs
2222 | (intr :: _) => maps extract_params
2223 (snd (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl intr))))
2224 val param_vs = map (fst o dest_Free) params
2225 fun add_clause intr clauses =
2227 val (Const (name, T), ts) = strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl intr))
2228 val prems = map (dest_prem thy params o HOLogic.dest_Trueprop) (Logic.strip_imp_prems intr)
2230 AList.update op = (name, these (AList.lookup op = clauses name) @
2231 [(ts, prems)]) clauses
2233 val clauses = fold add_clause intrs []
2235 (preds, all_vs, param_vs, extra_modes, clauses)
2238 (* sanity check of introduction rules *)
2239 (* TODO: rethink check with new modes *)
(* check_format_of_intro_rule thy intro: sanity check that tuple-typed
   arguments are written as explicit tuples. check_arg raises an error when
   an argument's type is a tuple of at least two components but the
   argument does not split into that many components (the then-branch is
   not visible here). The only visible call site is commented out in
   add_equations_of.
   NOTE(review): check_prem ignores the premise term t and re-checks the
   conclusion's args, so premise arguments are never inspected. Presumably
   it should check the arguments of strip_comb t -- confirm before enabling
   this check.
   Some lines of this definition are not visible in this view. *)
2241 fun check_format_of_intro_rule thy intro =
2243 val concl = Logic.strip_imp_concl (prop_of intro)
2244 val (p, args) = strip_comb (HOLogic.dest_Trueprop concl)
2245 val params = fst (chop (nparams_of thy (fst (dest_Const p))) args)
2246 fun check_arg arg = case HOLogic.strip_tupleT (fastype_of arg) of
2247 (Ts as _ :: _ :: _) =>
2248 if length (HOLogic.strip_tuple arg) = length Ts then
2251 error ("Format of introduction rule is invalid: tuples must be expanded:"
2252 ^ (Syntax.string_of_term_global thy arg) ^ " in " ^
2253 (Display.string_of_thm_global thy intro))
2255 val prems = Logic.strip_imp_prems (prop_of intro)
2256 fun check_prem (Prem t) = forall check_arg args
2257 | check_prem (Negprem t) = forall check_arg args
2258 | check_prem _ = true
2260 forall check_arg args andalso
2261 forall (check_prem o dest_prem thy params o HOLogic.dest_Trueprop) prems
(* check_intros_elim_match thy prednames: for each predicate, rebuilds a
   cases rule from its introduction rules. It raises an error unless that
   rule is equivalent (Thm.equiv_thm) to the stored elimination rule.
   The only visible call site is commented out in add_equations_of.
   NOTE(review): the signature declares mk_casesrule : Proof.context -> term
   -> thm list -> term, and generic_code_pred passes a predicate constant.
   Here nparams is passed as the second argument instead -- confirm this
   still type-checks. *)
2265 fun check_intros_elim_match thy prednames =
2267 fun check predname =
2269 val intros = intros_of thy predname
2270 val elim = the_elim_of thy predname
2271 val nparams = nparams_of thy predname
2273 (Drule.export_without_context o Skip_Proof.make_thm thy)
2274 (mk_casesrule (ProofContext.init thy) nparams intros)
2276 if not (Thm.equiv_thm (elim, elim')) then
2277 error "Introduction and elimination rules do not match!"
2280 in forall check prednames end
2283 (* create code equation *)
(* add_code_equations thy preds result_thmss: handles each predicate whose
   all-input mode is among modes_of Pred. For such a predicate it proves
       pred x1 ... xn = Predicate.holds (predfun x1 ... xn)
   by simplification with the predfun definition, holds_eq and eval_pred,
   and appends that equation to the predicate's result theorems.
   The branch for predicates without the all-input mode is not visible in
   this view. *)
2285 fun add_code_equations thy preds result_thmss =
2287 fun add_code_equation (predname, T) (pred, result_thms) =
2289 val full_mode = fold_rev (curry Fun) (map (K Input) (binder_types T)) Bool
2291 if member (op =) (modes_of Pred thy predname) full_mode then
2293 val Ts = binder_types T
2294 val arg_names = Name.variant_list []
2295 (map (fn i => "x" ^ string_of_int i) (1 upto length Ts))
2296 val args = map2 (curry Free) arg_names Ts
2297 val predfun = Const (function_name_of Pred thy predname full_mode,
2298 Ts ---> PredicateCompFuns.mk_predT @{typ unit})
2299 val rhs = @{term Predicate.holds} $ (list_comb (predfun, args))
2300 val eq_term = HOLogic.mk_Trueprop
2301 (HOLogic.mk_eq (list_comb (Const (predname, T), args), rhs))
2302 val def = predfun_definition_of thy predname full_mode
2303 val tac = fn _ => Simplifier.simp_tac
2304 (HOL_basic_ss addsimps [def, @{thm holds_eq}, @{thm eval_pred}]) 1
2305 val eq = Goal.prove (ProofContext.init thy) arg_names [] eq_term tac
2307 (pred, result_thms @ [eq])
2313 map2 add_code_equation preds result_thmss
2316 (** main function of predicate compiler **)
(* Steps: the per-compilation phases that add_equations_of runs:
   - infer_modes: mode inference;
   - define_functions: definition of executable functions;
   - prove: proof of the compiled equations;
   - add_code_equations: post-processing of the result theorems;
   - comp_modifiers: the compilation modifiers.
   add_equations_of also reads a qname field, which is not visible in this
   view. *)
2318 datatype steps = Steps of
2320 define_functions : options -> (string * typ) list -> string * mode list -> theory -> theory,
2321 infer_modes : options -> (string * typ) list -> (string * mode list) list
2322 -> string list -> (string * (term list * indprem list) list) list
2323 -> theory -> ((moded_clause list pred_mode_table * string list) * theory),
2324 prove : options -> theory -> (string * (term list * indprem list) list) list
2325 -> (string * typ) list -> (string * mode list) list
2326 -> moded_clause list pred_mode_table -> term pred_mode_table -> thm pred_mode_table,
2327 add_code_equations : theory -> (string * typ) list
2328 -> (string * thm list) list -> (string * thm list) list,
2329 comp_modifiers : Comp_Mod.comp_modifiers,
(* add_equations_of steps options prednames thy: runs the compiler pipeline
   on one group of predicates. gen_add_equations calls it once per strongly
   connected component. The phases are:
   - prepare the clauses;
   - infer modes, then check and print them;
   - define the executable functions;
   - compile the equations;
   - prove them;
   - post-process them with add_code_equations.
   Each predicate's theorems are stored under the name base_name.qname,
   with an attribute that registers them via Code.add_eqn.
   The intro/elim sanity checks are commented out.
   The end of this definition is not visible in this view. *)
2333 fun add_equations_of steps options prednames thy =
2335 fun dest_steps (Steps s) = s
2336 val _ = print_step options
2337 ("Starting predicate compiler for predicates " ^ commas prednames ^ "...")
2338 (*val _ = check_intros_elim_match thy prednames*)
2339 (*val _ = map (check_format_of_intro_rule thy) (maps (intros_of thy) prednames)*)
2340 val compilation = Comp_Mod.compilation (#comp_modifiers (dest_steps steps))
2341 val (preds, all_vs, param_vs, extra_modes, clauses) =
2342 prepare_intrs options compilation thy prednames (maps (intros_of thy) prednames)
2343 val _ = print_step options "Infering modes..."
2344 val ((moded_clauses, errors), thy') =
2345 #infer_modes (dest_steps steps) options preds extra_modes param_vs clauses thy
2346 val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses
2347 val _ = check_expected_modes preds options modes
2348 val _ = check_proposed_modes preds options modes extra_modes errors
2349 val _ = print_modes options thy' modes
2350 val _ = print_step options "Defining executable functions..."
2351 val thy'' = fold (#define_functions (dest_steps steps) options preds) modes thy'
2352 |> Theory.checkpoint
2353 val _ = print_step options "Compiling equations..."
2354 val compiled_terms =
2355 compile_preds (#comp_modifiers (dest_steps steps)) thy'' all_vs param_vs preds moded_clauses
2356 val _ = print_compiled_terms options thy'' compiled_terms
2357 val _ = print_step options "Proving equations..."
2358 val result_thms = #prove (dest_steps steps) options thy'' clauses preds (extra_modes @ modes)
2359 moded_clauses compiled_terms
2360 val result_thms' = #add_code_equations (dest_steps steps) thy'' preds
2361 (maps_modes result_thms)
2362 val qname = #qname (dest_steps steps)
2363 val attrib = fn thy => Attrib.attribute_i thy (Attrib.internal (K (Thm.declaration_attribute
2364 (fn thm => Context.mapping (Code.add_eqn thm) I))))
2365 val thy''' = fold (fn (name, result_thms) => fn thy => snd (PureThy.add_thmss
2366 [((Binding.qualify true (Long_Name.base_name name) (Binding.name qname), result_thms),
2367 [attrib thy ])] thy))
2368 result_thms' thy'' |> Theory.checkpoint
(* extend' value_of edges_of key (G, visited): makes key a node of G,
   adding it with value_of key if absent. It recursively extends G with
   those successors edges_of (key, v) that are not yet in visited, then adds
   an edge from key to every element of edges_of (key, v).
   Returns the extended graph and the updated visited list.
   NOTE(review): value_of key and edges_of (key, v) are each evaluated twice;
   binding them once would avoid repeated work if they are expensive.
   The branch for an existing node is not visible in this view. *)
2373 fun extend' value_of edges_of key (G, visited) =
2375 val (G', v) = case try (Graph.get_node G) key of
2377 | NONE => (Graph.new_node (key, value_of key) G, value_of key)
2378 val (G'', visited') = fold (extend' value_of edges_of)
2379 (subtract (op =) visited (edges_of (key, v)))
2380 (G', key :: visited)
2382 (fold (Graph.add_edge o (pair key)) (edges_of (key, v)) G'', visited')
(* extend value_of edges_of key G: runs extend' with an empty visited list
   and returns only the extended graph. *)
2385 fun extend value_of edges_of key G = fst (extend' value_of edges_of key (G, []))
(* gen_add_equations steps options names thy:
   - extends the PredData graph with names and the predicates they depend
     on;
   - computes the strongly connected components of the subgraph spanned by
     Graph.all_succs of names;
   - runs add_equations_of on each component via fold_rev, skipping a
     component when all of its predicates already have functions for this
     compilation.
   Some lines of this definition are not visible in this view. *)
2387 fun gen_add_equations steps options names thy =
2389 fun dest_steps (Steps s) = s
2390 val defined = defined_functions (Comp_Mod.compilation (#comp_modifiers (dest_steps steps)))
2392 |> PredData.map (fold (extend (fetch_pred_data thy) (depending_preds_of thy)) names)
2393 |> Theory.checkpoint;
2394 fun strong_conn_of gr keys =
2395 Graph.strong_conn (Graph.subgraph (member (op =) (Graph.all_succs gr keys)) gr)
2396 val scc = strong_conn_of (PredData.get thy') names
2398 val thy'' = fold_rev
2399 (fn preds => fn thy =>
2400 if not (forall (defined thy) preds) then
2401 add_equations_of steps options preds thy
2403 scc thy' |> Theory.checkpoint
(* Compilation modifiers for depth-limited compilation (prefix
   depth_limited_). Two extra arguments are added: a bool polarity and a
   code_numeral depth.
   - When depth is 0, the wrapper yields bot if polarity holds. Otherwise it
     yields single () when the mode has no output arguments, and undefined
     otherwise. The non-zero case is not visible here.
   - For each premise, the depth is replaced by depth - 1, and the polarity
     is negated for negated premises.
   NOTE(review): Algebras.minus and Algebras.one are referred to by
   hard-coded names instead of const_name antiquotations, so a renaming
   would break this silently. mode is used as a pair here (fst mode,
   snd mode, split_smodeT), unlike the Fun/Input mode terms used elsewhere
   in this file. This instance may be stale (its use in generic_code_pred is
   commented out) -- confirm.
   Some lines of this definition are not visible in this view. *)
2406 val depth_limited_comp_modifiers = Comp_Mod.Comp_Modifiers
2408 compilation = Depth_Limited,
2409 function_name_of = function_name_of Depth_Limited,
2410 set_function_name = set_function_name Depth_Limited,
2411 funT_of = depth_limited_funT_of : (compilation_funs -> mode -> typ -> typ),
2412 function_name_prefix = "depth_limited_",
2413 additional_arguments = fn names =>
2415 val [depth_name, polarity_name] = Name.variant_list names ["depth", "polarity"]
2416 in [Free (polarity_name, @{typ "bool"}), Free (depth_name, @{typ "code_numeral"})] end,
2418 fn compfuns => fn s => fn T => fn mode => fn additional_arguments => fn compilation =>
2420 val [polarity, depth] = additional_arguments
2421 val (_, Ts2) = chop (length (fst mode)) (binder_types T)
2422 val (_, Us2) = split_smodeT (snd mode) Ts2
2423 val T' = mk_predT compfuns (HOLogic.mk_tupleT Us2)
2424 val if_const = Const (@{const_name "If"}, @{typ bool} --> T' --> T' --> T')
2425 val full_mode = null Us2
2427 if_const $ HOLogic.mk_eq (depth, @{term "0 :: code_numeral"})
2428 $ (if_const $ polarity $ mk_bot compfuns (dest_predT compfuns T')
2429 $ (if full_mode then mk_single compfuns HOLogic.unit else
2430 Const (@{const_name undefined}, T')))
2433 transform_additional_arguments =
2434 fn prem => fn additional_arguments =>
2436 val [polarity, depth] = additional_arguments
2437 val polarity' = (case prem of Prem _ => I | Negprem _ => HOLogic.mk_not | _ => I) polarity
2439 Const ("Algebras.minus", @{typ "code_numeral => code_numeral => code_numeral"})
2440 $ depth $ Const ("Algebras.one", @{typ "Code_Numeral.code_numeral"})
2441 in [polarity', depth'] end
(* Compilation modifiers for random compilation (prefix random_; used by
   add_quickcheck_equations). They add one extra code_numeral argument
   named size, made fresh with respect to names. There is no wrapping of
   the compiled term and no per-premise transformation of the extra
   argument.
   NOTE(review): the field set shown (funT_of, function_name_of,
   set_function_name, no compfuns) differs from predicate_comp_modifiers.
   This instance may be stale -- confirm. *)
2444 val random_comp_modifiers = Comp_Mod.Comp_Modifiers
2446 compilation = Random,
2447 function_name_of = function_name_of Random,
2448 set_function_name = set_function_name Random,
2449 function_name_prefix = "random_",
2450 funT_of = K random_function_funT_of : (compilation_funs -> mode -> typ -> typ),
2451 additional_arguments = fn names => [Free (Name.variant names "size", @{typ code_numeral})],
2452 wrap_compilation = K (K (K (K (K I))))
2453 : (compilation_funs -> string -> typ -> mode -> term list -> term -> term),
2454 transform_additional_arguments = K I : (indprem -> term list -> term list)
2457 (* different instantiations of the predicate compiler *)
(* Compilation modifiers for the plain predicate compilation. They use an
   empty function name prefix and PredicateCompFuns, with no extra
   arguments and no wrapping. Some fields are not visible in this view. *)
2459 val predicate_comp_modifiers = Comp_Mod.Comp_Modifiers
2462 function_name_prefix = "",
2463 compfuns = PredicateCompFuns.compfuns,
2464 additional_arguments = K [],
2465 wrap_compilation = K (K (K (K (K I))))
2466 : (compilation_funs -> string -> typ -> mode -> term list -> term -> term),
2467 transform_additional_arguments = K I : (indprem -> term list -> term list)
(* add_equations: gen_add_equations instance for the plain predicate
   compilation. Functions are defined with create_definitions, code
   equations are added with add_code_equations, and results are stored
   under the name equation. One field of the record is not visible in this
   view. *)
2470 val add_equations = gen_add_equations
2471 (Steps {infer_modes = infer_modes false,
2472 define_functions = create_definitions,
2474 add_code_equations = add_code_equations,
2475 comp_modifiers = predicate_comp_modifiers,
2476 qname = "equation"})
(* Compilation modifiers for annotated compilation (prefix annotated_). They
   are like the predicate compilation, but each compiled call is wrapped
   with mk_tracing, which reports the predicate name and mode.
   Some fields are not visible in this view. *)
2478 val annotated_comp_modifiers = Comp_Mod.Comp_Modifiers
2480 compilation = Annotated,
2481 function_name_prefix = "annotated_",
2482 compfuns = PredicateCompFuns.compfuns,
2483 additional_arguments = K [],
2485 fn compfuns => fn s => fn T => fn mode => fn additional_arguments => fn compilation =>
2486 mk_tracing ("calling predicate " ^ s ^
2487 " with mode " ^ string_of_mode mode) compilation,
2488 transform_additional_arguments = K I : (indprem -> term list -> term list)
(* Compilation modifiers for DSequence compilation (prefix dseq_), using
   DSequence_CompFuns, with no extra arguments and no wrapping.
   Some fields are not visible in this view. *)
2491 val dseq_comp_modifiers = Comp_Mod.Comp_Modifiers
2494 function_name_prefix = "dseq_",
2495 compfuns = DSequence_CompFuns.compfuns,
2496 additional_arguments = K [],
2497 wrap_compilation = K (K (K (K (K I))))
2498 : (compilation_funs -> string -> typ -> mode -> term list -> term -> term),
2499 transform_additional_arguments = K I : (indprem -> term list -> term list)
(* Compilation modifiers for Random_DSeq compilation (prefix random_dseq_),
   using Random_Sequence_CompFuns, with no extra arguments and no wrapping.
   Some fields are not visible in this view. *)
2502 val random_dseq_comp_modifiers = Comp_Mod.Comp_Modifiers
2504 compilation = Random_DSeq,
2505 function_name_prefix = "random_dseq_",
2506 compfuns = Random_Sequence_CompFuns.compfuns,
2507 additional_arguments = K [],
2508 wrap_compilation = K (K (K (K (K I))))
2509 : (compilation_funs -> string -> typ -> mode -> term list -> term -> term),
2510 transform_additional_arguments = K I : (indprem -> term list -> term list)
(* add_depth_limited_equations: gen_add_equations instance for depth-limited
   compilation. Equations are not proved (prove_by_skip) and no code
   equations are added.
   NOTE(review): this record uses the fields compile_preds and defined,
   which the visible Steps datatype does not have, and passes infer_modes
   without its flag argument. Presumably this definition sits inside a
   comment -- confirm before re-enabling. *)
2514 val add_depth_limited_equations = gen_add_equations
2515 (Steps {infer_modes = infer_modes,
2516 define_functions = define_functions depth_limited_comp_modifiers PredicateCompFuns.compfuns,
2517 compile_preds = compile_preds depth_limited_comp_modifiers PredicateCompFuns.compfuns,
2518 prove = prove_by_skip,
2519 add_code_equations = K (K I),
2520 defined = defined_functions Depth_Limited,
2521 qname = "depth_limited_equation"})
(* add_annotated_equations: gen_add_equations instance for annotated
   compilation. Theorems are made without proof (prove_by_skip) and are
   returned unchanged by add_code_equations (K (K I)). *)
2523 val add_annotated_equations = gen_add_equations
2524 (Steps {infer_modes = infer_modes false,
2525 define_functions = define_functions annotated_comp_modifiers PredicateCompFuns.compfuns,
2526 prove = prove_by_skip,
2527 add_code_equations = K (K I),
2528 comp_modifiers = annotated_comp_modifiers,
2529 qname = "annotated_equation"})
(* add_quickcheck_equations: gen_add_equations instance for random
   compilation, with mode inference by infer_modes_with_generator and no
   proofs (prove_by_skip).
   NOTE(review): this record uses the fields compile_preds and defined,
   which the visible Steps datatype does not have. Presumably this
   definition sits inside a comment -- confirm before re-enabling. *)
2531 val add_quickcheck_equations = gen_add_equations
2532 (Steps {infer_modes = infer_modes_with_generator,
2533 define_functions = define_functions random_comp_modifiers RandomPredCompFuns.compfuns,
2534 compile_preds = compile_preds random_comp_modifiers RandomPredCompFuns.compfuns,
2535 prove = prove_by_skip,
2536 add_code_equations = K (K I),
2537 defined = defined_functions Random,
2538 qname = "random_equation"})
(* add_dseq_equations: gen_add_equations instance for DSequence compilation.
   Theorems are made without proof (prove_by_skip) and are stored under the
   name dseq_equation. *)
2540 val add_dseq_equations = gen_add_equations
2541 (Steps {infer_modes = infer_modes false,
2542 define_functions = define_functions dseq_comp_modifiers DSequence_CompFuns.compfuns,
2543 prove = prove_by_skip,
2544 add_code_equations = K (K I),
2545 comp_modifiers = dseq_comp_modifiers,
2546 qname = "dseq_equation"})
(* add_random_dseq_equations: gen_add_equations instance for Random_DSeq
   compilation. Unlike the other instances, infer_modes is called with
   true. Theorems are made without proof (prove_by_skip). *)
2548 val add_random_dseq_equations = gen_add_equations
2549 (Steps {infer_modes = infer_modes true,
2550 define_functions = define_functions random_dseq_comp_modifiers Random_Sequence_CompFuns.compfuns,
2551 prove = prove_by_skip,
2552 add_code_equations = K (K I),
2553 comp_modifiers = random_dseq_comp_modifiers,
2554 qname = "random_dseq_equation"})
2557 (** user interface **)
2559 (* code_pred_intro attribute *)
(* attrib f: a declaration attribute that applies f thm to a theory context
   and leaves proof contexts unchanged (I). code_pred_intro_attrib uses it
   to register an alternative introduction rule via add_intro. *)
2561 fun attrib f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I);
2563 val code_pred_intro_attrib = attrib add_intro;
2567 - Naming of auxiliary rules necessary?
(* setup: initialises PredData with the empty graph and registers the
   code_pred_intro attribute. *)
2570 val setup = PredData.put (Graph.empty) #>
2571 Attrib.setup @{binding code_pred_intro} (Scan.succeed (attrib add_intro))
2572 "adding alternative introduction rules for code generation of inductive predicates"
2574 (* TODO: make Theory_Data to Generic_Data & remove duplication of local theory and theory *)
(* generic_code_pred prep_const options raw_const lthy: implementation of
   the code_pred command.
   - Extends PredData with the constant and the predicates it depends on.
   - Builds a cases rule (mk_casesrule) for every reachable predicate that
     has no elimination rule yet.
   - Opens a proof of these rules, with one named case per predicate.
   - After the proof, stores the exported rules with set_elim and generates
     equations for the requested compilation (Pred, DSeq or Random_DSeq).
     Other compilations raise an error when after_qed runs.
   NOTE(review): that error message does not name the rejected compilation.
   Some lines of this definition are not visible in this view. *)
2575 fun generic_code_pred prep_const options raw_const lthy =
2577 val thy = ProofContext.theory_of lthy
2578 val const = prep_const thy raw_const
2579 val lthy' = Local_Theory.theory (PredData.map
2580 (extend (fetch_pred_data thy) (depending_preds_of thy) const)) lthy
2581 |> Local_Theory.checkpoint
2582 val thy' = ProofContext.theory_of lthy'
2583 val preds = Graph.all_succs (PredData.get thy') [const] |> filter_out (has_elim thy')
2584 fun mk_cases const =
2586 val T = Sign.the_const_type thy const
2587 val pred = Const (const, T)
2588 val intros = intros_of thy' const
2589 in mk_casesrule lthy' pred intros end
2590 val cases_rules = map mk_cases preds
2592 map (fn case_rule => Rule_Cases.Case {fixes = [],
2593 assumes = [("", Logic.strip_imp_prems case_rule)],
2594 binds = [], cases = []}) cases_rules
2595 val case_env = map2 (fn p => fn c => (Long_Name.base_name p, SOME c)) preds cases
2597 |> fold Variable.auto_fixes cases_rules
2598 |> ProofContext.add_cases true case_env
2599 fun after_qed thms goal_ctxt =
2601 val global_thms = ProofContext.export goal_ctxt
2602 (ProofContext.init (ProofContext.theory_of goal_ctxt)) (map the_single thms)
2604 goal_ctxt |> Local_Theory.theory (fold set_elim global_thms #>
2605 ((case compilation options of
2606 Pred => add_equations
2607 | DSeq => add_dseq_equations
2608 | Random_DSeq => add_random_dseq_equations
2609 | compilation => error ("Compilation not supported")
2610 (*| Random => (fn opt => fn cs => add_equations opt cs #> add_quickcheck_equations opt cs)
2611 | Depth_Limited => add_depth_limited_equations
2612 | Annotated => add_annotated_equations*)
2616 Proof.theorem_i NONE after_qed (map (single o (rpair [])) cases_rules) lthy''
(* code_pred uses the constant name as given. code_pred_cmd first reads it
   with Code.read_const. *)
2619 val code_pred = generic_code_pred (K I);
2620 val code_pred_cmd = generic_code_pred Code.read_const
2622 (* transformation for code generation *)
(* Mutable slots that eval hands to Code_Eval.eval, one per compilation
   (plain, random, dseq, random_dseq). Their types follow the shape of the
   evaluated computation. *)
2624 val eval_ref = Unsynchronized.ref (NONE : (unit -> term Predicate.pred) option);
2625 val random_eval_ref =
2626 Unsynchronized.ref (NONE : (unit -> int * int -> term Predicate.pred * (int * int)) option);
2627 val dseq_eval_ref = Unsynchronized.ref (NONE : (unit -> term DSequence.dseq) option);
2628 val random_dseq_eval_ref =
2629 Unsynchronized.ref (NONE : (unit -> int -> int -> int * int -> term DSequence.dseq * (int * int)) option);
2631 (*FIXME turn this into an LCF-guarded preprocessor for comprehensions*)
(* analyze_compr thy compfuns param_user_modes (compilation, arguments)
   t_compr: compiles a set comprehension Collect ... into a term of the
   compilation's monad.
   - The variables bound by the comprehension become fresh frees and are
     treated as outputs.
   - The body must be an application of a constant for which functions of
     this compilation are defined; otherwise an error is raised.
   - Each derivation of the body must have no missing variables and a head
     mode that is an instance of the user mode read off the argument
     structure. If param_user_modes is given, it must also agree with the
     context modes.
   - No remaining derivation is an error; several give a warning and the
     first is used.
   - When there are outputs, the result is mapped into the comprehension's
     tuple shape.
   NOTE(review): the bindings of (pred as Const ...) and (p_mode :: modes)
   raise Bind on a non-constant head or an empty mode list. The
   comp_modifiers case visibly covers only Pred, DSeq and Random_DSeq.
   Some lines of this definition are not visible in this view. *)
2632 fun analyze_compr thy compfuns param_user_modes (compilation, arguments) t_compr =
2634 val all_modes_of = all_modes_of compilation
2635 val split = case t_compr of (Const (@{const_name Collect}, _) $ t) => t
2636 | _ => error ("Not a set comprehension: " ^ Syntax.string_of_term_global thy t_compr);
2637 val (body, Ts, fp) = HOLogic.strip_psplits split;
2638 val output_names = Name.variant_list (Term.add_free_names body [])
2639 (map (fn i => "x" ^ string_of_int i) (1 upto length Ts))
2640 val output_frees = map2 (curry Free) output_names (rev Ts)
2641 val body = subst_bounds (output_frees, body)
2642 val T_compr = HOLogic.mk_ptupleT fp Ts
2643 val output_tuple = HOLogic.mk_ptuple fp T_compr (rev output_frees)
2644 val (pred as Const (name, T), all_args) = strip_comb body
2646 if defined_functions compilation thy name then
2648 fun extract_mode (Const ("Pair", _) $ t1 $ t2) = Pair (extract_mode t1, extract_mode t2)
2649 | extract_mode (Free (x, _)) = if member (op =) output_names x then Output else Input
2650 | extract_mode _ = Input
2651 val user_mode = fold_rev (curry Fun) (map extract_mode all_args) Bool
(* valid: user mode annotations (NONE = unconstrained) must match the
   context modes one-to-one. *)
2652 fun valid modes1 modes2 =
2653 case int_ord (length modes1, length modes2) of
2654 GREATER => error "Not enough mode annotations"
2655 | LESS => error "Too many mode annotations"
2656 | EQUAL => forall (fn (m, NONE) => true | (m, SOME m2) => eq_mode (m, m2))
2658 fun mode_instance_of (m1, m2) =
2660 fun instance_of (Fun _, Input) = true
2661 | instance_of (Input, Input) = true
2662 | instance_of (Output, Output) = true
2663 | instance_of (Pair (m1, m2), Pair (m1', m2')) =
2664 instance_of (m1, m1') andalso instance_of (m2, m2')
2665 | instance_of (Pair (m1, m2), Input) =
2666 instance_of (m1, Input) andalso instance_of (m2, Input)
2667 | instance_of (Pair (m1, m2), Output) =
2668 instance_of (m1, Output) andalso instance_of (m2, Output)
2669 | instance_of _ = false
2670 in forall instance_of (strip_fun_mode m1 ~~ strip_fun_mode m2) end
2671 val derivs = all_derivations_of thy (all_modes_of thy) [] body
2672 |> filter (fn (d, missing_vars) =>
2674 val (p_mode :: modes) = collect_context_modes d
2676 null missing_vars andalso
2677 mode_instance_of (p_mode, user_mode) andalso
2678 the_default true (Option.map (valid modes) param_user_modes)
2681 val deriv = case derivs of
2682 [] => error ("No mode possible for comprehension "
2683 ^ Syntax.string_of_term_global thy t_compr)
2685 | d :: _ :: _ => (warning ("Multiple modes possible for comprehension "
2686 ^ Syntax.string_of_term_global thy t_compr); d);
2687 val (_, outargs) = split_mode (head_mode_of deriv) all_args
2688 val additional_arguments =
(* Random: the size argument is fixed to 5 here. *)
2691 | Random => [@{term "5 :: code_numeral"}]
(* Depth_Limited: hd arguments raises Empty if no argument was given. *)
2693 | Depth_Limited => [@{term "True"}, HOLogic.mk_number @{typ "code_numeral"} (hd arguments)]
2696 val comp_modifiers =
2698 Pred => predicate_comp_modifiers
2699 (*| Random => random_comp_modifiers
2700 | Depth_Limited => depth_limited_comp_modifiers
2701 | Annotated => annotated_comp_modifiers*)
2702 | DSeq => dseq_comp_modifiers
2703 | Random_DSeq => random_dseq_comp_modifiers
2704 val t_pred = compile_expr comp_modifiers compfuns thy (body, deriv) additional_arguments;
2705 val T_pred = dest_predT compfuns (fastype_of t_pred)
2706 val arrange = split_lambda (HOLogic.mk_tuple outargs) output_tuple
2708 if null outargs then t_pred else mk_map compfuns T_pred T_compr arrange t_pred
2711 error "Evaluation with values is not possible because compilation with code_pred was not invoked"
(* eval thy param_user_modes (compilation, arguments) k t_compr:
   Evaluate the set comprehension t_compr with the code generator.
   - compfuns is chosen from the compilation kind. Random, DSeq and
     Random_DSeq have dedicated compilation functions; every other kind uses
     PredicateCompFuns.compfuns.
   - analyze_compr compiles t_compr to a predicate term t. Each element of t
     is then mapped to type HOLogic.termT with HOLogic.term_of_const, so the
     evaluated results come back as terms.
   - t' is evaluated with Code_Eval.eval through the reference cell that
     belongs to the compilation kind. Only the first component of the yieldn
     result is kept.
   arguments is the int list of compilation parameters (see the values_cmd
   signature). k is passed to yieldn, presumably as the maximum number of
   results to fetch -- TODO confirm.
   The caller values binds the result of eval as (T, ts).
   NOTE(review): this copy carries the original line numbers at the start of
   each line, and several numbers are skipped. Parts of this definition are
   therefore not shown here. *)
2714 | fun eval thy param_user_modes (options as (compilation, arguments)) k t_compr =
2718 | Random => RandomPredCompFuns.compfuns
2719 | DSeq => DSequence_CompFuns.compfuns
2720 | Random_DSeq => Random_Sequence_CompFuns.compfuns
2721 | _ => PredicateCompFuns.compfuns
2722 val t = analyze_compr thy compfuns param_user_modes options t_compr;
2723 val T = dest_predT compfuns (fastype_of t);
2724 val t' = mk_map compfuns T HOLogic.termT (HOLogic.term_of_const T) t;
(* Random: evaluated through random_eval_ref; the result of Code_Eval.eval
   is passed to Random_Engine.run. *)
2728 fst (Predicate.yieldn k
2729 (Code_Eval.eval NONE ("Predicate_Compile_Core.random_eval_ref", random_eval_ref)
2730 (fn proc => fn g => fn s => g s |>> Predicate.map proc) thy t' []
2731 |> Random_Engine.run))
(* Random_DSeq: exactly three arguments are expected. This val pattern is
   not exhaustive and raises Bind for a list of any other length. *)
2734 val [nrandom, size, depth] = arguments
2736 fst (DSequence.yieldn k
2737 (Code_Eval.eval NONE ("Predicate_Compile_Core.random_dseq_eval_ref", random_dseq_eval_ref)
2738 (fn proc => fn g => fn nrandom => fn size => fn s => g nrandom size s |>> DSequence.map proc)
2739 thy t' [] nrandom size
2740 |> Random_Engine.run)
(* DSeq: the_single requires arguments to hold exactly one element. That
   element is passed to DSequence.yieldn together with true (presumably a
   depth bound -- confirm). *)
2744 fst (DSequence.yieldn k
2745 (Code_Eval.eval NONE ("Predicate_Compile_Core.dseq_eval_ref", dseq_eval_ref)
2746 DSequence.map thy t' []) (the_single arguments) true)
(* Branch using eval_ref and Predicate.yieldn. Its case label is not shown
   in this copy; presumably it is the catch-all. *)
2748 fst (Predicate.yieldn k
2749 (Code_Eval.eval NONE ("Predicate_Compile_Core.eval_ref", eval_ref)
2750 Predicate.map thy t' []))
(* values ctxt param_user_modes (raw_expected, comp_options) k t_compr:
   Compute the values of the set comprehension t_compr via eval and return
   them as a HOL set term.
   - raw_expected : string option (see the values_cmd signature). The branch
     that uses it does three things:
       * parses the string with Syntax.read_term;
       * compares the parsed set's elements with the computed values using
         eq_set (op =);
       * on a mismatch, raises an error showing both sets.
     The case labels are not shown in this copy; presumably NONE skips the
     check.
   - If k = ~1 or fewer than k values were computed, the result is exactly
     the set of computed values. Otherwise it is that set united with a
     set-typed free variable named "...".
   NOTE(review): the line numbers embedded in this copy skip some lines of
   the definition. *)
2753 fun values ctxt param_user_modes (raw_expected, comp_options) k t_compr =
2755 val thy = ProofContext.theory_of ctxt
2756 val (T, ts) = eval thy param_user_modes comp_options k t_compr
2757 val setT = HOLogic.mk_setT T
2758 val elems = HOLogic.mk_set T ts
(* Set-typed free variable named "...". It is united with the computed set
   below when the number of results reaches k. *)
2759 val cont = Free ("...", setT)
2760 (* check expected values *)
2762 case raw_expected of
2765 if eq_set (op =) (HOLogic.dest_set (Syntax.read_term ctxt s), ts) then ()
2767 error ("expected and computed values do not match:\n" ^
2768 "expected values: " ^ Syntax.string_of_term ctxt (Syntax.read_term ctxt s) ^ "\n" ^
2769 "computed values: " ^ Syntax.string_of_term ctxt elems ^ "\n")
(* Presumably: reaching k results means more values may exist, so the "..."
   placeholder is added to the displayed set. *)
2771 if k = ~1 orelse length ts < k then elems
2772 else Const (@{const_name Set.union}, setT --> setT --> setT) $ elems $ cont
(* values_cmd print_modes param_user_modes options k raw_t state:
   Parse raw_t in the context of the toplevel state and compute its values
   with values. Then write the resulting set term, followed by "::" and its
   type, pretty-printed under print_modes.
   Per the signature:
   - print_modes : string list
   - options : string option * (compilation * int list)
   - k : int
   - raw_t : string
   - result : unit
   NOTE(review): the embedded line numbers skip 2776, so one line of this
   definition is not shown in this copy. *)
2775 fun values_cmd print_modes param_user_modes options k raw_t state =
2777 val ctxt = Toplevel.context_of state
2778 val t = Syntax.read_term ctxt raw_t
2779 val t' = values ctxt param_user_modes options k t
2780 val ty' = Term.type_of t'
(* Extend the context with t' via Variable.auto_fixes before printing. *)
2781 val ctxt' = Variable.auto_fixes t' ctxt
2782 val p = PrintMode.with_modes print_modes (fn () =>
2783 Pretty.block [Pretty.quote (Syntax.pretty_term ctxt' t'), Pretty.fbrk,
2784 Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt' ty')]) ();
2785 in Pretty.writeln p end;