src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 49415 f08425165cca
parent 49414 4bacc8983b3d
child 49416 e740216ca28d
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jul 20 22:19:46 2012 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jul 20 22:19:46 2012 +0200
     1.3 @@ -472,9 +472,9 @@
     1.4                  "Internal error when " ^ when ^ ":\n" ^
     1.5                  ML_Compiler.exn_message exn); def)
     1.6  
     1.7 -type mash_state = {fact_graph : unit Graph.T}
     1.8 +type mash_state = {fact_G : unit Graph.T}
     1.9  
    1.10 -val empty_state = {fact_graph = Graph.empty}
    1.11 +val empty_state = {fact_G = Graph.empty}
    1.12  
    1.13  local
    1.14  
    1.15 @@ -496,25 +496,25 @@
    1.16               | (name, parents) =>
    1.17                 Graph.default_node (name, ())
    1.18                 #> fold (add_edge_to name) parents
    1.19 -           val fact_graph =
    1.20 +           val fact_G =
    1.21               try_graph ctxt "loading state" Graph.empty (fn () =>
    1.22                   Graph.empty |> version' = version
    1.23                                  ? fold add_fact_line fact_lines)
    1.24 -         in {fact_graph = fact_graph} end
    1.25 +         in {fact_G = fact_G} end
    1.26         | _ => empty_state)
    1.27      end
    1.28  
    1.29 -fun save {fact_graph} =
    1.30 +fun save {fact_G} =
    1.31    let
    1.32      val path = mash_state_path ()
    1.33      fun fact_line_for name parents =
    1.34        escape_meta name ^ ": " ^ escape_metas parents
    1.35      val append_fact = File.append path o suffix "\n" oo fact_line_for
    1.36 +    fun append_entry (name, ((), (parents, _))) () =
    1.37 +      append_fact name (Graph.Keys.dest parents)
    1.38    in
    1.39      File.write path (version ^ "\n");
    1.40 -    Graph.fold (fn (name, ((), (parents, _))) => fn () =>
    1.41 -                   append_fact name (Graph.Keys.dest parents))
    1.42 -        fact_graph ()
    1.43 +    Graph.fold append_entry fact_G ()
    1.44    end
    1.45  
    1.46  val global_state =
    1.47 @@ -535,45 +535,52 @@
    1.48  end
    1.49  
    1.50  fun mash_could_suggest_facts () = mash_home () <> ""
    1.51 -fun mash_can_suggest_facts ctxt =
    1.52 -  not (Graph.is_empty (#fact_graph (mash_get ctxt)))
    1.53 +fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt)))
    1.54  
    1.55 -fun parents_wrt_facts facts fact_graph =
    1.56 +fun queue_of xs = Queue.empty |> fold Queue.enqueue xs
    1.57 +
    1.58 +fun max_facts_in_graph fact_G facts =
    1.59    let
    1.60      val facts = [] |> fold (cons o nickname_of o snd) facts
    1.61      val tab = Symtab.empty |> fold (fn name => Symtab.update (name, ())) facts
    1.62 -    fun insert_not_seen seen name =
    1.63 -      not (member (op =) seen name) ? insert (op =) name
    1.64 -    fun parents_of _ parents [] = parents
    1.65 -      | parents_of seen parents (name :: names) =
    1.66 -        if Symtab.defined tab name then
    1.67 -          parents_of (name :: seen) (name :: parents) names
    1.68 -        else
    1.69 -          parents_of (name :: seen) parents
    1.70 -                     (Graph.Keys.fold (insert_not_seen seen)
    1.71 -                                      (Graph.imm_preds fact_graph name) names)
    1.72 -  in parents_of [] [] (Graph.maximals fact_graph) end
    1.73 +    fun enqueue_new seen name =
    1.74 +      not (member (op =) seen name) ? Queue.enqueue name
    1.75 +    fun find_maxes seen maxs names =
    1.76 +        case try Queue.dequeue names of
    1.77 +          NONE => maxs
    1.78 +        | SOME (name, names) =>
    1.79 +          if Symtab.defined tab name then
    1.80 +            let
    1.81 +              fun no_path x y = not (member (op =) (Graph.all_preds fact_G [y]) x)
    1.82 +              val maxs = maxs |> filter (fn max => no_path max name)
    1.83 +              val maxs = maxs |> forall (no_path name) maxs ? cons name
    1.84 +            in find_maxes (name :: seen) maxs names end
    1.85 +          else
    1.86 +            find_maxes (name :: seen) maxs
    1.87 +                       (Graph.Keys.fold (enqueue_new seen)
    1.88 +                                        (Graph.imm_preds fact_G name) names)
    1.89 +  in find_maxes [] [] (queue_of (Graph.maximals fact_G)) end
    1.90  
    1.91  (* Generate more suggestions than requested, because some might be thrown out
    1.92     later for various reasons and "meshing" gives better results with some
    1.93     slack. *)
    1.94  fun max_suggs_of max_facts = max_facts + Int.min (200, max_facts)
    1.95  
    1.96 -fun is_fact_in_graph fact_graph (_, th) =
    1.97 -  can (Graph.get_node fact_graph) (nickname_of th)
    1.98 +fun is_fact_in_graph fact_G (_, th) =
    1.99 +  can (Graph.get_node fact_G) (nickname_of th)
   1.100  
   1.101  fun mash_suggest_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
   1.102                         concl_t facts =
   1.103    let
   1.104      val thy = Proof_Context.theory_of ctxt
   1.105 -    val fact_graph = #fact_graph (mash_get ctxt)
   1.106 -    val parents = parents_wrt_facts facts fact_graph
   1.107 +    val fact_G = #fact_G (mash_get ctxt)
   1.108 +    val parents = max_facts_in_graph fact_G facts
   1.109      val feats = features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
   1.110      val suggs =
   1.111 -      if Graph.is_empty fact_graph then []
   1.112 +      if Graph.is_empty fact_G then []
   1.113        else mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
   1.114      val selected = facts |> suggested_facts suggs
   1.115 -    val unknown = facts |> filter_out (is_fact_in_graph fact_graph)
   1.116 +    val unknown = facts |> filter_out (is_fact_in_graph fact_G)
   1.117    in (selected, unknown) end
   1.118  
   1.119  fun update_fact_graph ctxt (name, parents, feats, deps) (upds, graph) =
   1.120 @@ -596,6 +603,10 @@
   1.121      val desc = ("machine learner for Sledgehammer", "")
   1.122    in Async_Manager.launch MaShN birth_time death_time desc task end
   1.123  
   1.124 +fun freshish_name () =
   1.125 +  Date.fmt ".%Y_%m_%d_%H_%M_%S__" (Date.fromTimeLocal (Time.now ())) ^
   1.126 +  serial_string ()
   1.127 +
   1.128  fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
   1.129                       used_ths =
   1.130    if is_smt_prover ctxt prover then
   1.131 @@ -605,18 +616,17 @@
   1.132          (fn () =>
   1.133              let
   1.134                val thy = Proof_Context.theory_of ctxt
   1.135 -              val name = timestamp () ^ " " ^ serial_string () (* freshish *)
   1.136 +              val name = freshish_name ()
   1.137                val feats = features_of ctxt prover thy (Local, General) [t]
   1.138                val deps = used_ths |> map nickname_of
   1.139              in
   1.140 -              mash_map ctxt (fn {fact_graph} =>
   1.141 +              mash_map ctxt (fn {fact_G} =>
   1.142                    let
   1.143 -                    val parents = parents_wrt_facts facts fact_graph
   1.144 +                    val parents = max_facts_in_graph fact_G facts
   1.145                      val upds = [(name, parents, feats, deps)]
   1.146 -                    val (upds, fact_graph) =
   1.147 -                      ([], fact_graph) |> fold (update_fact_graph ctxt) upds
   1.148 -                  in
   1.149 -                    mash_ADD ctxt overlord upds; {fact_graph = fact_graph}
   1.150 +                    val (upds, fact_G) =
   1.151 +                      ([], fact_G) |> fold (update_fact_graph ctxt) upds
   1.152 +                  in mash_ADD ctxt overlord upds; {fact_G = fact_G}
   1.153                    end);
   1.154                (true, "")
   1.155              end)
   1.156 @@ -636,10 +646,10 @@
   1.157      val timer = Timer.startRealTimer ()
   1.158      fun next_commit_time () =
   1.159        Time.+ (Timer.checkRealTimer timer, commit_timeout)
   1.160 -    val {fact_graph} = mash_get ctxt
   1.161 -    val new_facts =
   1.162 -      facts |> filter_out (is_fact_in_graph fact_graph)
   1.163 -            |> sort (thm_ord o pairself snd)
   1.164 +    val {fact_G} = mash_get ctxt
   1.165 +    val (old_facts, new_facts) =
   1.166 +      facts |> List.partition (is_fact_in_graph fact_G)
   1.167 +            ||> sort (thm_ord o pairself snd)
   1.168      val num_new_facts = length new_facts
   1.169    in
   1.170      (if not auto then
   1.171 @@ -654,7 +664,7 @@
   1.172       else
   1.173         "")
   1.174      |> Output.urgent_message;
   1.175 -    if null new_facts then
   1.176 +    if num_new_facts = 0 then
   1.177        if not auto then
   1.178          "Nothing to learn.\n\nHint: Try " ^ sendback relearn_isarN ^ " or " ^
   1.179          sendback relearn_atpN ^ " to learn from scratch."
   1.180 @@ -662,17 +672,21 @@
   1.181          ""
   1.182      else
   1.183        let
   1.184 -        val ths = facts |> map snd
   1.185 +        val last_th = new_facts |> List.last |> snd
   1.186 +        (* crude approximation *)
   1.187 +        val ancestors =
   1.188 +          old_facts |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
   1.189          val all_names =
   1.190 -          ths |> filter_out is_likely_tautology_or_too_meta
   1.191 -              |> map (rpair () o nickname_of)
   1.192 -              |> Symtab.make
   1.193 +          facts |> map snd
   1.194 +                |> filter_out is_likely_tautology_or_too_meta
   1.195 +                |> map (rpair () o nickname_of)
   1.196 +                |> Symtab.make
   1.197          fun do_commit [] state = state
   1.198 -          | do_commit upds {fact_graph} =
   1.199 +          | do_commit upds {fact_G} =
   1.200              let
   1.201 -              val (upds, fact_graph) =
   1.202 -                ([], fact_graph) |> fold (update_fact_graph ctxt) upds
   1.203 -            in mash_ADD ctxt overlord (rev upds); {fact_graph = fact_graph} end
   1.204 +              val (upds, fact_G) =
   1.205 +                ([], fact_G) |> fold (update_fact_graph ctxt) upds
   1.206 +            in mash_ADD ctxt overlord (rev upds); {fact_G = fact_G} end
   1.207          fun trim_deps deps = if length deps > max_dependencies then [] else deps
   1.208          fun commit last upds =
   1.209            (if debug andalso not auto then Output.urgent_message "Committing..."
   1.210 @@ -708,7 +722,7 @@
   1.211                val timed_out =
   1.212                  Time.> (Timer.checkRealTimer timer, learn_timeout)
   1.213              in (upds, ([name], n, next_commit, timed_out)) end
   1.214 -        val parents = parents_wrt_facts facts fact_graph
   1.215 +        val parents = max_facts_in_graph fact_G ancestors
   1.216          val (upds, (_, n, _, _)) =
   1.217            ([], (parents, 0, next_commit_time (), false))
   1.218            |> fold do_fact new_facts