src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 52313 407b0258464b
parent 52273 fdcc06013f2d
child 52314 e8c9755fd14e
equal deleted inserted replaced
52312:9f472d5f112c 52313:407b0258464b
   746 
   746 
   747 (*** High-level communication with MaSh ***)
   747 (*** High-level communication with MaSh ***)
   748 
   748 
   749 fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
   749 fun num_keys keys = Graph.Keys.fold (K (Integer.add 1)) keys 0
   750 
   750 
   751 fun maximal_in_graph access_G facts =
   751 fun maximal_wrt_graph G keys =
   752   let
   752   let
   753     val facts = [] |> fold (cons o nickname_of_thm o snd) facts
   753     val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) keys
   754     val tab = Symtab.empty |> fold (fn name => Symtab.default (name, ())) facts
       
   755     fun insert_new seen name =
   754     fun insert_new seen name =
   756       not (Symtab.defined seen name) ? insert (op =) name
   755       not (Symtab.defined seen name) ? insert (op =) name
   757     fun find_maxes _ (maxs, []) = map snd maxs
   756     fun find_maxes _ (maxs, []) = map snd maxs
   758       | find_maxes seen (maxs, new :: news) =
   757       | find_maxes seen (maxs, new :: news) =
   759         find_maxes
   758         find_maxes
   760             (seen |> num_keys (Graph.imm_succs access_G new) > 1
   759             (seen |> num_keys (Graph.imm_succs G new) > 1
   761                      ? Symtab.default (new, ()))
   760                      ? Symtab.default (new, ()))
   762             (if Symtab.defined tab new then
   761             (if Symtab.defined tab new then
   763                let
   762                let
   764                  val newp = Graph.all_preds access_G [new]
   763                  val newp = Graph.all_preds G [new]
   765                  fun is_ancestor x yp = member (op =) yp x
   764                  fun is_ancestor x yp = member (op =) yp x
   766                  val maxs =
   765                  val maxs =
   767                    maxs |> filter (fn (_, max) => not (is_ancestor max newp))
   766                    maxs |> filter (fn (_, max) => not (is_ancestor max newp))
   768                in
   767                in
   769                  if exists (is_ancestor new o fst) maxs then
   768                  if exists (is_ancestor new o fst) maxs then
   773                     :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
   772                     :: filter_out (fn (_, max) => is_ancestor max newp) maxs,
   774                     news)
   773                     news)
   775                end
   774                end
   776              else
   775              else
   777                (maxs, Graph.Keys.fold (insert_new seen)
   776                (maxs, Graph.Keys.fold (insert_new seen)
   778                                       (Graph.imm_preds access_G new) news))
   777                                       (Graph.imm_preds G new) news))
   779   in find_maxes Symtab.empty ([], Graph.maximals access_G) end
   778   in find_maxes Symtab.empty ([], Graph.maximals G) end
       
   779 
       
   780 fun maximal_wrt_access_graph access_G =
       
   781   map (nickname_of_thm o snd)
       
   782   #> maximal_wrt_graph access_G
   780 
   783 
   781 fun is_fact_in_graph access_G get_th fact =
   784 fun is_fact_in_graph access_G get_th fact =
   782   can (Graph.get_node access_G) (nickname_of_thm (get_th fact))
   785   can (Graph.get_node access_G) (nickname_of_thm (get_th fact))
   783 
   786 
   784 (* FUDGE *)
   787 (* FUDGE *)
   828       peek_state ctxt (fn {access_G, ...} =>
   831       peek_state ctxt (fn {access_G, ...} =>
   829           if Graph.is_empty access_G then
   832           if Graph.is_empty access_G then
   830             (access_G, [])
   833             (access_G, [])
   831           else
   834           else
   832             let
   835             let
   833               val parents = maximal_in_graph access_G facts
   836               val parents = maximal_wrt_access_graph access_G facts
   834               val feats =
   837               val feats =
   835                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   838                 features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   836               val hints =
   839               val hints =
   837                 chained |> filter (is_fact_in_graph access_G snd)
   840                 chained |> filter (is_fact_in_graph access_G snd)
   838                         |> map (nickname_of_thm o snd)
   841                         |> map (nickname_of_thm o snd)
   886         val name = freshish_name ()
   889         val name = freshish_name ()
   887         val feats = features_of ctxt prover thy (Local, General) [t]
   890         val feats = features_of ctxt prover thy (Local, General) [t]
   888       in
   891       in
   889         peek_state ctxt (fn {access_G, ...} =>
   892         peek_state ctxt (fn {access_G, ...} =>
   890             let
   893             let
   891               val parents = maximal_in_graph access_G facts
   894               val parents = maximal_wrt_access_graph access_G facts
   892               val deps =
   895               val deps =
   893                 used_ths |> filter (is_fact_in_graph access_G I)
   896                 used_ths |> filter (is_fact_in_graph access_G I)
   894                          |> map nickname_of_thm
   897                          |> map nickname_of_thm
   895             in
   898             in
   896               MaSh.learn ctxt overlord [(name, parents, feats, deps)]
   899               MaSh.learn ctxt overlord [(name, parents, feats, deps)]
  1001               val last_th = new_facts |> List.last |> snd
  1004               val last_th = new_facts |> List.last |> snd
  1002               (* crude approximation *)
  1005               (* crude approximation *)
  1003               val ancestors =
  1006               val ancestors =
  1004                 old_facts
  1007                 old_facts
  1005                 |> filter (fn (_, th) => crude_thm_ord (th, last_th) <> GREATER)
  1008                 |> filter (fn (_, th) => crude_thm_ord (th, last_th) <> GREATER)
  1006               val parents = maximal_in_graph access_G ancestors
  1009               val parents = maximal_wrt_access_graph access_G ancestors
  1007               val (learns, (_, n, _, _)) =
  1010               val (learns, (_, n, _, _)) =
  1008                 ([], (parents, 0, next_commit_time (), false))
  1011                 ([], (parents, 0, next_commit_time (), false))
  1009                 |> fold learn_new_fact new_facts
  1012                 |> fold learn_new_fact new_facts
  1010             in commit true learns [] []; n end
  1013             in commit true learns [] []; n end
  1011         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
  1014         fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum