src/HOL/BNF/Tools/bnf_fp_n2m_sugar.ML
changeset 55735 6f0a49ed1bb1
parent 55734 32b5c4821d9d
child 55736 0b53378080d9
equal deleted inserted replaced
55734:32b5c4821d9d 55735:6f0a49ed1bb1
   262     ((fp_sugars0, (NONE, NONE)), no_defs_lthy0);
   262     ((fp_sugars0, (NONE, NONE)), no_defs_lthy0);
   263 
   263 
   264 fun indexify_callsss fp_sugar callsss =
   264 fun indexify_callsss fp_sugar callsss =
   265   let
   265   let
   266     val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
   266     val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
   267     fun do_ctr ctr =
   267     fun indexify_ctr ctr =
   268       (case AList.lookup Term.aconv_untyped callsss ctr of
   268       (case AList.lookup Term.aconv_untyped callsss ctr of
   269         NONE => replicate (num_binder_types (fastype_of ctr)) []
   269         NONE => replicate (num_binder_types (fastype_of ctr)) []
   270       | SOME callss => map (map (Envir.beta_eta_contract o unfold_let)) callss);
   270       | SOME callss => map (map (Envir.beta_eta_contract o unfold_let)) callss);
   271   in
   271   in
   272     map do_ctr ctrs
   272     map indexify_ctr ctrs
   273   end;
   273   end;
       
   274 
       
   275 fun retypargs tyargs (Type (s, _)) = Type (s, tyargs);
       
   276 
       
   277 fun fold_subtype_pairs f (T as Type (s, Ts), U as Type (s', Us)) =
       
   278     f (T, U) #> (if s = s' then fold (fold_subtype_pairs f) (Ts ~~ Us) else I)
       
   279   | fold_subtype_pairs f TU = f TU;
   274 
   280 
   275 fun nested_to_mutual_fps fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
   281 fun nested_to_mutual_fps fp actual_bs actual_Ts get_indices actual_callssss0 lthy =
   276   let
   282   let
   277     val qsoty = quote o Syntax.string_of_typ lthy;
   283     val qsoty = quote o Syntax.string_of_typ lthy;
   278     val qsotys = space_implode " or " o map qsoty;
   284     val qsotys = space_implode " or " o map qsoty;
   290     val _ = (case Library.duplicates (op =) actual_Ts of [] => () | T :: _ => duplicate_datatype T);
   296     val _ = (case Library.duplicates (op =) actual_Ts of [] => () | T :: _ => duplicate_datatype T);
   291 
   297 
   292     val perm_actual_Ts as Type (_, tyargs0) :: _ =
   298     val perm_actual_Ts as Type (_, tyargs0) :: _ =
   293       sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts;
   299       sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts;
   294 
   300 
       
   301     fun the_ctrs_of (Type (s, Ts)) = map (mk_ctr Ts) (#ctrs (the (ctr_sugar_of lthy s)));
       
   302 
   295     fun the_fp_sugar_of (T as Type (T_name, _)) =
   303     fun the_fp_sugar_of (T as Type (T_name, _)) =
   296       (case fp_sugar_of lthy T_name of
   304       (case fp_sugar_of lthy T_name of
   297         SOME (fp_sugar as {fp = fp', ...}) => if fp = fp' then fp_sugar else not_co_datatype T
   305         SOME (fp_sugar as {fp = fp', ...}) => if fp = fp' then fp_sugar else not_co_datatype T
   298       | NONE => not_co_datatype T);
   306       | NONE => not_co_datatype T);
   299 
   307 
   300     fun check_enrich_with_mutuals _ [] = []
   308     fun gen_rhss_in gen_Ts rho subTs =
   301       | check_enrich_with_mutuals seen ((T as Type (_, tyargs)) :: Ts) =
   309       let
       
   310         fun maybe_insert (T, Type (_, gen_tyargs)) =
       
   311             if member (op =) subTs T then insert (op =) gen_tyargs else I
       
   312           | maybe_insert _ = I;
       
   313 
       
   314         val ctrs = maps the_ctrs_of gen_Ts;
       
   315         val gen_ctr_Ts = maps (binder_types o fastype_of) ctrs;
       
   316         val ctr_Ts = map (Term.typ_subst_atomic rho) gen_ctr_Ts;
       
   317       in
       
   318         fold (fold_subtype_pairs maybe_insert) (ctr_Ts ~~ gen_ctr_Ts) []
       
   319       end;
       
   320 
       
   321     fun check_enrich_with_mutuals _ _ seen gen_seen [] = (seen, gen_seen)
       
   322       | check_enrich_with_mutuals lthy rho seen gen_seen ((T as Type (_, tyargs)) :: Ts) =
   302         let
   323         let
   303           val {fp_res = {Ts = Ts', ...}, ...} = the_fp_sugar_of T
   324           val {fp_res = {Ts = mutual_Ts0, ...}, ...} = the_fp_sugar_of T;
   304           val mutual_Ts = map (fn Type (s, _) => Type (s, tyargs)) Ts';
   325           val mutual_Ts = map (retypargs tyargs) mutual_Ts0;
   305           val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts;
   326 
       
   327           fun fresh_tyargs () =
       
   328             let
       
   329               (* The name "'z" is unlikely to clash with the context, yielding more cache hits. *)
       
   330               val (gen_tyargs, lthy') =
       
   331                 variant_tfrees (replicate (length tyargs) "z") lthy
       
   332                 |>> map Logic.varifyT_global;
       
   333               val rho' = (gen_tyargs ~~ tyargs) @ rho;
       
   334             in
       
   335               (rho', gen_tyargs, gen_seen, lthy')
       
   336             end;
       
   337 
       
   338           val (rho', gen_tyargs, gen_seen', lthy') =
       
   339             if exists (exists_subtype_in seen) mutual_Ts then
       
   340               (case gen_rhss_in gen_seen rho mutual_Ts of
       
   341                 [] => fresh_tyargs ()
       
   342               | [gen_tyargs] => (rho, gen_tyargs, gen_seen, lthy)
       
   343               | gen_tyargss as gen_tyargs :: gen_tyargss_tl =>
       
   344                 let
       
   345                   val unify_pairs = split_list (maps (curry (op ~~) gen_tyargs) gen_tyargss_tl);
       
   346                   val mgu = Type.raw_unifys unify_pairs Vartab.empty;
       
   347                   val gen_tyargs' = map (Envir.subst_type mgu) gen_tyargs;
       
   348                   val gen_seen' = map (Envir.subst_type mgu) gen_seen;
       
   349                 in
       
   350                   (rho, gen_tyargs', gen_seen', lthy)
       
   351                 end)
       
   352             else
       
   353               fresh_tyargs ();
       
   354 
       
   355           val gen_mutual_Ts = map (retypargs gen_tyargs) mutual_Ts0;
       
   356           val Ts' = filter_out (member (op =) mutual_Ts) Ts;
   306         in
   357         in
   307           mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts'
   358           check_enrich_with_mutuals lthy' rho' (seen @ mutual_Ts) (gen_seen' @ gen_mutual_Ts) Ts'
   308         end
   359         end
   309       | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T;
   360       | check_enrich_with_mutuals _ _ _ _ (T :: _) = not_co_datatype T;
   310 
   361 
   311     val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts;
   362     val (perm_Ts, perm_gen_Ts) = check_enrich_with_mutuals lthy [] [] [] perm_actual_Ts;
       
   363     val perm_frozen_gen_Ts = map Logic.unvarifyT_global perm_gen_Ts;
       
   364 
   312     val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
   365     val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
   313     val Ts = actual_Ts @ missing_Ts;
   366     val Ts = actual_Ts @ missing_Ts;
   314 
   367 
   315     val nn = length Ts;
   368     val nn = length Ts;
   316     val kks = 0 upto nn - 1;
   369     val kks = 0 upto nn - 1;
   332     val perm_callssss = map2 indexify_callsss perm_fp_sugars0 perm_callssss0;
   385     val perm_callssss = map2 indexify_callsss perm_fp_sugars0 perm_callssss0;
   333 
   386 
   334     val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
   387     val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices;
   335 
   388 
   336     val ((perm_fp_sugars, fp_sugar_thms), lthy) =
   389     val ((perm_fp_sugars, fp_sugar_thms), lthy) =
   337       mutualize_fp_sugars has_nested fp perm_bs perm_Ts get_perm_indices perm_callssss
   390       mutualize_fp_sugars has_nested fp perm_bs perm_frozen_gen_Ts get_perm_indices perm_callssss
   338         perm_fp_sugars0 lthy;
   391         perm_fp_sugars0 lthy;
   339 
   392 
   340     val fp_sugars = unpermute perm_fp_sugars;
   393     val fp_sugars = unpermute perm_fp_sugars;
   341   in
   394   in
   342     ((missing_Ts, perm_kks, fp_sugars, fp_sugar_thms), lthy)
   395     ((missing_Ts, perm_kks, fp_sugars, fp_sugar_thms), lthy)