src/Provers/Arith/fast_lin_arith.ML
author haftmann
Wed, 28 Jul 2010 14:09:56 +0200
changeset 38298 04a8de29e8f7
parent 36945 9bec62c10714
child 39037 996afaa9254a
permissions -rw-r--r--
dropped dead code
     1 (*  Title:      Provers/Arith/fast_lin_arith.ML
     2     Author:     Tobias Nipkow and Tjark Weber and Sascha Boehme
     3 
     4 A generic linear arithmetic package.  It provides two tactics
     5 (cut_lin_arith_tac, lin_arith_tac) and a simplification procedure
     6 (lin_arith_simproc).
     7 
     8 Only take premises and conclusions into account that are already
     9 (negated) (in)equations. lin_arith_simproc tries to prove or disprove
    10 the term.
    11 *)
    12 
    13 (*** Data needed for setting up the linear arithmetic package ***)
    14 
(*Object-logic interface: the theorems and term/theorem operations the
  generic package needs from the logic it is instantiated for.*)
signature LIN_ARITH_LOGIC =
sig
  val conjI       : thm  (* P ==> Q ==> P & Q *)
  val ccontr      : thm  (* (~ P ==> False) ==> P *)
  val notI        : thm  (* (P ==> False) ==> ~ P *)
  val not_lessD   : thm  (* ~(m < n) ==> n <= m *)
  val not_leD     : thm  (* ~(m <= n) ==> n < m *)
  val sym         : thm  (* x = y ==> y = x *)
  val trueI       : thm  (* True *)
  val mk_Eq       : thm -> thm         (* see below *)
  val atomize     : thm -> thm list
  val mk_Trueprop : term -> term
  val neg_prop    : term -> term       (* see below *)
  val is_False    : thm -> bool        (* used to test whether a derived theorem is False *)
  val is_nat      : typ list * term -> bool    (* see below *)
  val mk_nat_thm  : theory -> term -> thm      (* see below *)
end;
(*
mk_Eq(~in) = `in == False'
mk_Eq(in) = `in == True'
where `in' is an (in)equality.

neg_prop(t) = neg  if t is wrapped up in Trueprop and neg is the
  (logically) negated version of t (again wrapped up in Trueprop),
  where the negation of a negative term is the term itself (no
  double negation!); raises TERM ("neg_prop", [t]) if t is not of
  the form 'Trueprop $ _'

is_nat (parameter-types, t) = true iff t is of type nat
mk_nat_thm thy t = the theorem "0 <= t"
*)
    46 
(*Logic-specific data: decomposition of (in)equations into linear form,
  preprocessing of subgoals, and the limit on case splits.*)
signature LIN_ARITH_DATA =
sig
  (*internal representation of linear (in-)equations:*)
  type decomp = (term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat * bool
  val decomp: Proof.context -> term -> decomp option
  val domain_is_nat: term -> bool

  (*preprocessing, performed on a representation of subgoals as list of premises:*)
  val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list

  (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
  val pre_tac: simpset -> int -> tactic

  (*the limit on the number of ~= allowed; because each ~= is split
    into two cases, this can lead to an explosion*)
  val fast_arith_neq_limit: int Config.T
end;
(*
decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
   where Rel is one of "<", "~<", "<=", "~<=" and "=" and
         p (q, respectively) is the decomposition of the sum term x
         (y, respectively) into a list of summand * multiplicity
         pairs and a constant summand and d indicates if the domain
         is discrete.
   Rel may also be "~="; such premises are split into two "<" cases
   (or dropped) by split_items before they reach mklineq.

domain_is_nat(`x Rel y') should yield true iff x is of type "nat".

The relationship between pre_decomp and pre_tac is somewhat tricky.  The
internal representation of a subgoal and the corresponding theorem must
be modified by pre_decomp (pre_tac, resp.) in a corresponding way.  See
the comment for split_items below.  (This is even necessary for eta- and
beta-equivalent modifications, as some of the lin. arith. code is not
insensitive to them.)

ss must reduce contradictory <= to False.
   It should also cancel common summands to keep <= reduced;
   otherwise <= can grow to massive proportions.
*)
    85 
(*Interface of the instantiated package.*)
signature FAST_LIN_ARITH =
sig
  (*the two tactics and the simproc described in the file header*)
  val cut_lin_arith_tac: simpset -> int -> tactic
  val lin_arith_tac: Proof.context -> bool -> int -> tactic
  val lin_arith_simproc: simpset -> term -> thm option
  (*functional update of the package data (rule sets, simpset and
    numeral constructor) stored in a generic context*)
  val map_data: ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
                 lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
                 number_of : serial * (theory -> typ -> int -> cterm)}
                 -> {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
                     lessD: thm list, neqE: thm list, simpset: Simplifier.simpset,
                     number_of : serial * (theory -> typ -> int -> cterm)})
                -> Context.generic -> Context.generic
  (*when set, the decision procedure and proof reconstruction emit tracing output*)
  val trace: bool Unsynchronized.ref
end;
   100 
   101 functor Fast_Lin_Arith
   102   (structure LA_Logic: LIN_ARITH_LOGIC and LA_Data: LIN_ARITH_DATA): FAST_LIN_ARITH =
   103 struct
   104 
   105 
   106 (** theory data **)
   107 
   108 fun no_number_of _ _ _ = raise CTERM ("number_of", [])
   109 
(*Package data kept in the generic context:
    add_mono_thms, mult_mono_thms -- monotonicity rules used by mkthm to
                                     add and scale (in)equations
    inj_thms                      -- rules tried when adding two theorems
                                     fails directly (see add_thms in mkthm)
    lessD                         -- rules strengthening strict inequalities
                                     over discrete domains
    neqE                          -- rules for splitting ~= premises
                                     (cf. the comment in split_items)
    simpset                       -- used to normalise intermediate results
    number_of                     -- (serial, constructor of numerals at a
                                     given type)*)
