src/HOL/Library/Efficient_Nat.thy
author haftmann
Wed, 28 Jan 2009 13:36:11 +0100
changeset 29657 881f328dfbb3
parent 29287 5b0bfd63b5da
child 29730 86cac1fab613
permissions -rw-r--r--
slightly adapted towards more uniformity with div/mod on nat
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Index Code_Integer
     9 begin
    10 
    11 text {*
    12   When generating code for functions on natural numbers, the
    13   canonical representation using @{term "0::nat"} and
    14   @{term "Suc"} is unsuitable for computations involving large
    15   numbers.  The efficiency of the generated code can be improved
    16   drastically by implementing natural numbers by target-language
    17   integers.  To do this, just include this theory.
    18 *}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
    27 code_datatype number_nat_inst.number_of_nat  (* number_of_nat is declared as the code constructor for nat *)
    28 (* 0 is expressed as the numeral Numeral0 for code generation; the reversed equation is a postprocessing rule. *)
    29 lemma zero_nat_code [code, code inline]:
    30   "0 = (Numeral0 :: nat)"
    31   by simp
    32 lemmas [code post] = zero_nat_code [symmetric]
    33 (* Same treatment for 1 and Numeral1. *)
    34 lemma one_nat_code [code, code inline]:
    35   "1 = (Numeral1 :: nat)"
    36   by simp
    37 lemmas [code post] = one_nat_code [symmetric]
    38 
    39 lemma Suc_code [code]:  (* the successor is computed as addition of 1 *)
    40   "Suc n = n + 1"
    41   by simp
    42 (* The three code equations below compute on the of_nat images in int and map the result back with nat. *)
    43 lemma plus_nat_code [code]:
    44   "n + m = nat (of_nat n + of_nat m)"
    45   by simp
    46 
    47 lemma minus_nat_code [code]:
    48   "n - m = nat (of_nat n - of_nat m)"
    49   by simp
    50 
    51 lemma times_nat_code [code]:
    52   "n * m = nat (of_nat n * of_nat m)"
    53   unfolding of_nat_mult [symmetric] by simp
    54 
    55 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    56   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    57 
    58 definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
    59   [code del]: "divmod_aux = Divides.divmod"  (* the defining equation is not used as a code equation *)
    60 (* Divides.divmod treats the divisor 0 itself and hands every other case to divmod_aux. *)
    61 lemma [code]:
    62   "Divides.divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
    63   unfolding divmod_aux_def divmod_div_mod by simp
    64 (* divmod_aux is expressed through div and mod on the of_nat images, mapped back with nat. *)
    65 lemma divmod_aux_code [code]:
    66   "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
    67   unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    68 
    69 lemma eq_nat_code [code]:  (* equality on nat is decided on the int images *)
    70   "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
    71   by (simp add: eq)
    72 (* Reflexivity of equality, declared with the code nbe attribute only. *)
    73 lemma eq_nat_refl [code nbe]:
    74   "eq_class.eq (n::nat) n \<longleftrightarrow> True"
    75   by (rule HOL.eq_refl)
    76 (* Both orderings on nat are likewise decided on the int images. *)
    77 lemma less_eq_nat_code [code]:
    78   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
    79   by simp
    80 
    81 lemma less_nat_code [code]:
    82   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
    83   by simp
    84 
    85 subsection {* Case analysis *}
    86 
    87 text {*
    88   Case analysis on natural numbers is rephrased using a conditional
    89   expression:
    90 *}
    91 
    92 lemma [code, code unfold]:  (* nat_case becomes a test for 0 followed by a call on the predecessor n - 1 *)
    93   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    94   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    95 
    96 
    97 subsection {* Preprocessors *}
    98 
    99 text {*
   100   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   101   a constructor term. Therefore, all occurrences of this term in a position
   102   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   103   equation or in the arguments of an inductive relation in an introduction
   104   rule) must be eliminated.
   105   This can be accomplished by applying the following transformation rules:
   106 *}
   107 
   108 lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
   109   f n = (if n = 0 then g else h (n - 1))"
   110   by (case_tac n) simp_all  (* Suc_if_eq merges a Suc equation and a 0 equation into one conditional equation; used by remove_suc below *)
   111 (* Suc_clause restates a property of pairs (n, Suc n) for pairs (n - 1, n) with n non-zero; used by remove_suc_clause below. *)
   112 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   113   by (case_tac n) simp_all
   114 
   115 text {*
   116   The rules above are built into a preprocessor that is plugged into
   117   the code generator. Since the preprocessor for introduction rules
   118   does not know anything about modes, some of the modes that worked
   119   for the canonical representation of natural numbers may no longer work.
   120 *}
   121 
   122 (*<*)
   123 setup {*
   124 let
   125 
   126 fun remove_suc thy thms =  (* eliminate patterns Suc ?v from the left-hand sides of the equations thms *)
   127   let
   128     val vname = Name.variant (map fst  (* variant of the name x that differs from every schematic variable name in thms *)
   129       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
   130     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));  (* nat variable that stands in for a replaced Suc ?v *)
   131     fun lhs_of th = snd (Thm.dest_comb  (* left argument of the equation below Trueprop *)
   132       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
   133     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));  (* right argument of that equation *)
   134     fun find_vars ct = (case term_of ct of  (* one pair (ct with that Suc ?v replaced by cv, the ?v) per occurrence *)
   135         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]  (* checked constant name, as in the sibling functions *)
   136       | _ $ _ =>
   137         let val (ct1, ct2) = Thm.dest_comb ct
   138         in 
   139           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @  (* occurrences in the function part *)
   140           map (apfst (Thm.capply ct1)) (find_vars ct2)  (* occurrences in the argument part *)
   141         end
   142       | _ => []);
   143     val eqs = maps  (* every theorem paired with each Suc occurrence found in its left-hand side *)
   144       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
   145     fun mk_thms (th, (ct, cv')) =  (* try to combine th with other theorems of thms through Suc_if_eq *)
   146       let
   147         val th' =  (* Suc_if_eq instantiated from th, with its first premise discharged by th generalised over cv' *)
   148           Thm.implies_elim
   149            (Conv.fconv_rule (Thm.beta_conversion true)
   150              (Drule.instantiate'
   151                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),  (* NOTE(review): presumably instantiates f, h, (g left open), n in that order - confirm *)
   152                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   153                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   154       in
   155         case map_filter (fn th'' =>  (* keep each theorem of thms that resolves against th', paired with the result *)
   156             SOME (th'', singleton
   157               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   158           handle THM _ => NONE) thms of  (* a THM exception means this theorem does not fit *)
   159             [] => NONE
   160           | thps =>
   161               let val (ths1, ths2) = split_list thps
   162               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end  (* drop th and the consumed theorems, append the combined ones *)
   163       end
   164   in case get_first mk_thms eqs of  (* first candidate that yields a combination, if any *)
   165       NONE => thms
   166     | SOME x => remove_suc thy x  (* repeat on the rewritten list until nothing combines any more *)
   167   end;
   168 
   169 fun eqn_suc_preproc thy ths =  (* run remove_suc only if every theorem is an equation and some left-hand side mentions Suc *)
   170   let
   171     (* Left-hand side of an object-level equation; raises an exception for any other shape. *)
   172     fun lhs_of th = fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of th)));
   173     fun is_suc (c, _) = (c = @{const_name Suc});
   174     val applicable = forall (can lhs_of) ths andalso exists (exists_Const is_suc o lhs_of) ths;
   175   in
   176     if applicable then remove_suc thy ths else ths
   177   end;
   178 
   179 fun remove_suc_clause thy thms =  (* eliminate Suc ?v from the clauses thms, one theorem at a time *)
   180   let
   181     (* find_var t: first subterm of the form Suc ?v, function part searched before argument part; *)
   182     (* yields SOME (that subterm, the variable ?v) or NONE.  Abstractions are not entered. *)
   183     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   184       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   185       | find_var _ = NONE;
   186     fun find_thm th =  (* atomize th and report the first Suc ?v found in the atomized proposition *)
   187       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   188       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   189   in
   190     case get_first find_thm thms of
   191       NONE => thms  (* no Suc ?v left anywhere: done *)
   192     | SOME ((th, th'), (Sucv, v)) =>
   193         let
   194           val cert = cterm_of (Thm.theory_of_thm th);
   195           val th'' = ObjectLogic.rulify (Thm.implies_elim  (* Suc_clause applied to th', turned back into rule format *)
   196             (Conv.fconv_rule (Thm.beta_conversion true)
   197               (Drule.instantiate' []
   198                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,  (* predicate: abstracts v and the occurrences of Suc v *)
   199                    abstract_over (Sucv,
   200                      HOLogic.dest_Trueprop (prop_of th')))))),
   201                  SOME (cert v)] @{thm Suc_clause}))
   202             (Thm.forall_intr (cert v) th'))  (* first premise of Suc_clause discharged by th' generalised over v *)
   203         in
   204           remove_suc_clause thy (map (fn th''' =>  (* replace every theorem with the proposition of th, then repeat *)
   205             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   206         end
   207   end;
   208 
   209 fun clause_suc_preproc thy ths =  (* run remove_suc_clause only on membership clauses that mention Suc *)
   210   let
   211     fun elem_of t = fst (HOLogic.dest_mem (HOLogic.dest_Trueprop t));  (* element side of a membership statement *)
   212     val has_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   213     fun suc_in_clause th = exists has_suc (map_filter (try elem_of) (concl_of th :: prems_of th));
   214   in
   215     if forall (can (elem_of o concl_of)) ths andalso exists suc_in_clause ths
   216     then remove_suc_clause thy ths else ths
   217   end;
   218 
   219 fun lift f thy eqns1 =  (* adapt the theorem-list transformation f to pairs whose first component is a meta-equation *)
   220   let
   221     val eqns2 = burrow_fst Drule.zero_var_indexes_list eqns1;  (* normalise variable indexes of the theorem components *)
   222     val thms3 = try (map fst  (* NONE if any step of the pipeline raises an exception *)
   223       #> map (fn thm => thm RS @{thm meta_eq_to_obj_eq})  (* meta-level to object-level equations *)
   224       #> f thy
   225       #> map (fn thm => thm RS @{thm eq_reflection})  (* and back to meta-level equations *)
   226       #> map (Conv.fconv_rule Drule.beta_eta_conversion)) eqns2;
   227     val thms4 = Option.map Drule.zero_var_indexes_list thms3;
   228   in case thms4
   229    of NONE => NONE
   230     | SOME thms4 => if Thm.eq_thms (map fst eqns2, thms4)  (* result identical to the input: report no change *)
   231         then NONE else SOME (map (apfst (AxClass.overload thy) o Code_Unit.mk_eqn thy) thms4)
   232   end
   233 
   234 in
   235 
   236   Codegen.add_preprocessor eqn_suc_preproc  (* both preprocessors are registered directly with Codegen ... *)
   237   #> Codegen.add_preprocessor clause_suc_preproc
   238   #> Code.add_functrans ("eqn_Suc", lift eqn_suc_preproc)  (* ... and, wrapped by lift, as named function transformers with Code *)
   239   #> Code.add_functrans ("clause_Suc", lift clause_suc_preproc)
   240 
   241 end;
   242 *}
   243 (*>*)
   244 
   245 
   246 subsection {* Target language setup *}
   247 
   248 text {*
   249   For ML, we map @{typ nat} to target language integers, where we
   250   assert that values are always non-negative.
   251 *}
   252 
   253 code_type nat
   254   (SML "IntInf.int")
   255   (OCaml "Big'_int.big'_int")
   256 (* types_code maps nat to int and attaches a function rebuilding terms from values and a generator for testing. *)
   257 types_code
   258   nat ("int")
   259 attach (term_of) {*
   260 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   261 *}
   262 attach (test) {*
   263 fun gen_nat i =
   264   let val n = random_range 0 i  (* NOTE(review): presumably a random integer between 0 and i - confirm the bounds of random_range *)
   265   in (n, fn () => term_of_nat n) end;  (* the value together with a delayed computation of its term *)
   266 *}
   267 
   268 text {*
   269   For Haskell we define our own @{typ nat} type.  The reason
   270   is that we have to distinguish type class instances
   271   for @{typ nat} and @{typ int}.
   272 *}
   273 
   274 code_include Haskell "Nat" {*
   275 newtype Nat = Nat Integer deriving (Show, Eq);
   276 -- The operations below return a non-negative wrapped Integer whenever their operands are non-negative.
   277 instance Num Nat where {
   278   fromInteger k = Nat (if k >= 0 then k else 0);  -- negative arguments are clamped to 0
   279   Nat n + Nat m = Nat (n + m);
   280   Nat n - Nat m = fromInteger (n - m);  -- truncated subtraction, through the clamp in fromInteger
   281   Nat n * Nat m = Nat (n * m);
   282   abs n = n;
   283   signum (Nat n) = Nat (signum n);  -- 0 for zero, so that abs x * signum x == x holds
   284   negate n = error "negate Nat";
   285 };
   286 -- Ordering is the ordering of the wrapped integers.
   287 instance Ord Nat where {
   288   Nat n <= Nat m = n <= m;
   289   Nat n < Nat m = n < m;
   290 };
   291 
   292 instance Real Nat where {
   293   toRational (Nat n) = toRational n;
   294 };
   295 
   296 instance Enum Nat where {
   297   toEnum k = fromInteger (toEnum k);  -- goes through fromInteger, hence clamps negative arguments
   298   fromEnum (Nat n) = fromEnum n;
   299 };
   300 -- divMod is defined as quotRem; on non-negative operands the two agree.
   301 instance Integral Nat where {
   302   toInteger (Nat n) = n;
   303   divMod n m = quotRem n m;
   304   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   305 };
   306 *}
   307 
   308 code_reserved Haskell Nat
   309 (* nat is printed as the Haskell type Nat from the included module. *)
   310 code_type nat
   311   (Haskell "Nat")
   312 (* NOTE(review): presumably no eq instance is generated because the newtype Nat already derives Eq - confirm. *)
   313 code_instance nat :: eq
   314   (Haskell -)
   315 
   316 text {*
   317   Natural numerals.
   318 *}
   319 
   320 lemma [code inline, symmetric, code post]:  (* nat applied to a numeral is inlined to number_of_nat; the reverse is used in postprocessing *)
   321   "nat (number_of i) = number_nat_inst.number_of_nat i"
   322   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   323   by (simp add: number_nat_inst.number_of_nat)
   324 (* Numeral printing for number_of_nat is set up for all three targets.  NOTE(review): the meaning of the flags true false is not visible here - confirm. *)
   325 setup {*
   326   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   327     true false) ["SML", "OCaml", "Haskell"]
   328 *}
   329 
   330 text {*
   331   Since natural numbers are implemented
   332   using integers in ML, the coercion function @{const "of_nat"} of type
   333   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   334   For the @{const "nat"} function for converting an integer to a natural
   335   number, we give a specific implementation using an ML function that
   336   returns its input value, provided that it is non-negative, and otherwise
   337   returns @{text "0"}.
   338 *}
   339 
   340 definition
   341   int :: "nat \<Rightarrow> int"
   342 where
   343   [code del]: "int = of_nat"  (* auxiliary name for of_nat; the definition itself is not a code equation *)
   344 (* Code equations for int and nat applied to numerals: a negative numeral yields 0. *)
   345 lemma int_code' [code]:
   346   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   347   unfolding int_nat_number_of [folded int_def] ..
   348 
   349 lemma nat_code' [code]:
   350   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   351   unfolding nat_number_of_def number_of_is_id neg_def by simp
   352 (* of_nat is unfolded to int before code generation; the reversed equation is a postprocessing rule. *)
   353 lemma of_nat_int [code unfold]:
   354   "of_nat = int" by (simp add: int_def)
   355 declare of_nat_int [symmetric, code post]
   356 
   357 code_const int  (* in SML and OCaml, int is printed as its bare argument *)
   358   (SML "_")
   359   (OCaml "_")
   360 (* consts_code: int is again its bare argument; nat refers to the attached ML function, which maps negative input to 0. *)
   361 consts_code
   362   int ("(_)")
   363   nat ("\<module>nat")
   364 attach {*
   365 fun nat i = if i < 0 then 0 else i;
   366 *}
   367 (* In SML and OCaml, nat is the maximum of 0 and its argument. *)
   368 code_const nat
   369   (SML "IntInf.max/ (/0,/ _)")
   370   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   371 
   372 text {* For Haskell, things are slightly different again. *}
   373 
   374 code_const int and nat
   375   (Haskell "toInteger" and "fromInteger")
   376 (* fromInteger at type Nat clamps negative arguments to 0, see the Num instance in the included Haskell module. *)
   377 text {* Conversion from and to indices. *}
   378 (* NOTE(review): IntInf.toInt and int_of_big_int presumably fail outside the machine integer range - confirm against the target libraries. *)
   379 code_const index_of_nat
   380   (SML "IntInf.toInt")
   381   (OCaml "Big'_int.int'_of'_big'_int")
   382   (Haskell "fromEnum")
   383 
   384 code_const nat_of_index
   385   (SML "IntInf.fromInt")
   386   (OCaml "Big'_int.big'_int'_of'_int")
   387   (Haskell "toEnum")
   388 
   389 text {* Using target language arithmetic operations whenever appropriate *}
   390 
   391 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   392   (SML "IntInf.+ ((_), (_))")
   393   (OCaml "Big'_int.add'_big'_int")
   394   (Haskell infixl 6 "+")
   395 
   396 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   397   (SML "IntInf.* ((_), (_))")
   398   (OCaml "Big'_int.mult'_big'_int")
   399   (Haskell infixl 7 "*")
   400 (* divmod_aux is only reached with a non-zero divisor, see the code equation for Divides.divmod. *)
   401 code_const divmod_aux
   402   (SML "IntInf.divMod/ ((_),/ (_))")
   403   (OCaml "Big'_int.quomod'_big'_int")
   404   (Haskell "divMod")
   405 (* Equality and the two orderings on nat use the corresponding target-language operations on integers. *)
   406 code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   407   (SML "!((_ : IntInf.int) = _)")
   408   (OCaml "Big'_int.eq'_big'_int")
   409   (Haskell infixl 4 "==")
   410 
   411 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   412   (SML "IntInf.<= ((_), (_))")
   413   (OCaml "Big'_int.le'_big'_int")
   414   (Haskell infix 4 "<=")
   415 
   416 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   417   (SML "IntInf.< ((_), (_))")
   418   (OCaml "Big'_int.lt'_big'_int")
   419   (Haskell infix 4 "<")
   420 
   421 consts_code  (* output templates for nat constants: 0, 1, Suc as addition of 1, and infix +, *, <=, < *)
   422   "0::nat"                     ("0")
   423   "1::nat"                     ("1")
   424   Suc                          ("(_ +/ 1)")
   425   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
   426   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
   427   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   428   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   429 
   430 
   431 text {* Evaluation *}
   432 
   433 lemma [code, code del]:
   434   "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..
   435 (* term_of at type nat is given directly in SML: HOLogic.mk_number applied to the nat type. *)
   436 code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
   437   (SML "HOLogic.mk'_number/ HOLogic.natT")
   438 
   439 
   440 text {* Module names *}
   441 
   442 code_modulename SML
   443   Nat Integer
   444   Divides Integer
   445   Ring_and_Field Integer
   446   Efficient_Nat Integer
   447 (* The same four theories are assigned to the module Integer for OCaml and Haskell as well. *)
   448 code_modulename OCaml
   449   Nat Integer
   450   Divides Integer
   451   Ring_and_Field Integer
   452   Efficient_Nat Integer
   453 
   454 code_modulename Haskell
   455   Nat Integer
   456   Divides Integer
   457   Ring_and_Field Integer
   458   Efficient_Nat Integer
   459 (* The auxiliary constant int defined in this theory is hidden from here on. *)
   460 hide const int
   461 
   462 end