src/HOL/Tools/Function/scnp_reconstruct.ML
author wenzelm
Thu, 26 Aug 2010 17:01:12 +0200
changeset 39036 b32975d3db3e
parent 36514 73ed9f18fdd3
child 40104 0e1bd289c8ea
permissions -rw-r--r--
theory data merge: prefer left side uniformly;
     1 (*  Title:       HOL/Tools/Function/scnp_reconstruct.ML
     2     Author:      Armin Heller, TU Muenchen
     3     Author:      Alexander Krauss, TU Muenchen
     4 
     5 Proof reconstruction for SCNP
     6 *)
     7 
     8 signature SCNP_RECONSTRUCT =
     9 sig
    10 
    11   val sizechange_tac : Proof.context -> tactic -> tactic
    12 
    13   val decomp_scnp_tac : ScnpSolve.label list -> Proof.context -> tactic
    14 
    15   val setup : theory -> theory
    16 
    17   datatype multiset_setup =
    18     Multiset of
    19     {
    20      msetT : typ -> typ,
    21      mk_mset : typ -> term list -> term,
    22      mset_regroup_conv : int list -> conv,
    23      mset_member_tac : int -> int -> tactic,
    24      mset_nonempty_tac : int -> tactic,
    25      mset_pwleq_tac : int -> tactic,
    26      set_of_simps : thm list,
    27      smsI' : thm,
    28      wmsI2'' : thm,
    29      wmsI1 : thm,
    30      reduction_pair : thm
    31     }
    32 
    33 
    34   val multiset_setup : multiset_setup -> theory -> theory
    35 
    36 end
    37 
    38 structure ScnpReconstruct : SCNP_RECONSTRUCT =
    39 struct
    40 
    41 val PROFILE = Function_Common.PROFILE
    42 
    43 open ScnpSolve
    44 
    45 val natT = HOLogic.natT
    46 val nat_pairT = HOLogic.mk_prodT (natT, natT)
    47 
    48 (* Theory dependencies *)
    49 
    50 datatype multiset_setup =
    51   Multiset of
    52   {
    53    msetT : typ -> typ,
    54    mk_mset : typ -> term list -> term,
    55    mset_regroup_conv : int list -> conv,
    56    mset_member_tac : int -> int -> tactic,
    57    mset_nonempty_tac : int -> tactic,
    58    mset_pwleq_tac : int -> tactic,
    59    set_of_simps : thm list,
    60    smsI' : thm,
    61    wmsI2'' : thm,
    62    wmsI1 : thm,
    63    reduction_pair : thm
    64   }
    65 
    66 structure Multiset_Setup = Theory_Data
    67 (
    68   type T = multiset_setup option
    69   val empty = NONE
    70   val extend = I;
    71   fun merge (v1, v2) = if is_some v1 then v1 else v2
    72 )
    73 
    74 val multiset_setup = Multiset_Setup.put o SOME
    75 
    76 fun undef _ = error "undef"
    77 fun get_multiset_setup thy = Multiset_Setup.get thy
    78   |> the_default (Multiset
    79 { msetT = undef, mk_mset=undef,
    80   mset_regroup_conv=undef, mset_member_tac = undef,
    81   mset_nonempty_tac = undef, mset_pwleq_tac = undef,
    82   set_of_simps = [],reduction_pair = refl,
    83   smsI'=refl, wmsI2''=refl, wmsI1=refl })
    84 
    85 fun order_rpair _ MAX = @{thm max_rpair_set}
    86   | order_rpair msrp MS  = msrp
    87   | order_rpair _ MIN = @{thm min_rpair_set}
    88 
    89 fun ord_intros_max true =
    90     (@{thm smax_emptyI}, @{thm smax_insertI})
    91   | ord_intros_max false =
    92     (@{thm wmax_emptyI}, @{thm wmax_insertI})
    93 fun ord_intros_min true =
    94     (@{thm smin_emptyI}, @{thm smin_insertI})
    95   | ord_intros_min false =
    96     (@{thm wmin_emptyI}, @{thm wmin_insertI})
    97 
    98 fun gen_probl D cs =
    99   let
   100     val n = Termination.get_num_points D
   101     val arity = length o Termination.get_measures D
   102     fun measure p i = nth (Termination.get_measures D p) i
   103 
   104     fun mk_graph c =
   105       let
   106         val (_, p, _, q, _, _) = Termination.dest_call D c
   107 
   108         fun add_edge i j =
   109           case Termination.get_descent D c (measure p i) (measure q j)
   110            of SOME (Termination.Less _) => cons (i, GTR, j)
   111             | SOME (Termination.LessEq _) => cons (i, GEQ, j)
   112             | _ => I
   113 
   114         val edges =
   115           fold_product add_edge (0 upto arity p - 1) (0 upto arity q - 1) []
   116       in
   117         G (p, q, edges)
   118       end
   119   in
   120     GP (map_range arity n, map mk_graph cs)
   121   end
   122 
   123 (* General reduction pair application *)
   124 fun rem_inv_img ctxt =
   125   let
   126     val unfold_tac = Local_Defs.unfold_tac ctxt
   127   in
   128     rtac @{thm subsetI} 1
   129     THEN etac @{thm CollectE} 1
   130     THEN REPEAT (etac @{thm exE} 1)
   131     THEN unfold_tac @{thms inv_image_def}
   132     THEN rtac @{thm CollectI} 1
   133     THEN etac @{thm conjE} 1
   134     THEN etac @{thm ssubst} 1
   135     THEN unfold_tac (@{thms split_conv} @ @{thms triv_forall_equality}
   136                      @ @{thms sum.cases})
   137   end
   138 
   139 (* Sets *)
   140 
   141 val setT = HOLogic.mk_setT
   142 
   143 fun set_member_tac m i =
   144   if m = 0 then rtac @{thm insertI1} i
   145   else rtac @{thm insertI2} i THEN set_member_tac (m - 1) i
   146 
   147 val set_nonempty_tac = rtac @{thm insert_not_empty}
   148 
   149 fun set_finite_tac i =
   150   rtac @{thm finite.emptyI} i
   151   ORELSE (rtac @{thm finite.insertI} i THEN (fn st => set_finite_tac i st))
   152 
   153 
   154 (* Reconstruction *)
   155 
   156 fun reconstruct_tac ctxt D cs (GP (_, gs)) certificate =
   157   let
   158     val thy = ProofContext.theory_of ctxt
   159     val Multiset
   160           { msetT, mk_mset,
   161             mset_regroup_conv, mset_pwleq_tac, set_of_simps,
   162             smsI', wmsI2'', wmsI1, reduction_pair=ms_rp, ...} 
   163         = get_multiset_setup thy
   164 
   165     fun measure_fn p = nth (Termination.get_measures D p)
   166 
   167     fun get_desc_thm cidx m1 m2 bStrict =
   168       case Termination.get_descent D (nth cs cidx) m1 m2
   169        of SOME (Termination.Less thm) =>
   170           if bStrict then thm
   171           else (thm COMP (Thm.lift_rule (cprop_of thm) @{thm less_imp_le}))
   172         | SOME (Termination.LessEq (thm, _))  =>
   173           if not bStrict then thm
   174           else sys_error "get_desc_thm"
   175         | _ => sys_error "get_desc_thm"
   176 
   177     val (label, lev, sl, covering) = certificate
   178 
   179     fun prove_lev strict g =
   180       let
   181         val G (p, q, _) = nth gs g
   182 
   183         fun less_proof strict (j, b) (i, a) =
   184           let
   185             val tag_flag = b < a orelse (not strict andalso b <= a)
   186 
   187             val stored_thm =
   188               get_desc_thm g (measure_fn p i) (measure_fn q j)
   189                              (not tag_flag)
   190               |> Conv.fconv_rule (Thm.beta_conversion true)
   191 
   192             val rule = if strict
   193               then if b < a then @{thm pair_lessI2} else @{thm pair_lessI1}
   194               else if b <= a then @{thm pair_leqI2} else @{thm pair_leqI1}
   195           in
   196             rtac rule 1 THEN PRIMITIVE (Thm.elim_implies stored_thm)
   197             THEN (if tag_flag then Arith_Data.arith_tac ctxt 1 else all_tac)
   198           end
   199 
   200         fun steps_tac MAX strict lq lp =
   201           let
   202             val (empty, step) = ord_intros_max strict
   203           in
   204             if length lq = 0
   205             then rtac empty 1 THEN set_finite_tac 1
   206                  THEN (if strict then set_nonempty_tac 1 else all_tac)
   207             else
   208               let
   209                 val (j, b) :: rest = lq
   210                 val (i, a) = the (covering g strict j)
   211                 fun choose xs = set_member_tac (Library.find_index (curry op = (i, a)) xs) 1
   212                 val solve_tac = choose lp THEN less_proof strict (j, b) (i, a)
   213               in
   214                 rtac step 1 THEN solve_tac THEN steps_tac MAX strict rest lp
   215               end
   216           end
   217           | steps_tac MIN strict lq lp =
   218           let
   219             val (empty, step) = ord_intros_min strict
   220           in
   221             if length lp = 0
   222             then rtac empty 1
   223                  THEN (if strict then set_nonempty_tac 1 else all_tac)
   224             else
   225               let
   226                 val (i, a) :: rest = lp
   227                 val (j, b) = the (covering g strict i)
   228                 fun choose xs = set_member_tac (Library.find_index (curry op = (j, b)) xs) 1
   229                 val solve_tac = choose lq THEN less_proof strict (j, b) (i, a)
   230               in
   231                 rtac step 1 THEN solve_tac THEN steps_tac MIN strict lq rest
   232               end
   233           end
   234           | steps_tac MS strict lq lp =
   235           let
   236             fun get_str_cover (j, b) =
   237               if is_some (covering g true j) then SOME (j, b) else NONE
   238             fun get_wk_cover (j, b) = the (covering g false j)
   239 
   240             val qs = subtract (op =) (map_filter get_str_cover lq) lq
   241             val ps = map get_wk_cover qs
   242 
   243             fun indices xs ys = map (fn y => Library.find_index (curry op = y) xs) ys
   244             val iqs = indices lq qs
   245             val ips = indices lp ps
   246 
   247             local open Conv in
   248             fun t_conv a C =
   249               params_conv ~1 (K ((concl_conv ~1 o arg_conv o arg1_conv o a) C)) ctxt
   250             val goal_rewrite =
   251                 t_conv arg1_conv (mset_regroup_conv iqs)
   252                 then_conv t_conv arg_conv (mset_regroup_conv ips)
   253             end
   254           in
   255             CONVERSION goal_rewrite 1
   256             THEN (if strict then rtac smsI' 1
   257                   else if qs = lq then rtac wmsI2'' 1
   258                   else rtac wmsI1 1)
   259             THEN mset_pwleq_tac 1
   260             THEN EVERY (map2 (less_proof false) qs ps)
   261             THEN (if strict orelse qs <> lq
   262                   then Local_Defs.unfold_tac ctxt set_of_simps
   263                        THEN steps_tac MAX true
   264                        (subtract (op =) qs lq) (subtract (op =) ps lp)
   265                   else all_tac)
   266           end
   267       in
   268         rem_inv_img ctxt
   269         THEN steps_tac label strict (nth lev q) (nth lev p)
   270       end
   271 
   272     val (mk_set, setT) = if label = MS then (mk_mset, msetT) else (HOLogic.mk_set, setT)
   273 
   274     fun tag_pair p (i, tag) =
   275       HOLogic.pair_const natT natT $
   276         (measure_fn p i $ Bound 0) $ HOLogic.mk_number natT tag
   277 
   278     fun pt_lev (p, lm) = Abs ("x", Termination.get_types D p,
   279                            mk_set nat_pairT (map (tag_pair p) lm))
   280 
   281     val level_mapping =
   282       map_index pt_lev lev
   283         |> Termination.mk_sumcases D (setT nat_pairT)
   284         |> cterm_of thy
   285     in
   286       PROFILE "Proof Reconstruction"
   287         (CONVERSION (Conv.arg_conv (Conv.arg_conv (Function_Lib.regroup_union_conv sl))) 1
   288          THEN (rtac @{thm reduction_pair_lemma} 1)
   289          THEN (rtac @{thm rp_inv_image_rp} 1)
   290          THEN (rtac (order_rpair ms_rp label) 1)
   291          THEN PRIMITIVE (instantiate' [] [SOME level_mapping])
   292          THEN unfold_tac @{thms rp_inv_image_def} (simpset_of ctxt)
   293          THEN Local_Defs.unfold_tac ctxt
   294            (@{thms split_conv} @ @{thms fst_conv} @ @{thms snd_conv})
   295          THEN REPEAT (SOMEGOAL (resolve_tac [@{thm Un_least}, @{thm empty_subsetI}]))
   296          THEN EVERY (map (prove_lev true) sl)
   297          THEN EVERY (map (prove_lev false) (subtract (op =) sl (0 upto length cs - 1))))
   298     end
   299 
   300 
   301 
   302 local open Termination in
   303 fun print_cell (SOME (Less _)) = "<"
   304   | print_cell (SOME (LessEq _)) = "\<le>"
   305   | print_cell (SOME (None _)) = "-"
   306   | print_cell (SOME (False _)) = "-"
   307   | print_cell (NONE) = "?"
   308 
   309 fun print_error ctxt D = CALLS (fn (cs, _) =>
   310   let
   311     val np = get_num_points D
   312     val ms = map_range (get_measures D) np
   313     fun index xs = (1 upto length xs) ~~ xs
   314     fun outp s t f xs = map (fn (x, y) => s ^ Int.toString x ^ t ^ f y ^ "\n") xs
   315     val ims = index (map index ms)
   316     val _ = tracing (implode (outp "fn #" ":\n" (implode o outp "\tmeasure #" ": " (Syntax.string_of_term ctxt)) ims))
   317     fun print_call (k, c) =
   318       let
   319         val (_, p, _, q, _, _) = dest_call D c
   320         val _ = tracing ("call table for call #" ^ Int.toString k ^ ": fn " ^ 
   321                                 Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1))
   322         val caller_ms = nth ms p
   323         val callee_ms = nth ms q
   324         val entries = map (fn x => map (pair x) (callee_ms)) (caller_ms)
   325         fun print_ln (i : int, l) = implode (Int.toString i :: "   " :: map (enclose " " " " o print_cell o (uncurry (get_descent D c))) l)
   326         val _ = tracing (implode (Int.toString (p + 1) ^ "|" ^ Int.toString (q + 1) ^ 
   327                                         " " :: map (enclose " " " " o Int.toString) (1 upto length callee_ms)) ^ "\n" 
   328                                 ^ cat_lines (map print_ln ((1 upto (length entries)) ~~ entries)))
   329       in
   330         true
   331       end
   332     fun list_call (k, c) =
   333       let
   334         val (_, p, _, q, _, _) = dest_call D c
   335         val _ = tracing ("call #" ^ (Int.toString k) ^ ": fn " ^
   336                                 Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1) ^ "\n" ^ 
   337                                 (Syntax.string_of_term ctxt c))
   338       in true end
   339     val _ = forall list_call ((1 upto length cs) ~~ cs)
   340     val _ = forall print_call ((1 upto length cs) ~~ cs)
   341   in
   342     all_tac
   343   end)
   344 end
   345 
   346 fun single_scnp_tac use_tags orders ctxt cont err_cont D = Termination.CALLS (fn (cs, i) =>
   347   let
   348     val ms_configured = is_some (Multiset_Setup.get (ProofContext.theory_of ctxt))
   349     val orders' = if ms_configured then orders
   350                   else filter_out (curry op = MS) orders
   351     val gp = gen_probl D cs
   352     val certificate = generate_certificate use_tags orders' gp
   353   in
   354     case certificate
   355      of NONE => err_cont D i
   356       | SOME cert =>
   357           SELECT_GOAL (reconstruct_tac ctxt D cs gp cert) i
   358           THEN (rtac @{thm wf_empty} i ORELSE cont D i)
   359   end)
   360 
   361 fun gen_decomp_scnp_tac orders autom_tac ctxt err_cont =
   362   let
   363     open Termination
   364     val derive_diag = Termination.derive_diag ctxt autom_tac
   365     val derive_all = Termination.derive_all ctxt autom_tac
   366     val decompose = Termination.decompose_tac ctxt autom_tac
   367     val scnp_no_tags = single_scnp_tac false orders ctxt
   368     val scnp_full = single_scnp_tac true orders ctxt
   369 
   370     fun first_round c e =
   371         derive_diag (REPEAT scnp_no_tags c e)
   372 
   373     val second_round =
   374         REPEAT (fn c => fn e => decompose (scnp_no_tags c c) e)
   375 
   376     val third_round =
   377         derive_all oo
   378         REPEAT (fn c => fn e =>
   379           scnp_full (decompose c c) e)
   380 
   381     fun Then s1 s2 c e = s1 (s2 c c) (s2 c e)
   382 
   383     val strategy = Then (Then first_round second_round) third_round
   384 
   385   in
   386     TERMINATION ctxt (strategy err_cont err_cont)
   387   end
   388 
   389 fun gen_sizechange_tac orders autom_tac ctxt err_cont =
   390   TRY (Function_Common.apply_termination_rule ctxt 1)
   391   THEN TRY (Termination.wf_union_tac ctxt)
   392   THEN
   393    (rtac @{thm wf_empty} 1
   394     ORELSE gen_decomp_scnp_tac orders autom_tac ctxt err_cont 1)
   395 
   396 fun sizechange_tac ctxt autom_tac =
   397   gen_sizechange_tac [MAX, MS, MIN] autom_tac ctxt (K (K all_tac))
   398 
   399 fun decomp_scnp_tac orders ctxt =
   400   let
   401     val extra_simps = Function_Common.Termination_Simps.get ctxt
   402     val autom_tac = auto_tac (clasimpset_of ctxt addsimps2 extra_simps)
   403   in
   404      gen_sizechange_tac orders autom_tac ctxt (print_error ctxt)
   405   end
   406 
   407 
   408 (* Method setup *)
   409 
   410 val orders =
   411   Scan.repeat1
   412     ((Args.$$$ "max" >> K MAX) ||
   413      (Args.$$$ "min" >> K MIN) ||
   414      (Args.$$$ "ms" >> K MS))
   415   || Scan.succeed [MAX, MS, MIN]
   416 
   417 val setup = Method.setup @{binding size_change}
   418   (Scan.lift orders --| Method.sections clasimp_modifiers >>
   419     (fn orders => SIMPLE_METHOD o decomp_scnp_tac orders))
   420   "termination prover with graph decomposition and the NP subset of size change termination"
   421 
   422 end