src/HOL/SMT/Tools/smt_translate.ML
changeset 36890 8e55aa1306c5
parent 36885 48cf03469dc6
equal deleted inserted replaced
36889:6d1ecdb81ff0 36890:8e55aa1306c5
     1 (*  Title:      HOL/SMT/Tools/smt_translate.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3 
       
     4 Translate theorems into an SMT intermediate format and serialize them.
       
     5 *)
       
     6 
       
     7 signature SMT_TRANSLATE =
       
     8 sig
       
     9   (* intermediate term structure *)
       
    10   datatype squant = SForall | SExists
       
    11   datatype 'a spattern = SPat of 'a list | SNoPat of 'a list
       
    12   datatype sterm =
       
    13     SVar of int |
       
    14     SApp of string * sterm list |
       
    15     SLet of string * sterm * sterm |
       
    16     SQua of squant * string list * sterm spattern list * sterm
       
    17 
       
    18   (* configuration options *)
       
    19   type prefixes = {sort_prefix: string, func_prefix: string}
       
    20   type strict = {
       
    21     is_builtin_conn: string * typ -> bool,
       
    22     is_builtin_pred: string * typ -> bool,
       
    23     is_builtin_distinct: bool}
       
    24   type builtins = {
       
    25     builtin_typ: typ -> string option,
       
    26     builtin_num: typ -> int -> string option,
       
    27     builtin_fun: string * typ -> term list -> (string * term list) option }
       
    28   datatype smt_theory = Integer | Real | Bitvector
       
    29   type sign = {
       
    30     theories: smt_theory list,
       
    31     sorts: string list,
       
    32     funcs: (string * (string list * string)) list }
       
    33   type config = {
       
    34     prefixes: prefixes,
       
    35     strict: strict option,
       
    36     builtins: builtins,
       
    37     serialize: sign -> sterm list -> string }
       
    38   type recon = {
       
    39     typs: typ Symtab.table,
       
    40     terms: term Symtab.table,
       
    41     unfolds: thm list,
       
    42     assms: thm list option }
       
    43 
       
    44   val translate: config -> Proof.context -> thm list -> string * recon
       
    45 end
       
    46 
       
    47 structure SMT_Translate: SMT_TRANSLATE =
       
    48 struct
       
    49 
       
    50 (* intermediate term structure *)
       
    51 
       
    52 datatype squant = SForall | SExists
       
    53 
       
    54 datatype 'a spattern = SPat of 'a list | SNoPat of 'a list
       
    55 
       
    56 datatype sterm =
       
    57   SVar of int |
       
    58   SApp of string * sterm list |
       
    59   SLet of string * sterm * sterm |
       
    60   SQua of squant * string list * sterm spattern list * sterm
       
    61 
       
    62 
       
    63 
       
    64 (* configuration options *)
       
    65 
       
    66 type prefixes = {sort_prefix: string, func_prefix: string}
       
    67 
       
    68 type strict = {
       
    69   is_builtin_conn: string * typ -> bool,
       
    70   is_builtin_pred: string * typ -> bool,
       
    71   is_builtin_distinct: bool}
       
    72 
       
    73 type builtins = {
       
    74   builtin_typ: typ -> string option,
       
    75   builtin_num: typ -> int -> string option,
       
    76   builtin_fun: string * typ -> term list -> (string * term list) option }
       
    77 
       
    78 datatype smt_theory = Integer | Real | Bitvector
       
    79 
       
    80 type sign = {
       
    81   theories: smt_theory list,
       
    82   sorts: string list,
       
    83   funcs: (string * (string list * string)) list }
       
    84 
       
    85 type config = {
       
    86   prefixes: prefixes,
       
    87   strict: strict option,
       
    88   builtins: builtins,
       
    89   serialize: sign -> sterm list -> string }
       
    90 
       
    91 type recon = {
       
    92   typs: typ Symtab.table,
       
    93   terms: term Symtab.table,
       
    94   unfolds: thm list,
       
    95   assms: thm list option }
       
    96 
       
    97 
       
    98 
       
    99 (* utility functions *)
       
   100 
       
   101 val dest_funT =
       
   102   let
       
   103     fun dest Ts 0 T = (rev Ts, T)
       
   104       | dest Ts i (Type ("fun", [T, U])) = dest (T::Ts) (i-1) U
       
   105       | dest _ _ T = raise TYPE ("dest_funT", [T], [])
       
   106   in dest [] end
       
   107 
       
   108 val quantifier = (fn
       
   109     @{const_name All} => SOME SForall
       
   110   | @{const_name Ex} => SOME SExists
       
   111   | _ => NONE)
       
   112 
       
   113 fun group_quant qname Ts (t as Const (q, _) $ Abs (_, T, u)) =
       
   114       if q = qname then group_quant qname (T :: Ts) u else (Ts, t)
       
   115   | group_quant _ Ts t = (Ts, t)
       
   116 
       
   117 fun dest_pat ts (Const (@{const_name pat}, _) $ t) = SPat (rev (t :: ts))
       
   118   | dest_pat ts (Const (@{const_name nopat}, _) $ t) = SNoPat (rev (t :: ts))
       
   119   | dest_pat ts (Const (@{const_name andpat}, _) $ p $ t) = dest_pat (t::ts) p
       
   120   | dest_pat _ t = raise TERM ("dest_pat", [t])
       
   121 
       
   122 fun dest_trigger (@{term trigger} $ tl $ t) =
       
   123       (map (dest_pat []) (HOLogic.dest_list tl), t)
       
   124   | dest_trigger t = ([], t)
       
   125 
       
   126 fun dest_quant qn T t = quantifier qn |> Option.map (fn q =>
       
   127   let
       
   128     val (Ts, u) = group_quant qn [T] t
       
   129     val (ps, b) = dest_trigger u
       
   130   in (q, rev Ts, ps, b) end)
       
   131 
       
   132 fun fold_map_pat f (SPat ts) = fold_map f ts #>> SPat
       
   133   | fold_map_pat f (SNoPat ts) = fold_map f ts #>> SNoPat
       
   134 
       
   135 fun prop_of thm = HOLogic.dest_Trueprop (Thm.prop_of thm)
       
   136 
       
   137 
       
   138 
       
   139 (* enforce a strict separation between formulas and terms *)
       
   140 
       
   141 val term_eq_rewr = @{lemma "x term_eq y == x = y" by (simp add: term_eq_def)}
       
   142 
       
   143 val term_bool = @{lemma "~(True term_eq False)" by (simp add: term_eq_def)}
       
   144 val term_bool' = Simplifier.rewrite_rule [term_eq_rewr] term_bool
       
   145 
       
   146 
       
   147 val needs_rewrite = Thm.prop_of #> Term.exists_subterm (fn
       
   148     Const (@{const_name Let}, _) => true
       
   149   | @{term "op = :: bool => _"} $ _ $ @{term True} => true
       
   150   | Const (@{const_name If}, _) $ _ $ @{term True} $ @{term False} => true
       
   151   | _ => false)
       
   152 
       
   153 val rewrite_rules = [
       
   154   Let_def,
       
   155   @{lemma "P = True == P" by (rule eq_reflection) simp},
       
   156   @{lemma "if P then True else False == P" by (rule eq_reflection) simp}]
       
   157 
       
   158 fun rewrite ctxt = Simplifier.full_rewrite
       
   159   (Simplifier.context ctxt empty_ss addsimps rewrite_rules)
       
   160 
       
   161 fun normalize ctxt thm =
       
   162   if needs_rewrite thm then Conv.fconv_rule (rewrite ctxt) thm else thm
       
   163 
       
   164 val unfold_rules = term_eq_rewr :: rewrite_rules
       
   165 
       
   166 
       
   167 val revert_types =
       
   168   let
       
   169     fun revert @{typ prop} = @{typ bool}
       
   170       | revert (Type (n, Ts)) = Type (n, map revert Ts)
       
   171       | revert T = T
       
   172   in Term.map_types revert end
       
   173 
       
   174 
       
   175 fun strictify {is_builtin_conn, is_builtin_pred, is_builtin_distinct} ctxt =
       
   176   let
       
   177 
       
   178     fun is_builtin_conn' (@{const_name True}, _) = false
       
   179       | is_builtin_conn' (@{const_name False}, _) = false
       
   180       | is_builtin_conn' c = is_builtin_conn c
       
   181 
       
   182     val propT = @{typ prop} and boolT = @{typ bool}
       
   183     val as_propT = (fn @{typ bool} => propT | T => T)
       
   184     fun mapTs f g = Term.strip_type #> (fn (Ts, T) => map f Ts ---> g T)
       
   185     fun conn (n, T) = (n, mapTs as_propT as_propT T)
       
   186     fun pred (n, T) = (n, mapTs I as_propT T)
       
   187 
       
   188     val term_eq = @{term "op = :: bool => _"} |> Term.dest_Const |> pred
       
   189     fun as_term t = Const term_eq $ t $ @{term True}
       
   190 
       
   191     val if_term = Const (@{const_name If}, [propT, boolT, boolT] ---> boolT)
       
   192     fun wrap_in_if t = if_term $ t $ @{term True} $ @{term False}
       
   193 
       
   194     fun in_list T f t = HOLogic.mk_list T (map f (HOLogic.dest_list t))
       
   195 
       
   196     fun in_term t =
       
   197       (case Term.strip_comb t of
       
   198         (c as Const (@{const_name If}, _), [t1, t2, t3]) =>
       
   199           c $ in_form t1 $ in_term t2 $ in_term t3
       
   200       | (h as Const c, ts) =>
       
   201           if is_builtin_conn' (conn c) orelse is_builtin_pred (pred c)
       
   202           then wrap_in_if (in_form t)
       
   203           else Term.list_comb (h, map in_term ts)
       
   204       | (h as Free _, ts) => Term.list_comb (h, map in_term ts)
       
   205       | _ => t)
       
   206 
       
   207     and in_pat ((c as Const (@{const_name pat}, _)) $ t) = c $ in_term t
       
   208       | in_pat ((c as Const (@{const_name nopat}, _)) $ t) = c $ in_term t
       
   209       | in_pat ((c as Const (@{const_name andpat}, _)) $ p $ t) =
       
   210           c $ in_pat p $ in_term t
       
   211       | in_pat t = raise TERM ("in_pat", [t])
       
   212 
       
   213     and in_pats p = in_list @{typ pattern} in_pat p
       
   214 
       
   215     and in_trig ((c as @{term trigger}) $ p $ t) = c $ in_pats p $ in_form t
       
   216       | in_trig t = in_form t
       
   217 
       
   218     and in_form t =
       
   219       (case Term.strip_comb t of
       
   220         (q as Const (qn, _), [Abs (n, T, t')]) =>
       
   221           if is_some (quantifier qn) then q $ Abs (n, T, in_trig t')
       
   222           else as_term (in_term t)
       
   223       | (Const (c as (@{const_name distinct}, T)), [t']) =>
       
   224           if is_builtin_distinct then Const (pred c) $ in_list T in_term t'
       
   225           else as_term (in_term t)
       
   226       | (Const c, ts) =>
       
   227           if is_builtin_conn (conn c)
       
   228           then Term.list_comb (Const (conn c), map in_form ts)
       
   229           else if is_builtin_pred (pred c)
       
   230           then Term.list_comb (Const (pred c), map in_term ts)
       
   231           else as_term (in_term t)
       
   232       | _ => as_term (in_term t))
       
   233   in
       
   234     map (normalize ctxt) #> (fn thms => ((unfold_rules, term_bool' :: thms),
       
   235     map (in_form o prop_of) (term_bool :: thms)))
       
   236   end
       
   237 
       
   238 
       
   239 
       
   240 (* translation from Isabelle terms into SMT intermediate terms *)
       
   241 
       
(* initial translation context, a tuple of: next sort index, table from types
   to sort names, next function index, table from terms to function names
   with their argument/result sort names, theories collected so far; both
   indices start at 1 *)
val empty_context = (1, Typtab.empty, 1, Termtab.empty, [])
       
   243 
       
   244 fun make_sign (_, typs, _, terms, thys) = {
       
   245   theories = thys,
       
   246   sorts = Typtab.fold (cons o snd) typs [],
       
   247   funcs = Termtab.fold (cons o snd) terms [] }
       
   248 
       
   249 fun make_recon (unfolds, assms) (_, typs, _, terms, _) = {
       
   250   typs = Symtab.make (map swap (Typtab.dest typs)),
       
   251   terms = Symtab.make (map (fn (t, (n, _)) => (n, t)) (Termtab.dest terms)),
       
   252   unfolds = unfolds,
       
   253   assms = SOME assms }
       
   254 
       
   255 fun string_of_index pre i = pre ^ string_of_int i
       
   256 
       
   257 fun add_theory T (Tidx, typs, idx, terms, thys) =
       
   258   let
       
   259     fun add @{typ int} = insert (op =) Integer
       
   260       | add @{typ real} = insert (op =) Real
       
   261       | add (Type (@{type_name word}, _)) = insert (op =) Bitvector
       
   262       | add (Type (_, Ts)) = fold add Ts
       
   263       | add _ = I
       
   264   in (Tidx, typs, idx, terms, add T thys) end
       
   265 
       
   266 fun fresh_typ sort_prefix T (cx as (Tidx, typs, idx, terms, thys)) =
       
   267   (case Typtab.lookup typs T of
       
   268     SOME s => (s, cx)
       
   269   | NONE =>
       
   270       let
       
   271         val s = string_of_index sort_prefix Tidx
       
   272         val typs' = Typtab.update (T, s) typs
       
   273       in (s, (Tidx+1, typs', idx, terms, thys)) end)
       
   274 
       
   275 fun fresh_fun func_prefix t ss (cx as (Tidx, typs, idx, terms, thys)) =
       
   276   (case Termtab.lookup terms t of
       
   277     SOME (f, _) => (f, cx)
       
   278   | NONE =>
       
   279       let
       
   280         val f = string_of_index func_prefix idx
       
   281         val terms' = Termtab.update (revert_types t, (f, ss)) terms
       
   282       in (f, (Tidx, typs, idx+1, terms', thys)) end)
       
   283 
       
   284 fun relaxed thms = (([], thms), map prop_of thms)
       
   285 
       
   286 fun with_context f (ths, ts) =
       
   287   let val (us, context) = fold_map f ts empty_context
       
   288   in ((make_sign context, us), make_recon ths context) end
       
   289 
       
   290 
       
   291 fun translate {prefixes, strict, builtins, serialize} ctxt =
       
   292   let
       
   293     val {sort_prefix, func_prefix} = prefixes
       
   294     val {builtin_typ, builtin_num, builtin_fun} = builtins
       
   295 
       
   296     fun transT T = add_theory T #>
       
   297       (case builtin_typ T of
       
   298         SOME n => pair n
       
   299       | NONE => fresh_typ sort_prefix T)
       
   300 
       
   301     fun app n ts = SApp (n, ts)
       
   302 
       
   303     fun trans t =
       
   304       (case Term.strip_comb t of
       
   305         (Const (qn, _), [Abs (_, T, t1)]) =>
       
   306           (case dest_quant qn T t1 of
       
   307             SOME (q, Ts, ps, b) =>
       
   308               fold_map transT Ts ##>> fold_map (fold_map_pat trans) ps ##>>
       
   309               trans b #>> (fn ((Ts', ps'), b') => SQua (q, Ts', ps', b'))
       
   310           | NONE => raise TERM ("intermediate", [t]))
       
   311       | (Const (@{const_name Let}, _), [t1, Abs (_, T, t2)]) =>
       
   312           transT T ##>> trans t1 ##>> trans t2 #>>
       
   313           (fn ((U, u1), u2) => SLet (U, u1, u2))
       
   314       | (h as Const (c as (@{const_name distinct}, T)), [t1]) =>
       
   315           (case builtin_fun c (HOLogic.dest_list t1) of
       
   316             SOME (n, ts) => add_theory T #> fold_map trans ts #>> app n
       
   317           | NONE => transs h T [t1])
       
   318       | (h as Const (c as (_, T)), ts) =>
       
   319           (case try HOLogic.dest_number t of
       
   320             SOME (T, i) =>
       
   321               (case builtin_num T i of
       
   322                 SOME n => add_theory T #> pair (SApp (n, []))
       
   323               | NONE => transs t T [])
       
   324           | NONE =>
       
   325               (case builtin_fun c ts of
       
   326                 SOME (n, ts') => add_theory T #> fold_map trans ts' #>> app n
       
   327               | NONE => transs h T ts))
       
   328       | (h as Free (_, T), ts) => transs h T ts
       
   329       | (Bound i, []) => pair (SVar i)
       
   330       | _ => raise TERM ("intermediate", [t]))
       
   331 
       
   332     and transs t T ts =
       
   333       let val (Us, U) = dest_funT (length ts) T
       
   334       in
       
   335         fold_map transT Us ##>> transT U #-> (fn Up =>
       
   336         fresh_fun func_prefix t Up ##>> fold_map trans ts #>> SApp)
       
   337       end
       
   338   in
       
   339     (if is_some strict then strictify (the strict) ctxt else relaxed) #>
       
   340     with_context trans #>> uncurry serialize
       
   341   end
       
   342 
       
   343 end