1 (* Title: HOL/arith_data.ML
3 Author: Markus Wenzel, Stefan Berghofer and Tobias Nipkow
5 Various arithmetic proof procedures.
8 (*---------------------------------------------------------------------------*)
9 (* 1. Cancellation of common terms *)
10 (*---------------------------------------------------------------------------*)
12 structure NatArithUtils =
15 (** abstract syntax of structure nat: 0, Suc, + **)
17 (* mk_sum, mk_norm_sum *)
19 val one = HOLogic.mk_nat 1;
20 val mk_plus = HOLogic.mk_binop "op +";
22 fun mk_sum [] = HOLogic.zero
24 | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
26 (*normal form of sums: Suc (... (Suc (a + (b + ...))))*)
28 let val (ones, sums) = partition (equal one) ts in
29 funpow (length ones) HOLogic.mk_Suc (mk_sum sums)
35 val dest_plus = HOLogic.dest_bin "op +" HOLogic.natT;
38 if HOLogic.is_zero tm then []
40 (case try HOLogic.dest_Suc tm of
41 Some t => one :: dest_sum t
43 (case try dest_plus tm of
44 Some (t, u) => dest_sum t @ dest_sum u
48 (** generic proof tools **)
50 (* prove conversions *)
52 val mk_eqv = HOLogic.mk_Trueprop o HOLogic.mk_eq;
(* Prove the equation tu (built with mk_eqv) by running expand_tac and then
   norm_tac, and return the result converted by mk_meta_eq.
   If the proof attempt raises ERROR, a further error naming the goal is issued. *)
fun prove_conv expand_tac norm_tac sg tu =
  let
    (* the goal is certified on demand, exactly where the original did so *)
    fun goal () = cterm_of sg (mk_eqv tu);
    fun report () =
      error ("The error(s) above occurred while trying to prove " ^
        (string_of_cterm (goal ())));
  in
    mk_meta_eq (prove_goalw_cterm_nocheck [] (goal ()) (K [expand_tac, norm_tac]))
      handle ERROR => report ()
  end;
60 val subst_equals = prove_goal HOL.thy "[| t = s; u = t |] ==> u = s"
61 (fn prems => [cut_facts_tac prems 1, SIMPSET' asm_simp_tac 1]);
66 fun simp_all rules = ALLGOALS (simp_tac (HOL_ss addsimps rules));
68 val add_rules = [add_Suc, add_Suc_right, add_0, add_0_right];
69 val mult_rules = [mult_Suc, mult_Suc_right, mult_0, mult_0_right];
(* Build a simproc from a (name, patterns, procedure) triple, using the
   signature of the current theory context. *)
fun prep_simproc (name, pats, proc) =
  let val sg = Theory.sign_of (the_context ())
  in Simplifier.simproc sg name pats proc end;
76 signature ARITH_DATA =
78 val nat_cancel_sums_add: simproc list
79 val nat_cancel_sums: simproc list
82 structure ArithData: ARITH_DATA =
88 (** cancel common summands **)
92 val mk_sum = mk_norm_sum;
93 val dest_sum = dest_sum;
94 val prove_conv = prove_conv;
95 val norm_tac = simp_all add_rules THEN simp_all add_ac;
98 fun gen_uncancel_tac rule ct =
99 rtac (instantiate' [] [None, Some ct] (rule RS subst_equals)) 1;
104 structure EqCancelSums = CancelSumsFun
107 val mk_bal = HOLogic.mk_eq;
108 val dest_bal = HOLogic.dest_bin "op =" HOLogic.natT;
109 val uncancel_tac = gen_uncancel_tac nat_add_left_cancel;
115 structure LessCancelSums = CancelSumsFun
118 val mk_bal = HOLogic.mk_binrel "op <";
119 val dest_bal = HOLogic.dest_bin "op <" HOLogic.natT;
120 val uncancel_tac = gen_uncancel_tac nat_add_left_cancel_less;
126 structure LeCancelSums = CancelSumsFun
129 val mk_bal = HOLogic.mk_binrel "op <=";
130 val dest_bal = HOLogic.dest_bin "op <=" HOLogic.natT;
131 val uncancel_tac = gen_uncancel_tac nat_add_left_cancel_le;
137 structure DiffCancelSums = CancelSumsFun
140 val mk_bal = HOLogic.mk_binop "op -";
141 val dest_bal = HOLogic.dest_bin "op -" HOLogic.natT;
142 val uncancel_tac = gen_uncancel_tac diff_cancel;
147 (** prepare nat_cancel simprocs **)
(* Cancellation simprocs for =, < and <= on nat; each entry pairs a simproc
   name and its term patterns with the proc of the matching CancelSums
   structure. *)
val nat_cancel_sums_add =
  let
    val eq_entry =
      ("nateq_cancel_sums",
       ["(l::nat) + m = n", "(l::nat) = m + n", "Suc m = n", "m = Suc n"],
       EqCancelSums.proc);
    val less_entry =
      ("natless_cancel_sums",
       ["(l::nat) + m < n", "(l::nat) < m + n", "Suc m < n", "m < Suc n"],
       LessCancelSums.proc);
    val le_entry =
      ("natle_cancel_sums",
       ["(l::nat) + m <= n", "(l::nat) <= m + n", "Suc m <= n", "m <= Suc n"],
       LeCancelSums.proc);
  in map prep_simproc [eq_entry, less_entry, le_entry] end;
(* All cancellation simprocs: those of nat_cancel_sums_add followed by the
   one for subtraction. *)
val nat_cancel_sums =
  let
    val diff_pats = ["((l::nat) + m) - n", "(l::nat) - (m + n)", "Suc m - n", "m - Suc n"];
    val diff_simproc = prep_simproc ("natdiff_cancel_sums", diff_pats, DiffCancelSums.proc);
  in nat_cancel_sums_add @ [diff_simproc] end;
166 (*---------------------------------------------------------------------------*)
167 (* 2. Linear arithmetic *)
168 (*---------------------------------------------------------------------------*)
(* Parameter data for the general linear arithmetic functor *)
172 structure LA_Logic: LIN_ARITH_LOGIC =
176 val neqE = linorder_neqE;
179 val not_lessD = linorder_not_less RS iffD1;
180 val not_leD = linorder_not_le RS iffD1;
183 fun mk_Eq thm = (thm RS Eq_FalseI) handle THM _ => (thm RS Eq_TrueI);
185 val mk_Trueprop = HOLogic.mk_Trueprop;
(* Negate a proposition of the form TP $ t (TP is presumably the Trueprop
   head -- confirm): a leading "Not" is stripped, otherwise one is added.
   Terms that are not applications are not matched, as before. *)
fun neg_prop (TP $ body) =
  (case body of
    Const ("Not", _) $ t => TP $ t
  | t => TP $ (Const ("Not", HOLogic.boolT --> HOLogic.boolT) $ t));
191 let val _ $ t = #prop(rep_thm thm)
192 in t = Const("False",HOLogic.boolT) end;
194 fun is_nat(t) = fastype_of1 t = HOLogic.natT;
(* Instantiate the schematic variable ?n :: nat of theorem le0 with the term t,
   certifying both against the signature sg. *)
fun mk_nat_thm sg t =
  let
    val ct = cterm_of sg t;
    val n_var = cterm_of sg (Var (("n", 0), HOLogic.natT));
  in
    instantiate ([], [(n_var, ct)]) le0
  end;
203 (* arith theory data *)
205 structure ArithTheoryDataArgs =
207 val name = "HOL/arith";
208 type T = {splits: thm list, inj_consts: (string * typ)list, discrete: (string * bool) list, presburger: (int -> tactic) option};
210 val empty = {splits = [], inj_consts = [], discrete = [], presburger = None};
(* Merge two arith data records field by field; for presburger the first
   record's tactic wins and the second is used only when the first has none. *)
fun merge (d1: T, d2: T) =
  let
    val tac =
      (case #presburger d1 of
        None => #presburger d2
      | chosen => chosen);
  in
    {splits = Drule.merge_rules (#splits d1, #splits d2),
     inj_consts = merge_lists (#inj_consts d1) (#inj_consts d2),
     discrete = merge_alists (#discrete d1) (#discrete d2),
     presburger = tac}
  end;
222 structure ArithTheoryData = TheoryDataFun(ArithTheoryDataArgs);
(* Theory attribute: prepend thm to the splits field of the arith theory data
   and return the updated theory together with the unchanged theorem. *)
fun arith_split_add (thy, thm) =
  let
    fun add_split {splits, inj_consts, discrete, presburger} =
      {splits = thm :: splits, inj_consts = inj_consts,
       discrete = discrete, presburger = presburger};
  in (ArithTheoryData.map add_split thy, thm) end;
(* Theory transformer: prepend the entry d to the discrete field of the arith
   theory data, leaving the other fields untouched. *)
fun arith_discrete d =
  let
    fun add_discrete {splits, inj_consts, discrete, presburger} =
      {splits = splits, inj_consts = inj_consts,
       discrete = d :: discrete, presburger = presburger};
  in ArithTheoryData.map add_discrete end;
(* Theory transformer: prepend the constant c to the inj_consts field of the
   arith theory data, leaving the other fields untouched. *)
fun arith_inj_const c =
  let
    fun add_const {splits, inj_consts, discrete, presburger} =
      {splits = splits, inj_consts = c :: inj_consts,
       discrete = discrete, presburger = presburger};
  in ArithTheoryData.map add_const end;
234 structure LA_Data_Ref: LIN_ARITH_DATA =
237 (* Decomposition of terms *)
239 fun nT (Type("fun",[N,_])) = N = HOLogic.natT
(* Add coefficient m for atom t to the pair (p, i): if t already occurs in the
   association list p its coefficient is increased by m, otherwise (t, m) is
   put in front.  The constant i is passed through unchanged. *)
fun add_atom (t, m, (p, i)) =
  let
    val p' =
      (case assoc (p, t) of
        Some n => overwrite (p, (t, ratadd (n, m)))
      | None => (t, m) :: p);
  in (p', i) end;
(* Rational built from two binary numeral terms; raises Zero when the
   denominator is 0.  Both numerals are decoded before the check. *)
fun rat_of_term (numt, dent) =
  let
    val num = HOLogic.dest_binum numt;
    val den = HOLogic.dest_binum dent;
  in
    if den <> 0 then int_ratdiv (num, den) else raise Zero
  end;
251 (* Warning: in rare cases number_of encloses a non-numeral,
252 in which case dest_binum raises TERM; hence all the handles below.
253 Same for Suc-terms that turn out not to be numerals -
254 although the simplifier should eliminate those anyway...
(* Count the Suc applications wrapped around zero; raises TERM when the
   innermost term is not zero, i.e. when the term is not a Suc-numeral. *)
fun number_of_Sucs t =
  let
    fun count (Const ("Suc", _) $ n, k) = count (n, k + 1)
      | count (u, k) =
          if HOLogic.is_zero u then k
          else raise TERM ("number_of_Sucs", [])
  in count (t, 0) end
261 (* decompose nested multiplications, bracketing them to the right and combining all
265 fun demult inj_consts =
267 fun demult((mC as Const("op *",_)) $ s $ t,m) = ((case s of
268 Const("Numeral.number_of",_)$n
269 => demult(t,ratmul(m,rat_of_int(HOLogic.dest_binum n)))
270 | Const("uminus",_)$(Const("Numeral.number_of",_)$n)
271 => demult(t,ratmul(m,rat_of_int(~(HOLogic.dest_binum n))))
273 => demult(t,ratmul(m,rat_of_int(number_of_Sucs s)))
274 | Const("op *",_) $ s1 $ s2 => demult(mC $ s1 $ (mC $ s2 $ t),m)
275 | Const("HOL.divide",_) $ numt $ (Const("Numeral.number_of",_)$dent) =>
276 let val den = HOLogic.dest_binum dent
277 in if den = 0 then raise Zero
278 else demult(mC $ numt $ t,ratmul(m, ratinv(rat_of_int den)))
280 | _ => atomult(mC,s,t,m)
281 ) handle TERM _ => atomult(mC,s,t,m))
282 | demult(atom as Const("HOL.divide",_) $ t $ (Const("Numeral.number_of",_)$dent), m) =
283 (let val den = HOLogic.dest_binum dent
284 in if den = 0 then raise Zero else demult(t,ratmul(m, ratinv(rat_of_int den))) end
285 handle TERM _ => (Some atom,m))
286 | demult(Const("0",_),m) = (None, rat_of_int 0)
287 | demult(Const("1",_),m) = (None, m)
288 | demult(t as Const("Numeral.number_of",_)$n,m) =
289 ((None,ratmul(m,rat_of_int(HOLogic.dest_binum n)))
290 handle TERM _ => (Some t,m))
291 | demult(Const("uminus",_)$t, m) = demult(t,ratmul(m,rat_of_int(~1)))
292 | demult(t as Const f $ x, m) =
293 (if f mem inj_consts then Some x else Some t,m)
294 | demult(atom,m) = (Some atom,m)
296 and atomult(mC,atom,t,m) = (case demult(t,m) of (None,m') => (Some atom,m')
297 | (Some t',m') => (Some(mC $ atom $ t'),m'))
300 fun decomp2 inj_consts (rel,lhs,rhs) =
(* Turn a term into a list of summand * multiplicity pairs plus a constant.
   poly (t, m, (p, i)) accumulates the contribution of m * t: atoms and their
   rational coefficients go into the association list p (via add_atom),
   numeric parts are added to the constant i. *)
fun poly (Const ("op +", _) $ s $ t, m, pi) = poly (s, m, poly (t, m, pi))
  | poly (all as Const ("op -", T) $ s $ t, m, pi) =
      (* when nT accepts the type of "op -" the whole difference stays an atom *)
      if nT T then add_atom (all, m, pi)
      else poly (s, m, poly (t, ratneg m, pi))
  | poly (Const ("uminus", _) $ t, m, pi) = poly (t, ratneg m, pi)
  | poly (Const ("0", _), _, pi) = pi
  | poly (Const ("1", _), m, (p, i)) = (p, ratadd (i, m))
  | poly (Const ("Suc", _) $ t, m, (p, i)) = poly (t, m, (p, ratadd (i, m)))
  | poly (t as Const ("op *", _) $ _ $ _, m, pi as (p, i)) =
      (case demult inj_consts (t, m) of
        (* no atom left: demult has folded the numeral factors into m', so m'
           (not the incoming m) must be added to the constant, exactly as in
           the HOL.divide clause below *)
        (None, m') => (p, ratadd (i, m'))
      | (Some u, m') => add_atom (u, m', pi))
  | poly (t as Const ("HOL.divide", _) $ _ $ _, m, pi as (p, i)) =
      (case demult inj_consts (t, m) of
        (None, m') => (p, ratadd (i, m'))
      | (Some u, m') => add_atom (u, m', pi))
  | poly (all as (Const ("Numeral.number_of", _) $ t, m, (p, i))) =
      (* dest_binum raises TERM on a non-numeral argument; the whole term is
         then treated as an atom *)
      ((p, ratadd (i, ratmul (m, rat_of_int (HOLogic.dest_binum t))))
        handle TERM _ => add_atom all)
  | poly (all as Const f $ x, m, pi) =
      (* constants listed in inj_consts are stripped and their argument analysed *)
      if f mem inj_consts then poly (x, m, pi) else add_atom (all, m, pi)
  | poly x = add_atom x;
326 val (p,i) = poly(lhs,rat_of_int 1,([],rat_of_int 0))
327 and (q,j) = poly(rhs,rat_of_int 1,([],rat_of_int 0))
330 "op <" => Some(p,i,"<",q,j)
331 | "op <=" => Some(p,i,"<=",q,j)
332 | "op =" => Some(p,i,"=",q,j)
334 end handle Zero => None;
(* Negate a decomposed relation by prefixing its relation name with "~";
   None is passed through. *)
fun negate dec =
  (case dec of
    None => None
  | Some (x, i, rel, y, j, d) => Some (x, i, "~" ^ rel, y, j, d));
339 fun decomp1 (discrete,inj_consts) (T,xxx) =
341 Type("fun",[Type(D,[]),_]) =>
342 (case assoc(discrete,D) of
344 | Some d => (case decomp2 inj_consts xxx of
346 | Some(p,i,rel,q,j) => Some(p,i,rel,q,j,d)))
(* Decompose a proposition whose body is a binary relation, or the negation of
   one, via decomp1; any other shape yields None.
   NOTE(review): this rebinds the name decomp2; decomp1 above still refers to
   the earlier (rel, lhs, rhs) version -- confirm the shadowing is intended. *)
fun decomp2 data prop =
  (case prop of
    _ $ (Const (rel, T) $ lhs $ rhs) =>
      decomp1 data (T, (rel, lhs, rhs))
  | _ $ (Const ("Not", _) $ (Const (rel, T) $ lhs $ rhs)) =>
      negate (decomp1 data (T, (rel, lhs, rhs)))
  | _ => None)
355 let val {discrete, inj_consts, ...} = ArithTheoryData.get_sg sg
356 in decomp2 (discrete,inj_consts) end
358 fun number_of(n,T) = HOLogic.number_of_const T $ (HOLogic.mk_bin n)
363 structure Fast_Arith =
364 Fast_Lin_Arith(structure LA_Logic=LA_Logic and LA_Data=LA_Data_Ref);
366 val fast_arith_tac = Fast_Arith.lin_arith_tac false
367 and fast_ex_arith_tac = Fast_Arith.lin_arith_tac
368 and trace_arith = Fast_Arith.trace
369 and fast_arith_neq_limit = Fast_Arith.fast_arith_neq_limit;
374 let val thy = theory "Nat"
375 in prove_goal thy "Suc(i+j) = i+j + Suc 0"
376 (fn _ => [simp_tac (simpset_of thy) 1])
379 (* reduce contradictory <= to False.
380 Most of the work is done by the cancel tactics.
383 [add_zero_left,add_zero_right,Zero_not_Suc,Suc_not_Zero,le_0_eq,
384 One_nat_def,isolateSuc];
386 val add_mono_thms_ordered_semiring = map (fn s => prove_goal (the_context ()) s
387 (fn prems => [cut_facts_tac prems 1,
388 blast_tac (claset() addIs [add_mono]) 1]))
389 ["(i <= j) & (k <= l) ==> i + k <= j + (l::'a::ordered_semidom)",
390 "(i = j) & (k <= l) ==> i + k <= j + (l::'a::ordered_semidom)",
391 "(i <= j) & (k = l) ==> i + k <= j + (l::'a::ordered_semidom)",
392 "(i = j) & (k = l) ==> i + k = j + (l::'a::ordered_semidom)"
397 val init_lin_arith_data =
399 [Fast_Arith.map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset = _} =>
400 {add_mono_thms = add_mono_thms @ add_mono_thms_ordered_semiring,
401 mult_mono_thms = mult_mono_thms,
403 lessD = lessD @ [Suc_leI],
404 simpset = HOL_basic_ss addsimps add_rules addsimprocs nat_cancel_sums_add}),
405 ArithTheoryData.init, arith_discrete ("nat", true)];
(* Simproc "fast_nat_arith": runs Fast_Arith.lin_arith_prover on <, <= and =
   over nat, built against the signature of the current theory context. *)
val fast_nat_arith_simproc =
  let
    val sg = Theory.sign_of (the_context ());
    val pats = ["(m::nat) < n","(m::nat) <= n", "(m::nat) = n"];
  in
    Simplifier.simproc sg "fast_nat_arith" pats Fast_Arith.lin_arith_prover
  end;
414 (* Because of fast_nat_arith_simproc, the arithmetic solver is really only
415 useful to detect inconsistencies among the premises for subgoals which are
416 *not* themselves (in)equalities, because the latter activate
417 fast_nat_arith_simproc anyway. However, it seems cheaper to activate the
418 solver all the time rather than add the additional check. *)
421 (* arith proof method *)
423 (* FIXME: K true should be replaced by a sensible test to speed things up
424 in case there are lots of irrelevant terms involved;
425 elimination of min/max can be optimized:
426 (max m n + k <= r) = (m+k <= r & n+k <= r)
427 (l <= min m n + k) = (l <= m+k & l <= n+k)
431 fun raw_arith_tac ex i st =
433 (REPEAT o split_tac (#splits (ArithTheoryData.get_sg (Thm.sign_of_thm st))))
434 ((REPEAT_DETERM o etac linorder_neqE) THEN' fast_ex_arith_tac ex)
437 fun presburger_tac i st =
438 (case ArithTheoryData.get_sg (sign_of_thm st) of
439 {presburger = Some tac, ...} =>
440 (tracing "Simple arithmetic decision procedure failed.\nNow trying full Presburger arithmetic..."; tac i st)
445 val simple_arith_tac = FIRST' [fast_arith_tac,
446 ObjectLogic.atomize_tac THEN' raw_arith_tac true];
448 val arith_tac = FIRST' [fast_arith_tac,
449 ObjectLogic.atomize_tac THEN' raw_arith_tac true,
452 val silent_arith_tac = FIRST' [fast_arith_tac,
453 ObjectLogic.atomize_tac THEN' raw_arith_tac false,
(* Proof method: insert the given prems followed by the method's facts into
   the first goal, then apply arith_tac to it. *)
fun arith_method prems =
  let
    fun with_facts facts =
      HEADGOAL (Method.insert_tac (prems @ facts) THEN' arith_tac);
  in Method.METHOD with_facts end;
465 [Simplifier.change_simpset_of (op addsimprocs) nat_cancel_sums] @
466 init_lin_arith_data @
467 [Simplifier.change_simpset_of (op addSolver)
468 (mk_solver "lin. arith." Fast_Arith.cut_lin_arith_tac),
469 Simplifier.change_simpset_of (op addsimprocs) [fast_nat_arith_simproc],
470 Method.add_methods [("arith", (arith_method o #2) oo Method.syntax Args.bang_facts,
471 "decide linear arithmethic")],
472 Attrib.add_attributes [("arith_split",
473 (Attrib.no_args arith_split_add, Attrib.no_args Attrib.undef_local_attribute),
474 "declaration of split rules for arithmetic procedure")]];