src/HOL/Library/Efficient_Nat.thy
author haftmann
Mon, 23 Mar 2009 08:14:24 +0100
changeset 30663 0b6aff7451b2
parent 29874 ffed4bd4bfad
child 31090 3be41b271023
permissions -rw-r--r--
Main is (Complex_Main) base entry point in library theories
     1 (*  Title:      HOL/Library/Efficient_Nat.thy
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
     3 *)
     4 
     5 header {* Implementation of natural numbers by target-language integers *}
     6 
     7 theory Efficient_Nat
     8 imports Code_Index Code_Integer Main
     9 begin
    10 
    11 text {*
    12   When generating code for functions on natural numbers, the
    13   canonical representation using @{term "0::nat"} and
    14   @{term "Suc"} is unsuitable for computations involving large
    15   numbers.  The efficiency of the generated code can be improved
    16   drastically by implementing natural numbers by target-language
    17   integers.  To enable this, simply import this theory.
    18 *}
    19 
    20 subsection {* Basic arithmetic *}
    21 
    22 text {*
    23   Most standard arithmetic functions on natural numbers are implemented
    24   using their counterparts on the integers:
    25 *}
    26 
    (* Declare the numeral constant number_nat_inst.number_of_nat as the
       code-datatype constructor of nat.  Suc is then not a constructor for
       code generation, which is why Suc patterns are rewritten away by the
       preprocessors set up further down in this theory. *)
    27 code_datatype number_nat_inst.number_of_nat
    28 
    (* 0 and 1 on nat are expressed as numerals, both as code equations and
       as inlining rules.  The reversed equations are declared as code
       postprocessing rules. *)
    lemma zero_nat_code [code, code inline]:
      shows "0 = (Numeral0 :: nat)"
      by simp
    declare zero_nat_code [symmetric, code post]

    lemma one_nat_code [code, code inline]:
      shows "1 = (Numeral1 :: nat)"
      by simp
    declare one_nat_code [symmetric, code post]
    38 
    (* Suc becomes addition of 1. *)
    lemma Suc_code [code]:
      shows "Suc n = n + 1"
      by simp

    (* Addition, subtraction and multiplication on nat are computed on int
       (via of_nat) and converted back with nat; for subtraction the
       conversion back truncates negative differences to 0. *)
    lemma plus_nat_code [code]:
      shows "n + m = nat (of_nat n + of_nat m)"
      by simp

    lemma minus_nat_code [code]:
      shows "n - m = nat (of_nat n - of_nat m)"
      by simp

    lemma times_nat_code [code]:
      shows "n * m = nat (of_nat n * of_nat m)"
      unfolding of_nat_mult [symmetric]
      by simp
    54 
    55 text {* Specialized @{term "op div \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} 
    56   and @{term "op mod \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"} operations. *}
    57 
    (* divmod_aux is Divides.divmod under a separate name; its defining
       equation is excluded from code generation ([code del]) and replaced by
       divmod_aux_code below (and by a direct target mapping further down). *)
    definition
      divmod_aux :: "nat \<Rightarrow> nat \<Rightarrow> nat \<times> nat"
    where
      [code del]: "divmod_aux = Divides.divmod"

    (* The zero divisor is handled here, following HOL's convention
       (n div 0, n mod 0) = (0, n); divmod_aux is only reached for m ~= 0.
       NOTE(review): presumably because the target divMod operations need not
       follow this convention for a zero divisor. *)
    lemma [code]:
      shows "Divides.divmod n m = (if m = 0 then (0, n) else divmod_aux n m)"
      unfolding divmod_aux_def divmod_div_mod
      by simp

    lemma divmod_aux_code [code]:
      shows "divmod_aux n m = (nat (of_nat n div of_nat m), nat (of_nat n mod of_nat m))"
      unfolding divmod_aux_def divmod_div_mod zdiv_int [symmetric] zmod_int [symmetric]
      by simp
    68 
    (* Equality and the orderings on nat are decided on int via of_nat. *)
    lemma eq_nat_code [code]:
      shows "eq_class.eq n m \<longleftrightarrow> eq_class.eq (of_nat n \<Colon> int) (of_nat m)"
      by (simp add: eq)

    (* Reflexivity of equality, registered for normalization by evaluation. *)
    lemma eq_nat_refl [code nbe]:
      shows "eq_class.eq (n::nat) n \<longleftrightarrow> True"
      by (rule HOL.eq_refl)

    lemma less_eq_nat_code [code]:
      shows "n \<le> m \<longleftrightarrow> (of_nat n \<Colon> int) \<le> of_nat m"
      by simp

    lemma less_nat_code [code]:
      shows "n < m \<longleftrightarrow> (of_nat n \<Colon> int) < of_nat m"
      by simp
    84 
    85 subsection {* Case analysis *}
    86 
    87 text {*
    88   Case analysis on natural numbers is rephrased using a conditional
    89   expression:
    90 *}
    91 
    (* nat_case is replaced by a test on zero with predecessor n - 1, both as
       a code equation and as an unfolding rule. *)
    lemma [code, code unfold]:
      shows "nat_case = (\<lambda>f g n. if n = 0 then f else g (n - 1))"
      by (auto simp add: expand_fun_eq dest!: gr0_implies_Suc)
    95 
    96 
    97 subsection {* Preprocessors *}
    98 
    99 text {*
   100   In contrast to @{term "Suc n"}, the term @{term "n + (1::nat)"} is no longer
   101   a constructor term. Therefore, all occurrences of this term in a position
   102   where a pattern is expected (i.e.\ on the left-hand side of a recursion
   103   equation or in the arguments of an inductive relation in an introduction
   104   rule) must be eliminated.
   105   This can be accomplished by applying the following transformation rules:
   106 *}
   107 
   (* Object-logic form: if f is specified by an equation for Suc n and one
      for 0, then f n equals a conditional on n = 0.  eqn_suc_preproc' passes
      this theorem to gen_remove_suc, which instantiates it positionally with
      Drule.instantiate'; keep the occurrence order of its schematic
      variables (f, h, g, n) unchanged. *)
   108 lemma Suc_if_eq': "(\<And>n. f (Suc n) = h n) \<Longrightarrow> f 0 = g \<Longrightarrow>
   109   f n = (if n = 0 then g else h (n - 1))"
   110   by (cases n) simp_all
   111 
   (* Meta-equality (==) form of Suc_if_eq', derived from it by converting
      between meta and object equality.  eqn_suc_preproc uses it for code
      equations; like Suc_if_eq' it is instantiated positionally by
      gen_remove_suc, so its statement shape must be kept. *)
   112 lemma Suc_if_eq: "(\<And>n. f (Suc n) \<equiv> h n) \<Longrightarrow> f 0 \<equiv> g \<Longrightarrow>
   113   f n \<equiv> if n = 0 then g else h (n - 1)"
   114   by (rule eq_reflection, rule Suc_if_eq')
   115     (rule meta_eq_to_obj_eq, assumption,
   116      rule meta_eq_to_obj_eq, assumption)
   117 
   (* A property stated for all pairs (n, Suc n) holds for (n - 1, n) when n
      is nonzero.  remove_suc_clause uses it to rewrite introduction rules
      and instantiates it positionally (P, then n) via Drule.instantiate'. *)
   118 lemma Suc_clause: "(\<And>n. P n (Suc n)) \<Longrightarrow> n \<noteq> 0 \<Longrightarrow> P (n - 1) n"
   119   by (cases n) simp_all
   120 
   121 text {*
   122   The rules above are built into a preprocessor that is plugged into
   123   the code generator. Since the preprocessor for introduction rules
   124   does not know anything about modes, some of the modes that worked
   125   for the canonical representation of natural numbers may no longer work.
   126 *}
   127 
   128 (*<*)
   129 setup {*
   130 let
   131 
   (* One step of Suc elimination on a list of equations.
      Searches the left-hand sides of thms for a subterm Suc ?v (?v a
      schematic variable).  For such an occurrence in theorem th it
      instantiates the rule Suc_if_eq with f := (lhs with that occurrence
      abstracted) and h := (rhs of th), discharges the rule's first premise
      with th, and then tries to discharge the remaining "f 0" premise by
      resolving each theorem of thms against it.
      Suc_if_eq      -- Suc_if_eq or Suc_if_eq' (instantiated positionally)
      dest_judgement -- strips the judgement wrapper from a cprop
                        (I for meta equations, snd o Thm.dest_comb for Trueprop)
      Returns SOME of the new theorem list (th and every theorem used for the
      "f 0" premise removed, the combined conditional equations appended), or
      NONE if no occurrence leads to a successful resolution. *)
   132 fun gen_remove_suc Suc_if_eq dest_judgement thy thms =
   133   let
         (* Fresh variable name (based on "n") not used by any Var in thms. *)
   134     val vname = Name.variant (map fst
   135       (fold (Term.add_var_names o Thm.full_prop_of) thms [])) "n";
   136     val cv = cterm_of thy (Var ((vname, 0), HOLogic.natT));
         (* Left and right operand of the (stripped) equation of th. *)
   137     fun lhs_of th = snd (Thm.dest_comb
   138       (fst (Thm.dest_comb (dest_judgement (cprop_of th)))));
   139     fun rhs_of th = snd (Thm.dest_comb (dest_judgement (cprop_of th)));
         (* For every subterm Suc ?v of ct, the pair (ct with that subterm
            replaced by cv, ?v).  Only applications are traversed; abstractions
            and other terms yield no occurrences. *)
   140     fun find_vars ct = (case term_of ct of
   141         (Const (@{const_name Suc}, _) $ Var _) => [(cv, snd (Thm.dest_comb ct))]
   142       | _ $ _ =>
   143         let val (ct1, ct2) = Thm.dest_comb ct
   144         in 
   145           map (apfst (fn ct => Thm.capply ct ct2)) (find_vars ct1) @
   146           map (apfst (Thm.capply ct1)) (find_vars ct2)
   147         end
   148       | _ => []);
         (* All candidate (theorem, (context, variable)) occurrences. *)
   149     val eqs = maps
   150       (fn th => map (pair th) (find_vars (lhs_of th))) thms;
         (* Instantiate Suc_if_eq for one occurrence: type := type of the lhs,
            f := %cv. ct, h := %cv'. rhs, g left schematic (NONE), n := cv';
            beta-normalize and discharge the first premise with th
            generalized over cv'. *)
   151     fun mk_thms (th, (ct, cv')) =
   152       let
   153         val th' =
   154           Thm.implies_elim
   155            (Conv.fconv_rule (Thm.beta_conversion true)
   156              (Drule.instantiate'
   157                [SOME (ctyp_of_term ct)] [SOME (Thm.cabs cv ct),
   158                  SOME (Thm.cabs cv' (rhs_of th)), NONE, SOME cv']
   159                Suc_if_eq)) (Thm.forall_intr cv' th)
   160       in
             (* Try every theorem as the remaining premise (th'' RS th');
                Variable.trade presumably fixes th'''s schematic variables for
                the resolution and generalizes afterwards -- TODO confirm.
                Resolution failures (THM) just skip that theorem. *)
   161         case map_filter (fn th'' =>
   162             SOME (th'', singleton
   163               (Variable.trade (K (fn [th'''] => [th''' RS th'])) (Variable.thm_context th'')) th'')
   164           handle THM _ => NONE) thms of
   165             [] => NONE
   166           | thps =>
   167               let val (ths1, ths2) = split_list thps
   168               in SOME (subtract Thm.eq_thm (th :: ths1) thms @ ths2) end
   169       end
       (* First occurrence for which mk_thms succeeds. *)
   170   in get_first mk_thms eqs end;
   171 
   (* Shared driver for Suc elimination on equations.  It only acts when
      dest_lhs can extract a left-hand side from every theorem and at least
      one of those left-hand sides mentions Suc.  In that case gen_remove_suc
      is applied repeatedly via perhaps_loop; otherwise the result is NONE. *)
   fun gen_eqn_suc_preproc Suc_if_eq dest_judgement dest_lhs thy thms =
     let
       fun lhs th = dest_lhs (prop_of th);
       fun is_suc (c, _) = c = @{const_name Suc};
       fun lhs_mentions_suc th = exists_Const is_suc (lhs th);
     in
       if not (forall (can lhs) thms) then NONE
       else if exists lhs_mentions_suc thms
       then perhaps_loop (gen_remove_suc Suc_if_eq dest_judgement thy) thms
       else NONE
     end;
   181 
   (* Function transformer for the Code generator.  It drops the second
      component of each (theorem, _) pair, eliminates Suc patterns using the
      meta-level rule Suc_if_eq, and re-wraps the resulting theorems with
      Code_Unit.mk_eqn.  NONE means nothing was changed. *)
   fun eqn_suc_preproc thy eqns =
     let
       val thms = map fst eqns;
       val result =
         gen_eqn_suc_preproc @{thm Suc_if_eq} I (fst o Logic.dest_equals) thy thms;
     in
       Option.map (map (Code_Unit.mk_eqn thy)) result
     end;
   186 
   (* Preprocessor for Codegen on object-level equations Trueprop (lhs = rhs),
      using the rule Suc_if_eq'.  The theorems are returned unchanged when
      no Suc elimination applies. *)
   fun eqn_suc_preproc' thy thms =
     let
       val strip_trueprop = snd o Thm.dest_comb;
       val lhs_of_eq = fst o HOLogic.dest_eq o HOLogic.dest_Trueprop;
     in
       case gen_eqn_suc_preproc @{thm Suc_if_eq'} strip_trueprop lhs_of_eq thy thms of
         NONE => thms
       | SOME thms' => thms'
     end;
   190 
   (* Suc elimination for introduction rules, as used by clause_suc_preproc.
      Picks the first theorem whose atomized proposition contains a subterm
      Suc ?v (?v a schematic variable) and restates it via Suc_clause.  In
      the result, the occurrences of Suc ?v become ?v, the remaining
      occurrences of ?v become ?v - 1, and a premise ?v ~= 0 is added.  The
      rewritten theorem replaces every theorem with the same proposition.
      The process repeats until no theorem contains such a subterm.
      thy is only threaded through the recursion. *)
   fun remove_suc_clause thy thms =
     let
       (* Leftmost-first search for a subterm Suc ?v; returns it together
          with ?v.  Abstractions are not entered. *)
       fun find_var (t as Const (@{const_name Suc}, _) $ (v as Var _)) = SOME (t, v)
         | find_var (t $ u) = (case find_var t of NONE => find_var u | x => x)
         | find_var _ = NONE;
       (* Atomize th (Suc_clause is stated at object level) and look for a
          Suc pattern in the atomized proposition. *)
       fun find_thm th =
         let val th' = Conv.fconv_rule ObjectLogic.atomize th
         in Option.map (pair (th, th')) (find_var (prop_of th')) end
     in
       case get_first find_thm thms of
         NONE => thms
       | SOME ((th, th'), (Sucv, v)) =>
           let
             val cert = cterm_of (Thm.theory_of_thm th);
             (* P := %v x. prop(th') with Suc v abstracted to x, n := v.  The
                first premise of Suc_clause is discharged by th' generalized
                over v; rulify returns to the meta-level rule format. *)
             val th'' = ObjectLogic.rulify (Thm.implies_elim
               (Conv.fconv_rule (Thm.beta_conversion true)
                 (Drule.instantiate' []
                   [SOME (cert (lambda v (Abs ("x", HOLogic.natT,
                      abstract_over (Sucv,
                        HOLogic.dest_Trueprop (prop_of th')))))),
                    SOME (cert v)] @{thm Suc_clause}))
               (Thm.forall_intr (cert v) th'))
           in
             remove_suc_clause thy (map (fn th''' =>
               if (op = o pairself prop_of) (th''', th) then th'' else th''') thms)
           end
     end;
   220 
   (* Preprocessor for introduction rules.  It runs remove_suc_clause when
      every conclusion is a membership statement (x : S) and some conclusion
      or premise of that form mentions Suc.  Otherwise the rules are returned
      unchanged. *)
   fun clause_suc_preproc thy ths =
     let
       val dest = fst o HOLogic.dest_mem o HOLogic.dest_Trueprop;
       fun is_suc (c, _) = c = @{const_name Suc};
       fun mentions_suc th =
         exists (exists_Const is_suc)
           (map_filter (try dest) (concl_of th :: prems_of th));
       val all_membership = forall (can (dest o concl_of)) ths;
     in
       if all_membership andalso exists mentions_suc ths
       then remove_suc_clause thy ths
       else ths
     end;
   230 in
   231 
   (* Register the Suc-eliminating preprocessors.  The object-level equation
      and introduction-rule variants go to Codegen.  The meta-equation variant
      goes to Code as the function transformer "eqn_Suc". *)
   232   Codegen.add_preprocessor eqn_suc_preproc'
   233   #> Codegen.add_preprocessor clause_suc_preproc
   234   #> Code.add_functrans ("eqn_Suc", eqn_suc_preproc)
   235 
   236 end;
   237 *}
   238 (*>*)
   239 
   240 
   241 subsection {* Target language setup *}
   242 
   243 text {*
   244   For SML and OCaml, we map @{typ nat} to target-language integers,
   245   relying on the invariant that values are always non-negative.
   246 *}
   247 
   (* nat is represented by the arbitrary-precision integer type of OCaml
      (Big_int) and SML (IntInf); values of this type are assumed to be
      non-negative. *)
   code_type nat
     (OCaml "Big'_int.big'_int")
     (SML "IntInf.int")
   251 
   (* Setup for the Codegen-based generator (types_code): nat is represented
      by the ML type int.  The attached term_of_nat turns a value back into a
      HOL numeral term of type nat.  The attached test generator gen_nat draws
      n with random_range 0 i and pairs it with a thunk producing its term. *)
   252 types_code
   253   nat ("int")
   254 attach (term_of) {*
   255 val term_of_nat = HOLogic.mk_number HOLogic.natT;
   256 *}
   257 attach (test) {*
   258 fun gen_nat i =
   259   let val n = random_range 0 i
   260   in (n, fn () => term_of_nat n) end;
   261 *}
   262 
   263 text {*
   264   For Haskell we define our own @{typ nat} type.  The reason
   265   is that we have to distinguish type class instances
   266   for @{typ nat} and @{typ int}.
   267 *}
   268 
   (* Haskell module Nat: nat is a newtype over Integer, so its type class
      instances are distinct from those of int.  fromInteger, and hence
      subtraction, truncate negative results to 0.  divMod coincides with
      quotRem because both operands are non-negative. *)
   code_include Haskell "Nat" {*
newtype Nat = Nat Integer deriving (Show, Eq);

instance Num Nat where {
  fromInteger k = Nat (if k >= 0 then k else 0);
  Nat n + Nat m = Nat (n + m);
  Nat n - Nat m = fromInteger (n - m);
  Nat n * Nat m = Nat (n * m);
  abs n = n;
  -- Nat values are never negative, so the sign is 0 or 1; in particular
  -- signum 0 is 0, as the Prelude convention for Num requires.
  signum (Nat n) = Nat (signum n);
  negate n = error "negate Nat";
};

instance Ord Nat where {
  Nat n <= Nat m = n <= m;
  Nat n < Nat m = n < m;
};

instance Real Nat where {
  toRational (Nat n) = toRational n;
};

instance Enum Nat where {
  toEnum k = fromInteger (toEnum k);
  fromEnum (Nat n) = fromEnum n;
};

instance Integral Nat where {
  toInteger (Nat n) = n;
  divMod n m = quotRem n m;
  quotRem (Nat n) (Nat m) = (Nat k, Nat l) where (k, l) = quotRem n m;
};
*}
   302 
   (* Haskell: reserve the name Nat for the included module.  nat maps to
      Nat.Nat, and its eq instance is not generated (the included module
      derives Eq). *)
   code_reserved Haskell Nat
   code_type nat (Haskell "Nat.Nat")
   code_instance nat :: eq (Haskell -)
   310 
   311 text {*
   312   Natural numerals.
   313 *}
   314 
   (* Preprocessing (code inline) replaces nat applied to an integer numeral
      by the nat numeral number_of_nat.  The attribute list applies symmetric
      before code post, so the reversed equation is the one registered for
      postprocessing. *)
   315 lemma [code inline, symmetric, code post]:
   316   "nat (number_of i) = number_nat_inst.number_of_nat i"
   317   -- {* this interacts as desired with @{thm nat_number_of_def} *}
   318   by (simp add: number_nat_inst.number_of_nat)
   319 
   (* Register numeral printing for number_of_nat in SML, OCaml and Haskell,
      in that order.  NOTE(review): the meaning of the two boolean flags is
      defined by Numeral.add_code; they are passed through unchanged. *)
   setup {*
     let
       val add_nat_numeral =
         Numeral.add_code @{const_name number_nat_inst.number_of_nat} true false;
     in
       add_nat_numeral "SML"
       #> add_nat_numeral "OCaml"
       #> add_nat_numeral "Haskell"
     end
   *}
   324 
   325 text {*
   326   Since natural numbers are implemented
   327   using integers in ML, the coercion function @{const "of_nat"} of type
   328   @{typ "nat \<Rightarrow> int"} is simply implemented by the identity function.
   329   For the @{const "nat"} function for converting an integer to a natural
   330   number, we give a specific implementation using an ML function that
   331   returns its input value, provided that it is non-negative, and otherwise
   332   returns @{text "0"}.
   333 *}
   334 
   (* Auxiliary coercion int, defined as of_nat but without a code equation of
      its own.  of_nat is unfolded to int during preprocessing, and int is
      folded back to of_nat during postprocessing. *)
   definition int :: "nat \<Rightarrow> int" where
     [code del]: "int = of_nat"

   (* The coercions applied to numerals: negative numerals give 0. *)
   lemma int_code' [code]:
     shows "int (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
     unfolding int_nat_number_of [folded int_def]
     ..

   lemma nat_code' [code]:
     shows "nat (number_of l) = (if neg (number_of l \<Colon> int) then 0 else number_of l)"
     unfolding nat_number_of_def number_of_is_id neg_def
     by simp

   lemma of_nat_int [code unfold]:
     shows "of_nat = int"
     by (simp add: int_def)
   declare of_nat_int [symmetric, code post]
   351 
   (* SML and OCaml: int (nat => int) is the identity, and nat (int => nat)
      clamps negative integers to 0 via max with zero. *)
   code_const int
     (OCaml "_")
     (SML "_")

   code_const nat
     (OCaml "Big'_int.max'_big'_int/ Big'_int.zero'_big'_int")
     (SML "IntInf.max/ (/0,/ _)")

   (* The same coercions for the Codegen-based generator; the attached ML
      function belongs to the nat entry. *)
   consts_code
     int ("(_)")
     nat ("\<module>nat")
   attach {*
fun nat i = if i < 0 then 0 else i;
*}

   text {* For Haskell, things are slightly different again. *}

   code_const int and nat
     (Haskell "toInteger" and "fromInteger")
   371 
   372 text {* Conversion from and to indices. *}
   373 
   (* Conversions between nat (arbitrary precision) and index (the targets'
      machine integers: Int, OCaml int, SML int).  NOTE(review): the
      nat-to-index direction may fail for values outside the machine-integer
      range. *)
   code_const Code_Index.of_nat
     (Haskell "fromEnum")
     (OCaml "Big'_int.int'_of'_big'_int")
     (SML "IntInf.toInt")

   code_const Code_Index.nat_of
     (Haskell "toEnum")
     (OCaml "Big'_int.big'_int'_of'_int")
     (SML "IntInf.fromInt")
   383 
   384 text {* Using target language arithmetic operations whenever appropriate *}
   385 
   (* Arithmetic and comparisons on nat map directly to the targets' integer
      operations.  Subtraction has no mapping here and goes through its code
      equation on int. *)
   code_const "op + \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
     (Haskell infixl 6 "+")
     (OCaml "Big'_int.add'_big'_int")
     (SML "IntInf.+ ((_), (_))")

   code_const "op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat"
     (Haskell infixl 7 "*")
     (OCaml "Big'_int.mult'_big'_int")
     (SML "IntInf.* ((_), (_))")

   (* Only reached with a nonzero divisor (see the divmod code equation). *)
   code_const divmod_aux
     (Haskell "divMod")
     (OCaml "Big'_int.quomod'_big'_int")
     (SML "IntInf.divMod/ ((_),/ (_))")

   code_const "eq_class.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
     (Haskell infixl 4 "==")
     (OCaml "Big'_int.eq'_big'_int")
     (SML "!((_ : IntInf.int) = _)")

   code_const "op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
     (Haskell infix 4 "<=")
     (OCaml "Big'_int.le'_big'_int")
     (SML "IntInf.<= ((_), (_))")

   code_const "op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool"
     (Haskell infix 4 "<")
     (OCaml "Big'_int.lt'_big'_int")
     (SML "IntInf.< ((_), (_))")
   415 
   (* Codegen-based generator: constants and operations on nat are printed as
      plain ML int expressions (Suc as addition of 1). *)
   consts_code
     "0::nat"                     ("0")
     "1::nat"                     ("1")
     "op + \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ +/ _)")
     "op * \<Colon>  nat \<Rightarrow> nat \<Rightarrow> nat"   ("(_ */ _)")
     Suc                          ("(_ +/ 1)")
     "op < \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ </ _)")
     "op \<le> \<Colon>  nat \<Rightarrow> nat \<Rightarrow> bool"  ("(_ <=/ _)")
   424 
   425 
   426 text {* Evaluation *}
   427 
   (* Evaluation: the trivial equation is first declared as a code equation
      for term_of at nat and then deleted again; term_of at nat is instead
      mapped to HOLogic.mk_number HOLogic.natT in SML. *)
   lemma [code, code del]:
     shows "(Code_Eval.term_of \<Colon> nat \<Rightarrow> term) = Code_Eval.term_of"
     ..

   code_const "Code_Eval.term_of \<Colon> nat \<Rightarrow> term"
     (SML "HOLogic.mk'_number/ HOLogic.natT")
   433 
   434 
   435 text {* Module names *}
   436 
   (* In every target, code from Divides, Efficient_Nat, Nat and
      Ring_and_Field is placed into the module Integer. *)
   code_modulename SML
     Divides Integer
     Efficient_Nat Integer
     Nat Integer
     Ring_and_Field Integer

   code_modulename OCaml
     Divides Integer
     Efficient_Nat Integer
     Nat Integer
     Ring_and_Field Integer

   code_modulename Haskell
     Divides Integer
     Efficient_Nat Integer
     Nat Integer
     Ring_and_Field Integer
   454 
   (* Hide the auxiliary constant int introduced in this theory.
      NOTE(review): presumably so that the name does not leak into theories
      importing this one -- confirm. *)
   455 hide const int
   456 
   457 end