1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:33:21 2014 +0200
1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Thu Jun 26 13:33:27 2014 +0200
1.3 @@ -352,7 +352,7 @@
1.4
1.5 exception BOTTOM of int
1.6
1.7 -fun heap cmp bnd a =
1.8 +fun heap cmp bnd al a =
1.9 let
1.10 fun maxson l i =
1.11 let val i31 = i + i + i + 1 in
1.12 @@ -394,12 +394,10 @@
1.13 Array.update (a, i, e)
1.14 end
1.15
1.16 - val l = Array.length a
1.17 -
1.18 - fun for i = if i < 0 then () else (trickle l i (Array.sub (a, i)); for (i - 1))
1.19 + fun for i = if i < 0 then () else (trickle al i (Array.sub (a, i)); for (i - 1))
1.20
1.21 fun for2 i =
1.22 - if i < Integer.max 2 (l - bnd) then
1.23 + if i < Integer.max 2 (al - bnd) then
1.24 ()
1.25 else
1.26 let val e = Array.sub (a, i) in
1.27 @@ -408,9 +406,9 @@
1.28 for2 (i - 1)
1.29 end
1.30 in
1.31 - for (((l + 1) div 3) - 1);
1.32 - for2 (l - 1);
1.33 - if l > 1 then
1.34 + for (((al + 1) div 3) - 1);
1.35 + for2 (al - 1);
1.36 + if al > 1 then
1.37 let val e = Array.sub (a, 1) in
1.38 Array.update (a, 1, Array.sub (a, 0));
1.39 Array.update (a, 0, e)
1.40 @@ -457,7 +455,7 @@
1.41 end
1.42
1.43 val _ = List.app do_feat feats
1.44 - val _ = heap (Real.compare o pairself snd) num_facts overlaps_sqr
1.45 + val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
1.46 val no_recommends = Unsynchronized.ref 0
1.47 val recommends = Array.tabulate (num_visible_facts, rpair 0.0)
1.48 val age = Unsynchronized.ref 1000000000.0
1.49 @@ -498,39 +496,34 @@
1.50 if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
1.51 in
1.52 while1 (); while2 ();
1.53 - heap (Real.compare o pairself snd) max_suggs recommends;
1.54 + heap (Real.compare o pairself snd) max_suggs num_visible_facts recommends;
1.55 ret [] (Integer.max 0 (num_visible_facts - max_suggs))
1.56 end
1.57
1.58 val nb_def_prior_weight = 21 (* FUDGE *)
1.59
1.60 -fun naive_bayes_learn_fact tfreq sfreq dffreq th feats deps =
1.61 +fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats =
1.62 let
1.63 - fun add_th weight t =
1.64 + fun learn_fact th feats deps =
1.65 let
1.66 - val im = Array.sub (sfreq, t)
1.67 - fun fold_fn s sf = Inttab.map_default (s, 0) (Integer.add weight) sf
1.68 + fun add_th weight t =
1.69 + let
1.70 + val im = Array.sub (sfreq, t)
1.71 + fun fold_fn s sf = Inttab.map_default (s, 0) (Integer.add weight) sf
1.72 + in
1.73 + Array.update (tfreq, t, weight + Array.sub (tfreq, t));
1.74 + Array.update (sfreq, t, fold fold_fn feats im)
1.75 + end
1.76 +
1.77 + fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
1.78 in
1.79 - Array.update (tfreq, t, weight + Array.sub (tfreq, t));
1.80 - Array.update (sfreq, t, fold fold_fn feats im)
1.81 + add_th nb_def_prior_weight th;
1.82 + List.app (add_th 1) deps;
1.83 + List.app add_sym feats
1.84 end
1.85
1.86 - fun add_sym s = Array.update (dffreq, s, 1 + Array.sub (dffreq, s))
1.87 - in
1.88 - add_th nb_def_prior_weight th;
1.89 - List.app (add_th 1) deps;
1.90 - List.app add_sym feats
1.91 - end
1.92 -
1.93 -fun naive_bayes_learn num_facts get_deps get_feats num_feats =
1.94 - let
1.95 - val tfreq = Array.array (num_facts, 0)
1.96 - val sfreq = Array.array (num_facts, Inttab.empty)
1.97 - val dffreq = Array.array (num_feats, 0)
1.98 -
1.99 fun for i =
1.100 - if i = num_facts then ()
1.101 - else (naive_bayes_learn_fact tfreq sfreq dffreq i (get_feats i) (get_deps i); for (i + 1))
1.102 + if i = num_facts then () else (learn_fact i (get_feats i) (get_deps i); for (i + 1))
1.103
1.104 val ln_afreq = Math.ln (Real.fromInt num_facts)
1.105 in
1.106 @@ -539,6 +532,15 @@
1.107 Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq))
1.108 end
1.109
1.110 +fun learn num_facts get_deps get_feats num_feats =
1.111 + let
1.112 + val tfreq = Array.array (num_facts, 0)
1.113 + val sfreq = Array.array (num_facts, Inttab.empty)
1.114 + val dffreq = Array.array (num_feats, 0)
1.115 + in
1.116 + learn_facts tfreq sfreq dffreq num_facts get_deps get_feats num_feats
1.117 + end
1.118 +
1.119 fun naive_bayes_query (kuehlwein_log, kuehlwein_params) num_facts num_visible_facts max_suggs feats
1.120 (tfreq, sfreq, idf) =
1.121 let
1.122 @@ -579,12 +581,12 @@
1.123 fun ret acc at =
1.124 if at = num_visible_facts then acc else ret (Array.sub (posterior, at) :: acc) (at + 1)
1.125 in
1.126 - heap (Real.compare o pairself snd) max_suggs posterior;
1.127 + heap (Real.compare o pairself snd) max_suggs num_visible_facts posterior;
1.128 ret [] (Integer.max 0 (num_visible_facts - max_suggs))
1.129 end
1.130
1.131 fun naive_bayes opts num_facts num_visible_facts get_deps get_feats num_feats max_suggs feats =
1.132 - naive_bayes_learn num_facts get_deps get_feats num_feats
1.133 + learn num_facts get_deps get_feats num_feats
1.134 |> naive_bayes_query opts num_facts num_visible_facts max_suggs feats
1.135
1.136 (* experimental *)