1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:52 2014 +0200
1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:56 2014 +0200
1.3 @@ -121,8 +121,17 @@
1.4 val relearn_isarN = "relearn_isar"
1.5 val relearn_proverN = "relearn_prover"
1.6
1.7 +val hintsN = ".hints"
1.8 +
1.9 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
1.10
1.11 +type xtab = int * int Symtab.table
1.12 +
1.13 +val empty_xtab = (0, Symtab.empty)
1.14 +
1.15 +fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
1.16 +fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
1.17 +
1.18 fun mash_state_dir () = Path.explode "$ISABELLE_HOME_USER/mash" |> tap Isabelle_System.mkdir
1.19 fun mash_state_file () = Path.append (mash_state_dir ()) (Path.explode "state")
1.20
1.21 @@ -384,10 +393,10 @@
1.22
1.23 val number_of_nearest_neighbors = 10 (* FUDGE *)
1.24
1.25 -fun select_visible_facts recommends =
1.26 +fun select_visible_facts big_number recommends =
1.27 List.app (fn at =>
1.28 let val (j, ov) = Array.sub (recommends, at) in
1.29 - Array.update (recommends, at, (j, 1000000000.0 + ov))
1.30 + Array.update (recommends, at, (j, big_number + ov))
1.31 end)
1.32
1.33 exception EXIT of unit
1.34 @@ -461,7 +470,7 @@
1.35 in
1.36 while1 ();
1.37 while2 ();
1.38 - select_visible_facts recommends visible_facts;
1.39 + select_visible_facts 1000000000.0 recommends visible_facts;
1.40 heap (Real.compare o pairself snd) max_suggs num_facts recommends;
1.41 ret [] (Integer.max 0 (num_facts - max_suggs))
1.42 end
1.43 @@ -540,7 +549,7 @@
1.44 fun ret at acc =
1.45 if at = num_facts then acc else ret (at + 1) (Array.sub (posterior, at) :: acc)
1.46 in
1.47 - select_visible_facts posterior visible_facts;
1.48 + select_visible_facts 100000.0 posterior visible_facts;
1.49 heap (Real.compare o pairself snd) max_suggs num_facts posterior;
1.50 ret (Integer.max 0 (num_facts - max_suggs)) []
1.51 end
1.52 @@ -608,13 +617,20 @@
1.53 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
1.54
1.55 fun query ctxt overlord engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs
1.56 - learns conj_feats =
1.57 + learns0 conj_feats =
1.58 if engine = MaSh_SML_kNN_Cpp then
1.59 - k_nearest_neighbors_cpp max_suggs learns conj_feats
1.60 + k_nearest_neighbors_cpp max_suggs learns0 conj_feats
1.61 else if engine = MaSh_SML_NB_Cpp then
1.62 - naive_bayes_cpp max_suggs learns conj_feats
1.63 + naive_bayes_cpp max_suggs learns0 conj_feats
1.64 else
1.65 let
1.66 + val learn_ary = Array.array (num_facts, ("", [], []))
1.67 + val _ =
1.68 + List.app (fn entry as (fact, _, _) =>
1.69 + Array.update (learn_ary, the (Symtab.lookup fact_tab fact), entry))
1.70 + learns0
1.71 + val learns = Array.foldr (op ::) [] learn_ary
1.72 +
1.73 val facts = map #1 learns
1.74 val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
1.75 val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
1.76 @@ -675,10 +691,12 @@
1.77 Graph.default_node (parent, (Isar_Proof, [], []))
1.78 #> Graph.add_edge (parent, name)
1.79
1.80 -fun add_node kind name parents feats deps G =
1.81 - (Graph.new_node (name, (kind, feats, deps)) G
1.82 - handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) G)
1.83 - |> fold (add_edge_to name) parents
1.84 +fun add_node kind name parents feats deps (access_G, fact_xtab, feat_xtab) =
1.85 + ((Graph.new_node (name, (kind, feats, deps)) access_G
1.86 + handle Graph.DUP _ => Graph.map_node name (K (kind, feats, deps)) access_G)
1.87 + |> fold (add_edge_to name) parents,
1.88 + maybe_add_to_xtab name fact_xtab,
1.89 + fold maybe_add_to_xtab feats feat_xtab)
1.90
1.91 fun try_graph ctxt when def f =
1.92 f ()
1.93 @@ -703,10 +721,19 @@
1.94
1.95 type mash_state =
1.96 {access_G : (proof_kind * string list * string list) Graph.T,
1.97 - num_known_facts : int,
1.98 + fact_xtab : xtab,
1.99 + feat_xtab : xtab,
1.100 + num_known_facts : int, (* ### FIXME: kill *)
1.101 dirty_facts : string list option}
1.102
1.103 -val empty_state = {access_G = Graph.empty, num_known_facts = 0, dirty_facts = SOME []} : mash_state
1.104 +val empty_state =
1.105 + {access_G = Graph.empty,
1.106 + fact_xtab = empty_xtab,
1.107 + feat_xtab = empty_xtab,
1.108 + num_known_facts = 0,
1.109 + dirty_facts = SOME []} : mash_state
1.110 +
1.111 +val empty_graphxx = (Graph.empty, empty_xtab, empty_xtab)
1.112
1.113 local
1.114
1.115 @@ -741,21 +768,22 @@
1.116 NONE => I (* should not happen *)
1.117 | SOME (kind, name, parents, feats, deps) => add_node kind name parents feats deps)
1.118
1.119 - val (access_G, num_known_facts) =
1.120 + val ((access_G, fact_xtab, feat_xtab), num_known_facts) =
1.121 (case string_ord (version', version) of
1.122 EQUAL =>
1.123 - (try_graph ctxt "loading state" Graph.empty (fn () =>
1.124 - fold extract_line_and_add_node node_lines Graph.empty),
1.125 + (try_graph ctxt "loading state" empty_graphxx (fn () =>
1.126 + fold extract_line_and_add_node node_lines empty_graphxx),
1.127 length node_lines)
1.128 | LESS =>
1.129 (* cannot parse old file *)
1.130 (if the_mash_engine () = MaSh_Py then MaSh_Py.unlearn ctxt overlord
1.131 else wipe_out_mash_state_dir ();
1.132 - (Graph.empty, 0))
1.133 + (empty_graphxx, 0))
1.134 | GREATER => raise FILE_VERSION_TOO_NEW ())
1.135 in
1.136 trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
1.137 - {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = SOME []}
1.138 + {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
1.139 + num_known_facts = num_known_facts, dirty_facts = SOME []}
1.140 end
1.141 | _ => empty_state)))
1.142 end
1.143 @@ -765,7 +793,7 @@
1.144 encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
1.145
1.146 fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
1.147 - | save_state ctxt (memory_time, {access_G, num_known_facts, dirty_facts}) =
1.148 + | save_state ctxt (memory_time, {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts}) =
1.149 let
1.150 fun append_entry (name, ((kind, feats, deps), (parents, _))) =
1.151 cons (kind, name, Graph.Keys.dest parents, feats, deps)
1.152 @@ -786,7 +814,9 @@
1.153 (case dirty_facts of
1.154 SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
1.155 | _ => "") ^ ")");
1.156 - (Time.now (), {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = SOME []})
1.157 + (Time.now (),
1.158 + {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
1.159 + num_known_facts = num_known_facts, dirty_facts = SOME []})
1.160 end
1.161
1.162 val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
1.163 @@ -1291,11 +1321,6 @@
1.164 fun add_const_counts t =
1.165 fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
1.166
1.167 -val empty_xtab = (0, Symtab.empty)
1.168 -
1.169 -fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
1.170 -fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
1.171 -
1.172 fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
1.173 let
1.174 val thy = Proof_Context.theory_of ctxt
1.175 @@ -1344,19 +1369,18 @@
1.176 (parents, hints, feats)
1.177 end
1.178
1.179 - val (access_G, py_suggs) =
1.180 - peek_state ctxt overlord (fn {access_G, ...} =>
1.181 - if Graph.is_empty access_G then
1.182 - (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
1.183 - else
1.184 - (access_G,
1.185 - if engine = MaSh_Py then
1.186 - let val (parents, hints, feats) = query_args access_G in
1.187 - MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
1.188 - |> map fst
1.189 - end
1.190 - else
1.191 - []))
1.192 + val ((access_G, fact_xtab, feat_xtab), py_suggs) =
1.193 + peek_state ctxt overlord (fn {access_G, fact_xtab, feat_xtab, ...} =>
1.194 + ((access_G, fact_xtab, feat_xtab),
1.195 + if Graph.is_empty access_G then
1.196 + (trace_msg ctxt (K "Nothing has been learned yet"); [])
1.197 + else if engine = MaSh_Py then
1.198 + let val (parents, hints, feats) = query_args access_G in
1.199 + MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
1.200 + |> map fst
1.201 + end
1.202 + else
1.203 + []))
1.204
1.205 val sml_suggs =
1.206 if engine = MaSh_Py then
1.207 @@ -1367,11 +1391,8 @@
1.208 val feats = map fst feats0
1.209 val visible_facts = Graph.all_preds access_G parents
1.210 val learns =
1.211 - Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G @
1.212 - (if null hints then [] else [(".hints", feats, hints)])
1.213 -
1.214 - val fact_xtab = fold (add_to_xtab o #1) learns empty_xtab
1.215 - val feat_xtab = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
1.216 + (if null hints then [] else [(hintsN, feats, hints)]) @ (* ### FIXME *)
1.217 + Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
1.218 in
1.219 MaSh_SML.query ctxt overlord engine fact_xtab feat_xtab visible_facts max_facts learns
1.220 feats
1.221 @@ -1383,27 +1404,33 @@
1.222 |> pairself (map fact_of_raw_fact)
1.223 end
1.224
1.225 -fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
1.226 +fun learn_wrt_access_graph ctxt (name, parents, feats, deps)
1.227 + (learns, (access_G, fact_xtab, feat_xtab)) =
1.228 let
1.229 - fun maybe_learn_from from (accum as (parents, G)) =
1.230 + fun maybe_learn_from from (accum as (parents, access_G)) =
1.231 try_graph ctxt "updating graph" accum (fn () =>
1.232 - (from :: parents, Graph.add_edge_acyclic (from, name) G))
1.233 - val G = G |> Graph.default_node (name, (Isar_Proof, feats, deps))
1.234 - val (parents, G) = ([], G) |> fold maybe_learn_from parents
1.235 - val (deps, _) = ([], G) |> fold maybe_learn_from deps
1.236 + (from :: parents, Graph.add_edge_acyclic (from, name) access_G))
1.237 +
1.238 + val access_G = access_G |> Graph.default_node (name, (Isar_Proof, feats, deps))
1.239 + val (parents, access_G) = ([], access_G) |> fold maybe_learn_from parents
1.240 + val (deps, _) = ([], access_G) |> fold maybe_learn_from deps
1.241 +
1.242 + val fact_xtab = maybe_add_to_xtab name fact_xtab
1.243 + val feat_xtab = fold maybe_add_to_xtab feats feat_xtab
1.244 in
1.245 - ((name, parents, feats, deps) :: learns, G)
1.246 + ((name, parents, feats, deps) :: learns, (access_G, fact_xtab, feat_xtab))
1.247 end
1.248
1.249 -fun relearn_wrt_access_graph ctxt (name, deps) (relearns, G) =
1.250 +fun relearn_wrt_access_graph ctxt (name, deps) (relearns, access_G) =
1.251 let
1.252 - fun maybe_relearn_from from (accum as (parents, G)) =
1.253 + fun maybe_relearn_from from (accum as (parents, access_G)) =
1.254 try_graph ctxt "updating graph" accum (fn () =>
1.255 - (from :: parents, Graph.add_edge_acyclic (from, name) G))
1.256 - val G = G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
1.257 - val (deps, _) = ([], G) |> fold maybe_relearn_from deps
1.258 + (from :: parents, Graph.add_edge_acyclic (from, name) access_G))
1.259 + val access_G =
1.260 + access_G |> Graph.map_node name (fn (_, feats, _) => (Automatic_Proof, feats, deps))
1.261 + val (deps, _) = ([], access_G) |> fold maybe_relearn_from deps
1.262 in
1.263 - ((name, deps) :: relearns, G)
1.264 + ((name, deps) :: relearns, access_G)
1.265 end
1.266
1.267 fun flop_wrt_access_graph name =
1.268 @@ -1431,24 +1458,28 @@
1.269 val thy = Proof_Context.theory_of ctxt
1.270 val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t])
1.271 in
1.272 - map_state ctxt overlord (fn state as {access_G, num_known_facts, dirty_facts} =>
1.273 - let
1.274 - val parents = maximal_wrt_access_graph access_G facts
1.275 - val deps = used_ths
1.276 - |> filter (is_fact_in_graph access_G)
1.277 - |> map nickname_of_thm
1.278 - in
1.279 - if the_mash_engine () = MaSh_Py then
1.280 - (MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]; state)
1.281 - else
1.282 - let
1.283 - val name = learned_proof_name ()
1.284 - val access_G = access_G |> add_node Automatic_Proof name parents feats deps
1.285 - in
1.286 - {access_G = access_G, num_known_facts = num_known_facts + 1,
1.287 - dirty_facts = Option.map (cons name) dirty_facts}
1.288 - end
1.289 - end);
1.290 + map_state ctxt overlord
1.291 + (fn state as {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =>
1.292 + let
1.293 + val parents = maximal_wrt_access_graph access_G facts
1.294 + val deps = used_ths
1.295 + |> filter (is_fact_in_graph access_G)
1.296 + |> map nickname_of_thm
1.297 + in
1.298 + if the_mash_engine () = MaSh_Py then
1.299 + (MaSh_Py.learn ctxt overlord true [("", parents, feats, deps)]; state)
1.300 + else
1.301 + let
1.302 + val name = learned_proof_name ()
1.303 + val (access_G, fact_xtab, feat_xtab) =
1.304 + add_node Automatic_Proof name parents feats deps
1.305 + (access_G, fact_xtab, feat_xtab)
1.306 + in
1.307 + {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
1.308 + num_known_facts = num_known_facts + 1,
1.309 + dirty_facts = Option.map (cons name) dirty_facts}
1.310 + end
1.311 + end);
1.312 (true, "")
1.313 end)
1.314 else
1.315 @@ -1466,7 +1497,7 @@
1.316 fun next_commit_time () = Time.+ (Timer.checkRealTimer timer, commit_timeout)
1.317
1.318 val engine = the_mash_engine ()
1.319 - val {access_G, ...} = peek_state ctxt overlord I
1.320 + val {access_G, fact_xtab, feat_xtab, ...} = peek_state ctxt overlord I
1.321 val is_in_access_G = is_fact_in_graph access_G o snd
1.322 val no_new_facts = forall is_in_access_G facts
1.323 in
1.324 @@ -1493,12 +1524,15 @@
1.325 isar_dependencies_of name_tabs th
1.326
1.327 fun do_commit [] [] [] state = state
1.328 - | do_commit learns relearns flops {access_G, num_known_facts, dirty_facts} =
1.329 + | do_commit learns relearns flops
1.330 + {access_G, fact_xtab, feat_xtab, num_known_facts, dirty_facts} =
1.331 let
1.332 + val (learns, (access_G, fact_xtab, feat_xtab)) =
1.333 + fold (learn_wrt_access_graph ctxt) learns ([], (access_G, fact_xtab, feat_xtab))
1.334 + val (relearns, access_G) =
1.335 + fold (relearn_wrt_access_graph ctxt) relearns ([], access_G)
1.336 +
1.337 val was_empty = Graph.is_empty access_G
1.338 - val (learns, access_G) = ([], access_G) |> fold (learn_wrt_access_graph ctxt) learns
1.339 - val (relearns, access_G) =
1.340 - ([], access_G) |> fold (relearn_wrt_access_graph ctxt) relearns
1.341 val access_G = access_G |> fold flop_wrt_access_graph flops
1.342 val num_known_facts = num_known_facts + length learns
1.343 val dirty_facts =
1.344 @@ -1511,7 +1545,8 @@
1.345 MaSh_Py.relearn ctxt overlord save relearns)
1.346 else
1.347 ();
1.348 - {access_G = access_G, num_known_facts = num_known_facts, dirty_facts = dirty_facts}
1.349 + {access_G = access_G, fact_xtab = fact_xtab, feat_xtab = feat_xtab,
1.350 + num_known_facts = num_known_facts, dirty_facts = dirty_facts}
1.351 end
1.352
1.353 fun commit last learns relearns flops =