src/HOL/Library/Efficient_Nat.thy
author haftmann
Mon, 11 May 2009 10:53:19 +0200
changeset 31090 3be41b271023
parent 30663 0b6aff7451b2
child 31128 b3bb28c87409
permissions -rw-r--r--
clarified matter of "proper" flag in code equations
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Index Code_Integer Main
     9 begin
    10 
    11 text {*
    12   When generating code for functions on natural numbers, the
    13   canonical representation using @{term "0::nat"} and
    14   @{term "Suc"} is unsuitable for computations involving large
    15   numbers.  The efficiency of the generated code can be improved
    16   drastically by implementing natural numbers by target-language
    17   integers.  To do this, just include this theory.
    18 *}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
    27 code_datatype number_nat_inst.number_of_nat
         (* Declares the numeral constant number_of_nat as the code-generation
            constructor for nat.  The equations that follow re-express 0, 1 and
            Suc so that generated code does not rely on the 0/Suc
            representation. *)
    28 
    29 lemma zero_nat_code [code, code inline]:
         (* 0 is rewritten to the numeral Numeral0 during code generation; the
            lemmas line after the proof registers the symmetric equation as a
            postprocessing rule. *)
    30   "0 = (Numeral0 :: nat)"
    31   by simp
    32 lemmas [code post] = zero_nat_code [symmetric]
    33 
    34 lemma one_nat_code [code, code inline]:
         (* Same treatment for 1, using Numeral1. *)
    35   "1 = (Numeral1 :: nat)"
    36   by simp
    37 lemmas [code post] = one_nat_code [symmetric]
    38 
    39 lemma Suc_code [code]:
         (* Code equation for the successor: addition of 1. *)
    40   "Suc n = n + 1"
    41   by simp
    42 
    43 lemma plus_nat_code [code]:
         (* Addition on nat: embed both operands with of_nat, add there, and map
            the sum back with nat. *)
    44   "n + m = nat (of_nat n + of_nat m)"
    45   by simp
    46 
    47 lemma minus_nat_code [code]:
         (* Subtraction follows the same scheme; the outer nat is applied to the
            difference of the embedded operands.  The conversion nat is
            described later in this theory as returning 0 for negative input. *)
    48   "n - m = nat (of_nat n - of_nat m)"
    49   by simp
    50 
    51 lemma times_nat_code [code]:
         (* Multiplication via the embedded operands; the proof first folds the
            product of embeddings with of_nat_mult. *)
    52   "n * m = nat (of_nat n * of_nat m)"
    53   unfolding of_nat_mult [symmetric] by simp
    54 
    55 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    56   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    57 
    58 definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
         (* Auxiliary copy of Divides.divmod.  The defining equation is excluded
            from code generation by code del; the code equation actually used is
            divmod_aux_code at the end of this group. *)
    59   [code del]: "divmod_aux = Divides.divmod"
    60 
    61 lemma [code]:
         (* Code equation for Divides.divmod: a zero divisor is answered
            directly with (0, n); every other case is delegated to divmod_aux,
            which is therefore only reached with a non-zero divisor from here. *)
    62   "Divides.divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
    63   unfolding divmod_aux_def divmod_div_mod by simp
    64 
    65 lemma divmod_aux_code [code]:
         (* divmod_aux expressed through div and mod on the embedded operands,
            with each component mapped back by nat. *)
    66   "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
    67   unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    68 
    69 lemma eq_nat_code [code]:
         (* Equality on nat is decided by equality on the int embeddings of both
            sides. *)
    70   "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
    71   by (simp add: eq)
    72 
    73 lemma eq_nat_refl [code nbe]:
         (* Reflexivity of equality on nat, registered with the code nbe
            attribute only. *)
    74   "eq_class.eq (n::nat) n \<longleftrightarrow> True"
    75   by (rule HOL.eq_refl)
    76 
    77 lemma less_eq_nat_code [code]:
         (* The order <= on nat is decided on the int embeddings. *)
    78   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
    79   by simp
    80 
    81 lemma less_nat_code [code]:
         (* The strict order < on nat is decided on the int embeddings. *)
    82   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
    83   by simp
    84 
    85 subsection {* Case analysis *}
    86 
    87 text {*
    88   Case analysis on natural numbers is rephrased using a conditional
    89   expression:
    90 *}
    91 
    92 lemma [code, code unfold]:
         (* nat_case without pattern matching: the zero branch f is chosen when
            n = 0, otherwise g is applied to the predecessor n - 1.  Declared
            both as a code equation and as an unfolding rule. *)
    93   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    94   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    95 
    96 
    97 subsection {* Preprocessors *}
    98 
    99 text {*
   100   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   101   a constructor term. Therefore, all occurrences of this term in a position
   102   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   103   equation or in the arguments of an inductive relation in an introduction
   104   rule) must be eliminated.
   105   This can be accomplished by applying the following transformation rules:
   106 *}
   107 
   108 lemma Suc_if_eq': "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
         (* Object-equality version: an equation for f (Suc n) together with an
            equation for f 0 is merged into one equation for f n that tests
            n = 0 and otherwise uses the predecessor n - 1. *)
   109   f n = (if n = 0 then g else h (n - 1))"
   110   by (cases n) simp_all
   111 
   112 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
         (* Meta-equality version of the same rule; the proof reduces it to
            Suc_if_eq' by converting between meta and object equality. *)
   113   f n \<equiv> if n = 0 then g else h (n - 1)"
   114   by (rule eq_reflection, rule Suc_if_eq')
   115     (rule meta_eq_to_obj_eq, assumption,
   116      rule meta_eq_to_obj_eq, assumption)
   117 
   118 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
         (* Rule for clauses: a statement about the pair (n, Suc n) is turned
            into one about (n - 1, n) under the extra premise that n is not 0. *)
   119   by (cases n) simp_all
   120 
   121 text {*
   122   The rules above are built into a preprocessor that is plugged into
   123   the code generator. Since the preprocessor for introduction rules
   124   does not know anything about modes, some of the modes that worked
   125   for the canonical representation of natural numbers may no longer work.
   126 *}
   127 
   128 (*<*)
   129 setup {*
   130 let
   131 
   132 fun gen_remove_suc Suc_if_eq dest_judgement thy thms =
         (* One elimination step for Suc patterns in a list of equations.

            Looks for a theorem in thms whose left-hand side contains Suc applied
            to a schematic variable and tries to combine it, through the rule
            passed as Suc_if_eq, with other theorems of thms.

            Suc_if_eq      - the combination rule to instantiate
            dest_judgement - maps the certified proposition of a theorem to the
                             certified equation term that is taken apart below
            thy            - theory used for certifying the placeholder variable
            thms           - the theorems to transform

            Returns SOME of the updated theorem list for the first occurrence
            that can be combined with at least one theorem, NONE otherwise. *)
   133   let
         (* A variant of "n" chosen against the schematic variable names that
            occur in thms. *)
   134     val vname = Name.variant (map fst
   135       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
         (* Certified schematic variable of type nat with that name and index 0;
            it marks the position of a Suc pattern. *)
   136     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
         (* Both selectors treat the judgement as a two-argument application:
            lhs_of yields the first argument, rhs_of the second. *)
   137     fun lhs_of th = snd (Thm.dest_comb
   138       (fst (Thm.dest_comb (dest_judgement (cprop_of th)))));
   139     fun rhs_of th = snd (Thm.dest_comb (dest_judgement (cprop_of th)));
         (* For every subterm of ct of the form Suc applied to a schematic
            variable, yields a pair: ct with that one subterm replaced by cv, and
            the certified variable that stood under Suc. *)
   140     fun find_vars ct = (case term_of ct of
   141         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   142       | _ $ _ =>
   143         let val (ct1, ct2) = Thm.dest_comb ct
   144         in 
   145           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   146           map (apfst (Thm.capply ct1)) (find_vars ct2)
   147         end
   148       | _ => []);
         (* All candidates: each theorem paired with each Suc occurrence found in
            its left-hand side. *)
   149     val eqs = maps
   150       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
   151     fun mk_thms (th, (ct, cv')) =
   152       let
         (* Instantiate Suc_if_eq with the abstraction of ct over cv, the
            abstraction of the right-hand side over cv', and cv' itself (the
            third term argument is left open), beta-normalise, and discharge the
            first premise with th generalised over cv'.
            NOTE(review): the positional arguments presumably correspond to
            f, h, g, n of the rule -- confirm against its variable order. *)
   153         val th' =
   154           Thm.implies_elim
   155            (Conv.fconv_rule (Thm.beta_conversion true)
   156              (Drule.instantiate'
   157                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   158                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   159                Suc_if_eq)) (Thm.forall_intr cv' th)
   160       in
         (* Try to resolve every theorem of thms against th'; theorems for which
            resolution raises THM are skipped. *)
   161         case map_filter (fn th'' =>
   162             SOME (th'', singleton
   163               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   164           handle THM _ => NONE) thms of
   165             [] => NONE
   166           | thps =>
         (* Drop th and the theorems that were consumed, append the combined
            results. *)
   167               let val (ths1, ths2) = split_list thps
   168               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   169       end
   170   in get_first mk_thms eqs end;
   171 
   172 fun gen_eqn_suc_preproc Suc_if_eq dest_judgement dest_lhs thy thms =
         (* Guarded driver around gen_remove_suc.

            dest_lhs extracts the left-hand side from the proposition of a
            theorem.  The transformation is attempted only if dest_lhs succeeds
            on every theorem and at least one left-hand side mentions the
            constant Suc; otherwise the result is NONE.
            NOTE(review): perhaps_loop presumably repeats the step until it
            yields NONE -- confirm against its definition. *)
   173   let
   174     val dest = dest_lhs o prop_of;
   175     val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   176   in
   177     if forall (can dest) thms andalso exists (contains_suc o dest) thms
   178       then perhaps_loop (gen_remove_suc Suc_if_eq dest_judgement thy) thms
   179        else NONE
   180   end;
   181 
   182 val eqn_suc_preproc = Code.simple_functrans (gen_eqn_suc_preproc
         (* Instance for meta-equations: uses the rule Suc_if_eq, takes the
            proposition itself as the judgement (I) and reads the left-hand side
            with Logic.dest_equals.  Wrapped by Code.simple_functrans. *)
   183   @{thm Suc_if_eq} I (fst o Logic.dest_equals));
   184 
   185 fun eqn_suc_preproc' thy thms = gen_eqn_suc_preproc
         (* Instance for object-level equations: uses the rule Suc_if_eq', strips
            one application from the proposition to reach the equation, and
            reads the left-hand side below Trueprop.  When the transformation
            yields NONE the input theorems are returned unchanged. *)
   186   @{thm Suc_if_eq'} (snd o Thm.dest_comb) (fst o HOLogic.dest_eq o HOLogic.dest_Trueprop) thy thms
   187   |> the_default thms;
   188 
   189 fun remove_suc_clause thy thms =
         (* Eliminates Suc patterns from a list of clauses.

            Repeatedly picks the first theorem whose atomized proposition
            contains Suc applied to a schematic variable, rewrites that theorem
            with the rule Suc_clause, substitutes the result for the original in
            the list, and recurses.  The list is returned unchanged once no such
            theorem is left.

            thy  - only passed on to the recursive call; certification uses the
                   theory of the selected theorem
            thms - the clauses to transform *)
   190   let
         (* First subterm of the form Suc v with v a schematic variable, returned
            together with v; the function part of an application is searched
            before its argument. *)
   193     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   194       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   195       | find_var _ = NONE;
         (* Atomizes th and searches the resulting proposition; keeps both the
            original and the atomized theorem on success. *)
   196     fun find_thm th =
   197       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   198       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   199   in
   200     case get_first find_thm thms of
   201       NONE => thms
   202     | SOME ((th, th'), (Sucv, v)) =>
   203         let
   204           val cert = cterm_of (Thm.theory_of_thm th);
         (* Instantiate Suc_clause with the predicate obtained by abstracting
            the atomized proposition over v and over the occurrence Sucv, and
            with v itself; beta-normalise, discharge the first premise with th'
            generalised over v, and turn the result back into rule format. *)
   205           val th'' = ObjectLogic.rulify (Thm.implies_elim
   206             (Conv.fconv_rule (Thm.beta_conversion true)
   207               (Drule.instantiate' []
   208                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   209                    abstract_over (Sucv,
   210                      HOLogic.dest_Trueprop (prop_of th')))))),
   211                  SOME (cert v)] @{thm Suc_clause}))
   212             (Thm.forall_intr (cert v) th'))
   213         in
         (* Replace every theorem whose proposition equals that of th, then look
            for further Suc patterns. *)
   214           remove_suc_clause thy (map (fn th''' =>
   215             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   216         end
   217   end;
   218 
   219 fun clause_suc_preproc thy ths =
         (* Guarded entry point for remove_suc_clause.

            The transformation runs only if the conclusion of every theorem is a
            membership statement under Trueprop and, for at least one theorem,
            the element side of such a statement (in the conclusion or in a
            premise) mentions the constant Suc.  Otherwise ths is returned
            unchanged. *)
   220   let
   221     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   222   in
   223     if forall (can (dest o concl_of)) ths andalso
   224       exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
   225         (map_filter (try dest) (concl_of th :: prems_of th))) ths
   226     then remove_suc_clause thy ths else ths
   227   end;
   228 in
   229 
   230   Codegen.add_preprocessor eqn_suc_preproc'
         (* Theory update performed by this setup: the two preprocessors are
            added through Codegen.add_preprocessor, and the meta-equation
            variant is added through Code.add_functrans under the name
            "eqn_Suc". *)
   231   #> Codegen.add_preprocessor clause_suc_preproc
   232   #> Code.add_functrans ("eqn_Suc", eqn_suc_preproc)
   233 
   234 end;
   235 *}
   236 (*>*)
   237 
   238 
   239 subsection {* Target language setup *}
   240 
   241 text {*
   242   For ML, we map @{typ nat} to target language integers, where we
   243   assert that values are always non-negative.
   244 *}
   245 
   246 code_type nat
         (* Target type of nat: IntInf.int for SML, Big_int.big_int for OCaml
            (the primes escape the underscores in the serialisation string). *)
   247   (SML "IntInf.int")
   248   (OCaml "Big'_int.big'_int")
   249 
   250 types_code
         (* Setup for the types_code interface: nat is printed as int.  The
            first attachment supplies term_of_nat, which builds a numeral term of
            type nat; the second supplies gen_nat, which draws a value with
            random_range 0 i and pairs it with a delayed term for it. *)
   251   nat ("int")
   252 attach (term_of) {*
   253 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   254 *}
   255 attach (test) {*
   256 fun gen_nat i =
   257   let val n = random_range 0 i
   258   in (n, fn () => term_of_nat n) end;
   259 *}
   259 *}
   260 
   261 text {*
   262   For Haskell we define our own @{typ nat} type.  The reason
   263   is that we have to distinguish type class instances
   264   for @{typ nat} and @{typ int}.
   265 *}
   266 
   267 code_include Haskell "Nat" {*
       {- Natural numbers for generated Haskell code: a wrapper around Integer.
          Values built through fromInteger are clamped at 0, which also makes
          subtraction truncate at 0. -}
   268 newtype Nat = Nat Integer deriving (Show, Eq);
   269 
   270 instance Num Nat where {
   271   fromInteger k = Nat (if k >= 0 then k else 0);
   272   Nat n + Nat m = Nat (n + m);
   273   Nat n - Nat m = fromInteger (n - m);
   274   Nat n * Nat m = Nat (n * m);
   275   abs n = n;
   276   signum _ = 1;
   277   negate n = error "negate Nat";
   278 };
   279 
   280 instance Ord Nat where {
   281   Nat n <= Nat m = n <= m;
   282   Nat n < Nat m = n < m;
   283 };
   284 
   285 instance Real Nat where {
   286   toRational (Nat n) = toRational n;
   287 };
   288 
   289 instance Enum Nat where {
   290   toEnum k = fromInteger (toEnum k);
   291   fromEnum (Nat n) = fromEnum n;
   292 };
   293 
       -- Division by zero is total here and agrees with the HOL code equation
       -- for divmod in this theory: quotient 0, remainder the dividend.
   294 instance Integral Nat where {
   295   toInteger (Nat n) = n;
   296   divMod n m = quotRem n m;
   297   quotRem (Nat n) (Nat m)
           | (m == 0) = (Nat 0, Nat n)
           | otherwise = (Nat k, Nat l) where (k, l) = quotRem n m;
   298 };
   299 *}
   300 
   301 code_reserved Haskell Nat
         (* Reserves the name Nat in generated Haskell code; it is used by the
            included module of that name. *)
   302 
   303 code_type nat
         (* For Haskell, nat is serialised as the type Nat of module Nat. *)
   304   (Haskell "Nat.Nat")
   305 
   306 code_instance nat :: eq
         (* No eq instance is generated for nat in Haskell; the included newtype
            derives Eq. *)
   307   (Haskell -)
   308 
   309 text {*
   310   Natural numerals.
   311 *}
   312 
   313 lemma [code inline, symmetric, code post]:
         (* Relates nat applied to an integer numeral to the nat numeral
            constant.  The equation is used as an inline rule in the direction
            stated and, after the symmetric attribute, as a postprocessing rule
            in the opposite direction. *)
   314   "nat (number_of i) = number_nat_inst.number_of_nat i"
   315   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   316   by (simp add: number_nat_inst.number_of_nat)
   317 
   318 setup {*
   319   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   320     true false) ["SML", "OCaml", "Haskell"]
   321 *}
       (* The setup above applies Numeral.add_code for the nat numeral constant to
          each of the three listed targets.
          NOTE(review): the meaning of the two boolean flags is not visible in
          this file -- confirm against Numeral.add_code. *)
   322 
   323 text {*
   324   Since natural numbers are implemented
   325   using integers in ML, the coercion function @{const "of_nat"} of type
   326   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   327   For the @{const "nat"} function for converting an integer to a natural
   328   number, we give a specific implementation using an ML function that
   329   returns its input value, provided that it is non-negative, and otherwise
   330   returns @{text "0"}.
   331 *}
   332 
   333 definition
         (* A local name for the embedding of_nat at type nat => int, so that it
            can be given its own serialisation.  Its defining equation is
            excluded from code generation; the constant is hidden again at the
            end of the theory. *)
   334   int :: "nat \<Rightarrow> int"
   335 where
   336   [code del]: "int = of_nat"
   337 
   338 lemma int_code' [code]:
         (* int on a nat numeral: 0 when the numeral, read as an integer, is
            negative, otherwise the same numeral. *)
   339   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   340   unfolding int_nat_number_of [folded int_def] ..
   341 
   342 lemma nat_code' [code]:
         (* nat on an integer numeral: 0 for a negative numeral, otherwise the
            same numeral. *)
   343   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   344   unfolding nat_number_of_def number_of_is_id neg_def by simp
   345 
   346 lemma of_nat_int [code unfold]:
         (* of_nat is unfolded to int before code generation; the declaration
            after the proof restores of_nat in postprocessing. *)
   347   "of_nat = int" by (simp add: int_def)
   348 declare of_nat_int [symmetric, code post]
   349 
   350 code_const int
         (* In SML and OCaml the conversion int is printed as its bare
            argument. *)
   351   (SML "_")
   352   (OCaml "_")
   353 
   354 consts_code
         (* consts_code interface: int is printed as its parenthesised argument;
            nat refers to the attached ML function, which returns 0 for negative
            input and the input itself otherwise. *)
   355   int ("(_)")
   356   nat ("\<module>nat")
   357 attach {*
   358 fun nat i = if i < 0 then 0 else i;
   359 *}
   360 
   361 code_const nat
         (* nat as the maximum of 0 and the argument, in both SML and OCaml. *)
   362   (SML "IntInf.max/ (/0,/ _)")
   363   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   364 
   365 text {* For Haskell, things are slightly different again. *}
   366 
   367 code_const int and nat
         (* Haskell: int becomes toInteger and nat becomes fromInteger; both are
            defined for Nat in the included module, where fromInteger clamps
            negative values to 0. *)
   368   (Haskell "toInteger" and "fromInteger")
   369 
   370 text {* Conversion from and to indices. *}
   371 
   372 code_const Code_Index.of_nat
         (* Conversion from nat to index in each target: IntInf.toInt,
            Big_int.int_of_big_int and fromEnum. *)
   373   (SML "IntInf.toInt")
   374   (OCaml "Big'_int.int'_of'_big'_int")
   375   (Haskell "fromEnum")
   376 
   377 code_const Code_Index.nat_of
         (* Conversion from index to nat in each target: IntInf.fromInt,
            Big_int.big_int_of_int and toEnum. *)
   378   (SML "IntInf.fromInt")
   379   (OCaml "Big'_int.big'_int'_of'_int")
   380   (Haskell "toEnum")
   381 
   382 text {* Using target language arithmetic operations whenever appropriate *}
   383 
   384 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
         (* Addition on nat is mapped directly to target-language addition. *)
   385   (SML "IntInf.+ ((_), (_))")
   386   (OCaml "Big'_int.add'_big'_int")
   387   (Haskell infixl 6 "+")
   388 
   389 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
         (* Multiplication on nat is mapped directly to target-language
            multiplication. *)
   390   (SML "IntInf.* ((_), (_))")
   391   (OCaml "Big'_int.mult'_big'_int")
   392   (Haskell infixl 7 "*")
   393 
   394 code_const divmod_aux
         (* Combined quotient and remainder.  The code equation for
            Divides.divmod in this theory calls divmod_aux only when the divisor
            is not 0. *)
   395   (SML "IntInf.divMod/ ((_),/ (_))")
   396   (OCaml "Big'_int.quomod'_big'_int")
   397   (Haskell "divMod")
   398 
   399 code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
         (* Equality on nat.
            NOTE(review): the Haskell entry is declared infixl 4 while the two
            order relations below use infix 4 -- confirm this is intended. *)
   400   (SML "!((_ : IntInf.int) = _)")
   401   (OCaml "Big'_int.eq'_big'_int")
   402   (Haskell infixl 4 "==")
   403 
   404 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
         (* Non-strict order on nat. *)
   405   (SML "IntInf.<= ((_), (_))")
   406   (OCaml "Big'_int.le'_big'_int")
   407   (Haskell infix 4 "<=")
   408 
   409 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
         (* Strict order on nat. *)
   410   (SML "IntInf.< ((_), (_))")
   411   (OCaml "Big'_int.lt'_big'_int")
   412   (Haskell infix 4 "<")
   413 
   414 consts_code
         (* consts_code interface: 0, 1, successor, addition, multiplication and
            both order relations on nat are printed as the corresponding ML
            integer literals and infix operators.  Subtraction and division are
            not listed here. *)
   415   "0::nat"                     ("0")
   416   "1::nat"                     ("1")
   417   Suc                          ("(_ +/ 1)")
   418   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
   419   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
   420   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   421   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   422 
   423 
   424 text {* Evaluation *}
   425 
   426 lemma [code, code del]:
         (* Trivial equation for term_of at type nat, declared with code and
            code del.
            NOTE(review): presumably this removes the existing code equations
            for term_of on nat so that the serialisation below takes effect --
            confirm against the code generator documentation. *)
   427   "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..
   428 
   429 code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
         (* In SML, term_of on nat builds a numeral term of type nat with
            HOLogic.mk_number. *)
   430   (SML "HOLogic.mk'_number/ HOLogic.natT")
   431 
   432 
   433 text {* Module names *}
   434 
   435 code_modulename SML
         (* For each target, code from the theories Nat, Divides, Ring_and_Field
            and Efficient_Nat is placed in a single module named Integer.  The
            same four mappings are repeated for SML, OCaml and Haskell. *)
   436   Nat Integer
   437   Divides Integer
   438   Ring_and_Field Integer
   439   Efficient_Nat Integer
   440 
   441 code_modulename OCaml
   442   Nat Integer
   443   Divides Integer
   444   Ring_and_Field Integer
   445   Efficient_Nat Integer
   446 
   447 code_modulename Haskell
   448   Nat Integer
   449   Divides Integer
   450   Ring_and_Field Integer
   451   Efficient_Nat Integer
   452 
   453 hide const int
         (* Hides the auxiliary constant int defined in this theory. *)
   454 
   455 end