src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 49419 0a261b4aa093
parent 49418 1f214c653c80
child 49421 b002cc16aa99
     1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jul 20 22:19:46 2012 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Jul 20 22:19:46 2012 +0200
     1.3 @@ -41,14 +41,16 @@
     1.4      Proof.context -> params -> string -> fact list -> thm -> prover_result
     1.5    val features_of :
     1.6      Proof.context -> string -> theory -> stature -> term list -> string list
     1.7 -  val isar_dependencies_of : unit Symtab.table -> thm -> string list
     1.8 +  val isar_dependencies_of : unit Symtab.table -> thm -> string list option
     1.9    val atp_dependencies_of :
    1.10 -    Proof.context -> params -> string -> bool -> fact list -> unit Symtab.table
    1.11 -    -> thm -> string list
    1.12 +    Proof.context -> params -> string -> int -> fact list -> unit Symtab.table
    1.13 +    -> thm -> string list option
    1.14    val mash_CLEAR : Proof.context -> unit
    1.15    val mash_ADD :
    1.16      Proof.context -> bool
    1.17      -> (string * string list * string list * string list) list -> unit
    1.18 +  val mash_REPROVE :
    1.19 +    Proof.context -> bool -> (string * string list) list -> unit
    1.20    val mash_QUERY :
    1.21      Proof.context -> bool -> int -> string list * string list -> string list
    1.22    val mash_unlearn : Proof.context -> unit
    1.23 @@ -60,9 +62,6 @@
    1.24    val mash_learn_proof :
    1.25      Proof.context -> params -> string -> term -> ('a * thm) list -> thm list
    1.26      -> unit
    1.27 -  val mash_learn_facts :
    1.28 -    Proof.context -> params -> string -> bool -> bool -> Time.time -> fact list
    1.29 -    -> string
    1.30    val mash_learn :
    1.31      Proof.context -> params -> fact_override -> thm list -> bool -> unit
    1.32    val relevant_facts :
    1.33 @@ -320,14 +319,21 @@
    1.34        | Simp => cons "simp"
    1.35        | Def => cons "def")
    1.36  
    1.37 -fun isar_dependencies_of all_facts = thms_in_proof (SOME all_facts)
    1.38 +(* Too many dependencies is a sign that a decision procedure is at work. There
    1.39 +   isn't much too learn from such proofs. *)
    1.40 +val max_dependencies = 10
    1.41 +val atp_dependency_default_max_fact = 50
    1.42  
    1.43 -val atp_dep_default_max_fact = 50
    1.44 +fun trim_dependencies deps =
    1.45 +  if length deps <= max_dependencies then SOME deps else NONE
    1.46  
    1.47 -fun atp_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover auto
    1.48 -                        facts all_names th =
    1.49 +fun isar_dependencies_of all_facts =
    1.50 +  thms_in_proof (SOME all_facts) #> trim_dependencies
    1.51 +
    1.52 +fun atp_dependencies_of ctxt (params as {verbose, max_facts, ...}) prover
    1.53 +                        auto_level facts all_names th =
    1.54    case isar_dependencies_of all_names th of
    1.55 -    [] => []
    1.56 +    SOME [] => NONE
    1.57    | isar_deps =>
    1.58      let
    1.59        val thy = Proof_Context.theory_of ctxt
    1.60 @@ -344,12 +350,12 @@
    1.61          | NONE => accum (* shouldn't happen *)
    1.62        val facts =
    1.63          facts |> iterative_relevant_facts ctxt params prover
    1.64 -                     (max_facts |> the_default atp_dep_default_max_fact) NONE
    1.65 -                     hyp_ts concl_t
    1.66 -              |> fold (add_isar_dep facts) isar_deps
    1.67 +                     (max_facts |> the_default atp_dependency_default_max_fact)
    1.68 +                     NONE hyp_ts concl_t
    1.69 +              |> fold (add_isar_dep facts) (these isar_deps)
    1.70                |> map fix_name
    1.71      in
    1.72 -      if verbose andalso not auto then
    1.73 +      if verbose andalso auto_level = 0 then
    1.74          let val num_facts = length facts in
    1.75            "MaSh: " ^ quote prover ^ " on " ^ quote (nickname_of th) ^
    1.76            " with " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
    1.77 @@ -360,7 +366,7 @@
    1.78          ();
    1.79        case run_prover_for_mash ctxt params prover facts goal of
    1.80          {outcome = NONE, used_facts, ...} =>
    1.81 -        (if verbose andalso not auto then
    1.82 +        (if verbose andalso auto_level = 0 then
    1.83             let val num_facts = length used_facts in
    1.84               "Found proof with " ^ string_of_int num_facts ^ " fact" ^
    1.85               plural_s num_facts ^ "."
    1.86 @@ -368,8 +374,8 @@
    1.87             end
    1.88           else
    1.89             ();
    1.90 -         used_facts |> map fst)
    1.91 -      | _ => isar_deps
    1.92 +         used_facts |> map fst |> trim_dependencies)
    1.93 +      | _ => NONE
    1.94      end
    1.95  
    1.96  
    1.97 @@ -418,10 +424,13 @@
    1.98                                 [err_file, sugg_file, cmd_file])
    1.99    end
   1.100  
   1.101 -fun str_of_update (name, parents, feats, deps) =
   1.102 +fun str_of_add (name, parents, feats, deps) =
   1.103    "! " ^ escape_meta name ^ ": " ^ escape_metas parents ^ "; " ^
   1.104    escape_metas feats ^ "; " ^ escape_metas deps ^ "\n"
   1.105  
   1.106 +fun str_of_reprove (name, deps) =
   1.107 +  "p " ^ escape_meta name ^ ": " ^ escape_metas deps ^ "\n"
   1.108 +
   1.109  fun str_of_query (parents, feats) =
   1.110    "? " ^ escape_metas parents ^ "; " ^ escape_metas feats
   1.111  
   1.112 @@ -435,10 +444,16 @@
   1.113    end
   1.114  
   1.115  fun mash_ADD _ _ [] = ()
   1.116 -  | mash_ADD ctxt overlord upds =
   1.117 +  | mash_ADD ctxt overlord adds =
   1.118      (trace_msg ctxt (fn () => "MaSh ADD " ^
   1.119 -         elide_string 1000 (space_implode " " (map #1 upds)));
   1.120 -     run_mash_tool ctxt overlord true 0 (upds, str_of_update) (K ()))
   1.121 +         elide_string 1000 (space_implode " " (map #1 adds)));
   1.122 +     run_mash_tool ctxt overlord true 0 (adds, str_of_add) (K ()))
   1.123 +
   1.124 +fun mash_REPROVE _ _ [] = ()
   1.125 +  | mash_REPROVE ctxt overlord reps =
   1.126 +    (trace_msg ctxt (fn () => "MaSh REPROVE " ^
   1.127 +         elide_string 1000 (space_implode " " (map #1 reps)));
   1.128 +     run_mash_tool ctxt overlord true 0 (reps, str_of_reprove) (K ()))
   1.129  
   1.130  fun mash_QUERY ctxt overlord max_suggs (query as (_, feats)) =
   1.131    (trace_msg ctxt (fn () => "MaSh QUERY " ^ space_implode " " feats);
   1.132 @@ -584,7 +599,7 @@
   1.133      val unknown = facts |> filter_out (is_fact_in_graph fact_G)
   1.134    in (selected, unknown) end
   1.135  
   1.136 -fun update_fact_graph ctxt (name, parents, feats, deps) (upds, graph) =
   1.137 +fun add_to_fact_graph ctxt (name, parents, feats, deps) (adds, graph) =
   1.138    let
   1.139      fun maybe_add_from from (accum as (parents, graph)) =
   1.140        try_graph ctxt "updating graph" accum (fn () =>
   1.141 @@ -592,7 +607,7 @@
   1.142      val graph = graph |> Graph.default_node (name, ())
   1.143      val (parents, graph) = ([], graph) |> fold maybe_add_from parents
   1.144      val (deps, graph) = ([], graph) |> fold maybe_add_from deps
   1.145 -  in ((name, parents, feats, deps) :: upds, graph) end
   1.146 +  in ((name, parents, feats, deps) :: adds, graph) end
   1.147  
   1.148  val learn_timeout_slack = 2.0
   1.149  
   1.150 @@ -628,14 +643,11 @@
   1.151  fun sendback sub =
   1.152    Markup.markup Isabelle_Markup.sendback (sledgehammerN ^ " " ^ sub)
   1.153  
   1.154 -(* Too many dependencies is a sign that a decision procedure is at work. There
   1.155 -   isn't much too learn from such proofs. *)
   1.156 -val max_dependencies = 10
   1.157  val commit_timeout = seconds 30.0
   1.158  
   1.159  (* The timeout is understood in a very slack fashion. *)
   1.160 -fun mash_learn_facts ctxt (params as {debug, verbose, overlord, timeout, ...})
   1.161 -                     prover auto atp learn_timeout facts =
   1.162 +fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover
   1.163 +                     auto_level atp learn_timeout facts =
   1.164    let
   1.165      val timer = Timer.startRealTimer ()
   1.166      fun next_commit_time () =
   1.167 @@ -644,86 +656,123 @@
   1.168      val (old_facts, new_facts) =
   1.169        facts |> List.partition (is_fact_in_graph fact_G)
   1.170              ||> sort (thm_ord o pairself snd)
   1.171 -    val num_new_facts = length new_facts
   1.172    in
   1.173 -    (if not auto then
   1.174 -       "MaShing" ^
   1.175 -       (if not auto then
   1.176 -          " " ^ string_of_int num_new_facts ^ " fact" ^
   1.177 -          plural_s num_new_facts ^
   1.178 -          (if atp then " (ATP timeout: " ^ string_from_time timeout ^ ")"
   1.179 -           else "")
   1.180 -        else
   1.181 -          "") ^ "..."
   1.182 -     else
   1.183 -       "")
   1.184 -    |> Output.urgent_message;
   1.185 -    if num_new_facts = 0 then
   1.186 -      if not auto then
   1.187 -        "Nothing to learn.\n\nHint: Try " ^ sendback relearn_isarN ^ " or " ^
   1.188 -        sendback relearn_atpN ^ " to learn from scratch."
   1.189 +    if null new_facts andalso (not atp orelse null old_facts) then
   1.190 +      if auto_level < 2 then
   1.191 +        "No new " ^ (if atp then "ATP" else "Isar") ^ " proofs to learn." ^
   1.192 +        (if auto_level = 0 andalso not atp then
   1.193 +           "\n\nHint: Try " ^ sendback learn_atpN ^ " to learn from ATP proofs."
   1.194 +         else
   1.195 +           "")
   1.196        else
   1.197          ""
   1.198      else
   1.199        let
   1.200 -        val last_th = new_facts |> List.last |> snd
   1.201 -        (* crude approximation *)
   1.202 -        val ancestors =
   1.203 -          old_facts |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
   1.204          val all_names =
   1.205            facts |> map snd
   1.206                  |> filter_out is_likely_tautology_or_too_meta
   1.207                  |> map (rpair () o nickname_of)
   1.208                  |> Symtab.make
   1.209 -        fun do_commit [] state = state
   1.210 -          | do_commit upds {fact_G} =
   1.211 +        val deps_of =
   1.212 +          if atp then
   1.213 +            atp_dependencies_of ctxt params prover auto_level facts all_names
   1.214 +          else
   1.215 +            isar_dependencies_of all_names
   1.216 +        fun do_commit [] [] state = state
   1.217 +          | do_commit adds reps {fact_G} =
   1.218              let
   1.219 -              val (upds, fact_G) =
   1.220 -                ([], fact_G) |> fold (update_fact_graph ctxt) upds
   1.221 -            in mash_ADD ctxt overlord (rev upds); {fact_G = fact_G} end
   1.222 -        fun trim_deps deps = if length deps > max_dependencies then [] else deps
   1.223 -        fun commit last upds =
   1.224 -          (if debug andalso not auto then Output.urgent_message "Committing..."
   1.225 -           else ();
   1.226 -           mash_map ctxt (do_commit (rev upds));
   1.227 -           if not last andalso not auto then
   1.228 -             let val num_upds = length upds in
   1.229 -               "Processed " ^ string_of_int num_upds ^ " fact" ^
   1.230 -               plural_s num_upds ^ " in the last " ^
   1.231 +              val (adds, fact_G) =
   1.232 +                ([], fact_G) |> fold (add_to_fact_graph ctxt) adds
   1.233 +            in
   1.234 +              mash_ADD ctxt overlord (rev adds);
   1.235 +              mash_REPROVE ctxt overlord reps;
   1.236 +              {fact_G = fact_G}
   1.237 +            end
   1.238 +        fun commit last adds reps =
   1.239 +          (if debug andalso auto_level = 0 then
   1.240 +             Output.urgent_message "Committing..."
   1.241 +           else
   1.242 +             ();
   1.243 +           mash_map ctxt (do_commit (rev adds) reps);
   1.244 +           if not last andalso auto_level = 0 then
   1.245 +             let val num_proofs = length adds + length reps in
   1.246 +               "Learned " ^ string_of_int num_proofs ^ " " ^
   1.247 +               (if atp then "ATP" else "Isar") ^ " proof" ^
   1.248 +               plural_s num_proofs ^ " in the last " ^
   1.249                 string_from_time commit_timeout ^ "."
   1.250                 |> Output.urgent_message
   1.251               end
   1.252             else
   1.253               ())
   1.254 -        fun do_fact _ (accum as (_, (_, _, _, true))) = accum
   1.255 -          | do_fact ((_, stature), th)
   1.256 -                    (upds, (parents, n, next_commit, false)) =
   1.257 +        fun learn_new_fact _ (accum as (_, (_, _, _, true))) = accum
   1.258 +          | learn_new_fact ((_, stature), th)
   1.259 +                           (adds, (parents, n, next_commit, _)) =
   1.260              let
   1.261                val name = nickname_of th
   1.262                val feats =
   1.263                  features_of ctxt prover (theory_of_thm th) stature [prop_of th]
   1.264 -              val deps =
   1.265 -                (if atp then atp_dependencies_of ctxt params prover auto facts
   1.266 -                 else isar_dependencies_of) all_names th
   1.267 -                |> trim_deps
   1.268 +              val deps = deps_of th |> these
   1.269                val n = n |> not (null deps) ? Integer.add 1
   1.270 -              val upds = (name, parents, feats, deps) :: upds
   1.271 -              val (upds, next_commit) =
   1.272 +              val adds = (name, parents, feats, deps) :: adds
   1.273 +              val (adds, next_commit) =
   1.274                  if Time.> (Timer.checkRealTimer timer, next_commit) then
   1.275 -                  (commit false upds; ([], next_commit_time ()))
   1.276 +                  (commit false adds []; ([], next_commit_time ()))
   1.277                  else
   1.278 -                  (upds, next_commit)
   1.279 -              val timed_out =
   1.280 -                Time.> (Timer.checkRealTimer timer, learn_timeout)
   1.281 -            in (upds, ([name], n, next_commit, timed_out)) end
   1.282 -        val parents = max_facts_in_graph fact_G ancestors
   1.283 -        val (upds, (_, n, _, _)) =
   1.284 -          ([], (parents, 0, next_commit_time (), false))
   1.285 -          |> fold do_fact new_facts
   1.286 +                  (adds, next_commit)
   1.287 +              val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
   1.288 +            in (adds, ([name], n, next_commit, timed_out)) end
   1.289 +        val n =
   1.290 +          if null new_facts then
   1.291 +            0
   1.292 +          else
   1.293 +            let
   1.294 +              val last_th = new_facts |> List.last |> snd
   1.295 +              (* crude approximation *)
   1.296 +              val ancestors =
   1.297 +                old_facts
   1.298 +                |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
   1.299 +              val parents = max_facts_in_graph fact_G ancestors
   1.300 +              val (adds, (_, n, _, _)) =
   1.301 +                ([], (parents, 0, next_commit_time (), false))
   1.302 +                |> fold learn_new_fact new_facts
   1.303 +            in commit true adds []; n end
   1.304 +        fun relearn_old_fact _ (accum as (_, (_, _, true))) = accum
   1.305 +          | relearn_old_fact (_, th) (reps, (n, next_commit, _)) =
   1.306 +            let
   1.307 +              val name = nickname_of th
   1.308 +              val (n, reps) =
   1.309 +                case deps_of th of
   1.310 +                  SOME deps => (n + 1, (name, deps) :: reps)
   1.311 +                | NONE => (n, reps)
   1.312 +              val (reps, next_commit) =
   1.313 +                if Time.> (Timer.checkRealTimer timer, next_commit) then
   1.314 +                  (commit false [] reps; ([], next_commit_time ()))
   1.315 +                else
   1.316 +                  (reps, next_commit)
   1.317 +              val timed_out = Time.> (Timer.checkRealTimer timer, learn_timeout)
   1.318 +            in (reps, (n, next_commit, timed_out)) end
   1.319 +        val n =
   1.320 +          if null old_facts then
   1.321 +            n
   1.322 +          else
   1.323 +            let
   1.324 +              fun score_of (_, th) =
   1.325 +                random_range 0 (1000 * max_dependencies)
   1.326 +                - 500 * (th |> isar_dependencies_of all_names
   1.327 +                            |> Option.map length
   1.328 +                            |> the_default max_dependencies)
   1.329 +              val old_facts =
   1.330 +                old_facts |> map (`score_of)
   1.331 +                          |> sort (int_ord o pairself fst)
   1.332 +                          |> map snd
   1.333 +              val (reps, (n, _, _)) =
   1.334 +                ([], (n, next_commit_time (), false))
   1.335 +                |> fold relearn_old_fact old_facts
   1.336 +            in commit true [] reps; n end
   1.337        in
   1.338 -        commit true upds;
   1.339 -        if verbose orelse not auto then
   1.340 -          "Learned " ^ string_of_int n ^ " nontrivial proof" ^ plural_s n ^
   1.341 +        if verbose orelse auto_level < 2 then
   1.342 +          "Learned " ^ string_of_int n ^ " nontrivial " ^
   1.343 +          (if atp then "ATP" else "Isar") ^ " proof" ^ plural_s n ^
   1.344            (if verbose then
   1.345               " in " ^ string_from_time (Timer.checkRealTimer timer)
   1.346             else
   1.347 @@ -733,16 +782,35 @@
   1.348        end
   1.349    end
   1.350  
   1.351 -fun mash_learn ctxt (params as {provers, ...}) fact_override chained atp =
   1.352 +fun mash_learn ctxt (params as {provers, timeout, ...}) fact_override chained
   1.353 +               atp =
   1.354    let
   1.355      val css = Sledgehammer_Fact.clasimpset_rule_table_of ctxt
   1.356      val ctxt = ctxt |> Config.put instantiate_inducts false
   1.357      val facts =
   1.358        nearly_all_facts ctxt false fact_override Symtab.empty css chained []
   1.359                         @{prop True}
   1.360 +    val num_facts = length facts
   1.361 +    val prover = hd provers
   1.362 +    fun learn auto_level atp =
   1.363 +      mash_learn_facts ctxt params prover auto_level atp infinite_timeout facts
   1.364 +      |> Output.urgent_message
   1.365    in
   1.366 -     mash_learn_facts ctxt params (hd provers) false atp infinite_timeout facts
   1.367 -     |> Output.urgent_message
   1.368 +    (if atp then
   1.369 +       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
   1.370 +        plural_s num_facts ^ " for ATP proofs (" ^ quote prover ^ " timeout: " ^
   1.371 +        string_from_time timeout ^ ").\n\nCollecting Isar proofs first..."
   1.372 +        |> Output.urgent_message;
   1.373 +        learn 1 false;
   1.374 +        "Now collecting ATP proofs. This may take several hours. You can \
   1.375 +        \safely stop the learning process at any point."
   1.376 +        |> Output.urgent_message;
   1.377 +        learn 0 true)
   1.378 +     else
   1.379 +       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
   1.380 +        plural_s num_facts ^ " for Isar proofs..."
   1.381 +        |> Output.urgent_message;
   1.382 +        learn 0 false))
   1.383    end
   1.384  
   1.385  (* The threshold should be large enough so that MaSh doesn't kick in for Auto
   1.386 @@ -764,7 +832,7 @@
   1.387             Time.toSeconds timeout >= min_secs_for_learning then
   1.388            let val timeout = time_mult learn_timeout_slack timeout in
   1.389              launch_thread timeout
   1.390 -                (fn () => (true, mash_learn_facts ctxt params prover true false
   1.391 +                (fn () => (true, mash_learn_facts ctxt params prover 2 false
   1.392                                                    timeout facts))
   1.393            end
   1.394          else