src/HOL/Tools/function_package/fundef_datatype.ML
author wenzelm
Sat, 06 Oct 2007 16:50:04 +0200
changeset 24867 e5b55d7be9bb
parent 24466 619f78b717cb
child 24920 2a45e400fdad
permissions -rw-r--r--
simplified interfaces for outer syntax;
krauss@19564
     1
(*  Title:      HOL/Tools/function_package/fundef_datatype.ML
krauss@19564
     2
    ID:         $Id$
krauss@19564
     3
    Author:     Alexander Krauss, TU Muenchen
krauss@19564
     4
wenzelm@20344
     5
A package for general recursive function definitions.
krauss@19564
     6
A tactic to prove completeness of datatype patterns.
krauss@19564
     7
*)
krauss@19564
     8
wenzelm@20344
     9
signature FUNDEF_DATATYPE =
sig
  (* Tactic proving completeness of datatype patterns for the given subgoal. *)
  val pat_complete_tac : int -> tactic

  (* Proof method wrapping pat_complete_tac. *)
  val pat_completeness : method

  (* Registers the "pat_completeness" method and the sequential-mode preprocessor. *)
  val setup : theory -> theory
end
krauss@19564
    16
wenzelm@23351
    17
structure FundefDatatype: FUNDEF_DATATYPE =
krauss@19564
    18
struct
krauss@19564
    19
krauss@21051
    20
open FundefLib
krauss@21051
    21
open FundefCommon
krauss@19770
    22
krauss@23189
    23
krauss@23189
    24
(* Extra well-formedness checks applied to each equation in sequential mode.
   Raises an error (via err) for:
     - conditional equations (split_def yields non-empty guards gs),
     - patterns that are not built from datatype constructors,
     - nonlinear patterns (detected by counting, see below).
   Returns () if the equation is acceptable. *)
fun check_pats ctxt geq =
    let
      fun err str = error (cat_lines ["Malformed definition:",
                                      str ^ " not allowed in sequential mode.",
                                      ProofContext.string_of_term ctxt geq])
      val thy = ProofContext.theory_of ctxt

      (* Pattern variables appear as Bound terms and are accepted. Any other
         pattern must have a constant head that DatatypePackage recognises as a
         constructor; its arguments are then checked recursively. A non-Const
         head (Free, Abs, ...) makes dest_Const raise TERM and is rejected with
         the same message as an unknown constant. *)
      fun check_constr_pattern (Bound _) = ()
        | check_constr_pattern t =
          let
            val (hd, args) = strip_comb t
            val known_constr =
                (case DatatypePackage.datatype_of_constr thy (fst (dest_Const hd)) of
                   SOME _ => true
                 | NONE => false)
                handle TERM ("dest_Const", _) => false
          in
            if known_constr then List.app check_constr_pattern args
            else err "Non-constructor patterns"
          end

      val (_, qs, gs, args, _) = split_def ctxt geq
    in
      if null gs then () else err "Conditional equations";
      List.app check_constr_pattern args;
      (* just count occurrences to check linearity: more Bound occurrences in
         the argument patterns than quantified variables qs means some variable
         is repeated.
         NOTE(review): this assumes every variable in qs occurs in args; if one
         does not, a duplicate elsewhere could go undetected. *)
      if fold (fold_aterms (fn Bound _ => curry (op +) 1 | _ => I)) args 0 > length qs
      then err "Nonlinear patterns" else ()
    end
krauss@23203
    55
    
krauss@23189
    56
krauss@19564
    57
(* Fresh numbered free variables: "_av<i>" stand for constructor arguments,
   "_pv<i>" for the corresponding pattern variables. *)
local
  fun numbered_free prefix i T = Free (prefix ^ string_of_int i, T)
in
  fun mk_argvar i T = numbered_free "_av" i T
  fun mk_patvar i T = numbered_free "_pv" i T
end
krauss@19564
    59
krauss@19564
    60
(* Replace the free variable var by inst in thm: generalise over var with
   forall_intr, then instantiate the new quantifier with forall_elim. *)
fun inst_free var inst thm =
    thm
    |> forall_intr var
    |> forall_elim inst
krauss@19564
    62
krauss@19564
    63
