src/Tools/code/code_funcgr.ML
author haftmann
Mon, 10 Dec 2007 11:24:15 +0100
changeset 25597 34860182b250
parent 25485 33840a854e63
child 26331 92120667172f
permissions -rw-r--r--
moved instance parameter management from class.ML to axclass.ML
     1 (*  Title:      Tools/code/code_funcgr.ML
     2     ID:         $Id$
     3     Author:     Florian Haftmann, TU Muenchen
     4 
     5 Retrieving, normalizing and structuring defining equations in a graph
     6 with explicit dependencies.
     7 *)
     8 
     9 signature CODE_FUNCGR =
    10 sig
         (*graph mapping constant names to a recorded type and their defining equations*)
    11   type T
         (*when set, the time spent on each constant is reported*)
    12   val timing: bool ref
         (*defining equations recorded for a constant; [] if it is not in the graph*)
    13   val funcs: T -> string -> thm list
         (*type recorded for a constant; raises (via Graph.get_node) if it is not in the graph*)
    14   val typ: T -> string -> typ
         (*names of all constants in the graph*)
    15   val all: T -> string list
         (*constants, sorted by their printed names, each followed by its equations*)
    16   val pretty: theory -> T -> Pretty.T
         (*extend the theory's cached graph by the given constants and their
           dependencies; rejected equations are reported via error*)
    17   val make: theory -> string list -> T
         (*like make, but constants whose equations are rejected are skipped;
           also returns the constants that were accepted*)
    18   val make_consts: theory -> string list -> string list * T
         (*preprocess a term, make its constants available in the cached graph,
           apply the given conversion to the preprocessed term and chain the
           (postprocessed) result back to an equation about the original term*)
    19   val eval_conv: theory -> (T -> cterm -> thm) -> cterm -> thm
         (*preprocess a term, make its constants available in the cached graph
           and apply the given function to the preprocessed term*)
    20   val eval_term: theory -> (T -> term -> 'a) -> term -> 'a
    21 end
    22 
    23 structure CodeFuncgr : CODE_FUNCGR =
    24 struct
    25 
    26 (** the graph type **)
    27 
    28 type T = (typ * thm list) Graph.T;
       (*Nodes are keyed by constant name and carry the type recorded for the
         constant together with its defining equations; an edge leads from a
         constant to a constant (or class instance parameter) that its
         equations depend on.*)
    29 
    30 fun funcs funcgr =
    31   these o Option.map snd o try (Graph.get_node funcgr);
    32 
    33 fun typ funcgr =
    34   fst o Graph.get_node funcgr;
    35 
    36 fun all funcgr = Graph.keys funcgr;
    37 
    38 fun pretty thy funcgr =
    39   AList.make (snd o Graph.get_node funcgr) (Graph.keys funcgr)
    40   |> (map o apfst) (CodeUnit.string_of_const thy)
    41   |> sort (string_ord o pairself fst)
    42   |> map (fn (s, thms) =>
    43        (Pretty.block o Pretty.fbreaks) (
    44          Pretty.str s
    45          :: map Display.pretty_thm thms
    46        ))
    47   |> Pretty.chunks;
    48 
    49 
    50 (** generic combinators **)
    51 
    52 fun fold_consts f thms =
    53   thms
    54   |> maps (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of)
    55   |> (fold o fold_aterms) (fn Const c => f c | _ => I);
    56 
    57 fun consts_of (const, []) = []
    58   | consts_of (const, thms as _ :: _) = 
    59       let
    60         fun the_const (c, _) = if c = const then I else insert (op =) c
    61       in fold_consts the_const thms [] end;
    62 
    63 fun insts_of thy algebra c ty_decl ty =
         (*Class instances, as (type constructor, class) pairs, needed to use
           constant c at type ty, given the type ty_decl recorded for it.  The
           derivations come from Sorts.of_sort_derivation; no duplicate removal
           is done here.*)
    64   let
         (*type arguments of c at the recorded type and at the use type*)
    65     val tys_decl = Sign.const_typargs thy (c, ty_decl);
    66     val tys = Sign.const_typargs thy (c, ty);
         (*derivation callbacks: a class relation step passes the derivation
           through unchanged, a type constructor step records (tyco, class) plus
           whatever its arguments required, a type variable requires nothing*)
    67     fun class_relation (x, _) _ = x;
    68     fun type_constructor tyco xs class =
    69       (tyco, class) :: maps (maps fst) xs;
    70     fun type_variable (TVar (_, sort)) = map (pair []) sort
    71       | type_variable (TFree (_, sort)) = map (pair []) sort;
         (*pair each use-type subterm with the sort of the corresponding type
           variable of the recorded type, descending through type constructors
           in lockstep.  NOTE(review): non-exhaustive -- assumes ty_decl is at
           least as general as ty, otherwise the match fails*)
    72     fun mk_inst ty (TVar (_, sort)) = cons (ty, sort)
    73       | mk_inst ty (TFree (_, sort)) = cons (ty, sort)
    74       | mk_inst (Type (_, tys1)) (Type (_, tys2)) = fold2 mk_inst tys1 tys2;
    75     fun of_sort_deriv (ty, sort) =
    76       Sorts.of_sort_derivation (Sign.pp thy) algebra
    77         { class_relation = class_relation, type_constructor = type_constructor,
    78           type_variable = type_variable }
    79         (ty, sort)
    80   in
    81     flat (maps of_sort_deriv (fold2 mk_inst tys tys_decl []))
    82   end;
    83 
    84 fun drop_classes thy tfrees thm =
         (*Replace the type variables of thm by the free type variables tfrees
           (paired by position), then repeatedly resolve the remaining goals
           with AxClass.class_intros (presumably discharging the class premises
           introduced by unconstrainT -- TODO confirm).
           NOTE(review): map2 requires as many type variables in thm as entries
           in tfrees, in corresponding order.*)
    85   let
           (*presumably turns the free type variables of thm into schematic ones*)
    86     val (_, thm') = Thm.varifyT' [] thm;
    87     val tvars = Term.add_tvars (Thm.prop_of thm') [];
    88     val unconstr = map (Thm.ctyp_of thy o TVar) tvars;
           (*after unconstrainT the type variables are matched with empty sorts*)
    89     val instmap = map2 (fn (v_i, _) => fn (v, sort) => pairself (Thm.ctyp_of thy)
    90       (TVar (v_i, []), TFree (v, sort))) tvars tfrees;
    91   in
    92     thm'
    93     |> fold Thm.unconstrainT unconstr
    94     |> Thm.instantiate (instmap, [])
    95     |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
    96   end;
    97 
    98 
    99 (** graph algorithm **)
   100 
   101 val timing = ref false;
       (*if set, ensure_const reports the time taken for each constant via
         Output.timeap_msg*)
   102 
   103 local
   104 
   105 exception INVALID of string list * string;
       (*constants involved and an error message: raised when the equations of
         a constant cannot be accepted; ensure_consts reports it via error,
         check_consts uses it to skip the constant*)
   106 
   107 fun resort_thms algebra tap_typ [] = []
   108   | resort_thms algebra tap_typ (thms as thm :: _) =
             (*Instantiate the sorts of type variables in thms (via
               CodeUnit.typ_sort_inst and CodeUnit.inst_thm), presumably so that
               each constant occurring in them meets the sort requirements of the
               type tap_typ reports for it; constants for which tap_typ gives
               NONE are ignored.  Exceptions (e.g. Sorts.CLASS_ERROR, which
               resort_funcss handles) propagate to the caller.*)
   109       let
   110         val thy = Thm.theory_of_thm thm;
   111         val cs = fold_consts (insert (op =)) thms [];
   112         fun match_const c (ty, ty_decl) =
   113           let
   114             val tys = Sign.const_typargs thy (c, ty);
                   (*assumes the type arguments of the reported type are all TVars*)
   115             val sorts = map (snd o dest_TVar) (Sign.const_typargs thy (c, ty_decl));
   116           in fold2 (curry (CodeUnit.typ_sort_inst algebra)) tys sorts end;
   117         fun match (c, ty) =
   118           case tap_typ c
   119            of SOME ty_decl => match_const c (ty, ty_decl)
   120             | NONE => I;
               (*sort instantiation accumulated over all constant occurrences*)
   121         val tvars = fold match cs Vartab.empty;
   122       in map (CodeUnit.inst_thm tvars) thms end;
   123 
   124 fun resort_funcss thy algebra funcgr =
         (*Resort the equations of a group of constants (a list of (const, thms)):
           first against the types already recorded in funcgr, then repeatedly
           against the head types of the group's own equations until no head
           type changes.  Class errors in the first phase become INVALID.*)
   125   let
   126     val typ_funcgr = try (fst o Graph.get_node funcgr);
   127     fun resort_dep (const, thms) = (const, resort_thms algebra typ_funcgr thms)
   128       handle Sorts.CLASS_ERROR e => raise INVALID ([const], Sorts.msg_class_error (Sign.pp thy) e
   129                     ^ ",\nfor constant " ^ CodeUnit.string_of_const thy const
   130                     ^ "\nin defining equations\n"
   131                     ^ (cat_lines o map string_of_thm) thms)
           (*resort one entry against tap_typ; the flag is true if its head type
             is unchanged (up to Sign.typ_equiv)*)
   132     fun resort_rec tap_typ (const, []) = (true, (const, []))
   133       | resort_rec tap_typ (const, thms as thm :: _) =
   134           let
   135             val (_, ty) = CodeUnit.head_func thm;
   136             val thms' as thm' :: _ = resort_thms algebra tap_typ thms
   137             val (_, ty') = CodeUnit.head_func thm';
   138           in (Sign.typ_equiv thy (ty, ty'), (const, thms')) end;
           (*one round over the whole group; a constant's type is taken from the
             head of its first equation in this round's input*)
   139     fun resort_recs funcss =
   140       let
   141         fun tap_typ c =
   142           AList.lookup (op =) funcss c
   143           |> these
   144           |> try hd
   145           |> Option.map (snd o CodeUnit.head_func);
   146         val (unchangeds, funcss') = split_list (map (resort_rec tap_typ) funcss);
   147         val unchanged = fold (fn x => fn y => x andalso y) unchangeds true;
   148       in (unchanged, funcss') end;
           (*NOTE(review): no iteration bound -- termination relies on resorting
             reaching a fixpoint; Sorts.CLASS_ERROR raised in this phase is not
             converted to INVALID*)
   149     fun resort_rec_until funcss =
   150       let
   151         val (unchanged, funcss') = resort_recs funcss;
   152       in if unchanged then funcss' else resort_rec_until funcss' end;
   153   in map resort_dep #> resort_rec_until end;
   154 
   155 fun instances_of thy algebra insts =
   156   let
   157     val thy_classes = (#classes o Sorts.rep_algebra o Sign.classes_of) thy;
   158     fun all_classparams tyco class =
   159       these (try (#params o AxClass.get_info thy) class)
   160       |> map (fn (c, _) => AxClass.param_of_inst thy (c, tyco))
   161   in
   162     Symtab.empty
   163     |> fold (fn (tyco, class) =>
   164         Symtab.map_default (tyco, []) (insert (op =) class)) insts
   165     |> (fn tab => Symtab.fold (fn (tyco, classes) => append (maps (all_classparams tyco)
   166          (Graph.all_succs thy_classes classes))) tab [])
   167   end;
   168 
   169 fun instances_of_consts thy algebra funcgr consts =
   170   let
   171     fun inst (cexpr as (c, ty)) = insts_of thy algebra c
   172       ((fst o Graph.get_node funcgr) c) ty handle CLASS_ERROR => [];
   173   in
   174     []
   175     |> fold (fold (insert (op =)) o inst) consts
   176     |> instances_of thy algebra
   177   end;
   178 
   179 fun ensure_const' thy algebra funcgr const auxgr =
         (*Collect const and its not-yet-known dependencies into the auxiliary
           graph auxgr.  Returns NONE if const is already in funcgr, and
           SOME const if it is (now) a node of auxgr.*)
   180   if can (Graph.get_node funcgr) const
   181     then (NONE, auxgr)
   182   else if can (Graph.get_node auxgr) const
   183     then (SOME const, auxgr)
         (*datatype constructors get a node without equations or dependencies*)
   184   else if is_some (Code.get_datatype_of_constr thy const) then
   185     auxgr
   186     |> Graph.new_node (const, [])
   187     |> pair (SOME const)
   188   else let
   189     val thms = Code.these_funcs thy const
   190       |> CodeUnit.norm_args
   191       |> CodeUnit.norm_varnames CodeName.purify_tvar CodeName.purify_var;
   192     val rhs = consts_of (const, thms);
   193   in
           (*the node is added before recursing, so a cyclic dependency stops at
             the auxgr check above; edges are only added to constants that ended
             up in auxgr (SOME), not to those already in funcgr*)
   194     auxgr
   195     |> Graph.new_node (const, thms)
   196     |> fold_map (ensure_const thy algebra funcgr) rhs
   197     |-> (fn rhs' => fold (fn SOME const' => Graph.add_edge (const, const')
   198                            | NONE => I) rhs')
   199     |> pair (SOME const)
   200   end
       (*ensure_const' wrapped with per-constant timing output when timing is set*)
   201 and ensure_const thy algebra funcgr const =
   202   let
   203     val timeap = if !timing
   204       then Output.timeap_msg ("time for " ^ CodeUnit.string_of_const thy const)
   205       else I;
   206   in timeap (ensure_const' thy algebra funcgr const) end;
   207 
   208 fun merge_funcss thy algebra raw_funcss funcgr =
   209   let
   210     val funcss = raw_funcss
   211       |> resort_funcss thy algebra funcgr
   212       |> filter_out (can (Graph.get_node funcgr) o fst);
   213     fun typ_func c [] = Code.default_typ thy c
   214       | typ_func c (thms as thm :: _) = case AxClass.inst_of_param thy c
   215          of SOME (c', tyco) => 
   216               let
   217                 val (_, ty) = CodeUnit.head_func thm;
   218                 val SOME class = AxClass.class_of_param thy c';
   219                 val sorts_decl = Sorts.mg_domain algebra tyco [class];
   220                 val tys = Sign.const_typargs thy (c, ty);
   221                 val sorts = map (snd o dest_TVar) tys;
   222               in if sorts = sorts_decl then ty
   223                 else raise INVALID ([c], "Illegal instantation for class operation "
   224                   ^ CodeUnit.string_of_const thy c
   225                   ^ "\nin defining equations\n"
   226                   ^ (cat_lines o map string_of_thm) thms)
   227               end
   228           | NONE => (snd o CodeUnit.head_func) thm;
   229     fun add_funcs (const, thms) =
   230       Graph.new_node (const, (typ_func const thms, thms));
   231     fun add_deps (funcs as (const, thms)) funcgr =
   232       let
   233         val deps = consts_of funcs;
   234         val insts = instances_of_consts thy algebra funcgr
   235           (fold_consts (insert (op =)) thms []);
   236       in
   237         funcgr
   238         |> ensure_consts' thy algebra insts
   239         |> fold (curry Graph.add_edge const) deps
   240         |> fold (curry Graph.add_edge const) insts
   241        end;
   242   in
   243     funcgr
   244     |> fold add_funcs funcss
   245     |> fold add_deps funcss
   246   end
   247 and ensure_consts' thy algebra cs funcgr =
   248   let
   249     val auxgr = Graph.empty
   250       |> fold (snd oo ensure_const thy algebra funcgr) cs;
   251   in
   252     funcgr
   253     |> fold (merge_funcss thy algebra)
   254          (map (AList.make (Graph.get_node auxgr))
   255          (rev (Graph.strong_conn auxgr)))
   256   end handle INVALID (cs', msg)
   257     => raise INVALID (fold (insert (op =)) cs' cs, msg);
   258 
   259 in
   260 
   261 (** retrieval interfaces **)
   262 
   263 fun ensure_consts thy algebra consts funcgr =
   264   ensure_consts' thy algebra consts funcgr
   265     handle INVALID (cs', msg) => error (msg ^ ",\nwhile preprocessing equations for constant(s) "
   266     ^ commas (map (CodeUnit.string_of_const thy) cs'));
   267 
   268 fun check_consts thy consts funcgr =
   269   let
   270     val algebra = Code.coregular_algebra thy;
   271     fun try_const const funcgr =
   272       (SOME const, ensure_consts' thy algebra [const] funcgr)
   273       handle INVALID (cs', msg) => (NONE, funcgr);
   274     val (consts', funcgr') = fold_map try_const consts funcgr;
   275   in (map_filter I consts', funcgr') end;
   276 
   277 fun raw_eval thy f ct funcgr =
         (*Preprocess ct, make its constants available in funcgr, and call
           f drop thm5 funcgr'' ct'' on the resorted, preprocessed term ct''.
           Returns f's result together with the extended graph.*)
   278   let
   279     val algebra = Code.coregular_algebra thy;
   280     fun consts_of ct = fold_aterms (fn Const c_ty => cons c_ty | _ => I)
   281       (Thm.term_of ct) [];
           (*reject schematic variables (Sign.no_vars) and schematic type
             variables (Type.no_tvars) in the input term*)
   282     val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
   283     val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
           (*thm1 relates ct to its preprocessed form ct'*)
   284     val thm1 = Code.preprocess_conv ct;
   285     val ct' = Thm.rhs_of thm1;
   286     val cs = map fst (consts_of ct');
   287     val funcgr' = ensure_consts thy algebra cs funcgr;
           (*resort a reflexive equation on the varified right-hand side against
             the types recorded in funcgr'*)
   288     val (_, thm2) = Thm.varifyT' [] thm1;
   289     val thm3 = Thm.reflexive (Thm.rhs_of thm2);
   290     val [thm4] = resort_thms algebra (try (fst o Graph.get_node funcgr')) [thm3];
   291     val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
           (*map the type variables of thm back to the free type variables of thm1,
             keeping the sorts of thm's type variables.  NOTE(review): map2 pairs
             them by position and requires equal numbers*)
   292     fun inst thm =
   293       let
   294         val tvars = Term.add_tvars (Thm.prop_of thm) [];
   295         val instmap = map2 (fn (v_i, sort) => fn (v, _) => pairself (Thm.ctyp_of thy)
   296           (TVar (v_i, sort), TFree (v, sort))) tvars tfrees;
   297       in Thm.instantiate (instmap, []) thm end;
   298     val thm5 = inst thm2;
   299     val thm6 = inst thm4;
   300     val ct'' = Thm.rhs_of thm6;
           (*class instance parameters required by ct'' must be in the graph too*)
   301     val c_exprs = consts_of ct'';
   302     val drop = drop_classes thy tfrees;
   303     val instdefs = instances_of_consts thy algebra funcgr' c_exprs;
   304     val funcgr'' = ensure_consts thy algebra instdefs funcgr';
   305   in (f drop thm5 funcgr'' ct'', funcgr'') end;
   306 
   307 fun raw_eval_conv thy conv =
   308   let
   309     fun conv' drop_classes thm1 funcgr ct =
   310       let
   311         val thm2 = conv funcgr ct;
   312         val thm3 = Code.postprocess_conv (Thm.rhs_of thm2);
   313         val thm23 = drop_classes (Thm.transitive thm2 thm3);
   314       in
   315         Thm.transitive thm1 thm23 handle THM _ =>
   316           error ("could not construct proof:\n"
   317           ^ (cat_lines o map string_of_thm) [thm1, thm2, thm3])
   318       end;
   319   in raw_eval thy conv' end;
   320 
   321 fun raw_eval_term thy f t =
   322   let
   323     fun f' _ _ funcgr ct = f funcgr (Thm.term_of ct);
   324   in raw_eval thy f' (Thm.cterm_of thy t) end;
   325 
   326 end; (*local*)
   327 
   328 structure Funcgr = CodeDataFun
   329 (
   330   type T = T;
   331   val empty = Graph.empty;
   332   fun merge _ _ = Graph.empty;
   333   fun purge _ NONE _ = Graph.empty
   334     | purge _ (SOME cs) funcgr =
   335         Graph.del_nodes ((Graph.all_preds funcgr 
   336           o filter (can (Graph.get_node funcgr))) cs) funcgr;
   337 );
   338 
   339 fun make thy =
   340   Funcgr.change thy o ensure_consts thy (Code.coregular_algebra thy);
   341 
   342 fun make_consts thy =
   343   Funcgr.change_yield thy o check_consts thy;
   344 
   345 fun eval_conv thy f =
   346   fst o Funcgr.change_yield thy o raw_eval_conv thy f;
   347 
   348 fun eval_term thy f =
   349   fst o Funcgr.change_yield thy o raw_eval_term thy f;
   350 
   351 end; (*struct*)