tuning
authorblanchet
Thu, 26 Jun 2014 13:33:27 +0200
changeset 58697a9e0f9d35125
parent 58696 ded92100ffd7
child 58698 9816f692b0ca
tuning
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:33:21 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:33:27 2014 +0200
     1.3 @@ -352,7 +352,7 @@
     1.4  
     1.5  exception BOTTOM of int
     1.6  
     1.7 -fun heap cmp bnd a =
     1.8 +fun heap cmp bnd al a =
     1.9    let
    1.10      fun maxson l i =
    1.11        let val i31 = i + i + i + 1 in
    1.12 @@ -394,12 +394,10 @@
    1.13            Array.update (a, i, e)
    1.14        end
    1.15  
    1.16 -    val l = Array.length a
    1.17 -
    1.18 -    fun for i = if i < 0 then () else (trickle l i (Array.sub (a, i)); for (i - 1))
    1.19 +    fun for i = if i < 0 then () else (trickle al i (Array.sub (a, i)); for (i - 1))
    1.20  
    1.21      fun for2 i =
    1.22 -      if i < Integer.max 2 (l - bnd) then
    1.23 +      if i < Integer.max 2 (al - bnd) then
    1.24          ()
    1.25        else
    1.26          let val e = Array.sub (a, i) in
    1.27 @@ -408,9 +406,9 @@
    1.28            for2 (i - 1)
    1.29          end
    1.30    in
    1.31 -    for (((l + 1) div 3) - 1);
    1.32 -    for2 (l - 1);
    1.33 -    if l > 1 then
    1.34 +    for (((al + 1) div 3) - 1);
    1.35 +    for2 (al - 1);
    1.36 +    if al > 1 then
    1.37        let val e = Array.sub (a, 1) in
    1.38          Array.update (a, 1, Array.sub (a, 0));
    1.39          Array.update (a, 0, e)
    1.40 @@ -457,7 +455,7 @@
    1.41        end
    1.42  
    1.43      val _ = List.app do_feat feats
    1.44 -    val _ = heap (Real.compare o pairself snd) num_facts overlaps_sqr
    1.45 +    val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
    1.46      val no_recommends = Unsynchronized.ref 0
    1.47      val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
    1.48      val age = Unsynchronized.ref 1000000000.0
    1.49 @@ -498,39 +496,34 @@
    1.50        if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
    1.51    in
    1.52      while1 (); while2 ();
    1.53 -    heap (Real.compare o pairself snd) max_suggs recommends;
    1.54 +    heap (Real.compare o pairself snd) max_suggs num_visible_facts recommends;
    1.55      ret [] (Integer.max 0 (num_visible_facts - max_suggs))
    1.56    end
    1.57  
    1.58  val nb_def_prior_weight = 21 (* FUDGE *)
    1.59  
    1.60 -fun naive_bayes_learn_fact tfreq sfreq dffreq th feats deps =
    1.61 +fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats =
    1.62    let
    1.63 -    fun add_th weight t =
    1.64 +    fun learn_fact th feats deps =
    1.65        let
    1.66 -        val im = Array.sub (sfreq, t)
    1.67 -        fun fold_fn s sf = Inttab.map_default (s, 0) (Integer.add weight) sf
    1.68 +        fun add_th weight t =
    1.69 +          let
    1.70 +            val im = Array.sub (sfreq, t)
    1.71 +            fun fold_fn s sf = Inttab.map_default (s, 0) (Integer.add weight) sf
    1.72 +          in
    1.73 +            Array.update (tfreq, t, weight + Array.sub (tfreq, t));
    1.74 +            Array.update (sfreq, t, fold fold_fn feats im)
    1.75 +          end
    1.76 +
    1.77 +        fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
    1.78        in
    1.79 -        Array.update (tfreq, t, weight + Array.sub (tfreq, t));
    1.80 -        Array.update (sfreq, t, fold fold_fn feats im)
    1.81 +        add_th nb_def_prior_weight th;
    1.82 +        List.app (add_th 1) deps;
    1.83 +        List.app add_sym feats
    1.84        end
    1.85  
    1.86 -    fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
    1.87 -  in
    1.88 -    add_th nb_def_prior_weight th;
    1.89 -    List.app (add_th 1) deps;
    1.90 -    List.app add_sym feats
    1.91 -  end
    1.92 -
    1.93 -fun naive_bayes_learn num_facts get_deps get_feats num_feats =
    1.94 -  let
    1.95 -    val tfreq = Array.array (num_facts, 0)
    1.96 -    val sfreq = Array.array (num_facts, Inttab.empty)
    1.97 -    val dffreq = Array.array (num_feats, 0)
    1.98 -
    1.99      fun for i =
   1.100 -      if i = num_facts then ()
   1.101 -      else (naive_bayes_learn_fact tfreq sfreq dffreq i (get_feats i) (get_deps i); for (i + 1))
   1.102 +      if i = num_facts then () else (learn_fact i (get_feats i) (get_deps i); for (i + 1))
   1.103  
   1.104      val ln_afreq = Math.ln (Real.fromInt num_facts)
   1.105    in
   1.106 @@ -539,6 +532,15 @@
   1.107       Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq))
   1.108    end
   1.109  
   1.110 +fun learn num_facts get_deps get_feats num_feats =
   1.111 +  let
   1.112 +    val tfreq = Array.array (num_facts, 0)
   1.113 +    val sfreq = Array.array (num_facts, Inttab.empty)
   1.114 +    val dffreq = Array.array (num_feats, 0)
   1.115 +  in
   1.116 +    learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
   1.117 +  end
   1.118 +
   1.119  fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts num_visible_facts max_suggs feats
   1.120      (tfreq, sfreq, idf) =
   1.121    let
   1.122 @@ -579,12 +581,12 @@
   1.123      fun ret acc at =
   1.124        if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
   1.125    in
   1.126 -    heap (Real.compare o pairself snd) max_suggs posterior;
   1.127 +    heap (Real.compare o pairself snd) max_suggs num_visible_facts posterior;
   1.128      ret [] (Integer.max 0 (num_visible_facts - max_suggs))
   1.129    end
   1.130  
   1.131  fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
   1.132 -  naive_bayes_learn num_facts get_deps get_feats num_feats
   1.133 +  learn num_facts get_deps get_feats num_feats
   1.134    |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats
   1.135  
   1.136  (* experimental *)