1 (* Title: Provers/Arith/fast_lin_arith.ML
2 Author: Tobias Nipkow and Tjark Weber and Sascha Boehme
4 A generic linear arithmetic package. It provides two tactics
5 (cut_lin_arith_tac, lin_arith_tac) and a simplification procedure
8 Only take premises and conclusions into account that are already
9 (negated) (in)equations. lin_arith_simproc tries to prove or disprove
13 (*** Data needed for setting up the linear arithmetic package ***)
(* Interface to the object logic: the theorems and term/theorem operations
   the linear arithmetic package relies on. mk_Eq, neg_prop, is_nat and
   mk_nat_thm are specified in the comment that follows this signature. *)
15 signature LIN_ARITH_LOGIC =
17 val conjI : thm (* P ==> Q ==> P & Q *)
18 val ccontr : thm (* (~ P ==> False) ==> P *)
19 val notI : thm (* (P ==> False) ==> ~ P *)
20 val not_lessD : thm (* ~(m < n) ==> n <= m *)
21 val not_leD : thm (* ~(m <= n) ==> n < m *)
22 val sym : thm (* x = y ==> y = x *)
23 val trueI : thm (* True *)
24 val mk_Eq : thm -> thm
25 val atomize : thm -> thm list
26 val mk_Trueprop : term -> term
27 val neg_prop : term -> term
28 val is_False : thm -> bool
29 val is_nat : typ list * term -> bool
30 val mk_nat_thm : theory -> term -> thm
33 mk_Eq(~in) = `in == False'
34 mk_Eq(in) = `in == True'
35 where `in' is an (in)equality.
37 neg_prop(t) = neg if t is wrapped up in Trueprop and neg is the
38 (logically) negated version of t (again wrapped up in Trueprop),
39 where the negation of a negative term is the term itself (no
40 double negation!); raises TERM ("neg_prop", [t]) if t is not of
41 the form 'Trueprop $ _'
43 is_nat(parameter-types,t) = t:nat
44 mk_nat_thm(t) = "0 <= t"
(* Interface to the arithmetic front end: decomposition of terms into the
   internal decomp representation, the preprocessing hooks, and the limit on
   the number of ~= premises that are split. *)
47 signature LIN_ARITH_DATA =
49 (*internal representation of linear (in-)equations:*)
(* fields: lhs summands with coefficients, lhs constant, relation name,
   rhs summands with coefficients, rhs constant, discreteness flag *)
50 type decomp = (term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat * bool
51 val decomp: Proof.context -> term -> decomp option
52 val domain_is_nat: term -> bool
54 (*preprocessing, performed on a representation of subgoals as list of premises:*)
55 val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list
57 (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
58 val pre_tac: simpset -> int -> tactic
60 (*the limit on the number of ~= allowed; because each ~= is split
61 into two cases, this can lead to an explosion*)
62 val fast_arith_neq_limit: int Config.T
65 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
66 where Rel is one of "<", "~<", "<=", "~<=" and "=" and
67 p (q, respectively) is the decomposition of the sum term x
68 (y, respectively) into a list of summand * multiplicity
69 pairs and a constant summand and d indicates if the domain
72 domain_is_nat(`x Rel y') should yield true iff x is of type "nat".
74 The relationship between pre_decomp and pre_tac is somewhat tricky. The
75 internal representation of a subgoal and the corresponding theorem must
76 be modified by pre_decomp (pre_tac, resp.) in a corresponding way. See
77 the comment for split_items below. (This is even necessary for eta- and
78 beta-equivalent modifications, as some of the lin. arith. code is not
81 ss must reduce contradictory <= to False.
82 It should also cancel common summands to keep <= reduced;
83 otherwise <= can grow to massive proportions.
(* Public interface: two tactics, a simproc, operations that extend the
   package data stored in a generic context, and the tracing switch. *)
86 signature FAST_LIN_ARITH =
88 val cut_lin_arith_tac: simpset -> int -> tactic
89 val lin_arith_tac: Proof.context -> bool -> int -> tactic
90 val lin_arith_simproc: simpset -> term -> thm option
91 val map_data: ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
92 lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
93 number_of : serial * (theory -> typ -> int -> cterm)}
94 -> {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
95 lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
96 number_of : serial * (theory -> typ -> int -> cterm)})
97 -> Context.generic -> Context.generic
98 val add_inj_thms: thm list -> Context.generic -> Context.generic
99 val add_lessD: thm -> Context.generic -> Context.generic
100 val add_simps: thm list -> Context.generic -> Context.generic
101 val add_simprocs: simproc list -> Context.generic -> Context.generic
102 val set_number_of: (theory -> typ -> int -> cterm) -> Context.generic -> Context.generic
103 val trace: bool Unsynchronized.ref
106 functor Fast_Lin_Arith
107 (structure LA_Logic: LIN_ARITH_LOGIC and LA_Data: LIN_ARITH_DATA): FAST_LIN_ARITH =
(* Default for the number_of field of the package data: building a numeral
   cterm always fails with CTERM until set_number_of installs a real
   function. *)
val no_number_of = fn _ => fn _ => fn _ => raise CTERM ("number_of", [])
(* Package data kept per generic context. merge combines the theorem lists
   with Thm.merge_thms and the simpsets with Simplifier.merge_ss, and keeps
   the number_of entry whose serial is larger. *)
115 structure Data = Generic_Data
118 {add_mono_thms: thm list,
119 mult_mono_thms: thm list,
123 simpset: Simplifier.simpset,
124 number_of : serial * (theory -> typ -> int -> cterm)};
126 val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
127 lessD = [], neqE = [], simpset = Simplifier.empty_ss,
128 number_of = (serial (), no_number_of) } : T;
131 ({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1,
132 lessD = lessD1, neqE=neqE1, simpset = simpset1,
133 number_of = (number_of1 as (s1, _))},
134 {add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2,
135 lessD = lessD2, neqE=neqE2, simpset = simpset2,
136 number_of = (number_of2 as (s2, _))}) =
137 {add_mono_thms = Thm.merge_thms (add_mono_thms1, add_mono_thms2),
138 mult_mono_thms = Thm.merge_thms (mult_mono_thms1, mult_mono_thms2),
139 inj_thms = Thm.merge_thms (inj_thms1, inj_thms2),
140 lessD = Thm.merge_thms (lessD1, lessD2),
141 neqE = Thm.merge_thms (neqE1, neqE2),
142 simpset = Simplifier.merge_ss (simpset1, simpset2),
143 (* FIXME depends on accidental load order !?! *) (* FIXME *)
144 number_of = if s1 > s2 then number_of1 else number_of2};
(* Accessors for the package data: update it in a generic context, or read
   it from a proof context. *)
