src/Tools/nbe.ML
author wenzelm
Tue, 09 Oct 2007 00:20:13 +0200
changeset 24920 2a45e400fdad
parent 24867 e5b55d7be9bb
child 25080 21d44e3aea4c
permissions -rw-r--r--
generic Syntax.pretty/string_of operations;
     1 (*  Title:      Tools/nbe.ML
     2     ID:         $Id$
     3     Authors:    Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen
     4 
     5 Normalization by evaluation, based on generic code generator.
     6 *)
     7 
     8 signature NBE =
     9 sig
    10   datatype Univ = 
    11       Const of string * Univ list            (*named (uninterpreted) constants*)
    12     | Free of string * Univ list
    13     | BVar of int * Univ list
    14     | Abs of (int * (Univ list -> Univ)) * Univ list;
    15   val free: string -> Univ list -> Univ       (*free (uninterpreted) variables*)
    16   val abs: int -> (Univ list -> Univ) -> Univ list -> Univ
    17                                             (*abstractions as functions*)
    18   val app: Univ -> Univ -> Univ              (*explicit application*)
    19 
    20   val univs_ref: (unit -> Univ list) ref 
    21   val lookup_fun: string -> Univ
    22 
    23   val norm_conv: cterm -> thm
    24   val norm_term: theory -> term -> term
    25 
    26   val trace: bool ref
    27   val setup: theory -> theory
    28 end;
    29 
    30 structure Nbe: NBE =
    31 struct
    32 
    33 (* generic non-sense *)
    34 
    35 val trace = ref false;
    36 fun tracing f x = if !trace then (Output.tracing (f x); x) else x;
    37 
    38 
    39 (** the semantical universe **)
    40 
    41 (*
    42    Functions are given by their semantical function value. To avoid
    43    trouble with the ML-type system, these functions have the most
    44    generic type, that is "Univ list -> Univ". The calling convention is
    45    that the arguments come as a list, the last argument first. In
    46    other words, a function call that usually would look like
    47 
    48    f x_1 x_2 ... x_n   or   f(x_1,x_2, ..., x_n)
    49 
    50    would be in our convention called as
    51 
    52               f [x_n,..,x_2,x_1]
    53 
    54    Moreover, to handle functions that are still waiting for some
    55    arguments, we additionally keep a list of the arguments collected
    56    so far and the number of arguments we are still waiting for.
    57 *)
    58 
    59 datatype Univ = 
    60     Const of string * Univ list        (*named (uninterpreted) constants*)
    61   | Free of string * Univ list         (*free variables*)
    62   | BVar of int * Univ list            (*bound named variables*)
    63   | Abs of (int * (Univ list -> Univ)) * Univ list
    64                                       (*abstractions as closures*);
    65 
    66 (* constructor functions *)
    67 
    68 val free = curry Free;
    69 fun abs n f ts = Abs ((n, f), ts);
    70 fun app (Abs ((1, f), xs)) x = f (x :: xs)
    71   | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
    72   | app (Const (name, args)) x = Const (name, x :: args)
    73   | app (Free (name, args)) x = Free (name, x :: args)
    74   | app (BVar (name, args)) x = BVar (name, x :: args);
    75 
    76 (* global functions store *)
    77 
    78 structure Nbe_Functions = CodeDataFun
    79 (
    80   type T = Univ Graph.T;
    81   val empty = Graph.empty;
    82   fun merge _ = Graph.merge (K true);
    83   fun purge _ NONE _ = Graph.empty
    84     | purge NONE _ _ = Graph.empty
    85     | purge (SOME thy) (SOME cs) gr = Graph.empty
    86         (*let
    87           val cs_exisiting =
    88             map_filter (CodeName.const_rev thy) (Graph.keys gr);
    89           val dels = (Graph.all_preds gr
    90               o map (CodeName.const thy)
    91               o filter (member (op =) cs_exisiting)
    92             ) cs;
    93         in Graph.del_nodes dels gr end*);
    94 );
    95 
    96 fun defined gr = can (Graph.get_node gr);
    97 
    98 (* sandbox communication *)
    99 
   100 val univs_ref = ref (fn () => [] : Univ list);
   101 
   102 local
   103 
   104 val gr_ref = ref NONE : Nbe_Functions.T option ref;
   105 
   106 in
   107 
   108 fun lookup_fun s = case ! gr_ref
   109  of NONE => error "compile_univs"
   110   | SOME gr => Graph.get_node gr s;
   111 
   112 fun compile_univs tab ([], _) = []
   113   | compile_univs tab (cs, raw_s) =
   114       let
   115         val _ = univs_ref := (fn () => []);
   116         val s = "Nbe.univs_ref := " ^ raw_s;
   117         val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
   118         val _ = gr_ref := SOME tab;
   119         val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
   120           Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
   121           (!trace) s;
   122         val _ = gr_ref := NONE;
   123         val univs = case !univs_ref () of [] => error "compile_univs" | univs => univs;
   124       in cs ~~ univs end;
   125 
   126 end; (*local*)
   127 
   128 
   129 (** assembling and compiling ML code from terms **)
   130 
   131 (* abstract ML syntax *)
   132 
   133 infix 9 `$` `$$`;
   134 fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";
   135 fun e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
   136 fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";
   137 
   138 fun ml_Val v s = "val " ^ v ^ " = " ^ s;
   139 fun ml_cases t cs =
   140   "(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")";
   141 fun ml_Let ds e = "let\n" ^ space_implode "\n" ds ^ " in " ^ e ^ " end";
   142 
   143 fun ml_list es = "[" ^ commas es ^ "]";
   144 
   145 val ml_delay = ml_abs "()"
   146 
   147 fun ml_fundefs ([(name, [([], e)])]) =
   148       "val " ^ name ^ " = " ^ e ^ "\n"
   149   | ml_fundefs (eqs :: eqss) =
   150       let
   151         fun fundef (name, eqs) =
   152           let
   153             fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
   154           in space_implode "\n  | " (map eqn eqs) end;
   155       in
   156         (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
   157         |> space_implode "\n"
   158         |> suffix "\n"
   159       end;
   160 
   161 (* nbe specific syntax *)
   162 
   163 local
   164   val prefix =          "Nbe.";
   165   val name_const =      prefix ^ "Const";
   166   val name_free =       prefix ^ "free";
   167   val name_abs =        prefix ^ "abs";
   168   val name_app =        prefix ^ "app";
   169   val name_lookup_fun = prefix ^ "lookup_fun";
   170 in
   171 
   172 fun nbe_const c ts = name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
   173 fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
   174 fun nbe_free v = name_free `$$` [ML_Syntax.print_string v, ml_list []];
   175 fun nbe_bound v = "v_" ^ v;
   176 
   177 fun nbe_apps e es =
   178   Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);
   179 
   180 fun nbe_abss 0 f = f `$` ml_list []
   181   | nbe_abss n f = name_abs `$$` [string_of_int n, f, ml_list []];
   182 
   183 fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c);
   184 
   185 val nbe_value = "value";
   186 
   187 end;
   188 
   189 open BasicCodeThingol;
   190 
   191 (* greetings to Tarski *)
   192 
   193 fun assemble_iterm thy is_fun num_args =
   194   let
   195     fun of_iterm t =
   196       let
   197         val (t', ts) = CodeThingol.unfold_app t
   198       in of_iapp t' (fold (cons o of_iterm) ts []) end
   199     and of_iconst c ts = case num_args c
   200      of SOME n => if n <= length ts
   201           then let val (args2, args1) = chop (length ts - n) ts
   202           in nbe_apps (nbe_fun c `$` ml_list args1) args2
   203           end else nbe_const c ts
   204       | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
   205           else nbe_const c ts
   206     and of_iapp (IConst (c, (dss, _))) ts = of_iconst c ts
   207       | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
   208       | of_iapp ((v, _) `|-> t) ts =
   209           nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
   210       | of_iapp (ICase (((t, _), cs), t0)) ts =
   211           nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
   212             @ [("_", of_iterm t0)])) ts
   213   in of_iterm end;
   214 
   215 fun assemble_fun thy is_fun num_args (c, eqns) =
   216   let
   217     val assemble_arg = assemble_iterm thy (K false) (K NONE);
   218     val assemble_rhs = assemble_iterm thy is_fun num_args;
   219     fun assemble_eqn (args, rhs) =
   220       ([ml_list (map assemble_arg (rev args))], assemble_rhs rhs);
   221     val default_params = map nbe_bound
   222       (Name.invent_list [] "a" ((the o num_args) c));
   223     val default_eqn = ([ml_list default_params], nbe_const c default_params);
   224   in map assemble_eqn eqns @ [default_eqn] end;
   225 
   226 fun assemble_eqnss thy is_fun ([], deps) = ([], "")
   227   | assemble_eqnss thy is_fun (eqnss, deps) =
   228       let
   229         val cs = map fst eqnss;
   230         val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss;
   231         val funs = fold (fold (CodeThingol.fold_constnames
   232           (insert (op =))) o map snd o snd) eqnss [];
   233         val bind_funs = map nbe_lookup (filter is_fun funs);
   234         val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
   235           (assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss);
   236         val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args)
   237           |> ml_delay;
   238       in (cs, ml_Let (bind_funs @ [bind_locals]) result) end;
   239 
   240 fun assemble_eval thy is_fun (((vs, ty), t), deps) =
   241   let
   242     val funs = CodeThingol.fold_constnames (insert (op =)) t [];
   243     val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
   244     val bind_funs = map nbe_lookup (filter is_fun funs);
   245     val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees)],
   246       assemble_iterm thy is_fun (K NONE) t)])];
   247     val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)]
   248       |> ml_delay;
   249   in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end;
   250 
   251 fun eqns_of_stmt ((_, CodeThingol.Fun (_, [])), _) =
   252       NONE
   253   | eqns_of_stmt ((name, CodeThingol.Fun (_, eqns)), deps) =
   254       SOME ((name, map fst eqns), deps)
   255   | eqns_of_stmt ((_, CodeThingol.Datatypecons _), _) =
   256       NONE
   257   | eqns_of_stmt ((_, CodeThingol.Datatype _), _) =
   258       NONE
   259   | eqns_of_stmt ((_, CodeThingol.Class _), _) =
   260       NONE
   261   | eqns_of_stmt ((_, CodeThingol.Classrel _), _) =
   262       NONE
   263   | eqns_of_stmt ((_, CodeThingol.Classparam _), _) =
   264       NONE
   265   | eqns_of_stmt ((_, CodeThingol.Classinst _), _) =
   266       NONE;
   267 
   268 fun compile_stmts thy is_fun =
   269   map_filter eqns_of_stmt
   270   #> split_list
   271   #> assemble_eqnss thy is_fun
   272   #> compile_univs (Nbe_Functions.get thy);
   273 
   274 fun eval_term thy is_fun =
   275   assemble_eval thy is_fun
   276   #> compile_univs (Nbe_Functions.get thy)
   277   #> the_single
   278   #> snd;
   279 
   280 
   281 (** compilation and evaluation **)
   282 
   283 (* ensure global functions *)
   284 
   285 fun ensure_funs thy code =
   286   let
   287     fun add_dep (name, dep) gr =
   288       if can (Graph.get_node gr) name andalso can (Graph.get_node gr) dep
   289       then Graph.add_edge (name, dep) gr else gr;
   290     fun compile' stmts gr =
   291       let
   292         val compiled = compile_stmts thy (defined gr) stmts;
   293         val names = map (fst o fst) stmts;
   294         val deps = maps snd stmts;
   295       in
   296         Nbe_Functions.change thy (fold Graph.new_node compiled
   297           #> fold (fn name => fold (curry add_dep name) deps) names)
   298       end;
   299     val nbe_gr = Nbe_Functions.get thy;
   300     val stmtss = rev (Graph.strong_conn code)
   301       |> (map o map_filter) (fn name => if defined nbe_gr name
   302            then NONE
   303            else SOME ((name, Graph.get_node code name), Graph.imm_succs code name))
   304       |> filter_out null
   305   in fold compile' stmtss nbe_gr end;
   306 
   307 (* reification *)
   308 
   309 fun term_of_univ thy t =
   310   let
   311     fun of_apps bounds (t, ts) =
   312       fold_map (of_univ bounds) ts
   313       #>> (fn ts' => list_comb (t, rev ts'))
   314     and of_univ bounds (Const (name, ts)) typidx =
   315           let
   316             val SOME c = CodeName.const_rev thy name;
   317             val T = Code.default_typ thy c;
   318             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
   319             val typidx' = typidx + maxidx_of_typ T' + 1;
   320           in of_apps bounds (Term.Const (c, T'), ts) typidx' end
   321       | of_univ bounds (Free (name, ts)) typidx =
   322           of_apps bounds (Term.Free (name, dummyT), ts) typidx
   323       | of_univ bounds (BVar (name, ts)) typidx =
   324           of_apps bounds (Bound (bounds - name - 1), ts) typidx
   325       | of_univ bounds (t as Abs _) typidx =
   326           typidx
   327           |> of_univ (bounds + 1) (app t (BVar (bounds, [])))
   328           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   329   in of_univ 0 t 0 |> fst end;
   330 
   331 (* evaluation with type reconstruction *)
   332 
   333 fun eval thy code t vs_ty_t deps =
   334   let
   335     val ty = type_of t;
   336     fun subst_Frees [] = I
   337       | subst_Frees inst =
   338           Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
   339                             | t => t);
   340     val anno_vars =
   341       subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
   342       #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
   343     fun constrain t =
   344       singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t);
   345     fun check_tvars t = if null (Term.term_tvars t) then t else
   346       error ("Illegal schematic type variables in normalized term: "
   347         ^ setmp show_types true (Sign.string_of_term thy) t);
   348   in
   349     (vs_ty_t, deps)
   350     |> eval_term thy (defined (ensure_funs thy code))
   351     |> term_of_univ thy
   352     |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
   353     |> anno_vars
   354     |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
   355     |> constrain
   356     |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
   357     |> check_tvars
   358     |> tracing (fn _ => "---\n")
   359   end;
   360 
   361 (* evaluation oracle *)
   362 
   363 exception Norm of CodeThingol.code * term
   364   * (CodeThingol.typscheme * CodeThingol.iterm) * string list;
   365 
   366 fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) =
   367   Logic.mk_equals (t, eval thy code t vs_ty_t deps);
   368 
   369 fun norm_invoke thy code t vs_ty_t deps =
   370   Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps));
   371   (*FIXME get rid of hardwired theory name*)
   372 
   373 fun norm_conv ct =
   374   let
   375     val thy = Thm.theory_of_cterm ct;
   376     fun conv code vs_ty_t deps ct =
   377       let
   378         val t = Thm.term_of ct;
   379       in norm_invoke thy code t vs_ty_t deps end;
   380   in CodePackage.eval_conv thy conv ct end;
   381 
   382 fun norm_term thy =
   383   let
   384     fun invoke code vs_ty_t deps t =
   385       eval thy code t vs_ty_t deps;
   386   in CodePackage.eval_term thy invoke #> Code.postprocess_term thy end;
   387 
   388 (* evaluation command *)
   389 
   390 fun norm_print_term ctxt modes t =
   391   let
   392     val thy = ProofContext.theory_of ctxt;
   393     val t' = norm_term thy t;
   394     val ty' = Term.type_of t';
   395     val p = PrintMode.with_modes modes (fn () =>
   396       Pretty.block [Pretty.quote (Syntax.pretty_term ctxt t'), Pretty.fbrk,
   397         Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt ty')]) ();
   398   in Pretty.writeln p end;
   399 
   400 
   401 (** Isar setup **)
   402 
   403 fun norm_print_term_cmd (modes, s) state =
   404   let val ctxt = Toplevel.context_of state
   405   in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;
   406 
   407 val setup = Theory.add_oracle ("norm", norm_oracle)
   408 
   409 local structure P = OuterParse and K = OuterKeyword in
   410 
   411 val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
   412 
   413 val _ =
   414   OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
   415     (opt_modes -- P.typ >> (Toplevel.keep o norm_print_term_cmd));
   416 
   417 end;
   418 
   419 end;