src/HOL/Tools/Datatype/datatype_case.ML
author wenzelm
Thu, 09 Jun 2011 16:34:49 +0200
changeset 44206 2b47822868e4
parent 44098 b81fd5c8f2dc
child 45004 44adaa6db327
permissions -rw-r--r--
discontinued Name.variant to emphasize that this is old-style / indirect;
     1 (*  Title:      HOL/Tools/Datatype/datatype_case.ML
     2     Author:     Konrad Slind, Cambridge University Computer Laboratory
     3     Author:     Stefan Berghofer, TU Muenchen
     4 
     5 Datatype package: nested case expressions on datatypes.
     6 *)
     7 
     8 signature DATATYPE_CASE =
     9 sig
    10   datatype config = Error | Warning | Quiet
    11   type info = Datatype_Aux.info
    12   val make_case: (string * typ -> info option) ->
    13     Proof.context -> config -> string list -> term -> (term * term) list ->
    14     term
    15   val dest_case: (string -> info option) -> bool ->
    16     string list -> term -> (term * (term * term) list) option
    17   val strip_case: (string -> info option) -> bool ->
    18     term -> (term * (term * term) list) option
    19   val case_tr: bool -> (theory -> string * typ -> info option) ->
    20     Proof.context -> term list -> term
    21   val case_tr': (theory -> string -> info option) ->
    22     string -> Proof.context -> term list -> term
    23 end;
    24 
    25 structure Datatype_Case : DATATYPE_CASE =
    26 struct
    27 
    28 datatype config = Error | Warning | Quiet;
    29 type info = Datatype_Aux.info;
    30 
    31 exception CASE_ERROR of string * int;
    32 
    33 fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty;
    34 
    35 (* Get information about datatypes *)
    36 
    37 fun ty_info tab sT =
    38   (case tab sT of
    39     SOME ({descr, case_name, index, sorts, ...} : info) =>
    40       let
    41         val (_, (tname, dts, constrs)) = nth descr index;
    42         val mk_ty = Datatype_Aux.typ_of_dtyp descr sorts;
    43         val T = Type (tname, map mk_ty dts);
    44       in
    45         SOME {case_name = case_name,
    46           constructors = map (fn (cname, dts') =>
    47             Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs}
    48       end
    49   | NONE => NONE);
    50 
    51 
    52 (*Each pattern carries with it a tag i, which denotes the clause it
    53 came from. i = ~1 indicates that the clause was added by pattern
    54 completion.*)
    55 
    56 fun add_row_used ((prfx, pats), (tm, tag)) =
    57   fold Term.add_free_names (tm :: pats @ map Free prfx);
    58 
    59 (*try to preserve names given by user*)
    60 fun default_names names ts =
    61   map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts);
    62 
    63 fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) =
    64       strip_constraints t ||> cons tT
    65   | strip_constraints t = (t, []);
    66 
    67 fun mk_fun_constrain tT t =
    68   Syntax.const @{syntax_const "_constrain"} $ t $
    69     (Syntax.const @{type_syntax fun} $ tT $ Syntax.const @{type_syntax dummy});
    70 
    71 
    72 (*Produce an instance of a constructor, plus fresh variables for its arguments.*)
    73 fun fresh_constr ty_match ty_inst colty used c =
    74   let
    75     val (_, Ty) = dest_Const c
    76     val Ts = binder_types Ty;
    77     val names = Name.variant_list used
    78       (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts));
    79     val ty = body_type Ty;
    80     val ty_theta = ty_match ty colty handle Type.TYPE_MATCH =>
    81       raise CASE_ERROR ("type mismatch", ~1)
    82     val c' = ty_inst ty_theta c
    83     val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts)
    84   in (c', gvars) end;
    85 
    86 
    87 (*Goes through a list of rows and picks out the ones beginning with a
    88  pattern with constructor = name.*)
    89 fun mk_group (name, T) rows =
    90   let val k = length (binder_types T) in
    91     fold (fn (row as ((prfx, p :: ps), rhs as (_, i))) =>
    92       fn ((in_group, not_in_group), (names, cnstrts)) =>
    93         (case strip_comb p of
    94           (Const (name', _), args) =>
    95             if name = name' then
    96               if length args = k then
    97                 let val (args', cnstrts') = split_list (map strip_constraints args)
    98                 in
    99                   ((((prfx, args' @ ps), rhs) :: in_group, not_in_group),
   100                    (default_names names args', map2 append cnstrts cnstrts'))
   101                 end
   102               else raise CASE_ERROR
   103                 ("Wrong number of arguments for constructor " ^ name, i)
   104             else ((in_group, row :: not_in_group), (names, cnstrts))
   105         | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
   106     rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
   107   end;
   108 
   109 
   110 (* Partitioning *)
   111 
   112 fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
   113   | partition ty_match ty_inst type_of used constructors colty res_ty
   114         (rows as (((prfx, _ :: ps), _) :: _)) =
   115       let
   116         fun part [] [] = []
   117           | part [] ((_, (_, i)) :: _) =
   118               raise CASE_ERROR ("Not a constructor pattern", i)
   119           | part (c :: cs) rows =
   120               let
   121                 val ((in_group, not_in_group), (names, cnstrts)) =
   122                   mk_group (dest_Const c) rows;
   123                 val used' = fold add_row_used in_group used;
   124                 val (c', gvars) = fresh_constr ty_match ty_inst colty used' c;
   125                 val in_group' =
   126                   if null in_group  (* Constructor not given *)
   127                   then
   128                     let
   129                       val Ts = map type_of ps;
   130                       val xs = Name.variant_list
   131                         (fold Term.add_free_names gvars used')
   132                         (replicate (length ps) "x")
   133                     in
   134                       [((prfx, gvars @ map Free (xs ~~ Ts)),
   135                         (Const (@{const_syntax undefined}, res_ty), ~1))]
   136                     end
   137                   else in_group
   138               in
   139                 {constructor = c',
   140                  new_formals = gvars,
   141                  names = names,
   142                  constraints = cnstrts,
   143                  group = in_group'} :: part cs not_in_group
   144               end
   145       in part constructors rows end;
   146 
   147 fun v_to_prfx (prfx, Free v::pats) = (v::prfx,pats)
   148   | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1);
   149 
   150 
   151 (* Translation of pattern terms into nested case expressions. *)
   152  
   153 fun mk_case tab ctxt ty_match ty_inst type_of used range_ty =
   154   let
   155     val name = singleton (Name.variant_list used) "a";
   156     fun expand constructors used ty ((_, []), _) =
   157           raise CASE_ERROR ("mk_case: expand_var_row", ~1)
   158       | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
   159           if is_Free p then
   160             let
   161               val used' = add_row_used row used;
   162               fun expnd c =
   163                 let val capp =
   164                   list_comb (fresh_constr ty_match ty_inst ty used' c)
   165                 in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag))
   166                 end
   167             in map expnd constructors end
   168           else [row]
   169     fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
   170       | mk [] (((_, []), (tm, tag)) :: _) =  (* Done *)
   171           ([tag], tm)
   172       | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) =
   173           mk path [row]
   174       | mk (u :: us) (rows as ((_, _ :: _), _) :: _) =
   175           let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
   176             (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of
   177               NONE =>
   178                 let
   179                   val rows' = map (fn ((v, _), row) => row ||>
   180                     apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows);
   181                 in mk us rows' end
   182             | SOME (Const (cname, cT), i) =>
   183                 (case ty_info tab (cname, cT) of
   184                   NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
   185                 | SOME {case_name, constructors} =>
   186                   let
   187                     val pty = body_type cT;
   188                     val used' = fold Term.add_free_names us used;
   189                     val nrows = maps (expand constructors used' pty) rows;
   190                     val subproblems = partition ty_match ty_inst type_of used'
   191                       constructors pty range_ty nrows;
   192                     val (pat_rect, dtrees) = split_list (map (fn {new_formals, group, ...} =>
   193                       mk (new_formals @ us) group) subproblems)
   194                     val case_functions = map2
   195                       (fn {new_formals, names, constraints, ...} =>
   196                          fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
   197                            Abs (if s = "" then name else s, T,
   198                              abstract_over (x, t)) |>
   199                            fold mk_fun_constrain cnstrts)
   200                              (new_formals ~~ names ~~ constraints))
   201                       subproblems dtrees;
   202                     val types = map type_of (case_functions @ [u]);
   203                     val case_const = Const (case_name, types ---> range_ty)
   204                     val tree = list_comb (case_const, case_functions @ [u])
   205                   in (flat pat_rect, tree) end)
   206             | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^
   207                 Syntax.string_of_term ctxt t, i))
   208           end
   209       | mk _ _ = raise CASE_ERROR ("Malformed row matrix", ~1)
   210   in mk end;
   211 
   212 fun case_error s = error ("Error in case expression:\n" ^ s);
   213 
   214 (*Repeated variable occurrences in a pattern are not allowed.*)
   215 fun no_repeat_vars ctxt pat = fold_aterms
   216   (fn x as Free (s, _) => (fn xs =>
   217         if member op aconv xs x then
   218           case_error (quote s ^ " occurs repeatedly in the pattern " ^
   219             quote (Syntax.string_of_term ctxt pat))
   220         else x :: xs)
   221     | _ => I) pat [];
   222 
   223 fun gen_make_case ty_match ty_inst type_of tab ctxt config used x clauses =
   224   let
   225     fun string_of_clause (pat, rhs) =
   226       Syntax.string_of_term ctxt (Syntax.const @{syntax_const "_case1"} $ pat $ rhs);
   227     val _ = map (no_repeat_vars ctxt o fst) clauses;
   228     val rows = map_index (fn (i, (pat, rhs)) =>
   229       (([], [pat]), (rhs, i))) clauses;
   230     val rangeT =
   231       (case distinct op = (map (type_of o snd) clauses) of
   232         [] => case_error "no clauses given"
   233       | [T] => T
   234       | _ => case_error "all cases must have the same result type");
   235     val used' = fold add_row_used rows used;
   236     val (tags, case_tm) = mk_case tab ctxt ty_match ty_inst type_of
   237         used' rangeT [x] rows
   238       handle CASE_ERROR (msg, i) => case_error (msg ^
   239         (if i < 0 then ""
   240          else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
   241     val _ =
   242       (case subtract (op =) tags (map (snd o snd) rows) of
   243         [] => ()
   244       | is =>
   245           (case config of Error => case_error | Warning => warning | Quiet => fn _ => {})
   246             ("The following clauses are redundant (covered by preceding clauses):\n" ^
   247              cat_lines (map (string_of_clause o nth clauses) is)));
   248   in
   249     case_tm
   250   end;
   251 
   252 fun make_case tab ctxt = gen_make_case
   253   (match_type (Proof_Context.theory_of ctxt)) Envir.subst_term_types fastype_of tab ctxt;
   254 val make_case_untyped = gen_make_case (K (K Vartab.empty))
   255   (K (Term.map_types (K dummyT))) (K dummyT);
   256 
   257 
   258 (* parse translation *)
   259 
   260 fun case_tr err tab_of ctxt [t, u] =
   261       let
   262         val thy = Proof_Context.theory_of ctxt;
   263         val intern_const_syntax = Consts.intern_syntax (Proof_Context.consts_of ctxt);
   264 
   265         (* replace occurrences of dummy_pattern by distinct variables *)
   266         (* internalize constant names                                 *)
   267         (* FIXME proper name context!? *)
   268         fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used =
   269               let val (t', used') = prep_pat t used
   270               in (c $ t' $ tT, used') end
   271           | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used =
   272               let val x = singleton (Name.variant_list used) "x"
   273               in (Free (x, T), x :: used) end
   274           | prep_pat (Const (s, T)) used =
   275               (Const (intern_const_syntax s, T), used)
   276           | prep_pat (v as Free (s, T)) used =
   277               let val s' = Proof_Context.intern_const ctxt s in
   278                 if Sign.declared_const thy s' then
   279                   (Const (s', T), used)
   280                 else (v, used)
   281               end
   282           | prep_pat (t $ u) used =
   283               let
   284                 val (t', used') = prep_pat t used;
   285                 val (u', used'') = prep_pat u used';
   286               in
   287                 (t' $ u', used'')
   288               end
   289           | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
   290         fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
   291               let val (l', cnstrts) = strip_constraints l
   292               in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end
   293           | dest_case1 t = case_error "dest_case1";
   294         fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
   295           | dest_case2 t = [t];
   296         val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
   297         val case_tm = make_case_untyped (tab_of thy) ctxt
   298           (if err then Error else Warning) []
   299           (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
   300              (flat cnstrts) t) cases;
   301       in case_tm end
   302   | case_tr _ _ _ ts = case_error "case_tr";
   303 
   304 
   305 (* Pretty printing of nested case expressions *)
   306 
   307 (* destruct one level of pattern matching *)
   308 
   309 (* FIXME proper name context!? *)
   310 fun gen_dest_case name_of type_of tab d used t =
   311   (case apfst name_of (strip_comb t) of
   312     (SOME cname, ts as _ :: _) =>
   313       let
   314         val (fs, x) = split_last ts;
   315         fun strip_abs i t =
   316           let
   317             val zs = strip_abs_vars t;
   318             val _ = if length zs < i then raise CASE_ERROR ("", 0) else ();
   319             val (xs, ys) = chop i zs;
   320             val u = list_abs (ys, strip_abs_body t);
   321             val xs' = map Free (Name.variant_list (OldTerm.add_term_names (u, used))
   322               (map fst xs) ~~ map snd xs)
   323           in (xs', subst_bounds (rev xs', u)) end;
   324         fun is_dependent i t =
   325           let val k = length (strip_abs_vars t) - i
   326           in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end;
   327         fun count_cases (_, _, true) = I
   328           | count_cases (c, (_, body), false) =
   329               AList.map_default op aconv (body, []) (cons c);
   330         val is_undefined = name_of #> equal (SOME @{const_name undefined});
   331         fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body)
   332       in
   333         (case ty_info tab cname of
   334           SOME {constructors, case_name} =>
   335             if length fs = length constructors then
   336               let
   337                 val cases = map (fn (Const (s, U), t) =>
   338                   let
   339                     val k = length (binder_types U);
   340                     val p as (xs, _) = strip_abs k t
   341                   in
   342                     (Const (s, map type_of xs ---> type_of x),
   343                      p, is_dependent k t)
   344                   end) (constructors ~~ fs);
   345                 val cases' = sort (int_ord o swap o pairself (length o snd))
   346                   (fold_rev count_cases cases []);
   347                 val R = type_of t;
   348                 val dummy =
   349                   if d then Const (@{const_name dummy_pattern}, R)
   350                   else Free (singleton (Name.variant_list used) "x", R);
   351               in
   352                 SOME (x,
   353                   map mk_case
   354                     (case find_first (is_undefined o fst) cases' of
   355                       SOME (_, cs) =>
   356                       if length cs = length constructors then [hd cases]
   357                       else filter_out (fn (_, (_, body), _) => is_undefined body) cases
   358                     | NONE => case cases' of
   359                       [] => cases
   360                     | (default, cs) :: _ =>
   361                       if length cs = 1 then cases
   362                       else if length cs = length constructors then
   363                         [hd cases, (dummy, ([], default), false)]
   364                       else
   365                         filter_out (fn (c, _, _) => member op aconv cs c) cases @
   366                         [(dummy, ([], default), false)]))
   367               end handle CASE_ERROR _ => NONE
   368             else NONE
   369         | _ => NONE)
   370       end
   371   | _ => NONE);
   372 
   373 val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of;
   374 val dest_case' = gen_dest_case (try (dest_Const #> fst #> Lexicon.unmark_const)) (K dummyT);
   375 
   376 
   377 (* destruct nested patterns *)
   378 
   379 fun strip_case'' dest (pat, rhs) =
   380   (case dest (Term.add_free_names pat []) rhs of
   381     SOME (exp as Free _, clauses) =>
   382       if member op aconv (OldTerm.term_frees pat) exp andalso
   383         not (exists (fn (_, rhs') =>
   384           member op aconv (OldTerm.term_frees rhs') exp) clauses)
   385       then
   386         maps (strip_case'' dest) (map (fn (pat', rhs') =>
   387           (subst_free [(exp, pat')] pat, rhs')) clauses)
   388       else [(pat, rhs)]
   389   | _ => [(pat, rhs)]);
   390 
   391 fun gen_strip_case dest t =
   392   (case dest [] t of
   393     SOME (x, clauses) =>
   394       SOME (x, maps (strip_case'' dest) clauses)
   395   | NONE => NONE);
   396 
   397 val strip_case = gen_strip_case oo dest_case;
   398 val strip_case' = gen_strip_case oo dest_case';
   399 
   400 
   401 (* print translation *)
   402 
   403 fun case_tr' tab_of cname ctxt ts =
   404   let
   405     val thy = Proof_Context.theory_of ctxt;
   406     fun mk_clause (pat, rhs) =
   407       let val xs = Term.add_frees pat [] in
   408         Syntax.const @{syntax_const "_case1"} $
   409           map_aterms
   410             (fn Free p => Syntax_Trans.mark_boundT p
   411               | Const (s, _) => Syntax.const (Lexicon.mark_const s)
   412               | t => t) pat $
   413           map_aterms
   414             (fn x as Free (s, T) =>
   415                   if member (op =) xs (s, T) then Syntax_Trans.mark_bound s else x
   416               | t => t) rhs
   417       end;
   418   in
   419     (case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
   420       SOME (x, clauses) =>
   421         Syntax.const @{syntax_const "_case_syntax"} $ x $
   422           foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
   423             (map mk_clause clauses)
   424     | NONE => raise Match)
   425   end;
   426 
   427 end;