src/Pure/Isar/overloading.ML
author wenzelm
Thu, 26 Aug 2010 15:48:08 +0200
changeset 39032 2b3e054ae6fc
parent 38608 8b02c5bf1d0e
child 39624 df86b1b4ce10
permissions -rw-r--r--
renamed Local_Theory.theory(_result) to Local_Theory.background_theory(_result) to emphasize that this belongs to the infrastructure and is rarely appropriate in user-space tools;
     1 (*  Title:      Pure/Isar/overloading.ML
     2     Author:     Florian Haftmann, TU Muenchen
     3 
     4 Overloaded definitions without any discipline.
     5 *)
     6 
     7 signature OVERLOADING =
     8 sig
     9   type improvable_syntax
    10   val add_improvable_syntax: Proof.context -> Proof.context
    11   val map_improvable_syntax: (improvable_syntax -> improvable_syntax)
    12     -> Proof.context -> Proof.context
    13   val set_primary_constraints: Proof.context -> Proof.context
    14 
    15   val overloading: (string * (string * typ) * bool) list -> theory -> local_theory
    16   val overloading_cmd: (string * string * bool) list -> theory -> local_theory
    17 end;
    18 
    19 structure Overloading: OVERLOADING =
    20 struct
    21 
    22 (** generic check/uncheck combinators for improvable constants **)
    23 
    24 type improvable_syntax = ((((string * typ) list * (string * typ) list) *
    25   ((((string * typ -> (typ * typ) option) * (string * typ -> (typ * term) option)) * bool) *
    26     (term * term) list)) * bool);
    27 
    28 structure ImprovableSyntax = Proof_Data
    29 (
    30   type T = {
    31     primary_constraints: (string * typ) list,
    32     secondary_constraints: (string * typ) list,
    33     improve: string * typ -> (typ * typ) option,
    34     subst: string * typ -> (typ * term) option,
    35     consider_abbrevs: bool,
    36     unchecks: (term * term) list,
    37     passed: bool
    38   };
    39   fun init _ = {
    40     primary_constraints = [],
    41     secondary_constraints = [],
    42     improve = K NONE,
    43     subst = K NONE,
    44     consider_abbrevs = false,
    45     unchecks = [],
    46     passed = true
    47   };
    48 );
    49 
    50 fun map_improvable_syntax f = ImprovableSyntax.map (fn { primary_constraints,
    51   secondary_constraints, improve, subst, consider_abbrevs, unchecks, passed } => let
    52     val (((primary_constraints', secondary_constraints'),
    53       (((improve', subst'), consider_abbrevs'), unchecks')), passed')
    54         = f (((primary_constraints, secondary_constraints),
    55             (((improve, subst), consider_abbrevs), unchecks)), passed)
    56   in { primary_constraints = primary_constraints', secondary_constraints = secondary_constraints',
    57     improve = improve', subst = subst', consider_abbrevs = consider_abbrevs',
    58     unchecks = unchecks', passed = passed'
    59   } end);
    60 
    61 val mark_passed = (map_improvable_syntax o apsnd) (K true);
    62 
    63 fun improve_term_check ts ctxt =
    64   let
    65     val { secondary_constraints, improve, subst, consider_abbrevs, passed, ... } =
    66       ImprovableSyntax.get ctxt;
    67     val tsig = (Sign.tsig_of o ProofContext.theory_of) ctxt;
    68     val is_abbrev = consider_abbrevs andalso ProofContext.abbrev_mode ctxt;
    69     val passed_or_abbrev = passed orelse is_abbrev;
    70     fun accumulate_improvements (Const (c, ty)) = (case improve (c, ty)
    71          of SOME ty_ty' => Type.typ_match tsig ty_ty'
    72           | _ => I)
    73       | accumulate_improvements _ = I;
    74     val improvements = (fold o fold_aterms) accumulate_improvements ts Vartab.empty;
    75     val ts' = (map o map_types) (Envir.subst_type improvements) ts;
    76     fun apply_subst t = Envir.expand_term (fn Const (c, ty) => (case subst (c, ty)
    77          of SOME (ty', t') =>
    78               if Type.typ_instance tsig (ty, ty')
    79               then SOME (ty', apply_subst t') else NONE
    80           | NONE => NONE)
    81         | _ => NONE) t;
    82     val ts'' = if is_abbrev then ts' else map apply_subst ts';
    83   in if eq_list (op aconv) (ts, ts'') andalso passed_or_abbrev then NONE else
    84     if passed_or_abbrev then SOME (ts'', ctxt)
    85     else SOME (ts'', ctxt
    86       |> fold (ProofContext.add_const_constraint o apsnd SOME) secondary_constraints
    87       |> mark_passed)
    88   end;
    89 
    90 fun rewrite_liberal thy unchecks t =
    91   case try (Pattern.rewrite_term thy unchecks []) t
    92    of NONE => NONE
    93     | SOME t' => if t aconv t' then NONE else SOME t';
    94 
    95 fun improve_term_uncheck ts ctxt =
    96   let
    97     val thy = ProofContext.theory_of ctxt;
    98     val unchecks = (#unchecks o ImprovableSyntax.get) ctxt;
    99     val ts' = map (rewrite_liberal thy unchecks) ts;
   100   in if exists is_some ts' then SOME (map2 the_default ts ts', ctxt) else NONE end;
   101 
   102 fun set_primary_constraints ctxt =
   103   let
   104     val { primary_constraints, ... } = ImprovableSyntax.get ctxt;
   105   in fold (ProofContext.add_const_constraint o apsnd SOME) primary_constraints ctxt end;
   106 
   107 val add_improvable_syntax =
   108   Context.proof_map
   109     (Syntax.add_term_check 0 "improvement" improve_term_check
   110     #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck)
   111   #> set_primary_constraints;
   112 
   113 
   114 (** overloading target **)
   115 
   116 structure Data = Proof_Data
   117 (
   118   type T = ((string * typ) * (string * bool)) list;
   119   fun init _ = [];
   120 );
   121 
   122 val get_overloading = Data.get o Local_Theory.target_of;
   123 val map_overloading = Local_Theory.target o Data.map;
   124 
   125 fun operation lthy b = get_overloading lthy
   126   |> get_first (fn ((c, _), (v, checked)) =>
   127       if Binding.name_of b = v then SOME (c, (v, checked)) else NONE);
   128 
   129 fun synchronize_syntax ctxt =
   130   let
   131     val overloading = Data.get ctxt;
   132     fun subst (c, ty) = case AList.lookup (op =) overloading (c, ty)
   133      of SOME (v, _) => SOME (ty, Free (v, ty))
   134       | NONE => NONE;
   135     val unchecks =
   136       map (fn (c_ty as (_, ty), (v, _)) => (Free (v, ty), Const c_ty)) overloading;
   137   in 
   138     ctxt
   139     |> map_improvable_syntax (K ((([], []), (((K NONE, subst), false), unchecks)), false))
   140   end
   141 
   142 fun define_overloaded (c, U) (v, checked) (b_def, rhs) =
   143   Local_Theory.background_theory_result
   144     (Thm.add_def (not checked) true (b_def, Logic.mk_equals (Const (c, Term.fastype_of rhs), rhs)))
   145   ##> map_overloading (filter_out (fn (_, (v', _)) => v' = v))
   146   ##> Local_Theory.target synchronize_syntax
   147   #-> (fn (_, def) => pair (Const (c, U), def))
   148 
   149 fun foundation (((b, U), mx), (b_def, rhs)) (type_params, term_params) lthy =
   150   case operation lthy b
   151    of SOME (c, (v, checked)) => if mx <> NoSyn
   152        then error ("Illegal mixfix syntax for overloaded constant " ^ quote c)
   153         else lthy |> define_overloaded (c, U) (v, checked) (b_def, rhs)
   154     | NONE => lthy |>
   155         Generic_Target.theory_foundation (((b, U), mx), (b_def, rhs)) (type_params, term_params);
   156 
   157 fun pretty lthy =
   158   let
   159     val thy = ProofContext.theory_of lthy;
   160     val overloading = get_overloading lthy;
   161     fun pr_operation ((c, ty), (v, _)) =
   162       (Pretty.block o Pretty.breaks) [Pretty.str v, Pretty.str "==",
   163         Pretty.str (Sign.extern_const thy c), Pretty.str "::", Syntax.pretty_typ lthy ty];
   164   in Pretty.str "overloading" :: map pr_operation overloading end;
   165 
   166 fun conclude lthy =
   167   let
   168     val overloading = get_overloading lthy;
   169     val _ = if null overloading then () else
   170       error ("Missing definition(s) for parameter(s) " ^ commas (map (quote
   171         o Syntax.string_of_term lthy o Const o fst) overloading));
   172   in lthy end;
   173 
   174 fun gen_overloading prep_const raw_overloading thy =
   175   let
   176     val ctxt = ProofContext.init_global thy;
   177     val _ = if null raw_overloading then error "At least one parameter must be given" else ();
   178     val overloading = raw_overloading |> map (fn (v, const, checked) =>
   179       (Term.dest_Const (prep_const ctxt const), (v, checked)));
   180   in
   181     thy
   182     |> Theory.checkpoint
   183     |> ProofContext.init_global
   184     |> Data.put overloading
   185     |> fold (fn ((_, ty), (v, _)) => Variable.declare_names (Free (v, ty))) overloading
   186     |> add_improvable_syntax
   187     |> synchronize_syntax
   188     |> Local_Theory.init NONE ""
   189        {define = Generic_Target.define foundation,
   190         notes = Generic_Target.notes
   191           (fn kind => fn global_facts => fn _ => Generic_Target.theory_notes kind global_facts),
   192         abbrev = Generic_Target.abbrev
   193           (fn prmode => fn (b, mx) => fn (t, _) => fn _ =>
   194             Generic_Target.theory_abbrev prmode ((b, mx), t)),
   195         declaration = K Generic_Target.theory_declaration,
   196         syntax_declaration = K Generic_Target.theory_declaration,
   197         pretty = pretty,
   198         exit = Local_Theory.target_of o conclude}
   199   end;
   200 
   201 val overloading = gen_overloading (fn ctxt => Syntax.check_term ctxt o Const);
   202 val overloading_cmd = gen_overloading Syntax.read_term;
   203 
   204 end;