src/HOL/Library/Efficient_Nat.thy
author haftmann
Mon, 16 Feb 2009 13:38:09 +0100
changeset 29869 a2594b5c945a
parent 29752 9e94b7078fa5
child 29874 ffed4bd4bfad
permissions -rw-r--r--
dropped clause_suc_preproc for generic code generator
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Index Code_Integer
     9 begin
    10 
    11 text {*
    12   When generating code for functions on natural numbers, the
    13   canonical representation using @{term "0::nat"} and
    14   @{term "Suc"} is unsuitable for computations involving large
    15   numbers.  The efficiency of the generated code can be improved
    16   drastically by implementing natural numbers by target-language
    17   integers.  To do this, just import this theory.
    18 *}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
    27 code_datatype number_nat_inst.number_of_nat
    28 
    29 lemma zero_nat_code [code, code inline]:
    30   "0 = (Numeral0 :: nat)"
    31   by simp
    32 lemmas [code post] = zero_nat_code [symmetric]
    33 
    34 lemma one_nat_code [code, code inline]:
    35   "1 = (Numeral1 :: nat)"
    36   by simp
    37 lemmas [code post] = one_nat_code [symmetric]
    38 
    39 lemma Suc_code [code]:
    40   "Suc n = n + 1"
    41   by simp
    42 
    43 lemma plus_nat_code [code]:
    44   "n + m = nat (of_nat n + of_nat m)"
    45   by simp
    46 
    47 lemma minus_nat_code [code]:
    48   "n - m = nat (of_nat n - of_nat m)"
    49   by simp
    50 
    51 lemma times_nat_code [code]:
    52   "n * m = nat (of_nat n * of_nat m)"
    53   unfolding of_nat_mult [symmetric] by simp
    54 
    55 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    56   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    57 
    58 definition divmod_aux ::  "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat" where
    59   [code del]: "divmod_aux = Divides.divmod"
    60 
    61 lemma [code]:
    62   "Divides.divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
    63   unfolding divmod_aux_def divmod_div_mod by simp
    64 
    65 lemma divmod_aux_code [code]:
    66   "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
    67   unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric] by simp
    68 
    69 lemma eq_nat_code [code]:
    70   "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
    71   by (simp add: eq)
    72 
    73 lemma eq_nat_refl [code nbe]:
    74   "eq_class.eq (n::nat) n \<longleftrightarrow> True"
    75   by (rule HOL.eq_refl)
    76 
    77 lemma less_eq_nat_code [code]:
    78   "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
    79   by simp
    80 
    81 lemma less_nat_code [code]:
    82   "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
    83   by simp
    84 
    85 subsection {* Case analysis *}
    86 
    87 text {*
    88   Case analysis on natural numbers is rephrased using a conditional
    89   expression:
    90 *}
    91 
    92 lemma [code, code unfold]:
    93   "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
    94   by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    95 
    96 
    97 subsection {* Preprocessors *}
    98 
    99 text {*
   100   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   101   a constructor term. Therefore, all occurrences of this term in a position
   102   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   103   equation or in the arguments of an inductive relation in an introduction
   104   rule) must be eliminated.
   105   This can be accomplished by applying the following transformation rules:
   106 *}
   107 
   108 lemma Suc_if_eq: "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
   109   f n = (if n = 0 then g else h (n - 1))"
   110   by (cases n) simp_all
   111 
   112 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   113   by (cases n) simp_all
   114 
   115 text {*
   116   The rules above are built into preprocessors that are plugged into
   117   the code generator. Since the preprocessor for introduction rules
   118   does not know anything about modes, some of the modes that worked
   119   for the canonical representation of natural numbers may no longer work.
   120 *}
   121 
   122 (*<*)
   123 setup {*
   124 let
   125 
   126 fun remove_suc thy thms =
   127   let
   128     val vname = Name.variant (map fst
   129       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
   130     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
   131     fun lhs_of th = snd (Thm.dest_comb
   132       (fst (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))))));
   133     fun rhs_of th = snd (Thm.dest_comb (snd (Thm.dest_comb (cprop_of th))));
   134     fun find_vars ct = (case term_of ct of
   135         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   136       | _ $ _ =>
   137         let val (ct1, ct2) = Thm.dest_comb ct
   138         in 
   139           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   140           map (apfst (Thm.capply ct1)) (find_vars ct2)
   141         end
   142       | _ => []);
   143     val eqs = maps
   144       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
   145     fun mk_thms (th, (ct, cv')) =
   146       let
   147         val th' =
   148           Thm.implies_elim
   149            (Conv.fconv_rule (Thm.beta_conversion true)
   150              (Drule.instantiate'
   151                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   152                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   153                @{thm Suc_if_eq})) (Thm.forall_intr cv' th)
   154       in
   155         case map_filter (fn th'' =>
   156             SOME (th'', singleton
   157               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   158           handle THM _ => NONE) thms of
   159             [] => NONE
   160           | thps =>
   161               let val (ths1, ths2) = split_list thps
   162               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   163       end
   164   in case get_first mk_thms eqs of
   165       NONE => thms
   166     | SOME x => remove_suc thy x
   167   end;
   168 
   169 fun eqn_suc_preproc thy ths =
   170   let
   171     val dest = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of;
   172     val contains_suc = exists_Const (fn (c, _) => c = @{const_name Suc});
   173   in
   174     if forall (can dest) ths andalso
   175       exists (contains_suc o dest) ths
   176     then remove_suc thy ths else ths
   177   end;
   178 
   179 fun remove_suc_clause thy thms =
   180   let
   181     val vname = Name.variant (map fst
   182       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "x";
   183     fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
   184       | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
   185       | find_var _ = NONE;
   186     fun find_thm th =
   187       let val th' = Conv.fconv_rule ObjectLogic.atomize th
   188       in Option.map (pair (th, th')) (find_var (prop_of th')) end
   189   in
   190     case get_first find_thm thms of
   191       NONE => thms
   192     | SOME ((th, th'), (Sucv, v)) =>
   193         let
   194           val cert = cterm_of (Thm.theory_of_thm th);
   195           val th'' = ObjectLogic.rulify (Thm.implies_elim
   196             (Conv.fconv_rule (Thm.beta_conversion true)
   197               (Drule.instantiate' []
   198                 [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
   199                    abstract_over (Sucv,
   200                      HOLogic.dest_Trueprop (prop_of th')))))),
   201                  SOME (cert v)] @{thm Suc_clause}))
   202             (Thm.forall_intr (cert v) th'))
   203         in
   204           remove_suc_clause thy (map (fn th''' =>
   205             if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
   206         end
   207   end;
   208 
   209 fun clause_suc_preproc thy ths =
   210   let
   211     val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop
   212   in
   213     if forall (can (dest o concl_of)) ths andalso
   214       exists (fn th => exists (exists_Const (fn (c, _) => c = @{const_name Suc}))
   215         (map_filter (try dest) (concl_of th :: prems_of th))) ths
   216     then remove_suc_clause thy ths else ths
   217   end;
   218 
   219 fun lift f thy eqns1 =
   220   let
   221     val eqns2 = burrow_fst Drule.zero_var_indexes_list eqns1;
   222     val thms3 = try (map fst
   223       #> map (fn thm => thm RS @{thm meta_eq_to_obj_eq})
   224       #> f thy
   225       #> map (fn thm => thm RS @{thm eq_reflection})
   226       #> map (Conv.fconv_rule Drule.beta_eta_conversion)) eqns2;
   227     val thms4 = Option.map Drule.zero_var_indexes_list thms3;
   228   in case thms4
   229    of NONE => NONE
   230     | SOME thms4 => if Thm.eq_thms (map fst eqns2, thms4)
   231         then NONE else SOME (map (apfst (AxClass.overload thy) o Code_Unit.mk_eqn thy) thms4)
   232   end
   233 
   234 in
   235 
   236   Codegen.add_preprocessor eqn_suc_preproc
   237   #> Codegen.add_preprocessor clause_suc_preproc
   238   #> Code.add_functrans ("eqn_Suc", lift eqn_suc_preproc)
   239 
   240 end;
   241 *}
   242 (*>*)
   243 
   244 
   245 subsection {* Target language setup *}
   246 
   247 text {*
   248   For ML, we map @{typ nat} to target language integers, where we
   249   assert that values are always non-negative.
   250 *}
   251 
   252 code_type nat
   253   (SML "IntInf.int")
   254   (OCaml "Big'_int.big'_int")
   255 
   256 types_code
   257   nat ("int")
   258 attach (term_of) {*
   259 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   260 *}
   261 attach (test) {*
   262 fun gen_nat i =
   263   let val n = random_range 0 i
   264   in (n, fn () => term_of_nat n) end;
   265 *}
   266 
   267 text {*
   268   For Haskell we define our own @{typ nat} type.  The reason
   269   is that we have to distinguish type class instances
   270   for @{typ nat} and @{typ int}.
   271 *}
   272 
   273 code_include Haskell "Nat" {*
   274 newtype Nat = Nat Integer deriving (Show, Eq);
   275 
   276 instance Num Nat where {
   277   fromInteger k = Nat (if k >= 0 then k else 0);
   278   Nat n + Nat m = Nat (n + m);
   279   Nat n - Nat m = fromInteger (n - m);
   280   Nat n * Nat m = Nat (n * m);
   281   abs n = n;
   282   signum _ = 1;
   283   negate n = error "negate Nat";
   284 };
   285 
   286 instance Ord Nat where {
   287   Nat n <= Nat m = n <= m;
   288   Nat n < Nat m = n < m;
   289 };
   290 
   291 instance Real Nat where {
   292   toRational (Nat n) = toRational n;
   293 };
   294 
   295 instance Enum Nat where {
   296   toEnum k = fromInteger (toEnum k);
   297   fromEnum (Nat n) = fromEnum n;
   298 };
   299 
   300 instance Integral Nat where {
   301   toInteger (Nat n) = n;
   302   divMod n m = quotRem n m;
   303   quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
   304 };
   305 *}
   306 
   307 code_reserved Haskell Nat
   308 
   309 code_type nat
   310   (Haskell "Nat.Nat")
   311 
   312 code_instance nat :: eq
   313   (Haskell -)
   314 
   315 text {*
   316   Natural numerals.
   317 *}
   318 
   319 lemma [code inline, symmetric, code post]:
   320   "nat (number_of i) = number_nat_inst.number_of_nat i"
   321   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   322   by (simp add: number_nat_inst.number_of_nat)
   323 
   324 setup {*
   325   fold (Numeral.add_code @{const_name number_nat_inst.number_of_nat}
   326     true false) ["SML", "OCaml", "Haskell"]
   327 *}
   328 
   329 text {*
   330   Since natural numbers are implemented
   331   using integers in ML, the coercion function @{const "of_nat"} of type
   332   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   333   For the @{const "nat"} function for converting an integer to a natural
   334   number, we give a specific implementation using an ML function that
   335   returns its input value, provided that it is non-negative, and otherwise
   336   returns @{text "0"}.
   337 *}
   338 
   339 definition
   340   int :: "nat \<Rightarrow> int"
   341 where
   342   [code del]: "int = of_nat"
   343 
   344 lemma int_code' [code]:
   345   "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   346   unfolding int_nat_number_of [folded int_def] ..
   347 
   348 lemma nat_code' [code]:
   349   "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
   350   unfolding nat_number_of_def number_of_is_id neg_def by simp
   351 
   352 lemma of_nat_int [code unfold]:
   353   "of_nat = int" by (simp add: int_def)
   354 declare of_nat_int [symmetric, code post]
   355 
   356 code_const int
   357   (SML "_")
   358   (OCaml "_")
   359 
   360 consts_code
   361   int ("(_)")
   362   nat ("\<module>nat")
   363 attach {*
   364 fun nat i = if i < 0 then 0 else i;
   365 *}
   366 
   367 code_const nat
   368   (SML "IntInf.max/ (/0,/ _)")
   369   (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
   370 
   371 text {* For Haskell, things are slightly different again. *}
   372 
   373 code_const int and nat
   374   (Haskell "toInteger" and "fromInteger")
   375 
   376 text {* Conversion from and to indices. *}
   377 
   378 code_const Code_Index.of_nat
   379   (SML "IntInf.toInt")
   380   (OCaml "Big'_int.int'_of'_big'_int")
   381   (Haskell "fromEnum")
   382 
   383 code_const Code_Index.nat_of
   384   (SML "IntInf.fromInt")
   385   (OCaml "Big'_int.big'_int'_of'_int")
   386   (Haskell "toEnum")
   387 
   388 text {* Using target language arithmetic operations whenever appropriate *}
   389 
   390 code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   391   (SML "IntInf.+ ((_), (_))")
   392   (OCaml "Big'_int.add'_big'_int")
   393   (Haskell infixl 6 "+")
   394 
   395 code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
   396   (SML "IntInf.* ((_), (_))")
   397   (OCaml "Big'_int.mult'_big'_int")
   398   (Haskell infixl 7 "*")
   399 
   400 code_const divmod_aux
   401   (SML "IntInf.divMod/ ((_),/ (_))")
   402   (OCaml "Big'_int.quomod'_big'_int")
   403   (Haskell "divMod")
   404 
   405 code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   406   (SML "!((_ : IntInf.int) = _)")
   407   (OCaml "Big'_int.eq'_big'_int")
   408   (Haskell infixl 4 "==")
   409 
   410 code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   411   (SML "IntInf.<= ((_), (_))")
   412   (OCaml "Big'_int.le'_big'_int")
   413   (Haskell infix 4 "<=")
   414 
   415 code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
   416   (SML "IntInf.< ((_), (_))")
   417   (OCaml "Big'_int.lt'_big'_int")
   418   (Haskell infix 4 "<")
   419 
   420 consts_code
   421   "0::nat"                     ("0")
   422   "1::nat"                     ("1")
   423   Suc                          ("(_ +/ 1)")
   424   "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
   425   "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
   426   "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   427   "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
   428 
   429 
   430 text {* Evaluation *}
   431 
   432 lemma [code, code del]:
   433   "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of" ..
   434 
   435 code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
   436   (SML "HOLogic.mk'_number/ HOLogic.natT")
   437 
   438 
   439 text {* Module names *}
   440 
   441 code_modulename SML
   442   Nat Integer
   443   Divides Integer
   444   Ring_and_Field Integer
   445   Efficient_Nat Integer
   446 
   447 code_modulename OCaml
   448   Nat Integer
   449   Divides Integer
   450   Ring_and_Field Integer
   451   Efficient_Nat Integer
   452 
   453 code_modulename Haskell
   454   Nat Integer
   455   Divides Integer
   456   Ring_and_Field Integer
   457   Efficient_Nat Integer
   458 
   459 hide const int
   460 
   461 end