src/HOL/Tools/function_package/lexicographic_order.ML
author wenzelm
Tue, 09 Oct 2007 00:20:13 +0200
changeset 24920 2a45e400fdad
parent 24576 32ddd902b0ad
child 24961 5298ee9c3fe5
permissions -rw-r--r--
generic Syntax.pretty/string_of operations;
bulwahn@21131
     1
(*  Title:       HOL/Tools/function_package/lexicographic_order.ML
krauss@21201
     2
    ID:          $Id$
bulwahn@21131
     3
    Author:      Lukas Bulwahn, TU Muenchen
bulwahn@21131
     4
bulwahn@21131
     5
Method for termination proofs with lexicographic orderings.
bulwahn@21131
     6
*)
bulwahn@21131
     7
bulwahn@21131
     8
signature LEXICOGRAPHIC_ORDER =
sig
  (* the "lexicographic_order" proof method: applies the termination rule of
     the function package and then searches for a lexicographic combination
     of measure functions; the thm list argument is ignored by the current
     implementation (clasimp modifiers reach it through the context) *)
  val lexicographic_order : thm list -> Proof.context -> Method.method

  (* exported for use by size-change termination prototype.
     FIXME: provide a common interface later *)
  val mk_base_funs : theory -> typ -> term list

  (* registers the "lexicographic_order" method in the theory *)
  val setup: theory -> theory
end
bulwahn@21131
    18
bulwahn@21131
    19
structure LexicographicOrder : LEXICOGRAPHIC_ORDER =
struct

(** General stuff **)

(* mk_measures domT [f1, ..., fn] builds the relation term
     mlex_prod f1 (mlex_prod f2 ... (mlex_prod fn {}))
   of type (domT * domT) set; the empty list yields {}. *)
fun mk_measures domT mfuns =
    let
        val relT = HOLogic.mk_setT (HOLogic.mk_prodT (domT, domT))
        val mlexT = (domT --> HOLogic.natT) --> relT --> relT
        fun mk_ms [] = Const (@{const_name "{}"}, relT)
          | mk_ms (f::fs) =
            Const (@{const_name "Wellfounded_Relations.mlex_prod"}, mlexT) $ f $ mk_ms fs
    in
        mk_ms mfuns
    end

(* Remove the element at 0-based position n; a list without such a
   position is returned unchanged. *)
fun del_index n [] = []
  | del_index n (x :: xs) =
    if n > 0 then x :: del_index (n - 1) xs else xs

(* Transpose a list of equally long rows into a list of columns.
   An empty list of rows has no columns: without the first clause,
   transpose [] would recurse forever via [] :: transpose []. *)
fun transpose [] = []
  | transpose ([]::_) = []
  | transpose xss = map hd xss :: transpose (map tl xss)

(** Matrix cell datatype **)

(* One cell of the result matrix, for one recursive call (row) and one
   measure function (column), as built by mk_cell:
     Less thm                -- the "<" goal was proved (finished theorem)
     LessEq (thm, st_less)   -- "<=" proved (thm); st_less is the open "<" state
     None (st_leq, st_less)  -- neither proved; open "<=" and "<" goal states
     False st_leq            -- the "<=" goal was reduced to False *)
datatype cell = Less of thm| LessEq of (thm * thm) | None of (thm * thm) | False of thm;

fun is_Less (Less _) = true
  | is_Less _ = false

fun is_LessEq (LessEq _) = true
  | is_LessEq _ = false

(* the first theorem / goal state stored in a cell *)
fun thm_of_cell (Less thm) = thm
  | thm_of_cell (LessEq (thm, _)) = thm
  | thm_of_cell (False thm) = thm
  | thm_of_cell (None (thm, _)) = thm

(* three-character rendering of a cell for the printed result matrix *)
fun pr_cell (Less _ ) = " < "
  | pr_cell (LessEq _) = " <="
  | pr_cell (None _) = " ? "
  | pr_cell (False _) = " F "


(** Generating Measure Functions **)

