1.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri Jul 20 22:19:46 2012 +0200
1.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML Fri Jul 20 22:19:46 2012 +0200
1.3 @@ -472,9 +472,9 @@
1.4 "Internal error when " ^ when ^ ":\n" ^
1.5 ML_Compiler.exn_message exn); def)
1.6
1.7 -type mash_state = {fact_graph : unit Graph.T}
1.8 +type mash_state = {fact_G : unit Graph.T}
1.9
1.10 -val empty_state = {fact_graph = Graph.empty}
1.11 +val empty_state = {fact_G = Graph.empty}
1.12
1.13 local
1.14
1.15 @@ -496,25 +496,25 @@
1.16 | (name, parents) =>
1.17 Graph.default_node (name, ())
1.18 #> fold (add_edge_to name) parents
1.19 - val fact_graph =
1.20 + val fact_G =
1.21 try_graph ctxt "loading state" Graph.empty (fn () =>
1.22 Graph.empty |> version' = version
1.23 ? fold add_fact_line fact_lines)
1.24 - in {fact_graph = fact_graph} end
1.25 + in {fact_G = fact_G} end
1.26 | _ => empty_state)
1.27 end
1.28
1.29 -fun save {fact_graph} =
1.30 +fun save {fact_G} =
1.31 let
1.32 val path = mash_state_path ()
1.33 fun fact_line_for name parents =
1.34 escape_meta name ^ ": " ^ escape_metas parents
1.35 val append_fact = File.append path o suffix "\n" oo fact_line_for
1.36 + fun append_entry (name, ((), (parents, _))) () =
1.37 + append_fact name (Graph.Keys.dest parents)
1.38 in
1.39 File.write path (version ^ "\n");
1.40 - Graph.fold (fn (name, ((), (parents, _))) => fn () =>
1.41 - append_fact name (Graph.Keys.dest parents))
1.42 - fact_graph ()
1.43 + Graph.fold append_entry fact_G ()
1.44 end
1.45
1.46 val global_state =
1.47 @@ -535,45 +535,52 @@
1.48 end
1.49
1.50 fun mash_could_suggest_facts () = mash_home () <> ""
1.51 -fun mash_can_suggest_facts ctxt =
1.52 - not (Graph.is_empty (#fact_graph (mash_get ctxt)))
1.53 +fun mash_can_suggest_facts ctxt = not (Graph.is_empty (#fact_G (mash_get ctxt)))
1.54
1.55 -fun parents_wrt_facts facts fact_graph =
1.56 +fun queue_of xs = Queue.empty |> fold Queue.enqueue xs
1.57 +
1.58 +fun max_facts_in_graph fact_G facts =
1.59 let
1.60 val facts = [] |> fold (cons o nickname_of o snd) facts
1.61 val tab = Symtab.empty |> fold (fn name => Symtab.update (name, ())) facts
1.62 - fun insert_not_seen seen name =
1.63 - not (member (op =) seen name) ? insert (op =) name
1.64 - fun parents_of _ parents [] = parents
1.65 - | parents_of seen parents (name :: names) =
1.66 - if Symtab.defined tab name then
1.67 - parents_of (name :: seen) (name :: parents) names
1.68 - else
1.69 - parents_of (name :: seen) parents
1.70 - (Graph.Keys.fold (insert_not_seen seen)
1.71 - (Graph.imm_preds fact_graph name) names)
1.72 - in parents_of [] [] (Graph.maximals fact_graph) end
1.73 + fun enqueue_new seen name =
1.74 + not (member (op =) seen name) ? Queue.enqueue name
1.75 + fun find_maxes seen maxs names =
1.76 + case try Queue.dequeue names of
1.77 + NONE => maxs
1.78 + | SOME (name, names) =>
1.79 + if Symtab.defined tab name then
1.80 + let
1.81 + fun no_path x y = not (member (op =) (Graph.all_preds fact_G [y]) x)
1.82 + val maxs = maxs |> filter (fn max => no_path max name)
1.83 + val maxs = maxs |> forall (no_path name) maxs ? cons name
1.84 + in find_maxes (name :: seen) maxs names end
1.85 + else
1.86 + find_maxes (name :: seen) maxs
1.87 + (Graph.Keys.fold (enqueue_new seen)
1.88 + (Graph.imm_preds fact_G name) names)
1.89 + in find_maxes [] [] (queue_of (Graph.maximals fact_G)) end
1.90
1.91 (* Generate more suggestions than requested, because some might be thrown out
1.92 later for various reasons and "meshing" gives better results with some
1.93 slack. *)
1.94 fun max_suggs_of max_facts = max_facts + Int.min (200, max_facts)
1.95
1.96 -fun is_fact_in_graph fact_graph (_, th) =
1.97 - can (Graph.get_node fact_graph) (nickname_of th)
1.98 +fun is_fact_in_graph fact_G (_, th) =
1.99 + can (Graph.get_node fact_G) (nickname_of th)
1.100
1.101 fun mash_suggest_facts ctxt ({overlord, ...} : params) prover max_facts hyp_ts
1.102 concl_t facts =
1.103 let
1.104 val thy = Proof_Context.theory_of ctxt
1.105 - val fact_graph = #fact_graph (mash_get ctxt)
1.106 - val parents = parents_wrt_facts facts fact_graph
1.107 + val fact_G = #fact_G (mash_get ctxt)
1.108 + val parents = max_facts_in_graph fact_G facts
1.109 val feats = features_of ctxt prover thy (Local, General) (concl_t :: hyp_ts)
1.110 val suggs =
1.111 - if Graph.is_empty fact_graph then []
1.112 + if Graph.is_empty fact_G then []
1.113 else mash_QUERY ctxt overlord (max_suggs_of max_facts) (parents, feats)
1.114 val selected = facts |> suggested_facts suggs
1.115 - val unknown = facts |> filter_out (is_fact_in_graph fact_graph)
1.116 + val unknown = facts |> filter_out (is_fact_in_graph fact_G)
1.117 in (selected, unknown) end
1.118
1.119 fun update_fact_graph ctxt (name, parents, feats, deps) (upds, graph) =
1.120 @@ -596,6 +603,10 @@
1.121 val desc = ("machine learner for Sledgehammer", "")
1.122 in Async_Manager.launch MaShN birth_time death_time desc task end
1.123
1.124 +fun freshish_name () =
1.125 + Date.fmt ".%Y_%m_%d_%H_%M_%S__" (Date.fromTimeLocal (Time.now ())) ^
1.126 + serial_string ()
1.127 +
1.128 fun mash_learn_proof ctxt ({overlord, timeout, ...} : params) prover t facts
1.129 used_ths =
1.130 if is_smt_prover ctxt prover then
1.131 @@ -605,18 +616,17 @@
1.132 (fn () =>
1.133 let
1.134 val thy = Proof_Context.theory_of ctxt
1.135 - val name = timestamp () ^ " " ^ serial_string () (* freshish *)
1.136 + val name = freshish_name ()
1.137 val feats = features_of ctxt prover thy (Local, General) [t]
1.138 val deps = used_ths |> map nickname_of
1.139 in
1.140 - mash_map ctxt (fn {fact_graph} =>
1.141 + mash_map ctxt (fn {fact_G} =>
1.142 let
1.143 - val parents = parents_wrt_facts facts fact_graph
1.144 + val parents = max_facts_in_graph fact_G facts
1.145 val upds = [(name, parents, feats, deps)]
1.146 - val (upds, fact_graph) =
1.147 - ([], fact_graph) |> fold (update_fact_graph ctxt) upds
1.148 - in
1.149 - mash_ADD ctxt overlord upds; {fact_graph = fact_graph}
1.150 + val (upds, fact_G) =
1.151 + ([], fact_G) |> fold (update_fact_graph ctxt) upds
1.152 + in mash_ADD ctxt overlord upds; {fact_G = fact_G}
1.153 end);
1.154 (true, "")
1.155 end)
1.156 @@ -636,10 +646,10 @@
1.157 val timer = Timer.startRealTimer ()
1.158 fun next_commit_time () =
1.159 Time.+ (Timer.checkRealTimer timer, commit_timeout)
1.160 - val {fact_graph} = mash_get ctxt
1.161 - val new_facts =
1.162 - facts |> filter_out (is_fact_in_graph fact_graph)
1.163 - |> sort (thm_ord o pairself snd)
1.164 + val {fact_G} = mash_get ctxt
1.165 + val (old_facts, new_facts) =
1.166 + facts |> List.partition (is_fact_in_graph fact_G)
1.167 + ||> sort (thm_ord o pairself snd)
1.168 val num_new_facts = length new_facts
1.169 in
1.170 (if not auto then
1.171 @@ -654,7 +664,7 @@
1.172 else
1.173 "")
1.174 |> Output.urgent_message;
1.175 - if null new_facts then
1.176 + if num_new_facts = 0 then
1.177 if not auto then
1.178 "Nothing to learn.\n\nHint: Try " ^ sendback relearn_isarN ^ " or " ^
1.179 sendback relearn_atpN ^ " to learn from scratch."
1.180 @@ -662,17 +672,21 @@
1.181 ""
1.182 else
1.183 let
1.184 - val ths = facts |> map snd
1.185 + val last_th = new_facts |> List.last |> snd
1.186 + (* crude approximation *)
1.187 + val ancestors =
1.188 + old_facts |> filter (fn (_, th) => thm_ord (th, last_th) <> GREATER)
1.189 val all_names =
1.190 - ths |> filter_out is_likely_tautology_or_too_meta
1.191 - |> map (rpair () o nickname_of)
1.192 - |> Symtab.make
1.193 + facts |> map snd
1.194 + |> filter_out is_likely_tautology_or_too_meta
1.195 + |> map (rpair () o nickname_of)
1.196 + |> Symtab.make
1.197 fun do_commit [] state = state
1.198 - | do_commit upds {fact_graph} =
1.199 + | do_commit upds {fact_G} =
1.200 let
1.201 - val (upds, fact_graph) =
1.202 - ([], fact_graph) |> fold (update_fact_graph ctxt) upds
1.203 - in mash_ADD ctxt overlord (rev upds); {fact_graph = fact_graph} end
1.204 + val (upds, fact_G) =
1.205 + ([], fact_G) |> fold (update_fact_graph ctxt) upds
1.206 + in mash_ADD ctxt overlord (rev upds); {fact_G = fact_G} end
1.207 fun trim_deps deps = if length deps > max_dependencies then [] else deps
1.208 fun commit last upds =
1.209 (if debug andalso not auto then Output.urgent_message "Committing..."
1.210 @@ -708,7 +722,7 @@
1.211 val timed_out =
1.212 Time.> (Timer.checkRealTimer timer, learn_timeout)
1.213 in (upds, ([name], n, next_commit, timed_out)) end
1.214 - val parents = parents_wrt_facts facts fact_graph
1.215 + val parents = max_facts_in_graph fact_G ancestors
1.216 val (upds, (_, n, _, _)) =
1.217 ([], (parents, 0, next_commit_time (), false))
1.218 |> fold do_fact new_facts