src/ZF/Tools/primrec_package.ML
author wenzelm
Thu, 09 Jun 2011 16:34:49 +0200
changeset 44206 2b47822868e4
parent 41558 65631ca437c9
child 45112 7943b69f0188
permissions -rw-r--r--
discontinued Name.variant to emphasize that this is old-style / indirect;
paulson@6050
     1
(*  Title:      ZF/Tools/primrec_package.ML
wenzelm@29266
     2
    Author:     Norbert Voelker, FernUni Hagen
wenzelm@29266
     3
    Author:     Stefan Berghofer, TU Muenchen
wenzelm@29266
     4
    Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
paulson@6050
     5
wenzelm@29266
     6
Package for defining functions on datatypes by primitive recursion.
paulson@6050
     7
*)
paulson@6050
     8
paulson@6050
     9
(*Interface of the ZF primitive recursion package.
  add_primrec   -- equations supplied as strings, with attribute sources;
  add_primrec_i -- equations supplied as terms, with attributes.
  Each takes a list of named equations and the current theory, and returns
  the extended theory together with the list of proved equations.*)
signature PRIMREC_PACKAGE =
sig
  val add_primrec:
    ((binding * string) * Attrib.src list) list ->
      theory -> theory * thm list
  val add_primrec_i:
    ((binding * term) * attribute list) list ->
      theory -> theory * thm list
end;
paulson@6050
    14
paulson@6050
    15
structure PrimrecPackage : PRIMREC_PACKAGE =
struct

(*Raised while checking a single equation; process_eqn catches it and
  reports it via primrec_eq_err together with the offending equation.*)
exception RecError of string;

(*Remove outer Trueprop and equality sign*)
val dest_eqn = FOLogic.dest_eq o FOLogic.dest_Trueprop;

(*Abort with message s, prefixed by a primrec-specific header.*)
fun primrec_err s = error ("Primrec definition error:\n" ^ s);

(*Abort with message s followed by equation eq, printed in theory sign.*)
fun primrec_eq_err sign s eq =
  primrec_err (s ^ "\nin equation\n" ^ Syntax.string_of_term_global sign eq);


(* preprocessing of equations *)

(*Check one equation eq and merge it into rec_fn_opt, which records the
  equations already noted for this function:
    NONE for the first equation processed, otherwise
    SOME (fname, ftype, ls, rs, con_info, eqns)
  where ls / rs are the free arguments to the left / right of the
  constructor pattern, and eqns associates each constructor name seen so
  far with (rhs, pattern variables, original equation).
  Every failed check is reported through primrec_eq_err, which raises.*)
