improving code generation for multisets; adding exhaustive quickcheck generators for multisets
authorbulwahn
Tue, 10 Jan 2012 10:17:09 +0100
changeset 47039bef8c811df20
parent 47038 25eba8a5d7d0
child 47040 321abd584588
improving code generation for multisets; adding exhaustive quickcheck generators for multisets
src/HOL/Library/Multiset.thy
     1.1 --- a/src/HOL/Library/Multiset.thy	Tue Jan 10 10:17:07 2012 +0100
     1.2 +++ b/src/HOL/Library/Multiset.thy	Tue Jan 10 10:17:09 2012 +0100
     1.3 @@ -5,7 +5,7 @@
     1.4  header {* (Finite) multisets *}
     1.5  
     1.6  theory Multiset
     1.7 -imports Main
     1.8 +imports Main AList
     1.9  begin
    1.10  
    1.11  subsection {* The type of multisets *}
    1.12 @@ -1041,7 +1041,81 @@
    1.13    by (cases "i = j") (simp_all add: multiset_of_update nth_mem_multiset_of)
    1.14  
    1.15  
    1.16 -subsubsection {* Association lists -- including rudimentary code generation *}
    1.17 +subsubsection {* Association lists -- including code generation *}
    1.18 +
    1.19 +text {* Preliminaries *}
    1.20 +
    1.21 +text {* Raw operations on lists *}
    1.22 +
    1.23 +definition join_raw :: "('key \<Rightarrow> 'val \<times> 'val \<Rightarrow> 'val) \<Rightarrow> ('key \<times> 'val) list \<Rightarrow> ('key \<times> 'val) list \<Rightarrow> ('key \<times> 'val) list"
    1.24 +where
    1.25 +  "join_raw f xs ys = foldr (\<lambda>(k, v). map_default k v (%v'. f k (v', v))) ys xs"
    1.26 +
    1.27 +lemma join_raw_Nil [simp]:
    1.28 +  "join_raw f xs [] = xs"
    1.29 +by (simp add: join_raw_def)
    1.30 +
    1.31 +lemma join_raw_Cons [simp]:
    1.32 +  "join_raw f xs ((k, v) # ys) = map_default k v (%v'. f k (v', v)) (join_raw f xs ys)"
    1.33 +by (simp add: join_raw_def)
    1.34 +
    1.35 +lemma map_of_join_raw:
    1.36 +  assumes "distinct (map fst ys)"
    1.37 +  shows "map_of (join_raw f xs ys) x = (case map_of xs x of None => map_of ys x | Some v => (case map_of ys x of None => Some v | Some v' => Some (f x (v, v'))))"
    1.38 +using assms
    1.39 +apply (induct ys)
    1.40 +apply (auto simp add: map_of_map_default split: option.split)
    1.41 +apply (metis map_of_eq_None_iff option.simps(2) weak_map_of_SomeI)
    1.42 +by (metis Some_eq_map_of_iff map_of_eq_None_iff option.simps(2))
    1.43 +
    1.44 +lemma distinct_join_raw:
    1.45 +  assumes "distinct (map fst xs)"
    1.46 +  shows "distinct (map fst (join_raw f xs ys))"
    1.47 +using assms
    1.48 +proof (induct ys)
    1.49 +  case (Cons y ys)
    1.50 +  thus ?case by (cases y) (simp add: distinct_map_default)
    1.51 +qed auto
    1.52 +
    1.53 +definition
    1.54 +  "subtract_entries_raw xs ys = foldr (%(k, v). AList_Impl.map_entry k (%v'. v' - v)) ys xs"
    1.55 +
    1.56 +lemma map_of_subtract_entries_raw:
    1.57 +  "distinct (map fst ys) ==> map_of (subtract_entries_raw xs ys) x = (case map_of xs x of None => None | Some v => (case map_of ys x of None => Some v | Some v' => Some (v - v')))"
    1.58 +unfolding subtract_entries_raw_def
    1.59 +apply (induct ys)
    1.60 +apply auto
    1.61 +apply (simp split: option.split)
    1.62 +apply (simp add: map_of_map_entry)
    1.63 +apply (auto split: option.split)
    1.64 +apply (metis map_of_eq_None_iff option.simps(3) option.simps(4))
    1.65 +by (metis map_of_eq_None_iff option.simps(4) option.simps(5))
    1.66 +
    1.67 +lemma distinct_subtract_entries_raw:
    1.68 +  assumes "distinct (map fst xs)"
    1.69 +  shows "distinct (map fst (subtract_entries_raw xs ys))"
    1.70 +using assms
    1.71 +unfolding subtract_entries_raw_def by (induct ys) (auto simp add: distinct_map_entry)
    1.72 +
    1.73 +text {* Operations on alists *}
    1.74 +
    1.75 +definition join
    1.76 +where
    1.77 +  "join f xs ys = AList.Alist (join_raw f (AList.impl_of xs) (AList.impl_of ys))" 
    1.78 +
    1.79 +lemma [code abstract]:
    1.80 +  "AList.impl_of (join f xs ys) = join_raw f (AList.impl_of xs) (AList.impl_of ys)"
    1.81 +unfolding join_def by (simp add: Alist_inverse distinct_join_raw)
    1.82 +
    1.83 +definition subtract_entries
    1.84 +where
    1.85 +  "subtract_entries xs ys = AList.Alist (subtract_entries_raw (AList.impl_of xs) (AList.impl_of ys))"
    1.86 +
    1.87 +lemma [code abstract]:
    1.88 +  "AList.impl_of (subtract_entries xs ys) = subtract_entries_raw (AList.impl_of xs) (AList.impl_of ys)"
    1.89 +unfolding subtract_entries_def by (simp add: Alist_inverse distinct_subtract_entries_raw)
    1.90 +
    1.91 +text {* Implementing multisets by means of association lists *}
    1.92  
    1.93  definition count_of :: "('a \<times> nat) list \<Rightarrow> 'a \<Rightarrow> nat" where
    1.94    "count_of xs x = (case map_of xs x of None \<Rightarrow> 0 | Some n \<Rightarrow> n)"
    1.95 @@ -1074,32 +1148,55 @@
    1.96    by (induct xs) (simp_all add: count_of_def)
    1.97  
    1.98  lemma count_of_filter:
    1.99 -  "count_of (filter (P \<circ> fst) xs) x = (if P x then count_of xs x else 0)"
   1.100 +  "count_of (List.filter (P \<circ> fst) xs) x = (if P x then count_of xs x else 0)"
   1.101    by (induct xs) auto
   1.102  
   1.103 -definition Bag :: "('a \<times> nat) list \<Rightarrow> 'a multiset" where
   1.104 -  "Bag xs = Abs_multiset (count_of xs)"
   1.105 +lemma count_of_map_default [simp]:
   1.106 +  "count_of (map_default x b (%x. x + b) xs) y = (if x = y then count_of xs x + b else count_of xs y)"
   1.107 +unfolding count_of_def by (simp add: map_of_map_default split: option.split)
   1.108 +
   1.109 +lemma count_of_join_raw:
   1.110 +  "distinct (map fst ys) ==> count_of xs x + count_of ys x = count_of (join_raw (%x (x, y). x + y) xs ys) x"
   1.111 +unfolding count_of_def by (simp add: map_of_join_raw split: option.split)
   1.112 +
   1.113 +lemma count_of_subtract_entries_raw:
   1.114 +  "distinct (map fst ys) ==> count_of xs x - count_of ys x = count_of (subtract_entries_raw xs ys) x"
   1.115 +unfolding count_of_def by (simp add: map_of_subtract_entries_raw split: option.split)
   1.116 +
   1.117 +text {* Code equations for multiset operations *}
   1.118 +
   1.119 +definition Bag :: "('a, nat) alist \<Rightarrow> 'a multiset" where
   1.120 +  "Bag xs = Abs_multiset (count_of (AList.impl_of xs))"
   1.121  
   1.122  code_datatype Bag
   1.123  
   1.124  lemma count_Bag [simp, code]:
   1.125 -  "count (Bag xs) = count_of xs"
   1.126 +  "count (Bag xs) = count_of (AList.impl_of xs)"
   1.127    by (simp add: Bag_def count_of_multiset Abs_multiset_inverse)
   1.128  
   1.129  lemma Mempty_Bag [code]:
   1.130 -  "{#} = Bag []"
   1.131 -  by (simp add: multiset_eq_iff)
   1.132 +  "{#} = Bag (Alist [])"
   1.133 +  by (simp add: multiset_eq_iff alist.Alist_inverse)
   1.134    
   1.135  lemma single_Bag [code]:
   1.136 -  "{#x#} = Bag [(x, 1)]"
   1.137 -  by (simp add: multiset_eq_iff)
   1.138 +  "{#x#} = Bag (Alist [(x, 1)])"
   1.139 +  by (simp add: multiset_eq_iff alist.Alist_inverse)
   1.140 +
   1.141 +lemma union_Bag [code]:
   1.142 +  "Bag xs + Bag ys = Bag (join (\<lambda>x (n1, n2). n1 + n2) xs ys)"
   1.143 +by (rule multiset_eqI) (simp add: count_of_join_raw alist.Alist_inverse distinct_join_raw join_def)
   1.144 +
   1.145 +lemma minus_Bag [code]:
   1.146 +  "Bag xs - Bag ys = Bag (subtract_entries xs ys)"
   1.147 +by (rule multiset_eqI)
   1.148 +  (simp add: count_of_subtract_entries_raw alist.Alist_inverse distinct_subtract_entries_raw subtract_entries_def)
   1.149  
   1.150  lemma filter_Bag [code]:
   1.151 -  "Multiset.filter P (Bag xs) = Bag (filter (P \<circ> fst) xs)"
   1.152 -  by (rule multiset_eqI) (simp add: count_of_filter)
   1.153 +  "Multiset.filter P (Bag xs) = Bag (AList.filter (P \<circ> fst) xs)"
   1.154 +by (rule multiset_eqI) (simp add: count_of_filter impl_of_filter)
   1.155  
   1.156  lemma mset_less_eq_Bag [code]:
   1.157 -  "Bag xs \<le> A \<longleftrightarrow> (\<forall>(x, n) \<in> set xs. count_of xs x \<le> count A x)"
   1.158 +  "Bag xs \<le> A \<longleftrightarrow> (\<forall>(x, n) \<in> set (AList.impl_of xs). count_of (AList.impl_of xs) x \<le> count A x)"
   1.159      (is "?lhs \<longleftrightarrow> ?rhs")
   1.160  proof
   1.161    assume ?lhs then show ?rhs
   1.162 @@ -1109,8 +1206,8 @@
   1.163    show ?lhs
   1.164    proof (rule mset_less_eqI)
   1.165      fix x
   1.166 -    from `?rhs` have "count_of xs x \<le> count A x"
   1.167 -      by (cases "x \<in> fst ` set xs") (auto simp add: count_of_empty)
   1.168 +    from `?rhs` have "count_of (AList.impl_of xs) x \<le> count A x"
   1.169 +      by (cases "x \<in> fst ` set (AList.impl_of xs)") (auto simp add: count_of_empty)
   1.170      then show "count (Bag xs) x \<le> count A x"
   1.171        by (simp add: mset_le_def count_Bag)
   1.172    qed
   1.173 @@ -1127,12 +1224,10 @@
   1.174  
   1.175  end
   1.176  
   1.177 -lemma [code nbe]:
   1.178 -  "HOL.equal (A :: 'a::equal multiset) A \<longleftrightarrow> True"
   1.179 -  by (fact equal_refl)
   1.180 +text {* Quickcheck generators *}
   1.181  
   1.182  definition (in term_syntax)
   1.183 -  bagify :: "('a\<Colon>typerep \<times> nat) list \<times> (unit \<Rightarrow> Code_Evaluation.term)
   1.184 +  bagify :: "('a\<Colon>typerep, nat) alist \<times> (unit \<Rightarrow> Code_Evaluation.term)
   1.185      \<Rightarrow> 'a multiset \<times> (unit \<Rightarrow> Code_Evaluation.term)" where
   1.186    [code_unfold]: "bagify xs = Code_Evaluation.valtermify Bag {\<cdot>} xs"
   1.187  
   1.188 @@ -1152,6 +1247,28 @@
   1.189  no_notation fcomp (infixl "\<circ>>" 60)
   1.190  no_notation scomp (infixl "\<circ>\<rightarrow>" 60)
   1.191  
   1.192 +instantiation multiset :: (exhaustive) exhaustive
   1.193 +begin
   1.194 +
   1.195 +definition exhaustive_multiset :: "('a multiset => (bool * term list) option) => code_numeral => (bool * term list) option"
   1.196 +where
   1.197 +  "exhaustive_multiset f i = Quickcheck_Exhaustive.exhaustive (%xs. f (Bag xs)) i"
   1.198 +
   1.199 +instance ..
   1.200 +
   1.201 +end
   1.202 +
   1.203 +instantiation multiset :: (full_exhaustive) full_exhaustive
   1.204 +begin
   1.205 +
   1.206 +definition full_exhaustive_multiset :: "('a multiset * (unit => term) => (bool * term list) option) => code_numeral => (bool * term list) option"
   1.207 +where
   1.208 +  "full_exhaustive_multiset f i = Quickcheck_Exhaustive.full_exhaustive (%xs. f (bagify xs)) i"
   1.209 +
   1.210 +instance ..
   1.211 +
   1.212 +end
   1.213 +
   1.214  hide_const (open) bagify
   1.215  
   1.216