structure Data = Generic_Data
(
  type T =
   {add_mono_thms: thm list,
    mult_mono_thms: thm list,
    inj_thms: thm list,
    lessD: thm list,
    neqE: thm list,
    simpset: Simplifier.simpset,
    number_of : serial * (theory -> typ -> int -> cterm)};

  (*no_number_of makes every numeral request fail*)
  val empty = {add_mono_thms = [], mult_mono_thms = [], inj_thms = [],
               lessD = [], neqE = [], simpset = Simplifier.empty_ss,
               number_of = (serial (), no_number_of) } : T;
  val extend = I;
  (*theorem lists and simpsets are merged; of the two number_of entries
    the one with the larger serial is kept*)
  fun merge
    ({add_mono_thms= add_mono_thms1, mult_mono_thms= mult_mono_thms1, inj_thms= inj_thms1,
      lessD = lessD1, neqE=neqE1, simpset = simpset1,
      number_of = (number_of1 as (s1, _))},
     {add_mono_thms= add_mono_thms2, mult_mono_thms= mult_mono_thms2, inj_thms= inj_thms2,
      lessD = lessD2, neqE=neqE2, simpset = simpset2,
      number_of = (number_of2 as (s2, _))}) =
    {add_mono_thms = Thm.merge_thms (add_mono_thms1, add_mono_thms2),
     mult_mono_thms = Thm.merge_thms (mult_mono_thms1, mult_mono_thms2),
     inj_thms = Thm.merge_thms (inj_thms1, inj_thms2),
     lessD = Thm.merge_thms (lessD1, lessD2),
     neqE = Thm.merge_thms (neqE1, neqE2),
     simpset = Simplifier.merge_ss (simpset1, simpset2),
     (* FIXME depends on accidental load order !?! *)
     number_of = if s1 > s2 then number_of1 else number_of2};
);
   141 
   142 val map_data = Data.map;
   143 val get_data = Data.get o Context.Proof;
   144 
   145 
   146 
   147 (*** A fast decision procedure ***)
   148 (*** Code ported from HOL Light ***)
   149 (* possible optimizations:
   use (var,coeff) rep or vector rep to save space;
   151    treat non-negative atoms separately rather than adding 0 <= atom
   152 *)
   153 
(*When set, elimination steps and proof reconstruction are traced.*)
val trace = Unsynchronized.ref false;

(*Relation of a linear (in)equation: "=", "<=" or "<".*)
datatype lineq_type = Eq | Le | Lt;

(*Justification of a derived (in)equation: a recipe that mkthm replays
  to build the corresponding theorem.
    Asm i            -- the i-th premise
    Nat i            -- "0 <= atom i" for an atom of type nat
    LessD j          -- j strengthened by a lessD rule (used for "<" over
                        a discrete domain)
    NotLessD j       -- j rewritten with not_lessD
    NotLeD j         -- j rewritten with not_leD
    NotLeDD j        -- j rewritten with not_leD, then a lessD rule
    Multiplied (n,j) -- j multiplied by the integer n
    Added (j1,j2)    -- sum of j1 and j2*)
datatype injust = Asm of int
                | Nat of int (* index of atom *)
                | LessD of injust
                | NotLessD of injust
                | NotLeD of injust
                | NotLeDD of injust
                | Multiplied of int * injust
                | Added of injust * injust;

(*Lineq (k, ty, cs, j) stands for  k ty (sum of cs[i] * atom i),
  justified by j.*)
datatype lineq = Lineq of int * lineq_type * int list * injust;
   168 
   169 (* ------------------------------------------------------------------------- *)
   170 (* Finding a (counter) example from the trace of a failed elimination        *)
   171 (* ------------------------------------------------------------------------- *)
   172 (* Examples are represented as rational numbers,                             *)