(* mk_comp g f builds the beta-normalized term %x. f (g x) *)
fun mk_comp g f =
    let
      val fT = fastype_of f
      val gT as (Type ("fun", [xT, _])) = fastype_of g
      val comp = Abs ("f", fT, Abs ("g", gT, Abs ("x", xT, Bound 2 $ (Bound 1 $ Bound 0))))
    in
      Envir.beta_norm (comp $ f $ g)
    end

(* Candidate measures for a type: for a product, the measures of each
   component composed with fst / snd; otherwise the size function when T
   belongs to class size, and no measure at all when it does not. *)
fun mk_base_funs thy (T as Type("*", [fT, sT])) = (* products *)
      map (mk_comp (Const ("fst", T --> fT))) (mk_base_funs thy fT)
    @ map (mk_comp (Const ("snd", T --> sT))) (mk_base_funs thy sT)

  | mk_base_funs thy T = (* default: size function, if available *)
    if Sorts.of_sort (Sign.classes_of thy) (T, [HOLogic.class_size])
    then [HOLogic.size_const T]
    else []

(* sum_case f1 f2 : fT + sT => Q, for f1 : fT => Q and f2 : sT => Q *)
fun mk_sum_case f1 f2 =
    let
      val Type ("fun", [fT, Q]) = fastype_of f1
      val Type ("fun", [sT, _]) = fastype_of f2
    in
      Const (@{const_name "Sum_Type.sum_case"}, (fT --> Q) --> (sT --> Q) --> Type("+", [fT, sT]) --> Q) $ f1 $ f2
    end

fun constant_0 T = Abs ("x", T, HOLogic.zero)
fun constant_1 T = Abs ("x", T, HOLogic.Suc_zero)

(* One measure per leaf summand of a (nested) sum type: 1 on that
   summand and 0 on all others. *)
fun mk_funorder_funs (Type ("+", [fT, sT])) =
      map (fn m => mk_sum_case m (constant_0 sT)) (mk_funorder_funs fT)
    @ map (fn m => mk_sum_case (constant_0 fT) m) (mk_funorder_funs sT)
  | mk_funorder_funs T = [ constant_1 T ]

(* For sum types, every combination of component measures joined with
   sum_case; otherwise the base measures. *)
fun mk_ext_base_funs thy (Type("+", [fT, sT])) =
    product (mk_ext_base_funs thy fT) (mk_ext_base_funs thy sT)
       |> map (uncurry mk_sum_case)
  | mk_ext_base_funs thy T = mk_base_funs thy T

(* all candidate measure functions for the domain type of the relation *)
fun mk_all_measure_funs thy (T as Type ("+", _)) =
    mk_ext_base_funs thy T @ mk_funorder_funs T
  | mk_all_measure_funs thy T = mk_base_funs thy T



(** Proof attempts to build the matrix **)

(* Split a call premise  !!vars. prems ==> (lhs, rhs) : R
   into (vars, prems, lhs, rhs). *)
fun dest_term (t : term) =
    let
      val (vars, prop) = FundefLib.dest_all_all t
      val prems = Logic.strip_imp_prems prop
      val (lhs, rhs) = Logic.strip_imp_concl prop
                         |> HOLogic.dest_Trueprop
                         |> HOLogic.dest_mem |> fst
                         |> HOLogic.dest_prod
    in
      (vars, prems, lhs, rhs)
    end

(* Build the goal  !!vars. prems ==> rel lhs rhs. *)
fun mk_goal (vars, prems, lhs, rhs) rel =
    let
      val concl = HOLogic.mk_binrel rel (lhs, rhs) |> HOLogic.mk_Trueprop
    in
      Logic.list_implies (prems, concl)
        |> fold_rev FundefLib.mk_forall vars
    end

(* Run solve_tac on the goal t and return the resulting goal state, which
   may still contain open subgoals. If solve_tac fails outright, the
   initial goal state is returned unchanged, so the caller records the cell
   as unprovable instead of an Option exception escaping the tactic. *)
fun prove thy solve_tac t =
    let
      val goal = cterm_of thy t |> Goal.init
    in
      case SINGLE solve_tac goal of
        SOME st => st
      | NONE => goal
    end

