src/HOL/ex/predicate_compile.ML
changeset 30374 7311a1546d85
child 30379 1ae7b86638ad
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/ex/predicate_compile.ML	Sun Mar 08 15:25:28 2009 +0100
     1.3 @@ -0,0 +1,1346 @@
     1.4 +(* Author: Lukas Bulwahn
     1.5 +
     1.6 +(Prototype of) A compiler from predicates specified by intro/elim rules
     1.7 +to equations.
     1.8 +*)
     1.9 +
    1.10 +signature PREDICATE_COMPILE =
    1.11 +sig
    1.12 +  val create_def_equation': string -> (int list option list * int list) option -> theory -> theory
    1.13 +  val create_def_equation: string -> theory -> theory
    1.14 +  val intro_rule: theory -> string -> (int list option list * int list) -> thm
    1.15 +  val elim_rule: theory -> string -> (int list option list * int list) -> thm
    1.16 +  val strip_intro_concl : term -> int -> (term * (term list * term list))
    1.17 +  val code_ind_intros_attrib : attribute
    1.18 +  val code_ind_cases_attrib : attribute
    1.19 +  val setup : theory -> theory
    1.20 +  val print_alternative_rules : theory -> theory
    1.21 +  val do_proofs: bool ref
    1.22 +end;
    1.23 +
    1.24 +structure Predicate_Compile: PREDICATE_COMPILE =
    1.25 +struct
    1.26 +
    1.27 +structure PredModetab = TableFun(
    1.28 +  type key = (string * (int list option list * int list))
    1.29 +  val ord = prod_ord fast_string_ord (prod_ord
    1.30 +            (list_ord (option_ord (list_ord int_ord))) (list_ord int_ord)))
    1.31 +
    1.32 +
    1.33 +structure IndCodegenData = TheoryDataFun
    1.34 +(
    1.35 +  type T = {names : string PredModetab.table,
    1.36 +            modes : ((int list option list * int list) list) Symtab.table,
    1.37 +            function_defs : Thm.thm Symtab.table,
    1.38 +            function_intros : Thm.thm Symtab.table,
    1.39 +            function_elims : Thm.thm Symtab.table,
    1.40 +            intro_rules : (Thm.thm list) Symtab.table,
    1.41 +            elim_rules : Thm.thm Symtab.table,
    1.42 +            nparams : int Symtab.table
    1.43 +           };
    1.44 +      (* names: map from inductive predicate and mode to function name (string).
    1.45 +         modes: map from inductive predicates to modes
    1.46 +         function_defs: map from function name to definition
    1.47 +         function_intros: map from function name to intro rule
    1.48 +         function_elims: map from function name to elim rule
    1.49 +         intro_rules: map from inductive predicate to alternative intro rules
    1.50 +         elim_rules: map from inductive predicate to alternative elimination rule
    1.51 +         nparams: map from const name to number of parameters (* assuming there exist intro and elimination rules *) 
    1.52 +       *)
    1.53 +  val empty = {names = PredModetab.empty,
    1.54 +               modes = Symtab.empty,
    1.55 +               function_defs = Symtab.empty,
    1.56 +               function_intros = Symtab.empty,
    1.57 +               function_elims = Symtab.empty,
    1.58 +               intro_rules = Symtab.empty,
    1.59 +               elim_rules = Symtab.empty,
    1.60 +               nparams = Symtab.empty};
    1.61 +  val copy = I;
    1.62 +  val extend = I;
    1.63 +  fun merge _ r = {names = PredModetab.merge (op =) (pairself #names r),
    1.64 +                   modes = Symtab.merge (op =) (pairself #modes r),
    1.65 +                   function_defs = Symtab.merge Thm.eq_thm (pairself #function_defs r),
    1.66 +                   function_intros = Symtab.merge Thm.eq_thm (pairself #function_intros r),
    1.67 +                   function_elims = Symtab.merge Thm.eq_thm (pairself #function_elims r),
    1.68 +                   intro_rules = Symtab.merge ((forall Thm.eq_thm) o (op ~~)) (pairself #intro_rules r),
    1.69 +                   elim_rules = Symtab.merge Thm.eq_thm (pairself #elim_rules r),
    1.70 +                   nparams = Symtab.merge (op =) (pairself #nparams r)};
    1.71 +);
    1.72 +
    1.73 +  fun map_names f thy = IndCodegenData.map
    1.74 +    (fn x => {names = f (#names x), modes = #modes x, function_defs = #function_defs x,
    1.75 +            function_intros = #function_intros x, function_elims = #function_elims x,
    1.76 +            intro_rules = #intro_rules x, elim_rules = #elim_rules x,
    1.77 +            nparams = #nparams x}) thy
    1.78 +
    1.79 +  fun map_modes f thy = IndCodegenData.map
    1.80 +    (fn x => {names = #names x, modes = f (#modes x), function_defs = #function_defs x,
    1.81 +            function_intros = #function_intros x, function_elims = #function_elims x,
    1.82 +            intro_rules = #intro_rules x, elim_rules = #elim_rules x,
    1.83 +            nparams = #nparams x}) thy
    1.84 +
    1.85 +  fun map_function_defs f thy = IndCodegenData.map
    1.86 +    (fn x => {names = #names x, modes = #modes x, function_defs = f (#function_defs x),
    1.87 +            function_intros = #function_intros x, function_elims = #function_elims x,
    1.88 +            intro_rules = #intro_rules x, elim_rules = #elim_rules x,
    1.89 +            nparams = #nparams x}) thy 
    1.90 +  
    1.91 +  fun map_function_elims f thy = IndCodegenData.map
    1.92 +    (fn x => {names = #names x, modes = #modes x, function_defs = #function_defs x,
    1.93 +            function_intros = #function_intros x, function_elims = f (#function_elims x),
    1.94 +            intro_rules = #intro_rules x, elim_rules = #elim_rules x,
    1.95 +            nparams = #nparams x}) thy
    1.96 +
    1.97 +  fun map_function_intros f thy = IndCodegenData.map
    1.98 +    (fn x => {names = #names x, modes = #modes x, function_defs = #function_defs x,
    1.99 +            function_intros = f (#function_intros x), function_elims = #function_elims x,
   1.100 +            intro_rules = #intro_rules x, elim_rules = #elim_rules x,
   1.101 +            nparams = #nparams x}) thy
   1.102 +
   1.103 +  fun map_intro_rules f thy = IndCodegenData.map
   1.104 +    (fn x => {names = #names x, modes = #modes x, function_defs = #function_defs x,
   1.105 +            function_intros = #function_intros x, function_elims = #function_elims x,
   1.106 +            intro_rules = f (#intro_rules x), elim_rules = #elim_rules x,
   1.107 +            nparams = #nparams x}) thy 
   1.108 +  
   1.109 +  fun map_elim_rules f thy = IndCodegenData.map
   1.110 +    (fn x => {names = #names x, modes = #modes x, function_defs = #function_defs x,
   1.111 +            function_intros = #function_intros x, function_elims = #function_elims x,
   1.112 +            intro_rules = #intro_rules x, elim_rules = f (#elim_rules x),
   1.113 +            nparams = #nparams x}) thy
   1.114 +
   1.115 +  fun map_nparams f thy = IndCodegenData.map
   1.116 +    (fn x => {names = #names x, modes = #modes x, function_defs = #function_defs x,
   1.117 +            function_intros = #function_intros x, function_elims = #function_elims x,
   1.118 +            intro_rules = #intro_rules x, elim_rules = #elim_rules x,
   1.119 +            nparams = f (#nparams x)}) thy
   1.120 +
   1.121 +(* Debug stuff and tactics ***********************************************************)
   1.122 +
   1.123 +fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
   1.124 +fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single);
   1.125 +
   1.126 +fun debug_tac msg = (fn st =>
   1.127 +     (tracing msg; Seq.single st));
   1.128 +
   1.129 +(* removes first subgoal *)
   1.130 +fun mycheat_tac thy i st =
   1.131 +  (Tactic.rtac (SkipProof.make_thm thy (Var (("A", 0), propT))) i) st
   1.132 +
   1.133 +val (do_proofs : bool ref) = ref true;
   1.134 +
   1.135 +(* Lightweight mode analysis **********************************************)
   1.136 +
   1.137 +(* Hack for message from old code generator *)
   1.138 +val message = tracing;
   1.139 +
   1.140 +
   1.141 +(**************************************************************************)
   1.142 +(* source code from old code generator ************************************)
   1.143 +
   1.144 +(**** check if a term contains only constructor functions ****)
   1.145 +
   1.146 +fun is_constrt thy =
   1.147 +  let
   1.148 +    val cnstrs = flat (maps
   1.149 +      (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
   1.150 +      (Symtab.dest (DatatypePackage.get_datatypes thy)));
   1.151 +    fun check t = (case strip_comb t of
   1.152 +        (Free _, []) => true
   1.153 +      | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
   1.154 +            (SOME (i, Tname), Type (Tname', _)) => length ts = i andalso Tname = Tname' andalso forall check ts
   1.155 +          | _ => false)
   1.156 +      | _ => false)
   1.157 +  in check end;
   1.158 +
   1.159 +(**** check if a type is an equality type (i.e. doesn't contain fun) ****)
   1.160 +
   1.161 +fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts
   1.162 +  | is_eqT _ = true;
   1.163 +
   1.164 +(**** mode inference ****)
   1.165 +
   1.166 +fun string_of_mode (iss, is) = space_implode " -> " (map
   1.167 +  (fn NONE => "X"
   1.168 +    | SOME js => enclose "[" "]" (commas (map string_of_int js)))
   1.169 +       (iss @ [SOME is]));
   1.170 +
   1.171 +fun print_modes modes = message ("Inferred modes:\n" ^
   1.172 +  cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
   1.173 +    string_of_mode ms)) modes));
   1.174 +
   1.175 +fun term_vs tm = fold_aterms (fn Free (x, T) => cons x | _ => I) tm [];
   1.176 +val terms_vs = distinct (op =) o maps term_vs;
   1.177 +
   1.178 +(** collect all Frees in a term (with duplicates!) **)
   1.179 +fun term_vTs tm =
   1.180 +  fold_aterms (fn Free xT => cons xT | _ => I) tm [];
   1.181 +
   1.182 +fun get_args is ts = let
   1.183 +  fun get_args' _ _ [] = ([], [])
   1.184 +    | get_args' is i (t::ts) = (if i mem is then apfst else apsnd) (cons t)
   1.185 +        (get_args' is (i+1) ts)
   1.186 +in get_args' is 1 ts end
   1.187 +
   1.188 +fun merge xs [] = xs
   1.189 +  | merge [] ys = ys
   1.190 +  | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys)
   1.191 +      else y::merge (x::xs) ys;
   1.192 +
   1.193 +fun subsets i j = if i <= j then
   1.194 +       let val is = subsets (i+1) j
   1.195 +       in merge (map (fn ks => i::ks) is) is end
   1.196 +     else [[]];
   1.197 +
   1.198 +fun cprod ([], ys) = []
   1.199 +  | cprod (x :: xs, ys) = map (pair x) ys @ cprod (xs, ys);
   1.200 +
   1.201 +fun cprods xss = foldr (map op :: o cprod) [[]] xss;
   1.202 +
   1.203 +datatype mode = Mode of (int list option list * int list) * int list * mode option list;
   1.204 +
   1.205 +fun modes_of modes t =
   1.206 +  let
   1.207 +    val ks = 1 upto length (binder_types (fastype_of t));
   1.208 +    val default = [Mode (([], ks), ks, [])];
   1.209 +    fun mk_modes name args = Option.map (maps (fn (m as (iss, is)) =>
   1.210 +        let
   1.211 +          val (args1, args2) =
   1.212 +            if length args < length iss then
   1.213 +              error ("Too few arguments for inductive predicate " ^ name)
   1.214 +            else chop (length iss) args;
   1.215 +          val k = length args2;
   1.216 +          val prfx = 1 upto k
   1.217 +        in
   1.218 +          if not (is_prefix op = prfx is) then [] else
   1.219 +          let val is' = map (fn i => i - k) (List.drop (is, k))
   1.220 +          in map (fn x => Mode (m, is', x)) (cprods (map
   1.221 +            (fn (NONE, _) => [NONE]
   1.222 +              | (SOME js, arg) => map SOME (filter
   1.223 +                  (fn Mode (_, js', _) => js=js') (modes_of modes arg)))
   1.224 +                    (iss ~~ args1)))
   1.225 +          end
   1.226 +        end)) (AList.lookup op = modes name)
   1.227 +
   1.228 +  in (case strip_comb t of
   1.229 +      (Const (name, _), args) => the_default default (mk_modes name args)
   1.230 +    | (Var ((name, _), _), args) => the (mk_modes name args)
   1.231 +    | (Free (name, _), args) => the (mk_modes name args)
   1.232 +    | _ => default)
   1.233 +  end
   1.234 +
   1.235 +datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term;
   1.236 +
   1.237 +fun select_mode_prem thy modes vs ps =
   1.238 +  find_first (is_some o snd) (ps ~~ map
   1.239 +    (fn Prem (us, t) => find_first (fn Mode (_, is, _) =>
   1.240 +          let
   1.241 +            val (in_ts, out_ts) = get_args is us;
   1.242 +            val (out_ts', in_ts') = List.partition (is_constrt thy) out_ts;
   1.243 +            val vTs = maps term_vTs out_ts';
   1.244 +            val dupTs = map snd (duplicates (op =) vTs) @
   1.245 +              List.mapPartial (AList.lookup (op =) vTs) vs;
   1.246 +          in
   1.247 +            terms_vs (in_ts @ in_ts') subset vs andalso
   1.248 +            forall (is_eqT o fastype_of) in_ts' andalso
   1.249 +            term_vs t subset vs andalso
   1.250 +            forall is_eqT dupTs
   1.251 +          end)
   1.252 +            (modes_of modes t handle Option =>
   1.253 +               error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
   1.254 +      | Negprem (us, t) => find_first (fn Mode (_, is, _) =>
   1.255 +            length us = length is andalso
   1.256 +            terms_vs us subset vs andalso
   1.257 +            term_vs t subset vs)
   1.258 +            (modes_of modes t handle Option =>
   1.259 +               error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
   1.260 +      | Sidecond t => if term_vs t subset vs then SOME (Mode (([], []), [], []))
   1.261 +          else NONE
   1.262 +      ) ps);
   1.263 +
   1.264 +fun check_mode_clause thy param_vs modes (iss, is) (ts, ps) =
   1.265 +  let
   1.266 +    val modes' = modes @ List.mapPartial
   1.267 +      (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
   1.268 +        (param_vs ~~ iss); 
   1.269 +    fun check_mode_prems vs [] = SOME vs
   1.270 +      | check_mode_prems vs ps = (case select_mode_prem thy modes' vs ps of
   1.271 +          NONE => NONE
   1.272 +        | SOME (x, _) => check_mode_prems
   1.273 +            (case x of Prem (us, _) => vs union terms_vs us | _ => vs)
   1.274 +            (filter_out (equal x) ps))
   1.275 +    val (in_ts, in_ts') = List.partition (is_constrt thy) (fst (get_args is ts));
   1.276 +    val in_vs = terms_vs in_ts;
   1.277 +    val concl_vs = terms_vs ts
   1.278 +  in
   1.279 +    forall is_eqT (map snd (duplicates (op =) (maps term_vTs in_ts))) andalso
   1.280 +    forall (is_eqT o fastype_of) in_ts' andalso
   1.281 +    (case check_mode_prems (param_vs union in_vs) ps of
   1.282 +       NONE => false
   1.283 +     | SOME vs => concl_vs subset vs)
   1.284 +  end;
   1.285 +
   1.286 +fun check_modes_pred thy param_vs preds modes (p, ms) =
   1.287 +  let val SOME rs = AList.lookup (op =) preds p
   1.288 +  in (p, List.filter (fn m => case find_index
   1.289 +    (not o check_mode_clause thy param_vs modes m) rs of
   1.290 +      ~1 => true
   1.291 +    | i => (message ("Clause " ^ string_of_int (i+1) ^ " of " ^
   1.292 +      p ^ " violates mode " ^ string_of_mode m); false)) ms)
   1.293 +  end;
   1.294 +
   1.295 +fun fixp f (x : (string * (int list option list * int list) list) list) =
   1.296 +  let val y = f x
   1.297 +  in if x = y then x else fixp f y end;
   1.298 +
   1.299 +fun infer_modes thy extra_modes arities param_vs preds = fixp (fn modes =>
   1.300 +  map (check_modes_pred thy param_vs preds (modes @ extra_modes)) modes)
   1.301 +    (map (fn (s, (ks, k)) => (s, cprod (cprods (map
   1.302 +      (fn NONE => [NONE]
   1.303 +        | SOME k' => map SOME (subsets 1 k')) ks),
   1.304 +      subsets 1 k))) arities);
   1.305 +
   1.306 +
   1.307 +(*****************************************************************************************)
   1.308 +(**** end of old source code *************************************************************)
   1.309 +(*****************************************************************************************)
   1.310 +(**** term construction ****)
   1.311 +
   1.312 +fun mk_eq (x, xs) =
   1.313 +  let fun mk_eqs _ [] = []
   1.314 +        | mk_eqs a (b::cs) =
   1.315 +            HOLogic.mk_eq (Free (a, fastype_of b), b) :: mk_eqs a cs
   1.316 +  in mk_eqs x xs end;
   1.317 +
   1.318 +fun mk_tuple [] = HOLogic.unit
   1.319 +  | mk_tuple ts = foldr1 HOLogic.mk_prod ts;
   1.320 +
   1.321 +fun dest_tuple (Const (@{const_name Product_Type.Unity}, _)) = []
   1.322 +  | dest_tuple (Const (@{const_name Pair}, _) $ t1 $ t2) = t1 :: (dest_tuple t2)
   1.323 +  | dest_tuple t = [t]
   1.324 +
   1.325 +fun mk_tupleT [] = HOLogic.unitT
   1.326 +  | mk_tupleT Ts = foldr1 HOLogic.mk_prodT Ts;
   1.327 +
   1.328 +fun mk_pred_enumT T = Type ("Predicate.pred", [T])
   1.329 +
   1.330 +fun dest_pred_enumT (Type ("Predicate.pred", [T])) = T
   1.331 +  | dest_pred_enumT T = raise TYPE ("dest_pred_enumT", [T], []);
   1.332 +
   1.333 +fun mk_single t =
   1.334 +  let val T = fastype_of t
   1.335 +  in Const(@{const_name Predicate.single}, T --> mk_pred_enumT T) $ t end;
   1.336 +
   1.337 +fun mk_empty T = Const (@{const_name Orderings.bot}, mk_pred_enumT T);
   1.338 +
   1.339 +fun mk_if_predenum cond = Const (@{const_name Predicate.if_pred},
   1.340 +                          HOLogic.boolT --> mk_pred_enumT HOLogic.unitT) 
   1.341 +                         $ cond
   1.342 +
   1.343 +fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT
   1.344 +  in Const (@{const_name Predicate.not_pred}, T --> T) $ t end
   1.345 +
   1.346 +fun mk_bind (x, f) =
   1.347 +  let val T as Type ("fun", [_, U]) = fastype_of f
   1.348 +  in
   1.349 +    Const (@{const_name Predicate.bind}, fastype_of x --> T --> U) $ x $ f
   1.350 +  end;
   1.351 +
   1.352 +fun mk_Enum f =
   1.353 +  let val T as Type ("fun", [T', _]) = fastype_of f
   1.354 +  in
   1.355 +    Const (@{const_name Predicate.Pred}, T --> mk_pred_enumT T') $ f    
   1.356 +  end;
   1.357 +
   1.358 +fun mk_Eval (f, x) =
   1.359 +  let val T = fastype_of x
   1.360 +  in
   1.361 +    Const (@{const_name Predicate.eval}, mk_pred_enumT T --> T --> HOLogic.boolT) $ f $ x
   1.362 +  end;
   1.363 +
   1.364 +fun mk_Eval' f =
   1.365 +  let val T = fastype_of f
   1.366 +  in
   1.367 +    Const (@{const_name Predicate.eval}, T --> dest_pred_enumT T --> HOLogic.boolT) $ f
   1.368 +  end; 
   1.369 +
   1.370 +val mk_sup = HOLogic.mk_binop @{const_name sup};
   1.371 +
   1.372 +(* for simple modes (e.g. parameters) only: better call it param_funT *)
   1.373 +(* or even better: remove it and only use funT'_of - some modifications to funT'_of necessary *) 
   1.374 +fun funT_of T NONE = T
   1.375 +  | funT_of T (SOME mode) = let
   1.376 +     val Ts = binder_types T;
   1.377 +     val (Us1, Us2) = get_args mode Ts
   1.378 +   in Us1 ---> (mk_pred_enumT (mk_tupleT Us2)) end;
   1.379 +
   1.380 +fun funT'_of (iss, is) T = let
   1.381 +    val Ts = binder_types T
   1.382 +    val (paramTs, argTs) = chop (length iss) Ts
   1.383 +    val paramTs' = map2 (fn SOME is => funT'_of ([], is) | NONE => I) iss paramTs 
   1.384 +    val (inargTs, outargTs) = get_args is argTs
   1.385 +  in
   1.386 +    (paramTs' @ inargTs) ---> (mk_pred_enumT (mk_tupleT outargTs))
   1.387 +  end; 
   1.388 +
   1.389 +
   1.390 +fun mk_v (names, vs) s T = (case AList.lookup (op =) vs s of
   1.391 +      NONE => ((names, (s, [])::vs), Free (s, T))
   1.392 +    | SOME xs =>
   1.393 +        let
   1.394 +          val s' = Name.variant names s;
   1.395 +          val v = Free (s', T)
   1.396 +        in
   1.397 +          ((s'::names, AList.update (op =) (s, v::xs) vs), v)
   1.398 +        end);
   1.399 +
   1.400 +fun distinct_v (nvs, Free (s, T)) = mk_v nvs s T
   1.401 +  | distinct_v (nvs, t $ u) =
   1.402 +      let
   1.403 +        val (nvs', t') = distinct_v (nvs, t);
   1.404 +        val (nvs'', u') = distinct_v (nvs', u);
   1.405 +      in (nvs'', t' $ u') end
   1.406 +  | distinct_v x = x;
   1.407 +
   1.408 +fun compile_match thy eqs eqs' out_ts success_t =
   1.409 +  let 
   1.410 +    val eqs'' = maps mk_eq eqs @ eqs'
   1.411 +    val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
   1.412 +    val name = Name.variant names "x";
   1.413 +    val name' = Name.variant (name :: names) "y";
   1.414 +    val T = mk_tupleT (map fastype_of out_ts);
   1.415 +    val U = fastype_of success_t;
   1.416 +    val U' = dest_pred_enumT U;
   1.417 +    val v = Free (name, T);
   1.418 +    val v' = Free (name', T);
   1.419 +  in
   1.420 +    lambda v (fst (DatatypePackage.make_case
   1.421 +      (ProofContext.init thy) false [] v
   1.422 +      [(mk_tuple out_ts,
   1.423 +        if null eqs'' then success_t
   1.424 +        else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
   1.425 +          foldr1 HOLogic.mk_conj eqs'' $ success_t $
   1.426 +            mk_empty U'),
   1.427 +       (v', mk_empty U')]))
   1.428 +  end;
   1.429 +
   1.430 +fun modename thy name mode = let
   1.431 +    val v = (PredModetab.lookup (#names (IndCodegenData.get thy)) (name, mode))
   1.432 +  in if (is_some v) then the v
   1.433 +     else error ("fun modename - definition not found: name: " ^ name ^ " mode: " ^  (makestring mode))
   1.434 +  end
   1.435 +
   1.436 +(* function can be removed *)
   1.437 +fun mk_funcomp f t =
   1.438 +  let
   1.439 +    val names = Term.add_free_names t [];
   1.440 +    val Ts = binder_types (fastype_of t);
   1.441 +    val vs = map Free
   1.442 +      (Name.variant_list names (replicate (length Ts) "x") ~~ Ts)
   1.443 +  in
   1.444 +    fold_rev lambda vs (f (list_comb (t, vs)))
   1.445 +  end;
   1.446 +
   1.447 +fun compile_param thy modes (NONE, t) = t
   1.448 +  | compile_param thy modes (m as SOME (Mode ((iss, is'), is, ms)), t) = let
   1.449 +    val (f, args) = strip_comb t
   1.450 +    val (params, args') = chop (length ms) args
   1.451 +    val params' = map (compile_param thy modes) (ms ~~ params)
   1.452 +    val f' = case f of
   1.453 +        Const (name, T) =>
   1.454 +          if AList.defined op = modes name then
   1.455 +            Const (modename thy name (iss, is'), funT'_of (iss, is') T)
   1.456 +          else error "compile param: Not an inductive predicate with correct mode"
   1.457 +      | Free (name, T) => Free (name, funT_of T (SOME is'))
   1.458 +    in list_comb (f', params' @ args') end
   1.459 +  | compile_param _ _ _ = error "compile params"
   1.460 +
   1.461 +fun compile_expr thy modes (SOME (Mode (mode, is, ms)), t) =
   1.462 +      (case strip_comb t of
   1.463 +         (Const (name, T), params) =>
   1.464 +           if AList.defined op = modes name then
   1.465 +             let
   1.466 +               val (Ts, Us) = get_args is
   1.467 +                 (curry Library.drop (length ms) (fst (strip_type T)))
   1.468 +               val params' = map (compile_param thy modes) (ms ~~ params)
   1.469 +               val mode_id = modename thy name mode
   1.470 +             in list_comb (Const (mode_id, ((map fastype_of params') @ Ts) --->
   1.471 +               mk_pred_enumT (mk_tupleT Us)), params')
   1.472 +             end
   1.473 +           else error "not a valid inductive expression"
   1.474 +       | (Free (name, T), args) =>
   1.475 +         (*if name mem param_vs then *)
   1.476 +         (* Higher order mode call *)
   1.477 +         let val r = Free (name, funT_of T (SOME is))
   1.478 +         in list_comb (r, args) end)
   1.479 +  | compile_expr _ _ _ = error "not a valid inductive expression"
   1.480 +
   1.481 +
   1.482 +fun compile_clause thy all_vs param_vs modes (iss, is) (ts, ps) inp =
   1.483 +  let
   1.484 +    val modes' = modes @ List.mapPartial
   1.485 +      (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
   1.486 +        (param_vs ~~ iss);
   1.487 +    fun check_constrt ((names, eqs), t) =
   1.488 +      if is_constrt thy t then ((names, eqs), t) else
   1.489 +        let
   1.490 +          val s = Name.variant names "x";
   1.491 +          val v = Free (s, fastype_of t)
   1.492 +        in ((s::names, HOLogic.mk_eq (v, t)::eqs), v) end;
   1.493 +
   1.494 +    val (in_ts, out_ts) = get_args is ts;
   1.495 +    val ((all_vs', eqs), in_ts') =
   1.496 +      (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts);
   1.497 +
   1.498 +    fun compile_prems out_ts' vs names [] =
   1.499 +          let
   1.500 +            val ((names', eqs'), out_ts'') =
   1.501 +              (*FIXME*) Library.foldl_map check_constrt ((names, []), out_ts');
   1.502 +            val (nvs, out_ts''') = (*FIXME*) Library.foldl_map distinct_v
   1.503 +              ((names', map (rpair []) vs), out_ts'');
   1.504 +          in
   1.505 +            compile_match thy (snd nvs) (eqs @ eqs') out_ts'''
   1.506 +              (mk_single (mk_tuple out_ts))
   1.507 +          end
   1.508 +      | compile_prems out_ts vs names ps =
   1.509 +          let
   1.510 +            val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
   1.511 +            val SOME (p, mode as SOME (Mode (_, js, _))) =
   1.512 +              select_mode_prem thy modes' vs' ps
   1.513 +            val ps' = filter_out (equal p) ps
   1.514 +            val ((names', eqs), out_ts') =
   1.515 +              (*FIXME*) Library.foldl_map check_constrt ((names, []), out_ts)
   1.516 +            val (nvs, out_ts'') = (*FIXME*) Library.foldl_map distinct_v
   1.517 +              ((names', map (rpair []) vs), out_ts')
   1.518 +            val (compiled_clause, rest) = case p of
   1.519 +               Prem (us, t) =>
   1.520 +                 let
   1.521 +                   val (in_ts, out_ts''') = get_args js us;
   1.522 +                   val u = list_comb (compile_expr thy modes (mode, t), in_ts)
   1.523 +                   val rest = compile_prems out_ts''' vs' (fst nvs) ps'
   1.524 +                 in
   1.525 +                   (u, rest)
   1.526 +                 end
   1.527 +             | Negprem (us, t) =>
   1.528 +                 let
   1.529 +                   val (in_ts, out_ts''') = get_args js us
   1.530 +                   val u = list_comb (compile_expr thy modes (mode, t), in_ts)
   1.531 +                   val rest = compile_prems out_ts''' vs' (fst nvs) ps'
   1.532 +                 in
   1.533 +                   (mk_not_pred u, rest)
   1.534 +                 end
   1.535 +             | Sidecond t =>
   1.536 +                 let
   1.537 +                   val rest = compile_prems [] vs' (fst nvs) ps';
   1.538 +                 in
   1.539 +                   (mk_if_predenum t, rest)
   1.540 +                 end
   1.541 +          in
   1.542 +            compile_match thy (snd nvs) eqs out_ts'' 
   1.543 +              (mk_bind (compiled_clause, rest))
   1.544 +          end
   1.545 +    val prem_t = compile_prems in_ts' param_vs all_vs' ps;
   1.546 +  in
   1.547 +    mk_bind (mk_single inp, prem_t)
   1.548 +  end
   1.549 +
   1.550 +fun compile_pred thy all_vs param_vs modes s T cls mode =
   1.551 +  let
   1.552 +    val Ts = binder_types T;
   1.553 +    val (Ts1, Ts2) = chop (length param_vs) Ts;
   1.554 +    val Ts1' = map2 funT_of Ts1 (fst mode)
   1.555 +    val (Us1, Us2) = get_args (snd mode) Ts2;
   1.556 +    val xnames = Name.variant_list param_vs
   1.557 +      (map (fn i => "x" ^ string_of_int i) (snd mode));
   1.558 +    val xs = map2 (fn s => fn T => Free (s, T)) xnames Us1;
   1.559 +    val cl_ts =
   1.560 +      map (fn cl => compile_clause thy
   1.561 +        all_vs param_vs modes mode cl (mk_tuple xs)) cls;
   1.562 +    val mode_id = modename thy s mode
   1.563 +  in
   1.564 +    HOLogic.mk_Trueprop (HOLogic.mk_eq
   1.565 +      (list_comb (Const (mode_id, (Ts1' @ Us1) --->
   1.566 +           mk_pred_enumT (mk_tupleT Us2)),
   1.567 +         map2 (fn s => fn T => Free (s, T)) param_vs Ts1' @ xs),
   1.568 +       foldr1 mk_sup cl_ts))
   1.569 +  end;
   1.570 +
   1.571 +fun compile_preds thy all_vs param_vs modes preds =
   1.572 +  map (fn (s, (T, cls)) =>
   1.573 +    map (compile_pred thy all_vs param_vs modes s T cls)
   1.574 +      ((the o AList.lookup (op =) modes) s)) preds;
   1.575 +
   1.576 +(* end of term construction ******************************************************)
   1.577 +
   1.578 +(* special setup for simpset *)                  
   1.579 +val HOL_basic_ss' = HOL_basic_ss setSolver 
   1.580 +  (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac))
   1.581 +
   1.582 +
   1.583 +(* misc: constructing and proving tupleE rules ***********************************)
   1.584 +
   1.585 +
   1.586 +(* Creating definitions of functional programs 
   1.587 +   and proving intro and elim rules **********************************************) 
   1.588 +
   1.589 +fun is_ind_pred thy c = 
   1.590 +  (can (InductivePackage.the_inductive (ProofContext.init thy)) c) orelse
   1.591 +  (c mem_string (Symtab.keys (#intro_rules (IndCodegenData.get thy))))
   1.592 +
   1.593 +fun get_name_of_ind_calls_of_clauses thy preds intrs =
   1.594 +    fold Term.add_consts intrs [] |> map fst
   1.595 +    |> filter_out (member (op =) preds) |> filter (is_ind_pred thy)
   1.596 +
   1.597 +fun print_arities arities = message ("Arities:\n" ^
   1.598 +  cat_lines (map (fn (s, (ks, k)) => s ^ ": " ^
   1.599 +    space_implode " -> " (map
   1.600 +      (fn NONE => "X" | SOME k' => string_of_int k')
   1.601 +        (ks @ [SOME k]))) arities));
   1.602 +
   1.603 +fun mk_Eval_of ((x, T), NONE) names = (x, names)
   1.604 +  | mk_Eval_of ((x, T), SOME mode) names = let
   1.605 +  val Ts = binder_types T
   1.606 +  val argnames = Name.variant_list names
   1.607 +        (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
   1.608 +  val args = map Free (argnames ~~ Ts)
   1.609 +  val (inargs, outargs) = get_args mode args
   1.610 +  val r = mk_Eval (list_comb (x, inargs), mk_tuple outargs)
   1.611 +  val t = fold_rev lambda args r 
   1.612 +in
   1.613 +  (t, argnames @ names)
   1.614 +end;
   1.615 +
   1.616 +fun create_intro_rule nparams mode defthm mode_id funT pred thy =
   1.617 +let
   1.618 +  val Ts = binder_types (fastype_of pred)
   1.619 +  val funtrm = Const (mode_id, funT)
   1.620 +  val argnames = Name.variant_list []
   1.621 +        (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
   1.622 +  val (Ts1, Ts2) = chop nparams Ts;
   1.623 +  val Ts1' = map2 funT_of Ts1 (fst mode)
   1.624 +  val args = map Free (argnames ~~ (Ts1' @ Ts2))
   1.625 +  val (params, io_args) = chop nparams args
   1.626 +  val (inargs, outargs) = get_args (snd mode) io_args
   1.627 +  val (params', names) = fold_map mk_Eval_of ((params ~~ Ts1) ~~ (fst mode)) []
   1.628 +  val predprop = HOLogic.mk_Trueprop (list_comb (pred, params' @ io_args))
   1.629 +  val funargs = params @ inargs
   1.630 +  val funpropE = HOLogic.mk_Trueprop (mk_Eval (list_comb (funtrm, funargs),
   1.631 +                  if null outargs then Free("y", HOLogic.unitT) else mk_tuple outargs))
   1.632 +  val funpropI = HOLogic.mk_Trueprop (mk_Eval (list_comb (funtrm, funargs),
   1.633 +                   mk_tuple outargs))
   1.634 +  val introtrm = Logic.mk_implies (predprop, funpropI)
   1.635 +  val simprules = [defthm, @{thm eval_pred},
   1.636 +                   @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}]
   1.637 +  val unfolddef_tac = (Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simprules) 1)
   1.638 +  val introthm = Goal.prove (ProofContext.init thy) (argnames @ ["y"]) [] introtrm (fn {...} => unfolddef_tac)
   1.639 +  val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT));
   1.640 +  val elimtrm = Logic.list_implies ([funpropE, Logic.mk_implies (predprop, P)], P)
   1.641 +  val elimthm = Goal.prove (ProofContext.init thy) (argnames @ ["y", "P"]) [] elimtrm (fn {...} => unfolddef_tac)
   1.642 +in
   1.643 +  map_function_intros (Symtab.update_new (mode_id, introthm)) thy
   1.644 +  |> map_function_elims (Symtab.update_new (mode_id, elimthm))
   1.645 +  |> PureThy.store_thm (Binding.name (NameSpace.base_name mode_id ^ "I"), introthm) |> snd
   1.646 +  |> PureThy.store_thm (Binding.name (NameSpace.base_name mode_id ^ "E"), elimthm)  |> snd
   1.647 +end;
   1.648 +
   1.649 +fun create_definitions preds nparams (name, modes) thy =
   1.650 +  let
   1.651 +    val _ = tracing "create definitions"
   1.652 +    val T = AList.lookup (op =) preds name |> the
   1.653 +    fun create_definition mode thy = let
   1.654 +      fun string_of_mode mode = if null mode then "0"
   1.655 +        else space_implode "_" (map string_of_int mode)
   1.656 +      val HOmode = let
   1.657 +        fun string_of_HOmode m s = case m of NONE => s | SOME mode => s ^ "__" ^ (string_of_mode mode)    
   1.658 +        in (fold string_of_HOmode (fst mode) "") end;
   1.659 +      val mode_id = name ^ (if HOmode = "" then "_" else HOmode ^ "___")
   1.660 +        ^ (string_of_mode (snd mode))
   1.661 +      val Ts = binder_types T;
   1.662 +      val (Ts1, Ts2) = chop nparams Ts;
   1.663 +      val Ts1' = map2 funT_of Ts1 (fst mode)
   1.664 +      val (Us1, Us2) = get_args (snd mode) Ts2;
   1.665 +      val names = Name.variant_list []
   1.666 +        (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
   1.667 +      val xs = map Free (names ~~ (Ts1' @ Ts2));
   1.668 +      val (xparams, xargs) = chop nparams xs;
   1.669 +      val (xparams', names') = fold_map mk_Eval_of ((xparams ~~ Ts1) ~~ (fst mode)) names
   1.670 +      val (xins, xouts) = get_args (snd mode) xargs;
   1.671 +      fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
   1.672 +       | mk_split_lambda [x] t = lambda x t
   1.673 +       | mk_split_lambda xs t = let
   1.674 +         fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
   1.675 +           | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
   1.676 +         in mk_split_lambda' xs t end;
   1.677 +      val predterm = mk_Enum (mk_split_lambda xouts (list_comb (Const (name, T), xparams' @ xargs)))
   1.678 +      val funT = (Ts1' @ Us1) ---> (mk_pred_enumT (mk_tupleT Us2))
   1.679 +      val mode_id = Sign.full_bname thy (NameSpace.base_name mode_id)
   1.680 +      val lhs = list_comb (Const (mode_id, funT), xparams @ xins)
   1.681 +      val def = Logic.mk_equals (lhs, predterm)
   1.682 +      val ([defthm], thy') = thy |>
   1.683 +        Sign.add_consts_i [(NameSpace.base_name mode_id, funT, NoSyn)] |>
   1.684 +        PureThy.add_defs false [((Binding.name (NameSpace.base_name mode_id ^ "_def"), def), [])]
   1.685 +      in thy' |> map_names (PredModetab.update_new ((name, mode), mode_id))
   1.686 +           |> map_function_defs (Symtab.update_new (mode_id, defthm))
   1.687 +           |> create_intro_rule nparams mode defthm mode_id funT (Const (name, T))
   1.688 +      end;
   1.689 +  in
   1.690 +    fold create_definition modes thy
   1.691 +  end;
   1.692 +
   1.693 +(**************************************************************************************)
   1.694 +(* Proving equivalence of term *)
   1.695 +
   1.696 +
   1.697 +fun intro_rule thy pred mode = modename thy pred mode
   1.698 +    |> Symtab.lookup (#function_intros (IndCodegenData.get thy)) |> the
   1.699 +
   1.700 +fun elim_rule thy pred mode = modename thy pred mode
   1.701 +    |> Symtab.lookup (#function_elims (IndCodegenData.get thy)) |> the
   1.702 +
   1.703 +fun pred_intros thy predname = let
   1.704 +    fun is_intro_of pred intro = let
   1.705 +      val const = fst (strip_comb (HOLogic.dest_Trueprop (concl_of intro)))
   1.706 +    in (fst (dest_Const const) = pred) end;
   1.707 +    val d = IndCodegenData.get thy
   1.708 +  in
   1.709 +    if (Symtab.defined (#intro_rules d) predname) then
   1.710 +      rev (Symtab.lookup_list (#intro_rules d) predname)
   1.711 +    else
   1.712 +      InductivePackage.the_inductive (ProofContext.init thy) predname
   1.713 +      |> snd |> #intrs |> filter (is_intro_of predname)
   1.714 +  end
   1.715 +
   1.716 +fun function_definition thy pred mode =
   1.717 +  modename thy pred mode |> Symtab.lookup (#function_defs (IndCodegenData.get thy)) |> the
   1.718 +
   1.719 +fun is_Type (Type _) = true
   1.720 +  | is_Type _ = false
   1.721 +
   1.722 +fun imp_prems_conv cv ct =
   1.723 +  case Thm.term_of ct of
   1.724 +    Const ("==>", _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
   1.725 +  | _ => Conv.all_conv ct
   1.726 +
   1.727 +fun Trueprop_conv cv ct =
   1.728 +  case Thm.term_of ct of
   1.729 +    Const ("Trueprop", _) $ _ => Conv.arg_conv cv ct  
   1.730 +  | _ => error "Trueprop_conv"
   1.731 +
   1.732 +fun preprocess_intro thy rule = Thm.transfer thy rule (*FIXME preprocessor
   1.733 +  Conv.fconv_rule
   1.734 +    (imp_prems_conv
   1.735 +      (Trueprop_conv (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @ {thm Predicate.eq_is_eq})))))
   1.736 +    (Thm.transfer thy rule) *)
   1.737 +
   1.738 +fun preprocess_elim thy nargs elimrule = (*FIXME preprocessor -- let
   1.739 +   fun replace_eqs (Const ("Trueprop", _) $ (Const ("op =", T) $ lhs $ rhs)) =
   1.740 +      HOLogic.mk_Trueprop (Const (@ {const_name Predicate.eq}, T) $ lhs $ rhs)
   1.741 +    | replace_eqs t = t
   1.742 +   fun preprocess_case t = let
   1.743 +     val params = Logic.strip_params t
   1.744 +     val (assums1, assums2) = chop nargs (Logic.strip_assums_hyp t)
   1.745 +     val assums_hyp' = assums1 @ (map replace_eqs assums2)
   1.746 +     in list_all (params, Logic.list_implies (assums_hyp', Logic.strip_assums_concl t)) end
   1.747 +   val prems = Thm.prems_of elimrule
   1.748 +   val cases' = map preprocess_case (tl prems)
   1.749 +   val elimrule' = Logic.list_implies ((hd prems) :: cases', Thm.concl_of elimrule)
   1.750 + in
   1.751 +   Thm.equal_elim
   1.752 +     (Thm.symmetric (Conv.implies_concl_conv (MetaSimplifier.rewrite true [@ {thm eq_is_eq}])
   1.753 +        (cterm_of thy elimrule')))
   1.754 +     elimrule
   1.755 + end*) elimrule;
   1.756 +
   1.757 +
   1.758 +(* returns true if t is an application of an datatype constructor *)
   1.759 +(* which then consequently would be splitted *)
   1.760 +(* else false *)
   1.761 +fun is_constructor thy t =
   1.762 +  if (is_Type (fastype_of t)) then
   1.763 +    (case DatatypePackage.get_datatype thy ((fst o dest_Type o fastype_of) t) of
   1.764 +      NONE => false
   1.765 +    | SOME info => (let
   1.766 +      val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
   1.767 +      val (c, _) = strip_comb t
   1.768 +      in (case c of
   1.769 +        Const (name, _) => name mem_string constr_consts
   1.770 +        | _ => false) end))
   1.771 +  else false
   1.772 +
   1.773 +(* MAJOR FIXME:  prove_params should be simple
   1.774 + - different form of introrule for parameters ? *)
   1.775 +fun prove_param thy modes (NONE, t) = all_tac 
   1.776 +  | prove_param thy modes (m as SOME (Mode (mode, is, ms)), t) = let
   1.777 +    val  (f, args) = strip_comb t
   1.778 +    val (params, _) = chop (length ms) args
   1.779 +    val f_tac = case f of
   1.780 +        Const (name, T) => simp_tac (HOL_basic_ss addsimps 
   1.781 +           @{thm eval_pred}::function_definition thy name mode::[]) 1
   1.782 +      | Free _ => all_tac
   1.783 +  in  
   1.784 +    print_tac "before simplification in prove_args:"
   1.785 +    THEN debug_tac ("mode" ^ (makestring mode))
   1.786 +    THEN f_tac
   1.787 +    THEN print_tac "after simplification in prove_args"
   1.788 +    (* work with parameter arguments *)
   1.789 +    THEN (EVERY (map (prove_param thy modes) (ms ~~ params)))
   1.790 +    THEN (REPEAT_DETERM (atac 1))
   1.791 +  end
   1.792 +
   1.793 +fun prove_expr thy modes (SOME (Mode (mode, is, ms)), t, us) (premposition : int) =
   1.794 +  (case strip_comb t of
   1.795 +    (Const (name, T), args) =>
   1.796 +      if AList.defined op = modes name then (let
   1.797 +          val introrule = intro_rule thy name mode
   1.798 +          (*val (in_args, out_args) = get_args is us
   1.799 +          val (pred, rargs) = strip_comb (HOLogic.dest_Trueprop
   1.800 +            (hd (Logic.strip_imp_prems (prop_of introrule))))
   1.801 +          val nparams = length ms (* get_nparams thy (fst (dest_Const pred)) *)
   1.802 +          val (_, args) = chop nparams rargs
   1.803 +          val _ = tracing ("args: " ^ (makestring args))
   1.804 +          val subst = map (pairself (cterm_of thy)) (args ~~ us)
   1.805 +          val _ = tracing ("subst: " ^ (makestring subst))
   1.806 +          val inst_introrule = Drule.cterm_instantiate subst introrule*)
   1.807 +         (* the next line is old and probably wrong *)
   1.808 +          val (args1, args2) = chop (length ms) args
   1.809 +          val _ = tracing ("premposition: " ^ (makestring premposition))
   1.810 +        in
   1.811 +        rtac @{thm bindI} 1
   1.812 +        THEN print_tac "before intro rule:"
   1.813 +        THEN debug_tac ("mode" ^ (makestring mode))
   1.814 +        THEN debug_tac (makestring introrule)
   1.815 +        THEN debug_tac ("premposition: " ^ (makestring premposition))
   1.816 +        (* for the right assumption in first position *)
   1.817 +        THEN rotate_tac premposition 1
   1.818 +        THEN rtac introrule 1
   1.819 +        THEN print_tac "after intro rule"
   1.820 +        (* work with parameter arguments *)
   1.821 +        THEN (EVERY (map (prove_param thy modes) (ms ~~ args1)))
   1.822 +        THEN (REPEAT_DETERM (atac 1)) end)
   1.823 +      else error "Prove expr if case not implemented"
   1.824 +    | _ => rtac @{thm bindI} 1
   1.825 +           THEN atac 1)
   1.826 +  | prove_expr _ _ _ _ =  error "Prove expr not implemented"
   1.827 +
   1.828 +fun SOLVED tac st = FILTER (fn st' => nprems_of st' = nprems_of st - 1) tac st; 
   1.829 +
   1.830 +fun SOLVEDALL tac st = FILTER (fn st' => nprems_of st' = 0) tac st
   1.831 +
   1.832 +fun prove_match thy (out_ts : term list) = let
   1.833 +  fun get_case_rewrite t =
   1.834 +    if (is_constructor thy t) then let
   1.835 +      val case_rewrites = (#case_rewrites (DatatypePackage.the_datatype thy
   1.836 +        ((fst o dest_Type o fastype_of) t)))
   1.837 +      in case_rewrites @ (flat (map get_case_rewrite (snd (strip_comb t)))) end
   1.838 +    else []
   1.839 +  val simprules = @{thm "unit.cases"} :: @{thm "prod.cases"} :: (flat (map get_case_rewrite out_ts))
   1.840 +(* replace TRY by determining if it necessary - are there equations when calling compile match? *)
   1.841 +in
   1.842 +  print_tac ("before prove_match rewriting: simprules = " ^ (makestring simprules))
   1.843 +   (* make this simpset better! *)
   1.844 +  THEN asm_simp_tac (HOL_basic_ss' addsimps simprules) 1
   1.845 +  THEN print_tac "after prove_match:"
   1.846 +  THEN (DETERM (TRY (EqSubst.eqsubst_tac (ProofContext.init thy) [0] [@{thm "HOL.if_P"}] 1
   1.847 +         THEN (REPEAT_DETERM (rtac @{thm conjI} 1 THEN (SOLVED (asm_simp_tac HOL_basic_ss 1))))
   1.848 +         THEN (SOLVED (asm_simp_tac HOL_basic_ss 1)))))
   1.849 +  THEN print_tac "after if simplification"
   1.850 +end;
   1.851 +
   1.852 +(* corresponds to compile_fun -- maybe call that also compile_sidecond? *)
   1.853 +
   1.854 +fun prove_sidecond thy modes t = let
   1.855 +  val _ = tracing ("prove_sidecond:" ^ (makestring t))
   1.856 +  fun preds_of t nameTs = case strip_comb t of 
   1.857 +    (f as Const (name, T), args) =>
   1.858 +      if AList.defined (op =) modes name then (name, T) :: nameTs
   1.859 +        else fold preds_of args nameTs
   1.860 +    | _ => nameTs
   1.861 +  val preds = preds_of t []
   1.862 +  
   1.863 +  val _ = tracing ("preds: " ^ (makestring preds))
   1.864 +  val defs = map
   1.865 +    (fn (pred, T) => function_definition thy pred ([], (1 upto (length (binder_types T)))))
   1.866 +      preds
   1.867 +  val _ = tracing ("defs: " ^ (makestring defs))
   1.868 +in 
   1.869 +   (* remove not_False_eq_True when simpset in prove_match is better *)
   1.870 +   simp_tac (HOL_basic_ss addsimps @{thm not_False_eq_True} :: @{thm eval_pred} :: defs) 1 
   1.871 +   (* need better control here! *)
   1.872 +   THEN print_tac "after sidecond simplification"
   1.873 +   end
   1.874 +
   1.875 +fun prove_clause thy nargs all_vs param_vs modes (iss, is) (ts, ps) = let
   1.876 +  val modes' = modes @ List.mapPartial
   1.877 +   (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
   1.878 +     (param_vs ~~ iss);
   1.879 +  fun check_constrt ((names, eqs), t) =
   1.880 +      if is_constrt thy t then ((names, eqs), t) else
   1.881 +        let
   1.882 +          val s = Name.variant names "x";
   1.883 +          val v = Free (s, fastype_of t)
   1.884 +        in ((s::names, HOLogic.mk_eq (v, t)::eqs), v) end;
   1.885 +  
   1.886 +  val (in_ts, clause_out_ts) = get_args is ts;
   1.887 +  val ((all_vs', eqs), in_ts') =
   1.888 +      (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts);
   1.889 +  fun prove_prems out_ts vs [] =
   1.890 +    (prove_match thy out_ts)
   1.891 +    THEN asm_simp_tac HOL_basic_ss' 1
   1.892 +    THEN print_tac "before the last rule of singleI:"
   1.893 +    THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1)
   1.894 +  | prove_prems out_ts vs rps =
   1.895 +    let
   1.896 +      val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
   1.897 +      val SOME (p, mode as SOME (Mode ((iss, js), _, param_modes))) =
   1.898 +        select_mode_prem thy modes' vs' rps;
   1.899 +      val premposition = (find_index (equal p) ps) + nargs
   1.900 +      val rps' = filter_out (equal p) rps;
   1.901 +      val rest_tac = (case p of Prem (us, t) =>
   1.902 +          let
   1.903 +            val (in_ts, out_ts''') = get_args js us
   1.904 +            val rec_tac = prove_prems out_ts''' vs' rps'
   1.905 +          in
   1.906 +            print_tac "before clause:"
   1.907 +            THEN asm_simp_tac HOL_basic_ss 1
   1.908 +            THEN print_tac "before prove_expr:"
   1.909 +            THEN prove_expr thy modes (mode, t, us) premposition
   1.910 +            THEN print_tac "after prove_expr:"
   1.911 +            THEN rec_tac
   1.912 +          end
   1.913 +        | Negprem (us, t) =>
   1.914 +          let
   1.915 +            val (in_ts, out_ts''') = get_args js us
   1.916 +            val rec_tac = prove_prems out_ts''' vs' rps'
   1.917 +            val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
   1.918 +            val (_, params) = strip_comb t
   1.919 +          in
   1.920 +            print_tac "before negated clause:"
   1.921 +            THEN rtac @{thm bindI} 1
   1.922 +            THEN (if (is_some name) then
   1.923 +                simp_tac (HOL_basic_ss addsimps [function_definition thy (the name) (iss, js)]) 1
   1.924 +                THEN rtac @{thm not_predI} 1
   1.925 +                THEN print_tac "after neg. intro rule"
   1.926 +                THEN print_tac ("t = " ^ (makestring t))
   1.927 +                (* FIXME: work with parameter arguments *)
   1.928 +                THEN (EVERY (map (prove_param thy modes) (param_modes ~~ params)))
   1.929 +              else
   1.930 +                rtac @{thm not_predI'} 1)
   1.931 +            THEN (REPEAT_DETERM (atac 1))
   1.932 +            THEN rec_tac
   1.933 +          end
   1.934 +        | Sidecond t =>
   1.935 +         rtac @{thm bindI} 1
   1.936 +         THEN rtac @{thm if_predI} 1
   1.937 +         THEN print_tac "before sidecond:"
   1.938 +         THEN prove_sidecond thy modes t
   1.939 +         THEN print_tac "after sidecond:"
   1.940 +         THEN prove_prems [] vs' rps')
   1.941 +    in (prove_match thy out_ts)
   1.942 +        THEN rest_tac
   1.943 +    end;
   1.944 +  val prems_tac = prove_prems in_ts' param_vs ps
   1.945 +in
   1.946 +  rtac @{thm bindI} 1
   1.947 +  THEN rtac @{thm singleI} 1
   1.948 +  THEN prems_tac
   1.949 +end;
   1.950 +
   1.951 +fun select_sup 1 1 = []
   1.952 +  | select_sup _ 1 = [rtac @{thm supI1}]
   1.953 +  | select_sup n i = (rtac @{thm supI2})::(select_sup (n - 1) (i - 1));
   1.954 +
   1.955 +fun get_nparams thy s = let
   1.956 +    val _ = tracing ("get_nparams: " ^ s)
   1.957 +  in
   1.958 +  if Symtab.defined (#nparams (IndCodegenData.get thy)) s then
   1.959 +    the (Symtab.lookup (#nparams (IndCodegenData.get thy)) s) 
   1.960 +  else
   1.961 +    case try (InductivePackage.the_inductive (ProofContext.init thy)) s of
   1.962 +      SOME info => info |> snd |> #raw_induct |> Thm.unvarify
   1.963 +        |> InductivePackage.params_of |> length
   1.964 +    | NONE => 0 (* default value *)
   1.965 +  end
   1.966 +
   1.967 +val ind_set_codegen_preproc = InductiveSetPackage.codegen_preproc;
   1.968 +
   1.969 +fun pred_elim thy predname =
   1.970 +  if (Symtab.defined (#elim_rules (IndCodegenData.get thy)) predname) then
   1.971 +    the (Symtab.lookup (#elim_rules (IndCodegenData.get thy)) predname)
   1.972 +  else
   1.973 +    (let
   1.974 +      val ind_result = InductivePackage.the_inductive (ProofContext.init thy) predname
   1.975 +      val index = find_index (fn s => s = predname) (#names (fst ind_result))
   1.976 +    in nth (#elims (snd ind_result)) index end)
   1.977 +
   1.978 +fun prove_one_direction thy all_vs param_vs modes clauses ((pred, T), mode) = let
   1.979 +  val elim_rule = the (Symtab.lookup (#function_elims (IndCodegenData.get thy)) (modename thy pred mode))
   1.980 +(*  val ind_result = InductivePackage.the_inductive (ProofContext.init thy) pred
   1.981 +  val index = find_index (fn s => s = pred) (#names (fst ind_result))
   1.982 +  val (_, T) = dest_Const (nth (#preds (snd ind_result)) index) *)
   1.983 +  val nargs = length (binder_types T) - get_nparams thy pred
   1.984 +  val pred_case_rule = singleton (ind_set_codegen_preproc thy)
   1.985 +    (preprocess_elim thy nargs (pred_elim thy pred))
   1.986 +  (* FIXME preprocessor |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}])*)
   1.987 +  val _ = tracing ("pred_case_rule " ^ (makestring pred_case_rule))
   1.988 +in
   1.989 +  REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
   1.990 +  THEN etac elim_rule 1
   1.991 +  THEN etac pred_case_rule 1
   1.992 +  THEN (EVERY (map
   1.993 +         (fn i => EVERY' (select_sup (length clauses) i) i) 
   1.994 +           (1 upto (length clauses))))
   1.995 +  THEN (EVERY (map (prove_clause thy nargs all_vs param_vs modes mode) clauses))
   1.996 +end;
   1.997 +
   1.998 +(*******************************************************************************************************)
   1.999 +(* Proof in the other direction ************************************************************************)
  1.1000 +(*******************************************************************************************************)
  1.1001 +
  1.1002 +fun prove_match2 thy out_ts = let
  1.1003 +  fun split_term_tac (Free _) = all_tac
  1.1004 +    | split_term_tac t =
  1.1005 +      if (is_constructor thy t) then let
  1.1006 +        val info = DatatypePackage.the_datatype thy ((fst o dest_Type o fastype_of) t)
  1.1007 +        val num_of_constrs = length (#case_rewrites info)
  1.1008 +        (* special treatment of pairs -- because of fishing *)
  1.1009 +        val split_rules = case (fst o dest_Type o fastype_of) t of
  1.1010 +          "*" => [@{thm prod.split_asm}] 
  1.1011 +          | _ => PureThy.get_thms thy (((fst o dest_Type o fastype_of) t) ^ ".split_asm")
  1.1012 +        val (_, ts) = strip_comb t
  1.1013 +      in
  1.1014 +        print_tac ("splitting with t = " ^ (makestring t))
  1.1015 +        THEN (Splitter.split_asm_tac split_rules 1)
  1.1016 +(*        THEN (Simplifier.asm_full_simp_tac HOL_basic_ss 1)
  1.1017 +          THEN (DETERM (TRY (etac @{thm Pair_inject} 1))) *)
  1.1018 +        THEN (REPEAT_DETERM_N (num_of_constrs - 1) (etac @{thm botE} 1 ORELSE etac @{thm botE} 2))
  1.1019 +        THEN (EVERY (map split_term_tac ts))
  1.1020 +      end
  1.1021 +    else all_tac
  1.1022 +  in
  1.1023 +    split_term_tac (mk_tuple out_ts)
  1.1024 +    THEN (DETERM (TRY ((Splitter.split_asm_tac [@{thm "split_if_asm"}] 1) THEN (etac @{thm botE} 2))))
  1.1025 +  end
  1.1026 +
  1.1027 +(* VERY LARGE SIMILIRATIY to function prove_param 
  1.1028 +-- join both functions
  1.1029 +*) 
  1.1030 +fun prove_param2 thy modes (NONE, t) = all_tac 
  1.1031 +  | prove_param2 thy modes (m as SOME (Mode (mode, is, ms)), t) = let
  1.1032 +    val  (f, args) = strip_comb t
  1.1033 +    val (params, _) = chop (length ms) args
  1.1034 +    val f_tac = case f of
  1.1035 +        Const (name, T) => full_simp_tac (HOL_basic_ss addsimps 
  1.1036 +           @{thm eval_pred}::function_definition thy name mode::[]) 1
  1.1037 +      | Free _ => all_tac
  1.1038 +  in  
  1.1039 +    print_tac "before simplification in prove_args:"
  1.1040 +    THEN debug_tac ("function : " ^ (makestring f) ^ " - mode" ^ (makestring mode))
  1.1041 +    THEN f_tac
  1.1042 +    THEN print_tac "after simplification in prove_args"
  1.1043 +    (* work with parameter arguments *)
  1.1044 +    THEN (EVERY (map (prove_param2 thy modes) (ms ~~ params)))
  1.1045 +  end
  1.1046 +
  1.1047 +fun prove_expr2 thy modes (SOME (Mode (mode, is, ms)), t) = 
  1.1048 +  (case strip_comb t of
  1.1049 +    (Const (name, T), args) =>
  1.1050 +      if AList.defined op = modes name then
  1.1051 +        etac @{thm bindE} 1
  1.1052 +        THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})))
  1.1053 +        THEN (etac (elim_rule thy name mode) 1)
  1.1054 +        THEN (EVERY (map (prove_param2 thy modes) (ms ~~ args)))
  1.1055 +      else error "Prove expr2 if case not implemented"
  1.1056 +    | _ => etac @{thm bindE} 1)
  1.1057 +  | prove_expr2 _ _ _ = error "Prove expr2 not implemented"
  1.1058 +
  1.1059 +fun prove_sidecond2 thy modes t = let
  1.1060 +  val _ = tracing ("prove_sidecond:" ^ (makestring t))
  1.1061 +  fun preds_of t nameTs = case strip_comb t of 
  1.1062 +    (f as Const (name, T), args) =>
  1.1063 +      if AList.defined (op =) modes name then (name, T) :: nameTs
  1.1064 +        else fold preds_of args nameTs
  1.1065 +    | _ => nameTs
  1.1066 +  val preds = preds_of t []
  1.1067 +  val _ = tracing ("preds: " ^ (makestring preds))
  1.1068 +  val defs = map
  1.1069 +    (fn (pred, T) => function_definition thy pred ([], (1 upto (length (binder_types T)))))
  1.1070 +      preds
  1.1071 +  in
  1.1072 +   (* only simplify the one assumption *)
  1.1073 +   full_simp_tac (HOL_basic_ss' addsimps @{thm eval_pred} :: defs) 1 
  1.1074 +   (* need better control here! *)
  1.1075 +   THEN print_tac "after sidecond2 simplification"
  1.1076 +   end
  1.1077 +  
  1.1078 +fun prove_clause2 thy all_vs param_vs modes (iss, is) (ts, ps) pred i = let
  1.1079 +  val modes' = modes @ List.mapPartial
  1.1080 +   (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
  1.1081 +     (param_vs ~~ iss);
  1.1082 +  fun check_constrt ((names, eqs), t) =
  1.1083 +      if is_constrt thy t then ((names, eqs), t) else
  1.1084 +        let
  1.1085 +          val s = Name.variant names "x";
  1.1086 +          val v = Free (s, fastype_of t)
  1.1087 +        in ((s::names, HOLogic.mk_eq (v, t)::eqs), v) end;
  1.1088 +  val pred_intro_rule = nth (pred_intros thy pred) (i - 1)
  1.1089 +    |> preprocess_intro thy
  1.1090 +    |> (fn thm => hd (ind_set_codegen_preproc thy [thm]))
  1.1091 +    (* FIXME preprocess |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]) *)
  1.1092 +  val (in_ts, clause_out_ts) = get_args is ts;
  1.1093 +  val ((all_vs', eqs), in_ts') =
  1.1094 +      (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts);
  1.1095 +  fun prove_prems2 out_ts vs [] =
  1.1096 +    print_tac "before prove_match2 - last call:"
  1.1097 +    THEN prove_match2 thy out_ts
  1.1098 +    THEN print_tac "after prove_match2 - last call:"
  1.1099 +    THEN (etac @{thm singleE} 1)
  1.1100 +    THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1))
  1.1101 +    THEN (asm_full_simp_tac HOL_basic_ss' 1)
  1.1102 +    THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1))
  1.1103 +    THEN (asm_full_simp_tac HOL_basic_ss' 1)
  1.1104 +    THEN SOLVED (print_tac "state before applying intro rule:"
  1.1105 +      THEN (rtac pred_intro_rule 1)
  1.1106 +      (* How to handle equality correctly? *)
  1.1107 +      THEN (print_tac "state before assumption matching")
  1.1108 +      THEN (REPEAT (atac 1 ORELSE 
  1.1109 +         (CHANGED (asm_full_simp_tac HOL_basic_ss' 1)
  1.1110 +          THEN print_tac "state after simp_tac:"))))
  1.1111 +  | prove_prems2 out_ts vs ps = let
  1.1112 +      val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
  1.1113 +      val SOME (p, mode as SOME (Mode ((iss, js), _, param_modes))) =
  1.1114 +        select_mode_prem thy modes' vs' ps;
  1.1115 +      val ps' = filter_out (equal p) ps;
  1.1116 +      val rest_tac = (case p of Prem (us, t) =>
  1.1117 +          let
  1.1118 +            val (in_ts, out_ts''') = get_args js us
  1.1119 +            val rec_tac = prove_prems2 out_ts''' vs' ps'
  1.1120 +          in
  1.1121 +            (prove_expr2 thy modes (mode, t)) THEN rec_tac
  1.1122 +          end
  1.1123 +        | Negprem (us, t) =>
  1.1124 +          let
  1.1125 +            val (in_ts, out_ts''') = get_args js us
  1.1126 +            val rec_tac = prove_prems2 out_ts''' vs' ps'
  1.1127 +            val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
  1.1128 +            val (_, params) = strip_comb t
  1.1129 +          in
  1.1130 +            print_tac "before neg prem 2"
  1.1131 +            THEN etac @{thm bindE} 1
  1.1132 +            THEN (if is_some name then
  1.1133 +                full_simp_tac (HOL_basic_ss addsimps [function_definition thy (the name) (iss, js)]) 1 
  1.1134 +                THEN etac @{thm not_predE} 1
  1.1135 +                THEN (EVERY (map (prove_param2 thy modes) (param_modes ~~ params)))
  1.1136 +              else
  1.1137 +                etac @{thm not_predE'} 1)
  1.1138 +            THEN rec_tac
  1.1139 +          end 
  1.1140 +        | Sidecond t =>
  1.1141 +            etac @{thm bindE} 1
  1.1142 +            THEN etac @{thm if_predE} 1
  1.1143 +            THEN prove_sidecond2 thy modes t 
  1.1144 +            THEN prove_prems2 [] vs' ps')
  1.1145 +    in print_tac "before prove_match2:"
  1.1146 +       THEN prove_match2 thy out_ts
  1.1147 +       THEN print_tac "after prove_match2:"
  1.1148 +       THEN rest_tac
  1.1149 +    end;
  1.1150 +  val prems_tac = prove_prems2 in_ts' param_vs ps 
  1.1151 +in
  1.1152 +  print_tac "starting prove_clause2"
  1.1153 +  THEN etac @{thm bindE} 1
  1.1154 +  THEN (etac @{thm singleE'} 1)
  1.1155 +  THEN (TRY (etac @{thm Pair_inject} 1))
  1.1156 +  THEN print_tac "after singleE':"
  1.1157 +  THEN prems_tac
  1.1158 +end;
  1.1159 + 
  1.1160 +fun prove_other_direction thy all_vs param_vs modes clauses (pred, mode) = let
  1.1161 +  fun prove_clause (clause, i) =
  1.1162 +    (if i < length clauses then etac @{thm supE} 1 else all_tac)
  1.1163 +    THEN (prove_clause2 thy all_vs param_vs modes mode clause pred i)
  1.1164 +in
  1.1165 +  (DETERM (TRY (rtac @{thm unit.induct} 1)))
  1.1166 +   THEN (REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all})))
  1.1167 +   THEN (rtac (intro_rule thy pred mode) 1)
  1.1168 +   THEN (EVERY (map prove_clause (clauses ~~ (1 upto (length clauses)))))
  1.1169 +end;
  1.1170 +
  1.1171 +fun prove_pred thy all_vs param_vs modes clauses (((pred, T), mode), t) = let
  1.1172 +  val ctxt = ProofContext.init thy
  1.1173 +  val clauses' = the (AList.lookup (op =) clauses pred)
  1.1174 +in
  1.1175 +  Goal.prove ctxt (Term.fold_aterms (fn Free (x, _) => insert (op =) x | _ => I) t []) [] t
  1.1176 +    (if !do_proofs then
  1.1177 +      (fn _ =>
  1.1178 +      rtac @{thm pred_iffI} 1
  1.1179 +      THEN prove_one_direction thy all_vs param_vs modes clauses' ((pred, T), mode)
  1.1180 +      THEN print_tac "proved one direction"
  1.1181 +      THEN prove_other_direction thy all_vs param_vs modes clauses' (pred, mode)
  1.1182 +      THEN print_tac "proved other direction")
  1.1183 +     else (fn _ => mycheat_tac thy 1))
  1.1184 +end;
  1.1185 +
  1.1186 +fun prove_preds thy all_vs param_vs modes clauses pmts =
  1.1187 +  map (prove_pred thy all_vs param_vs modes clauses) pmts
  1.1188 +
  1.1189 +(* look for other place where this functionality was used before *)
  1.1190 +fun strip_intro_concl intro nparams = let
  1.1191 +  val _ $ u = Logic.strip_imp_concl intro
  1.1192 +  val (pred, all_args) = strip_comb u
  1.1193 +  val (params, args) = chop nparams all_args
  1.1194 +in (pred, (params, args)) end
  1.1195 +
  1.1196 +(* setup for alternative introduction and elimination rules *)
  1.1197 +
  1.1198 +fun add_intro_thm thm thy = let
  1.1199 +   val (pred, _) = dest_Const (fst (strip_intro_concl (prop_of thm) 0))
  1.1200 + in map_intro_rules (Symtab.insert_list Thm.eq_thm (pred, thm)) thy end
  1.1201 +
  1.1202 +fun add_elim_thm thm thy = let
  1.1203 +    val (pred, _) = dest_Const (fst 
  1.1204 +      (strip_comb (HOLogic.dest_Trueprop (hd (prems_of thm)))))
  1.1205 +  in map_elim_rules (Symtab.update (pred, thm)) thy end
  1.1206 +
  1.1207 +
  1.1208 +(* special case: inductive predicate with no clauses *)
  1.1209 +fun noclause (predname, T) thy = let
  1.1210 +  val Ts = binder_types T
  1.1211 +  val names = Name.variant_list []
  1.1212 +        (map (fn i => "x" ^ (string_of_int i)) (1 upto (length Ts)))
  1.1213 +  val vs = map Free (names ~~ Ts)
  1.1214 +  val clausehd =  HOLogic.mk_Trueprop (list_comb(Const (predname, T), vs))
  1.1215 +  val intro_t = Logic.mk_implies (@{prop False}, clausehd)
  1.1216 +  val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))
  1.1217 +  val elim_t = Logic.list_implies ([clausehd, Logic.mk_implies (@{prop False}, P)], P)
  1.1218 +  val intro_thm = Goal.prove (ProofContext.init thy) names [] intro_t
  1.1219 +        (fn {...} => etac @{thm FalseE} 1)
  1.1220 +  val elim_thm = Goal.prove (ProofContext.init thy) ("P" :: names) [] elim_t
  1.1221 +        (fn {...} => etac (pred_elim thy predname) 1) 
  1.1222 +in
  1.1223 +  add_intro_thm intro_thm thy
  1.1224 +  |> add_elim_thm elim_thm
  1.1225 +end
  1.1226 +
  1.1227 +(*************************************************************************************)
  1.1228 +(* main function *********************************************************************)
  1.1229 +(*************************************************************************************)
  1.1230 +
  1.1231 +fun create_def_equation' ind_name (mode : (int list option list * int list) option) thy =
  1.1232 +let
  1.1233 +  val _ = tracing ("starting create_def_equation' with " ^ ind_name)
  1.1234 +  val (prednames, preds) = 
  1.1235 +    case (try (InductivePackage.the_inductive (ProofContext.init thy)) ind_name) of
  1.1236 +      SOME info => let val preds = info |> snd |> #preds
  1.1237 +        in (map (fst o dest_Const) preds, map ((apsnd Logic.unvarifyT) o dest_Const) preds) end
  1.1238 +    | NONE => let
  1.1239 +        val pred = Symtab.lookup (#intro_rules (IndCodegenData.get thy)) ind_name
  1.1240 +          |> the |> hd |> prop_of
  1.1241 +          |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop |> strip_comb
  1.1242 +          |> fst |>  dest_Const |> apsnd Logic.unvarifyT
  1.1243 +       in ([ind_name], [pred]) end
  1.1244 +  val thy' = fold (fn pred as (predname, T) => fn thy =>
  1.1245 +    if null (pred_intros thy predname) then noclause pred thy else thy) preds thy
  1.1246 +  val intrs = map (preprocess_intro thy') (maps (pred_intros thy') prednames)
  1.1247 +    |> ind_set_codegen_preproc thy' (*FIXME preprocessor
  1.1248 +    |> map (Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]))*)
  1.1249 +    |> map (Logic.unvarify o prop_of)
  1.1250 +  val _ = tracing ("preprocessed intro rules:" ^ (makestring (map (cterm_of thy') intrs)))
  1.1251 +  val name_of_calls = get_name_of_ind_calls_of_clauses thy' prednames intrs 
  1.1252 +  val _ = tracing ("calling preds: " ^ makestring name_of_calls)
  1.1253 +  val _ = tracing "starting recursive compilations"
  1.1254 +  fun rec_call name thy = 
  1.1255 +    if not (name mem (Symtab.keys (#modes (IndCodegenData.get thy)))) then
  1.1256 +      create_def_equation name thy else thy
  1.1257 +  val thy'' = fold rec_call name_of_calls thy'
  1.1258 +  val _ = tracing "returning from recursive calls"
  1.1259 +  val _ = tracing "starting mode inference"
  1.1260 +  val extra_modes = Symtab.dest (#modes (IndCodegenData.get thy''))
  1.1261 +  val nparams = get_nparams thy'' ind_name
  1.1262 +  val _ $ u = Logic.strip_imp_concl (hd intrs);
  1.1263 +  val params = List.take (snd (strip_comb u), nparams);
  1.1264 +  val param_vs = maps term_vs params
  1.1265 +  val all_vs = terms_vs intrs
  1.1266 +  fun dest_prem t =
  1.1267 +      (case strip_comb t of
  1.1268 +        (v as Free _, ts) => if v mem params then Prem (ts, v) else Sidecond t
  1.1269 +      | (c as Const (@{const_name Not}, _), [t]) => (case dest_prem t of
  1.1270 +          Prem (ts, t) => Negprem (ts, t)
  1.1271 +        | Negprem _ => error ("Double negation not allowed in premise: " ^ (makestring (c $ t))) 
  1.1272 +        | Sidecond t => Sidecond (c $ t))
  1.1273 +      | (c as Const (s, _), ts) =>
  1.1274 +        if is_ind_pred thy'' s then
  1.1275 +          let val (ts1, ts2) = chop (get_nparams thy'' s) ts
  1.1276 +          in Prem (ts2, list_comb (c, ts1)) end
  1.1277 +        else Sidecond t
  1.1278 +      | _ => Sidecond t)
  1.1279 +  fun add_clause intr (clauses, arities) =
  1.1280 +  let
  1.1281 +    val _ $ t = Logic.strip_imp_concl intr;
  1.1282 +    val (Const (name, T), ts) = strip_comb t;
  1.1283 +    val (ts1, ts2) = chop nparams ts;
  1.1284 +    val prems = map (dest_prem o HOLogic.dest_Trueprop) (Logic.strip_imp_prems intr);
  1.1285 +    val (Ts, Us) = chop nparams (binder_types T)
  1.1286 +  in
  1.1287 +    (AList.update op = (name, these (AList.lookup op = clauses name) @
  1.1288 +      [(ts2, prems)]) clauses,
  1.1289 +     AList.update op = (name, (map (fn U => (case strip_type U of
  1.1290 +                 (Rs as _ :: _, Type ("bool", [])) => SOME (length Rs)
  1.1291 +               | _ => NONE)) Ts,
  1.1292 +             length Us)) arities)
  1.1293 +  end;
  1.1294 +  val (clauses, arities) = fold add_clause intrs ([], []);
  1.1295 +  val modes = infer_modes thy'' extra_modes arities param_vs clauses
  1.1296 +  val _ = print_arities arities;
  1.1297 +  val _ = print_modes modes;
  1.1298 +  val modes = if (is_some mode) then AList.update (op =) (ind_name, [the mode]) modes else modes
  1.1299 +  val _ = print_modes modes
  1.1300 +  val thy''' = fold (create_definitions preds nparams) modes thy''
  1.1301 +    |> map_modes (fold Symtab.update_new modes)
  1.1302 +  val clauses' = map (fn (s, cls) => (s, (the (AList.lookup (op =) preds s), cls))) clauses
  1.1303 +  val _ = tracing "compiling predicates..."
  1.1304 +  val ts = compile_preds thy''' all_vs param_vs (extra_modes @ modes) clauses'
  1.1305 +  val _ = tracing "returned term from compile_preds"
  1.1306 +  val pred_mode = maps (fn (s, (T, _)) => map (pair (s, T)) ((the o AList.lookup (op =) modes) s)) clauses'
  1.1307 +  val _ = tracing "starting proof"
  1.1308 +  val result_thms = prove_preds thy''' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts))
  1.1309 +  val (_, thy'''') = yield_singleton PureThy.add_thmss
  1.1310 +    ((Binding.name (NameSpace.base_name ind_name ^ "_codegen" (*FIXME other suffix*)), result_thms),
  1.1311 +      [Attrib.attribute_i thy''' Code.add_default_eqn_attrib]) thy'''
  1.1312 +in
  1.1313 +  thy''''
  1.1314 +end
  1.1315 +and create_def_equation ind_name thy = create_def_equation' ind_name NONE thy
  1.1316 +
  1.1317 +fun set_nparams (pred, nparams) thy = map_nparams (Symtab.update (pred, nparams)) thy
  1.1318 +
  1.1319 +fun print_alternative_rules thy = let
  1.1320 +    val d = IndCodegenData.get thy
  1.1321 +    val preds = (Symtab.keys (#intro_rules d)) union (Symtab.keys (#elim_rules d))
  1.1322 +    val _ = tracing ("preds: " ^ (makestring preds))
  1.1323 +    fun print pred = let
  1.1324 +      val _ = tracing ("predicate: " ^ pred)
  1.1325 +      val _ = tracing ("introrules: ")
  1.1326 +      val _ = fold (fn thm => fn u => tracing (makestring thm))
  1.1327 +        (rev (Symtab.lookup_list (#intro_rules d) pred)) ()
  1.1328 +      val _ = tracing ("casesrule: ")
  1.1329 +      val _ = tracing (makestring (Symtab.lookup (#elim_rules d) pred))
  1.1330 +    in () end
  1.1331 +    val _ = map print preds
  1.1332 + in thy end; 
  1.1333 +  
  1.1334 +fun attrib f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I)
  1.1335 +
  1.1336 +val code_ind_intros_attrib = attrib add_intro_thm
  1.1337 +
  1.1338 +val code_ind_cases_attrib = attrib add_elim_thm
  1.1339 +
  1.1340 +val setup = Attrib.add_attributes
  1.1341 +    [("code_ind_intros", Attrib.no_args code_ind_intros_attrib,
  1.1342 +      "adding alternative introduction rules for code generation of inductive predicates"),
  1.1343 +     ("code_ind_cases", Attrib.no_args code_ind_cases_attrib, 
  1.1344 +      "adding alternative elimination rules for code generation of inductive predicates")]
  1.1345 +
  1.1346 +end;
  1.1347 +
  1.1348 +fun pred_compile name thy = Predicate_Compile.create_def_equation
  1.1349 +  (Sign.intern_const thy name) thy;