src/HOL/Matrix/Compute_Oracle/am_compiler.ML
changeset 47859 9f492f5b0cec
parent 47858 15ce93dfe6da
child 47861 67cf9a6308f3
child 47862 88b0a8052c75
     1.1 --- a/src/HOL/Matrix/Compute_Oracle/am_compiler.ML	Sat Mar 17 12:26:19 2012 +0100
     1.2 +++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.3 @@ -1,208 +0,0 @@
     1.4 -(*  Title:      HOL/Matrix/Compute_Oracle/am_compiler.ML
     1.5 -    Author:     Steven Obua
     1.6 -*)
     1.7 -
     1.8 -signature COMPILING_AM = 
     1.9 -sig
    1.10 -  include ABSTRACT_MACHINE
    1.11 -
    1.12 -  val set_compiled_rewriter : (term -> term) -> unit
    1.13 -  val list_nth : 'a list * int -> 'a
    1.14 -  val list_map : ('a -> 'b) -> 'a list -> 'b list
    1.15 -end
    1.16 -
    1.17 -structure AM_Compiler : COMPILING_AM = struct
    1.18 -
    1.19 -val list_nth = List.nth;
    1.20 -val list_map = map;
    1.21 -
    1.22 -open AbstractMachine;
    1.23 -
    1.24 -val compiled_rewriter = Unsynchronized.ref (NONE:(term -> term)Option.option)
    1.25 -
    1.26 -fun set_compiled_rewriter r = (compiled_rewriter := SOME r)
    1.27 -
    1.28 -type program = (term -> term)
    1.29 -
    1.30 -fun count_patternvars PVar = 1
    1.31 -  | count_patternvars (PConst (_, ps)) =
    1.32 -      List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps
    1.33 -
    1.34 -fun print_rule (p, t) = 
    1.35 -    let
    1.36 -        fun str x = string_of_int x
    1.37 -        fun print_pattern n PVar = (n+1, "x"^(str n))
    1.38 -          | print_pattern n (PConst (c, [])) = (n, "c"^(str c))
    1.39 -          | print_pattern n (PConst (c, args)) = 
    1.40 -            let
    1.41 -                val h = print_pattern n (PConst (c,[]))
    1.42 -            in
    1.43 -                print_pattern_list h args
    1.44 -            end
    1.45 -        and print_pattern_list r [] = r
    1.46 -          | print_pattern_list (n, p) (t::ts) = 
    1.47 -            let
    1.48 -                val (n, t) = print_pattern n t
    1.49 -            in
    1.50 -                print_pattern_list (n, "App ("^p^", "^t^")") ts
    1.51 -            end
    1.52 -
    1.53 -        val (n, pattern) = print_pattern 0 p
    1.54 -        val pattern =
    1.55 -            if exists_string Symbol.is_ascii_blank pattern then "(" ^ pattern ^")"
    1.56 -            else pattern
    1.57 -        
    1.58 -        fun print_term d (Var x) = "Var " ^ str x
    1.59 -          | print_term d (Const c) = "c" ^ str c
    1.60 -          | print_term d (App (a,b)) = "App (" ^ print_term d a ^ ", " ^ print_term d b ^ ")"
    1.61 -          | print_term d (Abs c) = "Abs (" ^ print_term (d + 1) c ^ ")"
    1.62 -          | print_term d (Computed c) = print_term d c
    1.63 -
    1.64 -        fun listvars n = if n = 0 then "x0" else "x"^(str n)^", "^(listvars (n-1))
    1.65 -
    1.66 -        val term = print_term 0 t
    1.67 -        val term =
    1.68 -            if n > 0 then "Closure (["^(listvars (n-1))^"], "^term^")"
    1.69 -            else "Closure ([], "^term^")"
    1.70 -                           
    1.71 -    in
    1.72 -        "  | weak_reduce (false, stack, "^pattern^") = Continue (false, stack, "^term^")"
    1.73 -    end
    1.74 -
    1.75 -fun constants_of PVar = []
    1.76 -  | constants_of (PConst (c, ps)) = c :: maps constants_of ps
    1.77 -
    1.78 -fun constants_of_term (Var _) = []
    1.79 -  | constants_of_term (Abs m) = constants_of_term m
    1.80 -  | constants_of_term (App (a,b)) = (constants_of_term a)@(constants_of_term b)
    1.81 -  | constants_of_term (Const c) = [c]
    1.82 -  | constants_of_term (Computed c) = constants_of_term c
    1.83 -    
    1.84 -fun load_rules sname name prog = 
    1.85 -    let
    1.86 -        val buffer = Unsynchronized.ref ""
    1.87 -        fun write s = (buffer := (!buffer)^s)
    1.88 -        fun writeln s = (write s; write "\n")
    1.89 -        fun writelist [] = ()
    1.90 -          | writelist (s::ss) = (writeln s; writelist ss)
    1.91 -        fun str i = string_of_int i
    1.92 -        val _ = writelist [
    1.93 -                "structure "^name^" = struct",
    1.94 -                "",
    1.95 -                "datatype term = Dummy | App of term * term | Abs of term | Var of int | Const of int | Closure of term list * term"]
    1.96 -        val constants = distinct (op =) (maps (fn (p, r) => ((constants_of p)@(constants_of_term r))) prog)
    1.97 -        val _ = map (fn x => write (" | c"^(str x))) constants
    1.98 -        val _ = writelist [
    1.99 -                "",
   1.100 -                "datatype stack = SEmpty | SAppL of term * stack | SAppR of term * stack | SAbs of stack",
   1.101 -                "",
   1.102 -                "type state = bool * stack * term",
   1.103 -                "",
   1.104 -                "datatype loopstate = Continue of state | Stop of stack * term",
   1.105 -                "",
   1.106 -                "fun proj_C (Continue s) = s",
   1.107 -                "  | proj_C _ = raise Match",
   1.108 -                "",
   1.109 -                "fun proj_S (Stop s) = s",
   1.110 -                "  | proj_S _ = raise Match",
   1.111 -                "",
   1.112 -                "fun cont (Continue _) = true",
   1.113 -                "  | cont _ = false",
   1.114 -                "",
   1.115 -                "fun do_reduction reduce p =",
   1.116 -                "    let",
   1.117 -                "       val s = Unsynchronized.ref (Continue p)",
   1.118 -                "       val _ = while cont (!s) do (s := reduce (proj_C (!s)))",
   1.119 -                "   in",
   1.120 -                "       proj_S (!s)",
   1.121 -                "   end",
   1.122 -                ""]
   1.123 -
   1.124 -        val _ = writelist [
   1.125 -                "fun weak_reduce (false, stack, Closure (e, App (a, b))) = Continue (false, SAppL (Closure (e, b), stack), Closure (e, a))",
   1.126 -                "  | weak_reduce (false, SAppL (b, stack), Closure (e, Abs m)) = Continue (false, stack, Closure (b::e, m))",
   1.127 -                "  | weak_reduce (false, stack, c as Closure (e, Abs m)) = Continue (true, stack, c)",
   1.128 -                "  | weak_reduce (false, stack, Closure (e, Var n)) = Continue (false, stack, case "^sname^".list_nth (e, n) of Dummy => Var n | r => r)",
   1.129 -                "  | weak_reduce (false, stack, Closure (e, c)) = Continue (false, stack, c)"]
   1.130 -        val _ = writelist (map print_rule prog)
   1.131 -        val _ = writelist [
   1.132 -                "  | weak_reduce (false, stack, clos) = Continue (true, stack, clos)",
   1.133 -                "  | weak_reduce (true, SAppR (a, stack), b) = Continue (false, stack, App (a,b))",
   1.134 -                "  | weak_reduce (true, s as (SAppL (b, stack)), a) = Continue (false, SAppR (a, stack), b)",
   1.135 -                "  | weak_reduce (true, stack, c) = Stop (stack, c)",
   1.136 -                "",
   1.137 -                "fun strong_reduce (false, stack, Closure (e, Abs m)) =",
   1.138 -                "    let",
   1.139 -                "        val (stack', wnf) = do_reduction weak_reduce (false, SEmpty, Closure (Dummy::e, m))",
   1.140 -                "    in",
   1.141 -                "        case stack' of",
   1.142 -                "            SEmpty => Continue (false, SAbs stack, wnf)",
   1.143 -                "          | _ => raise ("^sname^".Run \"internal error in strong: weak failed\")",
   1.144 -                "    end",              
   1.145 -                "  | strong_reduce (false, stack, clos as (App (u, v))) = Continue (false, SAppL (v, stack), u)",
   1.146 -                "  | strong_reduce (false, stack, clos) = Continue (true, stack, clos)",
   1.147 -                "  | strong_reduce (true, SAbs stack, m) = Continue (false, stack, Abs m)",
   1.148 -                "  | strong_reduce (true, SAppL (b, stack), a) = Continue (false, SAppR (a, stack), b)",
   1.149 -                "  | strong_reduce (true, SAppR (a, stack), b) = Continue (true, stack, App (a, b))",
   1.150 -                "  | strong_reduce (true, stack, clos) = Stop (stack, clos)",
   1.151 -                ""]
   1.152 -        
   1.153 -        val ic = "(case c of "^(implode (map (fn c => (str c)^" => c"^(str c)^" | ") constants))^" _ => Const c)"                                                       
   1.154 -        val _ = writelist [
   1.155 -                "fun importTerm ("^sname^".Var x) = Var x",
   1.156 -                "  | importTerm ("^sname^".Const c) =  "^ic,
   1.157 -                "  | importTerm ("^sname^".App (a, b)) = App (importTerm a, importTerm b)",
   1.158 -                "  | importTerm ("^sname^".Abs m) = Abs (importTerm m)",
   1.159 -                ""]
   1.160 -
   1.161 -        fun ec c = "  | exportTerm c"^(str c)^" = "^sname^".Const "^(str c)
   1.162 -        val _ = writelist [
   1.163 -                "fun exportTerm (Var x) = "^sname^".Var x",
   1.164 -                "  | exportTerm (Const c) = "^sname^".Const c",
   1.165 -                "  | exportTerm (App (a,b)) = "^sname^".App (exportTerm a, exportTerm b)",
   1.166 -                "  | exportTerm (Abs m) = "^sname^".Abs (exportTerm m)",
   1.167 -                "  | exportTerm (Closure (closlist, clos)) = raise ("^sname^".Run \"internal error, cannot export Closure\")",
   1.168 -                "  | exportTerm Dummy = raise ("^sname^".Run \"internal error, cannot export Dummy\")"]
   1.169 -        val _ = writelist (map ec constants)
   1.170 -                
   1.171 -        val _ = writelist [
   1.172 -                "",
   1.173 -                "fun rewrite t = ",
   1.174 -                "    let",
   1.175 -                "      val (stack, wnf) = do_reduction weak_reduce (false, SEmpty, Closure ([], importTerm t))",
   1.176 -                "    in",
   1.177 -                "      case stack of ",
   1.178 -                "           SEmpty => (case do_reduction strong_reduce (false, SEmpty, wnf) of",
   1.179 -                "                          (SEmpty, snf) => exportTerm snf",
   1.180 -                "                        | _ => raise ("^sname^".Run \"internal error in rewrite: strong failed\"))",
   1.181 -                "         | _ => (raise ("^sname^".Run \"internal error in rewrite: weak failed\"))",
   1.182 -                "    end",
   1.183 -                "",
   1.184 -                "val _ = "^sname^".set_compiled_rewriter rewrite",
   1.185 -                "",
   1.186 -                "end;"]
   1.187 -
   1.188 -    in
   1.189 -        compiled_rewriter := NONE;      
   1.190 -        use_text ML_Env.local_context (1, "") false (!buffer);
   1.191 -        case !compiled_rewriter of 
   1.192 -            NONE => raise (Compile "cannot communicate with compiled function")
   1.193 -          | SOME r => (compiled_rewriter := NONE; r)
   1.194 -    end 
   1.195 -
   1.196 -fun compile eqs = 
   1.197 -    let
   1.198 -        val _ = if exists (fn (a,_,_) => not (null a)) eqs then raise Compile ("cannot deal with guards") else ()
   1.199 -        val eqs = map (fn (_,b,c) => (b,c)) eqs
   1.200 -        fun check (p, r) = if check_freevars (count_patternvars p) r then () else raise Compile ("unbound variables in rule") 
   1.201 -        val _ = map (fn (p, r) => 
   1.202 -                  (check (p, r); 
   1.203 -                   case p of PVar => raise (Compile "pattern is just a variable") | _ => ())) eqs
   1.204 -    in
   1.205 -        load_rules "AM_Compiler" "AM_compiled_code" eqs
   1.206 -    end 
   1.207 -
   1.208 -fun run prog t = prog t
   1.209 -
   1.210 -end
   1.211 -