1 (* Title: HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML
2 Author: Lukas Bulwahn, TU Muenchen
4 Deriving specialised predicates and their intro rules
(* Public interface of the specialisation pass.
   find_specialisations takes, in order: a list of names (the implementation
   binds it as black_list), a list of (constant name, intro rules) pairs, and
   a theory. It returns pairs of the same shape together with the updated
   theory.
   NOTE(review): the sig/end delimiters are not visible in this extract. *)
7 signature PREDICATE_COMPILE_SPECIALISATION =
9 val find_specialisations : string list -> (string * thm list) list -> theory -> (string * thm list) list * theory
12 structure Predicate_Compile_Specialisation : PREDICATE_COMPILE_SPECIALISATION =
15 open Predicate_Compile_Aux;
17 (* table of specialisations *)
(* Theory data holding an item net of term pairs. Elsewhere in this file the
   pairs are stored and read as (original term, specialised term).
   Two entries are considered equal when their first components are aconv,
   and an entry is indexed by its first component only.
   NOTE(review): the embedded line numbers skip here, so further fields of
   the Theory_Data argument may be missing from this extract. *)
18 structure Specialisations = Theory_Data
20 type T = (term * term) Item_Net.T;
21 val empty : T = Item_Net.init (op aconv o pairself fst) (single o fst);
23 val merge = Item_Net.merge;
(* Queries the specialisation table stored in thy with the term atom.
   Callers in this file pattern-match the result as a list of
   (term, specialised term) entries and use only its head.
   NOTE(review): a FIXME at the use site handles Pattern.MATCH after the
   lookup, which suggests a retrieved entry need not actually match atom;
   confirm against Item_Net.retrieve. *)
26 fun specialisation_of thy atom =
27 Item_Net.retrieve (Specialisations.get thy) atom
(* Debugging aid: sends the whole specialisation table of thy to tracing,
   one line per entry, each formatted as the original term, the separator
   ~~~> and the specialised term. *)
29 fun print_specialisations thy =
30 tracing (cat_lines (map (fn (t, spec_t) =>
31 Syntax.string_of_term_global thy t ^ " ~~~> " ^ Syntax.string_of_term_global thy spec_t)
32 (Item_Net.content (Specialisations.get thy))))
(* Imports the intro rules into ctxt with Variable.importT and adapts the
   types of args to the imported predicate.
   Returns (((pred', intros'), args'), ctxt').
   The pred argument is not used by any visible line: the returned predicate
   is re-read from the first imported intro rule.
   NOTE(review): the let/in/end lines are not visible in this extract. *)
34 fun import (pred, intros) args ctxt =
(* thy is not referenced by the visible lines below *)
36 val thy = Proof_Context.theory_of ctxt
37 val ((Tinst, intros'), ctxt') = Variable.importT intros ctxt
(* head of the conclusion of the first imported rule; hd requires at least one rule *)
38 val pred' = fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of (hd intros')))))
39 val Ts = binder_types (fastype_of pred')
40 val argTs = map fastype_of args
(* match the argument types against the predicate's argument types ... *)
41 val Tsubst = Type.raw_matches (argTs, Ts) Vartab.empty
(* ... and apply the resulting type substitution to the arguments *)
42 val args' = map (Envir.subst_term_types Tsubst) args
44 (((pred', intros'), args'), ctxt')
47 (* patterns only constructed of variables and pairs/tuples are trivial constructor terms*)
(* Decides whether t is a constructor term that is not trivial in the above
   sense. The local function check returns a pair of flags:
     first  - t consists only of variables, frees and Pair applications;
     second - t is a well-formed constructor term.
   The overall result is true exactly when check yields (false, true).
   NOTE(review): the let line is not visible in this extract. *)
48 fun is_nontrivial_constrt thy t =
(* association list built from all datatypes registered in thy:
   constructor name -> (number of constructor arguments, datatype name) *)
50 val cnstrs = flat (maps
51 (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
52 (Symtab.dest (Datatype.get_all thy)));
53 fun check t = (case strip_comb t of
(* a schematic or free variable without arguments is trivial and well-formed *)
54 (Var _, []) => (true, true)
55 | (Free _, []) => (true, true)
(* Pair: both flags are the conjunction of the corresponding flags of the arguments *)
56 | (Const (@{const_name Pair}, _), ts) =>
57 pairself (forall I) (split_list (map check ts))
(* any other constant: never trivial; well-formed only if it is a known
   constructor applied to exactly its number of arguments, its result type
   is the constructor's datatype, and every argument is well-formed *)
58 | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
59 (SOME (i, Tname), Type (Tname', _)) => (false,
60 length ts = i andalso Tname = Tname' andalso forall (snd o check) ts)
61 | _ => (false, false))
(* anything else is neither trivial nor a constructor term *)
62 | _ => (false, false))
63 in check t = (false, true) end;
(* Derives a specialised version of predicate pred for the argument
   patterns pats: it declares a fresh constant, instantiates those intro
   rules whose conclusion unifies with pred applied to pats, records the
   specialisation in the table, then recursively looks for specialisations
   in the new rules and registers them with Core_Data.
   black_list is passed on unchanged to find_specialisations.
   NOTE(review): several lines (let/in/end, some binding heads and the final
   result expression) are not visible in this extract. *)
65 fun specialise_intros black_list (pred, intros) pats thy =
67 val ctxt = Proof_Context.init_global thy
(* largest variable index occurring in the intro rules, starting from ~1 *)
68 val maxidx = fold (Term.maxidx_term o prop_of) intros ~1
(* shift the indexes of the patterns above it, presumably to keep their
   schematic variables apart from those of the rules *)
69 val pats = map (Logic.incr_indexes ([], maxidx + 1)) pats
70 val (((pred, intros), pats), ctxt') = import (pred, intros) pats ctxt
(* intros_t is not referenced by the visible lines below *)
71 val intros_t = map prop_of intros
(* the schematic variables of the patterns become the arguments of the new constant *)
72 val result_pats = map Var (fold_rev Term.add_vars pats [])
(* picks a name of the form specialised_ followed by the base name of pred,
   varied against names; if the fully qualified name is already a declared
   constant of thy, retries with the candidate added to names *)
73 fun mk_fresh_name names =
76 singleton (Name.variant_list names)
77 ("specialised_" ^ Long_Name.base_name (fst (dest_Const pred)))
78 val bname = Sign.full_bname thy name
80 if Sign.declared_const thy bname then
81 mk_fresh_name (name :: names)
85 val constname = mk_fresh_name []
86 val constT = map fastype_of result_pats ---> @{typ bool}
87 val specialised_const = Const (constname, constT)
(* one pair: pred applied to pats, and the new constant applied to its
   variables; the name this list is bound to is not visible in this extract *)
89 [(HOLogic.mk_Trueprop (list_comb (pred, pats)),
90 HOLogic.mk_Trueprop (list_comb (specialised_const, result_pats)))]
(* instantiates a single intro rule for the patterns; a rule whose conclusion
   does not unify with pred applied to pats yields NONE *)
91 fun specialise_intro intro =
93 val (prems, concl) = Logic.strip_horn (prop_of intro)
94 val env = Pattern.unify thy
95 (HOLogic.mk_Trueprop (list_comb (pred, pats)), concl) (Envir.empty 0)
(* normalise premises and constant arguments with respect to the unifier *)
96 val prems = map (Envir.norm_term env) prems
97 val args = map (Envir.norm_term env) result_pats
98 val concl = HOLogic.mk_Trueprop (list_comb (specialised_const, args))
99 val intro = Logic.list_implies (prems, concl)
102 end handle Pattern.Unif => NONE)
(* keep only the rules that could be instantiated *)
103 val specialised_intros_t = map_filter I (map specialise_intro intros)
(* declare the new constant under its base name *)
104 val thy' = Sign.add_consts_i [(Binding.name (Long_Name.base_name constname), constT, NoSyn)] thy
(* the specialised rules are turned into theorems through Skip_Proof.make_thm *)
105 val specialised_intros = map (Skip_Proof.make_thm thy') specialised_intros_t
106 val exported_intros = Variable.exportT ctxt' ctxt specialised_intros
107 val [t, specialised_t] = Variable.exportT_terms ctxt' ctxt
108 [list_comb (pred, pats), list_comb (specialised_const, result_pats)]
(* record the new specialisation in the table *)
109 val thy'' = Specialisations.map (Item_Net.update (t, specialised_t)) thy'
(* rules for which peephole_optimisation returns NONE are dropped *)
110 val optimised_intros =
111 map_filter (Predicate_Compile_Aux.peephole_optimisation thy'') exported_intros
(* recursively specialise inside the rules of the new constant *)
112 val ([spec], thy''') = find_specialisations black_list [(constname, optimised_intros)] thy''
113 val thy'''' = Core_Data.register_intros spec thy'''
(* Processes every (constant name, intro rules) pair in specs: each atom of
   each rule is passed through detect', which may create new specialisations
   in the theory and replace the atom by its specialised form.
   Returns the rewritten specs together with the final theory.
   Declared with and, so it is mutually recursive with specialise_intros.
   NOTE(review): many short lines (let/in/end, some case branches and result
   expressions) are not visible in this extract. *)
118 and find_specialisations black_list specs thy =
(* collects the schematic variables occurring in a term *)
120 val add_vars = fold_aterms (fn Var v => cons v | _ => I);
(* creates a free variable of type T whose name is a variant of x not in
   free_names, and returns it with the extended name list *)
121 fun fresh_free T free_names =
123 val free_name = singleton (Name.variant_list free_names) "x"
125 (Free (free_name, T), free_name :: free_names)
(* abstracts the term t to a fresh free variable, rewrites t to that variable
   in the remaining terms, and continues with the rest *)
127 fun replace_term_and_restrict thy T t Tts free_names =
129 val (free, free_names') = fresh_free T free_names
130 val Tts' = map (apsnd (Pattern.rewrite_term thy [(t, free)] [])) Tts
131 val (ts', free_names'') = restrict_pattern' thy Tts' free_names'
133 (free :: ts', free_names'')
(* walks a list of (type, term) pairs and returns the restricted terms along
   with the names of the free variables in use *)
135 and restrict_pattern' thy [] free_names = ([], free_names)
(* a free variable is kept, with the expected type T *)
136 | restrict_pattern' thy ((T, Free (x, _)) :: Tts) free_names =
138 val (ts', free_names') = restrict_pattern' thy Tts free_names
140 (Free (x, T) :: ts', free_names')
(* a term at a type variable is always abstracted *)
142 | restrict_pattern' thy ((T as TFree _, t) :: Tts) free_names =
143 replace_term_and_restrict thy T t Tts free_names
144 | restrict_pattern' thy ((T as Type (Tcon, Ts), t) :: Tts) free_names =
145 case Datatype_Data.get_constrs thy Tcon of
(* no constructors known for this type constructor: abstract the term *)
146 NONE => replace_term_and_restrict thy T t Tts free_names
147 | SOME constrs => (case strip_comb t of
148 (Const (s, _), ats) => (case AList.lookup (op =) constrs s of
(* s is a constructor of the type: instantiate its argument types for T,
   restrict its arguments together with the remaining pairs, then rebuild
   the constructor application from the restricted arguments *)
151 val (Ts', T') = strip_type constr_T
152 val Tsubst = Type.raw_match (T', T) Vartab.empty
153 val Ts = map (Envir.subst_type Tsubst) Ts'
154 val (bts', free_names') = restrict_pattern' thy ((Ts ~~ ats) @ Tts) free_names
155 val (ats', ts') = chop (length ats) bts'
157 (list_comb (Const (s, map fastype_of ats' ---> T), ats') :: ts', free_names')
(* a constant that is not a constructor of the type: abstract the term *)
159 | NONE => replace_term_and_restrict thy T t Tts free_names))
(* entry point of the restriction: works on unvarified terms and types and
   varifies the resulting patterns again *)
160 fun restrict_pattern thy Ts args =
162 val args = map Logic.unvarify_global args
163 val Ts = map Logic.unvarifyT_global Ts
164 val free_names = fold Term.add_free_names args []
165 val (pat, _) = restrict_pattern' thy (Ts ~~ args) free_names
166 in map Logic.varify_global pat end
(* handles one atom whose head is a constant; the visible lines bind both thy
   and thy', so a theory is threaded through *)
167 fun detect' atom thy =
168 case strip_comb atom of
169 (pred as Const (pred_name, _), args) =>
171 val Ts = binder_types (Sign.the_const_type thy pred_name)
(* vnames is not referenced by the visible lines below *)
172 val vnames = map fst (fold Term.add_var_names args [])
173 val pats = restrict_pattern thy Ts args
(* condition: some pattern is a nontrivial constructor term, or a schematic
   variable occurs more than once in the patterns *)
175 if (exists (is_nontrivial_constrt thy) pats)
176 orelse (has_duplicates (op =) (fold add_vars pats [])) then
179 case specialisation_of thy atom of
(* predicates in the current specs or on the black list are tested for here;
   the branch taken for them is not visible in this extract *)
181 if member (op =) ((map fst specs) @ black_list) pred_name then
184 (case try (Core_Data.intros_of (Proof_Context.init_global thy)) pred_name of
(* the black list passed down is extended by the current specs and pred_name *)
188 specialise_intros ((map fst specs) @ (pred_name :: black_list))
189 (pred, intros) pats thy)
(* a specialisation is already in the table: theory unchanged *)
190 | (t, specialised_t) :: _ => thy
(* second lookup, now in thy' *)
192 case specialisation_of thy' atom of
194 | (t, specialised_t) :: _ =>
(* instantiate the specialised term by matching the stored term against atom;
   if matching fails the atom is returned unchanged *)
196 val subst = Pattern.match thy' (t, atom) (Vartab.empty, Vartab.empty)
197 in Envir.subst_term subst specialised_t end handle Pattern.MATCH => atom
198 (*FIXME: this exception could be caught earlier in specialisation_of *)
(* rewrites all intro rules of a single constant *)
205 fun specialise' (constname, intros) thy =
207 (* FIXME: only necessary because of sloppy Logic.unvarify in restrict_pattern *)
208 val intros = Drule.zero_var_indexes_list intros
(* run detect' over every atom of every rule, threading the theory *)
209 val (intros_t', thy') = (fold_map o fold_map_atoms) detect' (map prop_of intros) thy
(* the rewritten rules are turned into theorems through Skip_Proof.make_thm *)
211 ((constname, map (Skip_Proof.make_thm thy') intros_t'), thy')
214 fold_map specialise' specs thy