src/HOL/Tools/recdef_package.ML
author wenzelm
Sat, 09 Aug 2008 22:43:46 +0200
changeset 27809 a1e409db516b
parent 27727 2397e310b2cc
child 28083 103d9282a946
permissions -rw-r--r--
unified Args.T with OuterLex.token, renamed some operations;
     1 (*  Title:      HOL/Tools/recdef_package.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel, TU Muenchen
     4 
     5 Wrapper module for Konrad Slind's TFL package.
     6 *)
     7 
     8 signature RECDEF_PACKAGE =
     9 sig
    10   val get_recdef: theory -> string
    11     -> {simps: thm list, rules: thm list list, induct: thm, tcs: term list} option
    12   val get_hints: Proof.context -> {simps: thm list, congs: (string * thm) list, wfs: thm list}
    13   val simp_add: attribute
    14   val simp_del: attribute
    15   val cong_add: attribute
    16   val cong_del: attribute
    17   val wf_add: attribute
    18   val wf_del: attribute
    19   val add_recdef: bool -> xstring -> string -> ((bstring * string) * Attrib.src list) list ->
    20     Attrib.src option -> theory -> theory
    21       * {simps: thm list, rules: thm list list, induct: thm, tcs: term list}
    22   val add_recdef_i: bool -> xstring -> term -> ((bstring * term) * attribute list) list ->
    23     theory -> theory * {simps: thm list, rules: thm list list, induct: thm, tcs: term list}
    24   val defer_recdef: xstring -> string list -> (Facts.ref * Attrib.src list) list
    25     -> theory -> theory * {induct_rules: thm}
    26   val defer_recdef_i: xstring -> term list -> thm list -> theory -> theory * {induct_rules: thm}
    27   val recdef_tc: bstring * Attrib.src list -> xstring -> int option -> bool ->
    28     local_theory -> Proof.state
    29   val recdef_tc_i: bstring * Attrib.src list -> string -> int option -> bool ->
    30     local_theory -> Proof.state
    31   val setup: theory -> theory
    32 end;
    33 
    34 structure RecdefPackage: RECDEF_PACKAGE =
    35 struct
    36 
    37 
    38 (** recdef hints **)
    39 
    40 (* type hints *)
    41 
    42 type hints = {simps: thm list, congs: (string * thm) list, wfs: thm list};
    43 
    44 fun mk_hints (simps, congs, wfs) = {simps = simps, congs = congs, wfs = wfs}: hints;
    45 fun map_hints f ({simps, congs, wfs}: hints) = mk_hints (f (simps, congs, wfs));
    46 
    47 fun map_simps f = map_hints (fn (simps, congs, wfs) => (f simps, congs, wfs));
    48 fun map_congs f = map_hints (fn (simps, congs, wfs) => (simps, f congs, wfs));
    49 fun map_wfs f = map_hints (fn (simps, congs, wfs) => (simps, congs, f wfs));
    50 
    51 fun pretty_hints ({simps, congs, wfs}: hints) =
    52  [Pretty.big_list "recdef simp hints:" (map Display.pretty_thm simps),
    53   Pretty.big_list "recdef cong hints:" (map Display.pretty_thm (map snd congs)),
    54   Pretty.big_list "recdef wf hints:" (map Display.pretty_thm wfs)];
    55 
    56 
    57 (* congruence rules *)
    58 
    59 local
    60 
    61 val cong_head =
    62   fst o Term.dest_Const o Term.head_of o fst o Logic.dest_equals o Thm.concl_of;
    63 
    64 fun prep_cong raw_thm =
    65   let val thm = safe_mk_meta_eq raw_thm in (cong_head thm, thm) end;
    66 
    67 in
    68 
    69 fun add_cong raw_thm congs =
    70   let
    71     val (c, thm) = prep_cong raw_thm;
    72     val _ = if AList.defined (op =) congs c
    73       then warning ("Overwriting recdef congruence rule for " ^ quote c)
    74       else ();
    75   in AList.update (op =) (c, thm) congs end;
    76 
    77 fun del_cong raw_thm congs =
    78   let
    79     val (c, thm) = prep_cong raw_thm;
    80     val _ = if AList.defined (op =) congs c
    81       then ()
    82       else warning ("No recdef congruence rule for " ^ quote c);
    83   in AList.delete (op =) c congs end;
    84 
    85 end;
    86 
    87 
    88 
    89 (** global and local recdef data **)
    90 
    91 (* theory data *)
    92 
    93 type recdef_info = {simps: thm list, rules: thm list list, induct: thm, tcs: term list};
    94 
    95 structure GlobalRecdefData = TheoryDataFun
    96 (
    97   type T = recdef_info Symtab.table * hints;
    98   val empty = (Symtab.empty, mk_hints ([], [], [])): T;
    99   val copy = I;
   100   val extend = I;
   101   fun merge _
   102    ((tab1, {simps = simps1, congs = congs1, wfs = wfs1}),
   103     (tab2, {simps = simps2, congs = congs2, wfs = wfs2})) : T =
   104       (Symtab.merge (K true) (tab1, tab2),
   105         mk_hints (Thm.merge_thms (simps1, simps2),
   106           AList.merge (op =) Thm.eq_thm (congs1, congs2),
   107           Thm.merge_thms (wfs1, wfs2)));
   108 );
   109 
   110 val get_recdef = Symtab.lookup o #1 o GlobalRecdefData.get;
   111 
   112 fun put_recdef name info thy =
   113   let
   114     val (tab, hints) = GlobalRecdefData.get thy;
   115     val tab' = Symtab.update_new (name, info) tab
   116       handle Symtab.DUP _ => error ("Duplicate recursive function definition " ^ quote name);
   117   in GlobalRecdefData.put (tab', hints) thy end;
   118 
   119 val get_global_hints = #2 o GlobalRecdefData.get;
   120 
   121 
   122 (* proof data *)
   123 
   124 structure LocalRecdefData = ProofDataFun
   125 (
   126   type T = hints;
   127   val init = get_global_hints;
   128 );
   129 
   130 val get_hints = LocalRecdefData.get;
   131 fun map_hints f = Context.mapping (GlobalRecdefData.map (apsnd f)) (LocalRecdefData.map f);
   132 
   133 
   134 (* attributes *)
   135 
   136 fun attrib f = Thm.declaration_attribute (map_hints o f);
   137 
   138 val simp_add = attrib (map_simps o Thm.add_thm);
   139 val simp_del = attrib (map_simps o Thm.del_thm);
   140 val cong_add = attrib (map_congs o add_cong);
   141 val cong_del = attrib (map_congs o del_cong);
   142 val wf_add = attrib (map_wfs o Thm.add_thm);
   143 val wf_del = attrib (map_wfs o Thm.del_thm);
   144 
   145 
   146 (* modifiers *)
   147 
   148 val recdef_simpN = "recdef_simp";
   149 val recdef_congN = "recdef_cong";
   150 val recdef_wfN = "recdef_wf";
   151 
   152 val recdef_modifiers =
   153  [Args.$$$ recdef_simpN -- Args.colon >> K ((I, simp_add): Method.modifier),
   154   Args.$$$ recdef_simpN -- Args.add -- Args.colon >> K (I, simp_add),
   155   Args.$$$ recdef_simpN -- Args.del -- Args.colon >> K (I, simp_del),
   156   Args.$$$ recdef_congN -- Args.colon >> K (I, cong_add),
   157   Args.$$$ recdef_congN -- Args.add -- Args.colon >> K (I, cong_add),
   158   Args.$$$ recdef_congN -- Args.del -- Args.colon >> K (I, cong_del),
   159   Args.$$$ recdef_wfN -- Args.colon >> K (I, wf_add),
   160   Args.$$$ recdef_wfN -- Args.add -- Args.colon >> K (I, wf_add),
   161   Args.$$$ recdef_wfN -- Args.del -- Args.colon >> K (I, wf_del)] @
   162   Clasimp.clasimp_modifiers;
   163 
   164 
   165 
   166 (** prepare_hints(_i) **)
   167 
   168 fun prepare_hints thy opt_src =
   169   let
   170     val ctxt0 = ProofContext.init thy;
   171     val ctxt =
   172       (case opt_src of
   173         NONE => ctxt0
   174       | SOME src => Method.only_sectioned_args recdef_modifiers I src ctxt0);
   175     val {simps, congs, wfs} = get_hints ctxt;
   176     val cs = local_claset_of ctxt;
   177     val ss = local_simpset_of ctxt addsimps simps;
   178   in (cs, ss, rev (map snd congs), wfs) end;
   179 
   180 fun prepare_hints_i thy () =
   181   let
   182     val ctxt0 = ProofContext.init thy;
   183     val {simps, congs, wfs} = get_global_hints thy;
   184   in (local_claset_of ctxt0, local_simpset_of ctxt0 addsimps simps, rev (map snd congs), wfs) end;
   185 
   186 
   187 
   188 (** add_recdef(_i) **)
   189 
   190 fun requires_recdef thy = Theory.requires thy "Recdef" "recursive functions";
   191 
   192 fun gen_add_recdef tfl_fn prep_att prep_hints not_permissive raw_name R eq_srcs hints thy =
   193   let
   194     val _ = requires_recdef thy;
   195 
   196     val name = Sign.intern_const thy raw_name;
   197     val bname = Sign.base_name name;
   198     val _ = writeln ("Defining recursive function " ^ quote name ^ " ...");
   199 
   200     val ((eq_names, eqs), raw_eq_atts) = apfst split_list (split_list eq_srcs);
   201     val eq_atts = map (map (prep_att thy)) raw_eq_atts;
   202 
   203     val (cs, ss, congs, wfs) = prep_hints thy hints;
   204     (*We must remove imp_cong to prevent looping when the induction rule
   205       is simplified. Many induction rules have nested implications that would
   206       give rise to looping conditional rewriting.*)
   207     val (thy, {rules = rules_idx, induct, tcs}) =
   208         tfl_fn not_permissive thy cs (ss delcongs [imp_cong])
   209                congs wfs name R eqs;
   210     val rules = (map o map) fst (partition_eq (eq_snd (op = : int * int -> bool)) rules_idx);
   211     val simp_att = if null tcs then [Simplifier.simp_add, RecfunCodegen.add_default] else [];
   212 
   213     val ((simps' :: rules', [induct']), thy) =
   214       thy
   215       |> Sign.add_path bname
   216       |> PureThy.add_thmss
   217         ((("simps", List.concat rules), simp_att) :: ((eq_names ~~ rules) ~~ eq_atts))
   218       ||>> PureThy.add_thms [(("induct", induct), [])];
   219     val result = {simps = simps', rules = rules', induct = induct', tcs = tcs};
   220     val thy =
   221       thy
   222       |> put_recdef name result
   223       |> Sign.parent_path;
   224   in (thy, result) end;
   225 
   226 val add_recdef = gen_add_recdef Tfl.define Attrib.attribute prepare_hints;
   227 fun add_recdef_i x y z w = gen_add_recdef Tfl.define_i (K I) prepare_hints_i x y z w ();
   228 
   229 
   230 
   231 (** defer_recdef(_i) **)
   232 
   233 fun gen_defer_recdef tfl_fn eval_thms raw_name eqs raw_congs thy =
   234   let
   235     val name = Sign.intern_const thy raw_name;
   236     val bname = Sign.base_name name;
   237 
   238     val _ = requires_recdef thy;
   239     val _ = writeln ("Deferred recursive function " ^ quote name ^ " ...");
   240 
   241     val congs = eval_thms (ProofContext.init thy) raw_congs;
   242     val (thy2, induct_rules) = tfl_fn thy congs name eqs;
   243     val ([induct_rules'], thy3) =
   244       thy2
   245       |> Sign.add_path bname
   246       |> PureThy.add_thms [(("induct_rules", induct_rules), [])]
   247       ||> Sign.parent_path;
   248   in (thy3, {induct_rules = induct_rules'}) end;
   249 
   250 val defer_recdef = gen_defer_recdef Tfl.defer Attrib.eval_thms;
   251 val defer_recdef_i = gen_defer_recdef Tfl.defer_i (K I);
   252 
   253 
   254 
   255 (** recdef_tc(_i) **)
   256 
   257 fun gen_recdef_tc prep_att prep_name (bname, raw_atts) raw_name opt_i int lthy =
   258   let
   259     val thy = ProofContext.theory_of lthy;
   260     val name = prep_name thy raw_name;
   261     val atts = map (prep_att thy) raw_atts;
   262     val tcs =
   263       (case get_recdef thy name of
   264         NONE => error ("No recdef definition of constant: " ^ quote name)
   265       | SOME {tcs, ...} => tcs);
   266     val i = the_default 1 opt_i;
   267     val tc = nth tcs (i - 1) handle Subscript =>
   268       error ("No termination condition #" ^ string_of_int i ^
   269         " in recdef definition of " ^ quote name);
   270   in
   271     Specification.theorem Thm.internalK NONE (K I) (bname, atts)
   272       [] (Element.Shows [(("", []), [(HOLogic.mk_Trueprop tc, [])])]) int lthy
   273   end;
   274 
   275 val recdef_tc = gen_recdef_tc Attrib.intern_src Sign.intern_const;
   276 val recdef_tc_i = gen_recdef_tc (K I) (K I);
   277 
   278 
   279 
   280 (** package setup **)
   281 
   282 (* setup theory *)
   283 
   284 val setup =
   285   Attrib.add_attributes
   286    [(recdef_simpN, Attrib.add_del_args simp_add simp_del, "declaration of recdef simp rule"),
   287     (recdef_congN, Attrib.add_del_args cong_add cong_del, "declaration of recdef cong rule"),
   288     (recdef_wfN, Attrib.add_del_args wf_add wf_del, "declaration of recdef wf rule")];
   289 
   290 
   291 (* outer syntax *)
   292 
   293 local structure P = OuterParse and K = OuterKeyword in
   294 
   295 val _ = List.app OuterKeyword.keyword ["permissive", "congs", "hints"];
   296 
   297 val hints =
   298   P.$$$ "(" |-- P.!!! (P.position (P.$$$ "hints" -- Args.parse) --| P.$$$ ")") >> Args.src;
   299 
   300 val recdef_decl =
   301   Scan.optional (P.$$$ "(" -- P.!!! (P.$$$ "permissive" -- P.$$$ ")") >> K false) true --
   302   P.name -- P.term -- Scan.repeat1 (SpecParse.opt_thm_name ":" -- P.prop) -- Scan.option hints
   303   >> (fn ((((p, f), R), eqs), src) => #1 o add_recdef p f R (map P.triple_swap eqs) src);
   304 
   305 val _ =
   306   OuterSyntax.command "recdef" "define general recursive functions (TFL)" K.thy_decl
   307     (recdef_decl >> Toplevel.theory);
   308 
   309 
   310 val defer_recdef_decl =
   311   P.name -- Scan.repeat1 P.prop --
   312   Scan.optional (P.$$$ "(" |-- P.$$$ "congs" |-- P.!!! (SpecParse.xthms1 --| P.$$$ ")")) []
   313   >> (fn ((f, eqs), congs) => #1 o defer_recdef f eqs congs);
   314 
   315 val _ =
   316   OuterSyntax.command "defer_recdef" "defer general recursive functions (TFL)" K.thy_decl
   317     (defer_recdef_decl >> Toplevel.theory);
   318 
   319 val _ =
   320   OuterSyntax.local_theory_to_proof' "recdef_tc" "recommence proof of termination condition (TFL)"
   321     K.thy_goal
   322     (SpecParse.opt_thm_name ":" -- P.xname -- Scan.option (P.$$$ "(" |-- P.nat --| P.$$$ ")")
   323       >> (fn ((thm_name, name), i) => recdef_tc thm_name name i));
   324 
   325 end;
   326 
   327 end;