tuning
authorblanchet
Thu, 26 Jun 2014 13:35:46 +0200
changeset 587116d422f19cefb
parent 58710 b89937ed6099
child 58712 9d420da6c7e2
tuning
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:39 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:35:46 2014 +0200
     1.3 @@ -46,30 +46,6 @@
     1.4    val is_mash_enabled : unit -> bool
     1.5    val the_mash_engine : unit -> mash_engine
     1.6  
     1.7 -  structure MaSh_Py :
     1.8 -  sig
     1.9 -    val unlearn : Proof.context -> bool -> unit
    1.10 -    val learn : Proof.context -> bool -> bool ->
    1.11 -      (string * string list * string list * string list) list -> unit
    1.12 -    val relearn : Proof.context -> bool -> bool -> (string * string list) list -> unit
    1.13 -    val query : Proof.context -> bool -> int ->
    1.14 -      (string * string list * string list * string list) list * string list * string list
    1.15 -        * (string * real) list ->
    1.16 -      (string * real) list
    1.17 -  end
    1.18 -
    1.19 -  structure MaSh_SML :
    1.20 -  sig
    1.21 -    val k_nearest_neighbors : int -> (int -> int list) -> (int -> int list) -> int -> int list ->
    1.22 -      int -> int list -> (int * real) list
    1.23 -    val naive_bayes : int -> (int -> int list) -> (int -> int list) -> int -> int -> int list ->
    1.24 -      int list -> (int * real) list
    1.25 -    val naive_bayes_py : Proof.context -> bool -> int -> (int -> int list) -> (int -> int list) ->
    1.26 -      int -> int -> int list -> (int * real) list
    1.27 -    val query : Proof.context -> bool -> mash_engine -> string list -> int ->
    1.28 -      (string * string list * string list) list * string list * string list -> string list
    1.29 -  end
    1.30 -
    1.31    val mash_unlearn : Proof.context -> params -> unit
    1.32    val nickname_of_thm : thm -> string
    1.33    val find_suggested_facts : Proof.context -> ('b * thm) list -> string list -> ('b * thm) list
    1.34 @@ -492,7 +468,7 @@
    1.35  
    1.36  val nb_def_prior_weight = 21 (* FUDGE *)
    1.37  
    1.38 -fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats =
    1.39 +fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats =
    1.40    let
    1.41      fun learn_fact th feats deps =
    1.42        let
    1.43 @@ -525,7 +501,7 @@
    1.44      val sfreq = Array.array (num_facts, Inttab.empty)
    1.45      val dffreq = Array.array (num_feats, 0)
    1.46    in
    1.47 -    learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
    1.48 +    learn_facts tfreq sfreq dffreq num_facts get_deps get_feats
    1.49    end
    1.50  
    1.51  fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, dffreq) =
    1.52 @@ -574,7 +550,7 @@
    1.53    |> naive_bayes_query num_facts max_suggs visible_facts feats
    1.54  
    1.55  (* experimental *)
    1.56 -fun naive_bayes_py ctxt overlord num_facts get_deps get_feats num_feats max_suggs feats =
    1.57 +fun naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs feats =
    1.58    let
    1.59      fun name_of_fact j = "f" ^ string_of_int j
    1.60      fun fact_of_name s = the (Int.fromString (unprefix "f" s))
    1.61 @@ -631,66 +607,54 @@
    1.62    c_plus_plus_tool ("newknn/knn" ^ " " ^ string_of_int number_of_nearest_neighbors)
    1.63  val naive_bayes_cpp = c_plus_plus_tool "predict/nbayes"
    1.64  
    1.65 -val empty_xtab = (0, Symtab.empty)
    1.66 +fun query ctxt overlord engine (num_facts, fact_tab) (num_feats, feat_tab) visible_facts max_suggs
    1.67 +    learns conj_feats =
    1.68 +  if engine = MaSh_SML_kNN_Cpp then
    1.69 +    k_nearest_neighbors_cpp max_suggs learns conj_feats
    1.70 +  else if engine = MaSh_SML_NB_Cpp then
    1.71 +    naive_bayes_cpp max_suggs learns conj_feats
    1.72 +  else
    1.73 +    let
    1.74 +      val facts = map #1 learns
    1.75 +      val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
    1.76 +      val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
    1.77  
    1.78 -fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
    1.79 -fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
    1.80 +      val fact_vec = Vector.fromList facts
    1.81 +      val deps_vec = Vector.fromList depss
    1.82  
    1.83 -fun query ctxt overlord engine visible_facts max_suggs (learns0, hints, feats) =
    1.84 -  let
    1.85 -    val learns = learns0 @ (if null hints then [] else [(".hints", feats, hints)])
    1.86 -  in
    1.87 -    if engine = MaSh_SML_kNN_Cpp then
    1.88 -      k_nearest_neighbors_cpp max_suggs learns feats
    1.89 -    else if engine = MaSh_SML_NB_Cpp then
    1.90 -      naive_bayes_cpp max_suggs learns feats
    1.91 -    else
    1.92 -      let
    1.93 -        val facts = map #1 learns
    1.94 -        val fact_vec = Vector.fromList facts
    1.95 +      val get_deps = curry Vector.sub deps_vec
    1.96  
    1.97 -        val fact_xtab as (num_facts, fact_tab) = fold add_to_xtab facts empty_xtab
    1.98 -        val feat_xtab as (num_feats, feat_tab) = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
    1.99 -
   1.100 -        val featss = map (map_filter (Symtab.lookup feat_tab) o #2) learns
   1.101 -
   1.102 -        val deps_vec = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
   1.103 -
   1.104 -        val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
   1.105 -
   1.106 -        val get_deps = curry Vector.sub deps_vec
   1.107 -
   1.108 -        val int_feats = map_filter (Symtab.lookup feat_tab) feats
   1.109 -      in
   1.110 -        trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs feats ^ " from {" ^
   1.111 -          elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
   1.112 -        (if engine = MaSh_SML_kNN then
   1.113 -           let
   1.114 -             val facts_ary = Array.array (num_feats, [])
   1.115 -             val _ =
   1.116 -               fold (fn feats => fn fact =>
   1.117 -                   (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
   1.118 -                 featss 0
   1.119 -             val get_facts = curry Array.sub facts_ary
   1.120 -           in
   1.121 -             k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
   1.122 -               int_feats
   1.123 -           end
   1.124 -         else
   1.125 -           let
   1.126 -             val unweighted_feats_ary = Vector.fromList featss
   1.127 -             val get_unweighted_feats = curry Vector.sub unweighted_feats_ary
   1.128 -           in
   1.129 -             (case engine of
   1.130 -               MaSh_SML_NB =>
   1.131 -               naive_bayes num_facts get_deps get_unweighted_feats num_feats max_suggs
   1.132 -                 int_visible_facts int_feats
   1.133 -             | MaSh_SML_NB_Py => naive_bayes_py ctxt overlord num_facts get_deps
   1.134 -                 get_unweighted_feats num_feats max_suggs int_feats)
   1.135 -           end)
   1.136 -        |> map (curry Vector.sub fact_vec o fst)
   1.137 -      end
   1.138 -  end
   1.139 +      val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
   1.140 +      val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats
   1.141 +    in
   1.142 +      trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_strs conj_feats ^ " from {" ^
   1.143 +        elide_string 1000 (space_implode " " (take num_facts facts)) ^ "}");
   1.144 +      (if engine = MaSh_SML_kNN then
   1.145 +         let
   1.146 +           val facts_ary = Array.array (num_feats, [])
   1.147 +           val _ =
   1.148 +             fold (fn feats => fn fact =>
   1.149 +                 (List.app (map_array_at facts_ary (cons fact)) feats; fact + 1))
   1.150 +               featss 0
   1.151 +           val get_facts = curry Array.sub facts_ary
   1.152 +         in
   1.153 +           k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
   1.154 +             int_conj_feats
   1.155 +         end
   1.156 +       else
   1.157 +         let
   1.158 +           val feats_ary = Vector.fromList featss
   1.159 +           val get_feats = curry Vector.sub feats_ary
   1.160 +         in
   1.161 +           (case engine of
   1.162 +             MaSh_SML_NB =>
   1.163 +             naive_bayes num_facts get_deps get_feats num_feats max_suggs int_visible_facts
   1.164 +               int_conj_feats
   1.165 +           | MaSh_SML_NB_Py =>
   1.166 +             naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs int_conj_feats)
   1.167 +         end)
   1.168 +      |> map (curry Vector.sub fact_vec o fst)
   1.169 +    end
   1.170  
   1.171  end;
   1.172  
   1.173 @@ -1328,6 +1292,11 @@
   1.174  fun add_const_counts t =
   1.175    fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
   1.176  
   1.177 +val empty_xtab = (0, Symtab.empty)
   1.178 +
   1.179 +fun add_to_xtab key (next, tab) = (next + 1, Symtab.update_new (key, next) tab)
   1.180 +fun maybe_add_to_xtab key = perhaps (try (add_to_xtab key))
   1.181 +
   1.182  fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_facts hyp_ts concl_t facts =
   1.183    let
   1.184      val thy = Proof_Context.theory_of ctxt
   1.185 @@ -1395,12 +1364,18 @@
   1.186          []
   1.187        else
   1.188          let
   1.189 -          val (parents, hints, feats) = query_args access_G
   1.190 +          val (parents, hints, feats0) = query_args access_G
   1.191 +          val feats = map fst feats0
   1.192            val visible_facts = Graph.all_preds access_G parents
   1.193            val learns =
   1.194 -            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
   1.195 +            Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G @
   1.196 +            (if null hints then [] else [(".hints", feats, hints)])
   1.197 +
   1.198 +          val fact_xtab = fold (add_to_xtab o #1) learns empty_xtab
   1.199 +          val feat_xtab = fold (fold maybe_add_to_xtab o #2) learns empty_xtab
   1.200          in
   1.201 -          MaSh_SML.query ctxt overlord engine visible_facts max_facts (learns, hints, map fst feats)
   1.202 +          MaSh_SML.query ctxt overlord engine fact_xtab feat_xtab visible_facts max_facts learns
   1.203 +            feats
   1.204          end
   1.205  
   1.206      val unknown = filter_out (is_fact_in_graph access_G o snd) facts