(* Don't blame John Harrison for this code - it is entirely mine. TN         *)
   174 
(*Raised by the counterexample finder (choose2) when rounding to an
  integer leaves the interval between the computed bounds.*)
exception NoEx;
   176 
   177 (* Coding: (i,true,cs) means i <= cs and (i,false,cs) means i < cs.
   178    In general, true means the bound is included, false means it is excluded.
   179    Need to know if it is a lower or upper bound for unambiguous interpretation!
   180 *)
   181 
   182 fun elim_eqns (Lineq (i, Le, cs, _)) = [(i, true, cs)]
   183   | elim_eqns (Lineq (i, Eq, cs, _)) = [(i, true, cs),(~i, true, map ~ cs)]
   184   | elim_eqns (Lineq (i, Lt, cs, _)) = [(i, false, cs)];
   185 
   186 (* PRE: ex[v] must be 0! *)
   187 fun eval ex v (a, le, cs) =
   188   let
   189     val rs = map Rat.rat_of_int cs;
   190     val rsum = fold2 (Rat.add oo Rat.mult) rs ex Rat.zero;
   191   in (Rat.mult (Rat.add (Rat.rat_of_int a) (Rat.neg rsum)) (Rat.inv (nth rs v)), le) end;
   192 (* If nth rs v < 0, le should be negated.
   193    Instead this swap is taken into account in ratrelmin2.
   194 *)
   195 
   196 fun ratrelmin2 (x as (r, ler), y as (s, les)) =
   197   case Rat.ord (r, s)
   198    of EQUAL => (r, (not ler) andalso (not les))
   199     | LESS => x
   200     | GREATER => y;
   201 
   202 fun ratrelmax2 (x as (r, ler), y as (s, les)) =
   203   case Rat.ord (r, s)
   204    of EQUAL => (r, ler andalso les)
   205     | LESS => y
   206     | GREATER => x;
   207 
   208 val ratrelmin = foldr1 ratrelmin2;
   209 val ratrelmax = foldr1 ratrelmax2;
   210 
   211 fun ratexact up (r, exact) =
   212   if exact then r else
   213   let
   214     val (_, q) = Rat.quotient_of_rat r;
   215     val nth = Rat.inv (Rat.rat_of_int q);
   216   in Rat.add r (if up then nth else Rat.neg nth) end;
   217 
   218 fun ratmiddle (r, s) = Rat.mult (Rat.add r s) (Rat.inv Rat.two);
   219 
(* Choose a value between the lower bound (lb, exactl) and the upper
   bound (ub, exactu); d tells whether the domain is discrete.
   Zero is returned when the first test succeeds.  Otherwise:
   - If lb > 0, the result is lb when exactl holds, else the midpoint
     (for a discrete domain: lb rounded up).
   - If lb <= 0, the result is ub when exactu holds, else the midpoint
     (for a discrete domain: ub rounded down).
   Raises NoEx if the rounded value falls outside [lb, ub].
   NOTE(review): the zero test inspects only the sign of lb; e.g. lb > 0
   with exactl still yields zero.  The resulting counterexample is only
   reported as "possibly spurious" -- confirm whether this is intended. *)
fun choose2 d ((lb, exactl), (ub, exactu)) =
  let val ord = Rat.sign lb in
  if (ord = LESS orelse exactl) andalso (ord = GREATER orelse exactu)
    then Rat.zero
    else if not d then
      if ord = GREATER
        then if exactl then lb else ratmiddle (lb, ub)
        else if exactu then ub else ratmiddle (lb, ub)
      else (*discrete domain, both bounds must be exact*)
      if ord = GREATER
        then let val lb' = Rat.roundup lb in
          if Rat.le lb' ub then lb' else raise NoEx end
        else let val ub' = Rat.rounddown ub in
          if Rat.le lb ub' then ub' else raise NoEx end
  end;
   235 
   236 fun findex1 discr (v, lineqs) ex =
   237   let
   238     val nz = filter (fn (Lineq (_, _, cs, _)) => nth cs v <> 0) lineqs;
   239     val ineqs = maps elim_eqns nz;
   240     val (ge, le) = List.partition (fn (_,_,cs) => nth cs v > 0) ineqs
   241     val lb = ratrelmax (map (eval ex v) ge)
   242     val ub = ratrelmin (map (eval ex v) le)
   243   in nth_map v (K (choose2 (nth discr v) (lb, ub))) ex end;
   244 
   245 fun elim1 v x =
   246   map (fn (a,le,bs) => (Rat.add a (Rat.neg (Rat.mult (nth bs v) x)), le,
   247                         nth_map v (K Rat.zero) bs));
   248 
   249 fun single_var v (_, _, cs) = case filter_out (curry (op =) EQUAL o Rat.sign) cs
   250  of [x] => x =/ nth cs v
   251   | _ => false;
   252 
   253 (* The base case:
   254    all variables occur only with positive or only with negative coefficients *)
   255 fun pick_vars discr (ineqs,ex) =
   256   let val nz = filter_out (fn (_,_,cs) => forall (curry (op =) EQUAL o Rat.sign) cs) ineqs
   257   in case nz of [] => ex
   258      | (_,_,cs) :: _ =>
   259        let val v = find_index (not o curry (op =) EQUAL o Rat.sign) cs
   260            val d = nth discr v;
   261            val pos = not (Rat.sign (nth cs v) = LESS);
   262            val sv = filter (single_var v) nz;
   263            val minmax =
   264              if pos then if d then Rat.roundup o fst o ratrelmax
   265                          else ratexact true o ratrelmax
   266                     else if d then Rat.rounddown o fst o ratrelmin
   267                          else ratexact false o ratrelmin
   268            val bnds = map (fn (a,le,bs) => (Rat.mult a (Rat.inv (nth bs v)), le)) sv
   269            val x = minmax((Rat.zero,if pos then true else false)::bnds)
   270            val ineqs' = elim1 v x nz
   271            val ex' = nth_map v (K x) ex
   272        in pick_vars discr (ineqs',ex') end
   273   end;
   274 
   275 fun findex0 discr n lineqs =
   276   let val ineqs = maps elim_eqns lineqs
   277       val rineqs = map (fn (a,le,cs) => (Rat.rat_of_int a, le, map Rat.rat_of_int cs))
   278                        ineqs
   279   in pick_vars discr (rineqs,replicate n Rat.zero) end;
   280 
   281 (* ------------------------------------------------------------------------- *)
   282 (* End of counterexample finder. The actual decision procedure starts here.  *)
   283 (* ------------------------------------------------------------------------- *)
   284 
   285 (* ------------------------------------------------------------------------- *)
   286 (* Calculate new (in)equality type after addition.                           *)
   287 (* ------------------------------------------------------------------------- *)
   288 
   289 fun find_add_type(Eq,x) = x
   290   | find_add_type(x,Eq) = x
   291   | find_add_type(_,Lt) = Lt
   292   | find_add_type(Lt,_) = Lt
   293   | find_add_type(Le,Le) = Le;
   294 
   295 (* ------------------------------------------------------------------------- *)
   296 (* Multiply out an (in)equation.                                             *)
   297 (* ------------------------------------------------------------------------- *)
   298 
   299 fun multiply_ineq n (i as Lineq(k,ty,l,just)) =
   300   if n = 1 then i
   301   else if n = 0 andalso ty = Lt then sys_error "multiply_ineq"
   302   else if n < 0 andalso (ty=Le orelse ty=Lt) then sys_error "multiply_ineq"
   303   else Lineq (n * k, ty, map (Integer.mult n) l, Multiplied (n, just));
   304 
   305 (* ------------------------------------------------------------------------- *)
   306 (* Add together (in)equations.                                               *)
   307 (* ------------------------------------------------------------------------- *)
   308 
   309 fun add_ineq (Lineq (k1,ty1,l1,just1)) (Lineq (k2,ty2,l2,just2)) =
   310   let val l = map2 Integer.add l1 l2
   311   in Lineq(k1+k2,find_add_type(ty1,ty2),l,Added(just1,just2)) end;
   312 
   313 (* ------------------------------------------------------------------------- *)
   314 (* Elimination of variable between a single pair of (in)equations.           *)
   315 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
   316 (* ------------------------------------------------------------------------- *)
   317 
   318 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   319   let val c1 = nth l1 v and c2 = nth l2 v
   320       val m = Integer.lcm (abs c1) (abs c2)
   321       val m1 = m div (abs c1) and m2 = m div (abs c2)
   322       val (n1,n2) =
   323         if (c1 >= 0) = (c2 >= 0)
   324         then if ty1 = Eq then (~m1,m2)
   325              else if ty2 = Eq then (m1,~m2)
   326                   else sys_error "elim_var"
   327         else (m1,m2)
   328       val (p1,p2) = if ty1=Eq andalso ty2=Eq andalso (n1 = ~1 orelse n2 = ~1)
   329                     then (~n1,~n2) else (n1,n2)
   330   in add_ineq (multiply_ineq p1 i1) (multiply_ineq p2 i2) end;
   331 
   332 (* ------------------------------------------------------------------------- *)
   333 (* The main refutation-finding code.                                         *)
   334 (* ------------------------------------------------------------------------- *)
   335 
   336 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
   337 
   338 fun is_contradictory (Lineq(k,ty,_,_)) =
   339   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
   340 
   341 fun calc_blowup l =
   342   let val (p,n) = List.partition (curry (op <) 0) (filter (curry (op <>) 0) l)
   343   in length p * length n end;
   344 
   345 (* ------------------------------------------------------------------------- *)
   346 (* Main elimination code:                                                    *)
   347 (*                                                                           *)
   348 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
   349 (*                                                                           *)
   350 (* (2) If there are any equations, picks a variable with the lowest absolute *)
   351 (* coefficient in any of them, and uses it to eliminate.                     *)
   352 (*                                                                           *)
   353 (* (3) Otherwise, chooses a variable in the inequality to minimize the       *)
   354 (* blowup (number of consequences generated) and eliminates it.              *)
   355 (* ------------------------------------------------------------------------- *)
   356 
   357 fun extract_first p =
   358   let
   359     fun extract xs (y::ys) = if p y then (y, xs @ ys) else extract (y::xs) ys
   360       | extract xs [] = raise Empty
   361   in extract [] end;
   362 
   363 fun print_ineqs ineqs =
   364   if !trace then
   365      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
   366        string_of_int c ^
   367        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
   368        commas(map string_of_int l)) ineqs))
   369   else ();
   370 