krauss@19564
    64
(* Instantiate the two schematic variables of the case rule thm: the one listed
   second by term_vars becomes x, the first becomes P.
   NOTE(review): relies on term_vars returning exactly two variables in the
   order [P-var, x-var]; any other result raises Bind. *)
fun inst_case_thm thy x P thm =
    let
      val [Pv, xv] = term_vars (prop_of thm)
      val cert = cterm_of thy
    in
      cterm_instantiate [(cert xv, cert x), (cert Pv, cert P)] thm
    end
krauss@19564
    70
krauss@19564
    71
krauss@19564
    72
(* For constructor constr, create one argument variable and one pattern
   variable per argument type, numbered consecutively from i.
   Returns (argument vars, pattern vars, next unused index). *)
fun invent_vars constr i =
    let
      val argTs = binder_types (fastype_of constr)
      val next = i + length argTs
      val idxs = i upto (next - 1)
    in
      (map2 mk_argvar idxs argTs, map2 mk_patvar idxs argTs, next)
    end
krauss@19564
    82
krauss@19564
    83
krauss@19564
    84
(* Select the rows relevant to constructor cons. Each row is (patterns, thm).
   - If the first pattern is a Free variable, it is replaced by
     cons applied to pvars, both in the pattern list and (via inst_free) in thm.
   - If the first pattern's head is cons, the row is kept unchanged.
   - Otherwise the row is dropped.
   A row with no patterns left raises Match. Rows are processed left to right. *)
fun filter_pats thy cons pvars pts =
    let
      val inst = list_comb (cons, pvars)

      fun select ([], _) = raise Match
        | select (pat :: pats, thm) =
          (case pat of
             Free _ => SOME (inst :: pats, inst_free (cterm_of thy pat) (cterm_of thy inst) thm)
           | _ => if fst (strip_comb pat) = cons then SOME (pat :: pats, thm) else NONE)
    in
      List.mapPartial select pts
    end
krauss@19564
    94
krauss@19564
    95
krauss@19564
    96
(* The constructors of the datatype T, as Const terms whose type variables are
   instantiated so that each constructor's result type matches T.
   Raises Match if T is not a Type. *)
fun inst_constrs_of thy (T as Type (name, _)) =
    let
      fun instantiate (Cn, CT) =
          Envir.subst_TVars (Sign.typ_match thy (body_type CT, T) Vartab.empty)
                            (Const (Cn, CT))
    in
      map instantiate (the (DatatypePackage.get_datatype_constrs thy name))
    end
  | inst_constrs_of _ _ = raise Match
krauss@19564
   100
krauss@19564
   101
krauss@19564
   102
(* Specialise one row to the constructor case described by c_assum
   (presumably "v = cons avars").
   The row's first pattern is split into its head arguments subps. The equations
   avar_i = subp_i are assumed and used to rewrite c_assum. The result discharges
   the first premise of thm (implies_elim), and the equations are re-introduced
   as premises. Returns (subps @ remaining patterns, new theorem).
   A row with no patterns raises Match. *)
fun transform_pat _ _ _ ([], _) = raise Match
  | transform_pat thy avars c_assum (pat :: pats, thm) =
    let
      val subps = snd (strip_comb pat)
      val eqs = (avars ~~ subps)
                |> map (cterm_of thy o HOLogic.mk_Trueprop o HOLogic.mk_eq)
      val rewritten = simplify (HOL_basic_ss addsimps (map assume eqs)) c_assum
      val discharged = implies_elim thm rewritten
    in
      (subps @ pats, fold_rev implies_intr eqs discharged)
    end
krauss@19564
   113
krauss@19564
   114
krauss@19564
   115
(* Raised when the pattern rows do not cover a case, or when pat_complete_tac
   finds a goal of unexpected shape; pat_complete_tac maps it to tactic failure. *)
exception COMPLETENESS

(* constr_case: handle one constructor cons for the term v in the first column.
   Fresh argument variables avars are invented, v = cons avars is assumed, and
   the rows are specialised to cons (filter_pats, then transform_pat). The
   algorithm recurses on the columns avars @ vs. Afterwards the assumption is
   discharged and avars are generalised. *)
