learn new facts on query if there aren't too many of them in MaSh
authorblanchet
Fri, 23 Aug 2013 13:30:25 +0200
changeset 54289cbd3c7c48d2c
parent 54288 fbf4d50dec91
child 54290 1e9735cd27aa
learn new facts on query if there aren't too many of them in MaSh
src/HOL/Tools/Sledgehammer/MaSh/src/mash.py
src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
     1.1 --- a/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Fri Aug 23 00:12:20 2013 +0200
     1.2 +++ b/src/HOL/Tools/Sledgehammer/MaSh/src/mash.py	Fri Aug 23 13:30:25 2013 +0200
     1.3 @@ -103,30 +103,27 @@
     1.4          received = communicate(data,args.host,args.port)
     1.5          logger.info(received)     
     1.6      
     1.7 -    if args.inputFile == None:
     1.8 -        return
     1.9 -    logger.debug('Using the following settings: %s',args)
    1.10 -    # IO Streams
    1.11 -    OS = open(args.predictions,'w')
    1.12 -    IS = open(args.inputFile,'r')
    1.13 -    lineCount = 0
    1.14 -    for line in IS:
    1.15 -        lineCount += 1
    1.16 -        if lineCount % 100 == 0:
    1.17 -            logger.info('On line %s', lineCount)
    1.18 -        #if lineCount == 50: ###
    1.19 -        #    break       
    1.20 -        received = communicate(line,args.host,args.port)
    1.21 -        if not received == '':
    1.22 -            OS.write('%s\n' % received)
    1.23 -    OS.close()
    1.24 -    IS.close()
    1.25 +    if not args.inputFile == None:
    1.26 +        logger.debug('Using the following settings: %s',args)
    1.27 +        # IO Streams
    1.28 +        OS = open(args.predictions,'w')
    1.29 +        IS = open(args.inputFile,'r')
    1.30 +        lineCount = 0
    1.31 +        for line in IS:
    1.32 +            lineCount += 1
    1.33 +            if lineCount % 100 == 0:
    1.34 +                logger.info('On line %s', lineCount)
    1.35 +            received = communicate(line,args.host,args.port)
    1.36 +            if not received == '':
    1.37 +                OS.write('%s\n' % received)
    1.38 +        OS.close()
    1.39 +        IS.close()
    1.40  
    1.41      # Statistics
    1.42      if args.statistics:
    1.43          received = communicate('avgStats',args.host,args.port)
    1.44          logger.info(received)
    1.45 -    elif args.saveModels:
    1.46 +    if args.saveModels:
    1.47          communicate('save',args.host,args.port)
    1.48  
    1.49  
     2.1 --- a/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Aug 23 00:12:20 2013 +0200
     2.2 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML	Fri Aug 23 13:30:25 2013 +0200
     2.3 @@ -151,7 +151,7 @@
     2.4     xs |> chunk_list 500 |> List.app (File.append path o implode o map f))
     2.5    handle IO.Io _ => ()
     2.6  
     2.7 -fun run_mash_tool ctxt overlord extra_args write_cmds read_suggs =
     2.8 +fun run_mash_tool ctxt overlord extra_args background write_cmds read_suggs =
     2.9    let
    2.10      val (temp_dir, serial) =
    2.11        if overlord then (getenv "ISABELLE_HOME_USER", "")
    2.12 @@ -172,7 +172,8 @@
    2.13        " --dictsFile=" ^ model_dir ^ "/dict.pickle" ^
    2.14        " --log " ^ log_file ^ " " ^ core ^
    2.15        (if extra_args = [] then "" else " " ^ space_implode " " extra_args) ^
    2.16 -      " >& " ^ err_file
    2.17 +      " >& " ^ err_file ^
    2.18 +      (if background then " &" else "")
    2.19      fun run_on () =
    2.20        (Isabelle_System.bash command
    2.21         |> tap (fn _ => trace_msg ctxt (fn () =>
    2.22 @@ -254,7 +255,10 @@
    2.23  struct
    2.24  
    2.25  fun shutdown ctxt overlord =
    2.26 -  run_mash_tool ctxt overlord [shutdown_server_arg] ([], K "") (K ())
    2.27 +  run_mash_tool ctxt overlord [shutdown_server_arg] true ([], K "") (K ())
    2.28 +
    2.29 +fun save ctxt overlord =
    2.30 +  run_mash_tool ctxt overlord [save_models_arg] true ([], K "") (K ())
    2.31  
    2.32  fun unlearn ctxt overlord =
    2.33    let val path = mash_model_dir () in
    2.34 @@ -270,19 +274,19 @@
    2.35    | learn ctxt overlord learns =
    2.36      (trace_msg ctxt (fn () => "MaSh learn " ^
    2.37           elide_string 1000 (space_implode " " (map #1 learns)));
    2.38 -     run_mash_tool ctxt overlord [save_models_arg] (learns, str_of_learn)
    2.39 +     run_mash_tool ctxt overlord [] false (learns, str_of_learn)
    2.40                     (K ()))
    2.41  
    2.42  fun relearn _ _ [] = ()
    2.43    | relearn ctxt overlord relearns =
    2.44      (trace_msg ctxt (fn () => "MaSh relearn " ^
    2.45           elide_string 1000 (space_implode " " (map #1 relearns)));
    2.46 -     run_mash_tool ctxt overlord [save_models_arg] (relearns, str_of_relearn)
    2.47 +     run_mash_tool ctxt overlord [] false (relearns, str_of_relearn)
    2.48                     (K ()))
    2.49  
    2.50  fun query ctxt overlord max_suggs (query as (_, _, _, feats)) =
    2.51    (trace_msg ctxt (fn () => "MaSh query " ^ encode_features feats);
    2.52 -   run_mash_tool ctxt overlord [] ([query], str_of_query max_suggs)
    2.53 +   run_mash_tool ctxt overlord [] false ([query], str_of_query max_suggs)
    2.54         (fn suggs =>
    2.55             case suggs () of
    2.56               [] => []
    2.57 @@ -359,8 +363,8 @@
    2.58       | _ => NONE)
    2.59    | _ => NONE
    2.60  
    2.61 -fun load _ _ (state as (true, _)) = state
    2.62 -  | load ctxt overlord _ =
    2.63 +fun load_state _ _ (state as (true, _)) = state
    2.64 +  | load_state ctxt overlord _ =
    2.65      let val path = mash_state_file () in
    2.66        (true,
    2.67         case try File.read_lines path of
    2.68 @@ -394,8 +398,8 @@
    2.69         | _ => empty_state)
    2.70      end
    2.71  
    2.72 -fun save _ (state as {dirty = SOME [], ...}) = state
    2.73 -  | save ctxt {access_G, num_known_facts, dirty} =
    2.74 +fun save_state _ (state as {dirty = SOME [], ...}) = state
    2.75 +  | save_state ctxt {access_G, num_known_facts, dirty} =
    2.76      let
    2.77        fun str_of_entry (name, parents, kind) =
    2.78          str_of_proof_kind kind ^ " " ^ encode_str name ^ ": " ^
    2.79 @@ -424,12 +428,13 @@
    2.80  in
    2.81  
    2.82  fun map_state ctxt overlord f =
    2.83 -  Synchronized.change global_state (load ctxt overlord ##> (f #> save ctxt))
    2.84 +  Synchronized.change global_state
    2.85 +                      (load_state ctxt overlord ##> (f #> save_state ctxt))
    2.86    handle FILE_VERSION_TOO_NEW () => ()
    2.87  
    2.88  fun peek_state ctxt overlord f =
    2.89    Synchronized.change_result global_state
    2.90 -      (perhaps (try (load ctxt overlord)) #> `snd #>> f)
    2.91 +      (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
    2.92  
    2.93  fun clear_state ctxt overlord =
    2.94    Synchronized.change global_state (fn _ =>
    2.95 @@ -1044,7 +1049,8 @@
    2.96                  used_ths |> filter (is_fact_in_graph access_G)
    2.97                           |> map nickname_of_thm
    2.98              in
    2.99 -              MaSh.learn ctxt overlord [(name, parents, feats, deps)]
   2.100 +              MaSh.learn ctxt overlord [(name, parents, feats, deps)];
   2.101 +              MaSh.save ctxt overlord
   2.102              end);
   2.103          (true, "")
   2.104        end)
   2.105 @@ -1056,7 +1062,7 @@
   2.106  
   2.107  (* The timeout is understood in a very relaxed fashion. *)
   2.108  fun mash_learn_facts ctxt (params as {debug, verbose, overlord, ...}) prover
   2.109 -                     auto_level run_prover learn_timeout facts =
   2.110 +                     save auto_level run_prover learn_timeout facts =
   2.111    let
   2.112      val timer = Timer.startRealTimer ()
   2.113      fun next_commit_time () =
   2.114 @@ -1107,6 +1113,7 @@
   2.115              in
   2.116                MaSh.learn ctxt overlord (rev learns);
   2.117                MaSh.relearn ctxt overlord relearns;
   2.118 +              if save then MaSh.save ctxt overlord else ();
   2.119                {access_G = access_G, num_known_facts = num_known_facts,
   2.120                 dirty = dirty}
   2.121              end
   2.122 @@ -1228,7 +1235,7 @@
   2.123      val num_facts = length facts
   2.124      val prover = hd provers
   2.125      fun learn auto_level run_prover =
   2.126 -      mash_learn_facts ctxt params prover auto_level run_prover NONE facts
   2.127 +      mash_learn_facts ctxt params prover true auto_level run_prover NONE facts
   2.128        |> Output.urgent_message
   2.129    in
   2.130      if run_prover then
   2.131 @@ -1261,6 +1268,8 @@
   2.132  val mepo_weight = 0.5
   2.133  val mash_weight = 0.5
   2.134  
   2.135 +val max_facts_to_learn_before_query = 100
   2.136 +
   2.137  (* The threshold should be large enough so that MaSh doesn't kick in for Auto
   2.138     Sledgehammer and Try. *)
   2.139  val min_secs_for_learning = 15
   2.140 @@ -1278,28 +1287,45 @@
   2.141      [("", [])]
   2.142    else
   2.143      let
   2.144 -      fun maybe_learn () =
   2.145 -        if learn andalso not (Async_Manager.has_running_threads MaShN) andalso
   2.146 +      fun maybe_launch_thread () =
   2.147 +        if not (Async_Manager.has_running_threads MaShN) andalso
   2.148             (timeout = NONE orelse
   2.149              Time.toSeconds (the timeout) >= min_secs_for_learning) then
   2.150            let
   2.151              val timeout = Option.map (time_mult learn_timeout_slack) timeout
   2.152            in
   2.153              launch_thread (timeout |> the_default one_day)
   2.154 -                (fn () => (true, mash_learn_facts ctxt params prover 2 false
   2.155 -                                                  timeout facts))
   2.156 +                (fn () => (true, mash_learn_facts ctxt params prover true 2
   2.157 +                                                  false timeout facts))
   2.158            end
   2.159          else
   2.160            ()
   2.161 -      val effective_fact_filter =
   2.162 +      fun maybe_learn () =
   2.163 +        if learn then
   2.164 +          let
   2.165 +            val {access_G, num_known_facts, ...} = peek_state ctxt overlord I
   2.166 +            val is_in_access_G = is_fact_in_graph access_G o snd
   2.167 +          in
   2.168 +            if length facts - num_known_facts <= max_facts_to_learn_before_query
   2.169 +               andalso length (filter_out is_in_access_G facts)
   2.170 +                       <= max_facts_to_learn_before_query then
   2.171 +              (mash_learn_facts ctxt params prover false 2 false timeout facts
   2.172 +               |> (fn "" => () | s => Output.urgent_message s);
   2.173 +               true)
   2.174 +            else
   2.175 +              (maybe_launch_thread (); false)
   2.176 +          end
   2.177 +        else
   2.178 +          false
   2.179 +      val (save, effective_fact_filter) =
   2.180          case fact_filter of
   2.181 -          SOME ff => (() |> ff <> mepoN ? maybe_learn; ff)
   2.182 +          SOME ff => (ff <> mepoN andalso maybe_learn (), ff)
   2.183          | NONE =>
   2.184            if is_mash_enabled () then
   2.185 -            (maybe_learn ();
   2.186 +            (maybe_learn (),
   2.187               if mash_can_suggest_facts ctxt overlord then meshN else mepoN)
   2.188            else
   2.189 -            mepoN
   2.190 +            (false, mepoN)
   2.191        val add_ths = Attrib.eval_thms ctxt add
   2.192        fun in_add (_, th) = member Thm.eq_thm_prop add_ths th
   2.193        fun add_and_take accepts =
   2.194 @@ -1330,6 +1356,7 @@
   2.195          mesh_facts (Thm.eq_thm_prop o pairself snd) max_facts mess
   2.196          |> add_and_take
   2.197      in
   2.198 +      if save then MaSh.save ctxt overlord else ();
   2.199        case (fact_filter, mess) of
   2.200          (NONE, [(_, (mepo, _)), (_, (mash, _))]) =>
   2.201          [(meshN, mesh), (mepoN, mepo |> map fst |> add_and_take),