src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
changeset 50313 36e551d3af3b
parent 50312 47fbf2e3e89c
child 50315 c707df2e2083
     1.1 --- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 22:13:22 2012 +0200
     1.2 +++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 22:31:43 2012 +0200
     1.3 @@ -8,7 +8,7 @@
     1.4  signature BNF_FP_SUGAR =
     1.5  sig
     1.6    val datatyp: bool ->
     1.7 -    bool * ((((typ * typ option) list * binding) * mixfix) * ((((binding * binding) *
     1.8 +    bool * ((((typ * sort) list * binding) * mixfix) * ((((binding * binding) *
     1.9        (binding * typ) list) * (binding * term) list) * mixfix) list) list ->
    1.10      local_theory -> local_theory
    1.11  end;
    1.12 @@ -47,7 +47,9 @@
    1.13      | SOME T' => T')
    1.14    | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
    1.15  
    1.16 -fun retype_free (Free (s, _)) T = Free (s, T);
    1.17 +fun resort_tfree S (TFree (s, _)) = TFree (s, S);
    1.18 +
    1.19 +fun retype_free T (Free (s, _)) = Free (s, T);
    1.20  
    1.21  val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
    1.22  
    1.23 @@ -69,23 +71,10 @@
    1.24  
    1.25  fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
    1.26  
    1.27 -fun merge_type_arg_constrained ctxt (T, c) (T', c') =
    1.28 -  if T = T' then
    1.29 -    (case (c, c') of
    1.30 -      (_, NONE) => (T, c)
    1.31 -    | (NONE, _) => (T, c')
    1.32 -    | _ =>
    1.33 -      if c = c' then
    1.34 -        (T, c)
    1.35 -      else
    1.36 -        error ("Inconsistent sort constraints for type variable " ^
    1.37 -          quote (Syntax.string_of_typ ctxt T)))
    1.38 -  else
    1.39 -    cannot_merge_types ();
    1.40 +fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
    1.41  
    1.42 -fun merge_type_args_constrained ctxt (cAs, cAs') =
    1.43 -  if length cAs = length cAs' then map2 (merge_type_arg_constrained ctxt) cAs cAs'
    1.44 -  else cannot_merge_types ();
    1.45 +fun merge_type_args (As, As') =
    1.46 +  if length As = length As' then map2 merge_type_arg As As' else cannot_merge_types ();
    1.47  
    1.48  fun type_args_constrained_of (((cAs, _), _), _) = cAs;
    1.49  val type_args_of = map fst o type_args_constrained_of;
    1.50 @@ -99,31 +88,45 @@
    1.51  fun defaults_of ((_, ds), _) = ds;
    1.52  fun ctr_mixfix_of (_, mx) = mx;
    1.53  
    1.54 -fun prepare_datatype prepare_typ prepare_term lfp (no_dests, specs) fake_lthy no_defs_lthy =
    1.55 +fun define_datatype prepare_constraint prepare_typ prepare_term lfp (no_dests, specs)
    1.56 +    no_defs_lthy0 =
    1.57    let
    1.58 +    (* TODO: sanity checks on arguments *)
    1.59 +
    1.60      val _ = if not lfp andalso no_dests then error "Cannot define destructor-less codatatypes"
    1.61        else ();
    1.62  
    1.63 -    val constrained_As =
    1.64 -      map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs
    1.65 -      |> Library.foldr1 (merge_type_args_constrained no_defs_lthy);
    1.66 -    val As = map fst constrained_As;
    1.67 -    val As' = map dest_TFree As;
    1.68 +    val N = length specs;
    1.69  
    1.70 -    val _ = (case duplicates (op =) As of [] => ()
    1.71 -      | A :: _ => error ("Duplicate type parameter " ^
    1.72 -          quote (Syntax.string_of_typ no_defs_lthy A)));
    1.73 +    fun prepare_type_arg (ty, c) =
    1.74 +      let val TFree (s, _) = prepare_typ no_defs_lthy0 ty in
    1.75 +        TFree (s, prepare_constraint no_defs_lthy0 c)
    1.76 +      end;
    1.77  
    1.78 -    (* TODO: use sort constraints on type args *)
    1.79 +    val Ass0 = map (map prepare_type_arg o type_args_constrained_of) specs;
    1.80 +    val unsorted_Ass0 = map (map (resort_tfree HOLogic.typeS)) Ass0;
    1.81 +    val unsorted_As = Library.foldr1 merge_type_args unsorted_Ass0;
    1.82  
    1.83 -    val N = length specs;
    1.84 +    val ((Bs, Cs), no_defs_lthy) =
    1.85 +      no_defs_lthy0
    1.86 +      |> fold (Variable.declare_typ o resort_tfree dummyS) unsorted_As
    1.87 +      |> mk_TFrees N
    1.88 +      ||>> mk_TFrees N;
    1.89 +
    1.90 +    (* TODO: cleaner handling of fake contexts, without "background_theory" *)
    1.91 +    (*the "perhaps o try" below helps gracefully handles the case where the new type is defined in a
    1.92 +      locale and shadows an existing global type*)
    1.93 +    val fake_thy =
    1.94 +      Theory.copy #> fold (fn spec => perhaps (try (Sign.add_type no_defs_lthy
    1.95 +        (type_binder_of spec, length (type_args_constrained_of spec), mixfix_of spec)))) specs;
    1.96 +    val fake_lthy = Proof_Context.background_theory fake_thy no_defs_lthy;
    1.97  
    1.98      fun mk_fake_T b =
    1.99        Type (fst (Term.dest_Type (Proof_Context.read_type_name fake_lthy true (Binding.name_of b))),
   1.100 -        As);
   1.101 +        unsorted_As);
   1.102  
   1.103      val bs = map type_binder_of specs;
   1.104 -    val fakeTs = map mk_fake_T bs;
   1.105 +    val fake_Ts = map mk_fake_T bs;
   1.106  
   1.107      val mixfixes = map mixfix_of specs;
   1.108  
   1.109 @@ -138,39 +141,41 @@
   1.110      val ctr_mixfixess = map (map ctr_mixfix_of) ctr_specss;
   1.111  
   1.112      val sel_bindersss = map (map (map fst)) ctr_argsss;
   1.113 -    val fake_ctr_Tsss = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
   1.114 -
   1.115 +    val fake_ctr_Tsss0 = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
   1.116      val raw_sel_defaultsss = map (map defaults_of) ctr_specss;
   1.117  
   1.118 +    val (Ass as As :: _) :: fake_ctr_Tsss =
   1.119 +      burrow (burrow (Syntax.check_typs fake_lthy)) (Ass0 :: fake_ctr_Tsss0);
   1.120 +
   1.121 +    val _ = (case duplicates (op =) unsorted_As of [] => ()
   1.122 +      | A :: _ => error ("Duplicate type parameter " ^
   1.123 +          quote (Syntax.string_of_typ no_defs_lthy A)));
   1.124 +
   1.125      val rhs_As' = fold (fold (fold Term.add_tfreesT)) fake_ctr_Tsss [];
   1.126 -    val _ = (case subtract (op =) As' rhs_As' of
   1.127 +    val _ = (case subtract (op =) (map dest_TFree As) rhs_As' of
   1.128          [] => ()
   1.129        | A' :: _ => error ("Extra type variables on rhs: " ^
   1.130            quote (Syntax.string_of_typ no_defs_lthy (TFree A'))));
   1.131  
   1.132 -    val ((Bs, Cs), _) =
   1.133 -      no_defs_lthy
   1.134 -      |> fold Variable.declare_typ As
   1.135 -      |> mk_TFrees N
   1.136 -      ||>> mk_TFrees N;
   1.137 -
   1.138      fun eq_fpT (T as Type (s, Us)) (Type (s', Us')) =
   1.139          s = s' andalso (Us = Us' orelse error ("Illegal occurrence of recursive type " ^
   1.140            quote (Syntax.string_of_typ fake_lthy T)))
   1.141        | eq_fpT _ _ = false;
   1.142  
   1.143      fun freeze_fp (T as Type (s, Us)) =
   1.144 -        (case find_index (eq_fpT T) fakeTs of ~1 => Type (s, map freeze_fp Us) | j => nth Bs j)
   1.145 +        (case find_index (eq_fpT T) fake_Ts of ~1 => Type (s, map freeze_fp Us) | j => nth Bs j)
   1.146        | freeze_fp T = T;
   1.147  
   1.148      val ctr_TsssBs = map (map (map freeze_fp)) fake_ctr_Tsss;
   1.149      val ctr_sum_prod_TsBs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssBs;
   1.150  
   1.151 -    val eqs = map dest_TFree Bs ~~ ctr_sum_prod_TsBs;
   1.152 +    val fp_eqs =
   1.153 +      map dest_TFree Bs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctr_sum_prod_TsBs;
   1.154  
   1.155      val (pre_bnfs, ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects,
   1.156          fp_iter_thms, fp_rec_thms), lthy)) =
   1.157 -      fp_bnf (if lfp then bnf_lfp else bnf_gfp) bs mixfixes As' eqs no_defs_lthy;
   1.158 +      fp_bnf (if lfp then bnf_lfp else bnf_gfp) bs mixfixes (map dest_TFree unsorted_As) fp_eqs
   1.159 +        no_defs_lthy0;
   1.160  
   1.161      val add_nested_bnf_names =
   1.162        let
   1.163 @@ -245,8 +250,8 @@
   1.164                dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
   1.165            val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
   1.166  
   1.167 -          val hss = map2 (map2 retype_free) gss h_Tss;
   1.168 -          val zssss_hd = map2 (map2 (map2 (fn y => fn T :: _ => retype_free y T))) ysss z_Tssss;
   1.169 +          val hss = map2 (map2 retype_free) h_Tss gss;
   1.170 +          val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
   1.171            val (zssss_tl, _) =
   1.172              lthy
   1.173              |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
   1.174 @@ -296,7 +301,7 @@
   1.175  
   1.176            val (s_Tssss, h_sum_prod_Ts, h_Tssss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
   1.177  
   1.178 -          val hssss_hd = map2 (map2 (map2 (fn [g] => fn T :: _ => retype_free g T))) gssss h_Tssss;
   1.179 +          val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
   1.180            val ((sssss, hssss_tl), _) =
   1.181              lthy
   1.182              |> mk_Freessss "q" s_Tssss
   1.183 @@ -688,20 +693,9 @@
   1.184      (timer; lthy')
   1.185    end;
   1.186  
   1.187 -fun datatyp lfp bundle lthy = prepare_datatype (K I) (K I) lfp bundle lthy lthy;
   1.188 +val datatyp = define_datatype (K I) (K I) (K I);
   1.189  
   1.190 -fun datatype_cmd lfp (bundle as (_, specs)) lthy =
   1.191 -  let
   1.192 -    (* TODO: cleaner handling of fake contexts, without "background_theory" *)
   1.193 -    (*the "perhaps o try" below helps gracefully handles the case where the new type is defined in a
   1.194 -      locale and shadows an existing global type*)
   1.195 -    val fake_thy = Theory.copy
   1.196 -      #> fold (fn spec => perhaps (try (Sign.add_type lthy
   1.197 -        (type_binder_of spec, length (type_args_constrained_of spec), mixfix_of spec)))) specs;
   1.198 -    val fake_lthy = Proof_Context.background_theory fake_thy lthy;
   1.199 -  in
   1.200 -    prepare_datatype Syntax.read_typ Syntax.read_term lfp bundle fake_lthy lthy
   1.201 -  end;
   1.202 +val datatype_cmd = define_datatype Typedecl.read_constraint Syntax.parse_typ Syntax.read_term;
   1.203  
   1.204  val parse_opt_binding_colon = Scan.optional (Parse.binding --| @{keyword ":"}) no_binder
   1.205