class_package - operational view on type classes
authorhaftmann
Mon, 14 Nov 2005 15:15:34 +0100
changeset 18168d35daf321b8a
parent 18167 4f9410e685df
child 18169 45def66f86cb
class_package - operational view on type classes
src/Pure/Tools/class_package.ML
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/Pure/Tools/class_package.ML	Mon Nov 14 15:15:34 2005 +0100
     1.3 @@ -0,0 +1,290 @@
     1.4 +(*  Title:      Pure/Tools/class_package.ML
     1.5 +    ID:         $Id$
     1.6 +    Author:     Florian Haftmann, TU Muenchen
     1.7 +
     1.8 +Haskell98-like operational view on type classes.
     1.9 +*)
    1.10 +
    1.11 +signature CLASS_PACKAGE =
    1.12 +sig
    1.13 +  val add_consts: class * xstring list -> theory -> theory
    1.14 +  val add_consts_i: class * string list -> theory -> theory
    1.15 +  val add_tycos: class * xstring list -> theory -> theory
    1.16 +  val add_tycos_i: class * (string * string) list -> theory -> theory
    1.17 +  val the_consts: theory -> class -> string list
    1.18 +  val the_tycos: theory -> class -> (string * string) list
    1.19 +
    1.20 +  val is_class: theory -> class -> bool
    1.21 +  val get_arities: theory -> sort -> string -> sort list
    1.22 +  val get_superclasses: theory -> class -> class list
    1.23 +  val get_const_sign: theory -> string -> string * typ
    1.24 +  val get_inst_consts_sign: theory -> string * class -> (string * typ) list
    1.25 +  val lookup_const_class: theory -> string -> class option
    1.26 +  val get_classtab: theory -> (string list * (string * string) list) Symtab.table
    1.27 +
    1.28 +  type sortcontext = (string * sort) list
    1.29 +  datatype sortlookup = Instance of (class * string) * sortlookup list list
    1.30 +                      | Lookup of class list * (string * int)
    1.31 +  val extract_sortctxt: theory -> typ -> sortcontext
    1.32 +  val extract_sortlookup: theory -> typ * typ -> sortlookup list list
    1.33 +end;
    1.34 +
    1.35 +structure ClassPackage: CLASS_PACKAGE =
    1.36 +struct
    1.37 +
    1.38 +
    1.39 +(* data kind 'Pure/classes' *)
    1.40 +
    1.41 +type class_data = {
    1.42 +  locale_name: string,
    1.43 +  axclass_name: string,
    1.44 +  consts: string list,
    1.45 +  tycos: (string * string) list
    1.46 +};
    1.47 +
    1.48 +structure ClassesData = TheoryDataFun (
    1.49 +  struct
    1.50 +    val name = "Pure/classes";
    1.51 +    type T = class_data Symtab.table * class Symtab.table;
    1.52 +    val empty = (Symtab.empty, Symtab.empty);
    1.53 +    val copy = I;
    1.54 +    val extend = I;
    1.55 +    fun merge _ ((t1, r1), (t2, r2))=
    1.56 +      (Symtab.merge (op =) (t1, t2),
    1.57 +       Symtab.merge (op =) (r1, r2));
    1.58 +    fun print _ (tab, _) = (Pretty.writeln o Pretty.chunks) (map Pretty.str (Symtab.keys tab));
    1.59 +  end
    1.60 +);
    1.61 +
    1.62 +val lookup_class_data = Symtab.lookup o fst o ClassesData.get;
    1.63 +val lookup_const_class = Symtab.lookup o snd o ClassesData.get;
    1.64 +
    1.65 +fun get_class_data thy class =
    1.66 +  case lookup_class_data thy class
    1.67 +    of NONE => error ("undeclared class " ^ quote class)
    1.68 +     | SOME data => data;
    1.69 +
    1.70 +fun put_class_data class data =
    1.71 +  ClassesData.map (apfst (Symtab.update (class, data)));
    1.72 +fun add_const class const =
    1.73 +  ClassesData.map (apsnd (Symtab.update (const, class)));
    1.74 +
    1.75 +
    1.76 +(* name mangling *)
    1.77 +
    1.78 +fun get_locale_for_class thy class =
    1.79 +  #locale_name (get_class_data thy class);
    1.80 +
    1.81 +fun get_axclass_for_class thy class =
    1.82 +  #axclass_name (get_class_data thy class);
    1.83 +
    1.84 +
    1.85 +(* assign consts to type classes *)
    1.86 +
    1.87 +local
    1.88 +
    1.89 +fun gen_add_consts prep_class prep_const (raw_class, raw_consts_new) thy =
    1.90 +  let
    1.91 +    val class = prep_class thy raw_class;
    1.92 +    val consts_new = map (prep_const thy) raw_consts_new;
    1.93 +    val {locale_name, axclass_name, consts, tycos} =
    1.94 +      get_class_data thy class;
    1.95 +  in
    1.96 +    thy
    1.97 +    |> put_class_data class {
    1.98 +         locale_name = locale_name,
    1.99 +         axclass_name = axclass_name,
   1.100 +         consts = consts @ consts_new,
   1.101 +         tycos = tycos
   1.102 +       }
   1.103 +    |> fold (add_const class) consts_new
   1.104 +  end;
   1.105 +
   1.106 +in
   1.107 +
   1.108 +val add_consts = gen_add_consts Sign.intern_class Sign.intern_const;
   1.109 +val add_consts_i = gen_add_consts (K I) (K I);
   1.110 +
   1.111 +end; (* local *)
   1.112 +
   1.113 +val the_consts = #consts oo get_class_data;
   1.114 +
   1.115 +
   1.116 +(* assign type constructors to type classes *)
   1.117 +
   1.118 +local
   1.119 +
   1.120 +fun gen_add_tycos prep_class prep_type (raw_class, raw_tycos_new) thy =
   1.121 +  let
   1.122 +    val class = prep_class thy raw_class
   1.123 +    val tycos_new = map (prep_type thy) raw_tycos_new
   1.124 +    val {locale_name, axclass_name, consts, tycos} =
   1.125 +      get_class_data thy class
   1.126 +  in
   1.127 +    thy
   1.128 +    |> put_class_data class {
   1.129 +         locale_name = locale_name,
   1.130 +         axclass_name = axclass_name,
   1.131 +         consts = consts,
   1.132 +         tycos = tycos @ tycos_new
   1.133 +       }
   1.134 +  end;
   1.135 +
   1.136 +in
   1.137 +
   1.138 +fun add_tycos xs thy =
   1.139 +  gen_add_tycos Sign.intern_class (rpair (Context.theory_name thy) oo Sign.intern_type) xs thy;
   1.140 +val add_tycos_i = gen_add_tycos (K I) (K I);
   1.141 +
   1.142 +end; (* local *)
   1.143 +
   1.144 +val the_tycos = #tycos oo get_class_data;
   1.145 +
   1.146 +
   1.147 +(* class queries *)
   1.148 +
   1.149 +fun is_class thy = is_some o lookup_class_data thy;
   1.150 +
   1.151 +fun filter_class thy = filter (is_class thy);
   1.152 +
   1.153 +fun assert_class thy class =
   1.154 +  if is_class thy class then class
   1.155 +  else error ("not a class: " ^ quote class);
   1.156 +
   1.157 +fun get_arities thy sort tycon =
   1.158 +  Sorts.mg_domain (Sign.classes_arities_of thy) tycon sort
   1.159 +  |> (map o map) (assert_class thy);
   1.160 +
   1.161 +fun get_superclasses thy class =
   1.162 +  Sorts.superclasses (Sign.classes_of thy) class
   1.163 +  |> filter_class thy;
   1.164 +
   1.165 +
   1.166 +(* instance queries *)
   1.167 +
   1.168 +fun get_const_sign thy const =
   1.169 +  let
   1.170 +    val class = (the o lookup_const_class thy) const;
   1.171 +    val ty = (Type.unvarifyT o Sign.the_const_constraint thy) const;
   1.172 +    val tvar = fold_atyps
   1.173 +      (fn TFree (tvar, sort) =>
   1.174 +        if Sorts.sort_eq (Sign.classes_of thy) ([class], sort) then K (SOME tvar) else I | _ => I) ty NONE
   1.175 +      |> the;
   1.176 +    val ty' = map_type_tfree (fn (tvar', sort) =>
   1.177 +        if tvar' = tvar
   1.178 +        then TFree (tvar, [])
   1.179 +        else TFree (tvar', sort)
   1.180 +      ) ty;
   1.181 +  in (tvar, ty') end;
   1.182 +
   1.183 +fun get_inst_consts_sign thy (tyco, class) =
   1.184 +  let
   1.185 +    val consts = the_consts thy class;
   1.186 +    val arities = get_arities thy [class] tyco;
   1.187 +    val const_signs = map (get_const_sign thy) consts;
   1.188 +    val vars_used = fold (fn (tvar, ty) => curry (gen_union (op =))
   1.189 +      (map fst (typ_tfrees ty) |> remove (op =) tvar)) const_signs [];
   1.190 +    val vars_new = Term.invent_names vars_used "'a" (length arities);
   1.191 +    val typ_arity = Type (tyco, map2 TFree (vars_new, arities));
   1.192 +    val instmem_signs =
   1.193 +      map (fn (tvar, ty) => typ_subst_atomic [(TFree (tvar, []), typ_arity)] ty) const_signs;
   1.194 +  in consts ~~ instmem_signs end;
   1.195 +
   1.196 +fun get_classtab thy =
   1.197 +  Symtab.fold
   1.198 +    (fn (class, { consts = consts, tycos = tycos, ... }) =>
   1.199 +      Symtab.update_new (class, (consts, tycos)))
   1.200 +       (fst (ClassesData.get thy)) Symtab.empty;
   1.201 +
   1.202 +
   1.203 +(* extracting dictionary obligations from types *)
   1.204 +
   1.205 +type sortcontext = (string * sort) list;
   1.206 +
   1.207 +fun extract_sortctxt thy typ =
   1.208 +  (typ_tfrees o Type.unvarifyT) typ
   1.209 +  |> map (apsnd (filter_class thy))
   1.210 +  |> filter (not o null o snd);
   1.211 +
   1.212 +datatype sortlookup = Instance of (class * string) * sortlookup list list
   1.213 +                    | Lookup of class list * (string * int)
   1.214 +
   1.215 +fun extract_sortlookup thy (raw_typ_def, raw_typ_use) =
   1.216 +  let
   1.217 +    val typ_def = Type.varifyT raw_typ_def;
   1.218 +    val typ_use = Type.varifyT raw_typ_use;
   1.219 +    val match_tab = Sign.typ_match thy (typ_def, typ_use) Vartab.empty;
   1.220 +    fun tab_lookup vname = (the o Vartab.lookup match_tab) (vname, 0);
   1.221 +    fun get_superclass_derivation (subclasses, superclass) =
   1.222 +      (the oo get_first) (fn subclass =>
   1.223 +        Sorts.class_le_path (Sign.classes_of thy) (subclass, superclass)
   1.224 +      ) subclasses;
   1.225 +    fun mk_class_deriv thy subclasses superclass =
   1.226 +      case get_superclass_derivation (subclasses, superclass)
   1.227 +      of (subclass::deriv) => (rev deriv, find_index_eq subclass subclasses);
   1.228 +    fun mk_lookup (sort_def, (Type (tycon, tys))) =
   1.229 +          let
   1.230 +            val arity_lookup = map2 mk_lookup
   1.231 +              (map (filter_class thy) (Sorts.mg_domain (Sign.classes_arities_of thy) tycon sort_def), tys)
   1.232 +          in map (fn class => Instance ((class, tycon), arity_lookup)) sort_def end
   1.233 +      | mk_lookup (sort_def, TVar ((vname, _), sort_use)) =
   1.234 +          let
   1.235 +            fun mk_look class =
   1.236 +              let val (deriv, classindex) = mk_class_deriv thy sort_use class
   1.237 +              in Lookup (deriv, (vname, classindex)) end;
   1.238 +          in map mk_look sort_def end;
   1.239 +  in
   1.240 +    extract_sortctxt thy raw_typ_def
   1.241 +    |> map (tab_lookup o fst)
   1.242 +    |> map (apfst (filter_class thy))
   1.243 +    |> filter (not o null o fst)
   1.244 +    |> map mk_lookup
   1.245 +  end;
   1.246 +
   1.247 +
   1.248 +(* outer syntax *)
   1.249 +
   1.250 +local
   1.251 +
   1.252 +structure P = OuterParse
   1.253 +and K = OuterKeyword;
   1.254 +
   1.255 +in
   1.256 +
   1.257 +val classcgK = "codegen_class";
   1.258 +
   1.259 +fun classcg raw_class raw_consts raw_tycos thy =
   1.260 +  let
   1.261 +    val class = Sign.intern_class thy raw_class;
   1.262 +  in
   1.263 +    thy
   1.264 +    |> put_class_data class {
   1.265 +         locale_name = "",
   1.266 +         axclass_name = class,
   1.267 +         consts = [],
   1.268 +         tycos = []
   1.269 +       }
   1.270 +    |> add_consts (class, raw_consts)
   1.271 +    |> add_tycos (class, raw_tycos)
   1.272 +  end
   1.273 +
   1.274 +val classcgP =
   1.275 +  OuterSyntax.command classcgK "codegen data for classes" K.thy_decl (
   1.276 +    P.xname
   1.277 +    -- ((P.$$$ "\\<Rightarrow>" || P.$$$ "=>") |-- (P.list1 P.name))
   1.278 +    -- (Scan.optional ((P.$$$ "\\<Rightarrow>" || P.$$$ "=>") |-- (P.list1 P.name)) [])
   1.279 +    >> (fn ((name, tycos), consts) => (Toplevel.theory (classcg name consts tycos)))
   1.280 +  )
   1.281 +
   1.282 +val _ = OuterSyntax.add_parsers [classcgP];
   1.283 +
   1.284 +val _ = OuterSyntax.add_keywords ["\\<Rightarrow>", "=>"];
   1.285 +
   1.286 +end; (* local *)
   1.287 +
   1.288 +
   1.289 +(* setup *)
   1.290 +
   1.291 +val _ = Context.add_setup [ClassesData.init];
   1.292 +
   1.293 +end; (* struct *)