src/HOL/BNF/Tools/coinduction.ML
changeset 56013 d64a4ef26edb
parent 55254 13bfdbcfbbfb
equal deleted inserted replaced
56012:cfb21e03fe2a 56013:d64a4ef26edb
     1 (*  Title:      HOL/BNF/Tools/coinduction.ML
       
     2     Author:     Johannes Hölzl, TU Muenchen
       
     3     Author:     Dmitriy Traytel, TU Muenchen
       
     4     Copyright   2013
       
     5 
       
     6 Coinduction method that avoids some of the boilerplate required when using the coinduct method.
       
     7 *)
       
     8 
       
     9 signature COINDUCTION =
       
    10 sig
       
    11   val coinduction_tac: Proof.context -> term list -> thm option -> thm list -> cases_tactic
       
    12   val setup: theory -> theory
       
    13 end;
       
    14 
       
    15 structure Coinduction : COINDUCTION =
       
    16 struct
       
    17 
       
    18 open BNF_Util
       
    19 open BNF_Tactics
       
    20 
       
    21 fun filter_in_out _ [] = ([], [])
       
    22   | filter_in_out P (x :: xs) = (let
       
    23       val (ins, outs) = filter_in_out P xs;
       
    24     in
       
    25       if P x then (x :: ins, outs) else (ins, x :: outs)
       
    26     end);
       
    27 
       
    28 fun ALLGOALS_SKIP skip tac st =
       
    29   let fun doall n = if n = skip then all_tac else tac n THEN doall (n - 1)
       
    30   in doall (nprems_of st) st  end;
       
    31 
       
    32 fun THEN_ALL_NEW_SKIP skip tac1 tac2 i st =
       
    33   st |> (tac1 i THEN (fn st' =>
       
    34     Seq.INTERVAL tac2 (i + skip) (i + nprems_of st' - nprems_of st) st'));
       
    35 
       
    36 fun DELETE_PREMS_AFTER skip tac i st =
       
    37   let
       
    38     val n = nth (prems_of st) (i - 1) |> Logic.strip_assums_hyp |> length;
       
    39   in
       
    40     (THEN_ALL_NEW_SKIP skip tac (REPEAT_DETERM_N n o etac thin_rl)) i st
       
    41   end;
       
    42 
       
    43 fun coinduction_tac ctxt raw_vars opt_raw_thm prems st =
       
    44   let
       
    45     val lhs_of_eq = HOLogic.dest_Trueprop #> HOLogic.dest_eq #> fst;
       
    46     fun find_coinduct t = 
       
    47       Induct.find_coinductP ctxt t @
       
    48       (try (Induct.find_coinductT ctxt o fastype_of o lhs_of_eq) t |> the_default [])
       
    49     val raw_thm = case opt_raw_thm
       
    50       of SOME raw_thm => raw_thm
       
    51        | NONE => st |> prems_of |> hd |> Logic.strip_assums_concl |> find_coinduct |> hd;
       
    52     val skip = Integer.max 1 (Rule_Cases.get_consumes raw_thm) - 1
       
    53     val cases = Rule_Cases.get raw_thm |> fst
       
    54   in
       
    55     NO_CASES (HEADGOAL (
       
    56       Object_Logic.rulify_tac THEN'
       
    57       Method.insert_tac prems THEN'
       
    58       Object_Logic.atomize_prems_tac THEN'
       
    59       DELETE_PREMS_AFTER skip (Subgoal.FOCUS (fn {concl, context = ctxt, params, prems, ...} =>
       
    60         let
       
    61           val vars = raw_vars @ map (term_of o snd) params;
       
    62           val names_ctxt = ctxt
       
    63             |> fold Variable.declare_names vars
       
    64             |> fold Variable.declare_thm (raw_thm :: prems);
       
    65           val thm_concl = Thm.cprop_of raw_thm |> strip_imp_concl;
       
    66           val (rhoTs, rhots) = Thm.match (thm_concl, concl)
       
    67             |>> map (pairself typ_of)
       
    68             ||> map (pairself term_of);
       
    69           val xs = hd (Thm.prems_of raw_thm) |> HOLogic.dest_Trueprop |> strip_comb |> snd
       
    70             |> map (subst_atomic_types rhoTs);
       
    71           val raw_eqs = map (fn x => (x, AList.lookup op aconv rhots x |> the)) xs;
       
    72           val ((names, ctxt), Ts) = map_split (apfst fst o dest_Var o fst) raw_eqs
       
    73             |>> (fn names => Variable.variant_fixes names names_ctxt) ;
       
    74           val eqs =
       
    75             map3 (fn name => fn T => fn (_, rhs) =>
       
    76               HOLogic.mk_eq (Free (name, T), rhs))
       
    77             names Ts raw_eqs;
       
    78           val phi = eqs @ map (HOLogic.dest_Trueprop o prop_of) prems
       
    79             |> try (Library.foldr1 HOLogic.mk_conj)
       
    80             |> the_default @{term True}
       
    81             |> list_exists_free vars
       
    82             |> Term.map_abs_vars (Variable.revert_fixed ctxt)
       
    83             |> fold_rev Term.absfree (names ~~ Ts)
       
    84             |> certify ctxt;
       
    85           val thm = cterm_instantiate_pos [SOME phi] raw_thm;
       
    86           val e = length eqs;
       
    87           val p = length prems;
       
    88         in
       
    89           HEADGOAL (EVERY' [rtac thm,
       
    90             EVERY' (map (fn var =>
       
    91               rtac (cterm_instantiate_pos [NONE, SOME (certify ctxt var)] exI)) vars),
       
    92             if p = 0 then CONJ_WRAP' (K (rtac refl)) eqs
       
    93             else REPEAT_DETERM_N e o (rtac conjI THEN' rtac refl) THEN' CONJ_WRAP' rtac prems,
       
    94             K (ALLGOALS_SKIP skip
       
    95                (REPEAT_DETERM_N (length vars) o (etac exE THEN' rotate_tac ~1) THEN'
       
    96                DELETE_PREMS_AFTER 0 (Subgoal.FOCUS (fn {prems, params, context = ctxt, ...} =>
       
    97                  (case prems of
       
    98                    [] => all_tac
       
    99                  | inv::case_prems =>
       
   100                      let
       
   101                        val (init, last) = funpow_yield (p + e - 1) HOLogic.conj_elim inv;
       
   102                        val inv_thms = init @ [last];
       
   103                        val eqs = take e inv_thms;
       
   104                        fun is_local_var t = 
       
   105                          member (fn (t, (_, t')) => t aconv (term_of t')) params t;
       
   106                         val (eqs, assms') = filter_in_out (is_local_var o lhs_of_eq o prop_of) eqs;
       
   107                         val assms = assms' @ drop e inv_thms
       
   108                       in
       
   109                         HEADGOAL (Method.insert_tac (assms @ case_prems)) THEN
       
   110                         unfold_thms_tac ctxt eqs
       
   111                       end)) ctxt)))])
       
   112         end) ctxt) THEN'
       
   113       K (prune_params_tac))) st
       
   114     |> Seq.maps (fn (_, st) =>
       
   115       CASES (Rule_Cases.make_common (Proof_Context.theory_of ctxt, prop_of st) cases) all_tac st)
       
   116   end;
       
   117 
       
   118 local
       
   119 
       
   120 val ruleN = "rule"
       
   121 val arbitraryN = "arbitrary"
       
   122 fun single_rule [rule] = rule
       
   123   | single_rule _ = error "Single rule expected";
       
   124 
       
   125 fun named_rule k arg get =
       
   126   Scan.lift (Args.$$$ k -- Args.colon) |-- Scan.repeat arg :|--
       
   127     (fn names => Scan.peek (fn context => Scan.succeed (names |> map (fn name =>
       
   128       (case get (Context.proof_of context) name of SOME x => x
       
   129       | NONE => error ("No rule for " ^ k ^ " " ^ quote name))))));
       
   130 
       
   131 fun rule get_type get_pred =
       
   132   named_rule Induct.typeN (Args.type_name false) get_type ||
       
   133   named_rule Induct.predN (Args.const false) get_pred ||
       
   134   named_rule Induct.setN (Args.const false) get_pred ||
       
   135   Scan.lift (Args.$$$ ruleN -- Args.colon) |-- Attrib.thms;
       
   136 
       
   137 val coinduct_rule = rule Induct.lookup_coinductT Induct.lookup_coinductP >> single_rule;
       
   138 
       
   139 fun unless_more_args scan = Scan.unless (Scan.lift
       
   140   ((Args.$$$ arbitraryN || Args.$$$ Induct.typeN ||
       
   141     Args.$$$ Induct.predN || Args.$$$ Induct.setN || Args.$$$ ruleN) -- Args.colon)) scan;
       
   142 
       
   143 val arbitrary = Scan.optional (Scan.lift (Args.$$$ arbitraryN -- Args.colon) |--
       
   144   Scan.repeat1 (unless_more_args Args.term)) [];
       
   145 
       
   146 in
       
   147 
       
   148 val setup =
       
   149   Method.setup @{binding coinduction}
       
   150     (arbitrary -- Scan.option coinduct_rule >>
       
   151       (fn (arbitrary, opt_rule) => fn ctxt =>
       
   152         RAW_METHOD_CASES (fn facts =>
       
   153           Seq.DETERM (coinduction_tac ctxt arbitrary opt_rule facts))))
       
   154     "coinduction on types or predicates/sets";
       
   155 
       
   156 end;
       
   157 
       
   158 end;
       
   159