1.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
1.2 +++ b/src/HOL/SMT/Tools/smt_normalize.ML Fri Sep 18 18:13:19 2009 +0200
1.3 @@ -0,0 +1,408 @@
1.4 +(* Title: HOL/SMT/Tools/smt_normalize.ML
1.5 + Author: Sascha Boehme, TU Muenchen
1.6 +
1.7 +Normalization steps on theorems required by SMT solvers:
1.8 + * unfold trivial let expressions,
1.9 + * replace negative numerals by negated positive numerals,
1.10 + * embed natural numbers into integers,
1.11 + * add extra rules specifying types and constants which occur frequently,
1.12 + * lift lambda terms,
1.13 + * make applications explicit for functions with varying number of arguments,
1.14 + * fully translate into object logic, add universal closure.
1.15 +*)
1.16 +
1.17 +signature SMT_NORMALIZE =
1.18 +sig
1.19 + val normalize_rule: Proof.context -> thm -> thm
1.20 + val instantiate_free: Thm.cterm * Thm.cterm -> thm -> thm
1.21 + val discharge_definition: Thm.cterm -> thm -> thm
1.22 +
1.23 + val trivial_let: Proof.context -> thm list -> thm list
1.24 + val positive_numerals: Proof.context -> thm list -> thm list
1.25 + val nat_as_int: Proof.context -> thm list -> thm list
1.26 + val unfold_defs: bool Config.T
1.27 + val add_pair_rules: Proof.context -> thm list -> thm list
1.28 + val add_fun_upd_rules: Proof.context -> thm list -> thm list
1.29 + val add_abs_min_max_rules: Proof.context -> thm list -> thm list
1.30 +
1.31 + datatype config =
1.32 + RewriteTrivialLets |
1.33 + RewriteNegativeNumerals |
1.34 + RewriteNaturalNumbers |
1.35 + AddPairRules |
1.36 + AddFunUpdRules |
1.37 + AddAbsMinMaxRules
1.38 +
1.39 + val normalize: config list -> Proof.context -> thm list ->
1.40 + Thm.cterm list * thm list
1.41 +
1.42 + val setup: theory -> theory
1.43 +end
1.44 +
1.45 +structure SMT_Normalize: SMT_NORMALIZE =
1.46 +struct
1.47 +
1.48 +val norm_binder_conv = Conv.try_conv (More_Conv.rewrs_conv [
1.49 + @{lemma "All P == ALL x. P x" by (rule reflexive)},
1.50 + @{lemma "Ex P == EX x. P x" by (rule reflexive)},
1.51 + @{lemma "Let c P == let x = c in P x" by (rule reflexive)}])
1.52 +
1.53 +fun cert ctxt = Thm.cterm_of (ProofContext.theory_of ctxt)
1.54 +
1.55 +fun norm_meta_def cv thm =
1.56 + let val thm' = Thm.combination thm (Thm.reflexive cv)
1.57 + in Thm.transitive thm' (Thm.beta_conversion false (Thm.rhs_of thm')) end
1.58 +
1.59 +fun norm_def ctxt thm =
1.60 + (case Thm.prop_of thm of
1.61 + Const (@{const_name "=="}, _) $ _ $ Abs (_, T, _) =>
1.62 + let val v = Var ((Name.uu, #maxidx (Thm.rep_thm thm) + 1), T)
1.63 + in norm_def ctxt (norm_meta_def (cert ctxt v) thm) end
1.64 + | @{term Trueprop} $ (Const (@{const_name "op ="}, _) $ _ $ Abs _) =>
1.65 + norm_def ctxt (thm RS @{thm fun_cong})
1.66 + | _ => thm)
1.67 +
1.68 +fun normalize_rule ctxt =
1.69 + Conv.fconv_rule (
1.70 + Thm.beta_conversion true then_conv
1.71 + Thm.eta_conversion then_conv
1.72 + More_Conv.bottom_conv (K norm_binder_conv) ctxt) #>
1.73 + norm_def ctxt #>
1.74 + Drule.forall_intr_vars #>
1.75 + Conv.fconv_rule ObjectLogic.atomize
1.76 +
1.77 +fun instantiate_free (cv, ct) thm =
1.78 + if Term.exists_subterm (equal (Thm.term_of cv)) (Thm.prop_of thm)
1.79 + then Thm.forall_elim ct (Thm.forall_intr cv thm)
1.80 + else thm
1.81 +
1.82 +fun discharge_definition ct thm =
1.83 + let val (cv, cu) = Thm.dest_equals ct
1.84 + in
1.85 + Thm.implies_intr ct thm
1.86 + |> instantiate_free (cv, cu)
1.87 + |> (fn thm => Thm.implies_elim thm (Thm.reflexive cu))
1.88 + end
1.89 +
1.90 +fun if_conv c cv1 cv2 ct = (if c (Thm.term_of ct) then cv1 else cv2) ct
1.91 +fun if_true_conv c cv = if_conv c cv Conv.all_conv
1.92 +
1.93 +
1.94 +(* simplification of trivial let expressions (whose bound variables occur at
1.95 + most once) *)
1.96 +
1.97 +local
1.98 + fun count i (Bound j) = if j = i then 1 else 0
1.99 + | count i (t $ u) = count i t + count i u
1.100 + | count i (Abs (_, _, t)) = count (i + 1) t
1.101 + | count _ _ = 0
1.102 +
1.103 + fun is_trivial_let (Const (@{const_name Let}, _) $ _ $ Abs (_, _, t)) =
1.104 + (count 0 t <= 1)
1.105 + | is_trivial_let _ = false
1.106 +
1.107 + fun let_conv _ = if_true_conv is_trivial_let (Conv.rewr_conv @{thm Let_def})
1.108 +
1.109 + fun cond_let_conv ctxt = if_true_conv (Term.exists_subterm is_trivial_let)
1.110 + (More_Conv.top_conv let_conv ctxt)
1.111 +in
1.112 +fun trivial_let ctxt = map (Conv.fconv_rule (cond_let_conv ctxt))
1.113 +end
1.114 +
1.115 +
1.116 +(* rewriting of negative integer numerals into positive numerals *)
1.117 +
1.118 +local
1.119 + fun neg_numeral @{term Int.Min} = true
1.120 + | neg_numeral _ = false
1.121 + fun is_number_sort thy T = Sign.of_sort thy (T, @{sort number_ring})
1.122 + fun is_neg_number ctxt (Const (@{const_name number_of}, T) $ t) =
1.123 + Term.exists_subterm neg_numeral t andalso
1.124 + is_number_sort (ProofContext.theory_of ctxt) (Term.body_type T)
1.125 + | is_neg_number _ _ = false
1.126 + fun has_neg_number ctxt = Term.exists_subterm (is_neg_number ctxt)
1.127 +
1.128 + val pos_numeral_ss = HOL_ss
1.129 + addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}]
1.130 + addsimps [@{thm Int.numeral_1_eq_1}]
1.131 + addsimps @{thms Int.pred_bin_simps}
1.132 + addsimps @{thms Int.normalize_bin_simps}
1.133 + addsimps @{lemma
1.134 + "Int.Min = - Int.Bit1 Int.Pls"
1.135 + "Int.Bit0 (- Int.Pls) = - Int.Pls"
1.136 + "Int.Bit0 (- k) = - Int.Bit0 k"
1.137 + "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)"
1.138 + by simp_all (simp add: pred_def)}
1.139 +
1.140 + fun pos_conv ctxt = if_conv (is_neg_number ctxt)
1.141 + (Simplifier.rewrite (Simplifier.context ctxt pos_numeral_ss))
1.142 + Conv.no_conv
1.143 +
1.144 + fun cond_pos_conv ctxt = if_true_conv (has_neg_number ctxt)
1.145 + (More_Conv.top_sweep_conv pos_conv ctxt)
1.146 +in
1.147 +fun positive_numerals ctxt = map (Conv.fconv_rule (cond_pos_conv ctxt))
1.148 +end
1.149 +
1.150 +
1.151 +(* embedding of standard natural number operations into integer operations *)
1.152 +
1.153 +local
1.154 + val nat_embedding = @{lemma
1.155 + "nat (int n) = n"
1.156 + "i >= 0 --> int (nat i) = i"
1.157 + "i < 0 --> int (nat i) = 0"
1.158 + by simp_all}
1.159 +
1.160 + val nat_rewriting = @{lemma
1.161 + "0 = nat 0"
1.162 + "1 = nat 1"
1.163 + "number_of i = nat (number_of i)"
1.164 + "int (nat 0) = 0"
1.165 + "int (nat 1) = 1"
1.166 + "a < b = (int a < int b)"
1.167 + "a <= b = (int a <= int b)"
1.168 + "Suc a = nat (int a + 1)"
1.169 + "a + b = nat (int a + int b)"
1.170 + "a - b = nat (int a - int b)"
1.171 + "a * b = nat (int a * int b)"
1.172 + "a div b = nat (int a div int b)"
1.173 + "a mod b = nat (int a mod int b)"
1.174 + "int (nat (int a + int b)) = int a + int b"
1.175 + "int (nat (int a * int b)) = int a * int b"
1.176 + "int (nat (int a div int b)) = int a div int b"
1.177 + "int (nat (int a mod int b)) = int a mod int b"
1.178 + by (simp add: nat_mult_distrib nat_div_distrib nat_mod_distrib
1.179 + int_mult[symmetric] zdiv_int[symmetric] zmod_int[symmetric])+}
1.180 +
1.181 + fun on_positive num f x =
1.182 + (case try HOLogic.dest_number (Thm.term_of num) of
1.183 + SOME (_, i) => if i >= 0 then SOME (f x) else NONE
1.184 + | NONE => NONE)
1.185 +
1.186 + val cancel_int_nat_ss = HOL_ss
1.187 + addsimps [@{thm Nat_Numeral.nat_number_of}]
1.188 + addsimps [@{thm Nat_Numeral.int_nat_number_of}]
1.189 + addsimps @{thms neg_simps}
1.190 +
1.191 + fun cancel_int_nat_simproc _ ss ct =
1.192 + let
1.193 + val num = Thm.dest_arg (Thm.dest_arg ct)
1.194 + val goal = Thm.mk_binop @{cterm "op == :: int => _"} ct num
1.195 + val simpset = Simplifier.inherit_context ss cancel_int_nat_ss
1.196 + fun tac _ = Simplifier.simp_tac simpset 1
1.197 + in on_positive num (Goal.prove_internal [] goal) tac end
1.198 +
1.199 + val nat_ss = HOL_ss
1.200 + addsimps nat_rewriting
1.201 + addsimprocs [Simplifier.make_simproc {
1.202 + name = "cancel_int_nat_num", lhss = [@{cpat "int (nat _)"}],
1.203 + proc = cancel_int_nat_simproc, identifier = [] }]
1.204 +
1.205 + fun conv ctxt = Simplifier.rewrite (Simplifier.context ctxt nat_ss)
1.206 +
1.207 + val uses_nat_type = Term.exists_type (Term.exists_subtype (equal @{typ nat}))
1.208 +in
1.209 +fun nat_as_int ctxt thms =
1.210 + let
1.211 + fun norm thm uses_nat =
1.212 + if not (uses_nat_type (Thm.prop_of thm)) then (thm, uses_nat)
1.213 + else (Conv.fconv_rule (conv ctxt) thm, true)
1.214 + val (thms', uses_nat) = fold_map norm thms false
1.215 + in if uses_nat then nat_embedding @ thms' else thms' end
1.216 +end
1.217 +
1.218 +
1.219 +(* include additional rules *)
1.220 +
1.221 +val (unfold_defs, unfold_defs_setup) =
1.222 + Attrib.config_bool "smt_unfold_defs" true
1.223 +
1.224 +local
1.225 + val pair_rules = [@{thm fst_conv}, @{thm snd_conv}, @{thm pair_collapse}]
1.226 +
1.227 + val pair_type = (fn Type (@{type_name "*"}, _) => true | _ => false)
1.228 + val exists_pair_type = Term.exists_type (Term.exists_subtype pair_type)
1.229 +
1.230 + val fun_upd_rules = [@{thm fun_upd_same}, @{thm fun_upd_apply}]
1.231 + val is_fun_upd = (fn Const (@{const_name fun_upd}, _) => true | _ => false)
1.232 + val exists_fun_upd = Term.exists_subterm is_fun_upd
1.233 +in
1.234 +fun add_pair_rules _ thms =
1.235 + thms
1.236 + |> exists (exists_pair_type o Thm.prop_of) thms ? append pair_rules
1.237 +
1.238 +fun add_fun_upd_rules _ thms =
1.239 + thms
1.240 + |> exists (exists_fun_upd o Thm.prop_of) thms ? append fun_upd_rules
1.241 +end
1.242 +
1.243 +
1.244 +local
1.245 + fun mk_entry t thm = (Term.head_of t, (thm, thm RS @{thm eq_reflection}))
1.246 + fun prepare_def thm =
1.247 + (case HOLogic.dest_Trueprop (Thm.prop_of thm) of
1.248 + Const (@{const_name "op ="}, _) $ t $ _ => mk_entry t thm
1.249 + | t => raise TERM ("prepare_def", [t]))
1.250 +
1.251 + val defs = map prepare_def [
1.252 + @{thm abs_if[where 'a = int]}, @{thm abs_if[where 'a = real]},
1.253 + @{thm min_def[where 'a = int]}, @{thm min_def[where 'a = real]},
1.254 + @{thm max_def[where 'a = int]}, @{thm max_def[where 'a = real]}]
1.255 +
1.256 + fun add_sym t = if AList.defined (op =) defs t then insert (op =) t else I
1.257 + fun add_syms thms = fold (Term.fold_aterms add_sym o Thm.prop_of) thms []
1.258 +
1.259 + fun unfold_conv ctxt ct =
1.260 + (case AList.lookup (op =) defs (Term.head_of (Thm.term_of ct)) of
1.261 + SOME (_, eq) => Conv.rewr_conv eq
1.262 + | NONE => Conv.all_conv) ct
1.263 +in
1.264 +fun add_abs_min_max_rules ctxt thms =
1.265 + if Config.get ctxt unfold_defs
1.266 + then map (Conv.fconv_rule (More_Conv.bottom_conv unfold_conv ctxt)) thms
1.267 + else map fst (map_filter (AList.lookup (op =) defs) (add_syms thms)) @ thms
1.268 +end
1.269 +
1.270 +
1.271 +(* lift lambda terms into additional rules *)
1.272 +
1.273 +local
1.274 + val meta_eq = @{cpat "op =="}
1.275 + val meta_eqT = hd (Thm.dest_ctyp (Thm.ctyp_of_term meta_eq))
1.276 + fun inst_meta cT = Thm.instantiate_cterm ([(meta_eqT, cT)], []) meta_eq
1.277 + fun mk_meta_eq ct cu = Thm.mk_binop (inst_meta (Thm.ctyp_of_term ct)) ct cu
1.278 +
1.279 + fun lambda_conv conv =
1.280 + let
1.281 + fun sub_conv cvs ctxt ct =
1.282 + (case Thm.term_of ct of
1.283 + Const (@{const_name All}, _) $ Abs _ => quant_conv cvs ctxt
1.284 + | Const (@{const_name Ex}, _) $ Abs _ => quant_conv cvs ctxt
1.285 + | Const _ $ Abs _ => Conv.arg_conv (at_lambda_conv cvs ctxt)
1.286 + | Const (@{const_name Let}, _) $ _ $ Abs _ => Conv.combination_conv
1.287 + (Conv.arg_conv (sub_conv cvs ctxt)) (abs_conv cvs ctxt)
1.288 + | Abs _ => at_lambda_conv cvs ctxt
1.289 + | _ $ _ => Conv.comb_conv (sub_conv cvs ctxt)
1.290 + | _ => Conv.all_conv) ct
1.291 + and abs_conv cvs = Conv.abs_conv (fn (cv, cx) => sub_conv (cv::cvs) cx)
1.292 + and quant_conv cvs ctxt = Conv.arg_conv (abs_conv cvs ctxt)
1.293 + and at_lambda_conv cvs ctxt = abs_conv cvs ctxt then_conv conv cvs ctxt
1.294 + in sub_conv [] end
1.295 +
1.296 + fun used_vars cvs ct =
1.297 + let
1.298 + val lookup = AList.lookup (op aconv) (map (` Thm.term_of) cvs)
1.299 + val add = (fn (SOME ct) => insert (op aconvc) ct | _ => I)
1.300 + in Term.fold_aterms (add o lookup) (Thm.term_of ct) [] end
1.301 +
1.302 + val rev_int_fst_ord = rev_order o int_ord o pairself fst
1.303 + fun ordered_values tab =
1.304 + Termtab.fold (fn (_, x) => OrdList.insert rev_int_fst_ord x) tab []
1.305 + |> map snd
1.306 +in
1.307 +fun lift_lambdas ctxt thms =
1.308 + let
1.309 + val declare_frees = fold (Thm.fold_terms Term.declare_term_frees)
1.310 + val names = ref (declare_frees thms (Name.make_context []))
1.311 + val fresh_name = change_result names o yield_singleton Name.variants
1.312 +
1.313 + val defs = ref (Termtab.empty : (int * thm) Termtab.table)
1.314 + fun add_def t thm = change defs (Termtab.update (t, (serial (), thm)))
1.315 + fun make_def cvs eq = Thm.symmetric (fold norm_meta_def cvs eq)
1.316 + fun def_conv cvs ctxt ct =
1.317 + let
1.318 + val cvs' = used_vars cvs ct
1.319 + val ct' = fold Thm.cabs cvs' ct
1.320 + in
1.321 + (case Termtab.lookup (!defs) (Thm.term_of ct') of
1.322 + SOME (_, eq) => make_def cvs' eq
1.323 + | NONE =>
1.324 + let
1.325 + val {t, T, ...} = Thm.rep_cterm ct'
1.326 + val eq = mk_meta_eq (cert ctxt (Free (fresh_name "", T))) ct'
1.327 + val thm = Thm.assume eq
1.328 + in (add_def t thm; make_def cvs' thm) end)
1.329 + end
1.330 + val thms' = map (Conv.fconv_rule (lambda_conv def_conv ctxt)) thms
1.331 + val eqs = ordered_values (!defs)
1.332 + in
1.333 + (maps (#hyps o Thm.crep_thm) eqs, map (normalize_rule ctxt) eqs @ thms')
1.334 + end
1.335 +end
1.336 +
1.337 +
1.338 +(* make application explicit for functions with varying number of arguments *)
1.339 +
1.340 +local
1.341 + val const = prefix "c" and free = prefix "f"
1.342 + fun min i (e as (_, j)) = if i <> j then (true, Int.min (i, j)) else e
1.343 + fun add t i = Symtab.map_default (t, (false, i)) (min i)
1.344 + fun traverse t =
1.345 + (case Term.strip_comb t of
1.346 + (Const (n, _), ts) => add (const n) (length ts) #> fold traverse ts
1.347 + | (Free (n, _), ts) => add (free n) (length ts) #> fold traverse ts
1.348 + | (Abs (_, _, u), ts) => fold traverse (u :: ts)
1.349 + | (_, ts) => fold traverse ts)
1.350 + val prune = (fn (n, (true, i)) => Symtab.update (n, i) | _ => I)
1.351 + fun prune_tab tab = Symtab.fold prune tab Symtab.empty
1.352 +
1.353 + fun binop_conv cv1 cv2 = Conv.combination_conv (Conv.arg_conv cv1) cv2
1.354 + fun nary_conv conv1 conv2 ct =
1.355 + (Conv.combination_conv (nary_conv conv1 conv2) conv2 else_conv conv1) ct
1.356 + fun abs_conv conv tb = Conv.abs_conv (fn (cv, cx) =>
1.357 + let val n = fst (Term.dest_Free (Thm.term_of cv))
1.358 + in conv (Symtab.update (free n, 0) tb) cx end)
1.359 + val apply_rule = @{lemma "f x == apply f x" by (simp add: apply_def)}
1.360 +in
1.361 +fun explicit_application ctxt thms =
1.362 + let
1.363 + fun sub_conv tb ctxt ct =
1.364 + (case Term.strip_comb (Thm.term_of ct) of
1.365 + (Const (n, _), ts) => app_conv tb (const n) (length ts) ctxt
1.366 + | (Free (n, _), ts) => app_conv tb (free n) (length ts) ctxt
1.367 + | (Abs _, ts) => nary_conv (abs_conv sub_conv tb ctxt) (sub_conv tb ctxt)
1.368 + | (_, ts) => nary_conv Conv.all_conv (sub_conv tb ctxt)) ct
1.369 + and app_conv tb n i ctxt =
1.370 + (case Symtab.lookup tb n of
1.371 + NONE => nary_conv Conv.all_conv (sub_conv tb ctxt)
1.372 + | SOME j => apply_conv tb ctxt (i - j))
1.373 + and apply_conv tb ctxt i ct = (
1.374 + if i = 0 then nary_conv Conv.all_conv (sub_conv tb ctxt)
1.375 + else
1.376 + Conv.rewr_conv apply_rule then_conv
1.377 + binop_conv (apply_conv tb ctxt (i-1)) (sub_conv tb ctxt)) ct
1.378 +
1.379 + val tab = prune_tab (fold (traverse o Thm.prop_of) thms Symtab.empty)
1.380 + in map (Conv.fconv_rule (sub_conv tab ctxt)) thms end
1.381 +end
1.382 +
1.383 +
1.384 +(* combined normalization *)
1.385 +
1.386 +datatype config =
1.387 + RewriteTrivialLets |
1.388 + RewriteNegativeNumerals |
1.389 + RewriteNaturalNumbers |
1.390 + AddPairRules |
1.391 + AddFunUpdRules |
1.392 + AddAbsMinMaxRules
1.393 +
1.394 +fun normalize config ctxt thms =
1.395 + let fun if_enabled c f = member (op =) config c ? f ctxt
1.396 + in
1.397 + thms
1.398 + |> if_enabled RewriteTrivialLets trivial_let
1.399 + |> if_enabled RewriteNegativeNumerals positive_numerals
1.400 + |> if_enabled RewriteNaturalNumbers nat_as_int
1.401 + |> if_enabled AddPairRules add_pair_rules
1.402 + |> if_enabled AddFunUpdRules add_fun_upd_rules
1.403 + |> if_enabled AddAbsMinMaxRules add_abs_min_max_rules
1.404 + |> map (normalize_rule ctxt)
1.405 + |> lift_lambdas ctxt
1.406 + |> apsnd (explicit_application ctxt)
1.407 + end
1.408 +
1.409 +val setup = unfold_defs_setup
1.410 +
1.411 +end