1 (* Title: HOL/Tools/Datatype/datatype_case.ML
2 Author: Konrad Slind, Cambridge University Computer Laboratory
3 Author: Stefan Berghofer, TU Muenchen
5 Datatype package: nested case expressions on datatypes.
*)
(* Interface of the nested-case machinery: building case expressions
   (make_case), taking them apart again (dest_case, strip_case) and the
   parse / print translations (case_tr, case_tr').
   NOTE(review): the embedded line numbers skip 9, 14 and 23-24, so the
   sig keyword, the result type of make_case and the closing end are
   presumably missing from this view -- confirm against the full file. *)
8 signature DATATYPE_CASE =
(* Selects how redundant clauses are reported by make_case:
   as an error, as a warning, or not at all. *)
10 datatype config = Error | Warning | Quiet
11 type info = Datatype_Aux.info
(* Arguments: datatype lookup by constructor (name, type), context,
   reporting mode, names already in use, the scrutinee and the
   (pattern, right-hand side) clauses. *)
12 val make_case: (string * typ -> info option) ->
13 Proof.context -> config -> string list -> term -> (term * term) list ->
(* Destructs one level of a case expression; NONE if the term is not one. *)
15 val dest_case: (string -> info option) -> bool ->
16 string list -> term -> (term * (term * term) list) option
(* Like dest_case, but also flattens nested patterns. *)
17 val strip_case: (string -> info option) -> bool ->
18 term -> (term * (term * term) list) option
(* Parse translation; the bool chooses Error (true) or Warning (false)
   for redundant clauses. *)
19 val case_tr: bool -> (theory -> string * typ -> info option) ->
20 Proof.context -> term list -> term
(* Print translation for the case constant with the given name. *)
21 val case_tr': (theory -> string -> info option) ->
22 string -> Proof.context -> term list -> term
(* NOTE(review): the struct keyword (embedded lines 26-27) is not visible
   in this view. *)
25 structure Datatype_Case : DATATYPE_CASE =
(* Reporting mode for redundant clauses. *)
28 datatype config = Error | Warning | Quiet;
29 type info = Datatype_Aux.info;
(* Internal failure signal: a message plus the tag (index) of the clause
   it concerns; ~1 is raised where no particular clause is available. *)
31 exception CASE_ERROR of string * int;
(* Match the type pattern pat against the type ob in theory thy,
   starting from the empty table Vartab.empty. *)
fun match_type thy pat ob =
  let val no_bindings = Vartab.empty
  in Sign.typ_match thy (pat, ob) no_bindings end;
35 (* Get information about datatypes *)
(* NOTE(review): the function header and the surrounding case / let / in
   lines (embedded numbers 37-38, 40, 44, 48-50) are not visible here.
   Other code in this file calls ty_info tab ... and matches its result
   against {case_name, constructors}; presumably this is its body. *)
39 SOME ({descr, case_name, index, sorts, ...} : info) =>
(* Pick this datatype's own entry out of the descriptor. *)
41 val (_, (tname, dts, constrs)) = nth descr index;
42 val mk_ty = Datatype_Aux.typ_of_dtyp descr sorts;
(* The datatype itself, applied to its type arguments. *)
43 val T = Type (tname, map mk_ty dts);
(* Result: the case combinator name plus one constant per constructor,
   each typed args ---> T and passed through Logic.varifyT_global. *)
45 SOME {case_name = case_name,
46 constructors = map (fn (cname, dts') =>
47 Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs}
52 (*Each pattern carries with it a tag i, which denotes the clause it
53 came from. i = ~1 indicates that the clause was added by pattern
completion.*)
(* Extend a list of names with the free-variable names of one row: its
   right-hand side, its patterns and the variables kept in its prefix.
   The name list is the (curried) final argument. *)
fun add_row_used ((prfx, pats), (tm, _)) =
  let val row_terms = tm :: (pats @ map Free prfx)
  in fold Term.add_free_names row_terms end;
(* Keep user-chosen names where possible: a slot whose name is still ""
   takes the name of the corresponding term if that term is a Free;
   every other slot keeps the name it already has. *)
fun default_names names ts =
  let
    fun choose ("", Free (x, _)) = x
      | choose (given, _) = given
  in map choose (names ~~ ts) end;
(* Peel nested "_constrain" wrappers off a term.  Returns the bare term
   paired with the removed constraints, outermost constraint first. *)
fun strip_constraints tm =
  (case tm of
    Const (@{syntax_const "_constrain"}, _) $ t $ tT =>
      let val (core, tTs) = strip_constraints t
      in (core, tT :: tTs) end
  | _ => (tm, []));
(* Wrap t in a "_constrain" node whose type is the function type from tT
   to the dummy type. *)
fun mk_fun_constrain tT t =
  let
    val dummy_ty = Syntax.const @{type_syntax dummy};
    val fun_ty = Syntax.const @{type_syntax fun} $ tT $ dummy_ty;
  in Syntax.const @{syntax_const "_constrain"} $ t $ fun_ty end;
72 (*Produce an instance of a constructor, plus fresh variables for its arguments.*)
(* NOTE(review): the let line and the closing in ... end (embedded lines
   74, 84-85) are not visible; callers bind the result as (c', gvars). *)
73 fun fresh_constr ty_match ty_inst colty used c =
75 val (_, Ty) = dest_Const c
(* Argument types of the constructor. *)
76 val Ts = binder_types Ty;
(* One name per argument, made distinct from everything in used. *)
77 val names = Name.variant_list used
78 (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts));
79 val ty = body_type Ty;
(* Match the constructor's result type against the column type; a failed
   match is reported as a CASE_ERROR without a clause tag. *)
80 val ty_theta = ty_match ty colty handle Type.TYPE_MATCH =>
81 raise CASE_ERROR ("type mismatch", ~1)
(* Instantiate the constructor and its argument variables with the
   resulting type instantiation. *)
82 val c' = ty_inst ty_theta c
83 val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts)
87 (*Goes through a list of rows and picks out the ones beginning with a
88 pattern with constructor = name.*)
(* The accumulator is ((in_group, not_in_group), (names, cnstrts)).
   NOTE(review): several interior lines (embedded numbers 93, 95, 98, 101,
   107-108) are not visible, among them the case head and the test that
   compares name' with name. *)
89 fun mk_group (name, T) rows =
(* k = number of arguments the constructor takes. *)
90 let val k = length (binder_types T) in
91 fold (fn (row as ((prfx, p :: ps), rhs as (_, i))) =>
92 fn ((in_group, not_in_group), (names, cnstrts)) =>
94 (Const (name', _), args) =>
96 if length args = k then
(* Separate each argument from its type constraints. *)
97 let val (args', cnstrts') = split_list (map strip_constraints args)
(* The constructor's arguments replace the pattern at the head of the row;
   names and constraints are merged position by position. *)
99 ((((prfx, args' @ ps), rhs) :: in_group, not_in_group),
100 (default_names names args', map2 append cnstrts cnstrts'))
102 else raise CASE_ERROR
103 ("Wrong number of arguments for constructor " ^ name, i)
104 else ((in_group, row :: not_in_group), (names, cnstrts))
105 | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
(* Start with empty groups, k empty names and k empty constraint lists.
   Rows are consed onto the groups, so both groups are reversed at the end. *)
106 rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
(* Split the rows into one subproblem per constructor, using the local
   function part to walk the constructor list.
   NOTE(review): many interior lines are not visible (embedded numbers
   115-116, 120, 125, 127-128, 133, 136-141, 144), including the first
   clause of part and the start of the record built per constructor. *)
112 fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
113 | partition ty_match ty_inst type_of used constructors colty res_ty
114 (rows as (((prfx, _ :: ps), _) :: _)) =
(* Rows are left over although every constructor has been handled. *)
117 | part [] ((_, (_, i)) :: _) =
118 raise CASE_ERROR ("Not a constructor pattern", i)
119 | part (c :: cs) rows =
(* Rows headed by constructor c, the remaining rows, and the argument
   names / constraints collected from the matching rows. *)
121 val ((in_group, not_in_group), (names, cnstrts)) =
122 mk_group (dest_Const c) rows;
123 val used' = fold add_row_used in_group used;
124 val (c', gvars) = fresh_constr ty_match ty_inst colty used' c;
126 if null in_group (* Constructor not given *)
(* No clause mentions c: build a single row with fresh variables whose
   right-hand side is the undefined constant, tagged ~1. *)
129 val Ts = map type_of ps;
130 val xs = Name.variant_list
131 (fold Term.add_free_names gvars used')
132 (replicate (length ps) "x")
134 [((prfx, gvars @ map Free (xs ~~ Ts)),
135 (Const (@{const_syntax undefined}, res_ty), ~1))]
142 constraints = cnstrts,
(* Continue with the next constructor on the rows that did not match. *)
143 group = in_group'} :: part cs not_in_group
145 in part constructors rows end;
(* Move the variable at the head of the pattern list onto the prefix;
   anything but a leading Free raises CASE_ERROR. *)
fun v_to_prfx (prfx, pats) =
  (case pats of
    Free v :: rest => (v :: prfx, rest)
  | _ => raise CASE_ERROR ("mk_case: v_to_prfx", ~1));
151 (* Translation of pattern terms into nested case expressions. *)
(* NOTE(review): a number of interior lines are not visible (embedded
   numbers 154, 159-160, 162-163, 166, 168, 171, 173, 177-178, 181, 186,
   201, 208, 210), so several let / in / case heads and two clause bodies
   of mk are missing from this view. *)
153 fun mk_case tab ctxt ty_match ty_inst type_of used range_ty =
(* Fallback name for bound variables that have no user-given name. *)
155 val name = singleton (Name.variant_list used) "a";
(* expand: a row without patterns is an internal error. *)
156 fun expand constructors used ty ((_, []), _) =
157 raise CASE_ERROR ("mk_case: expand_var_row", ~1)
158 | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
161 val used' = add_row_used row used;
(* Constructor applied to fresh argument variables; it replaces p both
   in the pattern list and inside the right-hand side. *)
164 list_comb (fresh_constr ty_match ty_inst ty used' c)
165 in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag))
167 in map expnd constructors end
(* mk: first argument is the list of terms still to be matched, second
   the row matrix. *)
169 fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
170 | mk [] (((_, []), (tm, tag)) :: _) = (* Done *)
172 | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) =
174 | mk (u :: us) (rows as ((_, _ :: _), _) :: _) =
(* First column of the matrix, each pattern paired with its clause tag. *)
175 let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
(* Dispatch on the head of the first non-variable pattern in that column. *)
176 (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of
(* Replace each column variable by u in the right-hand side and move the
   variable from the patterns to the prefix. *)
179 val rows' = map (fn ((v, _), row) => row ||>
180 apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows);
182 | SOME (Const (cname, cT), i) =>
183 (case ty_info tab (cname, cT) of
184 NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
185 | SOME {case_name, constructors} =>
187 val pty = body_type cT;
188 val used' = fold Term.add_free_names us used;
189 val nrows = maps (expand constructors used' pty) rows;
(* One subproblem per constructor. *)
190 val subproblems = partition ty_match ty_inst type_of used'
191 constructors pty range_ty nrows;
(* Solve each subproblem recursively with its new formals put in front. *)
192 val (pat_rect, dtrees) = split_list (map (fn {new_formals, group, ...} =>
193 mk (new_formals @ us) group) subproblems)
(* Abstract each new formal, using the recorded name (or the fallback
   name when it is empty), and attach that argument's constraints. *)
194 val case_functions = map2
195 (fn {new_formals, names, constraints, ...} =>
196 fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
197 Abs (if s = "" then name else s, T,
198 abstract_over (x, t)) |>
199 fold mk_fun_constrain cnstrts)
200 (new_formals ~~ names ~~ constraints))
(* Apply the case combinator to the branch functions and the scrutinee u. *)
202 val types = map type_of (case_functions @ [u]);
203 val case_const = Const (case_name, types ---> range_ty)
204 val tree = list_comb (case_const, case_functions @ [u])
205 in (flat pat_rect, tree) end)
206 | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^
207 Syntax.string_of_term ctxt t, i))
209 | mk _ _ = raise CASE_ERROR ("Malformed row matrix", ~1)
(* Raise a user-level error whose text is s, prefixed with the standard
   header for problems in case expressions. *)
fun case_error s =
  let val msg = "Error in case expression:\n" ^ s
  in error msg end;
214 (*Repeated variable occurrences in a pattern are not allowed.*)
(* Folds over the atomic subterms of pat and reports a Free that is already
   in the accumulated list.
   NOTE(review): the remaining lines (embedded numbers 220-221) are not
   visible, so the else branch, the case for other atoms and the initial
   accumulator cannot be seen here. *)
215 fun no_repeat_vars ctxt pat = fold_aterms
216 (fn x as Free (s, _) => (fn xs =>
217 if member op aconv xs x then
218 case_error (quote s ^ " occurs repeatedly in the pattern " ^
219 quote (Syntax.string_of_term ctxt pat))
(* Build a case expression on x from the (pattern, rhs) clauses and report
   clauses that are redundant, according to config.
   NOTE(review): several lines are not visible (embedded numbers 224, 230,
   233, 239, 241, 243-244, 248-250), including the binding of rangeT, the
   start of the conditional in the error handler and the final result. *)
223 fun gen_make_case ty_match ty_inst type_of tab ctxt config used x clauses =
(* Renders a clause for use in messages. *)
225 fun string_of_clause (pat, rhs) =
226 Syntax.string_of_term ctxt (Syntax.const @{syntax_const "_case1"} $ pat $ rhs);
(* Run purely for its check: every pattern must be free of repeated variables. *)
227 val _ = map (no_repeat_vars ctxt o fst) clauses;
(* One row per clause: empty prefix, the single pattern, and the rhs tagged
   with the clause index. *)
228 val rows = map_index (fn (i, (pat, rhs)) =>
229 (([], [pat]), (rhs, i))) clauses;
(* All right-hand sides must share one type. *)
231 (case distinct op = (map (type_of o snd) clauses) of
232 [] => case_error "no clauses given"
234 | _ => case_error "all cases must have the same result type");
235 val used' = fold add_row_used rows used;
(* tags: the clause tags returned by mk_case alongside the case term. *)
236 val (tags, case_tm) = mk_case tab ctxt ty_match ty_inst type_of
237 used' rangeT [x] rows
(* Turn the internal CASE_ERROR into a user error, quoting clause i. *)
238 handle CASE_ERROR (msg, i) => case_error (msg ^
240 else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
(* Clause tags that mk_case did not return. *)
242 (case subtract (op =) tags (map (snd o snd) rows) of
(* Error aborts, Warning warns, Quiet ignores the message. *)
245 (case config of Error => case_error | Warning => warning | Quiet => fn _ => {})
246 ("The following clauses are redundant (covered by preceding clauses):\n" ^
247 cat_lines (map (string_of_clause o nth clauses) is)));
(* Typed instance of gen_make_case: types are matched in the theory of
   ctxt, instantiated with Envir.subst_term_types and computed with
   fastype_of. *)
fun make_case tab ctxt =
  let
    val thy = Proof_Context.theory_of ctxt;
    val ty_match = match_type thy;
  in gen_make_case ty_match Envir.subst_term_types fastype_of tab ctxt end;
(* Untyped instance of gen_make_case: type matching always yields the empty
   table, instantiation maps every type in a term to dummyT, and every
   term is assigned the type dummyT. *)
val make_case_untyped =
  let
    fun no_match _ _ = Vartab.empty;
    fun erase_types _ = Term.map_types (fn _ => dummyT);
    fun untyped _ = dummyT;
  in gen_make_case no_match erase_types untyped end;
258 (* parse translation *)
(* Takes the scrutinee t and the clause tree u, prepares the patterns and
   hands everything to make_case_untyped.
   NOTE(review): some lines are not visible (embedded numbers 261, 264,
   280-281, 283, 286-288, 301), including the else branch for Free
   patterns, the result of the application case and the final result. *)
260 fun case_tr err tab_of ctxt [t, u] =
262 val thy = Proof_Context.theory_of ctxt;
263 val intern_const_syntax = Consts.intern_syntax (Proof_Context.consts_of ctxt);
265 (* replace occurrences of dummy_pattern by distinct variables *)
266 (* internalize constant names *)
267 (* FIXME proper name context!? *)
(* prep_pat threads the list of used names through the pattern.
   Constraints are kept around the prepared inner pattern. *)
268 fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used =
269 let val (t', used') = prep_pat t used
270 in (c $ t' $ tT, used') end
271 | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used =
272 let val x = singleton (Name.variant_list used) "x"
273 in (Free (x, T), x :: used) end
274 | prep_pat (Const (s, T)) used =
275 (Const (intern_const_syntax s, T), used)
(* A Free whose interned name is a declared constant becomes that constant. *)
276 | prep_pat (v as Free (s, T)) used =
277 let val s' = Proof_Context.intern_const ctxt s in
278 if Sign.declared_const thy s' then
279 (Const (s', T), used)
282 | prep_pat (t $ u) used =
284 val (t', used') = prep_pat t used;
285 val (u', used'') = prep_pat u used';
289 | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
(* One "_case1" clause: strip the outer constraints from the left-hand side
   and prepare it, starting from the free names of the whole clause. *)
290 fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
291 let val (l', cnstrts) = strip_constraints l
292 in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end
293 | dest_case1 t = case_error "dest_case1";
(* Flatten the right-nested "_case2" chain into a list of clauses. *)
294 fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
295 | dest_case2 t = [t];
296 val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
(* The constraints stripped from the patterns are re-attached to the
   scrutinee; err selects Error or Warning for redundant clauses. *)
297 val case_tm = make_case_untyped (tab_of thy) ctxt
298 (if err then Error else Warning) []
299 (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
300 (flat cnstrts) t) cases;
(* Any other argument list is rejected. *)
302 | case_tr _ _ _ ts = case_error "case_tr";
305 (* Pretty printing of nested case expressions *)
307 (* destruct one level of pattern matching *)
309 (* FIXME proper name context!? *)
(* name_of extracts an optional constant name from a head term, type_of
   computes types, tab looks up datatype information, d selects a
   dummy_pattern constant (true) or a fresh Free "x" (false) for the
   default pattern, used lists names to avoid.
   NOTE(review): many interior lines are not visible (embedded numbers 313,
   315-316, 332, 336, 338, 341, 343, 347-348, 351-353, 355, 359, 364,
   368-371), including the header of the local strip_abs function, the
   bindings of R and dummy, and the remaining result branches. *)
310 fun gen_dest_case name_of type_of tab d used t =
311 (case apfst name_of (strip_comb t) of
312 (SOME cname, ts as _ :: _) =>
(* fs: all arguments but the last; x: the last argument. *)
314 val (fs, x) = split_last ts;
317 val zs = strip_abs_vars t;
(* Fewer than i abstractions: give up via CASE_ERROR, handled below. *)
318 val _ = if length zs < i then raise CASE_ERROR ("", 0) else ();
319 val (xs, ys) = chop i zs;
320 val u = list_abs (ys, strip_abs_body t);
(* Turn the first i bound variables into Frees with names distinct from
   those in u and used. *)
321 val xs' = map Free (Name.variant_list (OldTerm.add_term_names (u, used))
322 (map fst xs) ~~ map snd xs)
323 in (xs', subst_bounds (rev xs', u)) end;
324 fun is_dependent i t =
325 let val k = length (strip_abs_vars t) - i
326 in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end;
(* Group constructors by body (compared with aconv); entries whose third
   component is true are left out. *)
327 fun count_cases (_, _, true) = I
328 | count_cases (c, (_, body), false) =
329 AList.map_default op aconv (body, []) (cons c);
330 val is_undefined = name_of #> equal (SOME @{const_name undefined});
331 fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body)
333 (case ty_info tab cname of
334 SOME {constructors, case_name} =>
(* Proceed only if there is one branch function per constructor. *)
335 if length fs = length constructors then
337 val cases = map (fn (Const (s, U), t) =>
339 val k = length (binder_types U);
340 val p as (xs, _) = strip_abs k t
342 (Const (s, map type_of xs ---> type_of x),
344 end) (constructors ~~ fs);
(* Bodies ordered by how many constructors share them, most shared first. *)
345 val cases' = sort (int_ord o swap o pairself (length o snd))
346 (fold_rev count_cases cases []);
349 if d then Const (@{const_name dummy_pattern}, R)
350 else Free (singleton (Name.variant_list used) "x", R);
(* Bodies that are the undefined constant are treated separately. *)
354 (case find_first (is_undefined o fst) cases' of
356 if length cs = length constructors then [hd cases]
357 else filter_out (fn (_, (_, body), _) => is_undefined body) cases
358 | NONE => case cases' of
(* default is the most widely shared body, cs the constructors using it. *)
360 | (default, cs) :: _ =>
361 if length cs = 1 then cases
362 else if length cs = length constructors then
363 [hd cases, (dummy, ([], default), false)]
(* Drop the clauses covered by the default and append one default clause. *)
365 filter_out (fn (c, _, _) => member op aconv cs c) cases @
366 [(dummy, ([], default), false)]))
(* Any internal CASE_ERROR means the term is not a case expression. *)
367 end handle CASE_ERROR _ => NONE
(* Typed one-level destructor: the head name is the name of the constant
   (NONE if the head is not a constant), types come from fastype_of. *)
val dest_case =
  let val const_name = try (fst o dest_Const)
  in gen_dest_case const_name fastype_of end;
(* Untyped one-level destructor: the head name is the constant's name passed
   through Lexicon.unmark_const (NONE if any step fails), and every term is
   given the type dummyT. *)
val dest_case' =
  let val unmarked_name = try (Lexicon.unmark_const o fst o dest_Const)
  in gen_dest_case unmarked_name (K dummyT) end;
377 (* destruct nested patterns *)
(* Expands a clause whose right-hand side is itself a case expression on a
   variable of the pattern into one clause per inner clause.
   NOTE(review): the then / else lines (embedded numbers 385, 388) are not
   visible in this view. *)
379 fun strip_case'' dest (pat, rhs) =
(* Try to destruct the rhs, avoiding the free names of the pattern. *)
380 (case dest (Term.add_free_names pat []) rhs of
381 SOME (exp as Free _, clauses) =>
(* Condition: the scrutinee variable occurs in the pattern and in none of
   the inner right-hand sides. *)
382 if member op aconv (OldTerm.term_frees pat) exp andalso
383 not (exists (fn (_, rhs') =>
384 member op aconv (OldTerm.term_frees rhs') exp) clauses)
(* Substitute each inner pattern for the variable and recurse. *)
386 maps (strip_case'' dest) (map (fn (pat', rhs') =>
387 (subst_free [(exp, pat')] pat, rhs')) clauses)
(* Otherwise the clause is kept unchanged. *)
389 | _ => [(pat, rhs)]);
(* Flattens the clauses of a destructed case expression with strip_case''.
   NOTE(review): the case head and the remaining branch (embedded numbers
   392-393, 395) are not visible; where x and clauses are bound cannot be
   seen here. *)
391 fun gen_strip_case dest t =
394 SOME (x, maps (strip_case'' dest) clauses)
(* Typed variant: gen_strip_case driven by dest_case applied to the
   datatype table and the dummy-pattern flag. *)
fun strip_case tab d = gen_strip_case (dest_case tab d);
(* Untyped variant: gen_strip_case driven by dest_case' applied to the
   datatype table and the dummy-pattern flag. *)
fun strip_case' tab d = gen_strip_case (dest_case' tab d);
401 (* print translation *)
(* Rebuilds "_case_syntax" from an application of the case constant cname
   to the arguments ts.
   NOTE(review): several lines are not visible (embedded numbers 404, 409,
   412-413, 416-418, 420), including the term-mapping calls that the two
   anonymous functions below are passed to and the SOME branch header. *)
403 fun case_tr' tab_of cname ctxt ts =
405 val thy = Proof_Context.theory_of ctxt;
406 fun mk_clause (pat, rhs) =
(* Free variables of the pattern. *)
407 let val xs = Term.add_frees pat [] in
408 Syntax.const @{syntax_const "_case1"} $
(* Frees are marked as bound (with type), constants get marked names. *)
410 (fn Free p => Syntax_Trans.mark_boundT p
411 | Const (s, _) => Syntax.const (Lexicon.mark_const s)
(* A Free that also occurs in the pattern is marked as bound. *)
414 (fn x as Free (s, T) =>
415 if member (op =) xs (s, T) then Syntax_Trans.mark_bound s else x
419 (case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
(* Chain the clauses with "_case2", nesting to the right. *)
421 Syntax.const @{syntax_const "_case_syntax"} $ x $
422 foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
423 (map mk_clause clauses)
(* Not a case expression: raise Match. *)
424 | NONE => raise Match)