(* Classify one (call, measure) pair: try "<" first, then "<=". *)
fun mk_cell (thy : theory) solve_tac (vars, prems, lhs, rhs) mfun =
    let
      val goals = mk_goal (vars, prems, mfun $ lhs, mfun $ rhs)
      val less_thm = goals @{const_name HOL.less} |> prove thy solve_tac
    in
      if Thm.no_prems less_thm then
        Less (Goal.finish less_thm)
      else
        let
          val lesseq_thm = goals @{const_name HOL.less_eq} |> prove thy solve_tac
        in
          if Thm.no_prems lesseq_thm then
            LessEq (Goal.finish lesseq_thm, less_thm)
          else
            if prems_of lesseq_thm = [HOLogic.Trueprop $ HOLogic.false_const] then False lesseq_thm
            else None (lesseq_thm, less_thm)
        end
    end


(** Search algorithms **)

(* a column can be chosen next if every entry is Less or LessEq and at
   least one entry is Less *)
fun check_col ls = forall (fn c => is_Less c orelse is_LessEq c) ls andalso not (forall (is_LessEq) ls)

(* drop the rows that are strictly decreasing in column col, then drop
   the column itself *)
fun transform_table table col = table |> filter_out (fn x => is_Less (nth x col)) |> map (del_index col)

(* shift column indices of a sub-table back to indices of the table that
   still contains column col *)
fun transform_order col order = map (fn x => if x >= col then x + 1 else x) order

(* Greedy search: pick the first admissible column, remove the rows it
   decreases, and recurse. Returns the chosen column indices in order, or
   NONE when some remaining sub-table has no admissible column. *)
fun search_table table =
    case table of
      [] => SOME []
    | _ =>
      let
        val col = find_index (check_col) (transpose table)
      in case col of
           ~1 => NONE
         | _ =>
           let
             val order_opt = (table, col) |-> transform_table |> search_table
           in case order_opt of
                NONE => NONE
              | SOME order =>SOME (col :: transform_order col order)
           end
      end

(* find all positions of elements in a list *)
fun find_index_list P =
    let fun find _ [] = []
          | find n (x :: xs) = if P x then n :: find (n + 1) xs else find (n + 1) xs
    in find 0 end

(* Simple breadth-first search over the table. Not referenced by the main
   function of this structure, which uses search_table. *)
fun bfs_search_table nodes =
    case nodes of
      [] => sys_error "INTERNAL ERROR IN lexicographic order termination tactic - fun search_table (breadth search finished)"
    | (node::rnodes) => let
        val (order, table) = node
      in
        case table of
          [] => SOME (foldr (fn (c, order) => c :: transform_order c order) [] (rev order))
        | _ => let
            val cols = find_index_list (check_col) (transpose table)
          in
            case cols of
              [] => NONE
            | _ => let
              val newtables = map (transform_table table) cols
              val neworders = map (fn c => c :: order) cols
              val newnodes = neworders ~~ newtables
            in
              bfs_search_table (rnodes @ newnodes)
            end
          end
      end

fun nsearch_table table = bfs_search_table [([], table)]

(** Proof Reconstruction **)

(* prove row :: cell list -> tactic
   A row (one call, columns in search order) is discharged by mlex_leq for
   each leading LessEq cell and mlex_less for the first Less cell. *)
fun prove_row (Less less_thm :: _) =
    (rtac @{thm "mlex_less"} 1)
    THEN PRIMITIVE (flip implies_elim less_thm)
  | prove_row (LessEq (lesseq_thm, _) :: tail) =
    (rtac @{thm "mlex_leq"} 1)
    THEN PRIMITIVE (flip implies_elim lesseq_thm)
    THEN prove_row tail
  | prove_row _ = sys_error "lexicographic_order"


(** Error reporting **)

fun pr_table table = writeln (cat_lines (map (fn r => concat (map pr_cell r)) table))

fun pr_goals ctxt st =
    Display.pretty_goals_aux (ProofContext.pp ctxt) Markup.none (true, false) (Thm.nprems_of st) st
     |> Pretty.chunks
     |> Pretty.string_of

fun row_index i = chr (i + 97)
fun col_index j = string_of_int (j + 1)

(* Report the open goal states of one cell; proved (Less) cells contribute
   nothing. The component order of None and the "<=" state held by False
   follow the constructors used in mk_cell. *)
