src/HOL/Tools/record_package.ML
changeset 24867 e5b55d7be9bb
parent 24830 a7b3ab44d993
child 25070 e2a39b6526b0
equal deleted inserted replaced
24866:6e6d9e80ebb4 24867:e5b55d7be9bb
   605                   => (let
   605                   => (let
   606                        val flds' = but_last flds;
   606                        val flds' = but_last flds;
   607                        val types = map snd flds';
   607                        val types = map snd flds';
   608                        val (args,rest) = splitargs (map fst flds') fargs;
   608                        val (args,rest) = splitargs (map fst flds') fargs;
   609                        val argtypes = map (Sign.certify_typ thy o decode_type thy) args;
   609                        val argtypes = map (Sign.certify_typ thy o decode_type thy) args;
   610                        val midx =  fold (fn T => fn i => Int.max (maxidx_of_typ T, i)) 
   610                        val midx =  fold (fn T => fn i => Int.max (maxidx_of_typ T, i))
   611                                     argtypes 0;
   611                                     argtypes 0;
   612                        val varifyT = varifyT midx;
   612                        val varifyT = varifyT midx;
   613                        val vartypes = map varifyT types;
   613                        val vartypes = map varifyT types;
   614 
   614 
   615                        val subst = fold (Sign.typ_match thy) (vartypes ~~ argtypes)
   615                        val subst = fold (Sign.typ_match thy) (vartypes ~~ argtypes)
   652         (Syntax.term_of_typ false (HOLogic.unitT));
   652         (Syntax.term_of_typ false (HOLogic.unitT));
   653 val adv_record_type_scheme_tr =
   653 val adv_record_type_scheme_tr =
   654       gen_adv_record_type_scheme_tr "_field_types" "_field_type" ext_typeN;
   654       gen_adv_record_type_scheme_tr "_field_types" "_field_type" ext_typeN;
   655 
   655 
   656 
   656 
   657 val parse_translation = 
   657 val parse_translation =
   658  [("_record_update", record_update_tr),
   658  [("_record_update", record_update_tr),
   659   ("_update_name", update_name_tr)];
   659   ("_update_name", update_name_tr)];
   660 
   660 
   661 
   661 
   662 val adv_parse_translation = 
   662 val adv_parse_translation =
   663  [("_record",adv_record_tr),
   663  [("_record",adv_record_tr),
   664   ("_record_scheme",adv_record_scheme_tr),
   664   ("_record_scheme",adv_record_scheme_tr),
   665   ("_record_type",adv_record_type_tr),
   665   ("_record_type",adv_record_type_tr),
   666   ("_record_type_scheme",adv_record_type_scheme_tr)];
   666   ("_record_type_scheme",adv_record_type_scheme_tr)];
   667 
   667 
   760       fun mk_type_abbr subst name alphas =
   760       fun mk_type_abbr subst name alphas =
   761           let val abbrT = Type (name, map (fn a => varifyT (TFree (a, Sign.defaultS thy))) alphas);
   761           let val abbrT = Type (name, map (fn a => varifyT (TFree (a, Sign.defaultS thy))) alphas);
   762           in Syntax.term_of_typ (! Syntax.show_sorts)
   762           in Syntax.term_of_typ (! Syntax.show_sorts)
   763                (Sign.extern_typ thy (Envir.norm_type subst abbrT)) end;
   763                (Sign.extern_typ thy (Envir.norm_type subst abbrT)) end;
   764 
   764 
   765       fun match rT T = (Sign.typ_match thy (varifyT rT,T) 
   765       fun match rT T = (Sign.typ_match thy (varifyT rT,T)
   766                                                 Vartab.empty);
   766                                                 Vartab.empty);
   767 
   767 
   768    in if !print_record_type_abbr
   768    in if !print_record_type_abbr
   769       then (case last_extT T of
   769       then (case last_extT T of
   770              SOME (name,_)
   770              SOME (name,_)
   874      val T = domain_type fT;
   874      val T = domain_type fT;
   875      val (x,T') = hd (Term.variant_frees t' [("x",T)]);
   875      val (x,T') = hd (Term.variant_frees t' [("x",T)]);
   876      val f_x = Free (f',fT)$(Free (x,T'));
   876      val f_x = Free (f',fT)$(Free (x,T'));
   877      fun is_constr (Const (c,_)$_) = can (unsuffix extN) c
   877      fun is_constr (Const (c,_)$_) = can (unsuffix extN) c
   878        | is_constr _ = false;
   878        | is_constr _ = false;
   879      fun subst (t as u$w) = if Free (f',fT)=u 
   879      fun subst (t as u$w) = if Free (f',fT)=u
   880                             then if is_constr w then f_x 
   880                             then if is_constr w then f_x
   881                                  else raise TERM ("abstract_over_fun_app",[t])
   881                                  else raise TERM ("abstract_over_fun_app",[t])
   882                             else subst u$subst w
   882                             else subst u$subst w
   883        | subst (Abs (x,T,t)) = (Abs (x,T,subst t))
   883        | subst (Abs (x,T,t)) = (Abs (x,T,subst t))
   884        | subst t = t
   884        | subst t = t
   885      val t'' = abstract_over (f_x,subst t');
   885      val t'' = abstract_over (f_x,subst t');
   886      val vars = strip_qnt_vars "all" t'';
   886      val vars = strip_qnt_vars "all" t'';
   887      val bdy = strip_qnt_body "all" t'';
   887      val bdy = strip_qnt_body "all" t'';
   888      
   888 
   889   in list_abs ((x,T')::vars,bdy) end
   889   in list_abs ((x,T')::vars,bdy) end
   890   | abstract_over_fun_app t = raise TERM ("abstract_over_fun_app",[t]);
   890   | abstract_over_fun_app t = raise TERM ("abstract_over_fun_app",[t]);
   891 (* Generates a theorem of the kind:
   891 (* Generates a theorem of the kind:
   892  * !!f x*. PROP P (f ( r x* ) x* == !!r x*. PROP P r x* 
   892  * !!f x*. PROP P (f ( r x* ) x* == !!r x*. PROP P r x*
   893  *) 
   893  *)
   894 fun mk_fun_apply_eq (Abs (f, fT, t)) thy =
   894 fun mk_fun_apply_eq (Abs (f, fT, t)) thy =
   895   let
   895   let
   896     val rT = domain_type fT;
   896     val rT = domain_type fT;
   897     val vars = Term.strip_qnt_vars "all" t;
   897     val vars = Term.strip_qnt_vars "all" t;
   898     val Ts = map snd vars;
   898     val Ts = map snd vars;
   899     val n = length vars;
   899     val n = length vars;
   900     fun app_bounds 0 t = t$Bound 0
   900     fun app_bounds 0 t = t$Bound 0
   901       | app_bounds n t = if n > 0 then app_bounds (n-1) (t$Bound n) else t
   901       | app_bounds n t = if n > 0 then app_bounds (n-1) (t$Bound n) else t
   902 
   902 
   903    
   903 
   904     val [P,r] = Term.variant_frees t [("P",rT::Ts--->Term.propT),("r",Ts--->rT)];
   904     val [P,r] = Term.variant_frees t [("P",rT::Ts--->Term.propT),("r",Ts--->rT)];
   905     val prop = Logic.mk_equals
   905     val prop = Logic.mk_equals
   906                 (list_all ((f,fT)::vars,
   906                 (list_all ((f,fT)::vars,
   907                            app_bounds (n - 1) ((Free P)$(Bound n$app_bounds (n-1) (Free r)))),
   907                            app_bounds (n - 1) ((Free P)$(Bound n$app_bounds (n-1) (Free r)))),
   908                  list_all ((fst r,rT)::vars,
   908                  list_all ((fst r,rT)::vars,
   909                            app_bounds (n - 1) ((Free P)$Bound n))); 
   909                            app_bounds (n - 1) ((Free P)$Bound n)));
   910     val prove_standard = quick_and_dirty_prove true thy;
   910     val prove_standard = quick_and_dirty_prove true thy;
   911     val thm = prove_standard [] prop (fn prems =>
   911     val thm = prove_standard [] prop (fn prems =>
   912 	 EVERY [rtac equal_intr_rule 1, 
   912 	 EVERY [rtac equal_intr_rule 1,
   913                 Goal.norm_hhf_tac 1,REPEAT (etac meta_allE 1), atac 1,
   913                 Goal.norm_hhf_tac 1,REPEAT (etac meta_allE 1), atac 1,
   914                 Goal.norm_hhf_tac 1,REPEAT (etac meta_allE 1), atac 1]);
   914                 Goal.norm_hhf_tac 1,REPEAT (etac meta_allE 1), atac 1]);
   915   in thm end
   915   in thm end
   916   | mk_fun_apply_eq t thy = raise TERM ("mk_fun_apply_eq",[t]);
   916   | mk_fun_apply_eq t thy = raise TERM ("mk_fun_apply_eq",[t]);
   917 
   917 
   926     (fn thy => fn _ => fn t =>
   926     (fn thy => fn _ => fn t =>
   927       (case t of (Const ("all", Type (_, [Type (_, [Type("fun",[T,T']), _]), _])))$
   927       (case t of (Const ("all", Type (_, [Type (_, [Type("fun",[T,T']), _]), _])))$
   928                   (trm as Abs _) =>
   928                   (trm as Abs _) =>
   929          (case rec_id (~1) T of
   929          (case rec_id (~1) T of
   930             "" => NONE
   930             "" => NONE
   931           | n => if T=T'  
   931           | n => if T=T'
   932                  then (let
   932                  then (let
   933                         val P=cterm_of thy (abstract_over_fun_app trm); 
   933                         val P=cterm_of thy (abstract_over_fun_app trm);
   934                         val thm = mk_fun_apply_eq trm thy;
   934                         val thm = mk_fun_apply_eq trm thy;
   935                         val PV = cterm_of thy (hd (term_vars (prop_of thm)));
   935                         val PV = cterm_of thy (hd (term_vars (prop_of thm)));
   936                         val thm' = cterm_instantiate [(PV,P)] thm;
   936                         val thm' = cterm_instantiate [(PV,P)] thm;
   937                        in SOME  thm' end handle TERM _ => NONE)
   937                        in SOME  thm' end handle TERM _ => NONE)
   938                 else NONE) 
   938                 else NONE)
   939        | _ => NONE))
   939        | _ => NONE))
   940 end
   940 end
   941 
   941 
   942 fun prove_split_simp thy ss T prop =
   942 fun prove_split_simp thy ss T prop =
   943   let
   943   let
   949                    SOME (all_thm,_,_,_) =>
   949                    SOME (all_thm,_,_,_) =>
   950                      all_thm::(case extsplits of [thm] => [] | _ => extsplits)
   950                      all_thm::(case extsplits of [thm] => [] | _ => extsplits)
   951                               (* [thm] is the same as all_thm *)
   951                               (* [thm] is the same as all_thm *)
   952                  | NONE => extsplits)
   952                  | NONE => extsplits)
   953     val thms'=K_comp_convs@thms;
   953     val thms'=K_comp_convs@thms;
   954     val ss' = (Simplifier.inherit_context ss simpset 
   954     val ss' = (Simplifier.inherit_context ss simpset
   955                 addsimps thms'
   955                 addsimps thms'
   956                 addsimprocs [record_split_f_more_simproc]);
   956                 addsimprocs [record_split_f_more_simproc]);
   957   in
   957   in
   958     quick_and_dirty_prove true thy [] prop (fn _ => simp_tac ss' 1)
   958     quick_and_dirty_prove true thy [] prop (fn _ => simp_tac ss' 1)
   959   end;
   959   end;
   990               fun mk_eq_terms ((upd as Const (u,Type(_,[kT,_]))) $ k $ r) =
   990               fun mk_eq_terms ((upd as Const (u,Type(_,[kT,_]))) $ k $ r) =
   991                   (case Symtab.lookup updates u of
   991                   (case Symtab.lookup updates u of
   992                      NONE => NONE
   992                      NONE => NONE
   993                    | SOME u_name
   993                    | SOME u_name
   994                      => if u_name = s
   994                      => if u_name = s
   995                         then (case mk_eq_terms r of 
   995                         then (case mk_eq_terms r of
   996                                NONE => 
   996                                NONE =>
   997                                  let
   997                                  let
   998                                    val rv = ("r",rT)
   998                                    val rv = ("r",rT)
   999                                    val rb = Bound 0
   999                                    val rb = Bound 0
  1000                                    val kv = ("k",kT)
  1000                                    val kv = ("k",kT)
  1001                                    val kb = Bound 1
  1001                                    val kb = Bound 1
  1062                    then let val kv = ("k", kT);
  1062                    then let val kv = ("k", kT);
  1063                             val kb = Bound (length vars);
  1063                             val kb = Bound (length vars);
  1064                         in ((Const (u,uT)$k$sprout,Const (u,uT)$kb$skeleton),kv::vars) end
  1064                         in ((Const (u,uT)$k$sprout,Const (u,uT)$kb$skeleton),kv::vars) end
  1065                    else ((sprout,skeleton),vars);
  1065                    else ((sprout,skeleton),vars);
  1066 
  1066 
  1067              fun is_upd_same (sprout,skeleton) u 
  1067              fun is_upd_same (sprout,skeleton) u
  1068                                 ((K_rec as Const ("Record.K_record",_))$
  1068                                 ((K_rec as Const ("Record.K_record",_))$
  1069                                   ((sel as Const (s,_))$r)) =
  1069                                   ((sel as Const (s,_))$r)) =
  1070                    if (unsuffix updateN u) = s andalso (seed s sprout) = r
  1070                    if (unsuffix updateN u) = s andalso (seed s sprout) = r
  1071                    then SOME (K_rec,sel,seed s skeleton)
  1071                    then SOME (K_rec,sel,seed s skeleton)
  1072                    else NONE
  1072                    else NONE
  1073                | is_upd_same _ _ _ = NONE
  1073                | is_upd_same _ _ _ = NONE
  1074 
  1074 
  1075              fun init_seed r = ((r,Bound 0), [("r", rT)]);
  1075              fun init_seed r = ((r,Bound 0), [("r", rT)]);
  1076 
  1076 
  1077              fun add (n:string) f fmaps = 
  1077              fun add (n:string) f fmaps =
  1078                (case AList.lookup (op =) fmaps n of
  1078                (case AList.lookup (op =) fmaps n of
  1079                   NONE => AList.update (op =) (n,[f]) fmaps
  1079                   NONE => AList.update (op =) (n,[f]) fmaps
  1080                 | SOME fs => AList.update (op =) (n,f::fs) fmaps) 
  1080                 | SOME fs => AList.update (op =) (n,f::fs) fmaps)
  1081 
  1081 
  1082              fun comps (n:string) T fmaps = 
  1082              fun comps (n:string) T fmaps =
  1083                (case AList.lookup (op =) fmaps n of
  1083                (case AList.lookup (op =) fmaps n of
  1084                  SOME fs => 
  1084                  SOME fs =>
  1085                    foldr1 (fn (f,g) => Const ("Fun.comp",(T-->T)-->(T-->T)-->(T-->T))$f$g) fs
  1085                    foldr1 (fn (f,g) => Const ("Fun.comp",(T-->T)-->(T-->T)-->(T-->T))$f$g) fs
  1086                 | NONE => error ("record_upd_simproc.comps"))
  1086                 | NONE => error ("record_upd_simproc.comps"))
  1087              
  1087 
  1088              (* mk_updterm returns either
  1088              (* mk_updterm returns either
  1089               *  - Init (orig-term, orig-term-skeleton, vars) if no optimisation can be made,
  1089               *  - Init (orig-term, orig-term-skeleton, vars) if no optimisation can be made,
  1090               *     where vars are the bound variables in the skeleton
  1090               *     where vars are the bound variables in the skeleton
  1091               *  - Inter (orig-term-skeleton,simplified-term-skeleton,
  1091               *  - Inter (orig-term-skeleton,simplified-term-skeleton,
  1092               *           vars, (term-sprout, skeleton-sprout))
  1092               *           vars, (term-sprout, skeleton-sprout))
  1117                                  let
  1117                                  let
  1118                                    val n = sel_name u;
  1118                                    val n = sel_name u;
  1119                                    val kv = (n, kT);
  1119                                    val kv = (n, kT);
  1120                                    val kb = Bound (length vars);
  1120                                    val kb = Bound (length vars);
  1121                                    val (sprout',vars') = grow u uT k kT (kv::vars) sprout;
  1121                                    val (sprout',vars') = grow u uT k kT (kv::vars) sprout;
  1122                                  in Inter(upd$kb$trm,trm',kv::vars',add n kb fmaps,sprout') 
  1122                                  in Inter(upd$kb$trm,trm',kv::vars',add n kb fmaps,sprout')
  1123                                  end)
  1123                                  end)
  1124                          else
  1124                          else
  1125                           (case rest (u::already) r of
  1125                           (case rest (u::already) r of
  1126                              Init ((sprout,skel),vars) =>
  1126                              Init ((sprout,skel),vars) =>
  1127                               (case is_upd_same (sprout,skel) u k of
  1127                               (case is_upd_same (sprout,skel) u k of
  1147                                     val n = sel_name u
  1147                                     val n = sel_name u
  1148                                     val T = domain_type kT
  1148                                     val T = domain_type kT
  1149                                     val kv = (n, kT)
  1149                                     val kv = (n, kT)
  1150                                     val kb = Bound (length vars)
  1150                                     val kb = Bound (length vars)
  1151                                     val (sprout',vars') = grow u uT k kT (kv::vars) sprout
  1151                                     val (sprout',vars') = grow u uT k kT (kv::vars) sprout
  1152                                     val fmaps' = add n kb fmaps 
  1152                                     val fmaps' = add n kb fmaps
  1153                                   in Inter (upd$kb$trm,upd$comps n T fmaps'$trm'
  1153                                   in Inter (upd$kb$trm,upd$comps n T fmaps'$trm'
  1154                                            ,vars',fmaps',sprout') end))
  1154                                            ,vars',fmaps',sprout') end))
  1155                      end
  1155                      end
  1156                  else Init (init_seed t)
  1156                  else Init (init_seed t)
  1157                | mk_updterm _ _ t = Init (init_seed t);
  1157                | mk_updterm _ _ t = Init (init_seed t);
  1159          in (case mk_updterm updates [] t of
  1159          in (case mk_updterm updates [] t of
  1160                Inter (trm,trm',vars,_,_)
  1160                Inter (trm,trm',vars,_,_)
  1161                 => SOME (prove_split_simp thy ss rT
  1161                 => SOME (prove_split_simp thy ss rT
  1162                           (list_all(vars,(equals rT$trm$trm'))))
  1162                           (list_all(vars,(equals rT$trm$trm'))))
  1163              | _ => NONE)
  1163              | _ => NONE)
  1164          end 
  1164          end
  1165        | _ => NONE))
  1165        | _ => NONE))
  1166 end
  1166 end
  1167 
  1167 
  1168 (* record_eq_simproc *)
  1168 (* record_eq_simproc *)
  1169 (* looks up the most specific record-equality.
  1169 (* looks up the most specific record-equality.
  1623         fun mkrefl (c,T) = Thm.reflexive
  1623         fun mkrefl (c,T) = Thm.reflexive
  1624                     (cterm_of defs_thy (Free (Name.variant variants (base c ^ "'"),T-->T)));
  1624                     (cterm_of defs_thy (Free (Name.variant variants (base c ^ "'"),T-->T)));
  1625         val refls = map mkrefl fields_more;
  1625         val refls = map mkrefl fields_more;
  1626         val dest_convs' = map mk_meta_eq dest_convs;
  1626         val dest_convs' = map mk_meta_eq dest_convs;
  1627         val map_eqs = map (uncurry Thm.combination) (refls ~~ dest_convs');
  1627         val map_eqs = map (uncurry Thm.combination) (refls ~~ dest_convs');
  1628         
  1628 
  1629         val constr_refl = Thm.reflexive (cterm_of defs_thy (head_of ext));
  1629         val constr_refl = Thm.reflexive (cterm_of defs_thy (head_of ext));
  1630 
  1630 
  1631         fun mkthm (udef,(fld_refl,thms)) =
  1631         fun mkthm (udef,(fld_refl,thms)) =
  1632           let val bdyeq = Library.foldl (uncurry Thm.combination) (constr_refl,thms);
  1632           let val bdyeq = Library.foldl (uncurry Thm.combination) (constr_refl,thms);
  1633                (* (|N=N (|N=N,M=M,K=K,more=more|)
  1633                (* (|N=N (|N=N,M=M,K=K,more=more|)
  1662                 rtac (prop_subst OF [surjective]) 1,
  1662                 rtac (prop_subst OF [surjective]) 1,
  1663                 REPEAT (etac meta_allE 1), atac 1]);
  1663                 REPEAT (etac meta_allE 1), atac 1]);
  1664     val split_meta = timeit_msg "record extension split_meta proof:" split_meta_prf;
  1664     val split_meta = timeit_msg "record extension split_meta proof:" split_meta_prf;
  1665 
  1665 
  1666 
  1666 
  1667     val (([inject',induct',cases',surjective',split_meta'], 
  1667     val (([inject',induct',cases',surjective',split_meta'],
  1668           [dest_convs',upd_convs']),
  1668           [dest_convs',upd_convs']),
  1669       thm_thy) =
  1669       thm_thy) =
  1670       defs_thy
  1670       defs_thy
  1671       |> (PureThy.add_thms o map Thm.no_attributes)
  1671       |> (PureThy.add_thms o map Thm.no_attributes)
  1672            [("ext_inject", inject),
  1672            [("ext_inject", inject),
  1913           (upd_decls @ [make_decl, fields_decl, extend_decl, truncate_decl])
  1913           (upd_decls @ [make_decl, fields_decl, extend_decl, truncate_decl])
  1914       |> ((PureThy.add_defs_i false o map Thm.no_attributes) sel_specs)
  1914       |> ((PureThy.add_defs_i false o map Thm.no_attributes) sel_specs)
  1915       ||>> ((PureThy.add_defs_i false o map Thm.no_attributes) upd_specs)
  1915       ||>> ((PureThy.add_defs_i false o map Thm.no_attributes) upd_specs)
  1916       ||>> ((PureThy.add_defs_i false o map Thm.no_attributes)
  1916       ||>> ((PureThy.add_defs_i false o map Thm.no_attributes)
  1917              [make_spec, fields_spec, extend_spec, truncate_spec])
  1917              [make_spec, fields_spec, extend_spec, truncate_spec])
  1918       |-> (fn defs as ((sel_defs, upd_defs), derived_defs) => 
  1918       |-> (fn defs as ((sel_defs, upd_defs), derived_defs) =>
  1919           fold Code.add_default_func sel_defs
  1919           fold Code.add_default_func sel_defs
  1920           #> fold Code.add_default_func upd_defs
  1920           #> fold Code.add_default_func upd_defs
  1921           #> fold Code.add_default_func derived_defs
  1921           #> fold Code.add_default_func derived_defs
  1922           #> pair defs)
  1922           #> pair defs)
  1923     val (((sel_defs, upd_defs), derived_defs), defs_thy) =
  1923     val (((sel_defs, upd_defs), derived_defs), defs_thy) =
  2274 
  2274 
  2275 val record_decl =
  2275 val record_decl =
  2276   P.type_args -- P.name --
  2276   P.type_args -- P.name --
  2277     (P.$$$ "=" |-- Scan.option (P.typ --| P.$$$ "+") -- Scan.repeat1 P.const);
  2277     (P.$$$ "=" |-- Scan.option (P.typ --| P.$$$ "+") -- Scan.repeat1 P.const);
  2278 
  2278 
  2279 val recordP =  
  2279 val _ =
  2280   OuterSyntax.command "record" "define extensible record" K.thy_decl
  2280   OuterSyntax.command "record" "define extensible record" K.thy_decl
  2281     (record_decl >> (fn (x, (y, z)) => Toplevel.theory (add_record x y z)));
  2281     (record_decl >> (fn (x, (y, z)) => Toplevel.theory (add_record x y z)));
  2282 
  2282 
  2283 val _ = OuterSyntax.add_parsers [recordP];
       
  2284 
       
  2285 end;
  2283 end;
  2286 
  2284 
  2287 end;
  2285 end;
  2288 
  2286 
  2289 
  2287