1 (* Title: HOL/Tools/datatype_case.ML
3 Author: Konrad Slind, Cambridge University Computer Laboratory
4 Stefan Berghofer, TU Muenchen
6 Nested case expressions on datatypes.
(* Interface of the case-expression translation: compiling nested pattern
   matches on datatypes into case-combinator terms (make_case, case_tr) and
   recovering clauses from such terms (dest_case, strip_case, case_tr'). *)
9 signature DATATYPE_CASE =
(* make_case tab ctxt err used x clauses compiles the (pattern, rhs) clauses
   on the scrutinee x, keeping fresh names apart from used; err selects
   whether redundant clauses are an error (true) or only a warning (false).
   The result is presumably the case term plus every generated pattern
   paired with its (clause index, user-given) tag. *)
11 val make_case: (string -> DatatypeAux.datatype_info option) ->
12 Proof.context -> bool -> string list -> term -> (term * term) list ->
13 term * (term * (int * bool)) list
(* dest_case tab d used t destructs one level of case combinator in t;
   strip_case tab d t flattens the full nesting. d selects a dummy_pattern
   (true) or a fresh variable (false) for merged default clauses.
   NONE when t cannot be destructed. *)
14 val dest_case: (string -> DatatypeAux.datatype_info option) -> bool ->
15 string list -> term -> (term * (term * term) list) option
16 val strip_case: (string -> DatatypeAux.datatype_info option) -> bool ->
17 term -> (term * (term * term) list) option
(* Parse translation (case syntax to case term) and print translation
   (case term back to case syntax). *)
18 val case_tr: bool -> (theory -> string -> DatatypeAux.datatype_info option)
19 -> Proof.context -> term list -> term
20 val case_tr': (theory -> string -> DatatypeAux.datatype_info option) ->
21 string -> Proof.context -> term list -> term
(* Implementation of DATATYPE_CASE. *)
24 structure DatatypeCase : DATATYPE_CASE =
(* Internal failure of the translation: a message plus the index of the
   offending clause, or ~1 when no particular clause is to blame.
   gen_dest_case also raises it, with index 0, to abandon a destruction
   attempt, and catches it itself. *)
27 exception CASE_ERROR of string * int;
(* Match the type pattern pat against the type ob in theory thy, starting
   from an empty type environment. A failed match raises Type.TYPE_MATCH,
   which fresh_constr turns into a CASE_ERROR. *)
fun match_type thy pat ob =
  let val no_bindings = Vartab.empty
  in Sign.typ_match thy (pat, ob) no_bindings end;
31 (*---------------------------------------------------------------------------
32 * Get information about datatypes
33 *---------------------------------------------------------------------------*)
(* Look up s in tab. For a datatype_info, return the datatype's case
   combinator name and its constructors as constants. Each constructor gets
   the type (argument types ---> T), varified with Logic.varifyT, where T is
   the datatype's type built from the descriptor entry at index.
   NOTE(review): the NONE branch is not shown here; callers match on NONE
   to report that a name is not a datatype constructor. *)
35 fun ty_info (tab : string -> DatatypeAux.datatype_info option) s =
37 SOME {descr, case_name, index, sorts, ...} =>
(* descriptor entry of this datatype: type name, type arguments, constructors *)
39 val (_, (tname, dts, constrs)) = nth descr index;
40 val mk_ty = DatatypeAux.typ_of_dtyp descr sorts;
(* the datatype's own type, applied to its type arguments *)
41 val T = Type (tname, map mk_ty dts)
43 SOME {case_name = case_name,
44 constructors = map (fn (cname, dts') =>
45 Const (cname, Logic.varifyT (map mk_ty dts' ---> T))) constrs}
50 (*---------------------------------------------------------------------------
51 * Each pattern carries with it a tag (i,b) where
52 * i is the clause it came from and
53 * b=true indicates that clause was given by the user
54 * (or is an instantiation of a user supplied pattern)
56 *---------------------------------------------------------------------------*)
(* Apply f to the term component of a (term, tag) pattern entry; the tag is
   passed through untouched. *)
fun pattern_map f = fn (tm, tag) => (f tm, tag);
(* Apply the substitution theta (a list of replacement pairs, as taken by
   subst_free) to the term component of a (term, tag) pattern entry. *)
fun pattern_subst theta (tm, tag) = (subst_free theta tm, tag);
(* The clause index i of an entry of shape (term, (i, b)). *)
fun row_of_pat (_, (i, _)) = i;
(* Extend used with the free-variable names of a row, collected with
   add_term_free_names: first the right-hand side, then the remaining
   patterns, then the prefix terms. *)
fun add_row_used ((prfx, pats), (tm, _)) used =
  let
    fun add_frees ts names = foldl add_term_free_names names ts
    val with_rhs = add_term_free_names (tm, used)
  in
    add_frees prfx (add_frees pats with_rhs)
  end;
(* Keep names chosen by the user: for every position whose current name is
   still empty, adopt the name of the corresponding argument when that
   argument is a free variable; any other position keeps its current name. *)
fun default_names names ts =
  let
    fun pick ("", Free (name', _)) = name'
      | pick (name, _) = name
  in map pick (names ~~ ts) end;
(* Peel nested _constrain syntax wrappers off a term. Returns the bare term
   and the removed constraint terms, outermost constraint first. *)
fun strip_constraints t =
  (case t of
    Const ("_constrain", _) $ u $ tT => strip_constraints u ||> cons tT
  | _ => (t, []));
(* Wrap t in a _constrain syntax node whose constraint is the syntax term
   fun applied to tT and dummy (presumably a function-type constraint with
   domain tT and unspecified range). mk_case uses it to reattach pattern
   constraints to case-function abstractions. *)
fun mk_fun_constrain tT t =
  let val fun_constraint = Syntax.free "fun" $ tT $ Syntax.free "dummy"
  in Syntax.const "_constrain" $ t $ fun_constraint end;
80 (*---------------------------------------------------------------------------
81 * Produce an instance of a constructor, plus genvars for its arguments.
82 *---------------------------------------------------------------------------*)
(* Instantiate constructor c at the column type colty:
   - match the constructor's result type against colty with ty_match; a
     Type.TYPE_MATCH failure becomes CASE_ERROR ("type mismatch", ~1);
   - apply that instantiation to c with ty_inst;
   - create one fresh Free per constructor argument, named by
     DatatypeProp.make_tnames and kept apart from used.
   Callers consume the result as the pair (instantiated constructor,
   argument variables). *)
83 fun fresh_constr ty_match ty_inst colty used c =
85 val (_, Ty) = dest_Const c
86 val Ts = binder_types Ty;
87 val names = Name.variant_list used
88 (DatatypeProp.make_tnames (map Logic.unvarifyT Ts));
89 val ty = body_type Ty;
90 val ty_theta = ty_match ty colty handle Type.TYPE_MATCH =>
91 raise CASE_ERROR ("type mismatch", ~1)
92 val c' = ty_inst ty_theta c
93 val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts)
98 (*---------------------------------------------------------------------------
99 * Goes through a list of rows and picks out the ones beginning with a
100 * pattern with constructor = name.
101 *---------------------------------------------------------------------------*)
(* Split rows by their first pattern.
   - A row headed by constructor name must apply it to exactly k arguments
     (k = arity of T). It goes to in_group, with the head replaced by its
     constraint-stripped arguments.
   - Other constructor-headed rows go to not_in_group.
   Both lists keep the input order: fold builds them reversed, and
   pairself rev restores the order.

   Alongside, it collects one default name per argument position and the
   stripped constraints per position (concatenated across rows). For names,
   default_names lets the first row that supplies a variable name for a
   position win.

   Raises CASE_ERROR for a wrong argument count, or when the head of a
   pattern is not a constant. *)
102 fun mk_group (name, T) rows =
103 let val k = length (binder_types T)
104 in fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) =>
105 fn ((in_group, not_in_group), (names, cnstrts)) => (case strip_comb p of
106 (Const (name', _), args) =>
108 if length args = k then
109 let val (args', cnstrts') = split_list (map strip_constraints args)
111 ((((prfx, args' @ rst), rhs) :: in_group, not_in_group),
112 (default_names names args', map2 append cnstrts cnstrts'))
114 else raise CASE_ERROR
115 ("Wrong number of arguments for constructor " ^ name, i)
(* presumably: the head is a different constructor; keep the row for later *)
116 else ((in_group, row :: not_in_group), (names, cnstrts))
117 | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
118 rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
121 (*---------------------------------------------------------------------------
122 * Partition the rows. Not efficient: we should use hashing.
123 *---------------------------------------------------------------------------*)
(* Split the row matrix into one subproblem per constructor, in the order
   of constructors (A is accumulated reversed, then reversed back).

   For each constructor c:
   - mk_group selects the rows headed by c;
   - fresh_constr yields the instantiated constructor and fresh argument
     variables;
   - if no row mentions c, a single default row is synthesized. Its
     remaining columns are fresh variables, its right-hand side is
     HOL.undefined at res_ty, and its tag is (~1, false).

   mk_case reads the fields constructor, new_formals, names, constraints
   and group of each subproblem record.

   Raises CASE_ERROR when rows are left over after every constructor has
   been handled, or when there are no rows at all. *)
124 fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
125 | partition ty_match ty_inst type_of used constructors colty res_ty
126 (rows as (((prfx, _ :: rstp), _) :: _)) =
128 fun part {constrs = [], rows = [], A} = rev A
129 | part {constrs = [], rows = (_, (_, (i, _))) :: _, A} =
130 raise CASE_ERROR ("Not a constructor pattern", i)
131 | part {constrs = c :: crst, rows, A} =
133 val ((in_group, not_in_group), (names, cnstrts)) =
134 mk_group (dest_Const c) rows;
135 val used' = fold add_row_used in_group used;
136 val (c', gvars) = fresh_constr ty_match ty_inst colty used' c;
138 if null in_group (* Constructor not given *)
(* default row: fresh variables named apart, typed like the first row's
   remaining patterns *)
141 val Ts = map type_of rstp;
142 val xs = Name.variant_list
143 (foldl add_term_free_names used' gvars)
144 (replicate (length rstp) "x")
146 [((prfx, gvars @ map Free (xs ~~ Ts)),
147 (Const ("HOL.undefined", res_ty), (~1, false)))]
153 A = {constructor = c',
156 constraints = cnstrts,
157 group = in_group'} :: A}
159 in part {constrs = constructors, rows = rows, A = []}
162 (*---------------------------------------------------------------------------
163 * Misc. routines used in mk_case
164 *---------------------------------------------------------------------------*)
(* Reassemble the patterns returned from one constructor's subproblem.
   Each (prfx, tag, plist) has its first L patterns (L = arity of c)
   replaced by the instantiated constructor c' applied to them.
   NOTE(review): the application of build to the elements of l is not
   shown here. *)
166 fun mk_pat ((c, c'), l) =
168 val L = length (binder_types (fastype_of c))
169 fun build (prfx, tag, plist) =
170 let val (args, plist') = chop L plist
171 in (prfx, tag, list_comb (c', args) :: plist') end
(* Move the first pattern of a row onto the row's prefix. A row without
   patterns is an internal error. *)
fun v_to_prfx (prfx, pats) =
  (case pats of
    v :: rest => (v :: prfx, rest)
  | [] => raise CASE_ERROR ("mk_case: v_to_prfx", ~1));
(* Inverse of v_to_prfx on a result row: move the first prefix element back
   onto the pattern list, keeping the tag. An empty prefix is an internal
   error. *)
fun v_to_pats (prfx, tag, pats) =
  (case prfx of
    v :: rest => (rest, tag, v :: pats)
  | [] => raise CASE_ERROR ("mk_case: v_to_pats", ~1));
181 (*----------------------------------------------------------------------------
182 * Translation of pattern terms into nested case expressions.
184 * This performs the translation and also builds the full set of patterns.
185 * Thus it supports the construction of induction theorems even when an
186 * incomplete set of patterns is given.
187 *---------------------------------------------------------------------------*)
(* Compile a row matrix {path, rows} into a nested case term.
   - path holds the terms still to be matched, one per pattern column.
   - Each row has the shape ((prfx, pats), (rhs, tag)).
   The result is the completed pattern rows (prfx, tag, pats) together with
   the case term. name is a fresh binder name, used for case-function
   abstractions that have no preferred name. *)
189 fun mk_case tab ctxt ty_match ty_inst type_of used range_ty =
191 val name = Name.variant used "a";
(* Replace the row's first pattern p by each constructor applied to fresh
   variables, substituting that instance for p in the right-hand side.
   NOTE(review): the test deciding which rows get expanded is not shown
   here. *)
192 fun expand constructors used ty ((_, []), _) =
193 raise CASE_ERROR ("mk_case: expand_var_row", ~1)
194 | expand constructors used ty (row as ((prfx, p :: rst), rhs)) =
197 val used' = add_row_used row used;
200 list_comb (fresh_constr ty_match ty_inst ty used' c)
201 in ((prfx, capp :: rst), pattern_subst [(p, capp)] rhs)
203 in map expnd constructors end
205 fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1)
(* all columns matched: the first remaining row wins *)
206 | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} = (* Done *)
207 ([(prfx, tag, [])], tm)
(* a lone variable pattern in the first row subsumes every later row *)
208 | mk {path, rows as ((row as ((_, [Free _]), _)) :: _ :: _)} =
209 mk {path = path, rows = [row]}
(* Otherwise inspect the first column:
   - if it holds only variables, substitute the path term u for each of
     them, move them to the prefix, and continue;
   - if its first non-variable head is a datatype constructor, split into
     one subproblem per constructor and build an application of the
     datatype's case combinator;
   - any other head is reported as not a constructor. *)
210 | mk {path = u :: rstp, rows as ((_, _ :: _), _) :: _} =
211 let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows
212 in case Option.map (apfst head_of)
213 (find_first (not o is_Free o fst) col0) of
216 val rows' = map (fn ((v, _), row) => row ||>
217 pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows);
218 val (pref_patl, tm) = mk {path = rstp, rows = rows'}
219 in (map v_to_pats pref_patl, tm) end
220 | SOME (Const (cname, cT), i) => (case ty_info tab cname of
221 NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
222 | SOME {case_name, constructors} =>
224 val pty = body_type cT;
225 val used' = foldl add_term_free_names used rstp;
226 val nrows = maps (expand constructors used' pty) rows;
227 val subproblems = partition ty_match ty_inst type_of used'
228 constructors pty range_ty nrows;
229 val new_formals = map #new_formals subproblems
230 val constructors' = map #constructor subproblems
231 val news = map (fn {new_formals, group, ...} =>
232 {path = new_formals @ rstp, rows = group}) subproblems;
233 val (pat_rect, dtrees) = split_list (map mk news);
(* Abstract each subproblem's result over its new formals. Binders are
   named after the collected names (or name when empty), and each binder
   gets its collected constraints back via mk_fun_constrain.
   NOTE(review): the second list given to map2 is not shown here. *)
234 val case_functions = map2
235 (fn {new_formals, names, constraints, ...} =>
236 fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
237 Abs (if s = "" then name else s, T,
238 abstract_over (x, t)) |>
239 fold mk_fun_constrain cnstrts)
240 (new_formals ~~ names ~~ constraints))
(* case combinator applied to the case functions and the scrutinee u *)
242 val types = map type_of (case_functions @ [u]);
243 val case_const = Const (case_name, types ---> range_ty)
244 val tree = list_comb (case_const, case_functions @ [u])
(* rebuild the full patterns of every constructor's subproblem *)
245 val pat_rect1 = flat (map mk_pat
246 (constructors ~~ constructors' ~~ pat_rect))
249 | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^
250 Syntax.string_of_term ctxt t, i)
252 | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1)
(* Abort with a user-level error, prefixed by a fixed header line. *)
fun case_error msg = error (String.concat ["Error in case expression:\n", msg]);
258 (* Repeated variable occurrences in a pattern are not allowed. *)
(* Walk the atoms of pat with fold_aterms, threading the list xs of free
   variables seen so far. A free variable met a second time is reported
   through case_error, quoting the variable and the whole pattern.
   NOTE(review): the accumulation step and the non-Free case are not shown
   here. *)
259 fun no_repeat_vars ctxt pat = fold_aterms
260 (fn x as Free (s, _) => (fn xs =>
261 if member op aconv xs x then
262 case_error (quote s ^ " occurs repeatedly in the pattern " ^
263 quote (Syntax.string_of_term ctxt pat))
(* Common front end of make_case and make_case_untyped. It is
   parameterized by type matching (ty_match), type instantiation (ty_inst)
   and typing (type_of). Steps:
   - number the clauses and tag them (i, true);
   - reject patterns with repeated variables;
   - require a single result type across all right-hand sides;
   - run mk_case on the scrutinee x;
   - turn CASE_ERROR into a user-level error, presumably quoting the
     offending clause whenever the index refers to one;
   - report as redundant every clause whose index does not survive into the
     final patterns: an error if err, otherwise a warning. *)
267 fun gen_make_case ty_match ty_inst type_of tab ctxt err used x clauses =
269 fun string_of_clause (pat, rhs) = Syntax.string_of_term ctxt
270 (Syntax.const "_case1" $ pat $ rhs);
(* evaluated only for its error check; the result list is discarded *)
271 val _ = map (no_repeat_vars ctxt o fst) clauses;
272 val rows = map_index (fn (i, (pat, rhs)) =>
273 (([], [pat]), (rhs, (i, true)))) clauses;
274 val rangeT = (case distinct op = (map (type_of o snd) clauses) of
275 [] => case_error "no clauses given"
277 | _ => case_error "all cases must have the same result type");
278 val used' = fold add_row_used rows used;
279 val (patts, case_tm) = mk_case tab ctxt ty_match ty_inst type_of
280 used' rangeT {path = [x], rows = rows}
281 handle CASE_ERROR (msg, i) => case_error (msg ^
283 else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
285 (fn (_, tag, [pat]) => (pat, tag)
286 | _ => case_error "error in pattern-match translation") patts;
(* clause indices reached by the final patterns, in ascending order *)
287 val patts2 = Library.sort (Library.int_ord o Library.pairself row_of_pat) patts1
288 val finals = map row_of_pat patts2
289 val originals = map (row_of_pat o #2) rows
290 val _ = case originals \\ finals of
292 | is => (if err then case_error else warning)
293 ("The following clauses are redundant (covered by preceding clauses):\n" ^
294 space_implode "\n" (map (string_of_clause o nth clauses) is));
(* Typed front end:
   - types are matched with match_type in the theory of ctxt;
   - instantiated with Envir.subst_TVars;
   - computed with fastype_of. *)
fun make_case tab ctxt =
  let val thy = ProofContext.theory_of ctxt
  in gen_make_case (match_type thy) Envir.subst_TVars fastype_of tab ctxt end;
(* Untyped front end used by the parse translation:
   - type matching always yields the empty environment;
   - instantiation replaces every type in a term by dummyT;
   - every term is assigned type dummyT. *)
val make_case_untyped =
  gen_make_case (fn _ => fn _ => Vartab.empty)
    (fn _ => Term.map_types (fn _ => dummyT))
    (fn _ => dummyT);
305 (* parse translation *)
(* Parse translation for case syntax. t is the scrutinee and u the clauses,
   chained with _case2 and each built with _case1.
   - Patterns are prepared with prep_pat.
   - Type constraints found on patterns are stripped and wrapped around the
     scrutinee t as _constrain nodes.
   - The clauses are compiled by make_case_untyped, with no names in use.
   Any other argument list is rejected with case_error. *)
307 fun case_tr err tab_of ctxt [t, u] =
309 val thy = ProofContext.theory_of ctxt;
310 (* replace occurrences of dummy_pattern by distinct variables *)
311 (* internalize constant names *)
312 fun prep_pat ((c as Const ("_constrain", _)) $ t $ tT) used =
313 let val (t', used') = prep_pat t used
314 in (c $ t' $ tT, used') end
315 | prep_pat (Const ("dummy_pattern", T)) used =
316 let val x = Name.variant used "x"
317 in (Free (x, T), x :: used) end
318 | prep_pat (Const (s, T)) used =
319 (case try (unprefix Syntax.constN) s of
320 SOME c => (Const (c, T), used)
321 | NONE => (Const (Sign.intern_const thy s, T), used))
(* a free name that interns to a declared constant denotes that constant *)
322 | prep_pat (v as Free (s, T)) used =
323 let val s' = Sign.intern_const thy s
325 if Sign.declared_const thy s' then
326 (Const (s', T), used)
329 | prep_pat (t $ u) used =
331 val (t', used') = prep_pat t used;
332 val (u', used'') = prep_pat u used'
336 | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
(* one _case1 clause: strip pattern constraints, then prepare the pattern
   with the clause's own free names in use *)
337 fun dest_case1 (t as Const ("_case1", _) $ l $ r) =
338 let val (l', cnstrts) = strip_constraints l
339 in ((fst (prep_pat l' (add_term_free_names (t, []))), r), cnstrts)
341 | dest_case1 t = case_error "dest_case1";
342 fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
343 | dest_case2 t = [t];
344 val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
345 val (case_tm, _) = make_case_untyped (tab_of thy) ctxt err []
346 (fold (fn tT => fn t => Syntax.const "_constrain" $ t $ tT)
347 (flat cnstrts) t) cases;
349 | case_tr _ _ _ ts = case_error "case_tr";
352 (*---------------------------------------------------------------------------
353 * Pretty printing of nested case expressions
354 *---------------------------------------------------------------------------*)
356 (* destruct one level of pattern matching *)
(* Try to read t as one level of case combinator: a head constant (named via
   name_of) applied to case functions fs and a final scrutinee x.

   With the datatype info from tab and exactly one function per constructor,
   each function is opened into a clause: the constructor applied to fresh
   variables, plus the body. Then:
   - Clauses whose body is HOL.undefined are dropped. If every constructor
     maps to undefined, only the first clause is kept.
   - Otherwise, clauses sharing a body that does not depend on the
     constructor arguments may be merged into one default clause. Its
     pattern is a dummy_pattern if d, otherwise a fresh variable.

   Returns SOME (x, clauses), or NONE when CASE_ERROR is raised on the way.
   NOTE(review): several lines of this function are not shown here; the
   summary is inferred from the visible branches. *)
358 fun gen_dest_case name_of type_of tab d used t =
359 case apfst name_of (strip_comb t) of
360 (SOME cname, ts as _ :: _) =>
362 val (fs, x) = split_last ts;
(* strip_abs i t: open the first i abstractions of t with fresh Frees, named
   apart from used and from the names in the remaining term. Fewer than i
   abstractions raise CASE_ERROR. *)
365 val zs = strip_abs_vars t;
366 val _ = if length zs < i then raise CASE_ERROR ("", 0) else ();
367 val (xs, ys) = chop i zs;
368 val u = list_abs (ys, strip_abs_body t);
369 val xs' = map Free (Name.variant_list (add_term_names (u, used))
370 (map fst xs) ~~ map snd xs)
371 in (xs', subst_bounds (rev xs', u)) end;
(* true if the body under all abstractions refers to one of the first i
   bound variables, or if t has fewer than i abstractions *)
372 fun is_dependent i t =
373 let val k = length (strip_abs_vars t) - i
374 in k < 0 orelse exists (fn j => j >= k)
375 (loose_bnos (strip_abs_body t))
(* group the clauses flagged false by body, collecting their constructors *)
377 fun count_cases (_, _, true) = I
378 | count_cases (c, (_, body), false) =
379 AList.map_default op aconv (body, []) (cons c);
380 val is_undefined = name_of #> equal (SOME "HOL.undefined");
(* local helper, distinct from the top-level mk_case *)
381 fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body)
382 in case ty_info tab cname of
383 SOME {constructors, case_name} =>
384 if length fs = length constructors then
386 val cases = map (fn (Const (s, U), t) =>
388 val k = length (binder_types U);
389 val p as (xs, _) = strip_abs k t
391 (Const (s, map type_of xs ---> type_of x),
393 end) (constructors ~~ fs);
(* bodies ordered by how many constructors share them, most first *)
394 val cases' = sort (int_ord o swap o pairself (length o snd))
395 (fold_rev count_cases cases []);
397 val dummy = if d then Const ("dummy_pattern", R)
398 else Free (Name.variant used "x", R)
400 SOME (x, map mk_case (case find_first (is_undefined o fst) cases' of
402 if length cs = length constructors then [hd cases]
403 else filter_out (fn (_, (_, body), _) => is_undefined body) cases
404 | NONE => case cases' of
406 | (default, cs) :: _ =>
407 if length cs = 1 then cases
408 else if length cs = length constructors then
409 [hd cases, (dummy, ([], default), false)]
411 filter_out (fn (c, _, _) => member op aconv cs c) cases @
412 [(dummy, ([], default), false)]))
413 end handle CASE_ERROR _ => NONE
(* One-level destructor on typed terms: head constants are read with their
   internal names, and types are computed with fastype_of. *)
val dest_case = gen_dest_case (try (fst o dest_Const)) fastype_of;
(* One-level destructor on syntax terms: a head constant's name is used with
   its Syntax.constN prefix removed (name_of yields NONE when the prefix is
   absent). Every type is taken to be dummyT. *)
val dest_case' =
  gen_dest_case (try (unprefix Syntax.constN o fst o dest_Const)) (fn _ => dummyT);
424 (* destruct nested patterns *)
(* Flatten nested case expressions. Suppose rhs is itself a case on a
   variable exp that occurs in pat but in no clause's right-hand side.
   Then exp in pat is replaced by each clause pattern in turn, and the
   resulting (pattern, body) pairs are flattened recursively.
   NOTE(review): the fallback branch, presumably returning [(pat, rhs)]
   unchanged, is not shown here. *)
426 fun strip_case' dest (pat, rhs) =
427 case dest (add_term_free_names (pat, [])) rhs of
428 SOME (exp as Free _, clauses) =>
429 if member op aconv (term_frees pat) exp andalso
430 not (exists (fn (_, rhs') =>
431 member op aconv (term_frees rhs') exp) clauses)
433 maps (strip_case' dest) (map (fn (pat', rhs') =>
434 (subst_free [(exp, pat')] pat, rhs')) clauses)
(* Destruct t one level with dest, with no names in use yet, and flatten
   every resulting clause with strip_case'.
   NOTE(review): the remaining branches (presumably yielding NONE) are not
   shown here. *)
438 fun gen_strip_case dest t = case dest [] t of
440 SOME (x, maps (strip_case' dest) clauses)
(* Full destructors: one level via dest_case (typed) or dest_case' (syntax),
   after which gen_strip_case flattens nested cases. The second definition
   shadows the clause-level strip_case' helper for later code, exactly as
   before. *)
fun strip_case tab d t = gen_strip_case (dest_case tab d) t;
fun strip_case' tab d t = gen_strip_case (dest_case' tab d) t;
447 (* print translation *)
(* Print translation. Rebuild the case-combinator application cname ts,
   strip it into clauses with the syntax-level strip_case' (default clauses
   become dummy_pattern), and render the result as _case_syntax applied to
   the scrutinee and the _case1 clauses joined by _case2.
   - Pattern variables are marked bound, presumably in both the pattern and
     the right-hand side.
   - Constants in patterns are externalized with Consts.extern_early.
   If the term is not a recognized case, Match is raised, presumably to
   signal that this translation does not apply. *)
449 fun case_tr' tab_of cname ctxt ts =
451 val thy = ProofContext.theory_of ctxt;
452 val consts = ProofContext.consts_of ctxt;
453 fun mk_clause (pat, rhs) =
454 let val xs = term_frees pat
456 Syntax.const "_case1" $
458 (fn Free p => Syntax.mark_boundT p
459 | Const (s, _) => Const (Consts.extern_early consts s, dummyT)
462 (fn x as Free (s, _) =>
463 if member op aconv xs x then Syntax.mark_bound s else x
466 in case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
467 SOME (x, clauses) => Syntax.const "_case_syntax" $ x $
468 foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u)
469 (map mk_clause clauses)
470 | NONE => raise Match