1 (* Title: Provers/Arith/fast_lin_arith.ML
2 Author: Tobias Nipkow and Tjark Weber and Sascha Boehme
4 A generic linear arithmetic package. It provides two tactics
5 (cut_lin_arith_tac, lin_arith_tac) and a simplification procedure
8 Only take premises and conclusions into account that are already
9 (negated) (in)equations. lin_arith_simproc tries to prove or disprove
13 (*** Data needed for setting up the linear arithmetic package ***)
15 signature LIN_ARITH_LOGIC =
17 val conjI : thm (* P ==> Q ==> P & Q *)
18 val ccontr : thm (* (~ P ==> False) ==> P *)
19 val notI : thm (* (P ==> False) ==> ~ P *)
20 val not_lessD : thm (* ~(m < n) ==> n <= m *)
21 val not_leD : thm (* ~(m <= n) ==> n < m *)
22 val sym : thm (* x = y ==> y = x *)
23 val trueI : thm (* True *)
24 val mk_Eq : thm -> thm
25 val atomize : thm -> thm list
26 val mk_Trueprop : term -> term
27 val neg_prop : term -> term
28 val is_False : thm -> bool
29 val is_nat : typ list * term -> bool
30 val mk_nat_thm : theory -> term -> thm
33 mk_Eq(~in) = `in == False'
34 mk_Eq(in) = `in == True'
35 where `in' is an (in)equality.
37 neg_prop(t) = neg if t is wrapped up in Trueprop and neg is the
38 (logically) negated version of t (again wrapped up in Trueprop),
39 where the negation of a negative term is the term itself (no
40 double negation!); raises TERM ("neg_prop", [t]) if t is not of
41 the form 'Trueprop $ _'
43 is_nat(parameter-types,t) = t:nat
44 mk_nat_thm(t) = "0 <= t"
47 signature LIN_ARITH_DATA =
49 (*internal representation of linear (in-)equations:*)
50 type decomp = (term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat * bool
51 val decomp: Proof.context -> term -> decomp option
52 val domain_is_nat: term -> bool
54 (*preprocessing, performed on a representation of subgoals as list of premises:*)
55 val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list
57 (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
58 val pre_tac: simpset -> int -> tactic
60 (*the limit on the number of ~= allowed; because each ~= is split
61 into two cases, this can lead to an explosion*)
62 val fast_arith_neq_limit: int Config.T
65 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
66 where Rel is one of "<", "~<", "<=", "~<=", "=" and "~=", and
67 p (q, respectively) is the decomposition of the sum term x
68 (y, respectively) into a list of summand * multiplicity
69 pairs and a constant summand and d indicates if the domain
72 domain_is_nat(`x Rel y') should yield true iff x is of type "nat".
74 The relationship between pre_decomp and pre_tac is somewhat tricky. The
75 internal representation of a subgoal and the corresponding theorem must
76 be modified by pre_decomp (pre_tac, resp.) in a corresponding way. See
77 the comment for split_items below. (This is even necessary for eta- and
78 beta-equivalent modifications, as some of the lin. arith. code is not
81 ss must reduce contradictory <= to False.
82 It should also cancel common summands to keep <= reduced;
83 otherwise <= can grow to massive proportions.
86 signature FAST_LIN_ARITH =
88 val cut_lin_arith_tac: simpset -> int -> tactic
89 val lin_arith_tac: Proof.context -> bool -> int -> tactic
90 val lin_arith_simproc: simpset -> term -> thm option
91 val map_data: ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
92 lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
93 number_of : serial * (theory -> typ -> int -> cterm)}
94 -> {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
95 lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
96 number_of : serial * (theory -> typ -> int -> cterm)})
97 -> Context.generic -> Context.generic
98 val trace: bool Unsynchronized.ref
101 functor Fast_Lin_Arith
102 (structure LA_Logic: LIN_ARITH_LOGIC and LA_Data: LIN_ARITH_DATA): FAST_LIN_ARITH =
(* Placeholder numeral constructor stored in the empty data slot:
   it takes three (ignored) curried arguments and then always raises
   CTERM ("number_of", []). *)
val no_number_of = fn _ => fn _ => fn _ => raise CTERM ("number_of", []);
110 structure Data = Generic_Data
113 {add_mono_thms: thm list,
114 mult_mono_thms: thm list,
118 simpset: Simplifier.simpset,
119 number_of : serial * (theory -> typ -> int -> cterm)};
121 val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
122 lessD = [], neqE = [], simpset = Simplifier.empty_ss,
123 number_of = (serial (), no_number_of) } : T;
126 ({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1,
127 lessD = lessD1, neqE=neqE1, simpset = simpset1,
128 number_of = (number_of1 as (s1, _))},
129 {add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2,
130 lessD = lessD2, neqE=neqE2, simpset = simpset2,
131 number_of = (number_of2 as (s2, _))}) =
132 {add_mono_thms = Thm.merge_thms (add_mono_thms1, add_mono_thms2),
133 mult_mono_thms = Thm.merge_thms (mult_mono_thms1, mult_mono_thms2),
134 inj_thms = Thm.merge_thms (inj_thms1, inj_thms2),
135 lessD = Thm.merge_thms (lessD1, lessD2),
136 neqE = Thm.merge_thms (neqE1, neqE2),
137 simpset = Simplifier.merge_ss (simpset1, simpset2),
138 (* FIXME depends on accidental load order !?! *)
139 number_of = if s1 > s2 then number_of1 else number_of2};
(* Functional update of the generic arithmetic data. *)
fun map_data f = Data.map f;

