src/HOL/Matrix/Compute_Oracle/am_sml.ML
changeset 47859 9f492f5b0cec
parent 47413 dcc575b30842
equal deleted inserted replaced
47858:15ce93dfe6da 47859:9f492f5b0cec
     1 (*  Title:      HOL/Matrix/Compute_Oracle/am_sml.ML
       
     2     Author:     Steven Obua
       
     3 
       
     4 TODO: "parameterless rewrite cannot be used in pattern": In a lot of
       
     5 cases it CAN be used, and these cases should be handled
       
     6 properly; right now, all cases raise an exception. 
       
     7 *)
       
     8 
       
     9 signature AM_SML = 
       
    10 sig
       
    11   include ABSTRACT_MACHINE
       
    12   val save_result : (string * term) -> unit
       
    13   val set_compiled_rewriter : (term -> term) -> unit
       
    14   val list_nth : 'a list * int -> 'a
       
    15   val dump_output : (string option) Unsynchronized.ref 
       
    16 end
       
    17 
       
    18 structure AM_SML : AM_SML = struct
       
    19 
       
    20 open AbstractMachine;
       
    21 
       
    22 val dump_output = Unsynchronized.ref (NONE: string option)
       
    23 
       
    24 type program = term Inttab.table * (term -> term)
       
    25 
       
    26 val saved_result = Unsynchronized.ref (NONE:(string*term)option)
       
    27 
       
    28 fun save_result r = (saved_result := SOME r)
       
    29 
       
    30 val list_nth = List.nth
       
    31 
       
    32 val compiled_rewriter = Unsynchronized.ref (NONE:(term -> term)Option.option)
       
    33 
       
    34 fun set_compiled_rewriter r = (compiled_rewriter := SOME r)
       
    35 
       
    36 fun count_patternvars PVar = 1
       
    37   | count_patternvars (PConst (_, ps)) =
       
    38       List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps
       
    39 
       
    40 fun update_arity arity code a = 
       
    41     (case Inttab.lookup arity code of
       
    42          NONE => Inttab.update_new (code, a) arity
       
    43        | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity)
       
    44 
       
    45 (* We have to find out the maximal arity of each constant *)
       
    46 fun collect_pattern_arity PVar arity = arity
       
    47   | collect_pattern_arity (PConst (c, args)) arity = fold collect_pattern_arity args (update_arity arity c (length args))
       
    48 
       
    49 (* We also need to find out the maximal toplevel arity of each function constant *)
       
    50 fun collect_pattern_toplevel_arity PVar arity = raise Compile "internal error: collect_pattern_toplevel_arity"
       
    51   | collect_pattern_toplevel_arity (PConst (c, args)) arity = update_arity arity c (length args)
       
    52 
       
    53 local
       
    54 fun collect applevel (Var _) arity = arity
       
    55   | collect applevel (Const c) arity = update_arity arity c applevel
       
    56   | collect applevel (Abs m) arity = collect 0 m arity
       
    57   | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity)
       
    58 in
       
    59 fun collect_term_arity t arity = collect 0 t arity
       
    60 end
       
    61 
       
    62 fun collect_guard_arity (Guard (a,b)) arity  = collect_term_arity b (collect_term_arity a arity)
       
    63 
       
    64 
       
    65 fun rep n x = if n < 0 then raise Compile "internal error: rep" else if n = 0 then [] else x::(rep (n-1) x)
       
    66 
       
    67 fun beta (Const c) = Const c
       
    68   | beta (Var i) = Var i
       
    69   | beta (App (Abs m, b)) = beta (unlift 0 (subst 0 m (lift 0 b)))
       
    70   | beta (App (a, b)) = 
       
    71     (case beta a of
       
    72          Abs m => beta (App (Abs m, b))
       
    73        | a => App (a, beta b))
       
    74   | beta (Abs m) = Abs (beta m)
       
    75   | beta (Computed t) = Computed t
       
    76 and subst x (Const c) t = Const c
       
    77   | subst x (Var i) t = if i = x then t else Var i
       
    78   | subst x (App (a,b)) t = App (subst x a t, subst x b t)
       
    79   | subst x (Abs m) t = Abs (subst (x+1) m (lift 0 t))
       
    80 and lift level (Const c) = Const c
       
    81   | lift level (App (a,b)) = App (lift level a, lift level b)
       
    82   | lift level (Var i) = if i < level then Var i else Var (i+1)
       
    83   | lift level (Abs m) = Abs (lift (level + 1) m)
       
    84 and unlift level (Const c) = Const c
       
    85   | unlift level (App (a, b)) = App (unlift level a, unlift level b)
       
    86   | unlift level (Abs m) = Abs (unlift (level+1) m)
       
    87   | unlift level (Var i) = if i < level then Var i else Var (i-1)
       
    88 
       
    89 fun nlift level n (Var m) = if m < level then Var m else Var (m+n) 
       
    90   | nlift level n (Const c) = Const c
       
    91   | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b)
       
    92   | nlift level n (Abs b) = Abs (nlift (level+1) n b)
       
    93 
       
    94 fun subst_const (c, t) (Const c') = if c = c' then t else Const c'
       
    95   | subst_const _ (Var i) = Var i
       
    96   | subst_const ct (App (a, b)) = App (subst_const ct a, subst_const ct b)
       
    97   | subst_const ct (Abs m) = Abs (subst_const ct m)
       
    98 
       
(* Remove all rules that are just parameterless rewrites, i.e. rules whose
   pattern is a constant c applied to no arguments. This is necessary because
   SML does not allow functions with no parameters; instead, the right-hand
   side r of such a rule is substituted for c everywhere else.

   Returns (table mapping each inlined constant to its replacement term,
   remaining rules).

   Raises Compile if a parameterless rewrite
   - mentions its own constant (directly, or after earlier inlining),
   - is guarded,
   - fails check_freevars 0 on its right-hand side,
   - conflicts with another rule for the same constant,
   or if an inlined constant occurs inside some rule's pattern. *)
fun inline_rules rules =
  let
    (* does Const c occur anywhere in the term?
       NOTE(review): there is no clause for Computed, so such a node raises Match. *)
    fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b
      | term_contains_const c (Abs m) = term_contains_const c m
      | term_contains_const c (Var _) = false
      | term_contains_const c (Const c') = (c = c')
    (* First rule whose pattern is a bare constant, validated and returned as (c, r).
       NOTE(review): the last error message mentions guards, but only r is passed
       to check_freevars here. *)
    fun find_rewrite [] = NONE
      | find_rewrite ((prems, PConst (c, []), r) :: _) = 
          if check_freevars 0 r then 
            if term_contains_const c r then 
              raise Compile "parameterless rewrite is caught in cycle"
            else if not (null prems) then
              raise Compile "parameterless rewrite may not be guarded"
            else
              SOME (c, r) 
          else raise Compile "unbound variable on right hand side or guards of rule"
      | find_rewrite (_ :: rules) = find_rewrite rules
    (* Drop every rule for c that is identical to the unguarded rewrite c = r;
       any other rule whose head is c is an error. *)
    fun remove_rewrite _ [] = []
      | remove_rewrite (cr as (c, r)) ((rule as (prems', PConst (c', args), r')) :: rules) = 
          if c = c' then 
            if null args andalso r = r' andalso null prems' then remove_rewrite cr rules 
            else raise Compile "incompatible parameterless rewrites found"
          else
            rule :: remove_rewrite cr rules
      | remove_rewrite cr (r :: rs) = r :: remove_rewrite cr rs
    fun pattern_contains_const c (PConst (c', args)) = c = c' orelse exists (pattern_contains_const c) args
      | pattern_contains_const c (PVar) = false
    (* Substitute r for c in the guards and right-hand side of one rule.
       An occurrence of c inside the pattern is not supported yet. *)
    fun inline_rewrite (ct as (c, _)) (prems, p, r) = 
        if pattern_contains_const c p then 
          raise Compile "parameterless rewrite cannot be used in pattern"
        else (map (fn (Guard (a, b)) => Guard (subst_const ct a, subst_const ct b)) prems, p, subst_const ct r)
    (* Repeat until no parameterless rewrite is left. inlined accumulates the
       (c, r) pairs found so far; previously recorded replacement terms are
       rewritten with each newly found pair as well. *)
    fun inline inlined rules =
      case find_rewrite rules of 
          NONE => (Inttab.make inlined, rules)
        | SOME ct => 
            let
              val rules = map (inline_rewrite ct) (remove_rewrite ct rules)
              val inlined = ct :: (map o apsnd) (subst_const ct) inlined
            in inline inlined rules end
  in
    inline [] rules
  end
       
   142 
       
   143 
       
(*
   Calculate the arity, the toplevel_arity, and adjust rules so that all toplevel pattern constants have maximal arity.
   Also beta reduce the adjusted right hand side of a rule.

   Returns (arity, toplevel_arity, adjusted rules):
   - arity maps each constant to the maximal arity it is used with in any
     pattern, right-hand side or guard;
   - toplevel_arity maps each rule head to the maximal number of arguments
     it is given in a rule pattern.
   Raises Compile on invalid rules (variable or argument-less patterns, unbound
   variables, nested pattern constants not at maximal arity).
*)
fun adjust_rules rules = 
    let
        val arity = fold (fn (prems, p, t) => fn arity => fold collect_guard_arity prems (collect_term_arity t (collect_pattern_arity p arity))) rules Inttab.empty
        val toplevel_arity = fold (fn (_, p, _) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty
        fun arity_of c = the (Inttab.lookup arity c)
        (* constants nested inside a pattern must be applied to exactly their maximal arity *)
        fun test_pattern PVar = ()
          | test_pattern (PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else (map test_pattern args; ())
        (* Validate one rule. If its head is applied to fewer than the maximal
           number of arguments, pad the pattern with fresh PVars, lift the
           variables of guards and right-hand side by the number of new
           variables, and apply the right-hand side to the new variables. *)
        fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable")
          | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters")
          | adjust_rule (rule as (prems, p as PConst (c, args),t)) = 
            let
                val patternvars_counted = count_patternvars p
                fun check_fv t = check_freevars patternvars_counted t
                val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else () 
                val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else () 
                val _ = map test_pattern args           
                val len = length args
                val arity = arity_of c
                val lift = nlift 0
                (* App (... App (App (t, Var (n-1)), Var (n-2)) ..., Var 0) *)
                fun addapps_tm n t = if n=0 then t else addapps_tm (n-1) (App (t, Var (n-1)))
                fun adjust_term n t = addapps_tm n (lift n t)
                fun adjust_guard n (Guard (a,b)) = Guard (lift n a, lift n b)
            in
                if len = arity then
                    rule
                else if arity >= len then  
                    (map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t)
                else (raise Compile "internal error in adjust_rule")
            end
        (* a Match escaping from beta is reported as a Compile error *)
        fun beta_rule (prems, p, t) = ((prems, p, beta t) handle Match => raise Compile "beta_rule")
    in
        (arity, toplevel_arity, map (beta_rule o adjust_rule) rules)
    end
       
   181 
       
(* Render an AbstractMachine term as SML source text for the generated module.

   module: optional structure name; when SOME s, the generated names
     (app, Const, C<c>, c<c>, Abs) are prefixed with "s.".
   arity_of / toplevel_arity_of: arity lookups for constants (int option).
   pattern_var_count: number of pattern variables x0 .. x(n-1) in scope.
   pattern_lazy_var_count: how many of the highest-numbered pattern variables
     are thunks that must be forced with "()".

   Returns a function term -> string. Vars bound by an Abs inside the term
   print as b<k>; all other Vars are mapped to pattern variables.

   NOTE(review): there is no clause for Computed. Such a node falls through to
   print_call and back into print_term, which does not terminate. Presumably
   rule terms never contain Computed; confirm. *)
fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count =
let
    fun str x = string_of_int x
    fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s
    val module_prefix = (case module of NONE => "" | SOME s => s^".")                                                                                     
    (* apply f to the remaining arguments one by one through the runtime app function *)
    fun print_apps d f [] = f
      | print_apps d f (a::args) = print_apps d (module_prefix^"app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args
    (* Walk down an application spine, collecting its arguments.
       - Constant without a known arity: symbolic Const application.
       - Arity 0: the nullary constructor C<c>.
       - Enough arguments: direct call of c<c>. The first strict_a arguments are
         passed as values, the rest up to the arity as (fn () => ...) thunks,
         and any surplus arguments go through app.
       - Too few arguments: eta-expand and print the resulting lambda. *)
    and print_call d (App (a, b)) args = print_call d a (b::args) 
      | print_call d (Const c) args = 
        (case arity_of c of 
             NONE => print_apps d (module_prefix^"Const "^(str c)) args 
           | SOME 0 => module_prefix^"C"^(str c)
           | SOME a =>
             let
                 val len = length args
             in
                 if a <= len then 
                     let
                         val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a)
                         val _ = if strict_a > a then raise Compile "strict" else ()
                         val s = module_prefix^"c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, strict_a))))
                         val s = s^(implode (map (fn t => " (fn () => "^print_term d t^")") (List.drop (List.take (args, a), strict_a))))
                     in
                         print_apps d s (List.drop (args, a))
                     end
                 else 
                     let
                         fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n - 1)))
                         fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t)
                         fun append_args [] t = t
                           | append_args (c::cs) t = append_args cs (App (t, c))
                     in
                         print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c)))))
                     end
             end)
      | print_call d t args = print_apps d (print_term d t) args
    (* d = number of Abs binders entered so far inside the term *)
    and print_term d (Var x) = 
        if x < d then 
            "b"^(str (d-x-1)) 
        else 
            let
                (* loose variable: maps to pattern variable x<n> *)
                val n = pattern_var_count - (x-d) - 1
                val x = "x"^(str n)
            in
                if n < pattern_var_count - pattern_lazy_var_count then 
                    x
                else 
                    "("^x^" ())"
            end                                                         
      | print_term d (Abs c) = module_prefix^"Abs (fn b"^(str d)^" => "^(print_term (d + 1) c)^")"
      | print_term d t = print_call d t []
