src/HOL/HOLCF/Tools/Domain/domain_induction.ML
author wenzelm
Wed, 12 Oct 2011 22:48:23 +0200
changeset 46004 2214ba5bdfff
parent 44951 53d95b52954c
child 46527 cf10bde35973
permissions -rw-r--r--
modernized structure Induct_Tacs;
     1 (*  Title:      HOL/HOLCF/Tools/Domain/domain_induction.ML
     2     Author:     David von Oheimb
     3     Author:     Brian Huffman
     4 
     5 Proofs of high-level (co)induction rules for the domain command.
     6 *)
     7 
     8 signature DOMAIN_INDUCTION =
     9 sig
         (* comp_theorems dbinds take_info constr_infos thy proves the take
            rewrite rules of the domains dbinds and, unless the definition is
            indirectly recursive, their (co)induction rules; it returns the
            take rewrite rules together with the extended theory. *)
    10   val comp_theorems :
    11       binding list ->
    12       Domain_Take_Proofs.take_induct_info ->
    13       Domain_Constructors.constr_info list ->
    14       theory -> thm list * theory
    15
         (* quiet_mode: when set, progress messages are suppressed.
            trace_domain: when set, individual proof steps are traced. *)
    16   val quiet_mode: bool Unsynchronized.ref
    17   val trace_domain: bool Unsynchronized.ref
    18 end
    19 
    20 structure Domain_Induction : DOMAIN_INDUCTION =
    21 struct
       (* Proves the high-level take rewrite rules, (finite) induction rules and
          coinduction rules for a group of domains introduced together, and
          stores them in the theory (entry point: comp_theorems). *)
    22
       (* Global switches: quiet_mode suppresses the progress output of
          [message]; trace_domain enables the [tracing] output of [trace]. *)
    23 val quiet_mode = Unsynchronized.ref false
    24 val trace_domain = Unsynchronized.ref false
    25
    26 fun message s = if !quiet_mode then () else writeln s
    27 fun trace s = if !trace_domain then tracing s else ()
    28
       (* NOTE(review): presumably the source of the unqualified term builders
          used below (mk_capply, mk_trp, mk_eq, mk_bottom, ...) -- confirm. *)
    29 open HOLCF_Library
    30
    31 (******************************************************************************)
    32 (***************************** proofs about take ******************************)
    33 (******************************************************************************)
    34
       (* For every domain, prove one rewrite rule per constructor
            take (Suc n) $ (con $ x1 $ ... $ xk) = con $ (m1 $ x1) $ ... $ (mk $ xk)
          where mi is the map of argument type Ti in which each domain type is
          replaced by its "take n"; the application is omitted when that map is
          ID.  The rules of each domain are stored as <dbind>.take_rews with the
          simp attribute.  Returns one theorem list per domain (in dbinds order)
          together with the extended theory. *)
    35 fun take_theorems
    36     (dbinds : binding list)
    37     (take_info : Domain_Take_Proofs.take_induct_info)
    38     (constr_infos : Domain_Constructors.constr_info list)
    39     (thy : theory) : thm list list * theory =
    40 let
    41   val {take_consts, take_Suc_thms, deflation_take_thms, ...} = take_info
    42   val deflation_thms = Domain_Take_Proofs.get_deflation_thms thy
    43
    44   val n = Free ("n", @{typ nat})
    45   val n' = @{const Suc} $ n
    46
    47   local
    48     val newTs = map (#absT o #iso_info) constr_infos
    49     val subs = newTs ~~ map (fn t => t $ n) take_consts
    50     fun is_ID (Const (c, _)) = (c = @{const_name ID})
    51       | is_ID _              = false
    52   in
           (* Apply to v the map of type T in which every domain type is
              replaced by its "take n" (see subs); v itself is returned when
              that map is ID. *)
    53     fun map_of_arg thy v T =
    54       let val m = Domain_Take_Proofs.map_of_typ thy subs T
    55       in if is_ID m then v else mk_capply (m, v) end
    56   end
    57
         (* Prove and store the take_rews rules of a single domain. *)
    58   fun prove_take_apps
    59       ((dbind, take_const), constr_info) thy =
    60     let
    61       val {iso_info, con_specs, con_betas, ...} : Domain_Constructors.constr_info = constr_info
    62       val {abs_inverse, ...} = iso_info
             (* The rule for one constructor; argument variables get fresh
                names distinct from "n". *)
    63       fun prove_take_app (con_const, args) =
    64         let
    65           val Ts = map snd args
    66           val ns = Name.variant_list ["n"] (Datatype_Prop.make_tnames Ts)
    67           val vs = map Free (ns ~~ Ts)
    68           val lhs = mk_capply (take_const $ n', list_ccomb (con_const, vs))
    69           val rhs = list_ccomb (con_const, map2 (map_of_arg thy) vs Ts)
    70           val goal = mk_trp (mk_eq (lhs, rhs))
    71           val rules =
    72               [abs_inverse] @ con_betas @ @{thms take_con_rules}
    73               @ take_Suc_thms @ deflation_thms @ deflation_take_thms
    74           val tac = simp_tac (HOL_basic_ss addsimps rules) 1
    75         in
    76           Goal.prove_global thy [] [] goal (K tac)
    77         end
    78       val take_apps = map prove_take_app con_specs
    79     in
    80       yield_singleton Global_Theory.add_thmss
    81         ((Binding.qualified true "take_rews" dbind, take_apps),
    82         [Simplifier.simp_add]) thy
    83     end
    84 in
    85   fold_map prove_take_apps
    86     (dbinds ~~ take_consts ~~ constr_infos) thy
    87 end
    88
    89 (******************************************************************************)
    90 (****************************** induction rules *******************************)
    91 (******************************************************************************)
    92
       (* Case split on bottom: ALL x. P x follows from P UU and from P x for
          every x ~= UU.  Used in prove_induction for strict (non-lazy)
          constructor arguments. *)
    93 val case_UU_allI =
    94     @{lemma "(!!x. x ~= UU ==> P x) ==> P UU ==> ALL x. P x" by metis}
    95
       (* Prove and store, for the domain group comp_dbind,
            <comp>.finite_induct:  P1 (take1 n $ x1) & ... & Pk (takek n $ xk)
            <comp>.induct:         P1 x1 & ... & Pk xk
          from the assumptions Pi UU plus one case per constructor of each
          domain (strict arguments are assumed defined; arguments of a domain
          type get induction hypotheses).  For non-finite domains, induct also
          assumes adm Pi for every i.  Each conjunct of induct is additionally
          stored as an unnamed rule with case names and an Induct.induct_type
          attribute for its domain type.  take_rews: take rewrite rules of all
          domains. *)
    96 fun prove_induction
    97     (comp_dbind : binding)
    98     (constr_infos : Domain_Constructors.constr_info list)
    99     (take_info : Domain_Take_Proofs.take_induct_info)
   100     (take_rews : thm list)
   101     (thy : theory) =
   102 let
   103   val comp_dname = Binding.name_of comp_dbind
   104
   105   val iso_infos = map #iso_info constr_infos
   106   val exhausts = map #exhaust constr_infos
   107   val con_rews = maps #con_rews constr_infos
   108   val {take_consts, take_induct_thms, ...} = take_info
   109
   110   val newTs = map #absT iso_infos
   111   val P_names = Datatype_Prop.indexify_names (map (K "P") newTs)
   112   val x_names = Datatype_Prop.indexify_names (map (K "x") newTs)
   113   val P_types = map (fn T => T --> HOLogic.boolT) newTs
   114   val Ps = map Free (P_names ~~ P_types)
   115   val xs = map Free (x_names ~~ newTs)
   116   val n = Free ("n", HOLogic.natT)
   117
         (* Induction case for constructor con, with fresh argument variables
            vs (named apart from P_names):
              !!vs. [| vi ~= UU  for strict args, only when defined |] ==>
                    [| P' vi     for args of a domain type |] ==> p (con $ vs) *)
   118   fun con_assm defined p (con, args) =
   119     let
   120       val Ts = map snd args
   121       val ns = Name.variant_list P_names (Datatype_Prop.make_tnames Ts)
   122       val vs = map Free (ns ~~ Ts)
   123       val nonlazy = map snd (filter_out (fst o fst) (args ~~ vs))
   124       fun ind_hyp (v, T) t =
   125           case AList.lookup (op =) (newTs ~~ Ps) T of NONE => t
   126           | SOME p' => Logic.mk_implies (mk_trp (p' $ v), t)
   127       val t1 = mk_trp (p $ list_ccomb (con, vs))
   128       val t2 = fold_rev ind_hyp (vs ~~ Ts) t1
   129       val t3 = Logic.list_implies (map (mk_trp o mk_defined) nonlazy, t2)
   130     in fold_rev Logic.all vs (if defined then t3 else t2) end
         (* All assumptions for one domain: the bottom case, followed by one
            case per constructor (with definedness conditions). *)
   131   fun eq_assms ((p, T), cons) =
   132       mk_trp (p $ HOLCF_Library.mk_bottom T) :: map (con_assm true p) cons
   133   val assms = maps eq_assms (Ps ~~ newTs ~~ map #con_specs constr_infos)
   134
         (* NOTE(review): Rep_cfun_strict1 presumably rewrites UU $ x to UU. *)
   135   val take_ss = HOL_ss addsimps (@{thm Rep_cfun_strict1} :: take_rews)
         (* Resolve subgoal i with spec once per name in x_names, instantiating
            spec's variable x with that name each time. *)
   136   fun quant_tac ctxt i = EVERY
   137     (map (fn name => res_inst_tac ctxt [(("x", 0), name)] spec i) x_names)
   138
   139   (* FIXME: move this message to domain_take_proofs.ML *)
   140   val is_finite = #is_finite take_info
   141   val _ = if is_finite
   142           then message ("Proving finiteness rule for domain "^comp_dname^" ...")
   143           else ()
   144
   145   val _ = trace " Proving finite_ind..."
         (* finite_ind: induction on n, followed by a case split with each
            domain's exhaust rule on the goal variable x.  The tactic below is
            sensitive to step and subgoal order; change with care. *)
   146   val finite_ind =
   147     let
   148       val concls =
   149           map (fn ((P, t), x) => P $ mk_capply (t $ n, x))
   150               (Ps ~~ take_consts ~~ xs)
   151       val goal = mk_trp (foldr1 mk_conj concls)
   152
   153       fun tacf {prems, context = ctxt} =
   154         let
   155           (* Prove stronger prems, without definedness side conditions *)
                 (* Each strict argument is split on UU with case_UU_allI (lazy
                    ones only get allI); the remaining goals are left to the
                    simplifier with prems, con_rews and simp_thms. *)
   156           fun con_thm p (con, args) =
   157             let
   158               val subgoal = con_assm false p (con, args)
   159               val rules = prems @ con_rews @ simp_thms
   160               val simplify = asm_simp_tac (HOL_basic_ss addsimps rules)
   161               fun arg_tac (lazy, _) =
   162                   rtac (if lazy then allI else case_UU_allI) 1
   163               val tacs =
   164                   rewrite_goals_tac @{thms atomize_all atomize_imp} ::
   165                   map arg_tac args @
   166                   [REPEAT (rtac impI 1), ALLGOALS simplify]
   167             in
   168               Goal.prove ctxt [] [] subgoal (K (EVERY tacs))
   169             end
   170           fun eq_thms (p, cons) = map (con_thm p) cons
   171           val conss = map #con_specs constr_infos
   172           val prems' = maps eq_thms (Ps ~~ conss)
   173
                 (* Quantify over the x variables, induct on n, simplify
                    subgoal 1 (presumably the n = 0 case) with take_ss and
                    prems, then TRY safe_tac on what remains. *)
   174           val tacs1 = [
   175             quant_tac ctxt 1,
   176             simp_tac HOL_ss 1,
   177             Induct_Tacs.induct_tac ctxt [[SOME "n"]] 1,
   178             simp_tac (take_ss addsimps prems) 1,
   179             TRY (safe_tac (put_claset HOL_cs ctxt))]
                 (* Constructor case: simplify with take_ss, then resolve with
                    one of prems' and eliminate with spec on every new
                    subgoal. *)
   180           fun con_tac _ = 
   181             asm_simp_tac take_ss 1 THEN
   182             (resolve_tac prems' THEN_ALL_NEW etac spec) 1
                 (* Per domain: case split on x with its exhaust rule, simplify
                    the first case (presumably x = UU) with take_ss and prems,
                    then one con_tac per constructor. *)
   183           fun cases_tacs (cons, exhaust) =
   184             res_inst_tac ctxt [(("y", 0), "x")] exhaust 1 ::
   185             asm_simp_tac (take_ss addsimps prems) 1 ::
   186             map con_tac cons
   187           val tacs = tacs1 @ maps cases_tacs (conss ~~ exhausts)
   188         in
                 (* DETERM: every step commits to its first result *)
   189           EVERY (map DETERM tacs)
   190         end
   191     in Goal.prove_global thy [] assms goal tacf end
   192
   193   val _ = trace " Proving ind..."
         (* ind: each conjunct is reduced with the domain's take_induct rule to
            the corresponding projection of finite_ind.  For non-finite domains,
            the first subgoal after take_induct is resolved with the prems
            (presumably the adm assumption). *)
   194   val ind =
   195     let
   196       val concls = map (op $) (Ps ~~ xs)
   197       val goal = mk_trp (foldr1 mk_conj concls)
   198       val adms = if is_finite then [] else map (mk_trp o mk_adm) Ps
   199       fun tacf {prems, context = ctxt} =
   200         let
   201           fun finite_tac (take_induct, fin_ind) =
   202               rtac take_induct 1 THEN
   203               (if is_finite then all_tac else resolve_tac prems 1) THEN
   204               (rtac fin_ind THEN_ALL_NEW solve_tac prems) 1
   205           val fin_inds = Project_Rule.projections ctxt finite_ind
   206         in
   207           TRY (safe_tac (put_claset HOL_cs ctxt)) THEN
   208           EVERY (map finite_tac (take_induct_thms ~~ fin_inds))
   209         end
   210     in Goal.prove_global thy [] (adms @ assms) goal tacf end
   211
   212   (* case names for induction rules *)
         (* Order: adm (non-finite only), then per domain bottom followed by
            the constructor base names; adm and bottom get a _<domain> suffix
            only when the group has more than one domain. *)
   213   val dnames = map (fst o dest_Type) newTs
   214   val case_ns =
   215     let
   216       val adms =
   217           if is_finite then [] else
   218           if length dnames = 1 then ["adm"] else
   219           map (fn s => "adm_" ^ Long_Name.base_name s) dnames
   220       val bottoms =
   221           if length dnames = 1 then ["bottom"] else
   222           map (fn s => "bottom_" ^ Long_Name.base_name s) dnames
   223       fun one_eq bot (constr_info : Domain_Constructors.constr_info) =
   224         let fun name_of (c, _) = Long_Name.base_name (fst (dest_Const c))
   225         in bot :: map name_of (#con_specs constr_info) end
   226     in adms @ flat (map2 one_eq bottoms constr_infos) end
   227
         (* One rule per domain, projected from the conjunction ind and stored
            without a name, carrying the case names above and declared with
            Induct.induct_type for its domain type. *)
   228   val inducts = Project_Rule.projections (Proof_Context.init_global thy) ind
   229   fun ind_rule (dname, rule) =
   230       ((Binding.empty, rule),
   231        [Rule_Cases.case_names case_ns, Induct.induct_type dname])
   232
   233 in
   234   thy
   235   |> snd o Global_Theory.add_thms [
   236      ((Binding.qualified true "finite_induct" comp_dbind, finite_ind), []),
   237      ((Binding.qualified true "induct"        comp_dbind, ind       ), [])]
   238   |> (snd o Global_Theory.add_thms (map ind_rule (dnames ~~ inducts)))
   239 end (* prove_induction *)
   240
   241 (******************************************************************************)
   242 (************************ bisimulation and coinduction ************************)
   243 (******************************************************************************)
   244
       (* Declare and define the bisimulation predicate <comp>_bisim (definition
          stored as <comp>.bisim_def), prove by induction on n the lemma
            <comp>_bisim R1 ... Rk ==>
              (ALL x y. R1 x y --> take1 n $ x = take1 n $ y) & ...
          and store one rule per domain as <dom>.coinduct:
            <comp>_bisim R1 ... Rk ==> Ri x y ==> x = y
          take_rews holds the take_rews rules of each domain, in the same order
          as constr_infos. *)
   245 fun prove_coinduction
   246     (comp_dbind : binding, dbinds : binding list)
   247     (constr_infos : Domain_Constructors.constr_info list)
   248     (take_info : Domain_Take_Proofs.take_induct_info)
   249     (take_rews : thm list list)
   250     (thy : theory) : theory =
   251 let
   252   val iso_infos = map #iso_info constr_infos
   253   val newTs = map #absT iso_infos
   254
   255   val {take_consts, take_0_thms, take_lemma_thms, ...} = take_info
   256
   257   val R_names = Datatype_Prop.indexify_names (map (K "R") newTs)
   258   val R_types = map (fn T => T --> T --> boolT) newTs
   259   val Rs = map Free (R_names ~~ R_types)
   260   val n = Free ("n", natT)
         (* names that constructor argument variables must avoid *)
   261   val reserved = "x" :: "y" :: R_names
   262
   263   (* declare bisimulation predicate *)
   264   val bisim_bind = Binding.suffix_name "_bisim" comp_dbind
   265   val bisim_type = R_types ---> boolT
   266   val (bisim_const, thy) =
   267       Sign.declare_const_global ((bisim_bind, bisim_type), NoSyn) thy
   268
   269   (* define bisimulation predicate *)
   270   local
           (* Disjunct for one constructor:
                EX vs1 vs2. (related args) & x = con $ vs1 & y = con $ vs2
              where arguments of a domain type are related by that domain's R
              and all other arguments by equality; vs2 are vs1's names with a
              prime appended. *)
   271     fun one_con T (con, args) =
   272       let
   273         val Ts = map snd args
   274         val ns1 = Name.variant_list reserved (Datatype_Prop.make_tnames Ts)
   275         val ns2 = map (fn n => n^"'") ns1
   276         val vs1 = map Free (ns1 ~~ Ts)
   277         val vs2 = map Free (ns2 ~~ Ts)
   278         val eq1 = mk_eq (Free ("x", T), list_ccomb (con, vs1))
   279         val eq2 = mk_eq (Free ("y", T), list_ccomb (con, vs2))
   280         fun rel ((v1, v2), T) =
   281             case AList.lookup (op =) (newTs ~~ Rs) T of
   282               NONE => mk_eq (v1, v2) | SOME r => r $ v1 $ v2
   283         val eqs = foldr1 mk_conj (map rel (vs1 ~~ vs2 ~~ Ts) @ [eq1, eq2])
   284       in
   285         Library.foldr mk_ex (vs1 @ vs2, eqs)
   286       end
           (* Conjunct for one domain:
                ALL x y. R x y --> (x = UU & y = UU) | <one disjunct per constructor> *)
   287     fun one_eq ((T, R), cons) =
   288       let
   289         val x = Free ("x", T)
   290         val y = Free ("y", T)
   291         val disj1 = mk_conj (mk_eq (x, mk_bottom T), mk_eq (y, mk_bottom T))
   292         val disjs = disj1 :: map (one_con T) cons
   293       in
   294         mk_all (x, mk_all (y, mk_imp (R $ x $ y, foldr1 mk_disj disjs)))
   295       end
   296     val conjs = map one_eq (newTs ~~ Rs ~~ map #con_specs constr_infos)
   297     val bisim_rhs = lambdas Rs (Library.foldr1 mk_conj conjs)
   298     val bisim_eqn = Logic.mk_equals (bisim_const, bisim_rhs)
   299   in
   300     val (bisim_def_thm, thy) = thy |>
   301         yield_singleton (Global_Theory.add_defs false)
   302          ((Binding.qualified true "bisim_def" comp_dbind, bisim_eqn), [])
   303   end (* local *)
   304
   305   (* prove coinduction lemma *)
         (* rtac nat.induct; subgoal 1 (presumably n = 0) is simplified with
            Rep_cfun_strict1 and take_0_thms; then, per domain, the bisim
            clause obtained from bisim_def is applied with dtac and the
            resulting goals are simplified with that domain's take_rews. *)
   306   val coind_lemma =
   307     let
   308       val assm = mk_trp (list_comb (bisim_const, Rs))
   309       fun one ((T, R), take_const) =
   310         let
   311           val x = Free ("x", T)
   312           val y = Free ("y", T)
   313           val lhs = mk_capply (take_const $ n, x)
   314           val rhs = mk_capply (take_const $ n, y)
   315         in
   316           mk_all (x, mk_all (y, mk_imp (R $ x $ y, mk_eq (lhs, rhs))))
   317         end
   318       val goal =
   319           mk_trp (foldr1 mk_conj (map one (newTs ~~ Rs ~~ take_consts)))
   320       val rules = @{thm Rep_cfun_strict1} :: take_0_thms
   321       fun tacf {prems, context = ctxt} =
   322         let
   323           val prem' = rewrite_rule [bisim_def_thm] (hd prems)
   324           val prems' = Project_Rule.projections ctxt prem'
   325           val dests = map (fn th => th RS spec RS spec RS mp) prems'
   326           fun one_tac (dest, rews) =
   327               dtac dest 1 THEN safe_tac (put_claset HOL_cs ctxt) THEN
   328               ALLGOALS (asm_simp_tac (HOL_basic_ss addsimps rews))
   329         in
   330           rtac @{thm nat.induct} 1 THEN
   331           simp_tac (HOL_ss addsimps rules) 1 THEN
   332           safe_tac (put_claset HOL_cs ctxt) THEN
   333           EVERY (map one_tac (dests ~~ take_rews))
   334         end
   335     in
   336       Goal.prove_global thy [] [assm] goal tacf
   337     end
   338
   339   (* prove individual coinduction rules *)
         (* x = y via the domain's take_lemma; its premise is simplified with
            coind_lemma discharged by the bisim premise, plus the prems. *)
   340   fun prove_coind ((T, R), take_lemma) =
   341     let
   342       val x = Free ("x", T)
   343       val y = Free ("y", T)
   344       val assm1 = mk_trp (list_comb (bisim_const, Rs))
   345       val assm2 = mk_trp (R $ x $ y)
   346       val goal = mk_trp (mk_eq (x, y))
   347       fun tacf {prems, context = _} =
   348         let
   349           val rule = hd prems RS coind_lemma
   350         in
   351           rtac take_lemma 1 THEN
   352           asm_simp_tac (HOL_basic_ss addsimps (rule :: prems)) 1
   353         end
   354     in
   355       Goal.prove_global thy [] [assm1, assm2] goal tacf
   356     end
   357   val coinds = map prove_coind (newTs ~~ Rs ~~ take_lemma_thms)
   358   val coind_binds = map (Binding.qualified true "coinduct") dbinds
   359
   360 in
   361   thy |> snd o Global_Theory.add_thms
   362     (map Thm.no_attributes (coind_binds ~~ coinds))
   363 end (* let *)
   364
   365 (******************************************************************************)
   366 (******************************* main function ********************************)
   367 (******************************************************************************)
   368
       (* Exported entry point: prove the take_rews rules of all domains in
          dbinds and, unless the definition is indirectly recursive, the
          induction and coinduction rules of the whole group (named after the
          domain names joined by "_").  Returns take_0_thms, take_strict_thms
          and all take_rews rules, together with the extended theory. *)
   369 fun comp_theorems
   370     (dbinds : binding list)
   371     (take_info : Domain_Take_Proofs.take_induct_info)
   372     (constr_infos : Domain_Constructors.constr_info list)
   373     (thy : theory) =
   374 let
   375
   376 val comp_dname = space_implode "_" (map Binding.name_of dbinds)
   377 val comp_dbind = Binding.name comp_dname
   378
   379 (* Test for emptiness *)
   380 (* FIXME: reimplement emptiness test
   381 local
   382   open Domain_Library
   383   val dnames = map (fst o fst) eqs
   384   val conss = map snd eqs
   385   fun rec_to ns lazy_rec (n,cons) = forall (exists (fn arg => 
   386         is_rec arg andalso not (member (op =) ns (rec_of arg)) andalso
   387         ((rec_of arg =  n andalso not (lazy_rec orelse is_lazy arg)) orelse 
   388           rec_of arg <> n andalso rec_to (rec_of arg::ns) 
   389             (lazy_rec orelse is_lazy arg) (n, nth conss (rec_of arg)))
   390         ) o snd) cons
   391   fun warn (n,cons) =
   392     if rec_to [] false (n,cons)
   393     then (warning ("domain " ^ nth dnames n ^ " is empty!") true)
   394     else false
   395 in
   396   val n__eqs = mapn (fn n => fn (_,cons) => (n,cons)) 0 eqs
   397   val is_emptys = map warn n__eqs
   398 end
   399 *)
   400
   401 (* Test for indirect recursion *)
       (* The definition counts as indirectly recursive if some constructor
          argument type contains one of the new domain types nested as an
          argument of another type constructor; an argument whose type is itself
          a domain type does not count by itself.  In that case the
          (co)induction proofs are skipped. *)
   402 local
   403   val newTs = map (#absT o #iso_info) constr_infos
   404   fun indirect_typ (Type (_, Ts)) =
   405       exists (fn T => member (op =) newTs T orelse indirect_typ T) Ts
   406     | indirect_typ _ = false
   407   fun indirect_arg (_, T) = indirect_typ T
   408   fun indirect_con (_, args) = exists indirect_arg args
   409   fun indirect_eq cons = exists indirect_con cons
   410 in
   411   val is_indirect = exists indirect_eq (map #con_specs constr_infos)
   412   val _ =
   413       if is_indirect
   414       then message "Indirect recursion detected, skipping proofs of (co)induction rules"
   415       else message ("Proving induction properties of domain "^comp_dname^" ...")
   416 end
   417
   418 (* theorems about take *)
   419
   420 val (take_rewss, thy) =
   421     take_theorems dbinds take_info constr_infos thy
   422
   423 val {take_0_thms, take_strict_thms, ...} = take_info
   424
   425 val take_rews = take_0_thms @ take_strict_thms @ flat take_rewss
   426
   427 (* prove induction rules, unless the definition is indirectly recursive *)
   428 val thy =
   429     if is_indirect then thy else
   430     prove_induction comp_dbind constr_infos take_info take_rews thy
   431
       (* likewise for the coinduction rules *)
   432 val thy =
   433     if is_indirect then thy else
   434     prove_coinduction (comp_dbind, dbinds) constr_infos take_info take_rewss thy
   435
   436 in
   437   (take_rews, thy)
   438 end (* let *)
   439 end (* struct *)