(*  Title:       src/HOL/Matrix/eq_codegen.ML
    Author:      obua
    Changeset:   15178:5f621aa35c25 (Fri, 03 Sep 2004 17:10:36 +0200)
    Child:       15531:08c8dad8e399
    Permissions: -rw-r--r--

Commit message: Matrix theory, linear programming.
*)

(* Instantiate the cterm ct with inst, a pair (type instantiation,
   term instantiation) in the form expected by Thm.instantiate.
   Thm.instantiate only acts on theorems, so ct is wrapped into the
   reflexivity theorem  ct == ct, that theorem is instantiated, and the
   left-hand side of the resulting meta-equality is taken again. *)
fun inst_cterm inst ct = fst (Drule.dest_equals
  (Thm.cprop_of (Thm.instantiate inst (reflexive ct))));

(* Instantiate only the type variables of a cterm.  The generated code
   (see SimprocsCodegen.inst_ty / mk_tyinst) calls this with a list of
   ((name, index), ctyp) pairs. *)
fun tyinst_cterm tyinst = inst_cterm (tyinst, []);

(* Debugging log: SimprocsCodegen.lookup adds to this list every term for
   which it finds no interpreted function.  It uses `ins`, so the list has
   no duplicates.  Nothing in the code visible here reads or resets it. *)
val bla = ref ([] : term list);

(******************************************************)
(*        Code generator for equational proofs        *)
(******************************************************)

(* Convert an object-level equation  thm : Trueprop (lhs = rhs)  into the
   meta-equality  lhs == rhs.
   The input shape is inferred from the callers: get_cases applies this to
   datatype case rewrites after HOLogic.dest_Trueprop / HOLogic.dest_eq.
   The conversion instantiates eq_reflection with the ctyp of lhs and with
   lhs, rhs (presumably  ?x = ?y ==> ?x == ?y  -- TODO confirm).  Its
   premise is then discharged with thm via Thm.implies_elim.
   A theorem of any other shape makes Thm.dest_comb fail. *)
fun my_mk_meta_eq thm =
  let
    val (_, eq) = Thm.dest_comb (cprop_of thm);   (* drop the Trueprop *)
    val (ct, rhs) = Thm.dest_comb eq;             (* (op = lhs) $ rhs *)
    val (_, lhs) = Thm.dest_comb ct               (* (op =) $ lhs *)
  in Thm.implies_elim (Drule.instantiate' [Some (ctyp_of_term lhs)]
    [Some lhs, Some rhs] eq_reflection) thm
  end;

(* SimprocsCodegen: generates ML source text from groups of defining
   meta-equations and compiles it with use_text (see use_simprocs_code).

   For each group (fname, d, eqns) it emits a function
     fun fname ct1 ... ctn = (case (term_of ct1, ..., term_of ctn) of ...)
   that pattern-matches on its argument terms.  Each clause returns a pair
   (ct, eq), where eq is the instantiated defining equation, chained with
   Thm.transitive to the equation generated for its right-hand side.  If d
   holds, a  handle Match  fallback returns  (ct, Thm.reflexive ct). *)
structure SimprocsCodegen =
struct

(* Theorems referenced by the generated code.  use_simprocs_code assigns
   this just before compiling, and the generated text reads it back via
   !SimprocsCodegen.simp_thms. *)
val simp_thms = ref ([] : thm list);

(*** Pretty-printing helpers for emitting ML source ***)

(* Enclose ps in parentheses if b, otherwise just group it as a block. *)
fun parens b = if b then Pretty.enclose "(" ")" else Pretty.block;

(* Emit  val <xs> = <ps>;  .  f decides how the binder list is bracketed and
   is told whether there is more than one name. *)
fun gen_mk_val f xs ps = Pretty.block ([Pretty.str "val ",
  f (length xs > 1) (flat
    (separate [Pretty.str ",", Pretty.brk 1] (map (single o Pretty.str) xs))),
  Pretty.str " =", Pretty.brk 1] @ ps @ [Pretty.str ";"]);

(* val x = ...;   or   val (x, y, ...) = ...;   (tuple pattern) *)
val mk_val = gen_mk_val parens;
(* val [x, y, ...] = ...;   (list pattern, always bracketed) *)
val mk_vall = gen_mk_val (K (Pretty.enclose "[" "]"));

(* Prime identifiers that are reserved words in ML. *)
fun rename s = if s mem ThmDatabase.ml_reserved then s ^ "'" else s;

