1 (* Title: HOL/Matrix/Compute_Oracle/am_sml.ML |
|
2 Author: Steven Obua |
|
3 |
|
4 TODO: "parameterless rewrite cannot be used in pattern": In a lot of |
|
5 cases it CAN be used, and these cases should be handled |
|
6 properly; right now, all cases raise an exception. |
|
7 *) |
|
8 |
|
(* Interface of the SML-compiling abstract machine.  Besides the generic
   ABSTRACT_MACHINE operations it exports the hooks that the generated
   SML code calls back into, plus a debugging switch. *)
signature AM_SML =
sig
  include ABSTRACT_MACHINE

  (* Debugging: when SOME path, the generated SML source is also written to path. *)
  val dump_output : string option Unsynchronized.ref

  (* Hooks referenced by the generated code. *)
  val save_result : string * term -> unit
  val set_compiled_rewriter : (term -> term) -> unit
  val list_nth : 'a list * int -> 'a
end
|
17 |
|
structure AM_SML : AM_SML = struct

open AbstractMachine;

(* When SOME path, compile also writes the generated SML source to path. *)
val dump_output = Unsynchronized.ref (NONE: string option)

(* A compiled program: the table of inlined parameterless rewrites (constant
   -> replacement term) together with the compiled rewriting function. *)
type program = term Inttab.table * (term -> term)

(* Written by the generated code's "export" function through save_result. *)
val saved_result = Unsynchronized.ref (NONE:(string*term)option)

fun save_result r = (saved_result := SOME r)

val list_nth = List.nth

(* The generated structure hands its rewriter back through this slot
   (via set_compiled_rewriter); compile reads it after evaluating the source. *)
val compiled_rewriter = Unsynchronized.ref (NONE:(term -> term)Option.option)

fun set_compiled_rewriter r = (compiled_rewriter := SOME r)

(* Number of PVar leaves in a pattern. *)
fun count_patternvars PVar = 1
  | count_patternvars (PConst (_, ps)) =
      List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps

(* Record arity a for constant code, keeping the maximum arity seen so far. *)
fun update_arity arity code a =
  (case Inttab.lookup arity code of
    NONE => Inttab.update_new (code, a) arity
  | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity)

(* We have to find out the maximal arity of each constant *)
fun collect_pattern_arity PVar arity = arity
  | collect_pattern_arity (PConst (c, args)) arity =
      fold collect_pattern_arity args (update_arity arity c (length args))

(* We also need to find out the maximal toplevel arity of each function constant *)
fun collect_pattern_toplevel_arity PVar arity = raise Compile "internal error: collect_pattern_toplevel_arity"
  | collect_pattern_toplevel_arity (PConst (c, args)) arity = update_arity arity c (length args)

local
  (* applevel = number of arguments the current subterm is applied to.
     Computed subterms contribute no arities; they are rejected later by
     print_term if they occur in a rule. *)
  fun collect applevel (Var _) arity = arity
    | collect applevel (Const c) arity = update_arity arity c applevel
    | collect applevel (Abs m) arity = collect 0 m arity
    | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity)
    | collect applevel (Computed _) arity = arity
in
  fun collect_term_arity t arity = collect 0 t arity
end

fun collect_guard_arity (Guard (a,b)) arity = collect_term_arity b (collect_term_arity a arity)

(* List of n copies of x; negative n is an internal error. *)
fun rep n x = if n < 0 then raise Compile "internal error: rep" else if n = 0 then [] else x::(rep (n-1) x)

(* Beta-reduction on de Bruijn terms.  subst x m t replaces Var x in m by t,
   lift/unlift shift free variables at or above level by +1/-1.
   Computed subterms are left untouched (treated as opaque); this assumes they
   contain no free variables, which the generated convert_computed also
   requires ("no bound variables in convert_computed allowed").  Previously
   subst/lift/unlift had no Computed case and raised Match. *)
fun beta (Const c) = Const c
  | beta (Var i) = Var i
  | beta (App (Abs m, b)) = beta (unlift 0 (subst 0 m (lift 0 b)))
  | beta (App (a, b)) =
      (case beta a of
        Abs m => beta (App (Abs m, b))
      | a => App (a, beta b))
  | beta (Abs m) = Abs (beta m)
  | beta (Computed t) = Computed t
