1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:39 2014 +0200
1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:35:46 2014 +0200
1.3 @@ -46,30 +46,6 @@
1.4 val is_mash_enabled : unit -> bool
1.5 val the_mash_engine : unit -> mash_engine
1.6
1.7 - structure MaSh_Py :
1.8 - sig
1.9 - val unlearn : Proof.context -> bool -> unit
1.10 - val learn : Proof.context -> bool -> bool ->
1.11 - (string * string list * string list * string list) list -> unit
1.12 - val relearn : Proof.context -> bool -> bool -> (string * string list) list -> unit
1.13 - val query : Proof.context -> bool -> int ->
1.14 - (string * string list * string list * string list) list * string list * string list
1.15 - * (string * real) list ->
1.16 - (string * real) list
1.17 - end
1.18 -
1.19 - structure MaSh_SML :
1.20 - sig
1.21 - val k_nearest_neighbors : int -> (int -> int list) -> (int -> int list) -> int -> int list ->
1.22 - int -> int list -> (int * real) list
1.23 - val naive_bayes : int -> (int -> int list) -> (int -> int list) -> int -> int -> int list ->
1.24 - int list -> (int * real) list
1.25 - val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
1.26 - int -> int -> int list -> (int * real) list
1.27 - val query : Proof.context -> bool -> mash_engine -> string list -> int ->
1.28 - (string * string list * string list) list * string list * string list -> string list
1.29 - end
1.30 -
1.31 val mash_unlearn : Proof.context -> params -> unit
1.32 val nickname_of_thm : thm -> string
1.33 val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
1.34 @@ -492,7 +468,7 @@
1.35
1.36 val nb_def_prior_weight = 21 (* FUDGE *)
1.37
1.38 -fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats =
1.39 +fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats =
1.40 let
1.41 fun learn_fact th feats deps =
1.42 let
1.43 @@ -525,7 +501,7 @@
1.44 val sfreq = Array.array (num_facts, Inttab.empty)
1.45 val dffreq = Array.array (num_feats, 0)
1.46 in
1.47 - learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
1.48 + learn_facts tfreq sfreq dffreq num_facts get_deps get_feats
1.49 end
1.50
1.51 fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, dffreq) =
1.52 @@ -574,7 +550,7 @@
1.53 |> naive_bayes_query num_facts max_suggs visible_facts feats
1.54
1.55 (* experimental *)
1.56 -fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats =
1.57 +fun naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs feats =
1.58 let
1.59 fun name_of_fact j = "f" ^ string_of_int j
1.60 fun fact_of_name s = the (Int.fromString (unprefix "f" s))
1.61 @@ -631,66 +607,54 @@
1.62 c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
1.63 val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
1.64
1.65 -val empty_xtab = (0, Symtab.empty)
1.66 +fun query ctxt overlord engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs
1.67 + learns conj_feats =
1.68 + if engine = MaSh_SML_kNN_Cpp then
1.69 + k_nearest_neighbors_cpp max_suggs learns conj_feats
1.70 + else if engine = MaSh_SML_NB_Cpp then
1.71 + naive_bayes_cpp max_suggs learns conj_feats
1.72 + else
1.73 + let
1.74 + val facts = map #1 learns
1.75 + val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
1.76 + val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
1.77
1.78 -fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
1.79 -fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
1.80 + val fact_vec = Vector.fromList facts
1.81 + val deps_vec = Vector.fromList depss
1.82
1.83 -fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
1.84 - let
1.85 - val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)])
1.86 - in
1.87 - if engine = MaSh_SML_kNN_Cpp then
1.88 - k_nearest_neighbors_cpp max_suggs learns feats
1.89 - else if engine = MaSh_SML_NB_Cpp then
1.90 - naive_bayes_cpp max_suggs learns feats
1.91 - else
1.92 - let
1.93 - val facts = map #1 learns
1.94 - val fact_vec = Vector.fromList facts
1.95 + val get_deps = curry Vector.sub deps_vec
1.96
1.97 - val fact_xtab as (num_facts, fact_tab) = fold add_to_xtab facts empty_xtab
1.98 - val feat_xtab as (num_feats, feat_tab) = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
1.99 -
1.100 - val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
1.101 -
1.102 - val deps_vec = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
1.103 -
1.104 - val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
1.105 -
1.106 - val get_deps = curry Vector.sub deps_vec
1.107 -
1.108 - val int_feats = map_filter (Symtab.lookup feat_tab) feats
1.109 - in
1.110 - trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs feats ^ " from {" ^
1.111 - elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
1.112 - (if engine = MaSh_SML_kNN then
1.113 - let
1.114 - val facts_ary = Array.array (num_feats, [])
1.115 - val _ =
1.116 - fold (fn feats => fn fact =>
1.117 - (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
1.118 - featss 0
1.119 - val get_facts = curry Array.sub facts_ary
1.120 - in
1.121 - k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
1.122 - int_feats
1.123 - end
1.124 - else
1.125 - let
1.126 - val unweighted_feats_ary = Vector.fromList featss
1.127 - val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
1.128 - in
1.129 - (case engine of
1.130 - MaSh_SML_NB =>
1.131 - naive_bayes num_facts get_deps get_unweighted_feats num_feats max_suggs
1.132 - int_visible_facts int_feats
1.133 - | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps
1.134 - get_unweighted_feats num_feats max_suggs int_feats)
1.135 - end)
1.136 - |> map (curry Vector.sub fact_vec o fst)
1.137 - end
1.138 - end
1.139 + val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
1.140 + val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats
1.141 + in
1.142 + trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs conj_feats ^ " from {" ^
1.143 + elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
1.144 + (if engine = MaSh_SML_kNN then
1.145 + let
1.146 + val facts_ary = Array.array (num_feats, [])
1.147 + val _ =
1.148 + fold (fn feats => fn fact =>
1.149 + (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
1.150 + featss 0
1.151 + val get_facts = curry Array.sub facts_ary
1.152 + in
1.153 + k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
1.154 + int_conj_feats
1.155 + end
1.156 + else
1.157 + let
1.158 + val feats_ary = Vector.fromList featss
1.159 + val get_feats = curry Vector.sub feats_ary
1.160 + in
1.161 + (case engine of
1.162 + MaSh_SML_NB =>
1.163 + naive_bayes num_facts get_deps get_feats num_feats max_suggs int_visible_facts
1.164 + int_conj_feats
1.165 + | MaSh_SML_NB_Py =>
1.166 + naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs int_conj_feats)
1.167 + end)
1.168 + |> map (curry Vector.sub fact_vec o fst)
1.169 + end
1.170
1.171 end;
1.172
1.173 @@ -1328,6 +1292,11 @@
1.174 fun add_const_counts t =
1.175 fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
1.176
1.177 +val empty_xtab = (0, Symtab.empty)
1.178 +
1.179 +fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
1.180 +fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
1.181 +
1.182 fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
1.183 let
1.184 val thy = Proof_Context.theory_of ctxt
1.185 @@ -1395,12 +1364,18 @@
1.186 []
1.187 else
1.188 let
1.189 - val (parents, hints, feats) = query_args access_G
1.190 + val (parents, hints, feats0) = query_args access_G
1.191 + val feats = map fst feats0
1.192 val visible_facts = Graph.all_preds access_G parents
1.193 val learns =
1.194 - Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
1.195 + Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G @
1.196 + (if null hints then [] else [(".hints", feats, hints)])
1.197 +
1.198 + val fact_xtab = fold (add_to_xtab o #1) learns empty_xtab
1.199 + val feat_xtab = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
1.200 in
1.201 - MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, map fst feats)
1.202 + MaSh_SML.query ctxt overlord engine fact_xtab feat_xtab visible_facts max_facts learns
1.203 + feats
1.204 end
1.205
1.206 val unknown = filter_out (is_fact_in_graph access_G o snd) facts