src/HOL/Tools/function_package/lexicographic_order.ML
author krauss
Tue, 22 May 2007 17:25:26 +0200
changeset 23074 a53cb8ddb052
parent 23056 448827ccd9e9
child 23128 8e0abe0fa80f
permissions -rw-r--r--
some optimizations, cleanup
bulwahn@21131
     1
(*  Title:       HOL/Tools/function_package/lexicographic_order.ML
krauss@21201
     2
    ID:          $Id$
bulwahn@21131
     3
    Author:      Lukas Bulwahn, TU Muenchen
bulwahn@21131
     4
bulwahn@21131
     5
Method for termination proofs with lexicographic orderings.
bulwahn@21131
     6
*)
bulwahn@21131
     7
bulwahn@21131
     8
signature LEXICOGRAPHIC_ORDER =
sig
  (* proof method: searches for a lexicographic combination of measure
     functions that decreases on every recursive call and uses it to
     prove termination *)
  val lexicographic_order : thm list -> Proof.context -> Method.method

  (* exported for use by size-change termination prototype.
     FIXME: provide a common interface later *)
  val mk_base_funs : theory -> typ -> term list

  (* registers the "lexicographic_order" proof method *)
  val setup: theory -> theory
end
bulwahn@21131
    18
bulwahn@21131
    19
structure LexicographicOrder : LEXICOGRAPHIC_ORDER =
struct

(** General stuff **)

(* mk_measures domT mfuns: the term "measures mfuns", where each element
   of mfuns has type domT => nat; the resulting relation has type
   (domT * domT) set *)
fun mk_measures domT mfuns =
    let val list = HOLogic.mk_list (domT --> HOLogic.natT) mfuns
    in
      Const (@{const_name "List.measures"}, fastype_of list --> (HOLogic.mk_setT (HOLogic.mk_prodT (domT, domT)))) $ list
    end

(* del_index n xs: xs without its n-th element (0-based, n >= 0 assumed);
   xs is returned unchanged if it has fewer than n + 1 elements *)
fun del_index n [] = []
  | del_index n (x :: xs) =
    if n > 0 then x :: del_index (n - 1) xs else xs

(* transpose a matrix given as a list of rows of equal length.
   The explicit clause for [] matters: otherwise the general clause would
   match [] and recurse forever building [] :: transpose []. *)
fun transpose [] = []
  | transpose ([] :: _) = []
  | transpose xss = map hd xss :: transpose (map tl xss)


(** Matrix cell datatype **)

(* Outcome of comparing one measure function on one recursive call:
   Less   -- the strict decrease was proved (finished theorem)
   LessEq -- only the non-strict decrease was proved (finished theorem)
   None   -- neither was proved; the thm is the remaining goal state
   False  -- the only remaining subgoal of the <= attempt is False *)
datatype cell = Less of thm | LessEq of thm | None of thm | False of thm;

fun is_Less (Less _) = true
  | is_Less _ = false

fun is_LessEq (LessEq _) = true
  | is_LessEq _ = false

fun thm_of_cell (Less thm) = thm
  | thm_of_cell (LessEq thm) = thm
  | thm_of_cell (False thm) = thm
  | thm_of_cell (None thm) = thm

(* every cell is printed as a 4-character string *)
fun pr_cell (Less _ ) = " <  "
  | pr_cell (LessEq _) = " <= "
  | pr_cell (None _) = " N  "
  | pr_cell (False _) = " F  "


(** Generating Measure Functions **)

(* mk_comp g f: the beta-normal term %x. f (g x) *)
fun mk_comp g f =
    let
      val fT = fastype_of f
      val gT as (Type ("fun", [xT, _])) = fastype_of g
      val comp = Abs ("f", fT, Abs ("g", gT, Abs ("x", xT, Bound 2 $ (Bound 1 $ Bound 0))))
    in
      Envir.beta_norm (comp $ f $ g)
    end

(* Base measures for type T: for a product type, the base measures of
   both components composed with fst / snd; for any other type, the size
   function if T belongs to class size, otherwise none. *)
fun mk_base_funs thy (T as Type("*", [fT, sT])) = (* products *)
      map (mk_comp (Const ("fst", T --> fT))) (mk_base_funs thy fT)
    @ map (mk_comp (Const ("snd", T --> sT))) (mk_base_funs thy sT)

  | mk_base_funs thy T = (* default: size function, if available *)
    if Sorts.of_sort (Sign.classes_of thy) (T, [HOLogic.class_size])
    then [HOLogic.size_const T]
    else []

(* mk_sum_case f1 f2: the term "sum_case f1 f2" of type fT + sT => Q,
   for f1 : fT => Q and f2 : sT => Q *)
fun mk_sum_case f1 f2 =
    let
      val Type ("fun", [fT, Q]) = fastype_of f1
      val Type ("fun", [sT, _]) = fastype_of f2
    in
      Const (@{const_name "Sum_Type.sum_case"}, (fT --> Q) --> (sT --> Q) --> Type("+", [fT, sT]) --> Q) $ f1 $ f2
    end

fun constant_0 T = Abs ("x", T, HOLogic.zero)
fun constant_1 T = Abs ("x", T, HOLogic.Suc_zero)

