src/HOL/Tools/Quotient/quotient_def.ML
author kuncar
Fri, 23 Mar 2012 14:25:31 +0100
changeset 47966 3ea48c19673e
parent 47961 d5cd13aca90b
child 48027 861f53bd95fe
permissions -rw-r--r--
generation of a code certificate from a respectfulness theorem for constants lifted by the quotient_definition command; setup_lifting command: sets up the Quotient infrastructure from a typedef theorem
     1 (*  Title:      HOL/Tools/Quotient/quotient_def.ML
     2     Author:     Cezary Kaliszyk and Christian Urban
     3 
     4 Definitions for constants on quotient types.
     5 *)
     6 
     7 signature QUOTIENT_DEF =
     8 sig
     9   val add_quotient_def:
    10     ((binding * mixfix) * Attrib.binding) * (term * term) -> thm ->
    11     local_theory -> Quotient_Info.quotconsts * local_theory
    12 
    13   val quotient_def:
    14     (binding * typ option * mixfix) option * (Attrib.binding * (term * term)) ->
    15     local_theory -> Proof.state
    16 
    17   val quotient_def_cmd:
    18     (binding * string option * mixfix) option * (Attrib.binding * (string * string)) ->
    19     local_theory -> Proof.state
    20 
    21   val lift_raw_const: typ list -> (string * term * mixfix) -> local_theory ->
    22     Quotient_Info.quotconsts * local_theory
    23 end;
    24 
    25 structure Quotient_Def: QUOTIENT_DEF =
    26 struct
    27 
    28 (** Interface and Syntax Setup **)
    29 
    30 (* Generation of the code certificate from the rsp theorem *)
    31 
    32 infix 0 MRSL
    33 
    34 fun ants MRSL thm = fold (fn rl => fn thm => rl RS thm) ants thm
    35 
    36 fun get_body_types (Type ("fun", [_, U]), Type ("fun", [_, V])) = get_body_types (U, V)
    37   | get_body_types (U, V)  = (U, V)
    38 
    39 fun get_binder_types (Type ("fun", [T, U]), Type ("fun", [V, W])) = (T, V) :: get_binder_types (U, W)
    40   | get_binder_types _ = []
    41 
    42 fun unabs_def ctxt def = 
    43   let
    44     val (_, rhs) = Thm.dest_equals (cprop_of def)
    45     fun dest_abs (Abs (var_name, T, _)) = (var_name, T)
    46       | dest_abs tm = raise TERM("get_abs_var",[tm])
    47     val (var_name, T) = dest_abs (term_of rhs)
    48     val (new_var_names, ctxt') = Variable.variant_fixes [var_name] ctxt
    49     val thy = Proof_Context.theory_of ctxt'
    50     val refl_thm = Thm.reflexive (cterm_of thy (Free (hd new_var_names, T)))
    51   in
    52     Thm.combination def refl_thm |>
    53     singleton (Proof_Context.export ctxt' ctxt)
    54   end
    55 
    56 fun unabs_all_def ctxt def = 
    57   let
    58     val (_, rhs) = Thm.dest_equals (cprop_of def)
    59     val xs = strip_abs_vars (term_of rhs)
    60   in  
    61     fold (K (unabs_def ctxt)) xs def
    62   end
    63 
    64 val map_fun_unfolded = 
    65   @{thm map_fun_def[abs_def]} |>
    66   unabs_def @{context} |>
    67   unabs_def @{context} |>
    68   Local_Defs.unfold @{context} [@{thm comp_def}]
    69 
    70 fun unfold_fun_maps ctm =
    71   let
    72     fun unfold_conv ctm =
    73       case (Thm.term_of ctm) of
    74         Const (@{const_name "map_fun"}, _) $ _ $ _ => 
    75           (Conv.arg_conv unfold_conv then_conv Conv.rewr_conv map_fun_unfolded) ctm
    76         | _ => Conv.all_conv ctm
    77     val try_beta_conv = Conv.try_conv (Thm.beta_conversion false)
    78   in
    79     (Conv.arg_conv (Conv.fun_conv unfold_conv then_conv try_beta_conv)) ctm
    80   end
    81 
    82 fun prove_rel ctxt rsp_thm (rty, qty) =
    83   let
    84     val ty_args = get_binder_types (rty, qty)
    85     fun disch_arg args_ty thm = 
    86       let
    87         val quot_thm = Quotient_Term.prove_quot_theorem ctxt args_ty
    88       in
    89         [quot_thm, thm] MRSL @{thm apply_rsp''}
    90       end
    91   in
    92     fold disch_arg ty_args rsp_thm
    93   end
    94 
    95 exception CODE_CERT_GEN of string
    96 
    97 fun simplify_code_eq ctxt def_thm = 
    98   Local_Defs.unfold ctxt [@{thm o_def}, @{thm map_fun_def}, @{thm id_def}] def_thm
    99 
   100 fun generate_code_cert ctxt def_thm rsp_thm (rty, qty) =
   101   let
   102     val quot_thm = Quotient_Term.prove_quot_theorem ctxt (get_body_types (rty, qty))
   103     val fun_rel = prove_rel ctxt rsp_thm (rty, qty)
   104     val abs_rep_thm = [quot_thm, fun_rel] MRSL @{thm Quotient_rep_abs}
   105     val abs_rep_eq = 
   106       case (HOLogic.dest_Trueprop o prop_of) fun_rel of
   107         Const (@{const_name HOL.eq}, _) $ _ $ _ => abs_rep_thm
   108         | Const (@{const_name invariant}, _) $ _ $ _ $ _ => abs_rep_thm RS @{thm invariant_to_eq}
   109         | _ => raise CODE_CERT_GEN "relation is neither equality nor invariant"
   110     val unfolded_def = Conv.fconv_rule unfold_fun_maps def_thm
   111     val unabs_def = unabs_all_def ctxt unfolded_def
   112     val rep = (snd o Thm.dest_comb o snd o Thm.dest_comb o cprop_of) quot_thm
   113     val rep_refl = Thm.reflexive rep RS @{thm meta_eq_to_obj_eq}
   114     val repped_eq = [rep_refl, unabs_def RS @{thm meta_eq_to_obj_eq}] MRSL @{thm cong}
   115     val code_cert = [repped_eq, abs_rep_eq] MRSL @{thm trans}
   116   in
   117     simplify_code_eq ctxt code_cert
   118   end
   119 
   120 fun define_code_cert def_thm rsp_thm (rty, qty) lthy = 
   121   let
   122     val quot_thm = Quotient_Term.prove_quot_theorem lthy (get_body_types (rty, qty))
   123   in
   124     if Quotient_Type.can_generate_code_cert quot_thm then
   125       let
   126         val code_cert = generate_code_cert lthy def_thm rsp_thm (rty, qty)
   127         val add_abs_eqn_attribute = 
   128           Thm.declaration_attribute (fn thm => Context.mapping (Code.add_abs_eqn thm) I)
   129         val add_abs_eqn_attrib = Attrib.internal (K add_abs_eqn_attribute);
   130       in
   131         lthy
   132           |> (snd oo Local_Theory.note) ((Binding.empty, [add_abs_eqn_attrib]), [code_cert])
   133       end
   134     else
   135       lthy
   136   end
   137 
   138 fun define_code_eq def_thm lthy =
   139   let
   140     val unfolded_def = Conv.fconv_rule unfold_fun_maps def_thm
   141     val code_eq = unabs_all_def lthy unfolded_def
   142     val simp_code_eq = simplify_code_eq lthy code_eq
   143   in
   144     lthy
   145       |> (snd oo Local_Theory.note) ((Binding.empty, [Code.add_default_eqn_attrib]), [simp_code_eq])
   146   end
   147 
   148 fun define_code def_thm rsp_thm (rty, qty) lthy =
   149   if body_type rty = body_type qty then 
   150     define_code_eq def_thm lthy
   151   else 
   152     define_code_cert def_thm rsp_thm (rty, qty) lthy
   153 
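(* The ML interface for a quotient definition (add_quotient_def) takes
   as arguments:

    - a binding and mixfix annotation for the new constant
    - a name and attributes for the definitional theorem
    - the new constant as a term
    - the rhs of the definition as a term
    - the respectfulness theorem for the rhs

   It stores the qconst_info in the quotconsts data slot, notes the
   respectfulness theorem under the constant's name suffixed with "_rsp",
   and sets up code generation for the new constant.

   Restriction: the quotient_definition front end (gen_quotient_def below)
   requires the left-hand side to be a new constant and rejects an
   abstraction as the right-hand side; a right-hand side that is not a
   constant only produces a warning.
*)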
   168 fun error_msg bind str =
   169   let
   170     val name = Binding.name_of bind
   171     val pos = Position.str_of (Binding.pos_of bind)
   172   in
   173     error ("Head of quotient_definition " ^
   174       quote str ^ " differs from declaration " ^ name ^ pos)
   175   end
   176 
   177 fun add_quotient_def ((var, (name, atts)), (lhs, rhs)) rsp_thm lthy =
   178   let
   179     val rty = fastype_of rhs
   180     val qty = fastype_of lhs
   181     val absrep_trm = 
   182       Quotient_Term.absrep_fun lthy Quotient_Term.AbsF (rty, qty) $ rhs
   183     val prop = Syntax.check_term lthy (Logic.mk_equals (lhs, absrep_trm))
   184     val (_, prop') = Local_Defs.cert_def lthy prop
   185     val (_, newrhs) = Local_Defs.abs_def prop'
   186 
   187     val ((trm, (_ , def_thm)), lthy') =
   188       Local_Theory.define (var, ((Thm.def_binding_optional (#1 var) name, atts), newrhs)) lthy
   189 
   190     (* data storage *)
   191     val qconst_data = {qconst = trm, rconst = rhs, def = def_thm}
   192     fun get_rsp_thm_name (lhs_name, _) = Binding.suffix_name "_rsp" lhs_name
   193     
   194     val lthy'' = lthy'
   195       |> Local_Theory.declaration {syntax = false, pervasive = true}
   196         (fn phi =>
   197           (case Quotient_Info.transform_quotconsts phi qconst_data of
   198             qcinfo as {qconst = Const (c, _), ...} =>
   199               Quotient_Info.update_quotconsts c qcinfo
   200           | _ => I))
   201       |> (snd oo Local_Theory.note) 
   202         ((get_rsp_thm_name var, [Attrib.internal (K Quotient_Info.rsp_rules_add)]),
   203         [rsp_thm])
   204       |> define_code def_thm rsp_thm (rty, qty)
   205 
   206   in
   207     (qconst_data, lthy'')
   208   end
   209 
   210 fun mk_readable_rsp_thm_eq tm lthy =
   211   let
   212     val ctm = cterm_of (Proof_Context.theory_of lthy) tm
   213     
   214     fun norm_fun_eq ctm = 
   215       let
   216         fun abs_conv2 cv = Conv.abs_conv (K (Conv.abs_conv (K cv) lthy)) lthy
   217         fun erase_quants ctm' =
   218           case (Thm.term_of ctm') of
   219             Const ("HOL.eq", _) $ _ $ _ => Conv.all_conv ctm'
   220             | _ => (Conv.binder_conv (K erase_quants) lthy then_conv 
   221               Conv.rewr_conv @{thm fun_eq_iff[symmetric, THEN eq_reflection]}) ctm'
   222       in
   223         (abs_conv2 erase_quants then_conv Thm.eta_conversion) ctm
   224       end
   225 
   226     fun simp_arrows_conv ctm =
   227       let
   228         val unfold_conv = Conv.rewrs_conv 
   229           [@{thm fun_rel_eq_invariant[THEN eq_reflection]}, @{thm fun_rel_eq_rel[THEN eq_reflection]}, 
   230             @{thm fun_rel_def[THEN eq_reflection]}]
   231         val left_conv = simp_arrows_conv then_conv Conv.try_conv norm_fun_eq
   232         fun binop_conv2 cv1 cv2 = Conv.combination_conv (Conv.arg_conv cv1) cv2
   233       in
   234         case (Thm.term_of ctm) of
   235           Const (@{const_name "fun_rel"}, _) $ _ $ _ => 
   236             (binop_conv2  left_conv simp_arrows_conv then_conv unfold_conv) ctm
   237           | _ => Conv.all_conv ctm
   238       end
   239 
   240     val unfold_ret_val_invs = Conv.bottom_conv 
   241       (K (Conv.try_conv (Conv.rewr_conv @{thm invariant_same_args}))) lthy 
   242     val simp_conv = Conv.arg_conv (Conv.fun2_conv simp_arrows_conv)
   243     val univq_conv = Conv.rewr_conv @{thm HOL.all_simps(6)[symmetric, THEN eq_reflection]}
   244     val univq_prenex_conv = Conv.top_conv (K (Conv.try_conv univq_conv)) lthy
   245     val beta_conv = Thm.beta_conversion true
   246     val eq_thm = 
   247       (simp_conv then_conv univq_prenex_conv then_conv beta_conv then_conv unfold_ret_val_invs) ctm
   248   in
   249     Object_Logic.rulify(eq_thm RS Drule.equal_elim_rule2)
   250   end
   251 
   252 
   253 
   254 fun gen_quotient_def prep_vars prep_term (raw_var, (attr, (lhs_raw, rhs_raw))) lthy =
   255   let
   256     val (vars, ctxt) = prep_vars (the_list raw_var) lthy
   257     val T_opt = (case vars of [(_, SOME T, _)] => SOME T | _ => NONE)
   258     val lhs = prep_term T_opt ctxt lhs_raw
   259     val rhs = prep_term NONE ctxt rhs_raw
   260 
   261     val (lhs_str, lhs_ty) = dest_Free lhs handle TERM _ => error "Constant already defined."
   262     val _ = if null (strip_abs_vars rhs) then () else error "The definiens cannot be an abstraction"
   263     val _ = if is_Const rhs then () else warning "The definiens is not a constant"
   264 
   265     val var =
   266       (case vars of 
   267         [] => (Binding.name lhs_str, NoSyn)
   268       | [(binding, _, mx)] =>
   269           if Variable.check_name binding = lhs_str then (binding, mx)
   270           else error_msg binding lhs_str
   271       | _ => raise Match)
   272     
   273     fun try_to_prove_refl thm = 
   274       let
   275         val lhs_eq =
   276           thm
   277           |> prop_of
   278           |> Logic.dest_implies
   279           |> fst
   280           |> strip_all_body
   281           |> try HOLogic.dest_Trueprop
   282       in
   283         case lhs_eq of
   284           SOME (Const ("HOL.eq", _) $ _ $ _) => SOME (@{thm refl} RS thm)
   285           | SOME _ => (case body_type (fastype_of lhs) of
   286             Type (typ_name, _) => ( SOME
   287               (#equiv_thm (the (Quotient_Info.lookup_quotients lthy typ_name)) 
   288                 RS @{thm Equiv_Relations.equivp_reflp} RS thm)
   289               handle _ => NONE)
   290             | _ => NONE
   291             )
   292           | _ => NONE
   293       end
   294 
   295     val rsp_rel = Quotient_Term.equiv_relation lthy (fastype_of rhs, lhs_ty)
   296     val internal_rsp_tm = HOLogic.mk_Trueprop (Syntax.check_term lthy (rsp_rel $ rhs $ rhs))
   297     val readable_rsp_thm_eq = mk_readable_rsp_thm_eq internal_rsp_tm lthy
   298     val maybe_proven_rsp_thm = try_to_prove_refl readable_rsp_thm_eq
   299     val (readable_rsp_tm, _) = Logic.dest_implies (prop_of readable_rsp_thm_eq)
   300   
   301     fun after_qed thm_list lthy = 
   302       let
   303         val internal_rsp_thm =
   304           case thm_list of
   305             [] => the maybe_proven_rsp_thm
   306           | [[thm]] => Goal.prove ctxt [] [] internal_rsp_tm 
   307             (fn _ => rtac readable_rsp_thm_eq 1 THEN Proof_Context.fact_tac [thm] 1)
   308       in
   309         snd (add_quotient_def ((var, attr), (lhs, rhs)) internal_rsp_thm lthy)
   310       end
   311 
   312   in
   313     case maybe_proven_rsp_thm of
   314       SOME _ => Proof.theorem NONE after_qed [] lthy
   315       | NONE =>  Proof.theorem NONE after_qed [[(readable_rsp_tm,[])]] lthy
   316   end
   317 
   318 fun check_term' cnstr ctxt =
   319   Syntax.check_term ctxt o (case cnstr of SOME T => Type.constraint T | _ => I)
   320 
   321 fun read_term' cnstr ctxt =
   322   check_term' cnstr ctxt o Syntax.parse_term ctxt
   323 
   324 val quotient_def = gen_quotient_def Proof_Context.cert_vars check_term'
   325 val quotient_def_cmd = gen_quotient_def Proof_Context.read_vars read_term'
   326 
   327 
   328 (* a wrapper for automatically lifting a raw constant *)
   329 fun lift_raw_const qtys (qconst_name, rconst, mx) ctxt =
   330   let
   331     val rty = fastype_of rconst
   332     val qty = Quotient_Term.derive_qtyp ctxt qtys rty
   333     val lhs = Free (qconst_name, qty)
   334   in
   335     (*quotient_def (SOME (Binding.name qconst_name, NONE, mx), (Attrib.empty_binding, (lhs, rconst))) ctxt*)
   336     ({qconst = lhs, rconst = lhs, def = @{thm refl}}, ctxt)
   337   end
   338 
   339 (* parser and command *)
   340 val quotdef_parser =
   341   Scan.option Parse_Spec.constdecl --
   342     Parse.!!! (Parse_Spec.opt_thm_name ":" -- (Parse.term --| @{keyword "is"} -- Parse.term))
   343 
   344 val _ =
   345   Outer_Syntax.local_theory_to_proof @{command_spec "quotient_definition"}
   346     "definition for constants over the quotient type"
   347       (quotdef_parser >> quotient_def_cmd)
   348 
   349 
   350 end; (* structure *)