src/Tools/code/code_funcgr.ML
author haftmann
Fri, 23 May 2008 16:05:11 +0200
changeset 26971 160117247294
parent 26939 1035c89b4c02
child 26997 40552bbac005
permissions -rw-r--r--
more permissive preprocessor
     1 (*  Title:      Tools/code/code_funcgr.ML
     2     ID:         $Id$
     3     Author:     Florian Haftmann, TU Muenchen
     4 
     5 Retrieving, normalizing and structuring defining equations in a graph
     6 with explicit dependencies.
     7 *)
     8 
     9 signature CODE_FUNCGR =
    10 sig
    11   type T
    12   val funcs: T -> string -> thm list
    13   val typ: T -> string -> (string * sort) list * typ
    14   val all: T -> string list
    15   val pretty: theory -> T -> Pretty.T
    16   val make: theory -> string list -> T
    17   val make_consts: theory -> string list -> string list * T
    18   val eval_conv: theory -> (term -> term * (T -> term -> thm)) -> cterm -> thm
    19   val eval_term: theory -> (term -> term * (T -> term -> 'a)) -> term -> 'a
    20   val timing: bool ref
    21 end
    22 
    23 structure CodeFuncgr : CODE_FUNCGR =
    24 struct
    25 
    26 (** the graph type **)
    27 
    28 type T = (((string * sort) list * typ) * thm list) Graph.T;
    29 
    30 fun funcs funcgr =
    31   these o Option.map snd o try (Graph.get_node funcgr);
    32 
    33 fun typ funcgr =
    34   fst o Graph.get_node funcgr;
    35 
    36 fun all funcgr = Graph.keys funcgr;
    37 
    38 fun pretty thy funcgr =
    39   AList.make (snd o Graph.get_node funcgr) (Graph.keys funcgr)
    40   |> (map o apfst) (CodeUnit.string_of_const thy)
    41   |> sort (string_ord o pairself fst)
    42   |> map (fn (s, thms) =>
    43        (Pretty.block o Pretty.fbreaks) (
    44          Pretty.str s
    45          :: map Display.pretty_thm thms
    46        ))
    47   |> Pretty.chunks;
    48 
    49 
    50 (** generic combinators **)
    51 
    52 fun fold_consts f thms =
    53   thms
    54   |> maps (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of)
    55   |> (fold o fold_aterms) (fn Const c => f c | _ => I);
    56 
    57 fun consts_of (const, []) = []
    58   | consts_of (const, thms as _ :: _) = 
    59       let
    60         fun the_const (c, _) = if c = const then I else insert (op =) c
    61       in fold_consts the_const thms [] end;
    62 
    63 fun insts_of thy algebra tys sorts =
    64   let
    65     fun class_relation (x, _) _ = x;
    66     fun type_constructor tyco xs class =
    67       (tyco, class) :: (maps o maps) fst xs;
    68     fun type_variable (TVar (_, sort)) = map (pair []) sort
    69       | type_variable (TFree (_, sort)) = map (pair []) sort;
    70     fun of_sort_deriv ty sort =
    71       Sorts.of_sort_derivation (Syntax.pp_global thy) algebra
    72         { class_relation = class_relation, type_constructor = type_constructor,
    73           type_variable = type_variable }
    74         (ty, sort) handle Sorts.CLASS_ERROR _ => [] (*permissive!*)
    75   in (flat o flat) (map2 of_sort_deriv tys sorts) end;
    76 
    77 fun meets_of thy algebra =
    78   let
    79     fun meet_of ty sort tab =
    80       Sorts.meet_sort algebra (ty, sort) tab
    81         handle Sorts.CLASS_ERROR _ => tab (*permissive!*);
    82   in fold2 meet_of end;
    83 
    84 
    85 (** graph algorithm **)
    86 
    87 val timing = ref false;
    88 
    89 local
    90 
    91 exception CLASS_ERROR of string list * string;
    92 
    93 fun resort_thms thy algebra typ_of thms =
    94   let
    95     val cs = fold_consts (insert (op =)) thms [];
    96     fun meets (c, ty) = case typ_of c
    97        of SOME (vs, _) =>
    98             meets_of thy algebra (Sign.const_typargs thy (c, ty)) (map snd vs)
    99         | NONE => I;
   100     val tab = fold meets cs Vartab.empty;
   101   in map (CodeUnit.inst_thm tab) thms end;
   102 
   103 fun resort_funcss thy algebra funcgr =
   104   let
   105     val typ_funcgr = try (fst o Graph.get_node funcgr);
   106     val resort_dep = apsnd (resort_thms thy algebra typ_funcgr);
   107     fun resort_rec typ_of (const, []) = (true, (const, []))
   108       | resort_rec typ_of (const, thms as thm :: _) =
   109           let
   110             val (_, (vs, ty)) = CodeUnit.head_func thm;
   111             val thms' as thm' :: _ = resort_thms thy algebra typ_of thms
   112             val (_, (vs', ty')) = CodeUnit.head_func thm'; (*FIXME simplify check*)
   113           in (Sign.typ_equiv thy (ty, ty'), (const, thms')) end;
   114     fun resort_recs funcss =
   115       let
   116         fun typ_of c = case these (AList.lookup (op =) funcss c)
   117          of thm :: _ => (SOME o snd o CodeUnit.head_func) thm
   118           | [] => NONE;
   119         val (unchangeds, funcss') = split_list (map (resort_rec typ_of) funcss);
   120         val unchanged = fold (fn x => fn y => x andalso y) unchangeds true;
   121       in (unchanged, funcss') end;
   122     fun resort_rec_until funcss =
   123       let
   124         val (unchanged, funcss') = resort_recs funcss;
   125       in if unchanged then funcss' else resort_rec_until funcss' end;
   126   in map resort_dep #> resort_rec_until end;
   127 
   128 fun instances_of thy algebra insts =
   129   let
   130     val thy_classes = (#classes o Sorts.rep_algebra o Sign.classes_of) thy;
   131     fun all_classparams tyco class =
   132       these (try (#params o AxClass.get_info thy) class)
   133       |> map_filter (fn (c, _) => try (AxClass.param_of_inst thy) (c, tyco))
   134   in
   135     Symtab.empty
   136     |> fold (fn (tyco, class) =>
   137         Symtab.map_default (tyco, []) (insert (op =) class)) insts
   138     |> (fn tab => Symtab.fold (fn (tyco, classes) => append (maps (all_classparams tyco)
   139          (Graph.all_succs thy_classes classes))) tab [])
   140   end;
   141 
   142 fun instances_of_consts thy algebra funcgr consts =
   143   let
   144     fun inst (cexpr as (c, ty)) = insts_of thy algebra
   145       (Sign.const_typargs thy (c, ty)) ((map snd o fst) (typ funcgr c));
   146   in
   147     []
   148     |> fold (fold (insert (op =)) o inst) consts
   149     |> instances_of thy algebra
   150   end;
   151 
   152 fun ensure_const' thy algebra funcgr const auxgr =
   153   if can (Graph.get_node funcgr) const
   154     then (NONE, auxgr)
   155   else if can (Graph.get_node auxgr) const
   156     then (SOME const, auxgr)
   157   else if is_some (Code.get_datatype_of_constr thy const) then
   158     auxgr
   159     |> Graph.new_node (const, [])
   160     |> pair (SOME const)
   161   else let
   162     val thms = Code.these_funcs thy const
   163       |> CodeUnit.norm_args
   164       |> CodeUnit.norm_varnames CodeName.purify_tvar CodeName.purify_var;
   165     val rhs = consts_of (const, thms);
   166   in
   167     auxgr
   168     |> Graph.new_node (const, thms)
   169     |> fold_map (ensure_const thy algebra funcgr) rhs
   170     |-> (fn rhs' => fold (fn SOME const' => Graph.add_edge (const, const')
   171                            | NONE => I) rhs')
   172     |> pair (SOME const)
   173   end
   174 and ensure_const thy algebra funcgr const =
   175   let
   176     val timeap = if !timing
   177       then Output.timeap_msg ("time for " ^ CodeUnit.string_of_const thy const)
   178       else I;
   179   in timeap (ensure_const' thy algebra funcgr const) end;
   180 
   181 fun merge_funcss thy algebra raw_funcss funcgr =
   182   let
   183     val funcss = raw_funcss
   184       |> resort_funcss thy algebra funcgr
   185       |> filter_out (can (Graph.get_node funcgr) o fst);
   186     fun typ_func c [] = Code.default_typ thy c
   187       | typ_func c (thms as thm :: _) = case AxClass.inst_of_param thy c
   188          of SOME (c', tyco) => 
   189               let
   190                 val (_, (vs, ty)) = CodeUnit.head_func thm;
   191                 val SOME class = AxClass.class_of_param thy c';
   192                 val sorts_decl = Sorts.mg_domain algebra tyco [class];
   193               in if map snd vs = sorts_decl then (vs, ty)
   194                 else raise CLASS_ERROR ([c], "Illegal instantation for class operation "
   195                   ^ CodeUnit.string_of_const thy c
   196                   ^ "\nin defining equations\n"
   197                   ^ (cat_lines o map (Display.string_of_thm o AxClass.overload thy)) thms)
   198               end
   199           | NONE => (snd o CodeUnit.head_func) thm;
   200     fun add_funcs (const, thms) =
   201       Graph.new_node (const, (typ_func const thms, thms));
   202     fun add_deps (funcs as (const, thms)) funcgr =
   203       let
   204         val deps = consts_of funcs;
   205         val insts = instances_of_consts thy algebra funcgr
   206           (fold_consts (insert (op =)) thms []);
   207       in
   208         funcgr
   209         |> ensure_consts' thy algebra insts
   210         |> fold (curry Graph.add_edge const) deps
   211         |> fold (curry Graph.add_edge const) insts
   212        end;
   213   in
   214     funcgr
   215     |> fold add_funcs funcss
   216     |> fold add_deps funcss
   217   end
   218 and ensure_consts' thy algebra cs funcgr =
   219   let
   220     val auxgr = Graph.empty
   221       |> fold (snd oo ensure_const thy algebra funcgr) cs;
   222   in
   223     funcgr
   224     |> fold (merge_funcss thy algebra)
   225          (map (AList.make (Graph.get_node auxgr))
   226          (rev (Graph.strong_conn auxgr)))
   227   end handle CLASS_ERROR (cs', msg)
   228     => raise CLASS_ERROR (fold (insert (op =)) cs' cs, msg);
   229 
   230 in
   231 
   232 (** retrieval interfaces **)
   233 
   234 fun ensure_consts thy algebra consts funcgr =
   235   ensure_consts' thy algebra consts funcgr
   236     handle CLASS_ERROR (cs', msg) => error (msg ^ ",\nwhile preprocessing equations for constant(s) "
   237     ^ commas (map (CodeUnit.string_of_const thy) cs'));
   238 
   239 fun check_consts thy consts funcgr =
   240   let
   241     val algebra = Code.coregular_algebra thy;
   242     fun try_const const funcgr =
   243       (SOME const, ensure_consts' thy algebra [const] funcgr)
   244       handle CLASS_ERROR (cs', msg) => (NONE, funcgr);
   245     val (consts', funcgr') = fold_map try_const consts funcgr;
   246   in (map_filter I consts', funcgr') end;
   247 
   248 fun proto_eval thy cterm_of evaluator_fr evaluator proto_ct funcgr =
   249   let
   250     val ct = cterm_of proto_ct;
   251     val _ = Sign.no_vars (Syntax.pp_global thy) (Thm.term_of ct);
   252     val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
   253     fun consts_of t = fold_aterms (fn Const c_ty => cons c_ty | _ => I)
   254       t [];
   255     val algebra = Code.coregular_algebra thy;
   256     val thm = Code.preprocess_conv ct;
   257     val ct' = Thm.rhs_of thm;
   258     val t' = Thm.term_of ct';
   259     val consts = map fst (consts_of t');
   260     val funcgr' = ensure_consts thy algebra consts funcgr;
   261     val (t'', evaluator') = apsnd evaluator_fr (evaluator t');
   262     val consts' = consts_of t'';
   263     val dicts = instances_of_consts thy algebra funcgr' consts';
   264     val funcgr'' = ensure_consts thy algebra dicts funcgr';
   265   in (evaluator' thm funcgr'' t'', funcgr'') end;
   266 
   267 fun proto_eval_conv thy =
   268   let
   269     fun evaluator evaluator' thm1 funcgr t =
   270       let
   271         val thm2 = evaluator' funcgr t;
   272         val thm3 = Code.postprocess_conv (Thm.rhs_of thm2);
   273       in
   274         Thm.transitive thm1 (Thm.transitive thm2 thm3) handle THM _ =>
   275           error ("could not construct evaluation proof:\n"
   276           ^ (cat_lines o map Display.string_of_thm) [thm1, thm2, thm3])
   277       end;
   278   in proto_eval thy I evaluator end;
   279 
   280 fun proto_eval_term thy =
   281   let
   282     fun evaluator evaluator' _ funcgr t = evaluator' funcgr t;
   283   in proto_eval thy (Thm.cterm_of thy) evaluator end;
   284 
   285 end; (*local*)
   286 
   287 structure Funcgr = CodeDataFun
   288 (
   289   type T = T;
   290   val empty = Graph.empty;
   291   fun merge _ _ = Graph.empty;
   292   fun purge _ NONE _ = Graph.empty
   293     | purge _ (SOME cs) funcgr =
   294         Graph.del_nodes ((Graph.all_preds funcgr 
   295           o filter (can (Graph.get_node funcgr))) cs) funcgr;
   296 );
   297 
   298 fun make thy =
   299   Funcgr.change thy o ensure_consts thy (Code.coregular_algebra thy);
   300 
   301 fun make_consts thy =
   302   Funcgr.change_yield thy o check_consts thy;
   303 
   304 fun eval_conv thy f =
   305   fst o Funcgr.change_yield thy o proto_eval_conv thy f;
   306 
   307 fun eval_term thy f =
   308   fst o Funcgr.change_yield thy o proto_eval_term thy f;
   309 
   310 end; (*struct*)