src/HOL/Library/Efficient_Nat.thy
author haftmann
Thu, 29 Dec 2011 10:47:55 +0100
changeset 46899 9f113cdf3d66
parent 46664 331ebffe0593
child 47368 89ccf66aa73d
permissions -rw-r--r--
attribute code_abbrev supersedes code_unfold_post
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Integer Main
     9 begin
    10 
    11 text {*
    12   When generating code for functions on natural numbers, the
    13   canonical representation using @{term "0::nat"} and
    14   @{term Suc} is unsuitable for computations involving large
    15   numbers.  The efficiency of the generated code can be improved
    16   drastically by implementing natural numbers by target-language
    17   integers.  To do this, just include this theory.
    18 *}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
    27 code_datatype number_nat_inst.number_of_nat
       (* Declares number_of_nat as the only code-generation constructor of nat.
          0, 1 and Suc are therefore plain functions for the code generator and
          receive the explicit code equations that follow. *)
    28 
    29 lemma zero_nat_code [code, code_unfold]:
       (* Code equation and preprocessor unfold rule: 0 is expressed as the
          numeral Numeral0. *)
    30   "0 = (Numeral0 :: nat)"
    31   by simp
    32 
    33 lemma one_nat_code [code, code_unfold]:
       (* Code equation and preprocessor unfold rule: 1 is expressed as the
          numeral Numeral1. *)
    34   "1 = (Numeral1 :: nat)"
    35   by simp
    36 
    37 lemma Suc_code [code]:
       (* Successor is computed as an ordinary addition of 1. *)
    38   "Suc n = n + 1"
    39   by simp
    40 
    41 lemma plus_nat_code [code]:
       (* Addition: embed both operands with of_nat, add there, map the result
          back with nat. *)
    42   "n + m = nat (of_nat n + of_nat m)"
    43   by simp
    44 
    45 lemma minus_nat_code [code]:
       (* Subtraction follows the same embed/compute/map-back scheme; the outer
          nat is what maps a negative integer difference back into nat. *)
    46   "n - m = nat (of_nat n - of_nat m)"
    47   by simp
    48 
    49 lemma times_nat_code [code]:
       (* Multiplication on the embedded operands; the proof first folds
          of_nat_mult right-to-left. *)
    50   "n * m = nat (of_nat n * of_nat m)"
    51   unfolding of_nat_mult [symmetric] by simp
    52 
    53 lemma divmod_nat_code [code]:
       (* Quotient and remainder in one step: pdivmod on the embedded operands,
          with nat applied to both components of the pair via map_pair. *)
    54   "divmod_nat n m = map_pair nat nat (pdivmod (of_nat n) (of_nat m))"
    55   by (simp add: map_pair_def split_def pdivmod_def nat_div_distrib nat_mod_distrib divmod_nat_div_mod)
    56 
    57 lemma eq_nat_code [code]:
       (* Executable equality on nat is reduced to executable equality on the
          int images of both operands. *)
    58   "HOL.equal n m \<longleftrightarrow> HOL.equal (of_nat n \<Colon> int) (of_nat m)"
    59   by (simp add: equal)
    60 
    61 lemma eq_nat_refl [code nbe]:
       (* Reflexivity of executable equality, registered with the "code nbe"
          attribute only (i.e. for normalization by evaluation). *)
    62   "HOL.equal (n::nat) n \<longleftrightarrow> True"
    63   by (rule equal_refl)
    64 
    65 lemma less_eq_nat_code [code]:
       (* Non-strict order is decided on the int images. *)
    66   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
    67   by simp
    68 
    69 lemma less_nat_code [code]:
       (* Strict order is decided on the int images. *)
    70   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
    71   by simp
    72 
    73 subsection {* Case analysis *}
    74 
    75 text {*
    76   Case analysis on natural numbers is rephrased using a conditional
    77   expression:
    78 *}
    79 
    80 lemma [code, code_unfold]:
       (* Replaces pattern matching on 0 / Suc by a test against 0: the first
          argument is returned for 0, otherwise the second is applied to the
          predecessor n - 1.  Registered both as code equation and as
          preprocessor unfold rule. *)
    81   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    82   by (auto simp add: fun_eq_iff dest!: gr0_implies_Suc)
    83 
    84 
    85 subsection {* Preprocessors *}
    86 
    87 text {*
    88   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
    89   a constructor term. Therefore, all occurrences of this term in a position
    90   where a pattern is expected (i.e.\ on the left-hand side of a recursion
    91   equation or in the arguments of an inductive relation in an introduction
    92   rule) must be eliminated.
    93   This can be accomplished by applying the following transformation rules:
    94 *}
    95 
    96 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
    97   f n \<equiv> if n = 0 then g else h (n - 1)"
       (* Merges a pair of meta-equations, one for Suc n and one for 0, into a
          single equation on an arbitrary n that branches on n = 0.  The ML
          setup block of this theory instantiates this theorem. *)
    98   by (rule eq_reflection) (cases n, simp_all)
    99 
   100 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
       (* Clause-level counterpart: a property stated for the pair (n, Suc n)
          is restated for (n - 1, n) under the side condition that n is not 0.
          NOTE(review): no use of this lemma is visible in this file's setup
          code - confirm whether it is still needed. *)
   101   by (cases n) simp_all
   102 
   103 text {*
   104   The rules above are built into a preprocessor that is plugged into
   105   the code generator. Since the preprocessor for introduction rules
   106   does not know anything about modes, some of the modes that worked
   107   for the canonical representation of natural numbers may no longer work.
   108 *}
   109 
   110 (*<*)
   111 setup {*
       (* ML setup: registers the function transformer "eqn_Suc" with the code
          preprocessor.  It rewrites sets of code equations whose left-hand
          sides contain Suc applied to a schematic variable, using the theorem
          Suc_if_eq. *)
   112 let
   113 
       (* remove_suc thy thms: attempts ONE elimination step on the equation
          list thms.  Returns SOME of the updated list when a step succeeded,
          NONE when no candidate could be combined. *)
   114 fun remove_suc thy thms =
   115   let
       (* A variable name derived from "n" that differs from every schematic
          variable name occurring in thms. *)
   116     val vname = singleton (Name.variant_list (map fst
   117       (fold (Term.add_var_names o Thm.full_prop_of) thms []))) "n";
       (* Certified schematic variable of type nat with that fresh name; it
          serves as the placeholder put in place of a "Suc ?x" subterm. *)
   118     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
       (* For a proposition of shape "(op lhs) rhs": lhs_of yields the argument
          of the inner application, rhs_of the argument of the outer one.
          The caller only passes meta-equations (checked in
          eqn_suc_base_preproc). *)
   119     fun lhs_of th = snd (Thm.dest_comb
   120       (fst (Thm.dest_comb (cprop_of th))));
   121     fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
       (* find_vars ct: one result per subterm "Suc ?x" of ct (with ?x a
          schematic variable).  Each result is a pair (ct', x) where ct' is ct
          with that single occurrence replaced by the placeholder cv, and x is
          the certified variable that stood under Suc. *)
   122     fun find_vars ct = (case term_of ct of
   123         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   124       | _ $ _ =>
   125         let val (ct1, ct2) = Thm.dest_comb ct
   126         in 
       (* Rebuild the application around every hit found in the function part
          and around every hit found in the argument part. *)
   127           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   128           map (apfst (Thm.capply ct1)) (find_vars ct2)
   129         end
   130       | _ => []);
       (* All candidates: every theorem paired with every Suc occurrence found
          in its left-hand side. *)
   131     val eqs = maps
   132       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
       (* mk_thms (th, (ct, cv')): tries to eliminate one Suc occurrence.
          th is the equation with the Suc pattern, ct its left-hand side with
          the occurrence abstracted to cv, cv' the variable under Suc. *)
   133     fun mk_thms (th, (ct, cv')) =
   134       let
       (* Instantiate Suc_if_eq positionally: abstraction of ct over cv,
          abstraction of th's right-hand side over cv', third variable left
          open (NONE), and cv' itself; presumably these are f, h, g and n of
          Suc_if_eq in that order - confirm against instantiate'.  After beta
          normalisation the first premise is discharged with th generalised
          over cv', leaving a rule that still expects the matching equation
          for 0. *)
   135         val th' =
   136           Thm.implies_elim
   137            (Conv.fconv_rule (Thm.beta_conversion true)
   138              (Drule.instantiate'
   139                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   140                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   141                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   142       in
       (* Try to resolve every theorem of thms against the remaining premise of
          th'; theorems for which resolution raises THM are dropped.  Each
          survivor is kept together with the combined equation it produced. *)
   143         case map_filter (fn th'' =>
   144             SOME (th'', singleton
   145               (Variable.trade (K (fn [th'''] => [th''' RS th']))
   146                 (Variable.global_thm_context th'')) th'')
   147           handle THM _ => NONE) thms of
   148             [] => NONE
   149           | thps =>
       (* Remove th and all consumed partner equations from thms and append
          the combined equations. *)
   150               let val (ths1, ths2) = split_list thps
   151               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   152       end
       (* The first candidate for which mk_thms succeeds determines the result. *)
   153   in get_first mk_thms eqs end;
   154 
       (* eqn_suc_base_preproc thy thms: entry point of the transformation.
          Applies only when every theorem is a meta-equation and at least one
          left-hand side mentions the constant Suc; then remove_suc is applied
          via perhaps_loop and the variable indexes of the resulting theorems
          are reset with zero_var_indexes.  Otherwise the result is NONE. *)
   155 fun eqn_suc_base_preproc thy thms =
   156   let
   157     val dest = fst o Logic.dest_equals o prop_of;
   158     val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   159   in
   160     if forall (can dest) thms andalso exists (contains_suc o dest) thms
   161       then thms |> perhaps_loop (remove_suc thy) |> (Option.map o map) Drule.zero_var_indexes
   162        else NONE
   163   end;
   164 
       (* Wrap the theorem-list transformation as a function transformer. *)
   165 val eqn_suc_preproc = Code_Preproc.simple_functrans eqn_suc_base_preproc;
   166 
   167 in
   168 
       (* The theory update performed by this setup: register the transformer
          under the name "eqn_Suc". *)
   169   Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc)
   170 
   171 end;
   172 *}
   173 (*>*)
   174 
   175 
   176 subsection {* Target language setup *}
   177 
   178 text {*
   179   For ML, we map @{typ nat} to target language integers, where we
   180   ensure that values are always non-negative.
   181 *}
   182 
   183 code_type nat
       (* Target types used for nat: IntInf.int in SML, Big_int.big_int in
          OCaml (quotes escape the underscores in the template) and int for
          the Eval target. *)
   184   (SML "IntInf.int")
   185   (OCaml "Big'_int.big'_int")
   186   (Eval "int")
   187 
   188 text {*
   189   For Haskell and Scala we define our own @{typ nat} type.  The reason
   190   is that we have to distinguish type class instances for @{typ nat}
   191   and @{typ int}.
   192 *}
   193 
   194 code_include Haskell "Nat"
       (* Haskell support module "Nat", included verbatim in generated code.
          Nat is a newtype over Integer with derived Eq, Show and Read.
          - fromInteger maps negative arguments to Nat 0.
          - Subtraction goes through fromInteger, so a negative difference
            yields Nat 0; addition and multiplication wrap the Integer result
            directly.
          - abs is the identity; negate raises the error "negate Nat".
          - signum delegates to Integer's signum, so signum of zero is zero and
            the Num law "abs x * signum x == x" also holds at 0 (the earlier
            text returned 1 for every argument, including zero).
          - quotRem with divisor 0 returns (0, dividend); divMod is quotRem.
          The identifier Nat is reserved afterwards so that generated Haskell
          names do not clash with it. *)
   195 {*newtype Nat = Nat Integer deriving (Eq, Show, Read);
   196 
   197 instance Num Nat where {
   198   fromInteger k = Nat (if k >= 0 then k else 0);
   199   Nat n + Nat m = Nat (n + m);
   200   Nat n - Nat m = fromInteger (n - m);
   201   Nat n * Nat m = Nat (n * m);
   202   abs n = n;
   203   signum (Nat n) = Nat (signum n);
   204   negate n = error "negate Nat";
   205 };
   206 
   207 instance Ord Nat where {
   208   Nat n <= Nat m = n <= m;
   209   Nat n < Nat m = n < m;
   210 };
   211 
   212 instance Real Nat where {
   213   toRational (Nat n) = toRational n;
   214 };
   215 
   216 instance Enum Nat where {
   217   toEnum k = fromInteger (toEnum k);
   218   fromEnum (Nat n) = fromEnum n;
   219 };
   220 
   221 instance Integral Nat where {
   222   toInteger (Nat n) = n;
   223   divMod n m = quotRem n m;
   224   quotRem (Nat n) (Nat m)
   225     | (m == 0) = (0, Nat n)
   226     | otherwise = (Nat k, Nat l) where (k, l) = quotRem n m;
   227 };
   228 *}
   229 
   230 code_reserved Haskell Nat
   231 
   232 code_include Scala "Nat"
       (* Scala support code "Nat", included verbatim in generated code.
          Companion object: apply(BigInt) builds a Nat from "numeral max 0",
          so negative arguments become 0; the Int and String overloads convert
          to BigInt first and then use that factory.
          Class Nat: private constructor around a private BigInt value.
          - hashCode, equals and toString delegate to the wrapped value.
          - as_BigInt exposes the value; as_Int returns it as Int when it lies
            within the Int range and otherwise calls error with the message
            "Int value out of range: " followed by the value.
          - Subtraction goes through the clamping factory; addition and
            multiplication call the constructor directly.
          - The operator for quotient and remainder returns (0, this) when the
            divisor is 0 and otherwise wraps the corresponding BigInt result.
          The identifier Nat is reserved afterwards for the Scala target. *)
   233 {*object Nat {
   234 
   235   def apply(numeral: BigInt): Nat = new Nat(numeral max 0)
   236   def apply(numeral: Int): Nat = Nat(BigInt(numeral))
   237   def apply(numeral: String): Nat = Nat(BigInt(numeral))
   238 
   239 }
   240 
   241 class Nat private(private val value: BigInt) {
   242 
   243   override def hashCode(): Int = this.value.hashCode()
   244 
   245   override def equals(that: Any): Boolean = that match {
   246     case that: Nat => this equals that
   247     case _ => false
   248   }
   249 
   250   override def toString(): String = this.value.toString
   251 
   252   def equals(that: Nat): Boolean = this.value == that.value
   253 
   254   def as_BigInt: BigInt = this.value
   255   def as_Int: Int = if (this.value >= scala.Int.MinValue && this.value <= scala.Int.MaxValue)
   256       this.value.intValue
   257     else error("Int value out of range: " + this.value.toString)
   258 
   259   def +(that: Nat): Nat = new Nat(this.value + that.value)
   260   def -(that: Nat): Nat = Nat(this.value - that.value)
   261   def *(that: Nat): Nat = new Nat(this.value * that.value)
   262 
   263   def /%(that: Nat): (Nat, Nat) = if (that.value == 0) (new Nat(0), this)
   264     else {
   265       val (k, l) = this.value /% that.value
   266       (new Nat(k), new Nat(l))
   267     }
   268 
   269   def <=(that: Nat): Boolean = this.value <= that.value
   270 
   271   def <(that: Nat): Boolean = this.value < that.value
   272 
   273 }
   274 *}
   275 
   276 code_reserved Scala Nat
   277 
   278 code_type nat
       (* For Haskell and Scala, nat is mapped to the Nat types supplied by the
          code_include declarations of this theory. *)
   279   (Haskell "Nat.Nat")
   280   (Scala "Nat")
   281 
   282 code_instance nat :: equal
       (* "-" requests that no instance of class equal be generated for nat in
          Haskell; the included newtype Nat already derives Eq. *)
   283   (Haskell -)
   284 
   285 text {*
   286   Natural numerals.
   287 *}
   288 
   289 lemma [code_abbrev]:
       (* Declares the constructor number_of_nat i as code abbreviation for
          the term nat (number_of i). *)
   290   "number_nat_inst.number_of_nat i = nat (number_of i)"
   291   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   292   by (simp add: number_nat_inst.number_of_nat)
   293 
   294 setup {*
       (* Registers numeral printing for number_of_nat in each of the four
          listed targets, using Code_Printer.literal_positive_numeral.
          NOTE(review): the boolean argument false presumably disables
          negative numerals - confirm against Numeral.add_code. *)
   295   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   296     false Code_Printer.literal_positive_numeral) ["SML", "OCaml", "Haskell", "Scala"]
   297 *}
   298 
   299 text {*
   300   Since natural numbers are implemented
   301   using integers in ML, the coercion function @{const "of_nat"} of type
   302   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   303   For the @{const nat} function for converting an integer to a natural
   304   number, we give a specific implementation using an ML function that
   305   returns its input value, provided that it is non-negative, and otherwise
   306   returns @{text "0"}.
   307 *}
   308 
   309 definition int :: "nat \<Rightarrow> int" where
       (* Auxiliary constant naming of_nat at type nat => int, so that this
          particular coercion can be given its own target-language mapping.
          "code del" removes the defining equation from code generation;
          "code_abbrev" registers it as a code abbreviation.  The name is
          hidden again by hide_const at the end of the theory. *)
   310   [code del, code_abbrev]: "int = of_nat"
   311 
   312 lemma int_code' [code]:
       (* Code equation for int on numerals: 0 when the literal is negative
          as an integer, otherwise the same literal. *)
   313   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   314   unfolding int_nat_number_of [folded int_def] ..
   315 
   316 lemma nat_code' [code]:
       (* Code equation for nat on numerals, with the same case distinction. *)
   317   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   318   unfolding nat_number_of_def number_of_is_id neg_def by simp
   319 
   320 lemma of_nat_int: (* FIXME delete candidate *)
       (* The defining equation of int read in the other direction. *)
   321   "of_nat = int" by (simp add: int_def)
   322 
   323 lemma of_nat_aux_int [code_unfold]:
       (* Preprocessor rule: the of_nat_aux term with increment function, as
          it appears in Nat.of_nat_code, is replaced by int k. *)
   324   "of_nat_aux (\<lambda>i. i + 1) k 0 = int k"
   325   by (simp add: int_def Nat.of_nat_code)
   326 
   327 code_const int
       (* In SML and OCaml the coercion prints as its bare argument ("_"),
          i.e. it is the identity on the shared integer representation. *)
   328   (SML "_")
   329   (OCaml "_")
   330 
   331 code_const nat
       (* nat is printed as the maximum of its argument and zero, using the
          respective library function of each ML target. *)
   332   (SML "IntInf.max/ (0,/ _)")
   333   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   334   (Eval "Integer.max/ _/ 0")
   335 
   336 text {* For Haskell and Scala, things are slightly different again. *}
   337 
   338 code_const int and nat
       (* Haskell: toInteger / fromInteger of the included Nat instances.
          Scala: the as_BigInt accessor / the Nat factory of the included
          Scala code. *)
   339   (Haskell "toInteger" and "fromInteger")
   340   (Scala "!_.as'_BigInt" and "Nat")
   341 
   342 text {* Conversion from and to code numerals. *}
   343 
   344 code_const Code_Numeral.of_nat
       (* nat to code numeral.  SML uses IntInf.toInt; OCaml and Eval print
          the bare argument; Haskell composes fromInteger with toInteger;
          Scala wraps the BigInt value in Natural. *)
   345   (SML "IntInf.toInt")
   346   (OCaml "_")
   347   (Haskell "!(fromInteger/ ./ toInteger)")
   348   (Scala "!Natural(_.as'_BigInt)")
   349   (Eval "_")
   350 
   351 code_const Code_Numeral.nat_of
       (* Code numeral to nat: the mirror image of the mapping above, with
          IntInf.fromInt in SML and the Nat factory in Scala. *)
   352   (SML "IntInf.fromInt")
   353   (OCaml "_")
   354   (Haskell "!(fromInteger/ ./ toInteger)")
   355   (Scala "!Nat(_.as'_BigInt)")
   356   (Eval "_")
   357 
   358 text {* Using target language arithmetic operations whenever appropriate *}
   359 
   360 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
       (* Addition on nat is printed as the target's own integer addition. *)
   361   (SML "IntInf.+ ((_), (_))")
   362   (OCaml "Big'_int.add'_big'_int")
   363   (Haskell infixl 6 "+")
   364   (Scala infixl 7 "+")
   365   (Eval infixl 8 "+")
   366 
   367 code_const "op - \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
       (* Only Haskell and Scala get a direct mapping for subtraction; no
          SML, OCaml or Eval entry is given here.  Presumably those targets
          rely on the code equation minus_nat_code - verify. *)
   368   (Haskell infixl 6 "-")
   369   (Scala infixl 7 "-")
   370 
   371 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
       (* Multiplication on nat is printed as the target's own integer
          multiplication. *)
   372   (SML "IntInf.* ((_), (_))")
   373   (OCaml "Big'_int.mult'_big'_int")
   374   (Haskell infixl 7 "*")
   375   (Scala infixl 8 "*")
   376   (Eval infixl 9 "*")
   377 
   378 code_const divmod_nat
       (* Combined quotient/remainder.  Haskell and Scala use divMod and the
          quotient/remainder operator of the included Nat code; the ML targets
          call their library functions directly.
          NOTE(review): behaviour of the ML library functions for divisor 0
          is not visible here - confirm it is acceptable. *)
   379   (SML "IntInf.divMod/ ((_),/ (_))")
   380   (OCaml "Big'_int.quomod'_big'_int")
   381   (Haskell "divMod")
   382   (Scala infixl 8 "/%")
   383   (Eval "Integer.div'_mod")
   384 
   385 code_const "HOL.equal \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
       (* Executable equality on nat; the SML template annotates the first
          operand with type IntInf.int. *)
   386   (SML "!((_ : IntInf.int) = _)")
   387   (OCaml "Big'_int.eq'_big'_int")
   388   (Haskell infix 4 "==")
   389   (Scala infixl 5 "==")
   390   (Eval infixl 6 "=")
   391 
   392 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
       (* Non-strict order on nat. *)
   393   (SML "IntInf.<= ((_), (_))")
   394   (OCaml "Big'_int.le'_big'_int")
   395   (Haskell infix 4 "<=")
   396   (Scala infixl 4 "<=")
   397   (Eval infixl 6 "<=")
   398 
   399 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
       (* Strict order on nat. *)
   400   (SML "IntInf.< ((_), (_))")
   401   (OCaml "Big'_int.lt'_big'_int")
   402   (Haskell infix 4 "<")
   403   (Scala infixl 4 "<")
   404   (Eval infixl 6 "<")
   405 
   406 
   407 text {* Evaluation *}
   408 
   409 lemma [code, code del]:
       (* Trivial reflexive equation declared with "code" and then
          "code del"; presumably this clears any existing code equations for
          term_of on nat before the direct mapping below - confirm. *)
   410   "(Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term) = Code_Evaluation.term_of" ..
   411 
   412 code_const "Code_Evaluation.term_of \<Colon> nat \<Rightarrow> term"
       (* term_of on nat is printed as HOLogic.mk_number applied to
          HOLogic.natT; only an SML entry is given. *)
   413   (SML "HOLogic.mk'_number/ HOLogic.natT")
   414 
   415 text {* Evaluation with @{text "Quickcheck_Narrowing"} does not work, as
   416   @{text "code_module"} is very aggressive, leading to bad Haskell code.
   417   Therefore, we simply deactivate the narrowing-based quickcheck from here on.
   418 *}
   419 
   420 declare [[quickcheck_narrowing_active = false]] 
       (* Sets the configuration option quickcheck_narrowing_active to false
          for the remainder of this theory and for theories importing it. *)
   421 
   422 text {* Module names *}
   423 
   424 code_modulename SML
       (* Code originating from theory Efficient_Nat is placed in the module
          named Arith; the same assignment is repeated for OCaml and Haskell
          below. *)
   425   Efficient_Nat Arith
   426 
   427 code_modulename OCaml
   428   Efficient_Nat Arith
   429 
   430 code_modulename Haskell
   431   Efficient_Nat Arith
   432 
   433 hide_const int
       (* Hides the name of the auxiliary constant int defined in this theory
          from the constant name space. *)
   434 
   434 
   435 end