1 (* Title: Tools/code/code_funcgr.ML
3 Author: Florian Haftmann, TU Muenchen
5 Retrieving, normalizing and structuring defining equations in a graph
6 with explicit dependencies.
(* Public interface of the equation graph used by the code generator: a graph
   of type T with one node per constant name, queries on it, construction
   from constant names, and evaluation of certified terms against it. *)
9 signature CODE_FUNCGR =
(* Defining equations stored for a constant; [] if the constant has no node. *)
13 val funcs: T -> string -> thm list
(* Type stored for a constant; the lookup is an unguarded Graph.get_node,
   so it fails if the constant has no node. *)
14 val typ: T -> string -> typ
(* Names of all constants that are nodes of the graph. *)
15 val all: T -> string list
(* Listing of every constant with its defining equations, sorted by the
   constant's printed name. *)
16 val pretty: theory -> T -> Pretty.T
(* Extends the theory's cached graph with the given constants; invalid
   equations are reported via error. *)
17 val make: theory -> string list -> T
(* Like make, but constants whose equations are invalid are skipped; returns
   the accepted constants together with the graph. *)
18 val make_consts: theory -> string list -> string list * T
(* Evaluates a cterm with a conversion that is handed the graph; the result
   theorem is chained with the code generator's pre- and postprocessing. *)
19 val eval_conv: theory -> (T -> cterm -> thm) -> cterm -> thm
(* Evaluates the preprocessed cterm with a function that is handed the graph. *)
20 val eval_term: theory -> (T -> cterm -> 'a) -> cterm -> 'a
23 structure CodeFuncgr : CODE_FUNCGR =
26 (** the graph type **)
(* Graph keyed by constant name; each node holds the constant's type and its
   defining equations. *)
28 type T = (typ * thm list) Graph.T;
(* Body of funcs: the equations of a constant, or [] when it has no node. *)
31 these o Option.map snd o try (Graph.get_node funcgr);
(* Body of typ: the type of a constant; raises if the constant has no node. *)
34 fst o Graph.get_node funcgr;
(* All constant names that currently have a node in the graph. *)
36 val all = Graph.keys;
(* Pretty-prints the whole graph.  Every constant is paired with its
   equations, and its name is rendered with CodeUnit.string_of_const.  The
   entries are sorted by that string, and each entry is printed as a block of
   line-broken parts that ends with its equations (Display.pretty_thm).  The
   node types are not part of the listing. *)
38 fun pretty thy funcgr =
39 AList.make (snd o Graph.get_node funcgr) (Graph.keys funcgr)
40 |> (map o apfst) (CodeUnit.string_of_const thy)
41 |> sort (string_ord o pairself fst)
42 |> map (fn (s, thms) =>
43 (Pretty.block o Pretty.fbreaks) (
45 :: map Display.pretty_thm thms
50 (** generic combinators **)
(* Folds f over every occurrence of a constant (name, type) in the given
   equations, both in the right-hand side and in the arguments of the
   left-hand side.  The head constant of the left-hand side is not visited. *)
52 fun fold_consts f thms =
(* for each equation lhs == rhs, keep rhs :: (arguments of lhs) *)
54 |> maps (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of)
55 |> (fold o fold_aterms) (fn Const c => f c | _ => I);
(* Names of the constants used by the equations of const, as collected by
   fold_consts.  The result has no duplicates and excludes const itself; it
   is [] when there are no equations. *)
57 fun consts_of (const, []) = []
58 | consts_of (const, thms as _ :: _) =
60 fun the_const (c, _) = if c = const then I else insert (op =) c
61 in fold_consts the_const thms [] end;
(* (Type constructor, class) pairs used by the sort derivations needed when
   constant c, declared at type ty_decl, is used at type ty.
   - The type arguments of both types are matched structurally (mk_inst).
     This yields obligations of the form (actual type, required sort).
   - Each obligation is derived with Sorts.of_sort_derivation.  The
     type_constructor callback records every (tyco, class) step.
   NOTE(review): presumably Sorts.of_sort_derivation raises Sorts.CLASS_ERROR
   when a derivation fails -- confirm. *)
63 fun insts_of thy algebra c ty_decl ty =
65 val tys_decl = Sign.const_typargs thy (c, ty_decl);
66 val tys = Sign.const_typargs thy (c, ty);
(* Derivation callbacks.  A derivation value is the list of (tyco, class)
   steps.  Superclass steps (class_relation) pass it on unchanged. *)
67 fun class_relation (x, _) _ = x;
68 fun type_constructor tyco xs class =
69 (tyco, class) :: maps (maps fst) xs;
70 fun type_variable (TVar (_, sort)) = map (pair []) sort
71 | type_variable (TFree (_, sort)) = map (pair []) sort;
(* Pairs actual type arguments (first argument) with declared ones (second).
   A declared type variable yields the obligation (actual type, its sort).
   NOTE(review): there is no case for an actual type variable against a
   declared Type; such input raises Match. *)
72 fun mk_inst ty (TVar (_, sort)) = cons (ty, sort)
73 | mk_inst ty (TFree (_, sort)) = cons (ty, sort)
74 | mk_inst (Type (_, tys1)) (Type (_, tys2)) = fold2 mk_inst tys1 tys2;
75 fun of_sort_deriv (ty, sort) =
76 Sorts.of_sort_derivation (Sign.pp thy) algebra
77 { class_relation = class_relation, type_constructor = type_constructor,
78 type_variable = type_variable }
81 flat (maps of_sort_deriv (fold2 mk_inst tys tys_decl []))
(* Removes the class constraints from thm's type variables:
   - the type variables are made schematic (Thm.varifyT');
   - their sort constraints are removed (Thm.unconstrainT);
   - they are instantiated with the free type variables tfrees, at the
     sorts given in tfrees;
   - the remaining goals are discharged by repeatedly resolving with
     AxClass.class_intros.  These are presumably the class premises
     introduced by unconstrainT -- confirm.
   NOTE(review): tvars and tfrees are paired by position (map2).  This relies
   on Term.add_tvars and the caller's tfrees having matching order and
   length. *)
84 fun drop_classes thy tfrees thm =
86 val (_, thm') = Thm.varifyT' [] thm;
87 val tvars = Term.add_tvars (Thm.prop_of thm') [];
88 val unconstr = map (Thm.ctyp_of thy o TVar) tvars;
89 val instmap = map2 (fn (v_i, _) => fn (v, sort) => pairself (Thm.ctyp_of thy)
90 (TVar (v_i, []), TFree (v, sort))) tvars tfrees;
93 |> fold Thm.unconstrainT unconstr
94 |> Thm.instantiate (instmap, [])
95 |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
99 (** graph algorithm **)
(* When true, ensure_const reports the time spent on each constant. *)
101 val timing = ref false;
(* Failure during graph construction: the constants involved and a message. *)
105 exception INVALID of string list * string;
(* Re-instantiates the type variables of a group of equations, presumably with
   stronger sorts.  The goal is that every constant occurrence for which
   tap_typ yields SOME declared type agrees with the sorts of that
   declaration.
   - Sort requirements are accumulated in a Vartab via
     CodeUnit.typ_sort_inst.
   - They are then applied to every equation with CodeUnit.inst_thm.
   The theory is taken from the first equation.  Callers expect
   Sorts.CLASS_ERROR on sort clashes (see resort_funcss). *)
107 fun resort_thms algebra tap_typ [] = []
108 | resort_thms algebra tap_typ (thms as thm :: _) =
110 val thy = Thm.theory_of_thm thm;
111 val cs = fold_consts (insert (op =)) thms [];
112 fun match_const c (ty, ty_decl) =
114 val tys = Sign.const_typargs thy (c, ty);
(* NOTE(review): dest_TVar assumes the declared type arguments are schematic type variables *)
115 val sorts = map (snd o dest_TVar) (Sign.const_typargs thy (c, ty_decl));
116 in fold2 (curry (CodeUnit.typ_sort_inst algebra)) tys sorts end;
119 of SOME ty_decl => match_const c (ty, ty_decl)
121 val tvars = fold match cs Vartab.empty;
122 in map (CodeUnit.inst_thm tvars) thms end;
(* Resorts a group of mutually dependent (constant, equations) pairs in two
   phases:
   1. resort_dep: resort against the types already stored in funcgr.
   2. resort_rec_until: resort against the group's own head types, repeating
      until no head type changes.
   A sort clash in the first phase is turned into INVALID. *)
124 fun resort_funcss thy algebra funcgr =
126 val typ_funcgr = try (fst o Graph.get_node funcgr);
127 fun resort_dep (const, thms) = (const, resort_thms algebra typ_funcgr thms)
128 handle Sorts.CLASS_ERROR e => raise INVALID ([const], Sorts.msg_class_error (Sign.pp thy) e
129 ^ ",\nfor constant " ^ CodeUnit.string_of_const thy const
130 ^ "\nin defining equations\n"
131 ^ (cat_lines o map string_of_thm) thms)
(* Resorts one constant's equations.  The flag tells whether its head type is
   unchanged (Sign.typ_equiv). *)
132 fun resort_rec tap_typ (const, []) = (true, (const, []))
133 | resort_rec tap_typ (const, thms as thm :: _) =
135 val (_, ty) = CodeUnit.head_func thm;
136 val thms' as thm' :: _ = resort_thms algebra tap_typ thms
137 val (_, ty') = CodeUnit.head_func thm';
138 in (Sign.typ_equiv thy (ty, ty'), (const, thms')) end;
(* One round.  Declared types are looked up among the group's own current
   equations, presumably from the head of the first equation. *)
139 fun resort_recs funcss =
142 AList.lookup (op =) funcss c
145 |> Option.map (snd o CodeUnit.head_func);
146 val (unchangeds, funcss') = split_list (map (resort_rec tap_typ) funcss);
(* true iff every member of the group is unchanged *)
147 val unchanged = fold (fn x => fn y => x andalso y) unchangeds true;
148 in (unchanged, funcss') end;
(* Iterates to a fixpoint.  NOTE(review): there is no explicit iteration
   bound; termination relies on the head types eventually stabilising. *)
149 fun resort_rec_until funcss =
151 val (unchanged, funcss') = resort_recs funcss;
152 in if unchanged then funcss' else resort_rec_until funcss' end;
153 in map resort_dep #> resort_rec_until end;
(* Instance constants for a list of (type constructor, class) pairs:
   - classes are grouped by type constructor;
   - each group is closed under Graph.all_succs in the theory's class graph,
     presumably adding superclasses;
   - every parameter of each resulting class is mapped to its instance
     constant at that type constructor (Class.inst_const). *)
155 fun instances_of thy algebra insts =
157 val thy_classes = (#classes o Sorts.rep_algebra o Sign.classes_of) thy;
(* Instance constants of all parameters of class at tyco.  The try
   presumably makes classes without parameter information contribute
   nothing -- confirm. *)
158 fun all_classops tyco class =
159 try (AxClass.params_of_class thy) class
162 |> map (fn (c, _) => Class.inst_const thy (c, tyco))
165 |> fold (fn (tyco, class) =>
166 Symtab.map_default (tyco, []) (insert (op =) class)) insts
167 |> (fn tab => Symtab.fold (fn (tyco, classes) => append (maps (all_classops tyco)
168 (Graph.all_succs thy_classes classes))) tab [])
(* Instance constants (class operations instantiated at type constructors, as
   computed by instances_of) required by the given constant occurrences
   consts : (name * type) list.  Each occurrence is checked against the type
   stored for its constant in funcgr (insts_of).
   Only a failed sort derivation, Sorts.CLASS_ERROR, counts as "requires no
   instances".  The exception must be written qualified with an argument
   pattern.  An unqualified `CLASS_ERROR` is not in scope here, so it would
   be a variable pattern that silently swallows every exception (interrupts,
   graph lookup failures, ...). *)
171 fun instances_of_consts thy algebra funcgr consts =
173 fun inst (c, ty) = insts_of thy algebra c
174 ((fst o Graph.get_node funcgr) c) ty handle Sorts.CLASS_ERROR _ => [];
177 |> fold (fold (insert (op =)) o inst) consts
178 |> instances_of thy algebra
(* Collects, in the auxiliary graph auxgr, the raw equations of const and,
   recursively, of the constants they use that are not yet in funcgr.
   Returns an optional constant name together with the updated auxgr.
   - A constant already in auxgr yields SOME const.
   - A datatype constructor gets a node without equations.
   - Any other constant gets its equations from Code.these_funcs, normalised
     by CodeUnit.norm_args and CodeUnit.norm_varnames.  An edge
     const -> c' is added for each used constant c' that comes back as SOME.
   NOTE(review): constants already in funcgr presumably yield NONE, so no
   edge is added for them -- the corresponding line is not visible. *)
181 fun ensure_const' thy algebra funcgr const auxgr =
182 if can (Graph.get_node funcgr) const
184 else if can (Graph.get_node auxgr) const
185 then (SOME const, auxgr)
186 else if is_some (Code.get_datatype_of_constr thy const) then
188 |> Graph.new_node (const, [])
191 val thms = Code.these_funcs thy const
192 |> CodeUnit.norm_args
193 |> CodeUnit.norm_varnames CodeName.purify_tvar CodeName.purify_var;
194 val rhs = consts_of (const, thms);
197 |> Graph.new_node (const, thms)
198 |> fold_map (ensure_const thy algebra funcgr) rhs
199 |-> (fn rhs' => fold (fn SOME const' => Graph.add_edge (const, const')
(* ensure_const' wrapped in a per-constant timing message when !timing is set *)
203 and ensure_const thy algebra funcgr const =
205 val timeap = if !timing
206 then Output.timeap_msg ("time for " ^ CodeUnit.string_of_const thy const)
208 in timeap (ensure_const' thy algebra funcgr const) end;
(* Adds one strongly connected group of constants to funcgr.  raw_funcss
   pairs each constant name with its raw equations.
   - The equations are resorted (resort_funcss).
   - Constants already present in funcgr are dropped.
   - Every remaining constant becomes a node carrying its type (typ_func)
     and its equations.
   - Edges are then added to the constants its equations use and to the
     class-instance constants they require.  Those instance constants are
     ensured first, via ensure_consts'. *)
210 fun merge_funcss thy algebra raw_funcss funcgr =
212 val funcss = raw_funcss
213 |> resort_funcss thy algebra funcgr
214 |> filter_out (can (Graph.get_node funcgr) o fst);
(* Node type for c:
   - no equations: the default type (Code.default_typ);
   - an instance of a class operation (Class.param_const gives
     SOME (c', tyco)): the head type, after checking that its type-argument
     sorts equal Sorts.mg_domain for tyco and the operation's class;
     otherwise INVALID is raised;
   - any other case: the head type of the first equation.
   NOTE(review): `val SOME class` raises Bind if AxClass.class_of_param
   returns NONE. *)
215 fun typ_func c [] = Code.default_typ thy c
216 | typ_func c (thms as thm :: _) = case Class.param_const thy c
217 of SOME (c', tyco) =>
219 val (_, ty) = CodeUnit.head_func thm;
220 val SOME class = AxClass.class_of_param thy c';
221 val sorts_decl = Sorts.mg_domain algebra tyco [class];
222 val tys = Sign.const_typargs thy (c, ty);
223 val sorts = map (snd o dest_TVar) tys;
224 in if sorts = sorts_decl then ty
225 else raise INVALID ([c], "Illegal instantation for class operation "
226 ^ CodeUnit.string_of_const thy c
227 ^ "\nin defining equations\n"
228 ^ (cat_lines o map string_of_thm) thms)
230 | NONE => (snd o CodeUnit.head_func) thm;
231 fun add_funcs (const, thms) =
232 Graph.new_node (const, (typ_func const thms, thms));
(* deps: constants used by the equations.  insts: instance constants
   required by their constant occurrences; these are ensured before the
   edges are added. *)
233 fun add_deps (funcs as (const, thms)) funcgr =
235 val deps = consts_of funcs;
236 val insts = instances_of_consts thy algebra funcgr
237 (fold_consts (insert (op =)) thms []);
240 |> ensure_consts' thy algebra insts
241 |> fold (curry Graph.add_edge const) deps
242 |> fold (curry Graph.add_edge const) insts
246 |> fold add_funcs funcss
247 |> fold add_deps funcss
(* Ensures the constants cs in funcgr in two steps:
   1. The raw equations of cs, and of all constants they reach, are gathered
      in a fresh auxiliary graph.
   2. They are merged into funcgr one strongly connected component at a
      time, in the order rev (Graph.strong_conn auxgr).
   An INVALID failure is re-raised with cs added to its list of constants. *)
249 and ensure_consts' thy algebra cs funcgr =
251 val auxgr = Graph.empty
252 |> fold (snd oo ensure_const thy algebra funcgr) cs;
255 |> fold (merge_funcss thy algebra)
256 (map (AList.make (Graph.get_node auxgr))
257 (rev (Graph.strong_conn auxgr)))
258 end handle INVALID (cs', msg)
259 => raise INVALID (fold (insert (op =)) cs' cs, msg);
263 (** retrieval interfaces **)
(* User-facing variant of ensure_consts'.  An INVALID failure is reported
   via error, naming the offending constants with CodeUnit.string_of_const. *)
265 fun ensure_consts thy algebra consts funcgr =
266 let fun fail (culprits, msg) = error (msg ^ ",\nwhile preprocessing equations for constant(s) "
267 ^ commas (map (CodeUnit.string_of_const thy) culprits))
268 in ensure_consts' thy algebra consts funcgr handle INVALID info => fail info end;
(* Ensures each constant of consts on its own, threading the graph through.
   If ensuring a constant raises INVALID, that constant is dropped and the
   graph stays as it was before the attempt.  Returns the accepted constants
   (in input order) and the final graph. *)
270 fun check_consts thy consts funcgr =
272 val algebra = Code.coregular_algebra thy;
273 fun try_const const funcgr =
274 (SOME const, ensure_consts' thy algebra [const] funcgr)
275 handle INVALID (cs', msg) => (NONE, funcgr);
276 val (consts', funcgr') = fold_map try_const consts funcgr;
277 in (map_filter I consts', funcgr') end;
(* Shared driver for evaluation.  Steps:
   1. Reject ct if it contains schematic term or type variables.
   2. Preprocess it (Code.preprocess_conv), giving thm1 with right-hand
      side ct'.
   3. Ensure the constants of ct' in funcgr.
   4. Resort the preprocessed term against the graph's declared types.
   5. Ensure the class-instance constants the resorted term needs.
   6. Call f drop thm5 funcgr'' ct''.
   Returns f's result paired with the extended graph. *)
279 fun raw_eval thy f ct funcgr =
281 val algebra = Code.coregular_algebra thy;
282 fun consts_of ct = fold_aterms (fn Const c_ty => cons c_ty | _ => I)
284 val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
285 val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
286 val thm1 = Code.preprocess_conv ct;
287 val ct' = Thm.rhs_of thm1;
288 val cs = map fst (consts_of ct');
289 val funcgr' = ensure_consts thy algebra cs funcgr;
(* make the type variables schematic so that resort_thms can re-sort them *)
290 val (_, thm2) = Thm.varifyT' [] thm1;
291 val thm3 = Thm.reflexive (Thm.rhs_of thm2);
292 val [thm4] = resort_thms algebra (try (fst o Graph.get_node funcgr')) [thm3];
293 val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
(* inst: turns the schematic type variables back into the free type variables
   of thm1, keeping the (possibly strengthened) schematic sorts.
   NOTE(review): pairs tvars with tfrees by position (map2).  This relies on
   matching order and count -- confirm. *)
296 val tvars = Term.add_tvars (Thm.prop_of thm) [];
297 val instmap = map2 (fn (v_i, sort) => fn (v, _) => pairself (Thm.ctyp_of thy)
298 (TVar (v_i, sort), TFree (v, sort))) tvars tfrees;
299 in Thm.instantiate (instmap, []) thm end;
300 val thm5 = inst thm2;
301 val thm6 = inst thm4;
302 val ct'' = Thm.rhs_of thm6;
303 val c_exprs = consts_of ct'';
(* drop: removes class constraints using the original free type variables;
   passed to f *)
304 val drop = drop_classes thy tfrees;
305 val instdefs = instances_of_consts thy algebra funcgr' c_exprs;
306 val funcgr'' = ensure_consts thy algebra instdefs funcgr';
307 in (f drop thm5 funcgr'' ct'' , funcgr'') end;
(* raw_eval specialised to conversions.  The result theorem chains three
   parts:
   - the preprocessing theorem thm1;
   - the conversion's result on the preprocessed term;
   - the postprocessing of that result's right-hand side.
   Class constraints are dropped from the latter two.  If a transitivity step
   fails (THM), an error listing the three theorems is raised.  The parameter
   drop_classes is the drop function supplied by raw_eval; it shadows the
   top-level drop_classes. *)
309 fun raw_eval_conv thy conv =
311 fun conv' drop_classes thm1 funcgr ct =
313 val thm2 = conv funcgr ct;
314 val thm3 = Code.postprocess_conv (Thm.rhs_of thm2);
315 val thm23 = drop_classes (Thm.transitive thm2 thm3);
317 Thm.transitive thm1 thm23 handle THM _ =>
318 error ("could not construct proof:\n"
319 ^ (cat_lines o map string_of_thm) [thm1, thm2, thm3])
321 in raw_eval thy conv' end;
(* raw_eval for arbitrary results.  f sees only the graph and the
   preprocessed term; the drop function and the preprocessing theorem are
   ignored. *)
323 fun raw_eval_term thy f =
325 fun f' _ _ funcgr ct = f funcgr ct;
326 in raw_eval thy f' end;
(* Per-theory cache of the equation graph, built with CodeDataFun.
   - It starts empty, and merging two caches yields an empty graph.
   - Purging with NONE clears everything.
   - Purging with SOME cs removes the nodes of the constants in cs that are
     present, together with all their predecessors (Graph.all_preds).  Those
     are the constants whose equations depend on them, since edges point
     from a user to the constants it uses. *)
330 structure Funcgr = CodeDataFun
333 val empty = Graph.empty;
334 fun merge _ _ = Graph.empty;
335 fun purge _ NONE _ = Graph.empty
336 | purge _ (SOME cs) funcgr =
337 Graph.del_nodes ((Graph.all_preds funcgr
338 o filter (can (Graph.get_node funcgr))) cs) funcgr;
(* Body of make: ensures the constants in the theory's cached graph, using the
   coregular algebra.  An INVALID failure becomes an error (ensure_consts). *)
342 Funcgr.change thy o ensure_consts thy (Code.coregular_algebra thy);
(* Ensures the given constants in the theory's cached graph.  Constants
   rejected by check_consts are skipped; yields the accepted constants and
   the graph. *)
344 fun make_consts thy consts =
345 Funcgr.change_yield thy (check_consts thy consts);
(* Runs the conversion f on ct against the theory's cached graph and returns
   only the resulting theorem; the graph part of the result is discarded. *)
347 fun eval_conv thy f ct =
348 #1 (Funcgr.change_yield thy (raw_eval_conv thy f ct));
(* Evaluates ct with f, which is handed the graph and the preprocessed term;
   only f's result is returned. *)
350 fun eval_term thy f ct =
351 #1 (Funcgr.change_yield thy (raw_eval_term thy f ct));