src/HOL/Tools/set_comprehension_pointfree.ML
author bulwahn
Sun, 14 Oct 2012 19:16:32 +0200
changeset 50865 873fa7156468
parent 50864 d9822ec4f434
child 50867 caaa1956f0da
permissions -rw-r--r--
adding postprocessing of computed pointfree expression in set_comprehension_pointfree simproc
wenzelm@49139
     1
(*  Title:      HOL/Tools/set_comprehension_pointfree.ML
bulwahn@49064
     2
    Author:     Felix Kuperjans, Lukas Bulwahn, TU Muenchen
wenzelm@49139
     3
    Author:     Rafal Kolanski, NICTA
bulwahn@49064
     4
bulwahn@49064
     5
Simproc for rewriting set comprehensions to pointfree expressions.
bulwahn@49064
     6
*)
bulwahn@49064
     7
bulwahn@49064
     8
signature SET_COMPREHENSION_POINTFREE =
sig
  (* Rewrites a redex "pred set_compr" by rewriting the set comprehension
     argument and lifting the resulting equation through pred. *)
  val simproc : simpset -> cterm -> thm option
  (* Rewrites a set comprehension redex itself into a pointfree expression. *)
  val base_simproc : simpset -> cterm -> thm option
  (* Like base_simproc, but first rewrites the redex with eq_equal[symmetric]
     and finally rewrites the result back with eq_equal. *)
  val code_simproc : simpset -> cterm -> thm option
end
bulwahn@49064
    14
bulwahn@49064
    15
structure Set_Comprehension_Pointfree : SET_COMPREHENSION_POINTFREE =
bulwahn@49064
    16
struct
bulwahn@49064
    17
bulwahn@50864
    18
bulwahn@49064
    19
(* syntactic operations *)
bulwahn@49064
    20
bulwahn@49064
    21
(* Build the lattice infimum  inf t1 t2  at the type of t1. *)
fun mk_inf (t1, t2) =
  let
    val elemT = fastype_of t1
    val inf_const = Const (@{const_name Lattices.inf_class.inf}, [elemT, elemT] ---> elemT)
  in
    inf_const $ t1 $ t2
  end
bulwahn@49064
    27
bulwahn@50783
    28
(* Build the lattice supremum  sup t1 t2  at the type of t1. *)
fun mk_sup (t1, t2) =
  let
    val elemT = fastype_of t1
    val sup_const = Const (@{const_name Lattices.sup_class.sup}, [elemT, elemT] ---> elemT)
  in
    sup_const $ t1 $ t2
  end
bulwahn@50783
    34
bulwahn@50783
    35
(* Build the complement  - t  (uminus) at the type of t. *)
fun mk_Compl t =
  let val setT = fastype_of t
  in Const (@{const_name "Groups.uminus_class.uminus"}, setT --> setT) $ t end
bulwahn@50783
    41
bulwahn@49064
    42