in
    print_term 0 
end
       
   236 
       
   237 fun section n = if n = 0 then [] else (section (n-1))@[n-1]
       
   238 
       
(* Render one rule (guards, pattern p, right-hand side t) as a clause of the
   generated function for p's head constant c.

   gnum: current guard group number. The top-level pattern is printed as
     c<c> (gnum = 0) or c<c>_<gnum>.
   Returns (next group number, clause text). The group number advances by one
   exactly when the rule has guards. In that case the clause is
   "if <guards> then <rhs> else c<c>_<gnum+1> a0 ... a(k-1)", with k the arity of c.

   Top-level arguments are bound as "a<i> as (<pattern>)". Nested constants
   print as constructors C<c>, and pattern variables are numbered x0, x1, ...
   from left to right. *)
fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) = 
    let 
        fun str x = string_of_int x                  
        (* returns (next free pattern-variable number, text) *)
        fun print_pattern top n PVar = (n+1, "x"^(str n))
          | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else ""))
          | print_pattern top n (PConst (c, args)) = 
            let
                val f = (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "")
                val (n, s) = print_pattern_list 0 top (n, f) args
            in
                (n, s)
            end
        (* remaining arguments: curried "(aK as (...))" at top level, tuple ", ..." below *)
        and print_pattern_list' counter top (n,p) [] = if top then (n,p) else (n,p^")")
          | print_pattern_list' counter top (n, p) (t::ts) = 
            let
                val (n, t) = print_pattern false n t
            in
                print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^", "^t) ts
            end 
        (* first argument; only called with a non-empty list (the [] case is handled above) *)
        and print_pattern_list counter top (n, p) (t::ts) = 
            let
                val (n, t) = print_pattern false n t
            in
                print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^" ("^t) ts
            end
        val c = (case p of PConst (c, _) => c | _ => raise Match)
        val (n, pattern) = print_pattern true 0 p
        (* arguments beyond the toplevel arity are passed as thunks *)
        val lazy_vars = the (arity_of c) - the (toplevel_arity_of c)
        fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm
        fun print_guard (Guard (a,b)) = "term_eq ("^(print_tm a)^") ("^(print_tm b)^")"
        val else_branch = "c"^(str c)^"_"^(str (gnum+1))^(implode (map (fn i => " a"^(str i)) (section (the (arity_of c)))))
        fun print_guards t [] = print_tm t
          | print_guards t (g::gs) = "if ("^(print_guard g)^")"^(implode (map (fn g => " andalso ("^(print_guard g)^")") gs))^" then ("^(print_tm t)^") else "^else_branch
    in
        (if null guards then gnum else gnum+1, pattern^" = "^(print_guards t guards))
    end
       
   275 
       
   276 fun group_rules rules =
       
   277     let
       
   278         fun add_rule (r as (_, PConst (c,_), _)) groups =
       
   279             let
       
   280                 val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs)
       
   281             in
       
   282                 Inttab.update (c, r::rs) groups
       
   283             end
       
   284           | add_rule _ _ = raise Compile "internal error group_rules"
       
   285     in
       
   286         fold_rev add_rule rules Inttab.empty
       
   287     end
       
   288 
       
