src/HOL/Tools/set_comprehension_pointfree.ML
changeset 50864 d9822ec4f434
parent 50846 b28dbb7a45d9
child 50865 873fa7156468
equal deleted inserted replaced
50863:f222a054342e 50864:d9822ec4f434
     5 Simproc for rewriting set comprehensions to pointfree expressions.
     5 Simproc for rewriting set comprehensions to pointfree expressions.
     6 *)
     6 *)
     7 
     7 
     8 signature SET_COMPREHENSION_POINTFREE =
     8 signature SET_COMPREHENSION_POINTFREE =
     9 sig
     9 sig
       
    10   val base_simproc : simpset -> cterm -> thm option
    10   val code_simproc : simpset -> cterm -> thm option
    11   val code_simproc : simpset -> cterm -> thm option
    11   val simproc : simpset -> cterm -> thm option
    12   val simproc : simpset -> cterm -> thm option
    12   val rewrite_term : term -> term option
       
    13   (* FIXME: function conv is not a conversion, i.e. of type cterm -> thm, MAYBE rename *)
       
    14   val conv : Proof.context -> term -> thm option
       
    15 end
    13 end
    16 
    14 
    17 structure Set_Comprehension_Pointfree : SET_COMPREHENSION_POINTFREE =
    15 structure Set_Comprehension_Pointfree : SET_COMPREHENSION_POINTFREE =
    18 struct
    16 struct
       
    17 
    19 
    18 
    20 (* syntactic operations *)
    19 (* syntactic operations *)
    21 
    20 
    22 fun mk_inf (t1, t2) =
    21 fun mk_inf (t1, t2) =
    23   let
    22   let
    57   in
    56   in
    58     Const (@{const_name Sigma},
    57     Const (@{const_name Sigma},
    59       T1 --> (setT --> T2) --> resT) $ t1 $ absdummy setT t2
    58       T1 --> (setT --> T2) --> resT) $ t1 $ absdummy setT t2
    60   end;
    59   end;
    61 
    60 
    62 fun dest_Bound (Bound x) = x
       
    63   | dest_Bound t = raise TERM("dest_Bound", [t]);
       
    64 
       
    65 fun dest_Collect (Const (@{const_name Collect}, _) $ Abs (_, _, t)) = t
    61 fun dest_Collect (Const (@{const_name Collect}, _) $ Abs (_, _, t)) = t
    66   | dest_Collect t = raise TERM ("dest_Collect", [t])
    62   | dest_Collect t = raise TERM ("dest_Collect", [t])
    67 
    63 
    68 (* Copied from predicate_compile_aux.ML *)
    64 (* Copied from predicate_compile_aux.ML *)
    69 fun strip_ex (Const (@{const_name Ex}, _) $ Abs (x, T, t)) =
    65 fun strip_ex (Const (@{const_name Ex}, _) $ Abs (x, T, t)) =
    72   in
    68   in
    73     ((x, T) :: xTs, t')
    69     ((x, T) :: xTs, t')
    74   end
    70   end
    75   | strip_ex t = ([], t)
    71   | strip_ex t = ([], t)
    76 
    72 
    77 fun list_tupled_abs [] f = f
    73 fun mk_prod1 Ts (t1, t2) =
    78   | list_tupled_abs [(n, T)] f = (Abs (n, T, f))
    74   let
    79   | list_tupled_abs ((n, T)::v::vs) f =
    75     val (T1, T2) = pairself (curry fastype_of1 Ts) (t1, t2)
    80       HOLogic.mk_split (Abs (n, T, list_tupled_abs (v::vs) f))
    76   in
       
    77     HOLogic.pair_const T1 T2 $ t1 $ t2
       
    78   end;
       
    79 
       
    80 
       
    81 (* patterns *)
       
    82 
       
    83 datatype pattern = TBound of int | TPair of pattern * pattern;
       
    84 
       
    85 fun mk_pattern (Bound n) = TBound n
       
    86   | mk_pattern (Const (@{const_name "Product_Type.Pair"}, _) $ l $ r) =
       
    87       TPair (mk_pattern l, mk_pattern r)
       
    88   | mk_pattern t = raise TERM ("mk_pattern: only bound variable tuples currently supported", [t]);
       
    89 
       
    90 fun type_of_pattern Ts (TBound n) = nth Ts n
       
    91   | type_of_pattern Ts (TPair (l, r)) = HOLogic.mk_prodT (type_of_pattern Ts l, type_of_pattern Ts r)
       
    92 
       
    93 fun term_of_pattern _ (TBound n) = Bound n
       
    94   | term_of_pattern Ts (TPair (l, r)) =
       
    95     let
       
    96       val (lt, rt) = pairself (term_of_pattern Ts) (l, r)
       
    97       val (lT, rT) = pairself (curry fastype_of1 Ts) (lt, rt) 
       
    98     in
       
    99       HOLogic.pair_const lT rT $ lt $ rt
       
   100     end;
       
   101 
       
   102 fun bounds_of_pattern (TBound i) = [i]
       
   103   | bounds_of_pattern (TPair (l, r)) = union (op =) (bounds_of_pattern l) (bounds_of_pattern r)
       
   104 
       
   105 
       
   106 (* formulas *)
       
   107 
       
   108 datatype formula = Atom of (pattern * term) | Int of formula * formula | Un of formula * formula
       
   109 
       
   110 fun mk_atom (Const (@{const_name "Set.member"}, _) $ x $ s) = (mk_pattern x, Atom (mk_pattern x, s))
       
   111   | mk_atom (Const (@{const_name "HOL.Not"}, _) $ (Const (@{const_name "Set.member"}, _) $ x $ s)) =
       
   112       (mk_pattern x, Atom (mk_pattern x, mk_Compl s))
       
   113 
       
   114 fun can_merge (pats1, pats2) =
       
   115   let
       
   116     fun check pat1 pat2 = (pat1 = pat2)
       
   117       orelse (inter (op =) (bounds_of_pattern pat1) (bounds_of_pattern pat2) = [])
       
   118   in
       
   119     forall (fn pat1 => forall (fn pat2 => check pat1 pat2) pats2) pats1 
       
   120   end
       
   121 
       
   122 fun merge_patterns (pats1, pats2) =
       
   123   if can_merge (pats1, pats2) then
       
   124     union (op =) pats1 pats2
       
   125   else raise Fail "merge_patterns: variable groups overlap"
       
   126 
       
   127 fun merge oper (pats1, sp1) (pats2, sp2) = (merge_patterns (pats1, pats2), oper (sp1, sp2))
       
   128 
       
   129 fun mk_formula (@{const HOL.conj} $ t1 $ t2) = merge Int (mk_formula t1) (mk_formula t2)
       
   130   | mk_formula (@{const HOL.disj} $ t1 $ t2) = merge Un (mk_formula t1) (mk_formula t2)
       
   131   | mk_formula t = apfst single (mk_atom t)
       
   132 
       
   133 
       
   134 (* term construction *)
       
   135 
       
   136 fun reorder_bounds pats t =
       
   137   let
       
   138     val bounds = maps bounds_of_pattern pats
       
   139     val bperm = bounds ~~ ((length bounds - 1) downto 0)
       
   140       |> sort (fn (i,j) => int_ord (fst i, fst j)) |> map snd
       
   141   in
       
   142     subst_bounds (map Bound bperm, t)
       
   143   end;
       
   144 
       
   145 fun mk_split_abs vs (Bound i) t = let val (x, T) = nth vs i in Abs (x, T, t) end
       
   146   | mk_split_abs vs (Const ("Product_Type.Pair", _) $ u $ v) t =
       
   147       HOLogic.mk_split (mk_split_abs vs u (mk_split_abs vs v t))
       
   148   | mk_split_abs _ t _ = raise TERM ("mk_split_abs: bad term", [t]);
    81 
   149 
    82 fun mk_pointfree_expr t =
   150 fun mk_pointfree_expr t =
    83   let
   151   let
    84     val (vs, t'') = strip_ex (dest_Collect t)
   152     val (vs, t'') = strip_ex (dest_Collect t)
       
   153     val Ts = map snd (rev vs)
       
   154     fun mk_mem_UNIV n = HOLogic.mk_mem (Bound n, HOLogic.mk_UNIV (nth Ts n))
       
   155     fun lookup (pat', t) pat = if pat = pat' then t else HOLogic.mk_UNIV (type_of_pattern Ts pat)
    85     val conjs = HOLogic.dest_conj t''
   156     val conjs = HOLogic.dest_conj t''
    86     val is_the_eq =
   157     val is_the_eq =
    87       the_default false o (try (fn eq => fst (HOLogic.dest_eq eq) = Bound (length vs)))
   158       the_default false o (try (fn eq => fst (HOLogic.dest_eq eq) = Bound (length vs)))
    88     val SOME eq = find_first is_the_eq conjs
   159     val SOME eq = find_first is_the_eq conjs
    89     val f = snd (HOLogic.dest_eq eq)
   160     val f = snd (HOLogic.dest_eq eq)
    90     val conjs' = filter_out (fn t => eq = t) conjs
   161     val conjs' = filter_out (fn t => eq = t) conjs
    91     val mems = map (apfst dest_Bound o HOLogic.dest_mem) conjs'
   162     val unused_bounds = subtract (op =) (distinct (op =) (maps loose_bnos conjs'))
    92     val grouped_mems = AList.group (op =) mems
   163       (0 upto (length vs - 1))
    93     fun mk_grouped_unions (i, T) =
   164     val (pats, fm) =
    94       case AList.lookup (op =) grouped_mems i of
   165       mk_formula (foldr1 HOLogic.mk_conj (conjs' @ map mk_mem_UNIV unused_bounds))
    95         SOME ts => foldr1 mk_inf ts
   166     fun mk_set (Atom pt) = (case map (lookup pt) pats of [t'] => t' | ts => foldr1 mk_sigma ts)
    96       | NONE => HOLogic.mk_UNIV T
   167       | mk_set (Un (f1, f2)) = mk_sup (mk_set f1, mk_set f2)
    97     val complete_sets = map mk_grouped_unions ((length vs - 1) downto 0 ~~ map snd vs)
   168       | mk_set (Int (f1, f2)) = mk_inf (mk_set f1, mk_set f2)
    98   in
   169     val pat = foldr1 (mk_prod1 Ts) (map (term_of_pattern Ts) pats)
    99     mk_image (list_tupled_abs vs f) (foldr1 mk_sigma complete_sets)
   170     val t = mk_split_abs (rev vs) pat (reorder_bounds pats f)
       
   171   in
       
   172     (fm, mk_image t (mk_set fm))
   100   end;
   173   end;
   101 
   174 
   102 val rewrite_term = try mk_pointfree_expr
   175 val rewrite_term = try mk_pointfree_expr
   103 
   176 
       
   177 
   104 (* proof tactic *)
   178 (* proof tactic *)
   105 
   179 
   106 (* Tactic works for arbitrary number of m : S conjuncts *)
   180 val prod_case_distrib = @{lemma "(prod_case g x) z = prod_case (% x y. (g x y) z) x" by (simp add: prod_case_beta)}
   107 
   181 
   108 val dest_Collect_tac = dtac @{thm iffD1[OF mem_Collect_eq]}
   182 (* FIXME: one of many clones *)
   109   THEN' (REPEAT_DETERM o (eresolve_tac @{thms exE conjE}))
   183 fun Trueprop_conv cv ct =
       
   184   (case Thm.term_of ct of
       
   185     Const (@{const_name Trueprop}, _) $ _ => Conv.arg_conv cv ct
       
   186   | _ => raise CTERM ("Trueprop_conv", [ct]))
       
   187 
       
   188 (* FIXME: another clone *)
       
   189 fun eq_conv cv1 cv2 ct =
       
   190   (case Thm.term_of ct of
       
   191     Const (@{const_name HOL.eq}, _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv1) cv2 ct
       
   192   | _ => raise CTERM ("eq_conv", [ct]))
       
   193 
       
   194 val elim_Collect_tac = dtac @{thm iffD1[OF mem_Collect_eq]}
       
   195   THEN' (REPEAT_DETERM o (eresolve_tac @{thms exE}))
       
   196   THEN' TRY o etac @{thm conjE}
   110   THEN' hyp_subst_tac;
   197   THEN' hyp_subst_tac;
   111 
   198 
   112 val intro_image_Sigma_tac = rtac @{thm image_eqI}
   199 fun intro_image_tac ctxt = rtac @{thm image_eqI}
   113     THEN' (REPEAT_DETERM1 o
   200     THEN' (REPEAT_DETERM1 o
   114       (rtac @{thm refl}
   201       (rtac @{thm refl}
   115       ORELSE' rtac
   202       ORELSE' rtac
   116         @{thm arg_cong2[OF refl, where f="op =", OF prod.cases, THEN iffD2]}));
   203         @{thm arg_cong2[OF refl, where f="op =", OF prod.cases, THEN iffD2]}
   117 
   204       ORELSE' CONVERSION (Conv.params_conv ~1 (K (Conv.concl_conv ~1
   118 val dest_image_Sigma_tac = etac @{thm imageE}
   205         (Trueprop_conv (eq_conv Conv.all_conv (Conv.rewr_conv (mk_meta_eq prod_case_distrib)))))) ctxt)))
       
   206 
       
   207 val elim_image_tac = etac @{thm imageE}
   119   THEN' (TRY o REPEAT_DETERM1 o Splitter.split_asm_tac @{thms prod.split_asm})
   208   THEN' (TRY o REPEAT_DETERM1 o Splitter.split_asm_tac @{thms prod.split_asm})
   120   THEN' hyp_subst_tac
   209   THEN' hyp_subst_tac
   121   THEN' (TRY o REPEAT_DETERM1 o
       
   122     (etac @{thm conjE} ORELSE' dtac @{thm iffD1[OF mem_Sigma_iff]}));
       
   123 
   210 
   124 val intro_Collect_tac = rtac @{thm iffD2[OF mem_Collect_eq]}
   211 val intro_Collect_tac = rtac @{thm iffD2[OF mem_Collect_eq]}
   125   THEN' REPEAT_DETERM1 o resolve_tac @{thms exI}
   212   THEN' REPEAT_DETERM1 o resolve_tac @{thms exI}
   126   THEN' (TRY o REPEAT_ALL_NEW (rtac @{thm conjI}))
   213   THEN' (TRY o (rtac @{thm conjI}))
   127   THEN' (K (ALLGOALS (TRY o ((TRY o hyp_subst_tac) THEN' rtac @{thm refl}))))
   214   THEN' (TRY o hyp_subst_tac)
   128 
   215   THEN' rtac @{thm refl};
   129 val tac =
   216 
       
   217 fun tac1_of_formula (Int (fm1, fm2)) =
       
   218     TRY o etac @{thm conjE}
       
   219     THEN' rtac @{thm IntI}
       
   220     THEN' (fn i => tac1_of_formula fm2 (i + 1))
       
   221     THEN' tac1_of_formula fm1
       
   222   | tac1_of_formula (Un (fm1, fm2)) =
       
   223     etac @{thm disjE} THEN' rtac @{thm UnI1}
       
   224     THEN' tac1_of_formula fm1
       
   225     THEN' rtac @{thm UnI2}
       
   226     THEN' tac1_of_formula fm2
       
   227   | tac1_of_formula (Atom _) =
       
   228     (REPEAT_DETERM1 o (rtac @{thm SigmaI}
       
   229       ORELSE' rtac @{thm UNIV_I}
       
   230       ORELSE' rtac @{thm iffD2[OF Compl_iff]}
       
   231       ORELSE' atac))
       
   232 
       
   233 fun tac2_of_formula (Int (fm1, fm2)) =
       
   234     TRY o etac @{thm IntE}
       
   235     THEN' TRY o rtac @{thm conjI}
       
   236     THEN' (fn i => tac2_of_formula fm2 (i + 1))
       
   237     THEN' tac2_of_formula fm1
       
   238   | tac2_of_formula (Un (fm1, fm2)) =
       
   239     etac @{thm UnE} THEN' rtac @{thm disjI1}
       
   240     THEN' tac2_of_formula fm1
       
   241     THEN' rtac @{thm disjI2}
       
   242     THEN' tac2_of_formula fm2
       
   243   | tac2_of_formula (Atom _) =
       
   244     TRY o REPEAT_DETERM1 o
       
   245       (dtac @{thm iffD1[OF mem_Sigma_iff]}
       
   246        ORELSE' etac @{thm conjE}
       
   247        ORELSE' etac @{thm ComplE}
       
   248        ORELSE' atac)
       
   249 
       
   250 fun tac ctxt fm =
   130   let
   251   let
   131     val subset_tac1 = rtac @{thm subsetI}
   252     val subset_tac1 = rtac @{thm subsetI}
   132       THEN' dest_Collect_tac
   253       THEN' elim_Collect_tac
   133       THEN' intro_image_Sigma_tac
   254       THEN' (intro_image_tac ctxt)
   134       THEN' (REPEAT_DETERM1 o
   255       THEN' (tac1_of_formula fm)
   135         (rtac @{thm SigmaI}
       
   136         ORELSE' rtac @{thm UNIV_I}
       
   137         ORELSE' rtac @{thm IntI}
       
   138         ORELSE' atac));
       
   139 
       
   140     val subset_tac2 = rtac @{thm subsetI}
   256     val subset_tac2 = rtac @{thm subsetI}
   141       THEN' dest_image_Sigma_tac
   257       THEN' elim_image_tac
   142       THEN' intro_Collect_tac
   258       THEN' intro_Collect_tac
   143       THEN' REPEAT_DETERM o (eresolve_tac @{thms IntD1 IntD2} ORELSE' atac);
   259       THEN' tac2_of_formula fm
   144   in
   260   in
   145     rtac @{thm subset_antisym} THEN' subset_tac1 THEN' subset_tac2
   261     rtac @{thm subset_antisym} THEN' subset_tac1 THEN' subset_tac2
   146   end;
   262   end;
       
   263 
       
   264 
       
   265 (* main simprocs *)
   147 
   266 
   148 fun conv ctxt t =
   267 fun conv ctxt t =
   149   let
   268   let
   150     val ct = cterm_of (Proof_Context.theory_of ctxt) t
   269     val ct = cterm_of (Proof_Context.theory_of ctxt) t
   151     val Bex_def = mk_meta_eq @{thm Bex_def}
   270     val Bex_def = mk_meta_eq @{thm Bex_def}
   152     val unfold_eq = Conv.bottom_conv (K (Conv.try_conv (Conv.rewr_conv Bex_def))) ctxt ct
   271     val unfold_eq = Conv.bottom_conv (K (Conv.try_conv (Conv.rewr_conv Bex_def))) ctxt ct
   153     val t' = term_of (Thm.rhs_of unfold_eq) 
   272     val t' = term_of (Thm.rhs_of unfold_eq)
   154     fun mk_thm t'' = Goal.prove ctxt [] []
   273     fun mk_thm (fm, t'') = Goal.prove ctxt [] []
   155       (HOLogic.mk_Trueprop (HOLogic.mk_eq (t', t''))) (K (tac 1))
   274       (HOLogic.mk_Trueprop (HOLogic.mk_eq (t', t''))) (fn {context, ...} => tac context fm 1)
   156     fun unfold th = th RS ((unfold_eq RS meta_eq_to_obj_eq) RS @{thm trans})
   275     fun unfold th = th RS ((unfold_eq RS meta_eq_to_obj_eq) RS @{thm trans})
   157   in
   276   in
   158     Option.map (unfold o mk_thm) (rewrite_term t')
   277     Option.map (unfold o mk_thm) (rewrite_term t')
   159   end;
   278   end;
   160 
       
   161 (* simproc *)
       
   162 
   279 
   163 fun base_simproc ss redex =
   280 fun base_simproc ss redex =
   164   let
   281   let
   165     val ctxt = Simplifier.the_context ss
   282     val ctxt = Simplifier.the_context ss
   166     val set_compr = term_of redex
   283     val set_compr = term_of redex