(* The arithmetic data as seen from a proof context. *)
fun get_data ctxt = Data.get (Context.Proof ctxt);
147 (*** A fast decision procedure ***)
148 (*** Code ported from HOL Light ***)
149 (* possible optimizations:
150 use (var,coeff) rep or vector rep to save space;
151 treat non-negative atoms separately rather than adding 0 <= atom
(* Global tracing switch, consulted by trace_thm, trace_term and trace_msg. *)
val trace : bool Unsynchronized.ref = Unsynchronized.ref false;

(* Relation of a linear (in)equation: equality, non-strict or strict
   inequality. *)
datatype lineq_type =
    Eq
  | Le
  | Lt;
158 datatype injust = Asm of int
159 | Nat of int (* index of atom *)
164 | Multiplied of int * injust
165 | Added of injust * injust;
(* Lineq (k, ty, cs, just): integer constant k, relation ty, one integer
   coefficient per atom (cs, indexed like the atom list), and the
   justification just from which the corresponding theorem is rebuilt. *)
datatype lineq =
  Lineq of int * lineq_type * int list * injust;
169 (* ------------------------------------------------------------------------- *)
170 (* Finding a (counter) example from the trace of a failed elimination *)
171 (* ------------------------------------------------------------------------- *)
172 (* Examples are represented as rational numbers, *)
173 (* Don't blame John Harrison for this code - it is entirely mine. TN *)
177 (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
178 In general, true means the bound is included, false means it is excluded.
179 Need to know if it is a lower or upper bound for unambiguous interpretation!
(* Convert a lineq into bound triples (k, inclusive, cs):
   (k, true, cs) encodes k <= cs and (k, false, cs) encodes k < cs.
   An equation becomes two inclusive triples: the original one and the
   one with the constant and every coefficient negated. *)
fun elim_eqns (Lineq (k, ty, coeffs, _)) =
  (case ty of
    Le => [(k, true, coeffs)]
  | Lt => [(k, false, coeffs)]
  | Eq => [(k, true, coeffs), (~k, true, map ~ coeffs)]);
186 (* PRE: ex[v] must be 0! *)
187 fun eval ex v (a, le, cs) =
189 val rs = map Rat.rat_of_int cs;
190 val rsum = fold2 (Rat.add oo Rat.mult) rs ex Rat.zero;
191 in (Rat.mult (Rat.add (Rat.rat_of_int a) (Rat.neg rsum)) (Rat.inv (nth rs v)), le) end;
192 (* If nth rs v < 0, le should be negated.
193 Instead this swap is taken into account in ratrelmin2.
196 fun ratrelmin2 (x as (r, ler), y as (s, les)) =
198 of EQUAL => (r, (not ler) andalso (not les))
202 fun ratrelmax2 (x as (r, ler), y as (s, les)) =
204 of EQUAL => (r, ler andalso les)
(* Combine a list of (bound, inclusive) pairs with ratrelmin2 / ratrelmax2,
   folding from the right via foldr1. *)
fun ratrelmin bounds = foldr1 ratrelmin2 bounds;
fun ratrelmax bounds = foldr1 ratrelmax2 bounds;
211 fun ratexact up (r, exact) =
214 val (_, q) = Rat.quotient_of_rat r;
215 val nth = Rat.inv (Rat.rat_of_int q);
216 in Rat.add r (if up then nth else Rat.neg nth) end;
(* Midpoint of two rationals: (lo + hi) * (1/2). *)
fun ratmiddle (lo, hi) =
  let val sum = Rat.add lo hi
  in Rat.mult sum (Rat.inv Rat.two) end;
220 fun choose2 d ((lb, exactl), (ub, exactu)) =
221 let val ord = Rat.sign lb in
222 if (ord = LESS orelse exactl) andalso (ord = GREATER orelse exactu)
226 then if exactl then lb else ratmiddle (lb, ub)
227 else if exactu then ub else ratmiddle (lb, ub)
228 else (*discrete domain, both bounds must be exact*)
230 then let val lb' = Rat.roundup lb in
231 if Rat.le lb' ub then lb' else raise NoEx end
232 else let val ub' = Rat.rounddown ub in
233 if Rat.le lb ub' then ub' else raise NoEx end
236 fun findex1 discr (v, lineqs) ex =
238 val nz = filter (fn (Lineq (_, _, cs, _)) => nth cs v <> 0) lineqs;
239 val ineqs = maps elim_eqns nz;
240 val (ge, le) = List.partition (fn (_,_,cs) => nth cs v > 0) ineqs
241 val lb = ratrelmax (map (eval ex v) ge)
242 val ub = ratrelmin (map (eval ex v) le)
243 in nth_map v (K (choose2 (nth discr v) (lb, ub))) ex end;
246 map (fn (a,le,bs) => (Rat.add a (Rat.neg (Rat.mult (nth bs v) x)), le,
247 nth_map v (K Rat.zero) bs));
249 fun single_var v (_, _, cs) = case filter_out (curry (op =) EQUAL o Rat.sign) cs
250 of [x] => x =/ nth cs v
254 all variables occur only with positive or only with negative coefficients *)
255 fun pick_vars discr (ineqs,ex) =
256 let val nz = filter_out (fn (_,_,cs) => forall (curry (op =) EQUAL o Rat.sign) cs) ineqs
257 in case nz of [] => ex
259 let val v = find_index (not o curry (op =) EQUAL o Rat.sign) cs
261 val pos = not (Rat.sign (nth cs v) = LESS);
262 val sv = filter (single_var v) nz;
264 if pos then if d then Rat.roundup o fst o ratrelmax
265 else ratexact true o ratrelmax
266 else if d then Rat.rounddown o fst o ratrelmin
267 else ratexact false o ratrelmin
268 val bnds = map (fn (a,le,bs) => (Rat.mult a (Rat.inv (nth bs v)), le)) sv
269 val x = minmax((Rat.zero,if pos then true else false)::bnds)
270 val ineqs' = elim1 v x nz
271 val ex' = nth_map v (K x) ex
272 in pick_vars discr (ineqs',ex') end
275 fun findex0 discr n lineqs =
276 let val ineqs = maps elim_eqns lineqs
277 val rineqs = map (fn (a,le,cs) => (Rat.rat_of_int a, le, map Rat.rat_of_int cs))
279 in pick_vars discr (rineqs,replicate n Rat.zero) end;
281 (* ------------------------------------------------------------------------- *)
282 (* End of counterexample finder. The actual decision procedure starts here. *)
283 (* ------------------------------------------------------------------------- *)
285 (* ------------------------------------------------------------------------- *)
286 (* Calculate new (in)equality type after addition. *)
287 (* ------------------------------------------------------------------------- *)
(* Relation type of the sum of two (in)equations: an equation leaves the
   other operand's type unchanged; otherwise any strict operand makes the
   sum strict, and two non-strict operands give a non-strict sum. *)
fun find_add_type (ty1, ty2) =
  if ty1 = Eq then ty2
  else if ty2 = Eq then ty1
  else if ty1 = Lt orelse ty2 = Lt then Lt
  else Le;
295 (* ------------------------------------------------------------------------- *)
296 (* Multiply out an (in)equation. *)
297 (* ------------------------------------------------------------------------- *)
299 fun multiply_ineq n (i as Lineq(k,ty,l,just)) =
301 else if n = 0 andalso ty = Lt then sys_error "multiply_ineq"
302 else if n < 0 andalso (ty=Le orelse ty=Lt) then sys_error "multiply_ineq"
303 else Lineq (n * k, ty, map (Integer.mult n) l, Multiplied (n, just));
305 (* ------------------------------------------------------------------------- *)
306 (* Add together (in)equations. *)
307 (* ------------------------------------------------------------------------- *)
(* Sum of two (in)equations: the constants and the coefficient vectors are
   added pointwise, the relation comes from find_add_type, and the
   justification records the addition. *)
fun add_ineq (Lineq (c1, t1, cs1, j1)) (Lineq (c2, t2, cs2, j2)) =
  Lineq (c1 + c2, find_add_type (t1, t2), map2 Integer.add cs1 cs2, Added (j1, j2));
313 (* ------------------------------------------------------------------------- *)
314 (* Elimination of variable between a single pair of (in)equations. *)
315 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve. *)
316 (* ------------------------------------------------------------------------- *)
318 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
319 let val c1 = nth l1 v and c2 = nth l2 v
320 val m = Integer.lcm (abs c1) (abs c2)
321 val m1 = m div (abs c1) and m2 = m div (abs c2)
323 if (c1 >= 0) = (c2 >= 0)
324 then if ty1 = Eq then (~m1,m2)
325 else if ty2 = Eq then (m1,~m2)
326 else sys_error "elim_var"
328 val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
329 then (~n1,~n2) else (n1,n2)
330 in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end;
332 (* ------------------------------------------------------------------------- *)
333 (* The main refutation-finding code. *)
334 (* ------------------------------------------------------------------------- *)
(* A lineq is trivial when none of its coefficients is non-zero. *)
fun is_trivial (Lineq (_, _, coeffs, _)) =
  not (List.exists (fn c => c <> 0) coeffs);

(* Used on trivial lineqs: reports whether the constant k alone violates
   the relation (Eq with k <> 0, Le with k > 0, Lt with k >= 0). *)
fun is_contradictory (Lineq (k, Eq, _, _)) = k <> 0
  | is_contradictory (Lineq (k, Le, _, _)) = k > 0
  | is_contradictory (Lineq (k, Lt, _, _)) = k >= 0;
342 let val (p,n) = List.partition (curry (op <) 0) (filter (curry (op <>) 0) l)
343 in length p * length n end;
345 (* ------------------------------------------------------------------------- *)
346 (* Main elimination code: *)
348 (* (1) Looks for immediate solutions (false assertions with no variables). *)
350 (* (2) If there are any equations, picks a variable with the lowest absolute *)
351 (* coefficient in any of them, and uses it to eliminate. *)
353 (* (3) Otherwise, chooses a variable in the inequality to minimize the *)
354 (* blowup (number of consequences generated) and eliminates it. *)
355 (* ------------------------------------------------------------------------- *)
357 fun extract_first p =
359 fun extract xs (y::ys) = if p y then (y, xs @ ys) else extract (y::xs) ys
360 | extract xs [] = raise Empty
363 fun print_ineqs ineqs =
365 tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
367 (case t of Eq => " = " | Lt=> " < " | Le => " <= ") ^
368 commas(map string_of_int l)) ineqs))
371 type history = (int * lineq list) list;
372 datatype result = Success of injust | Failure of history;
374 fun elim (ineqs, hist) =
375 let val _ = print_ineqs ineqs
376 val (triv, nontriv) = List.partition is_trivial ineqs in
378 then case Library.find_first is_contradictory triv of
379 NONE => elim (nontriv, hist)
380 | SOME(Lineq(_,_,_,j)) => Success j
382 if null nontriv then Failure hist
384 let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
385 if not (null eqs) then
387 fold (fn Lineq(_,_,l,_) => fn cs => union (op =) l cs) eqs []
388 |> filter (fn i => i <> 0)
389 |> sort (int_ord o pairself abs)
391 val (eq as Lineq(_,_,ceq,_),othereqs) =
392 extract_first (fn Lineq(_,_,l,_) => member (op =) l c) eqs
393 val v = find_index (fn v => v = c) ceq
394 val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0)
396 val others = map (elim_var v eq) roth @ ioth
397 in elim(others,(v,nontriv)::hist) end
399 let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
400 val numlist = 0 upto (length (hd lists) - 1)
401 val coeffs = map (fn i => map (fn xs => nth xs i) lists) numlist
402 val blows = map calc_blowup coeffs
403 val iblows = blows ~~ numlist
404 val nziblows = filter_out (fn (i, _) => i = 0) iblows
405 in if null nziblows then Failure((~1,nontriv)::hist)
407 let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
408 val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) ineqs
409 val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => nth l v > 0) yes
410 in elim(no @ map_product (elim_var v) pos neg, (v,nontriv)::hist) end
415 (* ------------------------------------------------------------------------- *)
416 (* Translate back a proof. *)
417 (* ------------------------------------------------------------------------- *)
(* Return th unchanged; if tracing is on, first print msg and then th. *)
fun trace_thm ctxt msg th =
  if not (!trace) then th
  else (tracing msg; tracing (Display.string_of_thm ctxt th); th);

