src/HOL/Tools/SMT/z3_proof_reconstruction.ML
changeset 36890 8e55aa1306c5
child 36891 bcd6fce5bf06
equal deleted inserted replaced
36889:6d1ecdb81ff0 36890:8e55aa1306c5
       
     1 (*  Title:      HOL/Tools/SMT/z3_proof_reconstruction.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3 
       
     4 Proof reconstruction for proofs found by Z3.
       
     5 *)
       
     6 
       
(*public interface of Z3 proof reconstruction*)
signature Z3_PROOF_RECONSTRUCTION =
sig
  (*configuration option "z3_trace_assms": when enabled, every assumption
    looked up while replaying an "asserted" step is printed via tracing*)
  val trace_assms: bool Config.T
  (*replay a Z3 proof in the given context, returning the reconstructed
    theorem and the updated context; NOTE(review): the string list is
    presumably the raw proof text produced by Z3 -- confirm with the caller*)
  val reconstruct: string list * SMT_Translate.recon -> Proof.context ->
    thm * Proof.context
  (*theory setup of this module (its definition lies outside this excerpt)*)
  val setup: theory -> theory
end
       
    14 
       
    15 structure Z3_Proof_Reconstruction: Z3_PROOF_RECONSTRUCTION =
       
    16 struct
       
    17 
       
    18 structure P = Z3_Proof_Parser
       
    19 structure T = Z3_Proof_Tools
       
    20 structure L = Z3_Proof_Literals
       
    21 
       
    22 fun z3_exn msg = raise SMT_Solver.SMT ("Z3 proof reconstruction: " ^ msg)
       
    23 
       
    24 
       
    25 
       
    26 (** net of schematic rules **)
       
    27 
       
    28 val z3_ruleN = "z3_rule"
       
    29 
       
    30 local
       
    31   val description = "declaration of Z3 proof rules"
       
    32 
       
    33   val eq = Thm.eq_thm
       
    34 
       
    35   structure Z3_Rules = Generic_Data
       
    36   (
       
    37     type T = thm Net.net
       
    38     val empty = Net.empty
       
    39     val extend = I
       
    40     val merge = Net.merge eq
       
    41   )
       
    42 
       
    43   val prep = `Thm.prop_of o Simplifier.rewrite_rule [L.rewrite_true]
       
    44 
       
    45   fun ins thm net = Net.insert_term eq (prep thm) net handle Net.INSERT => net
       
    46   fun del thm net = Net.delete_term eq (prep thm) net handle Net.DELETE => net
       
    47 
       
    48   val add = Thm.declaration_attribute (Z3_Rules.map o ins)
       
    49   val del = Thm.declaration_attribute (Z3_Rules.map o del)
       
    50 in
       
    51 
       
    52 fun get_schematic_rules ctxt = Net.content (Z3_Rules.get (Context.Proof ctxt))
       
    53 
       
    54 fun by_schematic_rule ctxt ct =
       
    55   the (T.net_instance (Z3_Rules.get (Context.Proof ctxt)) ct)
       
    56 
       
    57 val z3_rules_setup =
       
    58   Attrib.setup (Binding.name z3_ruleN) (Attrib.add_del add del) description #>
       
    59   PureThy.add_thms_dynamic (Binding.name z3_ruleN, Net.content o Z3_Rules.get)
       
    60 
       
    61 end
       
    62 
       
    63 
       
    64 
       
    65 (** proof tools **)
       
    66 
       
    67 fun named ctxt name prover ct =
       
    68   let val _ = SMT_Solver.trace_msg ctxt I ("Z3: trying " ^ name ^ " ...")
       
    69   in prover ct end
       
    70 
       
    71 fun NAMED ctxt name tac i st =
       
    72   let val _ = SMT_Solver.trace_msg ctxt I ("Z3: trying " ^ name ^ " ...")
       
    73   in tac i st end
       
    74 
       
    75 fun pretty_goal ctxt thms t =
       
    76   [Pretty.block [Pretty.str "proposition: ", Syntax.pretty_term ctxt t]]
       
    77   |> not (null thms) ? cons (Pretty.big_list "assumptions:"
       
    78        (map (Display.pretty_thm ctxt) thms))
       
    79 
       
    80 fun try_apply ctxt thms =
       
    81   let
       
    82     fun try_apply_err ct = Pretty.string_of (Pretty.chunks [
       
    83       Pretty.big_list ("Z3 found a proof," ^
       
    84         " but proof reconstruction failed at the following subgoal:")
       
    85         (pretty_goal ctxt thms (Thm.term_of ct)),
       
    86       Pretty.str ("Adding a rule to the lemma group " ^ quote z3_ruleN ^
       
    87         " might solve this problem.")])
       
    88 
       
    89     fun apply [] ct = error (try_apply_err ct)
       
    90       | apply (prover :: provers) ct =
       
    91           (case try prover ct of
       
    92             SOME thm => (SMT_Solver.trace_msg ctxt I "Z3: succeeded"; thm)
       
    93           | NONE => apply provers ct)
       
    94 
       
    95   in apply o cons (named ctxt "schematic rules" (by_schematic_rule ctxt)) end
       
    96 
       
    97 
       
    98 
       
    99 (** theorems and proofs **)
       
   100 
       
   101 (* theorem incarnations *)
       
   102 
       
   103 datatype theorem =
       
   104   Thm of thm | (* theorem without special features *)
       
   105   MetaEq of thm | (* meta equality "t == s" *)
       
   106   Literals of thm * L.littab
       
   107     (* "P1 & ... & Pn" and table of all literals P1, ..., Pn *)
       
   108 
       
   109 fun thm_of (Thm thm) = thm
       
   110   | thm_of (MetaEq thm) = thm COMP @{thm meta_eq_to_obj_eq}
       
   111   | thm_of (Literals (thm, _)) = thm
       
   112 
       
   113 fun meta_eq_of (MetaEq thm) = thm
       
   114   | meta_eq_of p = mk_meta_eq (thm_of p)
       
   115 
       
   116 fun literals_of (Literals (_, lits)) = lits
       
   117   | literals_of p = L.make_littab [thm_of p]
       
   118 
       
   119 
       
   120 (* proof representation *)
       
   121 
       
   122 datatype proof = Unproved of P.proof_step | Proved of theorem
       
   123 
       
   124 
       
   125 
       
   126 (** core proof rules **)
       
   127 
       
   128 (* assumption *)
       
   129 
       
   130 val (trace_assms, trace_assms_setup) =
       
   131   Attrib.config_bool "z3_trace_assms" (K false)
       
   132 
       
local
  (*drop trigger annotations: "trigger t p" is rewritten to "p"*)
  val remove_trigger = @{lemma "trigger t p == p"
    by (rule eq_reflection, rule trigger_def)}

  (*rewrites applied (besides the unfolds) to every assumption*)
  val prep_rules = [@{thm Let_def}, remove_trigger, L.rewrite_true]

  (*full rewriting with exactly the given equations (empty base simpset)*)
  fun rewrite_conv ctxt eqs = Simplifier.full_rewrite
    (Simplifier.context ctxt Simplifier.empty_ss addsimps eqs)

  (*rewrite every theorem of a list with the given equations*)
  fun rewrites ctxt eqs = map (Conv.fconv_rule (rewrite_conv ctxt eqs))

  (*print a used assumption when the z3_trace_assms option is enabled*)
  fun trace ctxt thm =
    if Config.get ctxt trace_assms
    then tracing (Display.string_of_thm ctxt thm)
    else ()

  (*find an assumption in the net that instantiates to ct; fail via z3_exn
    with the offending term otherwise*)
  fun lookup_assm ctxt assms ct =
    (case T.net_instance assms ct of
      SOME thm => (trace ctxt thm; thm)
    | _ => z3_exn ("not asserted: " ^
        quote (Syntax.string_of_term ctxt (Thm.term_of ct))))
in
(*preprocess unfolding equations and assumptions: the unfolds are rewritten
  with L.rewrite_true; the assumptions are rewritten with the unfolds plus
  prep_rules and then indexed in a net (T.thm_net_of)*)
fun prepare_assms ctxt unfolds assms =
  let
    val unfolds' = rewrites ctxt [L.rewrite_true] unfolds
    val assms' = rewrites ctxt (union Thm.eq_thm unfolds' prep_rules) assms
  in (unfolds', T.thm_net_of assms') end

(*justify an "asserted" proof step: without prepared assumptions the
  proposition is simply assumed; otherwise it is looked up in the net, with
  rewriting by the unfolds supplied to T.with_conv -- NOTE(review): exactly
  how T.with_conv applies revert_conv is defined in Z3_Proof_Tools*)
fun asserted _ NONE ct = Thm (Thm.assume ct)
  | asserted ctxt (SOME (unfolds, assms)) ct =
      let val revert_conv = rewrite_conv ctxt unfolds
      in Thm (T.with_conv revert_conv (lookup_assm ctxt assms) ct) end
end
       
   166 
       
   167 
       
   168 
       
(* P = Q ==> P ==> Q   or   P --> Q ==> P ==> Q *)
local
  val meta_iffD1 = @{lemma "P == Q ==> P ==> (Q::bool)" by simp}
  (*precomposed variants; the destructor passed to T.precompose2 presumably
    extracts the instantiation for P and Q from the premise (see
    Z3_Proof_Tools)*)
  val meta_iffD1_c = T.precompose2 Thm.dest_binop meta_iffD1

  val iffD1_c = T.precompose2 (Thm.dest_binop o Thm.dest_arg) @{thm iffD1}
  val mp_c = T.precompose2 (Thm.dest_binop o Thm.dest_arg) @{thm mp}
in
(*modus ponens: a meta equality goes through meta_iffD1; any other premise
  is first tried as an equivalence (iffD1) and, if that composition raises
  THM, as an implication (mp)*)
fun mp (MetaEq thm) p = Thm (Thm.implies_elim (T.compose meta_iffD1_c thm) p)
  | mp p_q p =
      let
        val pq = thm_of p_q
        val thm = T.compose iffD1_c pq handle THM _ => T.compose mp_c pq
      in Thm (Thm.implies_elim thm p) end
end
       
   184 
       
   185 
       
   186 
       
(* and_elim:     P1 & ... & Pn ==> Pi *)
(* not_or_elim:  ~(P1 | ... | Pn) ==> ~Pi *)
local
  (*predicate on a literal: does it contain t (up to aconv) among the
    sub-literals selected by conj?*)
  fun is_sublit conj t = L.exists_lit conj (fn u => u aconv t)

  (*t is not yet a literal of lits: explode the first literal containing t,
    replace it by its parts in the table, and record the refined table in
    the proof table entry idx so later lookups reuse it*)
  fun derive conj t lits idx ptab =
    let
      val lit = the (L.get_first_lit (is_sublit conj t) lits)
      val ls = L.explode conj false false [t] lit
      val lits' = fold L.insert_lit ls (L.delete_lit lit lits)

      fun upd (Proved thm) = Proved (Literals (thm_of thm, lits'))
        | upd p = p
    in (the (L.lookup_lit lits' t), Inttab.map_entry idx upd ptab) end

  (*prove ct from premise p (stored at index idx): use a cached literal if
    present, otherwise derive it*)
  fun lit_elim conj (p, idx) ct ptab =
    let val lits = literals_of p
    in
      (case L.lookup_lit lits (T.term_of ct) of
        SOME lit => (Thm lit, ptab)
      | NONE => apfst Thm (derive conj (T.term_of ct) lits idx ptab))
    end
in
val and_elim = lit_elim true
val not_or_elim = lit_elim false
end
       
   213 
       
   214 
       
   215 
       
(* P1, ..., Pn |- False ==> |- ~P1 | ... | ~Pn *)
local
  (*replace the hypothesis cprop_of lit in thm by the hypotheses of lit*)
  fun step lit thm =
    Thm.implies_elim (Thm.implies_intr (Thm.cprop_of lit) thm) lit
  val explode_disj = L.explode false false false
  (*explode the assumed theorem th into literals (guided by hyps) and use
    them to discharge the corresponding hypotheses of thm*)
  fun intro hyps thm th = fold step (explode_disj hyps th) thm

  fun dest_ccontr ct = [Thm.dest_arg (Thm.dest_arg (Thm.dest_arg1 ct))]
  val ccontr = T.precompose dest_ccontr @{thm ccontr}
in
(*thm proves False from hypotheses; assume the negation of ct's disjunction,
  redischarge the hypotheses through it, and close by ccontr*)
fun lemma thm ct =
  let
    val cu = Thm.capply @{cterm Not} (Thm.dest_arg ct)
    val hyps = map_filter (try HOLogic.dest_Trueprop) (#hyps (Thm.rep_thm thm))
  in Thm (T.compose ccontr (T.under_assumption (intro hyps thm) cu)) end
end
       
   232 
       
   233 
       
   234 
       
(* \/{P1, ..., Pn, Q1, ..., Qn}, ~P1, ..., ~Pn ==> \/{Q1, ..., Qn} *)
local
  val explode_disj = L.explode false true false
  val join_disj = L.join false
  (*from the assumed theorem th (the negated goal), together with the unit
    literals thms, derive the negation of the proposition of thm*)
  fun unit thm thms th =
    let val t = @{term Not} $ T.prop_of thm and ts = map T.prop_of thms
    in join_disj (L.make_littab (thms @ explode_disj ts th)) t end

  fun dest_arg2 ct = Thm.dest_arg (Thm.dest_arg ct)
  fun dest ct = pairself dest_arg2 (Thm.dest_binop ct)
  val contrapos = T.precompose2 dest @{lemma "(~P ==> ~Q) ==> Q ==> P" by fast}
in
(*thm is the big disjunction, thms are the unit literals; the result is
  obtained by contraposition and discharging thm*)
fun unit_resolution thm thms ct =
  Thm.capply @{cterm Not} (Thm.dest_arg ct)
  |> T.under_assumption (unit thm thms)
  |> Thm o T.discharge thm o T.compose contrapos
end
       
   252 
       
   253 
       
   254 
       
   255 (* P ==> P == True   or   P ==> P == False *)
       
   256 local
       
   257   val iff1 = @{lemma "P ==> P == (~ False)" by simp}
       
   258   val iff2 = @{lemma "~P ==> P == False" by simp}
       
   259 in
       
   260 fun iff_true thm = MetaEq (thm COMP iff1)
       
   261 fun iff_false thm = MetaEq (thm COMP iff2)
       
   262 end
       
   263 
       
   264 
       
   265 
       
   266 (* distributivity of | over & *)
       
   267 fun distributivity ctxt = Thm o try_apply ctxt [] [
       
   268   named ctxt "fast" (T.by_tac (Classical.best_tac HOL_cs))]
       
   269     (* FIXME: not very well tested *)
       
   270 
       
   271 
       
   272 
       
   273 (* Tseitin-like axioms *)
       
   274 
       
   275 local
       
   276   val disjI1 = @{lemma "(P ==> Q) ==> ~P | Q" by fast}
       
   277   val disjI2 = @{lemma "(~P ==> Q) ==> P | Q" by fast}
       
   278   val disjI3 = @{lemma "(~Q ==> P) ==> P | Q" by fast}
       
   279   val disjI4 = @{lemma "(Q ==> P) ==> P | ~Q" by fast}
       
   280 
       
   281   fun prove' conj1 conj2 ct2 thm =
       
   282     let val lits = L.true_thm :: L.explode conj1 true (conj1 <> conj2) [] thm
       
   283     in L.join conj2 (L.make_littab lits) (Thm.term_of ct2) end
       
   284 
       
   285   fun prove rule (ct1, conj1) (ct2, conj2) =
       
   286     T.under_assumption (prove' conj1 conj2 ct2) ct1 COMP rule
       
   287 
       
   288   fun prove_def_axiom ct =
       
   289     let val (ct1, ct2) = Thm.dest_binop (Thm.dest_arg ct)
       
   290     in
       
   291       (case Thm.term_of ct1 of
       
   292         @{term Not} $ (@{term "op &"} $ _ $ _) =>
       
   293           prove disjI1 (Thm.dest_arg ct1, true) (ct2, true)
       
   294       | @{term "op &"} $ _ $ _ =>
       
   295           prove disjI3 (Thm.capply @{cterm Not} ct2, false) (ct1, true)
       
   296       | @{term Not} $ (@{term "op |"} $ _ $ _) =>
       
   297           prove disjI3 (Thm.capply @{cterm Not} ct2, false) (ct1, false)
       
   298       | @{term "op |"} $ _ $ _ =>
       
   299           prove disjI2 (Thm.capply @{cterm Not} ct1, false) (ct2, true)
       
   300       | Const (@{const_name distinct}, _) $ _ =>
       
   301           let
       
   302             fun dis_conv cv = Conv.arg_conv (Conv.arg1_conv cv)
       
   303             fun prv cu =
       
   304               let val (cu1, cu2) = Thm.dest_binop (Thm.dest_arg cu)
       
   305               in prove disjI4 (Thm.dest_arg cu2, true) (cu1, true) end
       
   306           in T.with_conv (dis_conv T.unfold_distinct_conv) prv ct end
       
   307       | @{term Not} $ (Const (@{const_name distinct}, _) $ _) =>
       
   308           let
       
   309             fun dis_conv cv = Conv.arg_conv (Conv.arg1_conv (Conv.arg_conv cv))
       
   310             fun prv cu =
       
   311               let val (cu1, cu2) = Thm.dest_binop (Thm.dest_arg cu)
       
   312               in prove disjI1 (Thm.dest_arg cu1, true) (cu2, true) end
       
   313           in T.with_conv (dis_conv T.unfold_distinct_conv) prv ct end
       
   314       | _ => raise CTERM ("prove_def_axiom", [ct]))
       
   315     end
       
   316 
       
   317   val rewr_if =
       
   318     @{lemma "(if P then Q1 else Q2) = ((P --> Q1) & (~P --> Q2))" by simp}
       
   319 in
       
   320 fun def_axiom ctxt = Thm o try_apply ctxt [] [
       
   321   named ctxt "conj/disj/distinct" prove_def_axiom,
       
   322   T.by_abstraction ctxt [] (fn ctxt' =>
       
   323     named ctxt' "simp+fast" (T.by_tac (
       
   324       Simplifier.simp_tac (HOL_ss addsimps [rewr_if])
       
   325       THEN_ALL_NEW Classical.best_tac HOL_cs)))]
       
   326 end
       
   327 
       
   328 
       
   329 
       
(* local definitions *)
local
  (*consequences of a definition "n == ...", tried in this order*)
  val intro_rules = [
    @{lemma "n == P ==> (~n | P) & (n | ~P)" by simp},
    @{lemma "n == (if P then s else t) ==> (~P | n = s) & (P | n = t)"
      by simp},
    @{lemma "n == P ==> n = P" by (rule meta_eq_to_obj_eq)} ]

  (*turn the clausal forms above back into meta equalities "... == n"*)
  val apply_rules = [
    @{lemma "(~n | P) & (n | ~P) ==> P == n" by (atomize(full)) fast},
    @{lemma "(~P | n = s) & (P | n = t) ==> (if P then s else t) == n"
      by (atomize(full)) fastsimp} ]

  val inst_rule = T.match_instantiate Thm.dest_arg

  (*instantiate the first intro rule matching ct, or raise CTERM*)
  fun apply_rule ct =
    (case get_first (try (inst_rule ct)) intro_rules of
      SOME thm => thm
    | NONE => raise CTERM ("intro_def", [ct]))
in
(*introduce a hypothetical definition for ct via T.make_hyp_def (threaded
  through the context) and wrap the resulting theorem in Thm*)
fun intro_def ct = T.make_hyp_def (apply_rule ct) #>> Thm

(*use a definition: the first applicable apply_rule yields a MetaEq;
  otherwise thm is returned unchanged as a plain Thm*)
fun apply_def thm =
  get_first (try (fn rule => MetaEq (thm COMP rule))) apply_rules
  |> the_default (Thm thm)
end
       
   356 
       
   357 
       
   358 
       
(* negation normal form *)

local
  (*(rules for a quantifier whose body becomes constant, rules that keep
    the quantifier)*)
  val quant_rules1 = ([
    @{lemma "(!!x. P x == Q) ==> ALL x. P x == Q" by simp},
    @{lemma "(!!x. P x == Q) ==> EX x. P x == Q" by simp}], [
    @{lemma "(!!x. P x == Q x) ==> ALL x. P x == ALL x. Q x" by simp},
    @{lemma "(!!x. P x == Q x) ==> EX x. P x == EX x. Q x" by simp}])

  (*same, for a negated quantifier pushed inwards*)
  val quant_rules2 = ([
    @{lemma "(!!x. ~P x == Q) ==> ~(ALL x. P x) == Q" by simp},
    @{lemma "(!!x. ~P x == Q) ==> ~(EX x. P x) == Q" by simp}], [
    @{lemma "(!!x. ~P x == Q x) ==> ~(ALL x. P x) == EX x. Q x" by simp},
    @{lemma "(!!x. ~P x == Q x) ==> ~(EX x. P x) == ALL x. Q x" by simp}])

  (*close the goal with thm, or descend through one quantifier and recurse;
    the explicit i st parameters delay the recursive call so the tactic is
    not unfolded endlessly when it is built*)
  fun nnf_quant_tac thm (qs as (qs1, qs2)) i st = (
    Tactic.rtac thm ORELSE'
    (Tactic.match_tac qs1 THEN' nnf_quant_tac thm qs) ORELSE'
    (Tactic.match_tac qs2 THEN' nnf_quant_tac thm qs)) i st

  (*prove the quantified equation ct from the premise p, whose variables
    vars are made schematic*)
  fun nnf_quant vars qs p ct =
    T.as_meta_eq ct
    |> T.by_tac (nnf_quant_tac (T.varify vars (meta_eq_of p)) qs)

  fun prove_nnf ctxt = try_apply ctxt [] [
    named ctxt "conj/disj" L.prove_conj_disj_eq]
in
(*NNF step: equal quantified sides are proved by reflexivity, quantified
  sides via the premise (hd ps), and all other cases by unfolding the
  premises ps (reversed) and then proving the propositional equality*)
fun nnf ctxt vars ps ct =
  (case T.term_of ct of
    _ $ (l as Const _ $ Abs _) $ (r as Const _ $ Abs _) =>
      if l aconv r
      then MetaEq (Thm.reflexive (Thm.dest_arg (Thm.dest_arg ct)))
      else MetaEq (nnf_quant vars quant_rules1 (hd ps) ct)
  | _ $ (@{term Not} $ (Const _ $ Abs _)) $ (Const _ $ Abs _) =>
      MetaEq (nnf_quant vars quant_rules2 (hd ps) ct)
  | _ =>
      let
        val nnf_rewr_conv = Conv.arg_conv (Conv.arg_conv
          (T.unfold_eqs ctxt (map (Thm.symmetric o meta_eq_of) ps)))
      in Thm (T.with_conv nnf_rewr_conv (prove_nnf ctxt) ct) end)
end
       
   400 
       
   401 
       
   402 
       
   403 (** equality proof rules **)
       
   404 
       
   405 (* |- t = t *)
       
   406 fun refl ct = MetaEq (Thm.reflexive (Thm.dest_arg (Thm.dest_arg ct)))
       
   407 
       
   408 
       
   409 
       
   410 (* s = t ==> t = s *)
       
   411 local
       
   412   val symm_rule = @{lemma "s = t ==> t == s" by simp}
       
   413 in
       
   414 fun symm (MetaEq thm) = MetaEq (Thm.symmetric thm)
       
   415   | symm p = MetaEq (thm_of p COMP symm_rule)
       
   416 end
       
   417 
       
   418 
       
   419 
       
   420 (* s = t ==> t = u ==> s = u *)
       
   421 local
       
   422   val trans1 = @{lemma "s == t ==> t =  u ==> s == u" by simp}
       
   423   val trans2 = @{lemma "s =  t ==> t == u ==> s == u" by simp}
       
   424   val trans3 = @{lemma "s =  t ==> t =  u ==> s == u" by simp}
       
   425 in
       
   426 fun trans (MetaEq thm1) (MetaEq thm2) = MetaEq (Thm.transitive thm1 thm2)
       
   427   | trans (MetaEq thm) q = MetaEq (thm_of q COMP (thm COMP trans1))
       
   428   | trans p (MetaEq thm) = MetaEq (thm COMP (thm_of p COMP trans2))
       
   429   | trans p q = MetaEq (thm_of q COMP (thm_of p COMP trans3))
       
   430 end
       
   431 
       
   432 
       
   433 
       
(* t1 = s1 ==> ... ==> tn = sn ==> f t1 ... tn = f s1 .. sn
   (reflexive antecedents are dropped) *)
local
  (*raised when no premise proves a required equation*)
  exception MONO

  fun prove_refl (ct, _) = Thm.reflexive ct
  (*congruence: prove function parts with f and argument parts with g*)
  fun prove_comb f g cp =
    let val ((ct1, ct2), (cu1, cu2)) = pairself Thm.dest_comb cp
    in Thm.combination (f (ct1, cu1)) (g (ct2, cu2)) end
  fun prove_arg f = prove_comb prove_refl f

  (*descend through applications, proving arguments with f; when
    destructuring raises CTERM, fall back to reflexivity*)
  fun prove f cp = prove_comb (prove f) f cp handle CTERM _ => prove_refl cp

  (*n-ary conjunction/disjunction: try f on the whole pair; on MONO split
    off one operand (when is_comb holds) and recurse on the rest*)
  fun prove_nary is_comb f =
    let
      fun prove (cp as (ct, _)) = f cp handle MONO =>
        if is_comb (Thm.term_of ct)
        then prove_comb (prove_arg prove) prove cp
        else prove_refl cp
    in prove end

  (*HOL list with n elements: prove each element with f*)
  fun prove_list f n cp =
    if n = 0 then prove_refl cp
    else prove_comb (prove_arg f) (prove_list f (n-1)) cp

  fun with_length f (cp as (cl, _)) =
    f (length (HOLogic.dest_list (Thm.term_of cl))) cp

  fun prove_distinct f = prove_arg (with_length (prove_list f))

  (*look up the equation between both sides among the premises; when it is
    missing, raise MONO (exn = true) or use reflexivity (exn = false)*)
  fun prove_eq exn lookup cp =
    (case lookup (Logic.mk_equals (pairself Thm.term_of cp)) of
      SOME eq => eq
    | NONE => if exn then raise MONO else prove_refl cp)

  val prove_eq_exn = prove_eq true
  and prove_eq_safe = prove_eq false

  (*choose the congruence strategy from the head of the left-hand side*)
  fun mono f (cp as (cl, _)) =
    (case Term.head_of (Thm.term_of cl) of
      @{term "op &"} => prove_nary L.is_conj (prove_eq_exn f)
    | @{term "op |"} => prove_nary L.is_disj (prove_eq_exn f)
    | Const (@{const_name distinct}, _) => prove_distinct (prove_eq_safe f)
    | _ => prove (prove_eq_safe f)) cp
in
(*prove the equation ct from the premises eqs: directly if one of them is
  exactly the goal, otherwise by congruence*)
fun monotonicity eqs ct =
  let
    val lookup = AList.lookup (op aconv) (map (`Thm.prop_of o meta_eq_of) eqs)
    val cp = Thm.dest_binop (Thm.dest_arg ct)
  in MetaEq (prove_eq_exn lookup cp handle MONO => mono lookup cp) end
