src/HOL/Library/Efficient_Nat.thy
author haftmann
Tue, 14 Jul 2009 16:27:32 +0200
changeset 32061 6d28bbd33e2c
parent 31998 2c7a24f74db9
child 32065 0a83608e21f1
permissions -rw-r--r--
prefer code_inline over code_unfold; use code_unfold_post where appropriate
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Integer Main
     9 begin
    10 
text {*
  When generating code for functions on natural numbers, the
  canonical representation using @{term "0::nat"} and
  @{term "Suc"} is unsuitable for computations involving large
  numbers.  The efficiency of the generated code can be improved
  drastically by implementing natural numbers by target-language
  integers.  To do this, simply import this theory.
*}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
(* Declare the numeral function number_of_nat as the only code constructor
   of nat; 0 and Suc are therefore no longer constructors for the code
   generator and must be expressed through numerals. *)
code_datatype number_nat_inst.number_of_nat

(* 0 and 1 on nat, expressed as numerals.  Both rules are code equations
   and are also tagged code_unfold_post (NOTE(review): the exact pre-/post-
   processing direction of that attribute is defined by the code generator,
   not here). *)
lemma zero_nat_code [code, code_unfold_post]:
  "0 = (Numeral0 :: nat)"
  by simp

lemma one_nat_code [code, code_unfold_post]:
  "1 = (Numeral1 :: nat)"
  by simp
    36 
(* Suc is computed as addition of 1. *)
lemma Suc_code [code]:
  "Suc n = n + 1"
  by simp

(* Addition, subtraction and multiplication on nat are performed on int
   (via of_nat) and converted back with nat.  Because nat maps negative
   integers to 0, minus_nat_code yields truncated subtraction on nat. *)
lemma plus_nat_code [code]:
  "n + m = nat (of_nat n + of_nat m)"
  by simp

lemma minus_nat_code [code]:
  "n - m = nat (of_nat n - of_nat m)"
  by simp

lemma times_nat_code [code]:
  "n * m = nat (of_nat n * of_nat m)"
  unfolding of_nat_mult [symmetric] by simp
    52 
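text {* Specialized code for @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"}
  and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"}: division by zero is
  handled separately, and otherwise quotient and remainder are computed
  together on integers. *}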
    55 
(* divmod_aux is an alias of Divides.divmod whose defining equation is
   excluded from code generation ([code del]); it is given its own code
   equation below and is mapped to target-language divMod operations
   further down in this theory. *)
definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
  [code del]: "divmod_aux = Divides.divmod"

(* A zero divisor is handled here, yielding (0, n); this equation calls
   divmod_aux only when m is nonzero. *)
lemma [code]:
  "Divides.divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
  unfolding divmod_aux_def divmod_div_mod by simp

(* Quotient and remainder computed on int and converted back to nat. *)
lemma divmod_aux_code [code]:
  "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
  unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    66 
(* Equality on nat is decided by comparing the images in int. *)
lemma eq_nat_code [code]:
  "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
  by (simp add: eq)

(* Reflexivity of equality, registered under the "code nbe" attribute
   (presumably for normalisation by evaluation only -- confirm). *)
lemma eq_nat_refl [code nbe]:
  "eq_class.eq (n::nat) n \<longleftrightarrow> True"
  by (rule HOL.eq_refl)

(* The orderings on nat are likewise decided on int. *)
lemma less_eq_nat_code [code]:
  "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
  by simp

lemma less_nat_code [code]:
  "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
  by simp
    82 
    83 subsection {* Case analysis *}
    84 
    85 text {*
    86   Case analysis on natural numbers is rephrased using a conditional
    87   expression:
    88 *}
    89 
(* nat_case is replaced by a conditional on n = 0, both as a code equation
   and as an unfold rule, since 0/Suc are no longer code constructors. *)
lemma [code, code_unfold]:
  "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
  by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    93 
    94 
    95 subsection {* Preprocessors *}
    96 
text {*
  Since @{term "Suc n"} is now implemented as @{term "n + (1::nat)"}, which,
  unlike @{term "Suc n"}, is not a constructor term, all occurrences of
  @{term Suc} in positions where a pattern is expected (i.e.\ on the
  left-hand side of a recursion equation or in the arguments of an inductive
  relation in an introduction rule) must be eliminated.
  This is accomplished by the following transformation rules:
*}
   105 
(* Combines an equation for f (Suc n) and an equation for f 0 into a single
   equation for f n that distinguishes the two cases with a conditional. *)
lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
  f n \<equiv> if n = 0 then g else h (n - 1)"
  by (rule eq_reflection) (cases n, simp_all)

(* Turns a clause stated for Suc n into one stated for n, with the
   additional premise n \<noteq> 0 and n - 1 in place of the original n. *)
lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
  by (cases n) simp_all
   112 
   113 text {*
   114   The rules above are built into a preprocessor that is plugged into
   115   the code generator. Since the preprocessor for introduction rules
   116   does not know anything about modes, some of the modes that worked
   117   for the canonical representation of natural numbers may no longer work.
   118 *}
   119 
   120 (*<*)
   121 setup {*
   122 let
   123 
   124 fun remove_suc thy thms =
   125   let
   126     val vname = Name.variant (map fst
   127       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
   128     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
   129     fun lhs_of th = snd (Thm.dest_comb
   130       (fst (Thm.dest_comb (cprop_of th))));
   131     fun rhs_of th = snd (Thm.dest_comb (cprop_of th));
   132     fun find_vars ct = (case term_of ct of
   133         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   134       | _ $ _ =>
   135         let val (ct1, ct2) = Thm.dest_comb ct
   136         in 
   137           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   138           map (apfst (Thm.capply ct1)) (find_vars ct2)
   139         end
   140       | _ => []);
   141     val eqs = maps
   142       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
   143     fun mk_thms (th, (ct, cv')) =
   144       let
   145         val th' =
   146           Thm.implies_elim
   147            (Conv.fconv_rule (Thm.beta_conversion true)
   148              (Drule.instantiate'
   149                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   150                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   151                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   152       in
   153         case map_filter (fn th'' =>
   154             SOME (th'', singleton
   155               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   156           handle THM _ => NONE) thms of
   157             [] => NONE
   158           | thps =>
   159               let val (ths1, ths2) = split_list thps
   160               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   161       end
   162   in get_first mk_thms eqs end;
   163 
   164 fun eqn_suc_preproc thy thms =
   165   let
   166     val dest = fst o Logic.dest_equals o prop_of;
   167     val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   168   in
   169     if forall (can dest) thms andalso exists (contains_suc o dest) thms
   170       then perhaps_loop (remove_suc thy) thms
   171        else NONE
   172   end;
   173 
   174 val eqn_suc_preproc1 = Code_Preproc.simple_functrans eqn_suc_preproc;
   175 
   176 fun eqn_suc_preproc2 thy thms = eqn_suc_preproc thy thms
   177   |> the_default thms;
   178 
   179 fun remove_suc_clause thy thms =
   180   let
   181     val vname = Name.variant (map fst
   182       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
   183     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   184       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   185       | find_var _ = NONE;
   186     fun find_thm th =
   187       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   188       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   189   in
   190     case get_first find_thm thms of
   191       NONE => thms
   192     | SOME ((th, th'), (Sucv, v)) =>
   193         let
   194           val cert = cterm_of (Thm.theory_of_thm th);
   195           val th'' = ObjectLogic.rulify (Thm.implies_elim
   196             (Conv.fconv_rule (Thm.beta_conversion true)
   197               (Drule.instantiate' []
   198                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   199                    abstract_over (Sucv,
   200                      HOLogic.dest_Trueprop (prop_of th')))))),
   201                  SOME (cert v)] @{thm Suc_clause}))
   202             (Thm.forall_intr (cert v) th'))
   203         in
   204           remove_suc_clause thy (map (fn th''' =>
   205             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   206         end
   207   end;
   208 
   209 fun clause_suc_preproc thy ths =
   210   let
   211     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   212   in
   213     if forall (can (dest o concl_of)) ths andalso
   214       exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
   215         (map_filter (try dest) (concl_of th :: prems_of th))) ths
   216     then remove_suc_clause thy ths else ths
   217   end;
   218 in
   219 
   220   Codegen.add_preprocessor eqn_suc_preproc2
   221   #> Codegen.add_preprocessor clause_suc_preproc
   222   #> Code_Preproc.add_functrans ("eqn_Suc", eqn_suc_preproc1)
   223 
   224 end;
   225 *}
   226 (*>*)
   227 
   228 
   229 subsection {* Target language setup *}
   230 
   231 text {*
   232   For ML, we map @{typ nat} to target language integers, where we
   233   assert that values are always non-negative.
   234 *}
   235 
(* In SML and OCaml, nat is represented by the arbitrary-precision integer
   types IntInf.int and Big_int.big_int respectively. *)
code_type nat
  (SML "IntInf.int")
  (OCaml "Big'_int.big'_int")

(* Setup for the types_code/consts_code generator: nat is printed as ML
   int, with an attached term_of function and a test-data generator
   gen_nat (presumably drawing n from 0 to i via random_range -- confirm). *)
types_code
  nat ("int")
attach (term_of) {*
val term_of_nat = HOLogic.mk_number HOLogic.natT;
*}
attach (test) {*
fun gen_nat i =
  let val n = random_range 0 i
  in (n, fn () => term_of_nat n) end;
*}
   250 
   251 text {*
   252   For Haskell we define our own @{typ nat} type.  The reason
   253   is that we have to distinguish type class instances
   254   for @{typ nat} and @{typ int}.
   255 *}
   256 
   257 code_include Haskell "Nat" {*
   258 newtype Nat = Nat Integer deriving (Show, Eq);
   259 
   260 instance Num Nat where {
   261   fromInteger k = Nat (if k >= 0 then k else 0);
   262   Nat n + Nat m = Nat (n + m);
   263   Nat n - Nat m = fromInteger (n - m);
   264   Nat n * Nat m = Nat (n * m);
   265   abs n = n;
   266   signum _ = 1;
   267   negate n = error "negate Nat";
   268 };
   269 
   270 instance Ord Nat where {
   271   Nat n <= Nat m = n <= m;
   272   Nat n < Nat m = n < m;
   273 };
   274 
   275 instance Real Nat where {
   276   toRational (Nat n) = toRational n;
   277 };
   278 
   279 instance Enum Nat where {
   280   toEnum k = fromInteger (toEnum k);
   281   fromEnum (Nat n) = fromEnum n;
   282 };
   283 
   284 instance Integral Nat where {
   285   toInteger (Nat n) = n;
   286   divMod n m = quotRem n m;
   287   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   288 };
   289 *}
   290 
(* Reserve the identifier Nat in generated Haskell and map nat to the type
   Nat.Nat from the included module.  No eq instance is generated, since
   Nat derives Eq in that module. *)
code_reserved Haskell Nat

code_type nat
  (Haskell "Nat.Nat")

code_instance nat :: eq
  (Haskell -)
   298 
   299 text {*
   300   Natural numerals.
   301 *}
   302 
(* nat applied to an integer numeral is rewritten to the nat numeral
   constructor number_of_nat. *)
lemma [code_unfold_post]:
  "nat (number_of i) = number_nat_inst.number_of_nat i"
  -- {* this interacts as desired with @{thm nat_number_of_def} *}
  by (simp add: number_nat_inst.number_of_nat)

(* Register target-language numeral syntax for number_of_nat in SML, OCaml
   and Haskell.  NOTE(review): the two boolean flags are interpreted by
   Numeral.add_code; their meaning is not visible here. *)
setup {*
  fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
    false true) ["SML", "OCaml", "Haskell"]
*}
   312 
   313 text {*
   314   Since natural numbers are implemented
   315   using integers in ML, the coercion function @{const "of_nat"} of type
   316   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   317   For the @{const "nat"} function for converting an integer to a natural
   318   number, we give a specific implementation using an ML function that
   319   returns its input value, provided that it is non-negative, and otherwise
   320   returns @{text "0"}.
   321 *}
   322 
   323 definition
   324   int :: "nat \<Rightarrow> int"
   325 where
   326   [code del]: "int = of_nat"
   327 
   328 lemma int_code' [code]:
   329   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   330   unfolding int_nat_number_of [folded int_def] ..
   331 
   332 lemma nat_code' [code]:
   333   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   334   unfolding nat_number_of_def number_of_is_id neg_def by simp
   335 
   336 lemma of_nat_int [code_unfold_post]:
   337   "of_nat = int" by (simp add: int_def)
   338 
(* int is the identity in SML and OCaml, since nat and int share the same
   representation there. *)
code_const int
  (SML "_")
  (OCaml "_")

(* consts_code setup: int is the identity; nat uses an attached ML
   function that maps negative integers to 0. *)
consts_code
  int ("(_)")
  nat ("\<module>nat")
attach {*
fun nat i = if i < 0 then 0 else i;
*}

(* nat clamps at 0: max (0, i) in SML and max_big_int zero_big_int i in
   OCaml. *)
code_const nat
  (SML "IntInf.max/ (/0,/ _)")
  (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   353 
   354 text {* For Haskell, things are slightly different again. *}
   355 
(* Haskell: int unwraps a Nat with toInteger; nat uses Nat's fromInteger,
   which clamps negative integers to 0. *)
code_const int and nat
  (Haskell "toInteger" and "fromInteger")
   358 
   359 text {* Conversion from and to indices. *}
   360 
(* Conversions between nat and code_numeral (indices).
   NOTE(review): in SML, IntInf.toInt fails for values outside the default
   int range -- confirm callers only convert small naturals. *)
code_const Code_Numeral.of_nat
  (SML "IntInf.toInt")
  (OCaml "_")
  (Haskell "fromEnum")

code_const Code_Numeral.nat_of
  (SML "IntInf.fromInt")
  (OCaml "_")
  (Haskell "toEnum")
   370 
   371 text {* Using target language arithmetic operations whenever appropriate *}
   372 
(* Direct target-language mappings for arithmetic and comparisons on nat.
   Subtraction is deliberately absent: it is implemented through
   minus_nat_code, which truncates at 0. *)
code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.+ ((_), (_))")
  (OCaml "Big'_int.add'_big'_int")
  (Haskell infixl 6 "+")

code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
  (SML "IntInf.* ((_), (_))")
  (OCaml "Big'_int.mult'_big'_int")
  (Haskell infixl 7 "*")

(* The Divides.divmod code equation calls divmod_aux only with a nonzero
   divisor. *)
code_const divmod_aux
  (SML "IntInf.divMod/ ((_),/ (_))")
  (OCaml "Big'_int.quomod'_big'_int")
  (Haskell "divMod")

code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "!((_ : IntInf.int) = _)")
  (OCaml "Big'_int.eq'_big'_int")
  (Haskell infixl 4 "==")

code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.<= ((_), (_))")
  (OCaml "Big'_int.le'_big'_int")
  (Haskell infix 4 "<=")

code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
  (SML "IntInf.< ((_), (_))")
  (OCaml "Big'_int.lt'_big'_int")
  (Haskell infix 4 "<")
   402 
(* consts_code setup: constants and operators on nat printed as ML
   integer syntax; Suc becomes addition of 1. *)
consts_code
  "0::nat"                     ("0")
  "1::nat"                     ("1")
  Suc                          ("(_ +/ 1)")
  "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
  "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
  "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
  "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   411 
   412 
   413 text {* Evaluation *}
   414 
(* The trivial equation is declared and immediately deleted as a code
   equation for term_of on nat (presumably to drop any existing code
   equations -- confirm); term_of is instead mapped in SML to build the
   numeral term directly. *)
lemma [code, code del]:
  "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..

code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
  (SML "HOLogic.mk'_number/ HOLogic.natT")
   420 
   421 
   422 text {* Module names *}
   423 
(* Generated code from Nat, Divides, Ring_and_Field and this theory is
   placed into a single target module named Integer, for SML, OCaml and
   Haskell alike. *)
code_modulename SML
  Nat Integer
  Divides Integer
  Ring_and_Field Integer
  Efficient_Nat Integer

code_modulename OCaml
  Nat Integer
  Divides Integer
  Ring_and_Field Integer
  Efficient_Nat Integer

code_modulename Haskell
  Nat Integer
  Divides Integer
  Ring_and_Field Integer
  Efficient_Nat Integer

(* The auxiliary alias int is hidden from the name space after use. *)
hide const int
   443 
   444 end