src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML
changeset 49334 340187063d84
parent 49333 325c8fd0d762
child 49335 891a24a48155
equal deleted inserted replaced
49333:325c8fd0d762 49334:340187063d84
    13   type params = Sledgehammer_Provers.params
    13   type params = Sledgehammer_Provers.params
    14   type relevance_fudge = Sledgehammer_Provers.relevance_fudge
    14   type relevance_fudge = Sledgehammer_Provers.relevance_fudge
    15   type prover_result = Sledgehammer_Provers.prover_result
    15   type prover_result = Sledgehammer_Provers.prover_result
    16 
    16 
    17   val trace : bool Config.T
    17   val trace : bool Config.T
       
    18   val MaShN : string
    18   val meshN : string
    19   val meshN : string
    19   val iterN : string
    20   val iterN : string
    20   val mashN : string
    21   val mashN : string
    21   val fact_filters : string list
    22   val fact_filters : string list
    22   val escape_meta : string -> string
    23   val escape_meta : string -> string
    49   val mash_could_suggest_facts : unit -> bool
    50   val mash_could_suggest_facts : unit -> bool
    50   val mash_can_suggest_facts : unit -> bool
    51   val mash_can_suggest_facts : unit -> bool
    51   val mash_suggest_facts :
    52   val mash_suggest_facts :
    52     Proof.context -> params -> string -> int -> term list -> term -> fact list
    53     Proof.context -> params -> string -> int -> term list -> term -> fact list
    53     -> fact list
    54     -> fact list
    54   val mash_learn_thy : Proof.context -> params -> theory -> Time.time -> unit
    55   val mash_learn_thy :
       
    56     Proof.context -> params -> theory -> Time.time -> fact list -> string
    55   val mash_learn_proof :
    57   val mash_learn_proof :
    56     Proof.context -> params -> term -> thm list -> fact list -> unit
    58     Proof.context -> params -> term -> thm list -> fact list -> unit
    57   val relevant_facts :
    59   val relevant_facts :
    58     Proof.context -> params -> string -> int -> fact_override -> term list
    60     Proof.context -> params -> string -> int -> fact_override -> term list
    59     -> term -> fact list -> fact list
    61     -> term -> fact list -> fact list
       
    62   val kill_learners : unit -> unit
       
    63   val running_learners : unit -> unit
    60 end;
    64 end;
    61 
    65 
    62 structure Sledgehammer_Filter_MaSh : SLEDGEHAMMER_FILTER_MASH =
    66 structure Sledgehammer_Filter_MaSh : SLEDGEHAMMER_FILTER_MASH =
    63 struct
    67 struct
    64 
    68 
    71 open Sledgehammer_Minimize
    75 open Sledgehammer_Minimize
    72 
    76 
    73 val trace =
    77 val trace =
    74   Attrib.setup_config_bool @{binding sledgehammer_filter_mash_trace} (K false)
    78   Attrib.setup_config_bool @{binding sledgehammer_filter_mash_trace} (K false)
    75 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
    79 fun trace_msg ctxt msg = if Config.get ctxt trace then tracing (msg ()) else ()
       
    80 
       
    81 val MaShN = "MaSh"
    76 
    82 
    77 val meshN = "mesh"
    83 val meshN = "mesh"
    78 val iterN = "iter"
    84 val iterN = "iter"
    79 val mashN = "mash"
    85 val mashN = "mash"
    80 
    86 
   479 fun mash_suggest_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   485 fun mash_suggest_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   480                        concl_t facts =
   486                        concl_t facts =
   481   let
   487   let
   482     val thy = Proof_Context.theory_of ctxt
   488     val thy = Proof_Context.theory_of ctxt
   483     val fact_graph = #fact_graph (mash_get ())
   489     val fact_graph = #fact_graph (mash_get ())
   484 val _ = warning (PolyML.makestring (length (fact_graph |> Graph.keys), length (fact_graph |> Graph.maximals),
       
   485 length (fact_graph |> Graph.minimals))) (*###*)
       
   486     val parents = parents_wrt_facts facts fact_graph
   490     val parents = parents_wrt_facts facts fact_graph
   487     val feats = features_of ctxt prover thy General (concl_t :: hyp_ts)
   491     val feats = features_of ctxt prover thy General (concl_t :: hyp_ts)
   488     val suggs =
   492     val suggs =
   489       mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
   493       mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
   490   in suggested_facts suggs facts end
   494   in suggested_facts suggs facts end
   507     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
   511     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
   508     val (deps, graph) = ([], graph) |> fold maybe_add_from deps
   512     val (deps, graph) = ([], graph) |> fold maybe_add_from deps
   509   in ((name, parents, feats, deps) :: upds, graph) end
   513   in ((name, parents, feats, deps) :: upds, graph) end
   510 
   514 
   511 val pass1_learn_timeout_factor = 0.5
   515 val pass1_learn_timeout_factor = 0.5
   512 val pass2_learn_timeout_factor = 10.0
       
   513 
   516 
   514 (* The timeout is understood in a very slack fashion. *)
   517 (* The timeout is understood in a very slack fashion. *)
   515 fun mash_learn_thy ctxt ({provers, verbose, overlord, ...} : params) thy
   518 fun mash_learn_thy ctxt ({provers, verbose, overlord, ...} : params) thy timeout
   516                    timeout =
   519                    facts =
   517   let
   520   let
   518     val timer = Timer.startRealTimer ()
   521     val timer = Timer.startRealTimer ()
   519     val prover = hd provers
   522     val prover = hd provers
   520     fun timed_out frac =
   523     fun timed_out frac =
   521       Time.> (Timer.checkRealTimer timer, time_mult frac timeout)
   524       Time.> (Timer.checkRealTimer timer, time_mult frac timeout)
   522     val css_table = clasimpset_rule_table_of ctxt
       
   523     val facts = all_facts_of thy css_table
       
   524     val {fact_graph, ...} = mash_get ()
   525     val {fact_graph, ...} = mash_get ()
   525     fun is_old (_, th) = can (Graph.get_node fact_graph) (Thm.get_name_hint th)
   526     fun is_old (_, th) = can (Graph.get_node fact_graph) (Thm.get_name_hint th)
   526     val new_facts = facts |> filter_out is_old |> sort (thm_ord o pairself snd)
   527     val new_facts = facts |> filter_out is_old |> sort (thm_ord o pairself snd)
   527   in
   528   in
   528     if null new_facts then
   529     if null new_facts then
   529       ()
   530       ""
   530     else
   531     else
   531       let
   532       let
   532         val n = length new_facts
       
   533         val _ =
       
   534           if verbose then
       
   535             "MaShing " ^ string_of_int n ^ " fact" ^ plural_s n ^
       
   536             " (advisory timeout: " ^ string_from_time timeout ^ ")..."
       
   537             |> Output.urgent_message
       
   538           else
       
   539             ()
       
   540         val ths = facts |> map snd
   533         val ths = facts |> map snd
   541         val all_names =
   534         val all_names =
   542           ths |> filter_out (is_likely_tautology ctxt prover orf is_too_meta)
   535           ths |> filter_out (is_likely_tautology ctxt prover orf is_too_meta)
   543               |> map (rpair () o Thm.get_name_hint)
   536               |> map (rpair () o Thm.get_name_hint)
   544               |> Symtab.make
   537               |> Symtab.make
   564             (mash_INIT_or_ADD ctxt overlord (rev upds);
   557             (mash_INIT_or_ADD ctxt overlord (rev upds);
   565              {thys = thys |> add_thys_for thy,
   558              {thys = thys |> add_thys_for thy,
   566               fact_graph = fact_graph})
   559               fact_graph = fact_graph})
   567           end
   560           end
   568       in
   561       in
   569         TimeLimit.timeLimit (time_mult pass2_learn_timeout_factor timeout)
   562         mash_map trans;
   570                             mash_map trans
   563         if verbose then
   571         handle TimeLimit.TimeOut =>
   564           "Processed " ^ string_of_int n ^ " proof" ^ plural_s n ^
   572                (if verbose then
   565           (if verbose then
   573                   "MaSh timed out trying to learn " ^ string_of_int n ^
   566              " in " ^ string_from_time (Timer.checkRealTimer timer)
   574                   " fact" ^ plural_s n ^ " in " ^
   567            else
   575                   string_from_time (Timer.checkRealTimer timer) ^ "."
   568              "") ^ "."
   576                   |> Output.urgent_message
   569         else
   577                 else
   570           ""
   578                   ());
       
   579         (if verbose then
       
   580            "MaSh learned " ^ string_of_int n ^ " fact" ^ plural_s n ^ " in " ^
       
   581            string_from_time (Timer.checkRealTimer timer) ^ "."
       
   582            |> Output.urgent_message
       
   583          else
       
   584            ())
       
   585       end
   571       end
   586   end
   572   end
   587 
   573 
   588 fun mash_learn_proof ctxt ({provers, overlord, ...} : params) t used_ths facts =
   574 fun mash_learn_proof ctxt ({provers, overlord, ...} : params) t used_ths facts =
   589   let
   575   let
   621     []
   607     []
   622   else
   608   else
   623     let
   609     let
   624       val thy = Proof_Context.theory_of ctxt
   610       val thy = Proof_Context.theory_of ctxt
   625       fun maybe_learn can_suggest =
   611       fun maybe_learn can_suggest =
   626         if Time.toSeconds timeout >= min_secs_for_learning then
   612         if Async_Manager.has_running_threads MaShN orelse null facts then
   627           if Multithreading.enabled () then
   613           ()
   628             let
   614         else if Time.toSeconds timeout >= min_secs_for_learning then
   629               val factor =
   615           let
   630                 if can_suggest then short_learn_timeout_factor
   616             val factor =
   631                 else long_learn_timeout_factor
   617               if can_suggest then short_learn_timeout_factor
   632             in
   618               else long_learn_timeout_factor
   633               Future.fork (fn () => mash_learn_thy ctxt params thy
   619             val soft_timeout = time_mult factor timeout
   634                                         (time_mult factor timeout)); ()
   620             val hard_timeout = time_mult 2.0 soft_timeout
   635             end
   621             val birth_time = Time.now ()
   636           else
   622             val death_time = Time.+ (birth_time, hard_timeout)
   637             mash_learn_thy ctxt params thy
   623             val desc = ("machine learner for Sledgehammer", "")
   638                            (time_mult short_learn_timeout_factor timeout)
   624           in
       
   625             Async_Manager.launch MaShN birth_time death_time desc
       
   626                 (fn () =>
       
   627                     (true, mash_learn_thy ctxt params thy soft_timeout facts))
       
   628           end
   639         else
   629         else
   640           ()
   630           ()
   641       val fact_filter =
   631       val fact_filter =
   642         case fact_filter of
   632         case fact_filter of
   643           SOME ff =>
   633           SOME ff =>
   665     in
   655     in
   666       mesh_facts max_facts mess
   656       mesh_facts max_facts mess
   667       |> not (null add_ths) ? prepend_facts add_ths
   657       |> not (null add_ths) ? prepend_facts add_ths
   668     end
   658     end
   669 
   659 
       
   660 fun kill_learners () = Async_Manager.kill_threads MaShN "learner"
       
   661 fun running_learners () = Async_Manager.running_threads MaShN "learner"
       
   662 
   670 end;
   663 end;