merged
authorblanchet
Fri, 17 Dec 2010 21:32:06 +0100
changeset 41503a47133170dd0
parent 41500 4ae674714876
parent 41502 0e7d45cc005f
child 41504 73401632a80c
child 41546 aad679ca38d2
merged
     1.1 --- a/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML	Fri Dec 17 18:38:33 2010 +0100
     1.2 +++ b/src/HOL/Mirabelle/Tools/mirabelle_sledgehammer.ML	Fri Dec 17 21:32:06 2010 +0100
     1.3 @@ -577,6 +577,7 @@
     1.4  
     1.5  fun invoke args =
     1.6    let
     1.7 +    val _ = Sledgehammer_Run.show_facts_in_proofs := true
     1.8      val _ = Sledgehammer_Isar.full_types := AList.defined (op =) args full_typesK
     1.9    in Mirabelle.register (init, sledgehammer_action args, done) end
    1.10  
     2.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_minimize.ML	Fri Dec 17 18:38:33 2010 +0100
     2.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_minimize.ML	Fri Dec 17 21:32:06 2010 +0100
     2.3 @@ -10,9 +10,7 @@
     2.4    type locality = Sledgehammer_Filter.locality
     2.5    type params = Sledgehammer_Provers.params
     2.6  
     2.7 -  val filter_used_facts :
     2.8 -    (string * locality) list -> ((string * locality) * thm list) list
     2.9 -    -> ((string * locality) * thm list) list
    2.10 +  val filter_used_facts : ''a list -> (''a * 'b) list -> (''a * 'b) list
    2.11    val minimize_facts :
    2.12      params -> bool -> int -> int -> Proof.state
    2.13      -> ((string * locality) * thm list) list
     3.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Fri Dec 17 18:38:33 2010 +0100
     3.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML	Fri Dec 17 21:32:06 2010 +0100
     3.3 @@ -53,6 +53,12 @@
     3.4    type prover = params -> minimize_command -> prover_problem -> prover_result
     3.5  
     3.6    (* for experimentation purposes -- do not use in production code *)
     3.7 +  val smt_weights : bool Unsynchronized.ref
     3.8 +  val smt_weight_min_facts : int Unsynchronized.ref
     3.9 +  val smt_min_weight : int Unsynchronized.ref
    3.10 +  val smt_max_weight : int Unsynchronized.ref
    3.11 +  val smt_max_weight_index : int Unsynchronized.ref
    3.12 +  val smt_weight_curve : (int -> int) Unsynchronized.ref
    3.13    val smt_max_iters : int Unsynchronized.ref
    3.14    val smt_iter_fact_frac : real Unsynchronized.ref
    3.15    val smt_iter_time_frac : real Unsynchronized.ref
    3.16 @@ -73,9 +79,13 @@
    3.17    val dest_dir : string Config.T
    3.18    val problem_prefix : string Config.T
    3.19    val measure_run_time : bool Config.T
    3.20 +  val weight_smt_fact :
    3.21 +    theory -> int -> ((string * locality) * thm) * int
    3.22 +    -> (string * locality) * (int option * thm)
    3.23    val untranslated_fact : prover_fact -> (string * locality) * thm
    3.24    val smt_weighted_fact :
    3.25 -    prover_fact -> (string * locality) * (int option * thm)
    3.26 +    theory -> int -> prover_fact * int
    3.27 +    -> (string * locality) * (int option * thm)
    3.28    val available_provers : Proof.context -> unit
    3.29    val kill_provers : unit -> unit
    3.30    val running_provers : unit -> unit
    3.31 @@ -277,7 +287,26 @@
    3.32  fun proof_banner auto =
    3.33    if auto then "Auto Sledgehammer found a proof" else "Try this command"
    3.34  
    3.35 -(* generic TPTP-based ATPs *)
    3.36 +val smt_weights = Unsynchronized.ref true
    3.37 +val smt_weight_min_facts = Unsynchronized.ref 20
    3.38 +
    3.39 +(* FUDGE *)
    3.40 +val smt_min_weight = Unsynchronized.ref 0
    3.41 +val smt_max_weight = Unsynchronized.ref 10
    3.42 +val smt_max_weight_index = Unsynchronized.ref 200
    3.43 +val smt_weight_curve = Unsynchronized.ref (fn x : int => x * x)
    3.44 +
    3.45 +fun smt_fact_weight j num_facts =
    3.46 +  if !smt_weights andalso num_facts >= !smt_weight_min_facts then
    3.47 +    SOME (!smt_max_weight
    3.48 +          - (!smt_max_weight - !smt_min_weight + 1)
    3.49 +            * !smt_weight_curve (Int.max (0, !smt_max_weight_index - j - 1))
    3.50 +            div !smt_weight_curve (!smt_max_weight_index))
    3.51 +  else
    3.52 +    NONE
    3.53 +
    3.54 +fun weight_smt_fact thy num_facts ((info, th), j) =
    3.55 +  (info, (smt_fact_weight j num_facts, th |> Thm.transfer thy))
    3.56  
    3.57  fun untranslated_fact (Untranslated_Fact p) = p
    3.58    | untranslated_fact (ATP_Translated_Fact (_, p)) = p
    3.59 @@ -285,8 +314,11 @@
    3.60  fun atp_translated_fact _ (ATP_Translated_Fact p) = p
    3.61    | atp_translated_fact ctxt fact =
    3.62      translate_atp_fact ctxt (untranslated_fact fact)
    3.63 -fun smt_weighted_fact (SMT_Weighted_Fact p) = p
    3.64 -  | smt_weighted_fact fact = untranslated_fact fact |> apsnd (pair NONE)
    3.65 +fun smt_weighted_fact _ _ (SMT_Weighted_Fact p, _) = p
    3.66 +  | smt_weighted_fact thy num_facts (fact, j) =
    3.67 +    (untranslated_fact fact, j) |> weight_smt_fact thy num_facts
    3.68 +
    3.69 +(* generic TPTP-based ATPs *)
    3.70  
    3.71  fun int_opt_add (SOME m) (SOME n) = SOME (m + n)
    3.72    | int_opt_add _ _ = NONE
    3.73 @@ -602,7 +634,10 @@
    3.74           : prover_problem) =
    3.75    let
    3.76      val ctxt = Proof.context_of state
    3.77 -    val facts = facts |> map smt_weighted_fact
    3.78 +    val thy = Proof.theory_of state
    3.79 +    val num_facts = length facts
    3.80 +    val facts = facts ~~ (0 upto num_facts - 1)
    3.81 +                |> map (smt_weighted_fact thy num_facts)
    3.82      val {outcome, used_facts, run_time_in_msecs} =
    3.83        smt_filter_loop name params state subgoal smt_head facts
    3.84      val (chained_lemmas, other_lemmas) = split_used_facts (map fst used_facts)
     4.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_run.ML	Fri Dec 17 18:38:33 2010 +0100
     4.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_run.ML	Fri Dec 17 21:32:06 2010 +0100
     4.3 @@ -13,12 +13,7 @@
     4.4    type params = Sledgehammer_Provers.params
     4.5  
     4.6    (* for experimentation purposes -- do not use in production code *)
     4.7 -  val smt_weights : bool Unsynchronized.ref
     4.8 -  val smt_weight_min_facts : int Unsynchronized.ref
     4.9 -  val smt_min_weight : int Unsynchronized.ref
    4.10 -  val smt_max_weight : int Unsynchronized.ref
    4.11 -  val smt_max_weight_index : int Unsynchronized.ref
    4.12 -  val smt_weight_curve : (int -> int) Unsynchronized.ref
    4.13 +  val show_facts_in_proofs : bool Unsynchronized.ref
    4.14  
    4.15    val run_sledgehammer :
    4.16      params -> bool -> int -> relevance_override -> (string -> minimize_command)
    4.17 @@ -47,6 +42,8 @@
    4.18     else
    4.19       "\n" ^ Syntax.string_of_term ctxt (Thm.term_of (Thm.cprem_of goal i)))
    4.20  
    4.21 +val show_facts_in_proofs = Unsynchronized.ref false
    4.22 +
    4.23  val implicit_minimization_threshold = 50
    4.24  
    4.25  fun run_prover (params as {debug, blocking, max_relevant, timeout, expect, ...})
    4.26 @@ -66,22 +63,38 @@
    4.27        {state = state, goal = goal, subgoal = subgoal,
    4.28         subgoal_count = subgoal_count, facts = take num_facts facts,
    4.29         smt_head = smt_head}
    4.30 +    fun really_go () =
    4.31 +      prover params (minimize_command name) problem
    4.32 +      |> (fn {outcome, used_facts, message, ...} =>
    4.33 +             if is_some outcome then
    4.34 +               ("none", message)
    4.35 +             else
    4.36 +               let
    4.37 +                 val (used_facts, message) =
    4.38 +                   if length used_facts >= implicit_minimization_threshold then
    4.39 +                     minimize_facts params true subgoal subgoal_count state
    4.40 +                         (filter_used_facts used_facts
    4.41 +                              (map (apsnd single o untranslated_fact) facts))
    4.42 +                     |>> Option.map (map fst)
    4.43 +                   else
    4.44 +                     (SOME used_facts, message)
    4.45 +                 val _ =
    4.46 +                   case (debug orelse !show_facts_in_proofs, used_facts) of
    4.47 +                     (true, SOME (used_facts as _ :: _)) =>
    4.48 +                     facts ~~ (0 upto length facts - 1)
    4.49 +                     |> map (fn (fact, j) =>
    4.50 +                                fact |> untranslated_fact |> apsnd (K j))
    4.51 +                     |> filter_used_facts used_facts
    4.52 +                     |> map (fn ((name, _), j) => name ^ "@" ^ string_of_int j)
    4.53 +                     |> commas
    4.54 +                     |> enclose ("Fact" ^ plural_s num_facts ^ " in " ^
    4.55 +                                 quote name ^ " proof (of " ^
    4.56 +                                 string_of_int num_facts ^ "): ") "."
    4.57 +                     |> Output.urgent_message
    4.58 +                   | _ => ()
    4.59 +               in ("some", message) end)
    4.60      fun go () =
    4.61        let
    4.62 -        fun really_go () =
    4.63 -          prover params (minimize_command name) problem
    4.64 -          |> (fn {outcome, used_facts, message, ...} =>
    4.65 -                 if is_some outcome then
    4.66 -                   ("none", message)
    4.67 -                 else
    4.68 -                   ("some",
    4.69 -                    if length used_facts >= implicit_minimization_threshold then
    4.70 -                      minimize_facts params true subgoal subgoal_count state
    4.71 -                          (filter_used_facts used_facts
    4.72 -                               (map (apsnd single o untranslated_fact) facts))
    4.73 -                      |> snd
    4.74 -                    else
    4.75 -                      message))
    4.76          val (outcome_code, message) =
    4.77            if debug then
    4.78              really_go ()
    4.79 @@ -124,27 +137,6 @@
    4.80         (false, state))
    4.81    end
    4.82  
    4.83 -val smt_weights = Unsynchronized.ref true
    4.84 -val smt_weight_min_facts = Unsynchronized.ref 20
    4.85 -
    4.86 -(* FUDGE *)
    4.87 -val smt_min_weight = Unsynchronized.ref 0
    4.88 -val smt_max_weight = Unsynchronized.ref 10
    4.89 -val smt_max_weight_index = Unsynchronized.ref 200
    4.90 -val smt_weight_curve = Unsynchronized.ref (fn x : int => x * x)
    4.91 -
    4.92 -fun smt_fact_weight j num_facts =
    4.93 -  if !smt_weights andalso num_facts >= !smt_weight_min_facts then
    4.94 -    SOME (!smt_max_weight
    4.95 -          - (!smt_max_weight - !smt_min_weight + 1)
    4.96 -            * !smt_weight_curve (Int.max (0, !smt_max_weight_index - j - 1))
    4.97 -            div !smt_weight_curve (!smt_max_weight_index))
    4.98 -  else
    4.99 -    NONE
   4.100 -
   4.101 -fun weight_smt_fact thy num_facts (fact, j) =
   4.102 -  fact |> apsnd (pair (smt_fact_weight j num_facts) o Thm.transfer thy)
   4.103 -
   4.104  fun class_of_smt_solver ctxt name =
   4.105    ctxt |> select_smt_solver name
   4.106         |> SMT_Config.solver_class_of |> SMT_Utils.string_of_class
   4.107 @@ -154,6 +146,9 @@
   4.108    | smart_par_list_map f [x] = [f x]
   4.109    | smart_par_list_map f xs = Par_List.map f xs
   4.110  
   4.111 +fun dest_SMT_Weighted_Fact (SMT_Weighted_Fact p) = p
   4.112 +  | dest_SMT_Weighted_Fact _ = raise Fail "dest_SMT_Weighted_Fact"
   4.113 +
   4.114  (* FUDGE *)
   4.115  val auto_max_relevant_divisor = 2
   4.116  
   4.117 @@ -181,32 +176,29 @@
   4.118                | NONE => ()
   4.119        val _ = if auto then () else Output.urgent_message "Sledgehammering..."
   4.120        val (smts, atps) = provers |> List.partition (is_smt_prover ctxt)
   4.121 -      fun run_provers get_facts translate maybe_smt_head provers
   4.122 -                      (res as (success, state)) =
   4.123 -        if success orelse null provers then
   4.124 -          res
   4.125 -        else
   4.126 -          let
   4.127 -            val facts = get_facts ()
   4.128 -            val num_facts = length facts
   4.129 -            val facts = facts ~~ (0 upto num_facts - 1)
   4.130 -                        |> map (translate num_facts)
   4.131 -            val problem =
   4.132 -              {state = state, goal = goal, subgoal = i, subgoal_count = n,
   4.133 -               facts = facts,
   4.134 -               smt_head = maybe_smt_head (map smt_weighted_fact facts) i}
   4.135 -            val run_prover = run_prover params auto minimize_command only
   4.136 -          in
   4.137 -            if auto then
   4.138 -              fold (fn prover => fn (true, state) => (true, state)
   4.139 -                                  | (false, _) => run_prover problem prover)
   4.140 -                   provers (false, state)
   4.141 -            else
   4.142 -              provers
   4.143 -              |> (if blocking then smart_par_list_map else map)
   4.144 -                     (run_prover problem)
   4.145 -              |> exists fst |> rpair state
   4.146 -          end
   4.147 +      fun run_provers state get_facts translate maybe_smt_head provers =
   4.148 +        let
   4.149 +          val facts = get_facts ()
   4.150 +          val num_facts = length facts
   4.151 +          val facts = facts ~~ (0 upto num_facts - 1)
   4.152 +                      |> map (translate num_facts)
   4.153 +          val problem =
   4.154 +            {state = state, goal = goal, subgoal = i, subgoal_count = n,
   4.155 +             facts = facts,
   4.156 +             smt_head = maybe_smt_head
   4.157 +                  (fn () => map_filter (try dest_SMT_Weighted_Fact) facts) i}
   4.158 +          val run_prover = run_prover params auto minimize_command only
   4.159 +        in
   4.160 +          if auto then
   4.161 +            fold (fn prover => fn (true, state) => (true, state)
   4.162 +                                | (false, _) => run_prover problem prover)
   4.163 +                 provers (false, state)
   4.164 +          else
   4.165 +            provers
   4.166 +            |> (if blocking then smart_par_list_map else map)
   4.167 +                   (run_prover problem)
   4.168 +            |> exists fst |> rpair state
   4.169 +        end
   4.170        fun get_facts label no_dangerous_types relevance_fudge provers =
   4.171          let
   4.172            val max_max_relevant =
   4.173 @@ -235,24 +227,27 @@
   4.174                       else
   4.175                         ())
   4.176          end
   4.177 -      val run_atps =
   4.178 -        run_provers
   4.179 -            (get_facts "ATP" no_dangerous_types atp_relevance_fudge o K atps)
   4.180 -            (ATP_Translated_Fact oo K (translate_atp_fact ctxt o fst))
   4.181 -            (K (K NONE)) atps
   4.182 +      fun run_atps (accum as (success, _)) =
   4.183 +        if success orelse null atps then
   4.184 +          accum
   4.185 +        else
   4.186 +          run_provers state
   4.187 +              (get_facts "ATP" no_dangerous_types atp_relevance_fudge o K atps)
   4.188 +              (ATP_Translated_Fact oo K (translate_atp_fact ctxt o fst))
   4.189 +              (K (K NONE)) atps
   4.190        fun run_smts (accum as (success, _)) =
   4.191          if success orelse null smts then
   4.192            accum
   4.193          else
   4.194            let
   4.195              val facts = get_facts "SMT solver" true smt_relevance_fudge smts
   4.196 -            val translate = SMT_Weighted_Fact oo weight_smt_fact thy
   4.197 -            val maybe_smt_head = try o SMT_Solver.smt_filter_head state
   4.198 +            val weight = SMT_Weighted_Fact oo weight_smt_fact thy
   4.199 +            fun smt_head facts =
   4.200 +              try (SMT_Solver.smt_filter_head state (facts ()))
   4.201            in
   4.202              smts |> map (`(class_of_smt_solver ctxt))
   4.203                   |> AList.group (op =)
   4.204 -                 |> map (fn (_, smts) => run_provers (K facts) translate
   4.205 -                                                     maybe_smt_head smts accum)
   4.206 +                 |> map (run_provers state (K facts) weight smt_head o snd)
   4.207                   |> exists fst |> rpair state
   4.208            end
   4.209        fun run_atps_and_smt_solvers () =