src/HOL/Mirabelle/Tools/mirabelle_sledgehammer_filter.ML
author blanchet
Mon, 02 May 2011 22:52:15 +0200
changeset 43509 a7a30721767a
parent 43460 9f7c48463645
child 43513 f5b4b9d4acda
permissions -rw-r--r--
have each ATP filter out dangerous facts for themselves, based on their type system
     1 (*  Title:      HOL/Mirabelle/Tools/mirabelle_sledgehammer_filter.ML
     2     Author:     Jasmin Blanchette, TU Munich
     3 *)
     4 
     5 structure Mirabelle_Sledgehammer_Filter : MIRABELLE_ACTION =
     6 struct
     7 
     8 fun get args name default_value =
     9   case AList.lookup (op =) args name of
    10     SOME value => the (Real.fromString value)
    11   | NONE => default_value
    12 
    13 fun extract_relevance_fudge args
    14       {allow_ext, local_const_multiplier, worse_irrel_freq,
    15        higher_order_irrel_weight, abs_rel_weight, abs_irrel_weight,
    16        skolem_irrel_weight, theory_const_rel_weight, theory_const_irrel_weight,
    17        intro_bonus, elim_bonus, simp_bonus, local_bonus, assum_bonus,
    18        chained_bonus, max_imperfect, max_imperfect_exp, threshold_divisor,
    19        ridiculous_threshold} =
    20   {allow_ext = allow_ext,
    21    local_const_multiplier =
    22        get args "local_const_multiplier" local_const_multiplier,
    23    worse_irrel_freq = get args "worse_irrel_freq" worse_irrel_freq,
    24    higher_order_irrel_weight =
    25        get args "higher_order_irrel_weight" higher_order_irrel_weight,
    26    abs_rel_weight = get args "abs_rel_weight" abs_rel_weight,
    27    abs_irrel_weight = get args "abs_irrel_weight" abs_irrel_weight,
    28    skolem_irrel_weight = get args "skolem_irrel_weight" skolem_irrel_weight,
    29    theory_const_rel_weight =
    30        get args "theory_const_rel_weight" theory_const_rel_weight,
    31    theory_const_irrel_weight =
    32        get args "theory_const_irrel_weight" theory_const_irrel_weight,
    33    intro_bonus = get args "intro_bonus" intro_bonus,
    34    elim_bonus = get args "elim_bonus" elim_bonus,
    35    simp_bonus = get args "simp_bonus" simp_bonus,
    36    local_bonus = get args "local_bonus" local_bonus,
    37    assum_bonus = get args "assum_bonus" assum_bonus,
    38    chained_bonus = get args "chained_bonus" chained_bonus,
    39    max_imperfect = get args "max_imperfect" max_imperfect,
    40    max_imperfect_exp = get args "max_imperfect_exp" max_imperfect_exp,
    41    threshold_divisor = get args "threshold_divisor" threshold_divisor,
    42    ridiculous_threshold = get args "ridiculous_threshold" ridiculous_threshold}
    43 
    44 structure Prooftab =
    45   Table(type key = int * int val ord = prod_ord int_ord int_ord)
    46 
    47 val proof_table = Unsynchronized.ref (Prooftab.empty: string list list Prooftab.table)
    48 
    49 val num_successes = Unsynchronized.ref ([] : (int * int) list)
    50 val num_failures = Unsynchronized.ref ([] : (int * int) list)
    51 val num_found_proofs = Unsynchronized.ref ([] : (int * int) list)
    52 val num_lost_proofs = Unsynchronized.ref ([] : (int * int) list)
    53 val num_found_facts = Unsynchronized.ref ([] : (int * int) list)
    54 val num_lost_facts = Unsynchronized.ref ([] : (int * int) list)
    55 
    56 fun get id c = the_default 0 (AList.lookup (op =) (!c) id)
    57 fun add id c n =
    58   c := (case AList.lookup (op =) (!c) id of
    59           SOME m => AList.update (op =) (id, m + n) (!c)
    60         | NONE => (id, n) :: !c)
    61 
    62 fun init proof_file _ thy =
    63   let
    64     fun do_line line =
    65       case line |> space_explode ":" of
    66         [line_num, col_num, proof] =>
    67         SOME (pairself (the o Int.fromString) (line_num, col_num),
    68               proof |> space_explode " " |> filter_out (curry (op =) ""))
    69        | _ => NONE
    70     val proofs = File.read (Path.explode proof_file)
    71     val proof_tab =
    72       proofs |> space_explode "\n"
    73              |> map_filter do_line
    74              |> AList.coalesce (op =)
    75              |> Prooftab.make
    76   in proof_table := proof_tab; thy end
    77 
    78 fun percentage a b = if b = 0 then "N/A" else string_of_int (a * 100 div b)
    79 fun percentage_alt a b = percentage a (a + b)
    80 
    81 fun done id ({log, ...} : Mirabelle.done_args) =
    82   if get id num_successes + get id num_failures > 0 then
    83     (log "";
    84      log ("Number of overall successes: " ^
    85           string_of_int (get id num_successes));
    86      log ("Number of overall failures: " ^ string_of_int (get id num_failures));
    87      log ("Overall success rate: " ^
    88           percentage_alt (get id num_successes) (get id num_failures) ^ "%");
    89      log ("Number of found proofs: " ^ string_of_int (get id num_found_proofs));
    90      log ("Number of lost proofs: " ^ string_of_int (get id num_lost_proofs));
    91      log ("Proof found rate: " ^
    92           percentage_alt (get id num_found_proofs) (get id num_lost_proofs) ^
    93           "%");
    94      log ("Number of found facts: " ^ string_of_int (get id num_found_facts));
    95      log ("Number of lost facts: " ^ string_of_int (get id num_lost_facts));
    96      log ("Fact found rate: " ^
    97           percentage_alt (get id num_found_facts) (get id num_lost_facts) ^
    98           "%"))
    99   else
   100     ()
   101 
   102 val default_prover = ATP_Systems.eN (* arbitrary ATP *)
   103 
   104 fun with_index (i, s) = s ^ "@" ^ string_of_int i
   105 
   106 fun action args id ({pre, pos, log, ...} : Mirabelle.run_args) =
   107   case (Position.line_of pos, Position.column_of pos) of
   108     (SOME line_num, SOME col_num) =>
   109     (case Prooftab.lookup (!proof_table) (line_num, col_num) of
   110        SOME proofs =>
   111        let
   112          val {context = ctxt, facts, goal} = Proof.goal pre
   113          val prover = AList.lookup (op =) args "prover"
   114                       |> the_default default_prover
   115          val {relevance_thresholds, type_sys, max_relevant, slicing, ...} =
   116            Sledgehammer_Isar.default_params ctxt args
   117          val default_max_relevant =
   118            Sledgehammer_Provers.default_max_relevant_for_prover ctxt slicing
   119                                                                 prover
   120          val is_built_in_const =
   121            Sledgehammer_Provers.is_built_in_const_for_prover ctxt default_prover
   122          val relevance_fudge =
   123            extract_relevance_fudge args
   124                (Sledgehammer_Provers.relevance_fudge_for_prover ctxt prover)
   125          val relevance_override = {add = [], del = [], only = false}
   126          val subgoal = 1
   127          val (_, hyp_ts, concl_t) = Sledgehammer_Util.strip_subgoal goal subgoal
   128          val facts =
   129            Sledgehammer_Filter.relevant_facts ctxt relevance_thresholds
   130                (the_default default_max_relevant max_relevant) is_built_in_const
   131                relevance_fudge relevance_override facts hyp_ts concl_t
   132            |> map (fst o fst)
   133          val (found_facts, lost_facts) =
   134            flat proofs |> sort_distinct string_ord
   135            |> map (fn fact => (find_index (curry (op =) fact) facts, fact))
   136            |> List.partition (curry (op <=) 0 o fst)
   137            |>> sort (prod_ord int_ord string_ord) ||> map snd
   138          val found_proofs = filter (forall (member (op =) facts)) proofs
   139          val n = length found_proofs
   140          val _ =
   141            if n = 0 then
   142              (add id num_failures 1; log "Failure")
   143            else
   144              (add id num_successes 1;
   145               add id num_found_proofs n;
   146               log ("Success (" ^ string_of_int n ^ " of " ^
   147                    string_of_int (length proofs) ^ " proofs)"))
   148          val _ = add id num_lost_proofs (length proofs - n)
   149          val _ = add id num_found_facts (length found_facts)
   150          val _ = add id num_lost_facts (length lost_facts)
   151          val _ =
   152            if null found_facts then
   153              ()
   154            else
   155              let
   156                val found_weight =
   157                  Real.fromInt (fold (fn (n, _) =>
   158                                         Integer.add (n * n)) found_facts 0)
   159                    / Real.fromInt (length found_facts)
   160                  |> Math.sqrt |> Real.ceil
   161              in
   162                log ("Found facts (among " ^ string_of_int (length facts) ^
   163                     ", weight " ^ string_of_int found_weight ^ "): " ^
   164                     commas (map with_index found_facts))
   165              end
   166          val _ = if null lost_facts then
   167                    ()
   168                  else
   169                    log ("Lost facts (among " ^ string_of_int (length facts) ^
   170                         "): " ^ commas lost_facts)
   171        in () end
   172      | NONE => log "No known proof")
   173   | _ => ()
   174 
   175 val proof_fileK = "proof_file"
   176 
   177 fun invoke args =
   178   let
   179     val (pf_args, other_args) =
   180       args |> List.partition (curry (op =) proof_fileK o fst)
   181     val proof_file = case pf_args of
   182                        [] => error "No \"proof_file\" specified"
   183                      | (_, s) :: _ => s
   184   in Mirabelle.register (init proof_file, action other_args, done) end
   185 
   186 end;
   187 
   188 (* Workaround to keep the "mirabelle.pl" script happy *)
   189 structure Mirabelle_Sledgehammer_filter = Mirabelle_Sledgehammer_Filter;