(* For a (nested) sum type: one measure per leaf summand, which is 1 on
   that summand and 0 on all others. For a non-sum type: the constant 1. *)
fun mk_funorder_funs (Type ("+", [fT, sT])) =
      map (fn m => mk_sum_case m (constant_0 sT)) (mk_funorder_funs fT)
    @ map (fn m => mk_sum_case (constant_0 fT) m) (mk_funorder_funs sT)
  | mk_funorder_funs T = [ constant_1 T ]

(* For a sum type: every combination (via sum_case) of a measure for the
   left summand with a measure for the right summand. *)
fun mk_ext_base_funs thy (Type("+", [fT, sT])) =
    product (mk_ext_base_funs thy fT) (mk_ext_base_funs thy sT)
       |> map (uncurry mk_sum_case)
  | mk_ext_base_funs thy T = mk_base_funs thy T

(* all candidate measure functions for the domain type T *)
fun mk_all_measure_funs thy (T as Type ("+", _)) =
    mk_ext_base_funs thy T @ mk_funorder_funs T
  | mk_all_measure_funs thy T = mk_base_funs thy T


(** Proof attempts to build the matrix **)

(* Split a call goal  !!vars. prems ==> (lhs, rhs) : R  into
   (vars, prems, lhs, rhs). *)
fun dest_term (t : term) =
    let
      val (vars, prop) = FundefLib.dest_all_all t
      val prems = Logic.strip_imp_prems prop
      val (lhs, rhs) = Logic.strip_imp_concl prop
                         |> HOLogic.dest_Trueprop
                         |> HOLogic.dest_mem |> fst
                         |> HOLogic.dest_prod
    in
      (vars, prems, lhs, rhs)
    end

(* Rebuild  !!vars. prems ==> rel lhs rhs  for the relation constant rel. *)
fun mk_goal (vars, prems, lhs, rhs) rel =
    let
      val concl = HOLogic.mk_binrel rel (lhs, rhs) |> HOLogic.mk_Trueprop
    in
      Logic.list_implies (prems, concl)
        |> fold_rev FundefLib.mk_forall vars
    end

(* prove thy solve_tac t: apply solve_tac to the goal t and return the
   first resulting goal state, which may still have open subgoals.
   If solve_tac yields no result at all, the unchanged initial goal state
   is returned, so the caller records an unproved cell instead of the
   whole method failing with an uncaught Option exception. *)
fun prove (thy: theory) solve_tac (t: term) =
    let
      val goal = Goal.init (cterm_of thy t)
    in
      case SINGLE solve_tac goal of
        SOME st => st
      | NONE => goal
    end

(* Classify measure function mfun on one call: try "<" first and fall
   back to "<=" only if the strict goal is not fully proved. *)
fun mk_cell (thy : theory) solve_tac (vars, prems, lhs, rhs) mfun =
    let
      val goals = mk_goal (vars, prems, mfun $ lhs, mfun $ rhs)
      val less_thm = goals "Orderings.ord_class.less" |> prove thy solve_tac
    in
      if Thm.no_prems less_thm then
        Less (Goal.finish less_thm)
      else
        let
          val lesseq_thm = goals "Orderings.ord_class.less_eq" |> prove thy solve_tac
        in
          if Thm.no_prems lesseq_thm then
            LessEq (Goal.finish lesseq_thm)
          else
            if prems_of lesseq_thm = [HOLogic.Trueprop $ HOLogic.false_const] then False lesseq_thm
            else None lesseq_thm
        end
    end


(** Search algorithms **)

(* A column is usable iff every entry is Less or LessEq and at least one
   entry is Less. *)
fun check_col ls = forall (fn c => is_Less c orelse is_LessEq c) ls andalso not (forall (is_LessEq) ls)

(* Remove the rows already decided by column col (those with Less there),
   then delete column col from the remaining rows. *)
fun transform_table table col = table |> filter_out (fn x => is_Less (nth x col)) |> map (del_index col)

(* order refers to the table with column col deleted; shift its indices
   back so they refer to the columns of the table before the deletion *)
fun transform_order col order = map (fn x => if x >= col then x + 1 else x) order

(* simple depth-first search algorithm for the table: repeatedly pick the
   first usable column; NONE if at some point no column is usable *)
fun search_table table =
    case table of
      [] => SOME []
    | _ =>
      let
        val col = find_index (check_col) (transpose table)
      in case col of
           ~1 => NONE
         | _ =>
           let
             val order_opt = (table, col) |-> transform_table |> search_table
           in case order_opt of
                NONE => NONE
              | SOME order => SOME (col :: transform_order col order)
           end
      end

(* find all positions of elements in a list *)
fun find_index_list P =
    let fun find _ [] = []
          | find n (x :: xs) = if P x then n :: find (n + 1) xs else find (n + 1) xs
    in find 0 end

(* simple breadth-first search algorithm for the table. Each node is a
   pair (columns chosen so far, most recent first; remaining table).
   The search gives up (NONE) as soon as a dequeued node has no usable
   column; the other queued nodes are not explored. *)
