src/HOL/BNF/Tools/bnf_lfp_compat.ML
author blanchet
Thu, 07 Nov 2013 00:37:18 +0100
changeset 55738 22616f65d4ea
parent 55719 78e8a178b690
child 55808 4a655e62ad34
permissions -rw-r--r--
properly detect when to perform n2m -- e.g. handle the case of two independent functions on irrelevant types being defined in parallel
blanchet@54440
     1
(*  Title:      HOL/BNF/Tools/bnf_lfp_compat.ML
    Author:     Jasmin Blanchette, TU Muenchen
    Copyright   2013

Compatibility layer with the old datatype package.
*)
blanchet@54440
     7
blanchet@54440
     8
signature BNF_LFP_COMPAT =
sig
  (* Command behind datatype_new_compat: registers the named new-style datatypes, which must form
     a complete set of mutually recursive datatypes, with the old datatype package. *)
  val datatype_new_compat_cmd: string list -> local_theory -> local_theory
end;
blanchet@54440
    12
blanchet@54440
    13
structure BNF_LFP_Compat : BNF_LFP_COMPAT =
struct

open Ctr_Sugar
open BNF_Util
open BNF_FP_Util
open BNF_FP_Def_Sugar
open BNF_FP_N2M_Sugar

(* Translates a type into the old package's dtyp representation: a type equal to the k-th element
   of recTs becomes DtRec k, a free type variable becomes DtTFree, and any other type constructor
   application is translated argument-wise. There is no clause for TVar; within this file the
   function is only applied to types frozen by Type.legacy_freeze_type. *)
fun dtyp_of_typ _ (TFree a) = Datatype_Aux.DtTFree a
  | dtyp_of_typ recTs (T as Type (s, Ts)) =
    (case find_index (curry (op =) T) recTs of
      ~1 => Datatype_Aux.DtType (s, map (dtyp_of_typ recTs) Ts)
    | kk => Datatype_Aux.DtRec kk);

(* Prefix of the names under which the compatibility theorems are noted. *)
val compatN = "compat_";

(* Registers the new-style datatypes named by raw_fpT_names with the old datatype package
   (Datatype_Data.register followed by Datatype_Data.interpretation_data).

   The names must denote a complete set of mutually recursive new-style least-fixpoint datatypes.
   Otherwise an error is raised. An error is also raised for recursion through a type constructor
   that is not a new-style datatype.

   If the datatypes recurse through other new-style datatypes, the whole family is first
   mutualized, and the resulting induct, fold, and rec theorems are noted under
   compat_-prefixed names. *)
(* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
fun datatype_new_compat_cmd raw_fpT_names lthy =
  let
    val thy = Proof_Context.theory_of lthy;

    fun not_datatype s = error (quote s ^ " is not a new-style datatype");
    fun not_mutually_recursive ss =
      error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive new-style datatypes");

    (* Nonempty: the command parser uses Scan.repeat1. *)
    val (fpT_names as fpT_name1 :: _) =
      map (fst o dest_Type o Proof_Context.read_type_name_proper lthy false) raw_fpT_names;

    (* Fresh type arguments, sorted according to the arity of the first datatype. *)
    val Ss = Sign.arity_sorts thy fpT_name1 HOLogic.typeS;

    val (unsorted_As, _) = lthy |> mk_TFrees (length Ss);
    val As = map2 resort_tfree Ss unsorted_As;

    fun lfp_sugar_of s =
      (case fp_sugar_of lthy s of
        SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
      | _ => not_datatype s);

    val fp_sugar0 as {fp_res = {Ts = fpTs0, ...}, ...} = lfp_sugar_of fpT_name1;
    val fpT_names' = map (fst o dest_Type) fpTs0;

    (* The user must name exactly the mutual group of the first datatype. *)
    val _ = eq_set (op =) (fpT_names, fpT_names') orelse not_mutually_recursive fpT_names;

    val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';

    (* Appends to seen, in discovery order, the mutual group of T and then (transitively) the
       mutual groups of the new-style datatypes through which that group recurses. Function types
       are not collected; they only trigger a warning. *)
    fun add_nested_types_of (T as Type (s, _)) seen =
      if member (op =) seen T then
        seen
      else if s = @{type_name fun} then
        (warning "Partial support for recursion through functions -- 'primrec' will fail"; seen)
      else
        (case try lfp_sugar_of s of
          SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
          let
            (* Instantiate the datatype's generic types to the actual arguments of T. *)
            val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
            val substT = Term.typ_subst_TVars rho;

            val mutual_Ts = map substT mutual_Ts0;

            (* Collects type applications whose arguments mention a mutual type, together with
               such applications nested inside those arguments. *)
            fun add_interesting_subtypes (U as Type (_, Us)) =
                (case filter (exists_subtype_in mutual_Ts) Us of [] => I
                | Us' => insert (op =) U #> fold add_interesting_subtypes Us')
              | add_interesting_subtypes _ = I;

            val ctrs = maps #ctrs ctr_sugars;
            val ctr_Ts = maps (binder_types o substT o fastype_of) ctrs |> distinct (op =);
            val subTs = fold add_interesting_subtypes ctr_Ts [];
          in
            fold add_nested_types_of subTs (seen @ mutual_Ts)
          end
        | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
            " not associated with new-style datatype (cf. \"datatype_new\")"));

    (* The first nn_fp elements of Ts are the requested mutual group. *)
    val Ts = add_nested_types_of fpT1 [];
    val b_names = map base_name_of_typ Ts;
    val compat_b_names = map (prefix compatN) b_names;
    val compat_bs = map Binding.name compat_b_names;
    val common_name = compatN ^ mk_common_name b_names;
    val nn_fp = length fpTs;
    val nn = length Ts;
    val get_indices = K [];
    val fp_sugars0 = if nn = 1 then [fp_sugar0] else map (lfp_sugar_of o fst o dest_Type) Ts;
    val callssss = map (fn fp_sugar => indexify_callsss fp_sugar []) fp_sugars0;

    (* Mutualize (n2m) only if nesting contributed types beyond the requested mutual group.
       Otherwise Ts is exactly that group and the existing sugar is reused. *)
    val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
      if nn > nn_fp then
        mutualize_fp_sugars Least_FP compat_bs Ts get_indices callssss fp_sugars0 lthy
      else
        ((fp_sugars0, (NONE, NONE)), lthy);

    val {ctr_sugars, co_inducts = [induct], co_iterss, co_iter_thmsss = iter_thmsss, ...} :: _ =
      fp_sugars;
    val inducts = conj_dests nn induct;

    (* Freeze schematic type variables before building the old-style descriptor. *)
    val frozen_Ts = map Type.legacy_freeze_type Ts;
    val mk_dtyp = dtyp_of_typ frozen_Ts;

    fun mk_ctr_descr (Const (s, T)) =
      (s, map mk_dtyp (binder_types (Type.legacy_freeze_type T)));
    fun mk_typ_descr index (Type (T_name, Ts)) ({ctrs, ...} : ctr_sugar) =
      (index, (T_name, map mk_dtyp Ts, map mk_ctr_descr ctrs));

    val descr = map3 mk_typ_descr (0 upto nn - 1) frozen_Ts ctr_sugars;
    val recs = map (fst o dest_Const o co_rec_of) co_iterss;
    val rec_thms = maps co_rec_of iter_thmsss;

    (* Builds the old package's info record for one datatype. Fields shared by the whole family
       (descr, induct, inducts, rec_names, rec_rewrites) are computed above. *)
    fun mk_info ({T = Type (T_name0, _), index, ...} : fp_sugar) =
      let
        val {casex, exhaust, nchotomy, injects, distincts, case_thms, case_cong, weak_case_cong,
          split, split_asm, ...} = nth ctr_sugars index;
      in
        (T_name0,
         {index = index, descr = descr, inject = injects, distinct = distincts, induct = induct,
         inducts = inducts, exhaust = exhaust, nchotomy = nchotomy, rec_names = recs,
         rec_rewrites = rec_thms, case_name = fst (dest_Const casex), case_rewrites = case_thms,
         case_cong = case_cong, weak_case_cong = weak_case_cong, split = split,
         split_asm = split_asm})
      end;

    (* Only the requested mutual group is registered, not the auxiliary nested types. *)
    val infos = map mk_info (take nn_fp fp_sugars);

    (* Theorems are noted only if mutualization took place. *)
    val all_notes =
      (case lfp_sugar_thms of
        NONE => []
      | SOME ((induct_thms, induct_thm, induct_attrs), (fold_thmss, rec_thmss, _)) =>
        let
          val common_notes =
            (if nn > 1 then [(inductN, [induct_thm], induct_attrs)] else [])
            |> filter_out (null o #2)
            |> map (fn (thmN, thms, attrs) =>
              ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));

          val notes =
            [(foldN, fold_thmss, []),
             (inductN, map single induct_thms, induct_attrs),
             (recN, rec_thmss, [])]
            |> filter_out (null o #2)
            |> maps (fn (thmN, thmss, attrs) =>
              if forall null thmss then
                []
              else
                map2 (fn b_name => fn thms =>
                    ((Binding.qualify true b_name (Binding.name thmN), attrs), [(thms, [])]))
                  compat_b_names thmss);
        in
          common_notes @ notes
        end);

    val register_interpret =
      Datatype_Data.register infos
      #> Datatype_Data.interpretation_data (Datatype_Aux.default_config, map fst infos)
  in
    lthy
    |> Local_Theory.raw_theory register_interpret
    |> Local_Theory.notes all_notes |> snd
  end;

val _ =
  Outer_Syntax.local_theory @{command_spec "datatype_new_compat"}
    "register new-style datatypes as old-style datatypes"
    (Scan.repeat1 Parse.type_const >> datatype_new_compat_cmd);

end;