src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 58715 e9d47cd3239b
parent 58714 24738b4f8c6b
child 58716 cb6667e7cbc1
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:36:00 2014 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jun 26 13:36:06 2014 +0200
     1.3 @@ -398,16 +398,8 @@
     1.4  
     1.5  exception EXIT of unit
     1.6  
     1.7 -fun k_nearest_neighbors num_facts get_deps get_sym_ths max_suggs visible_facts num_feats feats =
     1.8 +fun k_nearest_neighbors dffreq num_facts deps_vec get_sym_ths max_suggs visible_facts conj_feats =
     1.9    let
    1.10 -    val dffreq = Array.array (num_feats, 0)
    1.11 -
    1.12 -    val add_sym = map_array_at dffreq (Integer.add 1)
    1.13 -    fun for1 i =
    1.14 -      if i = num_feats then () else
    1.15 -      (List.app (fn _ => add_sym i) (get_sym_ths i); for1 (i + 1))
    1.16 -    val _ = for1 0
    1.17 -
    1.18      val ln_afreq = Math.ln (Real.fromInt num_facts)
    1.19      fun tfidf feat = ln_afreq - Math.ln (Real.fromInt (Array.sub (dffreq, feat)))
    1.20  
    1.21 @@ -427,7 +419,7 @@
    1.22          List.app do_th (get_sym_ths s)
    1.23        end
    1.24  
    1.25 -    val _ = List.app do_feat feats
    1.26 +    val _ = List.app do_feat conj_feats
    1.27      val _ = heap (Real.compare o pairself snd) num_facts num_facts overlaps_sqr
    1.28      val no_recommends = Unsynchronized.ref 0
    1.29      val recommends = Array.tabulate (num_facts, rpair 0.0)
    1.30 @@ -447,7 +439,7 @@
    1.31            val (j, o2) = Array.sub (overlaps_sqr, num_facts - k - 1)
    1.32            val o1 = Math.sqrt o2
    1.33            val _ = inc_recommend j o1
    1.34 -          val ds = get_deps j
    1.35 +          val ds = Vector.sub (deps_vec, j)
    1.36            val l = Real.fromInt (length ds)
    1.37          in
    1.38            List.app (fn d => inc_recommend d (o1 / l)) ds
    1.39 @@ -474,7 +466,7 @@
    1.40  
    1.41  val nb_def_prior_weight = 21 (* FUDGE *)
    1.42  
    1.43 -fun learn_facts tfreq sfreq dffreq num_facts get_deps get_feats =
    1.44 +fun learn_facts tfreq sfreq dffreq num_facts depss featss =
    1.45    let
    1.46      fun learn_fact th feats deps =
    1.47        let
    1.48 @@ -495,35 +487,26 @@
    1.49        end
    1.50  
    1.51      fun for i =
    1.52 -      if i = num_facts then () else (learn_fact i (get_feats i) (get_deps i); for (i + 1))
    1.53 +      if i = num_facts then ()
    1.54 +      else (learn_fact i (Vector.sub (featss, i)) (Vector.sub (depss, i)); for (i + 1))
    1.55    in
    1.56 -    for 0;
    1.57 -    (Array.vector tfreq, Array.vector sfreq, Array.vector dffreq)
    1.58 +    for 0
    1.59    end
    1.60  
    1.61 -fun learn num_facts get_deps get_feats num_feats =
    1.62 -  let
    1.63 -    val tfreq = Array.array (num_facts, 0)
    1.64 -    val sfreq = Array.array (num_facts, Inttab.empty)
    1.65 -    val dffreq = Array.array (num_feats, 0)
    1.66 -  in
    1.67 -    learn_facts tfreq sfreq dffreq num_facts get_deps get_feats
    1.68 -  end
    1.69 -
    1.70 -fun naive_bayes_query num_facts max_suggs visible_facts feats (tfreq, sfreq, dffreq) =
    1.71 +fun naive_bayes_query tfreq sfreq dffreq num_facts max_suggs visible_facts conj_feats =
    1.72    let
    1.73      val tau = 0.05 (* FUDGE *)
    1.74      val pos_weight = 10.0 (* FUDGE *)
    1.75      val def_val = ~15.0 (* FUDGE *)
    1.76  
    1.77      val ln_afreq = Math.ln (Real.fromInt num_facts)
    1.78 -    val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) dffreq
    1.79 +    val idf = Vector.map (fn i => ln_afreq - Math.ln (Real.fromInt i)) (Array.vector dffreq)
    1.80  
    1.81      fun tfidf feat = Vector.sub (idf, feat)
    1.82  
    1.83      fun log_posterior i =
    1.84        let
    1.85 -        val tfreq = Real.fromInt (Vector.sub (tfreq, i))
    1.86 +        val tfreq = Real.fromInt (Array.sub (tfreq, i))
    1.87  
    1.88          fun fold_feats f (res, sfh) =
    1.89            (case Inttab.lookup sfh f of
    1.90 @@ -532,7 +515,7 @@
    1.91               Inttab.delete f sfh)
    1.92            | NONE => (res + tfidf f * def_val, sfh))
    1.93  
    1.94 -        val (res, sfh) = fold fold_feats feats (Math.ln tfreq, Vector.sub (sfreq, i))
    1.95 +        val (res, sfh) = fold fold_feats conj_feats (Math.ln tfreq, Array.sub (sfreq, i))
    1.96  
    1.97          fun fold_sfh (f, sf) sow = sow + tfidf f * Math.ln (1.0 + (1.0 - Real.fromInt sf) / tfreq)
    1.98  
    1.99 @@ -551,26 +534,23 @@
   1.100      ret (Integer.max 0 (num_facts - max_suggs)) []
   1.101    end
   1.102  
   1.103 -fun naive_bayes num_facts get_deps get_feats num_feats max_suggs visible_facts feats =
   1.104 -  learn num_facts get_deps get_feats num_feats
   1.105 -  |> naive_bayes_query num_facts max_suggs visible_facts feats
   1.106 -
   1.107  (* experimental *)
   1.108 -fun naive_bayes_py ctxt overlord num_facts get_deps get_feats max_suggs feats =
   1.109 +fun naive_bayes_py ctxt overlord num_facts depss featss max_suggs conj_feats =
   1.110    let
   1.111      fun name_of_fact j = "f" ^ string_of_int j
   1.112      fun fact_of_name s = the (Int.fromString (unprefix "f" s))
   1.113      fun name_of_feature j = "F" ^ string_of_int j
   1.114      fun parents_of j = if j = 0 then [] else [name_of_fact (j - 1)]
   1.115  
   1.116 -    val learns = map (fn j => (name_of_fact j, parents_of j, map name_of_feature (get_feats j),
   1.117 -      map name_of_fact (get_deps j))) (0 upto num_facts - 1)
   1.118 +    val learns = map (fn j => (name_of_fact j, parents_of j,
   1.119 +      map name_of_feature (Vector.sub (featss, j)),
   1.120 +      map name_of_fact (Vector.sub (depss, j)))) (0 upto num_facts - 1)
   1.121      val parents' = parents_of num_facts
   1.122 -    val feats' = map (rpair 1.0 o name_of_feature) feats
   1.123 +    val conj_feats' = map (rpair 1.0 o name_of_feature) conj_feats
   1.124    in
   1.125      MaSh_Py.unlearn ctxt overlord;
   1.126      OS.Process.sleep (seconds 2.0); (* hack *)
   1.127 -    MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', feats')
   1.128 +    MaSh_Py.query ctxt overlord max_suggs (learns, [], parents', conj_feats')
   1.129      |> map (apfst fact_of_name)
   1.130    end
   1.131  
   1.132 @@ -633,9 +613,14 @@
   1.133        val depss = map (map_filter (Symtab.lookup fact_tab) o #3) learns
   1.134  
   1.135        val fact_vec = Vector.fromList facts
   1.136 +      val feats_vec = Vector.fromList featss
   1.137        val deps_vec = Vector.fromList depss
   1.138  
   1.139 -      val get_deps = curry Vector.sub deps_vec
   1.140 +      val tfreq = Array.array (num_facts, 0)
   1.141 +      val sfreq = Array.array (num_facts, Inttab.empty)
   1.142 +      val dffreq = Array.array (num_feats, 0)
   1.143 +
   1.144 +      val _ = learn_facts tfreq sfreq dffreq num_facts deps_vec feats_vec
   1.145  
   1.146        val int_visible_facts = map_filter (Symtab.lookup fact_tab) visible_facts
   1.147        val int_conj_feats = map_filter (Symtab.lookup feat_tab) conj_feats
   1.148 @@ -652,17 +637,11 @@
   1.149                featss 0
   1.150            val get_facts = curry Array.sub facts_ary
   1.151          in
   1.152 -          k_nearest_neighbors num_facts get_deps get_facts max_suggs int_visible_facts num_feats
   1.153 +          k_nearest_neighbors dffreq num_facts deps_vec get_facts max_suggs int_visible_facts
   1.154              int_conj_feats
   1.155          end
   1.156        | MaSh_SML_NB =>
   1.157 -        let
   1.158 -          val feats_ary = Vector.fromList featss
   1.159 -          val get_feats = curry Vector.sub feats_ary
   1.160 -        in
   1.161 -          naive_bayes num_facts get_deps get_feats num_feats max_suggs int_visible_facts
   1.162 -            int_conj_feats
   1.163 -        end)
   1.164 +        naive_bayes_query tfreq sfreq dffreq num_facts max_suggs int_visible_facts int_conj_feats)
   1.165        |> map (curry Vector.sub fact_vec o fst)
   1.166      end
   1.167