improving code generation for multisets; adding exhaustive quickcheck generators for multisets
1.1 --- a/src/HOL/Library/Multiset.thy Tue Jan 10 10:17:07 2012 +0100
1.2 +++ b/src/HOL/Library/Multiset.thy Tue Jan 10 10:17:09 2012 +0100
1.3 @@ -5,7 +5,7 @@
1.4 header {* (Finite) multisets *}
1.5
1.6 theory Multiset
1.7 -imports Main
1.8 +imports Main AList
1.9 begin
1.10
1.11 subsection {* The type of multisets *}
1.12 @@ -1041,7 +1041,81 @@
1.13 by (cases "i = j") (simp_all add: multiset_of_update nth_mem_multiset_of)
1.14
1.15
1.16 -subsubsection {* Association lists -- including rudimentary code generation *}
1.17 +subsubsection {* Association lists -- including code generation *}
1.18 +
1.19 +text {* Preliminaries *}
1.20 +
1.21 +text {* Raw operations on lists *}
1.22 +
1.23 +definition join_raw :: "('key \<Rightarrow> 'val \<times> 'val \<Rightarrow> 'val) \<Rightarrow> ('key \<times> 'val) list \<Rightarrow> ('key \<times> 'val) list \<Rightarrow> ('key \<times> 'val) list"
1.24 +where
1.25 + "join_raw f xs ys = foldr (\<lambda>(k, v). map_default k v (%v'. f k (v', v))) ys xs"
1.26 +
1.27 +lemma join_raw_Nil [simp]:
1.28 + "join_raw f xs [] = xs"
1.29 +by (simp add: join_raw_def)
1.30 +
1.31 +lemma join_raw_Cons [simp]:
1.32 + "join_raw f xs ((k, v) # ys) = map_default k v (%v'. f k (v', v)) (join_raw f xs ys)"
1.33 +by (simp add: join_raw_def)
1.34 +
1.35 +lemma map_of_join_raw:
1.36 + assumes "distinct (map fst ys)"
1.37 + shows "map_of (join_raw f xs ys) x = (case map_of xs x of None => map_of ys x | Some v => (case map_of ys x of None => Some v | Some v' => Some (f x (v, v'))))"
1.38 +using assms
1.39 +apply (induct ys)
1.40 +apply (auto simp add: map_of_map_default split: option.split)
1.41 +apply (metis map_of_eq_None_iff option.simps(2) weak_map_of_SomeI)
1.42 +by (metis Some_eq_map_of_iff map_of_eq_None_iff option.simps(2))
1.43 +
1.44 +lemma distinct_join_raw:
1.45 + assumes "distinct (map fst xs)"
1.46 + shows "distinct (map fst (join_raw f xs ys))"
1.47 +using assms
1.48 +proof (induct ys)
1.49 + case (Cons y ys)
1.50 + thus ?case by (cases y) (simp add: distinct_map_default)
1.51 +qed auto
1.52 +
1.53 +definition
1.54 + "subtract_entries_raw xs ys = foldr (%(k, v). AList_Impl.map_entry k (%v'. v' - v)) ys xs"
1.55 +
1.56 +lemma map_of_subtract_entries_raw:
1.57 + "distinct (map fst ys) ==> map_of (subtract_entries_raw xs ys) x = (case map_of xs x of None => None | Some v => (case map_of ys x of None => Some v | Some v' => Some (v - v')))"
1.58 +unfolding subtract_entries_raw_def
1.59 +apply (induct ys)
1.60 +apply auto
1.61 +apply (simp split: option.split)
1.62 +apply (simp add: map_of_map_entry)
1.63 +apply (auto split: option.split)
1.64 +apply (metis map_of_eq_None_iff option.simps(3) option.simps(4))
1.65 +by (metis map_of_eq_None_iff option.simps(4) option.simps(5))
1.66 +
1.67 +lemma distinct_subtract_entries_raw:
1.68 + assumes "distinct (map fst xs)"
1.69 + shows "distinct (map fst (subtract_entries_raw xs ys))"
1.70 +using assms
1.71 +unfolding subtract_entries_raw_def by (induct ys) (auto simp add: distinct_map_entry)
1.72 +
1.73 +text {* Operations on alists *}
1.74 +
1.75 +definition join
1.76 +where
1.77 + "join f xs ys = AList.Alist (join_raw f (AList.impl_of xs) (AList.impl_of ys))"
1.78 +
1.79 +lemma [code abstract]:
1.80 + "AList.impl_of (join f xs ys) = join_raw f (AList.impl_of xs) (AList.impl_of ys)"
1.81 +unfolding join_def by (simp add: Alist_inverse distinct_join_raw)
1.82 +
1.83 +definition subtract_entries
1.84 +where
1.85 + "subtract_entries xs ys = AList.Alist (subtract_entries_raw (AList.impl_of xs) (AList.impl_of ys))"
1.86 +
1.87 +lemma [code abstract]:
1.88 + "AList.impl_of (subtract_entries xs ys) = subtract_entries_raw (AList.impl_of xs) (AList.impl_of ys)"
1.89 +unfolding subtract_entries_def by (simp add: Alist_inverse distinct_subtract_entries_raw)
1.90 +
1.91 +text {* Implementing multisets by means of association lists *}
1.92
1.93 definition count_of :: "('a \<times> nat) list \<Rightarrow> 'a \<Rightarrow> nat" where
1.94 "count_of xs x = (case map_of xs x of None \<Rightarrow> 0 | Some n \<Rightarrow> n)"
1.95 @@ -1074,32 +1148,55 @@
1.96 by (induct xs) (simp_all add: count_of_def)
1.97
1.98 lemma count_of_filter:
1.99 - "count_of (filter (P \<circ> fst) xs) x = (if P x then count_of xs x else 0)"
1.100 + "count_of (List.filter (P \<circ> fst) xs) x = (if P x then count_of xs x else 0)"
1.101 by (induct xs) auto
1.102
1.103 -definition Bag :: "('a \<times> nat) list \<Rightarrow> 'a multiset" where
1.104 - "Bag xs = Abs_multiset (count_of xs)"
1.105 +lemma count_of_map_default [simp]:
1.106 + "count_of (map_default x b (%x. x + b) xs) y = (if x = y then count_of xs x + b else count_of xs y)"
1.107 +unfolding count_of_def by (simp add: map_of_map_default split: option.split)
1.108 +
1.109 +lemma count_of_join_raw:
1.110 + "distinct (map fst ys) ==> count_of xs x + count_of ys x = count_of (join_raw (%x (x, y). x + y) xs ys) x"
1.111 +unfolding count_of_def by (simp add: map_of_join_raw split: option.split)
1.112 +
1.113 +lemma count_of_subtract_entries_raw:
1.114 + "distinct (map fst ys) ==> count_of xs x - count_of ys x = count_of (subtract_entries_raw xs ys) x"
1.115 +unfolding count_of_def by (simp add: map_of_subtract_entries_raw split: option.split)
1.116 +
1.117 +text {* Code equations for multiset operations *}
1.118 +
1.119 +definition Bag :: "('a, nat) alist \<Rightarrow> 'a multiset" where
1.120 + "Bag xs = Abs_multiset (count_of (AList.impl_of xs))"
1.121
1.122 code_datatype Bag
1.123
1.124 lemma count_Bag [simp, code]:
1.125 - "count (Bag xs) = count_of xs"
1.126 + "count (Bag xs) = count_of (AList.impl_of xs)"
1.127 by (simp add: Bag_def count_of_multiset Abs_multiset_inverse)
1.128
1.129 lemma Mempty_Bag [code]:
1.130 - "{#} = Bag []"
1.131 - by (simp add: multiset_eq_iff)
1.132 + "{#} = Bag (Alist [])"
1.133 + by (simp add: multiset_eq_iff alist.Alist_inverse)
1.134
1.135 lemma single_Bag [code]:
1.136 - "{#x#} = Bag [(x, 1)]"
1.137 - by (simp add: multiset_eq_iff)
1.138 + "{#x#} = Bag (Alist [(x, 1)])"
1.139 + by (simp add: multiset_eq_iff alist.Alist_inverse)
1.140 +
1.141 +lemma union_Bag [code]:
1.142 + "Bag xs + Bag ys = Bag (join (\<lambda>x (n1, n2). n1 + n2) xs ys)"
1.143 +by (rule multiset_eqI) (simp add: count_of_join_raw alist.Alist_inverse distinct_join_raw join_def)
1.144 +
1.145 +lemma minus_Bag [code]:
1.146 + "Bag xs - Bag ys = Bag (subtract_entries xs ys)"
1.147 +by (rule multiset_eqI)
1.148 + (simp add: count_of_subtract_entries_raw alist.Alist_inverse distinct_subtract_entries_raw subtract_entries_def)
1.149
1.150 lemma filter_Bag [code]:
1.151 - "Multiset.filter P (Bag xs) = Bag (filter (P \<circ> fst) xs)"
1.152 - by (rule multiset_eqI) (simp add: count_of_filter)
1.153 + "Multiset.filter P (Bag xs) = Bag (AList.filter (P \<circ> fst) xs)"
1.154 +by (rule multiset_eqI) (simp add: count_of_filter impl_of_filter)
1.155
1.156 lemma mset_less_eq_Bag [code]:
1.157 - "Bag xs \<le> A \<longleftrightarrow> (\<forall>(x, n) \<in> set xs. count_of xs x \<le> count A x)"
1.158 + "Bag xs \<le> A \<longleftrightarrow> (\<forall>(x, n) \<in> set (AList.impl_of xs). count_of (AList.impl_of xs) x \<le> count A x)"
1.159 (is "?lhs \<longleftrightarrow> ?rhs")
1.160 proof
1.161 assume ?lhs then show ?rhs
1.162 @@ -1109,8 +1206,8 @@
1.163 show ?lhs
1.164 proof (rule mset_less_eqI)
1.165 fix x
1.166 - from `?rhs` have "count_of xs x \<le> count A x"
1.167 - by (cases "x \<in> fst ` set xs") (auto simp add: count_of_empty)
1.168 + from `?rhs` have "count_of (AList.impl_of xs) x \<le> count A x"
1.169 + by (cases "x \<in> fst ` set (AList.impl_of xs)") (auto simp add: count_of_empty)
1.170 then show "count (Bag xs) x \<le> count A x"
1.171 by (simp add: mset_le_def count_Bag)
1.172 qed
1.173 @@ -1127,12 +1224,10 @@
1.174
1.175 end
1.176
1.177 -lemma [code nbe]:
1.178 - "HOL.equal (A :: 'a::equal multiset) A \<longleftrightarrow> True"
1.179 - by (fact equal_refl)
1.180 +text {* Quickcheck generators *}
1.181
1.182 definition (in term_syntax)
1.183 - bagify :: "('a\<Colon>typerep \<times> nat) list \<times> (unit \<Rightarrow> Code_Evaluation.term)
1.184 + bagify :: "('a\<Colon>typerep, nat) alist \<times> (unit \<Rightarrow> Code_Evaluation.term)
1.185 \<Rightarrow> 'a multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
1.186 [code_unfold]: "bagify xs = Code_Evaluation.valtermify Bag {\<cdot>} xs"
1.187
1.188 @@ -1152,6 +1247,28 @@
1.189 no_notation fcomp (infixl "\<circ>>" 60)
1.190 no_notation scomp (infixl "\<circ>\<rightarrow>" 60)
1.191
1.192 +instantiation multiset :: (exhaustive) exhaustive
1.193 +begin
1.194 +
1.195 +definition exhaustive_multiset :: "('a multiset => (bool * term list) option) => code_numeral => (bool * term list) option"
1.196 +where
1.197 + "exhaustive_multiset f i = Quickcheck_Exhaustive.exhaustive (%xs. f (Bag xs)) i"
1.198 +
1.199 +instance ..
1.200 +
1.201 +end
1.202 +
1.203 +instantiation multiset :: (full_exhaustive) full_exhaustive
1.204 +begin
1.205 +
1.206 +definition full_exhaustive_multiset :: "('a multiset * (unit => term) => (bool * term list) option) => code_numeral => (bool * term list) option"
1.207 +where
1.208 + "full_exhaustive_multiset f i = Quickcheck_Exhaustive.full_exhaustive (%xs. f (bagify xs)) i"
1.209 +
1.210 +instance ..
1.211 +
1.212 +end
1.213 +
1.214 hide_const (open) bagify
1.215
1.216