src/HOL/Tools/Nitpick/nitpick_preproc.ML
author blanchet
Thu, 29 Apr 2010 01:17:14 +0200
changeset 36553 8ff45c2076da
parent 36389 8228b3a4a2ba
child 37252 e01c1fe245cd
permissions -rw-r--r--
expand combinators in Isar proofs constructed by Sledgehammer;
this requires shuffling around a couple of functions previously defined in Refute
     1 (*  Title:      HOL/Tools/Nitpick/nitpick_preproc.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2008, 2009, 2010
     4 
     5 Nitpick's HOL preprocessor.
     6 *)
     7 
(* Interface of Nitpick's HOL preprocessor.  The only exported operation is
   [preprocess_term], implemented by structure [Nitpick_Preproc]. *)
signature NITPICK_PREPROC =
sig
  type hol_context = Nitpick_HOL.hol_context
  (* NOTE(review): the meaning of the two "(typ option * bool option) list"
     arguments and of the five components of the result is not documented
     here; confirm it against the definition of [preprocess_term]. *)
  val preprocess_term :
    hol_context -> (typ option * bool option) list
    -> (typ option * bool option) list -> term
    -> term list * term list * bool * bool * bool
end;
    16 
    17 structure Nitpick_Preproc : NITPICK_PREPROC =
    18 struct
    19 
    20 open Nitpick_Util
    21 open Nitpick_HOL
    22 open Nitpick_Mono
    23 
    24 fun is_positive_existential polar quant_s =
    25   (polar = Pos andalso quant_s = @{const_name Ex}) orelse
    26   (polar = Neg andalso quant_s <> @{const_name Ex})
    27 
    28 (** Binary coding of integers **)
    29 
(* If a formula contains a numeral whose absolute value is greater than this
   threshold, the unary coding is likely not to work well and we prefer the
   binary coding.  [should_use_binary_ints] compares numerals against it. *)
