src/HOL/Matrix_LP/fspmlp.ML
changeset 47859 9f492f5b0cec
parent 47404 faf233c4a404
child 48326 26315a545e26
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Matrix_LP/fspmlp.ML	Sat Mar 17 12:52:40 2012 +0100
     1.3 @@ -0,0 +1,313 @@
     1.4 +(*  Title:      HOL/Matrix/fspmlp.ML
     1.5 +    Author:     Steven Obua
     1.6 +*)
     1.7 +
     1.8 +signature FSPMLP =
     1.9 +sig
    1.10 +    type linprog
    1.11 +    type vector = FloatSparseMatrixBuilder.vector
    1.12 +    type matrix = FloatSparseMatrixBuilder.matrix
    1.13 +
    1.14 +    val y : linprog -> term
    1.15 +    val A : linprog -> term * term
    1.16 +    val b : linprog -> term
    1.17 +    val c : linprog -> term * term
    1.18 +    val r12 : linprog -> term * term
    1.19 +
    1.20 +    exception Load of string
    1.21 +
    1.22 +    val load : string -> int -> bool -> linprog
    1.23 +end
    1.24 +
    1.25 +structure Fspmlp : FSPMLP =
    1.26 +struct
    1.27 +
    1.28 +type vector = FloatSparseMatrixBuilder.vector
    1.29 +type matrix = FloatSparseMatrixBuilder.matrix
    1.30 +
    1.31 +type linprog = term * (term * term) * term * (term * term) * (term * term)
    1.32 +
    1.33 +fun y (c1, _, _, _, _) = c1
    1.34 +fun A (_, c2, _, _, _) = c2
    1.35 +fun b (_, _, c3, _, _) = c3
    1.36 +fun c (_, _, _, c4, _) = c4
    1.37 +fun r12 (_, _, _, _, c6) = c6
    1.38 +
    1.39 +structure CplexFloatSparseMatrixConverter =
    1.40 +MAKE_CPLEX_MATRIX_CONVERTER(structure cplex = Cplex and matrix_builder = FloatSparseMatrixBuilder);
    1.41 +
    1.42 +datatype bound_type = LOWER | UPPER
    1.43 +
    1.44 +fun intbound_ord ((i1: int, b1),(i2,b2)) =
    1.45 +    if i1 < i2 then LESS
    1.46 +    else if i1 = i2 then
    1.47 +        (if b1 = b2 then EQUAL else if b1=LOWER then LESS else GREATER)
    1.48 +    else GREATER
    1.49 +
    1.50 +structure Inttab = Table(type key = int val ord = (rev_order o int_ord));
    1.51 +
    1.52 +structure VarGraph = Table(type key = int*bound_type val ord = intbound_ord);
    1.53 +(* key -> (float option) * (int -> (float * (((float * float) * key) list)))) *)
    1.54 +(* dest_key -> (sure_bound * (row_index -> (row_bound * (((coeff_lower * coeff_upper) * src_key) list)))) *)
    1.55 +
    1.56 +exception Internal of string;
    1.57 +
    1.58 +fun add_row_bound g dest_key row_index row_bound =
    1.59 +    let
    1.60 +        val x =
    1.61 +            case VarGraph.lookup g dest_key of
    1.62 +                NONE => (NONE, Inttab.update (row_index, (row_bound, [])) Inttab.empty)
    1.63 +              | SOME (sure_bound, f) =>
    1.64 +                (sure_bound,
    1.65 +                 case Inttab.lookup f row_index of
    1.66 +                     NONE => Inttab.update (row_index, (row_bound, [])) f
    1.67 +                   | SOME _ => raise (Internal "add_row_bound"))
    1.68 +    in
    1.69 +        VarGraph.update (dest_key, x) g
    1.70 +    end
    1.71 +
    1.72 +fun update_sure_bound g (key as (_, btype)) bound =
    1.73 +    let
    1.74 +        val x =
    1.75 +            case VarGraph.lookup g key of
    1.76 +                NONE => (SOME bound, Inttab.empty)
    1.77 +              | SOME (NONE, f) => (SOME bound, f)
    1.78 +              | SOME (SOME old_bound, f) =>
    1.79 +                (SOME ((case btype of
    1.80 +                            UPPER => Float.min
    1.81 +                          | LOWER => Float.max)
    1.82 +                           old_bound bound), f)
    1.83 +    in
    1.84 +        VarGraph.update (key, x) g
    1.85 +    end
    1.86 +
    1.87 +fun get_sure_bound g key =
    1.88 +    case VarGraph.lookup g key of
    1.89 +        NONE => NONE
    1.90 +      | SOME (sure_bound, _) => sure_bound
    1.91 +
    1.92 +(*fun get_row_bound g key row_index =
    1.93 +    case VarGraph.lookup g key of
    1.94 +        NONE => NONE
    1.95 +      | SOME (sure_bound, f) =>
    1.96 +        (case Inttab.lookup f row_index of
    1.97 +             NONE => NONE
    1.98 +           | SOME (row_bound, _) => (sure_bound, row_bound))*)
    1.99 +
   1.100 +fun add_edge g src_key dest_key row_index coeff =
   1.101 +    case VarGraph.lookup g dest_key of
   1.102 +        NONE => raise (Internal "add_edge: dest_key not found")
   1.103 +      | SOME (sure_bound, f) =>
   1.104 +        (case Inttab.lookup f row_index of
   1.105 +             NONE => raise (Internal "add_edge: row_index not found")
   1.106 +           | SOME (row_bound, sources) =>
   1.107 +             VarGraph.update (dest_key, (sure_bound, Inttab.update (row_index, (row_bound, (coeff, src_key) :: sources)) f)) g)
   1.108 +
   1.109 +fun split_graph g =
   1.110 +  let
   1.111 +    fun split (key, (sure_bound, _)) (r1, r2) = case sure_bound
   1.112 +     of NONE => (r1, r2)
   1.113 +      | SOME bound =>  (case key
   1.114 +         of (u, UPPER) => (r1, Inttab.update (u, bound) r2)
   1.115 +          | (u, LOWER) => (Inttab.update (u, bound) r1, r2))
   1.116 +  in VarGraph.fold split g (Inttab.empty, Inttab.empty) end
   1.117 +
   1.118 +(* If safe is true, termination is guaranteed, but the sure bounds may be not optimal (relative to the algorithm).
   1.119 +   If safe is false, termination is not guaranteed, but on termination the sure bounds are optimal (relative to the algorithm) *)
   1.120 +fun propagate_sure_bounds safe names g =
   1.121 +    let
   1.122 +        (* returns NONE if no new sure bound could be calculated, otherwise the new sure bound is returned *)
   1.123 +        fun calc_sure_bound_from_sources g (key as (_, btype)) =
   1.124 +            let
   1.125 +                fun mult_upper x (lower, upper) =
   1.126 +                    if Float.sign x = LESS then
   1.127 +                        Float.mult x lower
   1.128 +                    else
   1.129 +                        Float.mult x upper
   1.130 +
   1.131 +                fun mult_lower x (lower, upper) =
   1.132 +                    if Float.sign x = LESS then
   1.133 +                        Float.mult x upper
   1.134 +                    else
   1.135 +                        Float.mult x lower
   1.136 +
   1.137 +                val mult_btype = case btype of UPPER => mult_upper | LOWER => mult_lower
   1.138 +
   1.139 +                fun calc_sure_bound (_, (row_bound, sources)) sure_bound =
   1.140 +                    let
   1.141 +                        fun add_src_bound (coeff, src_key) sum =
   1.142 +                            case sum of
   1.143 +                                NONE => NONE
   1.144 +                              | SOME x =>
   1.145 +                                (case get_sure_bound g src_key of
   1.146 +                                     NONE => NONE
   1.147 +                                   | SOME src_sure_bound => SOME (Float.add x (mult_btype src_sure_bound coeff)))
   1.148 +                    in
   1.149 +                        case fold add_src_bound sources (SOME row_bound) of
   1.150 +                            NONE => sure_bound
   1.151 +                          | new_sure_bound as (SOME new_bound) =>
   1.152 +                            (case sure_bound of
   1.153 +                                 NONE => new_sure_bound
   1.154 +                               | SOME old_bound =>
   1.155 +                                 SOME (case btype of
   1.156 +                                           UPPER => Float.min old_bound new_bound
   1.157 +                                         | LOWER => Float.max old_bound new_bound))
   1.158 +                    end
   1.159 +            in
   1.160 +                case VarGraph.lookup g key of
   1.161 +                    NONE => NONE
   1.162 +                  | SOME (sure_bound, f) =>
   1.163 +                    let
   1.164 +                        val x = Inttab.fold calc_sure_bound f sure_bound
   1.165 +                    in
   1.166 +                        if x = sure_bound then NONE else x
   1.167 +                    end
   1.168 +                end
   1.169 +
   1.170 +        fun propagate (key, _) (g, b) =
   1.171 +            case calc_sure_bound_from_sources g key of
   1.172 +                NONE => (g,b)
   1.173 +              | SOME bound => (update_sure_bound g key bound,
   1.174 +                               if safe then
   1.175 +                                   case get_sure_bound g key of
   1.176 +                                       NONE => true
   1.177 +                                     | _ => b
   1.178 +                               else
   1.179 +                                   true)
   1.180 +
   1.181 +        val (g, b) = VarGraph.fold propagate g (g, false)
   1.182 +    in
   1.183 +        if b then propagate_sure_bounds safe names g else g
   1.184 +    end
   1.185 +
   1.186 +exception Load of string;
   1.187 +
   1.188 +val empty_spvec = @{term "Nil :: real spvec"};
   1.189 +fun cons_spvec x xs = @{term "Cons :: nat * real => real spvec => real spvec"} $ x $ xs;
   1.190 +val empty_spmat = @{term "Nil :: real spmat"};
   1.191 +fun cons_spmat x xs = @{term "Cons :: nat * real spvec => real spmat => real spmat"} $ x $ xs;
   1.192 +
   1.193 +fun calcr safe_propagation xlen names prec A b =
   1.194 +    let
   1.195 +        fun test_1 (lower, upper) =
   1.196 +            if lower = upper then
   1.197 +                (if Float.eq (lower, (~1, 0)) then ~1
   1.198 +                 else if Float.eq (lower, (1, 0)) then 1
   1.199 +                 else 0)
   1.200 +            else 0
   1.201 +
   1.202 +        fun calcr (row_index, a) g =
   1.203 +            let
   1.204 +                val b =  FloatSparseMatrixBuilder.v_elem_at b row_index
   1.205 +                val (_, b2) = FloatArith.approx_decstr_by_bin prec (case b of NONE => "0" | SOME b => b)
   1.206 +                val approx_a = FloatSparseMatrixBuilder.v_fold (fn (i, s) => fn l =>
   1.207 +                                                                   (i, FloatArith.approx_decstr_by_bin prec s)::l) a []
   1.208 +
   1.209 +                fun fold_dest_nodes (dest_index, dest_value) g =
   1.210 +                    let
   1.211 +                        val dest_test = test_1 dest_value
   1.212 +                    in
   1.213 +                        if dest_test = 0 then
   1.214 +                            g
   1.215 +                        else let
   1.216 +                                val (dest_key as (_, dest_btype), row_bound) =
   1.217 +                                    if dest_test = ~1 then
   1.218 +                                        ((dest_index, LOWER), Float.neg b2)
   1.219 +                                    else
   1.220 +                                        ((dest_index, UPPER), b2)
   1.221 +
   1.222 +                                fun fold_src_nodes (src_index, src_value as (src_lower, src_upper)) g =
   1.223 +                                    if src_index = dest_index then g
   1.224 +                                    else
   1.225 +                                        let
   1.226 +                                            val coeff = case dest_btype of
   1.227 +                                                            UPPER => (Float.neg src_upper, Float.neg src_lower)
   1.228 +                                                          | LOWER => src_value
   1.229 +                                        in
   1.230 +                                            if Float.sign src_lower = LESS then
   1.231 +                                                add_edge g (src_index, UPPER) dest_key row_index coeff
   1.232 +                                            else
   1.233 +                                                add_edge g (src_index, LOWER) dest_key row_index coeff
   1.234 +                                        end
   1.235 +                            in
   1.236 +                                fold fold_src_nodes approx_a (add_row_bound g dest_key row_index row_bound)
   1.237 +                            end
   1.238 +                    end
   1.239 +            in
   1.240 +                case approx_a of
   1.241 +                    [] => g
   1.242 +                  | [(u, a)] =>
   1.243 +                    let
   1.244 +                        val atest = test_1 a
   1.245 +                    in
   1.246 +                        if atest = ~1 then
   1.247 +                            update_sure_bound g (u, LOWER) (Float.neg b2)
   1.248 +                        else if atest = 1 then
   1.249 +                            update_sure_bound g (u, UPPER) b2
   1.250 +                        else
   1.251 +                            g
   1.252 +                    end
   1.253 +                  | _ => fold fold_dest_nodes approx_a g
   1.254 +            end
   1.255 +
   1.256 +        val g = FloatSparseMatrixBuilder.m_fold calcr A VarGraph.empty
   1.257 +
   1.258 +        val g = propagate_sure_bounds safe_propagation names g
   1.259 +
   1.260 +        val (r1, r2) = split_graph g
   1.261 +
   1.262 +        fun add_row_entry m index f vname value =
   1.263 +            let
   1.264 +                val v = (case value of 
   1.265 +                             SOME value => FloatSparseMatrixBuilder.mk_spvec_entry 0 value
   1.266 +                           | NONE => FloatSparseMatrixBuilder.mk_spvec_entry' 0 (f $ (Var ((vname,0), HOLogic.realT))))
   1.267 +                val vec = cons_spvec v empty_spvec
   1.268 +            in
   1.269 +                cons_spmat (FloatSparseMatrixBuilder.mk_spmat_entry index vec) m
   1.270 +            end
   1.271 +
   1.272 +        fun abs_estimate i r1 r2 =
   1.273 +            if i = 0 then
   1.274 +                let val e = empty_spmat in (e, e) end
   1.275 +            else
   1.276 +                let
   1.277 +                    val index = xlen-i
   1.278 +                    val (r12_1, r12_2) = abs_estimate (i-1) r1 r2
   1.279 +                    val b1 = Inttab.lookup r1 index
   1.280 +                    val b2 = Inttab.lookup r2 index
   1.281 +                in
   1.282 +                    (add_row_entry r12_1 index @{term "lbound :: real => real"} ((names index)^"l") b1, 
   1.283 +                     add_row_entry r12_2 index @{term "ubound :: real => real"} ((names index)^"u") b2)
   1.284 +                end
   1.285 +
   1.286 +        val (r1, r2) = abs_estimate xlen r1 r2
   1.287 +
   1.288 +    in
   1.289 +        (r1, r2)
   1.290 +    end
   1.291 +
   1.292 +fun load filename prec safe_propagation =
   1.293 +    let
   1.294 +        val prog = Cplex.load_cplexFile filename
   1.295 +        val prog = Cplex.elim_nonfree_bounds prog
   1.296 +        val prog = Cplex.relax_strict_ineqs prog
   1.297 +        val (maximize, c, A, b, (xlen, names, _)) = CplexFloatSparseMatrixConverter.convert_prog prog                       
   1.298 +        val (r1, r2) = calcr safe_propagation xlen names prec A b
   1.299 +        val _ = if maximize then () else raise Load "sorry, cannot handle minimization problems"
   1.300 +        val (dualprog, indexof) = FloatSparseMatrixBuilder.dual_cplexProg c A b
   1.301 +        val results = Cplex.solve dualprog
   1.302 +        val (_, v) = CplexFloatSparseMatrixConverter.convert_results results indexof
   1.303 +        (*val A = FloatSparseMatrixBuilder.cut_matrix v NONE A*)
   1.304 +        fun id x = x
   1.305 +        val v = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 v
   1.306 +        val b = FloatSparseMatrixBuilder.transpose_matrix (FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 b)
   1.307 +        val c = FloatSparseMatrixBuilder.set_vector FloatSparseMatrixBuilder.empty_matrix 0 c
   1.308 +        val (y1, _) = FloatSparseMatrixBuilder.approx_matrix prec Float.positive_part v
   1.309 +        val A = FloatSparseMatrixBuilder.approx_matrix prec id A
   1.310 +        val (_,b2) = FloatSparseMatrixBuilder.approx_matrix prec id b
   1.311 +        val c = FloatSparseMatrixBuilder.approx_matrix prec id c
   1.312 +    in
   1.313 +        (y1, A, b2, c, (r1, r2))
   1.314 +    end handle CplexFloatSparseMatrixConverter.Converter s => (raise (Load ("Converter: "^s)))
   1.315 +
   1.316 +end