src/HOL/Tools/set_comprehension_pointfree.ML
author bulwahn
Sun, 14 Oct 2012 19:16:35 +0200
changeset 50867 caaa1956f0da
parent 50865 873fa7156468
child 50872 7bf407d77152
permissions -rw-r--r--
refined tactic in set_comprehension_pointfree simproc
     1 (*  Title:      HOL/Tools/set_comprehension_pointfree.ML
     2     Author:     Felix Kuperjans, Lukas Bulwahn, TU Muenchen
     3     Author:     Rafal Kolanski, NICTA
     4 
     5 Simproc for rewriting set comprehensions to pointfree expressions.
     6 *)
     7 
     8 signature SET_COMPREHENSION_POINTFREE =
     9 sig
    10   val base_simproc : simpset -> cterm -> thm option
    11   val code_simproc : simpset -> cterm -> thm option
    12   val simproc : simpset -> cterm -> thm option
    13 end
    14 
    15 structure Set_Comprehension_Pointfree : SET_COMPREHENSION_POINTFREE =
    16 struct
    17 
    18 
    19 (* syntactic operations *)
    20 
    21 fun mk_inf (t1, t2) =
    22   let
    23     val T = fastype_of t1
    24   in
    25     Const (@{const_name Lattices.inf_class.inf}, T --> T --> T) $ t1 $ t2
    26   end
    27 
    28 fun mk_sup (t1, t2) =
    29   let
    30     val T = fastype_of t1
    31   in
    32     Const (@{const_name Lattices.sup_class.sup}, T --> T --> T) $ t1 $ t2
    33   end
    34 
    35 fun mk_Compl t =
    36   let
    37     val T = fastype_of t
    38   in
    39     Const (@{const_name "Groups.uminus_class.uminus"}, T --> T) $ t
    40   end
    41 
    42 fun mk_image t1 t2 =
    43   let
    44     val T as Type (@{type_name fun}, [_ , R]) = fastype_of t1
    45   in
    46     Const (@{const_name image},
    47       T --> fastype_of t2 --> HOLogic.mk_setT R) $ t1 $ t2
    48   end;
    49 
    50 fun mk_sigma (t1, t2) =
    51   let
    52     val T1 = fastype_of t1
    53     val T2 = fastype_of t2
    54     val setT = HOLogic.dest_setT T1
    55     val resT = HOLogic.mk_setT (HOLogic.mk_prodT (setT, HOLogic.dest_setT T2))
    56   in
    57     Const (@{const_name Sigma},
    58       T1 --> (setT --> T2) --> resT) $ t1 $ absdummy setT t2
    59   end;
    60 
    61 fun dest_Collect (Const (@{const_name Collect}, _) $ Abs (_, _, t)) = t
    62   | dest_Collect t = raise TERM ("dest_Collect", [t])
    63 
    64 (* Copied from predicate_compile_aux.ML *)
    65 fun strip_ex (Const (@{const_name Ex}, _) $ Abs (x, T, t)) =
    66   let
    67     val (xTs, t') = strip_ex t
    68   in
    69     ((x, T) :: xTs, t')
    70   end
    71   | strip_ex t = ([], t)
    72 
    73 fun mk_prod1 Ts (t1, t2) =
    74   let
    75     val (T1, T2) = pairself (curry fastype_of1 Ts) (t1, t2)
    76   in
    77     HOLogic.pair_const T1 T2 $ t1 $ t2
    78   end;
    79 
    80 
    81 (* patterns *)
    82 
(* Tuples of bound variables: TBound n is the bound variable with de Bruijn
   index n, TPair combines two sub-patterns (mk_pattern builds it from
   Product_Type.Pair). *)
datatype pattern = TBound of int | TPair of pattern * pattern;
    84 
    85 fun mk_pattern (Bound n) = TBound n
    86   | mk_pattern (Const (@{const_name "Product_Type.Pair"}, _) $ l $ r) =
    87       TPair (mk_pattern l, mk_pattern r)
    88   | mk_pattern t = raise TERM ("mk_pattern: only bound variable tuples currently supported", [t]);
    89 
    90 fun type_of_pattern Ts (TBound n) = nth Ts n
    91   | type_of_pattern Ts (TPair (l, r)) = HOLogic.mk_prodT (type_of_pattern Ts l, type_of_pattern Ts r)
    92 
    93 fun term_of_pattern _ (TBound n) = Bound n
    94   | term_of_pattern Ts (TPair (l, r)) =
    95     let
    96       val (lt, rt) = pairself (term_of_pattern Ts) (l, r)
    97       val (lT, rT) = pairself (curry fastype_of1 Ts) (lt, rt) 
    98     in
    99       HOLogic.pair_const lT rT $ lt $ rt
   100     end;
   101 
   102 fun bounds_of_pattern (TBound i) = [i]
   103   | bounds_of_pattern (TPair (l, r)) = union (op =) (bounds_of_pattern l) (bounds_of_pattern r)
   104 
   105 
   106 (* formulas *)
   107 
(* Boolean structure of the comprehension condition: Atom (pat, S) stands for
   membership of the tuple pat in the set term S; Int and Un stand for
   conjunction and disjunction (mk_set in mk_pointfree_expr turns them into
   inf and sup). *)
datatype formula = Atom of (pattern * term) | Int of formula * formula | Un of formula * formula
   109 
   110 fun mk_atom (Const (@{const_name "Set.member"}, _) $ x $ s) = (mk_pattern x, Atom (mk_pattern x, s))
   111   | mk_atom (Const (@{const_name "HOL.Not"}, _) $ (Const (@{const_name "Set.member"}, _) $ x $ s)) =
   112       (mk_pattern x, Atom (mk_pattern x, mk_Compl s))
   113 
   114 fun can_merge (pats1, pats2) =
   115   let
   116     fun check pat1 pat2 = (pat1 = pat2)
   117       orelse (inter (op =) (bounds_of_pattern pat1) (bounds_of_pattern pat2) = [])
   118   in
   119     forall (fn pat1 => forall (fn pat2 => check pat1 pat2) pats2) pats1 
   120   end
   121 
   122 fun merge_patterns (pats1, pats2) =
   123   if can_merge (pats1, pats2) then
   124     union (op =) pats1 pats2
   125   else raise Fail "merge_patterns: variable groups overlap"
   126 
   127 fun merge oper (pats1, sp1) (pats2, sp2) = (merge_patterns (pats1, pats2), oper (sp1, sp2))
   128 
   129 fun mk_formula (@{const HOL.conj} $ t1 $ t2) = merge Int (mk_formula t1) (mk_formula t2)
   130   | mk_formula (@{const HOL.disj} $ t1 $ t2) = merge Un (mk_formula t1) (mk_formula t2)
   131   | mk_formula t = apfst single (mk_atom t)
   132 
   133 fun strip_Int (Int (fm1, fm2)) = fm1 :: (strip_Int fm2) 
   134   | strip_Int fm = [fm]
   135 
   136 (* term construction *)
   137 
   138 fun reorder_bounds pats t =
   139   let
   140     val bounds = maps bounds_of_pattern pats
   141     val bperm = bounds ~~ ((length bounds - 1) downto 0)
   142       |> sort (fn (i,j) => int_ord (fst i, fst j)) |> map snd
   143   in
   144     subst_bounds (map Bound bperm, t)
   145   end;
   146 
   147 fun mk_split_abs vs (Bound i) t = let val (x, T) = nth vs i in Abs (x, T, t) end
   148   | mk_split_abs vs (Const ("Product_Type.Pair", _) $ u $ v) t =
   149       HOLogic.mk_split (mk_split_abs vs u (mk_split_abs vs v t))
   150   | mk_split_abs _ t _ = raise TERM ("mk_split_abs: bad term", [t]);
   151 
(* Translate a set comprehension Collect (%y. EX x1 .. xn. C) into a pointfree
   form. One conjunct of C must be the equation y = f, where y appears as
   Bound n after stripping the n existentials. The other conjuncts must be
   (negated) memberships of tuples of the xi in sets, combined with conj/disj.
   Returns (fm, f' ` S):
   - fm is the formula built from those conjuncts; the proof tactic needs it.
   - S is a set expression made from Sigma, inf, sup, complement and UNIV.
   - f' is f abstracted over the tuple of all variables.
   Unsupported input raises some exception (e.g. Bind when no such equation
   exists); rewrite_term wraps this function in try. *)
fun mk_pointfree_expr t =
  let
    (* vs: existentially bound variables, outermost first; t'': matrix *)
    val (vs, t'') = strip_ex (dest_Collect t)
    (* types indexed by de Bruijn index (innermost variable = Bound 0) *)
    val Ts = map snd (rev vs)
    (* trivial constraint Bound n : UNIV for otherwise unconstrained variables *)
    fun mk_mem_UNIV n = HOLogic.mk_mem (Bound n, HOLogic.mk_UNIV (nth Ts n))
    (* set belonging to pattern pat inside an atom (pat', t): t if the
       patterns agree, UNIV of the pattern type otherwise *)
    fun lookup (pat', t) pat = if pat = pat' then t else HOLogic.mk_UNIV (type_of_pattern Ts pat)
    val conjs = HOLogic.dest_conj t''
    (* the defining equation has the comprehension variable on its left *)
    val is_the_eq =
      the_default false o (try (fn eq => fst (HOLogic.dest_eq eq) = Bound (length vs)))
    val SOME eq = find_first is_the_eq conjs
    (* f: the element expression of the comprehension *)
    val f = snd (HOLogic.dest_eq eq)
    val conjs' = filter_out (fn t => eq = t) conjs
    (* variables not mentioned in the remaining conjuncts, so that every
       variable ends up in some pattern *)
    val unused_bounds = subtract (op =) (distinct (op =) (maps loose_bnos conjs'))
      (0 upto (length vs - 1))
    val (pats, fm) =
      mk_formula (foldr1 HOLogic.mk_conj (conjs' @ map mk_mem_UNIV unused_bounds))
    (* an atom becomes the product (Sigma) over all patterns, with its set
       in its own position and UNIV elsewhere; Un/Int become sup/inf *)
    fun mk_set (Atom pt) = (case map (lookup pt) pats of [t'] => t' | ts => foldr1 mk_sigma ts)
      | mk_set (Un (f1, f2)) = mk_sup (mk_set f1, mk_set f2)
      | mk_set (Int (f1, f2)) = mk_inf (mk_set f1, mk_set f2)
    (* tuple of all patterns, nested to the right *)
    val pat = foldr1 (mk_prod1 Ts) (map (term_of_pattern Ts) pats)
    (* the function mapped over the set: f abstracted over that tuple *)
    val t = mk_split_abs (rev vs) pat (reorder_bounds pats f)
  in
    (fm, mk_image t (mk_set fm))
  end;
   176 
   177 val rewrite_term = try mk_pointfree_expr
   178 
   179 
   180 (* proof tactic *)
   181 
(* Moves an argument z applied to a prod_case term into the case function;
   intro_image_tac uses it left-to-right as a meta equation. *)
val prod_case_distrib = @{lemma "(prod_case g x) z = prod_case (% x y. (g x y) z) x" by (simp add: prod_case_beta)}
   183 
   184 (* FIXME: one of many clones *)
   185 fun Trueprop_conv cv ct =
   186   (case Thm.term_of ct of
   187     Const (@{const_name Trueprop}, _) $ _ => Conv.arg_conv cv ct
   188   | _ => raise CTERM ("Trueprop_conv", [ct]))
   189 
   190 (* FIXME: another clone *)
   191 fun eq_conv cv1 cv2 ct =
   192   (case Thm.term_of ct of
   193     Const (@{const_name HOL.eq}, _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv1) cv2 ct
   194   | _ => raise CTERM ("eq_conv", [ct]))
   195 
(* Decompose a hypothesis asserting membership in a set comprehension:
   - turn it into the comprehension predicate (mem_Collect_eq);
   - eliminate all existential quantifiers;
   - split one conjunction, if there is one;
   - substitute with hyp_subst_tac. *)
val elim_Collect_tac = dtac @{thm iffD1[OF mem_Collect_eq]}
  THEN' (REPEAT_DETERM o (eresolve_tac @{thms exE}))
  THEN' TRY o etac @{thm conjE}
  THEN' hyp_subst_tac;
   200 
(* Prove membership in an image via image_eqI. Then, at the same subgoal
   index, repeatedly apply (at least once, as long as one succeeds):
   - refl;
   - an arg_cong2/prod.cases instance that reduces a prod_case applied to a
     pair on the right-hand side of an equation;
   - a conversion that, beneath all parameters and premises, rewrites the
     right-hand side of an equational conclusion with prod_case_distrib. *)
fun intro_image_tac ctxt = rtac @{thm image_eqI}
    THEN' (REPEAT_DETERM1 o
      (rtac @{thm refl}
      ORELSE' rtac
        @{thm arg_cong2[OF refl, where f="op =", OF prod.cases, THEN iffD2]}
      ORELSE' CONVERSION (Conv.params_conv ~1 (K (Conv.concl_conv ~1
        (Trueprop_conv (eq_conv Conv.all_conv (Conv.rewr_conv (mk_meta_eq prod_case_distrib)))))) ctxt)))
   208 
(* Decompose a hypothesis asserting membership in an image:
   - eliminate it with imageE;
   - split prod_case in hypotheses with prod.split_asm as often as possible;
   - substitute with hyp_subst_tac. *)
val elim_image_tac = etac @{thm imageE}
  THEN' (TRY o REPEAT_DETERM1 o Splitter.split_asm_tac @{thms prod.split_asm})
  THEN' hyp_subst_tac
   212 
(* Direction comprehension -> pointfree set. From the decomposed
   comprehension hypotheses, prove membership in the set that mk_set builds
   from fm, following the structure of fm. *)
fun tac1_of_formula (Int (fm1, fm2)) =
    (* Split a conjunction hypothesis if present, then introduce the
       intersection. IntI leaves the fm1 goal at i and the fm2 goal at i+1;
       i+1 is treated first so that index i stays valid. *)
    TRY o etac @{thm conjE}
    THEN' rtac @{thm IntI}
    THEN' (fn i => tac1_of_formula fm2 (i + 1))
    THEN' tac1_of_formula fm1
  | tac1_of_formula (Un (fm1, fm2)) =
    (* Case split on the disjunction: the first case goes into the left
       operand of the union and the second case into the right one.
       Presumably the fm1 tactic closes its goal, so that the second case
       moves to index i; TODO confirm. *)
    etac @{thm disjE} THEN' rtac @{thm UnI1}
    THEN' tac1_of_formula fm1
    THEN' rtac @{thm UnI2}
    THEN' tac1_of_formula fm2
  | tac1_of_formula (Atom _) =
    (* Decompose membership in products, UNIV and complements; close the
       remaining goals by assumption. *)
    (REPEAT_DETERM1 o (rtac @{thm SigmaI}
      ORELSE' rtac @{thm UNIV_I}
      ORELSE' rtac @{thm iffD2[OF Compl_iff]}
      ORELSE' atac))
   228 
(* Direction pointfree set -> comprehension. From a hypothesis of membership
   in the set built from fm, prove the corresponding part of the
   comprehension condition, following the structure of fm. *)
fun tac2_of_formula (Int (fm1, fm2)) =
    (* Split an intersection hypothesis and a conjunction goal, if present.
       The fm2 part at i+1 is handled before the fm1 part at i so that index
       i stays valid. *)
    TRY o etac @{thm IntE}
    THEN' TRY o rtac @{thm conjI}
    THEN' (fn i => tac2_of_formula fm2 (i + 1))
    THEN' tac2_of_formula fm1
  | tac2_of_formula (Un (fm1, fm2)) =
    (* Case split on the union hypothesis: the left case proves the left
       disjunct and the right case the right disjunct. *)
    etac @{thm UnE} THEN' rtac @{thm disjI1}
    THEN' tac2_of_formula fm1
    THEN' rtac @{thm disjI2}
    THEN' tac2_of_formula fm2
  | tac2_of_formula (Atom _) =
    (* Decompose product and complement memberships in the hypotheses and
       finish by assumption where possible. *)
    TRY o REPEAT_DETERM1 o
      (dtac @{thm iffD1[OF mem_Sigma_iff]}
       ORELSE' etac @{thm conjE}
       ORELSE' etac @{thm ComplE}
       ORELSE' atac)
   245 
(* Prove comprehension = image expression (the goal stated in conv) by
   subset antisymmetry. fm is the formula returned by mk_pointfree_expr and
   guides the structural tactics in both directions. *)
fun tac ctxt fm =
  let
    (* comprehension <= image: decompose the comprehension hypothesis, choose
       the image witness, then prove set membership along fm *)
    val subset_tac1 = rtac @{thm subsetI}
      THEN' elim_Collect_tac
      THEN' (intro_image_tac ctxt)
      THEN' (tac1_of_formula fm)
    (* image <= comprehension. Steps:
       - decompose the image hypothesis;
       - reintroduce the comprehension's existentials and conjunctions;
       - close an equation by refl (after optional substitution) on some
         subgoal;
       - prove each top-level conjunct of fm (strip_Int) with
         tac2_of_formula, the last component first. *)
    val subset_tac2 = rtac @{thm subsetI}
      THEN' elim_image_tac
      THEN' rtac @{thm iffD2[OF mem_Collect_eq]}
      THEN' REPEAT_DETERM1 o resolve_tac @{thms exI}
      THEN' (TRY o REPEAT_ALL_NEW (rtac @{thm conjI}))
      THEN' (K (SOMEGOAL ((TRY o hyp_subst_tac) THEN' rtac @{thm refl})))
      THEN' (fn i => EVERY (rev (map_index (fn (j, f) =>
        REPEAT_DETERM (etac @{thm IntE} (i + j)) THEN tac2_of_formula f (i + j)) (strip_Int fm))))
  in
    rtac @{thm subset_antisym} THEN' subset_tac1 THEN' subset_tac2
  end;
   263 
   264 
   265 (* main simprocs *)
   266 
(* Meta-level rewrite rules that conv applies to the pointfree side of the
   proven equation. They fold unions and intersections of Cartesian products
   back into single products: Times_Un_distrib1 used right-to-left, plus the
   two lemmas stated here. *)
val post_thms =
  map mk_meta_eq [@{thm Times_Un_distrib1[symmetric]},
  @{lemma "A \<times> B \<union> A \<times> C = A \<times> (B \<union> C)" by auto},
  @{lemma "(A \<times> B \<inter> C \<times> D) = (A \<inter> C) \<times> (B \<inter> D)" by auto}]
   271 
   272 fun conv ctxt t =
   273   let
   274     val ct = cterm_of (Proof_Context.theory_of ctxt) t
   275     val Bex_def = mk_meta_eq @{thm Bex_def}
   276     val unfold_eq = Conv.bottom_conv (K (Conv.try_conv (Conv.rewr_conv Bex_def))) ctxt ct
   277     val t' = term_of (Thm.rhs_of unfold_eq)
   278     fun mk_thm (fm, t'') = Goal.prove ctxt [] []
   279       (HOLogic.mk_Trueprop (HOLogic.mk_eq (t', t''))) (fn {context, ...} => tac context fm 1)
   280     fun unfold th = th RS ((unfold_eq RS meta_eq_to_obj_eq) RS @{thm trans})
   281     fun post th = Conv.fconv_rule (Trueprop_conv (eq_conv Conv.all_conv
   282       (Raw_Simplifier.rewrite true post_thms))) th
   283   in
   284     Option.map (post o unfold o mk_thm) (rewrite_term t')
   285   end;
   286 
   287 fun base_simproc ss redex =
   288   let
   289     val ctxt = Simplifier.the_context ss
   290     val set_compr = term_of redex
   291   in
   292     conv ctxt set_compr
   293     |> Option.map (fn thm => thm RS @{thm eq_reflection})
   294   end;
   295 
   296 fun instantiate_arg_cong ctxt pred =
   297   let
   298     val certify = cterm_of (Proof_Context.theory_of ctxt)
   299     val arg_cong = Thm.incr_indexes (maxidx_of_term pred + 1) @{thm arg_cong}
   300     val f $ _ = fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (concl_of arg_cong)))
   301   in
   302     cterm_instantiate [(certify f, certify pred)] arg_cong
   303   end;
   304 
   305 fun simproc ss redex =
   306   let
   307     val ctxt = Simplifier.the_context ss
   308     val pred $ set_compr = term_of redex
   309     val arg_cong' = instantiate_arg_cong ctxt pred
   310   in
   311     conv ctxt set_compr
   312     |> Option.map (fn thm => thm RS arg_cong' RS @{thm eq_reflection})
   313   end;
   314 
   315 fun code_simproc ss redex =
   316   let
   317     val prep_thm = Raw_Simplifier.rewrite false @{thms eq_equal[symmetric]} redex
   318   in
   319     case base_simproc ss (Thm.rhs_of prep_thm) of
   320       SOME rewr_thm => SOME (transitive_thm OF [transitive_thm OF [prep_thm, rewr_thm],
   321         Raw_Simplifier.rewrite false @{thms eq_equal} (Thm.rhs_of rewr_thm)])
   322     | NONE => NONE
   323   end;
   324 
   325 end;
   326