(* Return t unchanged; if tracing is on, first print msg and t as one
   two-line message. *)
fun trace_term ctxt msg t =
  if not (!trace) then t
  else (tracing (cat_lines [msg, Syntax.string_of_term ctxt t]); t)
426 if !trace then tracing msg else ();
(* Union of term lists, identifying terms related by Pattern.aeconv. *)
fun union_term xs ys = union Pattern.aeconv xs ys;

(* Union of (flag, term) lists: two entries coincide when their flags are
   equal and their terms are Pattern.aeconv-related. *)
fun union_bterm xs ys =
  union (fn ((b : bool, t), (b', t')) => b = b' andalso Pattern.aeconv (t, t')) xs ys;

(* Add the summand terms of both sides of a decomposition to acc. *)
fun add_atoms (lhs, _, _, rhs, _, _) acc =
  union_term (map fst lhs) (union_term (map fst rhs) acc);

(* All atoms occurring in a list of decompositions. *)
fun atoms_of ds = List.foldl (fn (d, acc) => add_atoms d acc) [] ds;
437 Simplification may detect a contradiction 'prematurely' due to type
438 information: n+1 <= 0 is simplified to False and does not need to be crossed
442 exception FalseE of thm
445 fun mkthm ss asms (just: injust) =
447 val ctxt = Simplifier.the_context ss;
448 val thy = ProofContext.theory_of ctxt;
449 val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset,
450 number_of = (_, num_of), ...} = get_data ctxt;
451 val simpset' = Simplifier.inherit_context ss simpset;
452 fun only_concl f thm =
453 if Thm.no_prems thm then f (Thm.concl_of thm) else NONE;
454 val atoms = atoms_of (map_filter (only_concl (LA_Data.decomp ctxt)) asms);
456 fun use_first rules thm =
457 get_first (fn th => SOME (thm RS th) handle THM _ => NONE) rules
460 use_first add_mono_thms (thm1 RS (thm2 RS LA_Logic.conjI));
461 fun try_add thms thm = get_first (fn th => add2 th thm) thms;
463 fun add_thms thm1 thm2 =
464 (case add2 thm1 thm2 of
466 (case try_add ([thm1] RL inj_thms) thm2 of
468 (the (try_add ([thm2] RL inj_thms) thm1)
470 (trace_thm ctxt "" thm1; trace_thm ctxt "" thm2;
471 sys_error "Linear arithmetic: failed to add thms"))
475 fun mult_by_add n thm =
476 let fun mul i th = if i = 1 then th else mul (i - 1) (add_thms thm th)
479 val rewr = Simplifier.rewrite simpset';
480 val rewrite_concl = Conv.fconv_rule (Conv.concl_conv ~1 (Conv.arg_conv
481 (Conv.binop_conv rewr)));
482 fun discharge_prem thm = if Thm.nprems_of thm = 0 then thm else
483 let val cv = Conv.arg1_conv (Conv.arg_conv rewr)
484 in Thm.implies_elim (Conv.fconv_rule cv thm) LA_Logic.trueI end
487 (case use_first mult_mono_thms thm of
488 NONE => mult_by_add n thm
491 val cv = mth |> Thm.cprop_of |> Drule.strip_imp_concl
492 |> Thm.dest_arg |> Thm.dest_arg1 |> Thm.dest_arg1
493 val T = #T (Thm.rep_cterm cv)
496 |> Thm.instantiate ([], [(cv, num_of thy T n)])
499 handle CTERM _ => mult_by_add n thm
500 | THM _ => mult_by_add n thm
503 fun mult_thm (n, thm) =
504 if n = ~1 then thm RS LA_Logic.sym
505 else if n < 0 then mult (~n) thm RS LA_Logic.sym
509 let val thm' = trace_thm ctxt "Simplified:" (full_simplify simpset' thm)
510 in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end;
512 fun mk (Asm i) = trace_thm ctxt ("Asm " ^ string_of_int i) (nth asms i)
513 | mk (Nat i) = trace_thm ctxt ("Nat " ^ string_of_int i) (LA_Logic.mk_nat_thm thy (nth atoms i))
514 | mk (LessD j) = trace_thm ctxt "L" (hd ([mk j] RL lessD))
515 | mk (NotLeD j) = trace_thm ctxt "NLe" (mk j RS LA_Logic.not_leD)
516 | mk (NotLeDD j) = trace_thm ctxt "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD))
517 | mk (NotLessD j) = trace_thm ctxt "NL" (mk j RS LA_Logic.not_lessD)
518 | mk (Added (j1, j2)) = simp (trace_thm ctxt "+" (add_thms (mk j1) (mk j2)))
519 | mk (Multiplied (n, j)) =
520 (trace_msg ("*" ^ string_of_int n); trace_thm ctxt "*" (mult_thm (n, mk j)))
524 val _ = trace_msg "mkthm";
525 val thm = trace_thm ctxt "Final thm:" (mk just);
526 val fls = simplify simpset' thm;
527 val _ = trace_thm ctxt "After simplification:" fls;
529 if LA_Logic.is_False fls then ()
532 (["Assumptions:"] @ map (Display.string_of_thm ctxt) asms @ [""] @
533 ["Proved:", Display.string_of_thm ctxt fls, ""]));
534 warning "Linear arithmetic should have refuted the assumptions.\n\
535 \Please inform Tobias Nipkow.")
537 handle FalseE thm => trace_thm ctxt "False reached early:" thm
(* Coefficient of atom in poly (looked up modulo Pattern.aeconv);
   0 if the atom does not occur. *)