(* Build the image  t1 ` t2  of the set t2 under the function t1.  The element
   type of the resulting set is the range type of t1. *)
fun mk_image t1 t2 =
  let
    val funT = fastype_of t1
    val Type (@{type_name fun}, [_, rangeT]) = funT
  in
    Const (@{const_name image}, [funT, fastype_of t2] ---> HOLogic.mk_setT rangeT) $ t1 $ t2
  end;
bulwahn@49064
    49
bulwahn@49064
    50
(* Build the product of the sets t1 and t2 as  Sigma t1 (%_. t2). *)
fun mk_sigma (t1, t2) =
  let
    val (T1, T2) = (fastype_of t1, fastype_of t2)
    val elemT = HOLogic.dest_setT T1
    val pairT = HOLogic.mk_prodT (elemT, HOLogic.dest_setT T2)
    val sigmaT = [T1, elemT --> T2] ---> HOLogic.mk_setT pairT
  in
    Const (@{const_name Sigma}, sigmaT) $ t1 $ absdummy elemT t2
  end;
bulwahn@49064
    60
bulwahn@49064
    61
(* Body of a comprehension  Collect (%x. body), with x left as a loose Bound 0;
   raises TERM for any other term. *)
fun dest_Collect t =
  (case t of
    Const (@{const_name Collect}, _) $ Abs (_, _, body) => body
  | _ => raise TERM ("dest_Collect", [t]))
bulwahn@49064
    63
bulwahn@49064
    64
(* Copied from predicate_compile_aux.ML *)
bulwahn@49064
    65
(* Strip leading existential quantifiers.  Returns the binders (outermost
   first) and the remaining body with the binders' variables as loose Bounds. *)
fun strip_ex (Const (@{const_name Ex}, _) $ Abs (x, T, t)) = apfst (cons (x, T)) (strip_ex t)
  | strip_ex t = ([], t)
bulwahn@49064
    72
bulwahn@50864
    73
(* Build the pair  (t1, t2).  Ts supplies the types of the loose bound
   variables, which are needed to type t1 and t2. *)
fun mk_prod1 Ts (t1, t2) =
  HOLogic.pair_const (fastype_of1 (Ts, t1)) (fastype_of1 (Ts, t2)) $ t1 $ t2;
bulwahn@50864
    79
bulwahn@50864
    80
bulwahn@50864
    81
(* patterns *)
bulwahn@50864
    82
bulwahn@50864
    83
(* A pattern is a tuple (nested pairs) of bound variables, each identified by
   its de Bruijn index (see mk_pattern). *)
datatype pattern = TBound of int | TPair of pattern * pattern;
bulwahn@50864
    84
bulwahn@50864
    85
(* Read a tuple of bound variables (nested Pair applications) as a pattern.
   Raises TERM for anything else. *)
fun mk_pattern t =
  (case t of
    Bound n => TBound n
  | Const (@{const_name "Product_Type.Pair"}, _) $ l $ r => TPair (mk_pattern l, mk_pattern r)
  | _ => raise TERM ("mk_pattern: only bound variable tuples currently supported", [t]));
bulwahn@50864
    89
bulwahn@50864
    90
(* Product type of a pattern; Ts!n is the type of Bound n. *)
fun type_of_pattern Ts pat =
  (case pat of
    TBound n => nth Ts n
  | TPair (l, r) => HOLogic.mk_prodT (pairself (type_of_pattern Ts) (l, r)))
bulwahn@50864
    92
bulwahn@50864
    93
(* Turn a pattern back into a tuple term of bound variables; Ts gives the
   types of the loose bounds, which the Pair constants need. *)
fun term_of_pattern _ (TBound n) = Bound n
  | term_of_pattern Ts (TPair (l, r)) =
      mk_prod1 Ts (term_of_pattern Ts l, term_of_pattern Ts r);
bulwahn@50864
   101
bulwahn@50864
   102
(* Bound-variable indices occurring in a pattern, without duplicates, listed in
   left-to-right (in-order) order of the tuple.

   reorder_bounds pairs these indices, in list order, with the nesting of
   abstractions built by mk_split_abs, whose outermost abstraction is the
   leftmost tuple component.  The order therefore matters.  union (op =) cannot
   be used here: it inserts by prepending, so for a left component with several
   variables, e.g. ((x, y), z), it yields y before x and the variables of the
   image function would be swapped. *)
fun bounds_of_pattern pat =
  let
    fun collect (TBound i) = [i]
      | collect (TPair (l, r)) = collect l @ collect r
  in
    (* distinct keeps the first occurrence of each index, preserving order *)
    distinct (op =) (collect pat)
  end
bulwahn@50864
   104
bulwahn@50864
   105
bulwahn@50864
   106
(* formulas *)
bulwahn@50864
   107
bulwahn@50864
   108
(* Propositional structure of the comprehension condition.  An Atom pairs a
   pattern with the set that the pattern's value must belong to; Int and Un
   are built from conjunction and disjunction (see mk_formula). *)
datatype formula = Atom of (pattern * term) | Int of formula * formula | Un of formula * formula
bulwahn@50864
   109
bulwahn@50864
   110
(* Translate one literal of the comprehension condition into its pattern and
   an Atom:
     pat : S        becomes  (pat, Atom (pat, S))
     ~ (pat : S)    becomes  (pat, Atom (pat, - S))

   The set S must not contain loose bound variables.  Otherwise it depends on
   the variables of the comprehension, the pointfree term would contain
   dangling de Bruijn indices, and the problem would presumably only show up
   later as an exception from the proof in conv, outside the try in
   rewrite_term.  Such literals, and literals of any other shape (which
   previously fell through to a Match exception), are rejected with TERM. *)
fun mk_atom t =
  let
    fun closed_set s =
      if null (loose_bnos s) then s
      else raise TERM ("mk_atom: bound variables in the set expression", [s])
  in
    (case t of
      Const (@{const_name "Set.member"}, _) $ x $ s =>
        let val s' = closed_set s
        in (mk_pattern x, Atom (mk_pattern x, s')) end
    | Const (@{const_name "HOL.Not"}, _) $ (Const (@{const_name "Set.member"}, _) $ x $ s) =>
        let val s' = closed_set s
        in (mk_pattern x, Atom (mk_pattern x, mk_Compl s')) end
    | _ => raise TERM ("mk_atom: unsupported literal", [t]))
  end
bulwahn@50864
   113
bulwahn@50864
   114
(* Two pattern lists can be merged if every pattern of the first list is
   either identical to, or shares no bound variable with, every pattern of the
   second list. *)
fun can_merge (pats1, pats2) =
  let
    fun disjoint p q = null (inter (op =) (bounds_of_pattern p) (bounds_of_pattern q))
    fun compatible p q = p = q orelse disjoint p q
  in
    forall (fn p => forall (compatible p) pats2) pats1
  end
bulwahn@50864
   121
bulwahn@50864
   122
(* Union of two pattern lists; raises Fail if they are not compatible
   (see can_merge). *)
fun merge_patterns (pats1, pats2) =
  if not (can_merge (pats1, pats2)) then raise Fail "merge_patterns: variable groups overlap"
  else union (op =) pats1 pats2
bulwahn@50864
   126
bulwahn@50864
   127
(* Combine two (patterns, formula) results: merge the pattern lists and join
   the two formulas with the constructor oper. *)
fun merge oper (pats1, sp1) (pats2, sp2) =
  let val merged = merge_patterns (pats1, pats2)
  in (merged, oper (sp1, sp2)) end
bulwahn@50864
   128
bulwahn@50864
   129
(* Translate a condition built from & and | over membership literals into its
   patterns and a formula: & becomes Int, | becomes Un, and everything else is
   handed to mk_atom. *)
fun mk_formula t =
  (case t of
    @{const HOL.conj} $ t1 $ t2 => merge Int (mk_formula t1) (mk_formula t2)
  | @{const HOL.disj} $ t1 $ t2 => merge Un (mk_formula t1) (mk_formula t2)
  | _ => apfst single (mk_atom t))
bulwahn@50864
   132
bulwahn@50864
   133
bulwahn@50864
   134
(* term construction *)
bulwahn@50864
   135
bulwahn@50864
   136
(* Renumber the loose bound variables of t to follow the order of the variables
   in the pattern tuple.

   bounds is the concatenation of the bound indices of pats (in list order).
   Each entry is paired with a position counted from the end
   (length bounds - 1 downto 0).  The pairs are sorted by bound index, and
   subst_bounds then replaces Bound i by the position paired with i.

   If bounds is a permutation of 0 .. n-1, the k-th entry (0-based) is mapped
   to Bound (n - 1 - k), so the first tuple variable ends up with the highest
   index, i.e. the outermost of the abstractions built by mk_split_abs.

   NOTE(review): this relies on bounds being such a permutation, listed in the
   left-to-right order of the tuple; with gaps, duplicates or another order the
   substitution is not the intended one -- confirm at callers. *)
fun reorder_bounds pats t =
bulwahn@50864
   137
  let
bulwahn@50864
   138
    val bounds = maps bounds_of_pattern pats
bulwahn@50864
   139
    val bperm = bounds ~~ ((length bounds - 1) downto 0)
bulwahn@50864
   140
      |> sort (fn (i,j) => int_ord (fst i, fst j)) |> map snd
bulwahn@50864
   141
  in
bulwahn@50864
   142
    subst_bounds (map Bound bperm, t)
bulwahn@50864
   143
  end;
bulwahn@50864
   144
bulwahn@50864
   145
(* Abstract the body t over the bound variables of the tuple term pat, using
   HOLogic.mk_split for every pair.  vs!i gives the name and type for Bound i.
   The leftmost tuple component becomes the outermost abstraction.  Raises
   TERM for anything that is not a tuple of bound variables.

   The Pair constant is matched via the checked @{const_name} antiquotation,
   as in mk_pattern, rather than via an unchecked raw string. *)
fun mk_split_abs vs (Bound i) t = let val (x, T) = nth vs i in Abs (x, T, t) end
  | mk_split_abs vs (Const (@{const_name "Product_Type.Pair"}, _) $ u $ v) t =
      HOLogic.mk_split (mk_split_abs vs u (mk_split_abs vs v t))
  | mk_split_abs _ t _ = raise TERM ("mk_split_abs: bad term", [t]);
rafal@49123
   149
bulwahn@49064
   150
(* Compute a pointfree form of the set comprehension t.

   t is expected to be  Collect (%x. EX v1 ... vn. body).  The top-level
   conjuncts of body must include an equation  x = f  (f over the vi).  The
   other conjuncts are membership literals on tuples of the vi, possibly
   combined with & and |.

   Returns the formula fm (which drives the proof tactic) together with the
   term  image g S.  Here g is f abstracted over the pattern tuple, and S is
   the set built from fm.

   Raises an exception (TERM, Bind, Empty, Fail, ...) on unsupported input;
   callers use rewrite_term, which wraps this function in try. *)
fun mk_pointfree_expr t =
bulwahn@49064
   151
  let
bulwahn@49064
   152
    val (vs, t'') = strip_ex (dest_Collect t)
bulwahn@50864
   153
    (* strip_ex lists binders outermost first; after rev, Ts!n is the type of Bound n *)
    val Ts = map snd (rev vs)
bulwahn@50864
   154
    (* trivial literal  Bound n : UNIV  for variables that occur in no literal *)
    fun mk_mem_UNIV n = HOLogic.mk_mem (Bound n, HOLogic.mk_UNIV (nth Ts n))
bulwahn@50864
   155
    (* set of an atom seen from pattern pat: its own set if pat is its pattern,
       otherwise UNIV of pat's type *)
    fun lookup (pat', t) pat = if pat = pat' then t else HOLogic.mk_UNIV (type_of_pattern Ts pat)
bulwahn@50776
   156
    val conjs = HOLogic.dest_conj t''
bulwahn@50776
   157
    (* below the n existential binders, the Collect variable x is Bound n *)
    val is_the_eq =
bulwahn@50776
   158
      the_default false o (try (fn eq => fst (HOLogic.dest_eq eq) = Bound (length vs)))
bulwahn@50776
   159
    val SOME eq = find_first is_the_eq conjs
bulwahn@50776
   160
    val f = snd (HOLogic.dest_eq eq)
bulwahn@50776
   161
    val conjs' = filter_out (fn t => eq = t) conjs
bulwahn@50864
   162
    (* existential variables not mentioned by any remaining conjunct *)
    val unused_bounds = subtract (op =) (distinct (op =) (maps loose_bnos conjs'))
bulwahn@50864
   163
      (0 upto (length vs - 1))
bulwahn@50864
   164
    val (pats, fm) =
bulwahn@50864
   165
      mk_formula (foldr1 HOLogic.mk_conj (conjs' @ map mk_mem_UNIV unused_bounds))
bulwahn@50864
   166
    (* each atom becomes the product (Sigma) of its view under every pattern *)
    fun mk_set (Atom pt) = (case map (lookup pt) pats of [t'] => t' | ts => foldr1 mk_sigma ts)
bulwahn@50864
   167
      | mk_set (Un (f1, f2)) = mk_sup (mk_set f1, mk_set f2)
bulwahn@50864
   168
      | mk_set (Int (f1, f2)) = mk_inf (mk_set f1, mk_set f2)
bulwahn@50864
   169
    (* the tuple of all patterns, right-nested in the order of pats *)
    val pat = foldr1 (mk_prod1 Ts) (map (term_of_pattern Ts) pats)
bulwahn@50864
   170
    (* f, renumbered and abstracted over that tuple *)
    val t = mk_split_abs (rev vs) pat (reorder_bounds pats f)
bulwahn@49064
   171
  in
bulwahn@50864
   172
    (fm, mk_image t (mk_set fm))
bulwahn@49064
   173
  end;
bulwahn@49064
   174
rafal@49123
   175
(* Total variant of mk_pointfree_expr: NONE whenever mk_pointfree_expr raises
   an exception (i.e. the comprehension is not in a supported shape). *)
fun rewrite_term t = try mk_pointfree_expr t
rafal@49123
   176
bulwahn@50864
   177
bulwahn@49064
   178
(* proof tactic *)
bulwahn@49064
   179
bulwahn@50864
   180
(* Pushes an application to z inside a prod_case.  intro_image_tac rewrites
   with it (as a meta equation) on the right-hand side of equational goals. *)
val prod_case_distrib = @{lemma "(prod_case g x) z = prod_case (% x y. (g x y) z) x" by (simp add: prod_case_beta)}
bulwahn@49064
   181
bulwahn@50864
   182
(* FIXME: one of many clones *)
bulwahn@50864
   183
(* Apply the conversion cv to P in a certified term  Trueprop P;
   raises CTERM for any other term. *)
fun Trueprop_conv cv ct =
  if can HOLogic.dest_Trueprop (Thm.term_of ct) then Conv.arg_conv cv ct
  else raise CTERM ("Trueprop_conv", [ct])
bulwahn@50864
   187
bulwahn@50864
   188
(* FIXME: another clone *)
bulwahn@50864
   189
(* Apply cv1 to the left-hand side and cv2 to the right-hand side of a
   certified HOL equation; raises CTERM for any other term. *)
fun eq_conv cv1 cv2 ct =
  if can HOLogic.dest_eq (Thm.term_of ct) then Conv.combination_conv (Conv.arg_conv cv1) cv2 ct
  else raise CTERM ("eq_conv", [ct])
bulwahn@50864
   193
bulwahn@50864
   194
(* Eliminate a comprehension-membership hypothesis: turn it into the
   comprehension body, strip all existentials, split at most one conjunction
   and substitute with hyp_subst_tac. *)
val elim_Collect_tac =
  dresolve_tac [@{thm iffD1[OF mem_Collect_eq]}]
  THEN' REPEAT_DETERM o etac @{thm exE}
  THEN' TRY o etac @{thm conjE}
  THEN' hyp_subst_tac;
bulwahn@49064
   198
bulwahn@50864
   199
(* Introduce membership in an image: resolve with image_eqI, then repeatedly
   (at least once) apply the first of these that succeeds:
   - reflexivity;
   - arg_cong2[OF refl, where f=op =, OF prod.cases, THEN iffD2], presumably
     unfolding prod_case applied to an explicit pair inside an equation;
   - rewriting the right-hand side of the equational conclusion (under all
     parameters and premises) with prod_case_distrib. *)
fun intro_image_tac ctxt = rtac @{thm image_eqI}
rafal@49123
   200
    THEN' (REPEAT_DETERM1 o
rafal@49123
   201
      (rtac @{thm refl}
rafal@49123
   202
      ORELSE' rtac
bulwahn@50864
   203
        @{thm arg_cong2[OF refl, where f="op =", OF prod.cases, THEN iffD2]}
bulwahn@50864
   204
      (* eq_conv leaves the left-hand side alone and rewrites the right-hand side *)
      ORELSE' CONVERSION (Conv.params_conv ~1 (K (Conv.concl_conv ~1
bulwahn@50864
   205
        (Trueprop_conv (eq_conv Conv.all_conv (Conv.rewr_conv (mk_meta_eq prod_case_distrib)))))) ctxt)))
bulwahn@49064
   206
bulwahn@50864
   207
(* Eliminate an image-membership hypothesis with imageE, split pair-case
   hypotheses via prod.split_asm as often as possible, then substitute. *)
val elim_image_tac =
  eresolve_tac [@{thm imageE}]
  THEN' TRY o REPEAT_DETERM1 o Splitter.split_asm_tac @{thms prod.split_asm}
  THEN' hyp_subst_tac
rafal@49123
   210
rafal@49123
   211
(* Prove comprehension membership: reduce via mem_Collect_eq, introduce the
   existentials (at least one), optionally split a conjunction, substitute,
   and close the first remaining goal by reflexivity. *)
val intro_Collect_tac =
  resolve_tac [@{thm iffD2[OF mem_Collect_eq]}]
  THEN' REPEAT_DETERM1 o rtac @{thm exI}
  THEN' TRY o rtac @{thm conjI}
  THEN' TRY o hyp_subst_tac
  THEN' rtac @{thm refl};
bulwahn@49064
   216
bulwahn@50864
   217
(* Tactic for the subset direction  comprehension <= image  (used by tac after
   elim_Collect_tac and intro_image_tac).  It shows that the witness lies in
   the set built from the formula, using the comprehension's hypotheses.
   - Int: split a conjunctive hypothesis (if any) and apply IntI.  The second
     subgoal (index i + 1) is handled before the first, so that index i stays
     valid.
   - Un: case split on the disjunctive hypothesis.  The first case is proved
     via UnI1, the second via UnI2.  NOTE(review): this assumes the fm1 tactic
     closes subgoal i, so that the second case moves to index i -- confirm.
   - Atom: repeatedly apply SigmaI, UNIV_I, Compl_iff (backwards) or an
     assumption. *)
fun tac1_of_formula (Int (fm1, fm2)) =
bulwahn@50864
   218
    TRY o etac @{thm conjE}
bulwahn@50864
   219
    THEN' rtac @{thm IntI}
bulwahn@50864
   220
    THEN' (fn i => tac1_of_formula fm2 (i + 1))
bulwahn@50864
   221
    THEN' tac1_of_formula fm1
bulwahn@50864
   222
  | tac1_of_formula (Un (fm1, fm2)) =
bulwahn@50864
   223
    etac @{thm disjE} THEN' rtac @{thm UnI1}
bulwahn@50864
   224
    THEN' tac1_of_formula fm1
bulwahn@50864
   225
    THEN' rtac @{thm UnI2}
bulwahn@50864
   226
    THEN' tac1_of_formula fm2
bulwahn@50864
   227
  | tac1_of_formula (Atom _) =
bulwahn@50864
   228
    (REPEAT_DETERM1 o (rtac @{thm SigmaI}
bulwahn@50864
   229
      ORELSE' rtac @{thm UNIV_I}
bulwahn@50864
   230
      ORELSE' rtac @{thm iffD2[OF Compl_iff]}
bulwahn@50864
   231
      ORELSE' atac))
bulwahn@50864
   232
bulwahn@50864
   233
(* Tactic for the subset direction  image <= comprehension  (used by tac after
   elim_image_tac and intro_Collect_tac).  It derives the comprehension's
   condition from membership in the set built from the formula.
   - Int: optionally eliminate an intersection hypothesis and split a
     conjunctive goal, then handle fm2 at index i + 1 before fm1 at index i.
   - Un: case split on the union hypothesis and prove the left or right
     disjunct.  NOTE(review): this assumes the fm1 tactic closes subgoal i, so
     that the second case moves to index i -- confirm.
   - Atom: repeatedly (possibly zero times) destruct Sigma-membership and
     conjunction hypotheses, eliminate complements, or close by assumption. *)
fun tac2_of_formula (Int (fm1, fm2)) =
bulwahn@50864
   234
    TRY o etac @{thm IntE}
bulwahn@50864
   235
    THEN' TRY o rtac @{thm conjI}
bulwahn@50864
   236
    THEN' (fn i => tac2_of_formula fm2 (i + 1))
bulwahn@50864
   237
    THEN' tac2_of_formula fm1
bulwahn@50864
   238
  | tac2_of_formula (Un (fm1, fm2)) =
bulwahn@50864
   239
    etac @{thm UnE} THEN' rtac @{thm disjI1}
bulwahn@50864
   240
    THEN' tac2_of_formula fm1
bulwahn@50864
   241
    THEN' rtac @{thm disjI2}
bulwahn@50864
   242
    THEN' tac2_of_formula fm2
bulwahn@50864
   243
  | tac2_of_formula (Atom _) =
bulwahn@50864
   244
    TRY o REPEAT_DETERM1 o
bulwahn@50864
   245
      (dtac @{thm iffD1[OF mem_Sigma_iff]}
bulwahn@50864
   246
       ORELSE' etac @{thm conjE}
bulwahn@50864
   247
       ORELSE' etac @{thm ComplE}
bulwahn@50864
   248
       ORELSE' atac)
bulwahn@50864
   249
bulwahn@50864
   250
(* Prove  comprehension = image  by subset_antisym.  The first subset goal
   (comprehension <= image) is proved with elim_Collect_tac, intro_image_tac
   and tac1_of_formula.  The second (image <= comprehension) is proved with
   elim_image_tac, intro_Collect_tac and tac2_of_formula. *)
fun tac ctxt fm =
  let
    val collect_sub_image =
      rtac @{thm subsetI} THEN' elim_Collect_tac THEN' intro_image_tac ctxt
      THEN' tac1_of_formula fm
    val image_sub_collect =
      rtac @{thm subsetI} THEN' elim_image_tac THEN' intro_Collect_tac
      THEN' tac2_of_formula fm
  in
    resolve_tac [@{thm subset_antisym}] THEN' collect_sub_image THEN' image_sub_collect
  end;
rafal@49123
   263
bulwahn@50864
   264
bulwahn@50864
   265
(* main simprocs *)
bulwahn@50864
   266
bulwahn@50865
   267
(* Meta-equations used to post-process the computed pointfree expression;
   the two explicit lemmas merge a union or an intersection of two products
   into a single product. *)
val post_thms =
  [@{thm Times_Un_distrib1[symmetric]},
   @{lemma "A \<times> B \<union> A \<times> C = A \<times> (B \<union> C)" by auto},
   @{lemma "(A \<times> B \<inter> C \<times> D) = (A \<inter> C) \<times> (B \<inter> D)" by auto}]
  |> map mk_meta_eq
bulwahn@50865
   271
rafal@49123
   272
(* Prove  t = t''  where t'' is a pointfree form of the comprehension t:
   1. unfold Bex_def everywhere in t, giving t' (unfold_eq : t == t');
   2. compute the pointfree expression of t' with rewrite_term;
   3. prove  t' = t''  with tac;
   4. chain with unfold_eq to obtain  t = t'';
   5. rewrite the right-hand side with post_thms.
   Returns NONE if rewrite_term fails.
   NOTE(review): mk_thm is not wrapped in try, so a failing proof raises an
   exception out of the simproc instead of yielding NONE -- confirm this is
   intended. *)
fun conv ctxt t =
rafal@49123
   273
  let
bulwahn@50780
   274
    val ct = cterm_of (Proof_Context.theory_of ctxt) t
bulwahn@50780
   275
    val Bex_def = mk_meta_eq @{thm Bex_def}
bulwahn@50780
   276
    val unfold_eq = Conv.bottom_conv (K (Conv.try_conv (Conv.rewr_conv Bex_def))) ctxt ct
bulwahn@50864
   277
    val t' = term_of (Thm.rhs_of unfold_eq)
bulwahn@50864
   278
    fun mk_thm (fm, t'') = Goal.prove ctxt [] []
bulwahn@50864
   279
      (HOLogic.mk_Trueprop (HOLogic.mk_eq (t', t''))) (fn {context, ...} => tac context fm 1)
bulwahn@50780
   280
    (* from unfold_eq : t == t' and th : t' = t'' derive t = t'' via trans *)
    fun unfold th = th RS ((unfold_eq RS meta_eq_to_obj_eq) RS @{thm trans})
bulwahn@50865
   281
    (* simplify only the right-hand side of the final equation with post_thms *)
    fun post th = Conv.fconv_rule (Trueprop_conv (eq_conv Conv.all_conv
bulwahn@50865
   282
      (Raw_Simplifier.rewrite true post_thms))) th
rafal@49123
   283
  in
bulwahn@50865
   284
    Option.map (post o unfold o mk_thm) (rewrite_term t')
rafal@49123
   285
  end;
bulwahn@49064
   286
wenzelm@49143
   287
(* Simproc on a set comprehension redex: the equation from conv, turned into
   a meta equation. *)
fun base_simproc ss redex =
  Option.map (fn th => th RS @{thm eq_reflection})
    (conv (Simplifier.the_context ss) (term_of redex));
bulwahn@49137
   295
bulwahn@50778
   296
(* Instantiate the function variable of arg_cong with pred.  arg_cong's
   conclusion is destructed as  f x = y, with f the variable.  Its schematic
   indices are first shifted above those of pred to avoid clashes. *)
fun instantiate_arg_cong ctxt pred =
  let
    val thy = Proof_Context.theory_of ctxt
    val arg_cong = Thm.incr_indexes (maxidx_of_term pred + 1) @{thm arg_cong}
    val (lhs, _) = HOLogic.dest_eq (HOLogic.dest_Trueprop (concl_of arg_cong))
    val f $ _ = lhs
  in
    cterm_instantiate [(cterm_of thy f, cterm_of thy pred)] arg_cong
  end;
bulwahn@50778
   304
wenzelm@49139
   305
(* Simproc on a redex  pred set_compr: rewrite the comprehension with conv and
   lift the equation through pred using an instance of arg_cong. *)
fun simproc ss redex =
  let
    val ctxt = Simplifier.the_context ss
    val pred $ set_compr = term_of redex
    val cong = instantiate_arg_cong ctxt pred
    fun lift th = th RS cong RS @{thm eq_reflection}
  in
    Option.map lift (conv ctxt set_compr)
  end;
bulwahn@49064
   314
wenzelm@49143
   315
(* Variant of base_simproc that first rewrites the redex with
   eq_equal[symmetric], applies base_simproc to the result, and finally
   rewrites the outcome back with eq_equal, chaining all three steps by
   transitivity. *)
fun code_simproc ss redex =
  let
    val prep_thm = Raw_Simplifier.rewrite false @{thms eq_equal[symmetric]} redex
    fun finish rewr_thm =
      transitive_thm OF [transitive_thm OF [prep_thm, rewr_thm],
        Raw_Simplifier.rewrite false @{thms eq_equal} (Thm.rhs_of rewr_thm)]
  in
    Option.map finish (base_simproc ss (Thm.rhs_of prep_thm))
  end;
bulwahn@49137
   324
bulwahn@49064
   325
end;
rafal@49123
   326