src/HOL/Tools/enriched_type.ML
author wenzelm
Thu, 15 Mar 2012 20:07:00 +0100
changeset 47823 94aa7b81bcf6
parent 47722 0b8dd4c8c79a
child 47836 5c6955f487e5
permissions -rw-r--r--
prefer formally checked @{keyword} parser;
     1 (*  Title:      HOL/Tools/enriched_type.ML
     2     Author:     Florian Haftmann, TU Muenchen
     3 
     4 Functorial structure of types.
     5 *)
     6 
     7 signature ENRICHED_TYPE =
     8 sig
     9   val find_atomic: Proof.context -> typ -> (typ * (bool * bool)) list
    10   val construct_mapper: Proof.context -> (string * bool -> term)
    11     -> bool -> typ -> typ -> term
    12   val enriched_type: string option -> term -> local_theory -> Proof.state
    13   type entry
    14   val entries: Proof.context -> entry list Symtab.table
    15 end;
    16 
    17 structure Enriched_Type : ENRICHED_TYPE =
    18 struct
    19 
    20 (* bookkeeping *)
    21 
    22 val compN = "comp";
    23 val idN = "id";
    24 val compositionalityN = "compositionality";
    25 val identityN = "identity";
    26 
    27 type entry = { mapper: term, variances: (sort * (bool * bool)) list,
    28   comp: thm, id: thm };
    29 
    30 structure Data = Generic_Data
    31 (
    32   type T = entry list Symtab.table
    33   val empty = Symtab.empty
    34   val extend = I
    35   fun merge data = Symtab.merge (K true) data
    36 );
    37 
    38 val entries = Data.get o Context.Proof;
    39 
    40 
    41 (* type analysis *)
    42 
    43 fun term_with_typ ctxt T t = Envir.subst_term_types
    44   (Type.typ_match (Proof_Context.tsig_of ctxt) (fastype_of t, T) Vartab.empty) t;
    45 
    46 fun find_atomic ctxt T =
    47   let
    48     val variances_of = Option.map #variances o try hd o Symtab.lookup_list (entries ctxt);
    49     fun add_variance is_contra T =
    50       AList.map_default (op =) (T, (false, false))
    51         ((if is_contra then apsnd else apfst) (K true));
    52     fun analyze' is_contra (_, (co, contra)) T =
    53       (if co then analyze is_contra T else I)
    54       #> (if contra then analyze (not is_contra) T else I)
    55     and analyze is_contra (T as Type (tyco, Ts)) = (case variances_of tyco
    56           of NONE => add_variance is_contra T
    57            | SOME variances => fold2 (analyze' is_contra) variances Ts)
    58       | analyze is_contra T = add_variance is_contra T;
    59   in analyze false T [] end;
    60 
    61 fun construct_mapper ctxt atomic =
    62   let
    63     val lookup = hd o Symtab.lookup_list (entries ctxt);
    64     fun constructs is_contra (_, (co, contra)) T T' =
    65       (if co then [construct is_contra T T'] else [])
    66       @ (if contra then [construct (not is_contra) T T'] else [])
    67     and construct is_contra (T as Type (tyco, Ts)) (T' as Type (_, Ts')) =
    68           let
    69             val { mapper = raw_mapper, variances, ... } = lookup tyco;
    70             val args = maps (fn (arg_pattern, (T, T')) =>
    71               constructs is_contra arg_pattern T T')
    72                 (variances ~~ (Ts ~~ Ts'));
    73             val (U, U') = if is_contra then (T', T) else (T, T');
    74             val mapper = term_with_typ ctxt (map fastype_of args ---> U --> U') raw_mapper;
    75           in list_comb (mapper, args) end
    76       | construct is_contra (TFree (v, _)) (TFree _) = atomic (v, is_contra);
    77   in construct end;
    78 
    79 
    80 (* mapper properties *)
    81 
    82 val compositionality_ss = Simplifier.add_simp (Simpdata.mk_eq @{thm comp_def}) HOL_basic_ss;
    83 
    84 fun make_comp_prop ctxt variances (tyco, mapper) =
    85   let
    86     val sorts = map fst variances
    87     val (((vs3, vs2), vs1), _) = ctxt
    88       |> Variable.invent_types sorts
    89       ||>> Variable.invent_types sorts
    90       ||>> Variable.invent_types sorts
    91     val (Ts1, Ts2, Ts3) = (map TFree vs1, map TFree vs2, map TFree vs3);
    92     fun mk_argT ((T, T'), (_, (co, contra))) =
    93       (if co then [(T --> T')] else [])
    94       @ (if contra then [(T' --> T)] else []);
    95     val contras = maps (fn (_, (co, contra)) =>
    96       (if co then [false] else []) @ (if contra then [true] else [])) variances;
    97     val Ts21 = maps mk_argT ((Ts2 ~~ Ts1) ~~ variances);
    98     val Ts32 = maps mk_argT ((Ts3 ~~ Ts2) ~~ variances);
    99     fun invents n k nctxt =
   100       let
   101         val names = Name.invent nctxt n k;
   102       in (names, fold Name.declare names nctxt) end;
   103     val ((names21, names32), nctxt) = Variable.names_of ctxt
   104       |> invents "f" (length Ts21)
   105       ||>> invents "f" (length Ts32);
   106     val T1 = Type (tyco, Ts1);
   107     val T2 = Type (tyco, Ts2);
   108     val T3 = Type (tyco, Ts3);
   109     val (args21, args32) = (names21 ~~ Ts21, names32 ~~ Ts32);
   110     val args31 = map2 (fn is_contra => fn ((f21, T21), (f32, T32)) =>
   111       if not is_contra then
   112         HOLogic.mk_comp (Free (f21, T21), Free (f32, T32))
   113       else
   114         HOLogic.mk_comp (Free (f32, T32), Free (f21, T21))
   115       ) contras (args21 ~~ args32)
   116     fun mk_mapper T T' args = list_comb
   117       (term_with_typ ctxt (map fastype_of args ---> T --> T') mapper, args);
   118     val mapper21 = mk_mapper T2 T1 (map Free args21);
   119     val mapper32 = mk_mapper T3 T2 (map Free args32);
   120     val mapper31 = mk_mapper T3 T1 args31;
   121     val eq1 = (HOLogic.mk_Trueprop o HOLogic.mk_eq)
   122       (HOLogic.mk_comp (mapper21, mapper32), mapper31);
   123     val x = Free (the_single (Name.invent nctxt (Long_Name.base_name tyco) 1), T3)
   124     val eq2 = (HOLogic.mk_Trueprop o HOLogic.mk_eq)
   125       (mapper21 $ (mapper32 $ x), mapper31 $ x);
   126     val comp_prop = fold_rev Logic.all (map Free (args21 @ args32)) eq1;
   127     val compositionality_prop = fold_rev Logic.all (map Free (args21 @ args32) @ [x]) eq2;
   128     fun prove_compositionality ctxt comp_thm = Skip_Proof.prove ctxt [] [] compositionality_prop
   129       (K (ALLGOALS (Method.insert_tac [@{thm fun_cong} OF [comp_thm]]
   130         THEN' Simplifier.asm_lr_simp_tac compositionality_ss
   131         THEN_ALL_NEW (Goal.assume_rule_tac ctxt))));
   132   in (comp_prop, prove_compositionality) end;
   133 
   134 val identity_ss = Simplifier.add_simp (Simpdata.mk_eq @{thm id_def}) HOL_basic_ss;
   135 
   136 fun make_id_prop ctxt variances (tyco, mapper) =
   137   let
   138     val (vs, _) = Variable.invent_types (map fst variances) ctxt;
   139     val Ts = map TFree vs;
   140     fun bool_num b = if b then 1 else 0;
   141     fun mk_argT (T, (_, (co, contra))) =
   142       replicate (bool_num co + bool_num contra) T
   143     val arg_Ts = maps mk_argT (Ts ~~ variances)
   144     val T = Type (tyco, Ts);
   145     val head = term_with_typ ctxt (map (fn T => T --> T) arg_Ts ---> T --> T) mapper;
   146     val lhs1 = list_comb (head, map (HOLogic.id_const) arg_Ts);
   147     val lhs2 = list_comb (head, map (fn arg_T => Abs ("x", arg_T, Bound 0)) arg_Ts);
   148     val rhs = HOLogic.id_const T;
   149     val (id_prop, identity_prop) = pairself
   150       (HOLogic.mk_Trueprop o HOLogic.mk_eq o rpair rhs) (lhs1, lhs2);
   151     fun prove_identity ctxt id_thm = Skip_Proof.prove ctxt [] [] identity_prop
   152       (K (ALLGOALS (Method.insert_tac [id_thm] THEN' Simplifier.asm_lr_simp_tac identity_ss)));
   153   in (id_prop, prove_identity) end;
   154 
   155 
   156 (* analyzing and registering mappers *)
   157 
   158 fun consume eq x [] = (false, [])
   159   | consume eq x (ys as z :: zs) = if eq (x, z) then (true, zs) else (false, ys);
   160 
   161 fun split_mapper_typ "fun" T =
   162       let
   163         val (Ts', T') = strip_type T;
   164         val (Ts'', T'') = split_last Ts';
   165         val (Ts''', T''') = split_last Ts'';
   166       in (Ts''', T''', T'' --> T') end
   167   | split_mapper_typ _ T =
   168       let
   169         val (Ts', T') = strip_type T;
   170         val (Ts'', T'') = split_last Ts';
   171       in (Ts'', T'', T') end;
   172 
   173 fun analyze_mapper ctxt input_mapper =
   174   let
   175     val T = fastype_of input_mapper;
   176     val _ = Type.no_tvars T;
   177     val _ =
   178       if null (subtract (op =) (Term.add_tfreesT T []) (Term.add_tfrees input_mapper []))
   179       then ()
   180       else error ("Illegal additional type variable(s) in term: " ^ Syntax.string_of_term ctxt input_mapper);
   181     val _ =
   182       if null (Term.add_vars (singleton
   183         (Variable.export_terms (Variable.auto_fixes input_mapper ctxt) ctxt) input_mapper) [])
   184       then ()
   185       else error ("Illegal locally free variable(s) in term: "
   186         ^ Syntax.string_of_term ctxt input_mapper);;
   187     val mapper = singleton (Variable.polymorphic ctxt) input_mapper;
   188     val _ =
   189       if null (Term.add_tfreesT (fastype_of mapper) []) then ()
   190       else error ("Illegal locally fixed type variable(s) in type: " ^ Syntax.string_of_typ ctxt T);
   191     fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
   192       | add_tycos _ = I;
   193     val tycos = add_tycos T [];
   194     val tyco = if tycos = ["fun"] then "fun"
   195       else case remove (op =) "fun" tycos
   196        of [tyco] => tyco
   197         | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ ctxt T);
   198   in (mapper, T, tyco) end;
   199 
   200 fun analyze_variances ctxt tyco T =
   201   let
   202     fun bad_typ () = error ("Bad mapper type: " ^ Syntax.string_of_typ ctxt T);
   203     val (Ts, T1, T2) = split_mapper_typ tyco T
   204       handle List.Empty => bad_typ ();
   205     val _ = pairself
   206       ((fn tyco' => if tyco' = tyco then () else bad_typ ()) o fst o dest_Type) (T1, T2)
   207       handle TYPE _ => bad_typ ();
   208     val (vs1, vs2) = pairself (map dest_TFree o snd o dest_Type) (T1, T2)
   209       handle TYPE _ => bad_typ ();
   210     val _ = if has_duplicates (eq_fst (op =)) (vs1 @ vs2)
   211       then bad_typ () else ();
   212     fun check_variance_pair (var1 as (_, sort1), var2 as (_, sort2)) =
   213       let
   214         val coT = TFree var1 --> TFree var2;
   215         val contraT = TFree var2 --> TFree var1;
   216         val sort = Sign.inter_sort (Proof_Context.theory_of ctxt) (sort1, sort2);
   217       in
   218         consume (op =) coT
   219         ##>> consume (op =) contraT
   220         #>> pair sort
   221       end;
   222     val (variances, left_variances) = fold_map check_variance_pair (vs1 ~~ vs2) Ts;
   223     val _ = if null left_variances then () else bad_typ ();
   224   in variances end;
   225 
   226 fun gen_enriched_type prep_term some_prfx raw_mapper lthy =
   227   let
   228     val (mapper, T, tyco) = analyze_mapper lthy (prep_term lthy raw_mapper);
   229     val prfx = the_default (Long_Name.base_name tyco) some_prfx;
   230     val variances = analyze_variances lthy tyco T;
   231     val (comp_prop, prove_compositionality) = make_comp_prop lthy variances (tyco, mapper);
   232     val (id_prop, prove_identity) = make_id_prop lthy variances (tyco, mapper);
   233     val qualify = Binding.qualify true prfx o Binding.name;
   234     fun mapper_declaration comp_thm id_thm phi context =
   235       let
   236         val typ_instance = Sign.typ_instance (Context.theory_of context);
   237         val mapper' = Morphism.term phi mapper;
   238         val T_T' = pairself fastype_of (mapper, mapper');
   239         val vars = Term.add_vars mapper' [];
   240       in
   241         if null vars andalso typ_instance T_T' andalso typ_instance (swap T_T')
   242         then (Data.map o Symtab.cons_list) (tyco,
   243           { mapper = mapper', variances = variances,
   244             comp = Morphism.thm phi comp_thm, id = Morphism.thm phi id_thm }) context
   245         else context
   246       end;
   247     fun after_qed [single_comp_thm, single_id_thm] lthy =
   248       lthy
   249       |> Local_Theory.note ((qualify compN, []), single_comp_thm)
   250       ||>> Local_Theory.note ((qualify idN, []), single_id_thm)
   251       |-> (fn ((_, [comp_thm]), (_, [id_thm])) => fn lthy =>
   252         lthy
   253         |> Local_Theory.note ((qualify compositionalityN, []),
   254             [prove_compositionality lthy comp_thm])
   255         |> snd
   256         |> Local_Theory.note ((qualify identityN, []),
   257             [prove_identity lthy id_thm])
   258         |> snd
   259         |> Local_Theory.declaration {syntax = false, pervasive = false}
   260           (mapper_declaration comp_thm id_thm))
   261   in
   262     lthy
   263     |> Proof.theorem NONE after_qed (map (fn t => [(t, [])]) [comp_prop, id_prop])
   264   end
   265 
   266 val enriched_type = gen_enriched_type Syntax.check_term;
   267 val enriched_type_cmd = gen_enriched_type Syntax.read_term;
   268 
   269 val _ = Outer_Syntax.local_theory_to_proof "enriched_type"
   270   "register operations managing the functorial structure of a type"
   271   Keyword.thy_goal (Scan.option (Parse.name --| @{keyword ":"}) -- Parse.term
   272     >> (fn (prfx, t) => enriched_type_cmd prfx t));
   273 
   274 end;