improved backwards compatiblity of primrec_new (Isabelle/ML interface, attributes, etc.)
authortraytel
Tue, 01 Oct 2013 17:04:27 +0200
changeset 5515038c0bbb8348b
parent 55149 7a8263843acb
child 55151 21dac9a60f0c
improved backwards compatiblity of primrec_new (Isabelle/ML interface, attributes, etc.)
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
src/HOL/BNF/Tools/bnf_lfp.ML
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Tue Oct 01 15:02:12 2013 +0200
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Tue Oct 01 17:04:27 2013 +0200
     1.3 @@ -7,8 +7,17 @@
     1.4  
     1.5  signature BNF_FP_REC_SUGAR =
     1.6  sig
     1.7 +  val add_primrec: (binding * typ option * mixfix) list ->
     1.8 +    (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
     1.9    val add_primrec_cmd: (binding * string option * mixfix) list ->
    1.10 -    (Attrib.binding * string) list -> local_theory -> local_theory;
    1.11 +    (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
    1.12 +  val add_primrec_global: (binding * typ option * mixfix) list ->
    1.13 +    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
    1.14 +  val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    1.15 +    (binding * typ option * mixfix) list ->
    1.16 +    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
    1.17 +  val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
    1.18 +    local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
    1.19    val add_primcorecursive_cmd: bool ->
    1.20      (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
    1.21      Proof.context -> Proof.state
    1.22 @@ -31,8 +40,9 @@
    1.23  val selN = "sel"
    1.24  
    1.25  val nitpick_attrs = @{attributes [nitpick_simp]};
    1.26 -val code_nitpick_simp_attrs = Code.add_default_eqn_attrib :: nitpick_attrs;
    1.27  val simp_attrs = @{attributes [simp]};
    1.28 +val code_nitpick_attrs = Code.add_default_eqn_attrib :: nitpick_attrs;
    1.29 +val code_nitpick_simp_attrs = Code.add_default_eqn_attrib :: nitpick_attrs @ simp_attrs;
    1.30  
    1.31  exception Primrec_Error of string * term list;
    1.32  
    1.33 @@ -300,11 +310,11 @@
    1.34      |> (fn [] => NONE | callss => SOME (#ctr eqn_data, callss))
    1.35    end;
    1.36  
    1.37 -fun add_primrec fixes specs lthy =
    1.38 +fun prepare_primrec fixes specs lthy =
    1.39    let
    1.40      val (bs, mxs) = map_split (apfst fst) fixes;
    1.41      val fun_names = map Binding.name_of bs;
    1.42 -    val eqns_data = map (snd #> dissect_eqn lthy fun_names) specs;
    1.43 +    val eqns_data = map (dissect_eqn lthy fun_names) specs;
    1.44      val funs_data = eqns_data
    1.45        |> partition_eq ((op =) o pairself #fun_name)
    1.46        |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
    1.47 @@ -330,52 +340,51 @@
    1.48  
    1.49      val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
    1.50  
    1.51 -    fun prove def_thms' ({nested_map_idents, nested_map_comps, ctr_specs, ...} : rec_spec)
    1.52 -        induct_thm (fun_data : eqn_data list) lthy =
    1.53 +    fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
    1.54 +        (fun_data : eqn_data list) =
    1.55        let
    1.56 -        val fun_name = #fun_name (hd fun_data);
    1.57          val def_thms = map (snd o snd) def_thms';
    1.58 -        val simp_thms = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
    1.59 +        val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
    1.60            |> fst
    1.61            |> map_filter (try (fn (x, [y]) =>
    1.62              (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
    1.63            |> map (fn (user_eqn, num_extra_args, rec_thm) =>
    1.64              mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
    1.65 -            |> K |> Goal.prove lthy [] [] user_eqn)
    1.66 +            |> K |> Goal.prove lthy [] [] user_eqn);
    1.67 +        val poss = find_indices (fn (x, y) => #ctr x = #ctr y) fun_data eqns_data;
    1.68 +      in
    1.69 +        (poss, simp_thmss)
    1.70 +      end;
    1.71  
    1.72 -        val notes =
    1.73 -          [(inductN, if n2m then [induct_thm] else [], []),
    1.74 -           (simpsN, simp_thms, code_nitpick_simp_attrs @ simp_attrs)]
    1.75 -          |> filter_out (null o #2)
    1.76 -          |> map (fn (thmN, thms, attrs) =>
    1.77 -            ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])]));
    1.78 -      in
    1.79 -        lthy |> Local_Theory.notes notes
    1.80 -      end;
    1.81 +    val notes =
    1.82 +      (if n2m then map2 (fn name => fn thm =>
    1.83 +        (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms) else [])
    1.84 +      |> map (fn (prefix, thmN, thms, attrs) =>
    1.85 +        ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])]));
    1.86  
    1.87      val common_name = mk_common_name fun_names;
    1.88  
    1.89      val common_notes =
    1.90 -      [(inductN, if n2m then [induct_thm] else [], [])]
    1.91 -      |> filter_out (null o #2)
    1.92 +      (if n2m then [(inductN, [induct_thm], [])] else [])
    1.93        |> map (fn (thmN, thms, attrs) =>
    1.94          ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
    1.95    in
    1.96 -    lthy'
    1.97 -    |> fold_map Local_Theory.define defs
    1.98 -    |-> snd oo (fn def_thms' => fold_map3 (prove def_thms') (take actual_nn rec_specs)
    1.99 -      (take actual_nn induct_thms) funs_data)
   1.100 -    |> Local_Theory.notes common_notes |> snd
   1.101 +    (((fun_names, defs),
   1.102 +      fn lthy => fn defs =>
   1.103 +        split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
   1.104 +      lthy' |> Local_Theory.notes (notes @ common_notes) |> snd)
   1.105    end;
   1.106  
   1.107 -fun add_primrec_cmd raw_fixes raw_specs lthy =
   1.108 +(* primrec definition *)
   1.109 +
   1.110 +fun add_primrec_simple fixes ts lthy =
   1.111    let
   1.112 -    val _ = let val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes) in null d orelse
   1.113 -      primrec_error ("duplicate function name(s): " ^ commas d) end;
   1.114 -    val (fixes, specs) = fst (Specification.read_spec raw_fixes raw_specs lthy);
   1.115 +    val (((names, defs), prove), lthy) = prepare_primrec fixes ts lthy
   1.116 +      handle ERROR str => primrec_error str;
   1.117    in
   1.118 -    add_primrec fixes specs lthy
   1.119 -      handle ERROR str => primrec_error str
   1.120 +    lthy
   1.121 +    |> fold_map Local_Theory.define defs
   1.122 +    |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
   1.123    end
   1.124    handle Primrec_Error (str, eqns) =>
   1.125      if null eqns
   1.126 @@ -383,6 +392,56 @@
   1.127      else error ("primrec_new error:\n  " ^ str ^ "\nin\n  " ^
   1.128        space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));
   1.129  
   1.130 +local
   1.131 +
   1.132 +fun gen_primrec prep_spec raw_fixes raw_spec lthy =
   1.133 +  let
   1.134 +    val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
   1.135 +    val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
   1.136 +
   1.137 +    val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
   1.138 +
   1.139 +    val mk_notes =
   1.140 +      flat ooo map3 (fn poss => fn prefix => fn thms =>
   1.141 +        let
   1.142 +          val (bs, attrss) = map_split (fst o nth specs) poss;
   1.143 +          val notes =
   1.144 +            map3 (fn b => fn attrs => fn thm =>
   1.145 +              ((Binding.qualify false prefix b, code_nitpick_simp_attrs @ attrs), [([thm], [])]))
   1.146 +            bs attrss thms;
   1.147 +        in
   1.148 +          ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
   1.149 +        end);
   1.150 +  in
   1.151 +    lthy
   1.152 +    |> add_primrec_simple fixes (map snd specs)
   1.153 +    |-> (fn (names, (ts, (posss, simpss))) =>
   1.154 +      Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
   1.155 +      #> Local_Theory.notes (mk_notes posss names simpss)
   1.156 +      #>> pair ts o map snd)
   1.157 +  end;
   1.158 +
   1.159 +in
   1.160 +
   1.161 +val add_primrec = gen_primrec Specification.check_spec;
   1.162 +val add_primrec_cmd = gen_primrec Specification.read_spec;
   1.163 +
   1.164 +end;
   1.165 +
   1.166 +fun add_primrec_global fixes specs thy =
   1.167 +  let
   1.168 +    val lthy = Named_Target.theory_init thy;
   1.169 +    val ((ts, simps), lthy') = add_primrec fixes specs lthy;
   1.170 +    val simps' = burrow (Proof_Context.export lthy' lthy) simps;
   1.171 +  in ((ts, simps'), Local_Theory.exit_global lthy') end;
   1.172 +
   1.173 +fun add_primrec_overloaded ops fixes specs thy =
   1.174 +  let
   1.175 +    val lthy = Overloading.overloading ops thy;
   1.176 +    val ((ts, simps), lthy') = add_primrec fixes specs lthy;
   1.177 +    val simps' = burrow (Proof_Context.export lthy' lthy) simps;
   1.178 +  in ((ts, simps'), Local_Theory.exit_global lthy') end;
   1.179 +
   1.180  
   1.181  
   1.182  (* Primcorec *)
   1.183 @@ -875,7 +934,7 @@
   1.184  
   1.185          val notes =
   1.186            [(coinductN, map (if n2m then single else K []) coinduct_thms, []),
   1.187 -           (codeN, ctr_thmss(*FIXME*), code_nitpick_simp_attrs),
   1.188 +           (codeN, ctr_thmss(*FIXME*), code_nitpick_attrs),
   1.189             (ctrN, ctr_thmss, []),
   1.190             (discN, disc_thmss, simp_attrs),
   1.191             (selN, sel_thmss, simp_attrs),
     2.1 --- a/src/HOL/BNF/Tools/bnf_lfp.ML	Tue Oct 01 15:02:12 2013 +0200
     2.2 +++ b/src/HOL/BNF/Tools/bnf_lfp.ML	Tue Oct 01 17:04:27 2013 +0200
     2.3 @@ -1889,6 +1889,6 @@
     2.4  
     2.5  val _ = Outer_Syntax.local_theory @{command_spec "primrec_new"}
     2.6    "define primitive recursive functions"
     2.7 -  (Parse.fixes -- Parse_Spec.where_alt_specs >> uncurry add_primrec_cmd);
     2.8 +  (Parse.fixes -- Parse_Spec.where_alt_specs >> (snd oo uncurry add_primrec_cmd));
     2.9  
    2.10  end;