prefer recursive calls before others in the mode inference
authorbulwahn
Mon, 29 Mar 2010 17:30:46 +0200
changeset 360223837493fe4ab
parent 36021 29a15da9c63d
child 36023 a790b94e090c
prefer recursive calls before others in the mode inference
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
     1.1 --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Mon Mar 29 17:30:45 2010 +0200
     1.2 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Mon Mar 29 17:30:46 2010 +0200
     1.3 @@ -1563,50 +1563,60 @@
     1.4      EQUAL => ord2 (x, x')
     1.5    | ord => ord
     1.6  
     1.7 -fun deriv_ord2' thy modes t1 t2 ((deriv1, mvars1), (deriv2, mvars2)) =
     1.8 +fun deriv_ord2' thy pred modes t1 t2 ((deriv1, mvars1), (deriv2, mvars2)) =
     1.9    let
    1.10 +    (* prefer modes without requirement for generating random values *)
    1.11      fun mvars_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) =
    1.12        int_ord (length mvars1, length mvars2)
    1.13 +    (* prefer non-random modes *)
    1.14      fun random_mode_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) =
    1.15        int_ord (if random_mode_in_deriv modes t1 deriv1 then 1 else 0,
    1.16          if random_mode_in_deriv modes t1 deriv1 then 1 else 0)
    1.17 +    (* prefer modes with more input and less output *)
    1.18      fun output_mode_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) =
    1.19        int_ord (number_of_output_positions (head_mode_of deriv1),
    1.20          number_of_output_positions (head_mode_of deriv2))
    1.21 +    (* prefer recursive calls *)
    1.22 +    fun is_rec_premise t =
    1.23 +      case fst (strip_comb t) of Const (c, T) => c = pred | _ => false
    1.24 +    fun recursive_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) =
    1.25 +      int_ord (if is_rec_premise t1 then 0 else 1,
    1.26 +        if is_rec_premise t2 then 0 else 1)
    1.27 +    val ord = lex_ord mvars_ord (lex_ord random_mode_ord (lex_ord output_mode_ord recursive_ord))
    1.28    in
    1.29 -    lex_ord mvars_ord (lex_ord random_mode_ord output_mode_ord)
    1.30 -      ((t1, deriv1, mvars1), (t2, deriv2, mvars2))
    1.31 +    ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2))
    1.32    end
    1.33  
    1.34 -fun deriv_ord2 thy modes t = deriv_ord2' thy modes t t
    1.35 +fun deriv_ord2 thy pred modes t = deriv_ord2' thy pred modes t t
    1.36  
    1.37  fun deriv_ord ((deriv1, mvars1), (deriv2, mvars2)) =
    1.38    int_ord (length mvars1, length mvars2)
    1.39  
    1.40 -fun premise_ord thy modes ((prem1, a1), (prem2, a2)) =
    1.41 -  rev_option_ord (deriv_ord2' thy modes (term_of_prem prem1) (term_of_prem prem2)) (a1, a2)
    1.42 +fun premise_ord thy pred modes ((prem1, a1), (prem2, a2)) =
    1.43 +  rev_option_ord (deriv_ord2' thy pred modes (term_of_prem prem1) (term_of_prem prem2)) (a1, a2)
    1.44  
    1.45  fun print_mode_list modes =
    1.46    tracing ("modes: " ^ (commas (map (fn (s, ms) => s ^ ": " ^
    1.47      commas (map (fn (m, r) => string_of_mode m ^ (if r then " random " else " not ")) ms)) modes)))
    1.48  
    1.49 -fun select_mode_prem (mode_analysis_options : mode_analysis_options) (thy : theory) pol (modes, (pos_modes, neg_modes)) vs ps =
    1.50 +fun select_mode_prem (mode_analysis_options : mode_analysis_options) (thy : theory) pred
    1.51 +  pol (modes, (pos_modes, neg_modes)) vs ps =
    1.52    let
    1.53      fun choose_mode_of_prem (Prem t) = partial_hd
    1.54 -        (sort (deriv_ord2 thy modes t) (all_derivations_of thy pos_modes vs t))
    1.55 +        (sort (deriv_ord2 thy pred modes t) (all_derivations_of thy pos_modes vs t))
    1.56        | choose_mode_of_prem (Sidecond t) = SOME (Context Bool, missing_vars vs t)
    1.57        | choose_mode_of_prem (Negprem t) = partial_hd
    1.58 -          (sort (deriv_ord2 thy modes t) (filter (fn (d, missing_vars) => is_all_input (head_mode_of d))
    1.59 +          (sort (deriv_ord2 thy pred modes t) (filter (fn (d, missing_vars) => is_all_input (head_mode_of d))
    1.60               (all_derivations_of thy neg_modes vs t)))
    1.61        | choose_mode_of_prem p = raise Fail ("choose_mode_of_prem: " ^ string_of_prem thy p)
    1.62    in
    1.63      if #reorder_premises mode_analysis_options then
    1.64 -      partial_hd (sort (premise_ord thy modes) (ps ~~ map choose_mode_of_prem ps))
    1.65 +      partial_hd (sort (premise_ord thy pred modes) (ps ~~ map choose_mode_of_prem ps))
    1.66      else
    1.67        SOME (hd ps, choose_mode_of_prem (hd ps))
    1.68    end
    1.69  
    1.70 -fun check_mode_clause' (mode_analysis_options : mode_analysis_options) thy param_vs (modes :
    1.71 +fun check_mode_clause' (mode_analysis_options : mode_analysis_options) thy pred param_vs (modes :
    1.72    (string * ((bool * mode) * bool) list) list) ((pol, mode) : bool * mode) (ts, ps) =
    1.73    let
    1.74      val vTs = distinct (op =) (fold Term.add_frees (map term_of_prem ps) (fold Term.add_frees ts []))
    1.75 @@ -1631,7 +1641,7 @@
    1.76      fun check_mode_prems acc_ps rnd vs [] = SOME (acc_ps, vs, rnd)
    1.77        | check_mode_prems acc_ps rnd vs ps =
    1.78          (case
    1.79 -          (select_mode_prem mode_analysis_options thy pol (modes', (pos_modes', neg_modes')) vs ps) of
    1.80 +          (select_mode_prem mode_analysis_options thy pred pol (modes', (pos_modes', neg_modes')) vs ps) of
    1.81            SOME (p, SOME (deriv, [])) => check_mode_prems ((p, deriv) :: acc_ps) rnd
    1.82              (known_vs_after p vs) (filter_out (equal p) ps)
    1.83          | SOME (p, SOME (deriv, missing_vars)) =>
    1.84 @@ -1678,7 +1688,7 @@
    1.85      fun check_mode m =
    1.86        let
    1.87          val res = Output.cond_timeit false "work part of check_mode for one mode" (fn _ => 
    1.88 -          map (check_mode_clause' mode_analysis_options thy param_vs modes m) rs)
    1.89 +          map (check_mode_clause' mode_analysis_options thy p param_vs modes m) rs)
    1.90        in
    1.91          Output.cond_timeit false "aux part of check_mode for one mode" (fn _ => 
    1.92          case find_indices is_none res of
    1.93 @@ -1701,7 +1711,7 @@
    1.94      (p, map (fn (m, rnd) =>
    1.95        (m, map
    1.96          ((fn (ts, ps, rnd) => (ts, ps)) o the o
    1.97 -          check_mode_clause' mode_analysis_options thy param_vs modes m) rs)) ms)
    1.98 +          check_mode_clause' mode_analysis_options thy p param_vs modes m) rs)) ms)
    1.99    end;
   1.100  
   1.101  fun fixp f (x : (string * ((bool * mode) * bool) list) list) =
   1.102 @@ -1759,8 +1769,8 @@
   1.103        else
   1.104          map (fn (s, ms) => (s, map (fn m => ((true, m), false)) ms)) all_modes
   1.105      fun iteration modes = map
   1.106 -      (check_modes_pred' mode_analysis_options options thy param_vs clauses (modes @ extra_modes))
   1.107 -        modes
   1.108 +      (check_modes_pred' mode_analysis_options options thy param_vs clauses
   1.109 +        (modes @ extra_modes)) modes
   1.110      val ((modes : (string * ((bool * mode) * bool) list) list), errors) =
   1.111        Output.cond_timeit false "Fixpount computation of mode analysis" (fn () =>
   1.112        if collect_errors then