end
       
   485 
       
   486 
       
   487 
       
   488 (* |- f a b = f b a (where f is equality) *)
       
   489 local
       
   490   val rule = @{lemma "a = b == b = a" by (atomize(full)) (rule eq_commute)}
       
   491 in
       
   492 fun commutativity ct = MetaEq (T.match_instantiate I (T.as_meta_eq ct) rule)
       
   493 end
       
   494 
       
   495 
       
   496 
       
   497 (** quantifier proof rules **)
       
   498 
       
   499 (* P ?x = Q ?x ==> (ALL x. P x) = (ALL x. Q x)
       
   500    P ?x = Q ?x ==> (EX x. P x) = (EX x. Q x)    *)
       
   501 local
       
   502   val rules = [
       
   503     @{lemma "(!!x. P x == Q x) ==> (ALL x. P x) == (ALL x. Q x)" by simp},
       
   504     @{lemma "(!!x. P x == Q x) ==> (EX x. P x) == (EX x. Q x)" by simp}]
       
   505 in
       
   506 fun quant_intro vars p ct =
       
   507   let
       
   508     val thm = meta_eq_of p
       
   509     val rules' = T.varify vars thm :: rules
       
   510     val cu = T.as_meta_eq ct
       
   511   in MetaEq (T.by_tac (REPEAT_ALL_NEW (Tactic.match_tac rules')) cu) end
       
   512 end
       
   513 
       
   514 
       
   515 
       
   516 (* |- ((ALL x. P x) | Q) = (ALL x. P x | Q) *)
       
   517 fun pull_quant ctxt = Thm o try_apply ctxt [] [
       
   518   named ctxt "fast" (T.by_tac (Classical.fast_tac HOL_cs))]
       
   519     (* FIXME: not very well tested *)
       
   520 
       
   521 
       
   522 
       
   523 (* |- (ALL x. P x & Q x) = ((ALL x. P x) & (ALL x. Q x)) *)
       
   524 fun push_quant ctxt = Thm o try_apply ctxt [] [
       
   525   named ctxt "fast" (T.by_tac (Classical.fast_tac HOL_cs))]
       
   526     (* FIXME: not very well tested *)
       
   527 
       
   528 
       
   529 
       
   530 (* |- (ALL x1 ... xn y1 ... yn. P x1 ... xn) = (ALL x1 ... xn. P x1 ... xn) *)
       
   531 local
       
   532   val elim_all = @{lemma "(ALL x. P) == P" by simp}
       
   533   val elim_ex = @{lemma "(EX x. P) == P" by simp}
       
   534 
       
   535   fun elim_unused_conv ctxt =
       
   536     Conv.params_conv ~1 (K (Conv.arg_conv (Conv.arg1_conv
       
   537       (More_Conv.rewrs_conv [elim_all, elim_ex])))) ctxt
       
   538 
       
   539   fun elim_unused_tac ctxt =
       
   540     REPEAT_ALL_NEW (
       
   541       Tactic.match_tac [@{thm refl}, @{thm iff_allI}, @{thm iff_exI}]
       
   542       ORELSE' CONVERSION (elim_unused_conv ctxt))
       
   543 in
       
   544 fun elim_unused_vars ctxt = Thm o T.by_tac (elim_unused_tac ctxt)
       
   545 end
       
   546 
       
   547 
       
   548 
       
   549 (* |- (ALL x1 ... xn. ~(x1 = t1 & ... xn = tn) | P x1 ... xn) = P t1 ... tn *)
       
   550 fun dest_eq_res ctxt = Thm o try_apply ctxt [] [
       
   551   named ctxt "fast" (T.by_tac (Classical.fast_tac HOL_cs))]
       
   552     (* FIXME: not very well tested *)
       
   553 
       
   554 
       
   555 
       
   556 (* |- ~(ALL x1...xn. P x1...xn) | P a1...an *)
       
   557 local
       
   558   val rule = @{lemma "~ P x | Q ==> ~(ALL x. P x) | Q" by fast}
       
   559 in
       
   560 val quant_inst = Thm o T.by_tac (
       
   561   REPEAT_ALL_NEW (Tactic.match_tac [rule])
       
   562   THEN' Tactic.rtac @{thm excluded_middle})
       
   563 end
       
   564 
       
   565 
       
   566 
       
(* c = SOME x. P x |- (EX x. P x) = P c
   c = SOME x. ~ P x |- ~(ALL x. P x) = ~ P c *)
local
  (*drop a quantifier whose variable does not occur*)
  val elim_ex = @{lemma "EX x. P == P" by simp}
  val elim_all = @{lemma "~ (ALL x. P) == ~P" by simp}
  (*replace a quantifier by its Hilbert-choice witness c*)
  val sk_ex = @{lemma "c == SOME x. P x ==> EX x. P x == P c"
    by simp (intro eq_reflection some_eq_ex[symmetric])}
  val sk_all = @{lemma "c == SOME x. ~ P x ==> ~(ALL x. P x) == ~ P c"
    by (simp only: not_all) (intro eq_reflection some_eq_ex[symmetric])}
  (*((rule, destructor locating "SOME x. ..." in its premise), elim rule)*)
  val sk_ex_rule = ((sk_ex, I), elim_ex)
  and sk_all_rule = ((sk_all, Thm.dest_arg), elim_all)

  (*split the premise's right-hand side (after f) into function and arg*)
  fun dest f sk_rule =
    Thm.dest_comb (f (Thm.dest_arg (Thm.dest_arg (Thm.cprop_of sk_rule))))
  fun type_of f sk_rule = Thm.ctyp_of_term (snd (dest f sk_rule))
  fun pair2 (a, b) (c, d) = [(a, c), (b, d)]
  (*instantiate a skolemization rule: first the type variable to c's type,
    then the predicate p and the witness c; finally beta-normalize*)
  fun inst_sk (sk_rule, f) p c =
    Thm.instantiate ([(type_of f sk_rule, Thm.ctyp_of_term c)], []) sk_rule
    |> (fn sk' => Thm.instantiate ([], (pair2 (dest f sk') (p, c))) sk')
    |> Conv.fconv_rule (Thm.beta_conversion true)

  (*(rule, how to reach the quantifier, how to rebuild the body)*)
  fun kind (Const (@{const_name Ex}, _) $ _) = (sk_ex_rule, I, I)
    | kind (@{term Not} $ (Const (@{const_name All}, _) $ _)) =
        (sk_all_rule, Thm.dest_arg, Thm.capply @{cterm Not})
    | kind t = raise TERM ("skolemize", [t])

  fun dest_abs_type (Abs (_, T, _)) = T
    | dest_abs_type t = raise TERM ("dest_abs_type", [t])

  (*strip the quantifiers of lhs one by one, replacing each bound variable
    by a schematic Var x_idx, until the result first-order matches rhs;
    returns the term instantiation of that match and the list of
    (Var, abstraction) pairs, outermost first.  The local "dest" returned by
    kind shadows the outer dest function*)
  fun bodies_of thy lhs rhs =
    let
      val (rule, dest, make) = kind (Thm.term_of lhs)

      fun dest_body idx cbs ct =
        let
          val cb = Thm.dest_arg (dest ct)
          val T = dest_abs_type (Thm.term_of cb)
          val cv = Thm.cterm_of thy (Var (("x", idx), T))
          val cu = make (Drule.beta_conv cb cv)
          val cbs' = (cv, cb) :: cbs
        in
          (snd (Thm.first_order_match (cu, rhs)), rev cbs')
          handle Pattern.MATCH => dest_body (idx+1) cbs' cu
        end
    in (rule, dest_body 1 [] lhs) end

  fun transitive f thm = Thm.transitive thm (f (Thm.rhs_of thm))

  (*one quantifier: if its Var was matched to ct, introduce the hypothetical
    choice definition and chain it by transitivity (the instantiations
    collected so far are applied to the body cb); if unmatched, drop the
    quantifier with the elim rule*)
  fun sk_step (rule, elim) (cv, mct, cb) ((is, thm), ctxt) =
    (case mct of
      SOME ct =>
        ctxt
        |> T.make_hyp_def (inst_sk rule (Thm.instantiate_cterm ([], is) cb) ct)
        |>> pair ((cv, ct) :: is) o Thm.transitive thm
    | NONE => ((is, transitive (Conv.rewr_conv elim) thm), ctxt))
in
(*prove the skolemization equation ct (lhs = rhs), threading the context
  through the hypothetical definitions that are introduced*)
fun skolemize ct ctxt =
  let
    val (lhs, rhs) = Thm.dest_binop (Thm.dest_arg ct)
    val (rule, (ctab, cbs)) = bodies_of (ProofContext.theory_of ctxt) lhs rhs
    fun lookup_var (cv, cb) = (cv, AList.lookup (op aconvc) ctab cv, cb)
  in
    (([], Thm.reflexive lhs), ctxt)
    |> fold (sk_step rule) (map lookup_var cbs)
    |>> MetaEq o snd
  end
end
       
   634 
       
   635 
       
   636 
       
   637 (** theory proof rules **)
       
   638 
       
   639 (* theory lemmas: linear arithmetic, arrays *)
       
   640 
       
   641 fun th_lemma ctxt simpset thms = Thm o try_apply ctxt thms [
       
   642   T.by_abstraction ctxt thms (fn ctxt' => T.by_tac (
       
   643     NAMED ctxt' "arith" (Arith_Data.arith_tac ctxt')
       
   644     ORELSE' NAMED ctxt' "simp+arith" (Simplifier.simp_tac simpset THEN_ALL_NEW
       
   645       Arith_Data.arith_tac ctxt')))]
       
   646 
       
   647 
       
   648 
       
(* rewriting: prove equalities:
     * ACI of conjunction/disjunction
     * contradiction, excluded middle
     * logical rewriting rules (for negation, implication, equivalence,
         distinct)
     * normal forms for polynomials (integer/real arithmetic)
     * quantifier elimination over linear arithmetic
     * ... ? **)
(*named theorem collection "z3_simp"; NOTE(review): where it is consumed
  (presumably to build the simpset handed to rewrite/th_lemma) is outside
  this excerpt*)
structure Z3_Simps = Named_Thms
(
  val name = "z3_simp"
  val description = "simplification rules for Z3 proof reconstruction"
)
       
   662 
       
local
  (*strip all leading universal quantifiers (repeated RS spec), then turn
    the result into a meta equality*)
  fun spec_meta_eq_of thm =
    (case try (fn th => th RS @{thm spec}) thm of
      SOME thm' => spec_meta_eq_of thm'
    | NONE => mk_meta_eq thm)

  fun prep (Thm thm) = spec_meta_eq_of thm
    | prep (MetaEq thm) = thm
    | prep (Literals (thm, _)) = spec_meta_eq_of thm

  (*unfold the premises ths on both sides of the goal equation*)
  fun unfold_conv ctxt ths =
    Conv.arg_conv (Conv.binop_conv (T.unfold_eqs ctxt (map prep ths)))

  (*without premises the prover is used unchanged*)
  fun with_conv _ [] prv = prv
    | with_conv ctxt ths prv = T.with_conv (unfold_conv ctxt ths) prv

  (*NOTE(review): this shadows the unfold_conv above; with_conv was defined
    earlier and therefore still uses the first one.  This one unfolds
    "distinct" (where possible) on both sides of the goal equation*)
  val unfold_conv =
    Conv.arg_conv (Conv.binop_conv (Conv.try_conv T.unfold_distinct_conv))
  val prove_conj_disj_eq = T.with_conv unfold_conv L.prove_conj_disj_eq
in

(*prove a rewrite step: after unfolding the premises ths, try the
  conj/disj/distinct procedure, then (under abstraction) simp with simpset
  followed by fast_tac or arith on the remaining subgoals*)
fun rewrite ctxt simpset ths = Thm o with_conv ctxt ths (try_apply ctxt [] [
  named ctxt "conj/disj/distinct" prove_conj_disj_eq,
  T.by_abstraction ctxt [] (fn ctxt' => T.by_tac (
    NAMED ctxt' "simp" (Simplifier.simp_tac simpset)
    THEN_ALL_NEW (
      NAMED ctxt' "fast" (Classical.fast_tac HOL_cs)
      ORELSE' NAMED ctxt' "arith" (Arith_Data.arith_tac ctxt'))))])

end
       
   693 
       
   694 
       
   695 
       
   696 (** proof reconstruction **)
       
   697 
       
   698 (* tracing and checking *)
       
   699 
       
   local
     (* Summarise the proof table as (number of proved steps, number of steps). *)
     fun count_rules ptab =
       let
         fun proved (Proved _) = 1
           | proved (Unproved _) = 0
         fun add (_, entry) (solved, total) = (solved + proved entry, total + 1)
       in Inttab.fold add ptab (0, 0) end

     (* Trace line for step idx with rule r; the goal number is one more than
        the number of steps already proved, out of all steps in the table. *)
     fun header idx r (solved, total) =
       String.concat ["Z3: #", string_of_int idx, ": ", P.string_of_rule r,
         " (goal ", string_of_int (solved + 1), " of ", string_of_int total, ")"]

     (* Fail unless the theorem of p has exactly the proposition ct.  The error
        shows the premises, the obtained conclusion and the expected term. *)
     fun check ctxt idx r ps ct p =
       let
         val thm = thm_of p
         val _ = Thm.join_proofs [thm]

         fun mismatch () =
           let
             val title = "proof step failed: " ^ quote (P.string_of_rule r) ^
               " (#" ^ string_of_int idx ^ ")"
             val obtained =
               pretty_goal ctxt (map (thm_of o fst) ps) (Thm.prop_of thm)
             val expected = Pretty.block [Pretty.str "expected: ",
               Syntax.pretty_term ctxt (Thm.term_of ct)]
           in
             z3_exn (Pretty.string_of (Pretty.big_list title (obtained @ [expected])))
           end
       in if Thm.cprop_of thm aconvc ct then () else mismatch () end
   in

   (* Reconstruct one step with prove.  A trace header is emitted first.
      Afterwards, if SMT_Solver.trace is set in the resulting context, the
      conclusion is checked against ct. *)
   fun trace_rule idx prove r ps ct (cxp as (ctxt, ptab)) =
     let
       val _ = SMT_Solver.trace_msg ctxt (header idx r o count_rules) ptab
       val result as (p, (ctxt', _)) = prove r ps ct cxp
     in
       (if Config.get ctxt' SMT_Solver.trace then check ctxt' idx r ps ct p
        else ());
       result
     end

   end
       
   730 
       
   731 
       
   732 (* overall reconstruction procedure *)
       
   733 
       
   (* Abort on a Z3 proof rule for which no reconstruction is implemented. *)
   fun not_supported r =
     let val rule_name = quote (P.string_of_rule r)
     in raise Fail ("Z3: proof rule not implemented: " ^ rule_name) end
       
   736 
       
   (* Reconstruct a Z3 proof.  Applied to a context, the unfolding rules, the
      optional assumptions and the proof variables, it yields a function.
      That function maps (idx, ptab) -- the id of the step to prove and the
      parsed proof table -- to the theorem of step idx plus the final
      context.  Steps are reconstructed on demand, premises first, and every
      result is memoised in the table as Proved. *)
   fun prove ctxt unfolds assms vars =
     let
       val assms' = Option.map (prepare_assms ctxt unfolds) assms
       val simpset = T.make_simpset ctxt (Z3_Simps.get ctxt)

       (* Reconstruct rule r with the reconstructed premises ps (pairs of
          proof and step id), concluding ct.  The state is the pair
          (context, proof table). *)
       fun step r ps ct (state as (cx, ptab)) =
         let
           fun with_state p = (p, state)
           fun prem_proofs () = map fst ps
           fun prem_thms () = map (thm_of o fst) ps
           fun bad_arity () =
             raise Fail ("Z3: proof rule " ^ quote (P.string_of_rule r) ^
               " has an unexpected number of arguments.")
         in
           (case (r, ps) of
             (* core rules *)
             (P.TrueAxiom, _) => with_state (Thm L.true_thm)
           | (P.Asserted, _) => with_state (asserted cx assms' ct)
           | (P.Goal, _) => with_state (asserted cx assms' ct)
           | (P.ModusPonens, [(p, _), (q, _)]) => with_state (mp q (thm_of p))
           | (P.ModusPonensOeq, [(p, _), (q, _)]) => with_state (mp q (thm_of p))
           | (P.AndElim, [(p, i)]) => and_elim (p, i) ct ptab ||> pair cx
           | (P.NotOrElim, [(p, i)]) => not_or_elim (p, i) ct ptab ||> pair cx
           | (P.Hypothesis, _) => with_state (Thm (Thm.assume ct))
           | (P.Lemma, [(p, _)]) => with_state (lemma (thm_of p) ct)
           | (P.UnitResolution, (p, _) :: rest) =>
               with_state (unit_resolution (thm_of p) (map (thm_of o fst) rest) ct)
           | (P.IffTrue, [(p, _)]) => with_state (iff_true (thm_of p))
           | (P.IffFalse, [(p, _)]) => with_state (iff_false (thm_of p))
           | (P.Distributivity, _) => with_state (distributivity cx ct)
           | (P.DefAxiom, _) => with_state (def_axiom cx ct)
           | (P.IntroDef, _) => intro_def ct cx ||> rpair ptab
           | (P.ApplyDef, [(p, _)]) => with_state (apply_def (thm_of p))
           | (P.IffOeq, [(p, _)]) => with_state p
           | (P.NnfPos, _) => with_state (nnf cx vars (prem_proofs ()) ct)
           | (P.NnfNeg, _) => with_state (nnf cx vars (prem_proofs ()) ct)

             (* equality rules *)
           | (P.Reflexivity, _) => with_state (refl ct)
           | (P.Symmetry, [(p, _)]) => with_state (symm p)
           | (P.Transitivity, [(p, _), (q, _)]) => with_state (trans p q)
           | (P.Monotonicity, _) => with_state (monotonicity (prem_proofs ()) ct)
           | (P.Commutativity, _) => with_state (commutativity ct)

             (* quantifier rules *)
           | (P.QuantIntro, [(p, _)]) => with_state (quant_intro vars p ct)
           | (P.PullQuant, _) => with_state (pull_quant cx ct)
           | (P.PushQuant, _) => with_state (push_quant cx ct)
           | (P.ElimUnusedVars, _) => with_state (elim_unused_vars cx ct)
           | (P.DestEqRes, _) => with_state (dest_eq_res cx ct)
           | (P.QuantInst, _) => with_state (quant_inst ct)
           | (P.Skolemize, _) => skolemize ct cx ||> rpair ptab

             (* theory rules *)
           | (P.ThLemma, _) =>
               with_state (th_lemma cx simpset (prem_thms ()) ct)
           | (P.Rewrite, _) => with_state (rewrite cx simpset [] ct)
           | (P.RewriteStar, _) =>
               with_state (rewrite cx simpset (prem_proofs ()) ct)

             (* rules without reconstruction *)
           | (P.NnfStar, _) => not_supported r
           | (P.CnfStar, _) => not_supported r
           | (P.TransitivityStar, _) => not_supported r
           | (P.PullQuantStar, _) => not_supported r

           | _ => bad_arity ())
         end

       (* Run step (with tracing) and record its proof as Proved under idx. *)
       fun conclude idx rule prop (ps, state) =
         let val (p, (cx, ptab)) = trace_rule idx step rule ps prop state
         in (p, (cx, Inttab.update (idx, Proved p) ptab)) end

       (* Proof of step idx; unproved steps get their premises proved first. *)
       fun lookup idx (state as (_, ptab)) =
         (case Inttab.lookup ptab idx of
           SOME (Proved p) => (p, state)
         | SOME (Unproved (P.Proof_Step {rule, prems, prop})) =>
             let val (prem_ps, state') = fold_map lookup prems state
             in conclude idx rule prop (map2 rpair prems prem_ps, state') end
         | NONE => z3_exn ("unknown proof id: " ^ quote (string_of_int idx)))
     in
       fn (idx, ptab) =>
         let val (p, (cx, _)) = lookup idx (ctxt, Inttab.map Unproved ptab)
         in (thm_of p, cx) end
     end
       
   814 
       
   (* Parse the Z3 proof output using the translation's typs and terms, then
      reconstruct the step idx that the parser returns, in the context cx
      produced by parsing. *)
   fun reconstruct (output, {typs, terms, unfolds, assms}) ctxt =
     let val (idx, (ptab, vars, cx)) = P.parse ctxt typs terms output
     in prove cx unfolds assms vars (idx, ptab) end
       
   818 
       
   (* Theory setup: apply trace_assms_setup, then z3_rules_setup, then
      Z3_Simps.setup. *)
   fun setup thy =
     thy
     |> trace_assms_setup
     |> z3_rules_setup
     |> Z3_Simps.setup
       
   820 
       
   821 end