store string-to-index tables in memory
author blanchet
Thu, 26 Jun 2014 13:35:56 +0200
changeset 58713 0b2bce982afd
parent 58712 9d420da6c7e2
child 58714 24738b4f8c6b
store string-to-index tables in memory
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:52 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:56 2014 +0200
     1.3 @@ -121,8 +121,17 @@
     1.4  val relearn_isarN = "relearn_isar"
     1.5  val relearn_proverN = "relearn_prover"
     1.6  
     1.7 +val hintsN = ".hints"
     1.8 +
     1.9  fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
    1.10  
    1.11 +type xtab = int * int Symtab.table
    1.12 +
    1.13 +val empty_xtab = (0, Symtab.empty)
    1.14 +
    1.15 +fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
    1.16 +fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
    1.17 +
    1.18  fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
    1.19  fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
    1.20  
    1.21 @@ -384,10 +393,10 @@
    1.22  
    1.23  val number_of_nearest_neighbors = 10 (* FUDGE *)
    1.24  
    1.25 -fun select_visible_facts recommends =
    1.26 +fun select_visible_facts big_number recommends =
    1.27    List.app (fn at =>
    1.28      let val (j, ov) = Array.sub (recommends, at) in
    1.29 -      Array.update (recommends, at, (j, 1000000000.0 + ov))
    1.30 +      Array.update (recommends, at, (j, big_number + ov))
    1.31      end)
    1.32  
    1.33  exception EXIT of unit
    1.34 @@ -461,7 +470,7 @@
    1.35    in
    1.36      while1 ();
    1.37      while2 ();
    1.38 -    select_visible_facts recommends visible_facts;
    1.39 +    select_visible_facts 1000000000.0 recommends visible_facts;
    1.40      heap (Real.compare o pairself snd) max_suggs num_facts recommends;
    1.41      ret [] (Integer.max 0 (num_facts - max_suggs))
    1.42    end
    1.43 @@ -540,7 +549,7 @@
    1.44      fun ret at acc =
    1.45        if at = num_facts then acc else ret (at + 1) (Array.sub (posterior, at) :: acc)
    1.46    in
    1.47 -    select_visible_facts posterior visible_facts;
    1.48 +    select_visible_facts 100000.0 posterior visible_facts;
    1.49      heap (Real.compare o pairself snd) max_suggs num_facts posterior;
    1.50      ret (Integer.max 0 (num_facts - max_suggs)) []
    1.51    end
    1.52 @@ -608,13 +617,20 @@
    1.53  val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
    1.54  
    1.55  fun query ctxt overlord engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs
    1.56 -    learns conj_feats =
    1.57 +    learns0 conj_feats =
    1.58    if engine = MaSh_SML_kNN_Cpp then
    1.59 -    k_nearest_neighbors_cpp max_suggs learns conj_feats
    1.60 +    k_nearest_neighbors_cpp max_suggs learns0 conj_feats
    1.61    else if engine = MaSh_SML_NB_Cpp then
    1.62 -    naive_bayes_cpp max_suggs learns conj_feats
    1.63 +    naive_bayes_cpp max_suggs learns0 conj_feats
    1.64    else
    1.65      let
    1.66 +      val learn_ary = Array.array (num_facts, ("", [], []))
    1.67 +      val _ =
    1.68 +        List.app (fn entry as (fact, _, _) =>
    1.69 +            Array.update (learn_ary, the (Symtab.lookup fact_tab fact), entry))
    1.70 +          learns0
    1.71 +      val learns = Array.foldr (op ::) [] learn_ary
    1.72 +
    1.73        val facts = map #1 learns
    1.74        val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
    1.75        val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
    1.76 @@ -675,10 +691,12 @@
    1.77    Graph.default_node (parent, (Isar_Proof, [], []))
    1.78    #> Graph.add_edge (parent, name)
    1.79  
    1.80 -fun add_node kind name parents feats deps G =
    1.81 -  (Graph.new_node (name, (kind, feats, deps)) G
    1.82 -   handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) G)
    1.83 -  |> fold (add_edge_to name) parents
    1.84 +fun add_node kind name parents feats deps (access_G, fact_xtab, feat_xtab) =
    1.85 +  ((Graph.new_node (name, (kind, feats, deps)) access_G
    1.86 +    handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) access_G)
    1.87 +   |> fold (add_edge_to name) parents,
    1.88 +  maybe_add_to_xtab name fact_xtab,
    1.89 +  fold maybe_add_to_xtab feats feat_xtab)
    1.90  
    1.91  fun try_graph ctxt when def f =
    1.92    f ()
    1.93 @@ -703,10 +721,19 @@
    1.94  
    1.95  type mash_state =
    1.96    {access_G : (proof_kind * string list * string list) Graph.T,
    1.97 -   num_known_facts : int,
    1.98 +   fact_xtab : xtab,
    1.99 +   feat_xtab : xtab,
   1.100 +   num_known_facts : int, (* ### FIXME: kill *)
   1.101     dirty_facts : string list option}
   1.102  
   1.103 -val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty_facts = SOME []} : mash_state
   1.104 +val empty_state =
   1.105 +  {access_G = Graph.empty,
   1.106 +   fact_xtab = empty_xtab,
   1.107 +   feat_xtab = empty_xtab,
   1.108 +   num_known_facts = 0,
   1.109 +   dirty_facts = SOME []} : mash_state
   1.110 +
   1.111 +val empty_graphxx = (Graph.empty, empty_xtab, empty_xtab)
   1.112  
   1.113  local
   1.114  
   1.115 @@ -741,21 +768,22 @@
   1.116                   NONE => I (* should not happen *)
   1.117                 | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
   1.118  
   1.119 -             val (access_G, num_known_facts) =
   1.120 +             val ((access_G, fact_xtab, feat_xtab), num_known_facts) =
   1.121                 (case string_ord (version', version) of
   1.122                   EQUAL =>
   1.123 -                 (try_graph ctxt "loading state" Graph.empty (fn () =>
   1.124 -                    fold extract_line_and_add_node node_lines Graph.empty),
   1.125 +                 (try_graph ctxt "loading state" empty_graphxx (fn () =>
   1.126 +                    fold extract_line_and_add_node node_lines empty_graphxx),
   1.127                    length node_lines)
   1.128                 | LESS =>
   1.129                   (* cannot parse old file *)
   1.130                   (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
   1.131                    else wipe_out_mash_state_dir ();
   1.132 -                  (Graph.empty, 0))
   1.133 +                  (empty_graphxx, 0))
   1.134                 | GREATER => raise FILE_VERSION_TOO_NEW ())
   1.135             in
   1.136               trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
   1.137 -             {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = SOME []}
   1.138 +             {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
   1.139 +              num_known_facts = num_known_facts, dirty_facts = SOME []}
   1.140             end
   1.141           | _ => empty_state)))
   1.142    end
   1.143 @@ -765,7 +793,7 @@
   1.144    encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
   1.145  
   1.146  fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
   1.147 -  | save_state ctxt (memory_time, {access_G, num_known_facts, dirty_facts}) =
   1.148 +  | save_state ctxt (memory_time, {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts}) =
   1.149      let
   1.150        fun append_entry (name, ((kind, feats, deps), (parents, _))) =
   1.151          cons (kind, name, Graph.Keys.dest parents, feats, deps)
   1.152 @@ -786,7 +814,9 @@
   1.153          (case dirty_facts of
   1.154            SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
   1.155          | _ => "") ^  ")");
   1.156 -      (Time.now (), {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = SOME []})
   1.157 +      (Time.now (),
   1.158 +       {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
   1.159 +        num_known_facts = num_known_facts, dirty_facts = SOME []})
   1.160      end
   1.161  
   1.162  val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
   1.163 @@ -1291,11 +1321,6 @@
   1.164  fun add_const_counts t =
   1.165    fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
   1.166  
   1.167 -val empty_xtab = (0, Symtab.empty)
   1.168 -
   1.169 -fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
   1.170 -fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
   1.171 -
   1.172  fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
   1.173    let
   1.174      val thy = Proof_Context.theory_of ctxt
   1.175 @@ -1344,19 +1369,18 @@
   1.176          (parents, hints, feats)
   1.177        end
   1.178  
   1.179 -    val (access_G, py_suggs) =
   1.180 -      peek_state ctxt overlord (fn {access_G, ...} =>
   1.181 -        if Graph.is_empty access_G then
   1.182 -          (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
   1.183 -        else
   1.184 -          (access_G,
   1.185 -           if engine = MaSh_Py then
   1.186 -             let val (parents, hints, feats) = query_args access_G in
   1.187 -               MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
   1.188 -               |> map fst
   1.189 -             end
   1.190 -           else
   1.191 -             []))
   1.192 +    val ((access_G, fact_xtab, feat_xtab), py_suggs) =
   1.193 +      peek_state ctxt overlord (fn {access_G, fact_xtab, feat_xtab, ...} =>
   1.194 +        ((access_G, fact_xtab, feat_xtab),
   1.195 +         if Graph.is_empty access_G then
   1.196 +           (trace_msg ctxt (K "Nothing has been learned yet"); [])
   1.197 +         else if engine = MaSh_Py then
   1.198 +           let val (parents, hints, feats) = query_args access_G in
   1.199 +             MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
   1.200 +             |> map fst
   1.201 +           end
   1.202 +         else
   1.203 +           []))
   1.204  
   1.205      val sml_suggs =
   1.206        if engine = MaSh_Py then
   1.207 @@ -1367,11 +1391,8 @@
   1.208            val feats = map fst feats0
   1.209            val visible_facts = Graph.all_preds access_G parents
   1.210            val learns =
   1.211 -            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G @
   1.212 -            (if null hints then [] else [(".hints", feats, hints)])
   1.213 -
   1.214 -          val fact_xtab = fold (add_to_xtab o #1) learns empty_xtab
   1.215 -          val feat_xtab = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
   1.216 +            (if null hints then [] else [(hintsN, feats, hints)]) @ (* ### FIXME *)
   1.217 +            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
   1.218          in
   1.219            MaSh_SML.query ctxt overlord engine fact_xtab feat_xtab visible_facts max_facts learns
   1.220              feats
   1.221 @@ -1383,27 +1404,33 @@
   1.222      |> pairself (map fact_of_raw_fact)
   1.223    end
   1.224  
   1.225 -fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
   1.226 +fun learn_wrt_access_graph ctxt (name, parents, feats, deps)
   1.227 +    (learns, (access_G, fact_xtab, feat_xtab)) =
   1.228    let
   1.229 -    fun maybe_learn_from from (accum as (parents, G)) =
   1.230 +    fun maybe_learn_from from (accum as (parents, access_G)) =
   1.231        try_graph ctxt "updating graph" accum (fn () =>
   1.232 -        (from :: parents, Graph.add_edge_acyclic (from, name) G))
   1.233 -    val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
   1.234 -    val (parents, G) = ([], G) |> fold maybe_learn_from parents
   1.235 -    val (deps, _) = ([], G) |> fold maybe_learn_from deps
   1.236 +        (from :: parents, Graph.add_edge_acyclic (from, name) access_G))
   1.237 +
   1.238 +    val access_G = access_G |> Graph.default_node (name, (Isar_Proof, feats, deps))
   1.239 +    val (parents, access_G) = ([], access_G) |> fold maybe_learn_from parents
   1.240 +    val (deps, _) = ([], access_G) |> fold maybe_learn_from deps
   1.241 +
   1.242 +    val fact_xtab = maybe_add_to_xtab name fact_xtab
   1.243 +    val feat_xtab = fold maybe_add_to_xtab feats feat_xtab
   1.244    in
   1.245 -    ((name, parents, feats, deps) :: learns, G)
   1.246 +    ((name, parents, feats, deps) :: learns, (access_G, fact_xtab, feat_xtab))
   1.247    end
   1.248  
   1.249 -fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) =
   1.250 +fun relearn_wrt_access_graph ctxt (name, deps) (relearns, access_G) =
   1.251    let
   1.252 -    fun maybe_relearn_from from (accum as (parents, G)) =
   1.253 +    fun maybe_relearn_from from (accum as (parents, access_G)) =
   1.254        try_graph ctxt "updating graph" accum (fn () =>
   1.255 -        (from :: parents, Graph.add_edge_acyclic (from, name) G))
   1.256 -    val G = G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
   1.257 -    val (deps, _) = ([], G) |> fold maybe_relearn_from deps
   1.258 +        (from :: parents, Graph.add_edge_acyclic (from, name) access_G))
   1.259 +    val access_G =
   1.260 +      access_G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
   1.261 +    val (deps, _) = ([], access_G) |> fold maybe_relearn_from deps
   1.262    in
   1.263 -    ((name, deps) :: relearns, G)
   1.264 +    ((name, deps) :: relearns, access_G)
   1.265    end
   1.266  
   1.267  fun flop_wrt_access_graph name =
   1.268 @@ -1431,24 +1458,28 @@
   1.269          val thy = Proof_Context.theory_of ctxt
   1.270          val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t])
   1.271        in
   1.272 -        map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty_facts} =>
   1.273 -          let
   1.274 -            val parents = maximal_wrt_access_graph access_G facts
   1.275 -            val deps = used_ths
   1.276 -              |> filter (is_fact_in_graph access_G)
   1.277 -              |> map nickname_of_thm
   1.278 -          in
   1.279 -            if the_mash_engine () = MaSh_Py then
   1.280 -              (MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]; state)
   1.281 -            else
   1.282 -              let
   1.283 -                val name = learned_proof_name ()
   1.284 -                val access_G = access_G |> add_node Automatic_Proof name parents feats deps
   1.285 -              in
   1.286 -                {access_G = access_G, num_known_facts = num_known_facts + 1,
   1.287 -                 dirty_facts = Option.map (cons name) dirty_facts}
   1.288 -              end
   1.289 -          end);
   1.290 +        map_state ctxt overlord
   1.291 +          (fn state as {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =>
   1.292 +             let
   1.293 +               val parents = maximal_wrt_access_graph access_G facts
   1.294 +               val deps = used_ths
   1.295 +                 |> filter (is_fact_in_graph access_G)
   1.296 +                 |> map nickname_of_thm
   1.297 +             in
   1.298 +               if the_mash_engine () = MaSh_Py then
   1.299 +                 (MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]; state)
   1.300 +               else
   1.301 +                 let
   1.302 +                   val name = learned_proof_name ()
   1.303 +                   val (access_G, fact_xtab, feat_xtab) =
   1.304 +                     add_node Automatic_Proof name parents feats deps
   1.305 +                       (access_G, fact_xtab, feat_xtab)
   1.306 +                 in
   1.307 +                   {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
   1.308 +                    num_known_facts = num_known_facts + 1,
   1.309 +                    dirty_facts = Option.map (cons name) dirty_facts}
   1.310 +                 end
   1.311 +             end);
   1.312          (true, "")
   1.313        end)
   1.314    else
   1.315 @@ -1466,7 +1497,7 @@
   1.316      fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
   1.317  
   1.318      val engine = the_mash_engine ()
   1.319 -    val {access_G, ...} = peek_state ctxt overlord I
   1.320 +    val {access_G, fact_xtab, feat_xtab, ...} = peek_state ctxt overlord I
   1.321      val is_in_access_G = is_fact_in_graph access_G o snd
   1.322      val no_new_facts = forall is_in_access_G facts
   1.323    in
   1.324 @@ -1493,12 +1524,15 @@
   1.325              isar_dependencies_of name_tabs th
   1.326  
   1.327          fun do_commit [] [] [] state = state
   1.328 -          | do_commit learns relearns flops {access_G, num_known_facts, dirty_facts} =
   1.329 +          | do_commit learns relearns flops
   1.330 +              {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =
   1.331              let
   1.332 +              val (learns, (access_G, fact_xtab, feat_xtab)) =
   1.333 +                fold (learn_wrt_access_graph ctxt) learns ([], (access_G, fact_xtab, feat_xtab))
   1.334 +              val (relearns, access_G) =
   1.335 +                fold (relearn_wrt_access_graph ctxt) relearns ([], access_G)
   1.336 +
   1.337                val was_empty = Graph.is_empty access_G
   1.338 -              val (learns, access_G) = ([], access_G) |> fold (learn_wrt_access_graph ctxt) learns
   1.339 -              val (relearns, access_G) =
   1.340 -                ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
   1.341                val access_G = access_G |> fold flop_wrt_access_graph flops
   1.342                val num_known_facts = num_known_facts + length learns
   1.343                val dirty_facts =
   1.344 @@ -1511,7 +1545,8 @@
   1.345                   MaSh_Py.relearn ctxt overlord save relearns)
   1.346                else
   1.347                  ();
   1.348 -              {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = dirty_facts}
   1.349 +              {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
   1.350 +               num_known_facts = num_known_facts, dirty_facts = dirty_facts}
   1.351              end
   1.352  
   1.353          fun commit last learns relearns flops =