src/HOL/SMT/Tools/smt_normalize.ML
changeset 32618 42865636d006
child 32740 9dd0a2f83429
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/SMT/Tools/smt_normalize.ML	Fri Sep 18 18:13:19 2009 +0200
     1.3 @@ -0,0 +1,408 @@
     1.4 +(*  Title:      HOL/SMT/Tools/smt_normalize.ML
     1.5 +    Author:     Sascha Boehme, TU Muenchen
     1.6 +
     1.7 +Normalization steps on theorems required by SMT solvers:
     1.8 +  * unfold trivial let expressions,
     1.9 +  * replace negative numerals by negated positive numerals,
    1.10 +  * embed natural numbers into integers,
    1.11 +  * add extra rules specifying types and constants which occur frequently,
    1.12 +  * lift lambda terms,
    1.13 +  * make applications explicit for functions with varying number of arguments,
    1.14 +  * fully translate into object logic, add universal closure. 
    1.15 +*)
    1.16 +
    1.17 +signature SMT_NORMALIZE =
    1.18 +sig
    1.19 +  val normalize_rule: Proof.context -> thm -> thm
    1.20 +  val instantiate_free: Thm.cterm * Thm.cterm -> thm -> thm
    1.21 +  val discharge_definition: Thm.cterm -> thm -> thm
    1.22 +
    1.23 +  val trivial_let: Proof.context -> thm list -> thm list
    1.24 +  val positive_numerals: Proof.context -> thm list -> thm list
    1.25 +  val nat_as_int: Proof.context -> thm list -> thm list
    1.26 +  val unfold_defs: bool Config.T
    1.27 +  val add_pair_rules: Proof.context -> thm list -> thm list
    1.28 +  val add_fun_upd_rules: Proof.context -> thm list -> thm list
    1.29 +  val add_abs_min_max_rules: Proof.context -> thm list -> thm list
    1.30 +
    1.31 +  datatype config =
    1.32 +    RewriteTrivialLets |
    1.33 +    RewriteNegativeNumerals |
    1.34 +    RewriteNaturalNumbers |
    1.35 +    AddPairRules |
    1.36 +    AddFunUpdRules |
    1.37 +    AddAbsMinMaxRules
    1.38 +
    1.39 +  val normalize: config list -> Proof.context -> thm list ->
    1.40 +    Thm.cterm list * thm list
    1.41 +
    1.42 +  val setup: theory -> theory
    1.43 +end
    1.44 +
    1.45 +structure SMT_Normalize: SMT_NORMALIZE =
    1.46 +struct
    1.47 +
    1.48 +val norm_binder_conv = Conv.try_conv (More_Conv.rewrs_conv [
    1.49 +  @{lemma "All P == ALL x. P x" by (rule reflexive)},
    1.50 +  @{lemma "Ex P == EX x. P x" by (rule reflexive)},
    1.51 +  @{lemma "Let c P == let x = c in P x" by (rule reflexive)}])
    1.52 +
    1.53 +fun cert ctxt = Thm.cterm_of (ProofContext.theory_of ctxt)
    1.54 +
    1.55 +fun norm_meta_def cv thm = 
    1.56 +  let val thm' = Thm.combination thm (Thm.reflexive cv)
    1.57 +  in Thm.transitive thm' (Thm.beta_conversion false (Thm.rhs_of thm')) end
    1.58 +
    1.59 +fun norm_def ctxt thm =
    1.60 +  (case Thm.prop_of thm of
    1.61 +    Const (@{const_name "=="}, _) $ _ $ Abs (_, T, _) =>
    1.62 +      let val v = Var ((Name.uu, #maxidx (Thm.rep_thm thm) + 1), T)
    1.63 +      in norm_def ctxt (norm_meta_def (cert ctxt v) thm) end
    1.64 +  | @{term Trueprop} $ (Const (@{const_name "op ="}, _) $ _ $ Abs _) =>
    1.65 +      norm_def ctxt (thm RS @{thm fun_cong})
    1.66 +  | _ => thm)
    1.67 +
    1.68 +fun normalize_rule ctxt =
    1.69 +  Conv.fconv_rule (
    1.70 +    Thm.beta_conversion true then_conv
    1.71 +    Thm.eta_conversion then_conv
    1.72 +    More_Conv.bottom_conv (K norm_binder_conv) ctxt) #>
    1.73 +  norm_def ctxt #>
    1.74 +  Drule.forall_intr_vars #>
    1.75 +  Conv.fconv_rule ObjectLogic.atomize
    1.76 +
    1.77 +fun instantiate_free (cv, ct) thm =
    1.78 +  if Term.exists_subterm (equal (Thm.term_of cv)) (Thm.prop_of thm)
    1.79 +  then Thm.forall_elim ct (Thm.forall_intr cv thm)
    1.80 +  else thm
    1.81 +
    1.82 +fun discharge_definition ct thm =
    1.83 +  let val (cv, cu) = Thm.dest_equals ct
    1.84 +  in
    1.85 +    Thm.implies_intr ct thm
    1.86 +    |> instantiate_free (cv, cu)
    1.87 +    |> (fn thm => Thm.implies_elim thm (Thm.reflexive cu))
    1.88 +  end
    1.89 +
    1.90 +fun if_conv c cv1 cv2 ct = (if c (Thm.term_of ct) then cv1 else cv2) ct
    1.91 +fun if_true_conv c cv = if_conv c cv Conv.all_conv
    1.92 +
    1.93 +
    1.94 +(* simplification of trivial let expressions (whose bound variables occur at
    1.95 +   most once) *)
    1.96 +
    1.97 +local
    1.98 +  fun count i (Bound j) = if j = i then 1 else 0
    1.99 +    | count i (t $ u) = count i t + count i u
   1.100 +    | count i (Abs (_, _, t)) = count (i + 1) t
   1.101 +    | count _ _ = 0
   1.102 +
   1.103 +  fun is_trivial_let (Const (@{const_name Let}, _) $ _ $ Abs (_, _, t)) =
   1.104 +        (count 0 t <= 1)
   1.105 +    | is_trivial_let _ = false
   1.106 +
   1.107 +  fun let_conv _ = if_true_conv is_trivial_let (Conv.rewr_conv @{thm Let_def})
   1.108 +
   1.109 +  fun cond_let_conv ctxt = if_true_conv (Term.exists_subterm is_trivial_let)
   1.110 +    (More_Conv.top_conv let_conv ctxt)
   1.111 +in
   1.112 +fun trivial_let ctxt = map (Conv.fconv_rule (cond_let_conv ctxt))
   1.113 +end
   1.114 +
   1.115 +
   1.116 +(* rewriting of negative integer numerals into positive numerals *)
   1.117 +
   1.118 +local
   1.119 +  fun neg_numeral @{term Int.Min} = true
   1.120 +    | neg_numeral _ = false
   1.121 +  fun is_number_sort thy T = Sign.of_sort thy (T, @{sort number_ring})
   1.122 +  fun is_neg_number ctxt (Const (@{const_name number_of}, T) $ t) =
   1.123 +        Term.exists_subterm neg_numeral t andalso
   1.124 +        is_number_sort (ProofContext.theory_of ctxt) (Term.body_type T)
   1.125 +    | is_neg_number _ _ = false
   1.126 +  fun has_neg_number ctxt = Term.exists_subterm (is_neg_number ctxt)
   1.127 +
   1.128 +  val pos_numeral_ss = HOL_ss
   1.129 +    addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}]
   1.130 +    addsimps [@{thm Int.numeral_1_eq_1}]
   1.131 +    addsimps @{thms Int.pred_bin_simps}
   1.132 +    addsimps @{thms Int.normalize_bin_simps}
   1.133 +    addsimps @{lemma
   1.134 +      "Int.Min = - Int.Bit1 Int.Pls"
   1.135 +      "Int.Bit0 (- Int.Pls) = - Int.Pls"
   1.136 +      "Int.Bit0 (- k) = - Int.Bit0 k"
   1.137 +      "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)"
   1.138 +      by simp_all (simp add: pred_def)}
   1.139 +
   1.140 +  fun pos_conv ctxt = if_conv (is_neg_number ctxt)
   1.141 +    (Simplifier.rewrite (Simplifier.context ctxt pos_numeral_ss))
   1.142 +    Conv.no_conv
   1.143 +
   1.144 +  fun cond_pos_conv ctxt = if_true_conv (has_neg_number ctxt)
   1.145 +    (More_Conv.top_sweep_conv pos_conv ctxt)
   1.146 +in
   1.147 +fun positive_numerals ctxt = map (Conv.fconv_rule (cond_pos_conv ctxt))
   1.148 +end
   1.149 +
   1.150 +
   1.151 +(* embedding of standard natural number operations into integer operations *)
   1.152 +
   1.153 +local
   1.154 +  val nat_embedding = @{lemma
   1.155 +    "nat (int n) = n"
   1.156 +    "i >= 0 --> int (nat i) = i"
   1.157 +    "i < 0 --> int (nat i) = 0"
   1.158 +    by simp_all}
   1.159 +
   1.160 +  val nat_rewriting = @{lemma
   1.161 +    "0 = nat 0"
   1.162 +    "1 = nat 1"
   1.163 +    "number_of i = nat (number_of i)"
   1.164 +    "int (nat 0) = 0"
   1.165 +    "int (nat 1) = 1"
   1.166 +    "a < b = (int a < int b)"
   1.167 +    "a <= b = (int a <= int b)"
   1.168 +    "Suc a = nat (int a + 1)"
   1.169 +    "a + b = nat (int a + int b)"
   1.170 +    "a - b = nat (int a - int b)"
   1.171 +    "a * b = nat (int a * int b)"
   1.172 +    "a div b = nat (int a div int b)"
   1.173 +    "a mod b = nat (int a mod int b)"
   1.174 +    "int (nat (int a + int b)) = int a + int b"
   1.175 +    "int (nat (int a * int b)) = int a * int b"
   1.176 +    "int (nat (int a div int b)) = int a div int b"
   1.177 +    "int (nat (int a mod int b)) = int a mod int b"
   1.178 +    by (simp add: nat_mult_distrib nat_div_distrib nat_mod_distrib
   1.179 +      int_mult[symmetric] zdiv_int[symmetric] zmod_int[symmetric])+}
   1.180 +
   1.181 +  fun on_positive num f x = 
   1.182 +    (case try HOLogic.dest_number (Thm.term_of num) of
   1.183 +      SOME (_, i) => if i >= 0 then SOME (f x) else NONE
   1.184 +    | NONE => NONE)
   1.185 +
   1.186 +  val cancel_int_nat_ss = HOL_ss
   1.187 +    addsimps [@{thm Nat_Numeral.nat_number_of}]
   1.188 +    addsimps [@{thm Nat_Numeral.int_nat_number_of}]
   1.189 +    addsimps @{thms neg_simps}
   1.190 +
   1.191 +  fun cancel_int_nat_simproc _ ss ct = 
   1.192 +    let
   1.193 +      val num = Thm.dest_arg (Thm.dest_arg ct)
   1.194 +      val goal = Thm.mk_binop @{cterm "op == :: int => _"} ct num
   1.195 +      val simpset = Simplifier.inherit_context ss cancel_int_nat_ss
   1.196 +      fun tac _ = Simplifier.simp_tac simpset 1
   1.197 +    in on_positive num (Goal.prove_internal [] goal) tac end
   1.198 +
   1.199 +  val nat_ss = HOL_ss
   1.200 +    addsimps nat_rewriting
   1.201 +    addsimprocs [Simplifier.make_simproc {
   1.202 +      name = "cancel_int_nat_num", lhss = [@{cpat "int (nat _)"}],
   1.203 +      proc = cancel_int_nat_simproc, identifier = [] }]
   1.204 +
   1.205 +  fun conv ctxt = Simplifier.rewrite (Simplifier.context ctxt nat_ss)
   1.206 +
   1.207 +  val uses_nat_type = Term.exists_type (Term.exists_subtype (equal @{typ nat}))
   1.208 +in
   1.209 +fun nat_as_int ctxt thms =
   1.210 +  let
   1.211 +    fun norm thm uses_nat =
   1.212 +      if not (uses_nat_type (Thm.prop_of thm)) then (thm, uses_nat)
   1.213 +      else (Conv.fconv_rule (conv ctxt) thm, true)
   1.214 +    val (thms', uses_nat) = fold_map norm thms false
   1.215 +  in if uses_nat then nat_embedding @ thms' else thms' end
   1.216 +end
   1.217 +
   1.218 +
   1.219 +(* include additional rules *)
   1.220 +
   1.221 +val (unfold_defs, unfold_defs_setup) =
   1.222 +  Attrib.config_bool "smt_unfold_defs" true
   1.223 +
   1.224 +local
   1.225 +  val pair_rules = [@{thm fst_conv}, @{thm snd_conv}, @{thm pair_collapse}]
   1.226 +
   1.227 +  val pair_type = (fn Type (@{type_name "*"}, _) => true | _ => false)
   1.228 +  val exists_pair_type = Term.exists_type (Term.exists_subtype pair_type)
   1.229 +
   1.230 +  val fun_upd_rules = [@{thm fun_upd_same}, @{thm fun_upd_apply}]
   1.231 +  val is_fun_upd = (fn Const (@{const_name fun_upd}, _) => true | _ => false)
   1.232 +  val exists_fun_upd = Term.exists_subterm is_fun_upd
   1.233 +in
   1.234 +fun add_pair_rules _ thms =
   1.235 +  thms
   1.236 +  |> exists (exists_pair_type o Thm.prop_of) thms ? append pair_rules
   1.237 +
   1.238 +fun add_fun_upd_rules _ thms =
   1.239 +  thms
   1.240 +  |> exists (exists_fun_upd o Thm.prop_of) thms ? append fun_upd_rules
   1.241 +end
   1.242 +
   1.243 +
   1.244 +local
   1.245 +  fun mk_entry t thm = (Term.head_of t, (thm, thm RS @{thm eq_reflection}))
   1.246 +  fun prepare_def thm =
   1.247 +    (case HOLogic.dest_Trueprop (Thm.prop_of thm) of
   1.248 +      Const (@{const_name "op ="}, _) $ t $ _ => mk_entry t thm
   1.249 +    | t => raise TERM ("prepare_def", [t]))
   1.250 +
   1.251 +  val defs = map prepare_def [
   1.252 +    @{thm abs_if[where 'a = int]}, @{thm abs_if[where 'a = real]},
   1.253 +    @{thm min_def[where 'a = int]}, @{thm min_def[where 'a = real]},
   1.254 +    @{thm max_def[where 'a = int]}, @{thm max_def[where 'a = real]}]
   1.255 +
   1.256 +  fun add_sym t = if AList.defined (op =) defs t then insert (op =) t else I
   1.257 +  fun add_syms thms = fold (Term.fold_aterms add_sym o Thm.prop_of) thms []
   1.258 +
   1.259 +  fun unfold_conv ctxt ct =
   1.260 +    (case AList.lookup (op =) defs (Term.head_of (Thm.term_of ct)) of
   1.261 +      SOME (_, eq) => Conv.rewr_conv eq
   1.262 +    | NONE => Conv.all_conv) ct
   1.263 +in
   1.264 +fun add_abs_min_max_rules ctxt thms =
   1.265 +  if Config.get ctxt unfold_defs
   1.266 +  then map (Conv.fconv_rule (More_Conv.bottom_conv unfold_conv ctxt)) thms
   1.267 +  else map fst (map_filter (AList.lookup (op =) defs) (add_syms thms)) @ thms
   1.268 +end
   1.269 +
   1.270 +
   1.271 +(* lift lambda terms into additional rules *)
   1.272 +
   1.273 +local
   1.274 +  val meta_eq = @{cpat "op =="}
   1.275 +  val meta_eqT = hd (Thm.dest_ctyp (Thm.ctyp_of_term meta_eq))
   1.276 +  fun inst_meta cT = Thm.instantiate_cterm ([(meta_eqT, cT)], []) meta_eq
   1.277 +  fun mk_meta_eq ct cu = Thm.mk_binop (inst_meta (Thm.ctyp_of_term ct)) ct cu
   1.278 +
   1.279 +  fun lambda_conv conv =
   1.280 +    let
   1.281 +      fun sub_conv cvs ctxt ct =
   1.282 +        (case Thm.term_of ct of
   1.283 +          Const (@{const_name All}, _) $ Abs _ => quant_conv cvs ctxt
   1.284 +        | Const (@{const_name Ex}, _) $ Abs _ => quant_conv cvs ctxt
   1.285 +        | Const _ $ Abs _ => Conv.arg_conv (at_lambda_conv cvs ctxt)
   1.286 +        | Const (@{const_name Let}, _) $ _ $ Abs _ => Conv.combination_conv
   1.287 +            (Conv.arg_conv (sub_conv cvs ctxt)) (abs_conv cvs ctxt)
   1.288 +        | Abs _ => at_lambda_conv cvs ctxt
   1.289 +        | _ $ _ => Conv.comb_conv (sub_conv cvs ctxt)
   1.290 +        | _ => Conv.all_conv) ct
   1.291 +      and abs_conv cvs = Conv.abs_conv (fn (cv, cx) => sub_conv (cv::cvs) cx)
   1.292 +      and quant_conv cvs ctxt = Conv.arg_conv (abs_conv cvs ctxt)
   1.293 +      and at_lambda_conv cvs ctxt = abs_conv cvs ctxt then_conv conv cvs ctxt
   1.294 +    in sub_conv [] end
   1.295 +
   1.296 +  fun used_vars cvs ct =
   1.297 +    let
   1.298 +      val lookup = AList.lookup (op aconv) (map (` Thm.term_of) cvs)
   1.299 +      val add = (fn (SOME ct) => insert (op aconvc) ct | _ => I)
   1.300 +    in Term.fold_aterms (add o lookup) (Thm.term_of ct) [] end
   1.301 +
   1.302 +  val rev_int_fst_ord = rev_order o int_ord o pairself fst
   1.303 +  fun ordered_values tab =
   1.304 +    Termtab.fold (fn (_, x) => OrdList.insert rev_int_fst_ord x) tab []
   1.305 +    |> map snd
   1.306 +in
   1.307 +fun lift_lambdas ctxt thms =
   1.308 +  let
   1.309 +    val declare_frees = fold (Thm.fold_terms Term.declare_term_frees)
   1.310 +    val names = ref (declare_frees thms (Name.make_context []))
   1.311 +    val fresh_name = change_result names o yield_singleton Name.variants
   1.312 +
   1.313 +    val defs = ref (Termtab.empty : (int * thm) Termtab.table)
   1.314 +    fun add_def t thm = change defs (Termtab.update (t, (serial (), thm)))
   1.315 +    fun make_def cvs eq = Thm.symmetric (fold norm_meta_def cvs eq)
   1.316 +    fun def_conv cvs ctxt ct =
   1.317 +      let
   1.318 +        val cvs' = used_vars cvs ct
   1.319 +        val ct' = fold Thm.cabs cvs' ct
   1.320 +      in
   1.321 +        (case Termtab.lookup (!defs) (Thm.term_of ct') of
   1.322 +          SOME (_, eq) => make_def cvs' eq
   1.323 +        | NONE =>
   1.324 +            let
   1.325 +              val {t, T, ...} = Thm.rep_cterm ct'
   1.326 +              val eq = mk_meta_eq (cert ctxt (Free (fresh_name "", T))) ct'
   1.327 +              val thm = Thm.assume eq
   1.328 +            in (add_def t thm; make_def cvs' thm) end)
   1.329 +      end
   1.330 +    val thms' = map (Conv.fconv_rule (lambda_conv def_conv ctxt)) thms
   1.331 +    val eqs = ordered_values (!defs)
   1.332 +  in
   1.333 +    (maps (#hyps o Thm.crep_thm) eqs, map (normalize_rule ctxt) eqs @ thms')
   1.334 +  end
   1.335 +end
   1.336 +
   1.337 +
   1.338 +(* make application explicit for functions with varying number of arguments *)
   1.339 +
   1.340 +local
   1.341 +  val const = prefix "c" and free = prefix "f"
   1.342 +  fun min i (e as (_, j)) = if i <> j then (true, Int.min (i, j)) else e
   1.343 +  fun add t i = Symtab.map_default (t, (false, i)) (min i)
   1.344 +  fun traverse t =
   1.345 +    (case Term.strip_comb t of
   1.346 +      (Const (n, _), ts) => add (const n) (length ts) #> fold traverse ts 
   1.347 +    | (Free (n, _), ts) => add (free n) (length ts) #> fold traverse ts
   1.348 +    | (Abs (_, _, u), ts) => fold traverse (u :: ts)
   1.349 +    | (_, ts) => fold traverse ts)
   1.350 +  val prune = (fn (n, (true, i)) => Symtab.update (n, i) | _ => I)
   1.351 +  fun prune_tab tab = Symtab.fold prune tab Symtab.empty
   1.352 +
   1.353 +  fun binop_conv cv1 cv2 = Conv.combination_conv (Conv.arg_conv cv1) cv2
   1.354 +  fun nary_conv conv1 conv2 ct =
   1.355 +    (Conv.combination_conv (nary_conv conv1 conv2) conv2 else_conv conv1) ct
   1.356 +  fun abs_conv conv tb = Conv.abs_conv (fn (cv, cx) =>
   1.357 +    let val n = fst (Term.dest_Free (Thm.term_of cv))
   1.358 +    in conv (Symtab.update (free n, 0) tb) cx end)
   1.359 +  val apply_rule = @{lemma "f x == apply f x" by (simp add: apply_def)}
   1.360 +in
   1.361 +fun explicit_application ctxt thms =
   1.362 +  let
   1.363 +    fun sub_conv tb ctxt ct =
   1.364 +      (case Term.strip_comb (Thm.term_of ct) of
   1.365 +        (Const (n, _), ts) => app_conv tb (const n) (length ts) ctxt
   1.366 +      | (Free (n, _), ts) => app_conv tb (free n) (length ts) ctxt
   1.367 +      | (Abs _, ts) => nary_conv (abs_conv sub_conv tb ctxt) (sub_conv tb ctxt)
   1.368 +      | (_, ts) => nary_conv Conv.all_conv (sub_conv tb ctxt)) ct
   1.369 +    and app_conv tb n i ctxt =
   1.370 +      (case Symtab.lookup tb n of
   1.371 +        NONE => nary_conv Conv.all_conv (sub_conv tb ctxt)
   1.372 +      | SOME j => apply_conv tb ctxt (i - j))
   1.373 +    and apply_conv tb ctxt i ct = (
   1.374 +      if i = 0 then nary_conv Conv.all_conv (sub_conv tb ctxt)
   1.375 +      else
   1.376 +        Conv.rewr_conv apply_rule then_conv
   1.377 +        binop_conv (apply_conv tb ctxt (i-1)) (sub_conv tb ctxt)) ct
   1.378 +
   1.379 +    val tab = prune_tab (fold (traverse o Thm.prop_of) thms Symtab.empty)
   1.380 +  in map (Conv.fconv_rule (sub_conv tab ctxt)) thms end
   1.381 +end
   1.382 +
   1.383 +
   1.384 +(* combined normalization *)
   1.385 +
   1.386 +datatype config =
   1.387 +  RewriteTrivialLets |
   1.388 +  RewriteNegativeNumerals |
   1.389 +  RewriteNaturalNumbers |
   1.390 +  AddPairRules |
   1.391 +  AddFunUpdRules |
   1.392 +  AddAbsMinMaxRules
   1.393 +
   1.394 +fun normalize config ctxt thms =
   1.395 +  let fun if_enabled c f = member (op =) config c ? f ctxt
   1.396 +  in
   1.397 +    thms
   1.398 +    |> if_enabled RewriteTrivialLets trivial_let
   1.399 +    |> if_enabled RewriteNegativeNumerals positive_numerals
   1.400 +    |> if_enabled RewriteNaturalNumbers nat_as_int
   1.401 +    |> if_enabled AddPairRules add_pair_rules
   1.402 +    |> if_enabled AddFunUpdRules add_fun_upd_rules
   1.403 +    |> if_enabled AddAbsMinMaxRules add_abs_min_max_rules
   1.404 +    |> map (normalize_rule ctxt)
   1.405 +    |> lift_lambdas ctxt
   1.406 +    |> apsnd (explicit_application ctxt)
   1.407 +  end
   1.408 +
   1.409 +val setup = unfold_defs_setup
   1.410 +
   1.411 +end