src/HOL/Tools/Function/partial_function.ML
author wenzelm
Thu, 15 Mar 2012 20:07:00 +0100
changeset 47823 94aa7b81bcf6
parent 46277 7a0b8debef77
child 47836 5c6955f487e5
permissions -rw-r--r--
prefer formally checked @{keyword} parser;
     1 (*  Title:      HOL/Tools/Function/partial_function.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 Partial function definitions based on least fixed points in ccpos.
     5 *)
     6 
(*Public interface of the partial_function package.*)
signature PARTIAL_FUNCTION =
sig
  (*theory setup: installs the Mono_Rules theorem collection*)
  val setup: theory -> theory
  (*init mode fixp mono fixp_eq fixp_induct: declaration registering a
    definition mode under the given name, with its fixed-point constant,
    monotonicity predicate, fixed-point equation and optional induction rule*)
  val init: string -> term -> term -> thm -> thm option -> declaration

  (*define one partial function from typed fixes and one equation term
    (checked via Specification.check_spec)*)
  val add_partial_function: string -> (binding * typ option * mixfix) list ->
    Attrib.binding * term -> local_theory -> local_theory

  (*same, for unparsed string input (read via Specification.read_spec)*)
  val add_partial_function_cmd: string -> (binding * string option * mixfix) list ->
    Attrib.binding * string -> local_theory -> local_theory
