3 Authors: Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen
5 Normalization by evaluation, based on generic code generator.
11 Const of string * Univ list (*named (uninterpreted) constants*)
12 | Free of string * Univ list
13 | BVar of int * Univ list
14 | Abs of (int * (Univ list -> Univ)) * Univ list;
15 val free: string -> Univ list -> Univ (*free (uninterpreted) variables*)
16 val abs: int -> (Univ list -> Univ) -> Univ list -> Univ
17 (*abstractions as functions*)
18 val app: Univ -> Univ -> Univ (*explicit application*)
20 val univs_ref: (unit -> Univ list) ref
21 val lookup_fun: string -> Univ
23 val norm_conv: cterm -> thm
24 val norm_term: theory -> term -> term
27 val setup: theory -> theory
33 (* generic non-sense *)
35 val trace : bool ref = ref false; (*switch controlling whether tracing emits its diagnostic text*)
36 fun tracing f x = (if ! trace then Output.tracing (f x) else (); x); (*returns x unchanged; when trace is set, first prints f x*)
39 (** the semantical universe **)
42 Functions are given by their semantical function value. To avoid
43 trouble with the ML-type system, these functions have the most
44 generic type, that is "Univ list -> Univ". The calling convention is
45 that the arguments come as a list, the last argument first. In
46 other words, a function call that usually would look like
48 f x_1 x_2 ... x_n or f(x_1,x_2, ..., x_n)
50 would be in our convention called as
54 Moreover, to handle functions that are still waiting for some
55 arguments we have additionally a list of arguments collected so far
56 and the number of arguments we're still waiting for.
60 Const of string * Univ list (*named (uninterpreted) constants*)
61 | Free of string * Univ list (*free variables*)
62 | BVar of int * Univ list (*bound named variables*)
63 | Abs of (int * (Univ list -> Univ)) * Univ list
64 (*abstractions as closures*);
66 (* constructor functions *)
68 fun free name args = Free (name, args); (*curried form of the Free constructor*)
69 fun abs arity body args = Abs ((arity, body), args); (*closure still waiting for arity arguments, with args collected so far*)
70 fun app t x = (case t of Abs ((1, f), xs) => f (x :: xs) (*last missing argument arrived: run the body*)
71     | Abs ((n, f), xs) => Abs ((n - 1, f), x :: xs)
72     | Const (name, args) => Const (name, x :: args)
73     | Free (name, args) => Free (name, x :: args)
74     | BVar (idx, args) => BVar (idx, x :: args)); (*new argument is consed on, so the last argument comes first*)
76 (* global functions store *)
78 structure Nbe_Functions = CodeDataFun
80 type T = Univ Graph.T;
81 val empty = Graph.empty;
82 fun merge _ = Graph.merge (K true);
83 fun purge _ NONE _ = Graph.empty
84 | purge NONE _ _ = Graph.empty
85 | purge (SOME thy) (SOME cs) gr = Graph.empty
88 map_filter (CodeName.const_rev thy) (Graph.keys gr);
89 val dels = (Graph.all_preds gr
90 o map (CodeName.const thy)
91 o filter (member (op =) cs_exisiting)
93 in Graph.del_nodes dels gr end*);
96 fun defined gr name = can (Graph.get_node gr) name; (*true iff Graph.get_node succeeds for name in gr*)
98 (* sandbox communication *)
100 val univs_ref = ref (fn () => [] : Univ list); (*assigned by the generated code run in compile_univs, which resets it before each run and reads it afterwards*)
104 val gr_ref = ref NONE : Nbe_Functions.T option ref; (*function graph consulted by lookup_fun; compile_univs sets it to SOME tab before use_text and back to NONE after*)
108 fun lookup_fun s = (*fetch the compiled function s from the graph currently installed in gr_ref*)
109   let val gr = (case ! gr_ref of SOME gr => gr | NONE => error "compile_univs")
110   in Graph.get_node gr s end;
112 fun compile_univs tab ([], _) = []
113 | compile_univs tab (cs, raw_s) =
115 val _ = univs_ref := (fn () => []);
116 val s = "Nbe.univs_ref := " ^ raw_s;
117 val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
118 val _ = gr_ref := SOME tab;
119 val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
120 Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
122 val _ = gr_ref := NONE;
123 val univs = case !univs_ref () of [] => error "compile_univs" | univs => univs;
129 (** assembling and compiling ML code from terms **)
131 (* abstract ML syntax *)
134 fun e1 `$` e2 = String.concat ["(", e1, " ", e2, ")"]; (*ML source for e1 applied to e2, parenthesized*)
135 fun e `$$` es = e `$` space_implode " " es; (*ML source for e applied to all of es, space-separated*)
136 fun ml_abs v e = String.concat ["(fn ", v, " => ", e, ")"]; (*ML source for a lambda with pattern v and body e*)
138 fun ml_Val v s = String.concat ["val ", v, " = ", s]; (*ML source for a value binding of v to s*)
140 "(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")";
141 fun ml_Let ds e = String.concat ["let\n", space_implode "\n" ds, " in ", e, " end"]; (*ML let-expression: declarations ds one per line, then body e*)
143 fun ml_list es = enclose "[" "]" (commas es); (*ML list literal built from the element sources es*)
145 fun ml_delay e = ml_abs "()" e (*wraps the source e into a thunk taking unit*)
147 fun ml_fundefs ([(name, [([], e)])]) =
148 "val " ^ name ^ " = " ^ e ^ "\n"
149 | ml_fundefs (eqs :: eqss) =
151 fun fundef (name, eqs) =
153 fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
154 in space_implode "\n | " (map eqn eqs) end;
156 (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
157 |> space_implode "\n"
161 (* nbe specific syntax *)
165 val name_const = prefix ^ "Const"; (*names that generated code uses to reach Const, free, abs, app and lookup_fun; prefix is bound outside this view, presumably the Nbe. qualifier -- TODO confirm*)
166 val name_free = prefix ^ "free";
167 val name_abs = prefix ^ "abs";
168 val name_app = prefix ^ "app";
169 val name_lookup_fun = prefix ^ "lookup_fun";
172 fun nbe_const c ts = name_const `$` enclose "(" ")" (ML_Syntax.print_string c ^ ", " ^ ml_list ts); (*name_const applied to the pair of quoted c and list ts*)
173 fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | s => s) c; (*ML identifier for the compiled function of constant c, dots mapped to underscores*)
174 fun nbe_free name = name_free `$$` (ML_Syntax.print_string name :: [ml_list []]); (*name_free applied to quoted name and the empty list*)
175 fun nbe_bound v = String.concat ["v_", v]; (*ML identifier for bound variable v*)
178 Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);
180 fun nbe_abss n f = if n = 0 then f `$` ml_list [] (*arity zero: apply f to the empty argument list right away*)
181   else name_abs `$$` [string_of_int n, f, ml_list []]; (*otherwise wrap f as a closure waiting for n arguments*)
183 fun nbe_lookup c = let val key = ML_Syntax.print_string c in ml_Val (nbe_fun c) (name_lookup_fun `$` key) end; (*binds nbe_fun c to name_lookup_fun applied to the quoted name c*)
185 val nbe_value = "value"; (*ML name under which assemble_eval binds the compiled term*)
189 open BasicCodeThingol;
191 (* greetings to Tarski *)
193 fun assemble_iterm thy is_fun num_args =
197 val (t', ts) = CodeThingol.unfold_app t
198 in of_iapp t' (fold (cons o of_iterm) ts []) end
199 and of_iconst c ts = case num_args c
200 of SOME n => if n <= length ts
201 then let val (args2, args1) = chop (length ts - n) ts
202 in nbe_apps (nbe_fun c `$` ml_list args1) args2
203 end else nbe_const c ts
204 | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
206 and of_iapp (IConst (c, (dss, _))) ts = of_iconst c ts
207 | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
208 | of_iapp ((v, _) `|-> t) ts =
209 nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
210 | of_iapp (ICase (((t, _), cs), t0)) ts =
211 nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
212 @ [("_", of_iterm t0)])) ts
215 fun assemble_fun thy is_fun num_args (c, eqns) =
217 val assemble_arg = assemble_iterm thy (K false) (K NONE);
218 val assemble_rhs = assemble_iterm thy is_fun num_args;
219 fun assemble_eqn (args, rhs) =
220 ([ml_list (map assemble_arg (rev args))], assemble_rhs rhs);
221 val default_params = map nbe_bound
222 (Name.invent_list [] "a" ((the o num_args) c));
223 val default_eqn = ([ml_list default_params], nbe_const c default_params);
224 in map assemble_eqn eqns @ [default_eqn] end;
226 fun assemble_eqnss thy is_fun ([], deps) = ([], "")
227 | assemble_eqnss thy is_fun (eqnss, deps) =
229 val cs = map fst eqnss;
230 val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss;
231 val funs = fold (fold (CodeThingol.fold_constnames
232 (insert (op =))) o map snd o snd) eqnss [];
233 val bind_funs = map nbe_lookup (filter is_fun funs);
234 val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
235 (assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss);
236 val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args)
238 in (cs, ml_Let (bind_funs @ [bind_locals]) result) end;
240 fun assemble_eval thy is_fun (((vs, ty), t), deps) =
242 val funs = CodeThingol.fold_constnames (insert (op =)) t [];
243 val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
244 val bind_funs = map nbe_lookup (filter is_fun funs);
245 val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees)],
246 assemble_iterm thy is_fun (K NONE) t)])];
247 val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)]
249 in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end;
251 fun eqns_of_stmt ((_, CodeThingol.Fun (_, [])), _) =
253 | eqns_of_stmt ((name, CodeThingol.Fun (_, eqns)), deps) =
254 SOME ((name, map fst eqns), deps)
255 | eqns_of_stmt ((_, CodeThingol.Datatypecons _), _) =
257 | eqns_of_stmt ((_, CodeThingol.Datatype _), _) =
259 | eqns_of_stmt ((_, CodeThingol.Class _), _) =
261 | eqns_of_stmt ((_, CodeThingol.Classrel _), _) =
263 | eqns_of_stmt ((_, CodeThingol.Classparam _), _) =
265 | eqns_of_stmt ((_, CodeThingol.Classinst _), _) =
268 fun compile_stmts thy is_fun =
269 map_filter eqns_of_stmt
271 #> assemble_eqnss thy is_fun
272 #> compile_univs (Nbe_Functions.get thy);
274 fun eval_term thy is_fun =
275 assemble_eval thy is_fun
276 #> compile_univs (Nbe_Functions.get thy)
281 (** compilation and evaluation **)
283 (* ensure global functions *)
285 fun ensure_funs thy code =
287 fun add_dep (name, dep) gr =
288 if can (Graph.get_node gr) name andalso can (Graph.get_node gr) dep
289 then Graph.add_edge (name, dep) gr else gr;
290 fun compile' stmts gr =
292 val compiled = compile_stmts thy (defined gr) stmts;
293 val names = map (fst o fst) stmts;
294 val deps = maps snd stmts;
296 Nbe_Functions.change thy (fold Graph.new_node compiled
297 #> fold (fn name => fold (curry add_dep name) deps) names)
299 val nbe_gr = Nbe_Functions.get thy;
300 val stmtss = rev (Graph.strong_conn code)
301 |> (map o map_filter) (fn name => if defined nbe_gr name
303 else SOME ((name, Graph.get_node code name), Graph.imm_succs code name))
305 in fold compile' stmtss nbe_gr end;
309 fun term_of_univ thy t =
311 fun of_apps bounds (t, ts) =
312 fold_map (of_univ bounds) ts
313 #>> (fn ts' => list_comb (t, rev ts'))
314 and of_univ bounds (Const (name, ts)) typidx =
316 val SOME c = CodeName.const_rev thy name;
317 val T = Code.default_typ thy c;
318 val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
319 val typidx' = typidx + maxidx_of_typ T' + 1;
320 in of_apps bounds (Term.Const (c, T'), ts) typidx' end
321 | of_univ bounds (Free (name, ts)) typidx =
322 of_apps bounds (Term.Free (name, dummyT), ts) typidx
323 | of_univ bounds (BVar (name, ts)) typidx =
324 of_apps bounds (Bound (bounds - name - 1), ts) typidx
325 | of_univ bounds (t as Abs _) typidx =
327 |> of_univ (bounds + 1) (app t (BVar (bounds, [])))
328 |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
329 in of_univ 0 t 0 |> fst end;
331 (* evaluation with type reconstruction *)
333 fun eval thy code t vs_ty_t deps =
336 fun subst_Frees [] = I
338 Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
341 subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
342 #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
344 singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t);
345 fun check_tvars t = if null (Term.term_tvars t) then t else
346 error ("Illegal schematic type variables in normalized term: "
347 ^ setmp show_types true (Sign.string_of_term thy) t);
350 |> eval_term thy (defined (ensure_funs thy code))
352 |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
354 |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
356 |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
358 |> tracing (fn _ => "---\n")
361 (* evaluation oracle *)
363 exception Norm of CodeThingol.code * term (*oracle payload built by norm_invoke and consumed by norm_oracle: code, t, vs_ty_t, deps*)
364 * (CodeThingol.typscheme * CodeThingol.iterm) * string list;
366 fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) = (*states t equal to its evaluated normal form*)
367   let val normal = eval thy code t vs_ty_t deps in Logic.mk_equals (t, normal) end;
369 fun norm_invoke thy code t vs_ty_t deps = (*packs the arguments into Norm and invokes the registered oracle*)
370   let val payload = Norm (code, t, vs_ty_t, deps) in Thm.invoke_oracle_i thy "HOL.norm" (thy, payload) end;
371 (*FIXME get rid of hardwired theory name*)
375 val thy = Thm.theory_of_cterm ct;
376 fun conv code vs_ty_t deps ct =
378 val t = Thm.term_of ct;
379 in norm_invoke thy code t vs_ty_t deps end;
380 in CodePackage.eval_conv thy conv ct end;
384 fun invoke code vs_ty_t deps t =
385 eval thy code t vs_ty_t deps;
386 in CodePackage.eval_term thy invoke #> Code.postprocess_term thy end;
388 (* evaluation command *)
390 fun norm_print_term ctxt modes t =
392 val thy = ProofContext.theory_of ctxt;
393 val t' = norm_term thy t;
394 val ty' = Term.type_of t';
395 val p = PrintMode.with_modes modes (fn () =>
396 Pretty.block [Pretty.quote (Syntax.pretty_term ctxt t'), Pretty.fbrk,
397 Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt ty')]) ();
398 in Pretty.writeln p end;
403 fun norm_print_term_cmd (modes, s) state = (*reads s in the toplevel context of state and prints its normal form*)
404   let val ctxt = Toplevel.context_of state; val t = Syntax.read_term ctxt s
405   in norm_print_term ctxt modes t end;
407 val setup = Theory.add_oracle ("norm", norm_oracle) (*registers norm_oracle under the name norm; norm_invoke calls it as HOL.norm*)
409 local structure P = OuterParse and K = OuterKeyword in
411 val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
414 OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
415 (opt_modes -- P.typ >> (Toplevel.keep o norm_print_term_cmd));