1 (* Title: HOL/Tools/Lifting/lifting_def.ML
4 Definitions for constants on quotient types.
7 signature LIFTING_DEF =
10 (binding * mixfix) -> typ -> term -> thm -> local_theory -> local_theory
13 (binding * string option * mixfix) * string -> local_theory -> Proof.state
15 val can_generate_code_cert: thm -> bool
18 structure Lifting_Def: LIFTING_DEF =
21 (** Interface and Syntax Setup **)
23 (* Generation of the code certificate from the rsp theorem *)
(* ants MRSL thm: resolve the theorems of ants into thm, in list order.
   Each step is "rl RS thm": the conclusion of rl is resolved with the
   first premise of the current theorem. fold threads each result into
   the next step, so the first element of ants is used first.
   Used infix below, e.g. [quot_thm, thm] MRSL @{thm apply_rsp''}. *)
27 fun ants MRSL thm = fold (fn rl => fn thm => rl RS thm) ants thm
(* Strip function arrows from the raw type and the quotient type in
   lockstep, and return the pair of types that remain. Stripping stops as
   soon as either type is no longer a "fun" type. For types with different
   numbers of arrows, the result pairs whatever is left at that point. *)
29 fun get_body_types (Type ("fun", [_, U]), Type ("fun", [_, V])) = get_body_types (U, V)
30 | get_body_types (U, V) = (U, V)
(* Pair up the argument (binder) types of the raw and quotient function
   types, left to right. Stops when either type is no longer a "fun" type,
   so it returns [] when at least one of the two is not a function type. *)
32 fun get_binder_types (Type ("fun", [T, U]), Type ("fun", [V, W])) = (T, V) :: get_binder_types (U, W)
33 | get_binder_types _ = []
35 fun force_rty_type ctxt rty rhs =
37 val thy = Proof_Context.theory_of ctxt
38 val rhs_schematic = singleton (Variable.polymorphic ctxt) rhs
39 val rty_schematic = fastype_of rhs_schematic
40 val match = Sign.typ_match thy (rty_schematic, rty) Vartab.empty
42 Envir.subst_term_types match rhs_schematic
45 fun unabs_def ctxt def =
47 val (_, rhs) = Thm.dest_equals (cprop_of def)
48 fun dest_abs (Abs (var_name, T, _)) = (var_name, T)
49 | dest_abs tm = raise TERM("get_abs_var",[tm])
50 val (var_name, T) = dest_abs (term_of rhs)
51 val (new_var_names, ctxt') = Variable.variant_fixes [var_name] ctxt
52 val thy = Proof_Context.theory_of ctxt'
53 val refl_thm = Thm.reflexive (cterm_of thy (Free (hd new_var_names, T)))
55 Thm.combination def refl_thm |>
56 singleton (Proof_Context.export ctxt' ctxt)
59 fun unabs_all_def ctxt def =
61 val (_, rhs) = Thm.dest_equals (cprop_of def)
62 val xs = strip_abs_vars (term_of rhs)
64 fold (K (unabs_def ctxt)) xs def
(* Rewrite rule for map_fun, used by unfold_fun_maps via Conv.rewr_conv.
   Construction:
   - map_fun_def is taken in abstracted form (abs_def).
   - unabs_def is applied twice, so the first two lambda-bound arguments
     become explicit arguments of map_fun on the left-hand side. This
     matches the "map_fun $ _ $ _" pattern checked in unfold_fun_maps.
   - comp_def is then unfolded.
   All steps use the ML compile-time @{context}. *)
67 val map_fun_unfolded =
68 @{thm map_fun_def[abs_def]} |>
69 unabs_def @{context} |>
70 unabs_def @{context} |>
71 Local_Defs.unfold @{context} [@{thm comp_def}]
73 fun unfold_fun_maps ctm =
76 case (Thm.term_of ctm) of
77 Const (@{const_name "map_fun"}, _) $ _ $ _ =>
78 (Conv.arg_conv unfold_conv then_conv Conv.rewr_conv map_fun_unfolded) ctm
79 | _ => Conv.all_conv ctm
80 val try_beta_conv = Conv.try_conv (Thm.beta_conversion false)
82 (Conv.arg_conv (Conv.fun_conv unfold_conv then_conv try_beta_conv)) ctm
85 fun prove_rel ctxt rsp_thm (rty, qty) =
87 val ty_args = get_binder_types (rty, qty)
88 fun disch_arg args_ty thm =
90 val quot_thm = Lifting_Term.prove_quot_thm ctxt args_ty
92 [quot_thm, thm] MRSL @{thm apply_rsp''}
95 fold disch_arg ty_args rsp_thm
(* Raised by generate_code_cert when the relation of the respectfulness
   theorem is neither HOL equality nor "invariant". The string is a
   human-readable reason. *)
98 exception CODE_CERT_GEN of string
(* Post-process a code equation by unfolding o_def, map_fun_def and id_def
   in def_thm under ctxt. Applied both to generated code certificates and
   to plain code equations. *)
100 fun simplify_code_eq ctxt def_thm =
101 Local_Defs.unfold ctxt [@{thm o_def}, @{thm map_fun_def}, @{thm id_def}] def_thm
103 fun can_generate_code_cert quot_thm =
104 case Lifting_Term.quot_thm_rel quot_thm of
105 Const (@{const_name HOL.eq}, _) => true
106 | Const (@{const_name invariant}, _) $ _ => true
109 fun generate_code_cert ctxt def_thm rsp_thm (rty, qty) =
111 val thy = Proof_Context.theory_of ctxt
112 val quot_thm = Lifting_Term.prove_quot_thm ctxt (get_body_types (rty, qty))
113 val fun_rel = prove_rel ctxt rsp_thm (rty, qty)
114 val abs_rep_thm = [quot_thm, fun_rel] MRSL @{thm Quotient_rep_abs}
116 case (HOLogic.dest_Trueprop o prop_of) fun_rel of
117 Const (@{const_name HOL.eq}, _) $ _ $ _ => abs_rep_thm
118 | Const (@{const_name invariant}, _) $ _ $ _ $ _ => abs_rep_thm RS @{thm invariant_to_eq}
119 | _ => raise CODE_CERT_GEN "relation is neither equality nor invariant"
120 val unfolded_def = Conv.fconv_rule unfold_fun_maps def_thm
121 val unabs_def = unabs_all_def ctxt unfolded_def
122 val rep = (cterm_of thy o Lifting_Term.quot_thm_rep) quot_thm
123 val rep_refl = Thm.reflexive rep RS @{thm meta_eq_to_obj_eq}
124 val repped_eq = [rep_refl, unabs_def RS @{thm meta_eq_to_obj_eq}] MRSL @{thm cong}
125 val code_cert = [repped_eq, abs_rep_eq] MRSL @{thm trans}
127 simplify_code_eq ctxt code_cert
130 fun is_abstype ctxt typ =
132 val thy = Proof_Context.theory_of ctxt
133 val type_name = (fst o dest_Type) typ
135 (snd oo Code.get_type) thy type_name
139 fun define_code_cert code_eqn_thm_name def_thm rsp_thm (rty, qty) lthy =
141 val (rty_body, qty_body) = get_body_types (rty, qty)
142 val quot_thm = Lifting_Term.prove_quot_thm lthy (rty_body, qty_body)
144 if can_generate_code_cert quot_thm then
146 val code_cert = generate_code_cert lthy def_thm rsp_thm (rty, qty)
147 val add_abs_eqn_attribute =
148 Thm.declaration_attribute (fn thm => Context.mapping (Code.add_abs_eqn thm) I)
149 val add_abs_eqn_attrib = Attrib.internal (K add_abs_eqn_attribute);
151 (snd oo Local_Theory.note) ((code_eqn_thm_name, []), [code_cert]) lthy
153 if is_abstype lthy qty_body then
154 (snd oo Local_Theory.note) ((Binding.empty, [add_abs_eqn_attrib]), [code_cert]) lthy'
162 fun define_code_eq code_eqn_thm_name def_thm lthy =
164 val unfolded_def = Conv.fconv_rule unfold_fun_maps def_thm
165 val code_eq = unabs_all_def lthy unfolded_def
166 val simp_code_eq = simplify_code_eq lthy code_eq
169 |> (snd oo Local_Theory.note) ((code_eqn_thm_name, [Code.add_default_eqn_attrib]), [simp_code_eq])
172 fun define_code code_eqn_thm_name def_thm rsp_thm (rty, qty) lthy =
173 if body_type rty = body_type qty then
174 define_code_eq code_eqn_thm_name def_thm lthy
176 define_code_cert code_eqn_thm_name def_thm rsp_thm (rty, qty) lthy
179 fun add_lift_def var qty rhs rsp_thm lthy =
181 val rty = fastype_of rhs
182 val quotient_thm = Lifting_Term.prove_quot_thm lthy (rty, qty)
183 val absrep_trm = Lifting_Term.quot_thm_abs quotient_thm
184 val rty_forced = (domain_type o fastype_of) absrep_trm
185 val forced_rhs = force_rty_type lthy rty_forced rhs
186 val lhs = Free (Binding.print (#1 var), qty)
187 val prop = Logic.mk_equals (lhs, absrep_trm $ forced_rhs)
188 val (_, prop') = Local_Defs.cert_def lthy prop
189 val (_, newrhs) = Local_Defs.abs_def prop'
191 val ((_, (_ , def_thm)), lthy') =
192 Local_Theory.define (var, ((Thm.def_binding (#1 var), []), newrhs)) lthy
194 val transfer_thm = [quotient_thm, rsp_thm, def_thm] MRSL @{thm Quotient_to_transfer}
195 |> Raw_Simplifier.rewrite_rule (Transfer.get_relator_eq lthy')
197 fun qualify defname suffix = Binding.qualified true suffix defname
199 val lhs_name = (#1 var)
200 val rsp_thm_name = qualify lhs_name "rsp"
201 val code_eqn_thm_name = qualify lhs_name "rep_eq"
202 val transfer_thm_name = qualify lhs_name "transfer"
203 val transfer_attr = Attrib.internal (K Transfer.transfer_add)
206 |> (snd oo Local_Theory.note) ((rsp_thm_name, []), [rsp_thm])
207 |> (snd oo Local_Theory.note) ((transfer_thm_name, [transfer_attr]), [transfer_thm])
208 |> define_code code_eqn_thm_name def_thm rsp_thm (rty_forced, qty)
211 fun mk_readable_rsp_thm_eq tm lthy =
213 val ctm = cterm_of (Proof_Context.theory_of lthy) tm
215 fun norm_fun_eq ctm =
217 fun abs_conv2 cv = Conv.abs_conv (K (Conv.abs_conv (K cv) lthy)) lthy
218 fun erase_quants ctm' =
219 case (Thm.term_of ctm') of
220 Const ("HOL.eq", _) $ _ $ _ => Conv.all_conv ctm'
221 | _ => (Conv.binder_conv (K erase_quants) lthy then_conv
222 Conv.rewr_conv @{thm fun_eq_iff[symmetric, THEN eq_reflection]}) ctm'
224 (abs_conv2 erase_quants then_conv Thm.eta_conversion) ctm
227 fun simp_arrows_conv ctm =
229 val unfold_conv = Conv.rewrs_conv
230 [@{thm fun_rel_eq_invariant[THEN eq_reflection]}, @{thm fun_rel_eq_rel[THEN eq_reflection]},
231 @{thm fun_rel_def[THEN eq_reflection]}]
232 val left_conv = simp_arrows_conv then_conv Conv.try_conv norm_fun_eq
233 fun binop_conv2 cv1 cv2 = Conv.combination_conv (Conv.arg_conv cv1) cv2
235 case (Thm.term_of ctm) of
236 Const (@{const_name "fun_rel"}, _) $ _ $ _ =>
237 (binop_conv2 left_conv simp_arrows_conv then_conv unfold_conv) ctm
238 | _ => Conv.all_conv ctm
241 val unfold_ret_val_invs = Conv.bottom_conv
242 (K (Conv.try_conv (Conv.rewr_conv @{thm invariant_same_args}))) lthy
243 val simp_conv = Conv.arg_conv (Conv.fun2_conv simp_arrows_conv)
244 val univq_conv = Conv.rewr_conv @{thm HOL.all_simps(6)[symmetric, THEN eq_reflection]}
245 val univq_prenex_conv = Conv.top_conv (K (Conv.try_conv univq_conv)) lthy
246 val beta_conv = Thm.beta_conversion true
248 (simp_conv then_conv univq_prenex_conv then_conv beta_conv then_conv unfold_ret_val_invs) ctm
250 Object_Logic.rulify(eq_thm RS Drule.equal_elim_rule2)
255 fun lift_def_cmd (raw_var, rhs_raw) lthy =
257 val ((binding, SOME qty, mx), lthy') = yield_singleton Proof_Context.read_vars raw_var lthy
258 val rhs = (Syntax.check_term lthy' o Syntax.parse_term lthy') rhs_raw
260 fun try_to_prove_refl thm =
265 |> Logic.dest_implies
268 |> try HOLogic.dest_Trueprop
271 SOME (Const ("HOL.eq", _) $ _ $ _) => SOME (@{thm refl} RS thm)
275 val rsp_rel = Lifting_Term.equiv_relation lthy' (fastype_of rhs, qty)
276 val rty_forced = (domain_type o fastype_of) rsp_rel;
277 val forced_rhs = force_rty_type lthy' rty_forced rhs;
278 val internal_rsp_tm = HOLogic.mk_Trueprop (rsp_rel $ forced_rhs $ forced_rhs)
279 val readable_rsp_thm_eq = mk_readable_rsp_thm_eq internal_rsp_tm lthy'
280 val maybe_proven_rsp_thm = try_to_prove_refl readable_rsp_thm_eq
281 val (readable_rsp_tm, _) = Logic.dest_implies (prop_of readable_rsp_thm_eq)
283 fun after_qed thm_list lthy =
285 val internal_rsp_thm =
287 [] => the maybe_proven_rsp_thm
288 | [[thm]] => Goal.prove lthy [] [] internal_rsp_tm
289 (fn _ => rtac readable_rsp_thm_eq 1 THEN Proof_Context.fact_tac [thm] 1)
291 add_lift_def (binding, mx) qty rhs internal_rsp_thm lthy
295 case maybe_proven_rsp_thm of
296 SOME _ => Proof.theorem NONE after_qed [] lthy'
297 | NONE => Proof.theorem NONE after_qed [[(readable_rsp_tm,[])]] lthy'
300 fun quot_thm_err ctxt (rty, qty) pretty_msg =
302 val error_msg = cat_lines
303 ["Lifting failed for the following types:",
304 Pretty.string_of (Pretty.block
305 [Pretty.str "Raw type:", Pretty.brk 2, Syntax.pretty_typ ctxt rty]),
306 Pretty.string_of (Pretty.block
307 [Pretty.str "Abstract type:", Pretty.brk 2, Syntax.pretty_typ ctxt qty]),
309 (Pretty.string_of (Pretty.block
310 [Pretty.str "Reason:", Pretty.brk 2, pretty_msg]))]
315 fun check_rty_err ctxt (rty_schematic, rty_forced) (raw_var, rhs_raw) =
317 val (_, ctxt') = yield_singleton Proof_Context.read_vars raw_var ctxt
318 val rhs = (Syntax.check_term ctxt' o Syntax.parse_term ctxt') rhs_raw
319 val error_msg = cat_lines
320 ["Lifting failed for the following term:",
321 Pretty.string_of (Pretty.block
322 [Pretty.str "Term:", Pretty.brk 2, Syntax.pretty_term ctxt rhs]),
323 Pretty.string_of (Pretty.block
324 [Pretty.str "Type:", Pretty.brk 2, Syntax.pretty_typ ctxt rty_schematic]),
326 (Pretty.string_of (Pretty.block
327 [Pretty.str "Reason:",
329 Pretty.str "The type of the term cannot be instancied to",
331 Pretty.quote (Syntax.pretty_typ ctxt rty_forced),
(* Parser action of the lift_definition command. It runs lift_def_cmd and
   routes the two lifting-specific exceptions to dedicated error reporters:
   - The inner handler maps Lifting_Term.QUOT_THM to quot_thm_err.
   - The outer handler maps Lifting_Term.CHECK_RTY to check_rty_err.
     CHECK_RTY passes through the inner handler unmatched.
   Both reporters build their message with cat_lines. Their tails are not
   visible here; presumably they raise an error with that message
   (TODO confirm). *)
337 fun lift_def_cmd_with_err_handling (raw_var, rhs_raw) lthy =
338 (lift_def_cmd (raw_var, rhs_raw) lthy
339 handle Lifting_Term.QUOT_THM (rty, qty, msg) => quot_thm_err lthy (rty, qty) msg)
340 handle Lifting_Term.CHECK_RTY (rty_schematic, rty_forced) =>
341 check_rty_err lthy (rty_schematic, rty_forced) (raw_var, rhs_raw)
343 (* parser and command *)
345 ((Parse.binding -- (@{keyword "::"} |-- (Parse.typ >> SOME) -- Parse.opt_mixfix')) >> Parse.triple2)
346 --| @{keyword "is"} -- Parse.term
349 Outer_Syntax.local_theory_to_proof @{command_spec "lift_definition"}
350 "definition for constants over the quotient type"
351 (liftdef_parser >> lift_def_cmd_with_err_handling)