fun map_data f = Data.map f;
fun get_data ctxt = Data.get (Context.Proof ctxt);
(* Field updaters for the package data record: each applies f to exactly one
   field and copies every other field unchanged. *)
150 fun map_inj_thms f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
151 {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = f inj_thms,
152 lessD = lessD, neqE = neqE, simpset = simpset, number_of = number_of};
154 fun map_lessD f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
155 {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
156 lessD = f lessD, neqE = neqE, simpset = simpset, number_of = number_of};
158 fun map_simpset f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
159 {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
160 lessD = lessD, neqE = neqE, simpset = f simpset, number_of = number_of};
162 fun map_number_of f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
163 {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
164 lessD = lessD, neqE = neqE, simpset = simpset, number_of = f number_of};
(* Extend single fields of the package data in a generic context.
   New injection theorems go in front of the existing ones; a new lessD
   theorem goes at the end. *)
fun add_inj_thms thms = map_data (map_inj_thms (fn old => thms @ old));
fun add_lessD thm = map_data (map_lessD (fn old => old @ [thm]));
fun add_simps thms = map_data (map_simpset (fn ss => ss addsimps thms));
fun add_simprocs procs = map_data (map_simpset (fn ss => ss addsimprocs procs));

(* Install a numeral constructor. The pair gets a fresh serial once, when
   set_number_of f is applied; merge keeps the entry with the larger serial. *)
fun set_number_of f =
  let val tagged = (serial (), f)
  in map_data (map_number_of (K tagged)) end;
