src/HOL/Mirabelle/Tools/mirabelle_sledgehammer_filter.ML
author blanchet
Wed, 18 Jul 2012 08:44:03 +0200
changeset 49308 914ca0827804
parent 49307 7fcee834c7f5
child 49314 5e5c6616f0fe
permissions -rw-r--r--
renamed Sledgehammer options
wenzelm@48718
     1
(*  Title:      HOL/Mirabelle/Tools/mirabelle_sledgehammer_filter.ML
blanchet@39118
     2
    Author:     Jasmin Blanchette, TU Munich
blanchet@39118
     3
*)
blanchet@39118
     4
blanchet@39118
     5
structure Mirabelle_Sledgehammer_Filter : MIRABELLE_ACTION =
struct

(* Look up the real-valued option "name" in "args", falling back to
   "default_value" when the option is absent. A value that does not parse as
   a real is reported with a descriptive error rather than an anonymous
   Option exception. *)
fun real_arg args name default_value =
  case AList.lookup (op =) args name of
    SOME value =>
    (case Real.fromString value of
       SOME r => r
     | NONE =>
       error ("Bad real value for " ^ quote name ^ ": " ^ quote value))
  | NONE => default_value

(* Override each field of the given relevance-fudge record with the
   real-valued option of the same name from "args", when one is given. *)
fun extract_relevance_fudge args
      {local_const_multiplier, worse_irrel_freq, higher_order_irrel_weight,
       abs_rel_weight, abs_irrel_weight, skolem_irrel_weight,
       theory_const_rel_weight, theory_const_irrel_weight,
       chained_const_irrel_weight, intro_bonus, elim_bonus, simp_bonus,
       local_bonus, assum_bonus, chained_bonus, max_imperfect, max_imperfect_exp,
       threshold_divisor, ridiculous_threshold} =
  {local_const_multiplier =
       real_arg args "local_const_multiplier" local_const_multiplier,
   worse_irrel_freq = real_arg args "worse_irrel_freq" worse_irrel_freq,
   higher_order_irrel_weight =
       real_arg args "higher_order_irrel_weight" higher_order_irrel_weight,
   abs_rel_weight = real_arg args "abs_rel_weight" abs_rel_weight,
   abs_irrel_weight = real_arg args "abs_irrel_weight" abs_irrel_weight,
   skolem_irrel_weight =
       real_arg args "skolem_irrel_weight" skolem_irrel_weight,
   theory_const_rel_weight =
       real_arg args "theory_const_rel_weight" theory_const_rel_weight,
   theory_const_irrel_weight =
       real_arg args "theory_const_irrel_weight" theory_const_irrel_weight,
   chained_const_irrel_weight =
       real_arg args "chained_const_irrel_weight" chained_const_irrel_weight,
   intro_bonus = real_arg args "intro_bonus" intro_bonus,
   elim_bonus = real_arg args "elim_bonus" elim_bonus,
   simp_bonus = real_arg args "simp_bonus" simp_bonus,
   local_bonus = real_arg args "local_bonus" local_bonus,
   assum_bonus = real_arg args "assum_bonus" assum_bonus,
   chained_bonus = real_arg args "chained_bonus" chained_bonus,
   max_imperfect = real_arg args "max_imperfect" max_imperfect,
   max_imperfect_exp = real_arg args "max_imperfect_exp" max_imperfect_exp,
   threshold_divisor = real_arg args "threshold_divisor" threshold_divisor,
   ridiculous_threshold =
       real_arg args "ridiculous_threshold" ridiculous_threshold}

(* Tables keyed by a goal position given as (line number, offset). *)
structure Prooftab =
  Table(type key = int * int val ord = prod_ord int_ord int_ord)

(* Known proofs per goal position, filled by "init": each entry is a list of
   proofs, each proof being a list of fact names. *)
val proof_table = Unsynchronized.ref (Prooftab.empty: string list list Prooftab.table)

(* Statistics counters: association lists from the Mirabelle action
   instance "id" to a running total. *)
val num_successes = Unsynchronized.ref ([] : (int * int) list)
val num_failures = Unsynchronized.ref ([] : (int * int) list)
val num_found_proofs = Unsynchronized.ref ([] : (int * int) list)
val num_lost_proofs = Unsynchronized.ref ([] : (int * int) list)
val num_found_facts = Unsynchronized.ref ([] : (int * int) list)
val num_lost_facts = Unsynchronized.ref ([] : (int * int) list)

