src/Pure/Tools/class_package.ML
changeset 18702 7dc7dcd63224
parent 18670 c3f445b92aff
child 18708 4b3dadb4fe33
equal deleted inserted replaced
18701:98e6a0a011f3 18702:7dc7dcd63224
    16     -> theory -> Proof.state
    16     -> theory -> Proof.state
    17   val add_instance_arity_i: (string * sort list) * sort
    17   val add_instance_arity_i: (string * sort list) * sort
    18     -> ((bstring * term) * theory attribute list) list
    18     -> ((bstring * term) * theory attribute list) list
    19     -> theory -> Proof.state
    19     -> theory -> Proof.state
    20   val add_classentry: class -> xstring list -> xstring list -> theory -> theory
    20   val add_classentry: class -> xstring list -> xstring list -> theory -> theory
    21   val the_consts: theory -> class -> string list
    21 
    22   val the_tycos: theory -> class -> (string * string) list
    22   val syntactic_sort_of: theory -> sort -> sort
       
    23   val the_superclasses: theory -> class -> class list
       
    24   val the_consts_sign: theory -> class -> string * (string * typ) list
       
    25   val lookup_const_class: theory -> string -> class option
       
    26   val the_instances: theory -> class -> (string * string) list
       
    27   val the_inst_sign: theory -> class * string -> (string * sort) list * (string * typ) list
       
    28   val get_classtab: theory -> (string list * (string * string) list) Symtab.table
    23   val print_classes: theory -> unit
    29   val print_classes: theory -> unit
    24 
       
    25   val syntactic_sort_of: theory -> sort -> sort
       
    26   val get_arities: theory -> sort -> string -> sort list
       
    27   val get_superclasses: theory -> class -> class list
       
    28   val get_const_sign: theory -> string -> string -> typ
       
    29   val get_inst_consts_sign: theory -> string * class -> (string * typ) list
       
    30   val lookup_const_class: theory -> string -> class option
       
    31   val get_classtab: theory -> (string list * (string * string) list) Symtab.table
       
    32 
    30 
    33   type sortcontext = (string * sort) list
    31   type sortcontext = (string * sort) list
    34   datatype sortlookup = Instance of (class * string) * sortlookup list list
    32   datatype sortlookup = Instance of (class * string) * sortlookup list list
    35                       | Lookup of class list * (string * int)
    33                       | Lookup of class list * (string * int)
    36   val extract_sortctxt: theory -> typ -> sortcontext
    34   val extract_sortctxt: theory -> typ -> sortcontext
    37   val extract_sortlookup: theory -> typ * typ -> sortlookup list list
    35   val extract_sortlookup: theory -> string * typ -> sortlookup list list
    38 end;
    36 end;
    39 
    37 
    40 structure ClassPackage: CLASS_PACKAGE =
    38 structure ClassPackage: CLASS_PACKAGE =
    41 struct
    39 struct
    42 
    40 
   124            var = var,
   122            var = var,
   125            consts = consts,
   123            consts = consts,
   126            insts = insts @ [inst]
   124            insts = insts @ [inst]
   127           });
   125           });
   128 
   126 
   129 val the_consts = map fst o #consts oo get_class_data;
       
   130 val the_tycos = #insts oo get_class_data;
       
   131 
       
   132 
   127 
   133 (* classes and instances *)
   128 (* classes and instances *)
   134 
   129 
       
   130 fun subst_clsvar v ty_subst =
       
   131   map_type_tfree (fn u as (w, _) =>
       
   132     if w = v then ty_subst else TFree u);
       
   133 
   135 local
   134 local
   136 
   135 
   137 open Element
   136 open Element
   138 
   137 
   139 fun gen_add_class add_locale bname raw_import raw_body thy =
   138 fun gen_add_class add_locale bname raw_import raw_body thy =
   140   let
   139   let
   141     fun subst_clsvar v ty_subst =
       
   142       map_type_tfree (fn u as (w, _) =>
       
   143         if w = v then ty_subst else TFree u);
       
   144     fun extract_assumes c_adds elems =
   140     fun extract_assumes c_adds elems =
   145       let
   141       let
   146         fun subst_free ts =
   142         fun subst_free ts =
   147           let
   143           let
   148             val get_ty = the o AList.lookup (op =) (fold Term.add_frees ts []);
   144             val get_ty = the o AList.lookup (op =) (fold Term.add_frees ts []);
   238         |> pair (c, ty1)
   234         |> pair (c, ty1)
   239       end;
   235       end;
   240     fun get_c_given thy = map (fst o dest_def o snd o tap_def thy o fst) raw_defs;
   236     fun get_c_given thy = map (fst o dest_def o snd o tap_def thy o fst) raw_defs;
   241     fun check_defs c_given c_req thy =
   237     fun check_defs c_given c_req thy =
   242       let
   238       let
   243         fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2 andalso Sign.typ_instance thy (ty1, ty2)
   239         fun eq_c ((c1, ty1), (c2, ty2)) = c1 = c2
       
   240           andalso Sign.typ_instance thy (ty1, ty2)
       
   241           andalso Sign.typ_instance thy (ty2, ty1)
   244         val _ = case fold (remove eq_c) c_given c_req
   242         val _ = case fold (remove eq_c) c_given c_req
   245          of [] => ()
   243          of [] => ()
   246           | cs => error ("no definition(s) given for"
   244           | cs => error ("no definition(s) given for"
   247                     ^ (commas o map (fn (c, ty) => quote (c ^ "::" ^ Sign.string_of_typ thy ty))) cs);
   245                     ^ (commas o map (fn (c, ty) => quote (c ^ "::" ^ Sign.string_of_typ thy ty))) cs);
   248         val _ = case fold (remove eq_c) c_req c_given
   246         val _ = case fold (remove eq_c) c_req c_given
   261 
   259 
   262 val add_instance_arity = fn x => gen_instance_arity (AxClass.read_arity) IsarThy.add_defs read_axm x;
   260 val add_instance_arity = fn x => gen_instance_arity (AxClass.read_arity) IsarThy.add_defs read_axm x;
   263 val add_instance_arity_i = fn x => gen_instance_arity (AxClass.cert_arity) IsarThy.add_defs_i (K I) x;
   261 val add_instance_arity_i = fn x => gen_instance_arity (AxClass.cert_arity) IsarThy.add_defs_i (K I) x;
   264 
   262 
   265 
   263 
   266 (* class queries *)
   264 (* queries *)
   267 
   265 
   268 fun is_class thy cls = lookup_class_data thy cls |> Option.map (not o null o #consts) |> the_default false;
   266 fun is_class thy cls =
       
   267   lookup_class_data thy cls
       
   268   |> Option.map (not o null o #consts)
       
   269   |> the_default false;
   269 
   270 
   270 fun syntactic_sort_of thy sort =
   271 fun syntactic_sort_of thy sort =
   271   let
   272   let
   272     val classes = Sign.classes_of thy;
   273     val classes = Sign.classes_of thy;
   273     fun get_sort cls =
   274     fun get_sort cls =
   278     map get_sort sort
   279     map get_sort sort
   279     |> Library.flat
   280     |> Library.flat
   280     |> Sorts.norm_sort classes
   281     |> Sorts.norm_sort classes
   281   end;
   282   end;
   282 
   283 
   283 fun get_arities thy sort tycon =
   284 fun the_superclasses thy class =
   284   Sorts.mg_domain (Sign.classes_arities_of thy) tycon (syntactic_sort_of thy sort)
       
   285   |> map (syntactic_sort_of thy);
       
   286 
       
   287 fun get_superclasses thy class =
       
   288   if is_class thy class
   285   if is_class thy class
   289   then
   286   then
   290     Sorts.superclasses (Sign.classes_of thy) class
   287     Sorts.superclasses (Sign.classes_of thy) class
   291     |> syntactic_sort_of thy
   288     |> syntactic_sort_of thy
   292   else
   289   else
   293     error ("no syntactic class: " ^ class);
   290     error ("no syntactic class: " ^ class);
   294 
   291 
   295 
   292 fun the_consts_sign thy class =
   296 (* instance queries *)
   293   let
   297 
   294     val data = (the oo Symtab.lookup) ((fst o ClassData.get) thy) class
   298 fun mk_const_sign thy class tvar ty =
   295   in (#var data, #consts data) end;
   299   let
   296 
   300     val (ty', thaw) = Type.freeze_thaw_type ty;
   297 fun lookup_const_class thy =
   301     val tvars_used = Term.add_tfreesT ty' [];
   298   Symtab.lookup ((snd o ClassData.get) thy);
   302     val tvar_rename = hd (Term.invent_names (map fst tvars_used) tvar 1);
   299 
   303   in
   300 fun the_instances thy class =
   304     ty'
   301   (#insts o the o Symtab.lookup ((fst o ClassData.get) thy)) class;
   305     |> map_type_tfree (fn (tvar', sort) =>
   302 
   306           if Sorts.sort_eq (Sign.classes_of thy) ([class], sort)
   303 fun the_inst_sign thy (class, tyco) =
   307           then TFree (tvar, [])
   304   let
   308           else if tvar' = tvar
   305     val _ = if is_class thy class then () else error ("no syntactic class: " ^ class);
   309           then TVar ((tvar_rename, 0), sort)
   306     val arity = 
   310           else TFree (tvar', sort))
   307       Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class]
   311     |> thaw
   308       |> map (syntactic_sort_of thy);
   312   end;
   309     val clsvar = (#var o the o Symtab.lookup ((fst o ClassData.get) thy)) class;
   313 
   310     val const_sign = (snd o the_consts_sign thy) class;
   314 fun get_const_sign thy tvar const =
   311     fun add_var sort used =
   315   let
   312       let
   316     val class = (the o lookup_const_class thy) const;
   313         val v = hd (Term.invent_names used "'a" 1)
   317     val ty = Sign.the_const_constraint thy const;
   314       in ((v, sort), v::used) end;
   318   in mk_const_sign thy class tvar ty end;
   315     val (vsorts, _) =
   319 
   316       []
   320 fun get_inst_consts_sign thy (tyco, class) =
   317       |> fold (fn (_, ty) => curry (gen_union (op =))
   321   let
   318            ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) const_sign
   322     val consts = the_consts thy class;
   319       |> fold_map add_var arity;
   323     val arities = get_arities thy [class] tyco;
   320     val ty_inst = Type (tyco, map (fn (v, sort) => TVar ((v, 0), sort)) vsorts);
   324     val const_signs = map (get_const_sign thy "'a") consts;
   321     val inst_signs = map (apsnd (subst_clsvar clsvar ty_inst)) const_sign;
   325     val vars_used = fold (fn ty => curry (gen_union (op =))
   322   in (vsorts, inst_signs) end;
   326       (map fst (typ_tfrees ty) |> remove (op =) "'a")) const_signs [];
       
   327     val vars_new = Term.invent_names vars_used "'a" (length arities);
       
   328     val typ_arity = Type (tyco, map2 (curry TFree) vars_new arities);
       
   329     val instmem_signs =
       
   330       map (typ_subst_TVars [(("'a", 0), typ_arity)]) const_signs;
       
   331   in consts ~~ instmem_signs end;
       
   332 
   323 
   333 fun get_classtab thy =
   324 fun get_classtab thy =
   334   Symtab.fold
   325   Symtab.fold
   335     (fn (class, { consts = consts, insts = insts, ... }) =>
   326     (fn (class, { consts = consts, insts = insts, ... }) =>
   336       Symtab.update_new (class, (map fst consts, insts)))
   327       Symtab.update_new (class, (map fst consts, insts)))
   337        (fst (ClassData.get thy)) Symtab.empty;
   328        ((fst o ClassData.get) thy) Symtab.empty;
   338 
   329 
   339 
   330 
   340 (* extracting dictionary obligations from types *)
   331 (* extracting dictionary obligations from types *)
   341 
   332 
   342 type sortcontext = (string * sort) list;
   333 type sortcontext = (string * sort) list;
   343 
   334 
   344 fun extract_sortctxt thy ty =
   335 fun extract_sortctxt thy ty =
   345   (typ_tfrees o Type.no_tvars) ty
   336   (typ_tfrees o fst o Type.freeze_thaw_type) ty
   346   |> map (apsnd (syntactic_sort_of thy))
   337   |> map (apsnd (syntactic_sort_of thy))
   347   |> filter (not o null o snd);
   338   |> filter (not o null o snd);
   348 
   339 
   349 datatype sortlookup = Instance of (class * string) * sortlookup list list
   340 datatype sortlookup = Instance of (class * string) * sortlookup list list
   350                     | Lookup of class list * (string * int)
   341                     | Lookup of class list * (string * int)
   351 
   342 
   352 fun extract_sortlookup thy (raw_typ_def, raw_typ_use) =
   343 fun extract_sortlookup thy (c, raw_typ_use) =
   353   let
   344   let
       
   345     val raw_typ_def = Sign.the_const_constraint thy c;
   354     val typ_def = Type.varifyT raw_typ_def;
   346     val typ_def = Type.varifyT raw_typ_def;
   355     val typ_use = Type.varifyT raw_typ_use;
   347     val typ_use = Type.varifyT raw_typ_use;
   356     val match_tab = Sign.typ_match thy (typ_def, typ_use) Vartab.empty;
   348     val match_tab = Sign.typ_match thy (typ_def, typ_use) Vartab.empty;
   357     fun tab_lookup vname = (the o Vartab.lookup match_tab) (vname, 0);
   349     fun tab_lookup vname = (the o Vartab.lookup match_tab) (vname, 0);
   358     fun get_superclass_derivation (subclasses, superclass) =
   350     fun get_superclass_derivation (subclasses, superclass) =
   372           let
   364           let
   373             fun mk_look class =
   365             fun mk_look class =
   374               let val (deriv, classindex) = mk_class_deriv thy (syntactic_sort_of thy sort_use) class
   366               let val (deriv, classindex) = mk_class_deriv thy (syntactic_sort_of thy sort_use) class
   375               in Lookup (deriv, (vname, classindex)) end;
   367               in Lookup (deriv, (vname, classindex)) end;
   376           in map mk_look sort_def end;
   368           in map mk_look sort_def end;
       
   369     fun reorder_sortctxt ctxt =
       
   370       case lookup_const_class thy c
       
   371        of NONE => ctxt
       
   372         | SOME class =>
       
   373             let
       
   374               val data = (the o Symtab.lookup ((fst o ClassData.get) thy)) class;
       
   375               val sign = (Type.varifyT o the o AList.lookup (op =) (#consts data)) c;
       
   376               val match_tab = Sign.typ_match thy (sign, typ_def) Vartab.empty;
       
   377               val v : string = case Vartab.lookup match_tab (#var data, 0)
       
   378                 of SOME (_, TVar ((v, _), _)) => v;
       
   379             in
       
   380               (v, (the o AList.lookup (op =) ctxt) v) :: AList.delete (op =) v ctxt
       
   381             end;
   377   in
   382   in
   378     extract_sortctxt thy ((fst o Type.freeze_thaw_type) raw_typ_def)
   383     extract_sortctxt thy ((fst o Type.freeze_thaw_type) raw_typ_def)
       
   384     |> reorder_sortctxt
   379     |> map (tab_lookup o fst)
   385     |> map (tab_lookup o fst)
   380     |> map (apfst (syntactic_sort_of thy))
   386     |> map (apfst (syntactic_sort_of thy))
   381     |> filter (not o null o fst)
   387     |> filter (not o null o fst)
   382     |> map mk_lookup
   388     |> map mk_lookup
   383   end;
   389   end;
   386 (* intermediate auxiliary *)
   392 (* intermediate auxiliary *)
   387 
   393 
   388 fun add_classentry raw_class raw_cs raw_insts thy =
   394 fun add_classentry raw_class raw_cs raw_insts thy =
   389   let
   395   let
   390     val class = Sign.intern_class thy raw_class;
   396     val class = Sign.intern_class thy raw_class;
   391     val cs = raw_cs |> map (Sign.intern_const thy);
   397     val cs_proto =
       
   398       raw_cs
       
   399       |> map (Sign.intern_const thy)
       
   400       |> map (fn c => (c, Sign.the_const_constraint thy c));
       
   401     val used = 
       
   402       []
       
   403       |> fold (fn (_, ty) => curry (gen_union (op =))
       
   404            ((map (fst o fst) o typ_tvars) ty @ (map fst o typ_tfrees) ty)) cs_proto
       
   405     val v = hd (Term.invent_names used "'a" 1);
       
   406     val cs =
       
   407       cs_proto
       
   408       |> map (fn (c, ty) => (c, map_type_tvar (fn var as ((tvar', _), sort) =>
       
   409           if Sorts.sort_eq (Sign.classes_of thy) ([class], sort)
       
   410           then TFree (v, [])
       
   411           else TVar var
       
   412          ) ty));
   392     val insts = map (rpair (Context.theory_name thy) o Sign.intern_type thy) raw_insts;
   413     val insts = map (rpair (Context.theory_name thy) o Sign.intern_type thy) raw_insts;
   393   in
   414   in
   394     thy
   415     thy
   395     |> add_class_data (class, ([], "", class, "", map (rpair dummyT) cs))
   416     |> add_class_data (class, ([], "", class, v, cs))
   396     |> fold (curry add_inst_data class) insts
   417     |> fold (curry add_inst_data class) insts
   397   end;
   418   end;
   398 
   419 
   399 
   420 
   400 (* toplevel interface *)
   421 (* toplevel interface *)