1 (* Title: HOL/Tools/Function/scnp_reconstruct.ML
2 Author: Armin Heller, TU Muenchen
3 Author: Alexander Krauss, TU Muenchen
5 Proof reconstruction for SCNP (the NP subset of size-change termination).
(* Interface of the SCNP proof reconstruction.
   NOTE(review): this copy of the file omits several lines (e.g. the
   sig/end keywords and parts of the multiset_setup record); compare with
   the upstream file before editing. *)
8 signature SCNP_RECONSTRUCT =
(* sizechange_tac ctxt autom_tac: size-change termination trying the orders
   MAX, MS and MIN; autom_tac is the automation tactic handed down to the
   termination machinery; failures are silently ignored. *)
11 val sizechange_tac : Proof.context -> tactic -> tactic
(* decomp_scnp_tac orders ctxt: size-change termination restricted to the
   given orders, using auto_tac (extended by the termination simp rules) as
   automation and tracing an error report when no certificate is found. *)
13 val decomp_scnp_tac : ScnpSolve.label list -> Proof.context -> tactic
(* Registers the size_change proof method. *)
15 val setup : theory -> theory
(* Operations and theorems required for the multiset (MS) order. *)
17 datatype multiset_setup =
21 mk_mset : typ -> term list -> term,
22 mset_regroup_conv : int list -> conv,
23 mset_member_tac : int -> int -> tactic,
24 mset_nonempty_tac : int -> tactic,
25 mset_pwleq_tac : int -> tactic,
26 set_of_simps : thm list,
(* Stores a multiset setup in the theory data. *)
34 val multiset_setup : multiset_setup -> theory -> theory
38 structure ScnpReconstruct : SCNP_RECONSTRUCT =
(* Profiling wrapper, re-exported from Function_Common. *)
val PROFILE = Function_Common.PROFILE

(* HOL type abbreviations: nat, and nat * nat, which is the element type of
   the tagged level-mapping sets built during reconstruction. *)
val natT = HOLogic.natT
val nat_pairT = HOLogic.mk_prodT (HOLogic.natT, HOLogic.natT)
48 (* Theory dependencies *)
(* Record of multiset-specific operations and theorems used by the MS
   order.  It is stored per theory via multiset_setup and fetched via
   get_multiset_setup.
   NOTE(review): the constructor line and several fields (msetT,
   reduction_pair, smsI', wmsI2'', wmsI1, all referenced by
   get_multiset_setup) are not visible in this copy of the file. *)
50 datatype multiset_setup =
(* Builds a multiset term of the given element type from a list of terms. *)
54 mk_mset : typ -> term list -> term,
(* Conversion that regroups a multiset according to a list of positions. *)
55 mset_regroup_conv : int list -> conv,
56 mset_member_tac : int -> int -> tactic,
57 mset_nonempty_tac : int -> tactic,
(* Tactic used for the pointwise weak comparison step of the MS order. *)
58 mset_pwleq_tac : int -> tactic,
(* Rewrite rules unfolded before the remaining elements are compared with
   the MAX order. *)
59 set_of_simps : thm list,
(* Theory data slot holding the optional multiset setup.  Code elsewhere
   checks is_some on it to decide whether the MS order is available;
   presumably the empty value is NONE (not visible in this copy). *)
66 structure Multiset_Setup = Theory_Data
68 type T = multiset_setup option
(* Merge: keep the first theory's setup if it has one, else the second's. *)
71 fun merge (v1, v2) = if is_some v1 then v1 else v2
(* Store the given multiset setup record in the theory data. *)
fun multiset_setup ms_setup thy = Multiset_Setup.put (SOME ms_setup) thy
(* Placeholder for the operations of an unconfigured multiset setup;
   calling it raises an error. *)
fun undef _ = error "undef"

(* Fetch the multiset setup of thy.  When none has been registered, return
   a dummy record whose operations all fail through undef and whose
   theorems are refl, so callers can destructure it unconditionally. *)
fun get_multiset_setup thy =
  case Multiset_Setup.get thy of
    SOME registered => registered
  | NONE =>
      Multiset
        {msetT = undef, mk_mset = undef,
         mset_regroup_conv = undef, mset_member_tac = undef,
         mset_nonempty_tac = undef, mset_pwleq_tac = undef,
         set_of_simps = [], reduction_pair = refl,
         smsI' = refl, wmsI2'' = refl, wmsI1 = refl}
(* Reduction-pair theorem for an order label.  MAX and MIN use the fixed
   theorems max_rpair_set and min_rpair_set.  MS uses msrp, the reduction
   pair supplied by the multiset setup. *)
fun order_rpair msrp label =
  case label of
    MAX => @{thm max_rpair_set}
  | MS => msrp
  | MIN => @{thm min_rpair_set}
(* Introduction rules (empty case, insert case) for the MAX set ordering.
   A true flag selects the strict (smax) variants; false selects the weak
   (wmax) variants. *)
fun ord_intros_max strict =
  if strict then (@{thm smax_emptyI}, @{thm smax_insertI})
  else (@{thm wmax_emptyI}, @{thm wmax_insertI})

(* Same as ord_intros_max, but for the MIN set ordering (smin / wmin). *)
fun ord_intros_min strict =
  if strict then (@{thm smin_emptyI}, @{thm smin_insertI})
  else (@{thm wmin_emptyI}, @{thm wmin_insertI})
(* Body of the size-change problem generator (called as gen_probl D cs
   elsewhere in this file).
   NOTE(review): the function header and the let/in/end lines, together
   with parts of the edge and graph construction, are not visible in this
   copy of the file. *)
(* n: number of points; arity p: number of measures of point p;
   measure p i: the i-th measure of point p. *)
100 val n = Termination.get_num_points D
101 val arity = length o Termination.get_measures D
102 fun measure p i = nth (Termination.get_measures D p) i
(* Each call c goes from point p to point q. *)
106 val (_, p, _, q, _, _) = Termination.dest_call D c
(* For each pair of measure indices (i, j), the edge is taken from the
   descent table: a Less entry gives (i, GTR, j) and a LessEq entry gives
   (i, GEQ, j).  The remaining cases are not visible here; presumably they
   add no edge. *)
109 case Termination.get_descent D c (measure p i) (measure q j)
110 of SOME (Termination.Less _) => cons (i, GTR, j)
111 | SOME (Termination.LessEq _) => cons (i, GEQ, j)
(* All index pairs of caller measures times callee measures are visited. *)
115 fold_product add_edge (0 upto arity p - 1) (0 upto arity q - 1) []
(* Result: the arity of every point plus one graph per call. *)
120 GP (map_range arity n, map mk_graph cs)
123 (* General reduction pair application *)
(* Tactic script on goal 1: introduce a subset goal, eliminate the Collect
   membership and its existentials, unfold inv_image_def, reintroduce the
   Collect, split the conjunction, substitute with the resulting equation,
   and unfold the rules listed at the end.
   NOTE(review): the let/in lines and the rest of the final unfold list are
   not visible in this copy of the file. *)
124 fun rem_inv_img ctxt =
126 val unfold_tac = Local_Defs.unfold_tac ctxt
128 rtac @{thm subsetI} 1
129 THEN etac @{thm CollectE} 1
130 THEN REPEAT (etac @{thm exE} 1)
131 THEN unfold_tac @{thms inv_image_def}
132 THEN rtac @{thm CollectI} 1
133 THEN etac @{thm conjE} 1
134 THEN etac @{thm ssubst} 1
135 THEN unfold_tac (@{thms split_conv} @ @{thms triv_forall_equality}
(* HOL set type over the given element type. *)
fun setT elemT = HOLogic.mk_setT elemT
(* Tactic proving membership of the element at 0-based position m on goal
   i.  It applies insertI2 m times to skip earlier elements and then
   insertI1, so the set is presumably written as nested insert terms.
   Callers compute m with Library.find_index, which returns ~1 when the
   element is absent.  For a negative m this tactic is no_tac (plain
   failure).  Without that guard, the eager recursion never reaches 0. *)
fun set_member_tac m i =
  if m < 0 then no_tac
  else if m = 0 then rtac @{thm insertI1} i
  else rtac @{thm insertI2} i THEN set_member_tac (m - 1) i
(* Prove non-emptiness of an insert-built set on goal i. *)
fun set_nonempty_tac i = rtac @{thm insert_not_empty} i
(* Prove finiteness of an explicitly built set on goal i.  Either close the
   goal with finite.emptyI, or apply finite.insertI and continue
   recursively.  The recursive call is only a partial application, so
   building the tactic terminates; the recursion unfolds only while it is
   run on a goal state. *)
fun set_finite_tac i st =
  (rtac @{thm finite.emptyI} i
   ORELSE (rtac @{thm finite.insertI} i THEN set_finite_tac i)) st
(* Turn a certificate found by the SCNP solver into a proof of the goal.
   ctxt        - proof context
   D           - termination data (measures, calls, descent table)
   cs          - the list of calls
   GP (_, gs)  - the size-change problem; gs holds one graph per call
   certificate - (label, lev, sl, covering), destructured below
   NOTE(review): this copy is missing several lines of this function
   (let/in/end keywords, some if-conditions and bindings); restore them
   from upstream before relying on it. *)
156 fun reconstruct_tac ctxt D cs (GP (_, gs)) certificate =
158 val thy = ProofContext.theory_of ctxt
(* Multiset operations and theorems; when no multiset setup is registered,
   get_multiset_setup returns a dummy record. *)
161 mset_regroup_conv, mset_pwleq_tac, set_of_simps,
162 smsI', wmsI2'', wmsI1, reduction_pair=ms_rp, ...}
163 = get_multiset_setup thy
(* measure_fn p i is the i-th measure of point p. *)
165 fun measure_fn p = nth (Termination.get_measures D p)
(* Descent theorem of call number cidx between measures m1 and m2.
   A Less entry is weakened with less_imp_le on the visible else-branch
   (the then-branch is not visible in this copy).  A LessEq entry is only
   accepted when bStrict is false.  Anything else is an internal error. *)
167 fun get_desc_thm cidx m1 m2 bStrict =
168 case Termination.get_descent D (nth cs cidx) m1 m2
169 of SOME (Termination.Less thm) =>
171 else (thm COMP (Thm.lift_rule (cprop_of thm) @{thm less_imp_le}))
172 | SOME (Termination.LessEq (thm, _)) =>
173 if not bStrict then thm
174 else sys_error "get_desc_thm"
175 | _ => sys_error "get_desc_thm"
(* Certificate components, with their roles as used below:
   label    - the order (MAX, MIN or MS)
   lev      - the tagged level list of each point
   sl       - the calls that must decrease strictly
   covering - maps a graph, a strictness flag and an index to an optional
              covering element *)
177 val (label, lev, sl, covering) = certificate
(* Prove the level-mapping inclusion for call/graph g from point p to
   point q, strictly or weakly. *)
179 fun prove_lev strict g =
181 val G (p, q, _) = nth gs g
(* Compare the tagged pair (measure j of q, tag b) with (measure i of p,
   tag a).  tag_flag holds when the tags alone already give the required
   relation.  The chosen pair rule is applied, the descent theorem is
   combined with the goal state via Thm.elim_implies, and arith_tac closes
   the tag comparison when tag_flag holds. *)
183 fun less_proof strict (j, b) (i, a) =
185 val tag_flag = b < a orelse (not strict andalso b <= a)
188 get_desc_thm g (measure_fn p i) (measure_fn q j)
190 |> Conv.fconv_rule (Thm.beta_conversion true)
193 then if b < a then @{thm pair_lessI2} else @{thm pair_lessI1}
194 else if b <= a then @{thm pair_leqI2} else @{thm pair_leqI1}
196 rtac rule 1 THEN PRIMITIVE (Thm.elim_implies stored_thm)
197 THEN (if tag_flag then Arith_Data.arith_tac ctxt 1 else all_tac)
(* MAX order; lq and lp are the level lists of callee q and caller p.
   Base case (its condition is not visible here): apply the empty rule,
   prove finiteness, and prove non-emptiness when strict.
   Step case: take the first (j, b) of lq and its covering element (i, a).
   Prove membership of (i, a) in lp by its position, prove the comparison,
   and recurse on the rest of lq.
   NOTE(review): if (i, a) were absent from lp, find_index would give ~1. *)
200 fun steps_tac MAX strict lq lp =
202 val (empty, step) = ord_intros_max strict
205 then rtac empty 1 THEN set_finite_tac 1
206 THEN (if strict then set_nonempty_tac 1 else all_tac)
209 val (j, b) :: rest = lq
210 val (i, a) = the (covering g strict j)
211 fun choose xs = set_member_tac (Library.find_index (curry op = (i, a)) xs) 1
212 val solve_tac = choose lp THEN less_proof strict (j, b) (i, a)
214 rtac step 1 THEN solve_tac THEN steps_tac MAX strict rest lp
(* MIN order: dual to MAX.  It walks lp, covers each (i, a) by an element
   (j, b) of lq found via covering, and proves its membership in lq.
   Most of the base case is not visible in this copy. *)
217 | steps_tac MIN strict lq lp =
219 val (empty, step) = ord_intros_min strict
223 THEN (if strict then set_nonempty_tac 1 else all_tac)
226 val (i, a) :: rest = lp
227 val (j, b) = the (covering g strict i)
228 fun choose xs = set_member_tac (Library.find_index (curry op = (j, b)) xs) 1
229 val solve_tac = choose lq THEN less_proof strict (j, b) (i, a)
231 rtac step 1 THEN solve_tac THEN steps_tac MIN strict lq rest
(* MS (multiset) order.
   - qs: the elements of lq that have no strict cover.
   - ps: the weak covers of qs.
   - iqs, ips: the positions of qs in lq and of ps in lp.
   Both multisets are regrouped by these positions.  Then smsI' (strict),
   wmsI2'' (when no element has a strict cover), or a rule on a line not
   visible here is applied.  After that, mset_pwleq_tac runs and qs and ps
   are compared pairwise (weakly).  If strict, or if some elements are
   left, set_of_simps are unfolded and the remaining elements are compared
   with the strict MAX steps. *)
234 | steps_tac MS strict lq lp =
236 fun get_str_cover (j, b) =
237 if is_some (covering g true j) then SOME (j, b) else NONE
238 fun get_wk_cover (j, b) = the (covering g false j)
240 val qs = subtract (op =) (map_filter get_str_cover lq) lq
241 val ps = map get_wk_cover qs
243 fun indices xs ys = map (fn y => Library.find_index (curry op = y) xs) ys
244 val iqs = indices lq qs
245 val ips = indices lp ps
249 params_conv ~1 (K ((concl_conv ~1 o arg_conv o arg1_conv o a) C)) ctxt
251 t_conv arg1_conv (mset_regroup_conv iqs)
252 then_conv t_conv arg_conv (mset_regroup_conv ips)
255 CONVERSION goal_rewrite 1
256 THEN (if strict then rtac smsI' 1
257 else if qs = lq then rtac wmsI2'' 1
259 THEN mset_pwleq_tac 1
260 THEN EVERY (map2 (less_proof false) qs ps)
261 THEN (if strict orelse qs <> lq
262 then Local_Defs.unfold_tac ctxt set_of_simps
263 THEN steps_tac MAX true
264 (subtract (op =) qs lq) (subtract (op =) ps lp)
(* Compare the callee's level list with the caller's under the label. *)
269 THEN steps_tac label strict (nth lev q) (nth lev p)
(* Set constructors: multisets for MS, plain HOL sets otherwise. *)
272 val (mk_set, setT) = if label = MS then (mk_mset, msetT) else (HOLogic.mk_set, setT)
(* tag_pair p (i, tag): the nat pair of measure i of p, applied to the
   bound argument, and the numeral tag. *)
274 fun tag_pair p (i, tag) =
275 HOLogic.pair_const natT natT $
276 (measure_fn p i $ Bound 0) $ HOLogic.mk_number natT tag
(* pt_lev (p, lm): abstraction over the argument of point p, returning the
   set of tagged measures listed in lm. *)
278 fun pt_lev (p, lm) = Abs ("x", Termination.get_types D p,
279 mk_set nat_pairT (map (tag_pair p) lm))
(* The per-point mappings are combined with mk_sumcases.  The result is
   presumably bound to level_mapping (that binding line is not visible). *)
283 |> Termination.mk_sumcases D (setT nat_pairT)
(* Main proof script (profiled):
   1. Regroup the union of call relations according to sl.
   2. Apply reduction_pair_lemma, rp_inv_image_rp and the reduction pair
      for the label.
   3. Instantiate the level mapping and unfold the definitions.
   4. Resolve repeatedly with Un_least / empty_subsetI.
   5. Prove the calls in sl strictly and all other calls weakly. *)
286 PROFILE "Proof Reconstruction"
287 (CONVERSION (Conv.arg_conv (Conv.arg_conv (Function_Lib.regroup_union_conv sl))) 1
288 THEN (rtac @{thm reduction_pair_lemma} 1)
289 THEN (rtac @{thm rp_inv_image_rp} 1)
290 THEN (rtac (order_rpair ms_rp label) 1)
291 THEN PRIMITIVE (instantiate' [] [SOME level_mapping])
(* NOTE(review): the unqualified unfold_tac on the next line takes a
   simpset, unlike the Local_Defs.unfold_tac ctxt used right after it;
   confirm which unfold_tac is in scope here. *)
292 THEN unfold_tac @{thms rp_inv_image_def} (simpset_of ctxt)
293 THEN Local_Defs.unfold_tac ctxt
294 (@{thms split_conv} @ @{thms fst_conv} @ @{thms snd_conv})
295 THEN REPEAT (SOMEGOAL (resolve_tac [@{thm Un_least}, @{thm empty_subsetI}]))
296 THEN EVERY (map (prove_lev true) sl)
297 THEN EVERY (map (prove_lev false) (subtract (op =) sl (0 upto length cs - 1))))
local open Termination in
(* Render one entry of the descent table: strict descent as <, weak
   descent as \<le>, None and False entries as -, and a missing entry
   (NONE) as ?. *)
fun print_cell cell =
  case cell of
    NONE => "?"
  | SOME (Less _) => "<"
  | SOME (LessEq _) => "\<le>"
  | SOME (None _) => "-"
  | SOME (False _) => "-"
(* Error continuation used by decomp_scnp_tac when no certificate is found.
   It traces every measure of every function and lists each call.  It then
   prints a descent table per call: one row per caller measure, one column
   per callee measure, with cells rendered by print_cell.
   NOTE(review): the let/in/end lines and the final result of the CALLS
   function are not visible in this copy of the file. *)
309 fun print_error ctxt D = CALLS (fn (cs, _) =>
(* ms: the measures of every point; index numbers a list from 1. *)
311 val np = get_num_points D
312 val ms = map_range (get_measures D) np
313 fun index xs = (1 upto length xs) ~~ xs
(* outp s t f xs: one line per numbered element, as s ^ number ^ t ^ f y. *)
314 fun outp s t f xs = map (fn (x, y) => s ^ Int.toString x ^ t ^ f y ^ "\n") xs
315 val ims = index (map index ms)
316 val _ = tracing (implode (outp "fn #" ":\n" (implode o outp "\tmeasure #" ": " (Syntax.string_of_term ctxt)) ims))
(* Trace the descent table of call number k.  The header row shows the
   1-based point numbers and the callee measure numbers.  Each following
   row shows the cells for one caller measure. *)
317 fun print_call (k, c) =
319 val (_, p, _, q, _, _) = dest_call D c
320 val _ = tracing ("call table for call #" ^ Int.toString k ^ ": fn " ^
321 Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1))
322 val caller_ms = nth ms p
323 val callee_ms = nth ms q
324 val entries = map (fn x => map (pair x) (callee_ms)) (caller_ms)
325 fun print_ln (i : int, l) = implode (Int.toString i :: " " :: map (enclose " " " " o print_cell o (uncurry (get_descent D c))) l)
326 val _ = tracing (implode (Int.toString (p + 1) ^ "|" ^ Int.toString (q + 1) ^
327 " " :: map (enclose " " " " o Int.toString) (1 upto length callee_ms)) ^ "\n"
328 ^ cat_lines (map print_ln ((1 upto (length entries)) ~~ entries)))
(* Trace call number k: caller and callee (1-based) and the call term. *)
332 fun list_call (k, c) =
334 val (_, p, _, q, _, _) = dest_call D c
335 val _ = tracing ("call #" ^ (Int.toString k) ^ ": fn " ^
336 Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1) ^ "\n" ^
337 (Syntax.string_of_term ctxt c))
(* NOTE(review): forall is used here only for its tracing side effects,
   and it stops at the first element whose result is false.  The return
   values of list_call and print_call are not visible in this copy.
   List.app would state the intent if they were changed to return unit. *)
339 val _ = forall list_call ((1 upto length cs) ~~ cs)
340 val _ = forall print_call ((1 upto length cs) ~~ cs)
(* One SCNP attempt on goal i, in continuation-passing style.
   - If no multiset setup is registered for the theory, MS is dropped from
     the orders.
   - A size-change problem is generated from the calls cs, and a
     certificate is requested (use_tags is forwarded to the solver).
   - Without a certificate, err_cont D i is run.
   - With one, the proof is reconstructed on goal i.  The goal is then
     closed by wf_empty, or control passes to cont D i.
   NOTE(review): the let/in lines, the case head and the SOME branch
   pattern are not visible in this copy of the file. *)
346 fun single_scnp_tac use_tags orders ctxt cont err_cont D = Termination.CALLS (fn (cs, i) =>
348 val ms_configured = is_some (Multiset_Setup.get (ProofContext.theory_of ctxt))
349 val orders' = if ms_configured then orders
350 else filter_out (curry op = MS) orders
351 val gp = gen_probl D cs
352 val certificate = generate_certificate use_tags orders' gp
355 of NONE => err_cont D i
357 SELECT_GOAL (reconstruct_tac ctxt D cs gp cert) i
358 THEN (rtac @{thm wf_empty} i ORELSE cont D i)
(* Staged termination strategy combining diagonal derivation, graph
   decomposition and SCNP.  The stages are chained with Then and run via
   TERMINATION, with err_cont passed as both continuations of the combined
   strategy.  autom_tac is handed to the Termination derivation and
   decomposition tactics.
   NOTE(review): the let/in/end lines and the bindings of the second and
   third rounds (presumably second_round and third_round) are not visible
   in this copy.  derive_all is not used in the visible lines. *)
361 fun gen_decomp_scnp_tac orders autom_tac ctxt err_cont =
364 val derive_diag = Termination.derive_diag ctxt autom_tac
365 val derive_all = Termination.derive_all ctxt autom_tac
366 val decompose = Termination.decompose_tac ctxt autom_tac
(* SCNP without tags, and SCNP with tags (full). *)
367 val scnp_no_tags = single_scnp_tac false orders ctxt
368 val scnp_full = single_scnp_tac true orders ctxt
(* First round: derive_diag around repeated untagged SCNP. *)
370 fun first_round c e =
371 derive_diag (REPEAT scnp_no_tags c e)
(* Repeated decomposition that continues with untagged SCNP. *)
374 REPEAT (fn c => fn e => decompose (scnp_no_tags c c) e)
(* Repeated tagged SCNP that continues with decomposition. *)
378 REPEAT (fn c => fn e =>
379 scnp_full (decompose c c) e)
(* Sequencing of strategies: s1 runs with s2 c c as its first continuation
   and s2 c e as its second. *)
381 fun Then s1 s2 c e = s1 (s2 c c) (s2 c e)
383 val strategy = Then (Then first_round second_round) third_round
386 TERMINATION ctxt (strategy err_cont err_cont)
(* Size-change termination on goal 1.  It first tries (optionally, via
   TRY) the termination rule and the wf-union split.  It then closes the
   goal with wf_empty or runs the decomposition/SCNP strategy.
   NOTE(review): the line joining the TRY steps with the final alternative
   is not visible in this copy of the file. *)
389 fun gen_sizechange_tac orders autom_tac ctxt err_cont =
390 TRY (Function_Common.apply_termination_rule ctxt 1)
391 THEN TRY (Termination.wf_union_tac ctxt)
393 (rtac @{thm wf_empty} 1
394 ORELSE gen_decomp_scnp_tac orders autom_tac ctxt err_cont 1)
(* Size-change termination with the orders [MAX, MS, MIN] and the given
   automation tactic.  The error continuation does nothing, so a failed
   search leaves the goal unchanged. *)
fun sizechange_tac ctxt autom_tac =
  let
    val ignore_failure = fn _ => fn _ => all_tac
  in
    gen_sizechange_tac [MAX, MS, MIN] autom_tac ctxt ignore_failure
  end
(* Size-change termination restricted to the given orders.
   - Automation: auto_tac with the context's claset/simpset extended by
     the Termination_Simps rules.
   - Failure: an error report is traced with print_error.
   NOTE(review): the let/in/end lines are not visible in this copy. *)
399 fun decomp_scnp_tac orders ctxt =
401 val extra_simps = Function_Common.Termination_Simps.get ctxt
402 val autom_tac = auto_tac (clasimpset_of ctxt addsimps2 extra_simps)
404 gen_sizechange_tac orders autom_tac ctxt (print_error ctxt)
(* Parser for the order arguments of the size_change method.  The keywords
   max, min and ms map to MAX, MIN and MS.  Without keywords, the orders
   [MAX, MS, MIN] are used.
   NOTE(review): the binding line (presumably val orders = ...) and any
   repetition combinator around the alternatives are not visible in this
   copy of the file. *)
412 ((Args.$$$ "max" >> K MAX) ||
413 (Args.$$$ "min" >> K MIN) ||
414 (Args.$$$ "ms" >> K MS))
415 || Scan.succeed [MAX, MS, MIN]
(* Register the size_change proof method.  Its arguments are the order
   keywords followed by the usual classical/simplifier modifiers.  The
   parsed orders and the context are passed to decomp_scnp_tac. *)
val setup =
  let
    val method_parser =
      (Scan.lift orders --| Method.sections clasimp_modifiers)
        >> (fn orders => fn ctxt => SIMPLE_METHOD (decomp_scnp_tac orders ctxt))
  in
    Method.setup @{binding size_change} method_parser
      "termination prover with graph decomposition and the NP subset of size change termination"
  end