(* Generate the source text of an SML structure called name that implements
   the given rules.

   code: string embedded in the generated export function.
   Returns (table of inlined parameterless rewrites, generated source).

   The generated structure contains, in order:
   - datatype Term with one constructor C<c> per constant c;
   - term_eq, a structural equality used by rule guards;
   - app, followed by all functions c<c> / c<c>_<g>, chained to app with "and";
   - convert (to AM_SML terms), convert_computed and eval;
   - export, plus a call of AM_SML.set_compiled_rewriter registering
     fn t => convert (eval [] t).

   NOTE(review): write appends to one growing string, so generation is
   quadratic in the output size. Collecting the lines in a list and
   concatenating once at the end would be linear. *)
fun sml_prog name code rules = 
    let
        val buffer = Unsynchronized.ref ""
        fun write s = (buffer := (!buffer)^s)
        fun writeln s = (write s; write "\n")
        fun writelist [] = ()
          | writelist (s::ss) = (writeln s; writelist ss)
        fun str i = string_of_int i
        val (inlinetab, rules) = inline_rules rules
        val (arity, toplevel_arity, rules) = adjust_rules rules
        val rules = group_rules rules
        val constants = Inttab.keys arity
        fun arity_of c = Inttab.lookup arity c
        fun toplevel_arity_of c = Inttab.lookup toplevel_arity c
        fun rep_str s n = implode (rep n s)
        fun indexed s n = s^(str n)
        fun string_of_tuple [] = ""
          | string_of_tuple (x::xs) = "("^x^(implode (map (fn s => ", "^s) xs))^")"
        fun string_of_args [] = ""
          | string_of_args (x::xs) = x^(implode (map (fn s => " "^s) xs))
        (* Fallback clause of c<c> (or c<c>_<gnum>) when no rule matches: build
           the constructor term C<c>, or raise Run if c has lazy arguments. *)
        fun default_case gnum c = 
            let
                val leftargs = implode (map (indexed " x") (section (the (arity_of c))))
                val rightargs = section (the (arity_of c))
                val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa)
                val xs = map (fn n => if n < strict_args then "x"^(str n) else "x"^(str n)^"()") rightargs
                val right = (indexed "C" c)^" "^(string_of_tuple xs)
                val message = "(\"unresolved lazy call: " ^ string_of_int c ^ "\")"
                val right = if strict_args < the (arity_of c) then "raise AM_SML.Run "^message else right               
            in
                (indexed "c" c)^(if gnum > 0 then "_"^(str gnum) else "")^leftargs^" = "^right
            end

        (* eval clauses for Const c applied to n arguments, for n from the arity
           down to 0. Given arguments are evaluated, missing ones are bound by
           Abs wrappers, and lazy positions are passed as thunks. *)
        fun eval_rules c = 
            let
                val arity = the (arity_of c)
                val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa)
                fun eval_rule n = 
                    let
                        val sc = string_of_int c
                        val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section n) ("AbstractMachine.Const "^sc)
                        fun arg i = 
                            let
                                val x = indexed "x" i
                                val x = if i < n then "(eval bounds "^x^")" else x
                                val x = if i < strict_arity then x else "(fn () => "^x^")"
                            in
                                x
                            end
                        val right = "c"^sc^" "^(string_of_args (map arg (section arity)))
                        val right = fold_rev (fn i => fn s => "Abs (fn "^(indexed "x" i)^" => "^s^")") (List.drop (section arity, n)) right             
                        val right = if arity > 0 then right else "C"^sc
                    in
                        "  | eval bounds ("^left^") = "^right
                    end
            in
                map eval_rule (rev (section (arity + 1)))
            end

        (* convert_computed clause turning a fully applied Const c into C<c> *)
        fun convert_computed_rules (c: int) : string list = 
            let
                val arity = the (arity_of c)
                fun eval_rule () = 
                    let
                        val sc = string_of_int c
                        val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section arity) ("AbstractMachine.Const "^sc)
                        fun arg i = "(convert_computed "^(indexed "x" i)^")" 
                        val right = "C"^sc^" "^(string_of_tuple (map arg (section arity)))              
                        val right = if arity > 0 then right else "C"^sc
                    in
                        "  | convert_computed ("^left^") = "^right
                    end
            in
                [eval_rule ()]
            end
        
        fun mk_constr_type_args n = if n > 0 then " of Term "^(rep_str " * Term" (n-1)) else ""
        val _ = writelist [                   
                "structure "^name^" = struct",
                "",
                "datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)",
                "         "^(implode (map (fn c => " | C"^(str c)^(mk_constr_type_args (the (arity_of c)))) constants)),
                ""]
        fun make_constr c argprefix = "(C"^(str c)^" "^(string_of_tuple (map (fn i => argprefix^(str i)) (section (the (arity_of c)))))^")"
        fun make_term_eq c = "  | term_eq "^(make_constr c "a")^" "^(make_constr c "b")^" = "^
                             (case the (arity_of c) of 
                                  0 => "true"
                                | n => 
                                  let 
                                      val eqs = map (fn i => "term_eq a"^(str i)^" b"^(str i)) (section n)
                                      val (eq, eqs) = (List.hd eqs, map (fn s => " andalso "^s) (List.tl eqs))
                                  in
                                      eq^(implode eqs)
                                  end)
        val _ = writelist [
                "fun term_eq (Const c1) (Const c2) = (c1 = c2)",
                "  | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"]
        val _ = writelist (map make_term_eq constants)          
        val _ = writelist [
                "  | term_eq _ _ = false",
                "" 
                ] 
        (* the c<c> functions below are appended to this declaration with "and" *)
        val _ = writelist [
                "fun app (Abs a) b = a b",
                "  | app a b = App (a, b)",
                ""]     
        fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else [])
        fun writefundecl [] = () 
          | writefundecl (x::xs) = writelist ((("and "^x)::(map (fn s => "  | "^s) xs)))
        (* Clauses for constant c, one list per guard group. Each group except
           the last ends with a clause forwarding to the next group
           (c<c>_<g> args = c<c>_<g+1> args); the last group ends with the
           default case. Groups are built in reverse and reversed at the end. *)
        fun list_group c = (case Inttab.lookup rules c of 
                                NONE => [defcase 0 c]
                              | SOME rs => 
                                let
                                    val rs = 
                                        fold
                                            (fn r => 
                                             fn rs =>
                                                let 
                                                    val (gnum, l, rs) = 
                                                        (case rs of 
                                                             [] => (0, [], []) 
                                                           | (gnum, l)::rs => (gnum, l, rs))
                                                    val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r 
                                                in 
                                                    if gnum' = gnum then 
                                                        (gnum, r::l)::rs
                                                    else
                                                        let
                                                            val args = implode (map (fn i => " a"^(str i)) (section (the (arity_of c))))
                                                            fun gnumc g = if g > 0 then "c"^(str c)^"_"^(str g)^args else "c"^(str c)^args
                                                            val s = gnumc (gnum) ^ " = " ^ gnumc (gnum') 
                                                        in
                                                            (gnum', [])::(gnum, s::r::l)::rs
                                                        end
                                                end)
                                        rs []
                                    val rs = (case rs of [] => [(0,defcase 0 c)] | (gnum,l)::rs => (gnum, (defcase gnum c)@l)::rs)
                                in
                                    rev (map (fn z => rev (snd z)) rs)
                                end)
        val _ = map (fn z => (map writefundecl z; writeln "")) (map list_group constants)
        val _ = writelist [
                "fun convert (Const i) = AM_SML.Const i",
                "  | convert (App (a, b)) = AM_SML.App (convert a, convert b)",
                "  | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""]  
        (* convert clause mapping C<c> back to nested AM_SML.App over AM_SML.Const c *)
        fun make_convert c = 
            let
                val args = map (indexed "a") (section (the (arity_of c)))
                val leftargs = 
                    case args of
                        [] => ""
                      | (x::xs) => "("^x^(implode (map (fn s => ", "^s) xs))^")"
                val args = map (indexed "convert a") (section (the (arity_of c)))
                val right = fold (fn x => fn s => "AM_SML.App ("^s^", "^x^")") args ("AM_SML.Const "^(str c))
            in
                "  | convert (C"^(str c)^" "^leftargs^") = "^right
            end                 
        val _ = writelist (map make_convert constants)
        val _ = writelist [
                "",
                "fun convert_computed (AbstractMachine.Abs b) = raise AM_SML.Run \"no abstraction in convert_computed allowed\"",
                "  | convert_computed (AbstractMachine.Var i) = raise AM_SML.Run \"no bound variables in convert_computed allowed\""]
        val _ = map (writelist o convert_computed_rules) constants
        val _ = writelist [
                "  | convert_computed (AbstractMachine.Const c) = Const c",
                "  | convert_computed (AbstractMachine.App (a, b)) = App (convert_computed a, convert_computed b)",
                "  | convert_computed (AbstractMachine.Computed a) = raise AM_SML.Run \"no nesting in convert_computed allowed\""] 
        val _ = writelist [
                "",
                "fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)",
                "  | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth (bounds, i)"]
        val _ = map (writelist o eval_rules) constants
        val _ = writelist [
                "  | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)",
                "  | eval bounds (AbstractMachine.Const c) = Const c",
                "  | eval bounds (AbstractMachine.Computed t) = convert_computed t"]                
        val _ = writelist [             
                "",
                "fun export term = AM_SML.save_result (\""^code^"\", convert term)",
                "",
                "val _ = AM_SML.set_compiled_rewriter (fn t => (convert (eval [] t)))",
                "",
                "end"]
    in
        (inlinetab, !buffer)
    end
       
   475 
       
   476 val guid_counter = Unsynchronized.ref 0
       
   477 fun get_guid () = 
       
   478     let
       
   479         val c = !guid_counter
       
   480         val _ = guid_counter := !guid_counter + 1
       
   481     in
       
   482         string_of_int (Time.toMicroseconds (Time.now ())) ^ string_of_int c
       
   483     end
       
   484 
       
   485 
       
   486 fun writeTextFile name s = File.write (Path.explode name) s
       
   487 
       
   488 fun use_source src = use_text ML_Env.local_context (1, "") false src
       
   489     
       
   490 fun compile rules = 
       
   491     let
       
   492         val guid = get_guid ()
       
   493         val code = Real.toString (random ())
       
   494         val name = "AMSML_"^guid
       
   495         val (inlinetab, source) = sml_prog name code rules
       
   496         val _ = case !dump_output of NONE => () | SOME p => writeTextFile p source
       
   497         val _ = compiled_rewriter := NONE
       
   498         val _ = use_source source
       
   499     in
       
   500         case !compiled_rewriter of 
       
   501             NONE => raise Compile "broken link to compiled function"
       
   502           | SOME compiled_fun => (inlinetab, compiled_fun)
       
   503     end
       
   504 
       
   505 fun run (inlinetab, compiled_fun) t = 
       
   506     let 
       
   507         val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
       
   508         fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t)
       
   509           | inline (Var i) = Var i
       
   510           | inline (App (a, b)) = App (inline a, inline b)
       
   511           | inline (Abs m) = Abs (inline m)
       
   512           | inline (Computed t) = Computed t
       
   513     in
       
   514         compiled_fun (beta (inline t))
       
   515     end 
       
   516 
       
   517 end