1.1 --- a/src/HOL/Tools/Datatype/datatype_case.ML Fri Dec 16 10:52:35 2011 +0100
1.2 +++ b/src/HOL/Tools/Datatype/datatype_case.ML Fri Dec 16 11:02:55 2011 +0100
1.3 @@ -130,7 +130,7 @@
1.4 names = names,
1.5 constraints = cnstrts,
1.6 group = in_group'} :: part cs not_in_group
1.7 - end
1.8 + end;
1.9 in part constructors rows end;
1.10
1.11 fun v_to_prfx (prfx, Free v :: pats) = (v :: prfx, pats)
1.12 @@ -143,7 +143,6 @@
1.13 let
1.14 val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt);
1.15
1.16 - val name = singleton (Name.variant_list used) "a";
1.17 fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand_var_row", ~1)
1.18 | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
1.19 if is_Free p then
1.20 @@ -153,7 +152,10 @@
1.21 let val capp = list_comb (fresh_constr ty_match ty_inst ty used' c)
1.22 in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end
1.23 in map expnd constructors end
1.24 - else [row]
1.25 + else [row];
1.26 +
1.27 + val name = singleton (Name.variant_list used) "a";
1.28 +
1.29 fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
1.30 | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *)
1.31 | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row]
1.32 @@ -277,19 +279,22 @@
1.33 val (u', used'') = prep_pat u used';
1.34 in (t' $ u', used'') end
1.35 | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
1.36 +
1.37 fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
1.38 let val (l', cnstrts) = strip_constraints l
1.39 in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end
1.40 | dest_case1 t = case_error "dest_case1";
1.41 +
1.42 fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
1.43 | dest_case2 t = [t];
1.44 +
1.45 val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
1.46 - val case_tm =
1.47 - make_case_untyped ctxt
1.48 - (if err then Error else Warning) []
1.49 - (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
1.50 - (flat cnstrts) t) cases;
1.51 - in case_tm end
1.52 + in
1.53 + make_case_untyped ctxt
1.54 + (if err then Error else Warning) []
1.55 + (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
1.56 + (flat cnstrts) t) cases
1.57 + end
1.58 | case_tr _ _ _ = case_error "case_tr";
1.59
1.60 val trfun_setup =
2.1 --- a/src/HOL/Tools/Datatype/primrec.ML Fri Dec 16 10:52:35 2011 +0100
2.2 +++ b/src/HOL/Tools/Datatype/primrec.ML Fri Dec 16 11:02:55 2011 +0100
2.3 @@ -206,11 +206,11 @@
2.4
2.5 (* find datatypes which contain all datatypes in tnames' *)
2.6
2.7 -fun find_dts (dt_info : Datatype_Aux.info Symtab.table) _ [] = []
2.8 +fun find_dts _ _ [] = []
2.9 | find_dts dt_info tnames' (tname :: tnames) =
2.10 (case Symtab.lookup dt_info tname of
2.11 NONE => primrec_error (quote tname ^ " is not a datatype")
2.12 - | SOME dt =>
2.13 + | SOME (dt : Datatype_Aux.info) =>
2.14 if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
2.15 (tname, dt) :: (find_dts dt_info tnames' tnames)
2.16 else find_dts dt_info tnames' tnames);
2.17 @@ -218,12 +218,12 @@
2.18
2.19 (* distill primitive definition(s) from primrec specification *)
2.20
2.21 -fun distill lthy fixes eqs =
2.22 +fun distill ctxt fixes eqs =
2.23 let
2.24 - val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
2.25 + val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed ctxt v
2.26 orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
2.27 val tnames = distinct (op =) (map (#1 o snd) eqns);
2.28 - val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of lthy)) tnames tnames;
2.29 + val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of ctxt)) tnames tnames;
2.30 val main_fns = map (fn (tname, {index, ...}) =>
2.31 (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
2.32 val {descr, rec_names, rec_rewrites, ...} =
2.33 @@ -232,7 +232,7 @@
2.34 else snd (hd dts);
2.35 val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
2.36 val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
2.37 - val defs = map (make_def lthy fixes fs) raw_defs;
2.38 + val defs = map (make_def ctxt fixes fs) raw_defs;
2.39 val names = map snd fnames;
2.40 val names_eqns = map fst eqns;
2.41 val _ =
2.42 @@ -241,17 +241,17 @@
2.43 "\nare not mutually recursive");
2.44 val rec_rewrites' = map mk_meta_eq rec_rewrites;
2.45 val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
2.46 - fun prove lthy defs =
2.47 + fun prove ctxt defs =
2.48 let
2.49 - val frees = fold (Variable.add_free_names lthy) eqs [];
2.50 + val frees = fold (Variable.add_free_names ctxt) eqs [];
2.51 val rewrites = rec_rewrites' @ map (snd o snd) defs;
2.52 fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
2.53 - in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end;
2.54 + in map (fn eq => Goal.prove ctxt frees [] eq tac) eqs end;
2.55 in ((prefix, (fs, defs)), prove) end
2.56 handle PrimrecError (msg, some_eqn) =>
2.57 error ("Primrec definition error:\n" ^ msg ^
2.58 (case some_eqn of
2.59 - SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
2.60 + SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term ctxt eqn)
2.61 | NONE => ""));
2.62
2.63
2.64 @@ -259,7 +259,7 @@
2.65
2.66 fun add_primrec_simple fixes ts lthy =
2.67 let
2.68 - val ((prefix, (fs, defs)), prove) = distill lthy fixes ts;
2.69 + val ((prefix, (_, defs)), prove) = distill lthy fixes ts;
2.70 in
2.71 lthy
2.72 |> fold_map Local_Theory.define defs