1.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
1.2 +++ b/src/HOL/Matrix_LP/Compute_Oracle/am_sml.ML Sat Mar 17 12:52:40 2012 +0100
1.3 @@ -0,0 +1,517 @@
1.4 +(* Title: HOL/Matrix/Compute_Oracle/am_sml.ML
1.5 + Author: Steven Obua
1.6 +
1.7 +TODO: "parameterless rewrite cannot be used in pattern": In a lot of
1.8 +cases it CAN be used, and these cases should be handled
1.9 +properly; right now, all cases raise an exception.
1.10 +*)
1.11 +
1.12 +signature AM_SML =
1.13 +sig
1.14 + include ABSTRACT_MACHINE
1.15 + val save_result : (string * term) -> unit
1.16 + val set_compiled_rewriter : (term -> term) -> unit
1.17 + val list_nth : 'a list * int -> 'a
1.18 + val dump_output : (string option) Unsynchronized.ref
1.19 +end
1.20 +
1.21 +structure AM_SML : AM_SML = struct
1.22 +
1.23 +open AbstractMachine;
1.24 +
1.25 +val dump_output = Unsynchronized.ref (NONE: string option)
1.26 +
1.27 +type program = term Inttab.table * (term -> term)
1.28 +
1.29 +val saved_result = Unsynchronized.ref (NONE:(string*term)option)
1.30 +
1.31 +fun save_result r = (saved_result := SOME r)
1.32 +
1.33 +val list_nth = List.nth
1.34 +
1.35 +val compiled_rewriter = Unsynchronized.ref (NONE:(term -> term)Option.option)
1.36 +
1.37 +fun set_compiled_rewriter r = (compiled_rewriter := SOME r)
1.38 +
1.39 +fun count_patternvars PVar = 1
1.40 + | count_patternvars (PConst (_, ps)) =
1.41 + List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps
1.42 +
1.43 +fun update_arity arity code a =
1.44 + (case Inttab.lookup arity code of
1.45 + NONE => Inttab.update_new (code, a) arity
1.46 + | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity)
1.47 +
1.48 +(* We have to find out the maximal arity of each constant *)
1.49 +fun collect_pattern_arity PVar arity = arity
1.50 + | collect_pattern_arity (PConst (c, args)) arity = fold collect_pattern_arity args (update_arity arity c (length args))
1.51 +
1.52 +(* We also need to find out the maximal toplevel arity of each function constant *)
1.53 +fun collect_pattern_toplevel_arity PVar arity = raise Compile "internal error: collect_pattern_toplevel_arity"
1.54 + | collect_pattern_toplevel_arity (PConst (c, args)) arity = update_arity arity c (length args)
1.55 +
1.56 +local
1.57 +fun collect applevel (Var _) arity = arity
1.58 + | collect applevel (Const c) arity = update_arity arity c applevel
1.59 + | collect applevel (Abs m) arity = collect 0 m arity
1.60 + | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity)
1.61 +in
1.62 +fun collect_term_arity t arity = collect 0 t arity
1.63 +end
1.64 +
1.65 +fun collect_guard_arity (Guard (a,b)) arity = collect_term_arity b (collect_term_arity a arity)
1.66 +
1.67 +
1.68 +fun rep n x = if n < 0 then raise Compile "internal error: rep" else if n = 0 then [] else x::(rep (n-1) x)
1.69 +
1.70 +fun beta (Const c) = Const c
1.71 + | beta (Var i) = Var i
1.72 + | beta (App (Abs m, b)) = beta (unlift 0 (subst 0 m (lift 0 b)))
1.73 + | beta (App (a, b)) =
1.74 + (case beta a of
1.75 + Abs m => beta (App (Abs m, b))
1.76 + | a => App (a, beta b))
1.77 + | beta (Abs m) = Abs (beta m)
1.78 + | beta (Computed t) = Computed t
1.79 +and subst x (Const c) t = Const c
1.80 + | subst x (Var i) t = if i = x then t else Var i
1.81 + | subst x (App (a,b)) t = App (subst x a t, subst x b t)
1.82 + | subst x (Abs m) t = Abs (subst (x+1) m (lift 0 t))
1.83 +and lift level (Const c) = Const c
1.84 + | lift level (App (a,b)) = App (lift level a, lift level b)
1.85 + | lift level (Var i) = if i < level then Var i else Var (i+1)
1.86 + | lift level (Abs m) = Abs (lift (level + 1) m)
1.87 +and unlift level (Const c) = Const c
1.88 + | unlift level (App (a, b)) = App (unlift level a, unlift level b)
1.89 + | unlift level (Abs m) = Abs (unlift (level+1) m)
1.90 + | unlift level (Var i) = if i < level then Var i else Var (i-1)
1.91 +
1.92 +fun nlift level n (Var m) = if m < level then Var m else Var (m+n)
1.93 + | nlift level n (Const c) = Const c
1.94 + | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b)
1.95 + | nlift level n (Abs b) = Abs (nlift (level+1) n b)
1.96 +
1.97 +fun subst_const (c, t) (Const c') = if c = c' then t else Const c'
1.98 + | subst_const _ (Var i) = Var i
1.99 + | subst_const ct (App (a, b)) = App (subst_const ct a, subst_const ct b)
1.100 + | subst_const ct (Abs m) = Abs (subst_const ct m)
1.101 +
1.102 +(* Remove all rules that are just parameterless rewrites. This is necessary because SML does not allow functions with no parameters. *)
1.103 +fun inline_rules rules =
1.104 + let
1.105 + fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b
1.106 + | term_contains_const c (Abs m) = term_contains_const c m
1.107 + | term_contains_const c (Var _) = false
1.108 + | term_contains_const c (Const c') = (c = c')
1.109 + fun find_rewrite [] = NONE
1.110 + | find_rewrite ((prems, PConst (c, []), r) :: _) =
1.111 + if check_freevars 0 r then
1.112 + if term_contains_const c r then
1.113 + raise Compile "parameterless rewrite is caught in cycle"
1.114 + else if not (null prems) then
1.115 + raise Compile "parameterless rewrite may not be guarded"
1.116 + else
1.117 + SOME (c, r)
1.118 + else raise Compile "unbound variable on right hand side or guards of rule"
1.119 + | find_rewrite (_ :: rules) = find_rewrite rules
1.120 + fun remove_rewrite _ [] = []
1.121 + | remove_rewrite (cr as (c, r)) ((rule as (prems', PConst (c', args), r')) :: rules) =
1.122 + if c = c' then
1.123 + if null args andalso r = r' andalso null prems' then remove_rewrite cr rules
1.124 + else raise Compile "incompatible parameterless rewrites found"
1.125 + else
1.126 + rule :: remove_rewrite cr rules
1.127 + | remove_rewrite cr (r :: rs) = r :: remove_rewrite cr rs
1.128 + fun pattern_contains_const c (PConst (c', args)) = c = c' orelse exists (pattern_contains_const c) args
1.129 + | pattern_contains_const c (PVar) = false
1.130 + fun inline_rewrite (ct as (c, _)) (prems, p, r) =
1.131 + if pattern_contains_const c p then
1.132 + raise Compile "parameterless rewrite cannot be used in pattern"
1.133 + else (map (fn (Guard (a, b)) => Guard (subst_const ct a, subst_const ct b)) prems, p, subst_const ct r)
1.134 + fun inline inlined rules =
1.135 + case find_rewrite rules of
1.136 + NONE => (Inttab.make inlined, rules)
1.137 + | SOME ct =>
1.138 + let
1.139 + val rules = map (inline_rewrite ct) (remove_rewrite ct rules)
1.140 + val inlined = ct :: (map o apsnd) (subst_const ct) inlined
1.141 + in inline inlined rules end
1.142 + in
1.143 + inline [] rules
1.144 + end
1.145 +
1.146 +
1.147 +(*
1.148 + Calculate the arity, the toplevel_arity, and adjust rules so that all toplevel pattern constants have maximal arity.
1.149 + Also beta reduce the adjusted right hand side of a rule.
1.150 +*)
1.151 +fun adjust_rules rules =
1.152 + let
1.153 + val arity = fold (fn (prems, p, t) => fn arity => fold collect_guard_arity prems (collect_term_arity t (collect_pattern_arity p arity))) rules Inttab.empty
1.154 + val toplevel_arity = fold (fn (_, p, _) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty
1.155 + fun arity_of c = the (Inttab.lookup arity c)
1.156 + fun test_pattern PVar = ()
1.157 + | test_pattern (PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else (map test_pattern args; ())
1.158 + fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable")
1.159 + | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters")
1.160 + | adjust_rule (rule as (prems, p as PConst (c, args),t)) =
1.161 + let
1.162 + val patternvars_counted = count_patternvars p
1.163 + fun check_fv t = check_freevars patternvars_counted t
1.164 + val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else ()
1.165 + val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else ()
1.166 + val _ = map test_pattern args
1.167 + val len = length args
1.168 + val arity = arity_of c
1.169 + val lift = nlift 0
1.170 + fun addapps_tm n t = if n=0 then t else addapps_tm (n-1) (App (t, Var (n-1)))
1.171 + fun adjust_term n t = addapps_tm n (lift n t)
1.172 + fun adjust_guard n (Guard (a,b)) = Guard (lift n a, lift n b)
1.173 + in
1.174 + if len = arity then
1.175 + rule
1.176 + else if arity >= len then
1.177 + (map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t)
1.178 + else (raise Compile "internal error in adjust_rule")
1.179 + end
1.180 + fun beta_rule (prems, p, t) = ((prems, p, beta t) handle Match => raise Compile "beta_rule")
1.181 + in
1.182 + (arity, toplevel_arity, map (beta_rule o adjust_rule) rules)
1.183 + end
1.184 +
1.185 +fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count =
1.186 +let
1.187 + fun str x = string_of_int x
1.188 + fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s
1.189 + val module_prefix = (case module of NONE => "" | SOME s => s^".")
1.190 + fun print_apps d f [] = f
1.191 + | print_apps d f (a::args) = print_apps d (module_prefix^"app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args
1.192 + and print_call d (App (a, b)) args = print_call d a (b::args)
1.193 + | print_call d (Const c) args =
1.194 + (case arity_of c of
1.195 + NONE => print_apps d (module_prefix^"Const "^(str c)) args
1.196 + | SOME 0 => module_prefix^"C"^(str c)
1.197 + | SOME a =>
1.198 + let
1.199 + val len = length args
1.200 + in
1.201 + if a <= len then
1.202 + let
1.203 + val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a)
1.204 + val _ = if strict_a > a then raise Compile "strict" else ()
1.205 + val s = module_prefix^"c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, strict_a))))
1.206 + val s = s^(implode (map (fn t => " (fn () => "^print_term d t^")") (List.drop (List.take (args, a), strict_a))))
1.207 + in
1.208 + print_apps d s (List.drop (args, a))
1.209 + end
1.210 + else
1.211 + let
1.212 + fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n - 1)))
1.213 + fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t)
1.214 + fun append_args [] t = t
1.215 + | append_args (c::cs) t = append_args cs (App (t, c))
1.216 + in
1.217 + print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c)))))
1.218 + end
1.219 + end)
1.220 + | print_call d t args = print_apps d (print_term d t) args
1.221 + and print_term d (Var x) =
1.222 + if x < d then
1.223 + "b"^(str (d-x-1))
1.224 + else
1.225 + let
1.226 + val n = pattern_var_count - (x-d) - 1
1.227 + val x = "x"^(str n)
1.228 + in
1.229 + if n < pattern_var_count - pattern_lazy_var_count then
1.230 + x
1.231 + else
1.232 + "("^x^" ())"
1.233 + end
1.234 + | print_term d (Abs c) = module_prefix^"Abs (fn b"^(str d)^" => "^(print_term (d + 1) c)^")"
1.235 + | print_term d t = print_call d t []
1.236 +in
1.237 + print_term 0
1.238 +end
1.239 +
1.240 +fun section n = if n = 0 then [] else (section (n-1))@[n-1]
1.241 +
1.242 +fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) =
1.243 + let
1.244 + fun str x = string_of_int x
1.245 + fun print_pattern top n PVar = (n+1, "x"^(str n))
1.246 + | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else ""))
1.247 + | print_pattern top n (PConst (c, args)) =
1.248 + let
1.249 + val f = (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "")
1.250 + val (n, s) = print_pattern_list 0 top (n, f) args
1.251 + in
1.252 + (n, s)
1.253 + end
1.254 + and print_pattern_list' counter top (n,p) [] = if top then (n,p) else (n,p^")")
1.255 + | print_pattern_list' counter top (n, p) (t::ts) =
1.256 + let
1.257 + val (n, t) = print_pattern false n t
1.258 + in
1.259 + print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^", "^t) ts
1.260 + end
1.261 + and print_pattern_list counter top (n, p) (t::ts) =
1.262 + let
1.263 + val (n, t) = print_pattern false n t
1.264 + in
1.265 + print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^" ("^t) ts
1.266 + end
1.267 + val c = (case p of PConst (c, _) => c | _ => raise Match)
1.268 + val (n, pattern) = print_pattern true 0 p
1.269 + val lazy_vars = the (arity_of c) - the (toplevel_arity_of c)
1.270 + fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm
1.271 + fun print_guard (Guard (a,b)) = "term_eq ("^(print_tm a)^") ("^(print_tm b)^")"
1.272 + val else_branch = "c"^(str c)^"_"^(str (gnum+1))^(implode (map (fn i => " a"^(str i)) (section (the (arity_of c)))))
1.273 + fun print_guards t [] = print_tm t
1.274 + | print_guards t (g::gs) = "if ("^(print_guard g)^")"^(implode (map (fn g => " andalso ("^(print_guard g)^")") gs))^" then ("^(print_tm t)^") else "^else_branch
1.275 + in
1.276 + (if null guards then gnum else gnum+1, pattern^" = "^(print_guards t guards))
1.277 + end
1.278 +
1.279 +fun group_rules rules =
1.280 + let
1.281 + fun add_rule (r as (_, PConst (c,_), _)) groups =
1.282 + let
1.283 + val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs)
1.284 + in
1.285 + Inttab.update (c, r::rs) groups
1.286 + end
1.287 + | add_rule _ _ = raise Compile "internal error group_rules"
1.288 + in
1.289 + fold_rev add_rule rules Inttab.empty
1.290 + end
1.291 +
1.292 +fun sml_prog name code rules =
1.293 + let
1.294 + val buffer = Unsynchronized.ref ""
1.295 + fun write s = (buffer := (!buffer)^s)
1.296 + fun writeln s = (write s; write "\n")
1.297 + fun writelist [] = ()
1.298 + | writelist (s::ss) = (writeln s; writelist ss)
1.299 + fun str i = string_of_int i
1.300 + val (inlinetab, rules) = inline_rules rules
1.301 + val (arity, toplevel_arity, rules) = adjust_rules rules
1.302 + val rules = group_rules rules
1.303 + val constants = Inttab.keys arity
1.304 + fun arity_of c = Inttab.lookup arity c
1.305 + fun toplevel_arity_of c = Inttab.lookup toplevel_arity c
1.306 + fun rep_str s n = implode (rep n s)
1.307 + fun indexed s n = s^(str n)
1.308 + fun string_of_tuple [] = ""
1.309 + | string_of_tuple (x::xs) = "("^x^(implode (map (fn s => ", "^s) xs))^")"
1.310 + fun string_of_args [] = ""
1.311 + | string_of_args (x::xs) = x^(implode (map (fn s => " "^s) xs))
1.312 + fun default_case gnum c =
1.313 + let
1.314 + val leftargs = implode (map (indexed " x") (section (the (arity_of c))))
1.315 + val rightargs = section (the (arity_of c))
1.316 + val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa)
1.317 + val xs = map (fn n => if n < strict_args then "x"^(str n) else "x"^(str n)^"()") rightargs
1.318 + val right = (indexed "C" c)^" "^(string_of_tuple xs)
1.319 + val message = "(\"unresolved lazy call: " ^ string_of_int c ^ "\")"
1.320 + val right = if strict_args < the (arity_of c) then "raise AM_SML.Run "^message else right
1.321 + in
1.322 + (indexed "c" c)^(if gnum > 0 then "_"^(str gnum) else "")^leftargs^" = "^right
1.323 + end
1.324 +
1.325 + fun eval_rules c =
1.326 + let
1.327 + val arity = the (arity_of c)
1.328 + val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa)
1.329 + fun eval_rule n =
1.330 + let
1.331 + val sc = string_of_int c
1.332 + val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section n) ("AbstractMachine.Const "^sc)
1.333 + fun arg i =
1.334 + let
1.335 + val x = indexed "x" i
1.336 + val x = if i < n then "(eval bounds "^x^")" else x
1.337 + val x = if i < strict_arity then x else "(fn () => "^x^")"
1.338 + in
1.339 + x
1.340 + end
1.341 + val right = "c"^sc^" "^(string_of_args (map arg (section arity)))
1.342 + val right = fold_rev (fn i => fn s => "Abs (fn "^(indexed "x" i)^" => "^s^")") (List.drop (section arity, n)) right
1.343 + val right = if arity > 0 then right else "C"^sc
1.344 + in
1.345 + " | eval bounds ("^left^") = "^right
1.346 + end
1.347 + in
1.348 + map eval_rule (rev (section (arity + 1)))
1.349 + end
1.350 +
1.351 + fun convert_computed_rules (c: int) : string list =
1.352 + let
1.353 + val arity = the (arity_of c)
1.354 + fun eval_rule () =
1.355 + let
1.356 + val sc = string_of_int c
1.357 + val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section arity) ("AbstractMachine.Const "^sc)
1.358 + fun arg i = "(convert_computed "^(indexed "x" i)^")"
1.359 + val right = "C"^sc^" "^(string_of_tuple (map arg (section arity)))
1.360 + val right = if arity > 0 then right else "C"^sc
1.361 + in
1.362 + " | convert_computed ("^left^") = "^right
1.363 + end
1.364 + in
1.365 + [eval_rule ()]
1.366 + end
1.367 +
1.368 + fun mk_constr_type_args n = if n > 0 then " of Term "^(rep_str " * Term" (n-1)) else ""
1.369 + val _ = writelist [
1.370 + "structure "^name^" = struct",
1.371 + "",
1.372 + "datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)",
1.373 + " "^(implode (map (fn c => " | C"^(str c)^(mk_constr_type_args (the (arity_of c)))) constants)),
1.374 + ""]
1.375 + fun make_constr c argprefix = "(C"^(str c)^" "^(string_of_tuple (map (fn i => argprefix^(str i)) (section (the (arity_of c)))))^")"
1.376 + fun make_term_eq c = " | term_eq "^(make_constr c "a")^" "^(make_constr c "b")^" = "^
1.377 + (case the (arity_of c) of
1.378 + 0 => "true"
1.379 + | n =>
1.380 + let
1.381 + val eqs = map (fn i => "term_eq a"^(str i)^" b"^(str i)) (section n)
1.382 + val (eq, eqs) = (List.hd eqs, map (fn s => " andalso "^s) (List.tl eqs))
1.383 + in
1.384 + eq^(implode eqs)
1.385 + end)
1.386 + val _ = writelist [
1.387 + "fun term_eq (Const c1) (Const c2) = (c1 = c2)",
1.388 + " | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"]
1.389 + val _ = writelist (map make_term_eq constants)
1.390 + val _ = writelist [
1.391 + " | term_eq _ _ = false",
1.392 + ""
1.393 + ]
1.394 + val _ = writelist [
1.395 + "fun app (Abs a) b = a b",
1.396 + " | app a b = App (a, b)",
1.397 + ""]
1.398 + fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else [])
1.399 + fun writefundecl [] = ()
1.400 + | writefundecl (x::xs) = writelist ((("and "^x)::(map (fn s => " | "^s) xs)))
1.401 + fun list_group c = (case Inttab.lookup rules c of
1.402 + NONE => [defcase 0 c]
1.403 + | SOME rs =>
1.404 + let
1.405 + val rs =
1.406 + fold
1.407 + (fn r =>
1.408 + fn rs =>
1.409 + let
1.410 + val (gnum, l, rs) =
1.411 + (case rs of
1.412 + [] => (0, [], [])
1.413 + | (gnum, l)::rs => (gnum, l, rs))
1.414 + val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r
1.415 + in
1.416 + if gnum' = gnum then
1.417 + (gnum, r::l)::rs
1.418 + else
1.419 + let
1.420 + val args = implode (map (fn i => " a"^(str i)) (section (the (arity_of c))))
1.421 + fun gnumc g = if g > 0 then "c"^(str c)^"_"^(str g)^args else "c"^(str c)^args
1.422 + val s = gnumc (gnum) ^ " = " ^ gnumc (gnum')
1.423 + in
1.424 + (gnum', [])::(gnum, s::r::l)::rs
1.425 + end
1.426 + end)
1.427 + rs []
1.428 + val rs = (case rs of [] => [(0,defcase 0 c)] | (gnum,l)::rs => (gnum, (defcase gnum c)@l)::rs)
1.429 + in
1.430 + rev (map (fn z => rev (snd z)) rs)
1.431 + end)
1.432 + val _ = map (fn z => (map writefundecl z; writeln "")) (map list_group constants)
1.433 + val _ = writelist [
1.434 + "fun convert (Const i) = AM_SML.Const i",
1.435 + " | convert (App (a, b)) = AM_SML.App (convert a, convert b)",
1.436 + " | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""]
1.437 + fun make_convert c =
1.438 + let
1.439 + val args = map (indexed "a") (section (the (arity_of c)))
1.440 + val leftargs =
1.441 + case args of
1.442 + [] => ""
1.443 + | (x::xs) => "("^x^(implode (map (fn s => ", "^s) xs))^")"
1.444 + val args = map (indexed "convert a") (section (the (arity_of c)))
1.445 + val right = fold (fn x => fn s => "AM_SML.App ("^s^", "^x^")") args ("AM_SML.Const "^(str c))
1.446 + in
1.447 + " | convert (C"^(str c)^" "^leftargs^") = "^right
1.448 + end
1.449 + val _ = writelist (map make_convert constants)
1.450 + val _ = writelist [
1.451 + "",
1.452 + "fun convert_computed (AbstractMachine.Abs b) = raise AM_SML.Run \"no abstraction in convert_computed allowed\"",
1.453 + " | convert_computed (AbstractMachine.Var i) = raise AM_SML.Run \"no bound variables in convert_computed allowed\""]
1.454 + val _ = map (writelist o convert_computed_rules) constants
1.455 + val _ = writelist [
1.456 + " | convert_computed (AbstractMachine.Const c) = Const c",
1.457 + " | convert_computed (AbstractMachine.App (a, b)) = App (convert_computed a, convert_computed b)",
1.458 + " | convert_computed (AbstractMachine.Computed a) = raise AM_SML.Run \"no nesting in convert_computed allowed\""]
1.459 + val _ = writelist [
1.460 + "",
1.461 + "fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)",
1.462 + " | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth (bounds, i)"]
1.463 + val _ = map (writelist o eval_rules) constants
1.464 + val _ = writelist [
1.465 + " | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)",
1.466 + " | eval bounds (AbstractMachine.Const c) = Const c",
1.467 + " | eval bounds (AbstractMachine.Computed t) = convert_computed t"]
1.468 + val _ = writelist [
1.469 + "",
1.470 + "fun export term = AM_SML.save_result (\""^code^"\", convert term)",
1.471 + "",
1.472 + "val _ = AM_SML.set_compiled_rewriter (fn t => (convert (eval [] t)))",
1.473 + "",
1.474 + "end"]
1.475 + in
1.476 + (inlinetab, !buffer)
1.477 + end
1.478 +
1.479 +val guid_counter = Unsynchronized.ref 0
1.480 +fun get_guid () =
1.481 + let
1.482 + val c = !guid_counter
1.483 + val _ = guid_counter := !guid_counter + 1
1.484 + in
1.485 + string_of_int (Time.toMicroseconds (Time.now ())) ^ string_of_int c
1.486 + end
1.487 +
1.488 +
1.489 +fun writeTextFile name s = File.write (Path.explode name) s
1.490 +
1.491 +fun use_source src = use_text ML_Env.local_context (1, "") false src
1.492 +
1.493 +fun compile rules =
1.494 + let
1.495 + val guid = get_guid ()
1.496 + val code = Real.toString (random ())
1.497 + val name = "AMSML_"^guid
1.498 + val (inlinetab, source) = sml_prog name code rules
1.499 + val _ = case !dump_output of NONE => () | SOME p => writeTextFile p source
1.500 + val _ = compiled_rewriter := NONE
1.501 + val _ = use_source source
1.502 + in
1.503 + case !compiled_rewriter of
1.504 + NONE => raise Compile "broken link to compiled function"
1.505 + | SOME compiled_fun => (inlinetab, compiled_fun)
1.506 + end
1.507 +
1.508 +fun run (inlinetab, compiled_fun) t =
1.509 + let
1.510 + val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
1.511 + fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t)
1.512 + | inline (Var i) = Var i
1.513 + | inline (App (a, b)) = App (inline a, inline b)
1.514 + | inline (Abs m) = Abs (inline m)
1.515 + | inline (Computed t) = Computed t
1.516 + in
1.517 + compiled_fun (beta (inline t))
1.518 + end
1.519 +
1.520 +end