src/Tools/code/code_funcgr.ML
author haftmann
Fri, 10 Aug 2007 17:04:34 +0200
changeset 24219 e558fe311376
child 24283 8ca96f4e49cd
permissions -rw-r--r--
new structure for code generator modules
     1 (*  Title:      Tools/code/code_funcgr.ML
     2     ID:         $Id$
     3     Author:     Florian Haftmann, TU Muenchen
     4 
     5 Retrieving, normalizing and structuring defining equations in graph
     6 with explicit dependencies.
     7 *)
     8 
     9 signature CODE_FUNCGR =
    10 sig
    11   type T
    12   val timing: bool ref
    13   val funcs: T -> CodeUnit.const -> thm list
    14   val typ: T -> CodeUnit.const -> typ
    15   val deps: T -> CodeUnit.const list -> CodeUnit.const list list
    16   val all: T -> CodeUnit.const list
    17   val pretty: theory -> T -> Pretty.T
    18   structure Constgraph : GRAPH
    19 end
    20 
    21 signature CODE_FUNCGR_RETRIEVAL =
    22 sig
    23   type T (* = CODE_FUNCGR.T *)
    24   val make: theory -> CodeUnit.const list -> T
    25   val make_consts: theory -> CodeUnit.const list -> CodeUnit.const list * T
    26   val make_term: theory -> (T -> (thm -> thm) -> cterm -> thm -> 'a) -> cterm -> 'a * T
    27     (*FIXME drop make_term as soon as possible*)
    28   val eval_conv: theory -> (T -> cterm -> thm) -> cterm -> thm
    29   val intervene: theory -> T -> T
    30     (*FIXME drop intervene as soon as possible*)
    31 end;
    32 
    33 structure CodeFuncgr = (*signature is added later*)
    34 struct
    35 
    36 (** the graph type **)
    37 
    38 structure Constgraph = GraphFun (
    39   type key = CodeUnit.const;
    40   val ord = CodeUnit.const_ord;
    41 );
    42 
    43 type T = (typ * thm list) Constgraph.T;
    44 
    45 fun funcs funcgr =
    46   these o Option.map snd o try (Constgraph.get_node funcgr);
    47 
    48 fun typ funcgr =
    49   fst o Constgraph.get_node funcgr;
    50 
    51 fun deps funcgr cs =
    52   let
    53     val conn = Constgraph.strong_conn funcgr;
    54     val order = rev conn;
    55   in
    56     (map o filter) (member (op =) (Constgraph.all_succs funcgr cs)) order
    57     |> filter_out null
    58   end;
    59 
    60 fun all funcgr = Constgraph.keys funcgr;
    61 
    62 fun pretty thy funcgr =
    63   AList.make (snd o Constgraph.get_node funcgr) (Constgraph.keys funcgr)
    64   |> (map o apfst) (CodeUnit.string_of_const thy)
    65   |> sort (string_ord o pairself fst)
    66   |> map (fn (s, thms) =>
    67        (Pretty.block o Pretty.fbreaks) (
    68          Pretty.str s
    69          :: map Display.pretty_thm thms
    70        ))
    71   |> Pretty.chunks;
    72 
    73 
    74 (** generic combinators **)
    75 
    76 fun fold_consts f thms =
    77   thms
    78   |> maps (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of)
    79   |> (fold o fold_aterms) (fn Const c => f c | _ => I);
    80 
    81 fun consts_of (const, []) = []
    82   | consts_of (const, thms as thm :: _) = 
    83       let
    84         val thy = Thm.theory_of_thm thm;
    85         val is_refl = curry CodeUnit.eq_const const;
    86         fun the_const c = case try (CodeUnit.const_of_cexpr thy) c
    87          of SOME const => if is_refl const then I else insert CodeUnit.eq_const const
    88           | NONE => I
    89       in fold_consts the_const thms [] end;
    90 
    91 fun insts_of thy algebra c ty_decl ty =
    92   let
    93     val tys_decl = Sign.const_typargs thy (c, ty_decl);
    94     val tys = Sign.const_typargs thy (c, ty);
    95     fun class_relation (x, _) _ = x;
    96     fun type_constructor tyco xs class =
    97       (tyco, class) :: maps (maps fst) xs;
    98     fun type_variable (TVar (_, sort)) = map (pair []) sort
    99       | type_variable (TFree (_, sort)) = map (pair []) sort;
   100     fun mk_inst ty (TVar (_, sort)) = cons (ty, sort)
   101       | mk_inst ty (TFree (_, sort)) = cons (ty, sort)
   102       | mk_inst (Type (_, tys1)) (Type (_, tys2)) = fold2 mk_inst tys1 tys2;
   103     fun of_sort_deriv (ty, sort) =
   104       Sorts.of_sort_derivation (Sign.pp thy) algebra
   105         { class_relation = class_relation, type_constructor = type_constructor,
   106           type_variable = type_variable }
   107         (ty, sort)
   108   in
   109     flat (maps of_sort_deriv (fold2 mk_inst tys tys_decl []))
   110   end;
   111 
   112 fun drop_classes thy tfrees thm =
   113   let
   114     val (_, thm') = Thm.varifyT' [] thm;
   115     val tvars = Term.add_tvars (Thm.prop_of thm') [];
   116     val unconstr = map (Thm.ctyp_of thy o TVar) tvars;
   117     val instmap = map2 (fn (v_i, _) => fn (v, sort) => pairself (Thm.ctyp_of thy)
   118       (TVar (v_i, []), TFree (v, sort))) tvars tfrees;
   119   in
   120     thm'
   121     |> fold Thm.unconstrainT unconstr
   122     |> Thm.instantiate (instmap, [])
   123     |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
   124   end;
   125 
   126 
   127 (** graph algorithm **)
   128 
   129 val timing = ref false;
   130 
   131 local
   132 
   133 exception INVALID of CodeUnit.const list * string;
   134 
   135 fun resort_thms algebra tap_typ [] = []
   136   | resort_thms algebra tap_typ (thms as thm :: _) =
   137       let
   138         val thy = Thm.theory_of_thm thm;
   139         val cs = fold_consts (insert (op =)) thms [];
   140         fun match_const c (ty, ty_decl) =
   141           let
   142             val tys = CodeUnit.typargs thy (c, ty);
   143             val sorts = map (snd o dest_TVar) (CodeUnit.typargs thy (c, ty_decl));
   144           in fold2 (curry (CodeUnit.typ_sort_inst algebra)) tys sorts end;
   145         fun match (c_ty as (c, ty)) =
   146           case tap_typ c_ty
   147            of SOME ty_decl => match_const c (ty, ty_decl)
   148             | NONE => I;
   149         val tvars = fold match cs Vartab.empty;
   150       in map (CodeUnit.inst_thm tvars) thms end;
   151 
   152 fun resort_funcss thy algebra funcgr =
   153   let
   154     val typ_funcgr = try (fst o Constgraph.get_node funcgr o CodeUnit.const_of_cexpr thy);
   155     fun resort_dep (const, thms) = (const, resort_thms algebra typ_funcgr thms)
   156       handle Sorts.CLASS_ERROR e => raise INVALID ([const], Sorts.msg_class_error (Sign.pp thy) e
   157                     ^ ",\nfor constant " ^ CodeUnit.string_of_const thy const
   158                     ^ "\nin defining equations\n"
   159                     ^ (cat_lines o map string_of_thm) thms)
   160     fun resort_rec tap_typ (const, []) = (true, (const, []))
   161       | resort_rec tap_typ (const, thms as thm :: _) =
   162           let
   163             val (_, ty) = CodeUnit.head_func thm;
   164             val thms' as thm' :: _ = resort_thms algebra tap_typ thms
   165             val (_, ty') = CodeUnit.head_func thm';
   166           in (Sign.typ_equiv thy (ty, ty'), (const, thms')) end;
   167     fun resort_recs funcss =
   168       let
   169         fun tap_typ c_ty = case try (CodeUnit.const_of_cexpr thy) c_ty
   170          of SOME const => AList.lookup (CodeUnit.eq_const) funcss const
   171               |> these
   172               |> try hd
   173               |> Option.map (snd o CodeUnit.head_func)
   174           | NONE => NONE;
   175         val (unchangeds, funcss') = split_list (map (resort_rec tap_typ) funcss);
   176         val unchanged = fold (fn x => fn y => x andalso y) unchangeds true;
   177       in (unchanged, funcss') end;
   178     fun resort_rec_until funcss =
   179       let
   180         val (unchanged, funcss') = resort_recs funcss;
   181       in if unchanged then funcss' else resort_rec_until funcss' end;
   182   in map resort_dep #> resort_rec_until end;
   183 
   184 fun instances_of thy algebra insts =
   185   let
   186     val thy_classes = (#classes o Sorts.rep_algebra o Sign.classes_of) thy;
   187     fun all_classops tyco class =
   188       try (AxClass.params_of_class thy) class
   189       |> Option.map snd
   190       |> these
   191       |> map (fn (c, _) => (c, SOME tyco))
   192   in
   193     Symtab.empty
   194     |> fold (fn (tyco, class) =>
   195         Symtab.map_default (tyco, []) (insert (op =) class)) insts
   196     |> (fn tab => Symtab.fold (fn (tyco, classes) => append (maps (all_classops tyco)
   197          (Graph.all_succs thy_classes classes))) tab [])
   198   end;
   199 
   200 fun instances_of_consts thy algebra funcgr consts =
   201   let
   202     fun inst (cexpr as (c, ty)) = insts_of thy algebra c
   203       ((fst o Constgraph.get_node funcgr o CodeUnit.const_of_cexpr thy) cexpr)
   204       ty handle CLASS_ERROR => [];
   205   in
   206     []
   207     |> fold (fold (insert (op =)) o inst) consts
   208     |> instances_of thy algebra
   209   end;
   210 
   211 fun ensure_const' rewrites thy algebra funcgr const auxgr =
   212   if can (Constgraph.get_node funcgr) const
   213     then (NONE, auxgr)
   214   else if can (Constgraph.get_node auxgr) const
   215     then (SOME const, auxgr)
   216   else if is_some (Code.get_datatype_of_constr thy const) then
   217     auxgr
   218     |> Constgraph.new_node (const, [])
   219     |> pair (SOME const)
   220   else let
   221     val thms = Code.these_funcs thy const
   222       |> map (CodeUnit.rewrite_func (rewrites thy))
   223       |> CodeUnit.norm_args
   224       |> CodeUnit.norm_varnames CodeName.purify_tvar CodeName.purify_var;
   225     val rhs = consts_of (const, thms);
   226   in
   227     auxgr
   228     |> Constgraph.new_node (const, thms)
   229     |> fold_map (ensure_const rewrites thy algebra funcgr) rhs
   230     |-> (fn rhs' => fold (fn SOME const' => Constgraph.add_edge (const, const')
   231                            | NONE => I) rhs')
   232     |> pair (SOME const)
   233   end
   234 and ensure_const rewrites thy algebra funcgr const =
   235   let
   236     val timeap = if !timing
   237       then Output.timeap_msg ("time for " ^ CodeUnit.string_of_const thy const)
   238       else I;
   239   in timeap (ensure_const' rewrites thy algebra funcgr const) end;
   240 
   241 fun merge_funcss rewrites thy algebra raw_funcss funcgr =
   242   let
   243     val funcss = raw_funcss
   244       |> resort_funcss thy algebra funcgr
   245       |> filter_out (can (Constgraph.get_node funcgr) o fst);
   246     fun typ_func const [] = Code.default_typ thy const
   247       | typ_func (_, NONE) (thm :: _) = (snd o CodeUnit.head_func) thm
   248       | typ_func (const as (c, SOME tyco)) (thms as (thm :: _)) =
   249           let
   250             val (_, ty) = CodeUnit.head_func thm;
   251             val SOME class = AxClass.class_of_param thy c;
   252             val sorts_decl = Sorts.mg_domain algebra tyco [class];
   253             val tys = CodeUnit.typargs thy (c, ty);
   254             val sorts = map (snd o dest_TVar) tys;
   255           in if sorts = sorts_decl then ty
   256             else raise INVALID ([const], "Illegal instantation for class operation "
   257               ^ CodeUnit.string_of_const thy const
   258               ^ "\nin defining equations\n"
   259               ^ (cat_lines o map string_of_thm) thms)
   260           end;
   261     fun add_funcs (const, thms) =
   262       Constgraph.new_node (const, (typ_func const thms, thms));
   263     fun add_deps (funcs as (const, thms)) funcgr =
   264       let
   265         val deps = consts_of funcs;
   266         val insts = instances_of_consts thy algebra funcgr
   267           (fold_consts (insert (op =)) thms []);
   268       in
   269         funcgr
   270         |> ensure_consts' rewrites thy algebra insts
   271         |> fold (curry Constgraph.add_edge const) deps
   272         |> fold (curry Constgraph.add_edge const) insts
   273        end;
   274   in
   275     funcgr
   276     |> fold add_funcs funcss
   277     |> fold add_deps funcss
   278   end
   279 and ensure_consts' rewrites thy algebra cs funcgr =
   280   let
   281     val auxgr = Constgraph.empty
   282       |> fold (snd oo ensure_const rewrites thy algebra funcgr) cs;
   283   in
   284     funcgr
   285     |> fold (merge_funcss rewrites thy algebra)
   286          (map (AList.make (Constgraph.get_node auxgr))
   287          (rev (Constgraph.strong_conn auxgr)))
   288   end handle INVALID (cs', msg)
   289     => raise INVALID (fold (insert CodeUnit.eq_const) cs' cs, msg);
   290 
   291 fun ensure_consts rewrites thy consts funcgr =
   292   let
   293     val algebra = Code.coregular_algebra thy
   294   in ensure_consts' rewrites thy algebra consts funcgr
   295     handle INVALID (cs', msg) => error (msg ^ ",\nwhile preprocessing equations for constant(s) "
   296     ^ commas (map (CodeUnit.string_of_const thy) cs'))
   297   end;
   298 
   299 in
   300 
   301 (** retrieval interfaces **)
   302 
   303 val ensure_consts = ensure_consts;
   304 
   305 fun check_consts rewrites thy consts funcgr =
   306   let
   307     val algebra = Code.coregular_algebra thy;
   308     fun try_const const funcgr =
   309       (SOME const, ensure_consts' rewrites thy algebra [const] funcgr)
   310       handle INVALID (cs', msg) => (NONE, funcgr);
   311     val (consts', funcgr') = fold_map try_const consts funcgr;
   312   in (map_filter I consts', funcgr') end;
   313 
   314 fun ensure_consts_term rewrites thy f ct funcgr =
   315   let
   316     fun consts_of thy t =
   317       fold_aterms (fn Const c => cons (CodeUnit.const_of_cexpr thy c) | _ => I) t []
   318     fun rhs_conv conv thm =
   319       let
   320         val thm' = (conv o Thm.rhs_of) thm;
   321       in Thm.transitive thm thm' end
   322     val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
   323     val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
   324     val thm1 = Code.preprocess_conv ct
   325       |> fold (rhs_conv o MetaSimplifier.rewrite false o single) (rewrites thy);
   326     val ct' = Thm.rhs_of thm1;
   327     val consts = consts_of thy (Thm.term_of ct');
   328     val funcgr' = ensure_consts rewrites thy consts funcgr;
   329     val algebra = Code.coregular_algebra thy;
   330     val (_, thm2) = Thm.varifyT' [] thm1;
   331     val thm3 = Thm.reflexive (Thm.rhs_of thm2);
   332     val typ_funcgr = try (fst o Constgraph.get_node funcgr' o CodeUnit.const_of_cexpr thy);
   333     val [thm4] = resort_thms algebra typ_funcgr [thm3];
   334     val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
   335     fun inst thm =
   336       let
   337         val tvars = Term.add_tvars (Thm.prop_of thm) [];
   338         val instmap = map2 (fn (v_i, sort) => fn (v, _) => pairself (Thm.ctyp_of thy)
   339           (TVar (v_i, sort), TFree (v, sort))) tvars tfrees;
   340       in Thm.instantiate (instmap, []) thm end;
   341     val thm5 = inst thm2;
   342     val thm6 = inst thm4;
   343     val ct'' = Thm.rhs_of thm6;
   344     val cs = fold_aterms (fn Const c => cons c | _ => I) (Thm.term_of ct'') [];
   345     val drop = drop_classes thy tfrees;
   346     val instdefs = instances_of_consts thy algebra funcgr' cs;
   347     val funcgr'' = ensure_consts rewrites thy instdefs funcgr';
   348   in (f funcgr'' drop ct'' thm5, funcgr'') end;
   349 
   350 fun ensure_consts_eval rewrites thy conv =
   351   let
   352     fun conv' funcgr drop_classes ct thm1 =
   353       let
   354         val thm2 = conv funcgr ct;
   355         val thm3 = Code.postprocess_conv (Thm.rhs_of thm2);
   356         val thm23 = drop_classes (Thm.transitive thm2 thm3);
   357       in
   358         Thm.transitive thm1 thm23 handle THM _ =>
   359           error ("eval_conv - could not construct proof:\n"
   360           ^ (cat_lines o map string_of_thm) [thm1, thm2, thm3])
   361       end;
   362   in ensure_consts_term rewrites thy conv' end;
   363 
   364 end; (*local*)
   365 
   366 end; (*struct*)
   367 
   368 functor CodeFuncgrRetrieval (val rewrites: theory -> thm list) : CODE_FUNCGR_RETRIEVAL =
   369 struct
   370 
   371 (** code data **)
   372 
   373 type T = CodeFuncgr.T;
   374 
   375 structure Funcgr = CodeDataFun
   376 (struct
   377   type T = T;
   378   val empty = CodeFuncgr.Constgraph.empty;
   379   fun merge _ _ = CodeFuncgr.Constgraph.empty;
   380   fun purge _ NONE _ = CodeFuncgr.Constgraph.empty
   381     | purge _ (SOME cs) funcgr =
   382         CodeFuncgr.Constgraph.del_nodes ((CodeFuncgr.Constgraph.all_preds funcgr 
   383           o filter (can (CodeFuncgr.Constgraph.get_node funcgr))) cs) funcgr;
   384 end);
   385 
   386 fun make thy =
   387   Funcgr.change thy o CodeFuncgr.ensure_consts rewrites thy;
   388 
   389 fun make_consts thy =
   390   Funcgr.change_yield thy o CodeFuncgr.check_consts rewrites thy;
   391 
   392 fun make_term thy f =
   393   Funcgr.change_yield thy o CodeFuncgr.ensure_consts_term rewrites thy f;
   394 
   395 fun eval_conv thy f =
   396   fst o Funcgr.change_yield thy o CodeFuncgr.ensure_consts_eval rewrites thy f;
   397 
   398 fun intervene thy funcgr = Funcgr.change thy (K funcgr);
   399 
   400 end; (*functor*)
   401 
   402 structure CodeFuncgr : CODE_FUNCGR =
   403 struct
   404 
   405 open CodeFuncgr;
   406 
   407 end; (*struct*)