fun bfs_search_table nodes =
    case nodes of
      [] => sys_error "INTERNAL ERROR IN lexicographic order termination tactic - fun search_table (breadth search finished)"
    | (node :: rnodes) =>
      let
        val (order, table) = node
      in
        case table of
          [] => SOME (foldr (fn (c, order) => c :: transform_order c order) [] (rev order))
        | _ =>
          let
            val cols = find_index_list (check_col) (transpose table)
          in
            case cols of
              [] => NONE
            | _ =>
              let
                val newtables = map (transform_table table) cols
                val neworders = map (fn c => c :: order) cols
                val newnodes = neworders ~~ newtables
              in
                bfs_search_table (rnodes @ newnodes)
              end
          end
      end

fun nsearch_table table = bfs_search_table [([], table)]


(** Proof Reconstruction **)

(* prove row :: cell list -> tactic
   Discharges one call goal against "measures": skips LessEq columns and
   stops at the first Less column. *)
fun prove_row (Less less_thm :: _) =
    (rtac @{thm "measures_less"} 1)
    THEN PRIMITIVE (flip implies_elim less_thm)
  | prove_row (LessEq lesseq_thm :: tail) =
    (rtac @{thm "measures_lesseq"} 1)
    THEN PRIMITIVE (flip implies_elim lesseq_thm)
    THEN prove_row tail
  | prove_row _ = sys_error "lexicographic_order"


(** Error reporting **)

(* debugging aid: print the table *)
fun pr_table table = writeln (cat_lines (map (fn r => concat (map pr_cell r)) table))

(* the remaining goal states of all cells that are neither Less nor LessEq *)
fun pr_unprovable_subgoals table =
    filter_out (fn x => is_Less x orelse is_LessEq x) (flat table)
    |> map ((fn th => Pretty.string_of (Pretty.chunks (Display.pretty_goals (Thm.nprems_of th) th))) o thm_of_cell)

(* Error message showing the table, the call goals, the measure functions
   and the goals that could not be proved. *)
fun no_order_msg table thy goals measure_funs =
    let
      fun pr_fun t i = string_of_int i ^ ") " ^ string_of_cterm (cterm_of thy t)

      fun pr_goal t i =
          let
            val (_, _, lhs, rhs) = dest_term t
            val prterm = string_of_cterm o (cterm_of thy)
          in (* also show prems? *)
               i ^ ") " ^ prterm lhs ^ " '<' " ^ prterm rhs
          end

      (* row labels a, b, c, ...; rows beyond the 26th get non-letter characters *)
      val gc = map (fn i => chr (i + 96)) (1 upto length table)
      (* column labels 1, 2, 3, ... *)
      val mc = 1 upto length measure_funs
      (* each pr_cell string is 4 characters wide, so pad every column
         header to 4 characters to keep header and rows aligned *)
      fun pr_col_header i = StringCvt.padRight #" " 4 (" " ^ string_of_int i)
      val tstr = "   " ^ concat (map pr_col_header mc)
                 :: map2 (fn r => fn i => i ^ ": " ^ concat (map pr_cell r)) table gc
      val gstr = "Goals:" :: map2 pr_goal goals gc
      val mstr = "Measures:" :: map2 pr_fun measure_funs mc
      val ustr = "Unfinished subgoals:" :: pr_unprovable_subgoals table
    in
      cat_lines ("Could not find lexicographic termination order:" :: tstr @ gstr @ mstr @ ustr)
    end


(** The Main Function **)
fun lexicographic_order_tac ctxt solve_tac (st: thm) =
    let
      val thy = theory_of_thm st
      (* first premise has the shape  Trueprop (wf rel)  with rel still to
         be instantiated; the remaining premises are the call goals *)
      val ((_ $ (_ $ rel)) :: call_goals) = prems_of st

      val (domT, _) = HOLogic.dest_prodT (HOLogic.dest_setT (fastype_of rel))

      val measure_funs = mk_all_measure_funs thy domT (* 1: generate measures *)

      (* 2: create table *)
      val table = map (fn t => map (mk_cell thy solve_tac (dest_term t)) measure_funs) call_goals

      val order = the (search_table table) (* 3: search table *)
          handle Option => error (no_order_msg table thy call_goals measure_funs)

      val clean_table = map (fn x => map (nth x) order) table

      val relation = mk_measures domT (map (nth measure_funs) order)
      val _ = writeln ("Found termination order: " ^ quote (ProofContext.string_of_term ctxt relation))

    in (* 4: proof reconstruction *)
      st |> (PRIMITIVE (cterm_instantiate [(cterm_of thy rel, cterm_of thy relation)])
              THEN rtac @{thm "wf_measures"} 1
              THEN EVERY (map prove_row clean_table))
    end

fun lexicographic_order thms ctxt = Method.SIMPLE_METHOD (FundefCommon.apply_termination_rule ctxt 1
                                                         THEN lexicographic_order_tac ctxt (auto_tac (local_clasimpset_of ctxt)))

val setup = Method.add_methods [("lexicographic_order", Method.bang_sectioned_args clasimp_modifiers lexicographic_order,
                                 "termination prover for lexicographic orderings")]

end