src/HOL/ex/predicate_compile.ML
author haftmann
Thu, 14 May 2009 15:09:48 +0200
changeset 31156 90fed3d4430f
parent 31124 58bc773c60e2
child 31170 c6efe82fc652
permissions -rw-r--r--
merged module code_unit.ML into code.ML
haftmann@30374
     1
(* Author: Lukas Bulwahn
haftmann@30374
     2
haftmann@30374
     3
(Prototype of) A compiler from predicates specified by intro/elim rules
haftmann@30374
     4
to equations.
haftmann@30374
     5
*)
haftmann@30374
     6
haftmann@30374
     7
signature PREDICATE_COMPILE =
sig
  (* a mode pairs one optional list of argument positions per predicate parameter
     with the list of input argument positions of the predicate itself *)
  type mode = int list option list * int list

  (* mode bookkeeping stored in the theory *)
  val modes_of: theory -> string -> mode list
  val modename_of: theory -> string -> mode -> string
  val get_nparams: theory -> string -> int

  (* rules attached to predicates and their compiled functions *)
  val pred_intros: theory -> string -> thm list
  val intro_rule: theory -> string -> mode -> thm
  val elim_rule: theory -> string -> mode -> thm
  val strip_intro_concl: term -> int -> term * (term list * term list)
  val prove_equation: string -> mode option -> theory -> theory

  (* setup and commands *)
  val setup: theory -> theory
  val code_pred: string -> Proof.context -> Proof.state
  val code_pred_cmd: string -> Proof.context -> Proof.state
  val print_alternative_rules: theory -> theory (*FIXME diagnostic command?*)

  (* global switch, initially true *)
  val do_proofs: bool ref
end;
haftmann@30374
    24
haftmann@31124
    25
structure Predicate_Compile : PREDICATE_COMPILE =
haftmann@30374
    26
struct
haftmann@30374
    27
haftmann@30972
    28
(** auxiliary **)
haftmann@30972
    29
haftmann@30972
    30
(* debug stuff *)
haftmann@30972
    31
haftmann@30972
    32
(* trace output, emitted only while Toplevel.debug is set *)
fun tracing msg =
  if ! Toplevel.debug then Output.tracing msg else ();

(* Tactical.print_tac while Toplevel.debug is set, otherwise the identity tactic *)
fun print_tac label =
  if ! Toplevel.debug then Tactical.print_tac label else Seq.single;

(* identity tactic that traces msg when it is applied to a goal state *)
fun debug_tac msg st = (tracing msg; Seq.single st);

val do_proofs = ref true;
haftmann@30972
    38
haftmann@30972
    39
haftmann@30972
    40
(** fundamentals **)
haftmann@30972
    41
haftmann@30972
    42
(* syntactic operations *)
haftmann@30972
    43
haftmann@30972
    44
(* one HOL equation "x = t" per term t in ts, where x is a Free typed like t *)
fun mk_eq (x, ts) =
  map (fn t => HOLogic.mk_eq (Free (x, fastype_of t), t)) ts;
haftmann@30972
    49
haftmann@30972
    50
(* right-nested product type of the given types; unit for the empty list *)
fun mk_tupleT Ts =
  (case Ts of
    [] => HOLogic.unitT
  | _ => foldr1 HOLogic.mk_prodT Ts);

(* right-nested tuple of the given terms; the unit value for the empty list *)
fun mk_tuple ts =
  (case ts of
    [] => HOLogic.unit
  | _ => foldr1 HOLogic.mk_prod ts);
haftmann@30972
    55
haftmann@30972
    56
(* components of a right-nested tuple term; the unit value yields [] and any
   non-pair term yields the singleton list *)
fun dest_tuple t =
  (case t of
    Const (@{const_name Product_Type.Unity}, _) => []
  | Const (@{const_name Pair}, _) $ t1 $ t2 => t1 :: dest_tuple t2
  | _ => [t])
haftmann@30972
    59
haftmann@30972
    60
(* the type "T Predicate.pred" *)
fun mk_pred_enumT T = Type ("Predicate.pred", [T])

(* inverse of mk_pred_enumT; raises TYPE for any other type *)
fun dest_pred_enumT T =
  (case T of
    Type ("Predicate.pred", [elemT]) => elemT
  | _ => raise TYPE ("dest_pred_enumT", [T], []));
haftmann@30972
    64
haftmann@30972
    65
(* apply Predicate.Pred to f; the type of f must be a function type
   (otherwise the pattern binding below fails) *)
fun mk_Enum f =
  let
    val funT = fastype_of f;
    val Type ("fun", [argT, _]) = funT;
  in Const (@{const_name Predicate.Pred}, funT --> mk_pred_enumT argT) $ f end;
haftmann@30972
    70
haftmann@30972
    71
