1 (* Title: Pure/Isar/calculation.ML
3 Author: Markus Wenzel, TU Muenchen
5 Support for calculational proofs.
*)
(* Public interface of the calculational-proof support: inspection of the
   declared rules, the trans/sym/symmetric attributes, and the proof commands
   also/finally/moreover/ultimately.
   NOTE(review): the sig keyword and the closing end of this signature are not
   visible in this listing -- confirm against the full file. *)
8 signature CALCULATION =
(* Print the transitivity and symmetry rules stored in a theory / context. *)
10 val print_global_rules: theory -> unit
11 val print_local_rules: Proof.context -> unit
(* Add/delete transitivity rules (the trans attribute). *)
12 val trans_add_global: theory attribute
13 val trans_del_global: theory attribute
14 val trans_add_local: Proof.context attribute
15 val trans_del_local: Proof.context attribute
(* Add/delete symmetry rules (the sym attribute). *)
16 val sym_add_global: theory attribute
17 val sym_del_global: theory attribute
18 val sym_add_local: Proof.context attribute
19 val sym_del_local: Proof.context attribute
(* The symmetric attribute: resolve a fact with the declared symmetry rules. *)
20 val symmetric_global: theory attribute
21 val symmetric_local: Proof.context attribute
(* also/finally take optional explicit transitivity rules and a callback that
   prints the new calculation; they may yield several successor states. *)
22 val also: thm list option -> (Proof.context -> thm list -> unit)
23 -> Proof.state -> Proof.state Seq.seq
24 val finally: thm list option -> (Proof.context -> thm list -> unit)
25 -> Proof.state -> Proof.state Seq.seq
(* moreover/ultimately only accumulate facts; no rules are involved. *)
26 val moreover: (Proof.context -> thm list -> unit) -> Proof.state -> Proof.state
27 val ultimately: (Proof.context -> thm list -> unit) -> Proof.state -> Proof.state
30 structure Calculation: CALCULATION =
33 (** global and local calculation data **)
35 (* theory data kind 'Isar/calculation' *)
(* Pretty-print a rule pair: the transitivity rules held in the net, followed
   by the symmetry rules, each rendered with prt applied to the context x. *)
fun print_rules prt x (trans, sym) =
  let
    val trans_items = map (prt x) (NetRules.rules trans);
    val sym_items = map (prt x) sym;
  in
    Pretty.writeln (Pretty.chunks
      [Pretty.big_list "transitivity rules:" trans_items,
       Pretty.big_list "symmetry rules:" sym_items])
  end;
(* Argument structure for the theory data slot Isar/calculation: a net of
   transitivity rules paired with a list of symmetry rules.
   NOTE(review): the struct line, any further data components between the
   visible lines, and the closing end are not visible in this listing --
   confirm against the full file. *)
42 structure GlobalCalculationArgs =
44 val name = "Isar/calculation";
45 type T = thm NetRules.T * thm list
(* No symmetry rules initially; NetRules.elim is presumably an empty rule net
   -- confirm against NetRules. *)
47 val empty = (NetRules.elim, []);
(* Join the rule sets of two theories component-wise. *)
50 fun merge ((trans1, sym1), (trans2, sym2)) =
51 (NetRules.merge (trans1, trans2), Drule.merge_rules (sym1, sym2));
52 val print = print_rules Display.pretty_thm_sg;
(* Theory data holding the global calculation rules; its init function is
   registered as a setup step. *)
structure GlobalCalculation = TheoryDataFun(GlobalCalculationArgs);
val _ = Context.add_setup [GlobalCalculation.init];

(* Display the rules stored in the given theory. *)
fun print_global_rules thy = GlobalCalculation.print thy;
60 (* proof data kind 'Isar/calculation' *)
(* Argument structure for the proof-context data slot Isar/calculation: the
   local rule set, plus the current calculation (if any) tagged with the proof
   level at which it was recorded.
   NOTE(review): the struct line and the closing end are not visible in this
   listing -- confirm against the full file. *)