174 (*** A fast decision procedure ***)
175 (*** Code ported from HOL Light ***)
176 (* possible optimizations:
177 use (var,coeff) rep or vector rep to save space;
178 treat non-negative atoms separately rather than adding 0 <= atom
(* Global switch for the tracing output of this package; off by default. *)
val trace : bool Unsynchronized.ref = Unsynchronized.ref false;
(* Relation of a linear (in)equation: Eq is =, Le is <=, Lt is <. *)
183 datatype lineq_type = Eq | Le | Lt;
(* Justification of a derived (in)equation, replayed into a theorem by mkthm:
   Asm i is the i-th premise, Nat i the fact 0 <= atom i, and Multiplied and
   Added record scaling and summation of other justifications. *)
185 datatype injust = Asm of int
186 | Nat of int (* index of atom *)
191 | Multiplied of int * injust
192 | Added of injust * injust;
(* Lineq (k, ty, cs, j) stands for  k ty cs_1*x_1 + ... + cs_n*x_n,  where the
   x_i are the atoms in a fixed order and j justifies the constraint. *)
194 datatype lineq = Lineq of int * lineq_type * int list * injust;
196 (* ------------------------------------------------------------------------- *)
197 (* Finding a (counter) example from the trace of a failed elimination *)
198 (* ------------------------------------------------------------------------- *)
199 (* Examples are represented as rational numbers, *)
200 (* Don't blame John Harrison for this code - it is entirely mine. TN *)
204 (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
205 In general, true means the bound is included, false means it is excluded.
206 Need to know if it is a lower or upper bound for unambiguous interpretation!
(* Translate a linear (in)equation into bound triples (i, incl, cs), meaning
   i <= cs when incl is true and i < cs otherwise. Le and Lt each give one
   triple. Eq gives two: the row itself and its negation. *)
fun elim_eqns (Lineq (i, ty, cs, _)) =
  (case ty of
    Le => [(i, true, cs)]
  | Eq => [(i, true, cs), (~i, true, map ~ cs)]
  | Lt => [(i, false, cs)]);
213 (* PRE: ex[v] must be 0! *)
(* Bound on variable v implied by the constraint (a, le, cs) when the other
   variables take the values in ex:  (a - sum_i cs_i*ex_i) / cs_v,
   returned together with the inclusion flag le. *)
214 fun eval ex v (a, le, cs) =
216 val rs = map Rat.rat_of_int cs;
217 val rsum = fold2 (Rat.add oo Rat.mult) rs ex Rat.zero;
218 in (Rat.mult (Rat.add (Rat.rat_of_int a) (Rat.neg rsum)) (Rat.inv (nth rs v)), le) end;
219 (* If nth rs v < 0, le should be negated.
220 Instead this swap is taken into account in ratrelmin2.
(* Combine two (bound, flag) pairs, presumably selecting the smaller
   (ratrelmin2) or larger (ratrelmax2) bound. On equal bounds ratrelmin2
   negates both flags. According to the note in eval, this makes up for the
   sign swap that eval does not perform. *)
223 fun ratrelmin2 (x as (r, ler), y as (s, les)) =
225 of EQUAL => (r, (not ler) andalso (not les))
229 fun ratrelmax2 (x as (r, ler), y as (s, les)) =
231 of EQUAL => (r, ler andalso les)
(* Fold a list of (bound, flag) pairs with foldr1: ratrelmin combines the
   upper bounds and ratrelmax the lower bounds used by findex1. *)
fun ratrelmin bnds = foldr1 ratrelmin2 bnds;
fun ratrelmax bnds = foldr1 ratrelmax2 bnds;
(* Move r by 1/q, where q is the denominator of r: upward when up holds,
   downward otherwise. The local name nth shadows the list function nth
   inside this definition.
   NOTE(review): presumably r is returned unchanged when exact holds -- confirm. *)
238 fun ratexact up (r, exact) =
241 val (_, q) = Rat.quotient_of_rat r;
242 val nth = Rat.inv (Rat.rat_of_int q);
243 in Rat.add r (if up then nth else Rat.neg nth) end;
(* Arithmetic mean of two rationals. *)
fun ratmiddle (r, s) =
  let
    val sum = Rat.add r s
    val half = Rat.inv Rat.two
  in Rat.mult sum half end;
(* Pick a value between the lower bound (lb, exactl) and the upper bound
   (ub, exactu). In the discrete case the bounds are rounded to integers
   (lb up, ub down), and NoEx is raised when the rounded bound falls outside
   the interval. d presumably flags a discrete domain -- confirm. *)
247 fun choose2 d ((lb, exactl), (ub, exactu)) =
248 let val ord = Rat.sign lb in
249 if (ord = LESS orelse exactl) andalso (ord = GREATER orelse exactu)
253 then if exactl then lb else ratmiddle (lb, ub)
254 else if exactu then ub else ratmiddle (lb, ub)
255 else (*discrete domain, both bounds must be exact*)
257 then let val lb' = Rat.roundup lb in
258 if Rat.le lb' ub then lb' else raise NoEx end
259 else let val ub' = Rat.rounddown ub in
260 if Rat.le lb ub' then ub' else raise NoEx end
(* Fix variable v in the partial example ex:
   - keep the lineqs with a nonzero coefficient for v and turn them into bound triples;
   - treat triples with a positive coefficient as lower bounds and the rest as upper bounds;
   - set ex[v] to the value choose2 picks between the tightest bound of each kind. *)
263 fun findex1 discr (v, lineqs) ex =
265 val nz = filter (fn (Lineq (_, _, cs, _)) => nth cs v <> 0) lineqs;
266 val ineqs = maps elim_eqns nz;
267 val (ge, le) = List.partition (fn (_,_,cs) => nth cs v > 0) ineqs
268 val lb = ratrelmax (map (eval ex v) ge)
269 val ub = ratrelmin (map (eval ex v) le)
270 in nth_map v (K (choose2 (nth discr v) (lb, ub))) ex end;
(* Substitute the value x for variable v in every bound triple: subtract
   bs_v times x from the constant and set the coefficient of v to zero. *)
273 map (fn (a,le,bs) => (Rat.add a (Rat.neg (Rat.mult (nth bs v) x)), le,
274 nth_map v (K Rat.zero) bs));
(* Whether v is the only variable with a nonzero coefficient in cs. The case
   shown holds when exactly one entry is nonzero and that entry equals cs_v.
   NOTE(review): all other shapes presumably yield false -- confirm. *)
276 fun single_var v (_, _, cs) = case filter_out (curry (op =) EQUAL o Rat.sign) cs
277 of [x] => x =/ nth cs v
281 all variables occur only with positive or only with negative coefficients *)
(* Build the example one variable at a time while some bound triple still
   has a nonzero coefficient:
   - take a variable v with a nonzero coefficient (presumably from the first
     such triple);
   - compute a value x for v from the triples that mention only v, with 0
     always among the candidate bounds: rounded when d (presumably the
     discreteness flag of v) holds, nudged with ratexact otherwise;
   - substitute x with elim1, record it in ex, and recurse.
   The pair (Rat.zero, if pos then true else false) is simply
   (Rat.zero, pos). *)
282 fun pick_vars discr (ineqs,ex) =
283 let val nz = filter_out (fn (_,_,cs) => forall (curry (op =) EQUAL o Rat.sign) cs) ineqs
284 in case nz of [] => ex
286 let val v = find_index (not o curry (op =) EQUAL o Rat.sign) cs
288 val pos = not (Rat.sign (nth cs v) = LESS);
289 val sv = filter (single_var v) nz;
291 if pos then if d then Rat.roundup o fst o ratrelmax
292 else ratexact true o ratrelmax
293 else if d then Rat.rounddown o fst o ratrelmin
294 else ratexact false o ratrelmin
295 val bnds = map (fn (a,le,bs) => (Rat.mult a (Rat.inv (nth bs v)), le)) sv
296 val x = minmax((Rat.zero,if pos then true else false)::bnds)
297 val ineqs' = elim1 v x nz
298 val ex' = nth_map v (K x) ex
299 in pick_vars discr (ineqs',ex') end
(* Starting example: convert all lineqs to bound triples with rational
   entries and run pick_vars from the all-zero assignment of n variables. *)
302 fun findex0 discr n lineqs =
303 let val ineqs = maps elim_eqns lineqs
304 val rineqs = map (fn (a,le,cs) => (Rat.rat_of_int a, le, map Rat.rat_of_int cs))
306 in pick_vars discr (rineqs,replicate n Rat.zero) end;
308 (* ------------------------------------------------------------------------- *)
309 (* End of counterexample finder. The actual decision procedure starts here. *)
310 (* ------------------------------------------------------------------------- *)
312 (* ------------------------------------------------------------------------- *)
313 (* Calculate new (in)equality type after addition. *)
314 (* ------------------------------------------------------------------------- *)
(* Relation of the sum of two (in)equations. An equation does not change the
   relation of the other summand. Otherwise the sum is strict iff at least
   one summand is strict. *)
fun find_add_type (ty1, ty2) =
  if ty1 = Eq then ty2
  else if ty2 = Eq then ty1
  else if ty1 = Lt orelse ty2 = Lt then Lt
  else Le;
322 (* ------------------------------------------------------------------------- *)
323 (* Multiply out an (in)equation. *)
324 (* ------------------------------------------------------------------------- *)
(* Scale an (in)equation by n and record Multiplied (n, just). Two cases are
   rejected with sys_error: an inequality scaled by a negative factor, and a
   strict inequality scaled by 0. *)
326 fun multiply_ineq n (i as Lineq(k,ty,l,just)) =
328 else if n = 0 andalso ty = Lt then sys_error "multiply_ineq"
329 else if n < 0 andalso (ty=Le orelse ty=Lt) then sys_error "multiply_ineq"
330 else Lineq (n * k, ty, map (Integer.mult n) l, Multiplied (n, just));
332 (* ------------------------------------------------------------------------- *)
333 (* Add together (in)equations. *)
334 (* ------------------------------------------------------------------------- *)
(* Sum of two (in)equations. Constants and coefficient lists are added
   componentwise, the relation is combined by find_add_type, and the
   justification records the addition. *)
fun add_ineq (Lineq (k1, ty1, l1, just1)) (Lineq (k2, ty2, l2, just2)) =
  Lineq (k1 + k2, find_add_type (ty1, ty2), map2 Integer.add l1 l2,
    Added (just1, just2));
340 (* ------------------------------------------------------------------------- *)
341 (* Elimination of variable between a single pair of (in)equations. *)
342 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve. *)
343 (* ------------------------------------------------------------------------- *)
(* Eliminate variable v between i1 and i2:
   - scale both rows so that the coefficients of v become
     lcm(|c1|, |c2|) with opposite signs, then add them;
   - if c1 and c2 have the same sign, one of the rows must be an equation,
     and that row gets the negative factor; otherwise sys_error;
   - for two equations, both factors are negated when one of them is ~1. *)
345 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
346 let val c1 = nth l1 v and c2 = nth l2 v
347 val m = Integer.lcm (abs c1) (abs c2)
348 val m1 = m div (abs c1) and m2 = m div (abs c2)
350 if (c1 >= 0) = (c2 >= 0)
351 then if ty1 = Eq then (~m1,m2)
352 else if ty2 = Eq then (m1,~m2)
353 else sys_error "elim_var"
355 val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
356 then (~n1,~n2) else (n1,n2)
357 in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end;
359 (* ------------------------------------------------------------------------- *)
360 (* The main refutation-finding code. *)
361 (* ------------------------------------------------------------------------- *)
(* A row is trivial when no coefficient is nonzero, i.e. it mentions no atom. *)
fun is_trivial (Lineq (_, _, coeffs, _)) = not (exists (fn c => c <> 0) coeffs);
(* Whether a row with all-zero coefficients, read as "k ty 0", is false:
   k <> 0 for Eq, k > 0 for Le, k >= 0 for Lt. *)
fun is_contradictory (Lineq (k, Eq, _, _)) = k <> 0
  | is_contradictory (Lineq (k, Le, _, _)) = k > 0
  | is_contradictory (Lineq (k, Lt, _, _)) = k >= 0;
(* Blowup of eliminating a variable whose coefficient column is l: the number
   of positive entries times the number of negative entries. This is the
   number of rows produced by pairing every positive row with every
   negative one. *)
369 let val (p,n) = List.partition (curry (op <) 0) (filter (curry (op <>) 0) l)
370 in length p * length n end;
372 (* ------------------------------------------------------------------------- *)
373 (* Main elimination code: *)
375 (* (1) Looks for immediate solutions (false assertions with no variables). *)
377 (* (2) If there are any equations, picks a variable with the lowest absolute *)
378 (* coefficient in any of them, and uses it to eliminate. *)
380 (* (3) Otherwise, chooses a variable in the inequality to minimize the *)
381 (* blowup (number of consequences generated) and eliminates it. *)
382 (* ------------------------------------------------------------------------- *)
(* Split off the first element satisfying p and return it together with the
   remaining elements. The elements before it come back in reversed order.
   Raises Empty if no element satisfies p. *)
384 fun extract_first p =
386 fun extract xs (y::ys) = if p y then (y, xs @ ys) else extract (y::xs) ys
387 | extract xs [] = raise Empty
(* Trace the system of lineqs, one line per row: presumably the constant,
   then the relation and the coefficient list.
   NOTE(review): presumably guarded by !trace -- confirm. *)
390 fun print_ineqs ineqs =
392 tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
394 (case t of Eq => " = " | Lt=> " < " | Le => " <= ") ^
395 commas(map string_of_int l)) ineqs))
(* Elimination history, most recent step first. Each entry holds the
   eliminated variable (~1 when no variable could be chosen) and the system
   before that step. elim yields Success with the justification of a
   contradictory row, or Failure with this history. *)