and subst x (Const c) t = Const c
  | subst x (Var i) t = if i = x then t else Var i
  | subst x (App (a,b)) t = App (subst x a t, subst x b t)
  | subst x (Abs m) t = Abs (subst (x+1) m (lift 0 t))
  | subst x (Computed u) t = Computed u
and lift level (Const c) = Const c
  | lift level (App (a,b)) = App (lift level a, lift level b)
  | lift level (Var i) = if i < level then Var i else Var (i+1)
  | lift level (Abs m) = Abs (lift (level + 1) m)
  | lift level (Computed u) = Computed u
and unlift level (Const c) = Const c
  | unlift level (App (a, b)) = App (unlift level a, unlift level b)
  | unlift level (Abs m) = Abs (unlift (level+1) m)
  | unlift level (Var i) = if i < level then Var i else Var (i-1)
  | unlift level (Computed u) = Computed u

(* Shift every variable at or above level by n. *)
fun nlift level n (Var m) = if m < level then Var m else Var (m+n)
  | nlift level n (Const c) = Const c
  | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b)
  | nlift level n (Abs b) = Abs (nlift (level+1) n b)
  | nlift level n (Computed u) = Computed u

(* Replace every occurrence of constant c by the term t (Computed is opaque). *)
fun subst_const (c, t) (Const c') = if c = c' then t else Const c'
  | subst_const _ (Var i) = Var i
  | subst_const ct (App (a, b)) = App (subst_const ct a, subst_const ct b)
  | subst_const ct (Abs m) = Abs (subst_const ct m)
  | subst_const _ (Computed u) = Computed u

(* Remove all rules that are just parameterless rewrites. This is necessary because SML does not allow functions with no parameters.
   Returns (table of inlined constants -> replacement terms, remaining rules). *)
fun inline_rules rules =
  let
    fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b
      | term_contains_const c (Abs m) = term_contains_const c m
      | term_contains_const c (Var _) = false
      | term_contains_const c (Const c') = (c = c')
      | term_contains_const c (Computed _) = false
    fun find_rewrite [] = NONE
      | find_rewrite ((prems, PConst (c, []), r) :: _) =
          if check_freevars 0 r then
            if term_contains_const c r then
              raise Compile "parameterless rewrite is caught in cycle"
            else if not (null prems) then
              raise Compile "parameterless rewrite may not be guarded"
            else
              SOME (c, r)
          else raise Compile "unbound variable on right hand side or guards of rule"
      | find_rewrite (_ :: rules) = find_rewrite rules
    fun remove_rewrite _ [] = []
      | remove_rewrite (cr as (c, r)) ((rule as (prems', PConst (c', args), r')) :: rules) =
          if c = c' then
            if null args andalso r = r' andalso null prems' then remove_rewrite cr rules
            else raise Compile "incompatible parameterless rewrites found"
          else
            rule :: remove_rewrite cr rules
      | remove_rewrite cr (r :: rs) = r :: remove_rewrite cr rs
    fun pattern_contains_const c (PConst (c', args)) = c = c' orelse exists (pattern_contains_const c) args
      | pattern_contains_const c (PVar) = false
    fun inline_rewrite (ct as (c, _)) (prems, p, r) =
      if pattern_contains_const c p then
        raise Compile "parameterless rewrite cannot be used in pattern"
      else (map (fn (Guard (a, b)) => Guard (subst_const ct a, subst_const ct b)) prems, p, subst_const ct r)
    (* Repeatedly pick one parameterless rewrite, drop it, substitute it into
       the remaining rules and into the previously inlined replacements. *)
    fun inline inlined rules =
      case find_rewrite rules of
        NONE => (Inttab.make inlined, rules)
      | SOME ct =>
          let
            val rules = map (inline_rewrite ct) (remove_rewrite ct rules)
            val inlined = ct :: (map o apsnd) (subst_const ct) inlined
          in inline inlined rules end
  in
    inline [] rules
  end


(*
   Calculate the arity, the toplevel_arity, and adjust rules so that all toplevel pattern constants have maximal arity.
   Also beta reduce the adjusted right hand side of a rule.
*)
fun adjust_rules rules =
  let
    val arity = fold (fn (prems, p, t) => fn arity => fold collect_guard_arity prems (collect_term_arity t (collect_pattern_arity p arity))) rules Inttab.empty
    val toplevel_arity = fold (fn (_, p, _) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty
    fun arity_of c = the (Inttab.lookup arity c)
    (* Nested pattern constants are not padded, so they must already be used
       at their maximal arity. *)
    fun test_pattern PVar = ()
      | test_pattern (PConst (c, args)) =
          if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity")
          else List.app test_pattern args
    fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable")
      | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters")
      | adjust_rule (rule as (prems, p as PConst (c, args),t)) =
          let
            val patternvars_counted = count_patternvars p
            fun check_fv t = check_freevars patternvars_counted t
            val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else ()
            val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else ()
            val _ = List.app test_pattern args
            val len = length args
            val arity = arity_of c
            val lift = nlift 0
            (* Apply t to the n fresh trailing pattern variables Var (n-1) ... Var 0. *)
            fun addapps_tm n t = if n=0 then t else addapps_tm (n-1) (App (t, Var (n-1)))
            fun adjust_term n t = addapps_tm n (lift n t)
            fun adjust_guard n (Guard (a,b)) = Guard (lift n a, lift n b)
          in
            if len = arity then
              rule
            else if arity >= len then
              (map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t)
            else (raise Compile "internal error in adjust_rule")
          end
    fun beta_rule (prems, p, t) = ((prems, p, beta t) handle Match => raise Compile "beta_rule")
  in
    (arity, toplevel_arity, map (beta_rule o adjust_rule) rules)
  end

(* Render a term as SML source.  module is an optional structure prefix,
   pattern_var_count the number of pattern variables of the enclosing rule,
   and pattern_lazy_var_count how many of the trailing ones are thunks
   (printed as "(xN ())").  Bound variables are printed as bK. *)
fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count =
  let
    fun str x = string_of_int x
    fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s
    val module_prefix = (case module of NONE => "" | SOME s => s^".")
    fun print_apps d f [] = f
      | print_apps d f (a::args) = print_apps d (module_prefix^"app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args
    and print_call d (App (a, b)) args = print_call d a (b::args)
      | print_call d (Const c) args =
          (case arity_of c of
            NONE => print_apps d (module_prefix^"Const "^(str c)) args
            (* Beta reduction after arity collection can apply a zero-arity
               constant to arguments; keep them instead of dropping them. *)
          | SOME 0 => print_apps d (module_prefix^"C"^(str c)) args
          | SOME a =>
              let
                val len = length args
              in
                if a <= len then
                  let
                    val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a)
                    val _ = if strict_a > a then raise Compile "strict" else ()
                    val s = module_prefix^"c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, strict_a))))
                    val s = s^(implode (map (fn t => " (fn () => "^print_term d t^")") (List.drop (List.take (args, a), strict_a))))
                  in
                    print_apps d s (List.drop (args, a))
                  end
                else
                  (* Under-applied: eta-expand up to the full arity. *)
                  let
                    fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n - 1)))
                    fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t)
                    fun append_args [] t = t
                      | append_args (c::cs) t = append_args cs (App (t, c))
                  in
                    print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c)))))
                  end
              end)
        (* Without this case, Computed would bounce between print_call and
           print_term forever. *)
      | print_call d (Computed _) _ = raise Compile "cannot compile Computed term in rule"
      | print_call d t args = print_apps d (print_term d t) args
    and print_term d (Var x) =
          if x < d then
            "b"^(str (d-x-1))
          else
            let
              val n = pattern_var_count - (x-d) - 1
              val x = "x"^(str n)
            in
              if n < pattern_var_count - pattern_lazy_var_count then
                x
              else
                "("^x^" ())"
            end
      | print_term d (Abs c) = module_prefix^"Abs (fn b"^(str d)^" => "^(print_term (d + 1) c)^")"
      | print_term d t = print_call d t []
  in
    print_term 0
  end

