src/HOL/Tools/Sledgehammer/sledgehammer_filter_mash.ML
changeset 49335 891a24a48155
parent 49334 340187063d84
child 49336 c552d7f1720b
equal deleted inserted replaced
49334:340187063d84 49335:891a24a48155
    24   val escape_metas : string list -> string
    24   val escape_metas : string list -> string
    25   val unescape_meta : string -> string
    25   val unescape_meta : string -> string
    26   val unescape_metas : string -> string list
    26   val unescape_metas : string -> string list
    27   val extract_query : string -> string * string list
    27   val extract_query : string -> string * string list
    28   val suggested_facts : string list -> fact list -> fact list
    28   val suggested_facts : string list -> fact list -> fact list
    29   val mesh_facts : int -> (fact list * int option) list -> fact list
    29   val mesh_facts : int -> (fact list * fact list) list -> fact list
    30   val is_likely_tautology : Proof.context -> string -> thm -> bool
    30   val is_likely_tautology : Proof.context -> string -> thm -> bool
    31   val is_too_meta : thm -> bool
    31   val is_too_meta : thm -> bool
    32   val theory_ord : theory * theory -> order
    32   val theory_ord : theory * theory -> order
    33   val thm_ord : thm * thm -> order
    33   val thm_ord : thm * thm -> order
    34   val features_of :
    34   val features_of :
    49   val mash_reset : Proof.context -> unit
    49   val mash_reset : Proof.context -> unit
    50   val mash_could_suggest_facts : unit -> bool
    50   val mash_could_suggest_facts : unit -> bool
    51   val mash_can_suggest_facts : unit -> bool
    51   val mash_can_suggest_facts : unit -> bool
    52   val mash_suggest_facts :
    52   val mash_suggest_facts :
    53     Proof.context -> params -> string -> int -> term list -> term -> fact list
    53     Proof.context -> params -> string -> int -> term list -> term -> fact list
    54     -> fact list
    54     -> fact list * fact list
    55   val mash_learn_thy :
    55   val mash_learn_thy :
    56     Proof.context -> params -> theory -> Time.time -> fact list -> string
    56     Proof.context -> params -> theory -> Time.time -> fact list -> string
    57   val mash_learn_proof :
    57   val mash_learn_proof :
    58     Proof.context -> params -> term -> thm list -> fact list -> unit
    58     Proof.context -> params -> term -> thm list -> fact list -> unit
    59   val relevant_facts :
    59   val relevant_facts :
   123 
   123 
   124 fun find_suggested facts sugg =
   124 fun find_suggested facts sugg =
   125   find_first (fn (_, th) => Thm.get_name_hint th = sugg) facts
   125   find_first (fn (_, th) => Thm.get_name_hint th = sugg) facts
   126 fun suggested_facts suggs facts = map_filter (find_suggested facts) suggs
   126 fun suggested_facts suggs facts = map_filter (find_suggested facts) suggs
   127 
   127 
   128 fun sum_avg n xs =
   128 fun sum_avg _ [] = 1000000000 (* big number *)
   129   fold (Integer.add o Integer.mult n) xs 0 div (length xs)
   129   | sum_avg n xs = fold (Integer.add o Integer.mult n) xs 0 div (length xs)
   130 
   130 
   131 fun mesh_facts max_facts [(facts, _)] = facts |> take max_facts
   131 fun mesh_facts max_facts [(selected, unknown)] =
       
   132     take max_facts selected @ take (max_facts - length selected) unknown
   132   | mesh_facts max_facts mess =
   133   | mesh_facts max_facts mess =
   133     let
   134     let
       
   135       val mess = mess |> map (apfst (`length))
   134       val n = length mess
   136       val n = length mess
   135       val fact_eq = Thm.eq_thm o pairself snd
   137       val fact_eq = Thm.eq_thm o pairself snd
   136       fun score_in fact (facts, def) =
   138       fun score_in fact ((sel_len, sels), unks) =
   137         case find_index (curry fact_eq fact) facts of
   139         case find_index (curry fact_eq fact) sels of
   138           ~1 => def
   140           ~1 => (case find_index (curry fact_eq fact) unks of
       
   141                    ~1 => SOME sel_len
       
   142                  | _ => NONE)
   139         | j => SOME j
   143         | j => SOME j
   140       fun score_of fact = mess |> map_filter (score_in fact) |> sum_avg n
   144       fun score_of fact = mess |> map_filter (score_in fact) |> sum_avg n
   141       val facts = fold (union fact_eq o take max_facts o fst) mess []
   145       val facts = fold (union fact_eq o take max_facts o snd o fst) mess []
   142     in
   146     in
   143       facts |> map (`score_of) |> sort (int_ord o pairself fst) |> map snd
   147       facts |> map (`score_of) |> sort (int_ord o pairself fst) |> map snd
   144             |> take max_facts
   148             |> take max_facts
   145     end
   149     end
   146 
   150 
   480 (* Generate more suggestions than requested, because some might be thrown out
   484 (* Generate more suggestions than requested, because some might be thrown out
   481    later for various reasons and "meshing" gives better results with some
   485    later for various reasons and "meshing" gives better results with some
   482    slack. *)
   486    slack. *)
   483 fun max_suggs_of max_facts = max_facts + Int.min (200, max_facts)
   487 fun max_suggs_of max_facts = max_facts + Int.min (200, max_facts)
   484 
   488 
       
   489 fun is_fact_in_graph fact_graph (_, th) =
       
   490   can (Graph.get_node fact_graph) (Thm.get_name_hint th)
       
   491 
   485 fun mash_suggest_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   492 fun mash_suggest_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   486                        concl_t facts =
   493                        concl_t facts =
   487   let
   494   let
   488     val thy = Proof_Context.theory_of ctxt
   495     val thy = Proof_Context.theory_of ctxt
   489     val fact_graph = #fact_graph (mash_get ())
   496     val fact_graph = #fact_graph (mash_get ())
   490     val parents = parents_wrt_facts facts fact_graph
   497     val parents = parents_wrt_facts facts fact_graph
   491     val feats = features_of ctxt prover thy General (concl_t :: hyp_ts)
   498     val feats = features_of ctxt prover thy General (concl_t :: hyp_ts)
   492     val suggs =
   499     val suggs =
   493       mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
   500       if Graph.is_empty fact_graph then []
   494   in suggested_facts suggs facts end
   501       else mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
       
   502     val selected = facts |> suggested_facts suggs
       
   503     val unknown = facts |> filter_out (is_fact_in_graph fact_graph)
       
   504   in (selected, unknown) end
   495 
   505 
   496 fun add_thys_for thy =
   506 fun add_thys_for thy =
   497   let fun add comp thy = Symtab.update (Context.theory_name thy, comp) in
   507   let fun add comp thy = Symtab.update (Context.theory_name thy, comp) in
   498     add false thy #> fold (add true) (Theory.ancestors_of thy)
   508     add false thy #> fold (add true) (Theory.ancestors_of thy)
   499   end
   509   end
   521     val timer = Timer.startRealTimer ()
   531     val timer = Timer.startRealTimer ()
   522     val prover = hd provers
   532     val prover = hd provers
   523     fun timed_out frac =
   533     fun timed_out frac =
   524       Time.> (Timer.checkRealTimer timer, time_mult frac timeout)
   534       Time.> (Timer.checkRealTimer timer, time_mult frac timeout)
   525     val {fact_graph, ...} = mash_get ()
   535     val {fact_graph, ...} = mash_get ()
   526     fun is_old (_, th) = can (Graph.get_node fact_graph) (Thm.get_name_hint th)
   536     val new_facts =
   527     val new_facts = facts |> filter_out is_old |> sort (thm_ord o pairself snd)
   537       facts |> filter_out (is_fact_in_graph fact_graph)
       
   538             |> sort (thm_ord o pairself snd)
   528   in
   539   in
   529     if null new_facts then
   540     if null new_facts then
   530       ""
   541       ""
   531     else
   542     else
   532       let
   543       let
   643          (accepts |> filter_out (member Thm.eq_thm_prop ths o snd)))
   654          (accepts |> filter_out (member Thm.eq_thm_prop ths o snd)))
   644         |> take max_facts
   655         |> take max_facts
   645       fun iter () =
   656       fun iter () =
   646         iterative_relevant_facts ctxt params prover max_facts NONE hyp_ts
   657         iterative_relevant_facts ctxt params prover max_facts NONE hyp_ts
   647                                  concl_t facts
   658                                  concl_t facts
   648         |> (fn facts => (facts, SOME (length facts)))
       
   649       fun mash () =
   659       fun mash () =
   650         (mash_suggest_facts ctxt params prover max_facts hyp_ts concl_t facts,
   660         mash_suggest_facts ctxt params prover max_facts hyp_ts concl_t facts
   651          NONE)
       
   652       val mess =
   661       val mess =
   653         [] |> (if fact_filter <> mashN then cons (iter ()) else I)
   662         [] |> (if fact_filter <> mashN then cons (iter (), []) else I)
   654            |> (if fact_filter <> iterN then cons (mash ()) else I)
   663            |> (if fact_filter <> iterN then cons (mash ()) else I)
   655     in
   664     in
   656       mesh_facts max_facts mess
   665       mesh_facts max_facts mess
   657       |> not (null add_ths) ? prepend_facts add_ths
   666       |> not (null add_ths) ? prepend_facts add_ths
   658     end
   667     end