fun constr_case thy P idx (v :: vs) pats cons =
    let
      val (argvars, patvars, next_idx) = invent_vars cons idx
      val hyp = HOLogic.mk_eq (v, list_comb (cons, argvars))
                |> HOLogic.mk_Trueprop
                |> cterm_of thy
      val hyp_thm = assume hyp
      val rows = filter_pats thy cons patvars pats
                 |> map (transform_pat thy argvars hyp_thm)
    in
      o_alg thy P next_idx (argvars @ vs) rows
      |> implies_intr hyp
      |> fold_rev (fn a => forall_intr (cterm_of thy a)) argvars
    end
  | constr_case _ _ _ _ _ _ = raise Match

(* o_alg: prove P from the pattern matrix pts whose columns are the terms vs.
   - No columns left: the first row's theorem is the result.
   - Columns remain but no rows: the patterns are incomplete (COMPLETENESS).
   - Every row starts with a variable: instantiate it with the column term
     and continue with the remaining columns.
   - Otherwise: case-split with the datatype's exhaustion rule and combine
     one constr_case proof per constructor using COMP. *)
and o_alg _ _ _ [] (([], Pthm) :: _) = Pthm
  | o_alg _ _ _ (_ :: _) [] = raise COMPLETENESS
  | o_alg thy P idx (v :: vs) pts =
    let
      fun leads_with_var (pats, _) = is_Free (hd pats)
      fun bind_var (pv :: pats, thm) =
          (pats, refl RS (inst_free (cterm_of thy pv) (cterm_of thy v) thm))
    in
      if forall leads_with_var pts then
        (* Var case *)
        o_alg thy P idx vs (map bind_var pts)
      else
        (* Cons case *)
        let
          val T = fastype_of v
          val (tname, _) = dest_Type T
          val {exhaustion = case_thm, ...} = DatatypePackage.the_datatype thy tname
          val branches = map (constr_case thy P idx (v :: vs) pts) (inst_constrs_of thy T)
        in
          fold (fn branch => fn st => branch COMP st) branches
               (inst_case_thm thy v P case_thm)
        end
    end
  | o_alg _ _ _ _ _ = raise Match
krauss@19564
   147
wenzelm@20344
   148
krauss@19564
   149
(* Prove  [| !!qs1. x = pat1 ==> P; ...; !!qsn. x = patn ==> P |] ==> P
   by running o_alg on the one-column matrix [x] whose rows are the patterns.
   qss and pats are parallel lists (map2 raises on a length mismatch).
   Fresh-variable numbering for o_alg starts at 2. *)
fun prove_completeness thy x P qss pats =
    let
      val cert = cterm_of thy
      val goal = HOLogic.mk_Trueprop P

      (* the hypothesis  !!qs. x = pat ==> P  as a cterm *)
      fun case_hyp qs pat =
          Logic.mk_implies (HOLogic.mk_Trueprop (HOLogic.mk_eq (x, pat)), goal)
          |> fold_rev mk_forall qs
          |> cert

      val hyps = map2 case_hyp qss pats

      (* assume a hypothesis and strip its quantifiers by instantiating with qs *)
      fun open_hyp hyp qs = fold (fn q => forall_elim (cert q)) qs (assume hyp)

      val rows = map2 (fn pat => fn thm => ([pat], thm)) pats (map2 open_hyp hyps qss)
    in
      o_alg thy P 2 [x] rows
      |> fold_rev implies_intr hyps
    end
krauss@19564
   165
krauss@19564
   166
krauss@19564
   167
krauss@19564
   168
(* Tactic proving completeness of datatype patterns for subgoal i of thm.
   The subgoal must have exactly two outermost bound variables, taken as P and x.
   Each premise must have the form  !!qs. _ (_ $ _ $ pat) ==> _, from which the
   pattern pat is extracted.
   NOTE(review): this shape reading is inferred from dest_all_all / pat_of;
   presumably the premises are  !!qs. x = pat ==> P.
   The tactic fails (Seq.empty) rather than raising an exception when:
     - subgoal i does not exist,
     - the subgoal or a premise has a different shape,
     - the patterns are incomplete (COMPLETENESS). *)
