1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 16:41:30 2014 +0200
1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 16:41:30 2014 +0200
1.3 @@ -616,10 +616,10 @@
1.4 MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
1.5 | MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
1.6
1.7 -fun query_internal ctxt engine num_facts num_feats (facts, featss, depss) (freqs as (_, _, dffreq))
1.8 - visible_facts max_suggs goal_feats int_goal_feats =
1.9 +fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
1.10 + (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
1.11 (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_strs goal_feats ^ " from {" ^
1.12 - elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}");
1.13 + elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
1.14 (case engine of
1.15 MaSh_SML_kNN =>
1.16 let
1.17 @@ -632,7 +632,7 @@
1.18 k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats
1.19 end
1.20 | MaSh_SML_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats)
1.21 - |> map (curry Vector.sub facts o fst))
1.22 + |> map (curry Vector.sub fact_names o fst))
1.23
1.24 end;
1.25
1.26 @@ -684,14 +684,47 @@
1.27 type mash_state =
1.28 {access_G : (proof_kind * string list * string list) Graph.T,
1.29 xtabs : xtab * xtab,
1.30 + ffds : string vector * int list vector * int list vector,
1.31 + freqs : int vector * int Inttab.table vector * int vector,
1.32 dirty_facts : string list option}
1.33
1.34 +val empty_xtabs = (empty_xtab, empty_xtab)
1.35 +val empty_ffds = (Vector.fromList [], Vector.fromList [], Vector.fromList [])
1.36 +val empty_freqs = (Vector.fromList [], Vector.fromList [], Vector.fromList [])
1.37 +val empty_graphxx = (Graph.empty, empty_xtabs)
1.38 +
1.39 val empty_state =
1.40 {access_G = Graph.empty,
1.41 - xtabs = (empty_xtab, empty_xtab),
1.42 + xtabs = empty_xtabs,
1.43 + ffds = empty_ffds,
1.44 + freqs = empty_freqs,
1.45 dirty_facts = SOME []} : mash_state
1.46
1.47 -val empty_graphxx = (Graph.empty, (empty_xtab, empty_xtab))
1.48 +fun reorder_learns (num_facts, fact_tab) learns =
1.49 + let val ary = Array.array (num_facts, ("", [], [])) in
1.50 + List.app (fn learn as (fact, _, _) =>
1.51 + Array.update (ary, the (Symtab.lookup fact_tab fact), learn))
1.52 + learns;
1.53 + Array.foldr (op ::) [] ary
1.54 + end
1.55 +
1.56 +fun recompute_ffd_freqs access_G (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab)) =
1.57 + let
1.58 + val learns =
1.59 + Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
1.60 + |> reorder_learns fact_xtab
1.61 +
1.62 + val fact_names = Vector.fromList (map #1 learns)
1.63 + val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
1.64 + val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
1.65 +
1.66 + val tfreq = Vector.tabulate (num_facts, K 0)
1.67 + val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
1.68 + val dffreq = Vector.tabulate (num_feats, K 0)
1.69 + in
1.70 + ((fact_names, featss, depss),
1.71 + MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss)
1.72 + end
1.73
1.74 local
1.75
1.76 @@ -737,9 +770,11 @@
1.77 else wipe_out_mash_state_dir ();
1.78 empty_graphxx)
1.79 | GREATER => raise FILE_VERSION_TOO_NEW ())
1.80 +
1.81 + val (ffds, freqs) = recompute_ffd_freqs access_G xtabs
1.82 in
1.83 trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
1.84 - {access_G = access_G, xtabs = xtabs, dirty_facts = SOME []}
1.85 + {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []}
1.86 end
1.87 | _ => empty_state)))
1.88 end
1.89 @@ -749,7 +784,7 @@
1.90 encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
1.91
1.92 fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
1.93 - | save_state ctxt (memory_time, {access_G, xtabs, dirty_facts}) =
1.94 + | save_state ctxt (memory_time, {access_G, xtabs, ffds, freqs, dirty_facts}) =
1.95 let
1.96 fun append_entry (name, ((kind, feats, deps), (parents, _))) =
1.97 cons (kind, name, Graph.Keys.dest parents, feats, deps)
1.98 @@ -770,7 +805,8 @@
1.99 (case dirty_facts of
1.100 SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
1.101 | _ => "") ^ ")");
1.102 - (Time.now (), {access_G = access_G, xtabs = xtabs, dirty_facts = SOME []})
1.103 + (Time.now (),
1.104 + {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []})
1.105 end
1.106
1.107 val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
1.108 @@ -1275,16 +1311,6 @@
1.109 fun add_const_counts t =
1.110 fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
1.111
1.112 -fun reorder_learns (num_facts, fact_tab) learns0 =
1.113 - let
1.114 - val learns = Array.array (num_facts, ("", [], []))
1.115 - in
1.116 - List.app (fn learn as (fact, _, _) =>
1.117 - Array.update (learns, the (Symtab.lookup fact_tab fact), learn))
1.118 - learns0;
1.119 - Array.foldr (op ::) [] learns
1.120 - end
1.121 -
1.122 fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_suggs hyp_ts concl_t facts =
1.123 let
1.124 val thy = Proof_Context.theory_of ctxt
1.125 @@ -1333,9 +1359,9 @@
1.126 (parents, hints, feats)
1.127 end
1.128
1.129 - val ((access_G, (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab))), py_suggs) =
1.130 - peek_state ctxt overlord (fn {access_G, xtabs, ...} =>
1.131 - ((access_G, xtabs),
1.132 + val ((access_G, ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs), py_suggs) =
1.133 + peek_state ctxt overlord (fn {access_G, xtabs, ffds, freqs, ...} =>
1.134 + ((access_G, xtabs, ffds, freqs),
1.135 if Graph.is_empty access_G then
1.136 (trace_msg ctxt (K "Nothing has been learned yet"); [])
1.137 else if engine = MaSh_Py then
1.138 @@ -1364,25 +1390,10 @@
1.139 end
1.140 else
1.141 let
1.142 - val learns0 =
1.143 - Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
1.144 - val learns = reorder_learns fact_xtab learns0
1.145 -
1.146 - val facts = Vector.fromList (map #1 learns)
1.147 - val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
1.148 - val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
1.149 -
1.150 - val tfreq = Vector.tabulate (num_facts, K 0)
1.151 - val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
1.152 - val dffreq = Vector.tabulate (num_feats, K 0)
1.153 -
1.154 - val freqs' =
1.155 - MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss
1.156 -
1.157 val int_goal_feats = map_filter (Symtab.lookup feat_tab) goal_feats
1.158 in
1.159 - MaSh_SML.query_internal ctxt engine num_facts num_feats (facts, featss, depss) freqs'
1.160 - visible_facts max_suggs goal_feats int_goal_feats
1.161 + MaSh_SML.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts
1.162 + max_suggs goal_feats int_goal_feats
1.163 end
1.164 end
1.165
1.166 @@ -1447,7 +1458,7 @@
1.167 val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t])
1.168 in
1.169 map_state ctxt overlord
1.170 - (fn state as {access_G, xtabs, dirty_facts} =>
1.171 + (fn state as {access_G, xtabs, ffds, freqs, dirty_facts} =>
1.172 let
1.173 val parents = maximal_wrt_access_graph access_G facts
1.174 val deps = used_ths
1.175 @@ -1459,10 +1470,12 @@
1.176 else
1.177 let
1.178 val name = learned_proof_name ()
1.179 - val (access_G, xtabs) =
1.180 + val (access_G', xtabs') =
1.181 add_node Automatic_Proof name parents feats deps (access_G, xtabs)
1.182 +
1.183 + val (ffds', freqs') = recompute_ffd_freqs access_G' xtabs'
1.184 in
1.185 - {access_G = access_G, xtabs = xtabs,
1.186 + {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs',
1.187 dirty_facts = Option.map (cons name) dirty_facts}
1.188 end
1.189 end);
1.190 @@ -1510,26 +1523,31 @@
1.191 isar_dependencies_of name_tabs th
1.192
1.193 fun do_commit [] [] [] state = state
1.194 - | do_commit learns relearns flops {access_G, xtabs, dirty_facts} =
1.195 + | do_commit learns relearns flops {access_G, xtabs, ffds, freqs, dirty_facts} =
1.196 let
1.197 + val was_empty = Graph.is_empty access_G
1.198 +
1.199 + (* TODO: use "fold_map" *)
1.200 val (learns, (access_G, xtabs)) =
1.201 fold (learn_wrt_access_graph ctxt) learns ([], (access_G, xtabs))
1.202 val (relearns, access_G) =
1.203 fold (relearn_wrt_access_graph ctxt) relearns ([], access_G)
1.204
1.205 - val was_empty = Graph.is_empty access_G
1.206 val access_G = access_G |> fold flop_wrt_access_graph flops
1.207 val dirty_facts =
1.208 (case (was_empty, dirty_facts) of
1.209 (false, SOME names) => SOME (map #1 learns @ map #1 relearns @ names)
1.210 | _ => NONE)
1.211 +
1.212 + val (ffds', freqs') = recompute_ffd_freqs access_G xtabs
1.213 in
1.214 if engine = MaSh_Py then
1.215 (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
1.216 MaSh_Py.relearn ctxt overlord save relearns)
1.217 else
1.218 ();
1.219 - {access_G = access_G, xtabs = xtabs, dirty_facts = dirty_facts}
1.220 + {access_G = access_G, xtabs = xtabs, ffds = ffds', freqs = freqs',
1.221 + dirty_facts = dirty_facts}
1.222 end
1.223
1.224 fun commit last learns relearns flops =