src/Tools/Compute_Oracle/compute.ML
author haftmann
Sat, 15 Sep 2007 19:27:35 +0200
changeset 24584 01e83ffa6c54
parent 24137 8d7896398147
child 24654 329f1b4d9d16
permissions -rw-r--r--
fixed title
     1 (*  Title:      Tools/Compute_Oracle/compute.ML
     2     ID:         $Id$
     3     Author:     Steven Obua
     4 *)
     5 
     6 signature COMPUTE = sig
     7 
     8     type computer
     9 
    10     datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML
    11 
    12     exception Make of string
    13     val make : machine -> theory -> thm list -> computer
    14 
    15     exception Compute of string
    16     val compute : computer -> (int -> string) -> cterm -> term
    17     val theory_of : computer -> theory
    18     val hyps_of : computer -> term list
    19     val shyps_of : computer -> sort list
    20 
    21     val rewrite_param : computer -> (int -> string) -> cterm -> thm
    22     val rewrite : computer -> cterm -> thm
    23 
    24     val discard : computer -> unit
    25 
    26     val setup : theory -> theory
    27 
    28 end
    29 
    30 structure Compute :> COMPUTE = struct
    31 
    32 datatype machine = BARRAS | BARRAS_COMPILED | HASKELL | SML	 
    33 
    34 (* Terms are mapped to integer codes *)
    35 structure Encode :> 
    36 sig
    37     type encoding
    38     val empty : encoding
    39     val insert : term -> encoding -> int * encoding
    40     val lookup_code : term -> encoding -> int option
    41     val lookup_term : int -> encoding -> term option					
    42     val remove_code : int -> encoding -> encoding
    43     val remove_term : term -> encoding -> encoding
    44     val fold : ((term * int) -> 'a -> 'a) -> encoding -> 'a -> 'a    
    45 end 
    46 = 
    47 struct
    48 
    49 type encoding = int * (int Termtab.table) * (term Inttab.table)
    50 
    51 val empty = (0, Termtab.empty, Inttab.empty)
    52 
    53 fun insert t (e as (count, term2int, int2term)) = 
    54     (case Termtab.lookup term2int t of
    55 	 NONE => (count, (count+1, Termtab.update_new (t, count) term2int, Inttab.update_new (count, t) int2term))
    56        | SOME code => (code, e))
    57 
    58 fun lookup_code t (_, term2int, _) = Termtab.lookup term2int t
    59 
    60 fun lookup_term c (_, _, int2term) = Inttab.lookup int2term c
    61 
    62 fun remove_code c (e as (count, term2int, int2term)) = 
    63     (case lookup_term c e of NONE => e | SOME t => (count, Termtab.delete t term2int, Inttab.delete c int2term))
    64 
    65 fun remove_term t (e as (count, term2int, int2term)) = 
    66     (case lookup_code t e of NONE => e | SOME c => (count, Termtab.delete t term2int, Inttab.delete c int2term))
    67 
    68 fun fold f (_, term2int, _) = Termtab.fold f term2int 
    69 
    70 end
    71 
    72 
    73 exception Make of string;
    74 exception Compute of string;
    75 
    76 local
    77     fun make_constant t ty encoding = 
    78 	let 
    79 	    val (code, encoding) = Encode.insert t encoding 
    80 	in 
    81 	    (encoding, AbstractMachine.Const code)
    82 	end
    83 in
    84 
    85 fun remove_types encoding t =
    86     case t of 
    87 	Var (_, ty) => make_constant t ty encoding
    88       | Free (_, ty) => make_constant t ty encoding
    89       | Const (_, ty) => make_constant t ty encoding
    90       | Abs (_, ty, t') => 
    91 	let val (encoding, t'') = remove_types encoding t' in
    92 	    (encoding, AbstractMachine.Abs t'')
    93 	end
    94       | a $ b => 
    95 	let
    96 	    val (encoding, a) = remove_types encoding a
    97 	    val (encoding, b) = remove_types encoding b
    98 	in
    99 	    (encoding, AbstractMachine.App (a,b))
   100 	end
   101       | Bound b => (encoding, AbstractMachine.Var b)
   102 end
   103     
   104 local
   105     fun type_of (Free (_, ty)) = ty
   106       | type_of (Const (_, ty)) = ty
   107       | type_of (Var (_, ty)) = ty
   108       | type_of _ = sys_error "infer_types: type_of error"
   109 in
   110 fun infer_types naming encoding =
   111     let
   112         fun infer_types _ bounds _ (AbstractMachine.Var v) = (Bound v, List.nth (bounds, v))
   113 	  | infer_types _ bounds _ (AbstractMachine.Const code) = 
   114 	    let
   115 		val c = the (Encode.lookup_term code encoding)
   116 	    in
   117 		(c, type_of c)
   118 	    end
   119 	  | infer_types level bounds _ (AbstractMachine.App (a, b)) = 
   120 	    let
   121 		val (a, aty) = infer_types level bounds NONE a
   122 		val (adom, arange) =
   123                     case aty of
   124                         Type ("fun", [dom, range]) => (dom, range)
   125                       | _ => sys_error "infer_types: function type expected"
   126                 val (b, bty) = infer_types level bounds (SOME adom) b
   127 	    in
   128 		(a $ b, arange)
   129 	    end
   130           | infer_types level bounds (SOME (ty as Type ("fun", [dom, range]))) (AbstractMachine.Abs m) =
   131             let
   132                 val (m, _) = infer_types (level+1) (dom::bounds) (SOME range) m
   133             in
   134                 (Abs (naming level, dom, m), ty)
   135             end
   136           | infer_types _ _ NONE (AbstractMachine.Abs m) = sys_error "infer_types: cannot infer type of abstraction"
   137 
   138         fun infer ty term =
   139             let
   140                 val (term', _) = infer_types 0 [] (SOME ty) term
   141             in
   142                 term'
   143             end
   144     in
   145         infer
   146     end
   147 end
   148 
   149 datatype prog = 
   150 	 ProgBarras of AM_Interpreter.program 
   151        | ProgBarrasC of AM_Compiler.program
   152        | ProgHaskell of AM_GHC.program
   153        | ProgSML of AM_SML.program
   154 
   155 structure Sorttab = TableFun(type key = sort val ord = Term.sort_ord)
   156 
   157 datatype computer = Computer of theory_ref * Encode.encoding * term list * unit Sorttab.table * prog
   158 
   159 datatype cthm = ComputeThm of term list * sort list * term
   160 
   161 fun thm2cthm th = 
   162     let
   163 	val {hyps, prop, tpairs, shyps, ...} = Thm.rep_thm th
   164 	val _ = if not (null tpairs) then raise Make "theorems may not contain tpairs" else ()
   165     in
   166 	ComputeThm (hyps, shyps, prop)
   167     end
   168 
(* make machine thy raw_ths: builds a computer from the rewrite theorems
   raw_ths.

   Every theorem is transferred to thy, stripped of sort hypotheses via
   Thm.strip_shyps and turned into a rule (guards, pattern, right-hand
   side) for the abstract machine; the rules are then compiled by the
   backend selected through machine.  The hypotheses and sort hypotheses
   of all theorems are collected and stored in the resulting computer.

   Raises Make if a theorem has tpairs, is not a meta-level equation, has
   a guard that is not a meta-level equation, or has a left-hand side that
   is not an admissible pattern (see make_pattern below). *)
fun make machine thy raw_ths =
    let
	fun transfer (x:thm) = Thm.transfer thy x
	val ths = map (thm2cthm o Thm.strip_shyps o transfer) raw_ths

        (* Converts one theorem into a rule, extending the accumulated
           encoding and the tables of hypotheses and sort hypotheses. *)
        fun thm2rule (encoding, hyptable, shyptable) th =
            let
		val (ComputeThm (hyps, shyps, prop)) = th
		val hyptable = fold (fn h => Termtab.update (h, ())) hyps hyptable
		val shyptable = fold (fn sh => Sorttab.update (sh, ())) shyps shyptable
		(* premises act as guards, the conclusion must be an equation a == b *)
		val (prems, prop) = (Logic.strip_imp_prems prop, Logic.strip_imp_concl prop)
                val (a, b) = Logic.dest_equals prop
                  handle TERM _ => raise (Make "theorems must be meta-level equations (with optional guards)")
		val a = Envir.eta_contract a
		val b = Envir.eta_contract b
		val prems = map Envir.eta_contract prems

                (* the encoding is threaded through: left-hand side first,
                   then right-hand side, then the guards *)
                val (encoding, left) = remove_types encoding a     
		val (encoding, right) = remove_types encoding b  
                (* a guard t1 == t2 becomes AbstractMachine.Guard of the two encoded sides *)
                fun remove_types_of_guard encoding g = 
		    (let
			 val (t1, t2) = Logic.dest_equals g 
			 val (encoding, t1) = remove_types encoding t1
			 val (encoding, t2) = remove_types encoding t2
		     in
			 (encoding, AbstractMachine.Guard (t1, t2))
		     end handle TERM _ => raise (Make "guards must be meta-level equations"))
                (* fold_rev visits the guards from the last to the first, while
                   consing keeps the resulting list in the original order *)
                val (encoding, prems) = fold_rev (fn p => fn (encoding, ps) => let val (e, p) = remove_types_of_guard encoding p in (e, p::ps) end) prems (encoding, [])
                
                (* Turns the encoded left-hand side into a pattern.  n counts the
                   pattern variables met so far and vars maps the code of each
                   schematic variable to its number; a code whose term is a Var
                   becomes PVar, any other code a PConst, and applications collect
                   their arguments in the PConst at their head. *)
                fun make_pattern encoding n vars (var as AbstractMachine.Abs _) =
		    raise (Make "no lambda abstractions allowed in pattern")
		  | make_pattern encoding n vars (var as AbstractMachine.Var _) =
		    raise (Make "no bound variables allowed in pattern")
		  | make_pattern encoding n vars (AbstractMachine.Const code) =
		    (case the (Encode.lookup_term code encoding) of
			 Var _ => ((n+1, Inttab.update_new (code, n) vars, AbstractMachine.PVar)
				   handle Inttab.DUP _ => raise (Make "no duplicate variable in pattern allowed"))
		       | _ => (n, vars, AbstractMachine.PConst (code, [])))
                  | make_pattern encoding n vars (AbstractMachine.App (a, b)) =
                    let
                        val (n, vars, pa) = make_pattern encoding n vars a
                        val (n, vars, pb) = make_pattern encoding n vars b
                    in
                        case pa of
                            AbstractMachine.PVar =>
                              raise (Make "patterns may not start with a variable")
                          | AbstractMachine.PConst (c, args) =>
                              (n, vars, AbstractMachine.PConst (c, args@[pb]))
                    end

                (* Principally, a check should be made here to see if the (meta-) hyps contain any of the variables of the rule.
                   As it is, all variables of the rule are schematic, and there are no schematic variables in meta-hyps, therefore
                   this check can be left out. *)

                val (vcount, vars, pattern) = make_pattern encoding 0 Inttab.empty left
                (* a left-hand side that is just a variable is rejected as well *)
                val _ = (case pattern of
                             AbstractMachine.PVar =>
                             raise (Make "patterns may not start with a variable")
                         (*  | AbstractMachine.PConst (_, []) => 
			     (print th; raise (Make "no parameter rewrite found"))*)
			   | _ => ())

                (* finally, provide a function for renaming the
                   pattern bound variables on the right hand side *)

                (* A constant whose code is pattern variable number n becomes the
                   bound-variable index vcount-n-1, shifted by level, the number
                   of abstractions entered so far; everything else is kept. *)
                fun rename level vars (var as AbstractMachine.Var _) = var
		  | rename level vars (c as AbstractMachine.Const code) =
		    (case Inttab.lookup vars code of 
			 NONE => c 
		       | SOME n => AbstractMachine.Var (vcount-n-1+level))
                  | rename level vars (AbstractMachine.App (a, b)) =
                    AbstractMachine.App (rename level vars a, rename level vars b)
                  | rename level vars (AbstractMachine.Abs m) =
                    AbstractMachine.Abs (rename (level+1) vars m)
		    
		fun rename_guard (AbstractMachine.Guard (a,b)) = 
		    AbstractMachine.Guard (rename 0 vars a, rename 0 vars b)
            in
                ((encoding, hyptable, shyptable), (map rename_guard prems, pattern, rename 0 vars right))
            end

        (* theorems are processed from the last to the first; consing keeps
           the rules in the order of ths *)
        val ((encoding, hyptable, shyptable), rules) =
          fold_rev (fn th => fn (encoding_hyptable, rules) =>
            let
              val (encoding_hyptable, rule) = thm2rule encoding_hyptable th
            in (encoding_hyptable, rule::rules) end)
          ths ((Encode.empty, Termtab.empty, Sorttab.empty), [])

        (* compile the rules with the backend chosen by the caller *)
        val prog = 
	    case machine of 
		BARRAS => ProgBarras (AM_Interpreter.compile rules)
	      | BARRAS_COMPILED => ProgBarrasC (AM_Compiler.compile rules)
	      | HASKELL => ProgHaskell (AM_GHC.compile rules)
	      | SML => ProgSML (AM_SML.compile rules)

(*	val _ = print (Encode.fold (fn x => fn s => x::s) encoding [])*)

        (* true iff Sign.witness_sorts returns a non-empty result for s *)
        fun has_witness s = not (null (Sign.witness_sorts thy [] [s]))

	(* only sort hypotheses without a witness are kept *)
	val shyptable = fold Sorttab.delete (filter has_witness (Sorttab.keys (shyptable))) shyptable

    in Computer (Theory.check_thy thy, encoding, Termtab.keys hyptable, shyptable, prog) end
   271 
   272 (*fun timeit f =
   273     let
   274 	val t1 = Time.toMicroseconds (Time.now ())
   275 	val x = f ()
   276 	val t2 = Time.toMicroseconds (Time.now ())
   277 	val _ = writeln ("### time = "^(Real.toString ((Real.fromLargeInt t2 - Real.fromLargeInt t1)/(1000000.0)))^"s")
   278     in
   279 	x
   280     end*)
   281 
   282 fun report s f = f () (*writeln s; timeit f*)
   283 
   284 fun compute (Computer (rthy, encoding, hyps, shyptable, prog)) naming ct =
   285     let
   286 	fun run (ProgBarras p) = AM_Interpreter.run p
   287 	  | run (ProgBarrasC p) = AM_Compiler.run p
   288 	  | run (ProgHaskell p) = AM_GHC.run p
   289 	  | run (ProgSML p) = AM_SML.run p	    
   290         val {t=t, T=ty, thy=ctthy, ...} = rep_cterm ct
   291         val thy = Theory.merge (Theory.deref rthy, ctthy)
   292         val (encoding, t) = report "remove_types" (fn () => remove_types encoding t)
   293         val t = report "run" (fn () => run prog t)
   294         val t = report "infer_types" (fn () => infer_types naming encoding ty t)
   295     in
   296         t
   297     end
   298 
   299 fun discard (Computer (rthy, encoding, hyps, shyptable, prog)) = 
   300     (case prog of
   301 	 ProgBarras p => AM_Interpreter.discard p
   302        | ProgBarrasC p => AM_Compiler.discard p
   303        | ProgHaskell p => AM_GHC.discard p
   304        | ProgSML p => AM_SML.discard p)
   305 
   306 fun theory_of (Computer (rthy, _, _,_,_)) = Theory.deref rthy
   307 fun hyps_of (Computer (_, _, hyps, _, _)) = hyps
   308 fun shyps_of (Computer (_, _, _, shyptable, _)) = Sorttab.keys (shyptable)
   309 fun shyptab_of (Computer (_, _, _, shyptable, _)) = shyptable
   310 
   311 fun default_naming i = "v_" ^ Int.toString i
   312 
   313 exception Param of computer * (int -> string) * cterm;
   314 
   315 fun rewrite_param r n ct =
   316     let 
   317 	val thy = theory_of_cterm ct 
   318 	val th = timeit (fn () => invoke_oracle_i thy "Compute_Oracle.compute" (thy, Param (r, n, ct)))
   319 	val hyps = map (fn h => assume (cterm_of thy h)) (hyps_of r)
   320     in
   321 	fold (fn h => fn p => implies_elim p h) hyps th 
   322     end
   323 
   324 (*fun rewrite_param r n ct =
   325     let	
   326 	val hyps = hyps_of r
   327 	val shyps = shyps_of r
   328 	val thy = theory_of_cterm ct
   329 	val _ = Theory.assert_super (theory_of r) thy
   330 	val t' = timeit (fn () => compute r n ct)
   331 	val eq = Logic.mk_equals (term_of ct, t')
   332     in
   333 	Thm.unchecked_oracle thy "Compute.compute" (eq, hyps, shyps)
   334     end*)
   335 
   336 fun rewrite r ct = rewrite_param r default_naming ct
   337 
   338 (* theory setup *)
   339 
   340 fun compute_oracle (thy, Param (r, naming, ct)) =
   341     let
   342         val _ = Theory.assert_super (theory_of r) thy
   343         val t' = compute r naming ct
   344 	val eq = Logic.mk_equals (term_of ct, t')
   345 	val hyps = hyps_of r
   346 	val shyptab = shyptab_of r
   347 	fun delete s shyptab = Sorttab.delete s shyptab handle Sorttab.UNDEF _ => shyptab
   348 	fun delete_term t shyptab = fold delete (Sorts.insert_term t []) shyptab
   349 	val shyps = if Sorttab.is_empty shyptab then [] else Sorttab.keys (fold delete_term (eq::hyps) shyptab)
   350 	val _ = if not (null shyps) then raise Compute ("dangling sort hypotheses: "^(makestring shyps)) else ()
   351     in
   352         fold_rev (fn hyp => fn p => Logic.mk_implies (hyp, p)) hyps eq
   353     end
   354   | compute_oracle _ = raise Match
   355 
   356 
   357 val setup = (fn thy => (writeln "install oracle"; Theory.add_oracle ("compute", compute_oracle) thy))
   358 
   359 (*val _ = Context.add_setup (Theory.add_oracle ("compute", compute_oracle))*)
   360 
   361 end
   362