src/HOL/Tools/recfun_codegen.ML
author haftmann
Tue, 17 Jan 2006 16:36:57 +0100
changeset 18702 7dc7dcd63224
parent 18220 43cf5767f992
child 18708 4b3dadb4fe33
permissions -rw-r--r--
substantial improvements in code generator
     1 (*  Title:      HOL/Tools/recfun_codegen.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen
     4 
     5 Code generator for recursive functions.
     6 *)
     7 
     8 signature RECFUN_CODEGEN =
     9 sig
    10   val add: string option -> theory attribute
    11   val del: theory attribute
    12   val get_rec_equations: theory -> string * typ -> (term list * term) list * typ
    13   val get_thm_equations: theory -> string * typ -> (thm list * typ) option
    14   val setup: (theory -> theory) list
    15 end;
    16 
    17 structure RecfunCodegen : RECFUN_CODEGEN =
    18 struct
    19 
    20 open Codegen;
    21 
    22 structure CodegenData = TheoryDataFun
    23 (struct
    24   val name = "HOL/recfun_codegen";
    25   type T = (thm * string option) list Symtab.table;
    26   val empty = Symtab.empty;
    27   val copy = I;
    28   val extend = I;
    29   fun merge _ = Symtab.merge_multi' (Drule.eq_thm_prop o pairself fst);
    30   fun print _ _ = ();
    31 end);
    32 
    33 
    34 val dest_eqn = HOLogic.dest_eq o HOLogic.dest_Trueprop;
    35 val lhs_of = fst o dest_eqn o prop_of;
    36 val const_of = dest_Const o head_of o fst o dest_eqn;
    37 
    38 fun warn thm = warning ("RecfunCodegen: Not a proper equation:\n" ^
    39   string_of_thm thm);
    40 
    41 fun add optmod (p as (thy, thm)) =
    42   let
    43     val tab = CodegenData.get thy;
    44     val (s, _) = const_of (prop_of thm);
    45   in
    46     if Pattern.pattern (lhs_of thm) then
    47       (CodegenData.put (Symtab.update_multi (s, (thm, optmod)) tab) thy, thm)
    48     else (warn thm; p)
    49   end handle TERM _ => (warn thm; p);
    50 
    51 fun del (p as (thy, thm)) =
    52   let
    53     val tab = CodegenData.get thy;
    54     val (s, _) = const_of (prop_of thm);
    55   in case Symtab.lookup tab s of
    56       NONE => p
    57     | SOME thms => (CodegenData.put (Symtab.update (s,
    58         gen_rem (eq_thm o apfst fst) (thms, thm)) tab) thy, thm)
    59   end handle TERM _ => (warn thm; p);
    60 
    61 fun del_redundant thy eqs [] = eqs
    62   | del_redundant thy eqs (eq :: eqs') =
    63     let
    64       val matches = curry
    65         (Pattern.matches thy o pairself (lhs_of o fst))
    66     in del_redundant thy (eq :: eqs) (filter_out (matches eq) eqs') end;
    67 
    68 fun get_equations thy defs (s, T) =
    69   (case Symtab.lookup (CodegenData.get thy) s of
    70      NONE => ([], "")
    71    | SOME thms => 
    72        let val thms' = del_redundant thy []
    73          (List.filter (fn (thm, _) => is_instance thy T
    74            (snd (const_of (prop_of thm)))) thms)
    75        in if null thms' then ([], "")
    76          else (preprocess thy (map fst thms'),
    77            case snd (snd (split_last thms')) of
    78                NONE => (case get_defn thy defs s T of
    79                    NONE => thyname_of_const s thy
    80                  | SOME ((_, (thyname, _)), _) => thyname)
    81              | SOME thyname => thyname)
    82        end);
    83 
    84 fun get_rec_equations thy (s, T) =
    85   Symtab.lookup (CodegenData.get thy) s
    86   |> Option.map (fn thms => 
    87        List.filter (fn (thm, _) => is_instance thy T ((snd o const_of o prop_of) thm)) thms
    88        |> del_redundant thy [])
    89   |> Option.mapPartial (fn thms => if null thms then NONE else SOME thms)
    90   |> Option.map (fn thms =>
    91        (preprocess thy (map fst thms),
    92           (snd o const_of o prop_of o fst o hd) thms))
    93   |> the_default ([], dummyT)
    94   |> apfst (map prop_of)
    95   |> apfst (map (Codegen.rename_term #> HOLogic.dest_Trueprop #> HOLogic.dest_eq #> apfst (strip_comb #> snd)))
    96 
    97 fun get_thm_equations thy (c, ty) =
    98   Symtab.lookup (CodegenData.get thy) c
    99   |> Option.map (fn thms => 
   100        List.filter (fn (thm, _) => is_instance thy ty ((snd o const_of o prop_of) thm)) thms
   101        |> del_redundant thy [])
   102   |> Option.mapPartial (fn thms => if null thms then NONE else SOME thms)
   103   |> Option.map (fn thms =>
   104        (preprocess thy (map fst thms),
   105           (snd o const_of o prop_of o fst o hd) thms))
   106   |> (Option.map o apfst o map) (fn thm => thm RS HOL.eq_reflection);
   107 
   108 fun mk_suffix thy defs (s, T) = (case get_defn thy defs s T of
   109   SOME (_, SOME i) => " def" ^ string_of_int i | _ => "");
   110 
   111 exception EQN of string * typ * string;
   112 
   113 fun cycle g (xs, x) =
   114   if x mem xs then xs
   115   else Library.foldl (cycle g) (x :: xs, List.concat (Graph.find_paths (fst g) (x, x)));
   116 
   117 fun add_rec_funs thy defs gr dep eqs module =
   118   let
   119     fun dest_eq t = (fst (const_of t) ^ mk_suffix thy defs (const_of t),
   120       dest_eqn (rename_term t));
   121     val eqs' = map dest_eq eqs;
   122     val (dname, _) :: _ = eqs';
   123     val (s, T) = const_of (hd eqs);
   124 
   125     fun mk_fundef module fname prfx gr [] = (gr, [])
   126       | mk_fundef module fname prfx gr ((fname', (lhs, rhs)) :: xs) =
   127       let
   128         val (gr1, pl) = invoke_codegen thy defs dname module false (gr, lhs);
   129         val (gr2, pr) = invoke_codegen thy defs dname module false (gr1, rhs);
   130         val (gr3, rest) = mk_fundef module fname' "and " gr2 xs
   131       in
   132         (gr3, Pretty.blk (4, [Pretty.str (if fname = fname' then "  | " else prfx),
   133            pl, Pretty.str " =", Pretty.brk 1, pr]) :: rest)
   134       end;
   135 
   136     fun put_code module fundef = map_node dname
   137       (K (SOME (EQN ("", dummyT, dname)), module, Pretty.string_of (Pretty.blk (0,
   138       separate Pretty.fbrk fundef @ [Pretty.str ";"])) ^ "\n\n"));
   139 
   140   in
   141     (case try (get_node gr) dname of
   142        NONE =>
   143          let
   144            val gr1 = add_edge (dname, dep)
   145              (new_node (dname, (SOME (EQN (s, T, "")), module, "")) gr);
   146            val (gr2, fundef) = mk_fundef module "" "fun " gr1 eqs';
   147            val xs = cycle gr2 ([], dname);
   148            val cs = map (fn x => case get_node gr2 x of
   149                (SOME (EQN (s, T, _)), _, _) => (s, T)
   150              | _ => error ("RecfunCodegen: illegal cyclic dependencies:\n" ^
   151                 implode (separate ", " xs))) xs
   152          in (case xs of
   153              [_] => (put_code module fundef gr2, module)
   154            | _ =>
   155              if not (dep mem xs) then
   156                let
   157                  val thmss as (_, thyname) :: _ = map (get_equations thy defs) cs;
   158                  val module' = if_library thyname module;
   159                  val eqs'' = map (dest_eq o prop_of) (List.concat (map fst thmss));
   160                  val (gr3, fundef') = mk_fundef module' "" "fun "
   161                    (add_edge (dname, dep)
   162                      (foldr (uncurry new_node) (del_nodes xs gr2)
   163                        (map (fn k =>
   164                          (k, (SOME (EQN ("", dummyT, dname)), module', ""))) xs))) eqs''
   165                in (put_code module' fundef' gr3, module') end
   166              else (gr2, module))
   167          end
   168      | SOME (SOME (EQN (_, _, s)), module', _) =>
   169          (if s = "" then
   170             if dname = dep then gr else add_edge (dname, dep) gr
   171           else if s = dep then gr else add_edge (s, dep) gr,
   172           module'))
   173   end;
   174 
   175 fun recfun_codegen thy defs gr dep module brack t = (case strip_comb t of
   176     (Const (p as (s, T)), ts) => (case (get_equations thy defs p, get_assoc_code thy s T) of
   177        (([], _), _) => NONE
   178      | (_, SOME _) => NONE
   179      | ((eqns, thyname), NONE) =>
   180         let
   181           val module' = if_library thyname module;
   182           val (gr', ps) = foldl_map
   183             (invoke_codegen thy defs dep module true) (gr, ts);
   184           val suffix = mk_suffix thy defs p;
   185           val (gr'', module'') =
   186             add_rec_funs thy defs gr' dep (map prop_of eqns) module';
   187           val (gr''', fname) = mk_const_id module'' (s ^ suffix) gr''
   188         in
   189           SOME (gr''', mk_app brack (Pretty.str (mk_qual_id module fname)) ps)
   190         end)
   191   | _ => NONE);
   192 
   193 
   194 val setup = [
   195   CodegenData.init,
   196   add_codegen "recfun" recfun_codegen,
   197   add_attribute ""
   198     (   Args.del |-- Scan.succeed del
   199      || Scan.option (Args.$$$ "target" |-- Args.colon |-- Args.name) >> add),
   200   CodegenPackage.add_eqextr
   201     ("rec", fn thy => fn _ => get_thm_equations thy)
   202 ];
   203 
   204 end;