fun pr_unprovable_cell _ (_, Less _) = []
  | pr_unprovable_cell ctxt ((i,j), LessEq (_, st_less)) =
      ["(" ^ row_index i ^ ", " ^ col_index j ^ ", <):\n" ^ pr_goals ctxt st_less]
  | pr_unprovable_cell ctxt ((i,j), None (st_leq, st_less)) =
      ["(" ^ row_index i ^ ", " ^ col_index j ^ ", <):\n" ^ pr_goals ctxt st_less
       ^ "\n(" ^ row_index i ^ ", " ^ col_index j ^ ", <=):\n" ^ pr_goals ctxt st_leq]
  | pr_unprovable_cell ctxt ((i,j), False st_leq) =
      ["(" ^ row_index i ^ ", " ^ col_index j ^ ", <=):\n" ^ pr_goals ctxt st_leq]

(* one report entry per unproved cell, row-major *)
fun pr_unprovable_subgoals ctxt table =
    table
     |> map_index (fn (i,cs) => map_index (fn (j,x) => ((i,j), x)) cs)
     |> flat
     |> map (pr_unprovable_cell ctxt)
     |> flat

(* Build the failure message: unfinished subgoals, the calls (rows), the
   measures (columns) and the result matrix. *)
fun no_order_msg ctxt table calls measure_funs =
    let
      val prterm = Syntax.string_of_term ctxt
      fun pr_fun t i = string_of_int i ^ ") " ^ prterm t

      fun pr_goal t i =
          let
            val (_, _, lhs, rhs) = dest_term t
          in (* also show prems? *)
               i ^ ") " ^ prterm rhs ^ " ~> " ^ prterm lhs
          end

      val gc = map (fn i => chr (i + 96)) (1 upto length table)
      val mc = 1 upto length measure_funs
      val tstr = "Result matrix:" ::  "   " ^ concat (map (enclose " " " " o string_of_int) mc)
                 :: map2 (fn r => fn i => i ^ ": " ^ concat (map pr_cell r)) table gc
      val gstr = "Calls:" :: map2 (prefix "  " oo pr_goal) calls gc
      val mstr = "Measures:" :: map2 (prefix "  " oo pr_fun) measure_funs mc
      val ustr = "Unfinished subgoals:" :: pr_unprovable_subgoals ctxt table
    in
      cat_lines (ustr @ gstr @ mstr @ tstr @ ["", "Could not find lexicographic termination order."])
    end

(** The Main Function **)

(* Expects st with first premise  wf ?R  and one premise per recursive
   call; instantiates ?R with a lexicographic measure and proves all. *)
fun lexicographic_order_tac ctxt solve_tac (st: thm) =
    let
      val thy = theory_of_thm st
      val ((_ $ (_ $ rel)) :: calls) = prems_of st

      val (domT, _) = HOLogic.dest_prodT (HOLogic.dest_setT (fastype_of rel))

      val measure_funs = mk_all_measure_funs thy domT (* 1: generate measures *)

      (* 2: create table *)
      val table = map (fn t => map (mk_cell thy solve_tac (dest_term t)) measure_funs) calls

      val order = the (search_table table) (* 3: search table *)
          handle Option => error (no_order_msg ctxt table calls measure_funs)

      val clean_table = map (fn x => map (nth x) order) table

      val relation = mk_measures domT (map (nth measure_funs) order)
      val _ = writeln ("Found termination order: " ^ quote (Syntax.string_of_term ctxt relation))

    in (* 4: proof reconstruction *)
      st |> (PRIMITIVE (cterm_instantiate [(cterm_of thy rel, cterm_of thy relation)])
              THEN (REPEAT (rtac @{thm "wf_mlex"} 1))
              THEN (rtac @{thm "wf_empty"} 1)
              THEN EVERY (map prove_row clean_table))
    end

fun lexicographic_order thms ctxt = Method.SIMPLE_METHOD (FundefCommon.apply_termination_rule ctxt 1
                                                         THEN lexicographic_order_tac ctxt (auto_tac (local_clasimpset_of ctxt)))

val setup = Method.add_methods [("lexicographic_order", Method.bang_sectioned_args clasimp_modifiers lexicographic_order,
                                 "termination prover for lexicographic orderings")]

end