(* the term "Predicate.eval f x"; the constant's type is derived from the type of x *)
fun mk_Eval (f, x) =
  let
    val argT = fastype_of x;
    val evalT = mk_pred_enumT argT --> argT --> HOLogic.boolT;
  in Const (@{const_name Predicate.eval}, evalT) $ f $ x end;
haftmann@30972
    76
haftmann@30972
    77
(* the constant Orderings.bot at type "T Predicate.pred" *)
fun mk_empty T = Const (@{const_name Orderings.bot}, mk_pred_enumT T);

(* the term "Predicate.single t" *)
fun mk_single t =
  let val singleT = fastype_of t --> mk_pred_enumT (fastype_of t)
  in Const (@{const_name Predicate.single}, singleT) $ t end;
haftmann@30972
    82
haftmann@30972
    83
(* the term "Predicate.bind x f"; the result type is the range type of f,
   so f must have a function type (otherwise the pattern binding fails) *)
fun mk_bind (x, f) =
  let
    val funT = fastype_of f;
    val Type ("fun", [_, resT]) = funT;
    val bindT = fastype_of x --> funT --> resT;
  in Const (@{const_name Predicate.bind}, bindT) $ x $ f end;
haftmann@30972
    88
haftmann@30972
    89
(* binary application of the constant "sup" *)
fun mk_sup (t, u) = HOLogic.mk_binop @{const_name sup} (t, u);

(* the term "Predicate.if_pred cond", of type "unit Predicate.pred" *)
fun mk_if_predenum cond =
  let val ifT = HOLogic.boolT --> mk_pred_enumT HOLogic.unitT
  in Const (@{const_name Predicate.if_pred}, ifT) $ cond end;

(* the term "Predicate.not_pred t" on "unit Predicate.pred" *)
fun mk_not_pred t =
  let val unit_predT = mk_pred_enumT HOLogic.unitT
  in Const (@{const_name Predicate.not_pred}, unit_predT --> unit_predT) $ t end
haftmann@30972
    96
haftmann@30972
    97
haftmann@30972
    98
(* data structures *)
haftmann@30972
    99
haftmann@30972
   100
(* optional position lists for the parameters, position list for the arguments *)
type mode = int list option list * int list;

(* total order on modes, built componentwise from the integer order *)
val mode_ord =
  let val positions_ord = list_ord int_ord
  in prod_ord (list_ord (option_ord positions_ord)) positions_ord end;
haftmann@30972
   103
haftmann@30374
   104
(* tables keyed by a predicate name together with one of its modes *)
structure PredModetab = TableFun
(
  type key = string * mode;
  val ord = prod_ord fast_string_ord mode_ord;
);
haftmann@30374
   108
haftmann@30374
   109
haftmann@30972
   110
(*FIXME scrap boilerplate*)
haftmann@30972
   111
haftmann@30374
   112
structure IndCodegenData = TheoryDataFun
(
  type T = {names : string PredModetab.table,
            modes : mode list Symtab.table,
            function_defs : Thm.thm Symtab.table,
            function_intros : Thm.thm Symtab.table,
            function_elims : Thm.thm Symtab.table,
            intro_rules : Thm.thm list Symtab.table,
            elim_rules : Thm.thm Symtab.table,
            nparams : int Symtab.table
           }; (*FIXME: better group tables according to key*)
      (* names: map from inductive predicate and mode to function name (string).
         modes: map from inductive predicates to modes
         function_defs: map from function name to definition
         function_intros: map from function name to intro rule
         function_elims: map from function name to elim rule
         intro_rules: map from inductive predicate to alternative intro rules
         elim_rules: map from inductive predicate to alternative elimination rule
         nparams: map from const name to number of parameters (* assuming there exist intro and elimination rules *)
       *)
  val empty = {names = PredModetab.empty,
               modes = Symtab.empty,
               function_defs = Symtab.empty,
               function_intros = Symtab.empty,
               function_elims = Symtab.empty,
               intro_rules = Symtab.empty,
               elim_rules = Symtab.empty,
               nparams = Symtab.empty};
  val copy = I;
  val extend = I;
  (* Merge the two records table by table. *)
  fun merge _ r =
    let
      (* Equality on theorem lists.  The length test comes first because ~~
         raises an exception on lists of different length, which would escape
         from the table merge instead of just reporting "not equal". *)
      fun eq_thms (ths1, ths2) =
        length ths1 = length ths2 andalso forall Thm.eq_thm (ths1 ~~ ths2);
    in
      {names = PredModetab.merge (op =) (pairself #names r),
       modes = Symtab.merge (op =) (pairself #modes r),
       function_defs = Symtab.merge Thm.eq_thm (pairself #function_defs r),
       function_intros = Symtab.merge Thm.eq_thm (pairself #function_intros r),
       function_elims = Symtab.merge Thm.eq_thm (pairself #function_elims r),
       intro_rules = Symtab.merge eq_thms (pairself #intro_rules r),
       elim_rules = Symtab.merge Thm.eq_thm (pairself #elim_rules r),
       nparams = Symtab.merge (op =) (pairself #nparams r)}
    end;
);
haftmann@30374
   151
haftmann@30374
   152
  (* apply f to the names table of the IndCodegenData record, keeping the other fields *)
  fun map_names f thy =
    IndCodegenData.map
      (fn {names, modes, function_defs, function_intros, function_elims,
           intro_rules, elim_rules, nparams} =>
        {names = f names, modes = modes, function_defs = function_defs,
         function_intros = function_intros, function_elims = function_elims,
         intro_rules = intro_rules, elim_rules = elim_rules, nparams = nparams}) thy
haftmann@30374
   157
haftmann@30374
   158
  (* apply f to the modes table of the IndCodegenData record, keeping the other fields *)
  fun map_modes f thy =
    IndCodegenData.map
      (fn {names, modes, function_defs, function_intros, function_elims,
           intro_rules, elim_rules, nparams} =>
        {names = names, modes = f modes, function_defs = function_defs,
         function_intros = function_intros, function_elims = function_elims,
         intro_rules = intro_rules, elim_rules = elim_rules, nparams = nparams}) thy
haftmann@30374
   163
haftmann@30374
   164
  (* apply f to the function_defs table of the IndCodegenData record, keeping the other fields *)
  fun map_function_defs f thy =
    IndCodegenData.map
      (fn {names, modes, function_defs, function_intros, function_elims,
           intro_rules, elim_rules, nparams} =>
        {names = names, modes = modes, function_defs = f function_defs,
         function_intros = function_intros, function_elims = function_elims,
         intro_rules = intro_rules, elim_rules = elim_rules, nparams = nparams}) thy
haftmann@30374
   169
  
haftmann@30374
   170
  (* apply f to the function_elims table of the IndCodegenData record, keeping the other fields *)
  fun map_function_elims f thy =
    IndCodegenData.map
      (fn {names, modes, function_defs, function_intros, function_elims,
           intro_rules, elim_rules, nparams} =>
        {names = names, modes = modes, function_defs = function_defs,
         function_intros = function_intros, function_elims = f function_elims,
         intro_rules = intro_rules, elim_rules = elim_rules, nparams = nparams}) thy
haftmann@30374
   175
haftmann@30374
   176
  (* apply f to the function_intros table of the IndCodegenData record, keeping the other fields *)
  fun map_function_intros f thy =
    IndCodegenData.map
      (fn {names, modes, function_defs, function_intros, function_elims,
           intro_rules, elim_rules, nparams} =>
        {names = names, modes = modes, function_defs = function_defs,
         function_intros = f function_intros, function_elims = function_elims,
         intro_rules = intro_rules, elim_rules = elim_rules, nparams = nparams}) thy
haftmann@30374
   181
haftmann@30374
   182
  (* apply f to the intro_rules table of the IndCodegenData record, keeping the other fields *)
  fun map_intro_rules f thy =
    IndCodegenData.map
      (fn {names, modes, function_defs, function_intros, function_elims,
           intro_rules, elim_rules, nparams} =>
        {names = names, modes = modes, function_defs = function_defs,
         function_intros = function_intros, function_elims = function_elims,
         intro_rules = f intro_rules, elim_rules = elim_rules, nparams = nparams}) thy
haftmann@30374
   187
  
haftmann@30374
   188
  (* apply f to the elim_rules table of the IndCodegenData record, keeping the other fields *)
  fun map_elim_rules f thy =
    IndCodegenData.map
      (fn {names, modes, function_defs, function_intros, function_elims,
           intro_rules, elim_rules, nparams} =>
        {names = names, modes = modes, function_defs = function_defs,
         function_intros = function_intros, function_elims = function_elims,
         intro_rules = intro_rules, elim_rules = f elim_rules, nparams = nparams}) thy
haftmann@30374
   193
haftmann@30374
   194
  (* apply f to the nparams table of the IndCodegenData record, keeping the other fields *)
  fun map_nparams f thy =
    IndCodegenData.map
      (fn {names, modes, function_defs, function_intros, function_elims,
           intro_rules, elim_rules, nparams} =>
        {names = names, modes = modes, function_defs = function_defs,
         function_intros = function_intros, function_elims = function_elims,
         intro_rules = intro_rules, elim_rules = elim_rules, nparams = f nparams}) thy
haftmann@30374
   199
haftmann@30374
   200
(* resolves subgoal i with a skip-proof theorem stating the schematic proposition ?A *)
fun mycheat_tac thy i st =
  let val any_prop = SkipProof.make_thm thy (Var (("A", 0), propT))
  in Tactic.rtac any_prop i st end
haftmann@30374
   203
haftmann@30374
   204
(* Lightweight mode analysis **********************************************)
haftmann@30374
   205
haftmann@30374
   206
(**************************************************************************)
haftmann@30374
   207
(* source code from old code generator ************************************)
haftmann@30374
   208
haftmann@30374
   209
(**** check if a term contains only constructor functions ****)
haftmann@30374
   210
haftmann@30374
   211
(* "is_constrt thy" yields a test that holds for terms built only from free
   variables and fully applied datatype constructors of the matching datatype *)
fun is_constrt thy =
  let
    (* (constructor name, (number of arguments, datatype name)) for one descr entry *)
    fun constrs_of (_, (tyname, _, cs)) = map (apsnd (rpair tyname o length)) cs;
    val constr_table =
      flat (maps (map constrs_of o #descr o snd)
        (Symtab.dest (DatatypePackage.get_datatypes thy)));
    fun check t =
      (case strip_comb t of
        (Free _, []) => true
      | (Const (c, T), args) =>
          (case (AList.lookup (op =) constr_table c, body_type T) of
            (SOME (arity, tyname), Type (tyname', _)) =>
              length args = arity andalso tyname = tyname' andalso forall check args
          | _ => false)
      | _ => false);
  in check end;
haftmann@30374
   223
haftmann@30972
   224
(**** check if a type is an equality type (i.e. doesn't contain fun)
haftmann@30972
   225
  FIXME this is only an approximation ****)
haftmann@30374
   226
haftmann@30374
   227
(* true unless the type constructor "fun" occurs somewhere in the type *)
fun is_eqT T =
  (case T of
    Type (tyco, Ts) => tyco <> "fun" andalso forall is_eqT Ts
  | _ => true);
haftmann@30374
   229
haftmann@30374
   230
(**** mode inference ****)
haftmann@30374
   231
haftmann@30374
   232
(* one slot per parameter and a final slot for the arguments, joined by " -> ";
   a slot prints as "X" (no mode) or as a bracketed list of positions *)
fun string_of_mode (iss, is) =
  let
    fun string_of_slot NONE = "X"
      | string_of_slot (SOME js) = enclose "[" "]" (commas (map string_of_int js));
  in space_implode " -> " (map string_of_slot (iss @ [SOME is])) end;
haftmann@30374
   236
haftmann@30972
   237
(* trace the modes of every predicate, one "name: mode, mode, ..." line each *)
fun print_modes modes =
  let
    fun line_of (pred, ms) = pred ^ ": " ^ commas (map string_of_mode ms);
  in tracing ("Inferred modes:\n" ^ cat_lines (map line_of modes)) end;
haftmann@30374
   240
haftmann@30374
   241
(* names of all Frees in a term (with duplicates) *)
fun term_vs tm =
  fold_aterms (fn t => fn names => (case t of Free (x, _) => x :: names | _ => names)) tm [];

(* names of the Frees of a list of terms, duplicates removed *)
fun terms_vs ts = distinct (op =) (maps term_vs ts);

(** collect all Frees in a term as (name, type) pairs (with duplicates!) **)
fun term_vTs tm =
  fold_aterms (fn t => fn vTs => (case t of Free v => v :: vTs | _ => vTs)) tm [];
haftmann@30374
   247
haftmann@30374
   248
(* split ts by position (counted from 1): elements whose position is in "is"
   go to the first list, all others to the second; order is preserved *)
fun get_args is ts =
  let
    fun split _ [] = ([], [])
      | split i (t :: rest) =
          let val place = if i mem is then apfst else apsnd
          in place (cons t) (split (i + 1) rest) end;
  in split 1 ts end
haftmann@30374
   253
haftmann@30972
   254
(*FIXME this function should not be named merge... make it local instead*)
haftmann@30374
   255
(* merge two lists of lists, always taking the longer head first
   (on equal length the head of the first list wins) *)
fun merge xss yss =
  (case (xss, yss) of
    (_, []) => xss
  | ([], _) => yss
  | (xs :: xss', ys :: yss') =>
      if length xs < length ys then ys :: merge xss yss'
      else xs :: merge xss' yss);
haftmann@30374
   259
haftmann@30374
   260
(* all sublists of [i, ..., j], combined via merge; [[]] when i > j *)
fun subsets i j =
  if i > j then [[]]
  else
    let val rest = subsets (i + 1) j
    in merge (map (fn ks => i :: ks) rest) rest end;
haftmann@30374
   264
haftmann@30374
   265
(* cartesian product of two lists, ordered by the first component *)
fun cprod (xs, ys) = maps (fn x => map (pair x) ys) xs;

(* cartesian product of a list of lists: every way of picking one element per list *)
fun cprods xss = foldr (fn (xs, rest) => map (op ::) (cprod (xs, rest))) [[]] xss;
haftmann@30374
   269
haftmann@30972
   270
(*FIXME don't understand why there is another mode type!?*)
(* a mode, a list of argument positions, and one optional hmode per parameter *)
datatype hmode =
  Mode of mode * int list * hmode option list;
haftmann@30374
   272
haftmann@30374
   273
(* possible hmodes of term t with respect to the association list "modes".
   Constants without an entry (and non-atomic heads) get the default mode with all
   argument positions as inputs; a Var or Free head without an entry raises Option. *)
fun modes_of modes t =
  let
    val ks = 1 upto length (binder_types (fastype_of t));
    val default = [Mode (([], ks), ks, [])];

    (* the hmodes obtained from one stored mode m of "name" applied to args *)
    fun instances_of name args (m as (iss, is)) =
      let
        val (param_args, rest_args) =
          if length args < length iss then
            error ("Too few arguments for inductive predicate " ^ name)
          else chop (length iss) args;
        val k = length rest_args;
      in
        (* the k arguments already supplied must be the first k input positions *)
        if is_prefix (op =) (1 upto k) is then
          let
            val is' = map (fn i => i - k) (List.drop (is, k));
            fun param_modes (NONE, _) = [NONE]
              | param_modes (SOME js, arg) =
                  map SOME (filter (fn Mode (_, js', _) => js = js') (modes_of modes arg));
          in
            map (fn pms => Mode (m, is', pms))
              (cprods (map param_modes (iss ~~ param_args)))
          end
        else []
      end;

    fun mk_modes name args =
      Option.map (maps (instances_of name args)) (AList.lookup (op =) modes name);

  in
    (case strip_comb t of
      (Const (name, _), args) => the_default default (mk_modes name args)
    | (Var ((name, _), _), args) => the (mk_modes name args)
    | (Free (name, _), args) => the (mk_modes name args)
    | _ => default)
  end
haftmann@30374
   302
haftmann@30374
   303
(* premises of a clause: a predicate (argument list and head term), a negated
   predicate of the same shape, or a side condition *)
datatype indprem =
    Prem of term list * term
  | Negprem of term list * term
  | Sidecond of term;
haftmann@30374
   304
haftmann@30374
   305
(* first premise in ps (paired with its hmode) that has a usable hmode, given
   that the variables named in vs are already known; NONE if there is none *)
fun select_mode_prem thy modes vs ps =
  let
    fun modes_for t =
      (modes_of modes t handle Option =>
        error ("Bad predicate: " ^ Syntax.string_of_term_global thy t));

    (* positive premise: inputs and non-constructor outputs only mention known
       variables; outputs that must be compared need equality types *)
    fun prem_mode_ok us t (Mode (_, is, _)) =
      let
        val (in_ts, out_ts) = get_args is us;
        val (constr_outs, other_outs) = List.partition (is_constrt thy) out_ts;
        val out_vTs = maps term_vTs constr_outs;
        val dupTs =
          map snd (duplicates (op =) out_vTs) @
            List.mapPartial (AList.lookup (op =) out_vTs) vs;
      in
        terms_vs (in_ts @ other_outs) subset vs andalso
        forall (is_eqT o fastype_of) other_outs andalso
        term_vs t subset vs andalso
        forall is_eqT dupTs
      end;

    (* negated premise: every argument is an input and everything is known *)
    fun negprem_mode_ok us t (Mode (_, is, _)) =
      length us = length is andalso
      terms_vs us subset vs andalso
      term_vs t subset vs;

    fun find_mode (Prem (us, t)) = find_first (prem_mode_ok us t) (modes_for t)
      | find_mode (Negprem (us, t)) = find_first (negprem_mode_ok us t) (modes_for t)
      | find_mode (Sidecond t) =
          if term_vs t subset vs then SOME (Mode (([], []), [], [])) else NONE;
  in
    find_first (is_some o snd) (ps ~~ map find_mode ps)
  end;
haftmann@30374
   331
haftmann@30374
   332
(* does the clause (ts, ps) admit the mode (iss, is)?  The premises are consumed
   one at a time in an order found by select_mode_prem; afterwards every variable
   of the conclusion arguments ts must be known. *)
fun check_mode_clause thy param_vs modes (iss, is) (ts, ps) =
  let
    (* parameters with a mode are treated as predicates of their own *)
    fun param_mode (_, NONE) = NONE
      | param_mode (v, SOME js) = SOME (v, [([], js)]);
    val modes' = modes @ List.mapPartial param_mode (param_vs ~~ iss);

    fun check_mode_prems vs [] = SOME vs
      | check_mode_prems vs ps =
          (case select_mode_prem thy modes' vs ps of
            SOME (p, _) =>
              let
                val vs' = (case p of Prem (us, _) => vs union terms_vs us | _ => vs);
              in check_mode_prems vs' (filter_out (equal p) ps) end
          | NONE => NONE);

    val (constr_in_ts, other_in_ts) =
      List.partition (is_constrt thy) (fst (get_args is ts));
    val dup_in_Ts = map snd (duplicates (op =) (maps term_vTs constr_in_ts));
  in
    forall is_eqT dup_in_Ts andalso
    forall (is_eqT o fastype_of) other_in_ts andalso
    (case check_mode_prems (param_vs union terms_vs constr_in_ts) ps of
      SOME vs => terms_vs ts subset vs
    | NONE => false)
  end;
haftmann@30374
   353
haftmann@30374
   354
(* keep only those modes of predicate p that all of its clauses admit; the first
   violating clause of a discarded mode is traced *)
fun check_modes_pred thy param_vs preds modes (p, ms) =
  let
    val SOME clauses = AList.lookup (op =) preds p;
    fun admissible m =
      (case find_index (not o check_mode_clause thy param_vs modes m) clauses of
        ~1 => true
      | i =>
          (tracing ("Clause " ^ string_of_int (i + 1) ^ " of " ^
            p ^ " violates mode " ^ string_of_mode m);
           false));
  in (p, List.filter admissible ms) end;
haftmann@30374
   362
haftmann@30972
   363
(* iterate f, starting from x, until the value no longer changes *)
fun fixp f (x : (string * mode list) list) =
  (case f x of
    x' => if x = x' then x else fixp f x');
haftmann@30374
   366
haftmann@30374
   367
(* start from every conceivable mode for each predicate in arities and discard
   modes rejected by check_modes_pred until nothing changes any more *)
fun infer_modes thy extra_modes arities param_vs preds =
  let
    fun param_choices NONE = [NONE]
      | param_choices (SOME k') = map SOME (subsets 1 k');
    fun all_modes (s, (ks, k)) =
      (s, cprod (cprods (map param_choices ks), subsets 1 k));
    fun prune modes =
      map (check_modes_pred thy param_vs preds (modes @ extra_modes)) modes;
  in fixp prune (map all_modes arities) end;
haftmann@30374
   373
haftmann@30374
   374
haftmann@30374
   375
(*****************************************************************************************)
haftmann@30374
   376
(**** end of old source code *************************************************************)
haftmann@30374
   377
(*****************************************************************************************)
haftmann@30374
   378
(**** term construction ****)
haftmann@30374
   379
haftmann@30374
   380
(* for simple modes (e.g. parameters) only: better call it param_funT *)
haftmann@30374
   381
(* or even better: remove it and only use funT'_of - some modifications to funT'_of necessary *) 
haftmann@30374
   382
(* type of the function compiled for a simple mode: the argument types at the
   mode's positions map to a pred enumeration over the tuple of the remaining
   argument types; without a mode the type is returned unchanged *)
fun funT_of T mode_opt =
  (case mode_opt of
    NONE => T
  | SOME mode =>
      let val (inTs, outTs) = get_args mode (binder_types T)
      in inTs ---> mk_pred_enumT (mk_tupleT outTs) end);
haftmann@30374
   387
haftmann@30374
   388
(* like funT_of, but for a full mode: the first "length iss" argument types are
   parameters, which are themselves translated when they carry a mode *)
fun funT'_of (iss, is) T =
  let
    val (paramTs, argTs) = chop (length iss) (binder_types T);
    fun param_funT (SOME js) = funT'_of ([], js)
      | param_funT NONE = I;
    val paramTs' = map2 param_funT iss paramTs;
    val (inargTs, outargTs) = get_args is argTs;
  in (paramTs' @ inargTs) ---> mk_pred_enumT (mk_tupleT outargTs) end;
haftmann@30374
   396
haftmann@30374
   397
haftmann@30374
   398
(* a Free for name s: on first use the name itself, which is then registered in
   vs; on later uses a fresh variant, which is recorded under s and added to names *)
fun mk_v (names, vs) s T =
  (case AList.lookup (op =) vs s of
    SOME xs =>
      let
        val s' = Name.variant names s;
        val v = Free (s', T);
      in ((s' :: names, AList.update (op =) (s, v :: xs) vs), v) end
  | NONE => ((names, (s, []) :: vs), Free (s, T)));
haftmann@30374
   407
haftmann@30374
   408
(* thread mk_v through all Frees of a term, left to right; other atoms
   (and abstractions) are left untouched *)
fun distinct_v (nvs, t) =
  (case t of
    Free (s, T) => mk_v nvs s T
  | t1 $ t2 =>
      let
        val (nvs1, t1') = distinct_v (nvs, t1);
        val (nvs2, t2') = distinct_v (nvs1, t2);
      in (nvs2, t1' $ t2') end
  | _ => (nvs, t));
haftmann@30374
   415
haftmann@30374
   416
(* abstraction over a fresh variable that matches it against the tuple of out_ts:
   on a match it yields success_t (guarded by the conjunction of all equations,
   if there are any), in every other case the empty enumeration *)
fun compile_match thy eqs eqs' out_ts success_t =
  let
    val all_eqs = maps mk_eq eqs @ eqs';
    val used = fold Term.add_free_names (success_t :: all_eqs @ out_ts) [];
    val x = Name.variant used "x";
    val y = Name.variant (x :: used) "y";
    val tupleT = mk_tupleT (map fastype_of out_ts);
    val resT = fastype_of success_t;
    val empty = mk_empty (dest_pred_enumT resT);
    val scrutinee = Free (x, tupleT);
    val on_match =
      if null all_eqs then success_t
      else
        Const (@{const_name HOL.If}, HOLogic.boolT --> resT --> resT --> resT) $
          foldr1 HOLogic.mk_conj all_eqs $ success_t $ empty;
    val (case_t, _) =
      DatatypePackage.make_case (ProofContext.init thy) false [] scrutinee
        [(mk_tuple out_ts, on_match),
         (Free (y, tupleT), empty)];
  in lambda scrutinee case_t end;
haftmann@30374
   437
haftmann@30972
   438
(* Name of the function compiled for predicate "name" in the given mode, as
   recorded in the names table of IndCodegenData.  Fails with an error if no
   function has been registered for this predicate/mode pair.  The mode is
   rendered with string_of_mode rather than the implementation-specific
   makestring, so the message reads the same as the mode traces. *)
fun modename_of thy name mode =
  (case PredModetab.lookup (#names (IndCodegenData.get thy)) (name, mode) of
    SOME mode_name => mode_name
  | NONE =>
      error ("fun modename_of - definition not found: name: " ^ name ^
        " mode: " ^ string_of_mode mode))
haftmann@30374
   443
haftmann@30972
   444
(* modes recorded for a predicate in the theory data; [] if there is no entry *)
fun modes_of thy name =
  these (Symtab.lookup (#modes (IndCodegenData.get thy)) name);
haftmann@30972
   446
haftmann@30972
   447
(*FIXME function can be removed*)
haftmann@30374
   448
(* eta-expand t with fresh variables for all its arguments, apply f to the
   resulting application and abstract over the variables again *)
fun mk_funcomp f t =
  let
    val used = Term.add_free_names t [];
    val argTs = binder_types (fastype_of t);
    val arg_names = Name.variant_list used (replicate (length argTs) "x");
    val args = map Free (arg_names ~~ argTs);
  in fold_rev lambda args (f (list_comb (t, args))) end;
haftmann@30374
   457
haftmann@30374
   458
(* translate a parameter term according to its hmode: the head is replaced by the
   compiled function for that mode (constants) or retyped (Frees), and the leading
   arguments that are parameters themselves are translated recursively; a parameter
   without hmode is returned unchanged *)
fun compile_param thy modes (NONE, t) = t
  | compile_param thy modes (SOME (Mode ((iss, is'), _, ms)), t) =
      let
        val (head, args) = strip_comb t;
        val (params, rest_args) = chop (length ms) args;
        val params' = map (compile_param thy modes) (ms ~~ params);
        val head' =
          (case head of
            Const (name, T) =>
              if not (AList.defined (op =) modes name) then
                error "compile param: Not an inductive predicate with correct mode"
              else Const (modename_of thy name (iss, is'), funT'_of (iss, is') T)
          | Free (name, T) => Free (name, funT_of T (SOME is')));
      in list_comb (head', params' @ rest_args) end
  | compile_param _ _ _ = error "compile params"
haftmann@30374
   471
haftmann@30374
   472
(* translate a predicate application according to its hmode: a known constant
   becomes the compiled function for that mode applied to its translated
   parameters; a Free head is retyped via funT_of and keeps its arguments *)
fun compile_expr thy modes (SOME (Mode (mode, is, ms)), t) =
      (case strip_comb t of
        (Const (name, T), params) =>
          if not (AList.defined (op =) modes name) then
            error "not a valid inductive expression"
          else
            let
              (* argument types that remain after the parameters *)
              val argTs = curry Library.drop (length ms) (fst (strip_type T));
              val (inTs, outTs) = get_args is argTs;
              val params' = map (compile_param thy modes) (ms ~~ params);
              val mode_id = modename_of thy name mode;
              val constT =
                (map fastype_of params' @ inTs) ---> mk_pred_enumT (mk_tupleT outTs);
            in list_comb (Const (mode_id, constT), params') end
      | (Free (name, T), args) =>
          (*if name mem param_vs then *)
          (* Higher order mode call *)
          list_comb (Free (name, funT_of T (SOME is)), args))
  | compile_expr _ _ _ = error "not a valid inductive expression"
haftmann@30374
   491
haftmann@30374
   492
haftmann@30374
   493
(* Compile a single clause of an inductive predicate into a term.

   thy       -- theory handed to is_constrt, select_mode_prem, compile_expr, compile_match
   all_vs    -- names already in use; fresh variants of "x" are generated against them
   param_vs  -- names of the predicate parameters
   modes     -- mode table handed on to compile_expr
   (iss, is) -- mode of the clause: iss are the optional parameter modes, is selects
                (via get_args) the input positions among ts
   (ts, ps)  -- argument terms of the clause and its premises
   inp       -- term wrapped with mk_single and bound in front of the compiled premises

   Raises Bind if select_mode_prem does not return a premise together with a mode. *)
fun compile_clause thy all_vs param_vs modes (iss, is) (ts, ps) inp =
  let
    (*parameters that carry a mode of their own extend the mode table*)
    fun param_mode (_, NONE) = NONE
      | param_mode (v, SOME js) = SOME (v, [([], js)]);
    val modes' = modes @ List.mapPartial param_mode (param_vs ~~ iss);

    (*replace a term rejected by is_constrt by a fresh variable v and record v = t*)
    fun check_constrt ((names, eqs), t) =
      if is_constrt thy t then ((names, eqs), t)
      else
        let
          val x = Name.variant names "x";
          val v = Free (x, fastype_of t);
        in ((x :: names, HOLogic.mk_eq (v, t) :: eqs), v) end;

    val (clause_ins, clause_outs) = get_args is ts;
    val ((all_vs', in_eqs), clause_ins') =
      (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), clause_ins);

    (*run check_constrt and afterwards distinct_v over the pending terms us*)
    fun prepare names vs us =
      let
        val ((names', eqs), us') =
          (*FIXME*) Library.foldl_map check_constrt ((names, []), us);
        val ((names'', constr_vs), us'') =
          (*FIXME*) Library.foldl_map distinct_v ((names', map (rpair []) vs), us');
      in (names'', constr_vs, eqs, us'') end;

    fun compile_prems us vs names [] =
          let
            val (_, constr_vs, out_eqs, us') = prepare names vs us;
          in
            (*only the final match also checks the equations of the clause inputs*)
            compile_match thy constr_vs (in_eqs @ out_eqs) us'
              (mk_single (mk_tuple clause_outs))
          end
      | compile_prems us vs names prems =
          let
            val vs' = distinct (op =) (flat (vs :: map term_vs us));
            val SOME (p, mode as SOME (Mode (_, js, _))) =
              select_mode_prem thy modes' vs' prems;
            val prems' = filter_out (equal p) prems;
            val (names', constr_vs, eqs, us') = prepare names vs us;

            (*compile the call of t on its arguments, then the remaining premises;
              wrap is applied to the call last*)
            fun compile_call wrap (call_args, t) =
              let
                val (call_ins, call_outs) = get_args js call_args;
                val call = list_comb (compile_expr thy modes (mode, t), call_ins);
                val rest = compile_prems call_outs vs' names' prems';
              in (wrap call, rest) end;

            val (compiled_clause, rest) =
              (case p of
                Prem prem => compile_call I prem
              | Negprem prem => compile_call mk_not_pred prem
              | Sidecond t =>
                  let val rest = compile_prems [] vs' names' prems'
                  in (mk_if_predenum t, rest) end);
          in
            compile_match thy constr_vs eqs us' (mk_bind (compiled_clause, rest))
          end;

    val prem_t = compile_prems clause_ins' param_vs all_vs' ps;
  in
    mk_bind (mk_single inp, prem_t)
  end
haftmann@30374
   560
haftmann@30374
   561
(* Build, as a Trueprop term, the equation for predicate s (of type T) in the given mode.
   The left-hand side is the constant named by modename_of applied to the parameters and
   the input variables; the right-hand side joins the compiled clauses cls with mk_sup.
   foldr1 is used for the join, so cls is expected to be non-empty. *)
fun compile_pred thy all_vs param_vs modes s T cls mode =
  let
    val (iss, is) = mode;
    val (Ts1, Ts2) = chop (length param_vs) (binder_types T);
    val Ts1' = map2 funT_of Ts1 iss;
    val (Us1, Us2) = get_args is Ts2;
    (*one input variable per entry of the mode, named apart from the parameters*)
    val xnames = Name.variant_list param_vs (map (fn i => "x" ^ string_of_int i) is);
    val xs = map2 (curry Free) xnames Us1;
    fun compile cl = compile_clause thy all_vs param_vs modes mode cl (mk_tuple xs);
    val cl_ts = map compile cls;
    val mode_id = modename_of thy s mode;
    val funT = (Ts1' @ Us1) ---> mk_pred_enumT (mk_tupleT Us2);
    val lhs = list_comb (Const (mode_id, funT), map2 (curry Free) param_vs Ts1' @ xs);
  in
    HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, foldr1 mk_sup cl_ts))
  end;
haftmann@30374
   581
haftmann@30374
   582
(* For every predicate (s, (T, cls)) in preds, compile one equation per mode listed for s
   in the mode table. Raises Option if s has no entry in modes. *)
fun compile_preds thy all_vs param_vs modes preds =
  let
    fun compile_all_modes (s, (T, cls)) =
      let val pred_modes = the (AList.lookup (op =) modes s)
      in map (compile_pred thy all_vs param_vs modes s T cls) pred_modes end;
  in map compile_all_modes preds end;
haftmann@30374
   586
haftmann@30374
   587
(* end of term construction ******************************************************)
haftmann@30374
   588
haftmann@30374
   589
(* special setup for simpset *)                  
haftmann@30374
   590
(* HOL_basic_ss with a solver, installed via setSolver, that always runs all_tac *)
val HOL_basic_ss' =
  HOL_basic_ss setSolver (mk_solver "all_tac_solver" (K (K all_tac)))
haftmann@30374
   592
haftmann@30374
   593
haftmann@30374
   594
(* misc: constructing and proving tupleE rules ***********************************)
haftmann@30374
   595
haftmann@30374
   596
haftmann@30374
   597
(* Creating definitions of functional programs 
haftmann@30374
   598
   and proving intro and elim rules **********************************************) 
haftmann@30374
   599
haftmann@30374
   600
(* True if c is known to the inductive package or has intro rules registered in
   IndCodegenData. The table is queried directly with Symtab.defined instead of
   materialising the complete key list first. *)
fun is_ind_pred thy c =
  can (InductivePackage.the_inductive (ProofContext.init thy)) c orelse
  Symtab.defined (#intro_rules (IndCodegenData.get thy)) c
haftmann@30374
   603
haftmann@30374
   604
(* Names of the constants occurring in the terms intrs that satisfy is_ind_pred
   and are not listed in preds. *)
fun get_name_of_ind_calls_of_clauses thy preds intrs =
  let
    val const_names = map fst (fold Term.add_consts intrs []);
    fun is_other_ind_pred c = not (member (op =) preds c) andalso is_ind_pred thy c;
  in filter is_other_ind_pred const_names end
haftmann@30374
   607
haftmann@30972
   608
(* Trace the arities, one line per entry (s, (ks, k)): the entries of ks followed by k,
   separated by " -> ", with "X" standing for NONE. *)
fun print_arities arities =
  let
    fun string_of_arg NONE = "X"
      | string_of_arg (SOME k') = string_of_int k';
    fun string_of_arity (s, (ks, k)) =
      s ^ ": " ^ space_implode " -> " (map string_of_arg (ks @ [SOME k]));
  in tracing ("Arities:\n" ^ cat_lines (map string_of_arity arities)) end;
haftmann@30374
   613
haftmann@30374
   614
(* Without a mode the term x is returned unchanged. With a mode, x (of type T) is turned
   into an abstraction over fresh variables x1 .. xn, one per binder type of T, whose
   body is mk_Eval of x applied to the input arguments and the tuple of the output
   arguments. The fresh names are prepended to the accumulated names. *)
fun mk_Eval_of ((x, _), NONE) names = (x, names)
  | mk_Eval_of ((x, T), SOME mode) names =
      let
        val Ts = binder_types T;
        val candidates = map (fn i => "x" ^ string_of_int i) (1 upto length Ts);
        val argnames = Name.variant_list names candidates;
        val args = map Free (argnames ~~ Ts);
        val (inargs, outargs) = get_args mode args;
        val body = mk_Eval (list_comb (x, inargs), mk_tuple outargs);
      in (fold_rev lambda args body, argnames @ names) end;
haftmann@30374
   626
haftmann@30374
   627
(* Prove the intro and the elim rule relating predicate pred to the function constant
   mode_id (of type funT) for the given mode, by simplification with the definition
   defthm. Both theorems are recorded under mode_id via map_function_intros and
   map_function_elims and stored under the base name of mode_id suffixed with "I" and
   "E" respectively. Returns the updated theory. *)
fun create_intro_rule nparams mode defthm mode_id funT pred thy =
  let
    val (iss, is) = mode;
    val Ts = binder_types (fastype_of pred);
    val funtrm = Const (mode_id, funT);
    val argnames =
      Name.variant_list [] (map (fn i => "x" ^ string_of_int i) (1 upto length Ts));
    val (Ts1, Ts2) = chop nparams Ts;
    val Ts1' = map2 funT_of Ts1 iss;
    val args = map Free (argnames ~~ (Ts1' @ Ts2));
    val (params, io_args) = chop nparams args;
    val (inargs, outargs) = get_args is io_args;
    val (params', _) = fold_map mk_Eval_of ((params ~~ Ts1) ~~ iss) [];
    val predprop = HOLogic.mk_Trueprop (list_comb (pred, params' @ io_args));
    val funapp = list_comb (funtrm, params @ inargs);
    (*the elim rule uses a free variable y of unit type when there are no outputs*)
    val funpropE = HOLogic.mk_Trueprop (mk_Eval (funapp,
      if null outargs then Free ("y", HOLogic.unitT) else mk_tuple outargs));
    val funpropI = HOLogic.mk_Trueprop (mk_Eval (funapp, mk_tuple outargs));
    val introtrm = Logic.mk_implies (predprop, funpropI);
    val simprules =
      [defthm, @{thm eval_pred}, @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}];
    val unfolddef_tac = Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simprules) 1;
    fun prove fixes prop =
      Goal.prove (ProofContext.init thy) fixes [] prop (fn _ => unfolddef_tac);
    val introthm = prove (argnames @ ["y"]) introtrm;
    val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT));
    val elimtrm = Logic.list_implies ([funpropE, Logic.mk_implies (predprop, P)], P);
    val elimthm = prove (argnames @ ["y", "P"]) elimtrm;
    val base = Long_Name.base_name mode_id;
  in
    thy
    |> map_function_intros (Symtab.update_new (mode_id, introthm))
    |> map_function_elims (Symtab.update_new (mode_id, elimthm))
    |> PureThy.store_thm (Binding.name (base ^ "I"), introthm) |> snd
    |> PureThy.store_thm (Binding.name (base ^ "E"), elimthm) |> snd
  end;
haftmann@30374
   659
haftmann@30374
   660
(* For predicate name and each of its modes: declare a function constant whose name
   encodes the mode, define it through mk_Enum over the predicate, register the name and
   the definition (map_names, map_function_defs) and derive the intro/elim rules with
   create_intro_rule. preds maps predicate names to their types; Option is raised if
   name is missing from it. Returns the updated theory. *)
fun create_definitions preds nparams (name, modes) thy =
  let
    val _ = tracing "create definitions";
    val T = the (AList.lookup (op =) preds name);

    (*an empty mode is written "0", otherwise its indices joined by "_"*)
    fun string_of_mode [] = "0"
      | string_of_mode is = space_implode "_" (map string_of_int is);
    fun add_HOmode NONE s = s
      | add_HOmode (SOME is) s = s ^ "__" ^ string_of_mode is;

    fun create_definition (mode as (iss, is)) thy =
      let
        val HOmode = fold add_HOmode iss "";
        val raw_id =
          name ^ (if HOmode = "" then "_" else HOmode ^ "___") ^ string_of_mode is;
        val Ts = binder_types T;
        val (Ts1, Ts2) = chop nparams Ts;
        val Ts1' = map2 funT_of Ts1 iss;
        val (Us1, Us2) = get_args is Ts2;
        val names =
          Name.variant_list [] (map (fn i => "x" ^ string_of_int i) (1 upto length Ts));
        val xs = map Free (names ~~ (Ts1' @ Ts2));
        val (xparams, xargs) = chop nparams xs;
        val (xparams', names') = fold_map mk_Eval_of ((xparams ~~ Ts1) ~~ iss) names;
        val (xins, xouts) = get_args is xargs;

        (*abstract over the outputs: a fresh unit variable if there are none, a plain
          lambda for one, nested HOLogic.mk_split for two or more*)
        fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
          | mk_split_lambda [x] t = lambda x t
          | mk_split_lambda (x :: xs) t = HOLogic.mk_split (lambda x (mk_split_lambda xs t));

        val predterm =
          mk_Enum (mk_split_lambda xouts (list_comb (Const (name, T), xparams' @ xargs)));
        val funT = (Ts1' @ Us1) ---> mk_pred_enumT (mk_tupleT Us2);
        val mode_id = Sign.full_bname thy (Long_Name.base_name raw_id);
        val bname = Long_Name.base_name mode_id;
        val lhs = list_comb (Const (mode_id, funT), xparams @ xins);
        val def = Logic.mk_equals (lhs, predterm);
        val ([defthm], thy') = thy
          |> Sign.add_consts_i [(Binding.name bname, funT, NoSyn)]
          |> PureThy.add_defs false [((Binding.name (bname ^ "_def"), def), [])];
      in
        thy'
        |> map_names (PredModetab.update_new ((name, mode), mode_id))
        |> map_function_defs (Symtab.update_new (mode_id, defthm))
        |> create_intro_rule nparams mode defthm mode_id funT (Const (name, T))
      end;
  in
    fold create_definition modes thy
  end;
haftmann@30374
   703
haftmann@30374
   704
(**************************************************************************************)
haftmann@30374
   705
(* Proving equivalence of term *)
haftmann@30374
   706
haftmann@30374
   707
haftmann@30972
   708
(* Intro rule registered for predicate pred in the given mode; raises Option if none. *)
fun intro_rule thy pred mode =
  let val mode_id = modename_of thy pred mode
  in the (Symtab.lookup (#function_intros (IndCodegenData.get thy)) mode_id) end
haftmann@30374
   710
haftmann@30972
   711
(* Elim rule registered for predicate pred in the given mode; raises Option if none. *)
fun elim_rule thy pred mode =
  let val mode_id = modename_of thy pred mode
  in the (Symtab.lookup (#function_elims (IndCodegenData.get thy)) mode_id) end
haftmann@30374
   713
haftmann@30374
   714
(* Intro rules of predicate predname. Rules registered in IndCodegenData take
   precedence and are returned in reverse order of storage; otherwise the rules of the
   inductive package whose conclusion is headed by predname are returned. The table is
   consulted with a single lookup. *)
fun pred_intros thy predname =
  let
    fun is_intro_of pred intro =
      let val (head, _) = strip_comb (HOLogic.dest_Trueprop (concl_of intro))
      in fst (dest_Const head) = pred end;
  in
    (case Symtab.lookup (#intro_rules (IndCodegenData.get thy)) predname of
      SOME intros => rev intros
    | NONE =>
        InductivePackage.the_inductive (ProofContext.init thy) predname
        |> snd |> #intrs |> filter (is_intro_of predname))
  end
haftmann@30374
   726
haftmann@30374
   727
(* Definition registered for predicate pred in the given mode; raises Option if none. *)
fun function_definition thy pred mode =
  let val mode_id = modename_of thy pred mode
  in the (Symtab.lookup (#function_defs (IndCodegenData.get thy)) mode_id) end
haftmann@30374
   729
haftmann@30374
   730
(* True exactly for types built with the Type constructor. *)
fun is_Type T = (case T of Type _ => true | _ => false)
haftmann@30374
   732
haftmann@30374
   733
(* Apply conversion cv to every premise of a chain of meta-implications; the final
   conclusion is left unchanged. *)
fun imp_prems_conv cv ct =
  let
    fun is_implication (Const ("==>", _) $ _ $ _) = true
      | is_implication _ = false;
  in
    if is_implication (Thm.term_of ct)
    then Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
    else Conv.all_conv ct
  end
haftmann@30374
   737
haftmann@30374
   738
(* Apply conversion cv to the argument of Trueprop; any other term is an error. *)
fun Trueprop_conv cv ct =
  let
    fun is_Trueprop (Const ("Trueprop", _) $ _) = true
      | is_Trueprop _ = false;
  in
    if is_Trueprop (Thm.term_of ct) then Conv.arg_conv cv ct
    else error "Trueprop_conv"
  end
haftmann@30374
   742
bulwahn@31105
   743
(* Transfer rule to thy and try to rewrite each of its premises (below Trueprop) with
   the symmetric form of Predicate.eq_is_eq. *)
fun preprocess_intro thy rule =
  let
    val eq_conv = Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq}));
    val prems_conv = imp_prems_conv (Trueprop_conv eq_conv);
  in Conv.fconv_rule prems_conv (Thm.transfer thy rule) end
haftmann@30374
   748
bulwahn@31105
   749
(* Restate the elimination rule elimrule: in every case (every premise but the first),
   the hypotheses after the first nargs ones have HOL equality replaced by Predicate.eq.
   The restated rule is obtained from elimrule via an equation produced by rewriting
   with eq_is_eq. Raises Empty if elimrule has no premises. *)
fun preprocess_elim thy nargs elimrule =
  let
    fun replace_eqs (Const ("Trueprop", _) $ (Const ("op =", T) $ lhs $ rhs)) =
          HOLogic.mk_Trueprop (Const (@{const_name Predicate.eq}, T) $ lhs $ rhs)
      | replace_eqs t = t;
    fun preprocess_case t =
      let
        val params = Logic.strip_params t;
        val (kept, rest) = chop nargs (Logic.strip_assums_hyp t);
        val hyps' = kept @ map replace_eqs rest;
      in list_all (params, Logic.list_implies (hyps', Logic.strip_assums_concl t)) end;
    val prems = Thm.prems_of elimrule;
    val cases' = map preprocess_case (tl prems);
    val elimrule' = Logic.list_implies (hd prems :: cases', Thm.concl_of elimrule);
    val eq_thm =
      Conv.implies_concl_conv (MetaSimplifier.rewrite true [@{thm eq_is_eq}])
        (cterm_of thy elimrule');
  in
    Thm.equal_elim (Thm.symmetric eq_thm) elimrule
  end;
haftmann@30374
   767
haftmann@30374
   768
haftmann@30374
   769
(* returns true if t is an application of a datatype constructor *)
haftmann@30374
   770
(* which then consequently would be split *)
haftmann@30374
   771
(* else false *)
haftmann@30374
   772
(* True if the type of t is a registered datatype and the head of t is one of the
   constructor constants listed in its description. *)
fun is_constructor thy t =
  (case fastype_of t of
    Type (tyco, _) =>
      (case DatatypePackage.get_datatype thy tyco of
        NONE => false
      | SOME info =>
          let
            val constr_consts =
              maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info);
          in
            (case fst (strip_comb t) of
              Const (name, _) => name mem_string constr_consts
            | _ => false)
          end)
  | _ => false)
haftmann@30374
   783
haftmann@30374
   784
(* MAJOR FIXME:  prove_params should be simple
haftmann@30374
   785
 - different form of introrule for parameters ? *)
haftmann@30374
   786
(* Tactic for one parameter argument t of a predicate, paired with its optional mode.
   Without a mode nothing is to be done. Otherwise a constant head is unfolded with its
   function definition for that mode, the parameter arguments of t are treated
   recursively, and remaining subgoals are closed by assumption.
   A head that is neither a constant nor a free variable is reported with an explicit
   error; previously it surfaced as an uninformative Match exception. *)
fun prove_param thy modes (NONE, t) = all_tac
  | prove_param thy modes (SOME (Mode (mode, _, ms)), t) =
      let
        val (f, args) = strip_comb t;
        val (params, _) = chop (length ms) args;
        val f_tac =
          (case f of
            Const (name, _) =>
              simp_tac (HOL_basic_ss addsimps
                [@{thm eval_pred}, function_definition thy name mode]) 1
          | Free _ => all_tac
          | _ => error "prove_param: parameter is neither a constant nor a free variable");
      in
        print_tac "before simplification in prove_args:"
        THEN debug_tac ("mode" ^ (makestring mode))
        THEN f_tac
        THEN print_tac "after simplification in prove_args"
        (* work with parameter arguments *)
        THEN (EVERY (map (prove_param thy modes) (ms ~~ params)))
        THEN (REPEAT_DETERM (atac 1))
      end
haftmann@30374
   803
haftmann@30374
   804
(* Tactic for a premise t with a known mode. If the head of t is a constant listed in
   modes, bindI is applied, the assumptions are rotated by premposition, the intro rule
   of that constant and mode is applied and the parameter arguments are handled by
   prove_param. A constant missing from modes is an error; any other head is solved by
   bindI followed by assumption. A premise without a mode is an error. *)
fun prove_expr thy modes (SOME (Mode (mode, is, ms)), t, us) (premposition : int) =
      let
        val (head, args) = strip_comb t;
      in
        (case head of
          Const (name, _) =>
            if AList.defined (op =) modes name then
              let
                val introrule = intro_rule thy name mode;
                (*val (in_args, out_args) = get_args is us
                val (pred, rargs) = strip_comb (HOLogic.dest_Trueprop
                  (hd (Logic.strip_imp_prems (prop_of introrule))))
                val nparams = length ms (* get_nparams thy (fst (dest_Const pred)) *)
                val (_, args) = chop nparams rargs
                val _ = tracing ("args: " ^ (makestring args))
                val subst = map (pairself (cterm_of thy)) (args ~~ us)
                val _ = tracing ("subst: " ^ (makestring subst))
                val inst_introrule = Drule.cterm_instantiate subst introrule*)
                (* the next line is old and probably wrong *)
                val (param_args, _) = chop (length ms) args;
                val _ = tracing ("premposition: " ^ (makestring premposition));
              in
                rtac @{thm bindI} 1
                THEN print_tac "before intro rule:"
                THEN debug_tac ("mode" ^ (makestring mode))
                THEN debug_tac (makestring introrule)
                THEN debug_tac ("premposition: " ^ (makestring premposition))
                (* for the right assumption in first position *)
                THEN rotate_tac premposition 1
                THEN rtac introrule 1
                THEN print_tac "after intro rule"
                (* work with parameter arguments *)
                THEN (EVERY (map (prove_param thy modes) (ms ~~ param_args)))
                THEN (REPEAT_DETERM (atac 1))
              end
            else error "Prove expr if case not implemented"
        | _ => rtac @{thm bindI} 1 THEN atac 1)
      end
  | prove_expr _ _ _ _ = error "Prove expr not implemented"
haftmann@30374
   838
haftmann@30374
   839
(* Keep only those results of tac that have exactly one subgoal less than st. *)
fun SOLVED tac st =
  let val remaining = nprems_of st - 1
  in FILTER (fn st' => nprems_of st' = remaining) tac st end;
haftmann@30374
   840
haftmann@30374
   841
(* Keep only those results of tac that have no subgoals left. *)
fun SOLVEDALL tac =
  let fun no_subgoals st' = (nprems_of st' = 0)
  in FILTER no_subgoals tac end
haftmann@30374
   842
haftmann@30374
   843
(* Tactic for the pattern match on out_ts: simplify subgoal 1 with the case rewrites of
   all datatype constructors occurring in out_ts (plus those of unit and prod), then try
   once to resolve an if-expression via HOL.if_P, solving its conditions by
   simplification. *)
fun prove_match thy (out_ts : term list) =
  let
    (*case rewrites of the datatype of t and, recursively, of its arguments*)
    fun get_case_rewrite t =
      if is_constructor thy t then
        let
          val tyco = fst (dest_Type (fastype_of t));
          val case_rewrites = #case_rewrites (DatatypePackage.the_datatype thy tyco);
        in case_rewrites @ maps get_case_rewrite (snd (strip_comb t)) end
      else [];
    val simprules =
      @{thm "unit.cases"} :: @{thm "prod.cases"} :: maps get_case_rewrite out_ts;
    (* replace TRY by determining whether it is necessary - are there equations when
       calling compile_match? *)
  in
    print_tac ("before prove_match rewriting: simprules = " ^ (makestring simprules))
    (* make this simpset better! *)
    THEN asm_simp_tac (HOL_basic_ss' addsimps simprules) 1
    THEN print_tac "after prove_match:"
    THEN DETERM (TRY
      (EqSubst.eqsubst_tac (ProofContext.init thy) [0] [@{thm "HOL.if_P"}] 1
       THEN REPEAT_DETERM (rtac @{thm conjI} 1 THEN SOLVED (asm_simp_tac HOL_basic_ss 1))
       THEN SOLVED (asm_simp_tac HOL_basic_ss 1)))
    THEN print_tac "after if simplification"
  end;
haftmann@30374
   862
haftmann@30374
   863
(* corresponds to compile_fun -- maybe call that also compile_sidecond? *)
haftmann@30374
   864
haftmann@30374
   865
(* Tactic for a side condition t: simplify subgoal 1 with eval_pred, not_False_eq_True
   and the function definitions of all predicates from modes found in t, each taken in
   the mode without parameter modes in which every argument position is listed. *)
fun prove_sidecond thy modes t =
  let
    val _ = tracing ("prove_sidecond:" ^ (makestring t));
    (*collect constants known in modes; arguments of such a constant are not searched*)
    fun preds_of t nameTs =
      (case strip_comb t of
        (Const (name, T), args) =>
          if AList.defined (op =) modes name then (name, T) :: nameTs
          else fold preds_of args nameTs
      | _ => nameTs);
    val preds = preds_of t [];
    val _ = tracing ("preds: " ^ (makestring preds));
    fun full_mode_def (pred, T) =
      function_definition thy pred ([], 1 upto length (binder_types T));
    val defs = map full_mode_def preds;
    val _ = tracing ("defs: " ^ (makestring defs));
    (* remove not_False_eq_True when simpset in prove_match is better *)
    val simps = @{thm not_False_eq_True} :: @{thm eval_pred} :: defs;
  in
    simp_tac (HOL_basic_ss addsimps simps) 1
    (* need better control here! *)
    THEN print_tac "after sidecond simplification"
  end
haftmann@30374
   885
haftmann@30374
   886
(* Tactic proving one clause (ts, ps) of a predicate in mode (iss, is); it mirrors the
   structure of compile_clause. After bindI and singleI the premises are processed in
   the order chosen by select_mode_prem; nargs is added to the index of a premise in ps
   to obtain the position passed to prove_expr.
   Raises Bind if select_mode_prem does not return a premise together with a mode. *)
fun prove_clause thy nargs all_vs param_vs modes (iss, is) (ts, ps) =
  let
    (*parameters that carry a mode of their own extend the mode table*)
    fun param_mode (_, NONE) = NONE
      | param_mode (v, SOME js) = SOME (v, [([], js)]);
    val modes' = modes @ List.mapPartial param_mode (param_vs ~~ iss);

    (*replace a term rejected by is_constrt by a fresh variable v and record v = t*)
    fun check_constrt ((names, eqs), t) =
      if is_constrt thy t then ((names, eqs), t)
      else
        let
          val x = Name.variant names "x";
          val v = Free (x, fastype_of t);
        in ((x :: names, HOLogic.mk_eq (v, t) :: eqs), v) end;

    val (clause_ins, clause_outs) = get_args is ts;
    (*only the transformed input terms are needed here*)
    val (_, clause_ins') =
      (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), clause_ins);

    fun prove_prems us vs [] =
          prove_match thy us
          THEN asm_simp_tac HOL_basic_ss' 1
          THEN print_tac "before the last rule of singleI:"
          THEN rtac (if null clause_outs then @{thm singleI_unit} else @{thm singleI}) 1
      | prove_prems us vs rps =
          let
            val vs' = distinct (op =) (flat (vs :: map term_vs us));
            val SOME (p, mode as SOME (Mode ((iss, js), _, param_modes))) =
              select_mode_prem thy modes' vs' rps;
            val premposition = find_index (equal p) ps + nargs;
            val rps' = filter_out (equal p) rps;
            (*the tactic for the remaining premises is always built before the tactic
              for the current one*)
            val rest_tac =
              (case p of
                Prem (call_args, t) =>
                  let
                    val (_, call_outs) = get_args js call_args;
                    val rec_tac = prove_prems call_outs vs' rps';
                  in
                    print_tac "before clause:"
                    THEN asm_simp_tac HOL_basic_ss 1
                    THEN print_tac "before prove_expr:"
                    THEN prove_expr thy modes (mode, t, call_args) premposition
                    THEN print_tac "after prove_expr:"
                    THEN rec_tac
                  end
              | Negprem (call_args, t) =>
                  let
                    val (_, call_outs) = get_args js call_args;
                    val rec_tac = prove_prems call_outs vs' rps';
                    val (head, params) = strip_comb t;
                    val neg_tac =
                      (case head of
                        Const (c, _) =>
                          simp_tac (HOL_basic_ss addsimps
                            [function_definition thy c (iss, js)]) 1
                          THEN rtac @{thm not_predI} 1
                          THEN print_tac "after neg. intro rule"
                          THEN print_tac ("t = " ^ (makestring t))
                          (* FIXME: work with parameter arguments *)
                          THEN EVERY (map (prove_param thy modes) (param_modes ~~ params))
                      | _ => rtac @{thm not_predI'} 1);
                  in
                    print_tac "before negated clause:"
                    THEN rtac @{thm bindI} 1
                    THEN neg_tac
                    THEN REPEAT_DETERM (atac 1)
                    THEN rec_tac
                  end
              | Sidecond t =>
                  rtac @{thm bindI} 1
                  THEN rtac @{thm if_predI} 1
                  THEN print_tac "before sidecond:"
                  THEN prove_sidecond thy modes t
                  THEN print_tac "after sidecond:"
                  THEN prove_prems [] vs' rps');
          in
            prove_match thy us THEN rest_tac
          end;

    val prems_tac = prove_prems clause_ins' param_vs ps;
  in
    rtac @{thm bindI} 1
    THEN rtac @{thm singleI} 1
    THEN prems_tac
  end;
haftmann@30374
   961
haftmann@30374
   962
(* Rule applications selecting the i-th of n joined clauses: i - 1 times supI2,
   followed by supI1 unless i = n. *)
fun select_sup n i =
  if i = 1 then (if n = 1 then [] else [rtac @{thm supI1}])
  else rtac @{thm supI2} :: select_sup (n - 1) (i - 1);
haftmann@30374
   965
haftmann@30374
   966
(* Number of parameters of predicate s: the value registered in IndCodegenData if
   present, otherwise the number of parameters of the raw induction rule known to the
   inductive package, and 0 if s is unknown there as well. The table is consulted with
   a single lookup. *)
fun get_nparams thy s =
  let
    val _ = tracing ("get_nparams: " ^ s);
  in
    (case Symtab.lookup (#nparams (IndCodegenData.get thy)) s of
      SOME n => n
    | NONE =>
        (case try (InductivePackage.the_inductive (ProofContext.init thy)) s of
          SOME info =>
            info |> snd |> #raw_induct |> Thm.unvarify
            |> InductivePackage.params_of |> length
        | NONE => 0 (* default value *)))
  end
haftmann@30374
   977
haftmann@30374
   978
(* Local alias for InductiveSetPackage.codegen_preproc; prove_one_direction applies it
   to the preprocessed elimination rule. *)
val ind_set_codegen_preproc = InductiveSetPackage.codegen_preproc;
haftmann@30374
   979
haftmann@30374
   980
(* Elimination rule of predicate predname: the rule registered in IndCodegenData if
   present, otherwise the rule of the inductive package at the position of predname
   among the names of its definition. The table is consulted with a single lookup. *)
fun pred_elim thy predname =
  (case Symtab.lookup (#elim_rules (IndCodegenData.get thy)) predname of
    SOME elim => elim
  | NONE =>
      let
        val ind_result = InductivePackage.the_inductive (ProofContext.init thy) predname;
        val index = find_index (fn s => s = predname) (#names (fst ind_result));
      in nth (#elims (snd ind_result)) index end)
haftmann@30374
   988
haftmann@30374
   989
(* Tactic for one direction of the equivalence for predicate pred (of type T) in the
   given mode: eliminate with the function elim rule and the preprocessed case rule of
   the predicate, select the matching clause of the join for each case and prove every
   clause with prove_clause.
   The function elim rule is obtained through elim_rule instead of repeating its table
   lookup under a local name that shadowed that function. *)
fun prove_one_direction thy all_vs param_vs modes clauses ((pred, T), mode) =
  let
    val fun_elim = elim_rule thy pred mode;
    (*  val ind_result = InductivePackage.the_inductive (ProofContext.init thy) pred
    val index = find_index (fn s => s = pred) (#names (fst ind_result))
    val (_, T) = dest_Const (nth (#preds (snd ind_result)) index) *)
    val nargs = length (binder_types T) - get_nparams thy pred;
    val pred_case_rule = singleton (ind_set_codegen_preproc thy)
      (preprocess_elim thy nargs (pred_elim thy pred));
    (* FIXME preprocessor |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}])*)
    val _ = tracing ("pred_case_rule " ^ (makestring pred_case_rule));
    val nclauses = length clauses;
  in
    REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
    THEN etac fun_elim 1
    THEN etac pred_case_rule 1
    THEN EVERY (map (fn i => EVERY' (select_sup nclauses i) i) (1 upto nclauses))
    THEN EVERY (map (prove_clause thy nargs all_vs param_vs modes mode) clauses)
  end;
haftmann@30374
  1008
haftmann@30374
  1009
(*******************************************************************************************************)
haftmann@30374
  1010
(* Proof in the other direction ************************************************************************)
haftmann@30374
  1011
(*******************************************************************************************************)
haftmann@30374
  1012
haftmann@30374
  1013
(* Tactic that splits the assumptions of subgoal 1 along the constructor
   structure of the tuple built from out_ts.  Free variables and terms that
   are not constructor applications need no splitting.  Afterwards one
   optional split on split_if_asm is attempted. *)
fun prove_match2 thy out_ts =
  let
    fun split_term_tac (Free _) = all_tac
      | split_term_tac t =
          if not (is_constructor thy t) then all_tac
          else
            let
              val tyname = fst (dest_Type (fastype_of t))
              val info = DatatypePackage.the_datatype thy tyname
              val n_constrs = length (#case_rewrites info)
              (* pairs use the fixed rule prod.split_asm; any other type
                 looks up the fact named <type>.split_asm *)
              val split_rules =
                if tyname = "*" then [@{thm prod.split_asm}]
                else PureThy.get_thms thy (tyname ^ ".split_asm")
              val subterms = snd (strip_comb t)
            in
              print_tac ("splitting with t = " ^ makestring t)
              THEN Splitter.split_asm_tac split_rules 1
              (* discharge the other (n_constrs - 1) cases via botE *)
              THEN REPEAT_DETERM_N (n_constrs - 1)
                (etac @{thm botE} 1 ORELSE etac @{thm botE} 2)
              THEN EVERY (map split_term_tac subterms)
            end
  in
    split_term_tac (mk_tuple out_ts)
    THEN DETERM (TRY
      (Splitter.split_asm_tac [@{thm "split_if_asm"}] 1 THEN etac @{thm botE} 2))
  end
haftmann@30374
  1037
haftmann@30374
  1038
(* VERY LARGE SIMILARITY to function prove_param
haftmann@30374
  1039
-- join both functions
haftmann@30374
  1040
*) 
haftmann@30374
  1041
(* Tactic for one (mode option, term) parameter pair.  Without a mode nothing
   is done.  Otherwise a constant head is unfolded with eval_pred and its
   function definition for the mode, a free head is left alone, and the
   leading arguments (as many as there are parameter modes) are handled
   recursively.  Other head terms are not covered by the case distinction. *)
fun prove_param2 thy modes (NONE, t) = all_tac
  | prove_param2 thy modes (SOME (Mode (mode, is, ms)), t) =
      let
        val (head, all_args) = strip_comb t
        val params = fst (chop (length ms) all_args)
        val unfold_tac =
          (case head of
            Const (name, _) =>
              full_simp_tac
                (HOL_basic_ss addsimps [@{thm eval_pred}, function_definition thy name mode]) 1
          | Free _ => all_tac)
      in
        print_tac "before simplification in prove_args:"
        THEN debug_tac ("function : " ^ makestring head ^ " - mode" ^ makestring mode)
        THEN unfold_tac
        THEN print_tac "after simplification in prove_args"
        (* recurse into the parameter arguments *)
        THEN EVERY (map (prove_param2 thy modes) (ms ~~ params))
      end
haftmann@30374
  1057
haftmann@30374
  1058
(* Tactic for a positive premise t with a known mode.  A constant head must
   have an entry in modes: then bindE is eliminated, paired quantifiers are
   split, the elimination rule for the mode is applied, and the parameters
   are handled by prove_param2.  Any other head only eliminates bindE.
   A missing mode is reported as an error. *)
fun prove_expr2 thy modes (SOME (Mode (mode, is, ms)), t) =
      (case strip_comb t of
        (Const (name, _), args) =>
          if not (AList.defined (op =) modes name)
          then error "Prove expr2 if case not implemented"
          else
            etac @{thm bindE} 1
            THEN REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
            THEN etac (elim_rule thy name mode) 1
            THEN EVERY (map (prove_param2 thy modes) (ms ~~ args))
      | _ => etac @{thm bindE} 1)
  | prove_expr2 _ _ _ = error "Prove expr2 not implemented"
haftmann@30374
  1069
haftmann@30374
  1070
(* Tactic for a side condition t: collects the constants occurring in t that
   have an entry in modes, and simplifies subgoal 1 with eval_pred plus the
   function definitions of those constants for the mode in which all
   arguments are inputs.  The original author notes that better control over
   this simplification step is needed. *)
fun prove_sidecond2 thy modes t =
  let
    val _ = tracing ("prove_sidecond:" ^ makestring t)
    (* a known predicate is recorded and its arguments are not searched *)
    fun preds_of tm acc =
      (case strip_comb tm of
        (Const (name, T), args) =>
          if AList.defined (op =) modes name then (name, T) :: acc
          else fold preds_of args acc
      | _ => acc)
    val preds = preds_of t []
    val _ = tracing ("preds: " ^ makestring preds)
    fun full_mode_def (name, T) =
      function_definition thy name ([], 1 upto length (binder_types T))
    val defs = map full_mode_def preds
    val simpset = HOL_basic_ss' addsimps (@{thm eval_pred} :: defs)
  in
    full_simp_tac simpset 1
    THEN print_tac "after sidecond2 simplification"
  end
haftmann@30374
  1088
  
haftmann@30374
  1089
(* Tactic for the i-th clause (ts, ps) of pred in mode (iss, is), used for the
   second direction of the equation.  After eliminating bindE and singleE' it
   processes the premises one at a time, in the order chosen by
   select_mode_prem, and closes the goal with the i-th preprocessed
   introduction rule of pred.
   FIXME (kept from the original): the introduction rule should additionally
   be simplified with Predicate.memb_code by the preprocessor, and the use of
   Library.foldl_map is marked for replacement. *)
fun prove_clause2 thy all_vs param_vs modes (iss, is) (ts, ps) pred i =
  let
    (* modes extended by an entry for every parameter that carries a mode *)
    val modes' = modes @ List.mapPartial
      (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
      (param_vs ~~ iss)
    (* a term that is not a constructor term is replaced by a fresh variable,
       and the corresponding equation is recorded *)
    fun check_constrt ((names, eqs), t) =
      if is_constrt thy t then ((names, eqs), t)
      else
        let
          val fresh = Name.variant names "x"
          val v = Free (fresh, fastype_of t)
        in ((fresh :: names, HOLogic.mk_eq (v, t) :: eqs), v) end
    val pred_intro_rule =
      nth (pred_intros thy pred) (i - 1)
      |> preprocess_intro thy
      |> (fn thm => hd (ind_set_codegen_preproc thy [thm]))
    val (in_ts, _) = get_args is ts
    val (_, in_ts') = Library.foldl_map check_constrt ((all_vs, []), in_ts)

    val simp1 = asm_full_simp_tac HOL_basic_ss' 1
    val inject_tac = REPEAT_DETERM (etac @{thm Pair_inject} 1)

    (* no premises left: match the outputs and solve with the intro rule *)
    fun prove_prems2 out_ts vs [] =
          print_tac "before prove_match2 - last call:"
          THEN prove_match2 thy out_ts
          THEN print_tac "after prove_match2 - last call:"
          THEN etac @{thm singleE} 1
          THEN inject_tac
          THEN simp1
          THEN inject_tac
          THEN simp1
          THEN SOLVED (print_tac "state before applying intro rule:"
            THEN rtac pred_intro_rule 1
            (* How to handle equality correctly? *)
            THEN print_tac "state before assumption matching"
            THEN REPEAT (atac 1 ORELSE
              (CHANGED simp1 THEN print_tac "state after simp_tac:")))
      | prove_prems2 out_ts vs ps =
          let
            val vs' = distinct (op =) (flat (vs :: map term_vs out_ts))
            val SOME (p, mode as SOME (Mode ((iss, js), _, param_modes))) =
              select_mode_prem thy modes' vs' ps
            val ps' = filter_out (equal p) ps
            (* tactic for the remaining premises, fed with the outputs of us *)
            fun continue_with us = prove_prems2 (snd (get_args js us)) vs' ps'
            val rest_tac =
              (case p of
                Prem (us, t) =>
                  let val rec_tac = continue_with us
                  in prove_expr2 thy modes (mode, t) THEN rec_tac end
              | Negprem (us, t) =>
                  let
                    val rec_tac = continue_with us
                    val (head, params) = strip_comb t
                    val neg_tac =
                      (case head of
                        Const (c, _) =>
                          full_simp_tac
                            (HOL_basic_ss addsimps [function_definition thy c (iss, js)]) 1
                          THEN etac @{thm not_predE} 1
                          THEN EVERY (map (prove_param2 thy modes) (param_modes ~~ params))
                      | _ => etac @{thm not_predE'} 1)
                  in
                    print_tac "before neg prem 2"
                    THEN etac @{thm bindE} 1
                    THEN neg_tac
                    THEN rec_tac
                  end
              | Sidecond t =>
                  etac @{thm bindE} 1
                  THEN etac @{thm if_predE} 1
                  THEN prove_sidecond2 thy modes t
                  THEN prove_prems2 [] vs' ps')
          in
            print_tac "before prove_match2:"
            THEN prove_match2 thy out_ts
            THEN print_tac "after prove_match2:"
            THEN rest_tac
          end
    val prems_tac = prove_prems2 in_ts' param_vs ps
  in
    print_tac "starting prove_clause2"
    THEN etac @{thm bindE} 1
    THEN etac @{thm singleE'} 1
    THEN TRY (etac @{thm Pair_inject} 1)
    THEN print_tac "after singleE':"
    THEN prems_tac
  end;
haftmann@30374
  1170
 
haftmann@30374
  1171
(* Tactic for the second direction of the equation of pred in mode: applies
   the introduction rule for the mode and then proves the clauses in order
   with prove_clause2.  Every clause except the last first eliminates supE. *)
fun prove_other_direction thy all_vs param_vs modes clauses (pred, mode) =
  let
    val n = length clauses
    fun clause_tac (clause, i) =
      (if i < n then etac @{thm supE} 1 else all_tac)
      THEN prove_clause2 thy all_vs param_vs modes mode clause pred i
  in
    DETERM (TRY (rtac @{thm unit.induct} 1))
    THEN REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all}))
    THEN rtac (intro_rule thy pred mode) 1
    THEN EVERY (map clause_tac (clauses ~~ (1 upto n)))
  end;
haftmann@30374
  1181
haftmann@30374
  1182
(* Proves the proposition t for pred in mode with Goal.prove, fixing every
   free variable name that occurs in t.  When the do_proofs flag is set, both
   directions are proved after applying pred_iffI; otherwise the goal is
   closed with mycheat_tac.  The flag is read before Goal.prove is invoked. *)
fun prove_pred thy all_vs param_vs modes clauses (((pred, T), mode), t) =
  let
    val ctxt = ProofContext.init thy
    val pred_clauses = the (AList.lookup (op =) clauses pred)
    val fixes = Term.fold_aterms (fn Free (x, _) => insert (op =) x | _ => I) t []
    fun full_proof _ =
      rtac @{thm pred_iffI} 1
      THEN prove_one_direction thy all_vs param_vs modes pred_clauses ((pred, T), mode)
      THEN print_tac "proved one direction"
      THEN prove_other_direction thy all_vs param_vs modes pred_clauses (pred, mode)
      THEN print_tac "proved other direction"
    fun cheat _ = mycheat_tac thy 1
  in
    Goal.prove ctxt fixes [] t (if ! do_proofs then full_proof else cheat)
  end;
haftmann@30374
  1196
haftmann@30374
  1197
(* Applies prove_pred to every ((predicate, mode), proposition) entry of pmts. *)
fun prove_preds thy all_vs param_vs modes clauses pmts =
  let val prove = prove_pred thy all_vs param_vs modes clauses
  in map prove pmts end
haftmann@30374
  1199
haftmann@30374
  1200
(* look for other place where this functionality was used before *)
haftmann@30374
  1201
(* Decomposes the conclusion of the rule intro, which must have the form
   "c $ u": returns the head of u together with its arguments, split into the
   first nparams (the parameters) and the remaining ones. *)
fun strip_intro_concl intro nparams =
  let
    val _ $ concl = Logic.strip_imp_concl intro
    val (head, arguments) = strip_comb concl
  in (head, chop nparams arguments) end
haftmann@30374
  1206
haftmann@30374
  1207
(* setup for alternative introduction and elimination rules *)
haftmann@30374
  1208
haftmann@30374
  1209
(* Registers thm as an alternative introduction rule under the name of the
   constant heading its conclusion; a rule that is already present
   (up to Thm.eq_thm) is not added again. *)
fun add_intro_thm thm thy =
  let
    val head = fst (strip_intro_concl (prop_of thm) 0)
    val name = fst (dest_Const head)
  in thy |> map_intro_rules (Symtab.insert_list Thm.eq_thm (name, thm)) end
haftmann@30374
  1212
haftmann@30374
  1213
(* Registers thm as the alternative elimination rule of the constant heading
   its first premise; the entry for that constant is set via Symtab.update. *)
fun add_elim_thm thm thy =
  let
    val major = HOLogic.dest_Trueprop (hd (prems_of thm))
    val name = fst (dest_Const (fst (strip_comb major)))
  in thy |> map_elim_rules (Symtab.update (name, thm)) end
haftmann@30374
  1217
haftmann@30374
  1218
haftmann@30374
  1219
(* special case: inductive predicate with no clauses *)
haftmann@30374
  1220
(* Special case of an inductive predicate without clauses: proves the
   introduction rule "False ==> pred x1 ... xn" and the elimination rule
   "pred x1 ... xn ==> (False ==> P) ==> P" and registers both as
   alternative rules. *)
fun noclause (predname, T) thy =
  let
    val ctxt = ProofContext.init thy
    val Ts = binder_types T
    val names =
      Name.variant_list [] (map (fn i => "x" ^ string_of_int i) (1 upto length Ts))
    val head =
      HOLogic.mk_Trueprop (list_comb (Const (predname, T), map2 (curry Free) names Ts))
    val thesis = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))
    val intro_t = Logic.mk_implies (@{prop False}, head)
    val elim_t =
      Logic.list_implies ([head, Logic.mk_implies (@{prop False}, thesis)], thesis)
    val intro_thm =
      Goal.prove ctxt names [] intro_t (fn _ => etac @{thm FalseE} 1)
    val elim_thm =
      Goal.prove ctxt ("P" :: names) [] elim_t (fn _ => etac (pred_elim thy predname) 1)
  in
    thy |> add_intro_thm intro_thm |> add_elim_thm elim_thm
  end
haftmann@30374
  1237
haftmann@30374
  1238
(*************************************************************************************)
haftmann@30374
  1239
(* main function *********************************************************************)
haftmann@30374
  1240
(*************************************************************************************)
haftmann@30374
  1241
haftmann@31124
  1242
(* Main entry point: derives and stores the code equations for the inductive
   predicate ind_name.

   Steps, in order:
   - determine the predicates involved, either from the inductive package or,
     failing that, from the registered alternative introduction rules;
   - add trivial rules (noclause) for predicates without introduction rules;
   - preprocess the introduction rules and recursively process every called
     predicate that has no modes recorded yet;
   - split the rules into clauses, infer modes (an explicitly given mode
     replaces the inferred modes of ind_name), create the definitions and
     record the modes;
   - compile the predicates, prove the resulting equations and store them
     as "<base name of ind_name>.equation" with the default code equation
     attribute.

   FIXME (kept from the original): the preprocessed rules should also be
   simplified with Predicate.memb_code by the preprocessor. *)
fun prove_equation ind_name mode thy =
  let
    val _ = tracing ("starting prove_equation' with " ^ ind_name)
    val (prednames, preds) =
      (case try (InductivePackage.the_inductive (ProofContext.init thy)) ind_name of
        SOME info =>
          let val preds = info |> snd |> #preds
          in (map (fst o dest_Const) preds, map ((apsnd Logic.unvarifyT) o dest_Const) preds) end
      | NONE =>
          let
            val pred = Symtab.lookup (#intro_rules (IndCodegenData.get thy)) ind_name
              |> the |> hd |> prop_of
              |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop |> strip_comb
              |> fst |> dest_Const |> apsnd Logic.unvarifyT
          in ([ind_name], [pred]) end)
    val thy' = fold (fn pred as (predname, _) => fn thy =>
      if null (pred_intros thy predname) then noclause pred thy else thy) preds thy
    val intrs = map (preprocess_intro thy') (maps (pred_intros thy') prednames)
      |> ind_set_codegen_preproc thy'
      |> map (Logic.unvarify o prop_of)
    val _ = tracing ("preprocessed intro rules:" ^ (makestring (map (cterm_of thy') intrs)))
    val name_of_calls = get_name_of_ind_calls_of_clauses thy' prednames intrs
    val _ = tracing ("calling preds: " ^ makestring name_of_calls)
    val _ = tracing "starting recursive compilations"
    (* direct table lookup instead of a linear search through all keys *)
    fun rec_call name thy =
      if Symtab.defined (#modes (IndCodegenData.get thy)) name then thy
      else prove_equation name NONE thy
    val thy'' = fold rec_call name_of_calls thy'
    val _ = tracing "returning from recursive calls"
    val _ = tracing "starting mode inference"
    val extra_modes = Symtab.dest (#modes (IndCodegenData.get thy''))
    val nparams = get_nparams thy'' ind_name
    val _ $ u = Logic.strip_imp_concl (hd intrs);
    val params = List.take (snd (strip_comb u), nparams);
    val param_vs = maps term_vs params
    val all_vs = terms_vs intrs
    (* classify a premise as positive premise, negated premise or side condition *)
    fun dest_prem t =
      (case strip_comb t of
        (v as Free _, ts) => if member (op =) params v then Prem (ts, v) else Sidecond t
      | (c as Const (@{const_name Not}, _), [t]) =>
          (case dest_prem t of
            Prem (ts, t) => Negprem (ts, t)
          | Negprem _ => error ("Double negation not allowed in premise: " ^ (makestring (c $ t)))
          | Sidecond t => Sidecond (c $ t))
      | (c as Const (s, _), ts) =>
          if is_ind_pred thy'' s then
            let val (ts1, ts2) = chop (get_nparams thy'' s) ts
            in Prem (ts2, list_comb (c, ts1)) end
          else Sidecond t
      | _ => Sidecond t)
    (* append the clause of intr to its predicate and record the arity:
       one optional argument count per parameter, plus the number of
       non-parameter arguments *)
    fun add_clause intr (clauses, arities) =
      let
        val _ $ t = Logic.strip_imp_concl intr;
        val (Const (name, T), ts) = strip_comb t;
        val (_, ts2) = chop nparams ts;
        val prems = map (dest_prem o HOLogic.dest_Trueprop) (Logic.strip_imp_prems intr);
        val (Ts, Us) = chop nparams (binder_types T)
      in
        (AList.update (op =) (name, these (AList.lookup (op =) clauses name) @
          [(ts2, prems)]) clauses,
         AList.update (op =) (name, (map (fn U => (case strip_type U of
                     (Rs as _ :: _, Type ("bool", [])) => SOME (length Rs)
                   | _ => NONE)) Ts,
                 length Us)) arities)
      end;
    val (clauses, arities) = fold add_clause intrs ([], []);
    val modes = infer_modes thy'' extra_modes arities param_vs clauses
    val _ = print_arities arities;
    val _ = print_modes modes;
    val modes = if (is_some mode) then AList.update (op =) (ind_name, [the mode]) modes else modes
    val _ = print_modes modes
    val thy''' = fold (create_definitions preds nparams) modes thy''
      |> map_modes (fold Symtab.update_new modes)
    val clauses' = map (fn (s, cls) => (s, (the (AList.lookup (op =) preds s), cls))) clauses
    val _ = tracing "compiling predicates..."
    val ts = compile_preds thy''' all_vs param_vs (extra_modes @ modes) clauses'
    val _ = tracing "returned term from compile_preds"
    val pred_mode = maps (fn (s, (T, _)) => map (pair (s, T)) ((the o AList.lookup (op =) modes) s)) clauses'
    val _ = tracing "starting proof"
    val result_thms = prove_preds thy''' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts))
    val (_, thy'''') = yield_singleton PureThy.add_thmss
      ((Binding.qualify true (Long_Name.base_name ind_name) (Binding.name "equation"), result_thms),
        [Attrib.attribute_i thy''' Code.add_default_eqn_attrib]) thy'''
  in
    thy''''
  end
haftmann@30374
  1327
haftmann@30374
  1328
(* Records nparams as the number of parameters of pred in the theory data. *)
fun set_nparams (pred, nparams) thy =
  thy |> map_nparams (Symtab.update (pred, nparams))
haftmann@30374
  1329
haftmann@30374
  1330
(* Diagnostic output: traces, for every predicate that has alternative
   introduction or elimination rules registered, its introduction rules
   (in reversed table order) and its elimination rule, if any.
   Returns the theory unchanged. *)
fun print_alternative_rules thy =
  let
    val d = IndCodegenData.get thy
    val preds = (Symtab.keys (#intro_rules d)) union (Symtab.keys (#elim_rules d))
    val _ = tracing ("preds: " ^ (makestring preds))
    (* List.app is used for the side effects; no throwaway results are built *)
    fun print pred =
      (tracing ("predicate: " ^ pred);
       tracing ("introrules: ");
       List.app (fn thm => tracing (makestring thm))
         (rev (Symtab.lookup_list (#intro_rules d) pred));
       tracing ("casesrule: ");
       tracing (makestring (Symtab.lookup (#elim_rules d) pred)))
  in
    List.app print preds;
    thy
  end;
haftmann@30374
  1344
haftmann@30374
  1345
bulwahn@31106
  1346
(* generation of case rules from user-given introduction rules *)
bulwahn@31106
  1347
haftmann@31124
  1348
(* Builds the statement of a cases rule from the given introduction rules.
   A fresh variable based on the name "thesis" serves as the conclusion, and
   fresh variables a1, a2, ... stand for the non-parameter arguments.  For
   every introduction rule one case is generated: equations between the
   fresh argument variables and the rule's arguments, followed by the rule's
   premises, imply the conclusion, with all non-parameter free variables
   universally quantified.  The assumption "pred params a1 ..." and all cases
   are added as assumptions to the context.
   Returns the predicate term, the conclusion and the extended context. *)
fun mk_casesrule introrules nparams ctxt =
  let
    val intros = map prop_of introrules
    val (pred, (params, args)) = strip_intro_concl (hd intros) nparams
    val ([propname], ctxt1) = Variable.variant_fixes ["thesis"] ctxt
    val prop = HOLogic.mk_Trueprop (Free (propname, HOLogic.boolT))
    val (argnames, ctxt2) = Variable.variant_fixes
      (map (fn i => "a" ^ string_of_int i) (1 upto (length args))) ctxt1
    val argvs = map2 (curry Free) argnames (map fastype_of args)
    fun mk_case intro =
      let
        val (_, (_, args)) = strip_intro_concl intro nparams
        val prems = Logic.strip_imp_prems intro
        val eqprems = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (argvs ~~ args)
        (* free variables of arguments and premises, except the parameters *)
        val frees = (fold o fold_aterms)
          (fn t as Free _ =>
              if member (op aconv) params t then I else insert (op aconv) t
           | _ => I) (args @ prems) []
      in fold Logic.all frees (Logic.list_implies (eqprems @ prems, prop)) end
    val assm = HOLogic.mk_Trueprop (list_comb (pred, params @ argvs))
    val cases = map mk_case intros
    val (_, ctxt3) = ProofContext.add_assms_i Assumption.assume_export
              [((Binding.name AutoBind.assmsN, []), map (fn t => (t, [])) (assm :: cases))]
              ctxt2
  in (pred, prop, ctxt3) end;
bulwahn@31106
  1373
bulwahn@31106
  1374
haftmann@31124
  1375
(** user interface **)
bulwahn@31106
  1376
haftmann@31124
  1377
local
haftmann@31124
  1378
haftmann@31124
  1379
(* Wraps f, a function from a theorem to a theory transformation, as a
   declaration attribute; the second argument of Context.mapping is the
   identity I. *)
fun attrib f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I);
haftmann@31124
  1380
haftmann@31124
  1381
(* Attribute form of add_elim_thm; used by generic_code_pred and by the
   code_ind_cases attribute in setup. *)
val add_elim_attrib = attrib add_elim_thm;
haftmann@31124
  1382
haftmann@31124
  1383
(* Opens a proof of the cases rule for the predicate denoted by raw_const,
   which prep_const turns into a constant name.  The goal is generated by
   mk_casesrule from the imported introduction rules.  Once it is proved, the
   theorem is noted with add_elim_attrib and prove_equation is run on the
   theory for the predicate, without a fixed mode. *)
fun generic_code_pred prep_const raw_const lthy =
  let
    val thy = ProofContext.theory_of lthy
    val const = prep_const thy raw_const
    val nparams = get_nparams thy const
    val ((_, fact), lthy') =
      Variable.import_thms true (pred_intros thy const) lthy
    val (pred, prop, goal_ctxt) = mk_casesrule fact nparams lthy'
    val predname = fst (dest_Const pred)
    (* expects exactly one proved theorem *)
    fun after_qed [[th]] ctxt =
      ctxt
      |> LocalTheory.note Thm.theoremK
           ((Binding.empty, [Attrib.internal (K add_elim_attrib)]), [th])
      |> snd
      |> LocalTheory.theory (prove_equation predname NONE)
  in
    Proof.theorem_i NONE after_qed [[(prop, [])]] goal_ctxt
  end;
haftmann@31124
  1402
haftmann@31124
  1403
(* Short local alias for OuterParse, used in the command definition below. *)
structure P = OuterParse
haftmann@31124
  1404
haftmann@31124
  1405
in
haftmann@31124
  1406
haftmann@31124
  1407
(* Variant of generic_code_pred that takes the constant name as given:
   the preparation function K I ignores the theory and returns its argument. *)
val code_pred = generic_code_pred (K I);
haftmann@31156
  1408
(* Variant of generic_code_pred that prepares the raw constant with
   Code.read_const; used by the code_pred command below. *)
val code_pred_cmd = generic_code_pred Code.read_const
haftmann@31124
  1409
haftmann@31124
  1410
(* Theory setup: installs the attributes code_ind_intros and code_ind_cases,
   which register alternative introduction and elimination rules. *)
val setup =
  let
    val intros_setup =
      Attrib.setup @{binding code_ind_intros} (Scan.succeed (attrib add_intro_thm))
        "adding alternative introduction rules for code generation of inductive predicates"
    val cases_setup =
      Attrib.setup @{binding code_ind_cases} (Scan.succeed add_elim_attrib)
        "adding alternative elimination rules for code generation of inductive predicates"
  in intros_setup #> cases_setup end;
haftmann@31124
  1415
  (*FIXME name discrepancy in attribs and ML code*)
haftmann@31124
  1416
  (*FIXME intros should be better named intro*)
haftmann@31124
  1417
  (*FIXME why distinguished attribute for cases?*)
haftmann@31124
  1418
haftmann@31124
  1419
(* Registers the outer syntax command "code_pred": it parses one term group
   and passes it to code_pred_cmd, which opens the proof. *)
val _ =
  let
    val parse_const = P.term_group >> code_pred_cmd
  in
    OuterSyntax.local_theory_to_proof "code_pred"
      "prove equations for predicate specified by intro/elim rules"
      OuterKeyword.thy_goal parse_const
  end
haftmann@31124
  1422
haftmann@31124
  1423
end
haftmann@31124
  1424
haftmann@31124
  1425
(*FIXME
haftmann@31124
  1426
- Naming of auxiliary rules necessary?
haftmann@31124
  1427
*)
bulwahn@31106
  1428
haftmann@30374
  1429
end;