src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
author blanchet
Tue, 05 Nov 2013 05:48:08 +0100
changeset 55707 4f7c016d5bc6
parent 55706 d1478807f287
child 55708 4843082be7ef
permissions -rw-r--r--
also generalize fixed types
     1 (*  Title:      HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2013
     4 
     5 Sugared flattening of nested to mutual (co)recursion.
     6 *)
     7 
     8 signature BNF_FP_N2M_SUGAR =
     9 sig
         (* Unfolds "Let" bindings and curried "prod_case" abstractions in a term. *)
    10   val unfold_let: term -> term
         (* Given a context, the name "s" of a type constructor and a term, decomposes the term as
            an application of the map function of the BNF registered for "s", returning the map
            constant and the matched function arguments. *)
    11   val dest_map: Proof.context -> string -> term -> term * term list
    12 
         (* Mutualizes (co)datatype sugars.  Arguments: whether nesting is present, the fixpoint
            kind, the bindings and types of the (co)datatypes, a function on calls (presumably
            giving the indices of the types a call refers to), the calls indexed by type,
            constructor and argument, and the original sugars.  Without nesting, the sugars are
            returned unchanged and no theorem info is produced. *)
    13   val mutualize_fp_sugars: bool -> BNF_FP_Util.fp_kind -> binding list -> typ list ->
    14     (term -> int list) -> term list list list list -> BNF_FP_Def_Sugar.fp_sugar list ->
    15     local_theory ->
    16     (BNF_FP_Def_Sugar.fp_sugar list
    17      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
    18     * local_theory
         (* Pads the per-type call lists to the given number of types and indexes each entry by
            the constructors of the corresponding sugar. *)
    19   val pad_and_indexify_calls: BNF_FP_Def_Sugar.fp_sugar list -> int ->
    20     (term * term list list) list list -> term list list list list
         (* Reduces nested (co)recursion over the given types to mutual (co)recursion.  Returns
            the mutually recursive types that had to be added, the permutation of type indices,
            the resulting sugars (for the given types followed by the added ones) and the derived
            theorem info. *)
    21   val nested_to_mutual_fps: BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
    22     (term * term list list) list list -> local_theory ->
    23     (typ list * int list * BNF_FP_Def_Sugar.fp_sugar list
    24      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
    25     * local_theory
    26 end;
    27 
    28 structure BNF_FP_N2M_Sugar : BNF_FP_N2M_SUGAR =
    29 struct
    30 
    31 open Ctr_Sugar
    32 open BNF_Util
    33 open BNF_Def
    34 open BNF_FP_Util
    35 open BNF_FP_Def_Sugar
    36 open BNF_FP_N2M
    37 
    38 val n2mN = "n2m_"
    39 
    40 fun unfold_let (Const (@{const_name Let}, _) $ arg1 $ arg2) = unfold_let (betapply (arg2, arg1))
    41   | unfold_let (Const (@{const_name prod_case}, _) $ t) =
    42     (case unfold_let t of
    43       t' as Abs (s1, T1, Abs (s2, T2, _)) =>
    44       let
    45         val x = (s1 ^ s2, Term.maxidx_of_term t + 1);
    46         val v = Var (x, HOLogic.mk_prodT (T1, T2));
    47       in
    48         lambda v (unfold_let (betapplys (t', [HOLogic.mk_fst v, HOLogic.mk_snd v])))
    49       end
    50     | _ => t)
    51   | unfold_let (t $ u) = betapply (unfold_let t, unfold_let u)
    52   | unfold_let (Abs (s, T, t)) = Abs (s, T, unfold_let t)
    53   | unfold_let t = t;
    54 
    55 val dummy_var_name = "?f"
    56 
    57 fun mk_map_pattern ctxt s =
    58   let
    59     val bnf = the (bnf_of ctxt s);
    60     val mapx = map_of_bnf bnf;
    61     val live = live_of_bnf bnf;
    62     val (f_Ts, _) = strip_typeN live (fastype_of mapx);
    63     val fs = map_index (fn (i, T) => Var ((dummy_var_name, i), T)) f_Ts;
    64   in
    65     (mapx, betapplys (mapx, fs))
    66   end;
    67 
    68 fun dest_map ctxt s call =
    69   let
    70     val (map0, pat) = mk_map_pattern ctxt s;
    71     val (_, tenv) = fo_match ctxt call pat;
    72   in
    73     (map0, Vartab.fold_rev (fn (_, (_, f)) => cons f) tenv [])
    74   end;
    75 
    76 fun dest_abs_or_applied_map_or_ctr _ _ (Abs (_, _, t)) = (Term.dummy, [t])
    77   | dest_abs_or_applied_map_or_ctr ctxt s (t as t1 $ _) =
    78     (case try (dest_map ctxt s) t1 of
    79       SOME res => res
    80     | NONE =>
    81       let
    82         val thy = Proof_Context.theory_of ctxt;
    83         val map_thms = of_fp_sugar #mapss (the (fp_sugar_of ctxt s))
    84         val map_thms' = map (fn thm => thm RS sym RS eq_reflection) map_thms;
    85         val t' = Raw_Simplifier.rewrite_term thy map_thms' [] t;
    86       in
    87         if t aconv t' then raise Fail "dest_applied_map_or_ctr"
    88         else dest_map ctxt s (fst (dest_comb t'))
    89       end);
    90 
    91 fun map_partition f xs =
    92   fold_rev (fn x => fn (ys, (good, bad)) =>
    93       case f x of SOME y => (y :: ys, (x :: good, bad)) | NONE => (ys, (good, x :: bad)))
    94     xs ([], ([], []));
    95 
    96 (* TODO: test with sort constraints on As *)
    97 (* TODO: use right sorting order for "fp_sort" w.r.t. original BNFs (?) -- treat new variables
    98    as deads? *)
       (* Mutualize the (co)datatype sugars "fp_sugars0" of the types "fpTs" (with bindings "bs"),
          given their (co)recursive calls "callssss" (indexed by type, constructor and argument)
          and "get_indices", which presumably yields the indices (into "fpTs") of the types a
          call refers to.
          If "has_nested" is false, the sugars are returned unchanged with no theorem info.
          Otherwise a fresh mutual (co)datatype with one component per element of "fpTs" is
          constructed by "fp_bnf", its (co)iterators are defined, their theorems are derived,
          and the results are packaged as "fp_sugar" records exported back to
          "no_defs_lthy0". *)
    99 fun mutualize_fp_sugars has_nested fp bs fpTs get_indices callssss fp_sugars0 no_defs_lthy0 =
   100   if has_nested then
   101     let
   102       val thy = Proof_Context.theory_of no_defs_lthy0;
   103 
   104       val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
   105 
   106       fun incompatible_calls t1 t2 =
   107         error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ qsotm t1 ^ " vs. " ^
   108           qsotm t2);
   109 
   110       val b_names = map Binding.name_of bs;
   111       val fp_b_names = map base_name_of_typ fpTs;
   112 
   113       val nn = length fpTs;
   114 
             (* Instantiate the constructor sugar selected by "index" to "fpT", using the type
                substitution obtained by matching the sugar's type "T" against "fpT". *)
   115       fun target_ctr_sugar_of_fp_sugar fpT ({T, index, ctr_sugars, ...} : fp_sugar) =
   116         let
   117           val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T, fpT) Vartab.empty) [];
   118           val phi = Morphism.term_morphism (Term.subst_TVars rho);
   119         in
   120           morph_ctr_sugar phi (nth ctr_sugars index)
   121         end;
   122 
   123       val ctr_defss = map (of_fp_sugar #ctr_defss) fp_sugars0;
   124       val mapss = map (of_fp_sugar #mapss) fp_sugars0;
   125       val ctr_sugars0 = map2 target_ctr_sugar_of_fp_sugar fpTs fp_sugars0;
   126 
   127       val ctrss = map #ctrs ctr_sugars0;
   128       val ctr_Tss = map (map fastype_of) ctrss;
   129 
             (* The type variables of the instantiated constructor types.  They are declared
                before inventing "Cs" and "Xs" so that the fresh names cannot clash with them. *)
   130       val As' = fold (fold Term.add_tfreesT) ctr_Tss [];
   131       val As = map TFree As';
   132 
             (* "Cs": one fresh type per type, passed to the (co)iterator construction below;
                "Xs": placeholders for the (co)recursive occurrences, named after the types. *)
   133       val ((Cs, Xs), no_defs_lthy) =
   134         no_defs_lthy0
   135         |> fold Variable.declare_typ As
   136         |> mk_TFrees nn
   137         ||>> variant_tfrees fp_b_names;
   138 
             (* A call that does not decompose as a map application must not refer to any of the
                types being defined, i.e. "get_indices" must return no index for it. *)
   139       fun check_call_dead live_call call =
   140         if null (get_indices call) then () else incompatible_calls live_call call;
   141 
             (* Replace the occurrences of "fpTs" in a type by the corresponding placeholders
                "Xs", structurally. *)
   142       fun freeze_fpTs_simple (T as Type (s, Ts)) =
   143           (case find_index (curry (op =) T) fpTs of
   144             ~1 => Type (s, map freeze_fpTs_simple Ts)
   145           | kk => nth Xs kk)
   146         | freeze_fpTs_simple T = T;
   147 
             (* Descend into "Type (s, Ts)" guided by the calls at this position.  The calls that
                decompose as applications of the map of "s" (directly via "dest_map", or via
                "dest_abs_or_applied_map_or_ctr") supply their function arguments, transposed and
                distributed over the type arguments by "flatten_type_args_of_bnf"; the remaining
                calls are checked with "check_call_dead".  Without any decomposable call, fall
                back to "freeze_fpTs_simple". *)
   148       fun freeze_fpTs_map (callss, (live_call :: _, dead_calls)) s Ts =
   149         (List.app (check_call_dead live_call) dead_calls;
   150          Type (s, map2 freeze_fpTs (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
   151            (transpose callss)) Ts))
   152       and freeze_fpTs calls (T as Type (s, Ts)) =
   153           (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
   154             ([], _) =>
   155             (case map_partition (try (snd o dest_abs_or_applied_map_or_ctr no_defs_lthy s)) calls of
   156               ([], _) => freeze_fpTs_simple T
   157             | callsp => freeze_fpTs_map callsp s Ts)
   158           | callsp => freeze_fpTs_map callsp s Ts)
   159         | freeze_fpTs _ T = T;
   160 
             (* Constructor argument types per type and constructor, the same with the
                (co)recursive occurrences frozen to "Xs", and the resulting sum-of-products type
                per type. *)
   161       val ctr_Tsss = map (map binder_types) ctr_Tss;
   162       val ctrXs_Tsss = map2 (map2 (map2 freeze_fpTs)) callssss ctr_Tsss;
   163       val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
             (* The constructors' result types. *)
   164       val Ts = map (body_type o hd) ctr_Tss;
   165 
             (* Number of constructors per type, their 1-based indices, and the number of
                arguments per constructor. *)
   166       val ns = map length ctr_Tsss;
   167       val kss = map (fn n => 1 upto n) ns;
   168       val mss = map (map length) ctr_Tsss;
   169 
             (* The fixpoint equations "X_i = sum-of-products_i" handed to "fp_bnf". *)
   170       val fp_eqs = map dest_TFree Xs ~~ ctrXs_sum_prod_Ts;
   171 
             (* Names of the new types: "n2m_" followed by a variant of the type's base name,
                qualified by the corresponding binding name. *)
   172       val base_fp_names = Name.variant_list [] fp_b_names;
   173       val fp_bs = map2 (fn b_name => fn base_fp_name =>
   174           Binding.qualify true b_name (Binding.name (n2mN ^ base_fp_name)))
   175         b_names base_fp_names;
   176 
             (* Construct the mutual (co)datatype. *)
   177       val (pre_bnfs, (fp_res as {xtor_co_iterss = xtor_co_iterss0, xtor_co_induct,
   178              dtor_injects, dtor_ctors, xtor_co_iter_thmss, ...}, lthy)) =
   179         fp_bnf (construct_mutualized_fp fp fpTs fp_sugars0) fp_bs As' fp_eqs no_defs_lthy;
   180 
             (* BNFs involved in the frozen constructor argument types, relative to "As" and to
                the placeholders "Xs" respectively -- presumably the nesting vs. nested BNFs. *)
   181       val nesting_bnfs = nesty_bnfs lthy ctrXs_Tsss As;
   182       val nested_bnfs = nesty_bnfs lthy ctrXs_Tsss Xs;
   183 
   184       val ((xtor_co_iterss, iters_args_types, coiters_args_types), _) =
   185         mk_co_iters_prelims fp ctr_Tsss fpTs Cs ns mss xtor_co_iterss0 lthy;
   186 
   187       fun mk_binding b suf = Binding.suffix_name ("_" ^ suf) b;
   188 
             (* Define fold/rec (least fixpoint) or unfold/corec (greatest fixpoint) per type. *)
   189       val ((co_iterss, co_iter_defss), lthy) =
   190         fold_map2 (fn b =>
   191           (if fp = Least_FP then define_iters [foldN, recN] (the iters_args_types)
   192            else define_coiters [unfoldN, corecN] (the coiters_args_types))
   193             (mk_binding b) fpTs Cs) fp_bs xtor_co_iterss lthy
   194         |>> split_list;
   195 
             (* Morphism instantiating the constructor sugars' schematic type variables according
                to "tvar_subst thy Ts fpTs". *)
   196       val rho = tvar_subst thy Ts fpTs;
   197       val ctr_sugar_phi =
   198         Morphism.compose (Morphism.typ_morphism (Term.typ_subst_TVars rho))
   199           (Morphism.term_morphism (Term.subst_TVars rho));
   200       val inst_ctr_sugar = morph_ctr_sugar ctr_sugar_phi;
   201 
   202       val ctr_sugars = map inst_ctr_sugar ctr_sugars0;
   203 
             (* Derive the (co)induction and (co)iteration theorems.  The "`" combinator keeps the
                complete result (returned as "fp_sugar_thms") while projecting out the pieces
                stored in the sugar records; the least-fixpoint case has no discriminator or
                selector theorems, hence the empty lists. *)
   204       val ((co_inducts, un_fold_thmss, co_rec_thmss, disc_unfold_thmss, disc_corec_thmss,
   205             sel_unfold_thmsss, sel_corec_thmsss), fp_sugar_thms) =
   206         if fp = Least_FP then
   207           derive_induct_iters_thms_for_types pre_bnfs (the iters_args_types) xtor_co_induct
   208             xtor_co_iter_thmss nesting_bnfs nested_bnfs fpTs Cs Xs ctrXs_Tsss ctrss ctr_defss
   209             co_iterss co_iter_defss lthy
   210           |> `(fn ((_, induct, _), (fold_thmss, rec_thmss, _)) =>
   211             ([induct], fold_thmss, rec_thmss, [], [], [], []))
   212           ||> (fn info => (SOME info, NONE))
   213         else
   214           derive_coinduct_coiters_thms_for_types pre_bnfs (the coiters_args_types) xtor_co_induct
   215             dtor_injects dtor_ctors xtor_co_iter_thmss nesting_bnfs fpTs Cs Xs ctrXs_Tsss kss mss ns
   216             ctr_defss ctr_sugars co_iterss co_iter_defss (Proof_Context.export lthy no_defs_lthy)
   217             lthy
   218           |> `(fn ((coinduct_thms_pairs, _), (unfold_thmss, corec_thmss, _),
   219                   (disc_unfold_thmss, disc_corec_thmss, _), _,
   220                   (sel_unfold_thmsss, sel_corec_thmsss, _)) =>
   221             (map snd coinduct_thms_pairs, unfold_thmss, corec_thmss, disc_unfold_thmss,
   222              disc_corec_thmss, sel_unfold_thmsss, sel_corec_thmsss))
   223           ||> (fn info => (NONE, SOME info));
   224 
             (* Export from the context with the fresh type variables back to the caller's. *)
   225       val phi = Proof_Context.export_morphism no_defs_lthy no_defs_lthy0;
   226 
             (* The sugar record of the kk-th type "T"; all types share the mutual components. *)
   227       fun mk_target_fp_sugar (kk, T) =
   228         {T = T, fp = fp, index = kk, pre_bnfs = pre_bnfs, nested_bnfs = nested_bnfs,
   229          nesting_bnfs = nesting_bnfs, fp_res = fp_res, ctr_defss = ctr_defss,
   230          ctr_sugars = ctr_sugars, co_iterss = co_iterss, mapss = mapss, co_inducts = co_inducts,
   231          co_iter_thmsss = transpose [un_fold_thmss, co_rec_thmss],
   232          disc_co_itersss = transpose [disc_unfold_thmss, disc_corec_thmss],
   233          sel_co_iterssss = transpose [sel_unfold_thmsss, sel_corec_thmsss]}
   234         |> morph_fp_sugar phi;
   235     in
   236       ((map_index mk_target_fp_sugar fpTs, fp_sugar_thms), lthy)
   237     end
   238   else
   239     (* TODO: reorder hypotheses and predicates in (co)induction rules? *)
   240     ((fp_sugars0, (NONE, NONE)), no_defs_lthy0);
   241 
   242 fun indexify_callsss fp_sugar callsss =
   243   let
   244     val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
   245     fun do_ctr ctr =
   246       (case AList.lookup Term.aconv_untyped callsss ctr of
   247         NONE => replicate (num_binder_types (fastype_of ctr)) []
   248       | SOME callss => map (map (Envir.beta_eta_contract o unfold_let)) callss);
   249   in
   250     map do_ctr ctrs
   251   end;
   252 
   253 fun pad_and_indexify_calls fp_sugars0 = map2 indexify_callsss fp_sugars0 oo pad_list [];
   254 
   255 fun nested_to_mutual_fps fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
   256   let
   257     val qsoty = quote o Syntax.string_of_typ lthy;
   258     val qsotys = space_implode " or " o map qsoty;
   259 
   260     fun duplicate_datatype T = error (qsoty T ^ " is not mutually recursive with itself");
   261     fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
   262     fun not_co_datatype (T as Type (s, _)) =
   263         if fp = Least_FP andalso
   264            is_some (Datatype_Data.get_info (Proof_Context.theory_of lthy) s) then
   265           error (qsoty T ^ " is not a new-style datatype (cf. \"datatype_new\")")
   266         else
   267           not_co_datatype0 T
   268       | not_co_datatype T = not_co_datatype0 T;
   269     fun not_mutually_nested_rec Ts1 Ts2 =
   270       error (qsotys Ts1 ^ " is neither mutually recursive with nor nested recursive via " ^
   271         qsotys Ts2);
   272 
   273     val _ = (case Library.duplicates (op =) actual_Ts of [] => () | T :: _ => duplicate_datatype T);
   274 
   275     val perm_actual_Ts as Type (_, ty_args0) :: _ =
   276       sort (int_ord o pairself Term.size_of_typ) actual_Ts;
   277 
   278     fun check_enrich_with_mutuals _ [] = []
   279       | check_enrich_with_mutuals seen ((T as Type (T_name, ty_args)) :: Ts) =
   280         (case fp_sugar_of lthy T_name of
   281           SOME ({fp = fp', fp_res = {Ts = Ts', ...}, ...}) =>
   282           if fp = fp' then
   283             let
   284               val mutual_Ts = map (fn Type (s, _) => Type (s, ty_args)) Ts';
   285               val _ =
   286                 seen = [] orelse exists (exists_subtype_in seen) mutual_Ts orelse
   287                 not_mutually_nested_rec mutual_Ts seen;
   288               val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts;
   289             in
   290               mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts'
   291             end
   292           else
   293             not_co_datatype T
   294         | NONE => not_co_datatype T)
   295       | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T;
   296 
   297     val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts;
   298     val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
   299     val Ts = actual_Ts @ missing_Ts;
   300 
   301     fun generalize_simple_type T (seen, lthy) =
   302       mk_TFrees 1 lthy |> (fn ([U], lthy) => (U, ((T, U) :: seen, lthy)));
   303 
   304     fun generalize_type T (seen_lthy as (seen, _)) =
   305       (case AList.lookup (op =) seen T of
   306         SOME U => (U, seen_lthy)
   307       | NONE =>
   308         (case T of
   309           Type (s, Ts) =>
   310           if exists_subtype_in Ts T then fold_map generalize_type Ts seen_lthy |>> curry Type s
   311           else generalize_simple_type T seen_lthy
   312         | _ => generalize_simple_type T seen_lthy));
   313 
   314     val (perm_Us, _) = fold_map generalize_type perm_Ts ([], lthy);
   315 
   316     val nn = length Ts;
   317     val kks = 0 upto nn - 1;
   318 
   319     val common_name = mk_common_name (map Binding.name_of actual_bs);
   320     val bs = pad_list (Binding.name common_name) nn actual_bs;
   321 
   322     fun permute xs = permute_like (op =) Ts perm_Ts xs;
   323     fun unpermute perm_xs = permute_like (op =) perm_Ts Ts perm_xs;
   324 
   325     val perm_bs = permute bs;
   326     val perm_kks = permute kks;
   327     val perm_fp_sugars0 = map (the o fp_sugar_of lthy o fst o dest_Type) perm_Ts;
   328 
   329     val has_nested = exists (fn Type (_, ty_args) => ty_args <> ty_args0) Ts;
   330     val perm_callssss = pad_and_indexify_calls perm_fp_sugars0 nn actual_callssss0;
   331 
   332     val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
   333 
   334     val ((perm_fp_sugars, fp_sugar_thms), lthy) =
   335       mutualize_fp_sugars has_nested fp perm_bs perm_Us get_perm_indices perm_callssss
   336         perm_fp_sugars0 lthy;
   337 
   338     val fp_sugars = unpermute perm_fp_sugars;
   339   in
   340     ((missing_Ts, perm_kks, fp_sugars, fp_sugar_thms), lthy)
   341   end;
   342 
   343 end;