src/HOL/ex/predicate_compile.ML
author haftmann
Fri, 24 Apr 2009 17:45:16 +0200
changeset 30972 5b65835ccc92
parent 30530 7173bf123335
child 31108 0ce5f53fc65d
permissions -rw-r--r--
some experiments towards user interface for predicate compiler
     1 (* Author: Lukas Bulwahn
     2 
     3 (Prototype of) A compiler from predicates specified by intro/elim rules
     4 to equations.
     5 *)
     6 
     signature PREDICATE_COMPILE =
     sig
       (* a mode (iss, is): one optional index list per parameter, then the
          index list used to split the remaining arguments into inputs/outputs *)
       type mode = int list option list * int list

       (* global switch *)
       val do_proofs : bool ref

       (* lookups in the theory data *)
       val modename_of : theory -> string -> mode -> string
       val modes_of : theory -> string -> mode list
       val intro_rule : theory -> string -> mode -> thm
       val elim_rule : theory -> string -> mode -> thm

       (* remaining interface *)
       val create_def_equation' : string -> mode option -> theory -> theory
       val create_def_equation : string -> theory -> theory
       val strip_intro_concl : term -> int -> term * (term list * term list)
       val code_ind_intros_attrib : attribute
       val code_ind_cases_attrib : attribute
       val print_alternative_rules : theory -> theory
       val setup : theory -> theory
     end;
    23 
    24 structure Predicate_Compile: PREDICATE_COMPILE =
    25 struct
    26 
    27 (** auxiliary **)
    28 
    29 (* debug stuff *)
    30 
    (* Emit a trace message, but only while Toplevel.debug is switched on. *)
    fun tracing msg =
      (case ! Toplevel.debug of
        true => Output.tracing msg
      | false => ());

    (* Tactical.print_tac with the given message while debugging, the identity tactic otherwise. *)
    fun print_tac msg =
      (case ! Toplevel.debug of
        true => Tactical.print_tac msg
      | false => Seq.single);

    (* Tactic that leaves the state untouched and only traces msg. *)
    fun debug_tac msg st = (tracing msg; Seq.single st);

    (* Mutable switch exported through the signature. *)
    val do_proofs = ref true;
    37 
    38 
    39 (** fundamentals **)
    40 
    41 (* syntactic operations *)
    42 
    (* For every term b in xs build the HOL equation "Free x = b", where the
       free variable x is given the type of b. *)
    fun mk_eq (x, xs) =
      map (fn b => HOLogic.mk_eq (Free (x, fastype_of b), b)) xs;
    48 
    (* Product type over Ts, folded from the right; the unit type for []. *)
    fun mk_tupleT Ts =
      if null Ts then HOLogic.unitT else foldr1 HOLogic.mk_prodT Ts;

    (* Tuple term over ts, folded from the right; the unit value for []. *)
    fun mk_tuple ts =
      if null ts then HOLogic.unit else foldr1 HOLogic.mk_prod ts;

    (* Flatten a tuple term built by mk_tuple back into its components. *)
    fun dest_tuple t =
      (case t of
        Const (@{const_name Product_Type.Unity}, _) => []
      | Const (@{const_name Pair}, _) $ fst_t $ snd_t => fst_t :: dest_tuple snd_t
      | _ => [t]);
    58 
    (* The type "T Predicate.pred". *)
    fun mk_pred_enumT elemT = Type ("Predicate.pred", [elemT]);

    (* Inverse of mk_pred_enumT; raises TYPE for any other type. *)
    fun dest_pred_enumT T =
      (case T of
        Type ("Predicate.pred", [elemT]) => elemT
      | _ => raise TYPE ("dest_pred_enumT", [T], []));
    63 
    (* Apply the Predicate.Pred constant to the function term f; the element
       type of the result is the argument type of f. *)
    fun mk_Enum f =
      let
        val funT as Type ("fun", [argT, _]) = fastype_of f;
        val pred_const = Const (@{const_name Predicate.Pred}, funT --> mk_pred_enumT argT);
      in pred_const $ f end;

    (* The boolean term "Predicate.eval f x"; the element type is read off x. *)
    fun mk_Eval (f, x) =
      let
        val elemT = fastype_of x;
        val evalT = mk_pred_enumT elemT --> elemT --> HOLogic.boolT;
      in Const (@{const_name Predicate.eval}, evalT) $ f $ x end;
    75 
    (* The constant Orderings.bot at type "T pred". *)
    fun mk_empty T =
      let val predT = mk_pred_enumT T
      in Const (@{const_name Orderings.bot}, predT) end;

    (* "Predicate.single t", typed after t. *)
    fun mk_single t =
      let val elemT = fastype_of t
      in Const (@{const_name Predicate.single}, elemT --> mk_pred_enumT elemT) $ t end;

    (* "Predicate.bind x f"; the result type is the range type of f. *)
    fun mk_bind (x, f) =
      let
        val contT as Type ("fun", [_, resT]) = fastype_of f;
        val bindT = fastype_of x --> contT --> resT;
      in Const (@{const_name Predicate.bind}, bindT) $ x $ f end;
    87 
    (* Binary "sup" of two terms, via HOLogic.mk_binop. *)
    fun mk_sup (t, u) = HOLogic.mk_binop @{const_name sup} (t, u);

    (* "Predicate.if_pred cond" of type "unit pred". *)
    fun mk_if_predenum cond =
      let val ifT = HOLogic.boolT --> mk_pred_enumT HOLogic.unitT
      in Const (@{const_name Predicate.if_pred}, ifT) $ cond end;

    (* "Predicate.not_pred t" on "unit pred". *)
    fun mk_not_pred t =
      let val unit_predT = mk_pred_enumT HOLogic.unitT
      in Const (@{const_name Predicate.not_pred}, unit_predT --> unit_predT) $ t end;
    95 
    96 
    97 (* data structures *)
    98 
    (* A mode is a pair (iss, is): iss holds one optional index list per
       parameter, and is is the index list that get_args uses to split the
       remaining arguments into selected and unselected ones. *)
    99 type mode = int list option list * int list;
   100 
   (* Order on modes, assembled from prod_ord / list_ord / option_ord / int_ord. *)
   101 val mode_ord = prod_ord (list_ord (option_ord (list_ord int_ord))) (list_ord int_ord);
   102 
   (* Tables keyed by a pair of a name and a mode. *)
   103 structure PredModetab = TableFun(
   104   type key = string * mode
   105   val ord = prod_ord fast_string_ord mode_ord
   106 );
   107 
   108 
   109 (*FIXME scrap boilerplate*)
   110 
   structure IndCodegenData = TheoryDataFun
   (
     (* Per-theory bookkeeping of the predicate compiler:
          names:           (inductive predicate, mode) -> function name (string)
          modes:           inductive predicate -> its modes
          function_defs:   function name -> definition
          function_intros: function name -> intro rule
          function_elims:  function name -> elim rule
          intro_rules:     inductive predicate -> alternative intro rules
          elim_rules:      inductive predicate -> alternative elimination rule
          nparams:         const name -> number of parameters
                           (assuming there exist intro and elimination rules) *)
     (*FIXME: better group tables according to key*)
     type T = {names : string PredModetab.table,
               modes : mode list Symtab.table,
               function_defs : Thm.thm Symtab.table,
               function_intros : Thm.thm Symtab.table,
               function_elims : Thm.thm Symtab.table,
               intro_rules : Thm.thm list Symtab.table,
               elim_rules : Thm.thm Symtab.table,
               nparams : int Symtab.table
              };
     val empty = {names = PredModetab.empty,
                  modes = Symtab.empty,
                  function_defs = Symtab.empty,
                  function_intros = Symtab.empty,
                  function_elims = Symtab.empty,
                  intro_rules = Symtab.empty,
                  elim_rules = Symtab.empty,
                  nparams = Symtab.empty};
     val copy = I;
     val extend = I;
     (* Field-wise merge of the two data records in r.
        The intro rule lists are compared with eq_list, which treats lists of
        different length as unequal.  Zipping them with ~~ first would raise
        UnequalLengths in that situation, so the table merge would never get
        the chance to handle the clash itself. *)
     fun merge _ r = {names = PredModetab.merge (op =) (pairself #names r),
                      modes = Symtab.merge (op =) (pairself #modes r),
                      function_defs = Symtab.merge Thm.eq_thm (pairself #function_defs r),
                      function_intros = Symtab.merge Thm.eq_thm (pairself #function_intros r),
                      function_elims = Symtab.merge Thm.eq_thm (pairself #function_elims r),
                      intro_rules = Symtab.merge (eq_list Thm.eq_thm) (pairself #intro_rules r),
                      elim_rules = Symtab.merge Thm.eq_thm (pairself #elim_rules r),
                      nparams = Symtab.merge (op =) (pairself #nparams r)};
   );
   150 
   (* Field-wise updaters for IndCodegenData: each one applies f to a single
      table of the record and copies the other seven fields unchanged. *)

   fun map_names f thy = thy |> IndCodegenData.map
     (fn {names, modes, function_defs, function_intros, function_elims,
          intro_rules, elim_rules, nparams} =>
       {names = f names, modes = modes, function_defs = function_defs,
        function_intros = function_intros, function_elims = function_elims,
        intro_rules = intro_rules, elim_rules = elim_rules, nparams = nparams})

   fun map_modes f thy = thy |> IndCodegenData.map
     (fn {names, modes, function_defs, function_intros, function_elims,
          intro_rules, elim_rules, nparams} =>
       {names = names, modes = f modes, function_defs = function_defs,
        function_intros = function_intros, function_elims = function_elims,
        intro_rules = intro_rules, elim_rules = elim_rules, nparams = nparams})

   fun map_function_defs f thy = thy |> IndCodegenData.map
     (fn {names, modes, function_defs, function_intros, function_elims,
          intro_rules, elim_rules, nparams} =>
       {names = names, modes = modes, function_defs = f function_defs,
        function_intros = function_intros, function_elims = function_elims,
        intro_rules = intro_rules, elim_rules = elim_rules, nparams = nparams})

   fun map_function_elims f thy = thy |> IndCodegenData.map
     (fn {names, modes, function_defs, function_intros, function_elims,
          intro_rules, elim_rules, nparams} =>
       {names = names, modes = modes, function_defs = function_defs,
        function_intros = function_intros, function_elims = f function_elims,
        intro_rules = intro_rules, elim_rules = elim_rules, nparams = nparams})

   fun map_function_intros f thy = thy |> IndCodegenData.map
     (fn {names, modes, function_defs, function_intros, function_elims,
          intro_rules, elim_rules, nparams} =>
       {names = names, modes = modes, function_defs = function_defs,
        function_intros = f function_intros, function_elims = function_elims,
        intro_rules = intro_rules, elim_rules = elim_rules, nparams = nparams})

   fun map_intro_rules f thy = thy |> IndCodegenData.map
     (fn {names, modes, function_defs, function_intros, function_elims,
          intro_rules, elim_rules, nparams} =>
       {names = names, modes = modes, function_defs = function_defs,
        function_intros = function_intros, function_elims = function_elims,
        intro_rules = f intro_rules, elim_rules = elim_rules, nparams = nparams})

   fun map_elim_rules f thy = thy |> IndCodegenData.map
     (fn {names, modes, function_defs, function_intros, function_elims,
          intro_rules, elim_rules, nparams} =>
       {names = names, modes = modes, function_defs = function_defs,
        function_intros = function_intros, function_elims = function_elims,
        intro_rules = intro_rules, elim_rules = f elim_rules, nparams = nparams})

   fun map_nparams f thy = thy |> IndCodegenData.map
     (fn {names, modes, function_defs, function_intros, function_elims,
          intro_rules, elim_rules, nparams} =>
       {names = names, modes = modes, function_defs = function_defs,
        function_intros = function_intros, function_elims = function_elims,
        intro_rules = intro_rules, elim_rules = elim_rules, nparams = f nparams})
   198 
   199 (* removes first subgoal *)
   (* Resolve subgoal i with the theorem that SkipProof.make_thm produces for
      the schematic proposition ?A. *)
   fun mycheat_tac thy i st =
     let val any_prop = SkipProof.make_thm thy (Var (("A", 0), propT))
     in Tactic.rtac any_prop i st end
   202 
   203 (* Lightweight mode analysis **********************************************)
   204 
   205 (**************************************************************************)
   206 (* source code from old code generator ************************************)
   207 
   208 (**** check if a term contains only constructor functions ****)
   209 
   (* is_constrt thy t holds iff t is built only from free variables (without
      arguments) and fully applied datatype constructors of the matching type. *)
   fun is_constrt thy =
     let
       (* one descr entry gives: constructor name -> (number of arguments, datatype name) *)
       fun constrs_of (_, (Tname, _, cs)) =
         map (fn (c, arg_Ts) => (c, (length arg_Ts, Tname))) cs;
       val cnstrs =
         Symtab.dest (DatatypePackage.get_datatypes thy)
         |> maps (fn (_, info) => maps constrs_of (#descr info));
       fun check t =
         let val (head, args) = strip_comb t in
           (case head of
             Free _ => null args
           | Const (c, T) =>
               (case (AList.lookup (op =) cnstrs c, body_type T) of
                 (SOME (n, Tname), Type (Tname', _)) =>
                   length args = n andalso Tname = Tname' andalso forall check args
               | _ => false)
           | _ => false)
         end;
     in check end;
   222 
   223 (**** check if a type is an equality type (i.e. doesn't contain fun)
   224   FIXME this is only an approximation ****)
   225 
   (* A type counts as equality type when no "fun" type constructor occurs in it. *)
   fun is_eqT T =
     (case T of
       Type (tycon, args) => tycon <> "fun" andalso forall is_eqT args
     | _ => true);
   228 
   229 (**** mode inference ****)
   230 
   (* Render a mode as "p1 -> ... -> pn -> [i, j, ...]", where a parameter
      without mode is shown as "X". *)
   fun string_of_mode (iss, is) =
     let
       fun string_of_indices NONE = "X"
         | string_of_indices (SOME js) = enclose "[" "]" (commas (map string_of_int js));
     in space_implode " -> " (map string_of_indices (iss @ [SOME is])) end;

   (* Trace one line "name: mode, mode, ..." per entry of modes. *)
   fun print_modes modes =
     let
       fun line (s, ms) = s ^ ": " ^ commas (map string_of_mode ms);
     in tracing ("Inferred modes:\n" ^ cat_lines (map line modes)) end;
   239 
   (* names of all Frees in a term (with duplicates) *)
   fun term_vs tm =
     fold_aterms (fn Free (x, _) => (fn xs => x :: xs) | _ => I) tm [];

   (* names of the Frees of a list of terms, each name once *)
   fun terms_vs tms = distinct (op =) (maps term_vs tms);

   (** collect all Frees in a term (with duplicates!) **)
   fun term_vTs tm =
     fold_aterms (fn Free xT => (fn xTs => xT :: xTs) | _ => I) tm [];
   246 
   (* Split ts by position (counting from 1): elements whose position occurs
      in is go to the first list, all others to the second; order is kept. *)
   fun get_args is ts =
     let
       fun split _ [] = ([], [])
         | split i (t :: rest) =
             let val (selected, others) = split (i + 1) rest
             in if i mem is then (t :: selected, others) else (selected, t :: others) end;
     in split 1 ts end
   252 
   253 (*FIXME this function should not be named merge... make it local instead*)
   (* Merge two lists of lists, always taking the head that is at least as
      long as the other one (the left head wins ties). *)
   fun merge xs ys =
     (case (xs, ys) of
       (_, []) => xs
     | ([], _) => ys
     | (x :: xs', y :: ys') =>
         if length x >= length y then x :: merge xs' ys
         else y :: merge xs ys');
   258 
   (* All sublists of [i, ..., j], combined with merge; [[]] when i > j. *)
   fun subsets i j =
     if i > j then [[]]
     else
       let val rest = subsets (i + 1) j
       in merge (map (fn sub => i :: sub) rest) rest end;
   263 
   (* All pairs (x, y) with x from xs and y from ys, grouped by x. *)
   fun cprod (xs, ys) = maps (fn x => map (pair x) ys) xs;

   (* All ways of picking one element from each list of xss. *)
   fun cprods xss =
     foldr (fn (xs, picks) => map op :: (cprod (xs, picks))) [[]] xss;
   268 
   (* Mode (m, is, ms) as built by modes_of below: m is the mode found in the
      mode table, is is an index list for the arguments that are still
      missing, and ms holds one optional hmode per parameter. *)
   269 datatype hmode = Mode of mode * int list * hmode option list; (*FIXME don't understand
   270   why there is another mode type!?*)
   271 
   (* modes_of modes t: the hmodes available for the (possibly partially
      applied) term t, where modes is an association list from names to mode
      lists.
      - Const head without table entry: the default hmode, whose index list
        contains every argument position of t's type.
      - Var/Free head without table entry: "the" raises Option.
      NOTE: a later definition named modes_of (lookup in the theory data)
      shadows this function for all code that follows it. *)
   272 fun modes_of modes t =
   273   let
   274     val ks = 1 upto length (binder_types (fastype_of t));
   275     val default = [Mode (([], ks), ks, [])];
   276     fun mk_modes name args = Option.map (maps (fn (m as (iss, is)) =>
   277         let
             (* args1: one argument per parameter mode; args2: arguments supplied beyond those *)
   278           val (args1, args2) =
   279             if length args < length iss then
   280               error ("Too few arguments for inductive predicate " ^ name)
   281             else chop (length iss) args;
   282           val k = length args2;
   283           val prfx = 1 upto k
   284         in
             (* the mode is kept only if positions 1..k form a prefix of is *)
   285           if not (is_prefix op = prfx is) then [] else
             (* drop those k positions and renumber the remaining ones *)
   286           let val is' = map (fn i => i - k) (List.drop (is, k))
             (* one result per combination of parameter hmodes; for a parameter with
                mode js only hmodes of the argument whose index list equals js qualify *)
   287           in map (fn x => Mode (m, is', x)) (cprods (map
   288             (fn (NONE, _) => [NONE]
   289               | (SOME js, arg) => map SOME (filter
   290                   (fn Mode (_, js', _) => js=js') (modes_of modes arg)))
   291                     (iss ~~ args1)))
   292           end
   293         end)) (AList.lookup op = modes name)
   294 
   295   in (case strip_comb t of
   296       (Const (name, _), args) => the_default default (mk_modes name args)
   297     | (Var ((name, _), _), args) => the (mk_modes name args)
   298     | (Free (name, _), args) => the (mk_modes name args)
   299     | _ => default)
   300   end
   301 
   (* A premise of a clause: a predicate application (arguments, predicate
      term), the same in negated form, or a side condition. *)
   302 datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term;
   303 
   (* select_mode_prem thy modes vs ps: pairs every premise in ps with the first
      hmode (if any) under which it is usable when exactly the variables vs are
      known, and returns the first pair whose hmode is SOME; NONE if no premise
      qualifies.  If modes_of raises Option, that is reported as "Bad predicate". *)
   304 fun select_mode_prem thy modes vs ps =
   305   find_first (is_some o snd) (ps ~~ map
   306     (fn Prem (us, t) => find_first (fn Mode (_, is, _) =>
   307           let
               (* arguments selected by is are inputs, the rest outputs *)
   308             val (in_ts, out_ts) = get_args is us;
               (* outputs that are not constructor terms (in_ts') are checked like inputs *)
   309             val (out_ts', in_ts') = List.partition (is_constrt thy) out_ts;
   310             val vTs = maps term_vTs out_ts';
               (* types of variables occurring twice in the outputs, or occurring there while already in vs *)
   311             val dupTs = map snd (duplicates (op =) vTs) @
   312               List.mapPartial (AList.lookup (op =) vTs) vs;
   313           in
   314             terms_vs (in_ts @ in_ts') subset vs andalso
   315             forall (is_eqT o fastype_of) in_ts' andalso
   316             term_vs t subset vs andalso
   317             forall is_eqT dupTs
   318           end)
   319             (modes_of modes t handle Option =>
   320                error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
         (* negated premise: is must have as many entries as there are arguments, and every variable must be known *)
   321       | Negprem (us, t) => find_first (fn Mode (_, is, _) =>
   322             length us = length is andalso
   323             terms_vs us subset vs andalso
   324             term_vs t subset vs)
   325             (modes_of modes t handle Option =>
   326                error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
         (* side condition: usable once all of its variables are known; gets a dummy hmode *)
   327       | Sidecond t => if term_vs t subset vs then SOME (Mode (([], []), [], []))
   328           else NONE
   329       ) ps);
   330 
   (* check_mode_clause thy param_vs modes (iss, is) (ts, ps): tests whether the
      clause with conclusion arguments ts and premises ps is admissible for
      mode (iss, is).  The result is true iff
        - variables occurring more than once in the constructor-term inputs
          have equality types (in the sense of is_eqT),
        - inputs that are not constructor terms have equality types,
        - the premises can be consumed one after another via select_mode_prem,
          starting from the parameters and the input variables, and
        - afterwards every variable of the conclusion is known. *)
   331 fun check_mode_clause thy param_vs modes (iss, is) (ts, ps) =
   332   let
       (* every parameter that has a mode js is entered into the table with the single mode ([], js) *)
   333     val modes' = modes @ List.mapPartial
   334       (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
   335         (param_vs ~~ iss); 
       (* repeatedly pick a usable premise and remove it; a Prem contributes the
          variables of its arguments to the known ones; NONE when stuck *)
   336     fun check_mode_prems vs [] = SOME vs
   337       | check_mode_prems vs ps = (case select_mode_prem thy modes' vs ps of
   338           NONE => NONE
   339         | SOME (x, _) => check_mode_prems
   340             (case x of Prem (us, _) => vs union terms_vs us | _ => vs)
   341             (filter_out (equal x) ps))
       (* in_ts: constructor-term inputs of the conclusion; in_ts': the remaining inputs *)
   342     val (in_ts, in_ts') = List.partition (is_constrt thy) (fst (get_args is ts));
   343     val in_vs = terms_vs in_ts;
   344     val concl_vs = terms_vs ts
   345   in
   346     forall is_eqT (map snd (duplicates (op =) (maps term_vTs in_ts))) andalso
   347     forall (is_eqT o fastype_of) in_ts' andalso
   348     (case check_mode_prems (param_vs union in_vs) ps of
   349        NONE => false
   350      | SOME vs => concl_vs subset vs)
   351   end;
   352 
   (* Keep only those modes ms of predicate p that every clause of p passes
      (check_mode_clause); the first offending clause of a rejected mode is traced. *)
   fun check_modes_pred thy param_vs preds modes (p, ms) =
     let
       val SOME clauses = AList.lookup (op =) preds p;
       fun admissible m =
         (case find_index (not o check_mode_clause thy param_vs modes m) clauses of
           ~1 => true
         | i =>
             (tracing ("Clause " ^ string_of_int (i + 1) ^ " of " ^
                p ^ " violates mode " ^ string_of_mode m);
              false));
     in (p, List.filter admissible ms) end;
   361 
   (* Iterate f, starting from x, until the value no longer changes. *)
   fun fixp f (x : (string * mode list) list) =
     let val next = f x
     in if next = x then x else fixp f next end;
   365 
   (* Start with all candidate modes that the arities allow and repeatedly
      discard the ones rejected by check_modes_pred until nothing changes. *)
   fun infer_modes thy extra_modes arities param_vs preds =
     let
       fun param_candidates NONE = [NONE]
         | param_candidates (SOME k') = map SOME (subsets 1 k');
       fun candidates (s, (ks, k)) =
         (s, cprod (cprods (map param_candidates ks), subsets 1 k));
       fun refine modes =
         map (check_modes_pred thy param_vs preds (modes @ extra_modes)) modes;
     in fixp refine (map candidates arities) end;
   372 
   373 
   374 (*****************************************************************************************)
   375 (**** end of old source code *************************************************************)
   376 (*****************************************************************************************)
   377 (**** term construction ****)
   378 
   379 (* for simple modes (e.g. parameters) only: better call it param_funT *)
   380 (* or even better: remove it and only use funT'_of - some modifications to funT'_of necessary *) 
   (* Type of the function for a simple mode: the selected argument types map
      to a pred over the tuple of the remaining ones; T itself without a mode. *)
   fun funT_of T mode_opt =
     (case mode_opt of
       NONE => T
     | SOME mode =>
         let val (inTs, outTs) = get_args mode (binder_types T)
         in inTs ---> mk_pred_enumT (mk_tupleT outTs) end);

   (* Same for a full mode (iss, is): parameter types are translated according
      to iss first, then the remaining argument types are split by is. *)
   fun funT'_of (iss, is) T =
     let
       val (paramTs, argTs) = chop (length iss) (binder_types T);
       fun param_funT (SOME js) = funT'_of ([], js)
         | param_funT NONE = I;
       val paramTs' = map2 param_funT iss paramTs;
       val (inargTs, outargTs) = get_args is argTs;
     in (paramTs' @ inargTs) ---> mk_pred_enumT (mk_tupleT outargTs) end;
   395 
   396 
   (* Produce a Free for the variable s.  On the first occurrence s is used
      as is and entered into vs; on later occurrences a fresh variant is
      generated, added to names and recorded under s in vs. *)
   fun mk_v (names, vs) s T =
     (case AList.lookup (op =) vs s of
       SOME occs =>
         let
           val fresh = Name.variant names s;
           val v = Free (fresh, T);
         in ((fresh :: names, AList.update (op =) (s, v :: occs) vs), v) end
     | NONE => ((names, (s, []) :: vs), Free (s, T)));

   (* Apply mk_v to every Free of a term, threading the (names, vs) state
      from left to right through applications. *)
   fun distinct_v (nvs, tm) =
     (case tm of
       Free (s, T) => mk_v nvs s T
     | t $ u =>
         let
           val (nvs1, t') = distinct_v (nvs, t);
           val (nvs2, u') = distinct_v (nvs1, u);
         in (nvs2, t' $ u') end
     | _ => (nvs, tm));
   414 
   (* compile_match thy eqs eqs' out_ts success_t builds the abstraction
        %x. case x of
              <tuple of out_ts> => (if <conjunction of equations> then success_t else bot)
            | y => bot
      eqs is a list of pairs (name, terms) that mk_eq expands into equations
      "name = term"; eqs' are further equations.  When there are no equations
      at all, the first branch is success_t without the surrounding If. *)
   415 fun compile_match thy eqs eqs' out_ts success_t =
   416   let 
   417     val eqs'' = maps mk_eq eqs @ eqs'
       (* x and y are chosen distinct from all free names in the terms involved *)
   418     val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
   419     val name = Name.variant names "x";
   420     val name' = Name.variant (name :: names) "y";
   421     val T = mk_tupleT (map fastype_of out_ts);
       (* U: type of success_t, which must be a pred type (dest_pred_enumT raises TYPE otherwise) *)
   422     val U = fastype_of success_t;
   423     val U' = dest_pred_enumT U;
   424     val v = Free (name, T);
   425     val v' = Free (name', T);
   426   in
   427     lambda v (fst (DatatypePackage.make_case
   428       (ProofContext.init thy) false [] v
   429       [(mk_tuple out_ts,
   430         if null eqs'' then success_t
   431         else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
   432           foldr1 HOLogic.mk_conj eqs'' $ success_t $
   433             mk_empty U'),
   434        (v', mk_empty U')]))
   435   end;
   436 
   (* Name of the function registered for predicate name in the given mode;
      fails with an error message when there is no such entry. *)
   fun modename_of thy name mode =
     (case PredModetab.lookup (#names (IndCodegenData.get thy)) (name, mode) of
       SOME mode_id => mode_id
     | NONE =>
         error ("fun modename_of - definition not found: name: " ^ name ^ " mode: " ^  (makestring mode)))
   442 
   (* Modes stored for a predicate name in the theory data; [] if none are stored. *)
   fun modes_of thy name =
     these (Symtab.lookup (#modes (IndCodegenData.get thy)) name);
   445 
   446 (*FIXME function can be removed*)
   (* Abstract over fresh variables x... for all arguments of t and apply f to
      the fully applied term: %xs. f (t xs). *)
   fun mk_funcomp f t =
     let
       val argTs = binder_types (fastype_of t);
       val used = Term.add_free_names t [];
       val arg_names = Name.variant_list used (replicate (length argTs) "x");
       val args = map Free (arg_names ~~ argTs);
     in fold_rev lambda args (f (list_comb (t, args))) end;
   456 
   (* compile_param thy modes (hmode_opt, t): translate a parameter term.
      Without hmode the term is returned unchanged.  Otherwise the head of t is
      replaced by the function belonging to its mode (a constant looked up via
      modename_of, or a retyped free variable) and the leading arguments, one
      per entry of ms, are compiled recursively.
      Any other kind of head is reported with error "compile params".  The
      former catch-all clause carrying that message could never be reached,
      because the first two clauses already cover every hmode option, while
      such heads escaped from the inner case as a raw Match exception. *)
   fun compile_param thy modes (NONE, t) = t
     | compile_param thy modes (SOME (Mode ((iss, is'), _, ms)), t) =
         let
           val (f, args) = strip_comb t;
           val (params, args') = chop (length ms) args;
           val params' = map (compile_param thy modes) (ms ~~ params);
           val f' =
             (case f of
               Const (name, T) =>
                 if AList.defined op = modes name then
                   Const (modename_of thy name (iss, is'), funT'_of (iss, is') T)
                 else error "compile param: Not an inductive predicate with correct mode"
             | Free (name, T) => Free (name, funT_of T (SOME is'))
             | _ => error "compile params");
         in list_comb (f', params' @ args') end
   470 
   (* compile_expr thy modes (hmode_opt, t): translate the predicate term t of
      a premise into the function term for the selected hmode.
      - constant head with an entry in modes: the constant named by
        modename_of, applied to the compiled parameters;
      - free head: the variable retyped with funT_of (higher order mode call).
      Everything else, including a missing hmode, fails with
      "not a valid inductive expression".  The inner case used to be
      non-exhaustive, so other heads raised Match instead of this error. *)
   fun compile_expr thy modes (SOME (Mode (mode, is, ms)), t) =
         (case strip_comb t of
           (Const (name, T), params) =>
             if AList.defined op = modes name then
               let
                 (* argument types that remain after the parameters *)
                 val argTs = curry Library.drop (length ms) (fst (strip_type T));
                 val (Ts, Us) = get_args is argTs;
                 val params' = map (compile_param thy modes) (ms ~~ params);
                 val mode_id = modename_of thy name mode;
                 val funT = (map fastype_of params' @ Ts) ---> mk_pred_enumT (mk_tupleT Us);
               in list_comb (Const (mode_id, funT), params') end
             else error "not a valid inductive expression"
         | (Free (name, T), args) =>
             (*if name mem param_vs then *)
             (* Higher order mode call *)
             list_comb (Free (name, funT_of T (SOME is)), args)
         | _ => error "not a valid inductive expression")
     | compile_expr _ _ _ = error "not a valid inductive expression"
   490 
   491 
   (* compile_clause thy all_vs param_vs modes (iss, is) (ts, ps) inp:
      compile one clause (conclusion arguments ts, premises ps) for mode
      (iss, is).  The result is "bind (single inp) f", where f is a chain of
      matches (compile_match) and binds (mk_bind) with one step per premise,
      ending in "single <tuple of the conclusion's outputs>". *)
   492 fun compile_clause thy all_vs param_vs modes (iss, is) (ts, ps) inp =
   493   let
       (* every parameter that has a mode js is entered into the table with the single mode ([], js) *)
   494     val modes' = modes @ List.mapPartial
   495       (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
   496         (param_vs ~~ iss);
       (* a term that is not a constructor term is replaced by a fresh variable x,
          and the equation "x = term" is recorded in eqs *)
   497     fun check_constrt ((names, eqs), t) =
   498       if is_constrt thy t then ((names, eqs), t) else
   499         let
   500           val s = Name.variant names "x";
   501           val v = Free (s, fastype_of t)
   502         in ((s::names, HOLogic.mk_eq (v, t)::eqs), v) end;
   503 
       (* inputs and outputs of the conclusion, split by is *)
   504     val (in_ts, out_ts) = get_args is ts;
   505     val ((all_vs', eqs), in_ts') =
   506       (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts);
   507 
       (* compile_prems pats vs names ps: pats are the terms the incoming value is
          matched against, vs the variables known so far, names the names in use.
          No premises left: match and return the conclusion's outputs; out_ts and
          eqs in this clause are the bindings made above. *)
   508     fun compile_prems out_ts' vs names [] =
   509           let
   510             val ((names', eqs'), out_ts'') =
   511               (*FIXME*) Library.foldl_map check_constrt ((names, []), out_ts');
   512             val (nvs, out_ts''') = (*FIXME*) Library.foldl_map distinct_v
   513               ((names', map (rpair []) vs), out_ts'');
   514           in
   515             compile_match thy (snd nvs) (eqs @ eqs') out_ts'''
   516               (mk_single (mk_tuple out_ts))
   517           end
       (* Otherwise: in this clause the parameter out_ts and the local eqs shadow the
          outer bindings. *)
   518       | compile_prems out_ts vs names ps =
   519           let
   520             val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
               (* next usable premise; this pattern binding fails with Bind if
                  select_mode_prem finds none *)
   521             val SOME (p, mode as SOME (Mode (_, js, _))) =
   522               select_mode_prem thy modes' vs' ps
   523             val ps' = filter_out (equal p) ps
   524             val ((names', eqs), out_ts') =
   525               (*FIXME*) Library.foldl_map check_constrt ((names, []), out_ts)
   526             val (nvs, out_ts'') = (*FIXME*) Library.foldl_map distinct_v
   527               ((names', map (rpair []) vs), out_ts')
               (* compiled_clause: the term for premise p; rest: the compilation of the
                  remaining premises, matching against the outputs of p *)
   528             val (compiled_clause, rest) = case p of
   529                Prem (us, t) =>
   530                  let
   531                    val (in_ts, out_ts''') = get_args js us;
   532                    val u = list_comb (compile_expr thy modes (mode, t), in_ts)
   533                    val rest = compile_prems out_ts''' vs' (fst nvs) ps'
   534                  in
   535                    (u, rest)
   536                  end
   537              | Negprem (us, t) =>
   538                  let
   539                    val (in_ts, out_ts''') = get_args js us
   540                    val u = list_comb (compile_expr thy modes (mode, t), in_ts)
   541                    val rest = compile_prems out_ts''' vs' (fst nvs) ps'
   542                  in
   543                    (mk_not_pred u, rest)
   544                  end
   545              | Sidecond t =>
   546                  let
   547                    val rest = compile_prems [] vs' (fst nvs) ps';
   548                  in
   549                    (mk_if_predenum t, rest)
   550                  end
   551           in
   552             compile_match thy (snd nvs) eqs out_ts'' 
   553               (mk_bind (compiled_clause, rest))
   554           end
   555     val prem_t = compile_prems in_ts' param_vs all_vs' ps;
   556   in
   557     mk_bind (mk_single inp, prem_t)
   558   end
   559 
   (* compile_pred thy all_vs param_vs modes s T cls mode: the defining
      equation (as a Trueprop term) of the function for predicate s : T in the
      given mode.  The left-hand side applies the function named by
      modename_of to the parameters and to fresh input variables x<i>; the
      right-hand side is the "sup" of the compiled clauses cls.
      A predicate without clauses yields the bottom element on the right-hand
      side; foldr1 on the empty list would raise Empty. *)
   fun compile_pred thy all_vs param_vs modes s T cls mode =
     let
       val (iss, is) = mode;
       val (Ts1, Ts2) = chop (length param_vs) (binder_types T);
       val Ts1' = map2 funT_of Ts1 iss;
       val (Us1, Us2) = get_args is Ts2;
       val resT = mk_tupleT Us2;
       (* input variables, named apart from the parameters *)
       val xnames =
         Name.variant_list param_vs (map (fn i => "x" ^ string_of_int i) is);
       val xs = map2 (fn x => fn U => Free (x, U)) xnames Us1;
       val cl_ts =
         map (fn cl => compile_clause thy all_vs param_vs modes mode cl (mk_tuple xs)) cls;
       val mode_id = modename_of thy s mode;
       val params = map2 (fn v => fn U => Free (v, U)) param_vs Ts1';
       val lhs =
         list_comb (Const (mode_id, (Ts1' @ Us1) ---> mk_pred_enumT resT), params @ xs);
       val rhs = if null cl_ts then mk_empty resT else foldr1 mk_sup cl_ts;
     in
       HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
     end;
   580 
   (* For every predicate, one compiled equation per mode listed for it in modes. *)
   fun compile_preds thy all_vs param_vs modes preds =
     let
       fun compile_all (s, (T, cls)) =
         the (AList.lookup (op =) modes s)
         |> map (compile_pred thy all_vs param_vs modes s T cls);
     in map compile_all preds end;
   585 
   586 (* end of term construction ******************************************************)
   587 
   588 (* special setup for simpset *)                  
   (* HOL_basic_ss with its solver replaced by one that just runs all_tac. *)
   val HOL_basic_ss' =
     let val all_tac_solver = mk_solver "all_tac_solver" (fn _ => fn _ => all_tac)
     in HOL_basic_ss setSolver all_tac_solver end
   591 
   592 
   593 (* misc: constructing and proving tupleE rules ***********************************)
   594 
   595 
   596 (* Creating definitions of functional programs 
   597    and proving intro and elim rules **********************************************) 
   598 
   (* c counts as inductive predicate if InductivePackage.the_inductive
      succeeds on it, or if alternative intro rules are stored for it.
      The table is queried directly with Symtab.defined, which avoids building
      the list of all keys just to test membership. *)
   fun is_ind_pred thy c =
     can (InductivePackage.the_inductive (ProofContext.init thy)) c orelse
     Symtab.defined (#intro_rules (IndCodegenData.get thy)) c
   602 
   (* Names of the constants occurring in intrs that are inductive predicates
      (is_ind_pred) but not among preds. *)
   fun get_name_of_ind_calls_of_clauses thy preds intrs =
     let
       val const_names = map fst (fold Term.add_consts intrs []);
       fun is_foreign_ind_pred c =
         not (member (op =) preds c) andalso is_ind_pred thy c;
     in filter is_foreign_ind_pred const_names end
   606 
   (* Trace one line "name: k1 -> ... -> kn -> k" per entry; a parameter
      without arity is shown as "X". *)
   fun print_arities arities =
     let
       fun string_of_arity NONE = "X"
         | string_of_arity (SOME k') = string_of_int k';
       fun line (s, (ks, k)) =
         s ^ ": " ^ space_implode " -> " (map string_of_arity (ks @ [SOME k]));
     in tracing ("Arities:\n" ^ cat_lines (map line arities)) end;
   612 
   (* For a parameter x with original type T and a mode, build
        %x1 ... xn. eval (x <inputs>) <tuple of outputs>
      over fresh argument names, which are added to names.  Without a mode
      both x and names are returned unchanged. *)
   fun mk_Eval_of ((x, T), mode_opt) names =
     (case mode_opt of
       NONE => (x, names)
     | SOME mode =>
         let
           val argTs = binder_types T;
           val argnames =
             Name.variant_list names
               (map (fn i => "x" ^ string_of_int i) (1 upto (length argTs)));
           val args = map Free (argnames ~~ argTs);
           val (inargs, outargs) = get_args mode args;
           val eval = mk_Eval (list_comb (x, inargs), mk_tuple outargs);
         in (fold_rev lambda args eval, argnames @ names) end);
   625 
   (* create_intro_rule nparams mode defthm mode_id funT pred thy:
      proves an introduction and an elimination rule that relate the predicate
      pred to the function constant mode_id : funT, using the definition defthm:
        intro:  pred params' args ==> eval (f params inargs) <outargs>
        elim:   eval (f params inargs) <outargs> ==> (pred params' args ==> P) ==> P
      Both theorems are entered into the function_intros / function_elims
      tables under mode_id and stored as "<base name>I" and "<base name>E". *)
   626 fun create_intro_rule nparams mode defthm mode_id funT pred thy =
   627 let
   628   val Ts = binder_types (fastype_of pred)
   629   val funtrm = Const (mode_id, funT)
   630   val argnames = Name.variant_list []
   631         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
     (* the first nparams argument types are replaced by the function types of the parameter modes *)
   632   val (Ts1, Ts2) = chop nparams Ts;
   633   val Ts1' = map2 funT_of Ts1 (fst mode)
   634   val args = map Free (argnames ~~ (Ts1' @ Ts2))
   635   val (params, io_args) = chop nparams args
   636   val (inargs, outargs) = get_args (snd mode) io_args
     (* parameters that have a mode are turned back into predicates via eval (mk_Eval_of) *)
   637   val (params', names) = fold_map mk_Eval_of ((params ~~ Ts1) ~~ (fst mode)) []
   638   val predprop = HOLogic.mk_Trueprop (list_comb (pred, params' @ io_args))
   639   val funargs = params @ inargs
     (* in the elim rule a free variable y : unit is used when there are no outputs *)
   640   val funpropE = HOLogic.mk_Trueprop (mk_Eval (list_comb (funtrm, funargs),
   641                   if null outargs then Free("y", HOLogic.unitT) else mk_tuple outargs))
   642   val funpropI = HOLogic.mk_Trueprop (mk_Eval (list_comb (funtrm, funargs),
   643                    mk_tuple outargs))
   644   val introtrm = Logic.mk_implies (predprop, funpropI)
     (* both goals are solved by simplification with the definition and the rules listed here *)
   645   val simprules = [defthm, @{thm eval_pred},
   646                    @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}]
   647   val unfolddef_tac = (Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simprules) 1)
   648   val introthm = Goal.prove (ProofContext.init thy) (argnames @ ["y"]) [] introtrm (fn {...} => unfolddef_tac)
   649   val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT));
   650   val elimtrm = Logic.list_implies ([funpropE, Logic.mk_implies (predprop, P)], P)
   651   val elimthm = Goal.prove (ProofContext.init thy) (argnames @ ["y", "P"]) [] elimtrm (fn {...} => unfolddef_tac)
   652 in
   653   map_function_intros (Symtab.update_new (mode_id, introthm)) thy
   654   |> map_function_elims (Symtab.update_new (mode_id, elimthm))
   655   |> PureThy.store_thm (Binding.name (Long_Name.base_name mode_id ^ "I"), introthm) |> snd
   656   |> PureThy.store_thm (Binding.name (Long_Name.base_name mode_id ^ "E"), elimthm)  |> snd
   657 end;
   658 
   (* create_definitions preds nparams (name, modes) thy: for every mode in
      modes, declare a new constant for predicate name, define it as
        f params inputs == Pred (%outputs. name params' args)
      record the constant in the names and function_defs tables, and derive
      its intro/elim rules with create_intro_rule.
      preds maps predicate names to their types; the lookup fails (via "the")
      when name has no entry. *)
   659 fun create_definitions preds nparams (name, modes) thy =
   660   let
   661     val _ = tracing "create definitions"
   662     val T = AList.lookup (op =) preds name |> the
   663     fun create_definition mode thy = let
         (* local helper on index lists; it shadows the top-level string_of_mode:
            indices joined by "_", or "0" for the empty list *)
   664       fun string_of_mode mode = if null mode then "0"
   665         else space_implode "_" (map string_of_int mode)
         (* "__<indices>" for every parameter that has a mode *)
   666       val HOmode = let
   667         fun string_of_HOmode m s = case m of NONE => s | SOME mode => s ^ "__" ^ (string_of_mode mode)    
   668         in (fold string_of_HOmode (fst mode) "") end;
         (* name ^ "_" ^ indices, or name ^ HOmode ^ "___" ^ indices when HOmode is non-empty *)
   669       val mode_id = name ^ (if HOmode = "" then "_" else HOmode ^ "___")
   670         ^ (string_of_mode (snd mode))
   671       val Ts = binder_types T;
   672       val (Ts1, Ts2) = chop nparams Ts;
   673       val Ts1' = map2 funT_of Ts1 (fst mode)
   674       val (Us1, Us2) = get_args (snd mode) Ts2;
   675       val names = Name.variant_list []
   676         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
   677       val xs = map Free (names ~~ (Ts1' @ Ts2));
   678       val (xparams, xargs) = chop nparams xs;
   679       val (xparams', names') = fold_map mk_Eval_of ((xparams ~~ Ts1) ~~ (fst mode)) names
   680       val (xins, xouts) = get_args (snd mode) xargs;
         (* abstraction over a tuple of variables: over a fresh unit variable for [],
            a plain lambda for one variable, nested HOLogic.mk_split otherwise *)
   681       fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
   682        | mk_split_lambda [x] t = lambda x t
   683        | mk_split_lambda xs t = let
   684          fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
   685            | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
   686          in mk_split_lambda' xs t end;
   687       val predterm = mk_Enum (mk_split_lambda xouts (list_comb (Const (name, T), xparams' @ xargs)))
   688       val funT = (Ts1' @ Us1) ---> (mk_pred_enumT (mk_tupleT Us2))
         (* from here on mode_id is the full name of the new constant: only the base
            name of the string built above is kept and passed to Sign.full_bname *)
   689       val mode_id = Sign.full_bname thy (Long_Name.base_name mode_id)
   690       val lhs = list_comb (Const (mode_id, funT), xparams @ xins)
   691       val def = Logic.mk_equals (lhs, predterm)
   692       val ([defthm], thy') = thy |>
   693         Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_id), funT, NoSyn)] |>
   694         PureThy.add_defs false [((Binding.name (Long_Name.base_name mode_id ^ "_def"), def), [])]
   695       in thy' |> map_names (PredModetab.update_new ((name, mode), mode_id))
   696            |> map_function_defs (Symtab.update_new (mode_id, defthm))
   697            |> create_intro_rule nparams mode defthm mode_id funT (Const (name, T))
   698       end;
   699   in
   700     fold create_definition modes thy
   701   end;
   702 
   703 (**************************************************************************************)
   704 (* Proving equivalence of term *)
   705 
   706 
   (* intro_rule thy pred mode: the theorem stored in the #function_intros
      table of IndCodegenData under the name that modename_of yields for
      (pred, mode).  Raises via `the` when the table has no such entry. *)
   707 fun intro_rule thy pred mode =
   708   the (Symtab.lookup (#function_intros (IndCodegenData.get thy)) (modename_of thy pred mode))
   709 
   (* elim_rule thy pred mode: the theorem stored in the #function_elims
      table of IndCodegenData under the name that modename_of yields for
      (pred, mode).  Raises via `the` when the table has no such entry. *)
   710 fun elim_rule thy pred mode =
   711   the (Symtab.lookup (#function_elims (IndCodegenData.get thy)) (modename_of thy pred mode))
   712 
   (* pred_intros thy predname: introduction rules to use for predname.
      If the #intro_rules table of IndCodegenData has an entry for predname,
      that entry's list is returned in reversed order.  Otherwise the #intrs of
      InductivePackage.the_inductive are used, restricted to the rules whose
      conclusion is headed by the constant predname. *)
   713 fun pred_intros thy predname = let
         (* true iff the head constant of the rule's conclusion is pred;
            dest_Const is applied to the head unconditionally *)
   714     fun is_intro_of pred intro = let
   715       val const = fst (strip_comb (HOLogic.dest_Trueprop (concl_of intro)))
   716     in (fst (dest_Const const) = pred) end;
   717     val d = IndCodegenData.get thy
   718   in
   719     if (Symtab.defined (#intro_rules d) predname) then
   720       rev (Symtab.lookup_list (#intro_rules d) predname)
   721     else
   722       InductivePackage.the_inductive (ProofContext.init thy) predname
   723       |> snd |> #intrs |> filter (is_intro_of predname)
   724   end
   725 
   (* function_definition thy pred mode: the definition theorem stored in the
      #function_defs table of IndCodegenData under the name that modename_of
      yields for (pred, mode).  Raises via `the` when there is no entry. *)
   726 fun function_definition thy pred mode =
   727   the (Symtab.lookup (#function_defs (IndCodegenData.get thy)) (modename_of thy pred mode))
   728 
   (* is_Type T: true exactly when T is built with the Type constructor *)
   729 fun is_Type T =
   730   (case T of Type _ => true | _ => false)
   731 
   (* imp_prems_conv cv ct: on a term of the form A ==> B this combines
      Conv.arg_conv cv on the function part with a recursive call on the
      argument part; any other term is returned unchanged (Conv.all_conv).
      NOTE(review): the net effect looks like "apply cv to every premise of a
      nested meta-implication, leave the final conclusion alone" -- confirm
      against the Conv.combination_conv documentation. *)
   732 fun imp_prems_conv cv ct =
   733   case Thm.term_of ct of
   734     Const ("==>", _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
   735   | _ => Conv.all_conv ct
   736 
   (* Trueprop_conv cv ct: applies cv to the argument of a Trueprop
      application; any other term aborts with error "Trueprop_conv". *)
   737 fun Trueprop_conv cv ct =
   738   case Thm.term_of ct of
   739     Const ("Trueprop", _) $ _ => Conv.arg_conv cv ct  
   740   | _ => error "Trueprop_conv"
   741 
   (* preprocess_intro thy rule: at present this only transfers rule to thy
      with Thm.transfer.  The rewriting step that follows is disabled: it is
      entirely inside the FIXME comment. *)
   742 fun preprocess_intro thy rule = Thm.transfer thy rule (*FIXME preprocessor
   743   Conv.fconv_rule
   744     (imp_prems_conv
   745       (Trueprop_conv (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @ {thm Predicate.eq_is_eq})))))
   746     (Thm.transfer thy rule) *)
   747 
   (* preprocess_elim thy nargs elimrule: at present this returns elimrule
      unchanged.  The whole preprocessing body is disabled inside the FIXME
      comment, so thy and nargs are only mentioned by the disabled code. *)
   748 fun preprocess_elim thy nargs elimrule = (*FIXME preprocessor -- let
   749    fun replace_eqs (Const ("Trueprop", _) $ (Const ("op =", T) $ lhs $ rhs)) =
   750       HOLogic.mk_Trueprop (Const (@ {const_name Predicate.eq}, T) $ lhs $ rhs)
   751     | replace_eqs t = t
   752    fun preprocess_case t = let
   753      val params = Logic.strip_params t
   754      val (assums1, assums2) = chop nargs (Logic.strip_assums_hyp t)
   755      val assums_hyp' = assums1 @ (map replace_eqs assums2)
   756      in list_all (params, Logic.list_implies (assums_hyp', Logic.strip_assums_concl t)) end
   757    val prems = Thm.prems_of elimrule
   758    val cases' = map preprocess_case (tl prems)
   759    val elimrule' = Logic.list_implies ((hd prems) :: cases', Thm.concl_of elimrule)
   760  in
   761    Thm.equal_elim
   762      (Thm.symmetric (Conv.implies_concl_conv (MetaSimplifier.rewrite true [@ {thm eq_is_eq}])
   763         (cterm_of thy elimrule')))
   764      elimrule
   765  end*) elimrule;
   766 
   767 
   768 (* returns true if t is an application of a datatype constructor, *)
   769 (* which would consequently be split; *)
   770 (* otherwise returns false *)
   (* More precisely: the type of t must be a Type for which
      DatatypePackage.get_datatype finds an entry, and the head of t must be a
      constant whose name occurs among the constructor names collected from the
      #descr of that entry. *)
   771 fun is_constructor thy t =
   772   if (is_Type (fastype_of t)) then
   773     (case DatatypePackage.get_datatype thy ((fst o dest_Type o fastype_of) t) of
   774       NONE => false
   775     | SOME info => (let
   776       val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
   777       val (c, _) = strip_comb t
   778       in (case c of
   779         Const (name, _) => name mem_string constr_consts
   780         | _ => false) end))
   781   else false
   782 
   783 (* MAJOR FIXME:  prove_params should be simple
   784  - different form of introrule for parameters ? *)
   (* prove_param thy modes (m, t): tactic for a parameter term t with optional
      mode m.  Without a mode nothing is done (all_tac).  With a mode the head
      of t decides: for a constant, subgoal 1 is simplified with eval_pred and
      the function definition of that constant for the given mode; for a free
      variable nothing is done.  Then the first (length ms) arguments of t are
      handled recursively with the modes ms, and finally subgoals are closed by
      assumption for as long as that succeeds.
      The case on the head has clauses for Const and Free only.
      modes is only passed on to the recursive calls. *)
   785 fun prove_param thy modes (NONE, t) = all_tac 
   786   | prove_param thy modes (m as SOME (Mode (mode, is, ms)), t) = let
   787     val  (f, args) = strip_comb t
   788     val (params, _) = chop (length ms) args
   789     val f_tac = case f of
   790         Const (name, T) => simp_tac (HOL_basic_ss addsimps 
   791            @{thm eval_pred}::function_definition thy name mode::[]) 1
   792       | Free _ => all_tac
   793   in  
   794     print_tac "before simplification in prove_args:"
   795     THEN debug_tac ("mode" ^ (makestring mode))
   796     THEN f_tac
   797     THEN print_tac "after simplification in prove_args"
   798     (* work with parameter arguments *)
   799     THEN (EVERY (map (prove_param thy modes) (ms ~~ params)))
   800     THEN (REPEAT_DETERM (atac 1))
   801   end
   802 
   (* prove_expr thy modes (SOME (Mode (mode, is, ms)), t, us) premposition:
      tactic for one premise expression t.
      - head of t is a constant listed in modes: apply bindI, rotate the
        assumptions of subgoal 1 by premposition, apply the intro rule stored
        for that constant and mode, treat the first (length ms) arguments with
        prove_param, then close subgoals by assumption;
      - head is a constant not listed in modes: abort with an error;
      - head is not a constant: apply bindI and solve by assumption.
      Any first argument not of the form SOME (Mode ...) aborts with
      "Prove expr not implemented".
      us, is and args2 are not used by the live code; us and is appear only in
      the commented-out instantiation code. *)
   803 fun prove_expr thy modes (SOME (Mode (mode, is, ms)), t, us) (premposition : int) =
   804   (case strip_comb t of
   805     (Const (name, T), args) =>
   806       if AList.defined op = modes name then (let
   807           val introrule = intro_rule thy name mode
   808           (*val (in_args, out_args) = get_args is us
   809           val (pred, rargs) = strip_comb (HOLogic.dest_Trueprop
   810             (hd (Logic.strip_imp_prems (prop_of introrule))))
   811           val nparams = length ms (* get_nparams thy (fst (dest_Const pred)) *)
   812           val (_, args) = chop nparams rargs
   813           val _ = tracing ("args: " ^ (makestring args))
   814           val subst = map (pairself (cterm_of thy)) (args ~~ us)
   815           val _ = tracing ("subst: " ^ (makestring subst))
   816           val inst_introrule = Drule.cterm_instantiate subst introrule*)
   817          (* the next line is old and probably wrong *)
   818           val (args1, args2) = chop (length ms) args
   819           val _ = tracing ("premposition: " ^ (makestring premposition))
   820         in
   821         rtac @{thm bindI} 1
   822         THEN print_tac "before intro rule:"
   823         THEN debug_tac ("mode" ^ (makestring mode))
   824         THEN debug_tac (makestring introrule)
   825         THEN debug_tac ("premposition: " ^ (makestring premposition))
   826         (* for the right assumption in first position *)
   827         THEN rotate_tac premposition 1
   828         THEN rtac introrule 1
   829         THEN print_tac "after intro rule"
   830         (* work with parameter arguments *)
   831         THEN (EVERY (map (prove_param thy modes) (ms ~~ args1)))
   832         THEN (REPEAT_DETERM (atac 1)) end)
   833       else error "Prove expr if case not implemented"
   834     | _ => rtac @{thm bindI} 1
   835            THEN atac 1)
   836   | prove_expr _ _ _ _ =  error "Prove expr not implemented"
   837 
   (* SOLVED tac st: keeps only those results of tac whose number of premises
      is exactly one less than that of the input state st *)
   838 fun SOLVED tac st = let val wanted = nprems_of st - 1 in FILTER (fn st' => nprems_of st' = wanted) tac st end; 
   839 
   (* SOLVEDALL tac st: keeps only those results of tac that have no premises left *)
   840 fun SOLVEDALL tac st = let fun finished st' = (nprems_of st' = 0) in FILTER finished tac st end
   841 
   (* prove_match thy out_ts: tactic.  Simplifies subgoal 1 with the case
      rewrite rules of every datatype constructor occurring in out_ts
      (constructor arguments are searched recursively) together with unit.cases
      and prod.cases, on top of HOL_basic_ss'.  Afterwards it tries once,
      deterministically, to rewrite with HOL.if_P and to discharge the subgoals
      this produces with asm_simp_tac; if that attempt fails the state is kept
      (TRY). *)
   842 fun prove_match thy (out_ts : term list) = let
         (* case rewrites for t and, recursively, for its arguments; [] unless
            is_constructor accepts t *)
   843   fun get_case_rewrite t =
   844     if (is_constructor thy t) then let
   845       val case_rewrites = (#case_rewrites (DatatypePackage.the_datatype thy
   846         ((fst o dest_Type o fastype_of) t)))
   847       in case_rewrites @ (flat (map get_case_rewrite (snd (strip_comb t)))) end
   848     else []
   849   val simprules = @{thm "unit.cases"} :: @{thm "prod.cases"} :: (flat (map get_case_rewrite out_ts))
   850 (* replace TRY by determining if it necessary - are there equations when calling compile match? *)
   851 in
   852   print_tac ("before prove_match rewriting: simprules = " ^ (makestring simprules))
   853    (* make this simpset better! *)
   854   THEN asm_simp_tac (HOL_basic_ss' addsimps simprules) 1
   855   THEN print_tac "after prove_match:"
   856   THEN (DETERM (TRY (EqSubst.eqsubst_tac (ProofContext.init thy) [0] [@{thm "HOL.if_P"}] 1
   857          THEN (REPEAT_DETERM (rtac @{thm conjI} 1 THEN (SOLVED (asm_simp_tac HOL_basic_ss 1))))
   858          THEN (SOLVED (asm_simp_tac HOL_basic_ss 1)))))
   859   THEN print_tac "after if simplification"
   860 end;
   861 
   862 (* corresponds to compile_fun -- maybe call that also compile_sidecond? *)
   (* prove_sidecond thy modes t: tactic for a side condition t.  Collects the
      constants in t that are listed in modes (the search does not descend
      below such a constant), fetches for each of them the function definition
      for the mode with no parameter modes and all argument positions
      1..(number of binder types), and simplifies subgoal 1 with these
      definitions plus not_False_eq_True and eval_pred. *)
   863 
   864 fun prove_sidecond thy modes t = let
   865   val _ = tracing ("prove_sidecond:" ^ (makestring t))
         (* accumulates (name, type) of mode-carrying constants in front of nameTs *)
   866   fun preds_of t nameTs = case strip_comb t of 
   867     (f as Const (name, T), args) =>
   868       if AList.defined (op =) modes name then (name, T) :: nameTs
   869         else fold preds_of args nameTs
   870     | _ => nameTs
   871   val preds = preds_of t []
   872   
   873   val _ = tracing ("preds: " ^ (makestring preds))
   874   val defs = map
   875     (fn (pred, T) => function_definition thy pred ([], (1 upto (length (binder_types T)))))
   876       preds
   877   val _ = tracing ("defs: " ^ (makestring defs))
   878 in 
   879    (* remove not_False_eq_True when simpset in prove_match is better *)
   880    simp_tac (HOL_basic_ss addsimps @{thm not_False_eq_True} :: @{thm eval_pred} :: defs) 1 
   881    (* need better control here! *)
   882    THEN print_tac "after sidecond simplification"
   883    end
   884 
   (* prove_clause thy nargs all_vs param_vs modes (iss, is) (ts, ps):
      tactic for one clause under the mode (iss, is); ts are the clause
      arguments and ps its premises.  After bindI and singleI the premises are
      processed one at a time by prove_prems, in the order chosen by
      select_mode_prem; when none is left, singleI (or singleI_unit when the
      clause has no output terms) closes the goal.
      nargs is added to the index of a premise in ps to obtain the rotation
      offset handed to prove_expr. *)
   885 fun prove_clause thy nargs all_vs param_vs modes (iss, is) (ts, ps) = let
         (* modes extended by one entry per parameter variable that has SOME mode *)
   886   val modes' = modes @ List.mapPartial
   887    (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
   888      (param_vs ~~ iss);
         (* terms rejected by is_constrt are replaced by a fresh variable; the
            variable name and the equation  variable = term  are accumulated *)
   889   fun check_constrt ((names, eqs), t) =
   890       if is_constrt thy t then ((names, eqs), t) else
   891         let
   892           val s = Name.variant names "x";
   893           val v = Free (s, fastype_of t)
   894         in ((s::names, HOLogic.mk_eq (v, t)::eqs), v) end;
   895   
   896   val (in_ts, clause_out_ts) = get_args is ts;
         (* only in_ts' is used below; all_vs' and eqs are not referenced again *)
   897   val ((all_vs', eqs), in_ts') =
   898       (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts);
   899   fun prove_prems out_ts vs [] =
   900     (prove_match thy out_ts)
   901     THEN asm_simp_tac HOL_basic_ss' 1
   902     THEN print_tac "before the last rule of singleI:"
   903     THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1)
   904   | prove_prems out_ts vs rps =
   905     let
   906       val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
             (* refutable binding: raises Bind unless select_mode_prem returns
                SOME (p, SOME (Mode ...)) *)
   907       val SOME (p, mode as SOME (Mode ((iss, js), _, param_modes))) =
   908         select_mode_prem thy modes' vs' rps;
   909       val premposition = (find_index (equal p) ps) + nargs
   910       val rps' = filter_out (equal p) rps;
   911       val rest_tac = (case p of Prem (us, t) =>
   912           let
   913             val (in_ts, out_ts''') = get_args js us
   914             val rec_tac = prove_prems out_ts''' vs' rps'
   915           in
   916             print_tac "before clause:"
   917             THEN asm_simp_tac HOL_basic_ss 1
   918             THEN print_tac "before prove_expr:"
   919             THEN prove_expr thy modes (mode, t, us) premposition
   920             THEN print_tac "after prove_expr:"
   921             THEN rec_tac
   922           end
   923         | Negprem (us, t) =>
   924           let
   925             val (in_ts, out_ts''') = get_args js us
   926             val rec_tac = prove_prems out_ts''' vs' rps'
   927             val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
   928             val (_, params) = strip_comb t
   929           in
   930             print_tac "before negated clause:"
   931             THEN rtac @{thm bindI} 1
   932             THEN (if (is_some name) then
   933                 simp_tac (HOL_basic_ss addsimps [function_definition thy (the name) (iss, js)]) 1
   934                 THEN rtac @{thm not_predI} 1
   935                 THEN print_tac "after neg. intro rule"
   936                 THEN print_tac ("t = " ^ (makestring t))
   937                 (* FIXME: work with parameter arguments *)
   938                 THEN (EVERY (map (prove_param thy modes) (param_modes ~~ params)))
   939               else
   940                 rtac @{thm not_predI'} 1)
   941             THEN (REPEAT_DETERM (atac 1))
   942             THEN rec_tac
   943           end
   944         | Sidecond t =>
   945          rtac @{thm bindI} 1
   946          THEN rtac @{thm if_predI} 1
   947          THEN print_tac "before sidecond:"
   948          THEN prove_sidecond thy modes t
   949          THEN print_tac "after sidecond:"
   950          THEN prove_prems [] vs' rps')
   951     in (prove_match thy out_ts)
   952         THEN rest_tac
   953     end;
   954   val prems_tac = prove_prems in_ts' param_vs ps
   955 in
   956   rtac @{thm bindI} 1
   957   THEN rtac @{thm singleI} 1
   958   THEN prems_tac
   959 end;
   960 
   (* select_sup n i: list of rule applications selecting alternative i out of
      n: (i - 1) times supI2, followed by supI1 unless n has been counted down
      to 1 as well, in which case nothing follows. *)
   961 fun select_sup n i =
   962   if i = 1 then (if n = 1 then [] else [rtac @{thm supI1}])
   963   else (rtac @{thm supI2}) :: select_sup (n - 1) (i - 1);
   964 
   (* get_nparams thy s: number of parameters of predicate s.
      An entry in the #nparams table of IndCodegenData takes precedence.
      Otherwise, if the inductive package knows s, the number of parameters of
      its raw induction rule is used; if it does not, the result is 0.
      The table is consulted once with Symtab.lookup instead of
      Symtab.defined followed by a second lookup. *)
   965 fun get_nparams thy s = let
   966     val _ = tracing ("get_nparams: " ^ s)
   967   in
   968   case Symtab.lookup (#nparams (IndCodegenData.get thy)) s of
   969     SOME n => n
   970   | NONE =>
   971     (case try (InductivePackage.the_inductive (ProofContext.init thy)) s of
   972       SOME info => info |> snd |> #raw_induct |> Thm.unvarify
   973         |> InductivePackage.params_of |> length
   974     | NONE => 0 (* default value *))
   975   end
   976 
   (* local alias for InductiveSetPackage.codegen_preproc *)
   977 val ind_set_codegen_preproc = InductiveSetPackage.codegen_preproc;
   978 
   (* pred_elim thy predname: elimination rule to use for predname.
      An entry in the #elim_rules table of IndCodegenData takes precedence.
      Otherwise the rule is taken from the #elims of the inductive package
      result, at the position of predname within its #names.
      The table is consulted once with Symtab.lookup instead of
      Symtab.defined followed by a second lookup. *)
   979 fun pred_elim thy predname =
   980   case Symtab.lookup (#elim_rules (IndCodegenData.get thy)) predname of
   981     SOME elim => elim
   982   | NONE =>
   983     (let
   984       val ind_result = InductivePackage.the_inductive (ProofContext.init thy) predname
   985       val index = find_index (fn s => s = predname) (#names (fst ind_result))
   986     in nth (#elims (snd ind_result)) index end)
   987 
   (* prove_one_direction thy all_vs param_vs modes clauses ((pred, T), mode):
      tactic for the first direction of the equivalence proof.  It eliminates
      with the stored elimination rule of the compiled function, then with the
      (preprocessed) case rule of the predicate, selects for clause i the i-th
      alternative via select_sup on subgoal i, and finally runs prove_clause
      for every clause.
      The function elimination rule is obtained through the elim_rule helper
      of this file, which performs the same table lookup that used to be
      spelled out here. *)
   988 fun prove_one_direction thy all_vs param_vs modes clauses ((pred, T), mode) = let
   989   val mode_elim = elim_rule thy pred mode
   990 (*  val ind_result = InductivePackage.the_inductive (ProofContext.init thy) pred
   991   val index = find_index (fn s => s = pred) (#names (fst ind_result))
   992   val (_, T) = dest_Const (nth (#preds (snd ind_result)) index) *)
         (* number of non-parameter arguments of pred *)
   993   val nargs = length (binder_types T) - get_nparams thy pred
   994   val pred_case_rule = singleton (ind_set_codegen_preproc thy)
   995     (preprocess_elim thy nargs (pred_elim thy pred))
   996   (* FIXME preprocessor |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}])*)
   997   val _ = tracing ("pred_case_rule " ^ (makestring pred_case_rule))
   998 in
   999   REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
  1000   THEN etac mode_elim 1
  1001   THEN etac pred_case_rule 1
  1002   THEN (EVERY (map
  1003          (fn i => EVERY' (select_sup (length clauses) i) i) 
  1004            (1 upto (length clauses))))
  1005   THEN (EVERY (map (prove_clause thy nargs all_vs param_vs modes mode) clauses))
  1006 end;
  1007 
  1008 (*******************************************************************************************************)
  1009 (* Proof in the other direction ************************************************************************)
  1010 (*******************************************************************************************************)
  1011 
  (* prove_match2 thy out_ts: tactic for the other proof direction.  The tuple
     of out_ts is traversed; for every subterm accepted by is_constructor the
     split_asm rule of its datatype is applied to the assumptions of subgoal
     1, botE is applied (number of case rewrites - 1) times at most, and the
     constructor arguments are treated recursively.  Free variables and
     non-constructor terms are skipped.  At the end one split with
     split_if_asm followed by botE on subgoal 2 is tried (TRY, made
     deterministic). *)
  1012 fun prove_match2 thy out_ts = let
  1013   fun split_term_tac (Free _) = all_tac
  1014     | split_term_tac t =
  1015       if (is_constructor thy t) then let
  1016         val info = DatatypePackage.the_datatype thy ((fst o dest_Type o fastype_of) t)
  1017         val num_of_constrs = length (#case_rewrites info)
  1018         (* special treatment of pairs -- because of fishing *)
  1019         val split_rules = case (fst o dest_Type o fastype_of) t of
  1020           "*" => [@{thm prod.split_asm}] 
  1021           | _ => PureThy.get_thms thy (((fst o dest_Type o fastype_of) t) ^ ".split_asm")
  1022         val (_, ts) = strip_comb t
  1023       in
  1024         print_tac ("splitting with t = " ^ (makestring t))
  1025         THEN (Splitter.split_asm_tac split_rules 1)
  1026 (*        THEN (Simplifier.asm_full_simp_tac HOL_basic_ss 1)
  1027           THEN (DETERM (TRY (etac @{thm Pair_inject} 1))) *)
  1028         THEN (REPEAT_DETERM_N (num_of_constrs - 1) (etac @{thm botE} 1 ORELSE etac @{thm botE} 2))
  1029         THEN (EVERY (map split_term_tac ts))
  1030       end
  1031     else all_tac
  1032   in
  1033     split_term_tac (mk_tuple out_ts)
  1034     THEN (DETERM (TRY ((Splitter.split_asm_tac [@{thm "split_if_asm"}] 1) THEN (etac @{thm botE} 2))))
  1035   end
  1036 
  1037 (* VERY LARGE SIMILARITY to function prove_param
  1038 -- join both functions
  1039 *) 
  (* prove_param2 thy modes (m, t): counterpart of prove_param for the other
     proof direction.  Differences visible in the code: the definition of a
     constant head is unfolded with full_simp_tac instead of simp_tac, and
     there is no final solving by assumption.  The case on the head has
     clauses for Const and Free only. *)
  1040 fun prove_param2 thy modes (NONE, t) = all_tac 
  1041   | prove_param2 thy modes (m as SOME (Mode (mode, is, ms)), t) = let
  1042     val  (f, args) = strip_comb t
  1043     val (params, _) = chop (length ms) args
  1044     val f_tac = case f of
  1045         Const (name, T) => full_simp_tac (HOL_basic_ss addsimps 
  1046            @{thm eval_pred}::function_definition thy name mode::[]) 1
  1047       | Free _ => all_tac
  1048   in  
  1049     print_tac "before simplification in prove_args:"
  1050     THEN debug_tac ("function : " ^ (makestring f) ^ " - mode" ^ (makestring mode))
  1051     THEN f_tac
  1052     THEN print_tac "after simplification in prove_args"
  1053     (* work with parameter arguments *)
  1054     THEN (EVERY (map (prove_param2 thy modes) (ms ~~ params)))
  1055   end
  1056 
  (* prove_expr2 thy modes (SOME (Mode (mode, is, ms)), t): elimination
     counterpart of prove_expr.
     - head of t is a constant listed in modes: eliminate with bindE, unfold
       split_paired_all as long as it changes the state, eliminate with the
       stored elimination rule for that constant and mode, then run
       prove_param2 on the arguments paired with ms (all of args is zipped
       with ms here, without chopping to length ms first);
     - head is a constant not listed in modes: abort with an error;
     - head is not a constant: only eliminate with bindE.
     Any first argument of another shape aborts with
     "Prove expr2 not implemented". *)
  1057 fun prove_expr2 thy modes (SOME (Mode (mode, is, ms)), t) = 
  1058   (case strip_comb t of
  1059     (Const (name, T), args) =>
  1060       if AList.defined op = modes name then
  1061         etac @{thm bindE} 1
  1062         THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})))
  1063         THEN (etac (elim_rule thy name mode) 1)
  1064         THEN (EVERY (map (prove_param2 thy modes) (ms ~~ args)))
  1065       else error "Prove expr2 if case not implemented"
  1066     | _ => etac @{thm bindE} 1)
  1067   | prove_expr2 _ _ _ = error "Prove expr2 not implemented"
  1068 
  (* prove_sidecond2 thy modes t: tactic for a side condition t in the other
     proof direction.  Collects the constants in t that are listed in modes
     (the search does not descend below such a constant), fetches for each of
     them the function definition for the mode with no parameter modes and all
     argument positions 1..(number of binder types), and simplifies subgoal 1
     with these definitions plus eval_pred using full_simp_tac.
     The trace line is labelled with this function's own name so that debug
     output of the two proof directions can be told apart. *)
  1069 fun prove_sidecond2 thy modes t = let
  1070   val _ = tracing ("prove_sidecond2:" ^ (makestring t))
         (* accumulates (name, type) of mode-carrying constants in front of nameTs *)
  1071   fun preds_of t nameTs = case strip_comb t of 
  1072     (Const (name, T), args) =>
  1073       if AList.defined (op =) modes name then (name, T) :: nameTs
  1074         else fold preds_of args nameTs
  1075     | _ => nameTs
  1076   val preds = preds_of t []
  1077   val _ = tracing ("preds: " ^ (makestring preds))
  1078   val defs = map
  1079     (fn (pred, T) => function_definition thy pred ([], (1 upto (length (binder_types T)))))
  1080       preds
  1081   in
  1082    (* only simplify the one assumption *)
  1083    full_simp_tac (HOL_basic_ss' addsimps @{thm eval_pred} :: defs) 1 
  1084    (* need better control here! *)
  1085    THEN print_tac "after sidecond2 simplification"
  1086    end
  1087   
  (* prove_clause2 thy all_vs param_vs modes (iss, is) (ts, ps) pred i:
     tactic for clause number i (counted from 1) of pred in the other proof
     direction.  After eliminating with bindE and singleE' the premises are
     processed one at a time by prove_prems2, in the order chosen by
     select_mode_prem; when none is left the i-th introduction rule of pred
     (preprocessed) is applied and its premises are closed by assumption or
     simplification. *)
  1088 fun prove_clause2 thy all_vs param_vs modes (iss, is) (ts, ps) pred i = let
         (* modes extended by one entry per parameter variable that has SOME mode *)
  1089   val modes' = modes @ List.mapPartial
  1090    (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
  1091      (param_vs ~~ iss);
         (* terms rejected by is_constrt are replaced by a fresh variable; the
            variable name and the equation  variable = term  are accumulated *)
  1092   fun check_constrt ((names, eqs), t) =
  1093       if is_constrt thy t then ((names, eqs), t) else
  1094         let
  1095           val s = Name.variant names "x";
  1096           val v = Free (s, fastype_of t)
  1097         in ((s::names, HOLogic.mk_eq (v, t)::eqs), v) end;
  1098   val pred_intro_rule = nth (pred_intros thy pred) (i - 1)
  1099     |> preprocess_intro thy
  1100     |> (fn thm => hd (ind_set_codegen_preproc thy [thm]))
  1101     (* FIXME preprocess |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]) *)
  1102   val (in_ts, clause_out_ts) = get_args is ts;
         (* only in_ts' is used below; all_vs' and eqs are not referenced again *)
  1103   val ((all_vs', eqs), in_ts') =
  1104       (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts);
  1105   fun prove_prems2 out_ts vs [] =
  1106     print_tac "before prove_match2 - last call:"
  1107     THEN prove_match2 thy out_ts
  1108     THEN print_tac "after prove_match2 - last call:"
  1109     THEN (etac @{thm singleE} 1)
  1110     THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1))
  1111     THEN (asm_full_simp_tac HOL_basic_ss' 1)
  1112     THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1))
  1113     THEN (asm_full_simp_tac HOL_basic_ss' 1)
  1114     THEN SOLVED (print_tac "state before applying intro rule:"
  1115       THEN (rtac pred_intro_rule 1)
  1116       (* How to handle equality correctly? *)
  1117       THEN (print_tac "state before assumption matching")
  1118       THEN (REPEAT (atac 1 ORELSE 
  1119          (CHANGED (asm_full_simp_tac HOL_basic_ss' 1)
  1120           THEN print_tac "state after simp_tac:"))))
  1121   | prove_prems2 out_ts vs ps = let
  1122       val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
             (* refutable binding: raises Bind unless select_mode_prem returns
                SOME (p, SOME (Mode ...)) *)
  1123       val SOME (p, mode as SOME (Mode ((iss, js), _, param_modes))) =
  1124         select_mode_prem thy modes' vs' ps;
  1125       val ps' = filter_out (equal p) ps;
  1126       val rest_tac = (case p of Prem (us, t) =>
  1127           let
  1128             val (in_ts, out_ts''') = get_args js us
  1129             val rec_tac = prove_prems2 out_ts''' vs' ps'
  1130           in
  1131             (prove_expr2 thy modes (mode, t)) THEN rec_tac
  1132           end
  1133         | Negprem (us, t) =>
  1134           let
  1135             val (in_ts, out_ts''') = get_args js us
  1136             val rec_tac = prove_prems2 out_ts''' vs' ps'
  1137             val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
  1138             val (_, params) = strip_comb t
  1139           in
  1140             print_tac "before neg prem 2"
  1141             THEN etac @{thm bindE} 1
  1142             THEN (if is_some name then
  1143                 full_simp_tac (HOL_basic_ss addsimps [function_definition thy (the name) (iss, js)]) 1 
  1144                 THEN etac @{thm not_predE} 1
  1145                 THEN (EVERY (map (prove_param2 thy modes) (param_modes ~~ params)))
  1146               else
  1147                 etac @{thm not_predE'} 1)
  1148             THEN rec_tac
  1149           end 
  1150         | Sidecond t =>
  1151             etac @{thm bindE} 1
  1152             THEN etac @{thm if_predE} 1
  1153             THEN prove_sidecond2 thy modes t 
  1154             THEN prove_prems2 [] vs' ps')
  1155     in print_tac "before prove_match2:"
  1156        THEN prove_match2 thy out_ts
  1157        THEN print_tac "after prove_match2:"
  1158        THEN rest_tac
  1159     end;
  1160   val prems_tac = prove_prems2 in_ts' param_vs ps 
  1161 in
  1162   print_tac "starting prove_clause2"
  1163   THEN etac @{thm bindE} 1
  1164   THEN (etac @{thm singleE'} 1)
  1165   THEN (TRY (etac @{thm Pair_inject} 1))
  1166   THEN print_tac "after singleE':"
  1167   THEN prems_tac
  1168 end;
  1169  
  (* prove_other_direction thy all_vs param_vs modes clauses (pred, mode):
     tactic for the second direction of the equivalence proof.  It optionally
     applies unit.induct, unfolds split_paired_all as long as that changes the
     state, applies the stored introduction rule of the compiled function and
     then handles the clauses in order.  For every clause except the last one
     a supE elimination precedes prove_clause2. *)
  1170 fun prove_other_direction thy all_vs param_vs modes clauses (pred, mode) = let
         (* local helper; i is the clause position counted from 1 *)
  1171   fun prove_clause (clause, i) =
  1172     (if i < length clauses then etac @{thm supE} 1 else all_tac)
  1173     THEN (prove_clause2 thy all_vs param_vs modes mode clause pred i)
  1174 in
  1175   (DETERM (TRY (rtac @{thm unit.induct} 1)))
  1176    THEN (REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all})))
  1177    THEN (rtac (intro_rule thy pred mode) 1)
  1178    THEN (EVERY (map prove_clause (clauses ~~ (1 upto (length clauses)))))
  1179 end;
  1180 
  (* prove_pred thy all_vs param_vs modes clauses (((pred, T), mode), t):
     proves the proposition t with Goal.prove, fixing the names of all free
     variables occurring in t.  The clauses of pred are looked up in the
     association list clauses (a missing entry makes `the` raise).
     If the reference do_proofs is set, the proof applies pred_iffI and then
     runs prove_one_direction followed by prove_other_direction; otherwise
     mycheat_tac is used in place of a proof. *)
  1181 fun prove_pred thy all_vs param_vs modes clauses (((pred, T), mode), t) = let
  1182   val ctxt = ProofContext.init thy
  1183   val clauses' = the (AList.lookup (op =) clauses pred)
  1184 in
  1185   Goal.prove ctxt (Term.fold_aterms (fn Free (x, _) => insert (op =) x | _ => I) t []) [] t
  1186     (if !do_proofs then
  1187       (fn _ =>
  1188       rtac @{thm pred_iffI} 1
  1189       THEN prove_one_direction thy all_vs param_vs modes clauses' ((pred, T), mode)
  1190       THEN print_tac "proved one direction"
  1191       THEN prove_other_direction thy all_vs param_vs modes clauses' (pred, mode)
  1192       THEN print_tac "proved other direction")
  1193      else (fn _ => mycheat_tac thy 1))
  1194 end;
  1195 
  (* prove_preds thy all_vs param_vs modes clauses pmts: runs prove_pred with
     the same first five arguments on every element of pmts, in list order *)
  1196 fun prove_preds thy all_vs param_vs modes clauses pmts =
  1197   let val prove_one = prove_pred thy all_vs param_vs modes clauses in map prove_one pmts end
  1198 
  1199 (* look for other place where this functionality was used before *)
  (* strip_intro_concl intro nparams: takes the conclusion of the rule term
     intro, drops its outermost function symbol (presumably Trueprop -- the
     pattern accepts any application), and returns the head of the remaining
     term together with its arguments split into the first nparams
     (parameters) and the rest.  The binding  _ $ u  is refutable: a
     conclusion that is not an application raises Bind. *)
  1200 fun strip_intro_concl intro nparams = let
  1201   val _ $ u = Logic.strip_imp_concl intro
  1202   val (pred, all_args) = strip_comb u
  1203   val (params, args) = chop nparams all_args
  1204 in (pred, (params, args)) end
  1205 
  1206 (* setup for alternative introduction and elimination rules *)
  1207 
  (* add_intro_thm thm thy: registers thm as an alternative introduction rule.
     The key is the name of the constant heading the conclusion of thm;
     the rule is inserted into that key's list unless a rule equal to it under
     Thm.eq_thm is already there (Symtab.insert_list). *)
  1208 fun add_intro_thm thm thy = let
  1209    val (pred, _) = dest_Const (fst (strip_intro_concl (prop_of thm) 0))
  1210  in map_intro_rules (Symtab.insert_list Thm.eq_thm (pred, thm)) thy end
  1211 
  (* add_elim_thm thm thy: registers thm as the alternative elimination rule.
     The key is the name of the constant heading the first premise of thm;
     Symtab.update is used, so an existing entry for that key is replaced. *)
  1212 fun add_elim_thm thm thy = let
  1213     val (pred, _) = dest_Const (fst 
  1214       (strip_comb (HOLogic.dest_Trueprop (hd (prems_of thm)))))
  1215   in map_elim_rules (Symtab.update (pred, thm)) thy end
  1216 
  1217 
  1218 (* special case: inductive predicate with no clauses *)
  (* noclause (predname, T) thy: builds and proves two rules for predname
     applied to fresh variables x1..xn (one per binder type of T):
       intro:  False ==> predname x1 .. xn
       elim:   predname x1 .. xn ==> (False ==> P) ==> P
     The intro rule is proved with FalseE, the elim rule with the elimination
     rule returned by pred_elim.  Both are then registered through
     add_intro_thm and add_elim_thm. *)
  1219 fun noclause (predname, T) thy = let
  1220   val Ts = binder_types T
  1221   val names = Name.variant_list []
  1222         (map (fn i => "x" ^ (string_of_int i)) (1 upto (length Ts)))
  1223   val vs = map Free (names ~~ Ts)
  1224   val clausehd =  HOLogic.mk_Trueprop (list_comb(Const (predname, T), vs))
  1225   val intro_t = Logic.mk_implies (@{prop False}, clausehd)
  1226   val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))
  1227   val elim_t = Logic.list_implies ([clausehd, Logic.mk_implies (@{prop False}, P)], P)
  1228   val intro_thm = Goal.prove (ProofContext.init thy) names [] intro_t
  1229         (fn {...} => etac @{thm FalseE} 1)
  1230   val elim_thm = Goal.prove (ProofContext.init thy) ("P" :: names) [] elim_t
  1231         (fn {...} => etac (pred_elim thy predname) 1) 
  1232 in
  1233   add_intro_thm intro_thm thy
  1234   |> add_elim_thm elim_thm
  1235 end
  1236 
  1237 (*************************************************************************************)
  1238 (* main function *********************************************************************)
  1239 (*************************************************************************************)
  1240 
  1241 fun create_def_equation' ind_name (mode : mode option) thy =
  1242 let
  1243   val _ = tracing ("starting create_def_equation' with " ^ ind_name)
  1244   val (prednames, preds) = 
  1245     case (try (InductivePackage.the_inductive (ProofContext.init thy)) ind_name) of
  1246       SOME info => let val preds = info |> snd |> #preds
  1247         in (map (fst o dest_Const) preds, map ((apsnd Logic.unvarifyT) o dest_Const) preds) end
  1248     | NONE => let
  1249         val pred = Symtab.lookup (#intro_rules (IndCodegenData.get thy)) ind_name
  1250           |> the |> hd |> prop_of
  1251           |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop |> strip_comb
  1252           |> fst |>  dest_Const |> apsnd Logic.unvarifyT
  1253        in ([ind_name], [pred]) end
  1254   val thy' = fold (fn pred as (predname, T) => fn thy =>
  1255     if null (pred_intros thy predname) then noclause pred thy else thy) preds thy
  1256   val intrs = map (preprocess_intro thy') (maps (pred_intros thy') prednames)
  1257     |> ind_set_codegen_preproc thy' (*FIXME preprocessor
  1258     |> map (Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]))*)
  1259     |> map (Logic.unvarify o prop_of)
  1260   val _ = tracing ("preprocessed intro rules:" ^ (makestring (map (cterm_of thy') intrs)))
  1261   val name_of_calls = get_name_of_ind_calls_of_clauses thy' prednames intrs 
  1262   val _ = tracing ("calling preds: " ^ makestring name_of_calls)
  1263   val _ = tracing "starting recursive compilations"
  1264   fun rec_call name thy = 
  1265     (*FIXME use member instead of infix mem*)
  1266     if not (name mem (Symtab.keys (#modes (IndCodegenData.get thy)))) then
  1267       create_def_equation name thy else thy
  1268   val thy'' = fold rec_call name_of_calls thy'
  1269   val _ = tracing "returning from recursive calls"
  1270   val _ = tracing "starting mode inference"
  1271   val extra_modes = Symtab.dest (#modes (IndCodegenData.get thy''))
  1272   val nparams = get_nparams thy'' ind_name
  1273   val _ $ u = Logic.strip_imp_concl (hd intrs);
  1274   val params = List.take (snd (strip_comb u), nparams);
  1275   val param_vs = maps term_vs params
  1276   val all_vs = terms_vs intrs
  1277   fun dest_prem t =
  1278       (case strip_comb t of
  1279         (v as Free _, ts) => if v mem params then Prem (ts, v) else Sidecond t
  1280       | (c as Const (@{const_name Not}, _), [t]) => (case dest_prem t of
  1281           Prem (ts, t) => Negprem (ts, t)
  1282         | Negprem _ => error ("Double negation not allowed in premise: " ^ (makestring (c $ t))) 
  1283         | Sidecond t => Sidecond (c $ t))
  1284       | (c as Const (s, _), ts) =>
  1285         if is_ind_pred thy'' s then
  1286           let val (ts1, ts2) = chop (get_nparams thy'' s) ts
  1287           in Prem (ts2, list_comb (c, ts1)) end
  1288         else Sidecond t
  1289       | _ => Sidecond t)
  1290   fun add_clause intr (clauses, arities) =
  1291   let
  1292     val _ $ t = Logic.strip_imp_concl intr;
  1293     val (Const (name, T), ts) = strip_comb t;
  1294     val (ts1, ts2) = chop nparams ts;
  1295     val prems = map (dest_prem o HOLogic.dest_Trueprop) (Logic.strip_imp_prems intr);
  1296     val (Ts, Us) = chop nparams (binder_types T)
  1297   in
  1298     (AList.update op = (name, these (AList.lookup op = clauses name) @
  1299       [(ts2, prems)]) clauses,
  1300      AList.update op = (name, (map (fn U => (case strip_type U of
  1301                  (Rs as _ :: _, Type ("bool", [])) => SOME (length Rs)
  1302                | _ => NONE)) Ts,
  1303              length Us)) arities)
  1304   end;
  1305   val (clauses, arities) = fold add_clause intrs ([], []);
  1306   val modes = infer_modes thy'' extra_modes arities param_vs clauses
  1307   val _ = print_arities arities;
  1308   val _ = print_modes modes;
  1309   val modes = if (is_some mode) then AList.update (op =) (ind_name, [the mode]) modes else modes
  1310   val _ = print_modes modes
  1311   val thy''' = fold (create_definitions preds nparams) modes thy''
  1312     |> map_modes (fold Symtab.update_new modes)
  1313   val clauses' = map (fn (s, cls) => (s, (the (AList.lookup (op =) preds s), cls))) clauses
  1314   val _ = tracing "compiling predicates..."
  1315   val ts = compile_preds thy''' all_vs param_vs (extra_modes @ modes) clauses'
  1316   val _ = tracing "returned term from compile_preds"
  1317   val pred_mode = maps (fn (s, (T, _)) => map (pair (s, T)) ((the o AList.lookup (op =) modes) s)) clauses'
  1318   val _ = tracing "starting proof"
  1319   val result_thms = prove_preds thy''' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts))
  1320   val (_, thy'''') = yield_singleton PureThy.add_thmss
  1321     ((Binding.name (Long_Name.base_name ind_name ^ "_codegen" (*FIXME other suffix*)), result_thms),
  1322       [Attrib.attribute_i thy''' Code.add_default_eqn_attrib]) thy'''
  1323 in
  1324   thy''''
  1325 end
  (* Variant of create_def_equation' without a user-supplied mode:
     NONE is passed as the mode argument. *)
  and create_def_equation ind_name thy =
    thy |> create_def_equation' ind_name NONE
  1327 
  (* Stores nparams under the key pred, by applying Symtab.update through
     map_nparams to the given theory. *)
  fun set_nparams (pred, nparams) thy =
    thy |> map_nparams (Symtab.update (pred, nparams))
  1329 
  (* Debug helper: traces every predicate that has an entry in the
     intro_rules or elim_rules table of IndCodegenData, followed by the
     rules stored for it.  The theory is returned unchanged.
     All output goes through tracing, which at the top of this file is
     gated on Toplevel.debug. *)
  fun print_alternative_rules thy =
    let
      val d = IndCodegenData.get thy
      val preds = (Symtab.keys (#intro_rules d)) union (Symtab.keys (#elim_rules d))
      val _ = tracing ("preds: " ^ (makestring preds))
      (* Traces the name, the intro rules and the cases rule of one predicate. *)
      fun print pred =
        (tracing ("predicate: " ^ pred);
         tracing ("introrules: ");
         (* the stored list is reversed before it is printed *)
         List.app (fn thm => tracing (makestring thm))
           (rev (Symtab.lookup_list (#intro_rules d) pred));
         tracing ("casesrule: ");
         tracing (makestring (Symtab.lookup (#elim_rules d) pred)))
      (* List.app rather than map: print is run for its effect only, so no
         result list needs to be built. *)
      val _ = List.app print preds
    in thy end;
  1344   
  (* Turns f into an attribute.  For the theorem the attribute is applied
     to, Context.mapping receives (f thm) as its first function and the
     identity I as its second.
     NOTE(review): presumably the first is used for theory contexts and the
     second for proof contexts -- confirm against Context.mapping. *)
  fun attrib f =
    let
      fun declare thm = Context.mapping (f thm) I
    in Thm.declaration_attribute declare end

  (* Attribute built from add_intro_thm. *)
  val code_ind_intros_attrib = attrib add_intro_thm

  (* Attribute built from add_elim_thm. *)
  val code_ind_cases_attrib = attrib add_elim_thm
  1350 
  (* Theory setup: registers the attributes code_ind_intros and
     code_ind_cases, in that order. *)
  val setup =
    let
      val setup_intros =
        Attrib.setup @{binding code_ind_intros} (Scan.succeed code_ind_intros_attrib)
          "adding alternative introduction rules for code generation of inductive predicates"
      val setup_cases =
        Attrib.setup @{binding code_ind_cases} (Scan.succeed code_ind_cases_attrib)
          "adding alternative elimination rules for code generation of inductive predicates"
    in
      setup_intros #> setup_cases
    end;
  1356 
  1357 end;
  1358 
  (* Top-level convenience wrapper: interns name as a constant name in thy
     via Sign.intern_const and passes the result to
     Predicate_Compile.create_def_equation. *)
  fun pred_compile name thy =
    let
      val const_name = Sign.intern_const thy name
    in
      Predicate_Compile.create_def_equation const_name thy
    end;