src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 49422 47fe0ca12fc2
parent 49421 b002cc16aa99
child 49423 5493e67982ee
equal deleted inserted replaced
49421:b002cc16aa99 49422:47fe0ca12fc2
   182     fun find_sugg (name, weight) =
   182     fun find_sugg (name, weight) =
   183       Symtab.lookup tab name |> Option.map (rpair weight)
   183       Symtab.lookup tab name |> Option.map (rpair weight)
   184   in map_filter find_sugg suggs end
   184   in map_filter find_sugg suggs end
   185 
   185 
   186 fun sum_avg [] = 0
   186 fun sum_avg [] = 0
   187   | sum_avg xs = Real.ceil (100000.0 * fold (curry (op +)) xs 0.0) div length xs
   187   | sum_avg xs =
       
   188     Real.ceil (100000000.0 * fold (curry (op +)) xs 0.0) div length xs
   188 
   189 
   189 fun normalize_scores [] = []
   190 fun normalize_scores [] = []
   190   | normalize_scores ((fact, score) :: tail) =
   191   | normalize_scores ((fact, score) :: tail) =
   191     (fact, 1.0) :: map (apsnd (curry Real.* (1.0 / score))) tail
   192     (fact, 1.0) :: map (apsnd (curry Real.* (1.0 / score))) tail
   192 
   193 
   560 end
   561 end
   561 
   562 
   562 fun mash_could_suggest_facts () = mash_home () <> ""
   563 fun mash_could_suggest_facts () = mash_home () <> ""
   563 fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt)))
   564 fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt)))
   564 
   565 
   565 fun queue_of xs = Queue.empty |> fold Queue.enqueue xs
   566 fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
   566 
   567 
   567 fun max_facts_in_graph fact_G facts =
   568 fun maximal_in_graph fact_G facts =
   568   let
   569   let
   569     val facts = [] |> fold (cons o nickname_of o snd) facts
   570     val facts = [] |> fold (cons o nickname_of o snd) facts
   570     val tab = Symtab.empty |> fold (fn name => Symtab.update (name, ())) facts
   571     val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) facts
   571     fun enqueue_new seen name =
   572     fun insert_new seen name =
   572       not (member (op =) seen name) ? Queue.enqueue name
   573       not (Symtab.defined seen name) ? insert (op =) name
   573     fun find_maxes seen maxs names =
   574     fun find_maxes _ (maxs, []) = map snd maxs
   574       case try Queue.dequeue names of
   575       | find_maxes seen (maxs, new :: news) =
   575         NONE => map snd maxs
   576         find_maxes
   576       | SOME (name, names) =>
   577             (seen |> num_keys (Graph.imm_succs fact_G new) > 1
   577         if Symtab.defined tab name then
   578                      ? Symtab.default (new, ()))
   578           let
   579             (if Symtab.defined tab new then
   579             val new = (Graph.all_preds fact_G [name], name)
   580                let
   580             fun is_ancestor (_, x) (yp, _) = member (op =) yp x
   581                  val newp = Graph.all_preds fact_G [new]
   581             val maxs = maxs |> filter (fn max => not (is_ancestor max new))
   582                  fun is_ancestor x yp = member (op =) yp x
   582             val maxs =
   583                  val maxs =
   583               if exists (is_ancestor new) maxs then maxs
   584                    maxs |> filter (fn (_, max) => not (is_ancestor max newp))
   584               else new :: filter_out (fn max => is_ancestor max new) maxs
   585                in
   585           in find_maxes (name :: seen) maxs names end
   586                  if exists (is_ancestor new o fst) maxs then
   586         else
   587                    (maxs, news)
   587           find_maxes (name :: seen) maxs
   588                  else
   588                      (Graph.Keys.fold (enqueue_new seen)
   589                    ((newp, new)
   589                                       (Graph.imm_preds fact_G name) names)
   590                     :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
   590   in find_maxes [] [] (queue_of (Graph.maximals fact_G)) end
   591                     news)
       
   592                end
       
   593              else
       
   594                (maxs, Graph.Keys.fold (insert_new seen)
       
   595                                       (Graph.imm_preds fact_G new) news))
       
   596   in find_maxes Symtab.empty ([], Graph.maximals fact_G) end
   591 
   597 
   592 (* Generate more suggestions than requested, because some might be thrown out
   598 (* Generate more suggestions than requested, because some might be thrown out
   593    later for various reasons and "meshing" gives better results with some
   599    later for various reasons and "meshing" gives better results with some
   594    slack. *)
   600    slack. *)
   595 fun max_suggs_of max_facts = max_facts + Int.min (200, max_facts)
   601 fun max_suggs_of max_facts = max_facts + Int.min (200, max_facts)
   600 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   606 fun mash_suggested_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   601                          concl_t facts =
   607                          concl_t facts =
   602   let
   608   let
   603     val thy = Proof_Context.theory_of ctxt
   609     val thy = Proof_Context.theory_of ctxt
   604     val fact_G = #fact_G (mash_get ctxt)
   610     val fact_G = #fact_G (mash_get ctxt)
   605     val parents = max_facts_in_graph fact_G facts
   611     val parents = maximal_in_graph fact_G facts
   606     val feats = features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   612     val feats = features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   607     val suggs =
   613     val suggs =
   608       if Graph.is_empty fact_G then []
   614       if Graph.is_empty fact_G then []
   609       else mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
   615       else mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
   610     val selected = facts |> suggested_facts suggs
   616     val selected = facts |> suggested_facts suggs
   616     fun maybe_add_from from (accum as (parents, graph)) =
   622     fun maybe_add_from from (accum as (parents, graph)) =
   617       try_graph ctxt "updating graph" accum (fn () =>
   623       try_graph ctxt "updating graph" accum (fn () =>
   618           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   624           (from :: parents, Graph.add_edge_acyclic (from, name) graph))
   619     val graph = graph |> Graph.default_node (name, ())
   625     val graph = graph |> Graph.default_node (name, ())
   620     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
   626     val (parents, graph) = ([], graph) |> fold maybe_add_from parents
   621     val (deps, graph) = ([], graph) |> fold maybe_add_from deps
   627     val (deps, _) = ([], graph) |> fold maybe_add_from deps
   622   in ((name, parents, feats, deps) :: adds, graph) end
   628   in ((name, parents, feats, deps) :: adds, graph) end
   623 
   629 
   624 val learn_timeout_slack = 2.0
   630 val learn_timeout_slack = 2.0
   625 
   631 
   626 fun launch_thread timeout task =
   632 fun launch_thread timeout task =
   645           val thy = Proof_Context.theory_of ctxt
   651           val thy = Proof_Context.theory_of ctxt
   646           val name = freshish_name ()
   652           val name = freshish_name ()
   647           val feats = features_of ctxt prover thy (Local, General) [t]
   653           val feats = features_of ctxt prover thy (Local, General) [t]
   648           val deps = used_ths |> map nickname_of
   654           val deps = used_ths |> map nickname_of
   649           val {fact_G} = mash_get ctxt
   655           val {fact_G} = mash_get ctxt
   650           val parents = max_facts_in_graph fact_G facts
   656           val parents = timeit (fn () => maximal_in_graph fact_G facts)
   651         in
   657         in
   652           mash_ADD ctxt overlord [(name, parents, feats, deps)]; (true, "")
   658           mash_ADD ctxt overlord [(name, parents, feats, deps)]; (true, "")
   653         end)
   659         end)
   654 
   660 
   655 fun sendback sub =
   661 fun sendback sub =
   741               val last_th = new_facts |> List.last |> snd
   747               val last_th = new_facts |> List.last |> snd
   742               (* crude approximation *)
   748               (* crude approximation *)
   743               val ancestors =
   749               val ancestors =
   744                 old_facts
   750                 old_facts
   745                 |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
   751                 |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
   746               val parents = max_facts_in_graph fact_G ancestors
   752               val parents = maximal_in_graph fact_G ancestors
   747               val (adds, (_, n, _, _)) =
   753               val (adds, (_, n, _, _)) =
   748                 ([], (parents, 0, next_commit_time (), false))
   754                 ([], (parents, 0, next_commit_time (), false))
   749                 |> fold learn_new_fact new_facts
   755                 |> fold learn_new_fact new_facts
   750             in commit true adds []; n end
   756             in commit true adds []; n end
   751         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
   757         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
   851           ()
   857           ()
   852       val fact_filter =
   858       val fact_filter =
   853         case fact_filter of
   859         case fact_filter of
   854           SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
   860           SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
   855         | NONE =>
   861         | NONE =>
   856           if is_smt_prover ctxt prover then mepoN
   862           if is_smt_prover ctxt prover then
   857           else if mash_can_suggest_facts ctxt then (maybe_learn (); meshN)
   863             mepoN
   858           else if mash_could_suggest_facts () then (maybe_learn (); mepoN)
   864           else if mash_could_suggest_facts () then
   859           else mepoN
   865             (maybe_learn ();
       
   866              if mash_can_suggest_facts ctxt then meshN else mepoN)
       
   867           else
       
   868             mepoN
   860       val add_ths = Attrib.eval_thms ctxt add
   869       val add_ths = Attrib.eval_thms ctxt add
   861       fun prepend_facts ths accepts =
   870       fun prepend_facts ths accepts =
   862         ((facts |> filter (member Thm.eq_thm_prop ths o snd)) @
   871         ((facts |> filter (member Thm.eq_thm_prop ths o snd)) @
   863          (accepts |> filter_out (member Thm.eq_thm_prop ths o snd)))
   872          (accepts |> filter_out (member Thm.eq_thm_prop ths o snd)))
   864         |> take max_facts
   873         |> take max_facts