src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
author blanchet
Thu, 07 Nov 2013 00:37:18 +0100
changeset 55738 22616f65d4ea
parent 55736 0b53378080d9
child 55740 ce58fb149ff6
permissions -rw-r--r--
properly detect when to perform n2m -- e.g. handle the case of two independent functions on irrelevant types being defined in parallel
     1 (*  Title:      HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2013
     4 
     5 Sugared flattening of nested to mutual (co)recursion.
     6 *)
     7 
     8 signature BNF_FP_N2M_SUGAR =
     9 sig
    10   val unfold_let: term -> term
    11   val dest_map: Proof.context -> string -> term -> term * term list
    12 
    13   val mutualize_fp_sugars: BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
    14     term list list list list -> BNF_FP_Def_Sugar.fp_sugar list -> local_theory ->
    15     (BNF_FP_Def_Sugar.fp_sugar list
    16      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
    17     * local_theory
    18   val indexify_callsss: BNF_FP_Def_Sugar.fp_sugar -> (term * term list list) list ->
    19     term list list list
    20   val nested_to_mutual_fps: BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
    21     (term * term list list) list list -> local_theory ->
    22     (typ list * int list * BNF_FP_Def_Sugar.fp_sugar list
    23      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
    24     * local_theory
    25 end;
    26 
    27 structure BNF_FP_N2M_Sugar : BNF_FP_N2M_SUGAR =
    28 struct
    29 
    30 open Ctr_Sugar
    31 open BNF_Util
    32 open BNF_Def
    33 open BNF_FP_Util
    34 open BNF_FP_Def_Sugar
    35 open BNF_FP_N2M
    36 
    37 val n2mN = "n2m_"
    38 
    39 type n2m_sugar = fp_sugar list * (lfp_sugar_thms option * gfp_sugar_thms option);
    40 
    41 structure Data = Generic_Data
    42 (
    43   type T = n2m_sugar Typtab.table;
    44   val empty = Typtab.empty;
    45   val extend = I;
    46   val merge = Typtab.merge (eq_fst (eq_list eq_fp_sugar));
    47 );
    48 
    49 fun morph_n2m_sugar phi (fp_sugars, (lfp_sugar_thms_opt, gfp_sugar_thms_opt)) =
    50   (map (morph_fp_sugar phi) fp_sugars,
    51    (Option.map (morph_lfp_sugar_thms phi) lfp_sugar_thms_opt,
    52     Option.map (morph_gfp_sugar_thms phi) gfp_sugar_thms_opt));
    53 
    54 val transfer_n2m_sugar =
    55   morph_n2m_sugar o Morphism.thm_morphism o Thm.transfer o Proof_Context.theory_of;
    56 
    57 fun n2m_sugar_of ctxt =
    58   Typtab.lookup (Data.get (Context.Proof ctxt))
    59   #> Option.map (transfer_n2m_sugar ctxt);
    60 
    61 fun register_n2m_sugar key n2m_sugar =
    62   Local_Theory.declaration {syntax = false, pervasive = false}
    63     (fn phi => Data.map (Typtab.default (key, morph_n2m_sugar phi n2m_sugar)));
    64 
    65 fun unfold_let (Const (@{const_name Let}, _) $ arg1 $ arg2) = unfold_let (betapply (arg2, arg1))
    66   | unfold_let (Const (@{const_name prod_case}, _) $ t) =
    67     (case unfold_let t of
    68       t' as Abs (s1, T1, Abs (s2, T2, _)) =>
    69       let
    70         val x = (s1 ^ s2, Term.maxidx_of_term t + 1);
    71         val v = Var (x, HOLogic.mk_prodT (T1, T2));
    72       in
    73         lambda v (unfold_let (betapplys (t', [HOLogic.mk_fst v, HOLogic.mk_snd v])))
    74       end
    75     | _ => t)
    76   | unfold_let (t $ u) = betapply (unfold_let t, unfold_let u)
    77   | unfold_let (Abs (s, T, t)) = Abs (s, T, unfold_let t)
    78   | unfold_let t = t;
    79 
    80 fun mk_map_pattern ctxt s =
    81   let
    82     val bnf = the (bnf_of ctxt s);
    83     val mapx = map_of_bnf bnf;
    84     val live = live_of_bnf bnf;
    85     val (f_Ts, _) = strip_typeN live (fastype_of mapx);
    86     val fs = map_index (fn (i, T) => Var (("?f", i), T)) f_Ts;
    87   in
    88     (mapx, betapplys (mapx, fs))
    89   end;
    90 
    91 fun dest_map ctxt s call =
    92   let
    93     val (map0, pat) = mk_map_pattern ctxt s;
    94     val (_, tenv) = fo_match ctxt call pat;
    95   in
    96     (map0, Vartab.fold_rev (fn (_, (_, f)) => cons f) tenv [])
    97   end;
    98 
    99 fun dest_abs_or_applied_map _ _ (Abs (_, _, t)) = (Term.dummy, [t])
   100   | dest_abs_or_applied_map ctxt s (t1 $ _) = dest_map ctxt s t1;
   101 
   102 fun map_partition f xs =
   103   fold_rev (fn x => fn (ys, (good, bad)) =>
   104       case f x of SOME y => (y :: ys, (x :: good, bad)) | NONE => (ys, (good, x :: bad)))
   105     xs ([], ([], []));
   106 
   107 fun key_of_fp_eqs fp fpTs fp_eqs =
   108   Type (fp_case fp "l" "g", fpTs @ maps (fn (x, T) => [TFree x, T]) fp_eqs);
   109 
   110 (* TODO: test with sort constraints on As *)
   111 fun mutualize_fp_sugars fp bs fpTs get_indices callssss fp_sugars0 no_defs_lthy0 =
   112   let
   113     val thy = Proof_Context.theory_of no_defs_lthy0;
   114 
   115     val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
   116 
   117     fun incompatible_calls t1 t2 =
   118       error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2);
   119 
   120     val b_names = map Binding.name_of bs;
   121     val fp_b_names = map base_name_of_typ fpTs;
   122 
   123     val nn = length fpTs;
   124 
   125     fun target_ctr_sugar_of_fp_sugar fpT ({T, index, ctr_sugars, ...} : fp_sugar) =
   126       let
   127         val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T, fpT) Vartab.empty) [];
   128         val phi = Morphism.term_morphism (Term.subst_TVars rho);
   129       in
   130         morph_ctr_sugar phi (nth ctr_sugars index)
   131       end;
   132 
   133     val ctr_defss = map (of_fp_sugar #ctr_defss) fp_sugars0;
   134     val mapss = map (of_fp_sugar #mapss) fp_sugars0;
   135     val ctr_sugars0 = map2 target_ctr_sugar_of_fp_sugar fpTs fp_sugars0;
   136 
   137     val ctrss = map #ctrs ctr_sugars0;
   138     val ctr_Tss = map (map fastype_of) ctrss;
   139 
   140     val As' = fold (fold Term.add_tfreesT) ctr_Tss [];
   141     val As = map TFree As';
   142 
   143     val ((Cs, Xs), no_defs_lthy) =
   144       no_defs_lthy0
   145       |> fold Variable.declare_typ As
   146       |> mk_TFrees nn
   147       ||>> variant_tfrees fp_b_names;
   148 
   149     fun check_call_dead live_call call =
   150       if null (get_indices call) then () else incompatible_calls live_call call;
   151 
   152     fun freeze_fpTs_simple (T as Type (s, Ts)) =
   153         (case find_index (curry (op =) T) fpTs of
   154           ~1 => Type (s, map freeze_fpTs_simple Ts)
   155         | kk => nth Xs kk)
   156       | freeze_fpTs_simple T = T;
   157 
   158     fun freeze_fpTs_map (callss, (live_call :: _, dead_calls)) s Ts =
   159       (List.app (check_call_dead live_call) dead_calls;
   160        Type (s, map2 freeze_fpTs (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
   161          (transpose callss)) Ts))
   162     and freeze_fpTs calls (T as Type (s, Ts)) =
   163         (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
   164           ([], _) =>
   165           (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
   166             ([], _) => freeze_fpTs_simple T
   167           | callsp => freeze_fpTs_map callsp s Ts)
   168         | callsp => freeze_fpTs_map callsp s Ts)
   169       | freeze_fpTs _ T = T;
   170 
   171     val ctr_Tsss = map (map binder_types) ctr_Tss;
   172     val ctrXs_Tsss = map2 (map2 (map2 freeze_fpTs)) callssss ctr_Tsss;
   173     val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
   174     val Ts = map (body_type o hd) ctr_Tss;
   175 
   176     val ns = map length ctr_Tsss;
   177     val kss = map (fn n => 1 upto n) ns;
   178     val mss = map (map length) ctr_Tsss;
   179 
   180     val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;
   181     val key = key_of_fp_eqs fp fpTs fp_eqs;
   182   in
   183     (case n2m_sugar_of no_defs_lthy key of
   184       SOME n2m_sugar => (n2m_sugar, no_defs_lthy)
   185     | NONE =>
   186       let
   187         val base_fp_names = Name.variant_list [] fp_b_names;
   188         val fp_bs = map2 (fn b_name => fn base_fp_name =>
   189             Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
   190           b_names base_fp_names;
   191 
   192         val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct, dtor_injects,
   193                dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
   194           fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
   195 
   196         val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
   197         val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
   198 
   199         val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
   200           mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
   201 
   202         fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
   203 
   204         val ((co_iterss, co_iter_defss), lthy) =
   205           fold_map2 (fn b =>
   206             (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
   207              else define_coiters [unfoldN, corecN] (the coiters_args_types))
   208               (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
   209           |>> split_list;
   210 
   211         val rho = tvar_subst thy Ts fpTs;
   212         val ctr_sugar_phi = Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
   213             (Morphism.term_morphism (Term.subst_TVars rho));
   214         val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
   215 
   216         val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
   217 
   218         val ((co_inducts, un_fold_thmss, co_rec_thmss, disc_unfold_thmss, disc_corec_thmss,
   219               sel_unfold_thmsss, sel_corec_thmsss), fp_sugar_thms) =
   220           if fp = Least_FP then
   221             derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
   222               xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
   223               co_iterss co_iter_defss lthy
   224             |> `(fn ((_, induct, _), (fold_thmss, rec_thmss, _)) =>
   225               ([induct], fold_thmss, rec_thmss, [], [], [], []))
   226             ||> (fn info => (SOME info, NONE))
   227           else
   228             derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types) xtor_co_induct
   229               dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs Xs ctrXs_Tsss kss mss
   230               ns ctr_defss ctr_sugars co_iterss co_iter_defss
   231               (Proof_Context.export lthy no_defs_lthy) lthy
   232             |> `(fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _),
   233                     (disc_unfold_thmss, disc_corec_thmss, _), _,
   234                     (sel_unfold_thmsss, sel_corec_thmsss, _)) =>
   235               (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss, disc_unfold_thmss,
   236                disc_corec_thmss, sel_unfold_thmsss, sel_corec_thmsss))
   237             ||> (fn info => (NONE, SOME info));
   238 
   239         val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
   240 
   241         fun mk_target_fp_sugar (kk, T) =
   242           {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
   243            nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
   244            ctr_sugars = ctr_sugars, co_iterss = co_iterss, mapss = mapss, co_inducts = co_inducts,
   245            co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss],
   246            disc_co_itersss = transpose [disc_unfold_thmss, disc_corec_thmss],
   247            sel_co_iterssss = transpose [sel_unfold_thmsss, sel_corec_thmsss]}
   248           |> morph_fp_sugar phi;
   249 
   250         val n2m_sugar = (map_index mk_target_fp_sugar fpTs, fp_sugar_thms);
   251       in
   252         (n2m_sugar, lthy |> register_n2m_sugar key n2m_sugar)
   253       end)
   254   end;
   255 
   256 fun indexify_callsss fp_sugar callsss =
   257   let
   258     val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
   259     fun indexify_ctr ctr =
   260       (case AList.lookup Term.aconv_untyped callsss ctr of
   261         NONE => replicate (num_binder_types (fastype_of ctr)) []
   262       | SOME callss => map (map (Envir.beta_eta_contract o unfold_let)) callss);
   263   in
   264     map indexify_ctr ctrs
   265   end;
   266 
   267 fun retypargs tyargs (Type (s, _)) = Type (s, tyargs);
   268 
   269 fun fold_subtype_pairs f (T as Type (s, Ts), U as Type (s', Us)) =
   270     f (T, U) #> (if s = s' then fold (fold_subtype_pairs f) (Ts ~~ Us) else I)
   271   | fold_subtype_pairs f TU = f TU;
   272 
   273 fun nested_to_mutual_fps fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
   274   let
   275     val qsoty = quote o Syntax.string_of_typ lthy;
   276     val qsotys = space_implode " or " o map qsoty;
   277 
   278     fun duplicate_datatype T = error (qsoty T ^ " is not mutually recursive with itself");
   279     fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
   280     fun not_co_datatype (T as Type (s, _)) =
   281         if fp = Least_FP andalso
   282            is_some (Datatype_Data.get_info (Proof_Context.theory_of lthy) s) then
   283           error (qsoty T ^ " is not a new-style datatype (cf. \"datatype_new\")")
   284         else
   285           not_co_datatype0 T
   286       | not_co_datatype T = not_co_datatype0 T;
   287 
   288     val _ = (case Library.duplicates (op =) actual_Ts of [] => () | T :: _ => duplicate_datatype T);
   289 
   290     val perm_actual_Ts =
   291       sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts;
   292 
   293     fun the_ctrs_of (Type (s, Ts)) = map (mk_ctr Ts) (#ctrs (the (ctr_sugar_of lthy s)));
   294 
   295     fun the_fp_sugar_of (T as Type (T_name, _)) =
   296       (case fp_sugar_of lthy T_name of
   297         SOME (fp_sugar as {fp = fp', ...}) => if fp = fp' then fp_sugar else not_co_datatype T
   298       | NONE => not_co_datatype T);
   299 
   300     fun gen_rhss_in gen_Ts rho subTs =
   301       let
   302         fun maybe_insert (T, Type (_, gen_tyargs)) =
   303             if member (op =) subTs T then insert (op =) gen_tyargs else I
   304           | maybe_insert _ = I;
   305 
   306         val ctrs = maps the_ctrs_of gen_Ts;
   307         val gen_ctr_Ts = maps (binder_types o fastype_of) ctrs;
   308         val ctr_Ts = map (Term.typ_subst_atomic rho) gen_ctr_Ts;
   309       in
   310         fold (fold_subtype_pairs maybe_insert) (ctr_Ts ~~ gen_ctr_Ts) []
   311       end;
   312 
   313     fun gather_types _ _ num_groups seen gen_seen [] = (num_groups, seen, gen_seen)
   314       | gather_types lthy rho num_groups seen gen_seen ((T as Type (_, tyargs)) :: Ts) =
   315         let
   316           val {fp_res = {Ts = mutual_Ts0, ...}, ...} = the_fp_sugar_of T;
   317           val mutual_Ts = map (retypargs tyargs) mutual_Ts0;
   318 
   319           fun fresh_tyargs () =
   320             let
   321               (* The name "'z" is unlikely to clash with the context, yielding more cache hits. *)
   322               val (gen_tyargs, lthy') =
   323                 variant_tfrees (replicate (length tyargs) "z") lthy
   324                 |>> map Logic.varifyT_global;
   325               val rho' = (gen_tyargs ~~ tyargs) @ rho;
   326             in
   327               (rho', gen_tyargs, gen_seen, lthy')
   328             end;
   329 
   330           val (rho', gen_tyargs, gen_seen', lthy') =
   331             if exists (exists_subtype_in seen) mutual_Ts then
   332               (case gen_rhss_in gen_seen rho mutual_Ts of
   333                 [] => fresh_tyargs ()
   334               | gen_tyargss as gen_tyargs :: gen_tyargss_tl =>
   335                 let
   336                   val unify_pairs = split_list (maps (curry (op ~~) gen_tyargs) gen_tyargss_tl);
   337                   val mgu = Type.raw_unifys unify_pairs Vartab.empty;
   338                   val gen_tyargs' = map (Envir.subst_type mgu) gen_tyargs;
   339                   val gen_seen' = map (Envir.subst_type mgu) gen_seen;
   340                 in
   341                   (rho, gen_tyargs', gen_seen', lthy)
   342                 end)
   343             else
   344               fresh_tyargs ();
   345 
   346           val gen_mutual_Ts = map (retypargs gen_tyargs) mutual_Ts0;
   347           val Ts' = filter_out (member (op =) mutual_Ts) Ts;
   348         in
   349           gather_types lthy' rho' (num_groups + 1) (seen @ mutual_Ts) (gen_seen' @ gen_mutual_Ts)
   350             Ts'
   351         end
   352       | gather_types _ _ _ _ _ (T :: _) = not_co_datatype T;
   353 
   354     val (num_groups, perm_Ts, perm_gen_Ts) = gather_types lthy [] 0 [] [] perm_actual_Ts;
   355     val perm_frozen_gen_Ts = map Logic.unvarifyT_global perm_gen_Ts;
   356 
   357     val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
   358     val Ts = actual_Ts @ missing_Ts;
   359 
   360     val nn = length Ts;
   361     val kks = 0 upto nn - 1;
   362 
   363     val callssss0 = pad_list [] nn actual_callssss0;
   364 
   365     val common_name = mk_common_name (map Binding.name_of actual_bs);
   366     val bs = pad_list (Binding.name common_name) nn actual_bs;
   367 
   368     fun permute xs = permute_like (op =) Ts perm_Ts xs;
   369     fun unpermute perm_xs = permute_like (op =) perm_Ts Ts perm_xs;
   370 
   371     val perm_bs = permute bs;
   372     val perm_kks = permute kks;
   373     val perm_callssss0 = permute callssss0;
   374     val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;
   375 
   376     val perm_callssss = map2 indexify_callsss perm_fp_sugars0 perm_callssss0;
   377 
   378     val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
   379 
   380     val ((perm_fp_sugars, fp_sugar_thms), lthy) =
   381       if num_groups > 1 then
   382         mutualize_fp_sugars fp perm_bs perm_frozen_gen_Ts get_perm_indices perm_callssss
   383           perm_fp_sugars0 lthy
   384       else
   385         ((perm_fp_sugars0, (NONE, NONE)), lthy);
   386 
   387     val fp_sugars = unpermute perm_fp_sugars;
   388   in
   389     ((missing_Ts, perm_kks, fp_sugars, fp_sugar_thms), lthy)
   390   end;
   391 
   392 end;