(*Elimination history, most recent step first: (eliminated variable, the
  non-trivial lineqs before that step).  A variable of ~1 marks a step in
  which no variable could be eliminated.*)
type history = (int * lineq list) list;
(*Outcome of elim: the justification of a contradiction, or the history of
  a failed elimination (used to build a counterexample).*)
datatype result = Success of injust | Failure of history;
   373 
(* Fourier-Motzkin style elimination; see steps (1)-(3) above.
   Returns Success with the justification of a false trivial lineq, or
   Failure with the elimination history. *)
fun elim (ineqs, hist) =
  let val _ = print_ineqs ineqs
      val (triv, nontriv) = List.partition is_trivial ineqs in
  (* (1) a false trivial lineq refutes the system; true ones are dropped *)
  if not (null triv)
  then case Library.find_first is_contradictory triv of
         NONE => elim (nontriv, hist)
       | SOME(Lineq(_,_,_,j)) => Success j
  else
  if null nontriv then Failure hist
  else
  let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
  if not (null eqs) then
     (* (2) c is the non-zero coefficient of least absolute value among the
            equations; the variable eliminated is its first position in the
            first equation containing it *)
     let val c =
           fold (fn Lineq(_,_,l,_) => fn cs => union (op =) l cs) eqs []
           |> filter (fn i => i <> 0)
           |> sort (int_ord o pairself abs)
           |> hd
         val (eq as Lineq(_,_,ceq,_),othereqs) =
               extract_first (fn Lineq(_,_,l,_) => member (op =) l c) eqs
         val v = find_index (fn v => v = c) ceq
         (* lineqs not mentioning v are kept unchanged *)
         val (ioth,roth) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0)
                                     (othereqs @ noneqs)
         val others = map (elim_var v eq) roth @ ioth
     in elim(others,(v,nontriv)::hist) end
  else
  (* (3) only inequalities: compute the blowup of every variable index and
         eliminate one with the least non-zero blowup; if all blowups are
         zero, nothing can be eliminated and ~1 is recorded *)
  let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
      val numlist = 0 upto (length (hd lists) - 1)
      val coeffs = map (fn i => map (fn xs => nth xs i) lists) numlist
      val blows = map calc_blowup coeffs
      val iblows = blows ~~ numlist
      val nziblows = filter_out (fn (i, _) => i = 0) iblows
  in if null nziblows then Failure((~1,nontriv)::hist)
     else
     let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
         (* here ineqs = nontriv = noneqs, since triv and eqs are empty *)
         val (no,yes) = List.partition (fn (Lineq(_,_,l,_)) => nth l v = 0) ineqs
         val (pos,neg) = List.partition(fn (Lineq(_,_,l,_)) => nth l v > 0) yes
     in elim(no @ map_product (elim_var v) pos neg, (v,nontriv)::hist) end
  end
  end
  end;
   414 
   415 (* ------------------------------------------------------------------------- *)
   416 (* Translate back a proof.                                                   *)
   417 (* ------------------------------------------------------------------------- *)
   418 
   419 fun trace_thm ctxt msg th =
   420   (if !trace then (tracing msg; tracing (Display.string_of_thm ctxt th)) else (); th);
   421 
   422 fun trace_term ctxt msg t =
   423   (if !trace then tracing (cat_lines [msg, Syntax.string_of_term ctxt t]) else (); t)
   424 
   425 fun trace_msg msg =
   426   if !trace then tracing msg else ();
   427 
   428 val union_term = union Pattern.aeconv;
   429 val union_bterm = union (fn ((b:bool, t), (b', t')) => b = b' andalso Pattern.aeconv (t, t'));
   430 
   431 fun add_atoms (lhs, _, _, rhs, _, _) =
   432   union_term (map fst lhs) o union_term (map fst rhs);
   433 
   434 fun atoms_of ds = fold add_atoms ds [];
   435 
(*
Simplification may detect a contradiction 'prematurely' due to type
information: n+1 <= 0 is simplified to False and does not need to be crossed
with 0 <= n.
FalseE is raised by simp inside mkthm to short-circuit in that case.
*)
local
  exception FalseE of thm
in

(* Replay the justification just as a theorem, starting from the premises
   asms (Asm i refers to nth asms i).  Returns the simplified final
   theorem.  If that is not False, the assumptions and the result are
   traced and a warning is issued. *)
