src/Pure/Isar/overloading.ML
author wenzelm
Fri, 28 Oct 2011 23:16:50 +0200
changeset 46164 57def0b39696
parent 46162 57cd50f98fdc
child 46181 adaf2184b79d
permissions -rw-r--r--
refined Local_Theory.declaration {syntax = false, pervasive} semantics: update is applied to auxiliary context as well;
     1 (*  Title:      Pure/Isar/overloading.ML
     2     Author:     Florian Haftmann, TU Muenchen
     3 
     4 Overloaded definitions without any discipline.
     5 *)
     6 
     7 signature OVERLOADING =
     8 sig
     9   type improvable_syntax
    10   val activate_improvable_syntax: Proof.context -> Proof.context
    11   val map_improvable_syntax: (improvable_syntax -> improvable_syntax)
    12     -> Proof.context -> Proof.context
    13   val set_primary_constraints: Proof.context -> Proof.context
    14 
    15   val overloading: (string * (string * typ) * bool) list -> theory -> local_theory
    16   val overloading_cmd: (string * string * bool) list -> theory -> local_theory
    17 end;
    18 
    19 structure Overloading: OVERLOADING =
    20 struct
    21 
    22 (* generic check/uncheck combinators for improvable constants *)
    23 
    24 type improvable_syntax = ((((string * typ) list * (string * typ) list) *
    25   ((((string * typ -> (typ * typ) option) * (string * typ -> (typ * term) option)) * bool) *
    26     (term * term) list)) * bool);
    27 
    28 structure Improvable_Syntax = Proof_Data
    29 (
    30   type T = {
    31     primary_constraints: (string * typ) list,
    32     secondary_constraints: (string * typ) list,
    33     improve: string * typ -> (typ * typ) option,
    34     subst: string * typ -> (typ * term) option,
    35     consider_abbrevs: bool,
    36     unchecks: (term * term) list,
    37     passed: bool
    38   };
    39   fun init _ = {
    40     primary_constraints = [],
    41     secondary_constraints = [],
    42     improve = K NONE,
    43     subst = K NONE,
    44     consider_abbrevs = false,
    45     unchecks = [],
    46     passed = true
    47   };
    48 );
    49 
    50 fun map_improvable_syntax f = Improvable_Syntax.map (fn {primary_constraints,
    51     secondary_constraints, improve, subst, consider_abbrevs, unchecks, passed} =>
    52   let
    53     val (((primary_constraints', secondary_constraints'),
    54       (((improve', subst'), consider_abbrevs'), unchecks')), passed')
    55         = f (((primary_constraints, secondary_constraints),
    56             (((improve, subst), consider_abbrevs), unchecks)), passed)
    57   in
    58    {primary_constraints = primary_constraints', secondary_constraints = secondary_constraints',
    59     improve = improve', subst = subst', consider_abbrevs = consider_abbrevs',
    60     unchecks = unchecks', passed = passed'}
    61   end);
    62 
    63 val mark_passed = (map_improvable_syntax o apsnd) (K true);
    64 
    65 fun improve_term_check ts ctxt =
    66   let
    67     val thy = Proof_Context.theory_of ctxt;
    68 
    69     val {secondary_constraints, improve, subst, consider_abbrevs, passed, ...} =
    70       Improvable_Syntax.get ctxt;
    71     val is_abbrev = consider_abbrevs andalso Proof_Context.abbrev_mode ctxt;
    72     val passed_or_abbrev = passed orelse is_abbrev;
    73     fun accumulate_improvements (Const (c, ty)) =
    74           (case improve (c, ty) of
    75             SOME ty_ty' => Sign.typ_match thy ty_ty'
    76           | _ => I)
    77       | accumulate_improvements _ = I;
    78     val improvements = (fold o fold_aterms) accumulate_improvements ts Vartab.empty;
    79     val ts' = (map o map_types) (Envir.subst_type improvements) ts;
    80     fun apply_subst t =
    81       Envir.expand_term
    82         (fn Const (c, ty) =>
    83           (case subst (c, ty) of
    84             SOME (ty', t') =>
    85               if Sign.typ_instance thy (ty, ty')
    86               then SOME (ty', apply_subst t') else NONE
    87           | NONE => NONE)
    88         | _ => NONE) t;
    89     val ts'' = if is_abbrev then ts' else map apply_subst ts';
    90   in
    91     if eq_list (op aconv) (ts, ts'') andalso passed_or_abbrev then NONE
    92     else if passed_or_abbrev then SOME (ts'', ctxt)
    93     else
    94       SOME (ts'', ctxt
    95         |> fold (Proof_Context.add_const_constraint o apsnd SOME) secondary_constraints
    96         |> mark_passed)
    97   end;
    98 
    99 fun rewrite_liberal thy unchecks t =
   100   (case try (Pattern.rewrite_term thy unchecks []) t of
   101     NONE => NONE
   102   | SOME t' => if t aconv t' then NONE else SOME t');
   103 
   104 fun improve_term_uncheck ts ctxt =
   105   let
   106     val thy = Proof_Context.theory_of ctxt;
   107     val {unchecks, ...} = Improvable_Syntax.get ctxt;
   108     val ts' = map (rewrite_liberal thy unchecks) ts;
   109   in if exists is_some ts' then SOME (map2 the_default ts ts', ctxt) else NONE end;
   110 
   111 fun set_primary_constraints ctxt =
   112   let val {primary_constraints, ...} = Improvable_Syntax.get ctxt;
   113   in fold (Proof_Context.add_const_constraint o apsnd SOME) primary_constraints ctxt end;
   114 
   115 val activate_improvable_syntax =
   116   Context.proof_map
   117     (Syntax.context_term_check 0 "improvement" improve_term_check
   118     #> Syntax.context_term_uncheck 0 "improvement" improve_term_uncheck)
   119   #> set_primary_constraints;
   120 
   121 
   122 (* overloading target *)
   123 
   124 structure Data = Proof_Data
   125 (
   126   type T = ((string * typ) * (string * bool)) list;
   127   fun init _ = [];
   128 );
   129 
   130 val get_overloading = Data.get o Local_Theory.target_of;
   131 val map_overloading = Local_Theory.target o Data.map;
   132 
   133 fun operation lthy b =
   134   get_overloading lthy
   135   |> get_first (fn ((c, _), (v, checked)) =>
   136       if Binding.name_of b = v then SOME (c, (v, checked)) else NONE);
   137 
   138 fun synchronize_syntax ctxt =
   139   let
   140     val overloading = Data.get ctxt;
   141     fun subst (c, ty) =
   142       (case AList.lookup (op =) overloading (c, ty) of
   143         SOME (v, _) => SOME (ty, Free (v, ty))
   144       | NONE => NONE);
   145     val unchecks =
   146       map (fn (c_ty as (_, ty), (v, _)) => (Free (v, ty), Const c_ty)) overloading;
   147   in 
   148     ctxt
   149     |> map_improvable_syntax (K ((([], []), (((K NONE, subst), false), unchecks)), false))
   150   end
   151 
   152 fun define_overloaded (c, U) (v, checked) (b_def, rhs) =
   153   Local_Theory.background_theory_result
   154     (Thm.add_def_global (not checked) true
   155       (b_def, Logic.mk_equals (Const (c, Term.fastype_of rhs), rhs)))
   156   ##> map_overloading (filter_out (fn (_, (v', _)) => v' = v))
   157   ##> Local_Theory.target synchronize_syntax
   158   #-> (fn (_, def) => pair (Const (c, U), def))
   159 
   160 fun foundation (((b, U), mx), (b_def, rhs)) (type_params, term_params) lthy =
   161   (case operation lthy b of
   162     SOME (c, (v, checked)) =>
   163       if mx <> NoSyn
   164       then error ("Illegal mixfix syntax for overloaded constant " ^ quote c)
   165       else lthy |> define_overloaded (c, U) (v, checked) (b_def, rhs)
   166   | NONE => lthy
   167       |> Generic_Target.theory_foundation (((b, U), mx), (b_def, rhs)) (type_params, term_params));
   168 
   169 fun pretty lthy =
   170   let
   171     val overloading = get_overloading lthy;
   172     fun pr_operation ((c, ty), (v, _)) =
   173       Pretty.block (Pretty.breaks
   174         [Pretty.str v, Pretty.str "==", Pretty.str (Proof_Context.extern_const lthy c),
   175           Pretty.str "::", Syntax.pretty_typ lthy ty]);
   176   in Pretty.str "overloading" :: map pr_operation overloading end;
   177 
   178 fun conclude lthy =
   179   let
   180     val overloading = get_overloading lthy;
   181     val _ =
   182       if null overloading then ()
   183       else
   184         error ("Missing definition(s) for parameter(s) " ^
   185           commas_quote (map (Syntax.string_of_term lthy o Const o fst) overloading));
   186   in lthy end;
   187 
   188 fun gen_overloading prep_const raw_overloading thy =
   189   let
   190     val ctxt = Proof_Context.init_global thy;
   191     val _ = if null raw_overloading then error "At least one parameter must be given" else ();
   192     val overloading = raw_overloading |> map (fn (v, const, checked) =>
   193       (Term.dest_Const (prep_const ctxt const), (v, checked)));
   194   in
   195     thy
   196     |> Theory.checkpoint
   197     |> Proof_Context.init_global
   198     |> Data.put overloading
   199     |> fold (fn ((_, ty), (v, _)) => Variable.declare_names (Free (v, ty))) overloading
   200     |> activate_improvable_syntax
   201     |> synchronize_syntax
   202     |> Local_Theory.init NONE ""
   203        {define = Generic_Target.define foundation,
   204         notes = Generic_Target.notes
   205           (fn kind => fn global_facts => fn _ => Generic_Target.theory_notes kind global_facts),
   206         abbrev = Generic_Target.abbrev
   207           (fn prmode => fn (b, mx) => fn (t, _) => fn _ =>
   208             Generic_Target.theory_abbrev prmode ((b, mx), t)),
   209         declaration = Generic_Target.theory_declaration o #syntax,
   210         pretty = pretty,
   211         exit = Local_Theory.target_of o conclude}
   212   end;
   213 
   214 val overloading = gen_overloading (fn ctxt => Syntax.check_term ctxt o Const);
   215 val overloading_cmd = gen_overloading Syntax.read_term;
   216 
   217 end;