262 ((fp_sugars0, (NONE, NONE)), no_defs_lthy0); |
262 ((fp_sugars0, (NONE, NONE)), no_defs_lthy0); |
263 |
263 |
264 fun indexify_callsss fp_sugar callsss = |
264 fun indexify_callsss fp_sugar callsss = |
265 let |
265 let |
266 val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar; |
266 val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar; |
267 fun do_ctr ctr = |
267 fun indexify_ctr ctr = |
268 (case AList.lookup Term.aconv_untyped callsss ctr of |
268 (case AList.lookup Term.aconv_untyped callsss ctr of |
269 NONE => replicate (num_binder_types (fastype_of ctr)) [] |
269 NONE => replicate (num_binder_types (fastype_of ctr)) [] |
270 | SOME callss => map (map (Envir.beta_eta_contract o unfold_let)) callss); |
270 | SOME callss => map (map (Envir.beta_eta_contract o unfold_let)) callss); |
271 in |
271 in |
272 map do_ctr ctrs |
272 map indexify_ctr ctrs |
273 end; |
273 end; |
|
274 |
|
275 fun retypargs tyargs (Type (s, _)) = Type (s, tyargs); |
|
276 |
|
277 fun fold_subtype_pairs f (T as Type (s, Ts), U as Type (s', Us)) = |
|
278 f (T, U) #> (if s = s' then fold (fold_subtype_pairs f) (Ts ~~ Us) else I) |
|
279 | fold_subtype_pairs f TU = f TU; |
274 |
280 |
275 fun nested_to_mutual_fps fp actual_bs actual_Ts get_indices actual_callssss0 lthy = |
281 fun nested_to_mutual_fps fp actual_bs actual_Ts get_indices actual_callssss0 lthy = |
276 let |
282 let |
277 val qsoty = quote o Syntax.string_of_typ lthy; |
283 val qsoty = quote o Syntax.string_of_typ lthy; |
278 val qsotys = space_implode " or " o map qsoty; |
284 val qsotys = space_implode " or " o map qsoty; |
290 val _ = (case Library.duplicates (op =) actual_Ts of [] => () | T :: _ => duplicate_datatype T); |
296 val _ = (case Library.duplicates (op =) actual_Ts of [] => () | T :: _ => duplicate_datatype T); |
291 |
297 |
292 val perm_actual_Ts as Type (_, tyargs0) :: _ = |
298 val perm_actual_Ts as Type (_, tyargs0) :: _ = |
293 sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts; |
299 sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts; |
294 |
300 |
|
301 fun the_ctrs_of (Type (s, Ts)) = map (mk_ctr Ts) (#ctrs (the (ctr_sugar_of lthy s))); |
|
302 |
295 fun the_fp_sugar_of (T as Type (T_name, _)) = |
303 fun the_fp_sugar_of (T as Type (T_name, _)) = |
296 (case fp_sugar_of lthy T_name of |
304 (case fp_sugar_of lthy T_name of |
297 SOME (fp_sugar as {fp = fp', ...}) => if fp = fp' then fp_sugar else not_co_datatype T |
305 SOME (fp_sugar as {fp = fp', ...}) => if fp = fp' then fp_sugar else not_co_datatype T |
298 | NONE => not_co_datatype T); |
306 | NONE => not_co_datatype T); |
299 |
307 |
300 fun check_enrich_with_mutuals _ [] = [] |
308 fun gen_rhss_in gen_Ts rho subTs = |
301 | check_enrich_with_mutuals seen ((T as Type (_, tyargs)) :: Ts) = |
309 let |
|
310 fun maybe_insert (T, Type (_, gen_tyargs)) = |
|
311 if member (op =) subTs T then insert (op =) gen_tyargs else I |
|
312 | maybe_insert _ = I; |
|
313 |
|
314 val ctrs = maps the_ctrs_of gen_Ts; |
|
315 val gen_ctr_Ts = maps (binder_types o fastype_of) ctrs; |
|
316 val ctr_Ts = map (Term.typ_subst_atomic rho) gen_ctr_Ts; |
|
317 in |
|
318 fold (fold_subtype_pairs maybe_insert) (ctr_Ts ~~ gen_ctr_Ts) [] |
|
319 end; |
|
320 |
|
321 fun check_enrich_with_mutuals _ _ seen gen_seen [] = (seen, gen_seen) |
|
322 | check_enrich_with_mutuals lthy rho seen gen_seen ((T as Type (_, tyargs)) :: Ts) = |
302 let |
323 let |
303 val {fp_res = {Ts = Ts', ...}, ...} = the_fp_sugar_of T |
324 val {fp_res = {Ts = mutual_Ts0, ...}, ...} = the_fp_sugar_of T; |
304 val mutual_Ts = map (fn Type (s, _) => Type (s, tyargs)) Ts'; |
325 val mutual_Ts = map (retypargs tyargs) mutual_Ts0; |
305 val (seen', Ts') = List.partition (member (op =) mutual_Ts) Ts; |
326 |
|
327 fun fresh_tyargs () = |
|
328 let |
|
329 (* The name "'z" is unlikely to clash with the context, yielding more cache hits. *) |
|
330 val (gen_tyargs, lthy') = |
|
331 variant_tfrees (replicate (length tyargs) "z") lthy |
|
332 |>> map Logic.varifyT_global; |
|
333 val rho' = (gen_tyargs ~~ tyargs) @ rho; |
|
334 in |
|
335 (rho', gen_tyargs, gen_seen, lthy') |
|
336 end; |
|
337 |
|
338 val (rho', gen_tyargs, gen_seen', lthy') = |
|
339 if exists (exists_subtype_in seen) mutual_Ts then |
|
340 (case gen_rhss_in gen_seen rho mutual_Ts of |
|
341 [] => fresh_tyargs () |
|
342 | [gen_tyargs] => (rho, gen_tyargs, gen_seen, lthy) |
|
343 | gen_tyargss as gen_tyargs :: gen_tyargss_tl => |
|
344 let |
|
345 val unify_pairs = split_list (maps (curry (op ~~) gen_tyargs) gen_tyargss_tl); |
|
346 val mgu = Type.raw_unifys unify_pairs Vartab.empty; |
|
347 val gen_tyargs' = map (Envir.subst_type mgu) gen_tyargs; |
|
348 val gen_seen' = map (Envir.subst_type mgu) gen_seen; |
|
349 in |
|
350 (rho, gen_tyargs', gen_seen', lthy) |
|
351 end) |
|
352 else |
|
353 fresh_tyargs (); |
|
354 |
|
355 val gen_mutual_Ts = map (retypargs gen_tyargs) mutual_Ts0; |
|
356 val Ts' = filter_out (member (op =) mutual_Ts) Ts; |
306 in |
357 in |
307 mutual_Ts @ check_enrich_with_mutuals (seen @ T :: seen') Ts' |
358 check_enrich_with_mutuals lthy' rho' (seen @ mutual_Ts) (gen_seen' @ gen_mutual_Ts) Ts' |
308 end |
359 end |
309 | check_enrich_with_mutuals _ (T :: _) = not_co_datatype T; |
360 | check_enrich_with_mutuals _ _ _ _ (T :: _) = not_co_datatype T; |
310 |
361 |
311 val perm_Ts = check_enrich_with_mutuals [] perm_actual_Ts; |
362 val (perm_Ts, perm_gen_Ts) = check_enrich_with_mutuals lthy [] [] [] perm_actual_Ts; |
|
363 val perm_frozen_gen_Ts = map Logic.unvarifyT_global perm_gen_Ts; |
|
364 |
312 val missing_Ts = perm_Ts |> subtract (op =) actual_Ts; |
365 val missing_Ts = perm_Ts |> subtract (op =) actual_Ts; |
313 val Ts = actual_Ts @ missing_Ts; |
366 val Ts = actual_Ts @ missing_Ts; |
314 |
367 |
315 val nn = length Ts; |
368 val nn = length Ts; |
316 val kks = 0 upto nn - 1; |
369 val kks = 0 upto nn - 1; |
332 val perm_callssss = map2 indexify_callsss perm_fp_sugars0 perm_callssss0; |
385 val perm_callssss = map2 indexify_callsss perm_fp_sugars0 perm_callssss0; |
333 |
386 |
334 val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices; |
387 val get_perm_indices = map (fn kk => find_index (curry (op =) kk) perm_kks) o get_indices; |
335 |
388 |
336 val ((perm_fp_sugars, fp_sugar_thms), lthy) = |
389 val ((perm_fp_sugars, fp_sugar_thms), lthy) = |
337 mutualize_fp_sugars has_nested fp perm_bs perm_Ts get_perm_indices perm_callssss |
390 mutualize_fp_sugars has_nested fp perm_bs perm_frozen_gen_Ts get_perm_indices perm_callssss |
338 perm_fp_sugars0 lthy; |
391 perm_fp_sugars0 lthy; |
339 |
392 |
340 val fp_sugars = unpermute perm_fp_sugars; |
393 val fp_sugars = unpermute perm_fp_sugars; |
341 in |
394 in |
342 ((missing_Ts, perm_kks, fp_sugars, fp_sugar_thms), lthy) |
395 ((missing_Ts, perm_kks, fp_sugars, fp_sugar_thms), lthy) |