fun process_eqn thy (eq, rec_fn_opt) =
  let
    (*Schematic variables are rejected before the equation is split.*)
    val (lhs, rhs) =
        if null (Term.add_vars eq []) then
            dest_eqn eq handle TERM _ => raise RecError "not a proper equation"
        else raise RecError "illegal schematic variable(s)";

    val (recfun, args) = strip_comb lhs;
    val (fname, ftype) = dest_Const recfun handle TERM _ =>
      raise RecError "function is not declared as constant in theory";

    (*Leading Frees, then everything up to the trailing Frees; the middle
      part is expected to be exactly one constructor pattern.*)
    val (ls_frees, rest)  = take_prefix is_Free args;
    val (middle, rs_frees) = take_suffix is_Free rest;

    val (constr, cargs_frees) =
      if null middle then raise RecError "constructor missing"
      else strip_comb (hd middle);
    val (cname, _) = dest_Const constr
      handle TERM _ => raise RecError "ill-formed constructor";
    val con_info = the (Symtab.lookup (ConstructorsData.get thy) cname)
      handle Option =>
      raise RecError "cannot determine datatype associated with function"

    val (ls, cargs, rs) = (map dest_Free ls_frees,
                           map dest_Free cargs_frees,
                           map dest_Free rs_frees)
      handle TERM _ => raise RecError "illegal argument in pattern";
    val lfrees = ls @ rs @ cargs;

    (*Constructor name, paired with the rhs of the equation, the pattern
      variables and the full original equation.*)
    val new_eqn = (cname, (rhs, cargs, eq))

  in
    if has_duplicates (op =) lfrees then
      raise RecError "repeated variable name in pattern"
    else if not (subset (op =) (Term.add_frees rhs [], lfrees)) then
      raise RecError "extra variables on rhs"
    else if length middle > 1 then
      raise RecError "more than one non-variable in pattern"
    else case rec_fn_opt of
        NONE => SOME (fname, ftype, ls, rs, con_info, [new_eqn])
      | SOME (fname', _, ls', rs', con_info': constructor_info, eqns) =>
          if AList.defined (op =) eqns cname then
            raise RecError "constructor already occurred as pattern"
          else if (ls <> ls') orelse (rs <> rs') then
            raise RecError "non-recursive arguments are inconsistent"
          else if #big_rec_name con_info <> #big_rec_name con_info' then
             raise RecError ("Mixed datatypes for function " ^ fname)
          else if fname <> fname' then
             raise RecError ("inconsistent functions for datatype " ^
                             #big_rec_name con_info)
          else SOME (fname, ftype, ls, rs, con_info, new_eqn::eqns)
  end
  handle RecError s => primrec_eq_err thy s eq;


(*Instantiates a recursor equation with constructor arguments:
  the lhs's last argument is a constructor application, and its arguments
  are replaced in the rhs by the pattern variables cargs'.*)
fun inst_recursor ((_ $ constr, rhs), cargs') =
    subst_atomic (#2 (strip_comb constr) ~~ map Free cargs') rhs;


(*Convert a list of recursion equations into a recursor call.
  Returns (definition name, definition term), where the name is
  <fname>_<datatype>_def (base names) and the term states
    fname(ls, x, rs) == recursor(cases)(x)
  for a fresh variable x.*)
fun process_fun thy (fname, ftype, ls, rs, con_info: constructor_info, eqns) =
  let
    val fconst = Const(fname, ftype)
    (*fname applied to ls, a loose Bound 0 for the recursive argument, and rs*)
    val fabs = list_comb (fconst, map Free ls @ [Bound 0] @ map Free rs)
    and {big_rec_name, constructors, rec_rewrites, ...} = con_info

    (*Replace X_rec(args,t) by fname(ls,t,rs) *)
    fun use_fabs (_ $ t) = subst_bound (t, fabs)
      | use_fabs t       = t

    val cnames         = map (#1 o dest_Const) constructors
    and recursor_pairs = map (dest_eqn o concl_of) rec_rewrites

    (*Abstract body over a Free variable, or over an arbitrary term
      (a recursive call) using the bound name "rec".*)
    fun absterm (Free(a,T), body) = absfree (a,T,body)
      | absterm (t,body) = Abs("rec", Ind_Syntax.iT, abstract_over (t, body))

    (*Translate rec equations into function arguments suitable for recursor.
      Missing cases are replaced by 0 and all cases are put into order.
      A case whose abstraction still mentions fname is rejected.*)
    fun add_case ((cname, recursor_pair), cases) =
      let val (rhs, recursor_rhs, eq) =
            case AList.lookup (op =) eqns cname of
                NONE => (warning ("no equation for constructor " ^ cname ^
                                  "\nin definition of function " ^ fname);
                         (Const (@{const_name zero}, Ind_Syntax.iT),
                          #2 recursor_pair, Const (@{const_name zero}, Ind_Syntax.iT)))
              | SOME (rhs, cargs', eq) =>
                    (rhs, inst_recursor (recursor_pair, cargs'), eq)
          val allowed_terms = map use_fabs (#2 (strip_comb recursor_rhs))
          val abs = List.foldr absterm rhs allowed_terms
      in
          if !Ind_Syntax.trace then
              writeln ("recursor_rhs = " ^
                       Syntax.string_of_term_global thy recursor_rhs ^
                       "\nabs = " ^ Syntax.string_of_term_global thy abs)
          else();
          if Logic.occs (fconst, abs) then
              primrec_eq_err thy
                   ("illegal recursive occurrences of " ^ fname)
                   eq
          else abs :: cases
      end

    val recursor = head_of (#1 (hd recursor_pairs))

    (** make definition **)

    (*the recursive argument: a name derived from the datatype's base name,
      chosen to be distinct from the names in ls and rs*)
    val rec_arg =
      Free (singleton (Name.variant_list (map #1 (ls@rs))) (Long_Name.base_name big_rec_name),
        Ind_Syntax.iT)

    val def_tm = Logic.mk_equals
                    (subst_bound (rec_arg, fabs),
                     list_comb (recursor,
                                List.foldr add_case [] (cnames ~~ recursor_pairs))
                     $ rec_arg)

  in
      if !Ind_Syntax.trace then
            writeln ("primrec def:\n" ^
                     Syntax.string_of_term_global thy def_tm)
      else();
      (Long_Name.base_name fname ^ "_" ^ Long_Name.base_name big_rec_name ^ "_def",
       def_tm)
  end;


(* prepare functions needed for definitions *)

(*Define a primitive recursive function from the named equations in args.
  The definition is added under the path of fname's base name; each
  equation is then proved by rewriting with the definition and the
  datatype's recursor rules, stored under its given name and attributes,
  and all of them are stored again as "simps" with the simp attribute.
  Returns the resulting theory and the stored equation theorems.
  An empty equation list is reported as a primrec error.*)
fun add_primrec_i args thy =
  let
    val ((eqn_names, eqn_terms), eqn_atts) = apfst split_list (split_list args);
    (*Explicit case instead of "val SOME ... =": with no equations the fold
      yields NONE, which previously escaped as an anonymous Bind exception.*)
    val (fname, ftype, ls, rs, con_info, eqns) =
      (case List.foldr (process_eqn thy) NONE eqn_terms of
        SOME rec_fn => rec_fn
      | NONE => primrec_err "no equations given");
    val def = process_fun thy (fname, ftype, ls, rs, con_info, eqns);

    val ([def_thm], thy1) = thy
      |> Sign.add_path (Long_Name.base_name fname)
      |> Global_Theory.add_defs false [Thm.no_attributes (apfst Binding.name def)];

    val rewrites = def_thm :: map mk_meta_eq (#rec_rewrites con_info)
    val eqn_thms =
      eqn_terms |> map (fn t =>
        Goal.prove_global thy1 [] [] (Ind_Syntax.traceIt "next primrec equation = " thy1 t)
          (fn _ => EVERY [rewrite_goals_tac rewrites, rtac @{thm refl} 1]));

    val (eqn_thms', thy2) =
      thy1
      |> Global_Theory.add_thms ((eqn_names ~~ eqn_thms) ~~ eqn_atts);
    val (_, thy3) =
      thy2
      |> Global_Theory.add_thmss [((Binding.name "simps", eqn_thms'), [Simplifier.simp_add])]
      ||> Sign.parent_path;
  in (thy3, eqn_thms') end;

(*String front end of add_primrec_i: each equation is read as a proposition
  of thy and each attribute source is interpreted in thy.*)
fun add_primrec args thy =
  add_primrec_i (map (fn ((name, s), srcs) =>
    ((name, Syntax.read_prop_global thy s), map (Attrib.attribute thy) srcs))
    args) thy;


(* outer syntax *)

(*The "primrec" theory command: one or more equations, each optionally
  preceded by a theorem name and attributes.*)
val _ =
  Outer_Syntax.command "primrec" "define primitive recursive functions on datatypes"
    Keyword.thy_decl
    (Scan.repeat1 (Parse_Spec.opt_thm_name ":" -- Parse.prop)
      >> (Toplevel.theory o (#1 oo (add_primrec o map Parse.triple_swap))));

end;
wenzelm@12183
   205