Coercive subtyping via subtype constraints, by Dmitriy Traytel (21-Oct-2010).
1.1 --- a/src/HOL/IsaMakefile Fri Oct 29 18:17:11 2010 +0200
1.2 +++ b/src/HOL/IsaMakefile Fri Oct 29 21:34:07 2010 +0200
1.3 @@ -1012,8 +1012,8 @@
1.4 Number_Theory/Primes.thy ex/Abstract_NAT.thy ex/Antiquote.thy \
1.5 ex/Arith_Examples.thy ex/Arithmetic_Series_Complex.thy ex/BT.thy \
1.6 ex/BinEx.thy ex/Binary.thy ex/CTL.thy ex/Chinese.thy \
1.7 - ex/Classical.thy ex/CodegenSML_Test.thy ex/Coherent.thy \
1.8 - ex/Dedekind_Real.thy ex/Efficient_Nat_examples.thy \
1.9 + ex/Classical.thy ex/CodegenSML_Test.thy ex/Coercion_Examples.thy \
1.10 + ex/Coherent.thy ex/Dedekind_Real.thy ex/Efficient_Nat_examples.thy \
1.11 ex/Eval_Examples.thy ex/Fundefs.thy ex/Gauge_Integration.thy \
1.12 ex/Groebner_Examples.thy ex/Guess.thy ex/HarmonicSeries.thy \
1.13 ex/Hebrew.thy ex/Hex_Bin_Examples.thy ex/Higher_Order_Logic.thy \
2.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
2.2 +++ b/src/HOL/ex/Coercion_Examples.thy Fri Oct 29 21:34:07 2010 +0200
2.3 @@ -0,0 +1,172 @@
2.4 +theory Coercion_Examples
2.5 +imports Main
2.6 +uses "~~/src/Tools/subtyping.ML"
2.7 +begin
2.8 +
2.9 +(* Coercion/type maps definitions*)
2.10 +
2.11 +consts func :: "(nat \<Rightarrow> int) \<Rightarrow> nat"
2.12 +consts arg :: "int \<Rightarrow> nat"
2.13 +(* Invariant arguments
2.14 +term "func arg"
2.15 +*)
2.16 +(* No subtype relation - constraint
2.17 +term "(1::nat)::int"
2.18 +*)
2.19 +consts func' :: "int \<Rightarrow> int"
2.20 +consts arg' :: "nat"
2.21 +(* No subtype relation - function application
2.22 +term "func' arg'"
2.23 +*)
2.24 +(* Uncomparable types in bound
2.25 +term "arg' = True"
2.26 +*)
2.27 +(* Unfullfilled type class requirement
2.28 +term "1 = True"
2.29 +*)
2.30 +(* Different constructors
2.31 +term "[1::int] = func"
2.32 +*)
2.33 +
2.34 +primrec nat_of_bool :: "bool \<Rightarrow> nat"
2.35 +where
2.36 + "nat_of_bool False = 0"
2.37 +| "nat_of_bool True = 1"
2.38 +
2.39 +declare [[coercion nat_of_bool]]
2.40 +
2.41 +declare [[coercion int]]
2.42 +
2.43 +declare [[map_function map]]
2.44 +
2.45 +definition map_fun :: "('a \<Rightarrow> 'b) \<Rightarrow> ('c \<Rightarrow> 'd) \<Rightarrow> ('b \<Rightarrow> 'c) \<Rightarrow> ('a \<Rightarrow> 'd)" where
2.46 + "map_fun f g h = g o h o f"
2.47 +
2.48 +declare [[map_function "\<lambda> f g h . g o h o f"]]
2.49 +
2.50 +primrec map_pair :: "('a \<Rightarrow> 'c) \<Rightarrow> ('b \<Rightarrow> 'd) \<Rightarrow> ('a * 'b) \<Rightarrow> ('c * 'd)" where
2.51 + "map_pair f g (x,y) = (f x, g y)"
2.52 +
2.53 +declare [[map_function map_pair]]
2.54 +
2.55 +(* Examples taken from the haskell draft implementation *)
2.56 +
2.57 +term "(1::nat) = True"
2.58 +
2.59 +term "True = (1::nat)"
2.60 +
2.61 +term "(1::nat) = (True = (1::nat))"
2.62 +
2.63 +term "op = (True = (1::nat))"
2.64 +
2.65 +term "[1::nat,True]"
2.66 +
2.67 +term "[True,1::nat]"
2.68 +
2.69 +term "[1::nat] = [True]"
2.70 +
2.71 +term "[True] = [1::nat]"
2.72 +
2.73 +term "[[True]] = [[1::nat]]"
2.74 +
2.75 +term "[[[[[[[[[[True]]]]]]]]]] = [[[[[[[[[[1::nat]]]]]]]]]]"
2.76 +
2.77 +term "[[True],[42::nat]] = rev [[True]]"
2.78 +
2.79 +term "rev [10000::nat] = [False, 420000::nat, True]"
2.80 +
2.81 +term "\<lambda> x . x = (3::nat)"
2.82 +
2.83 +term "(\<lambda> x . x = (3::nat)) True"
2.84 +
2.85 +term "map (\<lambda> x . x = (3::nat))"
2.86 +
2.87 +term "map (\<lambda> x . x = (3::nat)) [True,1::nat]"
2.88 +
2.89 +consts bnn :: "(bool \<Rightarrow> nat) \<Rightarrow> nat"
2.90 +consts nb :: "nat \<Rightarrow> bool"
2.91 +consts ab :: "'a \<Rightarrow> bool"
2.92 +
2.93 +term "bnn nb"
2.94 +
2.95 +term "bnn ab"
2.96 +
2.97 +term "\<lambda> x . x = (3::int)"
2.98 +
2.99 +term "map (\<lambda> x . x = (3::int)) [True]"
2.100 +
2.101 +term "map (\<lambda> x . x = (3::int)) [True,1::nat]"
2.102 +
2.103 +term "map (\<lambda> x . x = (3::int)) [True,1::nat,1::int]"
2.104 +
2.105 +term "[1::nat,True,1::int,False]"
2.106 +
2.107 +term "map (map (\<lambda> x . x = (3::int))) [[True],[1::nat],[True,1::int]]"
2.108 +
2.109 +consts cbool :: "'a \<Rightarrow> bool"
2.110 +consts cnat :: "'a \<Rightarrow> nat"
2.111 +consts cint :: "'a \<Rightarrow> int"
2.112 +
2.113 +term "[id, cbool, cnat, cint]"
2.114 +
2.115 +consts funfun :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b"
2.116 +consts flip :: "('a \<Rightarrow> 'b \<Rightarrow> 'c) \<Rightarrow> 'b \<Rightarrow> 'a \<Rightarrow> 'c"
2.117 +
2.118 +term "flip funfun"
2.119 +
2.120 +term "map funfun [id,cnat,cint,cbool]"
2.121 +
2.122 +term "map (flip funfun True)"
2.123 +
2.124 +term "map (flip funfun True) [id,cnat,cint,cbool]"
2.125 +
2.126 +consts ii :: "int \<Rightarrow> int"
2.127 +consts aaa :: "'a \<Rightarrow> 'a \<Rightarrow> 'a"
2.128 +consts nlist :: "nat list"
2.129 +consts ilil :: "int list \<Rightarrow> int list"
2.130 +
2.131 +term "ii (aaa (1::nat) True)"
2.132 +
2.133 +term "map ii nlist"
2.134 +
2.135 +term "ilil nlist"
2.136 +
2.137 +(***************************************************)
2.138 +
2.139 +(* Other examples *)
2.140 +
2.141 +definition xs :: "bool list" where "xs = [True]"
2.142 +
2.143 +term "(xs::nat list)"
2.144 +
2.145 +term "(1::nat) = True"
2.146 +
2.147 +term "True = (1::nat)"
2.148 +
2.149 +term "int (1::nat)"
2.150 +
2.151 +term "((True::nat)::int)"
2.152 +
2.153 +term "1::nat"
2.154 +
2.155 +term "nat 1"
2.156 +
2.157 +definition C :: nat
2.158 +where "C = 123"
2.159 +
2.160 +consts g :: "int \<Rightarrow> int"
2.161 +consts h :: "nat \<Rightarrow> nat"
2.162 +
2.163 +term "(g (1::nat)) + (h 2)"
2.164 +
2.165 +term "g 1"
2.166 +
2.167 +term "1+(1::nat)"
2.168 +
2.169 +term "((1::int) + (1::nat),(1::int))"
2.170 +
2.171 +definition ys :: "bool list list list list list" where "ys=[[[[[True]]]]]"
2.172 +
2.173 +term "ys=[[[[[1::nat]]]]]"
2.174 +
2.175 +end
3.1 --- a/src/HOL/ex/ROOT.ML Fri Oct 29 18:17:11 2010 +0200
3.2 +++ b/src/HOL/ex/ROOT.ML Fri Oct 29 21:34:07 2010 +0200
3.3 @@ -13,6 +13,7 @@
3.4
3.5 use_thys [
3.6 "Iff_Oracle",
3.7 + "Coercion_Examples",
3.8 "Numeral",
3.9 "Higher_Order_Logic",
3.10 "Abstract_NAT",
4.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
4.2 +++ b/src/Tools/subtyping.ML Fri Oct 29 21:34:07 2010 +0200
4.3 @@ -0,0 +1,766 @@
4.4 +(* Title: Tools/subtyping.ML
4.5 + Author: Dmitriy Traytel, TU Muenchen
4.6 +
4.7 +Coercive subtyping via subtype constraints.
4.8 +*)
4.9 +
4.10 +signature SUBTYPING =
4.11 +sig
4.12 + datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT
4.13 + val infer_types: Proof.context -> (string -> typ option) -> (indexname -> typ option) ->
4.14 + term list -> term list
4.15 +end;
4.16 +
4.17 +structure Subtyping =
4.18 +struct
4.19 +
4.20 +
4.21 +
4.22 +(** coercions data **)
4.23 +
4.24 +datatype variance = COVARIANT | CONTRAVARIANT | INVARIANT
4.25 +
4.26 +datatype data = Data of
4.27 + {coes: term Symreltab.table, (* coercions table *)
4.28 + coes_graph: unit Graph.T, (* coercions graph *)
4.29 + tmaps: (term * variance list) Symtab.table}; (* map functions *)
4.30 +
4.31 +fun make_data (coes, coes_graph, tmaps) =
4.32 + Data {coes = coes, coes_graph = coes_graph, tmaps = tmaps};
4.33 +
4.34 +structure Data = Generic_Data
4.35 +(
4.36 + type T = data;
4.37 + val empty = make_data (Symreltab.empty, Graph.empty, Symtab.empty);
4.38 + val extend = I;
4.39 + fun merge
4.40 + (Data {coes = coes1, coes_graph = coes_graph1, tmaps = tmaps1},
4.41 + Data {coes = coes2, coes_graph = coes_graph2, tmaps = tmaps2}) =
4.42 + make_data (Symreltab.merge (op aconv) (coes1, coes2),
4.43 + Graph.merge (op =) (coes_graph1, coes_graph2),
4.44 + Symtab.merge (eq_pair (op aconv) (op =)) (tmaps1, tmaps2));
4.45 +);
4.46 +
4.47 +fun map_data f =
4.48 + Data.map (fn Data {coes, coes_graph, tmaps} =>
4.49 + make_data (f (coes, coes_graph, tmaps)));
4.50 +
4.51 +fun map_coes f =
4.52 + map_data (fn (coes, coes_graph, tmaps) =>
4.53 + (f coes, coes_graph, tmaps));
4.54 +
4.55 +fun map_coes_graph f =
4.56 + map_data (fn (coes, coes_graph, tmaps) =>
4.57 + (coes, f coes_graph, tmaps));
4.58 +
4.59 +fun map_coes_and_graph f =
4.60 + map_data (fn (coes, coes_graph, tmaps) =>
4.61 + let val (coes', coes_graph') = f (coes, coes_graph);
4.62 + in (coes', coes_graph', tmaps) end);
4.63 +
4.64 +fun map_tmaps f =
4.65 + map_data (fn (coes, coes_graph, tmaps) =>
4.66 + (coes, coes_graph, f tmaps));
4.67 +
4.68 +fun rep_data context = Data.get context |> (fn Data args => args);
4.69 +
4.70 +val coes_of = #coes o rep_data;
4.71 +val coes_graph_of = #coes_graph o rep_data;
4.72 +val tmaps_of = #tmaps o rep_data;
4.73 +
4.74 +
4.75 +
4.76 +(** utils **)
4.77 +
4.78 +val is_param = Type_Infer.is_param
4.79 +val is_paramT = Type_Infer.is_paramT
4.80 +val deref = Type_Infer.deref
4.81 +fun mk_param i S = TVar (("?'a", i), S); (* TODO dup? see src/Pure/type_infer.ML *)
4.82 +
4.83 +fun nameT (Type (s, [])) = s;
4.84 +fun t_of s = Type (s, []);
4.85 +fun sort_of (TFree (_, S)) = SOME S
4.86 + | sort_of (TVar (_, S)) = SOME S
4.87 + | sort_of _ = NONE;
4.88 +
4.89 +val is_typeT = fn (Type _) => true | _ => false;
4.90 +val is_compT = fn (Type (_, _::_)) => true | _ => false;
4.91 +val is_freeT = fn (TFree _) => true | _ => false;
4.92 +val is_fixedvarT = fn (TVar (xi, _)) => not (is_param xi) | _ => false;
4.93 +
4.94 +
4.95 +(* unification TODO dup? needed for weak unification *)
4.96 +
4.97 +exception NO_UNIFIER of string * typ Vartab.table;
4.98 +
4.99 +fun unify weak ctxt =
4.100 + let
4.101 + val thy = ProofContext.theory_of ctxt;
4.102 + val pp = Syntax.pp ctxt;
4.103 + val arity_sorts = Type.arity_sorts pp (Sign.tsig_of thy);
4.104 +
4.105 +
4.106 + (* adjust sorts of parameters *)
4.107 +
4.108 + fun not_of_sort x S' S =
4.109 + "Variable " ^ x ^ "::" ^ Syntax.string_of_sort ctxt S' ^ " not of sort " ^
4.110 + Syntax.string_of_sort ctxt S;
4.111 +
4.112 + fun meet (_, []) tye_idx = tye_idx
4.113 + | meet (Type (a, Ts), S) (tye_idx as (tye, _)) =
4.114 + meets (Ts, arity_sorts a S handle ERROR msg => raise NO_UNIFIER (msg, tye)) tye_idx
4.115 + | meet (TFree (x, S'), S) (tye_idx as (tye, _)) =
4.116 + if Sign.subsort thy (S', S) then tye_idx
4.117 + else raise NO_UNIFIER (not_of_sort x S' S, tye)
4.118 + | meet (TVar (xi, S'), S) (tye_idx as (tye, idx)) =
4.119 + if Sign.subsort thy (S', S) then tye_idx
4.120 + else if Type_Infer.is_param xi then
4.121 + (Vartab.update_new (xi, mk_param idx (Sign.inter_sort thy (S', S))) tye, idx + 1)
4.122 + else raise NO_UNIFIER (not_of_sort (Term.string_of_vname xi) S' S, tye)
4.123 + and meets (T :: Ts, S :: Ss) (tye_idx as (tye, _)) =
4.124 + meets (Ts, Ss) (meet (deref tye T, S) tye_idx)
4.125 + | meets _ tye_idx = tye_idx;
4.126 +
4.127 + val weak_meet = if weak then fn _ => I else meet
4.128 +
4.129 +
4.130 + (* occurs check and assignment *)
4.131 +
4.132 + fun occurs_check tye xi (TVar (xi', _)) =
4.133 + if xi = xi' then raise NO_UNIFIER ("Occurs check!", tye)
4.134 + else
4.135 + (case Vartab.lookup tye xi' of
4.136 + NONE => ()
4.137 + | SOME T => occurs_check tye xi T)
4.138 + | occurs_check tye xi (Type (_, Ts)) = List.app (occurs_check tye xi) Ts
4.139 + | occurs_check _ _ _ = ();
4.140 +
4.141 + fun assign xi (T as TVar (xi', _)) S env =
4.142 + if xi = xi' then env
4.143 + else env |> weak_meet (T, S) |>> Vartab.update_new (xi, T)
4.144 + | assign xi T S (env as (tye, _)) =
4.145 + (occurs_check tye xi T; env |> weak_meet (T, S) |>> Vartab.update_new (xi, T));
4.146 +
4.147 +
4.148 + (* unification *)
4.149 +
4.150 + fun show_tycon (a, Ts) =
4.151 + quote (Syntax.string_of_typ ctxt (Type (a, replicate (length Ts) dummyT)));
4.152 +
4.153 + fun unif (T1, T2) (env as (tye, _)) =
4.154 + (case pairself (`is_paramT o deref tye) (T1, T2) of
4.155 + ((true, TVar (xi, S)), (_, T)) => assign xi T S env
4.156 + | ((_, T), (true, TVar (xi, S))) => assign xi T S env
4.157 + | ((_, Type (a, Ts)), (_, Type (b, Us))) =>
4.158 + if weak andalso null Ts andalso null Us then env
4.159 + else if a <> b then
4.160 + raise NO_UNIFIER
4.161 + ("Clash of types " ^ show_tycon (a, Ts) ^ " and " ^ show_tycon (b, Us), tye)
4.162 + else fold unif (Ts ~~ Us) env
4.163 + | ((_, T), (_, U)) => if T = U then env else raise NO_UNIFIER ("", tye));
4.164 +
4.165 + in unif end;
4.166 +
4.167 +val weak_unify = unify true;
4.168 +val strong_unify = unify false;
4.169 +
4.170 +
4.171 +(* Typ_Graph shortcuts *)
4.172 +
4.173 +val add_edge = Typ_Graph.add_edge_acyclic;
4.174 +fun get_preds G T = Typ_Graph.all_preds G [T];
4.175 +fun get_succs G T = Typ_Graph.all_succs G [T];
4.176 +fun maybe_new_typnode T G = perhaps (try (Typ_Graph.new_node (T, ()))) G;
4.177 +fun maybe_new_typnodes Ts G = fold maybe_new_typnode Ts G;
4.178 +fun new_imm_preds G Ts =
4.179 + subtract (op =) Ts (distinct (op =) (maps (Typ_Graph.imm_preds G) Ts));
4.180 +fun new_imm_succs G Ts =
4.181 + subtract op= Ts (distinct (op =) (maps (Typ_Graph.imm_succs G) Ts));
4.182 +
4.183 +
4.184 +(* Graph shortcuts *)
4.185 +
4.186 +fun maybe_new_node s G = perhaps (try (Graph.new_node (s, ()))) G
4.187 +fun maybe_new_nodes ss G = fold maybe_new_node ss G
4.188 +
4.189 +
4.190 +
4.191 +(** error messages **)
4.192 +
4.193 +fun prep_output ctxt tye bs ts Ts =
4.194 + let
4.195 + val (Ts_bTs', ts') = Type_Infer.finish ctxt tye (Ts @ map snd bs, ts);
4.196 + val (Ts', Ts'') = chop (length Ts) Ts_bTs';
4.197 + fun prep t =
4.198 + let val xs = rev (Term.variant_frees t (rev (map fst bs ~~ Ts'')))
4.199 + in Term.subst_bounds (map Syntax.mark_boundT xs, t) end;
4.200 + in (map prep ts', Ts') end;
4.201 +
4.202 +fun err_loose i = error ("Loose bound variable: B." ^ string_of_int i);
4.203 +
4.204 +fun inf_failed msg =
4.205 + "Subtype inference failed" ^ (if msg = "" then "" else ": " ^ msg) ^ "\n\n";
4.206 +
4.207 +fun err_appl ctxt msg tye bs t T u U =
4.208 + let val ([t', u'], [T', U']) = prep_output ctxt tye bs [t, u] [T, U]
4.209 + in error (inf_failed msg ^ Type.appl_error (Syntax.pp ctxt) t' T' u' U' ^ "\n") end;
4.210 +
4.211 +fun err_subtype ctxt msg tye (bs, t $ u, U, V, U') =
4.212 + err_appl ctxt msg tye bs t (U --> V) u U';
4.213 +
4.214 +fun err_list ctxt msg tye Ts =
4.215 + let
4.216 + val (_, Ts') = prep_output ctxt tye [] [] Ts;
4.217 + val text = cat_lines ([inf_failed msg,
4.218 + "Cannot unify a list of types that should be the same,",
4.219 + "according to suptype dependencies:",
4.220 + (Pretty.string_of (Pretty.list "[" "]" (map (Pretty.typ (Syntax.pp ctxt)) Ts')))]);
4.221 + in
4.222 + error text
4.223 + end;
4.224 +
4.225 +fun err_bound ctxt msg tye packs =
4.226 + let
4.227 + val pp = Syntax.pp ctxt;
4.228 + val (ts, Ts) = fold
4.229 + (fn (bs, t $ u, U, _, U') => fn (ts, Ts) =>
4.230 + let val (t', T') = prep_output ctxt tye bs [t, u] [U, U']
4.231 + in (t'::ts, T'::Ts) end)
4.232 + packs ([], []);
4.233 + val text = cat_lines ([inf_failed msg, "Cannot fullfill subtype constraints:"] @
4.234 + (map2 (fn [t, u] => fn [T, U] => Pretty.string_of (
4.235 + Pretty.block [
4.236 + Pretty.typ pp T, Pretty.brk 2, Pretty.str "<:", Pretty.brk 2, Pretty.typ pp U,
4.237 + Pretty.brk 3, Pretty.str "from function application", Pretty.brk 2,
4.238 + Pretty.block [Pretty.term pp t, Pretty.brk 1, Pretty.term pp u]]))
4.239 + ts Ts))
4.240 + in
4.241 + error text
4.242 + end;
4.243 +
4.244 +
4.245 +
4.246 +(** constraint generation **)
4.247 +
4.248 +fun generate_constraints ctxt =
4.249 + let
4.250 + fun gen cs _ (Const (_, T)) tye_idx = (T, tye_idx, cs)
4.251 + | gen cs _ (Free (_, T)) tye_idx = (T, tye_idx, cs)
4.252 + | gen cs _ (Var (_, T)) tye_idx = (T, tye_idx, cs)
4.253 + | gen cs bs (Bound i) tye_idx =
4.254 + (snd (nth bs i handle Subscript => err_loose i), tye_idx, cs)
4.255 + | gen cs bs (Abs (x, T, t)) tye_idx =
4.256 + let val (U, tye_idx', cs') = gen cs ((x, T) :: bs) t tye_idx
4.257 + in (T --> U, tye_idx', cs') end
4.258 + | gen cs bs (t $ u) tye_idx =
4.259 + let
4.260 + val (T, tye_idx', cs') = gen cs bs t tye_idx;
4.261 + val (U', (tye, idx), cs'') = gen cs' bs u tye_idx';
4.262 + val U = mk_param idx [];
4.263 + val V = mk_param (idx + 1) [];
4.264 + val tye_idx''= strong_unify ctxt (U --> V, T) (tye, idx + 2)
4.265 + handle NO_UNIFIER (msg, tye') => err_appl ctxt msg tye' bs t T u U;
4.266 + val error_pack = (bs, t $ u, U, V, U');
4.267 + in (V, tye_idx'', ((U', U), error_pack) :: cs'') end;
4.268 + in
4.269 + gen [] []
4.270 + end;
4.271 +
4.272 +
4.273 +
4.274 +(** constraint resolution **)
4.275 +
4.276 +exception BOUND_ERROR of string;
4.277 +
4.278 +fun process_constraints ctxt cs tye_idx =
4.279 + let
4.280 + val coes_graph = coes_graph_of (Context.Proof ctxt);
4.281 + val tmaps = tmaps_of (Context.Proof ctxt);
4.282 + val tsig = Sign.tsig_of (ProofContext.theory_of ctxt);
4.283 + val pp = Syntax.pp ctxt;
4.284 + val arity_sorts = Type.arity_sorts pp tsig;
4.285 + val subsort = Type.subsort tsig;
4.286 +
4.287 + fun split_cs _ [] = ([], [])
4.288 + | split_cs f (c::cs) =
4.289 + (case pairself f (fst c) of
4.290 + (false, false) => apsnd (cons c) (split_cs f cs)
4.291 + | _ => apfst (cons c) (split_cs f cs));
4.292 +
4.293 +
4.294 + (* check whether constraint simplification will terminate using weak unification *)
4.295 +
4.296 + val _ = fold (fn (TU, error_pack) => fn tye_idx =>
4.297 + (weak_unify ctxt TU tye_idx handle NO_UNIFIER (msg, tye) =>
4.298 + err_subtype ctxt ("Weak unification of subtype constraints fails:\n" ^ msg)
4.299 + tye error_pack)) cs tye_idx;
4.300 +
4.301 +
4.302 + (* simplify constraints *)
4.303 +
4.304 + fun simplify_constraints cs tye_idx =
4.305 + let
4.306 + fun contract a Ts Us error_pack done todo tye idx =
4.307 + let
4.308 + val arg_var =
4.309 + (case Symtab.lookup tmaps a of
4.310 + (*everything is invariant for unknown constructors*)
4.311 + NONE => replicate (length Ts) INVARIANT
4.312 + | SOME av => snd av);
4.313 + fun new_constraints (variance, constraint) (cs, tye_idx) =
4.314 + (case variance of
4.315 + COVARIANT => (constraint :: cs, tye_idx)
4.316 + | CONTRAVARIANT => (swap constraint :: cs, tye_idx)
4.317 + | INVARIANT => (cs, strong_unify ctxt constraint tye_idx
4.318 + handle NO_UNIFIER (msg, tye) => err_subtype ctxt msg tye error_pack));
4.319 + val (new, (tye', idx')) = apfst (fn cs => (cs ~~ replicate (length cs) error_pack))
4.320 + (fold new_constraints (arg_var ~~ (Ts ~~ Us)) ([], (tye, idx)));
4.321 + val test_update = is_compT orf is_freeT orf is_fixedvarT;
4.322 + val (ch, done') =
4.323 + if not (null new) then ([], done)
4.324 + else split_cs (test_update o deref tye') done;
4.325 + val todo' = ch @ todo;
4.326 + in
4.327 + simplify done' (new @ todo') (tye', idx')
4.328 + end
4.329 + (*xi is definitely a parameter*)
4.330 + and expand varleq xi S a Ts error_pack done todo tye idx =
4.331 + let
4.332 + val n = length Ts;
4.333 + val args = map2 mk_param (idx upto idx + n - 1) (arity_sorts a S);
4.334 + val tye' = Vartab.update_new (xi, Type(a, args)) tye;
4.335 + val (ch, done') = split_cs (is_compT o deref tye') done;
4.336 + val todo' = ch @ todo;
4.337 + val new =
4.338 + if varleq then (Type(a, args), Type (a, Ts))
4.339 + else (Type (a, Ts), Type(a, args));
4.340 + in
4.341 + simplify done' ((new, error_pack) :: todo') (tye', idx + n)
4.342 + end
4.343 + (*TU is a pair of a parameter and a free/fixed variable*)
4.344 + and eliminate TU error_pack done todo tye idx =
4.345 + let
4.346 + val [TVar (xi, S)] = filter is_paramT TU;
4.347 + val [T] = filter_out is_paramT TU;
4.348 + val SOME S' = sort_of T;
4.349 + val test_update = if is_freeT T then is_freeT else is_fixedvarT;
4.350 + val tye' = Vartab.update_new (xi, T) tye;
4.351 + val (ch, done') = split_cs (test_update o deref tye') done;
4.352 + val todo' = ch @ todo;
4.353 + in
4.354 + if subsort (S', S) (*TODO check this*)
4.355 + then simplify done' todo' (tye', idx)
4.356 + else err_subtype ctxt "Sort mismatch" tye error_pack
4.357 + end
4.358 + and simplify done [] tye_idx = (done, tye_idx)
4.359 + | simplify done (((T, U), error_pack) :: todo) (tye_idx as (tye, idx)) =
4.360 + (case (deref tye T, deref tye U) of
4.361 + (Type (a, []), Type (b, [])) =>
4.362 + if a = b then simplify done todo tye_idx
4.363 + else if Graph.is_edge coes_graph (a, b) then simplify done todo tye_idx
4.364 + else err_subtype ctxt (a ^" is not a subtype of " ^ b) (fst tye_idx) error_pack
4.365 + | (Type (a, Ts), Type (b, Us)) =>
4.366 + if a<>b then err_subtype ctxt "Different constructors" (fst tye_idx) error_pack
4.367 + else contract a Ts Us error_pack done todo tye idx
4.368 + | (TVar (xi, S), Type (a, Ts as (_::_))) =>
4.369 + expand true xi S a Ts error_pack done todo tye idx
4.370 + | (Type (a, Ts as (_::_)), TVar (xi, S)) =>
4.371 + expand false xi S a Ts error_pack done todo tye idx
4.372 + | (T, U) =>
4.373 + if T = U then simplify done todo tye_idx
4.374 + else if exists (is_freeT orf is_fixedvarT) [T, U] andalso
4.375 + exists is_paramT [T, U]
4.376 + then eliminate [T, U] error_pack done todo tye idx
4.377 + else if exists (is_freeT orf is_fixedvarT) [T, U]
4.378 + then err_subtype ctxt "Not eliminated free/fixed variables"
4.379 + (fst tye_idx) error_pack
4.380 + else simplify (((T, U), error_pack)::done) todo tye_idx);
4.381 + in
4.382 + simplify [] cs tye_idx
4.383 + end;
4.384 +
4.385 +
4.386 + (* do simplification *)
4.387 +
4.388 + val (cs', tye_idx') = simplify_constraints cs tye_idx;
4.389 +
4.390 + fun find_error_pack lower T' =
4.391 + map snd (filter (fn ((T, U), _) => if lower then T' = U else T' = T) cs');
4.392 +
4.393 + fun unify_list (T::Ts) tye_idx =
4.394 + fold (fn U => fn tye_idx => strong_unify ctxt (T, U) tye_idx
4.395 + handle NO_UNIFIER (msg, tye) => err_list ctxt msg tye (T::Ts))
4.396 + Ts tye_idx;
4.397 +
4.398 + (*styps stands either for supertypes or for subtypes of a type T
4.399 + in terms of the subtype-relation (excluding T itself)*)
4.400 + fun styps super T =
4.401 + (if super then Graph.imm_succs else Graph.imm_preds) coes_graph T
4.402 + handle Graph.UNDEF _ => [];
4.403 +
4.404 + fun minmax sup (T::Ts) =
4.405 + let
4.406 + fun adjust T U = if sup then (T, U) else (U, T);
4.407 + fun extract T [] = T
4.408 + | extract T (U::Us) =
4.409 + if Graph.is_edge coes_graph (adjust T U) then extract T Us
4.410 + else if Graph.is_edge coes_graph (adjust U T) then extract U Us
4.411 + else raise BOUND_ERROR "Uncomparable types in type list";
4.412 + in
4.413 + t_of (extract T Ts)
4.414 + end;
4.415 +
4.416 + fun ex_styp_of_sort super T styps_and_sorts =
4.417 + let
4.418 + fun adjust T U = if super then (T, U) else (U, T);
4.419 + fun styp_test U Ts = forall
4.420 + (fn T => T = U orelse Graph.is_edge coes_graph (adjust U T)) Ts;
4.421 + fun fitting Ts S U = Type.of_sort tsig (t_of U, S) andalso styp_test U Ts
4.422 + in
4.423 + forall (fn (Ts, S) => exists (fitting Ts S) (T :: styps super T)) styps_and_sorts
4.424 + end;
4.425 +
4.426 + (* computes the tightest possible, correct assignment for 'a::S
4.427 + e.g. in the supremum case (sup = true):
4.428 + ------- 'a::S---
4.429 + / / \ \
4.430 + / / \ \
4.431 + 'b::C1 'c::C2 ... T1 T2 ...
4.432 +
4.433 + sorts - list of sorts [C1, C2, ...]
4.434 + T::Ts - non-empty list of base types [T1, T2, ...]
4.435 + *)
4.436 + fun tightest sup S styps_and_sorts (T::Ts) =
4.437 + let
4.438 + fun restriction T = Type.of_sort tsig (t_of T, S)
4.439 + andalso ex_styp_of_sort (not sup) T styps_and_sorts;
4.440 + fun candidates T = inter (op =) (filter restriction (T :: styps sup T));
4.441 + in
4.442 + (case fold candidates Ts (filter restriction (T :: styps sup T)) of
4.443 + [] => raise BOUND_ERROR ("No " ^ (if sup then "supremum" else "infimum"))
4.444 + | [T] => t_of T
4.445 + | Ts => minmax sup Ts)
4.446 + end;
4.447 +
4.448 + fun build_graph G [] tye_idx = (G, tye_idx)
4.449 + | build_graph G ((T, U)::cs) tye_idx =
4.450 + if T = U then build_graph G cs tye_idx
4.451 + else
4.452 + let
4.453 + val G' = maybe_new_typnodes [T, U] G;
4.454 + val (G'', tye_idx') = (add_edge (T, U) G', tye_idx)
4.455 + handle Typ_Graph.CYCLES cycles =>
4.456 + let
4.457 + val (tye, idx) = fold unify_list cycles tye_idx
4.458 + in
4.459 + (*all cycles collapse to one node,
4.460 + because all of them share at least the nodes x and y*)
4.461 + collapse (tye, idx) (distinct (op =) (flat cycles)) G
4.462 + end;
4.463 + in
4.464 + build_graph G'' cs tye_idx'
4.465 + end
4.466 + and collapse (tye, idx) nodes G = (*nodes non-empty list*)
4.467 + let
4.468 + val T = hd nodes;
4.469 + val P = new_imm_preds G nodes;
4.470 + val S = new_imm_succs G nodes;
4.471 + val G' = Typ_Graph.del_nodes (tl nodes) G;
4.472 + in
4.473 + build_graph G' (map (fn x => (x, T)) P @ map (fn x => (T, x)) S) (tye, idx)
4.474 + end;
4.475 +
4.476 + fun assign_bound lower G key (tye_idx as (tye, _)) =
4.477 + if is_paramT (deref tye key) then
4.478 + let
4.479 + val TVar (xi, S) = deref tye key;
4.480 + val get_bound = if lower then get_preds else get_succs;
4.481 + val raw_bound = get_bound G key;
4.482 + val bound = map (deref tye) raw_bound;
4.483 + val not_params = filter_out is_paramT bound;
4.484 + fun to_fulfil T =
4.485 + (case sort_of T of
4.486 + NONE => NONE
4.487 + | SOME S =>
4.488 + SOME (map nameT (filter_out is_paramT (map (deref tye) (get_bound G T))), S));
4.489 + val styps_and_sorts = distinct (op =) (map_filter to_fulfil raw_bound);
4.490 + val assignment =
4.491 + if null bound orelse null not_params then NONE
4.492 + else SOME (tightest lower S styps_and_sorts (map nameT not_params)
4.493 + handle BOUND_ERROR msg => err_bound ctxt msg tye (find_error_pack lower key))
4.494 + in
4.495 + (case assignment of
4.496 + NONE => tye_idx
4.497 + | SOME T =>
4.498 + if is_paramT T then tye_idx
4.499 + else if lower then (*upper bound check*)
4.500 + let
4.501 + val other_bound = map (deref tye) (get_succs G key);
4.502 + val s = nameT T;
4.503 + in
4.504 + if subset (op = o apfst nameT) (filter is_typeT other_bound, s :: styps true s)
4.505 + then apfst (Vartab.update (xi, T)) tye_idx
4.506 + else err_bound ctxt ("Assigned simple type " ^ s ^
4.507 + " clashes with the upper bound of variable " ^
4.508 + Syntax.string_of_typ ctxt (TVar(xi, S))) tye (find_error_pack (not lower) key)
4.509 + end
4.510 + else apfst (Vartab.update (xi, T)) tye_idx)
4.511 + end
4.512 + else tye_idx;
4.513 +
4.514 + val assign_lb = assign_bound true;
4.515 + val assign_ub = assign_bound false;
4.516 +
4.517 + fun assign_alternating ts' ts G tye_idx =
4.518 + if ts' = ts then tye_idx
4.519 + else
4.520 + let
4.521 + val (tye_idx' as (tye, _)) = fold (assign_lb G) ts tye_idx
4.522 + |> fold (assign_ub G) ts;
4.523 + in
4.524 + assign_alternating ts (filter (is_paramT o deref tye) ts) G tye_idx'
4.525 + end;
4.526 +
4.527 + (*Unify all weakly connected components of the constraint forest,
4.528 + that contain only params. These are the only WCCs that contain
4.529 + params anyway.*)
4.530 + fun unify_params G (tye_idx as (tye, _)) =
4.531 + let
4.532 + val max_params = filter (is_paramT o deref tye) (Typ_Graph.maximals G);
4.533 + val to_unify = map (fn T => T :: get_preds G T) max_params;
4.534 + in
4.535 + fold unify_list to_unify tye_idx
4.536 + end;
4.537 +
4.538 + fun solve_constraints G tye_idx = tye_idx
4.539 + |> assign_alternating [] (Typ_Graph.keys G) G
4.540 + |> unify_params G;
4.541 + in
4.542 + build_graph Typ_Graph.empty (map fst cs') tye_idx'
4.543 + |-> solve_constraints
4.544 + end;
4.545 +
4.546 +
4.547 +
4.548 +(** coercion insertion **)
4.549 +
4.550 +fun insert_coercions ctxt tye ts =
4.551 + let
4.552 + fun deep_deref T =
4.553 + (case deref tye T of
4.554 + Type (a, Ts) => Type (a, map deep_deref Ts)
4.555 + | U => U);
4.556 +
4.557 + fun gen_coercion ((Type (a, [])), (Type (b, []))) =
4.558 + if a = b
4.559 + then Abs (Name.uu, Type (a, []), Bound 0)
4.560 + else
4.561 + (case Symreltab.lookup (coes_of (Context.Proof ctxt)) (a, b) of
4.562 + NONE => raise Fail (a ^ " is not a subtype of " ^ b)
4.563 + | SOME co => co)
4.564 + | gen_coercion ((Type (a, Ts)), (Type (b, Us))) =
4.565 + if a <> b
4.566 + then raise raise Fail ("Different constructors: " ^ a ^ " and " ^ b)
4.567 + else
4.568 + let
4.569 + fun inst t Ts =
4.570 + Term.subst_vars
4.571 + (((Term.add_tvar_namesT (fastype_of t) []) ~~ rev Ts), []) t;
4.572 + fun sub_co (COVARIANT, TU) = gen_coercion TU
4.573 + | sub_co (CONTRAVARIANT, TU) = gen_coercion (swap TU);
4.574 + fun ts_of [] = []
4.575 + | ts_of (Type ("fun", [x1, x2])::xs) = x1::x2::(ts_of xs);
4.576 + in
4.577 + (case Symtab.lookup (tmaps_of (Context.Proof ctxt)) a of
4.578 + NONE => raise Fail ("No map function for " ^ a ^ " known")
4.579 + | SOME tmap =>
4.580 + let
4.581 + val used_coes = map sub_co ((snd tmap) ~~ (Ts ~~ Us));
4.582 + in
4.583 + Term.list_comb
4.584 + (inst (fst tmap) (ts_of (map fastype_of used_coes)), used_coes)
4.585 + end)
4.586 + end
4.587 + | gen_coercion (T, U) =
4.588 + if Type.could_unify (T, U)
4.589 + then Abs (Name.uu, T, Bound 0)
4.590 + else raise Fail ("Cannot generate coercion from "
4.591 + ^ Syntax.string_of_typ ctxt T ^ " to " ^ Syntax.string_of_typ ctxt U);
4.592 +
4.593 + fun insert _ (Const (c, T)) =
4.594 + let val T' = deep_deref T;
4.595 + in (Const (c, T'), T') end
4.596 + | insert _ (Free (x, T)) =
4.597 + let val T' = deep_deref T;
4.598 + in (Free (x, T'), T') end
4.599 + | insert _ (Var (xi, T)) =
4.600 + let val T' = deep_deref T;
4.601 + in (Var (xi, T'), T') end
4.602 + | insert bs (Bound i) =
4.603 + let val T = nth bs i handle Subscript =>
4.604 + raise TYPE ("Loose bound variable: B." ^ string_of_int i, [], []);
4.605 + in (Bound i, T) end
4.606 + | insert bs (Abs (x, T, t)) =
4.607 + let
4.608 + val T' = deep_deref T;
4.609 + val (t', T'') = insert (T'::bs) t;
4.610 + in
4.611 + (Abs (x, T', t'), T' --> T'')
4.612 + end
4.613 + | insert bs (t $ u) =
4.614 + let
4.615 + val (t', Type ("fun", [U, T])) = insert bs t;
4.616 + val (u', U') = insert bs u;
4.617 + in
4.618 + if U <> U'
4.619 + then (t' $ (gen_coercion (U', U) $ u'), T)
4.620 + else (t' $ u', T)
4.621 + end
4.622 + in
4.623 + map (fst o insert []) ts
4.624 + end;
4.625 +
4.626 +
4.627 +
4.628 +(** assembling the pipeline **)
4.629 +
4.630 +fun infer_types ctxt const_type var_type raw_ts =
4.631 + let
4.632 + val (idx, ts) = Type_Infer.prepare ctxt const_type var_type raw_ts;
4.633 +
4.634 + fun gen_all t (tye_idx, constraints) =
4.635 + let
4.636 + val (_, tye_idx', constraints') = generate_constraints ctxt t tye_idx
4.637 + in (tye_idx', constraints' @ constraints) end;
4.638 +
4.639 + val (tye_idx, constraints) = fold gen_all ts ((Vartab.empty, idx), []);
4.640 + val (tye, _) = process_constraints ctxt constraints tye_idx;
4.641 + val ts' = insert_coercions ctxt tye ts;
4.642 +
4.643 + val (_, ts'') = Type_Infer.finish ctxt tye ([], ts');
4.644 + in ts'' end;
4.645 +
4.646 +
4.647 +
4.648 +(** installation **)
4.649 +
4.650 +fun coercion_infer_types ctxt =
4.651 + infer_types ctxt
4.652 + (try (Consts.the_constraint (ProofContext.consts_of ctxt)))
4.653 + (ProofContext.def_type ctxt);
4.654 +
4.655 +local
4.656 +
4.657 +fun add eq what f = Context.>> (what (fn xs => fn ctxt =>
4.658 + let val xs' = f ctxt xs in if eq_list eq (xs, xs') then NONE else SOME (xs', ctxt) end));
4.659 +
4.660 +in
4.661 +
4.662 +val _ = add (op aconv) (Syntax.add_term_check ~100 "coercions") coercion_infer_types;
4.663 +
4.664 +end;
4.665 +
4.666 +
4.667 +(* interface *)
4.668 +
4.669 +fun add_type_map map_fun context =
4.670 + let
4.671 + val ctxt = Context.proof_of context;
4.672 + val t = singleton (Variable.polymorphic ctxt) (Syntax.read_term ctxt map_fun);
4.673 +
4.674 + fun err_str () = "\n\nthe general type signature for a map function is" ^
4.675 + "\nf1 => f2 => ... => fn => C [x1, ..., xn] => C [x1, ..., xn]" ^
4.676 + "\nwhere C is a constructor and fi is of type (xi => yi) or (yi => xi)";
4.677 +
4.678 + fun gen_arg_var ([], []) = []
4.679 + | gen_arg_var ((T, T')::Ts, (U, U')::Us) =
4.680 + if T = U andalso T' = U' then COVARIANT :: gen_arg_var (Ts, Us)
4.681 + else if T = U' andalso T' = U then CONTRAVARIANT :: gen_arg_var (Ts, Us)
4.682 + else error ("Functions do not apply to arguments correctly:" ^ err_str ())
4.683 + | gen_arg_var (_, _) =
4.684 + error ("Different numbers of functions and arguments\n" ^ err_str ());
4.685 +
4.686 + (* TODO: This function is only needed to introde the fun type map
4.687 + function: "% f g h . g o h o f". There must be a better solution. *)
4.688 + fun balanced (Type (_, [])) (Type (_, [])) = true
4.689 + | balanced (Type (a, Ts)) (Type (b, Us)) =
4.690 + a = b andalso forall I (map2 balanced Ts Us)
4.691 + | balanced (TFree _) (TFree _) = true
4.692 + | balanced (TVar _) (TVar _) = true
4.693 + | balanced _ _ = false;
4.694 +
4.695 + fun check_map_fun (pairs, []) (Type ("fun", [T as Type (C, Ts), U as Type (_, Us)])) =
4.696 + if balanced T U
4.697 + then ((pairs, Ts~~Us), C)
4.698 + else if C = "fun"
4.699 + then check_map_fun (pairs @ [(hd Ts, hd (tl Ts))], []) U
4.700 + else error ("Not a proper map function:" ^ err_str ())
4.701 + | check_map_fun _ _ = error ("Not a proper map function:" ^ err_str ());
4.702 +
4.703 + val res = check_map_fun ([], []) (fastype_of t);
4.704 + val res_av = gen_arg_var (fst res);
4.705 + in
4.706 + map_tmaps (Symtab.update (snd res, (t, res_av))) context
4.707 + end;
4.708 +
4.709 +fun add_coercion coercion context =
4.710 + let
4.711 + val ctxt = Context.proof_of context;
4.712 + val t = singleton (Variable.polymorphic ctxt) (Syntax.read_term ctxt coercion);
4.713 +
4.714 + fun err_coercion () = error ("Bad type for coercion " ^
4.715 + Syntax.string_of_term ctxt t ^ ":\n" ^
4.716 + Syntax.string_of_typ ctxt (fastype_of t));
4.717 +
4.718 + val (Type ("fun", [T1, T2])) = fastype_of t
4.719 + handle Bind => err_coercion ();
4.720 +
4.721 + val a =
4.722 + (case T1 of
4.723 + Type (x, []) => x
4.724 + | _ => err_coercion ());
4.725 +
4.726 + val b =
4.727 + (case T2 of
4.728 + Type (x, []) => x
4.729 + | _ => err_coercion ());
4.730 +
4.731 + fun coercion_data_update (tab, G) =
4.732 + let
4.733 + val G' = maybe_new_nodes [a, b] G
4.734 + val G'' = Graph.add_edge_trans_acyclic (a, b) G'
4.735 + handle Graph.CYCLES _ => error (a ^ " is already a subtype of " ^ b ^
4.736 + "!\n\nCannot add coercion of type: " ^ a ^ " => " ^ b);
4.737 + val new_edges =
4.738 + flat (Graph.dest G'' |> map (fn (x, ys) => ys |> map_filter (fn y =>
4.739 + if Graph.is_edge G' (x, y) then NONE else SOME (x, y))));
4.740 + val G_and_new = Graph.add_edge (a, b) G';
4.741 +
4.742 + fun complex_coercion tab G (a, b) =
4.743 + let
4.744 + val path = hd (Graph.irreducible_paths G (a, b))
4.745 + val path' = (fst (split_last path)) ~~ tl path
4.746 + in Abs (Name.uu, Type (a, []),
4.747 + fold (fn t => fn u => t $ u) (map (the o Symreltab.lookup tab) path') (Bound 0))
4.748 + end;
4.749 +
4.750 + val tab' = fold
4.751 + (fn pair => fn tab => Symreltab.update (pair, complex_coercion tab G_and_new pair) tab)
4.752 + (filter (fn pair => pair <> (a, b)) new_edges)
4.753 + (Symreltab.update ((a, b), t) tab);
4.754 + in
4.755 + (tab', G'')
4.756 + end;
4.757 + in
4.758 + map_coes_and_graph coercion_data_update context
4.759 + end;
4.760 +
4.761 +val _ = Context.>> (Context.map_theory
4.762 + (Attrib.setup (Binding.name "coercion") (Scan.lift Parse.term >>
4.763 + (fn t => fn (context, thm) => (add_coercion t context, thm)))
4.764 + "declaration of new coercions" #>
4.765 + Attrib.setup (Binding.name "map_function") (Scan.lift Parse.term >>
4.766 + (fn t => fn (context, thm) => (add_type_map t context, thm)))
4.767 + "declaration of new map functions"));
4.768 +
4.769 +end;