1 (* Title: HOL/Mirabelle/Tools/mirabelle_sledgehammer_filter.ML
2 Author: Jasmin Blanchette, TU Munich
5 structure Mirabelle_Sledgehammer_Filter : MIRABELLE_ACTION =
(* Look up the argument "name" among the Mirabelle arguments "args" and parse
   its value as a real number; return "default_value" if no such argument was
   given. A value that Real.fromString cannot parse is reported with "error",
   naming the offending parameter, instead of escaping as an anonymous
   "Option" exception from "the".
   NOTE: Real.fromString accepts a leading numeric prefix, so e.g. "1.5x"
   still parses as 1.5. *)
fun get args name default_value =
  (case AList.lookup (op =) args name of
    SOME value =>
      (case Real.fromString value of
        SOME r => r
      | NONE =>
          error ("Bad value for parameter \"" ^ name ^ "\": \"" ^ value ^
            "\" (expected a real number)"))
  | NONE => default_value)
(* Return a copy of the relevance-fudge record passed as the second argument,
   in which every field whose name occurs among the Mirabelle arguments "args"
   is replaced by that argument's value (parsed as a real by "get"); fields
   with no matching argument keep the value they had in the input record. *)
fun extract_relevance_fudge args
    {local_const_multiplier, worse_irrel_freq, higher_order_irrel_weight,
     abs_rel_weight, abs_irrel_weight, skolem_irrel_weight,
     theory_const_rel_weight, theory_const_irrel_weight,
     chained_const_irrel_weight, intro_bonus, elim_bonus, simp_bonus,
     local_bonus, assum_bonus, chained_bonus, max_imperfect, max_imperfect_exp,
     threshold_divisor, ridiculous_threshold} =
  let
    (* "param key dflt": the argument named "key" if given, otherwise "dflt".
       Fields are listed below in the same order as before, so any parse
       failure is raised for the same field as before. *)
    fun param key dflt = get args key dflt
  in
    {local_const_multiplier =
       param "local_const_multiplier" local_const_multiplier,
     worse_irrel_freq = param "worse_irrel_freq" worse_irrel_freq,
     higher_order_irrel_weight =
       param "higher_order_irrel_weight" higher_order_irrel_weight,
     abs_rel_weight = param "abs_rel_weight" abs_rel_weight,
     abs_irrel_weight = param "abs_irrel_weight" abs_irrel_weight,
     skolem_irrel_weight = param "skolem_irrel_weight" skolem_irrel_weight,
     theory_const_rel_weight =
       param "theory_const_rel_weight" theory_const_rel_weight,
     theory_const_irrel_weight =
       param "theory_const_irrel_weight" theory_const_irrel_weight,
     chained_const_irrel_weight =
       param "chained_const_irrel_weight" chained_const_irrel_weight,
     intro_bonus = param "intro_bonus" intro_bonus,
     elim_bonus = param "elim_bonus" elim_bonus,
     simp_bonus = param "simp_bonus" simp_bonus,
     local_bonus = param "local_bonus" local_bonus,
     assum_bonus = param "assum_bonus" assum_bonus,
     chained_bonus = param "chained_bonus" chained_bonus,
     max_imperfect = param "max_imperfect" max_imperfect,
     max_imperfect_exp = param "max_imperfect_exp" max_imperfect_exp,
     threshold_divisor = param "threshold_divisor" threshold_divisor,
     ridiculous_threshold = param "ridiculous_threshold" ridiculous_threshold}
  end
46 Table(type key = int * int val ord = prod_ord int_ord int_ord)
(* Known proofs, keyed by the (line number, offset) of the proof position;
   each entry is a list of proofs, each proof being a list of fact names.
   The table is filled in by "init" from the proof file. *)
val proof_table =
  Unsynchronized.ref (Prooftab.empty : string list list Prooftab.table)

(* Statistics counters. Each one is an association list from an action id
   to a running total; they are read with "get" and incremented with "add". *)
local
  fun new_counter () = Unsynchronized.ref ([] : (int * int) list)
in
  val num_successes = new_counter ()
  val num_failures = new_counter ()
  val num_found_proofs = new_counter ()
  val num_lost_proofs = new_counter ()
  val num_found_facts = new_counter ()
  val num_lost_facts = new_counter ()
end
(* Current total of counter "c" for action "id", or 0 if "id" has no entry yet. *)
fun get id c =
  (case AList.lookup (op =) (!c) id of
    SOME total => total
  | NONE => 0)
59 c := (case AList.lookup (op =) (!c) id of
60 SOME m => AList.update (op =) (id, m + n) (!c)
61 | NONE => (id, n) :: !c)
63 fun init proof_file _ thy =
66 case line |> space_explode ":" of
67 [line_num, offset, proof] =>
68 SOME (pairself (the o Int.fromString) (line_num, offset),
69 proof |> space_explode " " |> filter_out (curry (op =) ""))
71 val proofs = File.read (Path.explode proof_file)
73 proofs |> space_explode "\n"
75 |> AList.coalesce (op =)
77 in proof_table := proof_tab; thy end
(* Integer percentage of a relative to b, rounded down by "div";
   "N/A" when the denominator b is zero. *)
fun percentage a b =
  if b <> 0 then string_of_int (a * 100 div b) else "N/A"

(* Percentage that a represents of the total a + b
   (e.g. successes among successes plus failures). *)
fun percentage_alt a b =
  let val total = a + b in percentage a total end
82 fun done id ({log, ...} : Mirabelle.done_args) =
83 if get id num_successes + get id num_failures > 0 then
85 log ("Number of overall successes: " ^
86 string_of_int (get id num_successes));
87 log ("Number of overall failures: " ^ string_of_int (get id num_failures));
88 log ("Overall success rate: " ^
89 percentage_alt (get id num_successes) (get id num_failures) ^ "%");
90 log ("Number of found proofs: " ^ string_of_int (get id num_found_proofs));
91 log ("Number of lost proofs: " ^ string_of_int (get id num_lost_proofs));
92 log ("Proof found rate: " ^
93 percentage_alt (get id num_found_proofs) (get id num_lost_proofs) ^
95 log ("Number of found facts: " ^ string_of_int (get id num_found_facts));
96 log ("Number of lost facts: " ^ string_of_int (get id num_lost_facts));
97 log ("Fact found rate: " ^
98 percentage_alt (get id num_found_facts) (get id num_lost_facts) ^
(* Fallback prover name, used when no "prover" argument is supplied;
   the choice of ATP is arbitrary. *)
val default_prover =
  ATP_Systems.eN
(* Render name s tagged with index i, in the form "s@i". *)
fun with_index (i, s) = String.concat [s, "@", string_of_int i]
107 fun action args id ({pre, pos, log, ...} : Mirabelle.run_args) =
108 case (Position.line_of pos, Position.offset_of pos) of
109 (SOME line_num, SOME offset) =>
110 (case Prooftab.lookup (!proof_table) (line_num, offset) of
113 val {context = ctxt, facts = chained_ths, goal} = Proof.goal pre
114 val prover = AList.lookup (op =) args "prover"
115 |> the_default default_prover
116 val params as {max_facts, slice, ...} =
117 Sledgehammer_Isar.default_params ctxt args
118 val default_max_facts =
119 Sledgehammer_Provers.default_max_facts_for_prover ctxt slice prover
120 val is_appropriate_prop =
121 Sledgehammer_Provers.is_appropriate_prop_for_prover ctxt
123 val relevance_fudge =
124 extract_relevance_fudge args
125 (Sledgehammer_Provers.relevance_fudge_for_prover ctxt prover)
127 val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal ctxt goal subgoal
128 val ho_atp = Sledgehammer_Provers.is_ho_atp ctxt prover
129 val reserved = Sledgehammer_Util.reserved_isar_keyword_table ()
130 val css_table = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
132 Sledgehammer_Fact.nearly_all_facts ctxt ho_atp
133 Sledgehammer_Fact.no_fact_override reserved css_table chained_ths
135 |> filter (is_appropriate_prop o prop_of o snd)
136 |> Sledgehammer_MePo.iterative_relevant_facts ctxt params
137 default_prover (the_default default_max_facts max_facts)
138 (SOME relevance_fudge) hyp_ts concl_t
139 |> map ((fn name => name ()) o fst o fst)
140 val (found_facts, lost_facts) =
141 flat proofs |> sort_distinct string_ord
142 |> map (fn fact => (find_index (curry (op =) fact) facts, fact))
143 |> List.partition (curry (op <=) 0 o fst)
144 |>> sort (prod_ord int_ord string_ord) ||> map snd
145 val found_proofs = filter (forall (member (op =) facts)) proofs
146 val n = length found_proofs
149 (add id num_failures 1; log "Failure")
151 (add id num_successes 1;
152 add id num_found_proofs n;
153 log ("Success (" ^ string_of_int n ^ " of " ^
154 string_of_int (length proofs) ^ " proofs)"))
155 val _ = add id num_lost_proofs (length proofs - n)
156 val _ = add id num_found_facts (length found_facts)
157 val _ = add id num_lost_facts (length lost_facts)
159 if null found_facts then
164 Real.fromInt (fold (fn (n, _) =>
165 Integer.add (n * n)) found_facts 0)
166 / Real.fromInt (length found_facts)
167 |> Math.sqrt |> Real.ceil
169 log ("Found facts (among " ^ string_of_int (length facts) ^
170 ", weight " ^ string_of_int found_weight ^ "): " ^
171 commas (map with_index found_facts))
173 val _ = if null lost_facts then
176 log ("Lost facts (among " ^ string_of_int (length facts) ^
177 "): " ^ commas lost_facts)
179 | NONE => log "No known proof")
(* Name of the action argument that gives the path of the proof file. *)
val proof_fileK : string = "proof_file"
186 val (pf_args, other_args) =
187 args |> List.partition (curry (op =) proof_fileK o fst)
188 val proof_file = case pf_args of
189 [] => error "No \"proof_file\" specified"
191 in Mirabelle.register (init proof_file, action other_args, done) end
(* Alias under the spelling "Mirabelle_Sledgehammer_filter" (lowercase "f"),
   kept to keep the "mirabelle.pl" script happy. Presumably the script derives
   the structure name from the action name "sledgehammer_filter" -- TODO confirm.
   The ascription is the same signature the original structure already has. *)
structure Mirabelle_Sledgehammer_filter : MIRABELLE_ACTION =
  Mirabelle_Sledgehammer_Filter;