1 (* Title: HOL/Library/Efficient_Nat.thy
2 Author: Stefan Berghofer, Florian Haftmann, TU Muenchen
5 header {* Implementation of natural numbers by target-language integers *}
8 imports Code_Integer Main
12 When generating code for functions on natural numbers, the
13 canonical representation using @{term "0::nat"} and
14 @{term "Suc"} is unsuitable for computations involving large
15 numbers. The efficiency of the generated code can be improved
16 drastically by implementing natural numbers by target-language
17 integers. To do this, just include this theory.
20 subsection {* Basic arithmetic *}
23 Most standard arithmetic functions on natural numbers are implemented
24 using their counterparts on the integers:
27 code_datatype number_nat_inst.number_of_nat
29 lemma zero_nat_code [code, code_unfold_post]:
30 "0 = (Numeral0 :: nat)"
33 lemma one_nat_code [code, code_unfold_post]:
34 "1 = (Numeral1 :: nat)"
37 lemma Suc_code [code]:
41 lemma plus_nat_code [code]:
42 "n + m = nat (of_nat n + of_nat m)"
45 lemma minus_nat_code [code]:
46 "n - m = nat (of_nat n - of_nat m)"
49 lemma times_nat_code [code]: (* code equation: nat multiplication is done on int via of_nat, then converted back with nat *)
50 "n * m = nat (of_nat n * of_nat m)"
51 unfolding of_nat_mult [symmetric] by simp
53 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"}
54 and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
56 definition divmod_aux :: "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where (* alias of divmod_nat; its defining equation is excluded from code generation *)
57 [code del]: "divmod_aux = divmod_nat"
60 "divmod_nat n m = (if m = 0 then (0, n) else divmod_aux n m)"
61 unfolding divmod_aux_def divmod_nat_div_mod by simp
63 lemma divmod_aux_code [code]: (* quotient and remainder computed with int div/mod on the coerced arguments, converted back with nat *)
64 "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
65 unfolding divmod_aux_def divmod_nat_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
67 lemma eq_nat_code [code]:
68 "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
71 lemma eq_nat_refl [code nbe]:
72 "eq_class.eq (n::nat) n \<longleftrightarrow> True"
75 lemma less_eq_nat_code [code]:
76 "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
79 lemma less_nat_code [code]:
80 "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
83 subsection {* Case analysis *}
86 Case analysis on natural numbers is rephrased using a conditional
90 lemma [code, code_unfold]: (* replaces nat_case by a conditional on n = 0, using n - 1 in the successor branch *)
91 "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
92 by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
95 subsection {* Preprocessors *}
98 In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
99 a constructor term. Therefore, all occurrences of this term in a position
100 where a pattern is expected (i.e.\ on the left-hand side of a recursion
101 equation or in the arguments of an inductive relation in an introduction
102 rule) must be eliminated.
103 This can be accomplished by applying the following transformation rules:
106 lemma Suc_if_eq: (* merges equations for f (Suc n) and f 0 into one conditional equation for f n *) "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
107 f n \<equiv> if n = 0 then g else h (n - 1)"
108 by (rule eq_reflection) (cases n, simp_all)
110 lemma Suc_clause: (* from P n (Suc n) for all n, derives P (n - 1) n for any nonzero n *) "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
111 by (cases n) simp_all
114 The rules above are built into a preprocessor that is plugged into
115 the code generator. Since the preprocessor for introduction rules
116 does not know anything about modes, some of the modes that worked
117 for the canonical representation of natural numbers may no longer work.
124 fun remove_suc thy thms =
126 val vname = Name.variant (map fst
127 (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
128 val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
129 fun lhs_of th = snd (Thm.dest_comb
130 (fst (Thm.dest_comb (cprop_of th))));
131 fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
132 fun find_vars ct = (case term_of ct of
133 (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
135 let val (ct1, ct2) = Thm.dest_comb ct
137 map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
138 map (apfst (Thm.capply ct1)) (find_vars ct2)
142 (fn th => map (pair th) (find_vars (lhs_of th))) thms;
143 fun mk_thms (th, (ct, cv')) =
147 (Conv.fconv_rule (Thm.beta_conversion true)
149 [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
150 SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
151 @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
153 case map_filter (fn th'' =>
154 SOME (th'', singleton
155 (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
156 handle THM _ => NONE) thms of
159 let val (ths1, ths2) = split_list thps
160 in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
162 in get_first mk_thms eqs end;
164 fun eqn_suc_base_preproc thy thms =
166 val dest = fst o Logic.dest_equals o prop_of;
167 val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
169 if forall (can dest) thms andalso exists (contains_suc o dest) thms
170 then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
174 val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;
176 fun remove_suc_clause thy thms =
178 val vname = Name.variant (map fst
179 (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
180 fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
181 | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
184 let val th' = Conv.fconv_rule Object_Logic.atomize th
185 in Option.map (pair (th, th')) (find_var (prop_of th')) end
187 case get_first find_thm thms of
189 | SOME ((th, th'), (Sucv, v)) =>
191 val cert = cterm_of (Thm.theory_of_thm th);
192 val th'' = Object_Logic.rulify (Thm.implies_elim
193 (Conv.fconv_rule (Thm.beta_conversion true)
194 (Drule.instantiate' []
195 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
197 HOLogic.dest_Trueprop (prop_of th')))))),
198 SOME (cert v)] @{thm Suc_clause}))
199 (Thm.forall_intr (cert v) th'))
201 remove_suc_clause thy (map (fn th''' =>
202 if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
206 fun clause_suc_preproc thy ths =
208 val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
210 if forall (can (dest o concl_of)) ths andalso
211 exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
212 (map_filter (try dest) (concl_of th :: prems_of th))) ths
213 then remove_suc_clause thy ths else ths
217 Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)
218 #> Codegen.add_preprocessor clause_suc_preproc
225 subsection {* Target language setup *}
228 For ML, we map @{typ nat} to target language integers, where we
229 ensure that values are always non-negative.
234 (OCaml "Big'_int.big'_int")
239 val term_of_nat = HOLogic.mk_number HOLogic.natT;
243 let val n = random_range 0 i
244 in (n, fn () => term_of_nat n) end;
248 For Haskell and Scala we define our own @{typ nat} type. The reason
249 is that we have to distinguish type class instances for @{typ nat}
253 code_include Haskell "Nat" {*
254 newtype Nat = Nat Integer deriving (Show, Eq);
256 instance Num Nat where {
257 fromInteger k = Nat (if k >= 0 then k else 0);
258 Nat n + Nat m = Nat (n + m);
259 Nat n - Nat m = fromInteger (n - m);
260 Nat n * Nat m = Nat (n * m);
263 negate n = error "negate Nat";
266 instance Ord Nat where {
267 Nat n <= Nat m = n <= m;
268 Nat n < Nat m = n < m;
271 instance Real Nat where {
272 toRational (Nat n) = toRational n;
275 instance Enum Nat where {
276 toEnum k = fromInteger (toEnum k);
277 fromEnum (Nat n) = fromEnum n;
280 instance Integral Nat where {
281 toInteger (Nat n) = n;
282 divMod n m = quotRem n m;
283 quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
287 code_reserved Haskell Nat
289 code_include Scala "Nat" {*
294 def apply(numeral: BigInt): Nat = new Nat(numeral max 0)
295 def apply(numeral: Int): Nat = Nat(BigInt(numeral))
296 def apply(numeral: String): Nat = Nat(BigInt(numeral))
300 class Nat private(private val value: BigInt) {
302 override def hashCode(): Int = this.value.hashCode()
304 override def equals(that: Any): Boolean = that match {
305 case that: Nat => this equals that
309 override def toString(): String = this.value.toString
311 def equals(that: Nat): Boolean = this.value == that.value
313 def as_BigInt: BigInt = this.value
314 def as_Int: Int = if (this.value >= Math.MIN_INT && this.value <= Math.MAX_INT) // convert only when the value fits in a 32-bit Int (lower bound was wrongly MAX_INT, rejecting everything but MAX_INT)
316 else error("Int value too big:" + this.value.toString)
318 def +(that: Nat): Nat = new Nat(this.value + that.value)
319 def -(that: Nat): Nat = Nat(this.value - that.value) // truncated subtraction: Nat.apply clamps a negative difference to 0
320 def *(that: Nat): Nat = new Nat(this.value * that.value)
322 def /%(that: Nat): (Nat, Nat) = if (that.value == 0) (new Nat(0), this)
324 val (k, l) = this.value /% that.value
325 (new Nat(k), new Nat(l))
328 def <=(that: Nat): Boolean = this.value <= that.value
330 def <(that: Nat): Boolean = this.value < that.value
335 code_reserved Scala Nat
341 code_instance nat :: eq
348 lemma [code_unfold_post]: (* post-processing rewrite: nat applied to a numeral is shown as a nat numeral *)
349 "nat (number_of i) = number_nat_inst.number_of_nat i"
350 -- {* this interacts as desired with @{thm nat_number_of_def} *}
351 by (simp add: number_nat_inst.number_of_nat)
354 fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
355 false Code_Printer.literal_positive_numeral) ["SML", "OCaml", "Haskell"]
356 #> Numeral.add_code @{const_name number_nat_inst.number_of_nat}
357 false Code_Printer.literal_positive_numeral "Scala"
361 Since natural numbers are implemented
362 using integers in ML, the coercion function @{const "of_nat"} of type
363 @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
364 For the @{const "nat"} function for converting an integer to a natural
365 number, we give a specific implementation using an ML function that
366 returns its input value, provided that it is non-negative, and otherwise
370 definition int :: "nat \<Rightarrow> int" where (* alias for of_nat; it is mapped to target-language code by the code_const declarations for int *)
371 [code del]: "int = of_nat"
373 lemma int_code' [code]: (* code equation for int on numerals: a negative numeral yields 0 *)
374 "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
375 unfolding int_nat_number_of [folded int_def] ..
377 lemma nat_code' [code]: (* code equation for nat on numerals: a negative numeral yields 0 *)
378 "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
379 unfolding nat_number_of_def number_of_is_id neg_def by simp
381 lemma of_nat_int [code_unfold_post]: (* post-processing: of_nat is displayed as int *)
382 "of_nat = int" by (simp add: int_def)
384 lemma of_nat_aux_int [code_unfold]: (* preprocessing: of_nat_aux with step (\<lambda>i. i + 1) starting from 0 is replaced by int *)
385 "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
386 by (simp add: int_def Nat.of_nat_code)
396 fun nat i = if i < 0 then 0 else i; (* clamps negative integers to 0 *)
400 (SML "IntInf.max/ (/0,/ _)")
401 (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
403 text {* For Haskell and Scala, things are slightly different again. *}
405 code_const int and nat (* Haskell: toInteger / fromInteger; Scala: as_BigInt / Nat.Nat *)
406 (Haskell "toInteger" and "fromInteger")
407 (Scala "!_.as'_BigInt" and "!Nat.Nat((_))")
409 text {* Conversion from and to indices. *}
411 code_const Code_Numeral.of_nat
417 code_const Code_Numeral.nat_of
418 (SML "IntInf.fromInt")
421 (Scala "!Nat.Nat((_))")
423 text {* Using target language arithmetic operations whenever appropriate *}
425 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
426 (SML "IntInf.+ ((_), (_))")
427 (OCaml "Big'_int.add'_big'_int")
428 (Haskell infixl 6 "+")
431 code_const "op - \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
432 (Haskell infixl 6 "-")
435 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
436 (SML "IntInf.* ((_), (_))")
437 (OCaml "Big'_int.mult'_big'_int")
438 (Haskell infixl 7 "*")
441 code_const divmod_aux
442 (SML "IntInf.divMod/ ((_),/ (_))")
443 (OCaml "Big'_int.quomod'_big'_int")
445 (Scala infixl 8 "/%")
447 code_const divmod_nat
449 (Scala infixl 8 "/%")
451 code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" (* equality on nat mapped to each target's native equality on its integer representation *)
452 (SML "!((_ : IntInf.int) = _)")
453 (OCaml "Big'_int.eq'_big'_int")
454 (Haskell infixl 4 "==")
455 (Scala infixl 5 "==")
457 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" (* less-or-equal on nat mapped to each target's native comparison *)
458 (SML "IntInf.<= ((_), (_))")
459 (OCaml "Big'_int.le'_big'_int")
460 (Haskell infix 4 "<=")
461 (Scala infixl 4 "<=")
463 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
464 (SML "IntInf.< ((_), (_))")
465 (OCaml "Big'_int.lt'_big'_int")
466 (Haskell infix 4 "<")
473 "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat" ("(_ +/ _)")
474 "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat" ("(_ */ _)")
475 "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" ("(_ <=/ _)")
476 "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" ("(_ </ _)")
479 text {* Evaluation *}
481 lemma [code, code del]: (* NOTE(review): presumably clears the default code equations for term_of on nat, which is implemented by the code_const that follows *)
482 "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..
484 code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
485 (SML "HOLogic.mk'_number/ HOLogic.natT")
488 text {* Module names *}
493 code_modulename OCaml
496 code_modulename Haskell