src/HOL/Tools/Sledgehammer/sledgehammer_fact_filter.ML
changeset 38991 6628adcae4a7
parent 38990 01c4d14b2a61
child 39019 e46e7a9cb622
child 39051 21a6f261595e
equal deleted inserted replaced
38990:01c4d14b2a61 38991:6628adcae4a7
     3     Author:     Jasmin Blanchette, TU Muenchen
     3     Author:     Jasmin Blanchette, TU Muenchen
     4 *)
     4 *)
     5 
     5 
     6 signature SLEDGEHAMMER_FACT_FILTER =
     6 signature SLEDGEHAMMER_FACT_FILTER =
     7 sig
     7 sig
       
     8   datatype locality = General | Theory | Local | Chained
       
     9 
     8   type relevance_override =
    10   type relevance_override =
     9     {add: Facts.ref list,
    11     {add: Facts.ref list,
    10      del: Facts.ref list,
    12      del: Facts.ref list,
    11      only: bool}
    13      only: bool}
    12 
    14 
    13   val trace : bool Unsynchronized.ref
    15   val trace : bool Unsynchronized.ref
    14   val name_thm_pairs_from_ref :
    16   val name_thm_pairs_from_ref :
    15     Proof.context -> unit Symtab.table -> thm list -> Facts.ref
    17     Proof.context -> unit Symtab.table -> thm list -> Facts.ref
    16     -> ((unit -> string * bool) * (bool * thm)) list
    18     -> ((string * locality) * thm) list
    17   val relevant_facts :
    19   val relevant_facts :
    18     bool -> real * real -> int -> bool -> relevance_override
    20     bool -> real * real -> int -> bool -> relevance_override
    19     -> Proof.context * (thm list * 'a) -> term list -> term
    21     -> Proof.context * (thm list * 'a) -> term list -> term
    20     -> ((string * bool) * thm) list
    22     -> ((string * locality) * thm) list
    21 end;
    23 end;
    22 
    24 
    23 structure Sledgehammer_Fact_Filter : SLEDGEHAMMER_FACT_FILTER =
    25 structure Sledgehammer_Fact_Filter : SLEDGEHAMMER_FACT_FILTER =
    24 struct
    26 struct
    25 
    27 
    27 
    29 
    28 val trace = Unsynchronized.ref false
    30 val trace = Unsynchronized.ref false
    29 fun trace_msg msg = if !trace then tracing (msg ()) else ()
    31 fun trace_msg msg = if !trace then tracing (msg ()) else ()
    30 
    32 
    31 val respect_no_atp = true
    33 val respect_no_atp = true
       
    34 
       
    35 datatype locality = General | Theory | Local | Chained
    32 
    36 
    33 type relevance_override =
    37 type relevance_override =
    34   {add: Facts.ref list,
    38   {add: Facts.ref list,
    35    del: Facts.ref list,
    39    del: Facts.ref list,
    36    only: bool}
    40    only: bool}
    45   let
    49   let
    46     val ths = ProofContext.get_fact ctxt xref
    50     val ths = ProofContext.get_fact ctxt xref
    47     val name = Facts.string_of_ref xref
    51     val name = Facts.string_of_ref xref
    48     val multi = length ths > 1
    52     val multi = length ths > 1
    49   in
    53   in
    50     fold (fn th => fn (j, rest) =>
    54     (ths, (1, []))
    51              (j + 1, (fn () => (repair_name reserved multi j name,
    55     |-> fold (fn th => fn (j, rest) =>
    52                                 member Thm.eq_thm chained_ths th),
    56                  (j + 1, ((repair_name reserved multi j name,
    53                       (multi, th)) :: rest))
    57                           if member Thm.eq_thm chained_ths th then Chained
    54          ths (1, [])
    58                           else General), th) :: rest))
    55     |> snd
    59     |> snd
    56   end
    60   end
    57 
    61 
    58 (***************************************************************)
    62 (***************************************************************)
    59 (* Relevance Filtering                                         *)
    63 (* Relevance Filtering                                         *)
   243 (* TODO: experiment
   247 (* TODO: experiment
   244 fun irrel_log n = 0.5 + 1.0 / Math.ln (Real.fromInt n + 1.0)
   248 fun irrel_log n = 0.5 + 1.0 / Math.ln (Real.fromInt n + 1.0)
   245 *)
   249 *)
   246 fun irrel_log n = Math.ln (Real.fromInt n + 19.0) / 6.4
   250 fun irrel_log n = Math.ln (Real.fromInt n + 19.0) / 6.4
   247 
   251 
       
   252 (* FUDGE *)
       
   253 val skolem_weight = 1.0
       
   254 val abs_weight = 2.0
       
   255 
   248 (* Computes a constant's weight, as determined by its frequency. *)
   256 (* Computes a constant's weight, as determined by its frequency. *)
   249 val rel_weight = rel_log oo pseudoconst_freq match_pseudotypes
   257 val rel_weight = rel_log oo pseudoconst_freq match_pseudotypes
   250 fun irrel_weight const_tab (c as (s, _)) =
   258 fun irrel_weight const_tab (c as (s, _)) =
   251   if String.isPrefix skolem_prefix s then 1.0
   259   if String.isPrefix skolem_prefix s then skolem_weight
   252   else if String.isPrefix abs_prefix s then 2.0
   260   else if String.isPrefix abs_prefix s then abs_weight
   253   else irrel_log (pseudoconst_freq (match_pseudotypes o swap) const_tab c)
   261   else irrel_log (pseudoconst_freq (match_pseudotypes o swap) const_tab c)
   254 (* TODO: experiment
   262 (* TODO: experiment
   255 fun irrel_weight _ _ = 1.0
   263 fun irrel_weight _ _ = 1.0
   256 *)
   264 *)
   257 
   265 
   258 val chained_bonus_factor = 2.0
   266 (* FUDGE *)
   259 
   267 fun locality_multiplier General = 1.0
   260 fun axiom_weight chained const_tab relevant_consts axiom_consts =
   268   | locality_multiplier Theory = 1.1
       
   269   | locality_multiplier Local = 1.3
       
   270   | locality_multiplier Chained = 2.0
       
   271 
       
   272 fun axiom_weight loc const_tab relevant_consts axiom_consts =
   261   case axiom_consts |> List.partition (pseudoconst_mem I relevant_consts)
   273   case axiom_consts |> List.partition (pseudoconst_mem I relevant_consts)
   262                     ||> filter_out (pseudoconst_mem swap relevant_consts) of
   274                     ||> filter_out (pseudoconst_mem swap relevant_consts) of
   263     ([], []) => 0.0
   275     ([], []) => 0.0
   264   | (_, []) => 1.0
   276   | (_, []) => 1.0
   265   | (rel, irrel) =>
   277   | (rel, irrel) =>
   266     let
   278     let
   267       val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
   279       val rel_weight = fold (curry Real.+ o rel_weight const_tab) rel 0.0
   268                        |> chained ? curry Real.* chained_bonus_factor
   280                        |> curry Real.* (locality_multiplier loc)
   269       val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
   281       val irrel_weight = fold (curry Real.+ o irrel_weight const_tab) irrel 0.0
   270       val res = rel_weight / (rel_weight + irrel_weight)
   282       val res = rel_weight / (rel_weight + irrel_weight)
   271     in if Real.isFinite res then res else 0.0 end
   283     in if Real.isFinite res then res else 0.0 end
   272 
   284 
   273 (* TODO: experiment
   285 (* TODO: experiment
   292 fun pair_consts_axiom theory_relevant thy axiom =
   304 fun pair_consts_axiom theory_relevant thy axiom =
   293   (axiom, axiom |> snd |> theory_const_prop_of theory_relevant
   305   (axiom, axiom |> snd |> theory_const_prop_of theory_relevant
   294                 |> pseudoconsts_of_term thy)
   306                 |> pseudoconsts_of_term thy)
   295 
   307 
   296 type annotated_thm =
   308 type annotated_thm =
   297   ((unit -> string * bool) * thm) * (string * pseudotype list) list
   309   (((unit -> string) * locality) * thm) * (string * pseudotype list) list
   298 
   310 
   299 fun take_most_relevant max_max_imperfect max_relevant remaining_max
   311 fun take_most_relevant max_max_imperfect max_relevant remaining_max
   300                        (candidates : (annotated_thm * real) list) =
   312                        (candidates : (annotated_thm * real) list) =
   301   let
   313   let
   302     val max_imperfect =
   314     val max_imperfect =
   313                         string_of_int (length candidates));
   325                         string_of_int (length candidates));
   314     trace_msg (fn () => "Effective threshold: " ^
   326     trace_msg (fn () => "Effective threshold: " ^
   315                         Real.toString (#2 (hd accepts)));
   327                         Real.toString (#2 (hd accepts)));
   316     trace_msg (fn () => "Actually passed (" ^ Int.toString (length accepts) ^
   328     trace_msg (fn () => "Actually passed (" ^ Int.toString (length accepts) ^
   317         "): " ^ (accepts
   329         "): " ^ (accepts
   318                  |> map (fn (((name, _), _), weight) =>
   330                  |> map (fn ((((name, _), _), _), weight) =>
   319                             fst (name ()) ^ " [" ^ Real.toString weight ^ "]")
   331                             name () ^ " [" ^ Real.toString weight ^ "]")
   320                  |> commas));
   332                  |> commas));
   321     (accepts, more_rejects @ rejects)
   333     (accepts, more_rejects @ rejects)
   322   end
   334   end
   323 
   335 
       
   336 (* FUDGE *)
   324 val threshold_divisor = 2.0
   337 val threshold_divisor = 2.0
   325 val ridiculous_threshold = 0.1
   338 val ridiculous_threshold = 0.1
   326 val max_max_imperfect_fudge_factor = 0.66
   339 val max_max_imperfect_fudge_factor = 0.66
   327 
   340 
   328 fun relevance_filter ctxt threshold0 decay max_relevant theory_relevant
   341 fun relevance_filter ctxt threshold0 decay max_relevant theory_relevant
   390                else
   403                else
   391                  iter (j + 1) remaining_max threshold rel_const_tab'
   404                  iter (j + 1) remaining_max threshold rel_const_tab'
   392                       hopeless_rejects hopeful_rejects)
   405                       hopeless_rejects hopeful_rejects)
   393             end
   406             end
   394           | relevant candidates rejects hopeless
   407           | relevant candidates rejects hopeless
   395                      (((ax as ((name, th), axiom_consts)), cached_weight)
   408                      (((ax as (((_, loc), th), axiom_consts)), cached_weight)
   396                       :: hopeful) =
   409                       :: hopeful) =
   397             let
   410             let
   398               val weight =
   411               val weight =
   399                 case cached_weight of
   412                 case cached_weight of
   400                   SOME w => w
   413                   SOME w => w
   401                 | NONE => axiom_weight (snd (name ())) const_tab rel_const_tab
   414                 | NONE => axiom_weight loc const_tab rel_const_tab axiom_consts
   402                                        axiom_consts
       
   403 (* TODO: experiment
   415 (* TODO: experiment
   404 val _ = if String.isPrefix "lift.simps(3" (fst (name ())) then
   416 val name = fst (fst (fst ax)) ()
   405 tracing ("*** " ^ (fst (name ())) ^ PolyML.makestring (debug_axiom_weight const_tab rel_const_tab axiom_consts))
   417 val _ = if String.isPrefix "lift.simps(3" name then
       
   418 tracing ("*** " ^ name ^ PolyML.makestring (debug_axiom_weight const_tab rel_const_tab axiom_consts))
   406 else
   419 else
   407 ()
   420 ()
   408 *)
   421 *)
   409             in
   422             in
   410               if weight >= threshold then
   423               if weight >= threshold then
   568     is_strange_theorem thm
   581     is_strange_theorem thm
   569   end
   582   end
   570 
   583 
   571 fun all_name_thms_pairs ctxt reserved full_types add_thms chained_ths =
   584 fun all_name_thms_pairs ctxt reserved full_types add_thms chained_ths =
   572   let
   585   let
   573     val is_chained = member Thm.eq_thm chained_ths
   586     val thy = ProofContext.theory_of ctxt
   574     val global_facts = PureThy.facts_of (ProofContext.theory_of ctxt)
   587     val thy_prefix = Context.theory_name thy ^ Long_Name.separator
       
   588     val global_facts = PureThy.facts_of thy
   575     val local_facts = ProofContext.facts_of ctxt
   589     val local_facts = ProofContext.facts_of ctxt
   576     val named_locals = local_facts |> Facts.dest_static []
   590     val named_locals = local_facts |> Facts.dest_static []
       
   591     val is_chained = member Thm.eq_thm chained_ths
   577     (* Unnamed, not chained formulas with schematic variables are omitted,
   592     (* Unnamed, not chained formulas with schematic variables are omitted,
   578        because they are rejected by the backticks (`...`) parser for some
   593        because they are rejected by the backticks (`...`) parser for some
   579        reason. *)
   594        reason. *)
   580     fun is_good_unnamed_local th =
   595     fun is_good_unnamed_local th =
   581       forall (fn (_, ths) => not (member Thm.eq_thm ths th)) named_locals
   596       forall (fn (_, ths) => not (member Thm.eq_thm ths th)) named_locals
   583     val unnamed_locals =
   598     val unnamed_locals =
   584       local_facts |> Facts.props |> filter is_good_unnamed_local
   599       local_facts |> Facts.props |> filter is_good_unnamed_local
   585                   |> map (pair "" o single)
   600                   |> map (pair "" o single)
   586     val full_space =
   601     val full_space =
   587       Name_Space.merge (Facts.space_of global_facts, Facts.space_of local_facts)
   602       Name_Space.merge (Facts.space_of global_facts, Facts.space_of local_facts)
   588     fun add_valid_facts foldx facts =
   603     fun add_facts global foldx facts =
   589       foldx (fn (name0, ths) =>
   604       foldx (fn (name0, ths) =>
   590         if name0 <> "" andalso
   605         if name0 <> "" andalso
   591            forall (not o member Thm.eq_thm add_thms) ths andalso
   606            forall (not o member Thm.eq_thm add_thms) ths andalso
   592            (Facts.is_concealed facts name0 orelse
   607            (Facts.is_concealed facts name0 orelse
   593             (respect_no_atp andalso is_package_def name0) orelse
   608             (respect_no_atp andalso is_package_def name0) orelse
   594             exists (fn s => String.isSuffix s name0) multi_base_blacklist orelse
   609             exists (fn s => String.isSuffix s name0) multi_base_blacklist orelse
   595             String.isSuffix "_def_raw" (* FIXME: crude hack *) name0) then
   610             String.isSuffix "_def_raw" (* FIXME: crude hack *) name0) then
   596           I
   611           I
   597         else
   612         else
   598           let
   613           let
       
   614             val base_loc =
       
   615               if not global then Local
       
   616               else if String.isPrefix thy_prefix name0 then Theory
       
   617               else General
   599             val multi = length ths > 1
   618             val multi = length ths > 1
   600             fun backquotify th =
   619             fun backquotify th =
   601               "`" ^ Print_Mode.setmp [Print_Mode.input]
   620               "`" ^ Print_Mode.setmp [Print_Mode.input]
   602                                  (Syntax.string_of_term ctxt) (prop_of th) ^ "`"
   621                                  (Syntax.string_of_term ctxt) (prop_of th) ^ "`"
   603               |> String.translate (fn c => if Char.isPrint c then str c else "")
   622               |> String.translate (fn c => if Char.isPrint c then str c else "")
   612                  (j + 1,
   631                  (j + 1,
   613                   if is_theorem_bad_for_atps full_types th andalso
   632                   if is_theorem_bad_for_atps full_types th andalso
   614                      not (member Thm.eq_thm add_thms th) then
   633                      not (member Thm.eq_thm add_thms th) then
   615                     rest
   634                     rest
   616                   else
   635                   else
   617                     (fn () =>
   636                     (((fn () =>
   618                         (if name0 = "" then
   637                           if name0 = "" then
   619                            th |> backquotify
   638                             th |> backquotify
   620                          else
   639                           else
   621                            let
   640                             let
   622                              val name1 = Facts.extern facts name0
   641                               val name1 = Facts.extern facts name0
   623                              val name2 = Name_Space.extern full_space name0
   642                               val name2 = Name_Space.extern full_space name0
   624                            in
   643                             in
   625                              case find_first check_thms [name1, name2, name0] of
   644                               case find_first check_thms [name1, name2, name0] of
   626                                SOME name => repair_name reserved multi j name
   645                                 SOME name => repair_name reserved multi j name
   627                              | NONE => ""
   646                               | NONE => ""
   628                            end, is_chained th), (multi, th)) :: rest)) ths
   647                             end), if is_chained th then Chained else base_loc),
       
   648                       (multi, th)) :: rest)) ths
   629             #> snd
   649             #> snd
   630           end)
   650           end)
   631   in
   651   in
   632     [] |> add_valid_facts fold local_facts (unnamed_locals @ named_locals)
   652     [] |> add_facts false fold local_facts (unnamed_locals @ named_locals)
   633        |> add_valid_facts Facts.fold_static global_facts global_facts
   653        |> add_facts true Facts.fold_static global_facts global_facts
   634   end
   654   end
   635 
   655 
   636 (* The single-name theorems go after the multiple-name ones, so that single
   656 (* The single-name theorems go after the multiple-name ones, so that single
   637    names are preferred when both are available. *)
   657    names are preferred when both are available. *)
   638 fun name_thm_pairs ctxt respect_no_atp =
   658 fun name_thm_pairs ctxt respect_no_atp =
   651                                 1.0 / Real.fromInt (max_relevant + 1))
   671                                 1.0 / Real.fromInt (max_relevant + 1))
   652     val add_thms = maps (ProofContext.get_fact ctxt) add
   672     val add_thms = maps (ProofContext.get_fact ctxt) add
   653     val reserved = reserved_isar_keyword_table ()
   673     val reserved = reserved_isar_keyword_table ()
   654     val axioms =
   674     val axioms =
   655       (if only then
   675       (if only then
   656          maps (name_thm_pairs_from_ref ctxt reserved chained_ths) add
   676          maps (map (fn ((name, loc), th) => ((K name, loc), (true, th)))
       
   677                o name_thm_pairs_from_ref ctxt reserved chained_ths) add
   657        else
   678        else
   658          all_name_thms_pairs ctxt reserved full_types add_thms chained_ths)
   679          all_name_thms_pairs ctxt reserved full_types add_thms chained_ths)
   659       |> name_thm_pairs ctxt (respect_no_atp andalso not only)
   680       |> name_thm_pairs ctxt (respect_no_atp andalso not only)
   660       |> make_unique
   681       |> make_unique
   661   in
   682   in
   666      else if threshold0 < 0.0 then
   687      else if threshold0 < 0.0 then
   667        axioms
   688        axioms
   668      else
   689      else
   669        relevance_filter ctxt threshold0 decay max_relevant theory_relevant
   690        relevance_filter ctxt threshold0 decay max_relevant theory_relevant
   670                         relevance_override axioms (concl_t :: hyp_ts))
   691                         relevance_override axioms (concl_t :: hyp_ts))
   671     |> map (apfst (fn f => f ())) |> sort_wrt (fst o fst)
   692     |> map (apfst (apfst (fn f => f ()))) |> sort_wrt (fst o fst)
   672   end
   693   end
   673 
   694 
   674 end;
   695 end;