1 (* Title: HOL/Library/Efficient_Nat.thy
2 Author: Stefan Berghofer, Florian Haftmann, TU Muenchen
5 header {* Implementation of natural numbers by target-language integers *}
8 imports Code_Integer Main
12 When generating code for functions on natural numbers, the
13 canonical representation using @{term "0::nat"} and
14 @{term Suc} is unsuitable for computations involving large
15 numbers. The efficiency of the generated code can be improved
16 drastically by implementing natural numbers by target-language
17 integers. To do this, just include this theory.
20 subsection {* Basic arithmetic *}
23 Most standard arithmetic functions on natural numbers are implemented
24 using their counterparts on the integers:
27 code_datatype number_nat_inst.number_of_nat
29 lemma zero_nat_code [code, code_unfold_post]:
30 "0 = (Numeral0 :: nat)"
33 lemma one_nat_code [code, code_unfold_post]:
34 "1 = (Numeral1 :: nat)"
37 lemma Suc_code [code]:
41 lemma plus_nat_code [code]:
42 "n + m = nat (of_nat n + of_nat m)"
45 lemma minus_nat_code [code]:
46 "n - m = nat (of_nat n - of_nat m)"
49 lemma times_nat_code [code]:
50 "n * m = nat (of_nat n * of_nat m)"
51 unfolding of_nat_mult [symmetric] by simp
53 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"}
54 and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
56 definition divmod_aux :: "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
57 [code del]: "divmod_aux = divmod_nat"
60 "divmod_nat n m = (if m = 0 then (0, n) else divmod_aux n m)"
61 unfolding divmod_aux_def divmod_nat_div_mod by simp
63 lemma divmod_aux_code [code]:
64 "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
65 unfolding divmod_aux_def divmod_nat_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
67 lemma eq_nat_code [code]:
68 "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
71 lemma eq_nat_refl [code nbe]:
72 "eq_class.eq (n::nat) n \<longleftrightarrow> True"
75 lemma less_eq_nat_code [code]:
76 "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
79 lemma less_nat_code [code]:
80 "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
83 subsection {* Case analysis *}
86 Case analysis on natural numbers is rephrased using a conditional
90 lemma [code, code_unfold]:
91 "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
92 by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
95 subsection {* Preprocessors *}
98 In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
99 a constructor term. Therefore, all occurrences of this term in a position
100 where a pattern is expected (i.e.\ on the left-hand side of a recursion
101 equation or in the arguments of an inductive relation in an introduction
102 rule) must be eliminated.
103 This can be accomplished by applying the following transformation rules:
106 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
107 f n \<equiv> if n = 0 then g else h (n - 1)"
108 by (rule eq_reflection) (cases n, simp_all)
110 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
111 by (cases n) simp_all
114 The rules above are built into a preprocessor that is plugged into
115 the code generator. Since the preprocessor for introduction rules
116 does not know anything about modes, some of the modes that worked
117 for the canonical representation of natural numbers may no longer work.
124 fun remove_suc thy thms =
126 val vname = Name.variant (map fst
127 (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
128 val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
129 fun lhs_of th = snd (Thm.dest_comb
130 (fst (Thm.dest_comb (cprop_of th))));
131 fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
132 fun find_vars ct = (case term_of ct of
133 (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
135 let val (ct1, ct2) = Thm.dest_comb ct
137 map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
138 map (apfst (Thm.capply ct1)) (find_vars ct2)
142 (fn th => map (pair th) (find_vars (lhs_of th))) thms;
143 fun mk_thms (th, (ct, cv')) =
147 (Conv.fconv_rule (Thm.beta_conversion true)
149 [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
150 SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
151 @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
153 case map_filter (fn th'' =>
154 SOME (th'', singleton
155 (Variable.trade (K (fn [th'''] => [th''' RS th']))
156 (Variable.global_thm_context th'')) th'')
157 handle THM _ => NONE) thms of
160 let val (ths1, ths2) = split_list thps
161 in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
163 in get_first mk_thms eqs end;
165 fun eqn_suc_base_preproc thy thms =
167 val dest = fst o Logic.dest_equals o prop_of;
168 val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
170 if forall (can dest) thms andalso exists (contains_suc o dest) thms
171 then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
175 val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;
177 fun remove_suc_clause thy thms =
179 val vname = Name.variant (map fst
180 (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
181 fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
182 | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
185 let val th' = Conv.fconv_rule Object_Logic.atomize th
186 in Option.map (pair (th, th')) (find_var (prop_of th')) end
188 case get_first find_thm thms of
190 | SOME ((th, th'), (Sucv, v)) =>
192 val cert = cterm_of (Thm.theory_of_thm th);
193 val th'' = Object_Logic.rulify (Thm.implies_elim
194 (Conv.fconv_rule (Thm.beta_conversion true)
195 (Drule.instantiate' []
196 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
198 HOLogic.dest_Trueprop (prop_of th')))))),
199 SOME (cert v)] @{thm Suc_clause}))
200 (Thm.forall_intr (cert v) th'))
202 remove_suc_clause thy (map (fn th''' =>
203 if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
207 fun clause_suc_preproc thy ths =
209 val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
211 if forall (can (dest o concl_of)) ths andalso
212 exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
213 (map_filter (try dest) (concl_of th :: prems_of th))) ths
214 then remove_suc_clause thy ths else ths
218 Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)
219 #> Codegen.add_preprocessor clause_suc_preproc
226 subsection {* Target language setup *}
229 For ML, we map @{typ nat} to target language integers, where we
230 ensure that values are always non-negative.
235 (OCaml "Big'_int.big'_int")
240 val term_of_nat = HOLogic.mk_number HOLogic.natT;
244 let val n = random_range 0 i
245 in (n, fn () => term_of_nat n) end;
249 For Haskell and Scala we define our own @{typ nat} type. The reason
250 is that we have to distinguish type class instances for @{typ nat}
254 code_include Haskell "Nat" {*
255 newtype Nat = Nat Integer deriving (Eq, Show, Read);
257 instance Num Nat where {
258 fromInteger k = Nat (if k >= 0 then k else 0);
259 Nat n + Nat m = Nat (n + m);
260 Nat n - Nat m = fromInteger (n - m);
261 Nat n * Nat m = Nat (n * m);
264 negate n = error "negate Nat";
267 instance Ord Nat where {
268 Nat n <= Nat m = n <= m;
269 Nat n < Nat m = n < m;
272 instance Real Nat where {
273 toRational (Nat n) = toRational n;
276 instance Enum Nat where {
277 toEnum k = fromInteger (toEnum k);
278 fromEnum (Nat n) = fromEnum n;
281 instance Integral Nat where {
282 toInteger (Nat n) = n;
283 divMod n m = quotRem n m;
284 quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
288 code_reserved Haskell Nat
290 code_include Scala "Nat" {*
295 def apply(numeral: BigInt): Nat = new Nat(numeral max 0)
296 def apply(numeral: Int): Nat = Nat(BigInt(numeral))
297 def apply(numeral: String): Nat = Nat(BigInt(numeral))
301 class Nat private(private val value: BigInt) {
303 override def hashCode(): Int = this.value.hashCode()
305 override def equals(that: Any): Boolean = that match {
306 case that: Nat => this equals that
310 override def toString(): String = this.value.toString
312 def equals(that: Nat): Boolean = this.value == that.value
314 def as_BigInt: BigInt = this.value
315 def as_Int: Int = if (this.value >= Math.MIN_INT && this.value <= Math.MAX_INT) // convert only when the value fits in a 32-bit Int (lower bound is MIN_INT; the old check compared against MAX_INT twice and so rejected everything but MAX_INT)
317 else error("Int value too big:" + this.value.toString)
319 def +(that: Nat): Nat = new Nat(this.value + that.value)
320 def -(that: Nat): Nat = Nat(this.value - that.value)
321 def *(that: Nat): Nat = new Nat(this.value * that.value)
323 def /%(that: Nat): (Nat, Nat) = if (that.value == 0) (new Nat(0), this)
325 val (k, l) = this.value /% that.value
326 (new Nat(k), new Nat(l))
329 def <=(that: Nat): Boolean = this.value <= that.value
331 def <(that: Nat): Boolean = this.value < that.value
336 code_reserved Scala Nat
342 code_instance nat :: eq
349 lemma [code_unfold_post]:
350 "nat (number_of i) = number_nat_inst.number_of_nat i"
351 -- {* this interacts as desired with @{thm nat_number_of_def} *}
352 by (simp add: number_nat_inst.number_of_nat)
355 fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
356 false Code_Printer.literal_positive_numeral) ["SML", "OCaml", "Haskell"]
357 #> Numeral.add_code @{const_name number_nat_inst.number_of_nat}
358 false Code_Printer.literal_positive_numeral "Scala"
362 Since natural numbers are implemented
363 using integers in ML, the coercion function @{const "of_nat"} of type
364 @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
365 For the @{const nat} function for converting an integer to a natural
366 number, we give a specific implementation using an ML function that
367 returns its input value, provided that it is non-negative, and otherwise
371 definition int :: "nat \<Rightarrow> int" where
372 [code del]: "int = of_nat"
374 lemma int_code' [code]:
375 "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
376 unfolding int_nat_number_of [folded int_def] ..
378 lemma nat_code' [code]:
379 "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
380 unfolding nat_number_of_def number_of_is_id neg_def by simp
382 lemma of_nat_int [code_unfold_post]:
383 "of_nat = int" by (simp add: int_def)
385 lemma of_nat_aux_int [code_unfold]:
386 "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
387 by (simp add: int_def Nat.of_nat_code)
397 fun nat i = if i < 0 then 0 else i;
401 (SML "IntInf.max/ (/0,/ _)")
402 (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
404 text {* For Haskell and Scala, things are slightly different again. *}
406 code_const int and nat
407 (Haskell "toInteger" and "fromInteger")
408 (Scala "!_.as'_BigInt" and "!Nat.Nat((_))")
410 text {* Conversion from and to indices. *}
412 code_const Code_Numeral.of_nat
418 code_const Code_Numeral.nat_of
419 (SML "IntInf.fromInt")
422 (Scala "!Nat.Nat((_))")
424 text {* Using target language arithmetic operations whenever appropriate *}
426 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
427 (SML "IntInf.+ ((_), (_))")
428 (OCaml "Big'_int.add'_big'_int")
429 (Haskell infixl 6 "+")
432 code_const "op - \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
433 (Haskell infixl 6 "-")
436 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
437 (SML "IntInf.* ((_), (_))")
438 (OCaml "Big'_int.mult'_big'_int")
439 (Haskell infixl 7 "*")
442 code_const divmod_aux
443 (SML "IntInf.divMod/ ((_),/ (_))")
444 (OCaml "Big'_int.quomod'_big'_int")
446 (Scala infixl 8 "/%")
448 code_const divmod_nat
450 (Scala infixl 8 "/%")
452 code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
453 (SML "!((_ : IntInf.int) = _)")
454 (OCaml "Big'_int.eq'_big'_int")
455 (Haskell infixl 4 "==")
456 (Scala infixl 5 "==")
458 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
459 (SML "IntInf.<= ((_), (_))")
460 (OCaml "Big'_int.le'_big'_int")
461 (Haskell infix 4 "<=")
462 (Scala infixl 4 "<=")
464 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
465 (SML "IntInf.< ((_), (_))")
466 (OCaml "Big'_int.lt'_big'_int")
467 (Haskell infix 4 "<")
474 "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat" ("(_ +/ _)")
475 "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat" ("(_ */ _)")
476 "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" ("(_ <=/ _)")
477 "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" ("(_ </ _)")
480 text {* Evaluation *}
482 lemma [code, code del]:
483 "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..
485 code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
486 (SML "HOLogic.mk'_number/ HOLogic.natT")
489 text {* Module names *}
494 code_modulename OCaml
497 code_modulename Haskell