398 type history = (int * lineq list) list;
399 datatype result = Success of injust | Failure of history;
(* Main elimination loop:
   - if trivial (all-zero) rows exist, return Success for a contradictory one;
   - otherwise, if there are equations, eliminate the variable whose
     coefficient has the smallest absolute value among all equations;
   - otherwise eliminate the inequality variable with the smallest nonzero
     blowup, combining every row with a positive coefficient with every row
     with a negative one.
   Each elimination step is pushed onto the history. *)
401 fun elim (ineqs, hist) =
402 let val _ = print_ineqs ineqs
403 val (triv, nontriv) = List.partition is_trivial ineqs in
405 then case Library.find_first is_contradictory triv of
406 NONE => elim (nontriv, hist)
407 | SOME(Lineq(_,_,_,j)) => Success j
409 if null nontriv then Failure hist
411 let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
412 if not (null eqs) then
414 fold (fn Lineq(_,_,l,_) => fn cs => union (op =) l cs) eqs []
415 |> filter (fn i => i <> 0)
416 |> sort (int_ord o pairself abs)
418 val (eq as Lineq(_,_,ceq,_),othereqs) =
419 extract_first (fn Lineq(_,_,l,_) => member (op =) l c) eqs
420 val v = find_index (fn v => v = c) ceq
421 val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0)
423 val others = map (elim_var v eq) roth @ ioth
424 in elim(others,(v,nontriv)::hist) end
426 let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
427 val numlist = 0 upto (length (hd lists) - 1)
428 val coeffs = map (fn i => map (fn xs => nth xs i) lists) numlist
429 val blows = map calc_blowup coeffs
430 val iblows = blows ~~ numlist
431 val nziblows = filter_out (fn (i, _) => i = 0) iblows
432 in if null nziblows then Failure((~1,nontriv)::hist)
434 let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
(* NOTE(review): this partitions ineqs, including the trivial rows, rather
   than nontriv. Trivial rows have coefficient 0 at v, so they are simply
   carried along. *)
435 val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) ineqs
436 val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => nth l v > 0) yes
437 in elim(no @ map_product (elim_var v) pos neg, (v,nontriv)::hist) end
442 (* ------------------------------------------------------------------------- *)
443 (* Translate back a proof. *)
444 (* ------------------------------------------------------------------------- *)
(* Return th unchanged; if tracing is on, first print msg and th. *)
fun trace_thm ctxt msg th =
  let
    fun emit () = (tracing msg; tracing (Display.string_of_thm ctxt th))
  in
    if !trace then (emit (); th) else th
  end;

(* Return t unchanged; if tracing is on, first print msg followed by t. *)
fun trace_term ctxt msg t =
  let
    fun emit () = tracing (cat_lines [msg, Syntax.string_of_term ctxt t])
  in
    if !trace then (emit (); t) else t
  end
(* Emit msg via tracing when the trace flag is set. *)
453 if !trace then tracing msg else ();
(* Union of term lists modulo Pattern.aeconv, and of (flag, term) lists where
   the flags must agree as well. *)
fun union_term xs ys = union Pattern.aeconv xs ys;
fun union_bterm xs ys =
  let fun same ((b1: bool, t1), (b2, t2)) = b1 = b2 andalso Pattern.aeconv (t1, t2)
  in union same xs ys end;
(* Collect the atoms (summand terms) of decomposed (in)equations. An atom is
   added only if it is not already present modulo Pattern.aeconv. *)
fun add_atoms (lhs, _, _, rhs, _, _) atoms =
  union_term (map fst lhs) (union_term (map fst rhs) atoms);

fun atoms_of ds = List.foldl (fn (d, acc) => add_atoms d acc) [] ds;
464 Simplification may detect a contradiction 'prematurely' due to type
465 information: n+1 <= 0 is simplified to False and does not need to be crossed
(* Carries a theorem that simplified to False before the whole justification
   was replayed; raised and handled within mkthm. *)
