src/HOL/Tools/Nitpick/nitpick_util.ML
author blanchet
Wed, 20 Jan 2010 10:38:06 +0100
changeset 34923 c4f04bee79f3
parent 34121 c4628a1dcf75
child 34969 7b8c366e34a2
permissions -rw-r--r--
some work on Nitpick's support for quotient types;
quotient types are not yet in Isabelle, so for now I hardcoded "IntEx.my_int"
     1 (*  Title:      HOL/Tools/Nitpick/nitpick_util.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2008, 2009
     4 
     5 General-purpose functions used by the Nitpick modules.
     6 *)
     7 
     8 signature NITPICK_UTIL =
     9 sig
    10   type styp = string * typ
    11   datatype polarity = Pos | Neg | Neut
    12 
    13   exception ARG of string * string
    14   exception BAD of string * string
    15   exception TOO_SMALL of string * string
    16   exception TOO_LARGE of string * string
    17   exception NOT_SUPPORTED of string
    18   exception SAME of unit
    19 
    20   val nitpick_prefix : string
    21   val curry3 : ('a * 'b * 'c -> 'd) -> 'a -> 'b -> 'c -> 'd
    22   val pairf : ('a -> 'b) -> ('a -> 'c) -> 'a -> 'b * 'c
    23   val int_for_bool : bool -> int
    24   val nat_minus : int -> int -> int
    25   val reasonable_power : int -> int -> int
    26   val exact_log : int -> int -> int
    27   val exact_root : int -> int -> int
    28   val offset_list : int list -> int list
    29   val index_seq : int -> int -> int list
    30   val filter_indices : int list -> 'a list -> 'a list
    31   val filter_out_indices : int list -> 'a list -> 'a list
    32   val fold1 : ('a -> 'a -> 'a) -> 'a list -> 'a
    33   val replicate_list : int -> 'a list -> 'a list
    34   val n_fold_cartesian_product : 'a list list -> 'a list list
    35   val all_distinct_unordered_pairs_of : ''a list -> (''a * ''a) list
    36   val nth_combination : (int * int) list -> int -> int list
    37   val all_combinations : (int * int) list -> int list list
    38   val all_permutations : 'a list -> 'a list list
    39   val batch_list : int -> 'a list -> 'a list list
    40   val chunk_list_unevenly : int list -> 'a list -> 'a list list
    41   val map3 : ('a -> 'b -> 'c -> 'd) -> 'a list -> 'b list -> 'c list -> 'd list
    42   val double_lookup :
    43     ('a * 'a -> bool) -> ('a option * 'b) list -> 'a -> 'b option
    44   val triple_lookup :
    45     (''a * ''a -> bool) -> (''a option * 'b) list -> ''a -> 'b option
    46   val is_substring_of : string -> string -> bool
    47   val serial_commas : string -> string list -> string list
    48   val plural_s : int -> string
    49   val plural_s_for_list : 'a list -> string
    50   val flip_polarity : polarity -> polarity
    51   val prop_T : typ
    52   val bool_T : typ
    53   val nat_T : typ
    54   val int_T : typ
    55   val nat_subscript : int -> string
    56   val time_limit : Time.time option -> ('a -> 'b) -> 'a -> 'b
    57   val DETERM_TIMEOUT : Time.time option -> tactic -> tactic
    58   val setmp_show_all_types : ('a -> 'b) -> 'a -> 'b
    59   val indent_size : int
    60   val pstrs : string -> Pretty.T list
    61   val plain_string_from_yxml : string -> string
    62   val maybe_quote : string -> string
    63 end
    64 
    65 structure Nitpick_Util : NITPICK_UTIL =
    66 struct
    67 
    68 type styp = string * typ
    69 
    70 datatype polarity = Pos | Neg | Neut
    71 
    72 exception ARG of string * string
    73 exception BAD of string * string
    74 exception TOO_SMALL of string * string
    75 exception TOO_LARGE of string * string
    76 exception NOT_SUPPORTED of string
    77 exception SAME of unit
    78 
    79 val nitpick_prefix = "Nitpick."
    80 
    81 (* ('a * 'b * 'c -> 'd) -> 'a -> 'b -> 'c -> 'd *)
    82 fun curry3 f = fn x => fn y => fn z => f (x, y, z)
    83 
    84 (* ('a -> 'b) -> ('a -> 'c) -> 'a -> 'b * 'c *)
    85 fun pairf f g x = (f x, g x)
    86 
    87 (* bool -> int *)
    88 fun int_for_bool b = if b then 1 else 0
    89 (* int -> int -> int *)
    90 fun nat_minus i j = if i > j then i - j else 0
    91 
    92 val max_exponent = 16384
    93 
    94 (* int -> int -> int *)
    95 fun reasonable_power a 0 = 1
    96   | reasonable_power a 1 = a
    97   | reasonable_power 0 _ = 0
    98   | reasonable_power 1 _ = 1
    99   | reasonable_power a b =
   100     if b < 0 then
   101       raise ARG ("Nitpick_Util.reasonable_power",
   102                  "negative exponent (" ^ signed_string_of_int b ^ ")")
   103     else if b > max_exponent then
   104       raise TOO_LARGE ("Nitpick_Util.reasonable_power",
   105                        "too large exponent (" ^ signed_string_of_int b ^ ")")
   106     else
   107       let val c = reasonable_power a (b div 2) in
   108         c * c * reasonable_power a (b mod 2)
   109       end
   110 
   111 (* int -> int -> int *)
   112 fun exact_log m n =
   113   let
   114     val r = Math.ln (Real.fromInt n) / Math.ln (Real.fromInt m) |> Real.round
   115   in
   116     if reasonable_power m r = n then
   117       r
   118     else
   119       raise ARG ("Nitpick_Util.exact_log",
   120                  commas (map signed_string_of_int [m, n]))
   121   end
   122 
   123 (* int -> int -> int *)
   124 fun exact_root m n =
   125   let val r = Math.pow (Real.fromInt n, 1.0 / (Real.fromInt m)) |> Real.round in
   126     if reasonable_power r m = n then
   127       r
   128     else
   129       raise ARG ("Nitpick_Util.exact_root",
   130                  commas (map signed_string_of_int [m, n]))
   131   end
   132 
   133 (* ('a -> 'a -> 'a) -> 'a list -> 'a *)
   134 fun fold1 f = foldl1 (uncurry f)
   135 
   136 (* int -> 'a list -> 'a list *)
   137 fun replicate_list 0 _ = []
   138   | replicate_list n xs = xs @ replicate_list (n - 1) xs
   139 
   140 (* int list -> int list *)
   141 fun offset_list ns = rev (tl (fold (fn x => fn xs => (x + hd xs) :: xs) ns [0]))
   142 (* int -> int -> int list *)
   143 fun index_seq j0 n = if j0 < 0 then j0 downto j0 - n + 1 else j0 upto j0 + n - 1
   144 
   145 (* int list -> 'a list -> 'a list *)
   146 fun filter_indices js xs =
   147   let
   148     (* int -> int list -> 'a list -> 'a list *)
   149     fun aux _ [] _ = []
   150       | aux i (j :: js) (x :: xs) =
   151         if i = j then x :: aux (i + 1) js xs else aux (i + 1) (j :: js) xs
   152       | aux _ _ _ = raise ARG ("Nitpick_Util.filter_indices",
   153                                "indices unordered or out of range")
   154   in aux 0 js xs end
   155 fun filter_out_indices js xs =
   156   let
   157     (* int -> int list -> 'a list -> 'a list *)
   158     fun aux _ [] xs = xs
   159       | aux i (j :: js) (x :: xs) =
   160         if i = j then aux (i + 1) js xs else x :: aux (i + 1) (j :: js) xs
   161       | aux _ _ _ = raise ARG ("Nitpick_Util.filter_out_indices",
   162                                "indices unordered or out of range")
   163   in aux 0 js xs end
   164 
   165 (* 'a list -> 'a list list -> 'a list list *)
   166 fun cartesian_product [] _ = []
   167   | cartesian_product (x :: xs) yss =
   168     map (cons x) yss @ cartesian_product xs yss
   169 (* 'a list list -> 'a list list *)
   170 fun n_fold_cartesian_product xss = fold_rev cartesian_product xss [[]]
   171 (* ''a list -> (''a * ''a) list *)
   172 fun all_distinct_unordered_pairs_of [] = []
   173   | all_distinct_unordered_pairs_of (x :: xs) =
   174     map (pair x) xs @ all_distinct_unordered_pairs_of xs
   175 
   176 (* (int * int) list -> int -> int list *)
   177 val nth_combination =
   178   let
   179     (* (int * int) list -> int -> int list * int *)
   180     fun aux [] n = ([], n)
   181       | aux ((k, j0) :: xs) n =
   182         let val (js, n) = aux xs n in ((n mod k) + j0 :: js, n div k) end
   183   in fst oo aux end
   184 
   185 (* (int * int) list -> int list list *)
   186 val all_combinations = n_fold_cartesian_product o map (uncurry index_seq o swap)
   187 
   188 (* 'a list -> 'a list list *)
   189 fun all_permutations [] = [[]]
   190   | all_permutations xs =
   191     maps (fn j => map (cons (nth xs j)) (all_permutations (nth_drop j xs)))
   192          (index_seq 0 (length xs))
   193 
   194 (* int -> 'a list -> 'a list list *)
   195 fun batch_list _ [] = []
   196   | batch_list k xs =
   197     if length xs <= k then [xs]
   198     else List.take (xs, k) :: batch_list k (List.drop (xs, k))
   199 
   200 (* int list -> 'a list -> 'a list list *)
   201 fun chunk_list_unevenly _ [] = []
   202   | chunk_list_unevenly [] ys = map single ys
   203   | chunk_list_unevenly (k :: ks) ys =
   204     let val (ys1, ys2) = chop k ys in ys1 :: chunk_list_unevenly ks ys2 end
   205 
   206 (* ('a -> 'b -> 'c -> 'd) -> 'a list -> 'b list -> 'c list -> 'd list *)
   207 fun map3 _ [] [] [] = []
   208   | map3 f (x :: xs) (y :: ys) (z :: zs) = f x y z :: map3 f xs ys zs
   209   | map3 _ _ _ _ = raise UnequalLengths
   210 
   211 (* ('a * 'a -> bool) -> ('a option * 'b) list -> 'a -> 'b option *)
   212 fun double_lookup eq ps key =
   213   case AList.lookup (fn (SOME x, SOME y) => eq (x, y) | _ => false) ps
   214                     (SOME key) of
   215     SOME z => SOME z
   216   | NONE => ps |> find_first (is_none o fst) |> Option.map snd
   217 (* (''a * ''a -> bool) -> (''a option * 'b) list -> ''a -> 'b option *)
   218 fun triple_lookup eq ps key =
   219   case AList.lookup (op =) ps (SOME key) of
   220     SOME z => SOME z
   221   | NONE => double_lookup eq ps key
   222 
   223 (* string -> string -> bool *)
   224 fun is_substring_of needle stack =
   225   not (Substring.isEmpty (snd (Substring.position needle
   226                                                   (Substring.full stack))))
   227 
   228 (* string -> string list -> string list *)
   229 fun serial_commas _ [] = ["??"]
   230   | serial_commas _ [s] = [s]
   231   | serial_commas conj [s1, s2] = [s1, conj, s2]
   232   | serial_commas conj [s1, s2, s3] = [s1 ^ ",", s2 ^ ",", conj, s3]
   233   | serial_commas conj (s :: ss) = s ^ "," :: serial_commas conj ss
   234 
   235 (* int -> string *)
   236 fun plural_s n = if n = 1 then "" else "s"
   237 (* 'a list -> string *)
   238 fun plural_s_for_list xs = plural_s (length xs)
   239 
   240 (* polarity -> polarity *)
   241 fun flip_polarity Pos = Neg
   242   | flip_polarity Neg = Pos
   243   | flip_polarity Neut = Neut
   244 
   245 val prop_T = @{typ prop}
   246 val bool_T = @{typ bool}
   247 val nat_T = @{typ nat}
   248 val int_T = @{typ int}
   249 
   250 (* string -> string *)
   251 val subscript = implode o map (prefix "\<^isub>") o explode
   252 (* int -> string *)
   253 fun nat_subscript n =
   254   (* cheap trick to ensure proper alphanumeric ordering for one- and two-digit
   255      numbers *)
   256   if n <= 9 then "\<^bsub>" ^ signed_string_of_int n ^ "\<^esub>"
   257   else subscript (string_of_int n)
   258 
   259 (* Time.time option -> ('a -> 'b) -> 'a -> 'b *)
   260 fun time_limit NONE = I
   261   | time_limit (SOME delay) = TimeLimit.timeLimit delay
   262 
   263 (* Time.time option -> tactic -> tactic *)
   264 fun DETERM_TIMEOUT delay tac st =
   265   Seq.of_list (the_list (time_limit delay (fn () => SINGLE tac st) ()))
   266 
   267 (* ('a -> 'b) -> 'a -> 'b *)
   268 fun setmp_show_all_types f =
   269   setmp_CRITICAL show_all_types
   270                  (! show_types orelse ! show_sorts orelse ! show_all_types) f
   271 
   272 val indent_size = 2
   273 
   274 (* string -> Pretty.T list *)
   275 val pstrs = Pretty.breaks o map Pretty.str o space_explode " "
   276 
   277 (* XML.tree -> string *)
   278 fun plain_string_from_xml_tree t =
   279   Buffer.empty |> XML.add_content t |> Buffer.content
   280 (* string -> string *)
   281 val plain_string_from_yxml = plain_string_from_xml_tree o YXML.parse
   282 
   283 (* string -> bool *)
   284 val is_long_identifier = forall Syntax.is_identifier o space_explode "."
   285 (* string -> string *)
   286 fun maybe_quote y =
   287   let val s = plain_string_from_yxml y in
   288     y |> (not (is_long_identifier (perhaps (try (unprefix "'")) s)) orelse
   289           OuterKeyword.is_keyword s) ? quote
   290   end
   291 
   292 end;