end;
    18 
    19 
    20 structure Partial_Function: PARTIAL_FUNCTION =
    21 struct
    22 
    23 (*** Context Data ***)
    24 
    25 datatype setup_data = Setup_Data of 
    26  {fixp: term,
    27   mono: term,
    28   fixp_eq: thm,
    29   fixp_induct: thm option};
    30 
    31 structure Modes = Generic_Data
    32 (
    33   type T = setup_data Symtab.table;
    34   val empty = Symtab.empty;
    35   val extend = I;
    36   fun merge data = Symtab.merge (K true) data;
    37 )
    38 
    39 fun init mode fixp mono fixp_eq fixp_induct phi =
    40   let
    41     val term = Morphism.term phi;
    42     val thm = Morphism.thm phi;
    43     val data' = Setup_Data 
    44       {fixp=term fixp, mono=term mono, fixp_eq=thm fixp_eq,
    45        fixp_induct=Option.map thm fixp_induct};
    46   in
    47     Modes.map (Symtab.update (mode, data'))
    48   end
    49 
    50 val known_modes = Symtab.keys o Modes.get o Context.Proof;
    51 val lookup_mode = Symtab.lookup o Modes.get o Context.Proof;
    52 
    53 
(*Named theorem collection "partial_function_mono": the rules that mono_tac
  tries (via resolve_tac) when proving the monotonicity goal of a definition.*)
structure Mono_Rules = Named_Thms
(
  val name = @{binding partial_function_mono};
  val description = "monotonicity rules for partial function definitions";
);
    59 
    60 
    61 (*** Automated monotonicity proofs ***)
    62 
    63 fun strip_cases ctac = ctac #> Seq.map snd;
    64 
    65 (*rewrite conclusion with k-th assumtion*)
    66 fun rewrite_with_asm_tac ctxt k =
    67   Subgoal.FOCUS (fn {context = ctxt', prems, ...} =>
    68     Local_Defs.unfold_tac ctxt' [nth prems k]) ctxt;
    69 
    70 fun dest_case thy t =
    71   case strip_comb t of
    72     (Const (case_comb, _), args) =>
    73       (case Datatype.info_of_case thy case_comb of
    74          NONE => NONE
    75        | SOME {case_rewrites, ...} =>
    76            let
    77              val lhs = prop_of (hd case_rewrites)
    78                |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> fst;
    79              val arity = length (snd (strip_comb lhs));
    80              val conv = funpow (length args - arity) Conv.fun_conv
    81                (Conv.rewrs_conv (map mk_meta_eq case_rewrites));
    82            in
    83              SOME (nth args (arity - 1), conv)
    84            end)
    85   | _ => NONE;
    86 
    87 (*split on case expressions*)
    88 val split_cases_tac = Subgoal.FOCUS_PARAMS (fn {context=ctxt, ...} =>
    89   SUBGOAL (fn (t, i) => case t of
    90     _ $ (_ $ Abs (_, _, body)) =>
    91       (case dest_case (Proof_Context.theory_of ctxt) body of
    92          NONE => no_tac
    93        | SOME (arg, conv) =>
    94            let open Conv in
    95               if Term.is_open arg then no_tac
    96               else ((DETERM o strip_cases o Induct.cases_tac ctxt false [[SOME arg]] NONE [])
    97                 THEN_ALL_NEW (rewrite_with_asm_tac ctxt 0)
    98                 THEN_ALL_NEW etac @{thm thin_rl}
    99                 THEN_ALL_NEW (CONVERSION
   100                   (params_conv ~1 (fn ctxt' =>
   101                     arg_conv (arg_conv (abs_conv (K conv) ctxt'))) ctxt))) i
   102            end)
   103   | _ => no_tac) 1);
   104 
   105 (*monotonicity proof: apply rules + split case expressions*)
   106 fun mono_tac ctxt =
   107   K (Local_Defs.unfold_tac ctxt [@{thm curry_def}])
   108   THEN' (TRY o REPEAT_ALL_NEW
   109    (resolve_tac (Mono_Rules.get ctxt)
   110      ORELSE' split_cases_tac ctxt));
   111 
   112 
   113 (*** Auxiliary functions ***)
   114 
   115 (*positional instantiation with computed type substitution.
   116   internal version of  attribute "[of s t u]".*)
   117 fun cterm_instantiate' cts thm =
   118   let
   119     val thy = Thm.theory_of_thm thm;
   120     val vs = rev (Term.add_vars (prop_of thm) [])
   121       |> map (Thm.cterm_of thy o Var);
   122   in
   123     cterm_instantiate (zip_options vs cts) thm
   124   end;
   125 
   126 (*Returns t $ u, but instantiates the type of t to make the
   127 application type correct*)
   128 fun apply_inst ctxt t u =
   129   let
   130     val thy = Proof_Context.theory_of ctxt;
   131     val T = domain_type (fastype_of t);
   132     val T' = fastype_of u;
   133     val subst = Sign.typ_match thy (T, T') Vartab.empty
   134       handle Type.TYPE_MATCH => raise TYPE ("apply_inst", [T, T'], [t, u])
   135   in
   136     map_types (Envir.norm_type subst) t $ u
   137   end;
   138 
   139 fun head_conv cv ct =
   140   if can Thm.dest_comb ct then Conv.fun_conv (head_conv cv) ct else cv ct;
   141 
   142 
   143 (*** currying transformation ***)
   144 
   145 fun curry_const (A, B, C) =
   146   Const (@{const_name Product_Type.curry},
   147     [HOLogic.mk_prodT (A, B) --> C, A, B] ---> C);
   148 
   149 fun mk_curry f =
   150   case fastype_of f of
   151     Type ("fun", [Type (_, [S, T]), U]) =>
   152       curry_const (S, T, U) $ f
   153   | T => raise TYPE ("mk_curry", [T], [f]);
   154 
   155 (* iterated versions. Nonstandard left-nested tuples arise naturally
   156 from "split o split o split"*)
   157 fun curry_n arity = funpow (arity - 1) mk_curry;
   158 fun uncurry_n arity = funpow (arity - 1) HOLogic.mk_split;
   159 
   160 val curry_uncurry_ss = HOL_basic_ss addsimps
   161   [@{thm Product_Type.curry_split}, @{thm Product_Type.split_curry}]
   162 
   163 val split_conv_ss = HOL_basic_ss addsimps
   164   [@{thm Product_Type.split_conv}];
   165 
(*Derive a curried induction rule from the mode's induction rule, which is
  stated for the uncurried (tupled) function.
  - args: the argument variables (Frees) of the partial function.
  - ctxt: the context in which args are fixed.
  - ccurry, cuncurry: the certified curry/uncurry abstractions from the caller.
  - rule: the fixp_induct theorem of the mode.*)
fun mk_curried_induct args ctxt ccurry cuncurry rule =
  let
    val cert = Thm.cterm_of (Proof_Context.theory_of ctxt)
    (*fresh predicate name P, fixed only in ctxt'; exported back at the end*)
    val ([P], ctxt') = Variable.variant_fixes ["P"] ctxt

    (*rewrite with split_paired_all once per argument beyond the first*)
    val split_paired_all_conv =
      Conv.every_conv (replicate (length args - 1) (Conv.rewr_conv @{thm split_paired_all}))

    (*under all parameters: apply split_paired_all_conv to the antecedent of
      the implication, leave the conclusion unchanged*)
    val split_params_conv = 
      Conv.params_conv ~1 (fn ctxt' =>
        Conv.implies_conv split_paired_all_conv Conv.all_conv)

    (*positional: first schematic variable := uncurry, third := curry,
      second left alone*)
    val inst_rule =
      cterm_instantiate' [SOME cuncurry, NONE, SOME ccurry] rule

    (*read off from the type of the schematic head of the conclusion (range
      type, then its domain); presumably the plain result type of the
      function -- TODO confirm against the modes' induction rules*)
    val plain_resultT = 
      Thm.prop_of inst_rule |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop
      |> Term.head_of |> Term.dest_Var |> snd |> range_type |> domain_type
    (*predicate type: all argument types, then the result type, to bool*)
    val PT = map (snd o dest_Free) args ---> plain_resultT --> HOLogic.boolT
    (*the arguments as one left-nested tuple*)
    val x_inst = cert (foldl1 HOLogic.mk_prod args)
    (*P uncurried over that tuple*)
    val P_inst = cert (uncurry_n (length args) (Free (P, PT)))

    (*Post-processing of the instantiated rule:
      1. simplify premises 4 and 3 with curry_uncurry_ss;
      2. split paired parameters in premise 3;
      3. instantiate the 3rd/4th schematic variables (positional; presumably
         the predicate and the tuple) with P_inst and x_inst;
      4. evaluate split on pairs;
      5. export from ctxt' back to ctxt.*)
    val inst_rule' = inst_rule
      |> Tactic.rule_by_tactic ctxt
        (Simplifier.simp_tac curry_uncurry_ss 4
         THEN Simplifier.simp_tac curry_uncurry_ss 3
         THEN CONVERSION (split_params_conv ctxt
           then_conv (Conv.forall_conv (K split_paired_all_conv) ctxt)) 3)
      |> Drule.instantiate' [] [NONE, NONE, SOME P_inst, SOME x_inst]
      |> Simplifier.full_simplify split_conv_ss
      |> singleton (Variable.export ctxt' ctxt)
  in
    inst_rule'
  end;
   200     
   201 
   202 (*** partial_function definition ***)
   203 
   204 fun gen_add_partial_function prep mode fixes_raw eqn_raw lthy =
   205   let
   206     val setup_data = the (lookup_mode lthy mode)
   207       handle Option.Option => error (cat_lines ["Unknown mode " ^ quote mode ^ ".",
   208         "Known modes are " ^ commas_quote (known_modes lthy) ^ "."]);
   209     val Setup_Data {fixp, mono, fixp_eq, fixp_induct} = setup_data;
   210 
   211     val ((fixes, [(eq_abinding, eqn)]), _) = prep fixes_raw [eqn_raw] lthy;
   212     val ((_, plain_eqn), args_ctxt) = Variable.focus eqn lthy;
   213 
   214     val ((f_binding, fT), mixfix) = the_single fixes;
   215     val fname = Binding.name_of f_binding;
   216 
   217     val cert = cterm_of (Proof_Context.theory_of lthy);
   218     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop plain_eqn);
   219     val (head, args) = strip_comb lhs;
   220     val argnames = map (fst o dest_Free) args;
   221     val F = fold_rev lambda (head :: args) rhs;
   222 
   223     val arity = length args;
   224     val (aTs, bTs) = chop arity (binder_types fT);
   225 
   226     val tupleT = foldl1 HOLogic.mk_prodT aTs;
   227     val fT_uc = tupleT :: bTs ---> body_type fT;
   228     val f_uc = Var ((fname, 0), fT_uc);
   229     val x_uc = Var (("x", 0), tupleT);
   230     val uncurry = lambda head (uncurry_n arity head);
   231     val curry = lambda f_uc (curry_n arity f_uc);
   232 
   233     val F_uc =
   234       lambda f_uc (uncurry_n arity (F $ curry_n arity f_uc));
   235 
   236     val mono_goal = apply_inst lthy mono (lambda f_uc (F_uc $ f_uc $ x_uc))
   237       |> HOLogic.mk_Trueprop
   238       |> Logic.all x_uc;
   239 
   240     val mono_thm = Goal.prove_internal [] (cert mono_goal)
   241         (K (mono_tac lthy 1))
   242       |> Thm.forall_elim (cert x_uc);
   243 
   244     val f_def_rhs = curry_n arity (apply_inst lthy fixp F_uc);
   245     val f_def_binding = Binding.conceal (Binding.name (Thm.def_name fname));
   246     val ((f, (_, f_def)), lthy') = Local_Theory.define
   247       ((f_binding, mixfix), ((f_def_binding, []), f_def_rhs)) lthy;
   248 
   249     val eqn = HOLogic.mk_eq (list_comb (f, args),
   250         Term.betapplys (F, f :: args))
   251       |> HOLogic.mk_Trueprop;
   252 
   253     val unfold =
   254       (cterm_instantiate' (map (SOME o cert) [uncurry, F, curry]) fixp_eq
   255         OF [mono_thm, f_def])
   256       |> Tactic.rule_by_tactic lthy (Simplifier.simp_tac curry_uncurry_ss 1);
   257 
   258     val mk_raw_induct =
   259       mk_curried_induct args args_ctxt (cert curry) (cert uncurry)
   260       #> singleton (Variable.export args_ctxt lthy)
   261       #> (fn thm => cterm_instantiate' [SOME (cert F)] thm OF [mono_thm, f_def])
   262       #> Drule.rename_bvars' (map SOME (fname :: argnames @ argnames))
   263 
   264     val raw_induct = Option.map mk_raw_induct fixp_induct
   265     val rec_rule = let open Conv in
   266       Goal.prove lthy' (map (fst o dest_Free) args) [] eqn (fn _ =>
   267         CONVERSION ((arg_conv o arg1_conv o head_conv o rewr_conv) (mk_meta_eq unfold)) 1
   268         THEN rtac @{thm refl} 1) end;
   269   in
   270     lthy'
   271     |> Local_Theory.note (eq_abinding, [rec_rule])
   272     |-> (fn (_, rec') =>
   273       Spec_Rules.add Spec_Rules.Equational ([f], rec')
   274       #> Local_Theory.note ((Binding.qualify true fname (Binding.name "simps"), []), rec') #> snd)
   275     |> (case raw_induct of NONE => I | SOME thm =>
   276          Local_Theory.note ((Binding.qualify true fname (Binding.name "raw_induct"), []), [thm]) #> snd)
   277   end;
   278 
   279 val add_partial_function = gen_add_partial_function Specification.check_spec;
   280 val add_partial_function_cmd = gen_add_partial_function Specification.read_spec;
   281 
   282 val mode = @{keyword "("} |-- Parse.xname --| @{keyword ")"};
   283 
   284 val _ = Outer_Syntax.local_theory
   285   "partial_function" "define partial function" Keyword.thy_decl
   286   ((mode -- (Parse.fixes -- (Parse.where_ |-- Parse_Spec.spec)))
   287      >> (fn (mode, (fixes, spec)) => add_partial_function_cmd mode fixes spec));
   288 
   289 
   290 val setup = Mono_Rules.setup;
   291 
   292 end