469 exception FalseE of thm
(* Replay the justification just as a theorem about the premises asms.
   - mk rebuilds each justification step: premises, 0 <= atom facts for nat
     atoms, and the lessD / not_leD / not_lessD conversions.
   - Additions use add_mono_thms, retrying with inj_thms when the direct
     addition fails.
   - Multiplications use mult_mono_thms and fall back to repeated addition.
   - Sums are simplified with the package simpset; reaching False early
     raises FalseE, whose theorem is returned by the final handler.
   - If the final theorem does not simplify to False, a warning is issued. *)
472 fun mkthm ss asms (just: injust) =
474 val ctxt = Simplifier.the_context ss;
475 val thy = ProofContext.theory_of ctxt;
476 val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset,
477 number_of = (_, num_of), ...} = get_data ctxt;
478 val simpset' = Simplifier.inherit_context ss simpset;
479 fun only_concl f thm =
480 if Thm.no_prems thm then f (Thm.concl_of thm) else NONE;
481 val atoms = atoms_of (map_filter (only_concl (LA_Data.decomp ctxt)) asms);
483 fun use_first rules thm =
484 get_first (fn th => SOME (thm RS th) handle THM _ => NONE) rules
487 use_first add_mono_thms (thm1 RS (thm2 RS LA_Logic.conjI));
488 fun try_add thms thm = get_first (fn th => add2 th thm) thms;
490 fun add_thms thm1 thm2 =
491 (case add2 thm1 thm2 of
493 (case try_add ([thm1] RL inj_thms) thm2 of
495 (the (try_add ([thm2] RL inj_thms) thm1)
497 (trace_thm ctxt "" thm1; trace_thm ctxt "" thm2;
498 sys_error "Linear arithmetic: failed to add thms"))
502 fun mult_by_add n thm =
503 let fun mul i th = if i = 1 then th else mul (i - 1) (add_thms thm th)
506 val rewr = Simplifier.rewrite simpset';
507 val rewrite_concl = Conv.fconv_rule (Conv.concl_conv ~1 (Conv.arg_conv
508 (Conv.binop_conv rewr)));
509 fun discharge_prem thm = if Thm.nprems_of thm = 0 then thm else
510 let val cv = Conv.arg1_conv (Conv.arg_conv rewr)
511 in Thm.implies_elim (Conv.fconv_rule cv thm) LA_Logic.trueI end
514 (case use_first mult_mono_thms thm of
515 NONE => mult_by_add n thm
518 val cv = mth |> Thm.cprop_of |> Drule.strip_imp_concl
519 |> Thm.dest_arg |> Thm.dest_arg1 |> Thm.dest_arg1
520 val T = #T (Thm.rep_cterm cv)
523 |> Thm.instantiate ([], [(cv, num_of thy T n)])
526 handle CTERM _ => mult_by_add n thm
527 | THM _ => mult_by_add n thm
(* Negative factors arise only for equations (multiply_ineq rejects them for
   inequalities), hence the flip with sym. *)
530 fun mult_thm (n, thm) =
531 if n = ~1 then thm RS LA_Logic.sym
532 else if n < 0 then mult (~n) thm RS LA_Logic.sym
536 let val thm' = trace_thm ctxt "Simplified:" (full_simplify simpset' thm)
537 in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end;
539 fun mk (Asm i) = trace_thm ctxt ("Asm " ^ string_of_int i) (nth asms i)
540 | mk (Nat i) = trace_thm ctxt ("Nat " ^ string_of_int i) (LA_Logic.mk_nat_thm thy (nth atoms i))
541 | mk (LessD j) = trace_thm ctxt "L" (hd ([mk j] RL lessD))
542 | mk (NotLeD j) = trace_thm ctxt "NLe" (mk j RS LA_Logic.not_leD)
543 | mk (NotLeDD j) = trace_thm ctxt "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD))
544 | mk (NotLessD j) = trace_thm ctxt "NL" (mk j RS LA_Logic.not_lessD)
545 | mk (Added (j1, j2)) = simp (trace_thm ctxt "+" (add_thms (mk j1) (mk j2)))
546 | mk (Multiplied (n, j)) =
547 (trace_msg ("*" ^ string_of_int n); trace_thm ctxt "*" (mult_thm (n, mk j)))
551 val _ = trace_msg "mkthm";
552 val thm = trace_thm ctxt "Final thm:" (mk just);
553 val fls = simplify simpset' thm;
554 val _ = trace_thm ctxt "After simplification:" fls;
556 if LA_Logic.is_False fls then ()
559 (["Assumptions:"] @ map (Display.string_of_thm ctxt) asms @ [""] @
560 ["Proved:", Display.string_of_thm ctxt fls, ""]));
561 warning "Linear arithmetic should have refuted the assumptions.\n\
562 \Please inform Tobias Nipkow.")
564 handle FalseE thm => trace_thm ctxt "False reached early:" thm
(* Coefficient of atom in poly, an association list keyed by terms compared
   with Pattern.aeconv; 0 if the atom does not occur. *)
fun coeff poly atom =
  (case AList.lookup Pattern.aeconv poly atom of
    SOME c => c
  | NONE => 0);
(* Scale a decomposition to integer coefficients. m is the lcm of the
   denominators of both constants and of all summand coefficients. Returns m
   and the decomposition with every coefficient and constant multiplied
   by m. *)
572 fun integ(rlhs,r,rel,rrhs,s,d) =
573 let val (rn,rd) = Rat.quotient_of_rat r and (sn,sd) = Rat.quotient_of_rat s
574 val m = Integer.lcms(map (abs o snd o Rat.quotient_of_rat) (r :: s :: map snd rlhs @ map snd rrhs))
576 let val (i,j) = Rat.quotient_of_rat r
577 in (t,i * (m div j)) end
578 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
(* Turn a decomposed (in)equation into a Lineq over atoms:
   - the coefficient list is rhs minus lhs, per atom;
   - relations are normalised to Eq, Le or Lt: negated relations flip the
     signs, and on discrete domains ~<= (and, presumably, <) become <= with
     the constant adjusted by 1;
   - the justification is wrapped in Multiplied (m, _) when integ scaled by
     m <> 1. *)
582 let val (m, (lhs,i,rel,rhs,j,discrete)) = integ item
583 val lhsa = map (coeff lhs) atoms
584 and rhsa = map (coeff rhs) atoms
585 val diff = map2 (curry (op -)) rhsa lhsa
588 fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied(m,j))
590 "<=" => lineq(c,Le,diff,just)
591 | "~<=" => if discrete
592 then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
593 else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
595 then lineq(c+1,Le,diff,LessD(just))
596 else lineq(c,Lt,diff,just)
597 | "~<" => lineq(~c,Le,map (op~) diff,NotLessD(just))
598 | "=" => lineq(c,Eq,diff,just)
599 | _ => sys_error("mklineq" ^ rel)
602 (* ------------------------------------------------------------------------- *)
603 (* Print (counter) example *)
604 (* ------------------------------------------------------------------------- *)
(* Render one counterexample entry as "a = value". Discrete atoms show the
   numerator only; other atoms show p or p/q (presumably depending on whether
   q = 1 -- confirm). *)