fun pat_complete_tac i thm =
    (if i < 1 orelse i > length (prems_of thm) then Seq.empty
     else
       let
         val thy = theory_of_thm thm
         val subgoal = nth (prems_of thm) (i - 1)   (* FIXME SUBGOAL tactical *)

         (* exactly two outer binders expected; anything else is not our goal *)
         val (P, x, subgf) =
             (case dest_all_all subgoal of
                ([P, x], subgf) => (P, x, subgf)
              | _ => raise COMPLETENESS)

         val assums = Logic.strip_imp_prems subgf

         (* extract the quantified variables and the pattern of one premise *)
         fun pat_of assum =
             let
               val (qs, imp) = dest_all_all assum
             in
               (case Logic.dest_implies imp of
                  (_ $ (_ $ _ $ pat), _) => (qs, pat)
                | _ => raise COMPLETENESS)
               handle TERM _ => raise COMPLETENESS
             end

         val (qss, pats) = split_list (map pat_of assums)

         val complete_thm = prove_completeness thy x P qss pats
                              |> forall_intr (cterm_of thy x)
                              |> forall_intr (cterm_of thy P)
       in
         Seq.single (Drule.compose_single (complete_thm, i, thm))
       end)
    handle COMPLETENESS => Seq.empty
krauss@19564
   196
krauss@19564
   197
wenzelm@21588
   198
(* The pat_completeness proof method: pat_complete_tac lifted via SIMPLE_METHOD'. *)
val pat_completeness =
    pat_complete_tac |> Method.SIMPLE_METHOD'
krauss@20523
   199
krauss@20523
   200
(* Finish the pending global proof: pat_completeness is the initial method and
   HOL.auto the terminal method. *)
val by_pat_completeness_simp =
    let
      val initial_method = Method.Basic (K pat_completeness, Position.none)
      val terminal_method = Method.Source_i (Args.src (("HOL.auto", []), Position.none))
    in
      Proof.global_terminal_proof (initial_method, SOME terminal_method)
    end
krauss@20523
   204
krauss@22733
   205
(* Set up the termination proof obligation and discharge it with
   lexicographic_order (no terminal method). *)
val termination_by_lexicographic_order =
    let
      val setup_goal = FundefPackage.setup_termination_proof NONE
      val lex_method = Method.Basic (LexicographicOrder.lexicographic_order [], Position.none)
    in
      setup_goal #> Proof.global_terminal_proof (lex_method, NONE)
    end
krauss@19564
   209
krauss@23203
   210
(* For each fixed function (fname : fT), whose arity n is looked up in arities,
   build the catch-all equation
     !!a1 ... an. fname a1 ... an = HOL.undefined
   where undefined has the type of the remaining result.
   Raises Option if fname has no entry in arities. *)
fun mk_catchall fixes arities =
    let
      fun catchall_eqn ((fname, fT), _) =
          let
            val arity = the (Symtab.lookup arities fname)
            val (argTs, restTs) = chop arity (binder_types fT)
            val resT = restTs ---> body_type fT
            val vars = map Free (Name.invent_list [] "a" arity ~~ argTs)
            val lhs = list_comb (Free (fname, fT), vars)
          in
            HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, Const ("HOL.undefined", resT)))
            |> fold_rev mk_forall vars
          end
    in
      map catchall_eqn fixes
    end
krauss@23203
   228
krauss@24170
   229
(* Append one catch-all equation per fixed function to spec, each tagged with
   flag true. Arities are computed from the equations already in spec. *)
fun add_catchall ctxt fixes spec =
    let
      val arities = mk_arities (map (split_def ctxt o snd) spec)
    in
      spec @ map (fn eqn => (true, eqn)) (mk_catchall fixes arities)
    end
krauss@23203
   235
krauss@24356
   236
(* origs are (flag, equation) pairs. tss is the result of equation splitting;
   its first (length origs) entries correspond to origs.
   A warning is issued for every original equation whose entry is empty
   (presumably because earlier equations overlap it completely).
   Returns tss unchanged. *)