62 structure LocalCalculationArgs =
64 val name = "Isar/calculation";
65 type T = (thm NetRules.T * thm list) * (thm list * int) option;
(* Start from the theory's global rules, with no calculation yet. *)
67 fun init thy = (GlobalCalculation.get thy, NONE);
(* Only the rule component is printed; the calculation is ignored. *)
68 fun print ctxt (rs, _) = print_rules ProofContext.pretty_thm ctxt rs;
(* Proof data holding the local rules and the current calculation; its init
   function is registered as a setup step. *)
structure LocalCalculation = ProofDataFun(LocalCalculationArgs);
val _ = Context.add_setup [LocalCalculation.init];

(* Rule component (transitivity net, symmetry rules) of the state's context. *)
fun get_local_rules state = #1 (LocalCalculation.get (Proof.context_of state));

(* Display the rules stored in the given proof context. *)
fun print_local_rules ctxt = LocalCalculation.print ctxt;
77 (* access calculation *)
(* Current calculation of the proof state.  A calculation recorded at a proof
   level other than the current one is treated as absent (NONE), as is a
   context with no calculation at all. *)
fun get_calculation state =
  (case #2 (LocalCalculation.get (Proof.context_of state)) of
    NONE => NONE
  | SOME (thms, lev) => if lev = Proof.level state then SOME thms else NONE);
(* Record thms as the current calculation, tagged with the current proof
   level.  The local rules are kept unchanged.  The context update goes
   through Proof.map_context because state is a Proof.state, not a context
   (same pattern as reset_calculation). *)
fun put_calculation thms state =
  Proof.map_context
    (LocalCalculation.put (get_local_rules state, SOME (thms, Proof.level state))) state;
(* Drop the current calculation while keeping the local rules. *)
fun reset_calculation state =
  let val rules = get_local_rules state
  in Proof.map_context (LocalCalculation.put (rules, NONE)) state end;
(* Turn a data transformer f (applied to the declared theorem) into an
   attribute.  global_att updates the theory data.  local_att updates only the
   rule component of the proof-context data, leaving the calculation alone.
   In both cases the theorem is passed through unchanged. *)
fun global_att f (thy, th) = (GlobalCalculation.map (f th) thy, th);
fun local_att f (ctxt, th) = (LocalCalculation.map (apfst (f th)) ctxt, th);

(* trans: transitivity rules are kept in the rule net (first component). *)
val trans_add_global = global_att (fn th => apfst (NetRules.insert th));
val trans_del_global = global_att (fn th => apfst (NetRules.delete th));
val trans_add_local = local_att (fn th => apfst (NetRules.insert th));
val trans_del_local = local_att (fn th => apfst (NetRules.delete th));

(* sym: symmetry rules are kept in the plain list (second component).  The
   corresponding ContextRules attribute is applied first. *)
val sym_add_global =
  global_att (fn th => apsnd (Drule.add_rule th)) o ContextRules.elim_query_global NONE;
val sym_del_global =
  global_att (fn th => apsnd (Drule.del_rule th)) o ContextRules.rule_del_global;
val sym_add_local =
  local_att (fn th => apsnd (Drule.add_rule th)) o ContextRules.elim_query_local NONE;
val sym_del_local =
  local_att (fn th => apsnd (Drule.del_rule th)) o ContextRules.rule_del_local;
(* Rule attribute that resolves a fact th with the symmetry rules obtained
   from the context via get_sym.  At most two resolvents are inspected.
   Exactly one resolvent is the result; none or several raise THM. *)
fun gen_symmetric get_sym = Drule.rule_attribute (fn x => fn th =>
  (case Seq.chop (2, Method.multi_resolves [th] (get_sym x)) of
    ([th'], _) => th'
  | ([], _) => raise THM ("symmetric: no unifiers", 1, [th])
  | _ => raise THM ("symmetric: multiple unifiers", 1, [th])));
(* The global variant reads the symmetry rules from theory data.  The local
   variant reads them from the rule component of the proof-context data. *)
val symmetric_global = gen_symmetric (fn thy => #2 (GlobalCalculation.get thy));
val symmetric_local = gen_symmetric (fn ctxt => #2 (#1 (LocalCalculation.get ctxt)));
123 (* concrete syntax *)
(* (global, local) attribute syntax pairs for trans and sym, each built with
   Attrib.add_del_args from the corresponding add and del attributes.
   NOTE(review): the val trans_attr / val sym_attr binding lines are not
   visible in this listing; the setup that follows refers to both names. *)
126 (Attrib.add_del_args trans_add_global trans_del_global,
127 Attrib.add_del_args trans_add_local trans_del_local);
130 (Attrib.add_del_args sym_add_global sym_del_global,
131 Attrib.add_del_args sym_add_local sym_del_local);
(* Setup: register the trans, sym and symmetric attributes, and add
   transitive_thm and symmetric_thm as unnamed theorems declared as the
   initial transitivity and symmetry rules. *)
val _ =
  let
    val attributes =
      [("trans", trans_attr, "declaration of transitivity rule"),
       ("sym", sym_attr, "declaration of symmetry rule"),
       ("symmetric", (Attrib.no_args symmetric_global, Attrib.no_args symmetric_local),
         "resolution with symmetry rule")];
    val initial_rules =
      [(("", transitive_thm), [trans_add_global]),
       (("", symmetric_thm), [sym_add_global])];
  in
    Context.add_setup
      [Attrib.add_attributes attributes,
       #1 o PureThy.add_thms initial_rules]
  end;
145 (** proof commands **)
(* State check for the proof commands.  Final commands (finally, ultimately)
   use Proof.assert_forward.  The others use Proof.assert_forward_or_chain. *)
fun assert_sane true = Proof.assert_forward
  | assert_sane false = Proof.assert_forward_or_chain;
151 (* maintain calculation register *)
(* Fact name under which the current calculation is bound. *)
153 val calculationN = "calculation";
(* Store a new calculation.  Non-final (also/moreover): record calc as the
   current calculation and bind it under calculationN.  Final
   (finally/ultimately): reset the calculationN binding and note calc via
   Proof.simple_note_thms with an empty name.
   NOTE(review): both pipelines begin with |> and have no visible initial
   value, and the final branch has no visible terminator -- lines of this
   definition appear to be missing from this listing; confirm against the
   full file. *)
155 fun maintain_calculation false calc state =
157 |> put_calculation calc
158 |> Proof.put_thms (calculationN, calc)
159 | maintain_calculation true calc state =
162 |> Proof.reset_thms calculationN
163 |> Proof.simple_note_thms "" calc
167 (* 'also' and 'finally' *)
(* Raise Proof.STATE carrying msg and state when b holds; otherwise no-op. *)
fun err_if state b msg =
  if not b then () else raise Proof.STATE (msg, state);
(* Core of also (final = false) and finally (final = true).  Combines the
   current calculation with the new facts by resolution with transitivity
   rules.  The rules are either opt_rules or taken from the local rule net.
   Each combination is printed via print, and one successor state is yielded
   per combination.
   NOTE(review): the let/in/end frame, the head of the combine function and
   the NONE branch of the rule selection are not visible in this listing --
   confirm against the full file. *)
171 fun calculate final opt_rules print state =
173 val strip_assums_concl = Logic.strip_assums_concl o Thm.prop_of;
(* Two theorems count as equal when their conclusions (assumptions stripped,
   eta-contracted) are aconv. *)
174 val eq_prop = op aconv o pairself (Pattern.eta_contract o strip_assums_concl);
(* th is a projection if it merely restates one of the inputs ths. *)
175 fun projection ths th = Library.exists (Library.curry eq_prop th) ths;
178 (case opt_rules of SOME rules => rules
(* Without explicit rules: with no inputs take every transitivity rule,
   otherwise retrieve candidates by the conclusion of the first input. *)
180 (case ths of [] => NetRules.rules (#1 (get_local_rules state))
181 | th :: _ => NetRules.retrieve (#1 (get_local_rules state)) (strip_assums_concl th)))
(* Resolve the inputs against each candidate rule (as a lazy sequence) and
   drop trivial projections. *)
182 |> Seq.of_list |> Seq.map (Method.multi_resolve ths) |> Seq.flat
183 |> Seq.filter (not o projection ths);
185 val facts = Proof.the_facts (assert_sane final state);
(* With no calculation at this level the facts start a new one; otherwise
   each combination of calculation and facts becomes a singleton list. *)
186 val (initial, calculations) =
187 (case get_calculation state of
188 NONE => (true, Seq.single facts)
189 | SOME calc => (false, Seq.map single (combine (calc @ facts))));
191 err_if state (initial andalso final) "No calculation yet";
192 err_if state (initial andalso is_some opt_rules) "Initial calculation -- no rules to be given";
(* Print each candidate calculation before storing it in its own state. *)
193 calculations |> Seq.map (fn calc => (print (Proof.context_of state) calc;
194 state |> maintain_calculation final calc))
(* also / finally.  opt_rules: optional explicit transitivity rules.  print:
   callback that displays each new calculation.  The original code named the
   first argument print, although it is the thm list option of rules. *)
fun also opt_rules print = calculate false opt_rules print;
fun finally opt_rules print = calculate true opt_rules print;
201 (* 'moreover' and 'ultimately' *)
(* Core of moreover (final = false) and ultimately (final = true).  Appends
   the new facts to the current calculation, prints the result via print and
   stores it with maintain_calculation.
   NOTE(review): the let/in/end frame and the NONE branch of the case are not
   visible in this listing.  Presumably NONE yields initial = true with no
   previous facts, since the visible SOME branch yields false; confirm
   against the full file. *)
203 fun collect final print state =
205 val facts = Proof.the_facts (assert_sane final state);
206 val (initial, thms) =
207 (case get_calculation state of
209 | SOME thms => (false, thms));
210 val calc = thms @ facts;
(* A final step requires an existing calculation. *)
212 err_if state (initial andalso final) "No calculation yet";
213 print (Proof.context_of state) calc;
214 state |> maintain_calculation final calc
(* moreover / ultimately: accumulate facts into the calculation.  ultimately
   runs collect with final = true. *)
fun moreover print state = collect false print state;
fun ultimately print state = collect true print state;