fun mkthm ss asms (just: injust) =
  let
    val ctxt = Simplifier.the_context ss;
    val thy = ProofContext.theory_of ctxt;
    val {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset,
      number_of = (_, num_of), ...} = get_data ctxt;
    val simpset' = Simplifier.inherit_context ss simpset;
    fun only_concl f thm =
      if Thm.no_prems thm then f (Thm.concl_of thm) else NONE;
    (* atoms for "Nat i"; presumably in the same order as the atom list
       used when the lineqs were built -- TODO confirm *)
    val atoms = atoms_of (map_filter (only_concl (LA_Data.decomp ctxt)) asms);

    (* result of resolving thm with the first rule that accepts it *)
    fun use_first rules thm =
      get_first (fn th => SOME (thm RS th) handle THM _ => NONE) rules

    fun add2 thm1 thm2 =
      use_first add_mono_thms (thm1 RS (thm2 RS LA_Logic.conjI));
    fun try_add thms thm = get_first (fn th => add2 th thm) thms;

    (* add two theorems; if that fails, first transform thm1 and then thm2
       with inj_thms *)
    fun add_thms thm1 thm2 =
      (case add2 thm1 thm2 of
        NONE =>
          (case try_add ([thm1] RL inj_thms) thm2 of
            NONE =>
              (the (try_add ([thm2] RL inj_thms) thm1)
                handle Option =>
                  (trace_thm ctxt "" thm1; trace_thm ctxt "" thm2;
                   sys_error "Linear arithmetic: failed to add thms"))
          | SOME thm => thm)
      | SOME thm => thm);

    (* n * thm as an (n-1)-fold addition of thm to itself.
       NOTE(review): does not terminate for n < 1. *)
    fun mult_by_add n thm =
      let fun mul i th = if i = 1 then th else mul (i - 1) (add_thms thm th)
      in mul n thm end;

    val rewr = Simplifier.rewrite simpset';
    val rewrite_concl = Conv.fconv_rule (Conv.concl_conv ~1 (Conv.arg_conv
      (Conv.binop_conv rewr)));
    fun discharge_prem thm = if Thm.nprems_of thm = 0 then thm else
      let val cv = Conv.arg1_conv (Conv.arg_conv rewr)
      in Thm.implies_elim (Conv.fconv_rule cv thm) LA_Logic.trueI end

    (* n * thm via a mult_mono rule instantiated with the numeral n;
       falls back to repeated addition if no rule applies or
       instantiation/rewriting fails *)
    fun mult n thm =
      (case use_first mult_mono_thms thm of
        NONE => mult_by_add n thm
      | SOME mth =>
          let
            val cv = mth |> Thm.cprop_of |> Drule.strip_imp_concl
              |> Thm.dest_arg |> Thm.dest_arg1 |> Thm.dest_arg1
            val T = #T (Thm.rep_cterm cv)
          in
            mth
            |> Thm.instantiate ([], [(cv, num_of thy T n)])
            |> rewrite_concl
            |> discharge_prem
            handle CTERM _ => mult_by_add n thm
                 | THM _ => mult_by_add n thm
          end);

    (* negative multipliers only arise for equations (multiply_ineq
       rejects them for inequalities), so they are realised with sym *)
    fun mult_thm (n, thm) =
      if n = ~1 then thm RS LA_Logic.sym
      else if n < 0 then mult (~n) thm RS LA_Logic.sym
      else mult n thm;

    fun simp thm =
      let val thm' = trace_thm ctxt "Simplified:" (full_simplify simpset' thm)
      in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end;

    (* replay one justification step *)
    fun mk (Asm i) = trace_thm ctxt ("Asm " ^ string_of_int i) (nth asms i)
      | mk (Nat i) = trace_thm ctxt ("Nat " ^ string_of_int i) (LA_Logic.mk_nat_thm thy (nth atoms i))
      | mk (LessD j) = trace_thm ctxt "L" (hd ([mk j] RL lessD))
      | mk (NotLeD j) = trace_thm ctxt "NLe" (mk j RS LA_Logic.not_leD)
      | mk (NotLeDD j) = trace_thm ctxt "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD))
      | mk (NotLessD j) = trace_thm ctxt "NL" (mk j RS LA_Logic.not_lessD)
      | mk (Added (j1, j2)) = simp (trace_thm ctxt "+" (add_thms (mk j1) (mk j2)))
      | mk (Multiplied (n, j)) =
          (trace_msg ("*" ^ string_of_int n); trace_thm ctxt "*" (mult_thm (n, mk j)))

  in
    let
      val _ = trace_msg "mkthm";
      val thm = trace_thm ctxt "Final thm:" (mk just);
      val fls = simplify simpset' thm;
      val _ = trace_thm ctxt "After simplification:" fls;
      val _ =
        if LA_Logic.is_False fls then ()
        else
         (tracing (cat_lines
           (["Assumptions:"] @ map (Display.string_of_thm ctxt) asms @ [""] @
            ["Proved:", Display.string_of_thm ctxt fls, ""]));
          warning "Linear arithmetic should have refuted the assumptions.\n\
            \Please inform Tobias Nipkow.")
    in fls end
    handle FalseE thm => trace_thm ctxt "False reached early:" thm
  end;