(* [0, 1, ..., n-1]; raises Size for negative n. *)
fun section n = List.tabulate (n, fn i => i)

(* Print one rule as a clause of function c<c>[_gnum].  Returns the group
   number for the next rule (gnum+1 if this rule is guarded) and the clause. *)
fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) =
  let
    fun str x = string_of_int x
    fun print_pattern top n PVar = (n+1, "x"^(str n))
      | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else ""))
      | print_pattern top n (PConst (c, args)) =
          let
            val f = (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "")
            val (n, s) = print_pattern_list 0 top (n, f) args
          in
            (n, s)
          end
    and print_pattern_list' counter top (n,p) [] = if top then (n,p) else (n,p^")")
      | print_pattern_list' counter top (n, p) (t::ts) =
          let
            val (n, t) = print_pattern false n t
          in
            print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^", "^t) ts
          end
    and print_pattern_list counter top (n, p) (t::ts) =
          let
            val (n, t) = print_pattern false n t
          in
            print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^" ("^t) ts
          end
    val c = (case p of PConst (c, _) => c | _ => raise Match)
    val (n, pattern) = print_pattern true 0 p
    val lazy_vars = the (arity_of c) - the (toplevel_arity_of c)
    fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm
    fun print_guard (Guard (a,b)) = "term_eq ("^(print_tm a)^") ("^(print_tm b)^")"
    val else_branch = "c"^(str c)^"_"^(str (gnum+1))^(implode (map (fn i => " a"^(str i)) (section (the (arity_of c)))))
    fun print_guards t [] = print_tm t
      | print_guards t (g::gs) = "if ("^(print_guard g)^")"^(implode (map (fn g => " andalso ("^(print_guard g)^")") gs))^" then ("^(print_tm t)^") else "^else_branch
  in
    (if null guards then gnum else gnum+1, pattern^" = "^(print_guards t guards))
  end

