use correct weights in MeSh driver
authorblanchet
Thu, 17 Jan 2013 23:29:22 +0100
changeset 519807a7d1418301e
parent 51979 2a990baa09af
child 51981 b85cb3049df9
use correct weights in MeSh driver
src/HOL/TPTP/mash_eval.ML
src/HOL/TPTP/mash_export.ML
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/TPTP/mash_eval.ML	Thu Jan 17 23:29:17 2013 +0100
     1.2 +++ b/src/HOL/TPTP/mash_eval.ML	Thu Jan 17 23:29:22 2013 +0100
     1.3 @@ -93,11 +93,12 @@
     1.4                           mesh_isar_line), mesh_prover_line)) =
     1.5        if in_range range j then
     1.6          let
     1.7 -          val (name1, mepo_suggs) = extract_suggestions mepo_line
     1.8 -          val (name2, mash_isar_suggs) = extract_suggestions mash_isar_line
     1.9 -          val (name3, mash_prover_suggs) = extract_suggestions mash_prover_line
    1.10 -          val (name4, mesh_isar_suggs) = extract_suggestions mesh_isar_line
    1.11 -          val (name5, mesh_prover_suggs) = extract_suggestions mesh_prover_line
    1.12 +          val get_suggs = extract_suggestions ##> take slack_max_facts
    1.13 +          val (name1, mepo_suggs) = get_suggs mepo_line
    1.14 +          val (name2, mash_isar_suggs) = get_suggs mash_isar_line
    1.15 +          val (name3, mash_prover_suggs) = get_suggs mash_prover_line
    1.16 +          val (name4, mesh_isar_suggs) = get_suggs mesh_isar_line
    1.17 +          val (name5, mesh_prover_suggs) = get_suggs mesh_prover_line
    1.18            val [name] =
    1.19              [name1, name2, name3, name4, name5]
    1.20              |> filter (curry (op <>) "") |> distinct (op =)
    1.21 @@ -115,12 +116,12 @@
    1.22            val mepo_facts =
    1.23              get_facts mepo_suggs (fn _ =>
    1.24                  mepo_suggested_facts ctxt params prover slack_max_facts NONE
    1.25 -                                     hyp_ts concl_t facts
    1.26 -                |> weight_mepo_facts)
    1.27 +                                     hyp_ts concl_t facts)
    1.28 +            |> weight_mepo_facts
    1.29            fun mash_of suggs =
    1.30              get_facts suggs (fn _ =>
    1.31 -                find_mash_suggestions slack_max_facts suggs facts [] []
    1.32 -                |> fst |> weight_mash_facts)
    1.33 +                find_mash_suggestions slack_max_facts suggs facts [] [] |> fst)
    1.34 +            |> weight_mash_facts
    1.35            val mash_isar_facts = mash_of mash_isar_suggs
    1.36            val mash_prover_facts = mash_of mash_prover_suggs
    1.37            fun mess_of mash_facts =
    1.38 @@ -129,12 +130,10 @@
    1.39            fun mesh_of suggs mash_facts =
    1.40              get_facts suggs (fn _ =>
    1.41                  mesh_facts (Thm.eq_thm_prop o pairself snd) slack_max_facts
    1.42 -                           (mess_of mash_facts)
    1.43 -                |> map (rpair 1.0))
    1.44 +                           (mess_of mash_facts))
    1.45            val mesh_isar_facts = mesh_of mesh_isar_suggs mash_isar_facts
    1.46            val mesh_prover_facts = mesh_of mesh_prover_suggs mash_prover_facts
    1.47 -          val isar_facts =
    1.48 -            find_suggested_facts (map (rpair 1.0) isar_deps) facts
    1.49 +          val isar_facts = find_suggested_facts isar_deps facts
    1.50            (* adapted from "mirabelle_sledgehammer.ML" *)
    1.51            fun set_file_name method (SOME dir) =
    1.52                let
    1.53 @@ -147,7 +146,7 @@
    1.54                  #> Config.put SMT_Config.debug_files (dir ^ "/" ^ prob_prefix)
    1.55                end
    1.56              | set_file_name _ NONE = I
    1.57 -          fun prove method facts =
    1.58 +          fun prove method get facts =
    1.59              if not (member (op =) methods method) orelse
    1.60                 (null facts andalso method <> IsarN) then
    1.61                (str_of_method method ^ "Skipped", 0)
    1.62 @@ -157,7 +156,7 @@
    1.63                    ((K (encode_str (nickname_of_thm th)), stature), th)
    1.64                  val facts =
    1.65                    facts
    1.66 -                  |> map (fst #> nickify)
    1.67 +                  |> map (get #> nickify)
    1.68                    |> maybe_instantiate_inducts ctxt hyp_ts concl_t
    1.69                    |> take (the max_facts)
    1.70                  val ctxt = ctxt |> set_file_name method prob_dir_name
    1.71 @@ -166,12 +165,12 @@
    1.72                  val ok = if is_none outcome then 1 else 0
    1.73                in (str_of_result method facts res, ok) end
    1.74            val ress =
    1.75 -            [fn () => prove MePoN mepo_facts,
    1.76 -             fn () => prove MaSh_IsarN mash_isar_facts,
    1.77 -             fn () => prove MaSh_ProverN mash_prover_facts,
    1.78 -             fn () => prove MeSh_IsarN mesh_isar_facts,
    1.79 -             fn () => prove MeSh_ProverN mesh_prover_facts,
    1.80 -             fn () => prove IsarN isar_facts]
    1.81 +            [fn () => prove MePoN fst mepo_facts,
    1.82 +             fn () => prove MaSh_IsarN fst mash_isar_facts,
    1.83 +             fn () => prove MaSh_ProverN fst mash_prover_facts,
    1.84 +             fn () => prove MeSh_IsarN I mesh_isar_facts,
    1.85 +             fn () => prove MeSh_ProverN I mesh_prover_facts,
    1.86 +             fn () => prove IsarN I isar_facts]
    1.87              |> (* Par_List. *) map (fn f => f ())
    1.88          in
    1.89            "Goal " ^ string_of_int j ^ ": " ^ name :: map fst ress
     2.1 --- a/src/HOL/TPTP/mash_export.ML	Thu Jan 17 23:29:17 2013 +0100
     2.2 +++ b/src/HOL/TPTP/mash_export.ML	Thu Jan 17 23:29:22 2013 +0100
     2.3 @@ -213,10 +213,10 @@
     2.4        let
     2.5          val (name, mash_suggs) =
     2.6            extract_suggestions mash_line
     2.7 -          ||> (map fst #> weight_mash_facts)
     2.8 +          ||> weight_mash_facts
     2.9          val (name', mepo_suggs) =
    2.10            extract_suggestions mepo_line
    2.11 -          ||> (map fst #> weight_mash_facts)
    2.12 +          ||> weight_mepo_facts
    2.13          val _ = if name = name' then () else error "Input files out of sync."
    2.14          val mess =
    2.15            [(mepo_weight, (mepo_suggs, [])),
     3.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jan 17 23:29:17 2013 +0100
     3.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Thu Jan 17 23:29:22 2013 +0100
     3.3 @@ -29,7 +29,7 @@
     3.4    val unencode_str : string -> string
     3.5    val unencode_strs : string -> string list
     3.6    val encode_features : (string * real) list -> string
     3.7 -  val extract_suggestions : string -> string * (string * real) list
     3.8 +  val extract_suggestions : string -> string * string list
     3.9  
    3.10    structure MaSh:
    3.11    sig
    3.12 @@ -41,14 +41,12 @@
    3.13        Proof.context -> bool -> (string * string list) list -> unit
    3.14      val suggest :
    3.15        Proof.context -> bool -> bool -> int
    3.16 -      -> string list * (string * real) list * string list
    3.17 -      -> (string * real) list
    3.18 +      -> string list * (string * real) list * string list -> string list
    3.19    end
    3.20  
    3.21    val mash_unlearn : Proof.context -> unit
    3.22    val nickname_of_thm : thm -> string
    3.23 -  val find_suggested_facts :
    3.24 -    (string * 'a) list -> ('b * thm) list -> (('b * thm) * 'a) list
    3.25 +  val find_suggested_facts : string list -> ('b * thm) list -> ('b * thm) list
    3.26    val mesh_facts :
    3.27      ('a * 'a -> bool) -> int -> (real * (('a * real) list * 'a list)) list
    3.28      -> 'a list
    3.29 @@ -69,8 +67,8 @@
    3.30      -> bool * string list
    3.31    val weight_mash_facts : 'a list -> ('a * real) list
    3.32    val find_mash_suggestions :
    3.33 -    int -> (Symtab.key * 'a) list -> ('b * thm) list -> ('b * thm) list
    3.34 -    -> ('b * thm) list -> ('b * thm) list * ('b * thm) list
    3.35 +    int -> string list -> ('b * thm) list -> ('b * thm) list -> ('b * thm) list
    3.36 +    -> ('b * thm) list * ('b * thm) list
    3.37    val mash_suggested_facts :
    3.38      Proof.context -> params -> string -> int -> term list -> term -> fact list
    3.39      -> fact list * fact list
    3.40 @@ -219,11 +217,13 @@
    3.41    (if learn_hints orelse null hints then "" else "; " ^ encode_strs hints) ^
    3.42    "\n"
    3.43  
    3.44 +(* The weights currently returned by "mash.py" are too spaced out to make any
    3.45 +   sense. *)
    3.46  fun extract_suggestion sugg =
    3.47    case space_explode "=" sugg of
    3.48      [name, weight] =>
    3.49 -    SOME (unencode_str name, Real.fromString weight |> the_default 1.0)
    3.50 -  | [name] => SOME (unencode_str name, 1.0)
    3.51 +    SOME (unencode_str name (* , Real.fromString weight |> the_default 1.0 *))
    3.52 +  | [name] => SOME (unencode_str name (* , 1.0 *))
    3.53    | _ => NONE
    3.54  
    3.55  fun extract_suggestions line =
    3.56 @@ -436,10 +436,8 @@
    3.57  fun find_suggested_facts suggs facts =
    3.58    let
    3.59      fun add_fact (fact as (_, th)) = Symtab.default (nickname_of_thm th, fact)
    3.60 -    val tab = Symtab.empty |> fold add_fact facts
    3.61 -    fun find_sugg (name, weight) =
    3.62 -      Symtab.lookup tab name |> Option.map (rpair weight)
    3.63 -  in map_filter find_sugg suggs end
    3.64 +    val tab = fold add_fact facts Symtab.empty
    3.65 +  in map_filter (Symtab.lookup tab) suggs end
    3.66  
    3.67  fun scaled_avg [] = 0
    3.68    | scaled_avg xs =
    3.69 @@ -776,11 +774,7 @@
    3.70  fun find_mash_suggestions _ [] _ _ raw_unknown = ([], raw_unknown)
    3.71    | find_mash_suggestions max_facts suggs facts chained raw_unknown =
    3.72      let
    3.73 -      val raw_mash =
    3.74 -        facts |> find_suggested_facts suggs
    3.75 -              (* The weights currently returned by "mash.py" are too spaced out
    3.76 -                 to make any sense. *)
    3.77 -              |> map fst
    3.78 +      val raw_mash = find_suggested_facts suggs facts
    3.79        val unknown_chained =
    3.80          inter (Thm.eq_thm_prop o pairself snd) chained raw_unknown
    3.81        val proximity =
    3.82 @@ -814,9 +808,8 @@
    3.83                  chained |> filter (is_fact_in_graph access_G snd)
    3.84                          |> map (nickname_of_thm o snd)
    3.85              in
    3.86 -              (access_G,
    3.87 -               MaSh.suggest ctxt overlord learn max_facts
    3.88 -                            (parents, feats, hints))
    3.89 +              (access_G, MaSh.suggest ctxt overlord learn max_facts
    3.90 +                                      (parents, feats, hints))
    3.91              end)
    3.92      val unknown = facts |> filter_out (is_fact_in_graph access_G snd)
    3.93    in find_mash_suggestions max_facts suggs facts chained unknown end
    3.94 @@ -1079,7 +1072,7 @@
    3.95  
    3.96  (* Generate more suggestions than requested, because some might be thrown out
    3.97     later for various reasons. *)
    3.98 -fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts div 2)
    3.99 +fun generous_max_facts max_facts = max_facts + Int.min (50, max_facts)
   3.100  
   3.101  val mepo_weight = 0.5
   3.102  val mash_weight = 0.5