end;
   541 
   542 fun coeff poly atom =
   543   AList.lookup Pattern.aeconv poly atom |> the_default 0;
   544 
   545 fun integ(rlhs,r,rel,rrhs,s,d) =
   546 let val (rn,rd) = Rat.quotient_of_rat r and (sn,sd) = Rat.quotient_of_rat s
   547     val m = Integer.lcms(map (abs o snd o Rat.quotient_of_rat) (r :: s :: map snd rlhs @ map snd rrhs))
   548     fun mult(t,r) =
   549         let val (i,j) = Rat.quotient_of_rat r
   550         in (t,i * (m div j)) end
   551 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
   552 
   553 fun mklineq atoms =
   554   fn (item, k) =>
   555   let val (m, (lhs,i,rel,rhs,j,discrete)) = integ item
   556       val lhsa = map (coeff lhs) atoms
   557       and rhsa = map (coeff rhs) atoms
   558       val diff = map2 (curry (op -)) rhsa lhsa
   559       val c = i-j
   560       val just = Asm k
   561       fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied(m,j))
   562   in case rel of
   563       "<="   => lineq(c,Le,diff,just)
   564      | "~<=" => if discrete
   565                 then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
   566                 else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
   567      | "<"   => if discrete
   568                 then lineq(c+1,Le,diff,LessD(just))
   569                 else lineq(c,Lt,diff,just)
   570      | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
   571      | "="   => lineq(c,Eq,diff,just)
   572      | _     => sys_error("mklineq" ^ rel)
   573   end;
   574 
   575 (* ------------------------------------------------------------------------- *)
   576 (* Print (counter) example                                                   *)
   577 (* ------------------------------------------------------------------------- *)
   578 
   579 fun print_atom((a,d),r) =
   580   let val (p,q) = Rat.quotient_of_rat r
   581       val s = if d then string_of_int p else
   582               if p = 0 then "0"
   583               else string_of_int p ^ "/" ^ string_of_int q
   584   in a ^ " = " ^ s end;
   585 
   586 fun produce_ex sds =
   587   curry (op ~~) sds
   588   #> map print_atom
   589   #> commas
   590   #> curry (op ^) "Counterexample (possibly spurious):\n";
   591 
   592 fun trace_ex ctxt params atoms discr n (hist: history) =
   593   case hist of
   594     [] => ()
   595   | (v, lineqs) :: hist' =>
   596       let
   597         val frees = map Free params
   598         fun show_term t = Syntax.string_of_term ctxt (subst_bounds (frees, t))
   599         val start =
   600           if v = ~1 then (hist', findex0 discr n lineqs)
   601           else (hist, replicate n Rat.zero)
   602         val ex = SOME (produce_ex (map show_term atoms ~~ discr)
   603             (uncurry (fold (findex1 discr)) start))
   604           handle NoEx => NONE
   605       in
   606         case ex of
   607           SOME s => (warning "Linear arithmetic failed - see trace for a counterexample."; tracing s)
   608         | NONE => warning "Linear arithmetic failed"
   609       end;
   610 
   611 (* ------------------------------------------------------------------------- *)
   612 
   613 fun mknat (pTs : typ list) (ixs : int list) (atom : term, i : int) : lineq option =
   614   if LA_Logic.is_nat (pTs, atom)
   615   then let val l = map (fn j => if j=i then 1 else 0) ixs
   616        in SOME (Lineq (0, Le, l, Nat i)) end
   617   else NONE;
   618 
   619 (* This code is tricky. It takes a list of premises in the order they occur
   620 in the subgoal. Numerical premises are coded as SOME(tuple), non-numerical
   621 ones as NONE. Going through the premises, each numeric one is converted into
   622 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
   623 >. Thus split_items returns a list of equation systems. This may blow up if
   624 there are many ~=, but in practice it does not seem to happen. The really
   625 tricky bit is to arrange the order of the cases such that they coincide with
   626 the order in which the cases are in the end generated by the tactic that
   627 applies the generated refutation thms (see function 'refute_tac').
   628 
   629 For variables n of type nat, a constraint 0 <= n is added.
   630 *)
   631 
   632 (* FIXME: To optimize, the splitting of cases and the search for refutations *)
   633 (*        could be intertwined: separate the first (fully split) case,       *)
   634 (*        refute it, continue with splitting and refuting.  Terminate with   *)
   635 (*        failure as soon as a case could not be refuted; i.e. delay further *)
   636 (*        splitting until after a refutation for other cases has been found. *)
   637 
(* Turn the premises of a subgoal, given as (parameter types, terms), into
   a list of systems of numbered, decomposed premises, one per resulting
   case.  do_pre: run LA_Data.pre_decomp first.  split_neq: split "~="
   premises into two "<" cases; otherwise "~=" premises are ignored.
   Premises that do not decompose are skipped but still counted, so each
   number refers to a position in the premise list. *)
fun split_items ctxt do_pre split_neq (Ts, terms) : (typ list * (LA_Data.decomp * int) list) list =
let
  (* splits inequalities '~=' into '<' and '>'; this corresponds to *)
  (* 'REPEAT_DETERM (eresolve_tac neqE i)' at the theorem/tactic    *)
  (* level                                                          *)
  (* FIXME: this is currently sensitive to the order of theorems in *)
  (*        neqE:  The theorem for type "nat" must come first.  A   *)
  (*        better (i.e. less likely to break when neqE changes)    *)
  (*        implementation should *test* which theorem from neqE    *)
  (*        can be applied, and split the premise accordingly.      *)
  fun elim_neq (ineqs : (LA_Data.decomp option * bool) list) :
               (LA_Data.decomp option * bool) list list =
  let
    (* nat_only: split only "~=" premises whose domain is nat.  A split
       premise is removed and its "<" replacement appended at the end;
       the "l < r" case comes before the "r < l" case *)
    fun elim_neq' nat_only ([] : (LA_Data.decomp option * bool) list) :
                  (LA_Data.decomp option * bool) list list =
          [[]]
      | elim_neq' nat_only ((NONE, is_nat) :: ineqs) =
          map (cons (NONE, is_nat)) (elim_neq' nat_only ineqs)
      | elim_neq' nat_only ((ineq as (SOME (l, i, rel, r, j, d), is_nat)) :: ineqs) =
          if rel = "~=" andalso (not nat_only orelse is_nat) then
            (* [| ?l ~= ?r; ?l < ?r ==> ?R; ?r < ?l ==> ?R |] ==> ?R *)
            elim_neq' nat_only (ineqs @ [(SOME (l, i, "<", r, j, d), is_nat)]) @
            elim_neq' nat_only (ineqs @ [(SOME (r, j, "<", l, i, d), is_nat)])
          else
            map (cons ineq) (elim_neq' nat_only ineqs)
  in
    (* first split the nat premises, then all remaining ones (cf. FIXME) *)
    ineqs |> elim_neq' true
          |> maps (elim_neq' false)
  end

  (* drop "~=" premises when no splitting is requested *)
  fun ignore_neq (NONE, bool) = (NONE, bool)
    | ignore_neq (ineq as SOME (_, _, rel, _, _, _), bool) =
      if rel = "~=" then (NONE, bool) else (ineq, bool)

  (* pair each decomposed premise with its position; NONE entries are
     skipped but still counted *)
  fun number_hyps _ []             = []
    | number_hyps n (NONE::xs)     = number_hyps (n+1) xs
    | number_hyps n ((SOME x)::xs) = (x, n) :: number_hyps (n+1) xs

  val result = (Ts, terms)
    |> (* user-defined preprocessing of the subgoal *)
       (if do_pre then LA_Data.pre_decomp ctxt else Library.single)
    |> tap (fn subgoals => trace_msg ("Preprocessing yields " ^
         string_of_int (length subgoals) ^ " subgoal(s) total."))
    |> (* produce the internal encoding of (in-)equalities *)
       map (apsnd (map (fn t => (LA_Data.decomp ctxt t, LA_Data.domain_is_nat t))))
    |> (* splitting of inequalities *)
       map (apsnd (if split_neq then elim_neq else
                     Library.single o map ignore_neq))
    |> maps (fn (Ts, subgoals) => map (pair Ts o map fst) subgoals)
    |> (* numbering of hypotheses, ignoring irrelevant ones *)
       map (apsnd (number_hyps 0))
in
  trace_msg ("Splitting of inequalities yields " ^
    string_of_int (length result) ^ " subgoal(s) total.");
  result
end;
   694 
   695 fun add_datoms ((lhs,_,_,rhs,_,d) : LA_Data.decomp, _) (dats : (bool * term) list) =
   696   union_bterm (map (pair d o fst) lhs) (union_bterm (map (pair d o fst) rhs) dats);
   697 
   698 fun discr (initems : (LA_Data.decomp * int) list) : bool list =
   699   map fst (fold add_datoms initems []);
   700 
(* refutes ctxt params show_ex cases js:                                  *)
(*   Try to refute each case of `cases' (as produced by split_items), in  *)
(*   order, by running `elim' on the linear inequations built from it.   *)
(*   Every successful refutation appends its justification to js; when  *)
(*   all cases are refuted the result is SOME js (one justification per  *)
(*   case, in case order).  At the first case that cannot be refuted the *)
(*   result is NONE; if show_ex is set, a counterexample trace is printed *)
(*   first (trace_ex).                                                    *)
fun refutes ctxt params show_ex :
    (typ list * (LA_Data.decomp * int) list) list -> injust list -> injust list option =
  let
    fun refute ((Ts, initems : (LA_Data.decomp * int) list) :: initemss) (js: injust list) =
          let
            val atoms = atoms_of (map fst initems)
            val n = length atoms
            val mkleq = mklineq atoms
            (* number the atoms 0 .. n-1 *)
            val ixs = 0 upto (n - 1)
            val iatoms = atoms ~~ ixs
            (* extra inequations contributed by mknat for some atoms     *)
            (* (presumably non-negativity of nat atoms -- TODO confirm)  *)
            val natlineqs = map_filter (mknat Ts ixs) iatoms
            val ineqs = map mkleq initems @ natlineqs
          in case elim (ineqs, []) of
               Success j =>
                 (trace_msg ("Contradiction! (" ^ string_of_int (length js + 1) ^ ")");
                  refute initemss (js @ [j]))
             | Failure hist =>
                 (if not show_ex then ()
                  else
                    let
                      (* fresh names for the goal parameters, plus invented *)
                      (* names for any further entries of Ts (presumably    *)
                      (* parameters introduced by preprocessing -- confirm) *)
                      val (param_names, ctxt') = ctxt |> Variable.variant_fixes (map fst params)
                      val (more_names, ctxt'') = ctxt' |> Variable.variant_fixes
                        (Name.invents (Variable.names_of ctxt') Name.uu (length Ts - length params))
                      val params' = (more_names @ param_names) ~~ Ts
                    in
                      trace_ex ctxt'' params' atoms (discr initems) n hist
                    end; NONE)
          end
      | refute [] js = SOME js
  in refute end;
   731 
   732 fun refute ctxt params show_ex do_pre split_neq terms : injust list option =
   733   refutes ctxt params show_ex (split_items ctxt do_pre split_neq
   734     (map snd params, terms)) [];
   735 
   736 fun count P xs = length (filter P xs);
   737 
   738 fun prove ctxt params show_ex do_pre Hs concl : bool * injust list option =
   739   let
   740     val _ = trace_msg "prove:"
   741     (* append the negated conclusion to 'Hs' -- this corresponds to     *)
   742     (* 'DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i)' at the *)
   743     (* theorem/tactic level                                             *)
   744     val Hs' = Hs @ [LA_Logic.neg_prop concl]
   745     fun is_neq NONE                 = false
   746       | is_neq (SOME (_,_,r,_,_,_)) = (r = "~=")
   747     val neq_limit = Config.get ctxt LA_Data.fast_arith_neq_limit
   748     val split_neq = count is_neq (map (LA_Data.decomp ctxt) Hs') <= neq_limit
   749   in
   750     if split_neq then ()
   751     else
   752       trace_msg ("fast_arith_neq_limit exceeded (current value is " ^
   753         string_of_int neq_limit ^ "), ignoring all inequalities");
   754     (split_neq, refute ctxt params show_ex do_pre split_neq Hs')
   755   end handle TERM ("neg_prop", _) =>
   756     (* since no meta-logic negation is available, we can only fail if   *)
   757     (* the conclusion is not of the form 'Trueprop $ _' (simply         *)
   758     (* dropping the conclusion doesn't work either, because even        *)
   759     (* 'False' does not imply arbitrary 'concl::prop')                  *)
   760     (trace_msg "prove failed (cannot negate conclusion).";
   761       (false, NONE));
   762 
   763 fun refute_tac ss (i, split_neq, justs) =
   764   fn state =>
   765     let
   766       val ctxt = Simplifier.the_context ss;
   767       val _ = trace_thm ctxt
   768         ("refute_tac (on subgoal " ^ string_of_int i ^ ", with " ^
   769           string_of_int (length justs) ^ " justification(s)):") state
   770       val {neqE, ...} = get_data ctxt;
   771       fun just1 j =
   772         (* eliminate inequalities *)
   773         (if split_neq then
   774           REPEAT_DETERM (eresolve_tac neqE i)
   775         else
   776           all_tac) THEN
   777           PRIMITIVE (trace_thm ctxt "State after neqE:") THEN
   778           (* use theorems generated from the actual justifications *)
   779           Subgoal.FOCUS (fn {prems, ...} => rtac (mkthm ss prems j) 1) ctxt i
   780     in
   781       (* rewrite "[| A1; ...; An |] ==> B" to "[| A1; ...; An; ~B |] ==> False" *)
   782       DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i) THEN
   783       (* user-defined preprocessing of the subgoal *)
   784       DETERM (LA_Data.pre_tac ss i) THEN
   785       PRIMITIVE (trace_thm ctxt "State after pre_tac:") THEN
   786       (* prove every resulting subgoal, using its justification *)
   787       EVERY (map just1 justs)
   788     end  state;
   789 
   790 (*
   791 Fast but very incomplete decider. Only premises and conclusions
   792 that are already (negated) (in)equations are taken into account.
   793 *)
   794 fun simpset_lin_arith_tac ss show_ex = SUBGOAL (fn (A, i) =>
   795   let
   796     val ctxt = Simplifier.the_context ss
   797     val params = rev (Logic.strip_params A)
   798     val Hs = Logic.strip_assums_hyp A
   799     val concl = Logic.strip_assums_concl A
   800     val _ = trace_term ctxt ("Trying to refute subgoal " ^ string_of_int i) A
   801   in
   802     case prove ctxt params show_ex true Hs concl of
   803       (_, NONE) => (trace_msg "Refutation failed."; no_tac)
   804     | (split_neq, SOME js) => (trace_msg "Refutation succeeded.";
   805                                refute_tac ss (i, split_neq, js))
   806   end);
   807 
   808 fun cut_lin_arith_tac ss =
   809   cut_facts_tac (Simplifier.prems_of_ss ss) THEN'
   810   simpset_lin_arith_tac ss false;
   811 
   812 fun lin_arith_tac ctxt =
   813   simpset_lin_arith_tac (Simplifier.context ctxt Simplifier.empty_ss);
   814 
   815 
   816 
   817 (** Forward proof from theorems **)
   818 
   819 (* More tricky code. Needs to arrange the proofs of the multiple cases (due
   820 to splits of ~= premises) such that it coincides with the order of the cases
   821 generated by function split_items. *)
   822 
(* Case-split tree built by splitasms and consumed by fwdproof:           *)
(*   Tip asms                    -- a leaf case with premises asms;       *)
(*   Spl (spl, ct1, t1, ct2, t2) -- spl is a split theorem of the form    *)
(*     "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R"; t1 / t2 are the subtrees   *)
(*     for the cases in which ct1 / ct2 is assumed.                       *)
datatype splittree = Tip of thm list
                   | Spl of thm * cterm * splittree * cterm * splittree;
   825 
   826 (* "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R" is taken to (ct1, ct2) *)
   827 
   828 fun extract (imp : cterm) : cterm * cterm =
   829 let val (Il, r)    = Thm.dest_comb imp
   830     val (_, imp1)  = Thm.dest_comb Il
   831     val (Ict1, _)  = Thm.dest_comb imp1
   832     val (_, ct1)   = Thm.dest_comb Ict1
   833     val (Ir, _)    = Thm.dest_comb r
   834     val (_, Ict2r) = Thm.dest_comb Ir
   835     val (Ict2, _)  = Thm.dest_comb Ict2r
   836     val (_, ct2)   = Thm.dest_comb Ict2
   837 in (ct1, ct2) end;
   838 
   839 fun splitasms ctxt (asms : thm list) : splittree =
   840 let val {neqE, ...} = get_data ctxt
   841     fun elim_neq [] (asms', []) = Tip (rev asms')
   842       | elim_neq [] (asms', asms) = Tip (rev asms' @ asms)
   843       | elim_neq (neq :: neqs) (asms', []) = elim_neq neqs ([],rev asms')
   844       | elim_neq (neqs as (neq :: _)) (asms', asm::asms) =
   845       (case get_first (fn th => SOME (asm COMP th) handle THM _ => NONE) [neq] of
   846         SOME spl =>
   847           let val (ct1, ct2) = extract (cprop_of spl)
   848               val thm1 = Thm.assume ct1
   849               val thm2 = Thm.assume ct2
   850           in Spl (spl, ct1, elim_neq neqs (asms', asms@[thm1]),
   851             ct2, elim_neq neqs (asms', asms@[thm2]))
   852           end
   853       | NONE => elim_neq neqs (asm::asms', asms))
   854 in elim_neq neqE ([], asms) end;
   855 
   856 fun fwdproof ss (Tip asms : splittree) (j::js : injust list) = (mkthm ss asms j, js)
   857   | fwdproof ss (Spl (thm, ct1, tree1, ct2, tree2)) js =
   858       let
   859         val (thm1, js1) = fwdproof ss tree1 js
   860         val (thm2, js2) = fwdproof ss tree2 js1
   861         val thm1' = Thm.implies_intr ct1 thm1
   862         val thm2' = Thm.implies_intr ct2 thm2
   863       in (thm2' COMP (thm1' COMP thm), js2) end;
   864       (* FIXME needs handle THM _ => NONE ? *)
   865 
   866 fun prover ss thms Tconcl (js : injust list) split_neq pos : thm option =
   867   let
   868     val ctxt = Simplifier.the_context ss
   869     val thy = ProofContext.theory_of ctxt
   870     val nTconcl = LA_Logic.neg_prop Tconcl
   871     val cnTconcl = cterm_of thy nTconcl
   872     val nTconclthm = Thm.assume cnTconcl
   873     val tree = (if split_neq then splitasms ctxt else Tip) (thms @ [nTconclthm])
   874     val (Falsethm, _) = fwdproof ss tree js
   875     val contr = if pos then LA_Logic.ccontr else LA_Logic.notI
   876     val concl = Thm.implies_intr cnTconcl Falsethm COMP contr
   877   in SOME (trace_thm ctxt "Proved by lin. arith. prover:" (LA_Logic.mk_Eq concl)) end
   878   (*in case concl contains ?-var, which makes assume fail:*)   (* FIXME Variable.import_terms *)
   879   handle THM _ => NONE;
   880 
   881 (* PRE: concl is not negated!
   882    This assumption is OK because
   883    1. lin_arith_simproc tries both to prove and disprove concl and
   884    2. lin_arith_simproc is applied by the Simplifier which
   885       dives into terms and will thus try the non-negated concl anyway.
   886 *)
   887 fun lin_arith_simproc ss concl =
   888   let
   889     val ctxt = Simplifier.the_context ss
   890     val thms = maps LA_Logic.atomize (Simplifier.prems_of_ss ss)
   891     val Hs = map Thm.prop_of thms
   892     val Tconcl = LA_Logic.mk_Trueprop concl
   893   in
   894     case prove ctxt [] false false Hs Tconcl of (* concl provable? *)
   895       (split_neq, SOME js) => prover ss thms Tconcl js split_neq true
   896     | (_, NONE) =>
   897         let val nTconcl = LA_Logic.neg_prop Tconcl in
   898           case prove ctxt [] false false Hs nTconcl of (* ~concl provable? *)
   899             (split_neq, SOME js) => prover ss thms nTconcl js split_neq false
   900           | (_, NONE) => NONE
   901         end
   902   end;
   903 
   904 end;