1 (* Title: HOL/TPTP/mash_eval.ML
2 Author: Jasmin Blanchette, TU Muenchen
5 Evaluate proof suggestions from MaSh (Machine-learning for Sledgehammer).
10 type params = Sledgehammer_Provers.params
13 val MaSh_IsarN : string
14 val MaSh_ProverN : string
15 val MeSh_IsarN : string
16 val MeSh_ProverN : string
18 val evaluate_mash_suggestions :
19 Proof.context -> params -> int * int option -> bool -> string list
20 -> string option -> string -> string -> string -> string -> string -> string
24 structure MaSh_Eval : MASH_EVAL =
27 open Sledgehammer_Util
28 open Sledgehammer_Fact
29 open Sledgehammer_MePo
30 open Sledgehammer_MaSh
31 open Sledgehammer_Provers
32 open Sledgehammer_Isar
(* Names of the MaSh and MeSh evaluation methods: the base name (MaShN or
   MeShN, taken from one of the structures opened above) followed by an
   "-Isar" or "-Prover" suffix. *)
val (MaSh_IsarN, MaSh_ProverN) = (MaShN ^ "-Isar", MaShN ^ "-Prover")
val (MeSh_IsarN, MeSh_ProverN) = (MeShN ^ "-Isar", MeShN ^ "-Prover")
(* in_range (from, to) j: true iff j lies in the inclusive interval starting
   at "from". The upper bound is optional: NONE means unbounded above,
   SOME hi means j must not exceed hi. *)
fun in_range (from, to) j =
  if j < from then false
  else
    (case to of
      NONE => true
    | SOME hi => j <= hi)
43 fun evaluate_mash_suggestions ctxt params range linearize methods prob_dir_name
44 mepo_file_name mash_isar_file_name mash_prover_file_name
45 mesh_isar_file_name mesh_prover_file_name report_file_name =
47 val zeros = [0, 0, 0, 0, 0, 0]
48 val report_path = report_file_name |> Path.explode
49 val _ = File.write report_path ""
50 fun print s = File.append report_path (s ^ "\n")
51 val {provers, max_facts, slice, type_enc, lam_trans, timeout, ...} =
52 default_params ctxt []
53 val prover = hd provers
54 val slack_max_facts = generous_max_facts (the max_facts)
55 val lines_of = Path.explode #> try File.read_lines #> these
57 [mepo_file_name, mash_isar_file_name, mash_prover_file_name,
58 mesh_isar_file_name, mesh_prover_file_name]
59 val lines as [mepo_lines, mash_isar_lines, mash_prover_lines,
60 mesh_isar_lines, mesh_prover_lines] =
61 map lines_of file_names
62 val num_lines = fold (Integer.max o length) lines 0
63 fun pad lines = lines @ replicate (num_lines - length lines) ""
65 pad mepo_lines ~~ pad mash_isar_lines ~~ pad mash_prover_lines ~~
66 pad mesh_isar_lines ~~ pad mesh_prover_lines
67 val css = clasimpset_rule_table_of ctxt
68 val facts = all_facts ctxt true false Symtab.empty [] [] css
69 val name_tabs = build_name_tables nickname_of_thm facts
(* Pairs s with one plus the position that find_index reports for it in
   facts. NOTE(review): presumably find_index yields ~1 when nothing matches,
   which would make the resulting index 0 -- confirm against the library. *)
fun with_index facts s =
  let val pos = find_index (fn fact => s = fact) facts
  in (pos + 1, s) end
(* Renders an (index, name) pair as the name, an "@" sign, then the index. *)
fun index_str (idx, name) = name ^ "@" ^ string_of_int idx
(* Formats a method name for the report by passing it to enclose with " " and
   ": " (presumably used as left and right delimiters -- confirm). *)
fun str_of_method method = enclose " " ": " method
73 fun str_of_result method facts ({outcome, run_time, used_facts, ...}
75 let val facts = facts |> map (fst o fst) in
76 str_of_method method ^
77 (if is_none outcome then
78 "Success (" ^ ATP_Util.string_from_time run_time ^ "): " ^
79 (used_facts |> map (with_index facts o fst)
80 |> sort (int_ord o pairself fst)
82 |> space_implode " ") ^
83 (if length facts < the max_facts then
84 " (of " ^ string_of_int (length facts) ^ ")"
89 (facts |> take (the max_facts) |> tag_list 1
91 |> space_implode " "))
93 fun solve_goal (j, ((((mepo_line, mash_isar_line), mash_prover_line),
94 mesh_isar_line), mesh_prover_line)) =
95 if in_range range j then
97 val get_suggs = extract_suggestions ##> take slack_max_facts
98 val (name1, mepo_suggs) = get_suggs mepo_line
99 val (name2, mash_isar_suggs) = get_suggs mash_isar_line
100 val (name3, mash_prover_suggs) = get_suggs mash_prover_line
101 val (name4, mesh_isar_suggs) = get_suggs mesh_isar_line
102 val (name5, mesh_prover_suggs) = get_suggs mesh_prover_line
104 [name1, name2, name3, name4, name5]
105 |> filter (curry (op <>) "") |> distinct (op =)
106 handle General.Match => error "Input files out of sync."
108 case find_first (fn (_, th) => nickname_of_thm th = name) facts of
110 | NONE => error ("No fact called \"" ^ name ^ "\".")
111 val goal = goal_of_thm (Proof_Context.theory_of ctxt) th
112 val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal ctxt goal 1
113 val isar_deps = isar_dependencies_of name_tabs th
116 |> filter (fn (_, th') =>
117 if linearize then crude_thm_ord (th', th) = LESS
118 else thm_less (th', th))
120 find_suggested_facts ctxt facts #> map fact_of_raw_fact
121 fun get_facts [] compute = compute facts
122 | get_facts suggs _ = find_suggs suggs
124 get_facts mepo_suggs (fn _ =>
125 mepo_suggested_facts ctxt params prover slack_max_facts NONE
126 hyp_ts concl_t facts)
129 get_facts suggs (fn _ =>
130 find_mash_suggestions ctxt slack_max_facts suggs facts [] []
131 |> fst |> map fact_of_raw_fact)
133 val mash_isar_facts = mash_of mash_isar_suggs
134 val mash_prover_facts = mash_of mash_prover_suggs
135 fun mess_of mash_facts =
136 [(mepo_weight, (mepo_facts, [])),
137 (mash_weight, (mash_facts, []))]
138 fun mesh_of suggs mash_facts =
139 get_facts suggs (fn _ =>
140 mesh_facts (Thm.eq_thm_prop o pairself snd) slack_max_facts
141 (mess_of mash_facts))
142 val mesh_isar_facts = mesh_of mesh_isar_suggs mash_isar_facts
143 val mesh_prover_facts = mesh_of mesh_prover_suggs mash_prover_facts
144 val isar_facts = find_suggs isar_deps
145 (* adapted from "mirabelle_sledgehammer.ML" *)
146 fun set_file_name method (SOME dir) =
149 "goal_" ^ string_of_int j ^ "__" ^ encode_str name ^ "__" ^
152 Config.put dest_dir dir
153 #> Config.put problem_prefix (prob_prefix ^ "__")
154 #> Config.put SMT_Config.debug_files (dir ^ "/" ^ prob_prefix)
156 | set_file_name _ NONE = I
157 fun prove method get facts =
158 if not (member (op =) methods method) orelse
159 (null facts andalso method <> IsarN) then
160 (str_of_method method ^ "Skipped", 0)
163 fun nickify ((_, stature), th) =
164 ((K (encode_str (nickname_of_thm th)), stature), th)
167 |> map (get #> nickify)
168 |> maybe_instantiate_inducts ctxt hyp_ts concl_t
169 |> take (the max_facts)
170 |> map fact_of_raw_fact
171 val ctxt = ctxt |> set_file_name method prob_dir_name
172 val res as {outcome, ...} =
173 run_prover_for_mash ctxt params prover facts goal
174 val ok = if is_none outcome then 1 else 0
175 in (str_of_result method facts res, ok) end
177 [fn () => prove MePoN fst mepo_facts,
178 fn () => prove MaSh_IsarN fst mash_isar_facts,
179 fn () => prove MaSh_ProverN fst mash_prover_facts,
180 fn () => prove MeSh_IsarN I mesh_isar_facts,
181 fn () => prove MeSh_ProverN I mesh_prover_facts,
182 fn () => prove IsarN I isar_facts]
183 |> (* Par_List. *) map (fn f => f ())
185 "Goal " ^ string_of_int j ^ ": " ^ name :: map fst ress
186 |> cat_lines |> print;
(* Builds the summary entry for one method: the formatted method name, the
   number of successes "ok", and in parentheses the success rate as a
   percentage of n with one decimal digit. The denominator is clamped to at
   least 1, so n = 0 does not divide by zero. *)
fun total_of method ok n =
  let
    val goals = Real.fromInt (Int.max (1, n))
    val percentage = 100.0 * Real.fromInt ok / goals
    val shown = Real.fmt (StringCvt.FIX (SOME 1)) percentage
  in
    str_of_method method ^ string_of_int ok ^ " (" ^ shown ^ "%)"
  end
195 val inst_inducts = Config.get ctxt instantiate_inducts
197 ["prover = " ^ prover,
198 "max_facts = " ^ string_of_int (the max_facts),
199 "slice" |> not slice ? prefix "dont_",
200 "type_enc = " ^ the_default "smart" type_enc,
201 "lam_trans = " ^ the_default "smart" lam_trans,
202 "timeout = " ^ ATP_Util.string_from_time (the_default one_year timeout),
203 "instantiate_inducts" |> not inst_inducts ? prefix "dont_"]
204 val _ = print " * * *";
205 val _ = print ("Options: " ^ commas options);
206 val oks = Par_List.map solve_goal (tag_list 1 lines)
208 val [mepo_ok, mash_isar_ok, mash_prover_ok, mesh_isar_ok, mesh_prover_ok,
210 if n = 0 then zeros else map Integer.sum (map_transpose I oks)
212 ["Successes (of " ^ string_of_int n ^ " goals)",
213 total_of MePoN mepo_ok n,
214 total_of MaSh_IsarN mash_isar_ok n,
215 total_of MaSh_ProverN mash_prover_ok n,
216 total_of MeSh_IsarN mesh_isar_ok n,
217 total_of MeSh_ProverN mesh_prover_ok n,
218 total_of IsarN isar_ok n]
219 |> cat_lines |> print