val binary_int_threshold = 3
    34 
    35 val may_use_binary_ints =
    36   let
    37     fun aux def (Const (@{const_name "=="}, _) $ t1 $ t2) =
    38         aux def t1 andalso aux false t2
    39       | aux def (@{const "==>"} $ t1 $ t2) = aux false t1 andalso aux def t2
    40       | aux def (Const (@{const_name "op ="}, _) $ t1 $ t2) =
    41         aux def t1 andalso aux false t2
    42       | aux def (@{const "op -->"} $ t1 $ t2) = aux false t1 andalso aux def t2
    43       | aux def (t1 $ t2) = aux def t1 andalso aux def t2
    44       | aux def (t as Const (s, _)) =
    45         (not def orelse t <> @{const Suc}) andalso
    46         not (member (op =) [@{const_name Abs_Frac}, @{const_name Rep_Frac},
    47                             @{const_name nat_gcd}, @{const_name nat_lcm},
    48                             @{const_name Frac}, @{const_name norm_frac}] s)
    49       | aux def (Abs (_, _, t')) = aux def t'
    50       | aux _ _ = true
    51   in aux end
    52 val should_use_binary_ints =
    53   let
    54     fun aux (t1 $ t2) = aux t1 orelse aux t2
    55       | aux (Const (s, T)) =
    56         ((s = @{const_name times} orelse s = @{const_name div}) andalso
    57          is_integer_type (body_type T)) orelse
    58         (String.isPrefix numeral_prefix s andalso
    59          let val n = the (Int.fromString (unprefix numeral_prefix s)) in
    60            n < ~ binary_int_threshold orelse n > binary_int_threshold
    61          end)
    62       | aux (Abs (_, _, t')) = aux t'
    63       | aux _ = false
    64   in aux end
    65 
    66 (** Uncurrying **)
    67 
    68 fun add_to_uncurry_table thy t =
    69   let
    70     fun aux (t1 $ t2) args table =
    71         let val table = aux t2 [] table in aux t1 (t2 :: args) table end
    72       | aux (Abs (_, _, t')) _ table = aux t' [] table
    73       | aux (t as Const (x as (s, _))) args table =
    74         if is_built_in_const thy [(NONE, true)] true x orelse
    75            is_constr_like thy x orelse
    76            is_sel s orelse s = @{const_name Sigma} then
    77           table
    78         else
    79           Termtab.map_default (t, 65536) (Integer.min (length args)) table
    80       | aux _ _ table = table
    81   in aux t [] end
    82 
    83 fun uncurry_prefix_for k j =
    84   uncurry_prefix ^ string_of_int k ^ "@" ^ string_of_int j ^ name_sep
    85 
    86 fun uncurry_term table t =
    87   let
    88     fun aux (t1 $ t2) args = aux t1 (aux t2 [] :: args)
    89       | aux (Abs (s, T, t')) args = betapplys (Abs (s, T, aux t' []), args)
    90       | aux (t as Const (s, T)) args =
    91         (case Termtab.lookup table t of
    92            SOME n =>
    93            if n >= 2 then
    94              let
    95                val arg_Ts = strip_n_binders n T |> fst
    96                val j =
    97                  if is_iterator_type (hd arg_Ts) then
    98                    1
    99                  else case find_index (not_equal bool_T) arg_Ts of
   100                    ~1 => n
   101                  | j => j
   102                val ((before_args, tuple_args), after_args) =
   103                  args |> chop n |>> chop j
   104                val ((before_arg_Ts, tuple_arg_Ts), rest_T) =
   105                  T |> strip_n_binders n |>> chop j
   106                val tuple_T = HOLogic.mk_tupleT tuple_arg_Ts
   107              in
   108                if n - j < 2 then
   109                  betapplys (t, args)
   110                else
   111                  betapplys (Const (uncurry_prefix_for (n - j) j ^ s,
   112                                    before_arg_Ts ---> tuple_T --> rest_T),
   113                             before_args @ [mk_flat_tuple tuple_T tuple_args] @
   114                             after_args)
   115              end
   116            else
   117              betapplys (t, args)
   118          | NONE => betapplys (t, args))
   119       | aux t args = betapplys (t, args)
   120   in aux t [] end
   121 
   122 (** Boxing **)
   123 
(* Rewrites [orig_t] so that its subterms carry the types computed by
   [box_type] for the position they occur in (InFunLHS, InPair, InConstr,
   InSel or InExpr).  Wherever an argument's type no longer matches what its
   function expects, a [coerce_term] conversion is inserted.

   Polarity (Pos/Neg/Neut) is tracked through the logical connectives,
   because it decides how quantified variables are boxed.

   [def] indicates that [orig_t] is a definition.  Its schematic variables
   then take the type chosen by [box_var_in_def]. *)
fun box_fun_and_pair_in_term (hol_ctxt as {thy, stds, fast_descrs, ...}) def
                             orig_t =
  let
    (* Type used for "converse" and "trancl".  The function structure is
       kept; the components of pair types are boxed with InPair. *)
    fun box_relational_operator_type (Type (@{type_name fun}, Ts)) =
        Type (@{type_name fun}, map box_relational_operator_type Ts)
      | box_relational_operator_type (Type (@{type_name "*"}, Ts)) =
        Type (@{type_name "*"}, map (box_type hol_ctxt InPair) Ts)
      | box_relational_operator_type T = T
    (* Adds to an accumulator list (without duplicates) the types that term
       [t'], expected at type [T'], suggests for variable [z]:
       - [T'] if [t'] is [z] itself;
       - the component suggestions if [t'] is a "Pair";
       - [z]'s own type [T] if [z] occurs anywhere else in [t']. *)
    fun add_boxed_types_for_var (z as (_, T)) (T', t') =
      case t' of
        Var z' => z' = z ? insert (op =) T'
      | Const (@{const_name Pair}, _) $ t1 $ t2 =>
        (case T' of
           Type (_, [T1, T2]) =>
           fold (add_boxed_types_for_var z) [(T1, t1), (T2, t2)]
         | _ => raise TYPE ("Nitpick_Preproc.box_fun_and_pair_in_term.\
                            \add_boxed_types_for_var", [T'], []))
      | _ => exists_subterm (curry (op =) (Var z)) t' ? insert (op =) T
    (* Type to give variable [z] in the definition [t] ("==" or "op =",
       possibly under "Trueprop").  The head of the left-hand side is boxed,
       and its argument types are matched against the actual arguments.  If
       this yields exactly one candidate type, that type is used; otherwise
       [z] keeps its type [T]. *)
    fun box_var_in_def new_Ts old_Ts t (z as (_, T)) =
      case t of
        @{const Trueprop} $ t1 => box_var_in_def new_Ts old_Ts t1 z
      | Const (s0, _) $ t1 $ _ =>
        if s0 = @{const_name "=="} orelse s0 = @{const_name "op ="} then
          let
            val (t', args) = strip_comb t1
            val T' = fastype_of1 (new_Ts, do_term new_Ts old_Ts Neut t')
          in
            case fold (add_boxed_types_for_var z)
                      (fst (strip_n_binders (length args) T') ~~ args) [] of
              [T''] => T''
            | _ => T
          end
        else
          T
      | _ => T
    (* Quantifier.  The bound variable's type is boxed with InFunLHS in
       neutral positions and when the quantifier acts existentially (see
       [is_positive_existential]).  Otherwise it is left alone.  The body is
       processed with the same polarity. *)
    and do_quantifier new_Ts old_Ts polar quant_s quant_T abs_s abs_T t =
      let
        val abs_T' =
          if polar = Neut orelse is_positive_existential polar quant_s then
            box_type hol_ctxt InFunLHS abs_T
          else
            abs_T
        val body_T = body_type quant_T
      in
        Const (quant_s, (abs_T' --> body_T) --> body_T)
        $ Abs (abs_s, abs_T',
               t |> do_term (abs_T' :: new_Ts) (abs_T :: old_Ts) polar)
      end
    (* Equality.  Both sides are processed in neutral position.  The equality
       is taken at the greater of their types under [Term_Ord.typ_ord], and
       both sides are coerced to that type. *)
    and do_equals new_Ts old_Ts s0 T0 t1 t2 =
      let
        val (t1, t2) = pairself (do_term new_Ts old_Ts Neut) (t1, t2)
        val (T1, T2) = pairself (curry fastype_of1 new_Ts) (t1, t2)
        val T = [T1, T2] |> sort Term_Ord.typ_ord |> List.last
      in
        list_comb (Const (s0, T --> T --> body_type T0),
                   map2 (coerce_term hol_ctxt new_Ts T) [T1, T2] [t1, t2])
      end
    (* Description operators: the result type is boxed with InFunLHS, and
       the predicate argument is typed accordingly. *)
    and do_description_operator s T =
      let val T1 = box_type hol_ctxt InFunLHS (range_type T) in
        Const (s, (T1 --> bool_T) --> T1)
      end
    (* Main traversal.
       - [new_Ts] and [old_Ts]: the boxed and original types of the enclosing
         bound variables.
       - [polar]: the polarity of [t].
       Negation and the premise of an implication flip the polarity.
       Function and argument positions are processed as neutral. *)
    and do_term new_Ts old_Ts polar t =
      case t of
        Const (s0 as @{const_name all}, T0) $ Abs (s1, T1, t1) =>
        do_quantifier new_Ts old_Ts polar s0 T0 s1 T1 t1
      | Const (s0 as @{const_name "=="}, T0) $ t1 $ t2 =>
        do_equals new_Ts old_Ts s0 T0 t1 t2
      | @{const "==>"} $ t1 $ t2 =>
        @{const "==>"} $ do_term new_Ts old_Ts (flip_polarity polar) t1
        $ do_term new_Ts old_Ts polar t2
      | @{const Pure.conjunction} $ t1 $ t2 =>
        @{const Pure.conjunction} $ do_term new_Ts old_Ts polar t1
        $ do_term new_Ts old_Ts polar t2
      | @{const Trueprop} $ t1 =>
        @{const Trueprop} $ do_term new_Ts old_Ts polar t1
      | @{const Not} $ t1 =>
        @{const Not} $ do_term new_Ts old_Ts (flip_polarity polar) t1
      | Const (s0 as @{const_name All}, T0) $ Abs (s1, T1, t1) =>
        do_quantifier new_Ts old_Ts polar s0 T0 s1 T1 t1
      | Const (s0 as @{const_name Ex}, T0) $ Abs (s1, T1, t1) =>
        do_quantifier new_Ts old_Ts polar s0 T0 s1 T1 t1
      | Const (s0 as @{const_name "op ="}, T0) $ t1 $ t2 =>
        do_equals new_Ts old_Ts s0 T0 t1 t2
      | @{const "op &"} $ t1 $ t2 =>
        @{const "op &"} $ do_term new_Ts old_Ts polar t1
        $ do_term new_Ts old_Ts polar t2
      | @{const "op |"} $ t1 $ t2 =>
        @{const "op |"} $ do_term new_Ts old_Ts polar t1
        $ do_term new_Ts old_Ts polar t2
      | @{const "op -->"} $ t1 $ t2 =>
        @{const "op -->"} $ do_term new_Ts old_Ts (flip_polarity polar) t1
        $ do_term new_Ts old_Ts polar t2
      | Const (s as @{const_name The}, T) => do_description_operator s T
      | Const (s as @{const_name Eps}, T) => do_description_operator s T
      | Const (s as @{const_name safe_The}, T) => do_description_operator s T
      | Const (s as @{const_name safe_Eps}, T) => do_description_operator s T
      (* Other constants: the boxing context depends on the kind of
         constant.  Built-in constants and "Sigma" keep their type. *)
      | Const (x as (s, T)) =>
        Const (s, if s = @{const_name converse} orelse
                     s = @{const_name trancl} then
                    box_relational_operator_type T
                  else if String.isPrefix quot_normal_prefix s then
                    let val T' = box_type hol_ctxt InFunLHS (domain_type T) in
                      T' --> T'
                    end
                  else if is_built_in_const thy stds fast_descrs x orelse
                          s = @{const_name Sigma} then
                    T
                  else if is_constr_like thy x then
                    box_type hol_ctxt InConstr T
                  else if is_sel s
                       orelse is_rep_fun thy x then
                    box_type hol_ctxt InSel T
                  else
                    box_type hol_ctxt InExpr T)
      (* Application to an abstraction.  The bound variable takes the domain
         of the function's (possibly boxed) first type argument, and the
         abstraction is coerced to that argument type.  A head whose type is
         not "fun" is first unwrapped with the 0th "FunBox" selector. *)
      | t1 $ Abs (s, T, t2') =>
        let
          val t1 = do_term new_Ts old_Ts Neut t1
          val T1 = fastype_of1 (new_Ts, t1)
          val (s1, Ts1) = dest_Type T1
          val T' = hd (snd (dest_Type (hd Ts1)))
          val t2 = Abs (s, T', do_term (T' :: new_Ts) (T :: old_Ts) Neut t2')
          val T2 = fastype_of1 (new_Ts, t2)
          val t2 = coerce_term hol_ctxt new_Ts (hd Ts1) T2 t2
        in
          betapply (if s1 = @{type_name fun} then
                      t1
                    else
                      select_nth_constr_arg thy stds
                          (@{const_name FunBox},
                           Type (@{type_name fun}, Ts1) --> T1) t1 0
                          (Type (@{type_name fun}, Ts1)), t2)
        end
      (* General application: same head unwrapping as above, with the
         argument coerced to the expected type. *)
      | t1 $ t2 =>
        let
          val t1 = do_term new_Ts old_Ts Neut t1
          val T1 = fastype_of1 (new_Ts, t1)
          val (s1, Ts1) = dest_Type T1
          val t2 = do_term new_Ts old_Ts Neut t2
          val T2 = fastype_of1 (new_Ts, t2)
          val t2 = coerce_term hol_ctxt new_Ts (hd Ts1) T2 t2
        in
          betapply (if s1 = @{type_name fun} then
                      t1
                    else
                      select_nth_constr_arg thy stds
                          (@{const_name FunBox},
                           Type (@{type_name fun}, Ts1) --> T1) t1 0
                          (Type (@{type_name fun}, Ts1)), t2)
        end
      | Free (s, T) => Free (s, box_type hol_ctxt InExpr T)
      (* In definitions, the whole [orig_t] determines a variable's type. *)
      | Var (z as (x, T)) =>
        Var (x, if def then box_var_in_def new_Ts old_Ts orig_t z
                else box_type hol_ctxt InExpr T)
      | Bound _ => t
      | Abs (s, T, t') =>
        Abs (s, T, do_term (T :: new_Ts) (T :: old_Ts) Neut t')
  in do_term [] [] Pos orig_t end
   281 
   282 (** Destruction of constructors **)
   283 
(* Name prefix of the value variables created by [fresh_value_var]. *)
val val_var_prefix = nitpick_prefix ^ "v"
   285 
   286 fun fresh_value_var Ts k n j t =
   287   Var ((val_var_prefix ^ nat_subscript (n - j), k), fastype_of1 (Ts, t))
   288 
   289 fun has_heavy_bounds_or_vars Ts t =
   290   let
   291     fun aux [] = false
   292       | aux [T] = is_fun_type T orelse is_pair_type T
   293       | aux _ = true
   294   in aux (map snd (Term.add_vars t []) @ map (nth Ts) (loose_bnos t)) end
   295 
   296 fun pull_out_constr_comb ({thy, stds, ...} : hol_context) Ts relax k level t
   297                          args seen =
   298   let val t_comb = list_comb (t, args) in
   299     case t of
   300       Const x =>
   301       if not relax andalso is_constr thy stds x andalso
   302          not (is_fun_type (fastype_of1 (Ts, t_comb))) andalso
   303          has_heavy_bounds_or_vars Ts t_comb andalso
   304          not (loose_bvar (t_comb, level)) then
   305         let
   306           val (j, seen) = case find_index (curry (op =) t_comb) seen of
   307                             ~1 => (0, t_comb :: seen)
   308                           | j => (j, seen)
   309         in (fresh_value_var Ts k (length seen) j t_comb, seen) end
   310       else
   311         (t_comb, seen)
   312     | _ => (t_comb, seen)
   313   end
   314 
   315 fun equations_for_pulled_out_constrs mk_eq Ts k seen =
   316   let val n = length seen in
   317     map2 (fn j => fn t => mk_eq (fresh_value_var Ts k n j t, t))
   318          (index_seq 0 n) seen
   319   end
   320 
(* Replaces the constructor applications of [t] that [pull_out_constr_comb]
   accepts (with level 0, i.e. no loose bound variables) by value variables.
   Returns [t] with the equations "v == c ..." added as meta-premises.

   Pulling is disabled (relax = def) for atoms reached while [def] holds.
   The operands of "==", "op =", "==>" and "-->" are visited with [def]
   cleared, with two exceptions while [def] holds:
   - the right-hand side of an equality is kept as is;
   - a "==>" is returned unchanged. *)
fun pull_out_universal_constrs hol_ctxt def t =
  let
    (* Index of the new value variables: above every index used in [t]. *)
    val k = maxidx_of_term t + 1
    (* [args] holds the arguments of the current application spine; [seen]
       holds the terms pulled out so far (newest first). *)
    fun do_term Ts def t args seen =
      case t of
        (t0 as Const (@{const_name "=="}, _)) $ t1 $ t2 =>
        do_eq_or_imp Ts true def t0 t1 t2 seen
      (* NOTE(review): this returns [] instead of [seen] and ignores [args].
         That is harmless when the "==>" is the whole term, but confirm it
         cannot be reached deeper inside a definition. *)
      | (t0 as @{const "==>"}) $ t1 $ t2 =>
        if def then (t, []) else do_eq_or_imp Ts false def t0 t1 t2 seen
      | (t0 as Const (@{const_name "op ="}, _)) $ t1 $ t2 =>
        do_eq_or_imp Ts true def t0 t1 t2 seen
      | (t0 as @{const "op -->"}) $ t1 $ t2 =>
        do_eq_or_imp Ts false def t0 t1 t2 seen
      | Abs (s, T, t') =>
        let val (t', seen) = do_term (T :: Ts) def t' [] seen in
          (list_comb (Abs (s, T, t'), args), seen)
        end
      | t1 $ t2 =>
        let val (t2, seen) = do_term Ts def t2 [] seen in
          do_term Ts def t1 (t2 :: args) seen
        end
      | _ => pull_out_constr_comb hol_ctxt Ts def k 0 t args seen
    (* Binary connective [t0] applied to [t1] and [t2] (an equality if
       [eq]).  The right operand is visited first, then the left, both with
       [def] cleared.  The right operand of a definitional equality is kept
       verbatim. *)
    and do_eq_or_imp Ts eq def t0 t1 t2 seen =
      let
        val (t2, seen) = if eq andalso def then (t2, seen)
                         else do_term Ts false t2 [] seen
        val (t1, seen) = do_term Ts false t1 [] seen
      in (t0 $ t1 $ t2, seen) end
    val (concl, seen) = do_term [] def t [] []
  in
    (* The pulled-out terms have no loose bound variables (level 0 above),
       so their types can be computed with an empty bound-type list. *)
    Logic.list_implies (equations_for_pulled_out_constrs Logic.mk_equals [] k
                                                         seen, concl)
  end
   354 
   355 fun mk_exists v t =
   356   HOLogic.exists_const (fastype_of v) $ lambda v (incr_boundvars 1 t)
   357 
(* Existential counterpart of [pull_out_universal_constrs].

   Inside the body of each "Ex", constructor applications accepted by
   [pull_out_constr_comb] are replaced by value variables.  The level passed
   is the number of directly nested "Ex" binders, so a replaced term may only
   refer to the variables of those binders.  The new variables are
   existentially quantified just inside the "Ex" binder, and their defining
   equations are conjoined to the body.

   Below an abstraction that is not an "Ex" body, the count drops to 0 and
   nothing is pulled out. *)
fun pull_out_existential_constrs hol_ctxt t =
  let
    val k = maxidx_of_term t + 1
    (* [num_exists]: number of consecutive "Ex" binders just above [t].
       [args]: arguments of the current application spine.
       [seen]: terms pulled out in the current "Ex" block. *)
    fun aux Ts num_exists t args seen =
      case t of
        (t0 as Const (@{const_name Ex}, _)) $ Abs (s1, T1, t1) =>
        let
          (* The body starts its own list of pulled-out terms. *)
          val (t1, seen') = aux (T1 :: Ts) (num_exists + 1) t1 [] []
          val n = length seen'
          fun vars () = map2 (fresh_value_var Ts k n) (index_seq 0 n) seen'
        in
          (equations_for_pulled_out_constrs HOLogic.mk_eq Ts k seen'
           |> List.foldl s_conj t1 |> fold mk_exists (vars ())
           |> curry3 Abs s1 T1 |> curry (op $) t0, seen)
        end
      | t1 $ t2 =>
        let val (t2, seen) = aux Ts num_exists t2 [] seen in
          aux Ts num_exists t1 (t2 :: args) seen
        end
      (* [seen] is shifted under the new binder and shifted back afterwards,
         so that its loose bound indices stay consistent. *)
      | Abs (s, T, t') =>
        let
          val (t', seen) = aux (T :: Ts) 0 t' [] (map (incr_boundvars 1) seen)
        in (list_comb (Abs (s, T, t'), args), map (incr_boundvars ~1) seen) end
      | _ =>
        if num_exists > 0 then
          pull_out_constr_comb hol_ctxt Ts false k num_exists t args seen
        else
          (list_comb (t, args), seen)
  in aux [] 0 t [] [] |> fst end
   387 
(* Name of the schematic variable that [hol_let] binds with "Let". *)
val let_var_prefix = nitpick_prefix ^ "l"
(* Limit on (multiplier * term size) up to which [hol_let] inlines its
   argument instead of introducing a "Let". *)
val let_inline_threshold = 32
   390 
   391 fun hol_let n abs_T body_T f t =
   392   if n * size_of_term t <= let_inline_threshold then
   393     f t
   394   else
   395     let val z = ((let_var_prefix, 0), abs_T) in
   396       Const (@{const_name Let}, abs_T --> (abs_T --> body_T) --> body_T)
   397       $ t $ abs_var z (incr_boundvars 1 (f (Var z)))
   398     end
   399 
(* Rewrites equations ("==" or "op =", either orientation) whose side is a
   fully applied constructor: a constructor per [is_constr], "Pair", or
   "Suc :: nat => nat".  Such an equation becomes a conjunction of:
   - a [discriminate_value] test;
   - one equation per argument against [select_nth_constr_arg] (see
     [sel_eq]).
   The other side is shared through [hol_let] when large.

   In [axiom] mode:
   - an equation whose side is a variable occurring only once in [t]
     becomes True;
   - so does "PairBox (fst v) (snd v)" when [v] occurs exactly twice.

   [careful] (initially [axiom]) leaves equations untouched.  Premises of
   "==>" and "-->" are always processed with [careful] false, while
   conclusions inherit it.  For "==", the result is wrapped in "Trueprop". *)
fun destroy_pulled_out_constrs (hol_ctxt as {thy, stds, ...}) axiom t =
  let
    (* Number of occurrences of a schematic variable in [t]. *)
    val num_occs_of_var =
      fold_aterms (fn Var z => (fn f => fn z' => f z' |> z = z' ? Integer.add 1)
                    | _ => I) t (K 0)
    fun aux careful ((t0 as Const (@{const_name "=="}, _)) $ t1 $ t2) =
        aux_eq careful true t0 t1 t2
      | aux careful ((t0 as @{const "==>"}) $ t1 $ t2) =
        t0 $ aux false t1 $ aux careful t2
      | aux careful ((t0 as Const (@{const_name "op ="}, _)) $ t1 $ t2) =
        aux_eq careful true t0 t1 t2
      | aux careful ((t0 as @{const "op -->"}) $ t1 $ t2) =
        t0 $ aux false t1 $ aux careful t2
      | aux careful (Abs (s, T, t')) = Abs (s, T, aux careful t')
      | aux careful (t1 $ t2) = aux careful t1 $ aux careful t2
      | aux _ t = t
    (* Tries to rewrite the equation [t0 t1 t2], looking at [t2].  SAME means
       "not applicable".  After the first pass the sides are swapped and the
       rewrite is retried; after the second pass both sides are just
       traversed.  In the second pass [t1]/[t2] are swapped, so the fallback
       rebuilds the original order. *)
    and aux_eq careful pass1 t0 t1 t2 =
      ((if careful then
          raise SAME ()
        else if axiom andalso is_Var t2 andalso
                num_occs_of_var (dest_Var t2) = 1 then
          @{const True}
        else case strip_comb t2 of
          (* The first case is not as general as it could be. *)
          (Const (@{const_name PairBox}, _),
                  [Const (@{const_name fst}, _) $ Var z1,
                   Const (@{const_name snd}, _) $ Var z2]) =>
          if z1 = z2 andalso num_occs_of_var z1 = 2 then @{const True}
          else raise SAME ()
        | (Const (x as (s, T)), args) =>
          let
            val (arg_Ts, dataT) = strip_type T
            val n = length arg_Ts
          in
            (* NOTE(review): [careful] is always false here, because the
               first branch above raises SAME otherwise.  The
               "not careful orelse ..." conjunct is therefore always true,
               and the [val_var_prefix] test never matters.  Confirm whether
               this is intended. *)
            if length args = n andalso
               (is_constr thy stds x orelse s = @{const_name Pair} orelse
                x = (@{const_name Suc}, nat_T --> nat_T)) andalso
               (not careful orelse not (is_Var t1) orelse
                String.isPrefix val_var_prefix (fst (fst (dest_Var t1)))) then
                (* n selector equations plus one discriminator: [t1] is used
                   n + 1 times. *)
                hol_let (n + 1) dataT bool_T
                    (fn t1 => discriminate_value hol_ctxt x t1 ::
                              map3 (sel_eq x t1) (index_seq 0 n) arg_Ts args
                              |> foldr1 s_conj) t1
            else
              raise SAME ()
          end
        | _ => raise SAME ())
       |> body_type (type_of t0) = prop_T ? HOLogic.mk_Trueprop)
      handle SAME () => if pass1 then aux_eq careful false t0 t2 t1
                        else t0 $ aux false t2 $ aux false t1
    (* "nth_t = n-th selector of x applied to t", itself simplified further. *)
    and sel_eq x t n nth_T nth_t =
      HOLogic.eq_const nth_T $ nth_t
                             $ select_nth_constr_arg thy stds x t n nth_T
      |> aux false
  in aux axiom t end
   455 
   456 (** Destruction of universal and existential equalities **)
   457 
   458 fun curry_assms (@{const "==>"} $ (@{const Trueprop}
   459                                    $ (@{const "op &"} $ t1 $ t2)) $ t3) =
   460     curry_assms (Logic.list_implies ([t1, t2] |> map HOLogic.mk_Trueprop, t3))
   461   | curry_assms (@{const "==>"} $ t1 $ t2) =
   462     @{const "==>"} $ curry_assms t1 $ curry_assms t2
   463   | curry_assms t = t
   464 
   465 val destroy_universal_equalities =
   466   let
   467     fun aux prems zs t =
   468       case t of
   469         @{const "==>"} $ t1 $ t2 => aux_implies prems zs t1 t2
   470       | _ => Logic.list_implies (rev prems, t)
   471     and aux_implies prems zs t1 t2 =
   472       case t1 of
   473         Const (@{const_name "=="}, _) $ Var z $ t' => aux_eq prems zs z t' t1 t2
   474       | @{const Trueprop} $ (Const (@{const_name "op ="}, _) $ Var z $ t') =>
   475         aux_eq prems zs z t' t1 t2
   476       | @{const Trueprop} $ (Const (@{const_name "op ="}, _) $ t' $ Var z) =>
   477         aux_eq prems zs z t' t1 t2
   478       | _ => aux (t1 :: prems) (Term.add_vars t1 zs) t2
   479     and aux_eq prems zs z t' t1 t2 =
   480       if not (member (op =) zs z) andalso
   481          not (exists_subterm (curry (op =) (Var z)) t') then
   482         aux prems zs (subst_free [(Var z, t')] t2)
   483       else
   484         aux (t1 :: prems) (Term.add_vars t1 zs) t2
   485   in aux [] [] end
   486 
(* [find_bound_assign thy stds j seen ts] looks in the conjuncts [ts] for an
   equation "op =" that fixes Bound j to a term [u] with no loose Bound j.
   Either orientation is accepted.  One side may also be the 0th "FunBox"
   selector applied to Bound j; [u] is then wrapped with the "FunBox"
   constructor.

   Returns SOME (value for Bound j, the other conjuncts), or NONE.  The other
   conjuncts are the rest of [ts] followed by [seen], where [seen]
   accumulates the conjuncts already skipped, most recent first. *)
fun find_bound_assign thy stds j =
  let
    fun do_term _ [] = NONE
      | do_term seen (t :: ts) =
        let
          (* SAME means this conjunct does not qualify: move on to the next
             one.  In the first pass, a right-hand side that mentions
             Bound j causes the sides to be swapped. *)
          fun do_eq pass1 t1 t2 =
            (if loose_bvar1 (t2, j) then
               if pass1 then do_eq false t2 t1 else raise SAME ()
             else case t1 of
               Bound j' => if j' = j then SOME (t2, ts @ seen) else raise SAME ()
             | Const (s, Type (@{type_name fun}, [T1, T2])) $ Bound j' =>
               if j' = j andalso
                  s = nth_sel_name_for_constr_name @{const_name FunBox} 0 then
                 SOME (construct_value thy stds (@{const_name FunBox}, T2 --> T1)
                                       [t2], ts @ seen)
               else
                 raise SAME ()
             | _ => raise SAME ())
            handle SAME () => do_term (t :: seen) ts
        in
          case t of
            Const (@{const_name "op ="}, _) $ t1 $ t2 => do_eq true t1 t2
          | _ => do_term (t :: seen) ts
        end
  in do_term end
   512 
(* Substitutes [arg] for the loose bound variable Bound j of [t].
   - Loose bound variables above j are decremented, because binder j
     disappears.
   - Indices below the current level are left alone.
   - [arg] is shifted by the number of binders crossed (lev - j).
   The SAME exception signals "subterm unchanged", so unchanged parts are
   shared rather than rebuilt.  If nothing changes, [t] itself is returned. *)
fun subst_one_bound j arg t =
  let
    fun aux (Bound i, lev) =
        if i < lev then raise SAME ()
        else if i = lev then incr_boundvars (lev - j) arg
        else Bound (i - 1)
      | aux (Abs (a, T, body), lev) = Abs (a, T, aux (body, lev + 1))
      | aux (f $ t, lev) =
        (aux (f, lev) $ (aux (t, lev) handle SAME () => t)
         handle SAME () => f $ aux (t, lev))
      | aux _ = raise SAME ()
  in aux (t, j) handle SAME () => t end
   525 
(* Simplifies blocks of consecutive "Ex" binders.  Each quantified variable
   is processed from the outermost inwards.
   - If some conjunct of the block's body equates the variable with a term
     that does not depend on it (see [find_bound_assign]), the quantifier is
     dropped and that term substituted into the remaining conjuncts.
   - If that equation was the only conjunct, the whole remaining block
     becomes True.
   - Otherwise the "Ex" binder is kept. *)
fun destroy_existential_equalities ({thy, stds, ...} : hol_context) =
  let
    (* [ss]/[Ts]: names and types of the binders still to process, outermost
       first.  The head binder is therefore Bound (length ss) relative to
       the body.  [ts]: the body's conjuncts. *)
    fun kill [] [] ts = foldr1 s_conj ts
      | kill (s :: ss) (T :: Ts) ts =
        (case find_bound_assign thy stds (length ss) [] ts of
           SOME (_, []) => @{const True}
         | SOME (arg_t, ts) =>
           kill ss Ts (map (subst_one_bound (length ss)
                                (incr_bv (~1, length ss + 1, arg_t))) ts)
         | NONE =>
           Const (@{const_name Ex}, (T --> bool_T) --> bool_T)
           $ Abs (s, T, kill ss Ts ts))
      | kill _ _ _ = raise UnequalLengths
    (* Collects a block of consecutive "Ex" binders (appending, so the
       outermost comes first) and hands its body's conjuncts to [kill].
       Outside such blocks it recurses structurally. *)
    fun gather ss Ts (Const (@{const_name Ex}, _) $ Abs (s1, T1, t1)) =
        gather (ss @ [s1]) (Ts @ [T1]) t1
      | gather [] [] (Abs (s, T, t1)) = Abs (s, T, gather [] [] t1)
      | gather [] [] (t1 $ t2) = gather [] [] t1 $ gather [] [] t2
      | gather [] [] t = t
      | gather ss Ts t = kill ss Ts (conjuncts_of (gather [] [] t))
  in gather [] [] end
   546 
   547 (** Skolemization **)
   548 
   549 fun skolem_prefix_for k j =
   550   skolem_prefix ^ string_of_int k ^ "@" ^ string_of_int j ^ name_sep
   551 
(* Skolemizes a term.  Each quantifier that acts existentially (see
   [is_positive_existential]) at quantifier depth at most [skolem_depth] is
   replaced by a fresh Skolem constant.  The constant is applied to the
   variables of the enclosing quantifiers, and the reference [skolems]
   records it together with their names.

   It also rewrites occurrences of inductive predicates that are not
   well-founded (see [is_well_founded_inductive_pred]) according to their
   polarity.

   Parameters of [aux]:
   - [ss]/[Ts]: names and types of the enclosing quantified variables,
     innermost first;
   - [js]: their current de Bruijn indices;
   - [depth]: the quantifier depth;
   - [polar]: the polarity of [t]. *)
fun skolemize_term_and_more (hol_ctxt as {thy, def_table, skolems, ...})
                            skolem_depth =
  let
    val incrs = map (Integer.add 1)
    fun aux ss Ts js depth polar t =
      let
        fun do_quantifier quant_s quant_T abs_s abs_T t =
          (* Vacuous quantifier: drop it. *)
          if not (loose_bvar1 (t, 0)) then
            aux ss Ts js depth polar (incr_boundvars ~1 t)
          (* Skolemize.  The constant's name records how many variables it
             depends on and a serial number derived from the size of
             [skolems].  With no such variables, the body is simply
             beta-reduced; otherwise the constant is bound with "Let". *)
          else if depth <= skolem_depth andalso
                  is_positive_existential polar quant_s then
            let
              val j = length (!skolems) + 1
              val sko_s = skolem_prefix_for (length js) j ^ abs_s
              val _ = Unsynchronized.change skolems (cons (sko_s, ss))
              val sko_t = list_comb (Const (sko_s, rev Ts ---> abs_T),
                                     map Bound (rev js))
              val abs_t = Abs (abs_s, abs_T, aux ss Ts (incrs js) depth polar t)
            in
              if null js then betapply (abs_t, sko_t)
              else Const (@{const_name Let}, abs_T --> quant_T) $ sko_t $ abs_t
            end
          (* Keep the quantifier.  The body of a quantifier over a
             higher-order type is not traversed. *)
          else
            Const (quant_s, quant_T)
            $ Abs (abs_s, abs_T,
                   if is_higher_order_type abs_T then
                     t
                   else
                     aux (abs_s :: ss) (abs_T :: Ts) (0 :: incrs js)
                         (depth + 1) polar t)
      in
        case t of
          Const (s0 as @{const_name all}, T0) $ Abs (s1, T1, t1) =>
          do_quantifier s0 T0 s1 T1 t1
        | @{const "==>"} $ t1 $ t2 =>
          @{const "==>"} $ aux ss Ts js depth (flip_polarity polar) t1
          $ aux ss Ts js depth polar t2
        | @{const Pure.conjunction} $ t1 $ t2 =>
          @{const Pure.conjunction} $ aux ss Ts js depth polar t1
          $ aux ss Ts js depth polar t2
        | @{const Trueprop} $ t1 =>
          @{const Trueprop} $ aux ss Ts js depth polar t1
        | @{const Not} $ t1 =>
          @{const Not} $ aux ss Ts js depth (flip_polarity polar) t1
        | Const (s0 as @{const_name All}, T0) $ Abs (s1, T1, t1) =>
          do_quantifier s0 T0 s1 T1 t1
        | Const (s0 as @{const_name Ex}, T0) $ Abs (s1, T1, t1) =>
          do_quantifier s0 T0 s1 T1 t1
        | @{const "op &"} $ t1 $ t2 =>
          @{const "op &"} $ aux ss Ts js depth polar t1
          $ aux ss Ts js depth polar t2
        | @{const "op |"} $ t1 $ t2 =>
          @{const "op |"} $ aux ss Ts js depth polar t1
          $ aux ss Ts js depth polar t2
        | @{const "op -->"} $ t1 $ t2 =>
          @{const "op -->"} $ aux ss Ts js depth (flip_polarity polar) t1
          $ aux ss Ts js depth polar t2
        (* "Let": the bound value is left untouched; only the body is
           traversed. *)
        | (t0 as Const (@{const_name Let}, _)) $ t1 $ t2 =>
          t0 $ t1 $ aux ss Ts js depth polar t2
        (* Inductive predicate that is not well-founded.
           - pos (): the unrolled predicate, itself traversed.
           - neg (): the constant renamed with [lbfp_prefix] (greatest
             fixpoint) or [ubfp_prefix] (least fixpoint).
           The polarity, flipped for greatest fixpoints, selects one of them.
           In neutral positions both are combined: with "|"/"&", or for
           function types pointwise with sup/inf under abstractions over all
           but the last argument type. *)
        | Const (x as (s, T)) =>
          if is_inductive_pred hol_ctxt x andalso
             not (is_well_founded_inductive_pred hol_ctxt x) then
            let
              val gfp = (fixpoint_kind_of_const thy def_table x = Gfp)
              val (pref, connective, set_oper) =
                if gfp then
                  (lbfp_prefix, @{const "op |"},
                   @{const_name semilattice_sup_class.sup})
                else
                  (ubfp_prefix, @{const "op &"},
                   @{const_name semilattice_inf_class.inf})
              fun pos () = unrolled_inductive_pred_const hol_ctxt gfp x
                           |> aux ss Ts js depth polar
              fun neg () = Const (pref ^ s, T)
            in
              (case polar |> gfp ? flip_polarity of
                 Pos => pos ()
               | Neg => neg ()
               | Neut =>
                 if is_fun_type T then
                   let
                     val ((trunk_arg_Ts, rump_arg_T), body_T) =
                       T |> strip_type |>> split_last
                     val set_T = rump_arg_T --> body_T
                     fun app f =
                       list_comb (f (),
                                  map Bound (length trunk_arg_Ts - 1 downto 0))
                   in
                     List.foldr absdummy
                                (Const (set_oper, set_T --> set_T --> set_T)
                                        $ app pos $ app neg) trunk_arg_Ts
                   end
                 else
                   connective $ pos () $ neg ())
            end
          else
            Const x
        (* Application.  In the function part, depth skolem_depth + 1 turns
           skolemization off.  Arguments are traversed in neutral position.
           [js] is reset to [] for both. *)
        | t1 $ t2 =>
          betapply (aux ss Ts [] (skolem_depth + 1) polar t1,
                    aux ss Ts [] depth Neut t2)
        | Abs (s, T, t1) => Abs (s, T, aux ss Ts (incrs js) depth polar t1)
        | _ => t
      end
  in aux [] [] [] 0 Pos end
   656 
   657 (** Function specialization **)
   658 
   659 fun params_in_equation (@{const "==>"} $ _ $ t2) = params_in_equation t2
   660   | params_in_equation (@{const Trueprop} $ t1) = params_in_equation t1
   661   | params_in_equation (Const (@{const_name "op ="}, _) $ t1 $ _) =
   662     snd (strip_comb t1)
   663   | params_in_equation _ = []
   664 
   665 fun specialize_fun_axiom x x' fixed_js fixed_args extra_args t =
   666   let
   667     val k = fold Integer.max (map maxidx_of_term (fixed_args @ extra_args)) 0
   668             + 1
   669     val t = map_aterms (fn Var ((s, i), T) => Var ((s, k + i), T) | t' => t') t
   670     val fixed_params = filter_indices fixed_js (params_in_equation t)
   671     fun aux args (Abs (s, T, t)) = list_comb (Abs (s, T, aux [] t), args)
   672       | aux args (t1 $ t2) = aux (aux [] t2 :: args) t1
   673       | aux args t =
   674         if t = Const x then
   675           list_comb (Const x', extra_args @ filter_out_indices fixed_js args)
   676         else
   677           let val j = find_index (curry (op =) t) fixed_params in
   678             list_comb (if j >= 0 then nth fixed_args j else t, args)
   679           end
   680   in aux [] t end
   681 
(* Finds the argument positions of the constant "x" that are "static" in "t",
   i.e. that receive the same argument in every call to "x" (or to its ersatz
   constant, as recorded in "ersatz_table"). Returns an ascending list of
   pairs (j, opt):
     - (j, NONE)   if every call passes the same lone schematic variable at
                   position j, and no earlier position has exactly that same
                   singleton set;
     - (j, SOME c) if every call passes the same constant or free variable c
                   at position j.
   Only positions up to the length of the shortest call are considered. *)
fun static_args_in_term ({ersatz_table, ...} : hol_context) x t =
  let
    (* Accumulates the argument lists of all calls to "x" (or its ersatz) in
       "t", including calls nested in arguments and under abstractions. *)
    fun fun_calls (Abs (_, _, t)) _ = fun_calls t []
      | fun_calls (t1 $ t2) args = fun_calls t2 [] #> fun_calls t1 (t2 :: args)
      | fun_calls t args =
        (case t of
           Const (x' as (s', T')) =>
           x = x' orelse (case AList.lookup (op =) ersatz_table s' of
                            SOME s'' => x = (s'', T')
                          | NONE => false)
         | _ => false) ? cons args
    (* Transposes the argument lists into one ordered list (built with
       OrdList.insert) of terms per argument position. The first argument
       list that runs out ends the result, discarding the position being
       collected. *)
    fun call_sets [] [] vs = [vs]
      | call_sets [] uss vs = vs :: call_sets uss [] []
      | call_sets ([] :: _) _ _ = []
      | call_sets ((t :: ts) :: tss) uss vs =
        OrdList.insert Term_Ord.term_ord t vs |> call_sets tss (ts :: uss)
    val sets = call_sets (fun_calls t [] []) [] []
    val indexed_sets = sets ~~ (index_seq 0 (length sets))
  in
    (* A singleton set means all calls agree at that position. For schematic
       variables, the lookup returns the first position with that same set,
       so only that first position is kept. *)
    fold_rev (fn (set, j) =>
                 case set of
                   [Var _] => AList.lookup (op =) indexed_sets set = SOME j
                              ? cons (j, NONE)
                 | [t as Const _] => cons (j, SOME t)
                 | [t as Free _] => cons (j, SOME t)
                 | _ => I) indexed_sets []
  end
   709 fun static_args_in_terms hol_ctxt x =
   710   map (static_args_in_term hol_ctxt x)
   711   #> fold1 (OrdList.inter (prod_ord int_ord (option_ord Term_Ord.term_ord)))
   712 
   713 fun overlapping_indices [] _ = []
   714   | overlapping_indices _ [] = []
   715   | overlapping_indices (ps1 as (j1, t1) :: ps1') (ps2 as (j2, t2) :: ps2') =
   716     if j1 < j2 then overlapping_indices ps1' ps2
   717     else if j1 > j2 then overlapping_indices ps1 ps2'
   718     else overlapping_indices ps1' ps2' |> the_default t2 t1 = t2 ? cons j1
   719 
   720 fun is_eligible_arg Ts t =
   721   let val bad_Ts = map snd (Term.add_vars t []) @ map (nth Ts) (loose_bnos t) in
   722     null bad_Ts orelse
   723     (is_higher_order_type (fastype_of1 (Ts, t)) andalso
   724      forall (not o is_higher_order_type) bad_Ts)
   725   end
   726 
   727 fun special_prefix_for j = special_prefix ^ string_of_int j ^ name_sep
   728 
   729 (* If a constant's definition is picked up deeper than this threshold, we
   730    prevent excessive specialization by not specializing it. *)
   731 val special_max_depth = 20
   732 
   733 val bound_var_prefix = "b"
   734 
(* Specializes calls to equationally defined constants in "t". A call to such
   a constant "x" is eligible when some of its arguments are eligible (see
   "is_eligible_arg") and sit at positions that are static in all of x's
   equations. In that case a specialized constant with those arguments baked
   in replaces the call. Its schematic variables, the loose bound variables of
   the fixed arguments, and the remaining arguments become its arguments.
   New specializations are memoized in "special_funs", keyed by
   (x, fixed positions, fixed arguments), and their specialized equations are
   registered in "simp_table" via "add_simps". Returns "t" unchanged if
   "specialize" is off or "depth" exceeds "special_max_depth". *)
fun specialize_consts_in_term (hol_ctxt as {specialize, simp_table,
                                            special_funs, ...}) depth t =
  if not specialize orelse depth > special_max_depth then
    t
  else
    let
      (* Below the top level, the constant reported by "term_under_def"
         (presumably the one "t" defines) is never specialized within "t". *)
      val blacklist = if depth = 0 then []
                      else case term_under_def t of Const x => [x] | _ => []
      (* "args" are the (already processed) arguments of the current head;
         "Ts" are the types of the enclosing bound variables. Raising SAME ()
         means "do not specialize this call". *)
      fun aux args Ts (Const (x as (s, T))) =
          ((if not (member (op =) blacklist x) andalso not (null args) andalso
               not (String.isPrefix special_prefix s) andalso
               is_equational_fun hol_ctxt x then
              let
                val eligible_args = filter (is_eligible_arg Ts o snd)
                                           (index_seq 0 (length args) ~~ args)
                val _ = not (null eligible_args) orelse raise SAME ()
                val old_axs = equational_fun_axioms hol_ctxt x
                              |> map (destroy_existential_equalities hol_ctxt)
                val static_params = static_args_in_terms hol_ctxt x old_axs
                val fixed_js = overlapping_indices static_params eligible_args
                val _ = not (null fixed_js) orelse raise SAME ()
                val fixed_args = filter_indices fixed_js args
                (* Schematic and loose bound variables of the fixed arguments
                   must still be passed to the specialized constant. *)
                val vars = fold Term.add_vars fixed_args []
                           |> sort (Term_Ord.fast_indexname_ord o pairself fst)
                val bound_js = fold (fn t => fn js => add_loose_bnos (t, 0, js))
                                    fixed_args []
                               |> sort int_ord
                val live_args = filter_out_indices fixed_js args
                val extra_args = map Var vars @ map Bound bound_js @ live_args
                val extra_Ts = map snd vars @ filter_indices bound_js Ts
                (* In the axioms, loose bound variables are replaced by fresh
                   schematic variables with an index above those in "t". *)
                val k = maxidx_of_term t + 1
                fun var_for_bound_no j =
                  Var ((bound_var_prefix ^
                        nat_subscript (find_index (curry (op =) j) bound_js
                                       + 1), k),
                       nth Ts j)
                val fixed_args_in_axiom =
                  map (curry subst_bounds
                             (map var_for_bound_no (index_seq 0 (length Ts))))
                      fixed_args
              in
                (* Reuse an existing specialization if one matches. *)
                case AList.lookup (op =) (!special_funs)
                                  (x, fixed_js, fixed_args_in_axiom) of
                  SOME x' => list_comb (Const x', extra_args)
                | NONE =>
                  let
                    val extra_args_in_axiom =
                      map Var vars @ map var_for_bound_no bound_js
                    val x' as (s', _) =
                      (special_prefix_for (length (!special_funs) + 1) ^ s,
                       extra_Ts @ filter_out_indices fixed_js (binder_types T)
                       ---> body_type T)
                    val new_axs =
                      map (specialize_fun_axiom x x' fixed_js
                               fixed_args_in_axiom extra_args_in_axiom) old_axs
                    val _ =
                      Unsynchronized.change special_funs
                          (cons ((x, fixed_js, fixed_args_in_axiom), x'))
                    val _ = add_simps simp_table s' new_axs
                  in list_comb (Const x', extra_args) end
              end
            else
              raise SAME ())
           handle SAME () => list_comb (Const x, args))
        | aux args Ts (Abs (s, T, t)) =
          list_comb (Abs (s, T, aux [] (T :: Ts) t), args)
        | aux args Ts (t1 $ t2) = aux (aux [] Ts t2 :: args) Ts t1
        | aux args _ t = list_comb (t, args)
    in aux [] [] t end
   804 
(* Fixed argument positions, fixed argument terms, and a constant: the shape
   of the triples that "special_congruence_axioms" builds from the entries of
   "special_funs". *)
type special_triple = int list * term list * styp

(* Base name of the schematic variables that "special_congruence_axiom" passes
   to both constants at positions neither specialization fixes. *)
val cong_var_prefix = "c"
   808 
(* Builds a congruence axiom between two specializations (js1, ts1, x1) and
   (js2, ts2, x2) of a constant of type "T". For every argument position from
   0 up to the largest fixed position:
     - fixed by both: a premise "t1 == t2" is added (premises whose sides are
       equal, and duplicates, are dropped);
     - fixed by only one: that term is passed explicitly to the other constant;
     - fixed by neither: a shared fresh variable is passed to both.
   The conclusion equates each constant applied to the variables obtained from
   "special_bounds" of its fixed terms, followed by these arguments. The
   result is closed with "close_form". *)
fun special_congruence_axiom T (js1, ts1, x1) (js2, ts2, x2) =
  let
    val (bounds1, bounds2) = pairself (map Var o special_bounds) (ts1, ts2)
    val Ts = binder_types T
    val max_j = fold (fold Integer.max) [js1, js2] ~1
    (* Folding from max_j down to 0 and consing yields ascending order. *)
    val (eqs, (args1, args2)) =
      fold (fn j => case pairself (fn ps => AList.lookup (op =) ps j)
                                  (js1 ~~ ts1, js2 ~~ ts2) of
                      (SOME t1, SOME t2) => apfst (cons (t1, t2))
                    | (SOME t1, NONE) => apsnd (apsnd (cons t1))
                    | (NONE, SOME t2) => apsnd (apfst (cons t2))
                    | (NONE, NONE) =>
                      let val v = Var ((cong_var_prefix ^ nat_subscript j, 0),
                                       nth Ts j) in
                        apsnd (pairself (cons v))
                      end) (max_j downto 0) ([], ([], []))
  in
    Logic.list_implies (eqs |> filter_out (op =) |> distinct (op =)
                            |> map Logic.mk_equals,
                        Logic.mk_equals (list_comb (Const x1, bounds1 @ args1),
                                         list_comb (Const x2, bounds2 @ args2)))
    |> close_form (* TODO: needed? *)
  end
   832 
(* Generates congruence axioms among the specializations recorded in
   "special_funs". Entries are grouped by their original constant. Groups
   whose constant satisfies "is_equational_fun_surely_complete" are skipped.
   If the original constant is in "xs" (the constants met while collecting
   axioms), it joins its group unspecialized, as ([], [], x). Within a group,
   each entry is related to a different-constant entry whose fixed
   (position, argument) pairs are a subset of its own. Among such candidates,
   the one with the most fixed positions is chosen. Entries with no such
   candidate are related pairwise. *)
fun special_congruence_axioms (hol_ctxt as {special_funs, ...}) xs =
  let
    val groups =
      !special_funs
      |> map (fn ((x, js, ts), x') => (x, (js, ts, x')))
      |> AList.group (op =)
      |> filter_out (is_equational_fun_surely_complete hol_ctxt o fst)
      |> map (fn (x, zs) => (x, zs |> member (op =) xs x ? cons ([], [], x)))
    (* Fewer fixed positions = more general (larger value). *)
    fun generality (js, _, _) = ~(length js)
    fun is_more_specific (j1, t1, x1) (j2, t2, x2) =
      x1 <> x2 andalso OrdList.subset (prod_ord int_ord Term_Ord.term_ord)
                                      (j2 ~~ t2, j1 ~~ t1)
    (* Pass 1: relate each entry to its closest more general entry; collect
       the entries that have none in "skipped". A singleton group yields
       nothing. *)
    fun do_pass_1 _ [] [_] [_] = I
      | do_pass_1 T skipped _ [] = do_pass_2 T skipped
      | do_pass_1 T skipped all (z :: zs) =
        case filter (is_more_specific z) all
             |> sort (int_ord o pairself generality) of
          [] => do_pass_1 T (z :: skipped) all zs
        | (z' :: _) => cons (special_congruence_axiom T z z')
                       #> do_pass_1 T skipped all zs
    (* Pass 2: relate all skipped entries pairwise. *)
    and do_pass_2 _ [] = I
      | do_pass_2 T (z :: zs) =
        fold (cons o special_congruence_axiom T z) zs #> do_pass_2 T zs
  in fold (fn ((_, T), zs) => do_pass_1 T [] zs zs) groups [] end
   857 
   858 (** Axiom selection **)
   859 
   860 fun all_table_entries table = Symtab.fold (append o snd) table []
   861 fun extra_table table s = Symtab.make [(s, all_table_entries table)]
   862 
   863 fun eval_axiom_for_term j t =
   864   Logic.mk_equals (Const (eval_prefix ^ string_of_int j, fastype_of t), t)
   865 
   866 val is_trivial_equation = the_default false o try (op aconv o Logic.dest_equals)
   867 
   868 (* Prevents divergence in case of cyclic or infinite axiom dependencies. *)
   869 val axioms_max_depth = 255
   870 
(* Collects, transitively, the axioms that give meaning to the constants and
   types occurring in "t". Returns
     (t :: nondefs, defs, got_all_mono_user_axioms, no_poly_user_axioms)
   where:
     - "defs" are the axioms classified as definitional, plus the congruence
       axioms for specialized constants;
     - "nondefs" are the remaining axioms;
     - the first flag says whether every monomorphic user axiom was included
       (true when "user_axioms = SOME true" or there are none);
     - the second flag says whether there are no user axioms with hidden
       polymorphism.
   Raises TOO_LARGE if axioms nest deeper than "axioms_max_depth", and
   NOT_SUPPORTED for "Abs_"/"Rep_" functions of quotient types. *)
fun axioms_for_term
        (hol_ctxt as {thy, ctxt, max_bisim_depth, stds, user_axioms,
                      fast_descrs, evals, def_table, nondef_table,
                      choice_spec_table, user_nondefs, ...}) t =
  let
    (* Shape of the accumulator threaded through the functions below: the
       constants already visited, and the (definitional, non-definitional)
       axioms. Not referenced by name. *)
    type accumulator = styp list * (term list * term list)
    (* Preprocesses axiom "t" (unfold, skolemize, specialize) and adds it to
       the component of "axs" selected by "get"/"app" (fst/apfst for
       definitions, snd/apsnd otherwise), unless it is a trivial equation or
       is already present, in either form, up to alpha-conversion. Then
       recursively collects the axioms of what it mentions, at depth + 1. *)
    fun add_axiom get app depth t (accum as (xs, axs)) =
      let
        val t = t |> unfold_defs_in_term hol_ctxt
                  |> skolemize_term_and_more hol_ctxt ~1
      in
        if is_trivial_equation t then
          accum
        else
          let val t' = t |> specialize_consts_in_term hol_ctxt depth in
            if exists (member (op aconv) (get axs)) [t, t'] then accum
            else add_axioms_for_term (depth + 1) t' (xs, app (cons t') axs)
          end
      end
    and add_def_axiom depth = add_axiom fst apfst depth
    and add_nondef_axiom depth = add_axiom snd apsnd depth
    (* Axioms with premises ("==>") are treated as non-definitional. *)
    and add_maybe_def_axiom depth t =
      (if head_of t <> @{const "==>"} then add_def_axiom
       else add_nondef_axiom) depth t
    (* Equations count as definitional only for constructor patterns. *)
    and add_eq_axiom depth t =
      (if is_constr_pattern_formula thy t then add_def_axiom
       else add_nondef_axiom) depth t
    and add_axioms_for_term depth t (accum as (xs, axs)) =
      case t of
        t1 $ t2 => accum |> fold (add_axioms_for_term depth) [t1, t2]
      | Const (x as (s, T)) =>
        (* Already visited and built-in constants need no axioms. *)
        (if member (op =) xs x orelse
            is_built_in_const thy stds fast_descrs x then
           accum
         else
           let val accum = (x :: xs, axs) in
             if depth > axioms_max_depth then
               raise TOO_LARGE ("Nitpick_Preproc.axioms_for_term.\
                                \add_axioms_for_term",
                                "too many nested axioms (" ^
                                string_of_int depth ^ ")")
             (* Class constants: the of-class axiom instantiated to x's type,
                and the class definition, whichever are available. *)
             else if Refute.is_const_of_class thy x then
               let
                 val class = Logic.class_of_const s
                 val of_class = Logic.mk_of_class (TVar (("'a", 0), [class]),
                                                   class)
                 val ax1 = try (specialize_type thy x) of_class
                 val ax2 = Option.map (specialize_type thy x o snd)
                                      (Refute.get_classdef thy class)
               in
                 fold (add_maybe_def_axiom depth) (map_filter I [ax1, ax2])
                      accum
               end
             else if is_constr thy stds x then
               accum
             else if is_equational_fun hol_ctxt x then
               fold (add_eq_axiom depth) (equational_fun_axioms hol_ctxt x)
                    accum
             else if is_choice_spec_fun hol_ctxt x then
               fold (add_nondef_axiom depth)
                    (nondef_props_for_const thy true choice_spec_table x) accum
             (* "Abs_" functions: non-definitional properties, plus the
                definitions from "def_table" for funky typedefs and "nat". *)
             else if is_abs_fun thy x then
               if is_quot_type thy (range_type T) then
                 raise NOT_SUPPORTED "\"Abs_\" function of quotient type"
               else
                 accum |> fold (add_nondef_axiom depth)
                               (nondef_props_for_const thy false nondef_table x)
                       |> (is_funky_typedef thy (range_type T) orelse
                           range_type T = nat_T)
                          ? fold (add_maybe_def_axiom depth)
                                 (nondef_props_for_const thy true
                                                    (extra_table def_table s) x)
             (* "Rep_" functions: as for "Abs_", plus the matching "Abs_"
                function (via mate_of_rep_fun) and the inverse axioms. *)
             else if is_rep_fun thy x then
               if is_quot_type thy (domain_type T) then
                 raise NOT_SUPPORTED "\"Rep_\" function of quotient type"
               else
                 accum |> fold (add_nondef_axiom depth)
                               (nondef_props_for_const thy false nondef_table x)
                       (* NOTE(review): the quotient check above tests
                          domain_type T (the abstract type), but this test
                          uses range_type T, exactly as in the "Abs_" branch;
                          confirm this is intended for "Rep_" functions. *)
                       |> (is_funky_typedef thy (range_type T) orelse
                           range_type T = nat_T)
                          ? fold (add_maybe_def_axiom depth)
                                 (nondef_props_for_const thy true
                                                    (extra_table def_table s) x)
                       |> add_axioms_for_term depth
                                              (Const (mate_of_rep_fun thy x))
                       |> fold (add_def_axiom depth)
                               (inverse_axioms_for_rep_fun thy x)
             (* Any other constant: its non-definitional properties, unless
                the user disabled them ("user_axioms = SOME false"). *)
             else
               accum |> user_axioms <> SOME false
                        ? fold (add_nondef_axiom depth)
                               (nondef_props_for_const thy false nondef_table x)
           end)
        |> add_axioms_for_type depth T
      | Free (_, T) => add_axioms_for_type depth T accum
      | Var (_, T) => add_axioms_for_type depth T accum
      | Bound _ => accum
      | Abs (_, T, t) => accum |> add_axioms_for_term depth t
                               |> add_axioms_for_type depth T
    (* Axioms for a type: recurse into type arguments; type variables
       contribute their sort's class axioms; other type constructors add
       typedef, quotient, or codatatype bisimulation axioms (the latter only
       when max_bisim_depth >= 0). *)
    and add_axioms_for_type depth T =
      case T of
        Type (@{type_name fun}, Ts) => fold (add_axioms_for_type depth) Ts
      | Type (@{type_name "*"}, Ts) => fold (add_axioms_for_type depth) Ts
      | @{typ prop} => I
      | @{typ bool} => I
      | @{typ unit} => I
      | TFree (_, S) => add_axioms_for_sort depth T S
      | TVar (_, S) => add_axioms_for_sort depth T S
      | Type (z as (_, Ts)) =>
        fold (add_axioms_for_type depth) Ts
        #> (if is_pure_typedef thy T then
              fold (add_maybe_def_axiom depth) (optimized_typedef_axioms thy z)
            else if is_quot_type thy T then
              fold (add_def_axiom depth)
                   (optimized_quot_type_axioms ctxt stds z)
            else if max_bisim_depth >= 0 andalso is_codatatype thy T then
              fold (add_maybe_def_axiom depth)
                   (codatatype_bisim_axioms hol_ctxt T)
            else
              I)
    (* Adds the axioms of every class in the completed sort "S" as
       non-definitional axioms. In each, the single type variable is
       instantiated to "T"; more than one type variable raises TERM. Classes
       for which AxClass.get_info raises ERROR contribute nothing. *)
    and add_axioms_for_sort depth T S =
      let
        val supers = Sign.complete_sort thy S
        val class_axioms =
          maps (fn class => map prop_of (AxClass.get_info thy class |> #axioms
                                         handle ERROR _ => [])) supers
        val monomorphic_class_axioms =
          map (fn t => case Term.add_tvars t [] of
                         [] => t
                       | [(x, S)] =>
                         monomorphic_term (Vartab.make [(x, (S, T))]) t
                       | _ => raise TERM ("Nitpick_Preproc.axioms_for_term.\
                                          \add_axioms_for_sort", [t]))
              class_axioms
      in fold (add_nondef_axiom depth) monomorphic_class_axioms end
    val (mono_user_nondefs, poly_user_nondefs) =
      List.partition (null o Term.hidden_polymorphism) user_nondefs
    val eval_axioms = map2 eval_axiom_for_term (index_seq 0 (length evals))
                           evals
    (* Start from "t", add one definitional axiom per term to evaluate, and
       the monomorphic user axioms if the user asked for them. *)
    val (xs, (defs, nondefs)) =
      ([], ([], [])) |> add_axioms_for_term 1 t 
                     |> fold_rev (add_def_axiom 1) eval_axioms
                     |> user_axioms = SOME true
                        ? fold (add_nondef_axiom 1) mono_user_nondefs
    val defs = defs @ special_congruence_axioms hol_ctxt xs
    val got_all_mono_user_axioms =
      (user_axioms = SOME true orelse null mono_user_nondefs)
  in (t :: nondefs, defs, got_all_mono_user_axioms, null poly_user_nondefs) end
  1018 
  1019 (** Simplification of constructor/selector terms **)
  1020 
(* Simplifies constructor/selector combinations in "t". Arguments are
   simplified before their head. The rewrites are:
     - "Rep_Frac (Abs_Frac t1)" and "Abs_Frac (Rep_Frac t1)", when not further
       applied, become "t1";
     - a fully applied constructor-like constant whose i-th argument is the
       i-th selector applied to one and the same term t', with t' of the
       constructor's result type, becomes t';
     - a selector applied to an application of its own constructor becomes
       the selected argument, applied to any remaining arguments, provided
       the constructor has no pair-typed arguments.
   Everything else is rebuilt with "betapplys". *)
fun simplify_constrs_and_sels thy t =
  let
    (* NOTE(review): checks only the selector's position, not which
       constructor the selector belongs to (the selector case below does
       compare constr_name_for_sel_like); confirm that mixed
       constructor/selector terms cannot reach this function. *)
    fun is_nth_sel_on t' n (Const (s, _) $ t) =
        (t = t' andalso is_sel_like_and_no_discr s andalso
         sel_no_from_name s = n)
      | is_nth_sel_on _ _ _ = false
    fun do_term (Const (@{const_name Rep_Frac}, _)
                 $ (Const (@{const_name Abs_Frac}, _) $ t1)) [] = do_term t1 []
      | do_term (Const (@{const_name Abs_Frac}, _)
                 $ (Const (@{const_name Rep_Frac}, _) $ t1)) [] = do_term t1 []
      | do_term (t1 $ t2) args = do_term t1 (do_term t2 [] :: args)
      (* SAME () means "no rewrite applies"; the application is rebuilt. *)
      | do_term (t as Const (x as (s, T))) (args as _ :: _) =
        ((if is_constr_like thy x then
            if length args = num_binder_types T then
              case hd args of
                Const (_, T') $ t' =>
                if domain_type T' = body_type T andalso
                   forall (uncurry (is_nth_sel_on t'))
                          (index_seq 0 (length args) ~~ args) then
                  t'
                else
                  raise SAME ()
              | _ => raise SAME ()
            else
              raise SAME ()
          else if is_sel_like_and_no_discr s then
            case strip_comb (hd args) of
              (Const (x' as (s', T')), ts') =>
              if is_constr_like thy x' andalso
                 constr_name_for_sel_like s = s' andalso
                 not (exists is_pair_type (binder_types T')) then
                list_comb (nth ts' (sel_no_from_name s), tl args)
              else
                raise SAME ()
            | _ => raise SAME ()
          else
            raise SAME ())
         handle SAME () => betapplys (t, args))
      | do_term (Abs (s, T, t')) args =
        betapplys (Abs (s, T, do_term t' []), args)
      | do_term t args = betapplys (t, args)
  in do_term t [] end
  1063 
  1064 (** Quantifier massaging: Distributing quantifiers **)
  1065 
(* Distributes quantifiers over connectives:
     ALL x. P & Q    becomes  (ALL x. P) & (ALL x. Q)
     ALL x. ~ P      becomes  ~ (EX x. P)
     EX x. P | Q     becomes  (EX x. P) | (EX x. Q)
     EX x. P --> Q   becomes  (ALL x. P) --> (EX x. Q)
     EX x. ~ P       becomes  ~ (ALL x. P)
   Quantifiers whose bound variable does not occur in their body are dropped.
   For "All", the connective is matched on the body as given; for "Ex", it is
   matched on the already-distributed body. *)
fun distribute_quantifiers t =
  case t of
    (t0 as Const (@{const_name All}, T0)) $ Abs (s, T1, t1) =>
    (case t1 of
       (t10 as @{const "op &"}) $ t11 $ t12 =>
       t10 $ distribute_quantifiers (t0 $ Abs (s, T1, t11))
           $ distribute_quantifiers (t0 $ Abs (s, T1, t12))
     | (t10 as @{const Not}) $ t11 =>
       t10 $ distribute_quantifiers (Const (@{const_name Ex}, T0)
                                     $ Abs (s, T1, t11))
     | t1 =>
       if not (loose_bvar1 (t1, 0)) then
         distribute_quantifiers (incr_boundvars ~1 t1)
       else
         t0 $ Abs (s, T1, distribute_quantifiers t1))
  | (t0 as Const (@{const_name Ex}, T0)) $ Abs (s, T1, t1) =>
    (* NOTE(review): the body is distributed here and its operands are
       distributed again below (also in the fallback case), so nested
       existentials are traversed repeatedly. The second pass can rewrite
       further (e.g. an "All" whose body only became a conjunction after the
       first pass), so simply dropping it would change the output. *)
    (case distribute_quantifiers t1 of
       (t10 as @{const "op |"}) $ t11 $ t12 =>
       t10 $ distribute_quantifiers (t0 $ Abs (s, T1, t11))
           $ distribute_quantifiers (t0 $ Abs (s, T1, t12))
     | (t10 as @{const "op -->"}) $ t11 $ t12 =>
       t10 $ distribute_quantifiers (Const (@{const_name All}, T0)
                                     $ Abs (s, T1, t11))
           $ distribute_quantifiers (t0 $ Abs (s, T1, t12))
     | (t10 as @{const Not}) $ t11 =>
       t10 $ distribute_quantifiers (Const (@{const_name All}, T0)
                                     $ Abs (s, T1, t11))
     | t1 =>
       if not (loose_bvar1 (t1, 0)) then
         distribute_quantifiers (incr_boundvars ~1 t1)
       else
         t0 $ Abs (s, T1, distribute_quantifiers t1))
  | t1 $ t2 => distribute_quantifiers t1 $ distribute_quantifiers t2
  | Abs (s, T, t') => Abs (s, T, distribute_quantifiers t')
  | _ => t
  1101 
  1102 (** Quantifier massaging: Pushing quantifiers inward **)
  1103 
  1104 fun renumber_bounds j n f t =
  1105   case t of
  1106     t1 $ t2 => renumber_bounds j n f t1 $ renumber_bounds j n f t2
  1107   | Abs (s, T, t') => Abs (s, T, renumber_bounds (j + 1) n f t')
  1108   | Bound j' =>
  1109     Bound (if j' >= j andalso j' < j + n then f (j' - j) + j else j')
  1110   | _ => t
  1111 
(* Maximum size of a quantifier cluster for which "push_quantifiers_inward"
   chooses the quantifier order by trying every permutation (exponential in
   the cluster size). Larger clusters use a greedy heuristic inspired by
   Claessen & Soerensson's polynomial binary splitting procedure (p. 5 of
   their MODEL 2003 paper). *)
val quantifier_cluster_threshold = 7
  1117 
  1118 val push_quantifiers_inward =
  1119   let
  1120     fun aux quant_s ss Ts t =
  1121       (case t of
  1122          Const (s0, _) $ Abs (s1, T1, t1 as _ $ _) =>
  1123          if s0 = quant_s then
  1124            aux s0 (s1 :: ss) (T1 :: Ts) t1
  1125          else if quant_s = "" andalso
  1126                  (s0 = @{const_name All} orelse s0 = @{const_name Ex}) then
  1127            aux s0 [s1] [T1] t1
  1128          else
  1129            raise SAME ()
  1130        | _ => raise SAME ())
  1131       handle SAME () =>
  1132              case t of
  1133                t1 $ t2 =>
  1134                if quant_s = "" then
  1135                  aux "" [] [] t1 $ aux "" [] [] t2
  1136                else
  1137                  let
  1138                    val typical_card = 4
  1139                    fun big_union proj ps =
  1140                      fold (fold (insert (op =)) o proj) ps []
  1141                    val (ts, connective) = strip_any_connective t
  1142                    val T_costs =
  1143                      map (bounded_card_of_type 65536 typical_card []) Ts
  1144                    val t_costs = map size_of_term ts
  1145                    val num_Ts = length Ts
  1146                    val flip = curry (op -) (num_Ts - 1)
  1147                    val t_boundss = map (map flip o loose_bnos) ts
  1148                    fun merge costly_boundss [] = costly_boundss
  1149                      | merge costly_boundss (j :: js) =
  1150                        let
  1151                          val (yeas, nays) =
  1152                            List.partition (fn (bounds, _) =>
  1153                                               member (op =) bounds j)
  1154                                           costly_boundss
  1155                          val yeas_bounds = big_union fst yeas
  1156                          val yeas_cost = Integer.sum (map snd yeas)
  1157                                          * nth T_costs j
  1158                        in merge ((yeas_bounds, yeas_cost) :: nays) js end
  1159                    val cost = Integer.sum o map snd oo merge
  1160                    fun heuristically_best_permutation _ [] = []
  1161                      | heuristically_best_permutation costly_boundss js =
  1162                        let
  1163                          val (costly_boundss, (j, js)) =
  1164                            js |> map (`(merge costly_boundss o single))
  1165                               |> sort (int_ord
  1166                                        o pairself (Integer.sum o map snd o fst))
  1167                               |> split_list |>> hd ||> pairf hd tl
  1168                        in
  1169                          j :: heuristically_best_permutation costly_boundss js
  1170                        end
  1171                    val js =
  1172                      if length Ts <= quantifier_cluster_threshold then
  1173                        all_permutations (index_seq 0 num_Ts)
  1174                        |> map (`(cost (t_boundss ~~ t_costs)))
  1175                        |> sort (int_ord o pairself fst) |> hd |> snd
  1176                      else
  1177                        heuristically_best_permutation (t_boundss ~~ t_costs)
  1178                                                       (index_seq 0 num_Ts)
  1179                    val back_js = map (fn j => find_index (curry (op =) j) js)
  1180                                      (index_seq 0 num_Ts)
  1181                    val ts = map (renumber_bounds 0 num_Ts (nth back_js o flip))
  1182                                 ts
  1183                    fun mk_connection [] =
  1184                        raise ARG ("Nitpick_Preproc.push_quantifiers_inward.aux.\
  1185                                   \mk_connection", "")
  1186                      | mk_connection ts_cum_bounds =
  1187                        ts_cum_bounds |> map fst
  1188                        |> foldr1 (fn (t1, t2) => connective $ t1 $ t2)
  1189                    fun build ts_cum_bounds [] = ts_cum_bounds |> mk_connection
  1190                      | build ts_cum_bounds (j :: js) =
  1191                        let
  1192                          val (yeas, nays) =
  1193                            List.partition (fn (_, bounds) =>
  1194                                               member (op =) bounds j)
  1195                                           ts_cum_bounds
  1196                            ||> map (apfst (incr_boundvars ~1))
  1197                        in
  1198                          if null yeas then
  1199                            build nays js
  1200                          else
  1201                            let val T = nth Ts (flip j) in
  1202                              build ((Const (quant_s, (T --> bool_T) --> bool_T)
  1203                                      $ Abs (nth ss (flip j), T,
  1204                                             mk_connection yeas),
  1205                                       big_union snd yeas) :: nays) js
  1206                            end
  1207                        end
  1208                  in build (ts ~~ t_boundss) js end
  1209              | Abs (s, T, t') => Abs (s, T, aux "" [] [] t')
  1210              | _ => t
  1211   in aux "" [] [] end
  1212 
  1213 (** Inference of finite functions **)
  1214 
  1215 fun finitize_all_types_of_funs (hol_ctxt as {thy, ...}) binarize finitizes monos
  1216                                (nondef_ts, def_ts) =
  1217   let
  1218     val Ts = ground_types_in_terms hol_ctxt binarize (nondef_ts @ def_ts)
  1219              |> filter_out (fn Type (@{type_name fun_box}, _) => true
  1220                              | @{typ signed_bit} => true
  1221                              | @{typ unsigned_bit} => true
  1222                              | T => is_small_finite_type hol_ctxt T orelse
  1223                                     triple_lookup (type_match thy) monos T
  1224                                     = SOME (SOME false))
  1225   in fold (finitize_funs hol_ctxt binarize finitizes) Ts (nondef_ts, def_ts) end
  1226 
  1227 (** Preprocessor entry point **)
  1228 
  1229 val max_skolem_depth = 4
  1230 
  1231 fun preprocess_term (hol_ctxt as {thy, stds, binary_ints, destroy_constrs,
  1232                                   boxes, ...}) finitizes monos t =
  1233   let
  1234     val (nondef_ts, def_ts, got_all_mono_user_axioms, no_poly_user_axioms) =
  1235       t |> unfold_defs_in_term hol_ctxt
  1236         |> close_form
  1237         |> skolemize_term_and_more hol_ctxt max_skolem_depth
  1238         |> specialize_consts_in_term hol_ctxt 0
  1239         |> axioms_for_term hol_ctxt
  1240     val binarize =
  1241       is_standard_datatype thy stds nat_T andalso
  1242       case binary_ints of
  1243         SOME false => false
  1244       | _ => forall (may_use_binary_ints false) nondef_ts andalso
  1245              forall (may_use_binary_ints true) def_ts andalso
  1246              (binary_ints = SOME true orelse
  1247               exists should_use_binary_ints (nondef_ts @ def_ts))
  1248     val box = exists (not_equal (SOME false) o snd) boxes
  1249     val table =
  1250       Termtab.empty
  1251       |> box ? fold (add_to_uncurry_table thy) (nondef_ts @ def_ts)
  1252     fun do_rest def =
  1253       binarize ? binarize_nat_and_int_in_term
  1254       #> box ? uncurry_term table
  1255       #> box ? box_fun_and_pair_in_term hol_ctxt def
  1256       #> destroy_constrs ? (pull_out_universal_constrs hol_ctxt def
  1257                             #> pull_out_existential_constrs hol_ctxt
  1258                             #> destroy_pulled_out_constrs hol_ctxt def)
  1259       #> curry_assms
  1260       #> destroy_universal_equalities
  1261       #> destroy_existential_equalities hol_ctxt
  1262       #> simplify_constrs_and_sels thy
  1263       #> distribute_quantifiers
  1264       #> push_quantifiers_inward
  1265       #> close_form
  1266       #> Term.map_abs_vars shortest_name
  1267     val nondef_ts = map (do_rest false) nondef_ts
  1268     val def_ts = map (do_rest true) def_ts
  1269     val (nondef_ts, def_ts) =
  1270       finitize_all_types_of_funs hol_ctxt binarize finitizes monos
  1271                                  (nondef_ts, def_ts)
  1272   in
  1273     (nondef_ts, def_ts, got_all_mono_user_axioms, no_poly_user_axioms, binarize)
  1274   end
  1275 
  1276 end;