606 fun print_atom((a,d),r) =
607 let val (p,q) = Rat.quotient_of_rat r
608 val s = if d then string_of_int p else
610 else string_of_int p ^ "/" ^ string_of_int q
611 in a ^ " = " ^ s end;
(* Prefix the rendered counterexample with a header line. *)
617 #> curry (op ^) "Counterexample (possibly spurious):\n";
(* Report a failed refutation.
   - The starting point depends on the most recent history entry: findex0 is
     used when that entry records that no variable could be eliminated
     (v = ~1); otherwise the assignment starts at all zeros.
   - findex1 is then folded over the history to assign the eliminated
     variables.
   - The candidate counterexample is printed with the subgoal parameters
     substituted for loose bound variables.
   - Without an example, only a warning is issued. *)
619 fun trace_ex ctxt params atoms discr n (hist: history) =
622 | (v, lineqs) :: hist' =>
624 val frees = map Free params
625 fun show_term t = Syntax.string_of_term ctxt (subst_bounds (frees, t))
627 if v = ~1 then (hist', findex0 discr n lineqs)
628 else (hist, replicate n Rat.zero)
629 val ex = SOME (produce_ex (map show_term atoms ~~ discr)
630 (uncurry (fold (findex1 discr)) start))
634 SOME s => (warning "Linear arithmetic failed - see trace for a counterexample."; tracing s)
635 | NONE => warning "Linear arithmetic failed"
638 (* ------------------------------------------------------------------------- *)
(* For an atom of type nat (according to LA_Logic.is_nat), build the
   constraint 0 <= atom i as a Lineq with a unit coefficient list, justified
   by Nat i.
   NOTE(review): other atoms presumably yield NONE -- confirm. *)
640 fun mknat (pTs : typ list) (ixs : int list) (atom : term, i : int) : lineq option =
641 if LA_Logic.is_nat (pTs, atom)
642 then let val l = map (fn j => if j=i then 1 else 0) ixs
643 in SOME (Lineq (0, Le, l, Nat i)) end
646 (* This code is tricky. It takes a list of premises in the order they occur
647 in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
648 ones as NONE. Going through the premises, each numeric one is converted into
649 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
650 >. Thus split_items returns a list of equation systems. This may blow up if
651 there are many ~=, but in practice it does not seem to happen. The really
652 tricky bit is to arrange the order of the cases such that they coincide with
653 the order in which the cases are in the end generated by the tactic that
654 applies the generated refutation thms (see function 'refute_tac').
656 For variables n of type nat, a constraint 0 <= n is added.
659 (* FIXME: To optimize, the splitting of cases and the search for refutations *)
660 (* could be intertwined: separate the first (fully split) case, *)
661 (* refute it, continue with splitting and refuting. Terminate with *)
662 (* failure as soon as a case could not be refuted; i.e. delay further *)
663 (* splitting until after a refutation for other cases has been found. *)
(* Turn the premises of a subgoal into the case systems to be refuted:
   - optional pre_decomp preprocessing;
   - decomposition of every premise;
   - if split_neq holds, each ~= premise is split into its two < cases;
     otherwise ~= premises are turned into non-numeric ones;
   - the numeric premises are numbered by position (non-numeric premises are
     skipped but still counted).
   Each resulting system is paired with its parameter types. *)
665 fun split_items ctxt do_pre split_neq (Ts, terms) : (typ list * (LA_Data.decomp * int) list) list =
667 (* splits inequalities '~=' into '<' and '>'; this corresponds to *)
668 (* 'REPEAT_DETERM (eresolve_tac neqE i)' at the theorem/tactic *)
670 (* FIXME: this is currently sensitive to the order of theorems in *)
671 (* neqE: The theorem for type "nat" must come first. A *)
672 (* better (i.e. less likely to break when neqE changes) *)
673 (* implementation should *test* which theorem from neqE *)
674 (* can be applied, and split the premise accordingly. *)
675 fun elim_neq (ineqs : (LA_Data.decomp option * bool) list) :
676 (LA_Data.decomp option * bool) list list =
678 fun elim_neq' nat_only ([] : (LA_Data.decomp option * bool) list) :
679 (LA_Data.decomp option * bool) list list =
681 | elim_neq' nat_only ((NONE, is_nat) :: ineqs) =
682 map (cons (NONE, is_nat)) (elim_neq' nat_only ineqs)
683 | elim_neq' nat_only ((ineq as (SOME (l, i, rel, r, j, d), is_nat)) :: ineqs) =
684 if rel = "~=" andalso (not nat_only orelse is_nat) then
685 (* [| ?l ~= ?r; ?l < ?r ==> ?R; ?r < ?l ==> ?R |] ==> ?R *)
686 elim_neq' nat_only (ineqs @ [(SOME (l, i, "<", r, j, d), is_nat)]) @
687 elim_neq' nat_only (ineqs @ [(SOME (r, j, "<", l, i, d), is_nat)])
689 map (cons ineq) (elim_neq' nat_only ineqs)
(* Split ~= premises flagged as nat first, then all remaining ones. *)
691 ineqs |> elim_neq' true
692 |> maps (elim_neq' false)
695 fun ignore_neq (NONE, bool) = (NONE, bool)
696 | ignore_neq (ineq as SOME (_, _, rel, _, _, _), bool) =
697 if rel = "~=" then (NONE, bool) else (ineq, bool)
699 fun number_hyps _ [] = []
700 | number_hyps n (NONE::xs) = number_hyps (n+1) xs
701 | number_hyps n ((SOME x)::xs) = (x, n) :: number_hyps (n+1) xs
703 val result = (Ts, terms)
704 |> (* user-defined preprocessing of the subgoal *)
705 (if do_pre then LA_Data.pre_decomp ctxt else Library.single)
706 |> tap (fn subgoals => trace_msg ("Preprocessing yields " ^
707 string_of_int (length subgoals) ^ " subgoal(s) total."))
708 |> (* produce the internal encoding of (in-)equalities *)
709 map (apsnd (map (fn t => (LA_Data.decomp ctxt t, LA_Data.domain_is_nat t))))
710 |> (* splitting of inequalities *)
711 map (apsnd (if split_neq then elim_neq else
712 Library.single o map ignore_neq))
713 |> maps (fn (Ts, subgoals) => map (pair Ts o map fst) subgoals)
714 |> (* numbering of hypotheses, ignoring irrelevant ones *)
715 map (apsnd (number_hyps 0))
717 trace_msg ("Splitting of inequalities yields " ^
718 string_of_int (length result) ^ " subgoal(s) total.");
(* Pair every atom of an item with the item's discreteness flag d and union
   the pairs into dats (flags must agree, terms compared with Pattern.aeconv). *)
fun add_datoms ((lhs, _, _, rhs, _, d) : LA_Data.decomp, _) (dats : (bool * term) list) =
  let fun tag summands = map (fn (t, _) => (d, t)) summands
  in union_bterm (tag lhs) (union_bterm (tag rhs) dats) end;

(* Discreteness flags of the (flag, atom) pairs collected from initems. *)
fun discr (initems : (LA_Data.decomp * int) list) : bool list =
  map fst (List.foldl (fn (item, dats) => add_datoms item dats) [] initems);
(* Refute the case systems one after another.
   - For each system, build its Lineqs: its premises plus 0 <= atom for nat
     atoms.
   - Run elim; on Success continue with the next system, collecting the
     justifications in order.
   - On Failure, a counterexample is traced when show_ex is set, using fresh
     names for the parameters.
   - Once all systems are refuted, the collected justifications are
     returned. *)
728 fun refutes ctxt params show_ex :
729 (typ list * (LA_Data.decomp * int) list) list -> injust list -> injust list option =
731 fun refute ((Ts, initems : (LA_Data.decomp * int) list) :: initemss) (js: injust list) =
733 val atoms = atoms_of (map fst initems)
735 val mkleq = mklineq atoms
736 val ixs = 0 upto (n - 1)
737 val iatoms = atoms ~~ ixs
738 val natlineqs = map_filter (mknat Ts ixs) iatoms
739 val ineqs = map mkleq initems @ natlineqs
740 in case elim (ineqs, []) of
742 (trace_msg ("Contradiction! (" ^ string_of_int (length js + 1) ^ ")");
743 refute initemss (js @ [j]))
745 (if not show_ex then ()
748 val (param_names, ctxt') = ctxt |> Variable.variant_fixes (map fst params)
749 val (more_names, ctxt'') = ctxt' |> Variable.variant_fixes
750 (Name.invents (Variable.names_of ctxt') Name.uu (length Ts - length params))
751 val params' = (more_names @ param_names) ~~ Ts
753 trace_ex ctxt'' params' atoms (discr initems) n hist
756 | refute [] js = SOME js
(* Split the premises terms into case systems with split_items and refute
   them in order with refutes, starting with no justifications. *)
fun refute ctxt params show_ex do_pre split_neq terms : injust list option =
  let
    val run = refutes ctxt params show_ex
    val systems = split_items ctxt do_pre split_neq (map snd params, terms)
  in run systems [] end;
(* Number of elements of xs that satisfy P, testing them left to right. *)
fun count P xs = fold (fn x => fn n => if P x then n + 1 else n) xs 0;
(* Try to refute the premises Hs together with the negated conclusion.
   Returns two things: whether ~= premises are split (only if their number
   does not exceed fast_arith_neq_limit), and the justifications found by
   refute, if any. If neg_prop cannot negate the conclusion, the attempt
   fails. *)
765 fun prove ctxt params show_ex do_pre Hs concl : bool * injust list option =
767 val _ = trace_msg "prove:"
768 (* append the negated conclusion to 'Hs' -- this corresponds to *)
769 (* 'DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i)' at the *)
770 (* theorem/tactic level *)
771 val Hs' = Hs @ [LA_Logic.neg_prop concl]
772 fun is_neq NONE = false
773 | is_neq (SOME (_,_,r,_,_,_)) = (r = "~=")
774 val neq_limit = Config.get ctxt LA_Data.fast_arith_neq_limit
775 val split_neq = count is_neq (map (LA_Data.decomp ctxt) Hs') <= neq_limit
779 trace_msg ("fast_arith_neq_limit exceeded (current value is " ^
780 string_of_int neq_limit ^ "), ignoring all inequalities");
781 (split_neq, refute ctxt params show_ex do_pre split_neq Hs')
782 end handle TERM ("neg_prop", _) =>
783 (* since no meta-logic negation is available, we can only fail if *)
784 (* the conclusion is not of the form 'Trueprop $ _' (simply *)
785 (* dropping the conclusion doesn't work either, because even *)
786 (* 'False' does not imply arbitrary 'concl::prop') *)
787 (trace_msg "prove failed (cannot negate conclusion).";
(* Replay the justifications on subgoal i:
   - turn the goal into a refutation with notI/ccontr and run pre_tac;
   - for each justification, eliminate ~= premises with neqE (presumably
     only when split_neq holds);
   - close the resulting subgoal with the theorem mkthm derives from its
     premises. *)
790 fun refute_tac ss (i, split_neq, justs) =
793 val ctxt = Simplifier.the_context ss;
794 val _ = trace_thm ctxt
795 ("refute_tac (on subgoal " ^ string_of_int i ^ ", with " ^
796 string_of_int (length justs) ^ " justification(s)):") state
797 val {neqE, ...} = get_data ctxt;
799 (* eliminate inequalities *)
801 REPEAT_DETERM (eresolve_tac neqE i)
804 PRIMITIVE (trace_thm ctxt "State after neqE:") THEN
805 (* use theorems generated from the actual justifications *)
806 Subgoal.FOCUS (fn {prems, ...} => rtac (mkthm ss prems j) 1) ctxt i
808 (* rewrite "[| A1; ...; An |] ==> B" to "[| A1; ...; An; ~B |] ==> False" *)
809 DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i) THEN
810 (* user-defined preprocessing of the subgoal *)
811 DETERM (LA_Data.pre_tac ss i) THEN
812 PRIMITIVE (trace_thm ctxt "State after pre_tac:") THEN
813 (* prove every resulting subgoal, using its justification *)
814 EVERY (map just1 justs)
818 Fast but very incomplete decider. Only premises and conclusions
819 that are already (negated) (in)equations are taken into account.
(* Tactic on subgoal i. It searches with prove (preprocessing enabled) for a
   refutation of the premises and the negated conclusion, and replays it with
   refute_tac. Fails with no_tac if no refutation is found. *)
821 fun simpset_lin_arith_tac ss show_ex = SUBGOAL (fn (A, i) =>
823 val ctxt = Simplifier.the_context ss
824 val params = rev (Logic.strip_params A)
825 val Hs = Logic.strip_assums_hyp A
826 val concl = Logic.strip_assums_concl A
827 val _ = trace_term ctxt ("Trying to refute subgoal " ^ string_of_int i) A
829 case prove ctxt params show_ex true Hs concl of
830 (_, NONE) => (trace_msg "Refutation failed."; no_tac)
831 | (split_neq, SOME js) => (trace_msg "Refutation succeeded.";
832 refute_tac ss (i, split_neq, js))
(* Add the simpset's premises to subgoal i, then run the decision procedure
   on it without counterexample output. *)
fun cut_lin_arith_tac ss i =
  let val facts = Simplifier.prems_of_ss ss
  in cut_facts_tac facts i THEN simpset_lin_arith_tac ss false i end;
(* Run the decision procedure with an empty simpset attached to ctxt. *)
fun lin_arith_tac ctxt =
  let val ss = Simplifier.context ctxt Simplifier.empty_ss
  in simpset_lin_arith_tac ss end;
844 (** Forward proof from theorems **)
846 (* More tricky code. Needs to arrange the proofs of the multiple cases (due
847 to splits of ~= premises) such that it coincides with the order of the cases
848 generated by function split_items. *)
(* Case tree following the ~= splits. Tip holds the assumptions of one case.
   Spl holds the split theorem and, for each of its two cases, the case
   hypothesis and the subtree. *)
850 datatype splittree = Tip of thm list
851 | Spl of thm * cterm * splittree * cterm * splittree;
853 (* "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R" is taken to (ct1, ct2) *)
(* Given a cterm of the shape (ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R, return
   (ct1, ct2) by successive Thm.dest_comb projections. *)
855 fun extract (imp : cterm) : cterm * cterm =
856 let val (Il, r) = Thm.dest_comb imp
857 val (_, imp1) = Thm.dest_comb Il
858 val (Ict1, _) = Thm.dest_comb imp1
859 val (_, ct1) = Thm.dest_comb Ict1
860 val (Ir, _) = Thm.dest_comb r
861 val (_, Ict2r) = Thm.dest_comb Ir
862 val (Ict2, _) = Thm.dest_comb Ict2r
863 val (_, ct2) = Thm.dest_comb Ict2
(* Apply the neqE rules one at a time to the assumptions.
   - Each assumption the current rule applies to (via COMP) becomes a case
     split Spl. In each subtree the assumption is dropped and the respective
     case hypothesis is appended at the end.
   - When the assumptions are exhausted, the next rule is tried on all of
     them.
   - Tip collects the final assumptions. *)
866 fun splitasms ctxt (asms : thm list) : splittree =
867 let val {neqE, ...} = get_data ctxt
868 fun elim_neq [] (asms', []) = Tip (rev asms')
869 | elim_neq [] (asms', asms) = Tip (rev asms' @ asms)
870 | elim_neq (neq :: neqs) (asms', []) = elim_neq neqs ([],rev asms')
871 | elim_neq (neqs as (neq :: _)) (asms', asm::asms) =
872 (case get_first (fn th => SOME (asm COMP th) handle THM _ => NONE) [neq] of
874 let val (ct1, ct2) = extract (cprop_of spl)
875 val thm1 = Thm.assume ct1
876 val thm2 = Thm.assume ct2
877 in Spl (spl, ct1, elim_neq neqs (asms', asms@[thm1]),
878 ct2, elim_neq neqs (asms', asms@[thm2]))
880 | NONE => elim_neq neqs (asm::asms', asms))
881 in elim_neq neqE ([], asms) end;
(* Derive the contradiction for every leaf of a split tree.
   - Justifications are consumed left to right, first subtree first.
   - The two case hypotheses are discharged into the split theorem with COMP.
   - Returns the theorem and the unused justifications.
   NOTE(review): the Tip clause needs at least one remaining justification;
   with none left, the pattern match fails. *)
883 fun fwdproof ss (Tip asms : splittree) (j::js : injust list) = (mkthm ss asms j, js)
884 | fwdproof ss (Spl (thm, ct1, tree1, ct2, tree2)) js =
886 val (thm1, js1) = fwdproof ss tree1 js
887 val (thm2, js2) = fwdproof ss tree2 js1
888 val thm1' = Thm.implies_intr ct1 thm1
889 val thm2' = Thm.implies_intr ct2 thm2
890 in (thm2' COMP (thm1' COMP thm), js2) end;
891 (* FIXME needs handle THM _ => NONE ? *)
(* Forward proof of Tconcl from thms using the justifications js:
   - assume the negation of Tconcl;
   - derive False, splitting ~= assumptions if split_neq holds;
   - discharge the assumption and close with ccontr (pos) or notI;
   - return the result converted by mk_Eq.
   A THM exception yields NONE. *)
893 fun prover ss thms Tconcl (js : injust list) split_neq pos : thm option =
895 val ctxt = Simplifier.the_context ss
896 val thy = ProofContext.theory_of ctxt
897 val nTconcl = LA_Logic.neg_prop Tconcl
898 val cnTconcl = cterm_of thy nTconcl
899 val nTconclthm = Thm.assume cnTconcl
900 val tree = (if split_neq then splitasms ctxt else Tip) (thms @ [nTconclthm])
901 val (Falsethm, _) = fwdproof ss tree js
902 val contr = if pos then LA_Logic.ccontr else LA_Logic.notI
903 val concl = Thm.implies_intr cnTconcl Falsethm COMP contr
904 in SOME (trace_thm ctxt "Proved by lin. arith. prover:" (LA_Logic.mk_Eq concl)) end
905 (*in case concl contains ?-var, which makes assume fail:*) (* FIXME Variable.import_terms *)
906 handle THM _ => NONE;
908 (* PRE: concl is not negated!
909 This assumption is OK because
910 1. lin_arith_simproc tries both to prove and disprove concl and
911 2. lin_arith_simproc is applied by the Simplifier which
912 dives into terms and will thus try the non-negated concl anyway.
914 fun lin_arith_simproc ss concl =
916 val ctxt = Simplifier.the_context ss
917 val thms = maps LA_Logic.atomize (Simplifier.prems_of_ss ss)
918 val Hs = map Thm.prop_of thms
919 val Tconcl = LA_Logic.mk_Trueprop concl
921 case prove ctxt [] false false Hs Tconcl of (* concl provable? *)
922 (split_neq, SOME js) => prover ss thms Tconcl js split_neq true
924 let val nTconcl = LA_Logic.neg_prop Tconcl in
925 case prove ctxt [] false false Hs nTconcl of (* ~concl provable? *)
926 (split_neq, SOME js) => prover ss thms nTconcl js split_neq false