src/Tools/code/code_funcgr.ML
author haftmann
Wed, 15 Aug 2007 08:57:42 +0200
changeset 24283 8ca96f4e49cd
parent 24219 e558fe311376
child 24423 ae9cd0e92423
permissions -rw-r--r--
tuned
     1 (*  Title:      Tools/code/code_funcgr.ML
     2     ID:         $Id$
     3     Author:     Florian Haftmann, TU Muenchen
     4 
     5 Retrieving, normalizing and structuring defining equations in a graph
     6 with explicit dependencies.
     7 *)
     8 
     9 signature CODE_FUNCGR =
    10 sig
    11   type T
    12   val timing: bool ref
    13   val funcs: T -> CodeUnit.const -> thm list
    14   val typ: T -> CodeUnit.const -> typ
    15   val all: T -> CodeUnit.const list
    16   val pretty: theory -> T -> Pretty.T
    17   val make: theory -> CodeUnit.const list -> T
    18   val make_consts: theory -> CodeUnit.const list -> CodeUnit.const list * T
    19   val eval_conv: theory -> (T -> cterm -> thm) -> cterm -> thm
    20   val eval_term: theory -> (T -> cterm -> 'a) -> cterm -> 'a
    21   val intervene: theory -> T -> T
    22     (*FIXME drop intervene as soon as possible*)
    23   structure Constgraph : GRAPH
    24 end
    25 
    26 structure CodeFuncgr : CODE_FUNCGR =
    27 struct
    28 
    29 (** the graph type **)
    30 
    31 structure Constgraph = GraphFun (
    32   type key = CodeUnit.const;
    33   val ord = CodeUnit.const_ord;
    34 );
    35 
    36 type T = (typ * thm list) Constgraph.T;
    37 
    38 fun funcs funcgr =
    39   these o Option.map snd o try (Constgraph.get_node funcgr);
    40 
    41 fun typ funcgr =
    42   fst o Constgraph.get_node funcgr;
    43 
    44 fun all funcgr = Constgraph.keys funcgr;
    45 
    46 fun pretty thy funcgr =
    47   AList.make (snd o Constgraph.get_node funcgr) (Constgraph.keys funcgr)
    48   |> (map o apfst) (CodeUnit.string_of_const thy)
    49   |> sort (string_ord o pairself fst)
    50   |> map (fn (s, thms) =>
    51        (Pretty.block o Pretty.fbreaks) (
    52          Pretty.str s
    53          :: map Display.pretty_thm thms
    54        ))
    55   |> Pretty.chunks;
    56 
    57 
    58 (** generic combinators **)
    59 
    60 fun fold_consts f thms =
    61   thms
    62   |> maps (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of)
    63   |> (fold o fold_aterms) (fn Const c => f c | _ => I);
    64 
    65 fun consts_of (const, []) = []
    66   | consts_of (const, thms as thm :: _) = 
    67       let
    68         val thy = Thm.theory_of_thm thm;
    69         val is_refl = curry CodeUnit.eq_const const;
    70         fun the_const c = case try (CodeUnit.const_of_cexpr thy) c
    71          of SOME const => if is_refl const then I else insert CodeUnit.eq_const const
    72           | NONE => I
    73       in fold_consts the_const thms [] end;
    74 
    75 fun insts_of thy algebra c ty_decl ty =
    76   let
    77     val tys_decl = Sign.const_typargs thy (c, ty_decl);
    78     val tys = Sign.const_typargs thy (c, ty);
    79     fun class_relation (x, _) _ = x;
    80     fun type_constructor tyco xs class =
    81       (tyco, class) :: maps (maps fst) xs;
    82     fun type_variable (TVar (_, sort)) = map (pair []) sort
    83       | type_variable (TFree (_, sort)) = map (pair []) sort;
    84     fun mk_inst ty (TVar (_, sort)) = cons (ty, sort)
    85       | mk_inst ty (TFree (_, sort)) = cons (ty, sort)
    86       | mk_inst (Type (_, tys1)) (Type (_, tys2)) = fold2 mk_inst tys1 tys2;
    87     fun of_sort_deriv (ty, sort) =
    88       Sorts.of_sort_derivation (Sign.pp thy) algebra
    89         { class_relation = class_relation, type_constructor = type_constructor,
    90           type_variable = type_variable }
    91         (ty, sort)
    92   in
    93     flat (maps of_sort_deriv (fold2 mk_inst tys tys_decl []))
    94   end;
    95 
    96 fun drop_classes thy tfrees thm =
    97   let
    98     val (_, thm') = Thm.varifyT' [] thm;
    99     val tvars = Term.add_tvars (Thm.prop_of thm') [];
   100     val unconstr = map (Thm.ctyp_of thy o TVar) tvars;
   101     val instmap = map2 (fn (v_i, _) => fn (v, sort) => pairself (Thm.ctyp_of thy)
   102       (TVar (v_i, []), TFree (v, sort))) tvars tfrees;
   103   in
   104     thm'
   105     |> fold Thm.unconstrainT unconstr
   106     |> Thm.instantiate (instmap, [])
   107     |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
   108   end;
   109 
   110 
   111 (** graph algorithm **)
   112 
   113 val timing = ref false;
   114 
   115 local
   116 
   117 exception INVALID of CodeUnit.const list * string;
   118 
   119 fun resort_thms algebra tap_typ [] = []
   120   | resort_thms algebra tap_typ (thms as thm :: _) =
   121       let
   122         val thy = Thm.theory_of_thm thm;
   123         val cs = fold_consts (insert (op =)) thms [];
   124         fun match_const c (ty, ty_decl) =
   125           let
   126             val tys = CodeUnit.typargs thy (c, ty);
   127             val sorts = map (snd o dest_TVar) (CodeUnit.typargs thy (c, ty_decl));
   128           in fold2 (curry (CodeUnit.typ_sort_inst algebra)) tys sorts end;
   129         fun match (c_ty as (c, ty)) =
   130           case tap_typ c_ty
   131            of SOME ty_decl => match_const c (ty, ty_decl)
   132             | NONE => I;
   133         val tvars = fold match cs Vartab.empty;
   134       in map (CodeUnit.inst_thm tvars) thms end;
   135 
   136 fun resort_funcss thy algebra funcgr =
   137   let
   138     val typ_funcgr = try (fst o Constgraph.get_node funcgr o CodeUnit.const_of_cexpr thy);
   139     fun resort_dep (const, thms) = (const, resort_thms algebra typ_funcgr thms)
   140       handle Sorts.CLASS_ERROR e => raise INVALID ([const], Sorts.msg_class_error (Sign.pp thy) e
   141                     ^ ",\nfor constant " ^ CodeUnit.string_of_const thy const
   142                     ^ "\nin defining equations\n"
   143                     ^ (cat_lines o map string_of_thm) thms)
   144     fun resort_rec tap_typ (const, []) = (true, (const, []))
   145       | resort_rec tap_typ (const, thms as thm :: _) =
   146           let
   147             val (_, ty) = CodeUnit.head_func thm;
   148             val thms' as thm' :: _ = resort_thms algebra tap_typ thms
   149             val (_, ty') = CodeUnit.head_func thm';
   150           in (Sign.typ_equiv thy (ty, ty'), (const, thms')) end;
   151     fun resort_recs funcss =
   152       let
   153         fun tap_typ c_ty = case try (CodeUnit.const_of_cexpr thy) c_ty
   154          of SOME const => AList.lookup (CodeUnit.eq_const) funcss const
   155               |> these
   156               |> try hd
   157               |> Option.map (snd o CodeUnit.head_func)
   158           | NONE => NONE;
   159         val (unchangeds, funcss') = split_list (map (resort_rec tap_typ) funcss);
   160         val unchanged = fold (fn x => fn y => x andalso y) unchangeds true;
   161       in (unchanged, funcss') end;
   162     fun resort_rec_until funcss =
   163       let
   164         val (unchanged, funcss') = resort_recs funcss;
   165       in if unchanged then funcss' else resort_rec_until funcss' end;
   166   in map resort_dep #> resort_rec_until end;
   167 
   168 fun instances_of thy algebra insts =
   169   let
   170     val thy_classes = (#classes o Sorts.rep_algebra o Sign.classes_of) thy;
   171     fun all_classops tyco class =
   172       try (AxClass.params_of_class thy) class
   173       |> Option.map snd
   174       |> these
   175       |> map (fn (c, _) => (c, SOME tyco))
   176   in
   177     Symtab.empty
   178     |> fold (fn (tyco, class) =>
   179         Symtab.map_default (tyco, []) (insert (op =) class)) insts
   180     |> (fn tab => Symtab.fold (fn (tyco, classes) => append (maps (all_classops tyco)
   181          (Graph.all_succs thy_classes classes))) tab [])
   182   end;
   183 
   184 fun instances_of_consts thy algebra funcgr consts =
   185   let
   186     fun inst (cexpr as (c, ty)) = insts_of thy algebra c
   187       ((fst o Constgraph.get_node funcgr o CodeUnit.const_of_cexpr thy) cexpr)
   188       ty handle CLASS_ERROR => [];
   189   in
   190     []
   191     |> fold (fold (insert (op =)) o inst) consts
   192     |> instances_of thy algebra
   193   end;
   194 
   195 fun ensure_const' thy algebra funcgr const auxgr =
   196   if can (Constgraph.get_node funcgr) const
   197     then (NONE, auxgr)
   198   else if can (Constgraph.get_node auxgr) const
   199     then (SOME const, auxgr)
   200   else if is_some (Code.get_datatype_of_constr thy const) then
   201     auxgr
   202     |> Constgraph.new_node (const, [])
   203     |> pair (SOME const)
   204   else let
   205     val thms = Code.these_funcs thy const
   206       |> CodeUnit.norm_args
   207       |> CodeUnit.norm_varnames CodeName.purify_tvar CodeName.purify_var;
   208     val rhs = consts_of (const, thms);
   209   in
   210     auxgr
   211     |> Constgraph.new_node (const, thms)
   212     |> fold_map (ensure_const thy algebra funcgr) rhs
   213     |-> (fn rhs' => fold (fn SOME const' => Constgraph.add_edge (const, const')
   214                            | NONE => I) rhs')
   215     |> pair (SOME const)
   216   end
   217 and ensure_const thy algebra funcgr const =
   218   let
   219     val timeap = if !timing
   220       then Output.timeap_msg ("time for " ^ CodeUnit.string_of_const thy const)
   221       else I;
   222   in timeap (ensure_const' thy algebra funcgr const) end;
   223 
   224 fun merge_funcss thy algebra raw_funcss funcgr =
   225   let
   226     val funcss = raw_funcss
   227       |> resort_funcss thy algebra funcgr
   228       |> filter_out (can (Constgraph.get_node funcgr) o fst);
   229     fun typ_func const [] = Code.default_typ thy const
   230       | typ_func (_, NONE) (thm :: _) = (snd o CodeUnit.head_func) thm
   231       | typ_func (const as (c, SOME tyco)) (thms as (thm :: _)) =
   232           let
   233             val (_, ty) = CodeUnit.head_func thm;
   234             val SOME class = AxClass.class_of_param thy c;
   235             val sorts_decl = Sorts.mg_domain algebra tyco [class];
   236             val tys = CodeUnit.typargs thy (c, ty);
   237             val sorts = map (snd o dest_TVar) tys;
   238           in if sorts = sorts_decl then ty
   239             else raise INVALID ([const], "Illegal instantation for class operation "
   240               ^ CodeUnit.string_of_const thy const
   241               ^ "\nin defining equations\n"
   242               ^ (cat_lines o map string_of_thm) thms)
   243           end;
   244     fun add_funcs (const, thms) =
   245       Constgraph.new_node (const, (typ_func const thms, thms));
   246     fun add_deps (funcs as (const, thms)) funcgr =
   247       let
   248         val deps = consts_of funcs;
   249         val insts = instances_of_consts thy algebra funcgr
   250           (fold_consts (insert (op =)) thms []);
   251       in
   252         funcgr
   253         |> ensure_consts' thy algebra insts
   254         |> fold (curry Constgraph.add_edge const) deps
   255         |> fold (curry Constgraph.add_edge const) insts
   256        end;
   257   in
   258     funcgr
   259     |> fold add_funcs funcss
   260     |> fold add_deps funcss
   261   end
   262 and ensure_consts' thy algebra cs funcgr =
   263   let
   264     val auxgr = Constgraph.empty
   265       |> fold (snd oo ensure_const thy algebra funcgr) cs;
   266   in
   267     funcgr
   268     |> fold (merge_funcss thy algebra)
   269          (map (AList.make (Constgraph.get_node auxgr))
   270          (rev (Constgraph.strong_conn auxgr)))
   271   end handle INVALID (cs', msg)
   272     => raise INVALID (fold (insert CodeUnit.eq_const) cs' cs, msg);
   273 
   274 fun ensure_consts thy consts funcgr =
   275   let
   276     val algebra = Code.coregular_algebra thy
   277   in ensure_consts' thy algebra consts funcgr
   278     handle INVALID (cs', msg) => error (msg ^ ",\nwhile preprocessing equations for constant(s) "
   279     ^ commas (map (CodeUnit.string_of_const thy) cs'))
   280   end;
   281 
   282 in
   283 
   284 (** retrieval interfaces **)
   285 
   286 val ensure_consts = ensure_consts;
   287 
   288 fun check_consts thy consts funcgr =
   289   let
   290     val algebra = Code.coregular_algebra thy;
   291     fun try_const const funcgr =
   292       (SOME const, ensure_consts' thy algebra [const] funcgr)
   293       handle INVALID (cs', msg) => (NONE, funcgr);
   294     val (consts', funcgr') = fold_map try_const consts funcgr;
   295   in (map_filter I consts', funcgr') end;
   296 
   297 fun ensure_consts_term_proto thy f ct funcgr =
   298   let
   299     fun consts_of thy t =
   300       fold_aterms (fn Const c => cons (CodeUnit.const_of_cexpr thy c) | _ => I) t []
   301     fun rhs_conv conv thm =
   302       let
   303         val thm' = (conv o Thm.rhs_of) thm;
   304       in Thm.transitive thm thm' end
   305     val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
   306     val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
   307     val thm1 = Code.preprocess_conv ct;
   308     val ct' = Thm.rhs_of thm1;
   309     val consts = consts_of thy (Thm.term_of ct');
   310     val funcgr' = ensure_consts thy consts funcgr;
   311     val algebra = Code.coregular_algebra thy;
   312     val (_, thm2) = Thm.varifyT' [] thm1;
   313     val thm3 = Thm.reflexive (Thm.rhs_of thm2);
   314     val typ_funcgr = try (fst o Constgraph.get_node funcgr' o CodeUnit.const_of_cexpr thy);
   315     val [thm4] = resort_thms algebra typ_funcgr [thm3];
   316     val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
   317     fun inst thm =
   318       let
   319         val tvars = Term.add_tvars (Thm.prop_of thm) [];
   320         val instmap = map2 (fn (v_i, sort) => fn (v, _) => pairself (Thm.ctyp_of thy)
   321           (TVar (v_i, sort), TFree (v, sort))) tvars tfrees;
   322       in Thm.instantiate (instmap, []) thm end;
   323     val thm5 = inst thm2;
   324     val thm6 = inst thm4;
   325     val ct'' = Thm.rhs_of thm6;
   326     val cs = fold_aterms (fn Const c => cons c | _ => I) (Thm.term_of ct'') [];
   327     val drop = drop_classes thy tfrees;
   328     val instdefs = instances_of_consts thy algebra funcgr' cs;
   329     val funcgr'' = ensure_consts thy instdefs funcgr';
   330   in (f funcgr'' drop ct'' thm5, funcgr'') end;
   331 
   332 fun ensure_consts_eval thy conv =
   333   let
   334     fun conv' funcgr drop_classes ct thm1 =
   335       let
   336         val thm2 = conv funcgr ct;
   337         val thm3 = Code.postprocess_conv (Thm.rhs_of thm2);
   338         val thm23 = drop_classes (Thm.transitive thm2 thm3);
   339       in
   340         Thm.transitive thm1 thm23 handle THM _ =>
   341           error ("eval_conv - could not construct proof:\n"
   342           ^ (cat_lines o map string_of_thm) [thm1, thm2, thm3])
   343       end;
   344   in ensure_consts_term_proto thy conv' end;
   345 
   346 fun ensure_consts_term thy f =
   347   let
   348     fun f' funcgr drop_classes ct thm1 = f funcgr ct;
   349   in ensure_consts_term_proto thy f' end;
   350 
   351 end; (*local*)
   352 
   353 structure Funcgr = CodeDataFun
   354 (struct
   355   type T = T;
   356   val empty = Constgraph.empty;
   357   fun merge _ _ = Constgraph.empty;
   358   fun purge _ NONE _ = Constgraph.empty
   359     | purge _ (SOME cs) funcgr =
   360         Constgraph.del_nodes ((Constgraph.all_preds funcgr 
   361           o filter (can (Constgraph.get_node funcgr))) cs) funcgr;
   362 end);
   363 
   364 fun make thy =
   365   Funcgr.change thy o ensure_consts thy;
   366 
   367 fun make_consts thy =
   368   Funcgr.change_yield thy o check_consts thy;
   369 
   370 fun eval_conv thy f =
   371   fst o Funcgr.change_yield thy o ensure_consts_eval thy f;
   372 
   373 fun eval_term thy f =
   374   fst o Funcgr.change_yield thy o ensure_consts_term thy f;
   375 
   376 fun intervene thy funcgr = Funcgr.change thy (K funcgr);
   377 
   378 end; (*struct*)