(* Map each head constant to its rules, preserving the original rule order. *)
fun group_rules rules =
  let
    fun add_rule (r as (_, PConst (c,_), _)) groups =
          let
            val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs)
          in
            Inttab.update (c, r::rs) groups
          end
      | add_rule _ _ = raise Compile "internal error group_rules"
  in
    fold_rev add_rule rules Inttab.empty
  end

(* Generate the SML source of a structure called name that implements the
   rules.  code is embedded as the tag passed to AM_SML.save_result by the
   generated export function.  Returns (inlined rewrites, source text). *)
fun sml_prog name code rules =
  let
    (* Chunks are collected in reverse and concatenated once at the end,
       instead of re-copying the whole buffer on every write. *)
    val buffer = Unsynchronized.ref ([] : string list)
    fun write s = (buffer := s :: (!buffer))
    fun writeln s = (write s; write "\n")
    fun writelist [] = ()
      | writelist (s::ss) = (writeln s; writelist ss)
    fun str i = string_of_int i
    val (inlinetab, rules) = inline_rules rules
    val (arity, toplevel_arity, rules) = adjust_rules rules
    val rules = group_rules rules
    val constants = Inttab.keys arity
    fun arity_of c = Inttab.lookup arity c
    fun toplevel_arity_of c = Inttab.lookup toplevel_arity c
    fun rep_str s n = implode (rep n s)
    fun indexed s n = s^(str n)
    fun string_of_tuple [] = ""
      | string_of_tuple (x::xs) = "("^x^(implode (map (fn s => ", "^s) xs))^")"
    fun string_of_args [] = ""
      | string_of_args (x::xs) = x^(implode (map (fn s => " "^s) xs))
    (* Fallback clause of c<c>[_gnum]: build the constructor value, or fail if
       some arguments are lazy. *)
    fun default_case gnum c =
      let
        val leftargs = implode (map (indexed " x") (section (the (arity_of c))))
        val rightargs = section (the (arity_of c))
        val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa)
        val xs = map (fn n => if n < strict_args then "x"^(str n) else "x"^(str n)^"()") rightargs
        val right = (indexed "C" c)^" "^(string_of_tuple xs)
        val message = "(\"unresolved lazy call: " ^ string_of_int c ^ "\")"
        val right = if strict_args < the (arity_of c) then "raise AM_SML.Run "^message else right
      in
        (indexed "c" c)^(if gnum > 0 then "_"^(str gnum) else "")^leftargs^" = "^right
      end

    (* eval clauses for constant c applied to n = arity, ..., 0 arguments
       (most arguments first, so the more specific patterns match first). *)
    fun eval_rules c =
      let
        val arity = the (arity_of c)
        val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa)
        fun eval_rule n =
          let
            val sc = string_of_int c
            val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section n) ("AbstractMachine.Const "^sc)
            fun arg i =
              let
                val x = indexed "x" i
                val x = if i < n then "(eval bounds "^x^")" else x
                val x = if i < strict_arity then x else "(fn () => "^x^")"
              in
                x
              end
            val right = "c"^sc^" "^(string_of_args (map arg (section arity)))
            val right = fold_rev (fn i => fn s => "Abs (fn "^(indexed "x" i)^" => "^s^")") (List.drop (section arity, n)) right
            val right = if arity > 0 then right else "C"^sc
          in
            " | eval bounds ("^left^") = "^right
          end
      in
        map eval_rule (rev (section (arity + 1)))
      end

    fun convert_computed_rules (c: int) : string list =
      let
        val arity = the (arity_of c)
        fun eval_rule () =
          let
            val sc = string_of_int c
            val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section arity) ("AbstractMachine.Const "^sc)
            fun arg i = "(convert_computed "^(indexed "x" i)^")"
            val right = "C"^sc^" "^(string_of_tuple (map arg (section arity)))
            val right = if arity > 0 then right else "C"^sc
          in
            " | convert_computed ("^left^") = "^right
          end
      in
        [eval_rule ()]
      end

    fun mk_constr_type_args n = if n > 0 then " of Term "^(rep_str " * Term" (n-1)) else ""
    val _ = writelist [
      "structure "^name^" = struct",
      "",
      "datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)",
      " "^(implode (map (fn c => " | C"^(str c)^(mk_constr_type_args (the (arity_of c)))) constants)),
      ""]
    fun make_constr c argprefix = "(C"^(str c)^" "^(string_of_tuple (map (fn i => argprefix^(str i)) (section (the (arity_of c)))))^")"
    fun make_term_eq c = " | term_eq "^(make_constr c "a")^" "^(make_constr c "b")^" = "^
      (case the (arity_of c) of
        0 => "true"
      | n =>
          let
            val eqs = map (fn i => "term_eq a"^(str i)^" b"^(str i)) (section n)
            val (eq, eqs) = (List.hd eqs, map (fn s => " andalso "^s) (List.tl eqs))
          in
            eq^(implode eqs)
          end)
    val _ = writelist [
      "fun term_eq (Const c1) (Const c2) = (c1 = c2)",
      " | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"]
    val _ = writelist (map make_term_eq constants)
    val _ = writelist [
      " | term_eq _ _ = false",
      ""
      ]
    val _ = writelist [
      "fun app (Abs a) b = a b",
      " | app a b = App (a, b)",
      ""]
    fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else [])
    (* Each clause group continues the "fun app" declaration with "and". *)
    fun writefundecl [] = ()
      | writefundecl (x::xs) = writelist ((("and "^x)::(map (fn s => " | "^s) xs)))
    (* Split the rules of c into groups c<c>, c<c>_1, ...; a new group starts
       after every guarded rule, and each group ends with its fallback. *)
    fun list_group c = (case Inttab.lookup rules c of
        NONE => [defcase 0 c]
      | SOME rs =>
          let
            val rs =
              fold
                (fn r =>
                  fn rs =>
                    let
                      val (gnum, l, rs) =
                        (case rs of
                          [] => (0, [], [])
                        | (gnum, l)::rs => (gnum, l, rs))
                      val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r
                    in
                      if gnum' = gnum then
                        (gnum, r::l)::rs
                      else
                        let
                          val args = implode (map (fn i => " a"^(str i)) (section (the (arity_of c))))
                          fun gnumc g = if g > 0 then "c"^(str c)^"_"^(str g)^args else "c"^(str c)^args
                          val s = gnumc (gnum) ^ " = " ^ gnumc (gnum')
                        in
                          (gnum', [])::(gnum, s::r::l)::rs
                        end
                    end)
                rs []
            val rs = (case rs of [] => [(0,defcase 0 c)] | (gnum,l)::rs => (gnum, (defcase gnum c)@l)::rs)
          in
            rev (map (fn z => rev (snd z)) rs)
          end)
    val _ = List.app (fn z => (List.app writefundecl z; writeln "")) (map list_group constants)
    val _ = writelist [
      "fun convert (Const i) = AM_SML.Const i",
      " | convert (App (a, b)) = AM_SML.App (convert a, convert b)",
      " | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""]
    fun make_convert c =
      let
        val args = map (indexed "a") (section (the (arity_of c)))
        val leftargs =
          case args of
            [] => ""
          | (x::xs) => "("^x^(implode (map (fn s => ", "^s) xs))^")"
        val args = map (indexed "convert a") (section (the (arity_of c)))
        val right = fold (fn x => fn s => "AM_SML.App ("^s^", "^x^")") args ("AM_SML.Const "^(str c))
      in
        " | convert (C"^(str c)^" "^leftargs^") = "^right
      end
    val _ = writelist (map make_convert constants)
    val _ = writelist [
      "",
      "fun convert_computed (AbstractMachine.Abs b) = raise AM_SML.Run \"no abstraction in convert_computed allowed\"",
      " | convert_computed (AbstractMachine.Var i) = raise AM_SML.Run \"no bound variables in convert_computed allowed\""]
    val _ = List.app (writelist o convert_computed_rules) constants
    val _ = writelist [
      " | convert_computed (AbstractMachine.Const c) = Const c",
      " | convert_computed (AbstractMachine.App (a, b)) = App (convert_computed a, convert_computed b)",
      " | convert_computed (AbstractMachine.Computed a) = raise AM_SML.Run \"no nesting in convert_computed allowed\""]
    val _ = writelist [
      "",
      "fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)",
      " | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth (bounds, i)"]
    val _ = List.app (writelist o eval_rules) constants
    val _ = writelist [
      " | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)",
      " | eval bounds (AbstractMachine.Const c) = Const c",
      " | eval bounds (AbstractMachine.Computed t) = convert_computed t"]
    val _ = writelist [
      "",
      "fun export term = AM_SML.save_result (\""^code^"\", convert term)",
      "",
      "val _ = AM_SML.set_compiled_rewriter (fn t => (convert (eval [] t)))",
      "",
      "end"]
  in
    (inlinetab, String.concat (rev (!buffer)))
  end

val guid_counter = Unsynchronized.ref 0

(* Fresh identifier fragment built from the current time and a counter. *)
fun get_guid () =
  let
    val c = !guid_counter
    val _ = guid_counter := !guid_counter + 1
  in
    (* The separator keeps distinct (time, counter) pairs from producing the
       same string when concatenated. *)
    string_of_int (Time.toMicroseconds (Time.now ())) ^ "_" ^ string_of_int c
  end


fun writeTextFile name s = File.write (Path.explode name) s

fun use_source src = use_text ML_Env.local_context (1, "") false src

(* Generate SML for the rules, evaluate it, and pick up the rewriter that the
   generated structure registers through set_compiled_rewriter. *)
fun compile rules =
  let
    val guid = get_guid ()
    val code = Real.toString (random ())
    val name = "AMSML_"^guid
    val (inlinetab, source) = sml_prog name code rules
    val _ = case !dump_output of NONE => () | SOME p => writeTextFile p source
    val _ = compiled_rewriter := NONE
    val _ = use_source source
  in
    case !compiled_rewriter of
      NONE => raise Compile "broken link to compiled function"
    | SOME compiled_fun => (inlinetab, compiled_fun)
  end

(* Evaluate a closed term: expand inlined parameterless constants,
   beta-reduce, then run the compiled rewriter. *)
fun run (inlinetab, compiled_fun) t =
  let
    val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
    fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t)
      | inline (Var i) = Var i
      | inline (App (a, b)) = App (inline a, inline b)
      | inline (Abs m) = Abs (inline m)
      | inline (Computed t) = Computed t
  in
    compiled_fun (beta (inline t))
  end

end
|