1.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
1.2 +++ b/src/HOL/Matrix_LP/fspmlp.ML Sat Mar 17 12:52:40 2012 +0100
1.3 @@ -0,0 +1,313 @@
1.4 +(* Title: HOL/Matrix/fspmlp.ML
1.5 + Author: Steven Obua
1.6 +*)
1.7 +
1.8 +signature FSPMLP =
1.9 +sig
1.10 + type linprog
1.11 + type vector = FloatSparseMatrixBuilder.vector
1.12 + type matrix = FloatSparseMatrixBuilder.matrix
1.13 +
1.14 + val y : linprog -> term
1.15 + val A : linprog -> term * term
1.16 + val b : linprog -> term
1.17 + val c : linprog -> term * term
1.18 + val r12 : linprog -> term * term
1.19 +
1.20 + exception Load of string
1.21 +
1.22 + val load : string -> int -> bool -> linprog
1.23 +end
1.24 +
1.25 +structure Fspmlp : FSPMLP =
1.26 +struct
1.27 +
1.28 +type vector = FloatSparseMatrixBuilder.vector
1.29 +type matrix = FloatSparseMatrixBuilder.matrix
1.30 +
1.31 +type linprog = term * (term * term) * term * (term * term) * (term * term)
1.32 +
1.33 +fun y (c1, _, _, _, _) = c1
1.34 +fun A (_, c2, _, _, _) = c2
1.35 +fun b (_, _, c3, _, _) = c3
1.36 +fun c (_, _, _, c4, _) = c4
1.37 +fun r12 (_, _, _, _, c6) = c6
1.38 +
1.39 +structure CplexFloatSparseMatrixConverter =
1.40 +MAKE_CPLEX_MATRIX_CONVERTER(structure cplex = Cplex and matrix_builder = FloatSparseMatrixBuilder);
1.41 +
1.42 +datatype bound_type = LOWER | UPPER
1.43 +
1.44 +fun intbound_ord ((i1: int, b1),(i2,b2)) =
1.45 + if i1 < i2 then LESS
1.46 + else if i1 = i2 then
1.47 + (if b1 = b2 then EQUAL else if b1=LOWER then LESS else GREATER)
1.48 + else GREATER
1.49 +
1.50 +structure Inttab = Table(type key = int val ord = (rev_order o int_ord));
1.51 +
1.52 +structure VarGraph = Table(type key = int*bound_type val ord = intbound_ord);
1.53 +(* key -> (float option) * (int -> (float * (((float * float) * key) list)))) *)
1.54 +(* dest_key -> (sure_bound * (row_index -> (row_bound * (((coeff_lower * coeff_upper) * src_key) list)))) *)
1.55 +
1.56 +exception Internal of string;
1.57 +
1.58 +fun add_row_bound g dest_key row_index row_bound =
1.59 + let
1.60 + val x =
1.61 + case VarGraph.lookup g dest_key of
1.62 + NONE => (NONE, Inttab.update (row_index, (row_bound, [])) Inttab.empty)
1.63 + | SOME (sure_bound, f) =>
1.64 + (sure_bound,
1.65 + case Inttab.lookup f row_index of
1.66 + NONE => Inttab.update (row_index, (row_bound, [])) f
1.67 + | SOME _ => raise (Internal "add_row_bound"))
1.68 + in
1.69 + VarGraph.update (dest_key, x) g
1.70 + end
1.71 +
1.72 +fun update_sure_bound g (key as (_, btype)) bound =
1.73 + let
1.74 + val x =
1.75 + case VarGraph.lookup g key of
1.76 + NONE => (SOME bound, Inttab.empty)
1.77 + | SOME (NONE, f) => (SOME bound, f)
1.78 + | SOME (SOME old_bound, f) =>
1.79 + (SOME ((case btype of
1.80 + UPPER => Float.min
1.81 + | LOWER => Float.max)
1.82 + old_bound bound), f)
1.83 + in
1.84 + VarGraph.update (key, x) g
1.85 + end
1.86 +
1.87 +fun get_sure_bound g key =
1.88 + case VarGraph.lookup g key of
1.89 + NONE => NONE
1.90 + | SOME (sure_bound, _) => sure_bound
1.91 +
1.92 +(*fun get_row_bound g key row_index =
1.93 + case VarGraph.lookup g key of
1.94 + NONE => NONE
1.95 + | SOME (sure_bound, f) =>
1.96 + (case Inttab.lookup f row_index of
1.97 + NONE => NONE
1.98 + | SOME (row_bound, _) => (sure_bound, row_bound))*)
1.99 +
1.100 +fun add_edge g src_key dest_key row_index coeff =
1.101 + case VarGraph.lookup g dest_key of
1.102 + NONE => raise (Internal "add_edge: dest_key not found")
1.103 + | SOME (sure_bound, f) =>
1.104 + (case Inttab.lookup f row_index of
1.105 + NONE => raise (Internal "add_edge: row_index not found")
1.106 + | SOME (row_bound, sources) =>
1.107 + VarGraph.update (dest_key, (sure_bound, Inttab.update (row_index, (row_bound, (coeff, src_key) :: sources)) f)) g)
1.108 +
1.109 +fun split_graph g =
1.110 + let
1.111 + fun split (key, (sure_bound, _)) (r1, r2) = case sure_bound
1.112 + of NONE => (r1, r2)
1.113 + | SOME bound => (case key
1.114 + of (u, UPPER) => (r1, Inttab.update (u, bound) r2)
1.115 + | (u, LOWER) => (Inttab.update (u, bound) r1, r2))
1.116 + in VarGraph.fold split g (Inttab.empty, Inttab.empty) end
1.117 +
1.118 +(* If safe is true, termination is guaranteed, but the sure bounds may be not optimal (relative to the algorithm).
1.119 + If safe is false, termination is not guaranteed, but on termination the sure bounds are optimal (relative to the algorithm) *)
1.120 +fun propagate_sure_bounds safe names g =
1.121 + let
1.122 + (* returns NONE if no new sure bound could be calculated, otherwise the new sure bound is returned *)
1.123 + fun calc_sure_bound_from_sources g (key as (_, btype)) =
1.124 + let
1.125 + fun mult_upper x (lower, upper) =
1.126 + if Float.sign x = LESS then
1.127 + Float.mult x lower
1.128 + else
1.129 + Float.mult x upper
1.130 +
1.131 + fun mult_lower x (lower, upper) =
1.132 + if Float.sign x = LESS then
1.133 + Float.mult x upper
1.134 + else
1.135 + Float.mult x lower
1.136 +
1.137 + val mult_btype = case btype of UPPER => mult_upper | LOWER => mult_lower
1.138 +
1.139 + fun calc_sure_bound (_, (row_bound, sources)) sure_bound =
1.140 + let
1.141 + fun add_src_bound (coeff, src_key) sum =
1.142 + case sum of
1.143 + NONE => NONE
1.144 + | SOME x =>
1.145 + (case get_sure_bound g src_key of
1.146 + NONE => NONE
1.147 + | SOME src_sure_bound => SOME (Float.add x (mult_btype src_sure_bound coeff)))
1.148 + in
1.149 + case fold add_src_bound sources (SOME row_bound) of
1.150 + NONE => sure_bound
1.151 + | new_sure_bound as (SOME new_bound) =>
1.152 + (case sure_bound of
1.153 + NONE => new_sure_bound
1.154 + | SOME old_bound =>
1.155 + SOME (case btype of
1.156 + UPPER => Float.min old_bound new_bound
1.157 + | LOWER => Float.max old_bound new_bound))
1.158 + end
1.159 + in
1.160 + case VarGraph.lookup g key of
1.161 + NONE => NONE
1.162 + | SOME (sure_bound, f) =>
1.163 + let
1.164 + val x = Inttab.fold calc_sure_bound f sure_bound
1.165 + in
1.166 + if x = sure_bound then NONE else x
1.167 + end
1.168 + end
1.169 +
1.170 + fun propagate (key, _) (g, b) =
1.171 + case calc_sure_bound_from_sources g key of
1.172 + NONE => (g,b)
1.173 + | SOME bound => (update_sure_bound g key bound,
1.174 + if safe then
1.175 + case get_sure_bound g key of
1.176 + NONE => true
1.177 + | _ => b
1.178 + else
1.179 + true)
1.180 +
1.181 + val (g, b) = VarGraph.fold propagate g (g, false)
1.182 + in
1.183 + if b then propagate_sure_bounds safe names g else g
1.184 + end
1.185 +
1.186 +exception Load of string;
1.187 +
1.188 +val empty_spvec = @{term "Nil :: real spvec"};
1.189 +fun cons_spvec x xs = @{term "Cons :: nat * real => real spvec => real spvec"} $ x $ xs;
1.190 +val empty_spmat = @{term "Nil :: real spmat"};
1.191 +fun cons_spmat x xs = @{term "Cons :: nat * real spvec => real spmat => real spmat"} $ x $ xs;
1.192 +
1.193 +fun calcr safe_propagation xlen names prec A b =
1.194 + let
1.195 + fun test_1 (lower, upper) =
1.196 + if lower = upper then
1.197 + (if Float.eq (lower, (~1, 0)) then ~1
1.198 + else if Float.eq (lower, (1, 0)) then 1
1.199 + else 0)
1.200 + else 0
1.201 +
1.202 + fun calcr (row_index, a) g =
1.203 + let
1.204 + val b = FloatSparseMatrixBuilder.v_elem_at b row_index
1.205 + val (_, b2) = FloatArith.approx_decstr_by_bin prec (case b of NONE => "0" | SOME b => b)
1.206 + val approx_a = FloatSparseMatrixBuilder.v_fold (fn (i, s) => fn l =>
1.207 + (i, FloatArith.approx_decstr_by_bin prec s)::l) a []
1.208 +
1.209 + fun fold_dest_nodes (dest_index, dest_value) g =
1.210 + let
1.211 + val dest_test = test_1 dest_value
1.212 + in
1.213 + if dest_test = 0 then
1.214 + g
1.215 + else let
1.216 + val (dest_key as (_, dest_btype), row_bound) =
1.217 + if dest_test = ~1 then
1.218 + ((dest_index, LOWER), Float.neg b2)
1.219 + else
1.220 + ((dest_index, UPPER), b2)
1.221 +
1.222 + fun fold_src_nodes (src_index, src_value as (src_lower, src_upper)) g =
1.223 + if src_index = dest_index then g
1.224 + else
1.225 + let
1.226 + val coeff = case dest_btype of
1.227 + UPPER => (Float.neg src_upper, Float.neg src_lower)
1.228 + | LOWER => src_value
1.229 + in
1.230 + if Float.sign src_lower = LESS then
1.231 + add_edge g (src_index, UPPER) dest_key row_index coeff
1.232 + else
1.233 + add_edge g (src_index, LOWER) dest_key row_index coeff
1.234 + end
1.235 + in
1.236 + fold fold_src_nodes approx_a (add_row_bound g dest_key row_index row_bound)
1.237 + end
1.238 + end
1.239 + in
1.240 + case approx_a of
1.241 + [] => g
1.242 + | [(u, a)] =>
1.243 + let
1.244 + val atest = test_1 a
1.245 + in
1.246 + if atest = ~1 then
1.247 + update_sure_bound g (u, LOWER) (Float.neg b2)
1.248 + else if atest = 1 then
1.249 + update_sure_bound g (u, UPPER) b2
1.250 + else
1.251 + g
1.252 + end
1.253 + | _ => fold fold_dest_nodes approx_a g
1.254 + end
1.255 +
1.256 + val g = FloatSparseMatrixBuilder.m_fold calcr A VarGraph.empty
1.257 +
1.258 + val g = propagate_sure_bounds safe_propagation names g
1.259 +
1.260 + val (r1, r2) = split_graph g
1.261 +
1.262 + fun add_row_entry m index f vname value =
1.263 + let
1.264 + val v = (case value of
1.265 + SOME value => FloatSparseMatrixBuilder.mk_spvec_entry 0 value
1.266 + | NONE => FloatSparseMatrixBuilder.mk_spvec_entry' 0 (f $ (Var ((vname,0), HOLogic.realT))))
1.267 + val vec = cons_spvec v empty_spvec
1.268 + in
1.269 + cons_spmat (FloatSparseMatrixBuilder.mk_spmat_entry index vec) m
1.270 + end
1.271 +
1.272 + fun abs_estimate i r1 r2 =
1.273 + if i = 0 then
1.274 + let val e = empty_spmat in (e, e) end
1.275 + else
1.276 + let
1.277 + val index = xlen-i
1.278 + val (r12_1, r12_2) = abs_estimate (i-1) r1 r2
1.279 + val b1 = Inttab.lookup r1 index
1.280 + val b2 = Inttab.lookup r2 index
1.281 + in
1.282 + (add_row_entry r12_1 index @{term "lbound :: real => real"} ((names index)^"l") b1,
1.283 + add_row_entry r12_2 index @{term "ubound :: real => real"} ((names index)^"u") b2)
1.284 + end
1.285 +
1.286 + val (r1, r2) = abs_estimate xlen r1 r2
1.287 +
1.288 + in
1.289 + (r1, r2)
1.290 + end
1.291 +
1.292 +fun load filename prec safe_propagation =
1.293 + let
1.294 + val prog = Cplex.load_cplexFile filename
1.295 + val prog = Cplex.elim_nonfree_bounds prog
1.296 + val prog = Cplex.relax_strict_ineqs prog
1.297 + val (maximize, c, A, b, (xlen, names, _)) = CplexFloatSparseMatrixConverter.convert_prog prog
1.298 + val (r1, r2) = calcr safe_propagation xlen names prec A b
1.299 + val _ = if maximize then () else raise Load "sorry, cannot handle minimization problems"
1.300 + val (dualprog, indexof) = FloatSparseMatrixBuilder.dual_cplexProg c A b
1.301 + val results = Cplex.solve dualprog
1.302 + val (_, v) = CplexFloatSparseMatrixConverter.convert_results results indexof
1.303 + (*val A = FloatSparseMatrixBuilder.cut_matrix v NONE A*)
1.304 + fun id x = x
1.305 + val v = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 v
1.306 + val b = FloatSparseMatrixBuilder.transpose_matrix (FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 b)
1.307 + val c = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 c
1.308 + val (y1, _) = FloatSparseMatrixBuilder.approx_matrix prec Float.positive_part v
1.309 + val A = FloatSparseMatrixBuilder.approx_matrix prec id A
1.310 + val (_,b2) = FloatSparseMatrixBuilder.approx_matrix prec id b
1.311 + val c = FloatSparseMatrixBuilder.approx_matrix prec id c
1.312 + in
1.313 + (y1, A, b2, c, (r1, r2))
1.314 + end handle CplexFloatSparseMatrixConverter.Converter s => (raise (Load ("Converter: "^s)))
1.315 +
1.316 +end