1.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri Jul 20 22:19:45 2012 +0200
1.3 @@ -0,0 +1,691 @@
1.4 +(* Title: HOL/Tools/Sledgehammer/sledgehammer_mash.ML
1.5 + Author: Jasmin Blanchette, TU Muenchen
1.6 +
1.7 +Sledgehammer's machine-learning-based relevance filter (MaSh).
1.8 +*)
1.9 +
1.10 +signature SLEDGEHAMMER_FILTER_MASH = (* interface of the MaSh-based relevance filter *)
1.11 +sig
1.12 + type status = ATP_Problem_Generate.status
1.13 + type stature = ATP_Problem_Generate.stature
1.14 + type fact = Sledgehammer_Fact.fact
1.15 + type fact_override = Sledgehammer_Fact.fact_override
1.16 + type params = Sledgehammer_Provers.params
1.17 + type relevance_fudge = Sledgehammer_Provers.relevance_fudge
1.18 + type prover_result = Sledgehammer_Provers.prover_result
1.19 +
1.20 + val trace : bool Config.T (* option "sledgehammer_mash_trace" *)
1.21 + val MaShN : string (* name under which learner threads are registered *)
1.22 + val mepoN : string
1.23 + val mashN : string
1.24 + val meshN : string
1.25 + val fact_filters : string list (* the fact filter names accepted by relevant_facts *)
1.26 + val escape_meta : string -> string (* escaping used in the files exchanged with MaSh *)
1.27 + val escape_metas : string list -> string
1.28 + val unescape_meta : string -> string
1.29 + val unescape_metas : string -> string list
1.30 + val extract_query : string -> string * string list (* parses a "name: s1 s2 ..." line *)
1.31 + val nickname_of : thm -> string
1.32 + val suggested_facts : string list -> ('a * thm) list -> ('a * thm) list
1.33 + val mesh_facts :
1.34 + int -> (('a * thm) list * ('a * thm) list) list -> ('a * thm) list
1.35 + val is_likely_tautology_or_too_meta : thm -> bool
1.36 + val theory_ord : theory * theory -> order
1.37 + val thm_ord : thm * thm -> order
1.38 + val features_of :
1.39 + Proof.context -> string -> theory -> status -> term list -> string list
1.40 + val isabelle_dependencies_of : unit Symtab.table -> thm -> string list
1.41 + val goal_of_thm : theory -> thm -> thm
1.42 + val run_prover_for_mash :
1.43 + Proof.context -> params -> string -> fact list -> thm -> prover_result
1.44 + val mash_CLEAR : Proof.context -> unit (* low-level commands; the bool is "overlord" *)
1.45 + val mash_INIT :
1.46 + Proof.context -> bool
1.47 + -> (string * string list * string list * string list) list -> unit
1.48 + val mash_ADD :
1.49 + Proof.context -> bool
1.50 + -> (string * string list * string list * string list) list -> unit
1.51 + val mash_QUERY :
1.52 + Proof.context -> bool -> int -> string list * string list -> string list
1.53 + val mash_unlearn : Proof.context -> unit (* high-level operations on the learned state *)
1.54 + val mash_could_suggest_facts : unit -> bool
1.55 + val mash_can_suggest_facts : Proof.context -> bool
1.56 + val mash_suggest_facts :
1.57 + Proof.context -> params -> string -> int -> term list -> term
1.58 + -> ('a * thm) list -> ('a * thm) list * ('a * thm) list
1.59 + val mash_learn_thy :
1.60 + Proof.context -> params -> theory -> Time.time -> fact list -> string
1.61 + val mash_learn_proof :
1.62 + Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit
1.63 + val relevant_facts : (* entry point: select facts with the configured filter *)
1.64 + Proof.context -> params -> string -> int -> fact_override -> term list
1.65 + -> term -> fact list -> fact list
1.66 + val kill_learners : unit -> unit
1.67 + val running_learners : unit -> unit
1.68 +end;
1.69 +
1.70 +structure Sledgehammer_Filter_MaSh : SLEDGEHAMMER_FILTER_MASH =
1.71 +struct
1.72 +
1.73 +open ATP_Util
1.74 +open ATP_Problem_Generate
1.75 +open Sledgehammer_Util
1.76 +open Sledgehammer_Fact
1.77 +open Sledgehammer_Filter_Iter
1.78 +open Sledgehammer_Provers
1.79 +open Sledgehammer_Minimize
1.80 +
1.81 +val trace = (* Boolean option "sledgehammer_mash_trace", initially false *)
1.82 +  Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (fn _ => false)
1.83 +fun trace_msg ctxt msg = (case Config.get ctxt trace of true => tracing (msg ()) | false => ())
1.84 +
1.85 +val MaShN = "MaSh" (* tool name passed to Async_Manager for the learner threads *)
1.86 +
1.87 +(* Names of the individual fact filters. *)
1.88 +val (mepoN, mashN, meshN) = ("mepo", "mash", "mesh")
1.89 +
1.90 +(* All fact filter names that relevant_facts accepts. *)
1.91 +val fact_filters = meshN :: mepoN :: [mashN]
1.92 +
1.93 +fun mash_home () = getenv "MASH_HOME" (* value of the MASH_HOME setting *)
1.94 +fun mash_state_dir () = (* "$ISABELLE_HOME_USER/mash"; Isabelle_System.mkdir is called on it *)
1.95 +  let val dir = getenv "ISABELLE_HOME_USER" ^ "/mash"
1.96 +  in Isabelle_System.mkdir (Path.explode dir); dir end
1.97 +fun mash_state_path () = Path.explode (mash_state_dir () ^ "/state")
1.98 +
1.99 +
1.100 +(*** Isabelle helpers ***)
1.101 +
1.102 +(* Escape one character: alphanumerics and the characters of "_.()," pass through; *)
1.103 +(* anything else becomes a backslash followed by "stringN_of_int 3" of its code. *)
1.104 +fun meta_char c =
1.105 +  let val plain = Char.isAlphaNum c orelse Char.contains "_.()," c in
1.106 +    (* the code is fixed width, in case more digits follow *)
1.107 +    if plain then String.str c else "\\" ^ stringN_of_int 3 (Char.ord c)
1.108 +  end
1.109 +
1.110 +(* Decode backslash escapes in a character list; a malformed escape yields "". *)
1.111 +fun unmeta_chars accum chars =
1.112 +  case chars of [] => String.implode (rev accum)
1.113 +  | #"\\" :: d1 :: d2 :: d3 :: rest =>
1.114 +    (case Int.fromString (String.implode [d1, d2, d3]) of
1.115 +       NONE => "" (* error *) | SOME n => unmeta_chars (Char.chr n :: accum) rest)
1.116 +  | c :: rest => if c = #"\\" then "" (* error *) else unmeta_chars (c :: accum) rest
1.117 +
1.118 +fun escape_meta s = String.translate meta_char s
1.119 +fun escape_metas ss = space_implode " " (map escape_meta ss)
1.120 +fun unescape_meta s = unmeta_chars [] (String.explode s)
1.121 +fun unescape_metas s = (* split at spaces, drop empty pieces, decode each piece *)
1.122 +  map unescape_meta (filter_out (fn piece => "" = piece) (space_explode " " s))
1.123 +
1.124 +fun extract_query line = (* "name:s1 s2 ..." => (name, [s1, s2, ...]), both unescaped *)
1.125 +  let val fields = space_explode ":" line in
1.126 +    if length fields <> 2 then ("", []) (* not exactly two ":"-separated fields *)
1.127 +    else (unescape_meta (hd fields), unescape_metas (hd (tl fields))) end
1.128 +
1.129 +(* Name of the static fact in th's theory with the largest name-space #id. *)
1.130 +fun parent_of_local_thm th =
1.131 +  let
1.132 +    val facts = Global_Theory.facts_of (Thm.theory_of_thm th)
1.133 +    val space = Facts.space_of facts
1.134 +    fun id_of s = #id (Name_Space.the_entry space s)
1.135 +    fun max_id (s', _) (best as (_, id)) =
1.136 +      let val id' = id_of s' in if id > id' then best else (s', id') end
1.137 +  in fst (Facts.fold_static max_id facts ("", ~1)) end
1.138 +
1.139 +val local_prefix = "local" ^ Long_Name.separator
1.140 +
1.141 +(* Nickname of a theorem: its name hint, except that a hint starting with *)
1.142 +(* local_prefix is re-rooted under "parent_of_local_thm th". *)
1.143 +(* FIXME: There must be a better way to detect local facts. *)
1.144 +fun nickname_of th =
1.145 +  let val hint = Thm.get_name_hint th in
1.146 +    case try (unprefix local_prefix) hint of NONE => hint
1.147 +    | SOME suff => parent_of_local_thm th ^ Long_Name.separator ^ Long_Name.separator ^ suff
1.148 +  end
1.149 +
1.150 +(* Look up each suggested nickname among facts, keeping the order of suggs. *)
1.151 +fun suggested_facts suggs facts =
1.152 +  let
1.153 +    val by_nickname = fold (fn fact => Symtab.default (nickname_of (snd fact), fact)) facts Symtab.empty
1.154 +  in map_filter (fn sugg => Symtab.lookup by_nickname sugg) suggs end
1.155 +
1.156 +(* Ad hoc score of a rank x, roughly based on Blanchette's Ringberg 2011 data. *)
1.157 +fun score x = let val e = 15.5 - 0.05 * Real.fromInt x in Math.pow (1.5, e) + 15.0 end
1.158 +
1.159 +fun sum_sq_avg xs = (* scaled integer average of the scores of xs; 0 if xs is empty *)
1.160 +  let val total = List.foldl (fn (x, acc) => score x + acc) 0.0 xs
1.161 +  in if null xs then 0 else Real.ceil (100000.0 * total) div length xs end
1.162 +
1.163 +fun mesh_facts max_facts [(selected, unknown)] = (* a single filter result: nothing to mesh *)
1.164 + take max_facts selected @ take (max_facts - length selected) unknown
1.165 + | mesh_facts max_facts mess = (* otherwise: rank the candidates by their averaged scores *)
1.166 + let
1.167 + val mess = mess |> map (apfst (`length)) (* pair each selection with its length *)
1.168 + val fact_eq = Thm.eq_thm o pairself snd
1.169 + fun score_in fact ((sel_len, sels), unks) = (* rank of fact in one filter's result, if any *)
1.170 + case find_index (curry fact_eq fact) sels of
1.171 + ~1 => (case find_index (curry fact_eq fact) unks of
1.172 + ~1 => SOME sel_len (* neither selected nor unknown: rank just past the selection *)
1.173 + | _ => NONE) (* among the unknown facts: this result contributes no rank *)
1.174 + | j => SOME j
1.175 + fun score_of fact = mess |> map_filter (score_in fact) |> sum_sq_avg
1.176 + val facts = fold (union fact_eq o take max_facts o snd o fst) mess [] (* candidate facts *)
1.177 + in
1.178 + facts |> map (`score_of) |> sort (int_ord o swap o pairself fst) (* highest score first *)
1.179 + |> map snd |> take max_facts
1.180 + end
1.181 +
1.182 +val thy_feature_prefix = "y_" (* prepended to theory names to form features *)
1.183 +
1.184 +fun thy_feature_name_of s = thy_feature_prefix ^ s
1.185 +fun const_name_of s = const_prefix ^ s
1.186 +fun type_name_of s = type_const_prefix ^ s
1.187 +fun class_name_of s = class_prefix ^ s
1.188 +
1.189 +(* Heuristic: holds if the Boolean parts of th's proposition mention only constants from *)
1.190 +(* atp_widely_irrelevant_consts or a type containing "prop"; "ext" is always exempt. *)
1.191 +fun is_likely_tautology_or_too_meta th =
1.192 +  let
1.193 +    val is_boring_const = member (op =) atp_widely_irrelevant_consts
1.194 +    fun is_boring_bool t =
1.195 +      not (exists_Const (not o is_boring_const o fst) t) orelse
1.196 +      exists_type (exists_subtype (curry (op =) @{typ prop})) t
1.197 +    fun is_boring_prop (@{const Trueprop} $ t) = is_boring_bool t
1.198 +      | is_boring_prop (@{const "==>"} $ t $ u) =
1.199 +        is_boring_prop t andalso is_boring_prop u
1.200 +      (* "all" is unary: its only argument is the abstraction that binds the variable *)
1.201 +      | is_boring_prop (Const (@{const_name all}, _) $ Abs (_, _, t)) = is_boring_prop t
1.202 +      | is_boring_prop (Const (@{const_name "=="}, _) $ t $ u) =
1.203 +        is_boring_bool t andalso is_boring_bool u
1.204 +      | is_boring_prop _ = true (* any other shape is treated as boring *)
1.205 +  in is_boring_prop (prop_of th) andalso not (Thm.eq_thm_prop (@{thm ext}, th)) end
1.206 +
1.207 +(* Order on theories: equality and the sub-theory relation are checked first, *)
1.208 +(* then the number of ancestors, and finally the theory name. *)
1.209 +fun theory_ord (p as (thy1, thy2)) =
1.210 +  let
1.211 +    fun by_size () = int_ord (length (Theory.ancestors_of thy1), length (Theory.ancestors_of thy2))
1.212 +    fun by_name () = string_ord (Context.theory_name thy1, Context.theory_name thy2)
1.213 +  in
1.214 +    if Theory.eq_thy p then EQUAL else if Theory.subthy p then LESS
1.215 +    else if Theory.subthy (thy2, thy1) then GREATER
1.216 +    else (case by_size () of EQUAL => by_name () | order => order) end
1.217 +
1.218 +fun thm_ord (th1, th2) = theory_ord (theory_of_thm th1, theory_of_thm th2) (* order by theory *)
1.219 +
1.220 +val bad_types = [@{type_name prop}, @{type_name bool}, @{type_name fun}] (* never emitted as type features *)
1.221 +
1.222 +fun interesting_terms_types_and_classes ctxt prover term_max_depth (* feature names found in ts *)
1.223 + type_max_depth ts =
1.224 + let
1.225 + fun is_bad_const (x as (s, _)) args = (* logical constants and prover built-ins are skipped *)
1.226 + member (op =) atp_logical_consts s orelse
1.227 + fst (is_built_in_const_for_prover ctxt prover x args)
1.228 + fun add_classes @{sort type} = I (* the sort "type" contributes nothing *)
1.229 + | add_classes S = union (op =) (map class_name_of S)
1.230 + fun do_add_type (Type (s, Ts)) = (* type constructor names, recursively, minus bad_types *)
1.231 + (not (member (op =) bad_types s) ? insert (op =) (type_name_of s))
1.232 + #> fold do_add_type Ts
1.233 + | do_add_type (TFree (_, S)) = add_classes S
1.234 + | do_add_type (TVar (_, S)) = add_classes S
1.235 + fun add_type T = type_max_depth >= 0 ? do_add_type T (* a negative depth disables type features *)
1.236 + fun mk_app s args = (* the argument list is omitted when every argument is "" *)
1.237 + if member (op <>) args "" then s ^ "(" ^ space_implode "," args ^ ")"
1.238 + else s
1.239 + fun patternify ~1 _ = "" (* depth exhausted *)
1.240 + | patternify depth t =
1.241 + case strip_comb t of
1.242 + (Const (s, _), args) =>
1.243 + mk_app (const_name_of s) (map (patternify (depth - 1)) args)
1.244 + | _ => "" (* heads other than constants yield no pattern *)
1.245 + fun add_term_patterns ~1 _ = I
1.246 + | add_term_patterns depth t = (* one pattern for each depth from the given one down to 0 *)
1.247 + insert (op =) (patternify depth t)
1.248 + #> add_term_patterns (depth - 1) t
1.249 + val add_term = add_term_patterns term_max_depth
1.250 + fun add_patterns t = (* handle the head of t, then recurse into its arguments *)
1.251 + let val (head, args) = strip_comb t in
1.252 + (case head of
1.253 + Const (x as (_, T)) =>
1.254 + not (is_bad_const x args) ? (add_term t #> add_type T)
1.255 + | Free (_, T) => add_type T
1.256 + | Var (_, T) => add_type T
1.257 + | Abs (_, T, body) => add_type T #> add_patterns body
1.258 + | _ => I)
1.259 + #> fold add_patterns args
1.260 + end
1.261 + in [] |> fold add_patterns ts end
1.262 +
1.263 +fun is_exists (s, _) = member (op =) [@{const_name Ex}, @{const_name Ex1}] s (* Ex or Ex1 *)
1.264 +
1.265 +(* Depth bounds that features_of passes to interesting_terms_types_and_classes. *)
1.266 +val (term_max_depth, type_max_depth) = (1, 1)
1.267 +
1.268 +(* TODO: Generate type classes for types? *)
1.269 +(* Features of ts, in this order: optional status tag, "no_skos" and "no_lams" flags, *)
1.270 +(* the theory-name feature, then the interesting terms, types, and classes. *)
1.271 +fun features_of ctxt prover thy status ts =
1.272 +  let
1.273 +    val base =
1.274 +      thy_feature_name_of (Context.theory_name thy) ::
1.275 +      interesting_terms_types_and_classes ctxt prover term_max_depth type_max_depth ts
1.276 +    val lams = if forall is_lambda_free ts then ["no_lams"] else []
1.277 +    val skos = if forall (not o exists_Const is_exists) ts then ["no_skos"] else []
1.278 +    val tag =
1.279 +      (case status of General => [] | Induction => ["induction"] | Intro => ["intro"]
1.280 +       | Inductive => ["inductive"] | Elim => ["elim"] | Simp => ["simp"] | Def => ["def"])
1.281 +  in tag @ skos @ lams @ base
1.282 +  end
1.283 +
1.284 +fun isabelle_dependencies_of all_facts th = thms_in_proof (SOME all_facts) th (* via thms_in_proof *)
1.285 +
1.286 +fun freezeT T = Type.legacy_freeze_type T
1.287 +
1.288 +(* Turn schematic variables into frees of the same base name; freezeT is applied to all types. *)
1.289 +fun freeze tm =
1.290 +  case tm of
1.291 +    t $ u => freeze t $ freeze u | Abs (s, T, t) => Abs (s, freezeT T, freeze t)
1.292 +  | Var ((s, _), T) => Free (s, freezeT T) | Const (s, T) => Const (s, freezeT T)
1.293 +  | Free (s, T) => Free (s, freezeT T) | _ => tm
1.294 +
1.295 +fun goal_of_thm thy th = Goal.init (cterm_of thy (freeze (prop_of th))) (* frozen prop as goal *)
1.296 +
1.297 +(* Run the prover obtained from get_minimizing_prover on goal, passing facts untranslated. *)
1.298 +fun run_prover_for_mash ctxt params prover facts goal =
1.299 +  let
1.300 +    fun force_name ((name, x), th) = Untranslated_Fact ((name (), x), th)
1.301 +    val problem =
1.302 +      {state = Proof.init ctxt, goal = goal, subgoal = 1, subgoal_count = 1, facts = map force_name facts}
1.303 +    val run = get_minimizing_prover ctxt Normal (fn _ => ()) prover
1.304 +  in run params (fn _ => fn _ => fn _ => "") problem end
1.305 +
1.306 +
1.307 +(*** Low-level communication with MaSh ***)
1.308 +
1.309 +(* Write "" to file, then append "f x" for all xs, one chunk of chunk_list 500 at a time. *)
1.310 +fun write_file (xs, f) file =
1.311 +  let
1.312 +    val path = Path.explode file
1.313 +    fun append_chunk chunk = File.append path (space_implode "" (map f chunk))
1.314 +  in File.write path ""; List.app append_chunk (chunk_list 500 xs) end
1.315 +
1.316 +(* (directory, file-name suffix) used for the files exchanged with MaSh. *)
1.317 +fun mash_info true = (getenv "ISABELLE_HOME_USER", "")
1.318 +  | mash_info false = (getenv "ISABELLE_TMP", serial_string ())
1.319 +
1.320 +(* Run "mash.py" with the arguments in core, logging to "mash_log<serial>" and redirecting *)
1.321 +(* its output to "mash_err<serial>" in temp_dir; both are removed unless overlord is set. *)
1.322 +fun run_mash ctxt overlord (temp_dir, serial) core =
1.323 +  let
1.324 +    val log_file = temp_dir ^ "/mash_log" ^ serial
1.325 +    val err_file = temp_dir ^ "/mash_err" ^ serial
1.326 +    (* "> f 2>&1" rather than "2>&1 > f", so that stderr ends up in err_file as well *)
1.327 +    val command =
1.328 +      mash_home () ^ "/mash.py --quiet --outputDir " ^ mash_state_dir () ^
1.329 +      " --log " ^ log_file ^ " " ^ core ^ " > " ^ err_file ^ " 2>&1"
1.330 +  in
1.331 +    trace_msg ctxt (fn () => "Running " ^ command);
1.332 +    write_file ([], K "") log_file; write_file ([], K "") err_file;
1.333 +    Isabelle_System.bash command;
1.334 +    if overlord then () else List.app (File.rm o Path.explode) [log_file, err_file];
1.335 +    trace_msg ctxt (K "Done") end
1.336 +
1.337 +(* TODO: Eliminate code once "mash.py" handles sequences of ADD commands as fast
1.338 + as a single INIT. *)
1.339 +(* Write the three input files into a "mash_init" directory and run MaSh with "--init" *)
1.340 +(* on it; the directory is removed afterwards unless overlord is set. *)
1.341 +fun run_mash_init ctxt overlord write_access write_feats write_deps =
1.342 +  let
1.343 +    val info = mash_info overlord
1.344 +    val in_dir = fst info ^ "/mash_init" ^ snd info
1.345 +    val in_dir_path = Path.explode in_dir
1.346 +    fun write data name = write_file data (in_dir ^ name)
1.347 +  in
1.348 +    Isabelle_System.mkdir in_dir_path;
1.349 +    write write_access "/mash_accessibility"; write write_feats "/mash_features";
1.350 +    write write_deps "/mash_dependencies"; run_mash ctxt overlord info ("--init --inputDir " ^ in_dir);
1.351 +    (* FIXME: temporary hack *)
1.352 +    if overlord then () else ignore (Isabelle_System.bash ("rm -r -f " ^ File.shell_path in_dir_path)) end
1.353 +
1.354 +(* Write the commands to a file, run MaSh on it, and pass a reader for the suggestion *)
1.355 +(* file to read_suggs, whose result is returned. Unless overlord is set, the command *)
1.356 +(* and suggestion files are removed once read_suggs has returned. *)
1.357 +fun run_mash_commands ctxt overlord save max_suggs write_cmds read_suggs =
1.358 +  let
1.359 +    val info as (temp_dir, serial) = mash_info overlord
1.360 +    val sugg_file = temp_dir ^ "/mash_suggs" ^ serial
1.361 +    val cmd_file = temp_dir ^ "/mash_commands" ^ serial
1.362 +    val save_opt = if save then " --saveModel" else ""
1.363 +    val core =
1.364 +      "--inputFile " ^ cmd_file ^ " --predictions " ^ sugg_file ^
1.365 +      " --numberOfPredictions " ^ string_of_int max_suggs ^ save_opt
1.366 +    fun cleanup () = if overlord then () else List.app (File.rm o Path.explode) [sugg_file, cmd_file]
1.367 +  in
1.368 +    write_file ([], K "") sugg_file; write_file write_cmds cmd_file;
1.369 +    run_mash ctxt overlord info core;
1.370 +    let val result = read_suggs (fn () => File.read_lines (Path.explode sugg_file)) in cleanup (); result end end
1.371 +
1.372 +fun str_of_update (name, parents, feats, deps) = (* "! name: parents; feats; deps" plus newline *)
1.373 +  String.concat ["! ", escape_meta name, ": ", escape_metas parents, "; ",
1.374 +    escape_metas feats, "; ", escape_metas deps, "\n"]
1.375 +
1.376 +fun str_of_query (parents, feats) = (* "? parents; feats", without a newline *)
1.377 +  String.concat ["? ", escape_metas parents, "; ", escape_metas feats]
1.378 +
1.379 +fun init_str_of_update get (upd as (name, _, _, _)) = (* "name: <get upd>" plus newline *)
1.380 +  String.concat [escape_meta name, ": ", escape_metas (get upd), "\n"]
1.381 +
1.382 +(* Remove every entry that File.fold_dir reports for the MaSh state directory. *)
1.383 +fun mash_CLEAR ctxt =
1.384 +  let
1.385 +    val dir = Path.explode (mash_state_dir ())
1.386 +    fun remove file () = File.rm (Path.append dir (Path.basic file))
1.387 +    val _ = trace_msg ctxt (fn () => "MaSh CLEAR")
1.388 +  in File.fold_dir remove dir () end
1.389 +
1.390 +(* Initialize MaSh from the given updates; with no updates, its state is just cleared. *)
1.391 +fun mash_INIT ctxt overlord upds =
1.392 +  if null upds then mash_CLEAR ctxt
1.393 +  else
1.394 +    (trace_msg ctxt (fn () => "MaSh INIT " ^ elide_string 1000 (space_implode " " (map #1 upds)));
1.395 +     run_mash_init ctxt overlord (upds, init_str_of_update #2) (upds, init_str_of_update #3) (upds, init_str_of_update #4))
1.396 +
1.397 +(* Send the updates to MaSh as "!" commands with "--saveModel"; nothing happens for []. *)
1.398 +fun mash_ADD ctxt overlord upds =
1.399 +  if null upds then ()
1.400 +  else (trace_msg ctxt (fn () => "MaSh ADD " ^ elide_string 1000 (space_implode " " (map #1 upds)));
1.401 +        run_mash_commands ctxt overlord true 0 (upds, str_of_update) (fn _ => ()))
1.402 +
1.403 +(* Send one "?" command requesting max_suggs predictions; the answer is parsed from the *)
1.404 +(* last line of the suggestion file, and an empty file yields []. *)
1.405 +fun mash_QUERY ctxt overlord max_suggs (query as (_, feats)) =
1.406 +  let fun last_suggs lines = snd (extract_query (List.last (lines ()))) in
1.407 +    trace_msg ctxt (fn () => "MaSh QUERY " ^ space_implode " " feats);
1.408 +    run_mash_commands ctxt overlord false max_suggs ([query], str_of_query) last_suggs handle List.Empty => [] end
1.409 +
1.410 +
1.411 +(*** High-level communication with MaSh ***)
1.412 +
1.413 +(* Evaluate f (); if it raises Graph.CYCLES with a nonempty cycle list or Graph.UNDEF, *)
1.414 +(* emit a trace message mentioning "when" and return def instead. *)
1.415 +fun try_graph ctxt when def f =
1.416 +  let fun fallback what = (trace_msg ctxt (fn () => what () ^ " when " ^ when); def) in
1.417 +    f ()
1.418 +    handle Graph.CYCLES (cycle :: _) => fallback (fn () => "Cycle involving " ^ commas cycle)
1.419 +         | Graph.UNDEF name => fallback (fn () => "Unknown fact " ^ quote name)
1.420 +  end
1.421 +
1.422 +type mash_state = (* state kept in global_state and written out by mash_save *)
1.423 + {thys : bool Symtab.table, (* theory name -> Boolean flag, cf. add_thys_for *)
1.424 + fact_graph : unit Graph.T} (* nodes are fact names; edges are added by update_fact_graph *)
1.425 +
1.426 +val empty_state = {thys = Symtab.empty, fact_graph = Graph.empty}
1.427 +
1.428 +local
1.429 +
1.430 +fun mash_load _ (state as (true, _)) = state (* flag already set: keep the state as is *)
1.431 + | mash_load ctxt _ = (* otherwise read the state file and set the flag *)
1.432 + let val path = mash_state_path () in
1.433 + (true,
1.434 + case try File.read_lines path of
1.435 + SOME (comp_thys :: incomp_thys :: fact_lines) => (* two theory lines, then fact lines *)
1.436 + let
1.437 + fun add_thy comp thy = Symtab.update (thy, comp)
1.438 + fun add_edge_to name parent = (* edge parent -> name; the parent node is created on demand *)
1.439 + Graph.default_node (parent, ())
1.440 + #> Graph.add_edge (parent, name)
1.441 + fun add_fact_line line =
1.442 + case extract_query line of
1.443 + ("", _) => I (* shouldn't happen *)
1.444 + | (name, parents) =>
1.445 + Graph.default_node (name, ())
1.446 + #> fold (add_edge_to name) parents
1.447 + val thys = (* first line: theories flagged true; second line: theories flagged false *)
1.448 + Symtab.empty |> fold (add_thy true) (unescape_metas comp_thys)
1.449 + |> fold (add_thy false) (unescape_metas incomp_thys)
1.450 + val fact_graph = (* falls back to the empty graph if try_graph catches an exception *)
1.451 + try_graph ctxt "loading state" Graph.empty (fn () =>
1.452 + Graph.empty |> fold add_fact_line fact_lines)
1.453 + in {thys = thys, fact_graph = fact_graph} end
1.454 + | _ => empty_state) (* reading failed or the file has fewer than two lines *)
1.455 + end
1.456 +
1.457 +(* Write the state file: a line with the theories flagged true, a line with the theories *)
1.458 +(* flagged false, then one "name: parents" line for each node of the fact graph. *)
1.459 +fun mash_save ({thys, fact_graph, ...} : mash_state) =
1.460 +  let
1.461 +    val path = mash_state_path ()
1.462 +    val thy_flags = Symtab.dest thys
1.463 +    fun line_for_thys flag = escape_metas (AList.find (op =) thy_flags flag)
1.464 +    fun line_for_fact name parents = escape_meta name ^ ": " ^ escape_metas parents ^ "\n"
1.465 +    fun append_fact (name, ((), (parents, _))) () =
1.466 +      File.append path (line_for_fact name (Graph.Keys.dest parents))
1.467 +  in
1.468 +    File.write path (line_for_thys true ^ "\n" ^ line_for_thys false ^ "\n");
1.469 +    Graph.fold append_fact fact_graph ()
1.470 +  end
1.471 +
1.472 +val global_state = (* pair of a Boolean flag, inspected and set by mash_load, and the state *)
1.473 + Synchronized.var "Sledgehammer_Filter_MaSh.global_state" (false, empty_state)
1.474 +
1.475 +in
1.476 +
1.477 +fun mash_map ctxt f = (* load the state if needed, apply f to it, save and store the result *)
1.478 +  Synchronized.change global_state (fn st =>
1.479 +    let val (loaded, state) = mash_load ctxt st val state' = f state in mash_save state'; (loaded, state') end)
1.480 +fun mash_get ctxt = (* load the state if needed and return it *)
1.481 +  Synchronized.change_result global_state (fn st => let val st' = mash_load ctxt st in (snd st', st') end)
1.482 +
1.483 +fun mash_unlearn ctxt = (* clear MaSh's directory, empty the state file, reset the state *)
1.484 +  let
1.485 +    fun reset _ = (mash_CLEAR ctxt; File.write (mash_state_path ()) ""; (true, empty_state))
1.486 +  in Synchronized.change global_state reset end
1.487 +
1.488 +end
1.489 +
1.490 +fun mash_could_suggest_facts () = not (mash_home () = "") (* MASH_HOME is nonempty *)
1.491 +fun mash_can_suggest_facts ctxt = (* the loaded fact graph is nonempty *)
1.492 +  let val {fact_graph, ...} = mash_get ctxt in not (Graph.is_empty fact_graph) end
1.493 +
1.494 +fun parents_wrt_facts facts fact_graph = (* walk from the maximal nodes towards predecessors *)
1.495 + let
1.496 + val facts = [] |> fold (cons o nickname_of o snd) facts (* nicknames of the given facts *)
1.497 + val tab = Symtab.empty |> fold (fn name => Symtab.update (name, ())) facts
1.498 + fun insert_not_seen seen name =
1.499 + not (member (op =) seen name) ? insert (op =) name
1.500 + fun parents_of _ parents [] = parents
1.501 + | parents_of seen parents (name :: names) =
1.502 + if Symtab.defined tab name then (* among the given facts: keep it and stop descending *)
1.503 + parents_of (name :: seen) (name :: parents) names
1.504 + else (* otherwise continue with its immediate predecessors that were not seen yet *)
1.505 + parents_of (name :: seen) parents
1.506 + (Graph.Keys.fold (insert_not_seen seen)
1.507 + (Graph.imm_preds fact_graph name) names)
1.508 + in parents_of [] [] (Graph.maximals fact_graph) end
1.509 +
1.510 +(* Generate more suggestions than requested, because some might be thrown out
1.511 + later for various reasons and "meshing" gives better results with some
1.512 + slack. *)
1.513 +fun max_suggs_of max_facts = if max_facts < 200 then max_facts + max_facts else max_facts + 200
1.514 +
1.515 +fun is_fact_in_graph fact_graph fact = (* does the fact's nickname name a node of the graph? *)
1.516 +  is_some (try (Graph.get_node fact_graph) (nickname_of (snd fact)))
1.517 +
1.518 +(* Query MaSh for the goal (unless the fact graph is empty) and return the suggested facts *)
1.519 +(* among facts, paired with those facts that are not in the fact graph. *)
1.520 +fun mash_suggest_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts concl_t facts =
1.521 +  let
1.522 +    val {fact_graph, ...} = mash_get ctxt
1.523 +    val parents = parents_wrt_facts facts fact_graph
1.524 +    val feats = features_of ctxt prover (Proof_Context.theory_of ctxt) General (concl_t :: hyp_ts)
1.525 +    val suggs =
1.526 +      if Graph.is_empty fact_graph then []
1.527 +      else mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
1.528 +  in
1.529 +    (suggested_facts suggs facts, filter_out (is_fact_in_graph fact_graph) facts)
1.530 +  end
1.531 +
1.532 +(* Flag thy itself with false, then each of its ancestors with true. *)
1.533 +fun add_thys_for thy tab =
1.534 +  let fun flag comp thy' = Symtab.update (Context.theory_name thy', comp)
1.535 +  in fold (flag true) (Theory.ancestors_of thy) (flag false thy tab) end
1.536 +
1.537 +(* Add name as a node with edges from its parents and deps; an edge for which try_graph *)
1.538 +(* falls back is dropped, and the names actually linked are collected in reverse order. *)
1.539 +fun update_fact_graph ctxt (name, parents, feats, deps) (upds, graph) =
1.540 +  let
1.541 +    fun link from (acc as (froms, gr)) =
1.542 +      try_graph ctxt "updating graph" acc (fn () => (from :: froms, Graph.add_edge_acyclic (from, name) gr))
1.543 +    val (parents', graph1) = fold link parents ([], Graph.default_node (name, ()) graph)
1.544 +    val (deps', graph2) = fold link deps ([], graph1)
1.545 +  in ((name, parents', feats, deps') :: upds, graph2) end
1.546 +
1.547 +val pass1_learn_timeout_factor = 0.5 (* fraction of the timeout after which mash_learn_thy stops collecting facts *)
1.548 +
1.549 +(* Too many dependencies is a sign that a decision procedure is at work. There
1.550 + isn't much to learn from such proofs. *)
1.551 +val max_dependencies = 10
1.552 +
1.553 +(* Learn from the facts that are not yet in the fact graph: each one yields an update made *)
1.554 +(* of its nickname, parents, features, and dependencies, which is added to the graph and *)
1.555 +(* sent to MaSh. Returns a report if verbose is set, and "" otherwise. The timeout is *)
1.556 +(* understood in a very slack fashion. *)
1.557 +fun mash_learn_thy ctxt ({provers, verbose, overlord, ...} : params) thy timeout facts =
1.558 +  let
1.559 +    val timer = Timer.startRealTimer ()
1.560 +    val prover = hd provers
1.561 +    fun timed_out frac =
1.562 +      Time.> (Timer.checkRealTimer timer, time_mult frac timeout)
1.563 +    val {fact_graph, ...} = mash_get ctxt
1.564 +    val new_facts =
1.565 +      facts |> filter_out (is_fact_in_graph fact_graph)
1.566 +            |> sort (thm_ord o pairself snd)
1.567 +  in
1.568 +    if null new_facts then
1.569 +      ""
1.570 +    else
1.571 +      let
1.572 +        val ths = facts |> map snd
1.573 +        val all_names =
1.574 +          ths |> filter_out is_likely_tautology_or_too_meta
1.575 +              |> map (rpair () o nickname_of)
1.576 +              |> Symtab.make
1.577 +        (* more than max_dependencies dependencies: record none at all *)
1.578 +        fun trim_deps deps = if length deps > max_dependencies then [] else deps
1.579 +        (* the Boolean of the accumulator is set once the time budget is used up; *)
1.580 +        (* each processed fact becomes the sole parent of the next one *)
1.581 +        fun do_fact _ (accum as (_, true)) = accum
1.582 +          | do_fact ((_, (_, status)), th) ((parents, upds), false) =
1.583 +            let
1.584 +              val name = nickname_of th
1.585 +              val feats =
1.586 +                features_of ctxt prover (theory_of_thm th) status [prop_of th]
1.587 +              val deps = isabelle_dependencies_of all_names th |> trim_deps
1.588 +              val upd = (name, parents, feats, deps)
1.589 +            in (([name], upd :: upds), timed_out pass1_learn_timeout_factor) end
1.590 +        val parents = parents_wrt_facts facts fact_graph
1.591 +        val ((_, upds), _) =
1.592 +          ((parents, []), false) |> fold do_fact new_facts |>> apsnd rev
1.593 +        val n = length upds
1.594 +        (* state transformer: INIT if the graph is still empty, ADD otherwise *)
1.595 +        fun trans {thys, fact_graph} =
1.596 +          let
1.597 +            val mash_INIT_or_ADD =
1.598 +              if Graph.is_empty fact_graph then mash_INIT else mash_ADD
1.599 +            val (upds, fact_graph) =
1.600 +              ([], fact_graph) |> fold (update_fact_graph ctxt) upds
1.601 +          in
1.602 +            (mash_INIT_or_ADD ctxt overlord (rev upds);
1.603 +             {thys = thys |> add_thys_for thy, fact_graph = fact_graph})
1.604 +          end
1.605 +      in
1.606 +        mash_map ctxt trans;
1.607 +        if verbose then
1.608 +          "Processed " ^ string_of_int n ^ " proof" ^ plural_s n ^
1.609 +          " in " ^ string_from_time (Timer.checkRealTimer timer) ^ "."
1.610 +        else "" end
1.611 +  end
1.612 +
1.613 +(* Learn from a proof of t: an update with a freshly generated name, the features of t, *)
1.614 +(* and the nicknames of used_ths as dependencies is added to the fact graph and sent *)
1.615 +(* to MaSh via mash_ADD. *)
1.616 +fun mash_learn_proof ctxt ({provers, overlord, ...} : params) t facts used_ths =
1.617 +  let
1.618 +    val prover = hd provers
1.619 +    val name = ATP_Util.timestamp () ^ " " ^ serial_string () (* fresh enough *)
1.620 +    val feats = features_of ctxt prover (Proof_Context.theory_of ctxt) General [t]
1.621 +    val deps = map nickname_of used_ths
1.622 +    fun learn {thys, fact_graph} =
1.623 +      let
1.624 +        val upd = (name, parents_wrt_facts facts fact_graph, feats, deps)
1.625 +        val (upds, fact_graph') = update_fact_graph ctxt upd ([], fact_graph)
1.626 +      in
1.627 +        mash_ADD ctxt overlord upds;
1.628 +        {thys = thys, fact_graph = fact_graph'}
1.629 +      end
1.630 +  in mash_map ctxt learn
1.631 +  end
1.632 +
1.633 +(* The threshold should be large enough so that MaSh doesn't kick in for Auto
1.634 + Sledgehammer and Try. *)
1.635 +val min_secs_for_learning = 15 (* learning is only launched if the timeout is at least this many seconds *)
1.636 +val learn_timeout_factor = 2.0 (* soft learning timeout = this factor times the timeout *)
1.637 +
1.638 +fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover (* select up to max_facts facts *)
1.639 + max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
1.640 + if not (subset (op =) (the_list fact_filter, fact_filters)) then
1.641 + error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
1.642 + else if only then (* "only" override: the facts are returned unfiltered *)
1.643 + facts
1.644 + else if max_facts <= 0 orelse null facts then
1.645 + []
1.646 + else
1.647 + let
1.648 + val thy = Proof_Context.theory_of ctxt
1.649 + fun maybe_learn () = (* launch a background learner unless disabled or one is already running *)
1.650 + if not learn orelse Async_Manager.has_running_threads MaShN then
1.651 + ()
1.652 + else if Time.toSeconds timeout >= min_secs_for_learning then
1.653 + let
1.654 + val soft_timeout = time_mult learn_timeout_factor timeout
1.655 + val hard_timeout = time_mult 4.0 soft_timeout
1.656 + val birth_time = Time.now ()
1.657 + val death_time = Time.+ (birth_time, hard_timeout)
1.658 + val desc = ("machine learner for Sledgehammer", "")
1.659 + in
1.660 + Async_Manager.launch MaShN birth_time death_time desc
1.661 + (fn () =>
1.662 + (true, mash_learn_thy ctxt params thy soft_timeout facts))
1.663 + end
1.664 + else
1.665 + ()
1.666 + val fact_filter = (* resolve the effective filter; learning may be triggered on the way *)
1.667 + case fact_filter of
1.668 + SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
1.669 + | NONE =>
1.670 + if mash_can_suggest_facts ctxt then (maybe_learn (); meshN)
1.671 + else if mash_could_suggest_facts () then (maybe_learn (); mepoN)
1.672 + else mepoN
1.673 + val add_ths = Attrib.eval_thms ctxt add
1.674 + fun prepend_facts ths accepts = (* facts matching ths come first, then the other accepted ones *)
1.675 + ((facts |> filter (member Thm.eq_thm_prop ths o snd)) @
1.676 + (accepts |> filter_out (member Thm.eq_thm_prop ths o snd)))
1.677 + |> take max_facts
1.678 + fun iter () =
1.679 + iterative_relevant_facts ctxt params prover max_facts NONE hyp_ts
1.680 + concl_t facts
1.681 + fun mash () =
1.682 + mash_suggest_facts ctxt params prover max_facts hyp_ts concl_t facts
1.683 + val mess = (* "mepo" uses iter only, "mash" uses mash only, "mesh" uses both *)
1.684 + [] |> (if fact_filter <> mashN then cons (iter (), []) else I)
1.685 + |> (if fact_filter <> mepoN then cons (mash ()) else I)
1.686 + in
1.687 + mesh_facts max_facts mess
1.688 + |> not (null add_ths) ? prepend_facts add_ths
1.689 + end
1.690 +
1.691 +val kill_learners = fn () => Async_Manager.kill_threads MaShN "learner" (* threads registered as MaShN *)
1.692 +val running_learners = fn () => Async_Manager.running_threads MaShN "learner"
1.693 +
1.694 +end;