recompute learning data at learning time, not query time
authorblanchet
Thu, 26 Jun 2014 16:41:30 +0200
changeset 58720fe96689f393b
parent 58719 73e9b858ec8d
child 58721 dcaf04545de2
recompute learning data at learning time, not query time
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 16:41:30 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 16:41:30 2014 +0200
     1.3 @@ -616,10 +616,10 @@
     1.4       MaSh_SML_kNN_Ext => k_nearest_neighbors_ext max_suggs learns goal_feats
     1.5     | MaSh_SML_NB_Ext => naive_bayes_ext max_suggs learns goal_feats))
     1.6  
     1.7 -fun query_internal ctxt engine num_facts num_feats (facts, featss, depss) (freqs as (_, _, dffreq))
     1.8 -    visible_facts max_suggs goal_feats int_goal_feats =
     1.9 +fun query_internal ctxt engine num_facts num_feats (fact_names, featss, depss)
    1.10 +    (freqs as (_, _, dffreq)) visible_facts max_suggs goal_feats int_goal_feats =
    1.11    (trace_msg ctxt (fn () => "MaSh_SML query internal " ^ encode_strs goal_feats ^ " from {" ^
    1.12 -     elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] facts)) ^ "}");
    1.13 +     elide_string 1000 (space_implode " " (Vector.foldr (op ::) [] fact_names)) ^ "}");
    1.14     (case engine of
    1.15       MaSh_SML_kNN =>
    1.16       let
    1.17 @@ -632,7 +632,7 @@
    1.18         k_nearest_neighbors dffreq num_facts depss feat_facts max_suggs visible_facts int_goal_feats
    1.19       end
    1.20     | MaSh_SML_NB => naive_bayes freqs num_facts max_suggs visible_facts int_goal_feats)
    1.21 -   |> map (curry Vector.sub facts o fst))
    1.22 +   |> map (curry Vector.sub fact_names o fst))
    1.23  
    1.24  end;
    1.25  
    1.26 @@ -684,14 +684,47 @@
    1.27  type mash_state =
    1.28    {access_G : (proof_kind * string list * string list) Graph.T,
    1.29     xtabs : xtab * xtab,
    1.30 +   ffds : string vector * int list vector * int list vector,
    1.31 +   freqs : int vector * int Inttab.table vector * int vector,
    1.32     dirty_facts : string list option}
    1.33  
    1.34 +val empty_xtabs = (empty_xtab, empty_xtab)
    1.35 +val empty_ffds = (Vector.fromList [], Vector.fromList [], Vector.fromList [])
    1.36 +val empty_freqs = (Vector.fromList [], Vector.fromList [], Vector.fromList [])
    1.37 +val empty_graphxx = (Graph.empty, empty_xtabs)
    1.38 +
    1.39  val empty_state =
    1.40    {access_G = Graph.empty,
    1.41 -   xtabs = (empty_xtab, empty_xtab),
    1.42 +   xtabs = empty_xtabs,
    1.43 +   ffds = empty_ffds,
    1.44 +   freqs = empty_freqs,
    1.45     dirty_facts = SOME []} : mash_state
    1.46  
    1.47 -val empty_graphxx = (Graph.empty, (empty_xtab, empty_xtab))
    1.48 +fun reorder_learns (num_facts, fact_tab) learns =
    1.49 +  let val ary = Array.array (num_facts, ("", [], [])) in
    1.50 +    List.app (fn learn as (fact, _, _) =>
    1.51 +        Array.update (ary, the (Symtab.lookup fact_tab fact), learn))
    1.52 +      learns;
    1.53 +    Array.foldr (op ::) [] ary
    1.54 +  end
    1.55 +
    1.56 +fun recompute_ffd_freqs access_G (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab)) =
    1.57 +  let
    1.58 +    val learns =
    1.59 +      Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
    1.60 +      |> reorder_learns fact_xtab
    1.61 +
    1.62 +    val fact_names = Vector.fromList (map #1 learns)
    1.63 +    val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
    1.64 +    val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
    1.65 +
    1.66 +    val tfreq = Vector.tabulate (num_facts, K 0)
    1.67 +    val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
    1.68 +    val dffreq = Vector.tabulate (num_feats, K 0)
    1.69 +  in
    1.70 +    ((fact_names, featss, depss),
    1.71 +     MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss)
    1.72 +  end
    1.73  
    1.74  local
    1.75  
    1.76 @@ -737,9 +770,11 @@
    1.77                    else wipe_out_mash_state_dir ();
    1.78                    empty_graphxx)
    1.79                 | GREATER => raise FILE_VERSION_TOO_NEW ())
    1.80 +
    1.81 +             val (ffds, freqs) = recompute_ffd_freqs access_G xtabs
    1.82             in
    1.83               trace_msg ctxt (fn () => "Loaded fact graph (" ^ graph_info access_G ^ ")");
    1.84 -             {access_G = access_G, xtabs = xtabs, dirty_facts = SOME []}
    1.85 +             {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []}
    1.86             end
    1.87           | _ => empty_state)))
    1.88    end
    1.89 @@ -749,7 +784,7 @@
    1.90    encode_strs feats ^ "; " ^ encode_strs deps ^ "\n"
    1.91  
    1.92  fun save_state _ (time_state as (_, {dirty_facts = SOME [], ...})) = time_state
    1.93 -  | save_state ctxt (memory_time, {access_G, xtabs, dirty_facts}) =
    1.94 +  | save_state ctxt (memory_time, {access_G, xtabs, ffds, freqs, dirty_facts}) =
    1.95      let
    1.96        fun append_entry (name, ((kind, feats, deps), (parents, _))) =
    1.97          cons (kind, name, Graph.Keys.dest parents, feats, deps)
    1.98 @@ -770,7 +805,8 @@
    1.99          (case dirty_facts of
   1.100            SOME dirty_facts => "; " ^ string_of_int (length dirty_facts) ^ " dirty fact(s)"
   1.101          | _ => "") ^  ")");
   1.102 -      (Time.now (), {access_G = access_G, xtabs = xtabs, dirty_facts = SOME []})
   1.103 +      (Time.now (),
   1.104 +       {access_G = access_G, xtabs = xtabs, ffds = ffds, freqs = freqs, dirty_facts = SOME []})
   1.105      end
   1.106  
   1.107  val global_state = Synchronized.var "Sledgehammer_MaSh.global_state" (Time.zeroTime, empty_state)
   1.108 @@ -1275,16 +1311,6 @@
   1.109  fun add_const_counts t =
   1.110    fold (fn s => Symtab.map_default (s, 0) (Integer.add 1)) (Term.add_const_names t [])
   1.111  
   1.112 -fun reorder_learns (num_facts, fact_tab) learns0 =
   1.113 -  let
   1.114 -    val learns = Array.array (num_facts, ("", [], []))
   1.115 -  in
   1.116 -    List.app (fn learn as (fact, _, _) =>
   1.117 -        Array.update (learns, the (Symtab.lookup fact_tab fact), learn))
   1.118 -      learns0;
   1.119 -    Array.foldr (op ::) [] learns
   1.120 -  end
   1.121 -
   1.122  fun mash_suggested_facts ctxt ({debug, overlord, ...} : params) max_suggs hyp_ts concl_t facts =
   1.123    let
   1.124      val thy = Proof_Context.theory_of ctxt
   1.125 @@ -1333,9 +1359,9 @@
   1.126          (parents, hints, feats)
   1.127        end
   1.128  
   1.129 -    val ((access_G, (fact_xtab as (num_facts, fact_tab), (num_feats, feat_tab))), py_suggs) =
   1.130 -      peek_state ctxt overlord (fn {access_G, xtabs, ...} =>
   1.131 -        ((access_G, xtabs),
   1.132 +    val ((access_G, ((num_facts, fact_tab), (num_feats, feat_tab)), ffds, freqs), py_suggs) =
   1.133 +      peek_state ctxt overlord (fn {access_G, xtabs, ffds, freqs, ...} =>
   1.134 +        ((access_G, xtabs, ffds, freqs),
   1.135           if Graph.is_empty access_G then
   1.136             (trace_msg ctxt (K "Nothing has been learned yet"); [])
   1.137           else if engine = MaSh_Py then
   1.138 @@ -1364,25 +1390,10 @@
   1.139              end
   1.140            else
   1.141              let
   1.142 -              val learns0 =
   1.143 -                Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
   1.144 -              val learns = reorder_learns fact_xtab learns0
   1.145 -
   1.146 -              val facts = Vector.fromList (map #1 learns)
   1.147 -              val featss = Vector.fromList (map (map_filter (Symtab.lookup feat_tab) o #2) learns)
   1.148 -              val depss = Vector.fromList (map (map_filter (Symtab.lookup fact_tab) o #3) learns)
   1.149 -
   1.150 -              val tfreq = Vector.tabulate (num_facts, K 0)
   1.151 -              val sfreq = Vector.tabulate (num_facts, K Inttab.empty)
   1.152 -              val dffreq = Vector.tabulate (num_feats, K 0)
   1.153 -
   1.154 -              val freqs' =
   1.155 -                MaSh_SML.learn_facts (tfreq, sfreq, dffreq) 0 num_facts num_feats depss featss
   1.156 -
   1.157                val int_goal_feats = map_filter (Symtab.lookup feat_tab) goal_feats
   1.158              in
   1.159 -              MaSh_SML.query_internal ctxt engine num_facts num_feats (facts, featss, depss) freqs'
   1.160 -                visible_facts max_suggs goal_feats int_goal_feats
   1.161 +              MaSh_SML.query_internal ctxt engine num_facts num_feats ffds freqs visible_facts
   1.162 +                max_suggs goal_feats int_goal_feats
   1.163              end
   1.164          end
   1.165  
   1.166 @@ -1447,7 +1458,7 @@
   1.167          val feats = map fst (features_of ctxt thy 0 Symtab.empty (Local, General) [t])
   1.168        in
   1.169          map_state ctxt overlord
   1.170 -          (fn state as {access_G, xtabs, dirty_facts} =>
   1.171 +          (fn state as {access_G, xtabs, ffds, freqs, dirty_facts} =>
   1.172               let
   1.173                 val parents = maximal_wrt_access_graph access_G facts
   1.174                 val deps = used_ths
   1.175 @@ -1459,10 +1470,12 @@
   1.176                 else
   1.177                   let
   1.178                     val name = learned_proof_name ()
   1.179 -                   val (access_G, xtabs) =
   1.180 +                   val (access_G', xtabs') =
   1.181                       add_node Automatic_Proof name parents feats deps (access_G, xtabs)
   1.182 +
   1.183 +                   val (ffds', freqs') = recompute_ffd_freqs access_G' xtabs'
   1.184                   in
   1.185 -                   {access_G = access_G, xtabs = xtabs,
   1.186 +                   {access_G = access_G', xtabs = xtabs', ffds = ffds', freqs = freqs',
   1.187                      dirty_facts = Option.map (cons name) dirty_facts}
   1.188                   end
   1.189               end);
   1.190 @@ -1510,26 +1523,31 @@
   1.191              isar_dependencies_of name_tabs th
   1.192  
   1.193          fun do_commit [] [] [] state = state
   1.194 -          | do_commit learns relearns flops {access_G, xtabs, dirty_facts} =
   1.195 +          | do_commit learns relearns flops {access_G, xtabs, ffds, freqs, dirty_facts} =
   1.196              let
   1.197 +              val was_empty = Graph.is_empty access_G
   1.198 +
   1.199 +              (* TODO: use "fold_map" *)
   1.200                val (learns, (access_G, xtabs)) =
   1.201                  fold (learn_wrt_access_graph ctxt) learns ([], (access_G, xtabs))
   1.202                val (relearns, access_G) =
   1.203                  fold (relearn_wrt_access_graph ctxt) relearns ([], access_G)
   1.204  
   1.205 -              val was_empty = Graph.is_empty access_G
   1.206                val access_G = access_G |> fold flop_wrt_access_graph flops
   1.207                val dirty_facts =
   1.208                  (case (was_empty, dirty_facts) of
   1.209                    (false, SOME names) => SOME (map #1 learns @ map #1 relearns @ names)
   1.210                  | _ => NONE)
   1.211 +
   1.212 +              val (ffds', freqs') = recompute_ffd_freqs access_G xtabs
   1.213              in
   1.214                if engine = MaSh_Py then
   1.215                  (MaSh_Py.learn ctxt overlord (save andalso null relearns) (rev learns);
   1.216                   MaSh_Py.relearn ctxt overlord save relearns)
   1.217                else
   1.218                  ();
   1.219 -              {access_G = access_G, xtabs = xtabs, dirty_facts = dirty_facts}
   1.220 +              {access_G = access_G, xtabs = xtabs, ffds = ffds', freqs = freqs',
   1.221 +               dirty_facts = dirty_facts}
   1.222              end
   1.223  
   1.224          fun commit last learns relearns flops =