src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
author blanchet
Fri, 20 Jul 2012 22:19:46 +0200
changeset 49407 ca998fa08cd9
parent 49405 4147f2bc4442
child 49409 82fc8c956cdc
permissions -rw-r--r--
added "learn_from_atp" command to MaSh, for patient users
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3 
     4 Sledgehammer's machine-learning-based relevance filter (MaSh).
     5 *)
     6 
       (* Interface of MaSh, Sledgehammer's machine-learning-based relevance
          filter. *)
     7 signature SLEDGEHAMMER_MASH =
     8 sig
         (* Types from other Sledgehammer modules used in this interface. *)
     9   type stature = ATP_Problem_Generate.stature
    10   type fact = Sledgehammer_Fact.fact
    11   type fact_override = Sledgehammer_Fact.fact_override
    12   type params = Sledgehammer_Provers.params
    13   type relevance_fudge = Sledgehammer_Provers.relevance_fudge
    14   type prover_result = Sledgehammer_Provers.prover_result
    15 
         (* Tracing option and names of relevance filters and subcommands. *)
    16   val trace : bool Config.T
    17   val MaShN : string
    18   val mepoN : string
    19   val mashN : string
    20   val meshN : string
    21   val unlearnN : string
    22   val learn_isarN : string
    23   val learn_atpN : string
    24   val relearn_isarN : string
    25   val relearn_atpN : string
    26   val fact_filters : string list
         (* Escaping of names and parsing of lines in MaSh's text format. *)
    27   val escape_meta : string -> string
    28   val escape_metas : string list -> string
    29   val unescape_meta : string -> string
    30   val unescape_metas : string -> string list
    31   val extract_query : string -> string * string list
         (* Fact names, decoding of suggestions, and merging of rankings. *)
    32   val nickname_of : thm -> string
    33   val suggested_facts : string list -> ('a * thm) list -> ('a * thm) list
    34   val mesh_facts :
    35     int -> (('a * thm) list * ('a * thm) list) list -> ('a * thm) list
         (* Helpers for choosing, ordering, and proving facts to learn from. *)
    36   val is_likely_tautology_or_too_meta : thm -> bool
    37   val theory_ord : theory * theory -> order
    38   val thm_ord : thm * thm -> order
    39   val goal_of_thm : theory -> thm -> thm
    40   val run_prover_for_mash :
    41     Proof.context -> params -> string -> fact list -> thm -> prover_result
         (* Features and dependencies of facts. *)
    42   val features_of :
    43     Proof.context -> string -> theory -> stature -> term list -> string list
    44   val isar_dependencies_of : unit Symtab.table -> thm -> string list
    45   val atp_dependencies_of :
    46     Proof.context -> params -> string -> bool -> fact list -> unit Symtab.table
    47     -> thm -> string list
         (* Low-level commands sent to the MaSh tool. *)
    48   val mash_CLEAR : Proof.context -> unit
    49   val mash_ADD :
    50     Proof.context -> bool
    51     -> (string * string list * string list * string list) list -> unit
    52   val mash_QUERY :
    53     Proof.context -> bool -> int -> string list * string list -> string list
         (* High-level operations on MaSh's persistent state. *)
    54   val mash_unlearn : Proof.context -> unit
    55   val mash_could_suggest_facts : unit -> bool
    56   val mash_can_suggest_facts : Proof.context -> bool
    57   val mash_suggest_facts :
    58     Proof.context -> params -> string -> int -> term list -> term
    59     -> ('a * thm) list -> ('a * thm) list * ('a * thm) list
    60   val mash_learn_proof :
    61     Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    62     -> unit
    63   val mash_learn_facts :
    64     Proof.context -> params -> string -> bool -> bool -> Time.time -> fact list
    65     -> string
    66   val mash_learn : Proof.context -> params -> bool -> unit
         (* Relevance filtering with MePo, MaSh, or a mesh of both. *)
    67   val relevant_facts :
    68     Proof.context -> params -> string -> int -> fact_override -> term list
    69     -> term -> fact list -> fact list
         (* Control of background learner threads. *)
    70   val kill_learners : unit -> unit
    71   val running_learners : unit -> unit
    72 end;
    73 
    74 structure Sledgehammer_MaSh : SLEDGEHAMMER_MASH =
    75 struct
    76 
    77 open ATP_Util
    78 open ATP_Problem_Generate
    79 open Sledgehammer_Util
    80 open Sledgehammer_Fact
    81 open Sledgehammer_Provers
    82 open Sledgehammer_Minimize
    83 open Sledgehammer_MePo
    84 
    85 val trace =
    86   Attrib.setup_config_bool @{binding sledgehammer_mash_trace} (K false)
    87 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
    88 
    89 val MaShN = "MaSh"
    90 
    91 val mepoN = "mepo"
    92 val mashN = "mash"
    93 val meshN = "mesh"
    94 
    95 val fact_filters = [meshN, mepoN, mashN]
    96 
    97 val unlearnN = "unlearn"
    98 val learn_isarN = "learn_isar"
    99 val learn_atpN = "learn_atp"
   100 val relearn_isarN = "relearn_isar"
   101 val relearn_atpN = "relearn_atp"
   102 
   103 fun mash_home () = getenv "MASH_HOME"
   104 fun mash_state_dir () =
   105   getenv "ISABELLE_HOME_USER" ^ "/mash"
   106   |> tap (Isabelle_System.mkdir o Path.explode)
   107 fun mash_state_path () = mash_state_dir () ^ "/state" |> Path.explode
   108 
   109 
   110 (*** Isabelle helpers ***)
   111 
   112 fun meta_char c =
   113   if Char.isAlphaNum c orelse c = #"_" orelse c = #"." orelse c = #"(" orelse
   114      c = #")" orelse c = #"," then
   115     String.str c
   116   else
   117     (* fixed width, in case more digits follow *)
   118     "\\" ^ stringN_of_int 3 (Char.ord c)
   119 
   120 fun unmeta_chars accum [] = String.implode (rev accum)
   121   | unmeta_chars accum (#"\\" :: d1 :: d2 :: d3 :: cs) =
   122     (case Int.fromString (String.implode [d1, d2, d3]) of
   123        SOME n => unmeta_chars (Char.chr n :: accum) cs
   124      | NONE => "" (* error *))
   125   | unmeta_chars _ (#"\\" :: _) = "" (* error *)
   126   | unmeta_chars accum (c :: cs) = unmeta_chars (c :: accum) cs
   127 
   128 val escape_meta = String.translate meta_char
   129 val escape_metas = map escape_meta #> space_implode " "
   130 val unescape_meta = String.explode #> unmeta_chars []
   131 val unescape_metas =
   132   space_explode " " #> filter_out (curry (op =) "") #> map unescape_meta
   133 
   134 fun extract_query line =
   135   case space_explode ":" line of
   136     [goal_name, suggs] => (unescape_meta goal_name, unescape_metas suggs)
   137   | _ => ("", [])
   138 
   139 fun parent_of_local_thm th =
   140   let
   141     val thy = th |> Thm.theory_of_thm
   142     val facts = thy |> Global_Theory.facts_of
   143     val space = facts |> Facts.space_of
   144     fun id_of s = #id (Name_Space.the_entry space s)
   145     fun max_id (s', _) (s, id) =
   146       let val id' = id_of s' in if id > id' then (s, id) else (s', id') end
   147   in ("", ~1) |> Facts.fold_static max_id facts |> fst end
   148 
   149 val local_prefix = "local" ^ Long_Name.separator
   150 
   151 fun nickname_of th =
   152   let val hint = Thm.get_name_hint th in
   153     (* FIXME: There must be a better way to detect local facts. *)
   154     case try (unprefix local_prefix) hint of
   155       SOME suff =>
   156       parent_of_local_thm th ^ Long_Name.separator ^ Long_Name.separator ^ suff
   157     | NONE => hint
   158   end
   159 
   160 fun suggested_facts suggs facts =
   161   let
   162     fun add_fact (fact as (_, th)) = Symtab.default (nickname_of th, fact)
   163     val tab = Symtab.empty |> fold add_fact facts
   164   in map_filter (Symtab.lookup tab) suggs end
   165 
   166 (* Ad hoc score function roughly based on Blanchette's Ringberg 2011 data. *)
   167 fun score x = Math.pow (1.5, 15.5 - 0.05 * Real.fromInt x) + 15.0
   168 
   169 fun sum_sq_avg [] = 0
   170   | sum_sq_avg xs =
   171     Real.ceil (100000.0 * fold (curry (op +) o score) xs 0.0) div length xs
   172 
       (* Combine the fact rankings of several relevance filters. Each element
          of "mess" is a pair (selected, unknown): the facts a filter ranked
          (rank 0 first) and the facts it has no information about.
          - With a single pair, the selection is topped up with unknown facts,
            up to "max_facts".
          - Otherwise the union of the filters' top "max_facts" selected facts
            is scored (score_in, sum_sq_avg), and the "max_facts" facts with
            the highest scores are returned, best first. *)
   173 fun mesh_facts max_facts [(selected, unknown)] =
   174     take max_facts selected @ take (max_facts - length selected) unknown
   175   | mesh_facts max_facts mess =
   176     let
             (* Pair each selection with its length. *)
   177       val mess = mess |> map (apfst (`length))
   178       val fact_eq = Thm.eq_thm o pairself snd
             (* Rank of "fact" in one filter's selection.
                - A fact the filter did not select gets rank sel_len, just past
                  the end.
                - If the fact is among that filter's unknown facts, the filter
                  is left out of the average for it (NONE). *)
   179       fun score_in fact ((sel_len, sels), unks) =
   180         case find_index (curry fact_eq fact) sels of
   181           ~1 => (case find_index (curry fact_eq fact) unks of
   182                    ~1 => SOME sel_len
   183                  | _ => NONE)
   184         | j => SOME j
   185       fun score_of fact = mess |> map_filter (score_in fact) |> sum_sq_avg
             (* Candidates: union of the filters' top "max_facts" selections.
                Unknown facts are not candidates in this case. *)
   186       val facts = fold (union fact_eq o take max_facts o snd o fst) mess []
   187     in
             (* Sort by descending score and keep the best "max_facts". *)
   188       facts |> map (`score_of) |> sort (int_ord o swap o pairself fst)
   189             |> map snd |> take max_facts
   190     end
   191 
   192 val thy_feature_prefix = "y_"
   193 
   194 val thy_feature_name_of = prefix thy_feature_prefix
   195 val const_name_of = prefix const_prefix
   196 val type_name_of = prefix type_const_prefix
   197 val class_name_of = prefix class_prefix
   198 
   199 fun is_likely_tautology_or_too_meta th =
   200   let
   201     val is_boring_const = member (op =) atp_widely_irrelevant_consts
   202     fun is_boring_bool t =
   203       not (exists_Const (not o is_boring_const o fst) t) orelse
   204       exists_type (exists_subtype (curry (op =) @{typ prop})) t
   205     fun is_boring_prop (@{const Trueprop} $ t) = is_boring_bool t
   206       | is_boring_prop (@{const "==>"} $ t $ u) =
   207         is_boring_prop t andalso is_boring_prop u
   208       | is_boring_prop (Const (@{const_name all}, _) $ (Abs (_, _, t)) $ u) =
   209         is_boring_prop t andalso is_boring_prop u
   210       | is_boring_prop (Const (@{const_name "=="}, _) $ t $ u) =
   211         is_boring_bool t andalso is_boring_bool u
   212       | is_boring_prop _ = true
   213   in
   214     is_boring_prop (prop_of th) andalso not (Thm.eq_thm_prop (@{thm ext}, th))
   215   end
   216 
   217 fun theory_ord p =
   218   if Theory.eq_thy p then
   219     EQUAL
   220   else if Theory.subthy p then
   221     LESS
   222   else if Theory.subthy (swap p) then
   223     GREATER
   224   else case int_ord (pairself (length o Theory.ancestors_of) p) of
   225     EQUAL => string_ord (pairself Context.theory_name p)
   226   | order => order
   227 
   228 val thm_ord = theory_ord o pairself theory_of_thm
   229 
   230 val freezeT = Type.legacy_freeze_type
   231 
   232 fun freeze (t $ u) = freeze t $ freeze u
   233   | freeze (Abs (s, T, t)) = Abs (s, freezeT T, freeze t)
   234   | freeze (Var ((s, _), T)) = Free (s, freezeT T)
   235   | freeze (Const (s, T)) = Const (s, freezeT T)
   236   | freeze (Free (s, T)) = Free (s, freezeT T)
   237   | freeze t = t
   238 
   239 fun goal_of_thm thy = prop_of #> freeze #> cterm_of thy #> Goal.init
   240 
   241 fun run_prover_for_mash ctxt params prover facts goal =
   242   let
   243     val problem =
   244       {state = Proof.init ctxt, goal = goal, subgoal = 1, subgoal_count = 1,
   245        facts = facts |> map (apfst (apfst (fn name => name ())))
   246                      |> map Untranslated_Fact}
   247   in
   248     get_minimizing_prover ctxt MaSh (K ()) prover params (K (K (K "")))
   249                           problem
   250   end
   251 
       (* Type constructors that never become type features. *)
   252 val bad_types = [@{type_name prop}, @{type_name bool}, @{type_name fun}]
   253 
       (* Features of the terms "ts":
          - term patterns, at each depth from term_max_depth down to 0, of
            applications of constants that are neither in atp_logical_consts
            nor built into the prover;
          - names of the type constructors (other than bad_types) in the types
            of such constants and of free, schematic, and bound variables;
          - classes of type variables' sorts (other than the sort "type").
          A negative type_max_depth disables type and class features. Each
          feature is collected at most once.
          NOTE(review): add_term_patterns only stops at depth ~1, so this
          assumes term_max_depth >= ~1. *)
   254 fun interesting_terms_types_and_classes ctxt prover term_max_depth
   255                                         type_max_depth ts =
   256   let
   257     fun is_bad_const (x as (s, _)) args =
   258       member (op =) atp_logical_consts s orelse
   259       fst (is_built_in_const_for_prover ctxt prover x args)
   260     fun add_classes @{sort type} = I
   261       | add_classes S = union (op =) (map class_name_of S)
           (* Traverses the whole type. type_max_depth is only tested for being
              non-negative (in add_type); it is not used as a depth bound. *)
   262     fun do_add_type (Type (s, Ts)) =
   263         (not (member (op =) bad_types s) ? insert (op =) (type_name_of s))
   264         #> fold do_add_type Ts
   265       | do_add_type (TFree (_, S)) = add_classes S
   266       | do_add_type (TVar (_, S)) = add_classes S
   267     fun add_type T = type_max_depth >= 0 ? do_add_type T
           (* "s(arg1,...,argn)" if some argument pattern is non-empty,
              otherwise just "s". *)
   268     fun mk_app s args =
   269       if member (op <>) args "" then s ^ "(" ^ space_implode "," args ^ ")"
   270       else s
           (* Pattern of "t" up to the given depth. Gives "" at depth ~1 and
              for terms whose head is not a constant. *)
   271     fun patternify ~1 _ = ""
   272       | patternify depth t =
   273         case strip_comb t of
   274           (Const (s, _), args) =>
   275           mk_app (const_name_of s) (map (patternify (depth - 1)) args)
   276         | _ => ""
           (* Add the patterns of "t" at all depths from "depth" down to 0. *)
   277     fun add_term_patterns ~1 _ = I
   278       | add_term_patterns depth t =
   279         insert (op =) (patternify depth t)
   280         #> add_term_patterns (depth - 1) t
   281     val add_term = add_term_patterns term_max_depth
           (* Walk "t". For each application, record the features of its head,
              then recurse into the arguments and into abstraction bodies. *)
   282     fun add_patterns t =
   283       let val (head, args) = strip_comb t in
   284         (case head of
   285            Const (x as (_, T)) =>
   286            not (is_bad_const x args) ? (add_term t #> add_type T)
   287          | Free (_, T) => add_type T
   288          | Var (_, T) => add_type T
   289          | Abs (_, T, body) => add_type T #> add_patterns body
   290          | _ => I)
   291         #> fold add_patterns args
   292       end
   293   in [] |> fold add_patterns ts end
   294 
   295 fun is_exists (s, _) = (s = @{const_name Ex} orelse s = @{const_name Ex1})
   296 
   297 val term_max_depth = 1
   298 val type_max_depth = 1
   299 
   300 (* TODO: Generate type classes for types? *)
   301 fun features_of ctxt prover thy (scope, status) ts =
   302   thy_feature_name_of (Context.theory_name thy) ::
   303   interesting_terms_types_and_classes ctxt prover term_max_depth type_max_depth
   304                                       ts
   305   |> forall is_lambda_free ts ? cons "no_lams"
   306   |> forall (not o exists_Const is_exists) ts ? cons "no_skos"
   307   |> scope <> Global ? cons "local"
   308   |> (case status of
   309         General => I
   310       | Induction => cons "induction"
   311       | Intro => cons "intro"
   312       | Inductive => cons "inductive"
   313       | Elim => cons "elim"
   314       | Simp => cons "simp"
   315       | Def => cons "def")
   316 
   317 fun isar_dependencies_of all_facts = thms_in_proof (SOME all_facts)
   318 
   319 val atp_dep_default_max_fact = 50
   320 
       (* Dependencies of "th" found by rerunning "prover" on it.
          - If its Isar proof has no dependencies, [] is returned without
            running the prover.
          - Otherwise, facts from theories ordered before th's (thm_ord) are
            filtered by iterative_relevant_facts, keeping up to max_facts, or
            atp_dep_default_max_fact if unset. The Isar dependencies are
            appended if missing.
          - If the prover succeeds, the names of the facts it used are
            returned; otherwise the Isar dependencies are returned.
          "auto" suppresses the progress messages printed with "verbose". *)
   321 fun atp_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover auto
   322                         facts all_names th =
   323   case isar_dependencies_of all_names th of
   324     [] => []
   325   | isar_deps =>
   326     let
   327       val thy = Proof_Context.theory_of ctxt
   328       val goal = goal_of_thm thy th
   329       val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal ctxt goal 1
             (* Only facts from theories that come before th's. *)
   330       val facts = facts |> filter (fn (_, th') => thm_ord (th', th) = LESS)
             (* Name each fact by a thunk yielding its nickname, so the prover's
                used facts are reported by nickname (presumably; see
                run_prover_for_mash, which forces the thunks). *)
   331       fun fix_name ((_, stature), th) = ((fn () => nickname_of th, stature), th)
   332       fun is_dep dep (_, th) = nickname_of th = dep
             (* Append the fact nicknamed "dep" unless it is already present. *)
   333       fun add_isar_dep facts dep accum =
   334         if exists (is_dep dep) accum then
   335           accum
   336         else case find_first (is_dep dep) facts of
   337           SOME ((name, status), th) => accum @ [((name, status), th)]
   338         | NONE => accum (* shouldn't happen *)
             (* "facts" on the right-hand side is still the filtered list above. *)
   339       val facts =
   340         facts |> iterative_relevant_facts ctxt params prover
   341                      (max_facts |> the_default atp_dep_default_max_fact) NONE
   342                      hyp_ts concl_t
   343               |> fold (add_isar_dep facts) isar_deps
   344               |> map fix_name
   345     in
   346       if verbose andalso not auto then
   347         let val num_facts = length facts in
   348           "MaSh: " ^ quote prover ^ " on " ^ quote (nickname_of th) ^
   349           " with " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
   350           "."
   351           |> Output.urgent_message
   352         end
   353       else
   354         ();
   355       case run_prover_for_mash ctxt params prover facts goal of
   356         {outcome = NONE, used_facts, ...} =>
   357         (if verbose andalso not auto then
   358            let val num_facts = length used_facts in
   359              "Found proof with " ^ string_of_int num_facts ^ " fact" ^
   360              plural_s num_facts ^ "."
   361              |> Output.urgent_message
   362            end
   363          else
   364            ();
   365          used_facts |> map fst)
   366       | _ => isar_deps
   367     end
   368 
   369 
   370 (*** Low-level communication with MaSh ***)
   371 
   372 fun write_file (xs, f) file =
   373   let val path = Path.explode file in
   374     File.write path "";
   375     xs |> chunk_list 500
   376        |> List.app (File.append path o space_implode "" o map f)
   377   end
   378 
   379 fun mash_info overlord =
   380   if overlord then (getenv "ISABELLE_HOME_USER", "")
   381   else (getenv "ISABELLE_TMP", serial_string ())
   382 
   383 fun and_rm_files overlord flags files =
   384   if overlord then
   385     ""
   386   else
   387     " && rm -f" ^ flags ^ " -- " ^
   388     space_implode " " (map File.shell_quote files)
   389 
   390 fun run_mash ctxt overlord (temp_dir, serial) async core =
   391   let
   392     val log_file = temp_dir ^ "/mash_log" ^ serial
   393     val err_file = temp_dir ^ "/mash_err" ^ serial
   394     val command =
   395       "(" ^ mash_home () ^ "/mash --quiet --outputDir " ^ mash_state_dir () ^
   396       " --log " ^ log_file ^ " " ^ core ^ ") 2>&1 > " ^ err_file ^
   397       and_rm_files overlord "" [log_file, err_file] ^
   398       (if async then " &" else "")
   399   in
   400     trace_msg ctxt (fn () =>
   401         (if async then "Launching " else "Running ") ^ command);
   402     write_file ([], K "") log_file;
   403     write_file ([], K "") err_file;
   404     Isabelle_System.bash command;
   405     if not async then trace_msg ctxt (K "Done") else ()
   406   end
   407 
   408 fun run_mash_commands ctxt overlord save max_suggs write_cmds read_suggs =
   409   let
   410     val info as (temp_dir, serial) = mash_info overlord
   411     val sugg_file = temp_dir ^ "/mash_suggs" ^ serial
   412     val cmd_file = temp_dir ^ "/mash_commands" ^ serial
   413   in
   414     write_file ([], K "") sugg_file;
   415     write_file write_cmds cmd_file;
   416     run_mash ctxt overlord info false
   417              ("--inputFile " ^ cmd_file ^ " --predictions " ^ sugg_file ^
   418               " --numberOfPredictions " ^ string_of_int max_suggs ^
   419               (if save then " --saveModel" else "") ^
   420               (and_rm_files overlord "" [sugg_file, cmd_file]));
   421     read_suggs (fn () => File.read_lines (Path.explode sugg_file))
   422   end
   423 
   424 fun str_of_update (name, parents, feats, deps) =
   425   "! " ^ escape_meta name ^ ": " ^ escape_metas parents ^ "; " ^
   426   escape_metas feats ^ "; " ^ escape_metas deps ^ "\n"
   427 
   428 fun str_of_query (parents, feats) =
   429   "? " ^ escape_metas parents ^ "; " ^ escape_metas feats
   430 
   431 fun mash_CLEAR ctxt =
   432   let val path = mash_state_dir () |> Path.explode in
   433     trace_msg ctxt (K "MaSh CLEAR");
   434     File.fold_dir (fn file => fn () =>
   435                       File.rm (Path.append path (Path.basic file)))
   436                   path ()
   437   end
   438 
   439 fun mash_ADD _ _ [] = ()
   440   | mash_ADD ctxt overlord upds =
   441     (trace_msg ctxt (fn () => "MaSh ADD " ^
   442          elide_string 1000 (space_implode " " (map #1 upds)));
   443      run_mash_commands ctxt overlord true 0 (upds, str_of_update) (K ()))
   444 
   445 fun mash_QUERY ctxt overlord max_suggs (query as (_, feats)) =
   446   (trace_msg ctxt (fn () => "MaSh QUERY " ^ space_implode " " feats);
   447    run_mash_commands ctxt overlord false  max_suggs
   448        ([query], str_of_query)
   449        (fn suggs => snd (extract_query (List.last (suggs ()))))
   450    handle List.Empty => [])
   451 
   452 
   453 (*** High-level communication with MaSh ***)
   454 
   455 fun try_graph ctxt when def f =
   456   f ()
   457   handle Graph.CYCLES (cycle :: _) =>
   458          (trace_msg ctxt (fn () =>
   459               "Cycle involving " ^ commas cycle ^ " when " ^ when); def)
   460        | Graph.DUP name =>
   461          (trace_msg ctxt (fn () =>
   462               "Duplicate fact " ^ quote name ^ " when " ^ when); def)
   463        | Graph.UNDEF name =>
   464          (trace_msg ctxt (fn () =>
   465               "Unknown fact " ^ quote name ^ " when " ^ when); def)
   466        | exn =>
   467          if Exn.is_interrupt exn then
   468            reraise exn
   469          else
   470            (trace_msg ctxt (fn () =>
   471                 "Internal error when " ^ when ^ ":\n" ^
   472                 ML_Compiler.exn_message exn); def)
   473 
   474 type mash_state = {fact_graph : unit Graph.T}
   475 
   476 val empty_state = {fact_graph = Graph.empty}
   477 
   478 local
   479 
       (* First line of the state file. A file starting with another line is
          loaded as an empty graph. *)
   480 val version = "*** MaSh 0.0 ***"
   481 
       (* Read the fact graph from the state file, unless the state is already
          marked as loaded (flag true).
          - A missing or empty file gives empty_state.
          - A version mismatch gives an empty graph.
          - Any graph error while rebuilding the edges also gives an empty
            graph; try_graph traces it.
          The returned state is always marked as loaded. *)
   482 fun load _ (state as (true, _)) = state
   483   | load ctxt _ =
   484     let val path = mash_state_path () in
   485       (true,
   486        case try File.read_lines path of
   487          SOME (version' :: fact_lines) =>
   488          let
                  (* Edge from "parent" to "name", adding "parent" if needed. *)
   489            fun add_edge_to name parent =
   490              Graph.default_node (parent, ())
   491              #> Graph.add_edge (parent, name)
                  (* Lines have the form "name: parents", as written by save. *)
   492            fun add_fact_line line =
   493              case extract_query line of
   494                ("", _) => I (* shouldn't happen *)
   495              | (name, parents) =>
   496                Graph.default_node (name, ())
   497                #> fold (add_edge_to name) parents
   498            val fact_graph =
   499              try_graph ctxt "loading state" Graph.empty (fn () =>
   500                  Graph.empty |> version' = version
   501                                 ? fold add_fact_line fact_lines)
   502          in {fact_graph = fact_graph} end
   503        | _ => empty_state)
   504     end
   505 
   506 fun save {fact_graph} =
   507   let
   508     val path = mash_state_path ()
   509     fun fact_line_for name parents =
   510       escape_meta name ^ ": " ^ escape_metas parents
   511     val append_fact = File.append path o suffix "\n" oo fact_line_for
   512   in
   513     File.write path (version ^ "\n");
   514     Graph.fold (fn (name, ((), (parents, _))) => fn () =>
   515                    append_fact name (Graph.Keys.dest parents))
   516         fact_graph ()
   517   end
   518 
   519 val global_state =
   520   Synchronized.var "Sledgehammer_MaSh.global_state" (false, empty_state)
   521 
   522 in
   523 
   524 fun mash_map ctxt f =
   525   Synchronized.change global_state (load ctxt ##> (f #> tap save))
   526 
   527 fun mash_get ctxt =
   528   Synchronized.change_result global_state (load ctxt #> `snd)
   529 
   530 fun mash_unlearn ctxt =
   531   Synchronized.change global_state (fn _ =>
   532       (mash_CLEAR ctxt; File.write (mash_state_path ()) "";
   533        (true, empty_state)))
   534 
   535 end
   536 
   537 fun mash_could_suggest_facts () = mash_home () <> ""
   538 fun mash_can_suggest_facts ctxt =
   539   not (Graph.is_empty (#fact_graph (mash_get ctxt)))
   540 
       (* Nicknames of those "facts" that a new fact should have as parents.
          - Predecessors are explored from the maximal nodes of "fact_graph".
          - The search stops at each node that is one of "facts"; that node
            becomes a parent.
          - Nodes not among "facts" are passed through to their predecessors.
          NOTE(review): "seen" is a list, so membership tests are linear in
          the number of visited nodes. *)
   541 fun parents_wrt_facts facts fact_graph =
   542   let
   543     val facts = [] |> fold (cons o nickname_of o snd) facts
           (* Set of the nicknames of "facts". *)
   544     val tab = Symtab.empty |> fold (fn name => Symtab.update (name, ())) facts
           (* Add "name" to the work list unless it was already visited. *)
   545     fun insert_not_seen seen name =
   546       not (member (op =) seen name) ? insert (op =) name
   547     fun parents_of _ parents [] = parents
   548       | parents_of seen parents (name :: names) =
   549         if Symtab.defined tab name then
   550           parents_of (name :: seen) (name :: parents) names
   551         else
   552           parents_of (name :: seen) parents
   553                      (Graph.Keys.fold (insert_not_seen seen)
   554                                       (Graph.imm_preds fact_graph name) names)
   555   in parents_of [] [] (Graph.maximals fact_graph) end
   556 
   557 (* Generate more suggestions than requested, because some might be thrown out
   558    later for various reasons and "meshing" gives better results with some
   559    slack. *)
   560 fun max_suggs_of max_facts = max_facts + Int.min (200, max_facts)
   561 
   562 fun is_fact_in_graph fact_graph (_, th) =
   563   can (Graph.get_node fact_graph) (nickname_of th)
   564 
   565 fun mash_suggest_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   566                        concl_t facts =
   567   let
   568     val thy = Proof_Context.theory_of ctxt
   569     val fact_graph = #fact_graph (mash_get ctxt)
   570     val parents = parents_wrt_facts facts fact_graph
   571     val feats = features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   572     val suggs =
   573       if Graph.is_empty fact_graph then []
   574       else mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
   575     val selected = facts |> suggested_facts suggs
   576     val unknown = facts |> filter_out (is_fact_in_graph fact_graph)
   577   in (selected, unknown) end
   578 
   579 fun update_fact_graph ctxt (name, parents, feats, deps) (upds, graph) =
   580   let
   581     fun maybe_add_from from (accum as (parents, graph)) =
   582       try_graph ctxt "updating graph" accum (fn () =>
   583           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   584     val graph = graph |> Graph.default_node (name, ())
   585     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
   586     val (deps, graph) = ([], graph) |> fold maybe_add_from deps
   587   in ((name, parents, feats, deps) :: upds, graph) end
   588 
   589 val learn_timeout_slack = 2.0
   590 
   591 fun launch_thread timeout task =
   592   let
   593     val hard_timeout = time_mult learn_timeout_slack timeout
   594     val birth_time = Time.now ()
   595     val death_time = Time.+ (birth_time, hard_timeout)
   596     val desc = ("machine learner for Sledgehammer", "")
   597   in Async_Manager.launch MaShN birth_time death_time desc task end
   598 
   599 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
   600                      used_ths =
   601   if is_smt_prover ctxt prover then
   602     ()
   603   else
   604     launch_thread timeout
   605         (fn () =>
   606             let
   607               val thy = Proof_Context.theory_of ctxt
   608               val name = timestamp () ^ " " ^ serial_string () (* freshish *)
   609               val feats = features_of ctxt prover thy (Local, General) [t]
   610               val deps = used_ths |> map nickname_of
   611             in
   612               mash_map ctxt (fn {fact_graph} =>
   613                   let
   614                     val parents = parents_wrt_facts facts fact_graph
   615                     val upds = [(name, parents, feats, deps)]
   616                     val (upds, fact_graph) =
   617                       ([], fact_graph) |> fold (update_fact_graph ctxt) upds
   618                   in
   619                     mash_ADD ctxt overlord upds; {fact_graph = fact_graph}
   620                   end);
   621               (true, "")
   622             end)
   623 
   624 fun sendback sub =
   625   Markup.markup Isabelle_Markup.sendback (sledgehammerN ^ " " ^ sub)
   626 
   627 (* Too many dependencies is a sign that a decision procedure is at work. There
   628    isn't much too learn from such proofs. *)
   629 val max_dependencies = 10
   630 val commit_timeout = seconds 30.0
   631 
   632 (* The timeout is understood in a very slack fashion. *)
   633 fun mash_learn_facts ctxt (params as {debug, verbose, overlord, timeout, ...})
   634                      prover auto atp learn_timeout facts =
   635   let
   636     val timer = Timer.startRealTimer ()
   637     fun next_commit_time () =
   638       Time.+ (Timer.checkRealTimer timer, commit_timeout)
   639     val {fact_graph} = mash_get ctxt
   640     val new_facts =
   641       facts |> filter_out (is_fact_in_graph fact_graph)
   642             |> sort (thm_ord o pairself snd)
   643     val num_new_facts = length new_facts
   644   in
   645     "MaShing" ^
   646     (if not auto then
   647        " " ^ string_of_int num_new_facts ^ " fact" ^
   648        plural_s num_new_facts ^
   649        (if atp then " (ATP timeout: " ^ string_from_time timeout ^ ")" else "")
   650      else
   651        "") ^ "..."
   652     |> Output.urgent_message;
   653     if null new_facts then
   654       if verbose orelse not auto then
   655         "Nothing to learn.\n\nHint: Try " ^ sendback relearn_isarN ^ " or " ^
   656         sendback relearn_atpN ^ "  to learn from scratch."
   657       else
   658         ""
   659     else
   660       let
   661         val ths = facts |> map snd
   662         val all_names =
   663           ths |> filter_out is_likely_tautology_or_too_meta
   664               |> map (rpair () o nickname_of)
   665               |> Symtab.make
   666         fun do_commit [] state = state
   667           | do_commit upds {fact_graph} =
   668             let
   669               val (upds, fact_graph) =
   670                 ([], fact_graph) |> fold (update_fact_graph ctxt) upds
   671             in mash_ADD ctxt overlord (rev upds); {fact_graph = fact_graph} end
   672         fun trim_deps deps = if length deps > max_dependencies then [] else deps
   673         fun commit last upds =
   674           (if debug andalso not auto then Output.urgent_message "Committing..."
   675            else ();
   676            mash_map ctxt (do_commit (rev upds));
   677            if not last andalso not auto then
   678              let val num_upds = length upds in
   679                "Processed " ^ string_of_int num_upds ^ " fact" ^
   680                plural_s num_upds ^ " in the last " ^
   681                string_from_time commit_timeout ^ "."
   682                |> Output.urgent_message
   683              end
   684            else
   685              ())
   686         fun do_fact _ (accum as (_, (_, _, _, true))) = accum
   687           | do_fact ((_, stature), th)
   688                     (upds, (parents, n, next_commit, false)) =
   689             let
   690               val name = nickname_of th
   691               val feats =
   692                 features_of ctxt prover (theory_of_thm th) stature [prop_of th]
   693               val deps =
   694                 (if atp then atp_dependencies_of ctxt params prover auto facts
   695                  else isar_dependencies_of) all_names th
   696                 |> trim_deps
   697               val upds = (name, parents, feats, deps) :: upds
   698               val (upds, next_commit) =
   699                 if Time.> (Timer.checkRealTimer timer, next_commit) then
   700                   (commit false upds; ([], next_commit_time ()))
   701                 else
   702                   (upds, next_commit)
   703               val timed_out =
   704                 Time.> (Timer.checkRealTimer timer, learn_timeout)
   705             in (upds, ([name], n + 1, next_commit, timed_out)) end
   706         val parents = parents_wrt_facts facts fact_graph
   707         val (upds, (_, n, _, _)) =
   708           ([], (parents, 0, next_commit_time (), false))
   709           |> fold do_fact new_facts
   710       in
   711         commit true upds;
   712         if verbose orelse not auto then
   713           "Learned " ^ string_of_int n ^ " proof" ^ plural_s n ^
   714           (if verbose then
   715              " in " ^ string_from_time (Timer.checkRealTimer timer)
   716            else
   717              "") ^ "."
   718         else
   719           ""
   720       end
   721   end
   722 
   723 fun mash_learn ctxt (params as {provers, ...}) atp =
   724   let
   725     val thy = Proof_Context.theory_of ctxt
   726     val css_table = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
   727     val facts = all_facts_of thy css_table
   728   in
   729      mash_learn_facts ctxt params (hd provers) false atp infinite_timeout facts
   730      |> Output.urgent_message
   731   end
   732 
   733 (* The threshold should be large enough so that MaSh doesn't kick in for Auto
   734    Sledgehammer and Try. *)
   735 val min_secs_for_learning = 15
   736 
       (* Select up to "max_facts" facts relevant to the goal (hyp_ts, concl_t)
          from "facts".
          - The filter is "fact_filter" (MePo, MaSh, or a mesh of both). If it
            is unset: MePo for SMT provers, mesh once MaSh has learned
            something, MePo otherwise.
          - Facts of "facts" matching the "add" theorems are put first, and
            the result is cut to "max_facts".
          - With "only", "facts" is returned unchanged.
          - When MaSh is, or could be, involved, a background learner may be
            launched first (maybe_learn).
          Raises ERROR for an unknown fact filter. *)
   737 fun relevant_facts ctxt (params as {learn, fact_filter, timeout, ...}) prover
   738         max_facts ({add, only, ...} : fact_override) hyp_ts concl_t facts =
   739   if not (subset (op =) (the_list fact_filter, fact_filters)) then
   740     error ("Unknown fact filter: " ^ quote (the fact_filter) ^ ".")
   741   else if only then
   742     facts
   743   else if max_facts <= 0 orelse null facts then
   744     []
   745   else
   746     let
             (* Launch background learning, unless it is disabled ("learn"),
                a learner is already running, or the timeout is below
                min_secs_for_learning. The learner's own timeout is
                learn_timeout_slack times "timeout"; launch_thread applies the
                slack once more to the thread's death time. *)
   747       fun maybe_learn () =
   748         if learn andalso not (Async_Manager.has_running_threads MaShN) andalso
   749            Time.toSeconds timeout >= min_secs_for_learning then
   750           let val timeout = time_mult learn_timeout_slack timeout in
   751             launch_thread timeout
   752                 (fn () => (true, mash_learn_facts ctxt params prover true false
   753                                                   timeout facts))
   754           end
   755         else
   756           ()
             (* Choose the filter, launching a learner whenever MaSh takes part
                or could take part. *)
   757       val fact_filter =
   758         case fact_filter of
   759           SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
   760         | NONE =>
   761           if is_smt_prover ctxt prover then mepoN
   762           else if mash_can_suggest_facts ctxt then (maybe_learn (); meshN)
   763           else if mash_could_suggest_facts () then (maybe_learn (); mepoN)
   764           else mepoN
   765       val add_ths = Attrib.eval_thms ctxt add
             (* Facts matching "ths" first, then the other accepted facts. *)
   766       fun prepend_facts ths accepts =
   767         ((facts |> filter (member Thm.eq_thm_prop ths o snd)) @
   768          (accepts |> filter_out (member Thm.eq_thm_prop ths o snd)))
   769         |> take max_facts
   770       fun iter () =
   771         iterative_relevant_facts ctxt params prover max_facts NONE hyp_ts
   772                                  concl_t facts
   773       fun mash () =
   774         mash_suggest_facts ctxt params prover max_facts hyp_ts concl_t facts
             (* Input of mesh_facts: MePo's ranking (with no unknown facts)
                and/or MaSh's (selected, unknown) pair. *)
   775       val mess =
   776         [] |> (if fact_filter <> mashN then cons (iter (), []) else I)
   777            |> (if fact_filter <> mepoN then cons (mash ()) else I)
   778     in
   779       mesh_facts max_facts mess
   780       |> not (null add_ths) ? prepend_facts add_ths
   781     end
   782 
   783 fun kill_learners () = Async_Manager.kill_threads MaShN "learner"
   784 fun running_learners () = Async_Manager.running_threads MaShN "learner"
   785 
   786 end;