src/Pure/Isar/overloading.ML
changeset 43231 da8817d01e7c
parent 43230 6ca5407863ed
child 43246 774df7c59508
equal deleted inserted replaced
43230:6ca5407863ed 43231:da8817d01e7c
    63 
    63 
    64 fun improve_term_check ts ctxt =
    64 fun improve_term_check ts ctxt =
    65   let
    65   let
    66     val { secondary_constraints, improve, subst, consider_abbrevs, passed, ... } =
    66     val { secondary_constraints, improve, subst, consider_abbrevs, passed, ... } =
    67       ImprovableSyntax.get ctxt;
    67       ImprovableSyntax.get ctxt;
    68     val tsig = (Sign.tsig_of o ProofContext.theory_of) ctxt;
    68     val tsig = (Sign.tsig_of o Proof_Context.theory_of) ctxt;
    69     val is_abbrev = consider_abbrevs andalso ProofContext.abbrev_mode ctxt;
    69     val is_abbrev = consider_abbrevs andalso Proof_Context.abbrev_mode ctxt;
    70     val passed_or_abbrev = passed orelse is_abbrev;
    70     val passed_or_abbrev = passed orelse is_abbrev;
    71     fun accumulate_improvements (Const (c, ty)) = (case improve (c, ty)
    71     fun accumulate_improvements (Const (c, ty)) = (case improve (c, ty)
    72          of SOME ty_ty' => Type.typ_match tsig ty_ty'
    72          of SOME ty_ty' => Type.typ_match tsig ty_ty'
    73           | _ => I)
    73           | _ => I)
    74       | accumulate_improvements _ = I;
    74       | accumulate_improvements _ = I;
    83     val ts'' = if is_abbrev then ts' else map apply_subst ts';
    83     val ts'' = if is_abbrev then ts' else map apply_subst ts';
    84   in
    84   in
    85     if eq_list (op aconv) (ts, ts'') andalso passed_or_abbrev then NONE else
    85     if eq_list (op aconv) (ts, ts'') andalso passed_or_abbrev then NONE else
    86     if passed_or_abbrev then SOME (ts'', ctxt)
    86     if passed_or_abbrev then SOME (ts'', ctxt)
    87     else SOME (ts'', ctxt
    87     else SOME (ts'', ctxt
    88       |> fold (ProofContext.add_const_constraint o apsnd SOME) secondary_constraints
    88       |> fold (Proof_Context.add_const_constraint o apsnd SOME) secondary_constraints
    89       |> mark_passed)
    89       |> mark_passed)
    90   end;
    90   end;
    91 
    91 
    92 fun rewrite_liberal thy unchecks t =
    92 fun rewrite_liberal thy unchecks t =
    93   case try (Pattern.rewrite_term thy unchecks []) t
    93   case try (Pattern.rewrite_term thy unchecks []) t
    94    of NONE => NONE
    94    of NONE => NONE
    95     | SOME t' => if t aconv t' then NONE else SOME t';
    95     | SOME t' => if t aconv t' then NONE else SOME t';
    96 
    96 
    97 fun improve_term_uncheck ts ctxt =
    97 fun improve_term_uncheck ts ctxt =
    98   let
    98   let
    99     val thy = ProofContext.theory_of ctxt;
    99     val thy = Proof_Context.theory_of ctxt;
   100     val unchecks = (#unchecks o ImprovableSyntax.get) ctxt;
   100     val unchecks = (#unchecks o ImprovableSyntax.get) ctxt;
   101     val ts' = map (rewrite_liberal thy unchecks) ts;
   101     val ts' = map (rewrite_liberal thy unchecks) ts;
   102   in if exists is_some ts' then SOME (map2 the_default ts ts', ctxt) else NONE end;
   102   in if exists is_some ts' then SOME (map2 the_default ts ts', ctxt) else NONE end;
   103 
   103 
   104 fun set_primary_constraints ctxt =
   104 fun set_primary_constraints ctxt =
   105   let
   105   let
   106     val { primary_constraints, ... } = ImprovableSyntax.get ctxt;
   106     val { primary_constraints, ... } = ImprovableSyntax.get ctxt;
   107   in fold (ProofContext.add_const_constraint o apsnd SOME) primary_constraints ctxt end;
   107   in fold (Proof_Context.add_const_constraint o apsnd SOME) primary_constraints ctxt end;
   108 
   108 
   109 val activate_improvable_syntax =
   109 val activate_improvable_syntax =
   110   Context.proof_map
   110   Context.proof_map
   111     (Syntax.add_term_check 0 "improvement" improve_term_check
   111     (Syntax.add_term_check 0 "improvement" improve_term_check
   112     #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck)
   112     #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck)
   159 fun pretty lthy =
   159 fun pretty lthy =
   160   let
   160   let
   161     val overloading = get_overloading lthy;
   161     val overloading = get_overloading lthy;
   162     fun pr_operation ((c, ty), (v, _)) =
   162     fun pr_operation ((c, ty), (v, _)) =
   163       Pretty.block (Pretty.breaks
   163       Pretty.block (Pretty.breaks
   164         [Pretty.str v, Pretty.str "==", Pretty.str (ProofContext.extern_const lthy c),
   164         [Pretty.str v, Pretty.str "==", Pretty.str (Proof_Context.extern_const lthy c),
   165           Pretty.str "::", Syntax.pretty_typ lthy ty]);
   165           Pretty.str "::", Syntax.pretty_typ lthy ty]);
   166   in Pretty.str "overloading" :: map pr_operation overloading end;
   166   in Pretty.str "overloading" :: map pr_operation overloading end;
   167 
   167 
   168 fun conclude lthy =
   168 fun conclude lthy =
   169   let
   169   let
   175           o Syntax.string_of_term lthy o Const o fst) overloading));
   175           o Syntax.string_of_term lthy o Const o fst) overloading));
   176   in lthy end;
   176   in lthy end;
   177 
   177 
   178 fun gen_overloading prep_const raw_overloading thy =
   178 fun gen_overloading prep_const raw_overloading thy =
   179   let
   179   let
   180     val ctxt = ProofContext.init_global thy;
   180     val ctxt = Proof_Context.init_global thy;
   181     val _ = if null raw_overloading then error "At least one parameter must be given" else ();
   181     val _ = if null raw_overloading then error "At least one parameter must be given" else ();
   182     val overloading = raw_overloading |> map (fn (v, const, checked) =>
   182     val overloading = raw_overloading |> map (fn (v, const, checked) =>
   183       (Term.dest_Const (prep_const ctxt const), (v, checked)));
   183       (Term.dest_Const (prep_const ctxt const), (v, checked)));
   184   in
   184   in
   185     thy
   185     thy
   186     |> Theory.checkpoint
   186     |> Theory.checkpoint
   187     |> ProofContext.init_global
   187     |> Proof_Context.init_global
   188     |> Data.put overloading
   188     |> Data.put overloading
   189     |> fold (fn ((_, ty), (v, _)) => Variable.declare_names (Free (v, ty))) overloading
   189     |> fold (fn ((_, ty), (v, _)) => Variable.declare_names (Free (v, ty))) overloading
   190     |> activate_improvable_syntax
   190     |> activate_improvable_syntax
   191     |> synchronize_syntax
   191     |> synchronize_syntax
   192     |> Local_Theory.init NONE ""
   192     |> Local_Theory.init NONE ""