(* Current value of counter "c" for instance "id" (0 if nothing recorded). *)
fun count_of id c = the_default 0 (AList.lookup (op =) (!c) id)

(* Add "n" to counter "c" for instance "id". *)
fun add_count id c n =
  c := (case AList.lookup (op =) (!c) id of
          SOME m => AList.update (op =) (id, m + n) (!c)
        | NONE => (id, n) :: !c)

(* Load the proof file into "proof_table". Every line of the form
   "LINE:OFFSET:FACT1 FACT2 ..." records one known proof for the goal at that
   position; lines with a different number of ":"-separated fields are
   ignored. Several lines may name the same position (not necessarily
   adjacently); their proofs are collected into one entry. *)
fun init proof_file _ thy =
  let
    fun parse_int s =
      case Int.fromString s of
        SOME i => i
      | NONE =>
        error ("Bad integer in proof file " ^ quote proof_file ^ ": " ^
               quote s)
    fun do_line line =
      case line |> space_explode ":" of
        [line_num, offset, proof] =>
        SOME (pairself parse_int (line_num, offset),
              proof |> space_explode " " |> filter_out (curry (op =) ""))
      | _ => NONE
    val proofs = File.read (Path.explode proof_file)
    val proof_tab =
      proofs |> space_explode "\n"
             |> map_filter do_line
             (* "AList.coalesce" only merges adjacent equal keys; sort first so
                that "Prooftab.make" never sees a duplicate key. *)
             |> sort (prod_ord int_ord int_ord o pairself fst)
             |> AList.coalesce (op =)
             |> Prooftab.make
  in proof_table := proof_tab; thy end

(* Integer percentage of "a" out of "b" (rounded down), or "N/A" if b = 0. *)
fun percentage a b = if b = 0 then "N/A" else string_of_int (a * 100 div b)
(* Percentage of "a" out of "a + b". *)
fun percentage_alt a b = percentage a (a + b)

(* Log the accumulated statistics for instance "id", provided at least one
   goal with a known proof was processed. *)
fun done id ({log, ...} : Mirabelle.done_args) =
  if count_of id num_successes + count_of id num_failures > 0 then
    (log "";
     log ("Number of overall successes: " ^
          string_of_int (count_of id num_successes));
     log ("Number of overall failures: " ^
          string_of_int (count_of id num_failures));
     log ("Overall success rate: " ^
          percentage_alt (count_of id num_successes)
                         (count_of id num_failures) ^ "%");
     log ("Number of found proofs: " ^
          string_of_int (count_of id num_found_proofs));
     log ("Number of lost proofs: " ^
          string_of_int (count_of id num_lost_proofs));
     log ("Proof found rate: " ^
          percentage_alt (count_of id num_found_proofs)
                         (count_of id num_lost_proofs) ^
          "%");
     log ("Number of found facts: " ^
          string_of_int (count_of id num_found_facts));
     log ("Number of lost facts: " ^
          string_of_int (count_of id num_lost_facts));
     log ("Fact found rate: " ^
          percentage_alt (count_of id num_found_facts)
                         (count_of id num_lost_facts) ^
          "%"))
  else
    ()

(* Prover used when no "prover" argument is given; the choice is arbitrary. *)
val default_prover = ATP_Systems.eN (* arbitrary ATP *)

