src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 49395 d4b7c7be3116
parent 49394 2b5ad61e2ccc
child 49396 1b7d798460bb
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jul 20 22:19:45 2012 +0200
     1.3 @@ -0,0 +1,691 @@
     1.4 +(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.5 +    Author:     Jasmin Blanchette, TU Muenchen
     1.6 +
     1.7 +Sledgehammer's machine-learning-based relevance filter (MaSh).
     1.8 +*)
     1.9 +
(* Interface of the MaSh relevance filter: name escaping for the exchange
   format, fact naming and ordering, feature extraction, the low-level MaSh
   commands, and the high-level learning/suggestion entry points. *)
signature SLEDGEHAMMER_FILTER_MASH =
sig
  (* Abbreviations for types declared in other ATP/Sledgehammer structures. *)
  type status = ATP_Problem_Generate.status
  type stature = ATP_Problem_Generate.stature
  type fact = Sledgehammer_Fact.fact
  type fact_override = Sledgehammer_Fact.fact_override
  type params = Sledgehammer_Provers.params
  type relevance_fudge = Sledgehammer_Provers.relevance_fudge
  type prover_result = Sledgehammer_Provers.prover_result

  (* Tracing flag and the names of the learner and of the fact filters. *)
  val trace : bool Config.T
  val MaShN : string
  val mepoN : string
  val mashN : string
  val meshN : string
  val fact_filters : string list

  (* Escaping of names: characters outside a small safe set are written as a
     backslash followed by a three-digit character code. The plural variants
     work on space-separated lists. *)
  val escape_meta : string -> string
  val escape_metas : string list -> string
  val unescape_meta : string -> string
  val unescape_metas : string -> string list
  (* Splits a "name: item item ..." line into its unescaped components. *)
  val extract_query : string -> string * string list

  (* Fact naming, selection, and merging of several filters' results. *)
  val nickname_of : thm -> string
  val suggested_facts : string list -> ('a * thm) list -> ('a * thm) list
  val mesh_facts :
    int -> (('a * thm) list * ('a * thm) list) list -> ('a * thm) list
  val is_likely_tautology_or_too_meta : thm -> bool
  val theory_ord : theory * theory -> order
  val thm_ord : thm * thm -> order

  (* Feature and dependency extraction. *)
  val features_of :
    Proof.context -> string -> theory -> status -> term list -> string list
  val isabelle_dependencies_of : unit Symtab.table -> thm -> string list
  val goal_of_thm : theory -> thm -> thm
  val run_prover_for_mash :
    Proof.context -> params -> string -> fact list -> thm -> prover_result

  (* Low-level commands. The bool argument is the "overlord" flag; each update
     is a tuple (name, parents, features, dependencies). *)
  val mash_CLEAR : Proof.context -> unit
  val mash_INIT :
    Proof.context -> bool
    -> (string * string list * string list * string list) list -> unit
  val mash_ADD :
    Proof.context -> bool
    -> (string * string list * string list * string list) list -> unit
  val mash_QUERY :
    Proof.context -> bool -> int -> string list * string list -> string list

  (* High-level operations on the persistent learner state. *)
  val mash_unlearn : Proof.context -> unit
  val mash_could_suggest_facts : unit -> bool
  val mash_can_suggest_facts : Proof.context -> bool
  (* Returns (facts suggested by MaSh, facts not yet known to MaSh). *)
  val mash_suggest_facts :
    Proof.context -> params -> string -> int -> term list -> term
    -> ('a * thm) list -> ('a * thm) list * ('a * thm) list
  val mash_learn_thy :
    Proof.context -> params -> theory -> Time.time -> fact list -> string
  val mash_learn_proof :
    Proof.context -> params -> term -> ('a * thm) list -> thm list -> unit

  (* Main entry point of the relevance filter and learner thread management. *)
  val relevant_facts :
    Proof.context -> params -> string -> int -> fact_override -> term list
    -> term -> fact list -> fact list
  val kill_learners : unit -> unit
  val running_learners : unit -> unit
end;
    1.69 +
    1.70 +structure Sledgehammer_Filter_MaSh : SLEDGEHAMMER_FILTER_MASH =
    1.71 +struct
    1.72 +
    1.73 +open ATP_Util
    1.74 +open ATP_Problem_Generate
    1.75 +open Sledgehammer_Util
    1.76 +open Sledgehammer_Fact
    1.77 +open Sledgehammer_Filter_Iter
    1.78 +open Sledgehammer_Provers
    1.79 +open Sledgehammer_Minimize
    1.80 +
    1.81 +val trace =
    1.82 +  Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
    1.83 +fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
    1.84 +
(* Name under which learner threads are registered with Async_Manager. *)
val MaShN = "MaSh"

(* Names of the available fact filters. *)
val mepoN = "mepo"
val mashN = "mash"
val meshN = "mesh"

(* All filter names accepted by "relevant_facts". *)
val fact_filters = [meshN, mepoN, mashN]
    1.92 +
    1.93 +fun mash_home () = getenv "MASH_HOME"
    1.94 +fun mash_state_dir () =
    1.95 +  getenv "ISABELLE_HOME_USER" ^ "/mash"
    1.96 +  |> tap (Isabelle_System.mkdir o Path.explode)
    1.97 +fun mash_state_path () = mash_state_dir () ^ "/state" |> Path.explode
    1.98 +
    1.99 +
   1.100 +(*** Isabelle helpers ***)
   1.101 +
   1.102 +fun meta_char c =
   1.103 +  if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse
   1.104 +     c = #")" orelse c = #"," then
   1.105 +    String.str c
   1.106 +  else
   1.107 +    (* fixed width, in case more digits follow *)
   1.108 +    "\\" ^ stringN_of_int 3 (Char.ord c)
   1.109 +
   1.110 +fun unmeta_chars accum [] = String.implode (rev accum)
   1.111 +  | unmeta_chars accum (#"\\" :: d1 :: d2 :: d3 :: cs) =
   1.112 +    (case Int.fromString (String.implode [d1, d2, d3]) of
   1.113 +       SOME n => unmeta_chars (Char.chr n :: accum) cs
   1.114 +     | NONE => "" (* error *))
   1.115 +  | unmeta_chars _ (#"\\" :: _) = "" (* error *)
   1.116 +  | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
   1.117 +
   1.118 +val escape_meta = String.translate meta_char
   1.119 +val escape_metas = map escape_meta #> space_implode " "
   1.120 +val unescape_meta = String.explode #> unmeta_chars []
   1.121 +val unescape_metas =
   1.122 +  space_explode " " #> filter_out (curry (op =) "") #> map unescape_meta
   1.123 +
   1.124 +fun extract_query line =
   1.125 +  case space_explode ":" line of
   1.126 +    [goal_name, suggs] => (unescape_meta goal_name, unescape_metas suggs)
   1.127 +  | _ => ("", [])
   1.128 +
   1.129 +fun parent_of_local_thm th =
   1.130 +  let
   1.131 +    val thy = th |> Thm.theory_of_thm
   1.132 +    val facts = thy |> Global_Theory.facts_of
   1.133 +    val space = facts |> Facts.space_of
   1.134 +    fun id_of s = #id (Name_Space.the_entry space s)
   1.135 +    fun max_id (s', _) (s, id) =
   1.136 +      let val id' = id_of s' in if id > id' then (s, id) else (s', id') end
   1.137 +  in ("", ~1) |> Facts.fold_static max_id facts |> fst end
   1.138 +
   1.139 +val local_prefix = "local" ^ Long_Name.separator
   1.140 +
   1.141 +fun nickname_of th =
   1.142 +  let val hint = Thm.get_name_hint th in
   1.143 +    (* FIXME: There must be a better way to detect local facts. *)
   1.144 +    case try (unprefix local_prefix) hint of
   1.145 +      SOME suff =>
   1.146 +      parent_of_local_thm th ^ Long_Name.separator ^ Long_Name.separator ^ suff
   1.147 +    | NONE => hint
   1.148 +  end
   1.149 +
   1.150 +fun suggested_facts suggs facts =
   1.151 +  let
   1.152 +    fun add_fact (fact as (_, th)) = Symtab.default (nickname_of th, fact)
   1.153 +    val tab = Symtab.empty |> fold add_fact facts
   1.154 +  in map_filter (Symtab.lookup tab) suggs end
   1.155 +
   1.156 +(* Ad hoc score function roughly based on Blanchette's Ringberg 2011 data. *)
   1.157 +fun score x = Math.pow (1.5, 15.5 - 0.05 * Real.fromInt x) + 15.0
   1.158 +
   1.159 +fun sum_sq_avg [] = 0
   1.160 +  | sum_sq_avg xs =
   1.161 +    Real.ceil (100000.0 * fold (curry (op +) o score) xs 0.0) div length xs
   1.162 +
   1.163 +fun mesh_facts max_facts [(selected, unknown)] =
   1.164 +    take max_facts selected @ take (max_facts - length selected) unknown
   1.165 +  | mesh_facts max_facts mess =
   1.166 +    let
   1.167 +      val mess = mess |> map (apfst (`length))
   1.168 +      val fact_eq = Thm.eq_thm o pairself snd
   1.169 +      fun score_in fact ((sel_len, sels), unks) =
   1.170 +        case find_index (curry fact_eq fact) sels of
   1.171 +          ~1 => (case find_index (curry fact_eq fact) unks of
   1.172 +                   ~1 => SOME sel_len
   1.173 +                 | _ => NONE)
   1.174 +        | j => SOME j
   1.175 +      fun score_of fact = mess |> map_filter (score_in fact) |> sum_sq_avg
   1.176 +      val facts = fold (union fact_eq o take max_facts o snd o fst) mess []
   1.177 +    in
   1.178 +      facts |> map (`score_of) |> sort (int_ord o swap o pairself fst)
   1.179 +            |> map snd |> take max_facts
   1.180 +    end
   1.181 +
   1.182 +val thy_feature_prefix = "y_"
   1.183 +
   1.184 +val thy_feature_name_of = prefix thy_feature_prefix
   1.185 +val const_name_of = prefix const_prefix
   1.186 +val type_name_of = prefix type_const_prefix
   1.187 +val class_name_of = prefix class_prefix
   1.188 +
   1.189 +fun is_likely_tautology_or_too_meta th =
   1.190 +  let
   1.191 +    val is_boring_const = member (op =) atp_widely_irrelevant_consts
   1.192 +    fun is_boring_bool t =
   1.193 +      not (exists_Const (not o is_boring_const o fst) t) orelse
   1.194 +      exists_type (exists_subtype (curry (op =) @{typ prop})) t
   1.195 +    fun is_boring_prop (@{const Trueprop} $ t) = is_boring_bool t
   1.196 +      | is_boring_prop (@{const "==>"} $ t $ u) =
   1.197 +        is_boring_prop t andalso is_boring_prop u
   1.198 +      | is_boring_prop (Const (@{const_name all}, _) $ (Abs (_, _, t)) $ u) =
   1.199 +        is_boring_prop t andalso is_boring_prop u
   1.200 +      | is_boring_prop (Const (@{const_name "=="}, _) $ t $ u) =
   1.201 +        is_boring_bool t andalso is_boring_bool u
   1.202 +      | is_boring_prop _ = true
   1.203 +  in
   1.204 +    is_boring_prop (prop_of th) andalso not (Thm.eq_thm_prop (@{thm ext}, th))
   1.205 +  end
   1.206 +
   1.207 +fun theory_ord p =
   1.208 +  if Theory.eq_thy p then
   1.209 +    EQUAL
   1.210 +  else if Theory.subthy p then
   1.211 +    LESS
   1.212 +  else if Theory.subthy (swap p) then
   1.213 +    GREATER
   1.214 +  else case int_ord (pairself (length o Theory.ancestors_of) p) of
   1.215 +    EQUAL => string_ord (pairself Context.theory_name p)
   1.216 +  | order => order
   1.217 +
   1.218 +val thm_ord = theory_ord o pairself theory_of_thm
   1.219 +
(* Type constructors that are never turned into features. *)
val bad_types = [@{type_name prop}, @{type_name bool}, @{type_name fun}]

(* Collects, without duplicates, the feature names of the constants (as term
   patterns), type constructors, and type classes occurring in the terms
   "ts". *)
fun interesting_terms_types_and_classes ctxt prover term_max_depth
                                        type_max_depth ts =
  let
    (* Logical constants and constants built into the prover are skipped. *)
    fun is_bad_const (x as (s, _)) args =
      member (op =) atp_logical_consts s orelse
      fst (is_built_in_const_for_prover ctxt prover x args)
    (* The sort "type" contributes no class features. *)
    fun add_classes @{sort type} = I
      | add_classes S = union (op =) (map class_name_of S)
    (* Adds all type constructors not listed in "bad_types" and the classes of
       all type variables, at any nesting depth. *)
    fun do_add_type (Type (s, Ts)) =
        (not (member (op =) bad_types s) ? insert (op =) (type_name_of s))
        #> fold do_add_type Ts
      | do_add_type (TFree (_, S)) = add_classes S
      | do_add_type (TVar (_, S)) = add_classes S
    (* "type_max_depth" only switches type features on or off here; it does
       not bound the traversal of "do_add_type". *)
    fun add_type T = type_max_depth >= 0 ? do_add_type T
    (* Renders "s(arg1,...,argn)", or just "s" if all arguments are "". *)
    fun mk_app s args =
      if member (op <>) args "" then s ^ "(" ^ space_implode "," args ^ ")"
      else s
    (* Pattern of a term up to the given depth: a constant applied to the
       patterns of its arguments; "" for terms without a constant head or
       once the depth is exhausted. *)
    fun patternify ~1 _ = ""
      | patternify depth t =
        case strip_comb t of
          (Const (s, _), args) =>
          mk_app (const_name_of s) (map (patternify (depth - 1)) args)
        | _ => ""
    (* Adds the patterns of "t" for every depth from "depth" down to 0. *)
    fun add_term_patterns ~1 _ = I
      | add_term_patterns depth t =
        insert (op =) (patternify depth t)
        #> add_term_patterns (depth - 1) t
    val add_term = add_term_patterns term_max_depth
    (* Traverses a term: handles its head, then recurses into the arguments
       (and into the body of an abstraction). *)
    fun add_patterns t =
      let val (head, args) = strip_comb t in
        (case head of
           Const (x as (_, T)) =>
           not (is_bad_const x args) ? (add_term t #> add_type T)
         | Free (_, T) => add_type T
         | Var (_, T) => add_type T
         | Abs (_, T, body) => add_type T #> add_patterns body
         | _ => I)
        #> fold add_patterns args
      end
  in [] |> fold add_patterns ts end
   1.262 +
   1.263 +fun is_exists (s, _) = (s = @{const_name Ex} orelse s = @{const_name Ex1})
   1.264 +
   1.265 +val term_max_depth = 1
   1.266 +val type_max_depth = 1
   1.267 +
   1.268 +(* TODO: Generate type classes for types? *)
   1.269 +fun features_of ctxt prover thy status ts =
   1.270 +  thy_feature_name_of (Context.theory_name thy) ::
   1.271 +  interesting_terms_types_and_classes ctxt prover term_max_depth type_max_depth
   1.272 +                                      ts
   1.273 +  |> forall is_lambda_free ts ? cons "no_lams"
   1.274 +  |> forall (not o exists_Const is_exists) ts ? cons "no_skos"
   1.275 +  |> (case status of
   1.276 +        General => I
   1.277 +      | Induction => cons "induction"
   1.278 +      | Intro => cons "intro"
   1.279 +      | Inductive => cons "inductive"
   1.280 +      | Elim => cons "elim"
   1.281 +      | Simp => cons "simp"
   1.282 +      | Def => cons "def")
   1.283 +
   1.284 +fun isabelle_dependencies_of all_facts = thms_in_proof (SOME all_facts)
   1.285 +
   1.286 +val freezeT = Type.legacy_freeze_type
   1.287 +
   1.288 +fun freeze (t $ u) = freeze t $ freeze u
   1.289 +  | freeze (Abs (s, T, t)) = Abs (s, freezeT T, freeze t)
   1.290 +  | freeze (Var ((s, _), T)) = Free (s, freezeT T)
   1.291 +  | freeze (Const (s, T)) = Const (s, freezeT T)
   1.292 +  | freeze (Free (s, T)) = Free (s, freezeT T)
   1.293 +  | freeze t = t
   1.294 +
   1.295 +fun goal_of_thm thy = prop_of #> freeze #> cterm_of thy #> Goal.init
   1.296 +
   1.297 +fun run_prover_for_mash ctxt params prover facts goal =
   1.298 +  let
   1.299 +    val problem =
   1.300 +      {state = Proof.init ctxt, goal = goal, subgoal = 1, subgoal_count = 1,
   1.301 +       facts = facts |> map (apfst (apfst (fn name => name ())))
   1.302 +                     |> map Untranslated_Fact}
   1.303 +    val prover = get_minimizing_prover ctxt Normal (K ()) prover
   1.304 +  in prover params (K (K (K ""))) problem end
   1.305 +
   1.306 +
   1.307 +(*** Low-level communication with MaSh ***)
   1.308 +
   1.309 +fun write_file (xs, f) file =
   1.310 +  let val path = Path.explode file in
   1.311 +    File.write path "";
   1.312 +    xs |> chunk_list 500
   1.313 +       |> List.app (File.append path o space_implode "" o map f)
   1.314 +  end
   1.315 +
   1.316 +fun mash_info overlord =
   1.317 +  if overlord then (getenv "ISABELLE_HOME_USER", "")
   1.318 +  else (getenv "ISABELLE_TMP", serial_string ())
   1.319 +
   1.320 +fun run_mash ctxt overlord (temp_dir, serial) core =
   1.321 +  let
   1.322 +    val log_file = temp_dir ^ "/mash_log" ^ serial
   1.323 +    val err_file = temp_dir ^ "/mash_err" ^ serial
   1.324 +    val command =
   1.325 +      mash_home () ^ "/mash.py --quiet --outputDir " ^ mash_state_dir () ^
   1.326 +      " --log " ^ log_file ^ " " ^ core ^ " 2>&1 > " ^ err_file
   1.327 +  in
   1.328 +    trace_msg ctxt (fn () => "Running " ^ command);
   1.329 +    write_file ([], K "") log_file;
   1.330 +    write_file ([], K "") err_file;
   1.331 +    Isabelle_System.bash command;
   1.332 +    if overlord then ()
   1.333 +    else (map (File.rm o Path.explode) [log_file, err_file]; ());
   1.334 +    trace_msg ctxt (K "Done")
   1.335 +  end
   1.336 +
   1.337 +(* TODO: Eliminate code once "mash.py" handles sequences of ADD commands as fast
   1.338 +   as a single INIT. *)
   1.339 +fun run_mash_init ctxt overlord write_access write_feats write_deps =
   1.340 +  let
   1.341 +    val info as (temp_dir, serial) = mash_info overlord
   1.342 +    val in_dir = temp_dir ^ "/mash_init" ^ serial
   1.343 +    val in_dir_path = in_dir |> Path.explode |> tap Isabelle_System.mkdir
   1.344 +  in
   1.345 +    write_file write_access (in_dir ^ "/mash_accessibility");
   1.346 +    write_file write_feats (in_dir ^ "/mash_features");
   1.347 +    write_file write_deps (in_dir ^ "/mash_dependencies");
   1.348 +    run_mash ctxt overlord info ("--init --inputDir " ^ in_dir);
   1.349 +    (* FIXME: temporary hack *)
   1.350 +    if overlord then ()
   1.351 +    else (Isabelle_System.bash ("rm -r -f " ^ File.shell_path in_dir_path); ())
   1.352 +  end
   1.353 +
   1.354 +fun run_mash_commands ctxt overlord save max_suggs write_cmds read_suggs =
   1.355 +  let
   1.356 +    val info as (temp_dir, serial) = mash_info overlord
   1.357 +    val sugg_file = temp_dir ^ "/mash_suggs" ^ serial
   1.358 +    val cmd_file = temp_dir ^ "/mash_commands" ^ serial
   1.359 +  in
   1.360 +    write_file ([], K "") sugg_file;
   1.361 +    write_file write_cmds cmd_file;
   1.362 +    run_mash ctxt overlord info
   1.363 +             ("--inputFile " ^ cmd_file ^ " --predictions " ^ sugg_file ^
   1.364 +              " --numberOfPredictions " ^ string_of_int max_suggs ^
   1.365 +              (if save then " --saveModel" else ""));
   1.366 +    read_suggs (fn () => File.read_lines (Path.explode sugg_file))
   1.367 +    |> tap (fn _ =>
   1.368 +               if overlord then ()
   1.369 +               else (map (File.rm o Path.explode) [sugg_file, cmd_file]; ()))
   1.370 +  end
   1.371 +
   1.372 +fun str_of_update (name, parents, feats, deps) =
   1.373 +  "! " ^ escape_meta name ^ ": " ^ escape_metas parents ^ "; " ^
   1.374 +  escape_metas feats ^ "; " ^ escape_metas deps ^ "\n"
   1.375 +
   1.376 +fun str_of_query (parents, feats) =
   1.377 +  "? " ^ escape_metas parents ^ "; " ^ escape_metas feats
   1.378 +
   1.379 +fun init_str_of_update get (upd as (name, _, _, _)) =
   1.380 +  escape_meta name ^ ": " ^ escape_metas (get upd) ^ "\n"
   1.381 +
   1.382 +fun mash_CLEAR ctxt =
   1.383 +  let val path = mash_state_dir () |> Path.explode in
   1.384 +    trace_msg ctxt (K "MaSh CLEAR");
   1.385 +    File.fold_dir (fn file => fn () =>
   1.386 +                      File.rm (Path.append path (Path.basic file)))
   1.387 +                  path ()
   1.388 +  end
   1.389 +
   1.390 +fun mash_INIT ctxt _ [] = mash_CLEAR ctxt
   1.391 +  | mash_INIT ctxt overlord upds =
   1.392 +    (trace_msg ctxt (fn () => "MaSh INIT " ^
   1.393 +         elide_string 1000 (space_implode " " (map #1 upds)));
   1.394 +     run_mash_init ctxt overlord (upds, init_str_of_update #2)
   1.395 +                   (upds, init_str_of_update #3) (upds, init_str_of_update #4))
   1.396 +
   1.397 +fun mash_ADD _ _ [] = ()
   1.398 +  | mash_ADD ctxt overlord upds =
   1.399 +    (trace_msg ctxt (fn () => "MaSh ADD " ^
   1.400 +         elide_string 1000 (space_implode " " (map #1 upds)));
   1.401 +     run_mash_commands ctxt overlord true 0 (upds, str_of_update) (K ()))
   1.402 +
   1.403 +fun mash_QUERY ctxt overlord max_suggs (query as (_, feats)) =
   1.404 +  (trace_msg ctxt (fn () => "MaSh QUERY " ^ space_implode " " feats);
   1.405 +   run_mash_commands ctxt overlord false max_suggs
   1.406 +       ([query], str_of_query)
   1.407 +       (fn suggs => snd (extract_query (List.last (suggs ()))))
   1.408 +   handle List.Empty => [])
   1.409 +
   1.410 +
   1.411 +(*** High-level communication with MaSh ***)
   1.412 +
   1.413 +fun try_graph ctxt when def f =
   1.414 +  f ()
   1.415 +  handle Graph.CYCLES (cycle :: _) =>
   1.416 +         (trace_msg ctxt (fn () =>
   1.417 +              "Cycle involving " ^ commas cycle ^ " when " ^ when); def)
   1.418 +       | Graph.UNDEF name =>
   1.419 +         (trace_msg ctxt (fn () =>
   1.420 +              "Unknown fact " ^ quote name ^ " when " ^ when); def)
   1.421 +
(* Isabelle-side view of what MaSh has learned: "thys" maps theory names to a
   Boolean flag (recorded as "true" for ancestor theories and "false" for the
   theory being learned, cf. "add_thys_for"), and "fact_graph" has the known
   fact names as nodes, with edges from each fact's recorded parents and
   dependencies to the fact. *)
type mash_state =
  {thys : bool Symtab.table,
   fact_graph : unit Graph.T}

(* State before anything has been learned. *)
val empty_state = {thys = Symtab.empty, fact_graph = Graph.empty}
   1.427 +
   1.428 +local
   1.429 +
(* Loads the persistent state on first use. The argument and result are pairs
   (loaded flag, state); a pair whose flag is already "true" is returned
   unchanged. The state file is expected to consist of two lines of theory
   names (flagged "true" and "false", respectively) followed by one
   "name: parents" line per fact. An unreadable or too short file yields the
   empty state. *)
fun mash_load _ (state as (true, _)) = state
  | mash_load ctxt _ =
    let val path = mash_state_path () in
      (true,
       case try File.read_lines path of
         SOME (comp_thys :: incomp_thys :: fact_lines) =>
         let
           fun add_thy comp thy = Symtab.update (thy, comp)
           (* Parents may be mentioned before their own line has been read,
              hence "default_node". *)
           fun add_edge_to name parent =
             Graph.default_node (parent, ())
             #> Graph.add_edge (parent, name)
           fun add_fact_line line =
             case extract_query line of
               ("", _) => I (* shouldn't happen *)
             | (name, parents) =>
               Graph.default_node (name, ())
               #> fold (add_edge_to name) parents
           val thys =
             Symtab.empty |> fold (add_thy true) (unescape_metas comp_thys)
                          |> fold (add_thy false) (unescape_metas incomp_thys)
           (* A graph exception while loading discards the whole graph. *)
           val fact_graph =
             try_graph ctxt "loading state" Graph.empty (fn () =>
                 Graph.empty |> fold add_fact_line fact_lines)
         in {thys = thys, fact_graph = fact_graph} end
       | _ => empty_state)
    end
   1.456 +
   1.457 +fun mash_save ({thys, fact_graph, ...} : mash_state) =
   1.458 +  let
   1.459 +    val path = mash_state_path ()
   1.460 +    val thys = Symtab.dest thys
   1.461 +    val line_for_thys = escape_metas o AList.find (op =) thys
   1.462 +    fun fact_line_for name parents =
   1.463 +      escape_meta name ^ ": " ^ escape_metas parents
   1.464 +    val append_fact = File.append path o suffix "\n" oo fact_line_for
   1.465 +  in
   1.466 +    File.write path (line_for_thys true ^ "\n" ^ line_for_thys false ^ "\n");
   1.467 +    Graph.fold (fn (name, ((), (parents, _))) => fn () =>
   1.468 +                   append_fact name (Graph.Keys.dest parents))
   1.469 +        fact_graph ()
   1.470 +  end
   1.471 +
(* Synchronized variable holding (loaded flag, state). It starts out unloaded;
   "mash_load" fills it in on first access. *)
val global_state =
  Synchronized.var "Sledgehammer_Filter_MaSh.global_state" (false, empty_state)
   1.474 +
   1.475 +in
   1.476 +
   1.477 +fun mash_map ctxt f =
   1.478 +  Synchronized.change global_state (mash_load ctxt ##> (f #> tap mash_save))
   1.479 +
   1.480 +fun mash_get ctxt =
   1.481 +  Synchronized.change_result global_state (mash_load ctxt #> `snd)
   1.482 +
   1.483 +fun mash_unlearn ctxt =
   1.484 +  Synchronized.change global_state (fn _ =>
   1.485 +      (mash_CLEAR ctxt; File.write (mash_state_path ()) "";
   1.486 +       (true, empty_state)))
   1.487 +
   1.488 +end
   1.489 +
   1.490 +fun mash_could_suggest_facts () = mash_home () <> ""
   1.491 +fun mash_can_suggest_facts ctxt =
   1.492 +  not (Graph.is_empty (#fact_graph (mash_get ctxt)))
   1.493 +
(* Computes the "parents" of a goal with respect to the given facts: starting
   from the maximal nodes of the fact graph, every node whose name is among the
   nicknames of "facts" is reported as a parent; any other node is replaced by
   its immediate predecessors, which are examined in turn. *)
fun parents_wrt_facts facts fact_graph =
  let
    val facts = [] |> fold (cons o nickname_of o snd) facts
    (* Set of the nicknames, for fast membership tests. *)
    val tab = Symtab.empty |> fold (fn name => Symtab.update (name, ())) facts
    (* Only enqueue names that have not been taken off the work list yet. *)
    fun insert_not_seen seen name =
      not (member (op =) seen name) ? insert (op =) name
    (* Work-list loop: "seen" are the processed names, "parents" the result so
       far, and the last argument the names still to be examined. *)
    fun parents_of _ parents [] = parents
      | parents_of seen parents (name :: names) =
        if Symtab.defined tab name then
          parents_of (name :: seen) (name :: parents) names
        else
          parents_of (name :: seen) parents
                     (Graph.Keys.fold (insert_not_seen seen)
                                      (Graph.imm_preds fact_graph name) names)
  in parents_of [] [] (Graph.maximals fact_graph) end
   1.509 +
   1.510 +(* Generate more suggestions than requested, because some might be thrown out
   1.511 +   later for various reasons and "meshing" gives better results with some
   1.512 +   slack. *)
   1.513 +fun max_suggs_of max_facts = max_facts + Int.min (200, max_facts)
   1.514 +
   1.515 +fun is_fact_in_graph fact_graph (_, th) =
   1.516 +  can (Graph.get_node fact_graph) (nickname_of th)
   1.517 +
   1.518 +fun mash_suggest_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   1.519 +                       concl_t facts =
   1.520 +  let
   1.521 +    val thy = Proof_Context.theory_of ctxt
   1.522 +    val fact_graph = #fact_graph (mash_get ctxt)
   1.523 +    val parents = parents_wrt_facts facts fact_graph
   1.524 +    val feats = features_of ctxt prover thy General (concl_t :: hyp_ts)
   1.525 +    val suggs =
   1.526 +      if Graph.is_empty fact_graph then []
   1.527 +      else mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
   1.528 +    val selected = facts |> suggested_facts suggs
   1.529 +    val unknown = facts |> filter_out (is_fact_in_graph fact_graph)
   1.530 +  in (selected, unknown) end
   1.531 +
   1.532 +fun add_thys_for thy =
   1.533 +  let fun add comp thy = Symtab.update (Context.theory_name thy, comp) in
   1.534 +    add false thy #> fold (add true) (Theory.ancestors_of thy)
   1.535 +  end
   1.536 +
   1.537 +fun update_fact_graph ctxt (name, parents, feats, deps) (upds, graph) =
   1.538 +  let
   1.539 +    fun maybe_add_from from (accum as (parents, graph)) =
   1.540 +      try_graph ctxt "updating graph" accum (fn () =>
   1.541 +          (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   1.542 +    val graph = graph |> Graph.default_node (name, ())
   1.543 +    val (parents, graph) = ([], graph) |> fold maybe_add_from parents
   1.544 +    val (deps, graph) = ([], graph) |> fold maybe_add_from deps
   1.545 +  in ((name, parents, feats, deps) :: upds, graph) end
   1.546 +
   1.547 +val pass1_learn_timeout_factor = 0.5
   1.548 +
   1.549 +(* Too many dependencies is a sign that a decision procedure is at work. There
   1.550 +   isn't much too learn from such proofs. *)
   1.551 +val max_dependencies = 10
   1.552 +
   1.553 +(* The timeout is understood in a very slack fashion. *)
   1.554 +fun mash_learn_thy ctxt ({provers, verbose, overlord, ...} : params) thy timeout
   1.555 +                   facts =
   1.556 +  let
   1.557 +    val timer = Timer.startRealTimer ()
   1.558 +    val prover = hd provers
   1.559 +    fun timed_out frac =
   1.560 +      Time.> (Timer.checkRealTimer timer, time_mult frac timeout)
   1.561 +    val {fact_graph, ...} = mash_get ctxt
   1.562 +    val new_facts =
   1.563 +      facts |> filter_out (is_fact_in_graph fact_graph)
   1.564 +            |> sort (thm_ord o pairself snd)
   1.565 +  in
   1.566 +    if null new_facts then
   1.567 +      ""
   1.568 +    else
   1.569 +      let
   1.570 +        val ths = facts |> map snd
   1.571 +        val all_names =
   1.572 +          ths |> filter_out is_likely_tautology_or_too_meta
   1.573 +              |> map (rpair () o nickname_of)
   1.574 +              |> Symtab.make
   1.575 +        fun trim_deps deps = if length deps > max_dependencies then [] else deps
   1.576 +        fun do_fact _ (accum as (_, true)) = accum
   1.577 +          | do_fact ((_, (_, status)), th) ((parents, upds), false) =
   1.578 +            let
   1.579 +              val name = nickname_of th
   1.580 +              val feats =
   1.581 +                features_of ctxt prover (theory_of_thm th) status [prop_of th]
   1.582 +              val deps = isabelle_dependencies_of all_names th |> trim_deps
   1.583 +              val upd = (name, parents, feats, deps)
   1.584 +            in (([name], upd :: upds), timed_out pass1_learn_timeout_factor) end
   1.585 +        val parents = parents_wrt_facts facts fact_graph
   1.586 +        val ((_, upds), _) =
   1.587 +          ((parents, []), false) |> fold do_fact new_facts |>> apsnd rev
   1.588 +        val n = length upds
   1.589 +        fun trans {thys, fact_graph} =
   1.590 +          let
   1.591 +            val mash_INIT_or_ADD =
   1.592 +              if Graph.is_empty fact_graph then mash_INIT else mash_ADD
   1.593 +            val (upds, fact_graph) =
   1.594 +              ([], fact_graph) |> fold (update_fact_graph ctxt) upds
   1.595 +          in
   1.596 +            (mash_INIT_or_ADD ctxt overlord (rev upds);
   1.597 +             {thys = thys |> add_thys_for thy,
   1.598 +              fact_graph = fact_graph})
   1.599 +          end
   1.600 +      in
   1.601 +        mash_map ctxt trans;
   1.602 +        if verbose then
   1.603 +          "Processed " ^ string_of_int n ^ " proof" ^ plural_s n ^
   1.604 +          (if verbose then
   1.605 +             " in " ^ string_from_time (Timer.checkRealTimer timer)
   1.606 +           else
   1.607 +             "") ^ "."
   1.608 +        else
   1.609 +          ""
   1.610 +      end
   1.611 +  end
   1.612 +
   1.613 +fun mash_learn_proof ctxt ({provers, overlord, ...} : params) t facts used_ths =
   1.614 +  let
   1.615 +    val thy = Proof_Context.theory_of ctxt
   1.616 +    val prover = hd provers
   1.617 +    val name = ATP_Util.timestamp () ^ " " ^ serial_string () (* fresh enough *)
   1.618 +    val feats = features_of ctxt prover thy General [t]
   1.619 +    val deps = used_ths |> map nickname_of
   1.620 +  in
   1.621 +    mash_map ctxt (fn {thys, fact_graph} =>
   1.622 +        let
   1.623 +          val parents = parents_wrt_facts facts fact_graph
   1.624 +          val upds = [(name, parents, feats, deps)]
   1.625 +          val (upds, fact_graph) =
   1.626 +            ([], fact_graph) |> fold (update_fact_graph ctxt) upds
   1.627 +        in
   1.628 +          mash_ADD ctxt overlord upds;
   1.629 +          {thys = thys, fact_graph = fact_graph}
   1.630 +        end)
   1.631 +  end
   1.632 +
(* The threshold should be large enough so that MaSh doesn't kick in for Auto
   Sledgehammer and Try. *)
val min_secs_for_learning = 15
(* The learner's soft timeout is this multiple of the Sledgehammer timeout. *)
val learn_timeout_factor = 2.0

(* Main entry point: selects at most "max_facts" relevant facts for the goal
   given by "hyp_ts" and "concl_t", using the filter requested in the
   parameters ("mepo", "mash", or "mesh") or a default that depends on the
   availability of MaSh. As a side effect, a background learner thread may be
   launched. Facts given via "add" are moved to the front of the result; with
   "only", all facts are returned unfiltered. *)
fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
        max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
  if not (subset (op =) (the_list fact_filter, fact_filters)) then
    error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
  else if only then
    facts
  else if max_facts <= 0 orelse null facts then
    []
  else
    let
      val thy = Proof_Context.theory_of ctxt
      (* Launches a learner thread unless learning is disabled, a learner is
         already running, or the timeout is below the threshold. *)
      fun maybe_learn () =
        if not learn orelse Async_Manager.has_running_threads MaShN then
          ()
        else if Time.toSeconds timeout >= min_secs_for_learning then
          let
            val soft_timeout = time_mult learn_timeout_factor timeout
            val hard_timeout = time_mult 4.0 soft_timeout
            val birth_time = Time.now ()
            val death_time = Time.+ (birth_time, hard_timeout)
            val desc = ("machine learner for Sledgehammer", "")
          in
            Async_Manager.launch MaShN birth_time death_time desc
                (fn () =>
                    (true, mash_learn_thy ctxt params thy soft_timeout facts))
          end
        else
          ()
      (* An explicitly requested filter is used as is (learning is skipped only
         for "mepo"). Otherwise: "mesh" if MaSh has learned something, "mepo"
         if not -- but learning is still attempted if MaSh is installed. *)
      val fact_filter =
        case fact_filter of
          SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
        | NONE =>
          if mash_can_suggest_facts ctxt then (maybe_learn (); meshN)
          else if mash_could_suggest_facts () then (maybe_learn (); mepoN)
          else mepoN
      val add_ths = Attrib.eval_thms ctxt add
      (* Puts the facts matching "ths" in front of the accepted facts, without
         duplicates, and re-applies the cap. *)
      fun prepend_facts ths accepts =
        ((facts |> filter (member Thm.eq_thm_prop ths o snd)) @
         (accepts |> filter_out (member Thm.eq_thm_prop ths o snd)))
        |> take max_facts
      fun iter () =
        iterative_relevant_facts ctxt params prover max_facts NONE hyp_ts
                                 concl_t facts
      fun mash () =
        mash_suggest_facts ctxt params prover max_facts hyp_ts concl_t facts
      (* One (selected, unknown) pair per active filter; the iterative filter
         has no unknown facts. *)
      val mess =
        [] |> (if fact_filter <> mashN then cons (iter (), []) else I)
           |> (if fact_filter <> mepoN then cons (mash ()) else I)
    in
      mesh_facts max_facts mess
      |> not (null add_ths) ? prepend_facts add_ths
    end
   1.690 +
(* Delegate to Async_Manager for the threads registered under "MaShN"; the
   second argument is the word used for these threads ("learner"). *)
fun kill_learners () = Async_Manager.kill_threads MaShN "learner"
fun running_learners () = Async_Manager.running_threads MaShN "learner"
   1.693 +
   1.694 +end;