fun warn_if_redundant ctxt origs tss =
    let
      fun warn_about t =
          Output.warning ("Ignoring redundant equation: " ^ quote (ProofContext.string_of_term ctxt t))

      val (orig_results, _) = chop (length origs) tss

      fun report ((_, t), []) = warn_about t
        | report _ = ()
    in
      List.app report (origs ~~ orig_results);
      tss
    end
krauss@24356
   246
krauss@24356
   247
krauss@23203
   248
(* Preprocessor for sequential mode.
   When neither sequential is set nor any per-equation flag is true, it
   delegates to FundefCommon.empty_preproc. Otherwise it:
     1. checks the equations (check_defs, then check_pats),
     2. adds catch-all equations (add_catchall),
     3. splits them (FundefSplit.split_some_equations), warning about
        redundant ones.
   It returns
     - the flattened split equations,
     - restore_spec: regroups theorems per original spec entry,
     - sort: groups items by the index of their function name in fixes. *)
fun sequential_preproc (config as FundefConfig {sequential, ...}) flags ctxt fixes spec =
    if not (sequential orelse exists I flags) then
      FundefCommon.empty_preproc check_defs config flags ctxt fixes spec
    else
      let
        (* in sequential mode every equation is flagged *)
        val eq_flags = if sequential then map (K true) flags else flags
        val (names_attrs, eqss) = split_list spec
        val eqs = map the_single eqss

        val _ = check_defs ctxt fixes eqs          (* Standard checks *)
        val _ = map (check_pats ctxt) eqs          (* More checks for sequential mode *)
        val flagged_eqs = eq_flags ~~ eqs

        val completed = add_catchall ctxt fixes flagged_eqs   (* Completion *)
        val split_eqs = warn_if_redundant ctxt flagged_eqs
                          (FundefSplit.split_some_equations ctxt completed)

        val n_orig = length names_attrs

        fun restore_spec thms =
            names_attrs ~~ Library.take (n_orig, Library.unflat split_eqs thms)

        val fnames = map (fst o fst) fixes
        val indices = flat (Library.take (n_orig, split_eqs))
                      |> map (fn eq => find_index (curry op = (fname_of eq)) fnames)

        fun sort xs =
            (indices ~~ xs)
            |> partition_list (fn i => fn (j, _) => i = j) 0 (length fnames - 1)
            |> map (map snd)
      in
        (flat split_eqs, restore_spec, sort)
      end
krauss@23203
   286
wenzelm@20344
   287
(* Theory setup: register the pat_completeness proof method and install
   sequential_preproc as the function package preprocessor. *)
val setup =
    let
      val pat_method =
          ("pat_completeness", Method.no_args pat_completeness,
           "Completeness prover for datatype patterns")
    in
      Method.add_methods [pat_method]
      #> Context.theory_map (FundefCommon.set_preproc sequential_preproc)
    end
krauss@19564
   291
krauss@20523
   292
krauss@23203
   293
(* Default configuration for the "fun" command: sequential mode enabled,
   default value "%x. arbitrary", no target, no domintros, no tailrec. *)
val fun_config =
    FundefConfig { sequential = true,
                   default = "%x. arbitrary",
                   target = NONE,
                   domintros = false,
                   tailrec = false }
krauss@20523
   295
krauss@20523
   296
krauss@20523
   297
local structure P = OuterParse and K = OuterKeyword in

(* The "fun" command:
     1. define the functions with FundefPackage.add_fundef,
     2. close the resulting proof with pat_completeness followed by auto,
     3. prove termination with lexicographic_order.
   Kept eta-expanded in lthy so that partial application at parse time does
   no work. *)
fun fun_cmd config fixes statements flags lthy =
    FundefPackage.add_fundef fixes statements config flags lthy
    |> by_pat_completeness_simp
    |> termination_by_lexicographic_order

(* Register the outer-syntax command "fun". It is parsed by fundef_parser,
   starting from fun_config, and run as a local-theory transition. *)
val _ =
  let
    fun to_transition ((config, fixes), (flags, statements)) =
        Toplevel.local_theory (target_of config) (fun_cmd config fixes statements flags)
  in
    OuterSyntax.command "fun" "define general recursive functions (short version)" K.thy_decl
      (fundef_parser fun_config >> to_transition)
  end;

end
krauss@20523
   312
wenzelm@20875
   313
end