(* Render a fact name with its index in the filter's output, e.g. "foo@3". *)
fun with_index (i, s) = s ^ "@" ^ string_of_int i

(* For the goal at "pos", run Sledgehammer's iterative relevance filter and
   compare the selected facts with the known proofs recorded for that
   position. Updates the counters of instance "id" and logs the outcome.
   Positions lacking line or offset information are skipped. *)
fun action args id ({pre, pos, log, ...} : Mirabelle.run_args) =
  case (Position.line_of pos, Position.offset_of pos) of
    (SOME line_num, SOME offset) =>
    (case Prooftab.lookup (!proof_table) (line_num, offset) of
       SOME proofs =>
       let
         val {context = ctxt, facts = chained_ths, goal} = Proof.goal pre
         (* Every prover-dependent setting below (appropriateness filter,
            maximum number of facts, relevance fudge, higher-order support and
            the prover handed to the relevance filter) uses this same prover. *)
         val prover = AList.lookup (op =) args "prover"
                      |> the_default default_prover
         val params as {max_facts, slice, ...} =
           Sledgehammer_Isar.default_params ctxt args
         val default_max_facts =
           Sledgehammer_Provers.default_max_facts_for_prover ctxt slice prover
         val is_appropriate_prop =
           Sledgehammer_Provers.is_appropriate_prop_for_prover ctxt prover
         val relevance_fudge =
           extract_relevance_fudge args
               (Sledgehammer_Provers.relevance_fudge_for_prover ctxt prover)
         val subgoal = 1
         val (_, hyp_ts, concl_t) = ATP_Util.strip_subgoal ctxt goal subgoal
         val ho_atp = Sledgehammer_Provers.is_ho_atp ctxt prover
         (* Names of the facts selected by the filter, in the order it returns
            them; a fact's position in this list is its reported index. *)
         val facts =
           Sledgehammer_Fact.nearly_all_facts ctxt ho_atp
               Sledgehammer_Fact.no_fact_override chained_ths hyp_ts concl_t
           |> filter (is_appropriate_prop o prop_of o snd)
           |> Sledgehammer_Filter_Iter.iterative_relevant_facts ctxt params
                  prover (the_default default_max_facts max_facts)
                  (SOME relevance_fudge) hyp_ts concl_t
           |> map ((fn name => name ()) o fst o fst)
         val num_facts = length facts
         (* Facts occurring in some known proof, split into those the filter
            selected (paired with their index, sorted by index) and those it
            missed ("find_index" yields ~1 for a missing fact). *)
         val (found_facts, lost_facts) =
           flat proofs |> sort_distinct string_ord
           |> map (fn fact => (find_index (curry (op =) fact) facts, fact))
           |> List.partition (fn (i, _) => i >= 0)
           |>> sort (prod_ord int_ord string_ord) ||> map snd
         (* A known proof is found when all of its facts were selected. *)
         val found_proofs = filter (forall (member (op =) facts)) proofs
         val num_found = length found_proofs
         val num_proofs = length proofs
         val _ =
           if num_found = 0 then
             (add_count id num_failures 1; log "Failure")
           else
             (add_count id num_successes 1;
              add_count id num_found_proofs num_found;
              log ("Success (" ^ string_of_int num_found ^ " of " ^
                   string_of_int num_proofs ^ " proofs)"))
         val _ = add_count id num_lost_proofs (num_proofs - num_found)
         val _ = add_count id num_found_facts (length found_facts)
         val _ = add_count id num_lost_facts (length lost_facts)
         val _ =
           if null found_facts then
             ()
           else
             let
               (* Root mean square of the found facts' indices, rounded up. *)
               val found_weight =
                 Real.fromInt (fold (fn (i, _) => Integer.add (i * i))
                                    found_facts 0)
                   / Real.fromInt (length found_facts)
                 |> Math.sqrt |> Real.ceil
             in
               log ("Found facts (among " ^ string_of_int num_facts ^
                    ", weight " ^ string_of_int found_weight ^ "): " ^
                    commas (map with_index found_facts))
             end
         val _ =
           if null lost_facts then
             ()
           else
             log ("Lost facts (among " ^ string_of_int num_facts ^
                  "): " ^ commas lost_facts)
       in () end
     | NONE => log "No known proof")
  | _ => ()

(* Name of the mandatory action argument giving the path of the proof file. *)
val proof_fileK = "proof_file"

(* Register the action with Mirabelle. The first "proof_file" argument is
   consumed here; all other arguments are passed on to "action". *)
fun invoke args =
  let
    val (pf_args, other_args) =
      args |> List.partition (curry (op =) proof_fileK o fst)
    val proof_file = case pf_args of
                       [] => error "No \"proof_file\" specified"
                     | (_, s) :: _ => s
  in Mirabelle.register (init proof_file, action other_args, done) end

end;
blanchet@39120
   191
blanchet@39120
   192
(* Workaround to keep the "mirabelle.pl" script happy *)
blanchet@39120
   193
(* Alias of Mirabelle_Sledgehammer_Filter, spelled with a lowercase "filter". *)
structure Mirabelle_Sledgehammer_filter = Mirabelle_Sledgehammer_Filter;