fun coeff poly atom =
  (case AList.lookup Pattern.aeconv poly atom of
    SOME c => c
  | NONE => 0);
545 fun integ(rlhs,r,rel,rrhs,s,d) =
546 let val (rn,rd) = Rat.quotient_of_rat r and (sn,sd) = Rat.quotient_of_rat s
547 val m = Integer.lcms(map (abs o snd o Rat.quotient_of_rat) (r :: s :: map snd rlhs @ map snd rrhs))
549 let val (i,j) = Rat.quotient_of_rat r
550 in (t,i * (m div j)) end
551 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
555 let val (m, (lhs,i,rel,rhs,j,discrete)) = integ item
556 val lhsa = map (coeff lhs) atoms
557 and rhsa = map (coeff rhs) atoms
558 val diff = map2 (curry (op -)) rhsa lhsa
561 fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied(m,j))
563 "<=" => lineq(c,Le,diff,just)
564 | "~<=" => if discrete
565 then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
566 else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
568 then lineq(c+1,Le,diff,LessD(just))
569 else lineq(c,Lt,diff,just)
570 | "~<" => lineq(~c,Le,map (op~) diff,NotLessD(just))
571 | "=" => lineq(c,Eq,diff,just)
572 | _ => sys_error("mklineq" ^ rel)
575 (* ------------------------------------------------------------------------- *)
576 (* Print (counter) example *)
577 (* ------------------------------------------------------------------------- *)
579 fun print_atom((a,d),r) =
580 let val (p,q) = Rat.quotient_of_rat r
581 val s = if d then string_of_int p else
583 else string_of_int p ^ "/" ^ string_of_int q
584 in a ^ " = " ^ s end;
590 #> curry (op ^) "Counterexample (possibly spurious):\n";
592 fun trace_ex ctxt params atoms discr n (hist: history) =
595 | (v, lineqs) :: hist' =>
597 val frees = map Free params
598 fun show_term t = Syntax.string_of_term ctxt (subst_bounds (frees, t))
600 if v = ~1 then (hist', findex0 discr n lineqs)
601 else (hist, replicate n Rat.zero)
602 val ex = SOME (produce_ex (map show_term atoms ~~ discr)
603 (uncurry (fold (findex1 discr)) start))
607 SOME s => (warning "Linear arithmetic failed - see trace for a counterexample."; tracing s)
608 | NONE => warning "Linear arithmetic failed"
611 (* ------------------------------------------------------------------------- *)
613 fun mknat (pTs : typ list) (ixs : int list) (atom : term, i : int) : lineq option =
614 if LA_Logic.is_nat (pTs, atom)
615 then let val l = map (fn j => if j=i then 1 else 0) ixs
616 in SOME (Lineq (0, Le, l, Nat i)) end
619 (* This code is tricky. It takes a list of premises in the order they occur
620 in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
621 ones as NONE. Going through the premises, each numeric one is converted into
622 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
623 >. Thus split_items returns a list of equation systems. This may blow up if
624 there are many ~=, but in practice it does not seem to happen. The really
625 tricky bit is to arrange the order of the cases such that they coincide with
626 the order in which the cases are in the end generated by the tactic that
627 applies the generated refutation thms (see function 'refute_tac').
629 For variables n of type nat, a constraint 0 <= n is added.
632 (* FIXME: To optimize, the splitting of cases and the search for refutations *)
633 (* could be intertwined: separate the first (fully split) case, *)
634 (* refute it, continue with splitting and refuting. Terminate with *)
635 (* failure as soon as a case could not be refuted; i.e. delay further *)
636 (* splitting until after a refutation for other cases has been found. *)
638 fun split_items ctxt do_pre split_neq (Ts, terms) : (typ list * (LA_Data.decomp * int) list) list =
640 (* splits inequalities '~=' into '<' and '>'; this corresponds to *)
641 (* 'REPEAT_DETERM (eresolve_tac neqE i)' at the theorem/tactic *)
643 (* FIXME: this is currently sensitive to the order of theorems in *)
644 (* neqE: The theorem for type "nat" must come first. A *)
645 (* better (i.e. less likely to break when neqE changes) *)
646 (* implementation should *test* which theorem from neqE *)
647 (* can be applied, and split the premise accordingly. *)
648 fun elim_neq (ineqs : (LA_Data.decomp option * bool) list) :
649 (LA_Data.decomp option * bool) list list =
651 fun elim_neq' nat_only ([] : (LA_Data.decomp option * bool) list) :
652 (LA_Data.decomp option * bool) list list =
654 | elim_neq' nat_only ((NONE, is_nat) :: ineqs) =
655 map (cons (NONE, is_nat)) (elim_neq' nat_only ineqs)
656 | elim_neq' nat_only ((ineq as (SOME (l, i, rel, r, j, d), is_nat)) :: ineqs) =
657 if rel = "~=" andalso (not nat_only orelse is_nat) then
658 (* [| ?l ~= ?r; ?l < ?r ==> ?R; ?r < ?l ==> ?R |] ==> ?R *)
659 elim_neq' nat_only (ineqs @ [(SOME (l, i, "<", r, j, d), is_nat)]) @
660 elim_neq' nat_only (ineqs @ [(SOME (r, j, "<", l, i, d), is_nat)])
662 map (cons ineq) (elim_neq' nat_only ineqs)
664 ineqs |> elim_neq' true
665 |> maps (elim_neq' false)
(* Drop a "~=" premise (its decomposition becomes NONE); any other entry,
   including NONE, is returned unchanged. *)
fun ignore_neq (SOME (_, _, "~=", _, _, _), bool) = (NONE, bool)
  | ignore_neq other = other;

(* Pair every SOME x with its position (counting from n); positions of
   NONE entries are skipped but still consumed. *)
fun number_hyps n xs =
  let
    fun step (NONE, (i, acc)) = (i + 1, acc)
      | step (SOME x, (i, acc)) = (i + 1, (x, i) :: acc)
  in rev (#2 (List.foldl step (n, []) xs)) end
676 val result = (Ts, terms)
677 |> (* user-defined preprocessing of the subgoal *)
678 (if do_pre then LA_Data.pre_decomp ctxt else Library.single)
679 |> tap (fn subgoals => trace_msg ("Preprocessing yields " ^
680 string_of_int (length subgoals) ^ " subgoal(s) total."))
681 |> (* produce the internal encoding of (in-)equalities *)
682 map (apsnd (map (fn t => (LA_Data.decomp ctxt t, LA_Data.domain_is_nat t))))
683 |> (* splitting of inequalities *)
684 map (apsnd (if split_neq then elim_neq else
685 Library.single o map ignore_neq))
686 |> maps (fn (Ts, subgoals) => map (pair Ts o map fst) subgoals)
687 |> (* numbering of hypotheses, ignoring irrelevant ones *)
688 map (apsnd (number_hyps 0))
690 trace_msg ("Splitting of inequalities yields " ^
691 string_of_int (length result) ^ " subgoal(s) total.");
(* Add the atoms of one decomposed item, each tagged with the item's
   discreteness flag d, to the accumulator dats (merged via union_bterm). *)
fun add_datoms (((lhs, _, _, rhs, _, d) : LA_Data.decomp), _) (dats : (bool * term) list) =
  let fun tagged summands = map (fn (atom, _) => (d, atom)) summands
  in union_bterm (tagged lhs) (union_bterm (tagged rhs) dats) end;

(* Discreteness flags of the atoms collected from all items. *)
fun discr (initems : (LA_Data.decomp * int) list) : bool list =
  map fst (List.foldl (fn (item, dats) => add_datoms item dats) [] initems);
701 fun refutes ctxt params show_ex :
702 (typ list * (LA_Data.decomp * int) list) list -> injust list -> injust list option =
704 fun refute ((Ts, initems : (LA_Data.decomp * int) list) :: initemss) (js: injust list) =
706 val atoms = atoms_of (map fst initems)
708 val mkleq = mklineq atoms
709 val ixs = 0 upto (n - 1)
710 val iatoms = atoms ~~ ixs
711 val natlineqs = map_filter (mknat Ts ixs) iatoms
712 val ineqs = map mkleq initems @ natlineqs
713 in case elim (ineqs, []) of
715 (trace_msg ("Contradiction! (" ^ string_of_int (length js + 1) ^ ")");
716 refute initemss (js @ [j]))
718 (if not show_ex then ()
721 val (param_names, ctxt') = ctxt |> Variable.variant_fixes (map fst params)
722 val (more_names, ctxt'') = ctxt' |> Variable.variant_fixes
723 (Name.invents (Variable.names_of ctxt') Name.uu (length Ts - length params))
724 val params' = (more_names @ param_names) ~~ Ts
726 trace_ex ctxt'' params' atoms (discr initems) n hist
729 | refute [] js = SOME js
732 fun refute ctxt params show_ex do_pre split_neq terms : injust list option =
733 refutes ctxt params show_ex (split_items ctxt do_pre split_neq
734 (map snd params, terms)) [];
(* Number of elements of xs satisfying P (P is applied left to right). *)
fun count P xs = List.foldl (fn (x, n) => if P x then n + 1 else n) 0 xs;
738 fun prove ctxt params show_ex do_pre Hs concl : bool * injust list option =
740 val _ = trace_msg "prove:"
741 (* append the negated conclusion to 'Hs' -- this corresponds to *)
742 (* 'DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i)' at the *)
743 (* theorem/tactic level *)
744 val Hs' = Hs @ [LA_Logic.neg_prop concl]
745 fun is_neq NONE = false
746 | is_neq (SOME (_,_,r,_,_,_)) = (r = "~=")
747 val neq_limit = Config.get ctxt LA_Data.fast_arith_neq_limit
748 val split_neq = count is_neq (map (LA_Data.decomp ctxt) Hs') <= neq_limit
752 trace_msg ("fast_arith_neq_limit exceeded (current value is " ^
753 string_of_int neq_limit ^ "), ignoring all inequalities");
754 (split_neq, refute ctxt params show_ex do_pre split_neq Hs')
755 end handle TERM ("neg_prop", _) =>
756 (* since no meta-logic negation is available, we can only fail if *)
757 (* the conclusion is not of the form 'Trueprop $ _' (simply *)
758 (* dropping the conclusion doesn't work either, because even *)
759 (* 'False' does not imply arbitrary 'concl::prop') *)
760 (trace_msg "prove failed (cannot negate conclusion).";
763 fun refute_tac ss (i, split_neq, justs) =
766 val ctxt = Simplifier.the_context ss;
767 val _ = trace_thm ctxt
768 ("refute_tac (on subgoal " ^ string_of_int i ^ ", with " ^
769 string_of_int (length justs) ^ " justification(s)):") state
770 val {neqE, ...} = get_data ctxt;
772 (* eliminate inequalities *)
774 REPEAT_DETERM (eresolve_tac neqE i)
777 PRIMITIVE (trace_thm ctxt "State after neqE:") THEN
778 (* use theorems generated from the actual justifications *)
779 Subgoal.FOCUS (fn {prems, ...} => rtac (mkthm ss prems j) 1) ctxt i
781 (* rewrite "[| A1; ...; An |] ==> B" to "[| A1; ...; An; ~B |] ==> False" *)
782 DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i) THEN
783 (* user-defined preprocessing of the subgoal *)
784 DETERM (LA_Data.pre_tac ss i) THEN
785 PRIMITIVE (trace_thm ctxt "State after pre_tac:") THEN
786 (* prove every resulting subgoal, using its justification *)
787 EVERY (map just1 justs)
791 Fast but very incomplete decider. Only premises and conclusions
792 that are already (negated) (in)equations are taken into account.
794 fun simpset_lin_arith_tac ss show_ex = SUBGOAL (fn (A, i) =>
796 val ctxt = Simplifier.the_context ss
797 val params = rev (Logic.strip_params A)
798 val Hs = Logic.strip_assums_hyp A
799 val concl = Logic.strip_assums_concl A
800 val _ = trace_term ctxt ("Trying to refute subgoal " ^ string_of_int i) A
802 case prove ctxt params show_ex true Hs concl of
803 (_, NONE) => (trace_msg "Refutation failed."; no_tac)
804 | (split_neq, SOME js) => (trace_msg "Refutation succeeded.";
805 refute_tac ss (i, split_neq, js))
(* Insert the simpset's premises into subgoal i, then run the decision
   procedure on it with show_ex = false. *)
fun cut_lin_arith_tac ss i =
  cut_facts_tac (Simplifier.prems_of_ss ss) i THEN
  simpset_lin_arith_tac ss false i;

(* Run the decision procedure in ctxt, with an empty simpset attached. *)
fun lin_arith_tac ctxt show_ex =
  simpset_lin_arith_tac (Simplifier.context ctxt Simplifier.empty_ss) show_ex;
817 (** Forward proof from theorems **)
819 (* More tricky code. Needs to arrange the proofs of the multiple cases (due
820 to splits of ~= premises) such that it coincides with the order of the cases
821 generated by function split_items. *)
823 datatype splittree = Tip of thm list
824 | Spl of thm * cterm * splittree * cterm * splittree;
826 (* "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R" is taken to (ct1, ct2) *)
828 fun extract (imp : cterm) : cterm * cterm =
829 let val (Il, r) = Thm.dest_comb imp
830 val (_, imp1) = Thm.dest_comb Il
831 val (Ict1, _) = Thm.dest_comb imp1
832 val (_, ct1) = Thm.dest_comb Ict1
833 val (Ir, _) = Thm.dest_comb r
834 val (_, Ict2r) = Thm.dest_comb Ir
835 val (Ict2, _) = Thm.dest_comb Ict2r
836 val (_, ct2) = Thm.dest_comb Ict2
839 fun splitasms ctxt (asms : thm list) : splittree =
840 let val {neqE, ...} = get_data ctxt
841 fun elim_neq [] (asms', []) = Tip (rev asms')
842 | elim_neq [] (asms', asms) = Tip (rev asms' @ asms)
843 | elim_neq (neq :: neqs) (asms', []) = elim_neq neqs ([],rev asms')
844 | elim_neq (neqs as (neq :: _)) (asms', asm::asms) =
845 (case get_first (fn th => SOME (asm COMP th) handle THM _ => NONE) [neq] of
847 let val (ct1, ct2) = extract (cprop_of spl)
848 val thm1 = Thm.assume ct1
849 val thm2 = Thm.assume ct2
850 in Spl (spl, ct1, elim_neq neqs (asms', asms@[thm1]),
851 ct2, elim_neq neqs (asms', asms@[thm2]))
853 | NONE => elim_neq neqs (asm::asms', asms))
854 in elim_neq neqE ([], asms) end;
856 fun fwdproof ss (Tip asms : splittree) (j::js : injust list) = (mkthm ss asms j, js)
857 | fwdproof ss (Spl (thm, ct1, tree1, ct2, tree2)) js =
859 val (thm1, js1) = fwdproof ss tree1 js
860 val (thm2, js2) = fwdproof ss tree2 js1
861 val thm1' = Thm.implies_intr ct1 thm1
862 val thm2' = Thm.implies_intr ct2 thm2
863 in (thm2' COMP (thm1' COMP thm), js2) end;
864 (* FIXME needs handle THM _ => NONE ? *)
866 fun prover ss thms Tconcl (js : injust list) split_neq pos : thm option =
868 val ctxt = Simplifier.the_context ss
869 val thy = ProofContext.theory_of ctxt
870 val nTconcl = LA_Logic.neg_prop Tconcl
871 val cnTconcl = cterm_of thy nTconcl
872 val nTconclthm = Thm.assume cnTconcl
873 val tree = (if split_neq then splitasms ctxt else Tip) (thms @ [nTconclthm])
874 val (Falsethm, _) = fwdproof ss tree js
875 val contr = if pos then LA_Logic.ccontr else LA_Logic.notI
876 val concl = Thm.implies_intr cnTconcl Falsethm COMP contr
877 in SOME (trace_thm ctxt "Proved by lin. arith. prover:" (LA_Logic.mk_Eq concl)) end
878 (*in case concl contains ?-var, which makes assume fail:*) (* FIXME Variable.import_terms *)
879 handle THM _ => NONE;
881 (* PRE: concl is not negated!
882 This assumption is OK because
883 1. lin_arith_simproc tries both to prove and disprove concl and
884 2. lin_arith_simproc is applied by the Simplifier which
885 dives into terms and will thus try the non-negated concl anyway.
887 fun lin_arith_simproc ss concl =
889 val ctxt = Simplifier.the_context ss
890 val thms = maps LA_Logic.atomize (Simplifier.prems_of_ss ss)
891 val Hs = map Thm.prop_of thms
892 val Tconcl = LA_Logic.mk_Trueprop concl
894 case prove ctxt [] false false Hs Tconcl of (* concl provable? *)
895 (split_neq, SOME js) => prover ss thms Tconcl js split_neq true
897 let val nTconcl = LA_Logic.neg_prop Tconcl in
898 case prove ctxt [] false false Hs nTconcl of (* ~concl provable? *)
899 (split_neq, SOME js) => prover ss thms nTconcl js split_neq false