src/HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML
author wenzelm
Thu, 09 Jun 2011 16:34:49 +0200
changeset 44206 2b47822868e4
parent 43232 23f352990944
child 46777 0aaeb5520f2f
permissions -rw-r--r--
discontinued Name.variant to emphasize that this is old-style / indirect;
     1 (*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML
     2     Author:     Lukas Bulwahn, TU Muenchen
     3 
     4 Deriving specialised predicates and their intro rules
     5 *)
     6 
     7 signature PREDICATE_COMPILE_SPECIALISATION =
     8 sig
     9   val find_specialisations : string list -> (string * thm list) list -> theory -> (string * thm list) list * theory
    10 end;
    11 
    12 structure Predicate_Compile_Specialisation : PREDICATE_COMPILE_SPECIALISATION =
    13 struct
    14 
    15 open Predicate_Compile_Aux;
    16 
    17 (* table of specialisations *)
    18 structure Specialisations = Theory_Data
    19 (
    20   type T = (term * term) Item_Net.T;
    21   val empty : T = Item_Net.init (op aconv o pairself fst) (single o fst);
    22   val extend = I;
    23   val merge = Item_Net.merge;
    24 )
    25 
    26 fun specialisation_of thy atom =
    27   Item_Net.retrieve (Specialisations.get thy) atom
    28 
    29 fun print_specialisations thy =
    30   tracing (cat_lines (map (fn (t, spec_t) =>
    31       Syntax.string_of_term_global thy t ^ " ~~~> " ^ Syntax.string_of_term_global thy spec_t)
    32     (Item_Net.content (Specialisations.get thy))))
    33 
    34 fun import (pred, intros) args ctxt =
    35   let
    36     val thy = Proof_Context.theory_of ctxt
    37     val ((Tinst, intros'), ctxt') = Variable.importT intros ctxt
    38     val pred' = fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of (hd intros')))))
    39     val Ts = binder_types (fastype_of pred')
    40     val argTs = map fastype_of args
    41     val Tsubst = Type.raw_matches (argTs, Ts) Vartab.empty
    42     val args' = map (Envir.subst_term_types Tsubst) args
    43   in
    44     (((pred', intros'), args'), ctxt')
    45   end
    46 
    47 (* patterns only constructed of variables and pairs/tuples are trivial constructor terms*)
    48 fun is_nontrivial_constrt thy t =
    49   let
    50     val cnstrs = flat (maps
    51       (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
    52       (Symtab.dest (Datatype.get_all thy)));
    53     fun check t = (case strip_comb t of
    54         (Var _, []) => (true, true)
    55       | (Free _, []) => (true, true)
    56       | (Const (@{const_name Pair}, _), ts) =>
    57         pairself (forall I) (split_list (map check ts))
    58       | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
    59             (SOME (i, Tname), Type (Tname', _)) => (false,
    60               length ts = i andalso Tname = Tname' andalso forall (snd o check) ts)
    61           | _ => (false, false))
    62       | _ => (false, false))
    63   in check t = (false, true) end;
    64 
    65 fun specialise_intros black_list (pred, intros) pats thy =
    66   let
    67     val ctxt = Proof_Context.init_global thy
    68     val maxidx = fold (Term.maxidx_term o prop_of) intros ~1
    69     val pats = map (Logic.incr_indexes ([],  maxidx + 1)) pats
    70     val (((pred, intros), pats), ctxt') = import (pred, intros) pats ctxt
    71     val intros_t = map prop_of intros
    72     val result_pats = map Var (fold_rev Term.add_vars pats [])
    73     fun mk_fresh_name names =
    74       let
    75         val name =
    76           singleton (Name.variant_list names)
    77             ("specialised_" ^ Long_Name.base_name (fst (dest_Const pred)))
    78         val bname = Sign.full_bname thy name
    79       in
    80         if Sign.declared_const thy bname then
    81           mk_fresh_name (name :: names)
    82         else
    83           bname
    84       end
    85     val constname = mk_fresh_name []
    86     val constT = map fastype_of result_pats ---> @{typ bool}
    87     val specialised_const = Const (constname, constT)
    88     val specialisation =
    89       [(HOLogic.mk_Trueprop (list_comb (pred, pats)),
    90         HOLogic.mk_Trueprop (list_comb (specialised_const, result_pats)))]
    91     fun specialise_intro intro =
    92       (let
    93         val (prems, concl) = Logic.strip_horn (prop_of intro)
    94         val env = Pattern.unify thy
    95           (HOLogic.mk_Trueprop (list_comb (pred, pats)), concl) (Envir.empty 0)
    96         val prems = map (Envir.norm_term env) prems
    97         val args = map (Envir.norm_term env) result_pats
    98         val concl = HOLogic.mk_Trueprop (list_comb (specialised_const, args))
    99         val intro = Logic.list_implies (prems, concl)
   100       in
   101         SOME intro
   102       end handle Pattern.Unif => NONE)
   103     val specialised_intros_t = map_filter I (map specialise_intro intros)
   104     val thy' = Sign.add_consts_i [(Binding.name (Long_Name.base_name constname), constT, NoSyn)] thy
   105     val specialised_intros = map (Skip_Proof.make_thm thy') specialised_intros_t
   106     val exported_intros = Variable.exportT ctxt' ctxt specialised_intros
   107     val [t, specialised_t] = Variable.exportT_terms ctxt' ctxt
   108       [list_comb (pred, pats), list_comb (specialised_const, result_pats)]
   109     val thy'' = Specialisations.map (Item_Net.update (t, specialised_t)) thy'
   110     val optimised_intros =
   111       map_filter (Predicate_Compile_Aux.peephole_optimisation thy'') exported_intros
   112     val ([spec], thy''') = find_specialisations black_list [(constname, optimised_intros)] thy''
   113     val thy'''' = Core_Data.register_intros spec thy'''
   114   in
   115     thy''''
   116   end
   117 
   118 and find_specialisations black_list specs thy =
   119   let
   120     val add_vars = fold_aterms (fn Var v => cons v | _ => I);
   121     fun fresh_free T free_names =
   122       let
   123         val free_name = singleton (Name.variant_list free_names) "x"
   124       in
   125         (Free (free_name, T), free_name :: free_names)
   126       end
   127     fun replace_term_and_restrict thy T t Tts free_names =
   128       let
   129         val (free, free_names') = fresh_free T free_names
   130         val Tts' = map (apsnd (Pattern.rewrite_term thy [(t, free)] [])) Tts
   131         val (ts', free_names'') = restrict_pattern' thy Tts' free_names'
   132       in
   133         (free :: ts', free_names'')
   134       end
   135     and restrict_pattern' thy [] free_names = ([], free_names)
   136       | restrict_pattern' thy ((T, Free (x, _)) :: Tts) free_names =
   137       let
   138         val (ts', free_names') = restrict_pattern' thy Tts free_names
   139       in
   140         (Free (x, T) :: ts', free_names')
   141       end
   142       | restrict_pattern' thy ((T as TFree _, t) :: Tts) free_names =
   143         replace_term_and_restrict thy T t Tts free_names
   144       | restrict_pattern' thy ((T as Type (Tcon, Ts), t) :: Tts) free_names =
   145         case Datatype_Data.get_constrs thy Tcon of
   146           NONE => replace_term_and_restrict thy T t Tts free_names
   147         | SOME constrs => (case strip_comb t of
   148           (Const (s, _), ats) => (case AList.lookup (op =) constrs s of
   149             SOME constr_T =>
   150               let
   151                 val (Ts', T') = strip_type constr_T
   152                 val Tsubst = Type.raw_match (T', T) Vartab.empty
   153                 val Ts = map (Envir.subst_type Tsubst) Ts'
   154                 val (bts', free_names') = restrict_pattern' thy ((Ts ~~ ats) @ Tts) free_names
   155                 val (ats', ts') = chop (length ats) bts'
   156               in
   157                 (list_comb (Const (s, map fastype_of ats' ---> T), ats') :: ts', free_names')
   158               end
   159             | NONE => replace_term_and_restrict thy T t Tts free_names))
   160     fun restrict_pattern thy Ts args =
   161       let
   162         val args = map Logic.unvarify_global args
   163         val Ts = map Logic.unvarifyT_global Ts
   164         val free_names = fold Term.add_free_names args []
   165         val (pat, _) = restrict_pattern' thy (Ts ~~ args) free_names
   166       in map Logic.varify_global pat end
   167     fun detect' atom thy =
   168       case strip_comb atom of
   169         (pred as Const (pred_name, _), args) =>
   170           let
   171           val Ts = binder_types (Sign.the_const_type thy pred_name)
   172           val vnames = map fst (fold Term.add_var_names args [])
   173           val pats = restrict_pattern thy Ts args
   174         in
   175           if (exists (is_nontrivial_constrt thy) pats)
   176             orelse (has_duplicates (op =) (fold add_vars pats [])) then
   177             let
   178               val thy' =
   179                 case specialisation_of thy atom of
   180                   [] =>
   181                     if member (op =) ((map fst specs) @ black_list) pred_name then
   182                       thy
   183                     else
   184                       (case try (Core_Data.intros_of (Proof_Context.init_global thy)) pred_name of
   185                         NONE => thy
   186                       | SOME [] => thy
   187                       | SOME intros =>
   188                           specialise_intros ((map fst specs) @ (pred_name :: black_list))
   189                             (pred, intros) pats thy)
   190                   | (t, specialised_t) :: _ => thy
   191                 val atom' =
   192                   case specialisation_of thy' atom of
   193                     [] => atom
   194                   | (t, specialised_t) :: _ =>
   195                     let
   196                       val subst = Pattern.match thy' (t, atom) (Vartab.empty, Vartab.empty)
   197                     in Envir.subst_term subst specialised_t end handle Pattern.MATCH => atom
   198                     (*FIXME: this exception could be caught earlier in specialisation_of *)
   199             in
   200               (atom', thy')
   201             end
   202           else (atom, thy)
   203         end
   204       | _ => (atom, thy)
   205     fun specialise' (constname, intros) thy =
   206       let
   207         (* FIXME: only necessary because of sloppy Logic.unvarify in restrict_pattern *)
   208         val intros = Drule.zero_var_indexes_list intros
   209         val (intros_t', thy') = (fold_map o fold_map_atoms) detect' (map prop_of intros) thy
   210       in
   211         ((constname, map (Skip_Proof.make_thm thy') intros_t'), thy')
   212       end
   213   in
   214     fold_map specialise' specs thy
   215   end
   216 
   217 end;