(* Suggested ML identifier for the cterm bound to a subterm. *)
fun mk_decomp_name (Var ((s, i), _)) = rename (if i=0 then s else s ^ string_of_int i)
  | mk_decomp_name (Const (s, _)) = rename (Codegen.mk_id (Sign.base_name s))
  | mk_decomp_name _ = "ct";

(* Generate code that takes apart the cterm named v, whose term is t, with
   Thm.dest_comb / Thm.dest_abs, and bind the cterms of its subterms.
   The accumulator (vs, bs, ps) holds:
     vs: identifiers already in use,
     bs: (subterm, identifier) bindings collected so far,
     ps: generated val declarations.
   Vars are always bound; constants only if cn holds.  Abstractions bind
   their variable as a Free.  Splitting an application is undone when it
   yields no new binding.  Free subterms are not handled (Match). *)
fun decomp_term_code cn ((vs, bs, ps), (v, t)) =
  if exists (equal t o fst) bs then (vs, bs, ps)
  else (case t of
      Var _ => (vs, bs @ [(t, v)], ps)
    | Const _ => (vs, if cn then bs @ [(t, v)] else bs, ps)
    | Bound _ => (vs, bs, ps)
    | Abs (s, T, t) =>
      let
        val v1 = variant vs s;
        val v2 = variant (v1 :: vs) (mk_decomp_name t)
      in
        decomp_term_code cn ((v1 :: v2 :: vs,
          bs @ [(Free (s, T), v1)],
          ps @ [mk_val [v1, v2] [Pretty.str "Thm.dest_abs", Pretty.brk 1,
            Pretty.str "None", Pretty.brk 1, Pretty.str v]]), (v2, t))
      end
    | t $ u =>
      let
        val v1 = variant vs (mk_decomp_name t);
        val v2 = variant (v1 :: vs) (mk_decomp_name u);
        val (vs', bs', ps') = decomp_term_code cn ((v1 :: v2 :: vs, bs,
          ps @ [mk_val [v1, v2] [Pretty.str "Thm.dest_comb", Pretty.brk 1,
            Pretty.str v]]), (v1, t));
        val (vs'', bs'', ps'') = decomp_term_code cn ((vs', bs', ps'), (v2, u))
      in
        (* nothing bound below this node: drop the dest_comb again *)
        if bs'' = bs then (vs, bs, ps) else (vs'', bs'', ps'')
      end);

(* Drop the first character (the leading apostrophe of a type variable). *)
val strip_tv = implode o tl o explode;

(* Suggested ML identifier for the ctyp bound to a type. *)
fun mk_decomp_tname (TVar ((s, i), _)) =
      strip_tv ((if i=0 then s else s ^ string_of_int i) ^ "T")
  | mk_decomp_tname (Type (s, _)) = Codegen.mk_id (Sign.base_name s) ^ "T"
  | mk_decomp_tname _ = "cT";

(* Type-level analogue of decomp_term_code.  It takes apart the ctyp named v
   with Thm.dest_ctyp and binds each type variable (by indexname) to the
   identifier of its ctyp.  TFree types are not handled (Match). *)
fun decomp_type_code ((vs, bs, ps), (v, TVar (ixn, _))) =
      if exists (equal ixn o fst) bs then (vs, bs, ps)
      else (vs, bs @ [(ixn, v)], ps)
  | decomp_type_code ((vs, bs, ps), (v, Type (_, Ts))) =
      let
        val vs' = variantlist (map mk_decomp_tname Ts, vs);
        val (vs'', bs', ps') =
          foldl decomp_type_code ((vs @ vs', bs, ps @
            [mk_vall vs' [Pretty.str "Thm.dest_ctyp", Pretty.brk 1,
              Pretty.str v]]), vs' ~~ Ts)
      in
        if bs' = bs then (vs, bs, ps) else (vs'', bs', ps')
      end;

(* Bind a fresh identifier (based on s) to  dest v, then decompose x under
   that identifier with decomp.  Everything is discarded again if no new
   binding results. *)
fun gen_mk_bindings s dest decomp ((vs, bs, ps), (v, x)) =
  let
    val s' = variant vs s;
    val (vs', bs', ps') = decomp ((s' :: vs, bs, ps @
      [mk_val [s'] (dest v)]), (s', x))
  in
    if bs' = bs then (vs, bs, ps) else (vs', bs', ps')
  end;

(* Bindings for the subterms (Vars and constants) of a theorem's proposition. *)
val mk_term_bindings = gen_mk_bindings "ct"
  (fn s => [Pretty.str "cprop_of", Pretty.brk 1, Pretty.str s])
  (decomp_term_code true);

(* Bindings for the type variables occurring in the type of a cterm. *)
val mk_type_bindings = gen_mk_bindings "cT"
  (fn s => [Pretty.str "Thm.ctyp_of_term", Pretty.brk 1, Pretty.str s])
  decomp_type_code;

(* ML pattern matching a term.  Constants become  Const ("name", _),
   applications become  _ $ _  chains, and everything else becomes _ .
   b: parenthesise an application. *)
fun pretty_pattern b (Const (s, _)) = Pretty.block [Pretty.str "Const",
      Pretty.brk 1, Pretty.str ("(\"" ^ s ^ "\", _)")]
  | pretty_pattern b (t as _ $ _) = parens b
      (flat (separate [Pretty.str " $", Pretty.brk 1]
        (map (single o pretty_pattern true) (op :: (strip_comb t)))))
  | pretty_pattern b _ = Pretty.str "_";

(* All distinct Const subterms of t, together with their types. *)
fun term_consts' t = foldl_aterms
  (fn (cs, c as Const _) => c ins cs | (cs, _) => cs) ([], t);

(* Left-nested application  s (... (s p q1) ...) qn .  Inner applications are
   always parenthesised; the outermost only if b. *)
fun mk_apps s b p [] = p
  | mk_apps s b p (q :: qs) =
      mk_apps s b (parens (b orelse not (null qs))
        [Pretty.str s, Pretty.brk 1, p, Pretty.brk 1, q]) qs;

(* val eq = Thm.reflexive ct; *)
fun mk_refleq eq ct = mk_val [eq] [Pretty.str ("Thm.reflexive " ^ ct)];

(* (("s", i), s')  -- one entry of a type instantiation list *)
fun mk_tyinst ((s, i), s') =
  Pretty.block [Pretty.str ("((" ^ quote s ^ ","), Pretty.brk 1,
    Pretty.str (string_of_int i ^ "),"), Pretty.brk 1,
    Pretty.str (s' ^ ")")];

(* Code for the cterm named s, type-instantiated with tyinst_cterm when t
   contains type variables.  ty_bs maps every such type variable to the
   identifier of its ctyp; `the` fails if one is missing. *)
fun inst_ty b ty_bs t s = (case term_tvars t of
    [] => Pretty.str s
  | Ts => parens b [Pretty.str "tyinst_cterm", Pretty.brk 1,
      Pretty.list "[" "]" (map (fn (ixn, _) => mk_tyinst
        (ixn, the (assoc (ty_bs, ixn)))) Ts),
      Pretty.brk 1, Pretty.str s]);

(* ML expression rebuilding the cterm of t with Thm.capply / Thm.cabs.
   ts: atomic term -> identifier of its cterm.
   xs: identifiers for loose bound variables; Bound i is the i-th, and the
       innermost abstraction comes first.
   vals: binding list (as in mk_simpl_code).  Type-instantiating bindings
         for polymorphic constants and abstraction variables are appended
         to it.
   b: parenthesise the result. *)
fun mk_cterm_code b ty_bs ts xs (vals, t $ u) =
      let
        val (vals', p1) = mk_cterm_code true ty_bs ts xs (vals, t);
        val (vals'', p2) = mk_cterm_code true ty_bs ts xs (vals', u)
      in
        (vals'', parens b [Pretty.str "Thm.capply", Pretty.brk 1,
          p1, Pretty.brk 1, p2])
      end
  | mk_cterm_code b ty_bs ts xs (vals, Abs (s, T, t)) =
      let
        val u = Free (s, T);
        val Some s' = assoc (ts, u);
        val p = Pretty.str s';
        val (vals', p') = mk_cterm_code true ty_bs ts (p :: xs)
          (if null (typ_tvars T) then vals
           else vals @ [(u, (("", s'), [mk_val [s'] [inst_ty true ty_bs u s']]))], t)
      in (vals',
        parens b [Pretty.str "Thm.cabs", Pretty.brk 1, p, Pretty.brk 1, p'])
      end
  | mk_cterm_code b ty_bs ts xs (vals, Bound i) = (vals, nth_elem (i, xs))
  | mk_cterm_code b ty_bs ts xs (vals, t) = (case assoc (vals, t) of
        None =>
          let val Some s = assoc (ts, t)
          in (if is_Const t andalso not (null (term_tvars t)) then
              vals @ [(t, (("", s), [mk_val [s] [inst_ty true ty_bs t s]]))]
            else vals, Pretty.str s)
          end
      | Some ((_, s), _) => (vals, Pretty.str s));

(* Table mapping the name of each datatype's case combinator (the head
   constant of its first case rewrite rule) to all of its case rewrite
   rules, converted to meta-equalities. *)
fun get_cases sg =
  Symtab.foldl (fn (tab, (k, {case_rewrites, ...})) => Symtab.update_new
    ((fst (dest_Const (head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop
      (prop_of (hd case_rewrites))))))), map my_mk_meta_eq case_rewrites), tab))
        (Symtab.empty, DatatypePackage.get_datatypes_sg sg);

(* For a case rule  c f1 ... fn (C x1 ... xm) == rhs  return
   ([f1, ..., fn], "C", [x1, ..., xm], C x1 ... xm).
   Raises Bind if the last argument is not headed by a constant. *)
fun decomp_case th =
  let
    val (lhs, _) = Logic.dest_equals (prop_of th);
    val (f, ts) = strip_comb lhs;
    val (us, u) = split_last ts;
    val (Const (s, _), vs) = strip_comb u
  in (us, s, vs, u) end;

(* Rename the Vars of t whose names clash with vs; returns the extended
   name list and the renamed term.
   NOTE(review): this shadows the rename above for the rest of the file and
   is not called anywhere in the code visible here. *)
fun rename vs t =
  let
    fun mk_subst ((vs, subs), Var ((s, i), T)) =
      let val s' = variant vs s
      in if s = s' then (vs, subs)
        else (s' :: vs, ((s, i), Var ((s', i), T)) :: subs)
      end;
    val (vs', subs) = foldl mk_subst ((vs, []), term_vars t)
  in (vs', subst_Vars subs t) end;

(* t is an instance of u: t equals u after instantiating u's type variables
   by matching u's type against t's type. *)
fun is_instance sg t u = t = subst_TVars_Vartab
  (Type.typ_match (Sign.tsig_of sg) (Vartab.empty,
    (fastype_of u, fastype_of t))) u handle Type.TYPE_MATCH => false;

(* Look up the interpreted function for t: the data x of the first entry
   (pattern, x) in fs such that t is an instance of pattern.
   NOTE(review): every miss is recorded in the top-level debug list bla,
   which looks like leftover debugging instrumentation. *)
fun lookup sg fs t = (case Library.find_first (is_instance sg t o fst) fs of
    None => (bla := (t ins !bla); None)
  | Some (_, x) => Some x);

(* No constant occurring in t is interpreted. *)
fun unint sg fs t = forall (is_none o lookup sg fs) (term_consts' t);

(* Lay out  s <xs> in <ys> end  (s is "let" or "local"), with each
   declaration on its own line, indented by i. *)
fun mk_let s i xs ys =
  Pretty.blk (0, [Pretty.blk (i, separate Pretty.fbrk (Pretty.str s :: xs)),
    Pretty.fbrk,
    Pretty.blk (i, ([Pretty.str "in", Pretty.fbrk] @ ys)),
    Pretty.fbrk, Pretty.str "end"]);

(*****************************************************************************)
(* Generate bindings for simplifying term t                                  *)
(* mkeq:     whether to generate reflexivity theorem for uninterpreted terms *)
(* case_tab: case combinator name -> case rewrite rules (see get_cases)      *)
(* fs:       interpreted functions                                           *)
(* ts:       atomic terms, mapped to the identifiers of their cterms         *)
(* ty_bs:    type variable -> identifier of its ctyp                         *)
(* thm_bs:   term -> identifier of its cterm (theorem bindings)              *)
(* vs:       used identifiers                                                *)
(* vals:     list of bindings of the form (t, ((eq, ct), ps)) where          *)
(*       eq: name of equational theorem ("" if none was generated)           *)
(*       ct: name of simplified cterm                                        *)
(*       ps: ML code for creating the above two items                        *)
(* Result: ((vs', vals'), (eq, ct)) for the term t itself.                   *)
(*****************************************************************************)

fun mk_simpl_code sg case_tab mkeq fs ts ty_bs thm_bs ((vs, vals), t) =
  (case assoc (vals, t) of
    Some ((eq, ct), ps) =>  (* binding already generated *)
      if mkeq andalso eq="" then
        (* an equation is now required: add a reflexivity theorem *)
        let val eq' = variant vs "eq"
        in ((eq' :: vs, overwrite (vals,
          (t, ((eq', ct), ps @ [mk_refleq eq' ct])))), (eq', ct))
        end
      else ((vs, vals), (eq, ct))
  | None => (case assoc (ts, t) of
      Some v =>  (* atomic term *)
        let val xs = if not (null (term_tvars t)) andalso is_Const t then
          [mk_val [v] [inst_ty false ty_bs t v]] else []
        in
          if mkeq then
            let val eq = variant vs "eq"
            in ((eq :: vs, vals @
              [(t, ((eq, v), xs @ [mk_refleq eq v]))]), (eq, v))
            end
          else ((vs, if null xs then vals else vals @
            [(t, (("", v), xs))]), ("", v))
        end
    | None =>  (* complex term *)
        let val (f as Const (cname, _), us) = strip_comb t
        in case Symtab.lookup (case_tab, cname) of
            Some cases =>  (* case expression *)
              let
                (* u is the scrutinee, us' the branch functions *)
                val (us', u) = split_last us;
                val b = unint sg fs u;
                val ((vs1, vals1), (eq, ct)) =
                  mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs ((vs, vals), u);
                val xs = variantlist (replicate (length us') "f", vs1);
                val (vals2, ps) = foldl_map
                  (mk_cterm_code false ty_bs ts []) (vals1, us');
                val fvals = map (fn (x, p) => mk_val [x] [p]) (xs ~~ ps);
                val uT = fastype_of u;
                val (us'', _, _, u') = decomp_case (hd cases);
                val (vs2, ty_bs', ty_vals) = mk_type_bindings
                  (mk_type_bindings ((vs1 @ xs, [], []),
                    (hd xs, fastype_of (hd us''))), (ct, fastype_of u'));
                val insts1 = map mk_tyinst ty_bs';
                (* bindings beyond index i are local to a single branch *)
                val i = length vals2;

                (* One branch  <constructor pattern> => let ... in (ct, eq) end . *)
                fun mk_case_code ((vs, vals), (f, (name, eqn))) =
                  let
                    val (fvs, cname, cvs, _) = decomp_case eqn;
                    val Ts = binder_types (fastype_of f);
                    val ys = variantlist (map (fst o fst o dest_Var) cvs, vs);
                    val cvs' = map Var (map (rpair 0) ys ~~ Ts);
                    val rs = cvs' ~~ cvs;
                    val lhs = list_comb (Const (cname, Ts ---> uT), cvs');
                    val rhs = foldl betapply (f, cvs');
                    val (vs', tm_bs, tm_vals) = decomp_term_code false
                      ((vs @ ys, [], []), (ct, lhs));
                    val ((vs'', all_vals), (eq', ct')) = mk_simpl_code sg case_tab
                      false fs (tm_bs @ ts) ty_bs thm_bs ((vs', vals), rhs);
                    val (old_vals, eq_vals) = splitAt (i, all_vals);
                    val vs''' = vs @ filter (fn v => exists
                      (fn (_, ((v', _), _)) => v = v') old_vals) (vs'' \\ vs');
                    val insts2 = map (fn (t, s) => Pretty.block [Pretty.str "(",
                      inst_ty false ty_bs' t (the (assoc (thm_bs, t))), Pretty.str ",",
                      Pretty.brk 1, Pretty.str (s ^ ")")]) ((fvs ~~ xs) @
                        (map (fn (v, s) => (the (assoc (rs, v)), s)) tm_bs));
                    val eq'' = if null insts1 andalso null insts2 then Pretty.str name
                      else parens (eq' <> "") [Pretty.str
                          (if null cvs then "Thm.instantiate" else "Drule.instantiate"),
                        Pretty.brk 1, Pretty.str "(", Pretty.list "[" "]" insts1,
                        Pretty.str ",", Pretty.brk 1, Pretty.list "[" "]" insts2,
                        Pretty.str ")", Pretty.brk 1, Pretty.str name];
                    val eq''' = if eq' = "" then eq'' else
                      Pretty.block [Pretty.str "Thm.transitive", Pretty.brk 1,
                        eq'', Pretty.brk 1, Pretty.str eq']
                  in
                    ((vs''', old_vals), Pretty.block [pretty_pattern false lhs,
                      Pretty.str " =>",
                      Pretty.brk 1, mk_let "let" 2 (tm_vals @ flat (map (snd o snd) eq_vals))
                        [Pretty.str ("(" ^ ct' ^ ","), Pretty.brk 1, eq''', Pretty.str ")"]])
                  end;

                val case_names = map (fn i => Sign.base_name cname ^ "_" ^
                  string_of_int i) (1 upto length cases);
                val ((vs3, vals3), case_ps) = foldl_map mk_case_code
                  ((vs2, vals2), us' ~~ (case_names ~~ cases));
                val eq' = variant vs3 "eq";
                val ct' = variant (eq' :: vs3) "ct";
                val eq'' = variant (eq' :: ct' :: vs3) "eq";
                val case_vals =
                  fvals @ ty_vals @
                  [mk_val [ct', eq'] ([Pretty.str "(case", Pretty.brk 1,
                    Pretty.str ("term_of " ^ ct ^ " of"), Pretty.brk 1] @
                    flat (separate [Pretty.brk 1, Pretty.str "| "]
                      (map single case_ps)) @ [Pretty.str ")"])]
              in
                if b then
                  ((eq' :: ct' :: vs3, vals3 @
                     [(t, ((eq', ct'), case_vals))]), (eq', ct'))
                else
                  (* the scrutinee was rewritten as well: combine its equation
                     with the case equation *)
                  let val ((vs4, vals4), (_, ctcase)) = mk_simpl_code sg case_tab false
                    fs ts ty_bs thm_bs ((eq' :: eq'' :: ct' :: vs3, vals3), f)
                  in
                    ((vs4, vals4 @ [(t, ((eq'', ct'), case_vals @
                       [mk_val [eq''] [Pretty.str "Thm.transitive", Pretty.brk 1,
                          Pretty.str "(Thm.combination", Pretty.brk 1,
                          Pretty.str "(Thm.reflexive", Pretty.brk 1,
                          mk_apps "Thm.capply" true (Pretty.str ctcase)
                            (map Pretty.str xs),
                          Pretty.str ")", Pretty.brk 1, Pretty.str (eq ^ ")"),
                          Pretty.brk 1, Pretty.str eq']]))]), (eq'', ct'))
                  end
              end

          | None =>  (* ordinary application  f u1 ... un *)
            let
              val b = forall (unint sg fs) us;
              val (q, eqs) = foldl_map
                (mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs) ((vs, vals), us);
              (* looked up once; the earlier code called lookup twice here *)
              val fdef = lookup sg fs f;
              val ((vs', vals'), (eqf, ctf)) = if is_some fdef andalso b
                then (q, ("", ""))
                else mk_simpl_code sg case_tab (not b) fs ts ty_bs thm_bs (q, f);
              val ct = variant vs' "ct";
              val eq = variant (ct :: vs') "eq";
              val ctv = mk_val [ct] [mk_apps "Thm.capply" false
                (Pretty.str ctf) (map (Pretty.str o snd) eqs)];
              fun combp b = mk_apps "Thm.combination" b
                (Pretty.str eqf) (map (Pretty.str o fst) eqs)
            in
              case (fdef, b) of
                (None, true) =>  (* completely uninterpreted *)
                  if mkeq then ((ct :: eq :: vs', vals' @
                    [(t, ((eq, ct), [ctv, mk_refleq eq ct]))]), (eq, ct))
                  else ((ct :: vs', vals' @ [(t, (("", ct), [ctv]))]), ("", ct))
              | (None, false) =>  (* function uninterpreted *)
                  ((eq :: ct :: vs', vals' @
                     [(t, ((eq, ct), [ctv, mk_val [eq] [combp false]]))]), (eq, ct))
              | (Some (s, _, _), true) =>  (* arguments uninterpreted *)
                  ((eq :: ct :: vs', vals' @
                     [(t, ((eq, ct), [mk_val [ct, eq] (separate (Pretty.brk 1)
                       (Pretty.str s :: map (Pretty.str o snd) eqs))]))]), (eq, ct))
              | (Some (s, _, _), false) =>  (* function and arguments interpreted *)
                  let val eq' = variant (eq :: ct :: vs') "eq"
                  in ((eq' :: eq :: ct :: vs', vals' @ [(t, ((eq', ct),
                    [mk_val [ct, eq] (separate (Pretty.brk 1)
                       (Pretty.str s :: map (Pretty.str o snd) eqs)),
                     mk_val [eq'] [Pretty.str "Thm.transitive", Pretty.brk 1,
                       combp true, Pretty.brk 1, Pretty.str eq]]))]), (eq', ct))
                  end
            end
        end));

(* Left- and right-hand side of a meta-equation theorem. *)
fun lhs_of thm = fst (Logic.dest_equals (prop_of thm));
fun rhs_of thm = snd (Logic.dest_equals (prop_of thm));

(* Generate  local <case theorems, theorem bindings> in fun ... and ... end
   for one group fs' of mutually recursive functions.
   fs:  all interpreted functions, as (head term, (fname, d, eqns)).
   fs': the entries of fs to define here.  eqns is a list of (name, thm)
        meta-equations; d requests the  handle Match  fallback. *)
fun mk_funs_code sg case_tab fs fs' =
  let
    (* case combinators used in the equations, with their rule names *)
    val case_thms = mapfilter (fn s => (case Symtab.lookup (case_tab, s) of
        None => None
      | Some thms => Some (unsuffix "_case" (Sign.base_name s) ^ ".cases",
          map (fn i => Sign.base_name s ^ "_" ^ string_of_int i)
            (1 upto length thms) ~~ thms)))
      (foldr add_term_consts (map (prop_of o snd)
        (flat (map (#3 o snd) fs')), []));
    val case_vals = map (fn (s, cs) => mk_vall (map fst cs)
      [Pretty.str "map my_mk_meta_eq", Pretty.brk 1,
       Pretty.str ("(thms \"" ^ s ^ "\")")]) case_thms;
    val (vs, thm_bs, thm_vals) = foldl mk_term_bindings (([], [], []),
      flat (map (map (apsnd prop_of) o #3 o snd) fs') @
      map (apsnd prop_of) (flat (map snd case_thms)));

    (* One function; prfx is "fun " for the first one and "and " after it. *)
    fun mk_fun_code (prfx, (fname, d, eqns)) =
      let
        val (f, ts) = strip_comb (lhs_of (snd (hd eqns)));
        val args = variantlist (replicate (length ts) "ct", vs);
        val (vs', ty_bs, ty_vals) = foldl mk_type_bindings
          ((vs @ args, [], []), args ~~ map fastype_of ts);
        val insts1 = map mk_tyinst ty_bs;

        (* One clause  <argument patterns> => let ... in (ct, eq) end . *)
        fun mk_eqn_code (name, eqn) =
          let
            val (_, argts) = strip_comb (lhs_of eqn);
            val (vs'', tm_bs, tm_vals) = foldl (decomp_term_code false)
              ((vs', [], []), args ~~ argts);
            val ((_, eq_vals), (eq, ct)) = mk_simpl_code sg case_tab false fs
              (tm_bs @ filter_out (is_Var o fst) thm_bs) ty_bs thm_bs
              ((vs'', []), rhs_of eqn);
            val insts2 = map (fn (t, s) => Pretty.block [Pretty.str "(",
              inst_ty false ty_bs t (the (assoc (thm_bs, t))), Pretty.str ",", Pretty.brk 1,
              Pretty.str (s ^ ")")]) tm_bs
            val eq' = if null insts1 andalso null insts2 then Pretty.str name
              else parens (eq <> "") [Pretty.str "Thm.instantiate",
                Pretty.brk 1, Pretty.str "(", Pretty.list "[" "]" insts1,
                Pretty.str ",", Pretty.brk 1, Pretty.list "[" "]" insts2,
                Pretty.str ")", Pretty.brk 1, Pretty.str name];
            val eq'' = if eq = "" then eq' else
              Pretty.block [Pretty.str "Thm.transitive", Pretty.brk 1,
                eq', Pretty.brk 1, Pretty.str eq]
          in
            Pretty.block [parens (length argts > 1)
                (Pretty.commas (map (pretty_pattern false) argts)),
              Pretty.str " =>",
              Pretty.brk 1, mk_let "let" 2 (ty_vals @ tm_vals @ flat (map (snd o snd) eq_vals))
                [Pretty.str ("(" ^ ct ^ ","), Pretty.brk 1, eq'', Pretty.str ")"]]
          end;

        (* fallback: rebuild  f args  and return it with a reflexivity proof *)
        val default = if d then
            let
              val Some s = assoc (thm_bs, f);
              val ct = variant vs' "ct"
            in [Pretty.brk 1, Pretty.str "handle", Pretty.brk 1,
              Pretty.str "Match =>", Pretty.brk 1, mk_let "let" 2
                (ty_vals @ (if null (term_tvars f) then [] else
                   [mk_val [s] [inst_ty false ty_bs f s]]) @
                 [mk_val [ct] [mk_apps "Thm.capply" false (Pretty.str s)
                    (map Pretty.str args)]])
                [Pretty.str ("(" ^ ct ^ ","), Pretty.brk 1,
                 Pretty.str "Thm.reflexive", Pretty.brk 1, Pretty.str (ct ^ ")")]]
            end
          else []
      in
        ("and ", Pretty.block (separate (Pretty.brk 1)
            (Pretty.str (prfx ^ fname) :: map Pretty.str args) @
          [Pretty.str " =", Pretty.brk 1, Pretty.str "(case", Pretty.brk 1,
           Pretty.list "(" ")" (map (fn s => Pretty.str ("term_of " ^ s)) args),
           Pretty.str " of", Pretty.brk 1] @
          flat (separate [Pretty.brk 1, Pretty.str "| "]
            (map (single o mk_eqn_code) eqns)) @ [Pretty.str ")"] @ default))
      end;

    val (_, decls) = foldl_map mk_fun_code ("fun ", map snd fs')
  in
    mk_let "local" 2 (case_vals @ thm_vals) (separate Pretty.fbrk decls)
  end;

(* Code for all function groups in eqns, a list of (fname, d, [(name, thm)]).
   A call graph is built from a dummy root "" to every function and from
   each function to the interpreted functions used in its equations.  Its
   strongly connected components are emitted as separate mk_funs_code
   groups, ordered by the reachability order keys (presumably so that
   callees are defined before their callers -- TODO confirm against
   Graph.all_succs). *)
fun mk_simprocs_code sg eqns =
  let
    val case_tab = get_cases sg;
    fun get_head th = head_of (fst (Logic.dest_equals (prop_of th)));
    fun attach_term (x as (_, _, (_, th) :: _)) = (get_head th, x);
    val eqns' = map attach_term eqns;
    fun mk_node (s, _, (_, th) :: _) = (s, get_head th);
    fun mk_edges (s, _, ths) = map (pair s) (distinct
      (mapfilter (fn t => apsome #1 (lookup sg eqns' t))
        (flat (map (term_consts' o prop_of o snd) ths))));
    val gr = foldr (uncurry Graph.add_edge)
      (map (pair "" o #1) eqns @ flat (map mk_edges eqns),
       foldr (uncurry Graph.new_node)
         (("", Bound 0) :: map mk_node eqns, Graph.empty));
    val keys = rev (Graph.all_succs gr [""] \ "");
    fun gr_ord (x :: _, y :: _) =
      int_ord (find_index (equal x) keys, find_index (equal y) keys);
    val scc = map (fn xs => filter (fn (_, (s, _, _)) => s mem xs) eqns')
      (sort gr_ord (Graph.strong_conn gr \ [""]));
  in
    flat (separate [Pretty.str ";", Pretty.fbrk, Pretty.str " ", Pretty.fbrk]
      (map (fn eqns'' => [mk_funs_code sg case_tab eqns' eqns'']) scc)) @
    [Pretty.str ";", Pretty.fbrk]
  end;

(* Name the theorems simp_thm_1, simp_thm_2, ..., store them in simp_thms,
   generate the code for eqns (a list of (fname, d, [thm])) and compile it
   with use_text. *)
fun use_simprocs_code sg eqns =
  let
    fun attach_name (i, x) = (i+1, ("simp_thm_" ^ string_of_int i, x));
    fun attach_names (i, (s, b, eqs)) =
      let val (i', eqs') = foldl_map attach_name (i, eqs)
      in (i', (s, b, eqs')) end;
    val (_, eqns') = foldl_map attach_names (1, eqns);
    val (names, thms) = split_list (flat (map #3 eqns'));
    val s = setmp print_mode [] Pretty.string_of
      (mk_let "local" 2 [mk_vall names [Pretty.str "!SimprocsCodegen.simp_thms"]]
        (mk_simprocs_code sg eqns'))
  in
    (* simp_thms must be set before the generated text dereferences it *)
    (simp_thms := thms; use_text Context.ml_output false s)
  end;

end;