tuned;
authorwenzelm
Fri, 16 Dec 2011 11:02:55 +0100
changeset 46769b619242b0439
parent 46768 65cef0298158
child 46771 c9ae2bc95fad
child 46772 df887263a379
tuned;
src/HOL/Tools/Datatype/datatype_case.ML
src/HOL/Tools/Datatype/primrec.ML
     1.1 --- a/src/HOL/Tools/Datatype/datatype_case.ML	Fri Dec 16 10:52:35 2011 +0100
     1.2 +++ b/src/HOL/Tools/Datatype/datatype_case.ML	Fri Dec 16 11:02:55 2011 +0100
     1.3 @@ -130,7 +130,7 @@
     1.4                   names = names,
     1.5                   constraints = cnstrts,
     1.6                   group = in_group'} :: part cs not_in_group
     1.7 -              end
     1.8 +              end;
     1.9        in part constructors rows end;
    1.10  
    1.11  fun v_to_prfx (prfx, Free v :: pats) = (v :: prfx, pats)
    1.12 @@ -143,7 +143,6 @@
    1.13    let
    1.14      val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt);
    1.15  
    1.16 -    val name = singleton (Name.variant_list used) "a";
    1.17      fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand_var_row", ~1)
    1.18        | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
    1.19            if is_Free p then
    1.20 @@ -153,7 +152,10 @@
    1.21                  let val capp = list_comb (fresh_constr ty_match ty_inst ty used' c)
    1.22                  in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end
    1.23              in map expnd constructors end
    1.24 -          else [row]
    1.25 +          else [row];
    1.26 +
    1.27 +    val name = singleton (Name.variant_list used) "a";
    1.28 +
    1.29      fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
    1.30        | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *)
    1.31        | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row]
    1.32 @@ -277,19 +279,22 @@
    1.33                  val (u', used'') = prep_pat u used';
    1.34                in (t' $ u', used'') end
    1.35            | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
    1.36 +
    1.37          fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
    1.38                let val (l', cnstrts) = strip_constraints l
    1.39                in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end
    1.40            | dest_case1 t = case_error "dest_case1";
    1.41 +
    1.42          fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
    1.43            | dest_case2 t = [t];
    1.44 +
    1.45          val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
    1.46 -        val case_tm =
    1.47 -          make_case_untyped ctxt
    1.48 -            (if err then Error else Warning) []
    1.49 -            (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
    1.50 -               (flat cnstrts) t) cases;
    1.51 -      in case_tm end
    1.52 +      in
    1.53 +        make_case_untyped ctxt
    1.54 +          (if err then Error else Warning) []
    1.55 +          (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
    1.56 +             (flat cnstrts) t) cases
    1.57 +      end
    1.58    | case_tr _ _ _ = case_error "case_tr";
    1.59  
    1.60  val trfun_setup =
     2.1 --- a/src/HOL/Tools/Datatype/primrec.ML	Fri Dec 16 10:52:35 2011 +0100
     2.2 +++ b/src/HOL/Tools/Datatype/primrec.ML	Fri Dec 16 11:02:55 2011 +0100
     2.3 @@ -206,11 +206,11 @@
     2.4  
     2.5  (* find datatypes which contain all datatypes in tnames' *)
     2.6  
     2.7 -fun find_dts (dt_info : Datatype_Aux.info Symtab.table) _ [] = []
     2.8 +fun find_dts _ _ [] = []
     2.9    | find_dts dt_info tnames' (tname :: tnames) =
    2.10        (case Symtab.lookup dt_info tname of
    2.11          NONE => primrec_error (quote tname ^ " is not a datatype")
    2.12 -      | SOME dt =>
    2.13 +      | SOME (dt : Datatype_Aux.info) =>
    2.14            if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
    2.15              (tname, dt) :: (find_dts dt_info tnames' tnames)
    2.16            else find_dts dt_info tnames' tnames);
    2.17 @@ -218,12 +218,12 @@
    2.18  
    2.19  (* distill primitive definition(s) from primrec specification *)
    2.20  
    2.21 -fun distill lthy fixes eqs =
    2.22 +fun distill ctxt fixes eqs =
    2.23    let
    2.24 -    val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
    2.25 +    val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed ctxt v
    2.26        orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
    2.27      val tnames = distinct (op =) (map (#1 o snd) eqns);
    2.28 -    val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of lthy)) tnames tnames;
    2.29 +    val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of ctxt)) tnames tnames;
    2.30      val main_fns = map (fn (tname, {index, ...}) =>
    2.31        (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
    2.32      val {descr, rec_names, rec_rewrites, ...} =
    2.33 @@ -232,7 +232,7 @@
    2.34        else snd (hd dts);
    2.35      val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
    2.36      val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
    2.37 -    val defs = map (make_def lthy fixes fs) raw_defs;
    2.38 +    val defs = map (make_def ctxt fixes fs) raw_defs;
    2.39      val names = map snd fnames;
    2.40      val names_eqns = map fst eqns;
    2.41      val _ =
    2.42 @@ -241,17 +241,17 @@
    2.43          "\nare not mutually recursive");
    2.44      val rec_rewrites' = map mk_meta_eq rec_rewrites;
    2.45      val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
    2.46 -    fun prove lthy defs =
    2.47 +    fun prove ctxt defs =
    2.48        let
    2.49 -        val frees = fold (Variable.add_free_names lthy) eqs [];
    2.50 +        val frees = fold (Variable.add_free_names ctxt) eqs [];
    2.51          val rewrites = rec_rewrites' @ map (snd o snd) defs;
    2.52          fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
    2.53 -      in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end;
    2.54 +      in map (fn eq => Goal.prove ctxt frees [] eq tac) eqs end;
    2.55    in ((prefix, (fs, defs)), prove) end
    2.56    handle PrimrecError (msg, some_eqn) =>
    2.57      error ("Primrec definition error:\n" ^ msg ^
    2.58        (case some_eqn of
    2.59 -        SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
    2.60 +        SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term ctxt eqn)
    2.61        | NONE => ""));
    2.62  
    2.63  
    2.64 @@ -259,7 +259,7 @@
    2.65  
    2.66  fun add_primrec_simple fixes ts lthy =
    2.67    let
    2.68 -    val ((prefix, (fs, defs)), prove) = distill lthy fixes ts;
    2.69 +    val ((prefix, (_, defs)), prove) = distill lthy fixes ts;
    2.70    in
    2.71      lthy
    2.72      |> fold_map Local_Theory.define defs