src/HOL/Tools/primrec.ML
author wenzelm
Thu, 09 Jun 2011 17:51:49 +0200
changeset 44208 47cf4bc789aa
parent 43353 42c7ef050bc3
child 45112 7943b69f0188
permissions -rw-r--r--
simplified Name.variant -- discontinued builtin fold_map;
     1 (*  Title:      HOL/Tools/primrec.ML
     2     Author:     Norbert Voelker, FernUni Hagen
     3     Author:     Stefan Berghofer, TU Muenchen
     4     Author:     Florian Haftmann, TU Muenchen
     5 
     6 Primitive recursive functions on datatypes.
     7 *)
     8 
     9 signature PRIMREC =
    10 sig
    11   val add_primrec: (binding * typ option * mixfix) list ->
    12     (Attrib.binding * term) list -> local_theory -> (term list * thm list) * local_theory
    13   val add_primrec_cmd: (binding * string option * mixfix) list ->
    14     (Attrib.binding * string) list -> local_theory -> (term list * thm list) * local_theory
    15   val add_primrec_global: (binding * typ option * mixfix) list ->
    16     (Attrib.binding * term) list -> theory -> (term list * thm list) * theory
    17   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    18     (binding * typ option * mixfix) list ->
    19     (Attrib.binding * term) list -> theory -> (term list * thm list) * theory
    20   val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
    21     local_theory -> (string * (term list * thm list)) * local_theory
    22 end;
    23 
    24 structure Primrec : PRIMREC =
    25 struct
    26 
    27 open Datatype_Aux;
    28 
    29 exception PrimrecError of string * term option;
    30 
    31 fun primrec_error msg = raise PrimrecError (msg, NONE);
    32 fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);
    33 
    34 
    35 (* preprocessing of equations *)
    36 
    37 fun process_eqn is_fixed spec rec_fns =
    38   let
    39     val (vs, Ts) = split_list (strip_qnt_vars "all" spec);
    40     val body = strip_qnt_body "all" spec;
    41     val (vs', _) = fold_map Name.variant vs (Name.make_context (fold_aterms
    42       (fn Free (v, _) => insert (op =) v | _ => I) body []));
    43     val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
    44     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    45       handle TERM _ => primrec_error "not a proper equation";
    46     val (recfun, args) = strip_comb lhs;
    47     val fname =
    48       (case recfun of
    49         Free (v, _) =>
    50           if is_fixed v then v
    51           else primrec_error "illegal head of function equation"
    52       | _ => primrec_error "illegal head of function equation");
    53 
    54     val (ls', rest)  = take_prefix is_Free args;
    55     val (middle, rs') = take_suffix is_Free rest;
    56     val rpos = length ls';
    57 
    58     val (constr, cargs') =
    59       if null middle then primrec_error "constructor missing"
    60       else strip_comb (hd middle);
    61     val (cname, T) = dest_Const constr
    62       handle TERM _ => primrec_error "ill-formed constructor";
    63     val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
    64       primrec_error "cannot determine datatype associated with function"
    65 
    66     val (ls, cargs, rs) =
    67       (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
    68       handle TERM _ => primrec_error "illegal argument in pattern";
    69     val lfrees = ls @ rs @ cargs;
    70 
    71     fun check_vars _ [] = ()
    72       | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars)) eqn;
    73   in
    74     if length middle > 1 then
    75       primrec_error "more than one non-variable in pattern"
    76     else
    77      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    78       check_vars "extra variables on rhs: "
    79         (Term.add_frees rhs [] |> subtract (op =) lfrees
    80           |> filter_out (is_fixed o fst));
    81       (case AList.lookup (op =) rec_fns fname of
    82         NONE =>
    83           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))])) :: rec_fns
    84       | SOME (_, rpos', eqns) =>
    85           if AList.defined (op =) eqns cname then
    86             primrec_error "constructor already occurred as pattern"
    87           else if rpos <> rpos' then
    88             primrec_error "position of recursive argument inconsistent"
    89           else
    90             AList.update (op =)
    91               (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn)) :: eqns))
    92               rec_fns))
    93   end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
    94 
    95 fun process_fun descr eqns (i, fname) (fnames, fnss) =
    96   let
    97     val (_, (tname, _, constrs)) = nth descr i;
    98 
    99     (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
   100 
   101     fun subst [] t fs = (t, fs)
   102       | subst subs (Abs (a, T, t)) fs =
   103           fs
   104           |> subst subs t
   105           |-> (fn t' => pair (Abs (a, T, t')))
   106       | subst subs (t as (_ $ _)) fs =
   107           let
   108             val (f, ts) = strip_comb t;
   109           in
   110             if is_Free f
   111               andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
   112               let
   113                 val (fname', _) = dest_Free f;
   114                 val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
   115                 val (ls, rs) = chop rpos ts
   116                 val (x', rs') =
   117                   (case rs of
   118                     x' :: rs => (x', rs)
   119                   | [] => primrec_error ("not enough arguments in recursive application\n" ^
   120                       "of function " ^ quote fname' ^ " on rhs"));
   121                 val (x, xs) = strip_comb x';
   122               in
   123                 (case AList.lookup (op =) subs x of
   124                   NONE =>
   125                     fs
   126                     |> fold_map (subst subs) ts
   127                     |-> (fn ts' => pair (list_comb (f, ts')))
   128                 | SOME (i', y) =>
   129                     fs
   130                     |> fold_map (subst subs) (xs @ ls @ rs')
   131                     ||> process_fun descr eqns (i', fname')
   132                     |-> (fn ts' => pair (list_comb (y, ts'))))
   133               end
   134             else
   135               fs
   136               |> fold_map (subst subs) (f :: ts)
   137               |-> (fn f' :: ts' => pair (list_comb (f', ts')))
   138           end
   139       | subst _ t fs = (t, fs);
   140 
   141     (* translate rec equations into function arguments suitable for rec comb *)
   142 
   143     fun trans eqns (cname, cargs) (fnames', fnss', fns) =
   144       (case AList.lookup (op =) eqns cname of
   145         NONE => (warning ("No equation for constructor " ^ quote cname ^
   146           "\nin definition of function " ^ quote fname);
   147             (fnames', fnss', (Const (@{const_name undefined}, dummyT)) :: fns))
   148       | SOME (ls, cargs', rs, rhs, eq) =>
   149           let
   150             val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   151             val rargs = map fst recs;
   152             val subs = map (rpair dummyT o fst)
   153               (rev (Term.rename_wrt_term rhs rargs));
   154             val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
   155               (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss')
   156                 handle PrimrecError (s, NONE) => primrec_error_eqn s eq
   157           in (fnames'', fnss'',
   158               (list_abs_free (cargs' @ subs @ ls @ rs, rhs')) :: fns)
   159           end)
   160 
   161   in
   162     (case AList.lookup (op =) fnames i of
   163       NONE =>
   164         if exists (fn (_, v) => fname = v) fnames then
   165           primrec_error ("inconsistent functions for datatype " ^ quote tname)
   166         else
   167           let
   168             val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
   169             val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
   170               ((i, fname) :: fnames, fnss, [])
   171           in
   172             (fnames', (i, (fname, #1 (snd (hd eqns)), fns)) :: fnss')
   173           end
   174     | SOME fname' =>
   175         if fname = fname' then (fnames, fnss)
   176         else primrec_error ("inconsistent functions for datatype " ^ quote tname))
   177   end;
   178 
   179 
   180 (* prepare functions needed for definitions *)
   181 
   182 fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
   183   (case AList.lookup (op =) fns i of
   184     NONE =>
   185       let
   186         val dummy_fns = map (fn (_, cargs) => Const (@{const_name undefined},
   187           replicate (length cargs + length (filter is_rec_type cargs))
   188             dummyT ---> HOLogic.unitT)) constrs;
   189         val _ = warning ("No function definition for datatype " ^ quote tname)
   190       in
   191         (dummy_fns @ fs, defs)
   192       end
   193   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs));
   194 
   195 
   196 (* make definition *)
   197 
   198 fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
   199   let
   200     val SOME (var, varT) = get_first (fn ((b, T), mx) =>
   201       if Binding.name_of b = fname then SOME ((b, mx), T) else NONE) fixes;
   202     val def_name = Thm.def_name (Long_Name.base_name fname);
   203     val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) (map snd ls @ [dummyT])
   204       (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 :: (length ls downto 1))))
   205     val rhs = singleton (Syntax.check_terms ctxt) (Type.constraint varT raw_rhs);
   206   in (var, ((Binding.conceal (Binding.name def_name), []), rhs)) end;
   207 
   208 
   209 (* find datatypes which contain all datatypes in tnames' *)
   210 
   211 fun find_dts (dt_info : info Symtab.table) _ [] = []
   212   | find_dts dt_info tnames' (tname :: tnames) =
   213       (case Symtab.lookup dt_info tname of
   214         NONE => primrec_error (quote tname ^ " is not a datatype")
   215       | SOME dt =>
   216           if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
   217             (tname, dt) :: (find_dts dt_info tnames' tnames)
   218           else find_dts dt_info tnames' tnames);
   219 
   220 
   221 (* distill primitive definition(s) from primrec specification *)
   222 
   223 fun distill lthy fixes eqs = 
   224   let
   225     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
   226       orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
   227     val tnames = distinct (op =) (map (#1 o snd) eqns);
   228     val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of lthy)) tnames tnames;
   229     val main_fns = map (fn (tname, {index, ...}) =>
   230       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   231     val {descr, rec_names, rec_rewrites, ...} =
   232       if null dts then primrec_error
   233         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   234       else snd (hd dts);
   235     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   236     val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   237     val defs = map (make_def lthy fixes fs) raw_defs;
   238     val names = map snd fnames;
   239     val names_eqns = map fst eqns;
   240     val _ =
   241       if eq_set (op =) (names, names_eqns) then ()
   242       else primrec_error ("functions " ^ commas_quote names_eqns ^
   243         "\nare not mutually recursive");
   244     val rec_rewrites' = map mk_meta_eq rec_rewrites;
   245     val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
   246     fun prove lthy defs =
   247       let
   248         val frees = fold (Variable.add_free_names lthy) eqs [];
   249         val rewrites = rec_rewrites' @ map (snd o snd) defs;
   250         fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   251       in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end;
   252   in ((prefix, (fs, defs)), prove) end
   253   handle PrimrecError (msg, some_eqn) =>
   254     error ("Primrec definition error:\n" ^ msg ^
   255       (case some_eqn of
   256         SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
   257       | NONE => ""));
   258 
   259 
   260 (* primrec definition *)
   261 
   262 fun add_primrec_simple fixes ts lthy =
   263   let
   264     val ((prefix, (fs, defs)), prove) = distill lthy fixes ts;
   265   in
   266     lthy
   267     |> fold_map Local_Theory.define defs
   268     |-> (fn defs => `(fn lthy => (prefix, (map fst defs, prove lthy defs))))
   269   end;
   270 
   271 local
   272 
   273 fun gen_primrec prep_spec raw_fixes raw_spec lthy =
   274   let
   275     val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
   276     fun attr_bindings prefix = map (fn ((b, attrs), _) =>
   277       (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec;
   278     fun simp_attr_binding prefix =
   279       (Binding.qualify true prefix (Binding.name "simps"),
   280         map (Attrib.internal o K) [Simplifier.simp_add, Nitpick_Simps.add]);
   281   in
   282     lthy
   283     |> add_primrec_simple fixes (map snd spec)
   284     |-> (fn (prefix, (ts, simps)) =>
   285       Spec_Rules.add Spec_Rules.Equational (ts, simps)
   286       #> fold_map Local_Theory.note (attr_bindings prefix ~~ map single simps)
   287       #-> (fn simps' => Local_Theory.note (simp_attr_binding prefix, maps snd simps')
   288       #>> (fn (_, simps'') => (ts, simps''))))
   289   end;
   290 
   291 in
   292 
   293 val add_primrec = gen_primrec Specification.check_spec;
   294 val add_primrec_cmd = gen_primrec Specification.read_spec;
   295 
   296 end;
   297 
   298 fun add_primrec_global fixes specs thy =
   299   let
   300     val lthy = Named_Target.theory_init thy;
   301     val ((ts, simps), lthy') = add_primrec fixes specs lthy;
   302     val simps' = Proof_Context.export lthy' lthy simps;
   303   in ((ts, simps'), Local_Theory.exit_global lthy') end;
   304 
   305 fun add_primrec_overloaded ops fixes specs thy =
   306   let
   307     val lthy = Overloading.overloading ops thy;
   308     val ((ts, simps), lthy') = add_primrec fixes specs lthy;
   309     val simps' = Proof_Context.export lthy' lthy simps;
   310   in ((ts, simps'), Local_Theory.exit_global lthy') end;
   311 
   312 
   313 (* outer syntax *)
   314 
   315 val _ =
   316   Outer_Syntax.local_theory "primrec" "define primitive recursive functions on datatypes"
   317     Keyword.thy_decl
   318     (Parse.fixes -- Parse_Spec.where_alt_specs
   319       >> (fn (fixes, specs) => add_primrec_cmd fixes specs #> snd));
   320 
   321 end;