src/HOL/Matrix_LP/Compute_Oracle/am_sml.ML
changeset 47859 9f492f5b0cec
parent 47413 dcc575b30842
child 48326 26315a545e26
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Matrix_LP/Compute_Oracle/am_sml.ML	Sat Mar 17 12:52:40 2012 +0100
     1.3 @@ -0,0 +1,517 @@
     1.4 +(*  Title:      HOL/Matrix/Compute_Oracle/am_sml.ML
     1.5 +    Author:     Steven Obua
     1.6 +
     1.7 +TODO: "parameterless rewrite cannot be used in pattern": In a lot of
     1.8 +cases it CAN be used, and these cases should be handled
     1.9 +properly; right now, all cases raise an exception. 
    1.10 +*)
    1.11 +
    1.12 +signature AM_SML = 
    1.13 +sig
    1.14 +  include ABSTRACT_MACHINE
    1.15 +  val save_result : (string * term) -> unit
    1.16 +  val set_compiled_rewriter : (term -> term) -> unit
    1.17 +  val list_nth : 'a list * int -> 'a
    1.18 +  val dump_output : (string option) Unsynchronized.ref 
    1.19 +end
    1.20 +
    1.21 +structure AM_SML : AM_SML = struct
    1.22 +
    1.23 +open AbstractMachine;
    1.24 +
    1.25 +val dump_output = Unsynchronized.ref (NONE: string option)
    1.26 +
    1.27 +type program = term Inttab.table * (term -> term)
    1.28 +
    1.29 +val saved_result = Unsynchronized.ref (NONE:(string*term)option)
    1.30 +
    1.31 +fun save_result r = (saved_result := SOME r)
    1.32 +
    1.33 +val list_nth = List.nth
    1.34 +
    1.35 +val compiled_rewriter = Unsynchronized.ref (NONE:(term -> term)Option.option)
    1.36 +
    1.37 +fun set_compiled_rewriter r = (compiled_rewriter := SOME r)
    1.38 +
    1.39 +fun count_patternvars PVar = 1
    1.40 +  | count_patternvars (PConst (_, ps)) =
    1.41 +      List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps
    1.42 +
    1.43 +fun update_arity arity code a = 
    1.44 +    (case Inttab.lookup arity code of
    1.45 +         NONE => Inttab.update_new (code, a) arity
    1.46 +       | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity)
    1.47 +
    1.48 +(* We have to find out the maximal arity of each constant *)
    1.49 +fun collect_pattern_arity PVar arity = arity
    1.50 +  | collect_pattern_arity (PConst (c, args)) arity = fold collect_pattern_arity args (update_arity arity c (length args))
    1.51 +
    1.52 +(* We also need to find out the maximal toplevel arity of each function constant *)
    1.53 +fun collect_pattern_toplevel_arity PVar arity = raise Compile "internal error: collect_pattern_toplevel_arity"
    1.54 +  | collect_pattern_toplevel_arity (PConst (c, args)) arity = update_arity arity c (length args)
    1.55 +
    1.56 +local
    1.57 +fun collect applevel (Var _) arity = arity
    1.58 +  | collect applevel (Const c) arity = update_arity arity c applevel
    1.59 +  | collect applevel (Abs m) arity = collect 0 m arity
    1.60 +  | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity)
    1.61 +in
    1.62 +fun collect_term_arity t arity = collect 0 t arity
    1.63 +end
    1.64 +
    1.65 +fun collect_guard_arity (Guard (a,b)) arity  = collect_term_arity b (collect_term_arity a arity)
    1.66 +
    1.67 +
    1.68 +fun rep n x = if n < 0 then raise Compile "internal error: rep" else if n = 0 then [] else x::(rep (n-1) x)
    1.69 +
    1.70 +fun beta (Const c) = Const c
    1.71 +  | beta (Var i) = Var i
    1.72 +  | beta (App (Abs m, b)) = beta (unlift 0 (subst 0 m (lift 0 b)))
    1.73 +  | beta (App (a, b)) = 
    1.74 +    (case beta a of
    1.75 +         Abs m => beta (App (Abs m, b))
    1.76 +       | a => App (a, beta b))
    1.77 +  | beta (Abs m) = Abs (beta m)
    1.78 +  | beta (Computed t) = Computed t
    1.79 +and subst x (Const c) t = Const c
    1.80 +  | subst x (Var i) t = if i = x then t else Var i
    1.81 +  | subst x (App (a,b)) t = App (subst x a t, subst x b t)
    1.82 +  | subst x (Abs m) t = Abs (subst (x+1) m (lift 0 t))
    1.83 +and lift level (Const c) = Const c
    1.84 +  | lift level (App (a,b)) = App (lift level a, lift level b)
    1.85 +  | lift level (Var i) = if i < level then Var i else Var (i+1)
    1.86 +  | lift level (Abs m) = Abs (lift (level + 1) m)
    1.87 +and unlift level (Const c) = Const c
    1.88 +  | unlift level (App (a, b)) = App (unlift level a, unlift level b)
    1.89 +  | unlift level (Abs m) = Abs (unlift (level+1) m)
    1.90 +  | unlift level (Var i) = if i < level then Var i else Var (i-1)
    1.91 +
    1.92 +fun nlift level n (Var m) = if m < level then Var m else Var (m+n) 
    1.93 +  | nlift level n (Const c) = Const c
    1.94 +  | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b)
    1.95 +  | nlift level n (Abs b) = Abs (nlift (level+1) n b)
    1.96 +
    1.97 +fun subst_const (c, t) (Const c') = if c = c' then t else Const c'
    1.98 +  | subst_const _ (Var i) = Var i
    1.99 +  | subst_const ct (App (a, b)) = App (subst_const ct a, subst_const ct b)
   1.100 +  | subst_const ct (Abs m) = Abs (subst_const ct m)
   1.101 +
   1.102 +(* Remove all rules that are just parameterless rewrites. This is necessary because SML does not allow functions with no parameters. *)
   1.103 +fun inline_rules rules =
   1.104 +  let
   1.105 +    fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b
   1.106 +      | term_contains_const c (Abs m) = term_contains_const c m
   1.107 +      | term_contains_const c (Var _) = false
   1.108 +      | term_contains_const c (Const c') = (c = c')
   1.109 +    fun find_rewrite [] = NONE
   1.110 +      | find_rewrite ((prems, PConst (c, []), r) :: _) = 
   1.111 +          if check_freevars 0 r then 
   1.112 +            if term_contains_const c r then 
   1.113 +              raise Compile "parameterless rewrite is caught in cycle"
   1.114 +            else if not (null prems) then
   1.115 +              raise Compile "parameterless rewrite may not be guarded"
   1.116 +            else
   1.117 +              SOME (c, r) 
   1.118 +          else raise Compile "unbound variable on right hand side or guards of rule"
   1.119 +      | find_rewrite (_ :: rules) = find_rewrite rules
   1.120 +    fun remove_rewrite _ [] = []
   1.121 +      | remove_rewrite (cr as (c, r)) ((rule as (prems', PConst (c', args), r')) :: rules) = 
   1.122 +          if c = c' then 
   1.123 +            if null args andalso r = r' andalso null prems' then remove_rewrite cr rules 
   1.124 +            else raise Compile "incompatible parameterless rewrites found"
   1.125 +          else
   1.126 +            rule :: remove_rewrite cr rules
   1.127 +      | remove_rewrite cr (r :: rs) = r :: remove_rewrite cr rs
   1.128 +    fun pattern_contains_const c (PConst (c', args)) = c = c' orelse exists (pattern_contains_const c) args
   1.129 +      | pattern_contains_const c (PVar) = false
   1.130 +    fun inline_rewrite (ct as (c, _)) (prems, p, r) = 
   1.131 +        if pattern_contains_const c p then 
   1.132 +          raise Compile "parameterless rewrite cannot be used in pattern"
   1.133 +        else (map (fn (Guard (a, b)) => Guard (subst_const ct a, subst_const ct b)) prems, p, subst_const ct r)
   1.134 +    fun inline inlined rules =
   1.135 +      case find_rewrite rules of 
   1.136 +          NONE => (Inttab.make inlined, rules)
   1.137 +        | SOME ct => 
   1.138 +            let
   1.139 +              val rules = map (inline_rewrite ct) (remove_rewrite ct rules)
   1.140 +              val inlined = ct :: (map o apsnd) (subst_const ct) inlined
   1.141 +            in inline inlined rules end
   1.142 +  in
   1.143 +    inline [] rules
   1.144 +  end
   1.145 +
   1.146 +
   1.147 +(*
   1.148 +   Calculate the arity, the toplevel_arity, and adjust rules so that all toplevel pattern constants have maximal arity.
   1.149 +   Also beta reduce the adjusted right hand side of a rule.   
   1.150 +*)
   1.151 +fun adjust_rules rules = 
   1.152 +    let
   1.153 +        val arity = fold (fn (prems, p, t) => fn arity => fold collect_guard_arity prems (collect_term_arity t (collect_pattern_arity p arity))) rules Inttab.empty
   1.154 +        val toplevel_arity = fold (fn (_, p, _) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty
   1.155 +        fun arity_of c = the (Inttab.lookup arity c)
   1.156 +        fun test_pattern PVar = ()
   1.157 +          | test_pattern (PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else (map test_pattern args; ())
   1.158 +        fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable")
   1.159 +          | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters")
   1.160 +          | adjust_rule (rule as (prems, p as PConst (c, args),t)) = 
   1.161 +            let
   1.162 +                val patternvars_counted = count_patternvars p
   1.163 +                fun check_fv t = check_freevars patternvars_counted t
   1.164 +                val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else () 
   1.165 +                val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else () 
   1.166 +                val _ = map test_pattern args           
   1.167 +                val len = length args
   1.168 +                val arity = arity_of c
   1.169 +                val lift = nlift 0
   1.170 +                fun addapps_tm n t = if n=0 then t else addapps_tm (n-1) (App (t, Var (n-1)))
   1.171 +                fun adjust_term n t = addapps_tm n (lift n t)
   1.172 +                fun adjust_guard n (Guard (a,b)) = Guard (lift n a, lift n b)
   1.173 +            in
   1.174 +                if len = arity then
   1.175 +                    rule
   1.176 +                else if arity >= len then  
   1.177 +                    (map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t)
   1.178 +                else (raise Compile "internal error in adjust_rule")
   1.179 +            end
   1.180 +        fun beta_rule (prems, p, t) = ((prems, p, beta t) handle Match => raise Compile "beta_rule")
   1.181 +    in
   1.182 +        (arity, toplevel_arity, map (beta_rule o adjust_rule) rules)
   1.183 +    end             
   1.184 +
   1.185 +fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count =
   1.186 +let
   1.187 +    fun str x = string_of_int x
   1.188 +    fun protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s
   1.189 +    val module_prefix = (case module of NONE => "" | SOME s => s^".")                                                                                     
   1.190 +    fun print_apps d f [] = f
   1.191 +      | print_apps d f (a::args) = print_apps d (module_prefix^"app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args
   1.192 +    and print_call d (App (a, b)) args = print_call d a (b::args) 
   1.193 +      | print_call d (Const c) args = 
   1.194 +        (case arity_of c of 
   1.195 +             NONE => print_apps d (module_prefix^"Const "^(str c)) args 
   1.196 +           | SOME 0 => module_prefix^"C"^(str c)
   1.197 +           | SOME a =>
   1.198 +             let
   1.199 +                 val len = length args
   1.200 +             in
   1.201 +                 if a <= len then 
   1.202 +                     let
   1.203 +                         val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a)
   1.204 +                         val _ = if strict_a > a then raise Compile "strict" else ()
   1.205 +                         val s = module_prefix^"c"^(str c)^(implode (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, strict_a))))
   1.206 +                         val s = s^(implode (map (fn t => " (fn () => "^print_term d t^")") (List.drop (List.take (args, a), strict_a))))
   1.207 +                     in
   1.208 +                         print_apps d s (List.drop (args, a))
   1.209 +                     end
   1.210 +                 else 
   1.211 +                     let
   1.212 +                         fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n - 1)))
   1.213 +                         fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t)
   1.214 +                         fun append_args [] t = t
   1.215 +                           | append_args (c::cs) t = append_args cs (App (t, c))
   1.216 +                     in
   1.217 +                         print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c)))))
   1.218 +                     end
   1.219 +             end)
   1.220 +      | print_call d t args = print_apps d (print_term d t) args
   1.221 +    and print_term d (Var x) = 
   1.222 +        if x < d then 
   1.223 +            "b"^(str (d-x-1)) 
   1.224 +        else 
   1.225 +            let
   1.226 +                val n = pattern_var_count - (x-d) - 1
   1.227 +                val x = "x"^(str n)
   1.228 +            in
   1.229 +                if n < pattern_var_count - pattern_lazy_var_count then 
   1.230 +                    x
   1.231 +                else 
   1.232 +                    "("^x^" ())"
   1.233 +            end                                                         
   1.234 +      | print_term d (Abs c) = module_prefix^"Abs (fn b"^(str d)^" => "^(print_term (d + 1) c)^")"
   1.235 +      | print_term d t = print_call d t []
   1.236 +in
   1.237 +    print_term 0 
   1.238 +end
   1.239 +
   1.240 +fun section n = if n = 0 then [] else (section (n-1))@[n-1]
   1.241 +
   1.242 +fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) = 
   1.243 +    let 
   1.244 +        fun str x = string_of_int x                  
   1.245 +        fun print_pattern top n PVar = (n+1, "x"^(str n))
   1.246 +          | print_pattern top n (PConst (c, [])) = (n, (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else ""))
   1.247 +          | print_pattern top n (PConst (c, args)) = 
   1.248 +            let
   1.249 +                val f = (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "")
   1.250 +                val (n, s) = print_pattern_list 0 top (n, f) args
   1.251 +            in
   1.252 +                (n, s)
   1.253 +            end
   1.254 +        and print_pattern_list' counter top (n,p) [] = if top then (n,p) else (n,p^")")
   1.255 +          | print_pattern_list' counter top (n, p) (t::ts) = 
   1.256 +            let
   1.257 +                val (n, t) = print_pattern false n t
   1.258 +            in
   1.259 +                print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^", "^t) ts
   1.260 +            end 
   1.261 +        and print_pattern_list counter top (n, p) (t::ts) = 
   1.262 +            let
   1.263 +                val (n, t) = print_pattern false n t
   1.264 +            in
   1.265 +                print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^" ("^t) ts
   1.266 +            end
   1.267 +        val c = (case p of PConst (c, _) => c | _ => raise Match)
   1.268 +        val (n, pattern) = print_pattern true 0 p
   1.269 +        val lazy_vars = the (arity_of c) - the (toplevel_arity_of c)
   1.270 +        fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm
   1.271 +        fun print_guard (Guard (a,b)) = "term_eq ("^(print_tm a)^") ("^(print_tm b)^")"
   1.272 +        val else_branch = "c"^(str c)^"_"^(str (gnum+1))^(implode (map (fn i => " a"^(str i)) (section (the (arity_of c)))))
   1.273 +        fun print_guards t [] = print_tm t
   1.274 +          | print_guards t (g::gs) = "if ("^(print_guard g)^")"^(implode (map (fn g => " andalso ("^(print_guard g)^")") gs))^" then ("^(print_tm t)^") else "^else_branch
   1.275 +    in
   1.276 +        (if null guards then gnum else gnum+1, pattern^" = "^(print_guards t guards))
   1.277 +    end
   1.278 +
   1.279 +fun group_rules rules =
   1.280 +    let
   1.281 +        fun add_rule (r as (_, PConst (c,_), _)) groups =
   1.282 +            let
   1.283 +                val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs)
   1.284 +            in
   1.285 +                Inttab.update (c, r::rs) groups
   1.286 +            end
   1.287 +          | add_rule _ _ = raise Compile "internal error group_rules"
   1.288 +    in
   1.289 +        fold_rev add_rule rules Inttab.empty
   1.290 +    end
   1.291 +
   1.292 +fun sml_prog name code rules = 
   1.293 +    let
   1.294 +        val buffer = Unsynchronized.ref ""
   1.295 +        fun write s = (buffer := (!buffer)^s)
   1.296 +        fun writeln s = (write s; write "\n")
   1.297 +        fun writelist [] = ()
   1.298 +          | writelist (s::ss) = (writeln s; writelist ss)
   1.299 +        fun str i = string_of_int i
   1.300 +        val (inlinetab, rules) = inline_rules rules
   1.301 +        val (arity, toplevel_arity, rules) = adjust_rules rules
   1.302 +        val rules = group_rules rules
   1.303 +        val constants = Inttab.keys arity
   1.304 +        fun arity_of c = Inttab.lookup arity c
   1.305 +        fun toplevel_arity_of c = Inttab.lookup toplevel_arity c
   1.306 +        fun rep_str s n = implode (rep n s)
   1.307 +        fun indexed s n = s^(str n)
   1.308 +        fun string_of_tuple [] = ""
   1.309 +          | string_of_tuple (x::xs) = "("^x^(implode (map (fn s => ", "^s) xs))^")"
   1.310 +        fun string_of_args [] = ""
   1.311 +          | string_of_args (x::xs) = x^(implode (map (fn s => " "^s) xs))
   1.312 +        fun default_case gnum c = 
   1.313 +            let
   1.314 +                val leftargs = implode (map (indexed " x") (section (the (arity_of c))))
   1.315 +                val rightargs = section (the (arity_of c))
   1.316 +                val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa)
   1.317 +                val xs = map (fn n => if n < strict_args then "x"^(str n) else "x"^(str n)^"()") rightargs
   1.318 +                val right = (indexed "C" c)^" "^(string_of_tuple xs)
   1.319 +                val message = "(\"unresolved lazy call: " ^ string_of_int c ^ "\")"
   1.320 +                val right = if strict_args < the (arity_of c) then "raise AM_SML.Run "^message else right               
   1.321 +            in
   1.322 +                (indexed "c" c)^(if gnum > 0 then "_"^(str gnum) else "")^leftargs^" = "^right
   1.323 +            end
   1.324 +
   1.325 +        fun eval_rules c = 
   1.326 +            let
   1.327 +                val arity = the (arity_of c)
   1.328 +                val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa)
   1.329 +                fun eval_rule n = 
   1.330 +                    let
   1.331 +                        val sc = string_of_int c
   1.332 +                        val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section n) ("AbstractMachine.Const "^sc)
   1.333 +                        fun arg i = 
   1.334 +                            let
   1.335 +                                val x = indexed "x" i
   1.336 +                                val x = if i < n then "(eval bounds "^x^")" else x
   1.337 +                                val x = if i < strict_arity then x else "(fn () => "^x^")"
   1.338 +                            in
   1.339 +                                x
   1.340 +                            end
   1.341 +                        val right = "c"^sc^" "^(string_of_args (map arg (section arity)))
   1.342 +                        val right = fold_rev (fn i => fn s => "Abs (fn "^(indexed "x" i)^" => "^s^")") (List.drop (section arity, n)) right             
   1.343 +                        val right = if arity > 0 then right else "C"^sc
   1.344 +                    in
   1.345 +                        "  | eval bounds ("^left^") = "^right
   1.346 +                    end
   1.347 +            in
   1.348 +                map eval_rule (rev (section (arity + 1)))
   1.349 +            end
   1.350 +
   1.351 +        fun convert_computed_rules (c: int) : string list = 
   1.352 +            let
   1.353 +                val arity = the (arity_of c)
   1.354 +                fun eval_rule () = 
   1.355 +                    let
   1.356 +                        val sc = string_of_int c
   1.357 +                        val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section arity) ("AbstractMachine.Const "^sc)
   1.358 +                        fun arg i = "(convert_computed "^(indexed "x" i)^")" 
   1.359 +                        val right = "C"^sc^" "^(string_of_tuple (map arg (section arity)))              
   1.360 +                        val right = if arity > 0 then right else "C"^sc
   1.361 +                    in
   1.362 +                        "  | convert_computed ("^left^") = "^right
   1.363 +                    end
   1.364 +            in
   1.365 +                [eval_rule ()]
   1.366 +            end
   1.367 +        
   1.368 +        fun mk_constr_type_args n = if n > 0 then " of Term "^(rep_str " * Term" (n-1)) else ""
   1.369 +        val _ = writelist [                   
   1.370 +                "structure "^name^" = struct",
   1.371 +                "",
   1.372 +                "datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)",
   1.373 +                "         "^(implode (map (fn c => " | C"^(str c)^(mk_constr_type_args (the (arity_of c)))) constants)),
   1.374 +                ""]
   1.375 +        fun make_constr c argprefix = "(C"^(str c)^" "^(string_of_tuple (map (fn i => argprefix^(str i)) (section (the (arity_of c)))))^")"
   1.376 +        fun make_term_eq c = "  | term_eq "^(make_constr c "a")^" "^(make_constr c "b")^" = "^
   1.377 +                             (case the (arity_of c) of 
   1.378 +                                  0 => "true"
   1.379 +                                | n => 
   1.380 +                                  let 
   1.381 +                                      val eqs = map (fn i => "term_eq a"^(str i)^" b"^(str i)) (section n)
   1.382 +                                      val (eq, eqs) = (List.hd eqs, map (fn s => " andalso "^s) (List.tl eqs))
   1.383 +                                  in
   1.384 +                                      eq^(implode eqs)
   1.385 +                                  end)
   1.386 +        val _ = writelist [
   1.387 +                "fun term_eq (Const c1) (Const c2) = (c1 = c2)",
   1.388 +                "  | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"]
   1.389 +        val _ = writelist (map make_term_eq constants)          
   1.390 +        val _ = writelist [
   1.391 +                "  | term_eq _ _ = false",
   1.392 +                "" 
   1.393 +                ] 
   1.394 +        val _ = writelist [
   1.395 +                "fun app (Abs a) b = a b",
   1.396 +                "  | app a b = App (a, b)",
   1.397 +                ""]     
   1.398 +        fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else [])
   1.399 +        fun writefundecl [] = () 
   1.400 +          | writefundecl (x::xs) = writelist ((("and "^x)::(map (fn s => "  | "^s) xs)))
   1.401 +        fun list_group c = (case Inttab.lookup rules c of 
   1.402 +                                NONE => [defcase 0 c]
   1.403 +                              | SOME rs => 
   1.404 +                                let
   1.405 +                                    val rs = 
   1.406 +                                        fold
   1.407 +                                            (fn r => 
   1.408 +                                             fn rs =>
   1.409 +                                                let 
   1.410 +                                                    val (gnum, l, rs) = 
   1.411 +                                                        (case rs of 
   1.412 +                                                             [] => (0, [], []) 
   1.413 +                                                           | (gnum, l)::rs => (gnum, l, rs))
   1.414 +                                                    val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r 
   1.415 +                                                in 
   1.416 +                                                    if gnum' = gnum then 
   1.417 +                                                        (gnum, r::l)::rs
   1.418 +                                                    else
   1.419 +                                                        let
   1.420 +                                                            val args = implode (map (fn i => " a"^(str i)) (section (the (arity_of c))))
   1.421 +                                                            fun gnumc g = if g > 0 then "c"^(str c)^"_"^(str g)^args else "c"^(str c)^args
   1.422 +                                                            val s = gnumc (gnum) ^ " = " ^ gnumc (gnum') 
   1.423 +                                                        in
   1.424 +                                                            (gnum', [])::(gnum, s::r::l)::rs
   1.425 +                                                        end
   1.426 +                                                end)
   1.427 +                                        rs []
   1.428 +                                    val rs = (case rs of [] => [(0,defcase 0 c)] | (gnum,l)::rs => (gnum, (defcase gnum c)@l)::rs)
   1.429 +                                in
   1.430 +                                    rev (map (fn z => rev (snd z)) rs)
   1.431 +                                end)
   1.432 +        val _ = map (fn z => (map writefundecl z; writeln "")) (map list_group constants)
   1.433 +        val _ = writelist [
   1.434 +                "fun convert (Const i) = AM_SML.Const i",
   1.435 +                "  | convert (App (a, b)) = AM_SML.App (convert a, convert b)",
   1.436 +                "  | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""]  
   1.437 +        fun make_convert c = 
   1.438 +            let
   1.439 +                val args = map (indexed "a") (section (the (arity_of c)))
   1.440 +                val leftargs = 
   1.441 +                    case args of
   1.442 +                        [] => ""
   1.443 +                      | (x::xs) => "("^x^(implode (map (fn s => ", "^s) xs))^")"
   1.444 +                val args = map (indexed "convert a") (section (the (arity_of c)))
   1.445 +                val right = fold (fn x => fn s => "AM_SML.App ("^s^", "^x^")") args ("AM_SML.Const "^(str c))
   1.446 +            in
   1.447 +                "  | convert (C"^(str c)^" "^leftargs^") = "^right
   1.448 +            end                 
   1.449 +        val _ = writelist (map make_convert constants)
   1.450 +        val _ = writelist [
   1.451 +                "",
   1.452 +                "fun convert_computed (AbstractMachine.Abs b) = raise AM_SML.Run \"no abstraction in convert_computed allowed\"",
   1.453 +                "  | convert_computed (AbstractMachine.Var i) = raise AM_SML.Run \"no bound variables in convert_computed allowed\""]
   1.454 +        val _ = map (writelist o convert_computed_rules) constants
   1.455 +        val _ = writelist [
   1.456 +                "  | convert_computed (AbstractMachine.Const c) = Const c",
   1.457 +                "  | convert_computed (AbstractMachine.App (a, b)) = App (convert_computed a, convert_computed b)",
   1.458 +                "  | convert_computed (AbstractMachine.Computed a) = raise AM_SML.Run \"no nesting in convert_computed allowed\""] 
   1.459 +        val _ = writelist [
   1.460 +                "",
   1.461 +                "fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)",
   1.462 +                "  | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth (bounds, i)"]
   1.463 +        val _ = map (writelist o eval_rules) constants
   1.464 +        val _ = writelist [
   1.465 +                "  | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)",
   1.466 +                "  | eval bounds (AbstractMachine.Const c) = Const c",
   1.467 +                "  | eval bounds (AbstractMachine.Computed t) = convert_computed t"]                
   1.468 +        val _ = writelist [             
   1.469 +                "",
   1.470 +                "fun export term = AM_SML.save_result (\""^code^"\", convert term)",
   1.471 +                "",
   1.472 +                "val _ = AM_SML.set_compiled_rewriter (fn t => (convert (eval [] t)))",
   1.473 +                "",
   1.474 +                "end"]
   1.475 +    in
   1.476 +        (inlinetab, !buffer)
   1.477 +    end
   1.478 +
   1.479 +val guid_counter = Unsynchronized.ref 0
   1.480 +fun get_guid () = 
   1.481 +    let
   1.482 +        val c = !guid_counter
   1.483 +        val _ = guid_counter := !guid_counter + 1
   1.484 +    in
   1.485 +        string_of_int (Time.toMicroseconds (Time.now ())) ^ string_of_int c
   1.486 +    end
   1.487 +
   1.488 +
   1.489 +fun writeTextFile name s = File.write (Path.explode name) s
   1.490 +
   1.491 +fun use_source src = use_text ML_Env.local_context (1, "") false src
   1.492 +    
   1.493 +fun compile rules = 
   1.494 +    let
   1.495 +        val guid = get_guid ()
   1.496 +        val code = Real.toString (random ())
   1.497 +        val name = "AMSML_"^guid
   1.498 +        val (inlinetab, source) = sml_prog name code rules
   1.499 +        val _ = case !dump_output of NONE => () | SOME p => writeTextFile p source
   1.500 +        val _ = compiled_rewriter := NONE
   1.501 +        val _ = use_source source
   1.502 +    in
   1.503 +        case !compiled_rewriter of 
   1.504 +            NONE => raise Compile "broken link to compiled function"
   1.505 +          | SOME compiled_fun => (inlinetab, compiled_fun)
   1.506 +    end
   1.507 +
   1.508 +fun run (inlinetab, compiled_fun) t = 
   1.509 +    let 
   1.510 +        val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms")
   1.511 +        fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t)
   1.512 +          | inline (Var i) = Var i
   1.513 +          | inline (App (a, b)) = App (inline a, inline b)
   1.514 +          | inline (Abs m) = Abs (inline m)
   1.515 +          | inline (Computed t) = Computed t
   1.516 +    in
   1.517 +        compiled_fun (beta (inline t))
   1.518 +    end 
   1.519 +
   1.520 +end