1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML Wed Aug 25 09:42:28 2010 +0200
1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML Wed Aug 25 17:49:52 2010 +0200
1.3 @@ -11,11 +11,11 @@
1.4 only: bool}
1.5
1.6 val trace : bool Unsynchronized.ref
1.7 - val name_thms_pair_from_ref :
1.8 + val name_thm_pairs_from_ref :
1.9 Proof.context -> unit Symtab.table -> thm list -> Facts.ref
1.10 - -> (unit -> string * bool) * thm list
1.11 + -> ((unit -> string * bool) * (bool * thm)) list
1.12 val relevant_facts :
1.13 - bool -> real -> real -> int -> bool -> relevance_override
1.14 + bool -> real -> real option -> int -> bool -> relevance_override
1.15 -> Proof.context * (thm list * 'a) -> term list -> term
1.16 -> ((string * bool) * thm) list
1.17 end;
1.18 @@ -37,13 +37,22 @@
1.19
1.20 val sledgehammer_prefix = "Sledgehammer" ^ Long_Name.separator
1.21
1.22 -fun name_thms_pair_from_ref ctxt reserved chained_ths xref =
1.23 - let val ths = ProofContext.get_fact ctxt xref in
1.24 - (fn () => let
1.25 - val name = Facts.string_of_ref xref
1.26 - val name = name |> Symtab.defined reserved name ? quote
1.27 - val chained = forall (member Thm.eq_thm chained_ths) ths
1.28 - in (name, chained) end, ths)
1.29 +fun repair_name reserved multi j name =
1.30 + (name |> Symtab.defined reserved name ? quote) ^
1.31 + (if multi then "(" ^ Int.toString j ^ ")" else "")
1.32 +
1.33 +fun name_thm_pairs_from_ref ctxt reserved chained_ths xref =
1.34 + let
1.35 + val ths = ProofContext.get_fact ctxt xref
1.36 + val name = Facts.string_of_ref xref
1.37 + val multi = length ths > 1
1.38 + in
1.39 + fold (fn th => fn (j, rest) =>
1.40 + (j + 1, (fn () => (repair_name reserved multi j name,
1.41 + member Thm.eq_thm chained_ths th),
1.42 + (multi, th)) :: rest))
1.43 + ths (1, [])
1.44 + |> snd
1.45 end
1.46
1.47 (***************************************************************)
1.48 @@ -53,30 +62,44 @@
1.49 (*** constants with types ***)
1.50
1.51 (*An abstraction of Isabelle types*)
1.52 -datatype const_typ = CTVar | CType of string * const_typ list
1.53 +datatype pseudotype = PVar | PType of string * pseudotype list
1.54 +
1.55 +fun string_for_pseudotype PVar = "?"
1.56 + | string_for_pseudotype (PType (s, Ts)) =
1.57 + (case Ts of
1.58 + [] => ""
1.59 + | [T] => string_for_pseudotype T
1.60 + | Ts => string_for_pseudotypes Ts ^ " ") ^ s
1.61 +and string_for_pseudotypes Ts =
1.62 + "(" ^ commas (map string_for_pseudotype Ts) ^ ")"
1.63
1.64 (*Is the second type an instance of the first one?*)
1.65 -fun match_type (CType(con1,args1)) (CType(con2,args2)) =
1.66 - con1=con2 andalso match_types args1 args2
1.67 - | match_type CTVar _ = true
1.68 - | match_type _ CTVar = false
1.69 -and match_types [] [] = true
1.70 - | match_types (a1::as1) (a2::as2) = match_type a1 a2 andalso match_types as1 as2;
1.71 +fun match_pseudotype (PType (a, T), PType (b, U)) =
1.72 + a = b andalso match_pseudotypes (T, U)
1.73 + | match_pseudotype (PVar, _) = true
1.74 + | match_pseudotype (_, PVar) = false
1.75 +and match_pseudotypes ([], []) = true
1.76 + | match_pseudotypes (T :: Ts, U :: Us) =
1.77 + match_pseudotype (T, U) andalso match_pseudotypes (Ts, Us)
1.78
1.79 (*Is there a unifiable constant?*)
1.80 -fun const_mem const_tab (c, c_typ) =
1.81 - exists (match_types c_typ) (these (Symtab.lookup const_tab c))
1.82 +fun pseudoconst_mem f const_tab (c, c_typ) =
1.83 + exists (curry (match_pseudotypes o f) c_typ)
1.84 + (these (Symtab.lookup const_tab c))
1.85
1.86 -(*Maps a "real" type to a const_typ*)
1.87 -fun const_typ_of (Type (c,typs)) = CType (c, map const_typ_of typs)
1.88 - | const_typ_of (TFree _) = CTVar
1.89 - | const_typ_of (TVar _) = CTVar
1.90 +fun pseudotype_for (Type (c,typs)) = PType (c, map pseudotype_for typs)
1.91 + | pseudotype_for (TFree _) = PVar
1.92 + | pseudotype_for (TVar _) = PVar
1.93 +(* Pairs a constant with the list of its type instantiations. *)
1.94 +fun pseudoconst_for thy (c, T) =
1.95 + (c, map pseudotype_for (Sign.const_typargs thy (c, T)))
1.96 + handle TYPE _ => (c, []) (* Variable (locale constant): monomorphic *)
1.97
1.98 -(*Pairs a constant with the list of its type instantiations (using const_typ)*)
1.99 -fun const_with_typ thy (c,typ) =
1.100 - let val tvars = Sign.const_typargs thy (c,typ) in
1.101 - (c, map const_typ_of tvars) end
1.102 - handle TYPE _ => (c, []) (*Variable (locale constant): monomorphic*)
1.103 +fun string_for_pseudoconst (s, []) = s
1.104 + | string_for_pseudoconst (s, Ts) = s ^ string_for_pseudotypes Ts
1.105 +fun string_for_super_pseudoconst (s, [[]]) = s
1.106 + | string_for_super_pseudoconst (s, Tss) =
1.107 + s ^ "{" ^ commas (map string_for_pseudotypes Tss) ^ "}"
1.108
1.109 (*Add a const/type pair to the table, but a [] entry means a standard connective,
1.110 which we ignore.*)
1.111 @@ -86,7 +109,7 @@
1.112
1.113 fun is_formula_type T = (T = HOLogic.boolT orelse T = propT)
1.114
1.115 -val fresh_prefix = "Sledgehammer.FRESH."
1.116 +val fresh_prefix = "Sledgehammer.skolem."
1.117 val flip = Option.map not
1.118 (* These are typically simplified away by "Meson.presimplify". *)
1.119 val boring_consts =
1.120 @@ -99,7 +122,7 @@
1.121 introduce a fresh constant to simulate the effect of Skolemization. *)
1.122 fun do_term t =
1.123 case t of
1.124 - Const x => add_const_to_table (const_with_typ thy x)
1.125 + Const x => add_const_to_table (pseudoconst_for thy x)
1.126 | Free (s, _) => add_const_to_table (s, [])
1.127 | t1 $ t2 => fold do_term [t1, t2]
1.128 | Abs (_, _, t') => do_term t'
1.129 @@ -166,23 +189,23 @@
1.130
1.131 (* A two-dimensional symbol table counts frequencies of constants. It's keyed
1.132 first by constant name and second by its list of type instantiations. For the
1.133 - latter, we need a linear ordering on "const_typ list". *)
1.134 + latter, we need a linear ordering on "pseudotype list". *)
1.135
1.136 -fun const_typ_ord p =
1.137 +fun pseudotype_ord p =
1.138 case p of
1.139 - (CTVar, CTVar) => EQUAL
1.140 - | (CTVar, CType _) => LESS
1.141 - | (CType _, CTVar) => GREATER
1.142 - | (CType q1, CType q2) =>
1.143 - prod_ord fast_string_ord (dict_ord const_typ_ord) (q1, q2)
1.144 + (PVar, PVar) => EQUAL
1.145 + | (PVar, PType _) => LESS
1.146 + | (PType _, PVar) => GREATER
1.147 + | (PType q1, PType q2) =>
1.148 + prod_ord fast_string_ord (dict_ord pseudotype_ord) (q1, q2)
1.149
1.150 structure CTtab =
1.151 - Table(type key = const_typ list val ord = dict_ord const_typ_ord)
1.152 + Table(type key = pseudotype list val ord = dict_ord pseudotype_ord)
1.153
1.154 fun count_axiom_consts theory_relevant thy (_, th) =
1.155 let
1.156 fun do_const (a, T) =
1.157 - let val (c, cts) = const_with_typ thy (a, T) in
1.158 + let val (c, cts) = pseudoconst_for thy (a, T) in
1.159 (* Two-dimensional table update. Constant maps to types maps to
1.160 count. *)
1.161 CTtab.map_default (cts, 0) (Integer.add 1)
1.162 @@ -199,8 +222,8 @@
1.163 (**** Actual Filtering Code ****)
1.164
1.165 (*The frequency of a constant is the sum of those of all instances of its type.*)
1.166 -fun const_frequency const_tab (c, cts) =
1.167 - CTtab.fold (fn (cts', m) => match_types cts cts' ? Integer.add m)
1.168 +fun pseudoconst_freq match const_tab (c, cts) =
1.169 + CTtab.fold (fn (cts', m) => match (cts, cts') ? Integer.add m)
1.170 (the (Symtab.lookup const_tab c)) 0
1.171 handle Option.Option => 0
1.172
1.173 @@ -214,29 +237,22 @@
1.174 fun irrel_log (x : real) = Math.ln (x + 19.0) / 6.4
1.175
1.176 (* Computes a constant's weight, as determined by its frequency. *)
1.177 -val rel_const_weight = rel_log o real oo const_frequency
1.178 -val irrel_const_weight = irrel_log o real oo const_frequency
1.179 -(* fun irrel_const_weight _ _ = 1.0 FIXME: OLD CODE *)
1.180 +val rel_weight = rel_log o real oo pseudoconst_freq match_pseudotypes
1.181 +val irrel_weight =
1.182 + irrel_log o real oo pseudoconst_freq (match_pseudotypes o swap)
1.183 +(* fun irrel_weight _ _ = 1.0 FIXME: OLD CODE *)
1.184
1.185 fun axiom_weight const_tab relevant_consts axiom_consts =
1.186 - let
1.187 - val (rel, irrel) = List.partition (const_mem relevant_consts) axiom_consts
1.188 - val rel_weight = fold (curry Real.+ o rel_const_weight const_tab) rel 0.0
1.189 - val irrel_weight = fold (curry Real.+ o irrel_const_weight const_tab) irrel 0.0
1.190 - val res = rel_weight / (rel_weight + irrel_weight)
1.191 - in if Real.isFinite res then res else 0.0 end
1.192 -
1.193 -(* OLD CODE:
1.194 -(*Relevant constants are weighted according to frequency,
1.195 - but irrelevant constants are simply counted. Otherwise, Skolem functions,
1.196 - which are rare, would harm a formula's chances of being picked.*)
1.197 -fun axiom_weight const_tab relevant_consts axiom_consts =
1.198 - let
1.199 - val rel = filter (const_mem relevant_consts) axiom_consts
1.200 - val rel_weight = fold (curry Real.+ o rel_const_weight const_tab) rel 0.0
1.201 - val res = rel_weight / (rel_weight + real (length axiom_consts - length rel))
1.202 - in if Real.isFinite res then res else 0.0 end
1.203 -*)
1.204 + case axiom_consts |> List.partition (pseudoconst_mem I relevant_consts)
1.205 + ||> filter_out (pseudoconst_mem swap relevant_consts) of
1.206 + ([], []) => 0.0
1.207 + | (_, []) => 1.0
1.208 + | (rel, irrel) =>
1.209 + let
1.210 + val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
1.211 + val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
1.212 + val res = rel_weight / (rel_weight + irrel_weight)
1.213 + in if Real.isFinite res then res else 0.0 end
1.214
1.215 fun consts_of_term thy t =
1.216 Symtab.fold (fn (x, ys) => fold (fn y => cons (x, y)) ys)
1.217 @@ -247,83 +263,82 @@
1.218 |> consts_of_term thy)
1.219
1.220 type annotated_thm =
1.221 - ((unit -> string * bool) * thm) * (string * const_typ list) list
1.222 + ((unit -> string * bool) * thm) * (string * pseudotype list) list
1.223
1.224 -(*For a reverse sort, putting the largest values first.*)
1.225 -fun compare_pairs ((_, w1), (_, w2)) = Real.compare (w2, w1)
1.226 +fun rev_compare_pairs ((_, w1), (_, w2)) = Real.compare (w2, w1)
1.227
1.228 -(* Limit the number of new facts, to prevent runaway acceptance. *)
1.229 -fun take_best max_relevant_per_iter (new_pairs : (annotated_thm * real) list) =
1.230 - let val nnew = length new_pairs in
1.231 - if nnew <= max_relevant_per_iter then
1.232 - (map #1 new_pairs, [])
1.233 - else
1.234 - let
1.235 - val new_pairs = sort compare_pairs new_pairs
1.236 - val accepted = List.take (new_pairs, max_relevant_per_iter)
1.237 - in
1.238 - trace_msg (fn () => ("Number of candidates, " ^ Int.toString nnew ^
1.239 - ", exceeds the limit of " ^ Int.toString max_relevant_per_iter));
1.240 - trace_msg (fn () => ("Effective pass mark: " ^ Real.toString (#2 (List.last accepted))));
1.241 - trace_msg (fn () => "Actually passed: " ^
1.242 - space_implode ", " (map (fst o (fn f => f ()) o fst o fst o fst) accepted));
1.243 - (map #1 accepted, List.drop (new_pairs, max_relevant_per_iter))
1.244 - end
1.245 - end;
1.246 +fun take_best max (new_pairs : (annotated_thm * real) list) =
1.247 + let
1.248 + val ((perfect, more_perfect), imperfect) =
1.249 + new_pairs |> List.partition (fn (_, w) => w > 0.99999)
1.250 + |>> chop (max - 1) ||> sort rev_compare_pairs
1.251 + val (accepted, rejected) =
1.252 + case more_perfect @ imperfect of
1.253 + [] => (perfect, [])
1.254 + | (q :: qs) => (q :: perfect, qs)
1.255 + in
1.256 + trace_msg (fn () => "Number of candidates: " ^
1.257 + string_of_int (length new_pairs));
1.258 + trace_msg (fn () => "Effective threshold: " ^
1.259 + Real.toString (#2 (hd accepted)));
1.260 + trace_msg (fn () => "Actually passed: " ^
1.261 + (accepted |> map (fn (((name, _), _), weight) =>
1.262 + fst (name ()) ^ " [" ^ Real.toString weight ^ "]")
1.263 + |> commas));
1.264 + (map #1 accepted, rejected)
1.265 + end
1.266
1.267 val threshold_divisor = 2.0
1.268 val ridiculous_threshold = 0.1
1.269
1.270 -fun relevance_filter ctxt relevance_threshold relevance_decay
1.271 - max_relevant_per_iter theory_relevant
1.272 - ({add, del, ...} : relevance_override) axioms goal_ts =
1.273 +fun relevance_filter ctxt relevance_threshold relevance_decay max_relevant
1.274 + theory_relevant ({add, del, ...} : relevance_override)
1.275 + axioms goal_ts =
1.276 let
1.277 val thy = ProofContext.theory_of ctxt
1.278 val const_tab = fold (count_axiom_consts theory_relevant thy) axioms
1.279 Symtab.empty
1.280 - val goal_const_tab = get_consts thy (SOME false) goal_ts
1.281 - val _ =
1.282 - trace_msg (fn () => "Initial constants: " ^
1.283 - commas (goal_const_tab |> Symtab.dest
1.284 - |> filter (curry (op <>) [] o snd)
1.285 - |> map fst))
1.286 val add_thms = maps (ProofContext.get_fact ctxt) add
1.287 val del_thms = maps (ProofContext.get_fact ctxt) del
1.288 - fun iter j threshold rel_const_tab =
1.289 + fun iter j max threshold rel_const_tab rest =
1.290 let
1.291 + fun game_over rejects =
1.292 + if j = 0 andalso threshold >= ridiculous_threshold then
1.293 + (* First iteration? Try again. *)
1.294 + iter 0 max (threshold / threshold_divisor) rel_const_tab rejects
1.295 + else
1.296 + (* Add "add:" facts. *)
1.297 + if null add_thms then
1.298 + []
1.299 + else
1.300 + map_filter (fn ((p as (_, th), _), _) =>
1.301 + if member Thm.eq_thm add_thms th then SOME p
1.302 + else NONE) rejects
1.303 fun relevant ([], rejects) [] =
1.304 - (* Nothing was added this iteration. *)
1.305 - if j = 0 andalso threshold >= ridiculous_threshold then
1.306 - (* First iteration? Try again. *)
1.307 - iter 0 (threshold / threshold_divisor) rel_const_tab
1.308 - (map (apsnd SOME) rejects)
1.309 - else
1.310 - (* Add "add:" facts. *)
1.311 - if null add_thms then
1.312 - []
1.313 - else
1.314 - map_filter (fn ((p as (_, th), _), _) =>
1.315 - if member Thm.eq_thm add_thms th then SOME p
1.316 - else NONE) rejects
1.317 + (* Nothing has been added this iteration. *)
1.318 + game_over (map (apsnd SOME) rejects)
1.319 | relevant (new_pairs, rejects) [] =
1.320 let
1.321 - val (new_rels, more_rejects) =
1.322 - take_best max_relevant_per_iter new_pairs
1.323 + val (new_rels, more_rejects) = take_best max new_pairs
1.324 val rel_const_tab' =
1.325 rel_const_tab |> fold add_const_to_table (maps snd new_rels)
1.326 - fun is_dirty c =
1.327 - const_mem rel_const_tab' c andalso
1.328 - not (const_mem rel_const_tab c)
1.329 + fun is_dirty (c, _) =
1.330 + Symtab.lookup rel_const_tab' c <> Symtab.lookup rel_const_tab c
1.331 val rejects =
1.332 more_rejects @ rejects
1.333 |> map (fn (ax as (_, consts), old_weight) =>
1.334 (ax, if exists is_dirty consts then NONE
1.335 else SOME old_weight))
1.336 val threshold = threshold + (1.0 - threshold) * relevance_decay
1.337 + val max = max - length new_rels
1.338 in
1.339 - trace_msg (fn () => "relevant this iteration: " ^
1.340 - Int.toString (length new_rels));
1.341 - map #1 new_rels @ iter (j + 1) threshold rel_const_tab' rejects
1.342 + trace_msg (fn () => "New or updated constants: " ^
1.343 + commas (rel_const_tab' |> Symtab.dest
1.344 + |> subtract (op =) (Symtab.dest rel_const_tab)
1.345 + |> map string_for_super_pseudoconst));
1.346 + map #1 new_rels @
1.347 + (if max = 0 then game_over rejects
1.348 + else iter (j + 1) max threshold rel_const_tab' rejects)
1.349 end
1.350 | relevant (new_rels, rejects)
1.351 (((ax as ((name, th), axiom_consts)), cached_weight)
1.352 @@ -335,26 +350,29 @@
1.353 | NONE => axiom_weight const_tab rel_const_tab axiom_consts
1.354 in
1.355 if weight >= threshold then
1.356 - (trace_msg (fn () =>
1.357 - fst (name ()) ^ " passes: " ^ Real.toString weight
1.358 - ^ " consts: " ^ commas (map fst axiom_consts));
1.359 - relevant ((ax, weight) :: new_rels, rejects) rest)
1.360 + relevant ((ax, weight) :: new_rels, rejects) rest
1.361 else
1.362 relevant (new_rels, (ax, weight) :: rejects) rest
1.363 end
1.364 in
1.365 - trace_msg (fn () => "relevant_facts, current threshold: " ^
1.366 - Real.toString threshold);
1.367 - relevant ([], [])
1.368 + trace_msg (fn () =>
1.369 + "ITERATION " ^ string_of_int j ^ ": current threshold: " ^
1.370 + Real.toString threshold ^ ", constants: " ^
1.371 + commas (rel_const_tab |> Symtab.dest
1.372 + |> filter (curry (op <>) [] o snd)
1.373 + |> map string_for_super_pseudoconst));
1.374 + relevant ([], []) rest
1.375 end
1.376 in
1.377 axioms |> filter_out (member Thm.eq_thm del_thms o snd)
1.378 |> map (rpair NONE o pair_consts_axiom theory_relevant thy)
1.379 - |> iter 0 relevance_threshold goal_const_tab
1.380 + |> iter 0 max_relevant relevance_threshold
1.381 + (get_consts thy (SOME false) goal_ts)
1.382 |> tap (fn res => trace_msg (fn () =>
1.383 "Total relevant: " ^ Int.toString (length res)))
1.384 end
1.385
1.386 +
1.387 (***************************************************************)
1.388 (* Retrieving and filtering lemmas *)
1.389 (***************************************************************)
1.390 @@ -547,14 +565,7 @@
1.391 val name2 = Name_Space.extern full_space name0
1.392 in
1.393 case find_first check_thms [name1, name2, name0] of
1.394 - SOME name =>
1.395 - let
1.396 - val name =
1.397 - name |> Symtab.defined reserved name ? quote
1.398 - in
1.399 - if multi then name ^ "(" ^ Int.toString j ^ ")"
1.400 - else name
1.401 - end
1.402 + SOME name => repair_name reserved multi j name
1.403 | NONE => ""
1.404 end, is_chained th), (multi, th)) :: rest)) ths
1.405 #> snd
1.406 @@ -567,25 +578,26 @@
1.407 (* The single-name theorems go after the multiple-name ones, so that single
1.408 names are preferred when both are available. *)
1.409 fun name_thm_pairs ctxt respect_no_atp =
1.410 - List.partition (fst o snd) #> op @
1.411 - #> map (apsnd snd)
1.412 + List.partition (fst o snd) #> op @ #> map (apsnd snd)
1.413 #> respect_no_atp ? filter_out (No_ATPs.member ctxt o snd)
1.414
1.415 (***************************************************************)
1.416 (* ATP invocation methods setup *)
1.417 (***************************************************************)
1.418
1.419 -fun relevant_facts full_types relevance_threshold relevance_decay
1.420 - max_relevant_per_iter theory_relevant
1.421 - (relevance_override as {add, del, only})
1.422 +fun relevant_facts full_types relevance_threshold relevance_decay max_relevant
1.423 + theory_relevant (relevance_override as {add, del, only})
1.424 (ctxt, (chained_ths, _)) hyp_ts concl_t =
1.425 let
1.426 + val relevance_decay =
1.427 + case relevance_decay of
1.428 + SOME x => x
1.429 + | NONE => 0.35 / Math.ln (Real.fromInt (max_relevant + 1))
1.430 val add_thms = maps (ProofContext.get_fact ctxt) add
1.431 val reserved = reserved_isar_keyword_table ()
1.432 val axioms =
1.433 (if only then
1.434 - maps ((fn (n, ths) => map (pair n o pair false) ths)
1.435 - o name_thms_pair_from_ref ctxt reserved chained_ths) add
1.436 + maps (name_thm_pairs_from_ref ctxt reserved chained_ths) add
1.437 else
1.438 all_name_thms_pairs ctxt reserved full_types add_thms chained_ths)
1.439 |> name_thm_pairs ctxt (respect_no_atp andalso not only)
1.440 @@ -598,11 +610,10 @@
1.441 else if relevance_threshold < 0.0 then
1.442 axioms
1.443 else
1.444 - relevance_filter ctxt relevance_threshold relevance_decay
1.445 - max_relevant_per_iter theory_relevant relevance_override
1.446 - axioms (concl_t :: hyp_ts))
1.447 - |> map (apfst (fn f => f ()))
1.448 - |> sort_wrt (fst o fst)
1.449 + relevance_filter ctxt relevance_threshold relevance_decay max_relevant
1.450 + theory_relevant relevance_override axioms
1.451 + (concl_t :: hyp_ts))
1.452 + |> map (apfst (fn f => f ())) |> sort_wrt (fst o fst)
1.453 end
1.454
1.455 end;