src/Provers/Arith/fast_lin_arith.ML
author wenzelm
Thu, 26 Aug 2010 17:37:26 +0200
changeset 39037 996afaa9254a
parent 38298 04a8de29e8f7
child 39038 283f1f9969ba
permissions -rw-r--r--
slightly more abstract data handling in Fast_Lin_Arith;
     1 (*  Title:      Provers/Arith/fast_lin_arith.ML
     2     Author:     Tobias Nipkow and Tjark Weber and Sascha Boehme
     3 
     4 A generic linear arithmetic package.  It provides two tactics
     5 (cut_lin_arith_tac, lin_arith_tac) and a simplification procedure
     6 (lin_arith_simproc).
     7 
     8 Only premises and conclusions that are already (negated) (in)equations
     9 are taken into account.  lin_arith_simproc tries to prove or disprove
    10 the given term.
    11 *)
    12 
    13 (*** Data needed for setting up the linear arithmetic package ***)
    14 
    15 signature LIN_ARITH_LOGIC =
    16 sig
    17   val conjI       : thm  (* P ==> Q ==> P & Q *)
    18   val ccontr      : thm  (* (~ P ==> False) ==> P *)
    19   val notI        : thm  (* (P ==> False) ==> ~ P *)
    20   val not_lessD   : thm  (* ~(m < n) ==> n <= m *)
    21   val not_leD     : thm  (* ~(m <= n) ==> n < m *)
    22   val sym         : thm  (* x = y ==> y = x *)
    23   val trueI       : thm  (* True *)
    24   val mk_Eq       : thm -> thm
    25   val atomize     : thm -> thm list
    26   val mk_Trueprop : term -> term
    27   val neg_prop    : term -> term
    28   val is_False    : thm -> bool
    29   val is_nat      : typ list * term -> bool
    30   val mk_nat_thm  : theory -> term -> thm
    31 end;
    32 (*
    33 mk_Eq(~in) = `in == False'
    34 mk_Eq(in) = `in == True'
    35 where `in' is an (in)equality.
    36 
    37 neg_prop(t) = neg  if t is wrapped up in Trueprop and neg is the
    38   (logically) negated version of t (again wrapped up in Trueprop),
    39   where the negation of a negative term is the term itself (no
    40   double negation!); raises TERM ("neg_prop", [t]) if t is not of
    41   the form 'Trueprop $ _'
    42 
    43 is_nat(parameter-types,t) = true iff t is of type nat
    44 mk_nat_thm thy t = the theorem "0 <= t"
    45 *)
    46 
    47 signature LIN_ARITH_DATA =
    48 sig
    49   (*internal representation of linear (in-)equations:*)
    50   type decomp = (term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat * bool
    51   val decomp: Proof.context -> term -> decomp option
    52   val domain_is_nat: term -> bool
    53 
    54   (*preprocessing, performed on a representation of subgoals as list of premises:*)
    55   val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list
    56 
    57   (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
    58   val pre_tac: simpset -> int -> tactic
    59 
    60   (*the limit on the number of ~= allowed; because each ~= is split
    61     into two cases, this can lead to an explosion*)
    62   val fast_arith_neq_limit: int Config.T
    63 end;
    64 (*
    65 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
    66    where Rel is one of "<", "~<", "<=", "~<=", "=" and "~=";
    67          p (q, respectively) is the decomposition of the sum term x
    68          (y, respectively) into a list of (summand, multiplicity)
    69          pairs, i (j, respectively) is its constant summand, and d
    70          indicates whether the domain is discrete.
    71 
    72 domain_is_nat(`x Rel y') should yield true iff x is of type "nat".
    73 
    74 The relationship between pre_decomp and pre_tac is somewhat tricky.  The
    75 internal representation of a subgoal and the corresponding theorem must
    76 be modified by pre_decomp (pre_tac, resp.) in a corresponding way.  See
    77 the comment for split_items below.  (This is even necessary for eta- and
    78 beta-equivalent modifications, as some of the lin. arith. code is
    79 sensitive to them.)
    80 
    81 ss must reduce contradictory <= to False.
    82    It should also cancel common summands to keep <= reduced;
    83    otherwise <= can grow to massive proportions.
    84 *)
    85 
    86 signature FAST_LIN_ARITH =
    87 sig
    88   val cut_lin_arith_tac: simpset -> int -> tactic
    89   val lin_arith_tac: Proof.context -> bool -> int -> tactic
    90   val lin_arith_simproc: simpset -> term -> thm option
    91   val map_data: ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
    92                  lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
    93                  number_of : serial * (theory -> typ -> int -> cterm)}
    94                  -> {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
    95                      lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
    96                      number_of : serial * (theory -> typ -> int -> cterm)})
    97                 -> Context.generic -> Context.generic
    98   val add_inj_thms: thm list -> Context.generic -> Context.generic
    99   val add_lessD: thm -> Context.generic -> Context.generic
   100   val add_simps: thm list -> Context.generic -> Context.generic
   101   val add_simprocs: simproc list -> Context.generic -> Context.generic
   102   val set_number_of: (theory -> typ -> int -> cterm) -> Context.generic -> Context.generic
   103   val trace: bool Unsynchronized.ref
   104 end;
   105 
   106 functor Fast_Lin_Arith
   107   (structure LA_Logic: LIN_ARITH_LOGIC and LA_Data: LIN_ARITH_DATA): FAST_LIN_ARITH =
   108 struct
   109 
   110 
   111 (** theory data **)
   112 
   113 fun no_number_of _ _ _ = raise CTERM ("number_of", [])
   114 
   115 structure Data = Generic_Data
   116 (
   117   type T =
   118    {add_mono_thms: thm list,
   119     mult_mono_thms: thm list,
   120     inj_thms: thm list,
   121     lessD: thm list,
   122     neqE: thm list,
   123     simpset: Simplifier.simpset,
   124     number_of : serial * (theory -> typ -> int -> cterm)};
   125 
   126   val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
   127                lessD = [], neqE = [], simpset = Simplifier.empty_ss,
   128                number_of = (serial (), no_number_of) } : T;
   129   val extend = I;
   130   fun merge
   131     ({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1,
   132       lessD = lessD1, neqE=neqE1, simpset = simpset1,
   133       number_of = (number_of1 as (s1, _))},
   134      {add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2,
   135       lessD = lessD2, neqE=neqE2, simpset = simpset2,
   136       number_of = (number_of2 as (s2, _))}) =
   137     {add_mono_thms = Thm.merge_thms (add_mono_thms1, add_mono_thms2),
   138      mult_mono_thms = Thm.merge_thms (mult_mono_thms1, mult_mono_thms2),
   139      inj_thms = Thm.merge_thms (inj_thms1, inj_thms2),
   140      lessD = Thm.merge_thms (lessD1, lessD2),
   141      neqE = Thm.merge_thms (neqE1, neqE2),
   142      simpset = Simplifier.merge_ss (simpset1, simpset2),
   143      (* FIXME depends on accidental load order !?! *)  (* FIXME *)
   144      number_of = if s1 > s2 then number_of1 else number_of2};
   145 );
   146 
   147 val map_data = Data.map;
   148 val get_data = Data.get o Context.Proof;
   149 
   150 fun map_inj_thms f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
   151   {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = f inj_thms,
   152     lessD = lessD, neqE = neqE, simpset = simpset, number_of = number_of};
   153 
   154 fun map_lessD f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
   155   {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
   156     lessD = f lessD, neqE = neqE, simpset = simpset, number_of = number_of};
   157 
   158 fun map_simpset f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
   159   {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
   160     lessD = lessD, neqE = neqE, simpset = f simpset, number_of = number_of};
   161 
   162 fun map_number_of f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset, number_of} =
   163   {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
   164     lessD = lessD, neqE = neqE, simpset = simpset, number_of = f number_of};
   165 
   166 fun add_inj_thms thms = map_data (map_inj_thms (append thms));
   167 fun add_lessD thm = map_data (map_lessD (fn thms => thms @ [thm]));
   168 fun add_simps thms = map_data (map_simpset (fn simpset => simpset addsimps thms));
   169 fun add_simprocs procs = map_data (map_simpset (fn simpset => simpset addsimprocs procs));
   170 
   171 fun set_number_of f = map_data (map_number_of (K (serial (), f)));
   172 
   173 
   174 (*** A fast decision procedure ***)
   175 (*** Code ported from HOL Light ***)
   176 (* possible optimizations:
   177    use (var,coeff) rep or vector rep to save space;
   178    treat non-negative atoms separately rather than adding 0 <= atom
   179 *)
   180 
   181 val trace = Unsynchronized.ref false;
   182 
   183 datatype lineq_type = Eq | Le | Lt;
   184 
   185 datatype injust = Asm of int
   186                 | Nat of int (* index of atom *)
   187                 | LessD of injust
   188                 | NotLessD of injust
   189                 | NotLeD of injust
   190                 | NotLeDD of injust
   191                 | Multiplied of int * injust
   192                 | Added of injust * injust;
   193 
   194 datatype lineq = Lineq of int * lineq_type * int list * injust;
   195 
   196 (* ------------------------------------------------------------------------- *)
   197 (* Finding a (counter) example from the trace of a failed elimination        *)
   198 (* ------------------------------------------------------------------------- *)
   199 (* Examples are represented as rational numbers,                             *)
   200 (* Don't blame John Harrison for this code - it is entirely mine. TN         *)
   201 
   202 exception NoEx;
   203 
   204 (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
   205    In general, true means the bound is included, false means it is excluded.
   206    Need to know if it is a lower or upper bound for unambiguous interpretation!
   207 *)
   208 
   209 fun elim_eqns (Lineq (i, Le, cs, _)) = [(i, true, cs)]
   210   | elim_eqns (Lineq (i, Eq, cs, _)) = [(i, true, cs),(~i, true, map ~ cs)]
   211   | elim_eqns (Lineq (i, Lt, cs, _)) = [(i, false, cs)];
   212 
   213 (* PRE: ex[v] must be 0! *)
   214 fun eval ex v (a, le, cs) =
   215   let
   216     val rs = map Rat.rat_of_int cs;
   217     val rsum = fold2 (Rat.add oo Rat.mult) rs ex Rat.zero;
   218   in (Rat.mult (Rat.add (Rat.rat_of_int a) (Rat.neg rsum)) (Rat.inv (nth rs v)), le) end;
   219 (* If nth rs v < 0, le should be negated.
   220    Instead this swap is taken into account in ratrelmin2.
   221 *)
   222 
   223 fun ratrelmin2 (x as (r, ler), y as (s, les)) =
   224   case Rat.ord (r, s)
   225    of EQUAL => (r, (not ler) andalso (not les))
   226     | LESS => x
   227     | GREATER => y;
   228 
   229 fun ratrelmax2 (x as (r, ler), y as (s, les)) =
   230   case Rat.ord (r, s)
   231    of EQUAL => (r, ler andalso les)
   232     | LESS => y
   233     | GREATER => x;
   234 
   235 val ratrelmin = foldr1 ratrelmin2;
   236 val ratrelmax = foldr1 ratrelmax2;
   237 
   238 fun ratexact up (r, exact) =
   239   if exact then r else
   240   let
   241     val (_, q) = Rat.quotient_of_rat r;
   242     val nth = Rat.inv (Rat.rat_of_int q);
   243   in Rat.add r (if up then nth else Rat.neg nth) end;
   244 
   245 fun ratmiddle (r, s) = Rat.mult (Rat.add r s) (Rat.inv Rat.two);
   246 
   247 fun choose2 d ((lb, exactl), (ub, exactu)) =
   248   let val ord = Rat.sign lb in
   249   if (ord = LESS orelse exactl) andalso (ord = GREATER orelse exactu)
   250     then Rat.zero
   251     else if not d then
   252       if ord = GREATER
   253         then if exactl then lb else ratmiddle (lb, ub)
   254         else if exactu then ub else ratmiddle (lb, ub)
   255       else (*discrete domain, both bounds must be exact*)
   256       if ord = GREATER
   257         then let val lb' = Rat.roundup lb in
   258           if Rat.le lb' ub then lb' else raise NoEx end
   259         else let val ub' = Rat.rounddown ub in
   260           if Rat.le lb ub' then ub' else raise NoEx end
   261   end;
   262 
   263 fun findex1 discr (v, lineqs) ex =
   264   let
   265     val nz = filter (fn (Lineq (_, _, cs, _)) => nth cs v <> 0) lineqs;
   266     val ineqs = maps elim_eqns nz;
   267     val (ge, le) = List.partition (fn (_,_,cs) => nth cs v > 0) ineqs
   268     val lb = ratrelmax (map (eval ex v) ge)
   269     val ub = ratrelmin (map (eval ex v) le)
   270   in nth_map v (K (choose2 (nth discr v) (lb, ub))) ex end;
   271 
   272 fun elim1 v x =
   273   map (fn (a,le,bs) => (Rat.add a (Rat.neg (Rat.mult (nth bs v) x)), le,
   274                         nth_map v (K Rat.zero) bs));
   275 
   276 fun single_var v (_, _, cs) = case filter_out (curry (op =) EQUAL o Rat.sign) cs
   277  of [x] => x =/ nth cs v
   278   | _ => false;
   279 
   280 (* The base case:
   281    all variables occur only with positive or only with negative coefficients *)
   282 fun pick_vars discr (ineqs,ex) =
   283   let val nz = filter_out (fn (_,_,cs) => forall (curry (op =) EQUAL o Rat.sign) cs) ineqs
   284   in case nz of [] => ex
   285      | (_,_,cs) :: _ =>
   286        let val v = find_index (not o curry (op =) EQUAL o Rat.sign) cs
   287            val d = nth discr v;
   288            val pos = not (Rat.sign (nth cs v) = LESS);
   289            val sv = filter (single_var v) nz;
   290            val minmax =
   291              if pos then if d then Rat.roundup o fst o ratrelmax
   292                          else ratexact true o ratrelmax
   293                     else if d then Rat.rounddown o fst o ratrelmin
   294                          else ratexact false o ratrelmin
   295            val bnds = map (fn (a,le,bs) => (Rat.mult a (Rat.inv (nth bs v)), le)) sv
   296            val x = minmax((Rat.zero,if pos then true else false)::bnds)
   297            val ineqs' = elim1 v x nz
   298            val ex' = nth_map v (K x) ex
   299        in pick_vars discr (ineqs',ex') end
   300   end;
   301 
   302 fun findex0 discr n lineqs =
   303   let val ineqs = maps elim_eqns lineqs
   304       val rineqs = map (fn (a,le,cs) => (Rat.rat_of_int a, le, map Rat.rat_of_int cs))
   305                        ineqs
   306   in pick_vars discr (rineqs,replicate n Rat.zero) end;
   307 
   308 (* ------------------------------------------------------------------------- *)
   309 (* End of counterexample finder. The actual decision procedure starts here.  *)
   310 (* ------------------------------------------------------------------------- *)
   311 
   312 (* ------------------------------------------------------------------------- *)
   313 (* Calculate new (in)equality type after addition.                           *)
   314 (* ------------------------------------------------------------------------- *)
   315 
   316 fun find_add_type(Eq,x) = x
   317   | find_add_type(x,Eq) = x
   318   | find_add_type(_,Lt) = Lt
   319   | find_add_type(Lt,_) = Lt
   320   | find_add_type(Le,Le) = Le;
   321 
   322 (* ------------------------------------------------------------------------- *)
   323 (* Multiply out an (in)equation.                                             *)
   324 (* ------------------------------------------------------------------------- *)
   325 
   326 fun multiply_ineq n (i as Lineq(k,ty,l,just)) =
   327   if n = 1 then i
   328   else if n = 0 andalso ty = Lt then sys_error "multiply_ineq"
   329   else if n < 0 andalso (ty=Le orelse ty=Lt) then sys_error "multiply_ineq"
   330   else Lineq (n * k, ty, map (Integer.mult n) l, Multiplied (n, just));
   331 
   332 (* ------------------------------------------------------------------------- *)
   333 (* Add together (in)equations.                                               *)
   334 (* ------------------------------------------------------------------------- *)
   335 
   336 fun add_ineq (Lineq (k1,ty1,l1,just1)) (Lineq (k2,ty2,l2,just2)) =
   337   let val l = map2 Integer.add l1 l2
   338   in Lineq(k1+k2,find_add_type(ty1,ty2),l,Added(just1,just2)) end;
   339 
   340 (* ------------------------------------------------------------------------- *)
   341 (* Elimination of variable between a single pair of (in)equations.           *)
   342 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
   343 (* ------------------------------------------------------------------------- *)
   344 
   345 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   346   let val c1 = nth l1 v and c2 = nth l2 v
   347       val m = Integer.lcm (abs c1) (abs c2)
   348       val m1 = m div (abs c1) and m2 = m div (abs c2)
   349       val (n1,n2) =
   350         if (c1 >= 0) = (c2 >= 0)
   351         then if ty1 = Eq then (~m1,m2)
   352              else if ty2 = Eq then (m1,~m2)
   353                   else sys_error "elim_var"
   354         else (m1,m2)
   355       val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
   356                     then (~n1,~n2) else (n1,n2)
   357   in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end;
   358 
   359 (* ------------------------------------------------------------------------- *)
   360 (* The main refutation-finding code.                                         *)
   361 (* ------------------------------------------------------------------------- *)
   362 
   363 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
   364 
   365 fun is_contradictory (Lineq(k,ty,_,_)) =
   366   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
   367 
   368 fun calc_blowup l =
   369   let val (p,n) = List.partition (curry (op <) 0) (filter (curry (op <>) 0) l)
   370   in length p * length n end;
   371 
   372 (* ------------------------------------------------------------------------- *)
   373 (* Main elimination code:                                                    *)
   374 (*                                                                           *)
   375 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
   376 (*                                                                           *)
   377 (* (2) If there are any equations, picks a variable with the lowest absolute *)
   378 (* coefficient in any of them, and uses it to eliminate.                     *)
   379 (*                                                                           *)
   380 (* (3) Otherwise, chooses a variable in the inequality to minimize the       *)
   381 (* blowup (number of consequences generated) and eliminates it.              *)
   382 (* ------------------------------------------------------------------------- *)
   383 
   384 fun extract_first p =
   385   let
   386     fun extract xs (y::ys) = if p y then (y, xs @ ys) else extract (y::xs) ys
   387       | extract xs [] = raise Empty
   388   in extract [] end;
   389 
   390 fun print_ineqs ineqs =
   391   if !trace then
   392      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
   393        string_of_int c ^
   394        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
   395        commas(map string_of_int l)) ineqs))
   396   else ();
   397 
   398 type history = (int * lineq list) list;
   399 datatype result = Success of injust | Failure of history;
   400 
   401 fun elim (ineqs, hist) =
   402   let val _ = print_ineqs ineqs
   403       val (triv, nontriv) = List.partition is_trivial ineqs in
   404   if not (null triv)
   405   then case Library.find_first is_contradictory triv of
   406          NONE => elim (nontriv, hist)
   407        | SOME(Lineq(_,_,_,j)) => Success j
   408   else
   409   if null nontriv then Failure hist
   410   else
   411   let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
   412   if not (null eqs) then
   413      let val c =
   414            fold (fn Lineq(_,_,l,_) => fn cs => union (op =) l cs) eqs []
   415            |> filter (fn i => i <> 0)
   416            |> sort (int_ord o pairself abs)
   417            |> hd
   418          val (eq as Lineq(_,_,ceq,_),othereqs) =
   419                extract_first (fn Lineq(_,_,l,_) => member (op =) l c) eqs
   420          val v = find_index (fn v => v = c) ceq
   421          val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0)
   422                                      (othereqs @ noneqs)
   423          val others = map (elim_var v eq) roth @ ioth
   424      in elim(others,(v,nontriv)::hist) end
   425   else
   426   let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
   427       val numlist = 0 upto (length (hd lists) - 1)
   428       val coeffs = map (fn i => map (fn xs => nth xs i) lists) numlist
   429       val blows = map calc_blowup coeffs
   430       val iblows = blows ~~ numlist
   431       val nziblows = filter_out (fn (i, _) => i = 0) iblows
   432   in if null nziblows then Failure((~1,nontriv)::hist)
   433      else
   434      let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
   435          val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) ineqs
   436          val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => nth l v > 0) yes
   437      in elim(no @ map_product (elim_var v) pos neg, (v,nontriv)::hist) end
   438   end
   439   end
   440   end;
   441 
   442 (* ------------------------------------------------------------------------- *)
   443 (* Translate back a proof.                                                   *)
   444 (* ------------------------------------------------------------------------- *)
   445 
   446 fun trace_thm ctxt msg th =
   447   (if !trace then (tracing msg; tracing (Display.string_of_thm ctxt th)) else (); th);
   448 
   449 fun trace_term ctxt msg t =
   450   (if !trace then tracing (cat_lines [msg, Syntax.string_of_term ctxt t]) else (); t)
   451 
   452 fun trace_msg msg =
   453   if !trace then tracing msg else ();
   454 
   455 val union_term = union Pattern.aeconv;
   456 val union_bterm = union (fn ((b:bool, t), (b', t')) => b = b' andalso Pattern.aeconv (t, t'));
   457 
   458 fun add_atoms (lhs, _, _, rhs, _, _) =
   459   union_term (map fst lhs) o union_term (map fst rhs);
   460 
   461 fun atoms_of ds = fold add_atoms ds [];
   462 
   463 (*
   464 Simplification may detect a contradiction 'prematurely' due to type
   465 information: n+1 <= 0 is simplified to False and does not need to be crossed
   466 with 0 <= n.
   467 *)
   468 local
       (* FalseE th: the intermediate theorem th already simplified to False;
          raised by simp inside mkthm and caught at the end of mkthm. *)
   469   exception FalseE of thm
   470 in
   471 
       (* mkthm ss asms just: replay the refutation certificate just as a chain
          of inferences, where Asm i stands for the theorem nth asms i.  The
          final theorem is simplified with the simpset of the linear arithmetic
          data (inheriting the context of ss) and returned.  It is expected to
          be False; otherwise the assumptions and the result are traced and a
          warning is printed.  If an intermediate sum already simplifies to
          False, that theorem is returned instead (see FalseE). *)
   472 fun mkthm ss asms (just: injust) =
   473   let
   474     val ctxt = Simplifier.the_context ss;
   475     val thy = ProofContext.theory_of ctxt;
   476     val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset,
   477       number_of = (_, num_of), ...} = get_data ctxt;
   478     val simpset' = Simplifier.inherit_context ss simpset;
           (* atoms of the premise-free, decomposable assumptions; Nat i refers
              to nth atoms i.  NOTE(review): presumably this must reproduce the
              atom order used when the Lineqs were built -- confirm. *)
   479     fun only_concl f thm =
   480       if Thm.no_prems thm then f (Thm.concl_of thm) else NONE;
   481     val atoms = atoms_of (map_filter (only_concl (LA_Data.decomp ctxt)) asms);
   482 
           (* result of thm RS th for the first rule th in rules where it succeeds *)
   483     fun use_first rules thm =
   484       get_first (fn th => SOME (thm RS th) handle THM _ => NONE) rules
   485 
           (* add two (in)equations: conjoin them with conjI, then resolve with the
              first applicable add_mono_thms rule *)
   486     fun add2 thm1 thm2 =
   487       use_first add_mono_thms (thm1 RS (thm2 RS LA_Logic.conjI));
   488     fun try_add thms thm = get_first (fn th => add2 th thm) thms;
   489 
           (* add2; if that fails, retry with thm1, then with thm2, first
              transformed by an inj_thms rule.  Total failure traces both
              theorems and is an internal error. *)
   490     fun add_thms thm1 thm2 =
   491       (case add2 thm1 thm2 of
   492         NONE =>
   493           (case try_add ([thm1] RL inj_thms) thm2 of
   494             NONE =>
   495               (the (try_add ([thm2] RL inj_thms) thm1)
   496                 handle Option =>
   497                   (trace_thm ctxt "" thm1; trace_thm ctxt "" thm2;
   498                    sys_error "Linear arithmetic: failed to add thms"))
   499           | SOME thm => thm)
   500       | SOME thm => thm);
   501 
           (* n-fold sum of thm by repeated addition.  NOTE(review): the
              recursion only stops at i = 1, so n < 1 would not terminate;
              callers appear to pass n >= 1. *)
   502     fun mult_by_add n thm =
   503       let fun mul i th = if i = 1 then th else mul (i - 1) (add_thms thm th)
   504       in mul n thm end;
   505 
           (* helpers for mult: rewrite both sides of the conclusion; discharge
              the first premise (if any) by rewriting it to True *)
   506     val rewr = Simplifier.rewrite simpset';
   507     val rewrite_concl = Conv.fconv_rule (Conv.concl_conv ~1 (Conv.arg_conv
   508       (Conv.binop_conv rewr)));
   509     fun discharge_prem thm = if Thm.nprems_of thm = 0 then thm else
   510       let val cv = Conv.arg1_conv (Conv.arg_conv rewr)
   511       in Thm.implies_elim (Conv.fconv_rule cv thm) LA_Logic.trueI end
   512 
           (* multiply by the constant n using the first applicable
              mult_mono_thms rule: its multiplier variable (located by the
              dest_arg path below -- TODO confirm against the rule shapes in
              use) is instantiated with the numeral from num_of.  If no rule
              applies, or instantiation/rewriting raises CTERM or THM (e.g.
              from the default no_number_of), fall back to mult_by_add. *)
   513     fun mult n thm =
   514       (case use_first mult_mono_thms thm of
   515         NONE => mult_by_add n thm
   516       | SOME mth =>
   517           let
   518             val cv = mth |> Thm.cprop_of |> Drule.strip_imp_concl
   519               |> Thm.dest_arg |> Thm.dest_arg1 |> Thm.dest_arg1
   520             val T = #T (Thm.rep_cterm cv)
   521           in
   522             mth
   523             |> Thm.instantiate ([], [(cv, num_of thy T n)])
   524             |> rewrite_concl
   525             |> discharge_prem
   526             handle CTERM _ => mult_by_add n thm
   527                  | THM _ => mult_by_add n thm
   528           end);
   529 
           (* negative factors (allowed only on equations by multiply_ineq) are
              realized by multiplying with |n| and swapping the sides with sym;
              swapping negates the constant and the coefficient differences
              that mklineq builds from the two sides *)
   530     fun mult_thm (n, thm) =
   531       if n = ~1 then thm RS LA_Logic.sym
   532       else if n < 0 then mult (~n) thm RS LA_Logic.sym
   533       else mult n thm;
   534 
           (* simplify an intermediate sum; stop early via FalseE once False is reached *)
   535     fun simp thm =
   536       let val thm' = trace_thm ctxt "Simplified:" (full_simplify simpset' thm)
   537       in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end;
   538 
           (* interpret each injust constructor as the corresponding inference *)
   539     fun mk (Asm i) = trace_thm ctxt ("Asm " ^ string_of_int i) (nth asms i)
   540       | mk (Nat i) = trace_thm ctxt ("Nat " ^ string_of_int i) (LA_Logic.mk_nat_thm thy (nth atoms i))
   541       | mk (LessD j) = trace_thm ctxt "L" (hd ([mk j] RL lessD))
   542       | mk (NotLeD j) = trace_thm ctxt "NLe" (mk j RS LA_Logic.not_leD)
   543       | mk (NotLeDD j) = trace_thm ctxt "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD))
   544       | mk (NotLessD j) = trace_thm ctxt "NL" (mk j RS LA_Logic.not_lessD)
   545       | mk (Added (j1, j2)) = simp (trace_thm ctxt "+" (add_thms (mk j1) (mk j2)))
   546       | mk (Multiplied (n, j)) =
   547           (trace_msg ("*" ^ string_of_int n); trace_thm ctxt "*" (mult_thm (n, mk j)))
   548 
   549   in
       (* replay the certificate, then simplify the final theorem; it should be False *)
   550     let
   551       val _ = trace_msg "mkthm";
   552       val thm = trace_thm ctxt "Final thm:" (mk just);
   553       val fls = simplify simpset' thm;
   554       val _ = trace_thm ctxt "After simplification:" fls;
   555       val _ =
   556         if LA_Logic.is_False fls then ()
   557         else
   558          (tracing (cat_lines
   559            (["Assumptions:"] @ map (Display.string_of_thm ctxt) asms @ [""] @
   560             ["Proved:", Display.string_of_thm ctxt fls, ""]));
   561           warning "Linear arithmetic should have refuted the assumptions.\n\
   562             \Please inform Tobias Nipkow.")
   563     in fls end
   564     handle FalseE thm => trace_thm ctxt "False reached early:" thm
   565   end;
   566 
   567 end;
   568 
   569 fun coeff poly atom =
   570   AList.lookup Pattern.aeconv poly atom |> the_default 0;
   571 
   572 fun integ(rlhs,r,rel,rrhs,s,d) =
   573 let val (rn,rd) = Rat.quotient_of_rat r and (sn,sd) = Rat.quotient_of_rat s
   574     val m = Integer.lcms(map (abs o snd o Rat.quotient_of_rat) (r :: s :: map snd rlhs @ map snd rrhs))
   575     fun mult(t,r) =
   576         let val (i,j) = Rat.quotient_of_rat r
   577         in (t,i * (m div j)) end
   578 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
   579 
   580 fun mklineq atoms =
   581   fn (item, k) =>
   582   let val (m, (lhs,i,rel,rhs,j,discrete)) = integ item
   583       val lhsa = map (coeff lhs) atoms
   584       and rhsa = map (coeff rhs) atoms
   585       val diff = map2 (curry (op -)) rhsa lhsa
   586       val c = i-j
   587       val just = Asm k
   588       fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied(m,j))
   589   in case rel of
   590       "<="   => lineq(c,Le,diff,just)
   591      | "~<=" => if discrete
   592                 then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
   593                 else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
   594      | "<"   => if discrete
   595                 then lineq(c+1,Le,diff,LessD(just))
   596                 else lineq(c,Lt,diff,just)
   597      | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
   598      | "="   => lineq(c,Eq,diff,just)
   599      | _     => sys_error("mklineq" ^ rel)
   600   end;
   601 
   602 (* ------------------------------------------------------------------------- *)
   603 (* Print (counter) example                                                   *)
   604 (* ------------------------------------------------------------------------- *)
   605 
   606 fun print_atom((a,d),r) =
   607   let val (p,q) = Rat.quotient_of_rat r
   608       val s = if d then string_of_int p else
   609               if p = 0 then "0"
   610               else string_of_int p ^ "/" ^ string_of_int q
   611   in a ^ " = " ^ s end;
   612 
   613 fun produce_ex sds =
   614   curry (op ~~) sds
   615   #> map print_atom
   616   #> commas
   617   #> curry (op ^) "Counterexample (possibly spurious):\n";
   618 
   619 fun trace_ex ctxt params atoms discr n (hist: history) =
   620   case hist of
   621     [] => ()
   622   | (v, lineqs) :: hist' =>
   623       let
   624         val frees = map Free params
   625         fun show_term t = Syntax.string_of_term ctxt (subst_bounds (frees, t))
   626         val start =
   627           if v = ~1 then (hist', findex0 discr n lineqs)
   628           else (hist, replicate n Rat.zero)
   629         val ex = SOME (produce_ex (map show_term atoms ~~ discr)
   630             (uncurry (fold (findex1 discr)) start))
   631           handle NoEx => NONE
   632       in
   633         case ex of
   634           SOME s => (warning "Linear arithmetic failed - see trace for a counterexample."; tracing s)
   635         | NONE => warning "Linear arithmetic failed"
   636       end;
   637 
   638 (* ------------------------------------------------------------------------- *)
   639 
   640 fun mknat (pTs : typ list) (ixs : int list) (atom : term, i : int) : lineq option =
   641   if LA_Logic.is_nat (pTs, atom)
   642   then let val l = map (fn j => if j=i then 1 else 0) ixs
   643        in SOME (Lineq (0, Le, l, Nat i)) end
   644   else NONE;
   645 
   646 (* This code is tricky. It takes a list of premises in the order they occur
   647 in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
   648 ones as NONE. Going through the premises, each numeric one is converted into
   649 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
   650 >. Thus split_items returns a list of equation systems. This may blow up if
   651 there are many ~=, but in practice it does not seem to happen. The really
   652 tricky bit is to arrange the order of the cases such that they coincide with
   653 the order in which the cases are in the end generated by the tactic that
   654 applies the generated refutation thms (see function 'refute_tac').
   655 
   656 For variables n of type nat, a constraint 0 <= n is added.
   657 *)
   658 
   659 (* FIXME: To optimize, the splitting of cases and the search for refutations *)
   660 (*        could be intertwined: separate the first (fully split) case,       *)
   661 (*        refute it, continue with splitting and refuting.  Terminate with   *)
   662 (*        failure as soon as a case could not be refuted; i.e. delay further *)
   663 (*        splitting until after a refutation for other cases has been found. *)
   664 
    665 fun split_items ctxt do_pre split_neq (Ts, terms) : (typ list * (LA_Data.decomp * int) list) list =
        (* Turn the premises 'terms' (under parameter types Ts) into a list of
           cases.  Each case pairs a type list with the decomposed numeric
           premises, each tagged with its position among the premises.  The
           steps are: optional user preprocessing (do_pre); decomposition;
           case-splitting of '~=' (split_neq) or dropping of '~='; numbering. *)
    666 let
    667   (* splits inequalities '~=' into '<' and '>'; this corresponds to *)
    668   (* 'REPEAT_DETERM (eresolve_tac neqE i)' at the theorem/tactic    *)
    669   (* level                                                          *)
    670   (* FIXME: this is currently sensitive to the order of theorems in *)
    671   (*        neqE:  The theorem for type "nat" must come first.  A   *)
    672   (*        better (i.e. less likely to break when neqE changes)    *)
    673   (*        implementation should *test* which theorem from neqE    *)
    674   (*        can be applied, and split the premise accordingly.      *)
    675   fun elim_neq (ineqs : (LA_Data.decomp option * bool) list) :
    676                (LA_Data.decomp option * bool) list list =
    677   let
          (* nat_only = true: only split '~=' whose flag is_nat is set;
             nat_only = false: split every remaining '~=' *)
    678     fun elim_neq' nat_only ([] : (LA_Data.decomp option * bool) list) :
    679                   (LA_Data.decomp option * bool) list list =
    680           [[]]
    681       | elim_neq' nat_only ((NONE, is_nat) :: ineqs) =
    682           map (cons (NONE, is_nat)) (elim_neq' nat_only ineqs)
    683       | elim_neq' nat_only ((ineq as (SOME (l, i, rel, r, j, d), is_nat)) :: ineqs) =
    684           if rel = "~=" andalso (not nat_only orelse is_nat) then
    685             (* [| ?l ~= ?r; ?l < ?r ==> ?R; ?r < ?l ==> ?R |] ==> ?R *)
                  (* the '~=' premise is removed and each new '<' premise goes to
                     the END of the list, so the '<' case comes first and the
                     '>' case second; this order must match the cases that
                     refute_tac produces via eresolve_tac neqE *)
    686             elim_neq' nat_only (ineqs @ [(SOME (l, i, "<", r, j, d), is_nat)]) @
    687             elim_neq' nat_only (ineqs @ [(SOME (r, j, "<", l, i, d), is_nat)])
    688           else
    689             map (cons ineq) (elim_neq' nat_only ineqs)
    690   in
          (* first pass splits nat-domain '~=' only, second pass splits the rest;
             this mirrors the required order of theorems in neqE (see FIXME) *)
    691     ineqs |> elim_neq' true
    692           |> maps (elim_neq' false)
    693   end
    694 
        (* used when splitting is disabled: a '~=' premise becomes NONE
           (ignored), while its is_nat flag is kept *)
    695   fun ignore_neq (NONE, bool) = (NONE, bool)
    696     | ignore_neq (ineq as SOME (_, _, rel, _, _, _), bool) =
    697       if rel = "~=" then (NONE, bool) else (ineq, bool)
    698 
        (* pairs each SOME item with its position n; NONE items are dropped but
           still advance n, so the numbers refer to the original premise positions *)
    699   fun number_hyps _ []             = []
    700     | number_hyps n (NONE::xs)     = number_hyps (n+1) xs
    701     | number_hyps n ((SOME x)::xs) = (x, n) :: number_hyps (n+1) xs
    702 
    703   val result = (Ts, terms)
    704     |> (* user-defined preprocessing of the subgoal *)
    705        (if do_pre then LA_Data.pre_decomp ctxt else Library.single)
    706     |> tap (fn subgoals => trace_msg ("Preprocessing yields " ^
    707          string_of_int (length subgoals) ^ " subgoal(s) total."))
    708     |> (* produce the internal encoding of (in-)equalities *)
    709        map (apsnd (map (fn t => (LA_Data.decomp ctxt t, LA_Data.domain_is_nat t))))
    710     |> (* splitting of inequalities *)
    711        map (apsnd (if split_neq then elim_neq else
    712                      Library.single o map ignore_neq))
          (* flatten to one entry per case, dropping the is_nat flags *)
    713     |> maps (fn (Ts, subgoals) => map (pair Ts o map fst) subgoals)
    714     |> (* numbering of hypotheses, ignoring irrelevant ones *)
    715        map (apsnd (number_hyps 0))
    716 in
    717   trace_msg ("Splitting of inequalities yields " ^
    718     string_of_int (length result) ^ " subgoal(s) total.");
    719   result
    720 end;
   721 
   722 fun add_datoms ((lhs,_,_,rhs,_,d) : LA_Data.decomp, _) (dats : (bool * term) list) =
   723   union_bterm (map (pair d o fst) lhs) (union_bterm (map (pair d o fst) rhs) dats);
   724 
   725 fun discr (initems : (LA_Data.decomp * int) list) : bool list =
   726   map fst (fold add_datoms initems []);
   727 
    728 fun refutes ctxt params show_ex :
    729     (typ list * (LA_Data.decomp * int) list) list -> injust list -> injust list option =
        (* Try to refute every case produced by split_items, in order.  The
           justifications found so far are accumulated in js (in case order).
           Returns SOME js once all cases are refuted.  Returns NONE as soon as
           elim fails on one case; if show_ex is set, a counterexample is traced
           first. *)
    730   let
    731     fun refute ((Ts, initems : (LA_Data.decomp * int) list) :: initemss) (js: injust list) =
    732           let
    733             val atoms = atoms_of (map fst initems)
    734             val n = length atoms
    735             val mkleq = mklineq atoms
                  (* atom k is paired with index k, 0 <= k < n *)
    736             val ixs = 0 upto (n - 1)
    737             val iatoms = atoms ~~ ixs
                  (* extra constraints for nat-typed atoms (see mknat) *)
    738             val natlineqs = map_filter (mknat Ts ixs) iatoms
    739             val ineqs = map mkleq initems @ natlineqs
    740           in case elim (ineqs, []) of
    741                Success j =>
                       (* this case is refuted: record j after the earlier
                          justifications and continue with the remaining cases *)
    742                  (trace_msg ("Contradiction! (" ^ string_of_int (length js + 1) ^ ")");
    743                   refute initemss (js @ [j]))
    744              | Failure hist =>
                       (* one unrefuted case makes the whole attempt fail *)
    745                  (if not show_ex then ()
    746                   else
    747                     let
                            (* fresh names for the parameters, plus invented names for
                               any entries of Ts beyond params (NOTE(review): presumably
                               introduced by preprocessing; confirm) *)
    748                       val (param_names, ctxt') = ctxt |> Variable.variant_fixes (map fst params)
    749                       val (more_names, ctxt'') = ctxt' |> Variable.variant_fixes
    750                         (Name.invents (Variable.names_of ctxt') Name.uu (length Ts - length params))
    751                       val params' = (more_names @ param_names) ~~ Ts
    752                     in
    753                       trace_ex ctxt'' params' atoms (discr initems) n hist
    754                     end; NONE)
    755           end
            (* no cases left: all were refuted *)
    756       | refute [] js = SOME js
    757   in refute end;
   758 
   759 fun refute ctxt params show_ex do_pre split_neq terms : injust list option =
   760   refutes ctxt params show_ex (split_items ctxt do_pre split_neq
   761     (map snd params, terms)) [];
   762 
   763 fun count P xs = length (filter P xs);
   764 
   765 fun prove ctxt params show_ex do_pre Hs concl : bool * injust list option =
   766   let
   767     val _ = trace_msg "prove:"
   768     (* append the negated conclusion to 'Hs' -- this corresponds to     *)
   769     (* 'DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i)' at the *)
   770     (* theorem/tactic level                                             *)
   771     val Hs' = Hs @ [LA_Logic.neg_prop concl]
   772     fun is_neq NONE                 = false
   773       | is_neq (SOME (_,_,r,_,_,_)) = (r = "~=")
   774     val neq_limit = Config.get ctxt LA_Data.fast_arith_neq_limit
   775     val split_neq = count is_neq (map (LA_Data.decomp ctxt) Hs') <= neq_limit
   776   in
   777     if split_neq then ()
   778     else
   779       trace_msg ("fast_arith_neq_limit exceeded (current value is " ^
   780         string_of_int neq_limit ^ "), ignoring all inequalities");
   781     (split_neq, refute ctxt params show_ex do_pre split_neq Hs')
   782   end handle TERM ("neg_prop", _) =>
   783     (* since no meta-logic negation is available, we can only fail if   *)
   784     (* the conclusion is not of the form 'Trueprop $ _' (simply         *)
   785     (* dropping the conclusion doesn't work either, because even        *)
   786     (* 'False' does not imply arbitrary 'concl::prop')                  *)
   787     (trace_msg "prove failed (cannot negate conclusion).";
   788       (false, NONE));
   789 
    790 fun refute_tac ss (i, split_neq, justs) =
        (* Tactic that replays a successful refutation on subgoal i.  It negates
           the conclusion, runs LA_Data.pre_tac, and then, for each
           justification in justs (in order), optionally splits '~=' premises
           via neqE and closes the focused subgoal with the theorem built by
           mkthm.  The order of justs must match the case order produced by
           split_items. *)
    791   fn state =>
    792     let
    793       val ctxt = Simplifier.the_context ss;
    794       val _ = trace_thm ctxt
    795         ("refute_tac (on subgoal " ^ string_of_int i ^ ", with " ^
    796           string_of_int (length justs) ^ " justification(s)):") state
    797       val {neqE, ...} = get_data ctxt;
            (* closes one case; every step addresses subgoal i, so after a case
               is closed the next case is expected to be subgoal i *)
    798       fun just1 j =
    799         (* eliminate inequalities *)
    800         (if split_neq then
    801           REPEAT_DETERM (eresolve_tac neqE i)
    802         else
    803           all_tac) THEN
    804           PRIMITIVE (trace_thm ctxt "State after neqE:") THEN
    805           (* use theorems generated from the actual justifications *)
    806           Subgoal.FOCUS (fn {prems, ...} => rtac (mkthm ss prems j) 1) ctxt i
    807     in
    808       (* rewrite "[| A1; ...; An |] ==> B" to "[| A1; ...; An; ~B |] ==> False" *)
    809       DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i) THEN
    810       (* user-defined preprocessing of the subgoal *)
    811       DETERM (LA_Data.pre_tac ss i) THEN
    812       PRIMITIVE (trace_thm ctxt "State after pre_tac:") THEN
    813       (* prove every resulting subgoal, using its justification *)
    814       EVERY (map just1 justs)
    815     end  state;
   816 
   817 (*
   818 Fast but very incomplete decider. Only premises and conclusions
   819 that are already (negated) (in)equations are taken into account.
   820 *)
   821 fun simpset_lin_arith_tac ss show_ex = SUBGOAL (fn (A, i) =>
   822   let
   823     val ctxt = Simplifier.the_context ss
   824     val params = rev (Logic.strip_params A)
   825     val Hs = Logic.strip_assums_hyp A
   826     val concl = Logic.strip_assums_concl A
   827     val _ = trace_term ctxt ("Trying to refute subgoal " ^ string_of_int i) A
   828   in
   829     case prove ctxt params show_ex true Hs concl of
   830       (_, NONE) => (trace_msg "Refutation failed."; no_tac)
   831     | (split_neq, SOME js) => (trace_msg "Refutation succeeded.";
   832                                refute_tac ss (i, split_neq, js))
   833   end);
   834 
   835 fun cut_lin_arith_tac ss =
   836   cut_facts_tac (Simplifier.prems_of_ss ss) THEN'
   837   simpset_lin_arith_tac ss false;
   838 
   839 fun lin_arith_tac ctxt =
   840   simpset_lin_arith_tac (Simplifier.context ctxt Simplifier.empty_ss);
   841 
   842 
   843 
   844 (** Forward proof from theorems **)
   845 
    846 (* More tricky code. The proofs of the multiple cases (due to splits of ~=
    847 premises) must be arranged so that their order coincides with the order of
    848 the cases generated by function split_items. *)
   849 
   850 datatype splittree = Tip of thm list
   851                    | Spl of thm * cterm * splittree * cterm * splittree;
   852 
   853 (* "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R" is taken to (ct1, ct2) *)
   854 
   855 fun extract (imp : cterm) : cterm * cterm =
   856 let val (Il, r)    = Thm.dest_comb imp
   857     val (_, imp1)  = Thm.dest_comb Il
   858     val (Ict1, _)  = Thm.dest_comb imp1
   859     val (_, ct1)   = Thm.dest_comb Ict1
   860     val (Ir, _)    = Thm.dest_comb r
   861     val (_, Ict2r) = Thm.dest_comb Ir
   862     val (Ict2, _)  = Thm.dest_comb Ict2r
   863     val (_, ct2)   = Thm.dest_comb Ict2
   864 in (ct1, ct2) end;
   865 
   866 fun splitasms ctxt (asms : thm list) : splittree =
   867 let val {neqE, ...} = get_data ctxt
   868     fun elim_neq [] (asms', []) = Tip (rev asms')
   869       | elim_neq [] (asms', asms) = Tip (rev asms' @ asms)
   870       | elim_neq (neq :: neqs) (asms', []) = elim_neq neqs ([],rev asms')
   871       | elim_neq (neqs as (neq :: _)) (asms', asm::asms) =
   872       (case get_first (fn th => SOME (asm COMP th) handle THM _ => NONE) [neq] of
   873         SOME spl =>
   874           let val (ct1, ct2) = extract (cprop_of spl)
   875               val thm1 = Thm.assume ct1
   876               val thm2 = Thm.assume ct2
   877           in Spl (spl, ct1, elim_neq neqs (asms', asms@[thm1]),
   878             ct2, elim_neq neqs (asms', asms@[thm2]))
   879           end
   880       | NONE => elim_neq neqs (asm::asms', asms))
   881 in elim_neq neqE ([], asms) end;
   882 
    883 fun fwdproof ss (Tip asms : splittree) (j::js : injust list) = (mkthm ss asms j, js)
          (* Build the theorem for a split tree by consuming justifications left
             to right: a Tip takes the next one and builds its theorem with
             mkthm; a Spl first proves tree1, then tree2 with the leftover
             justifications, discharges each case hypothesis with implies_intr,
             and combines both with the split theorem via COMP.  Returns the
             theorem and the unused justifications.  A Tip reached with no
             justification left is not matched by any clause (Match). *)
    884   | fwdproof ss (Spl (thm, ct1, tree1, ct2, tree2)) js =
    885       let
    886         val (thm1, js1) = fwdproof ss tree1 js
    887         val (thm2, js2) = fwdproof ss tree2 js1
    888         val thm1' = Thm.implies_intr ct1 thm1
    889         val thm2' = Thm.implies_intr ct2 thm2
    890       in (thm2' COMP (thm1' COMP thm), js2) end;
    891       (* FIXME needs handle THM _ => NONE ? *)
   892 
   893 fun prover ss thms Tconcl (js : injust list) split_neq pos : thm option =
   894   let
   895     val ctxt = Simplifier.the_context ss
   896     val thy = ProofContext.theory_of ctxt
   897     val nTconcl = LA_Logic.neg_prop Tconcl
   898     val cnTconcl = cterm_of thy nTconcl
   899     val nTconclthm = Thm.assume cnTconcl
   900     val tree = (if split_neq then splitasms ctxt else Tip) (thms @ [nTconclthm])
   901     val (Falsethm, _) = fwdproof ss tree js
   902     val contr = if pos then LA_Logic.ccontr else LA_Logic.notI
   903     val concl = Thm.implies_intr cnTconcl Falsethm COMP contr
   904   in SOME (trace_thm ctxt "Proved by lin. arith. prover:" (LA_Logic.mk_Eq concl)) end
   905   (*in case concl contains ?-var, which makes assume fail:*)   (* FIXME Variable.import_terms *)
   906   handle THM _ => NONE;
   907 
   908 (* PRE: concl is not negated!
   909    This assumption is OK because
   910    1. lin_arith_simproc tries both to prove and disprove concl and
   911    2. lin_arith_simproc is applied by the Simplifier which
   912       dives into terms and will thus try the non-negated concl anyway.
   913 *)
   914 fun lin_arith_simproc ss concl =
   915   let
   916     val ctxt = Simplifier.the_context ss
   917     val thms = maps LA_Logic.atomize (Simplifier.prems_of_ss ss)
   918     val Hs = map Thm.prop_of thms
   919     val Tconcl = LA_Logic.mk_Trueprop concl
   920   in
   921     case prove ctxt [] false false Hs Tconcl of (* concl provable? *)
   922       (split_neq, SOME js) => prover ss thms Tconcl js split_neq true
   923     | (_, NONE) =>
   924         let val nTconcl = LA_Logic.neg_prop Tconcl in
   925           case prove ctxt [] false false Hs nTconcl of (* ~concl provable? *)
   926             (split_neq, SOME js) => prover ss thms nTconcl js split_neq false
   927           | (_, NONE) => NONE
   928         end
   929   end;
   930 
   931 end;