adding specialisation of predicates to the predicate compiler
authorbulwahn
Mon, 29 Mar 2010 17:30:52 +0200
changeset 36026dfd30b5b4e73
parent 36025 199fe16cdaab
child 36027 7106f079bd05
adding specialisation of predicates to the predicate compiler
src/HOL/IsaMakefile
src/HOL/Predicate_Compile.thy
src/HOL/Tools/Predicate_Compile/predicate_compile.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML
     1.1 --- a/src/HOL/IsaMakefile	Mon Mar 29 17:30:50 2010 +0200
     1.2 +++ b/src/HOL/IsaMakefile	Mon Mar 29 17:30:52 2010 +0200
     1.3 @@ -301,6 +301,7 @@
     1.4    Tools/Predicate_Compile/predicate_compile_data.ML \
     1.5    Tools/Predicate_Compile/predicate_compile_fun.ML \
     1.6    Tools/Predicate_Compile/predicate_compile.ML \
     1.7 +  Tools/Predicate_Compile/predicate_compile_specialisation.ML \
     1.8    Tools/Predicate_Compile/predicate_compile_pred.ML \
     1.9    Tools/quickcheck_generators.ML \
    1.10    Tools/Qelim/cooper_data.ML \
     2.1 --- a/src/HOL/Predicate_Compile.thy	Mon Mar 29 17:30:50 2010 +0200
     2.2 +++ b/src/HOL/Predicate_Compile.thy	Mon Mar 29 17:30:52 2010 +0200
     2.3 @@ -12,6 +12,7 @@
     2.4    "Tools/Predicate_Compile/predicate_compile_data.ML"
     2.5    "Tools/Predicate_Compile/predicate_compile_fun.ML"
     2.6    "Tools/Predicate_Compile/predicate_compile_pred.ML"
     2.7 +  "Tools/Predicate_Compile/predicate_compile_specialisation.ML"
     2.8    "Tools/Predicate_Compile/predicate_compile.ML"
     2.9  begin
    2.10  
     3.1 --- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Mon Mar 29 17:30:50 2010 +0200
     3.2 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Mon Mar 29 17:30:52 2010 +0200
     3.3 @@ -97,7 +97,7 @@
     3.4        val _ = print_step options
     3.5          ("Compiling functions (" ^ commas (map (Syntax.string_of_term_global thy) funnames) ^
     3.6            ") to predicates...")
     3.7 -      val (fun_pred_specs, thy') =
     3.8 +      val (fun_pred_specs, thy1) =
     3.9          (if function_flattening options andalso (not (null funnames)) then
    3.10            if fail_safe_function_flattening options then
    3.11              case try (Predicate_Compile_Fun.define_predicates (get_specs funnames)) thy of
    3.12 @@ -106,24 +106,26 @@
    3.13            else Predicate_Compile_Fun.define_predicates (get_specs funnames) thy
    3.14          else ([], thy))
    3.15          (*||> Theory.checkpoint*)
    3.16 -      val _ = print_specs options thy' fun_pred_specs
    3.17 +      val _ = print_specs options thy1 fun_pred_specs
    3.18        val specs = (get_specs prednames) @ fun_pred_specs
    3.19 -      val (intross3, thy''') = process_specification options specs thy'
    3.20 -      val _ = print_intross options thy''' "Introduction rules with new constants: " intross3
    3.21 +      val (intross3, thy2) = process_specification options specs thy1
    3.22 +      val _ = print_intross options thy2 "Introduction rules with new constants: " intross3
    3.23        val intross4 = map_specs (maps remove_pointless_clauses) intross3
    3.24 -      val _ = print_intross options thy''' "After removing pointless clauses: " intross4
    3.25 -      val intross5 = map_specs (map (remove_equalities thy''')) intross4
    3.26 -      val _ = print_intross options thy''' "After removing equality premises:" intross5
    3.27 +      val _ = print_intross options thy2 "After removing pointless clauses: " intross4
    3.28 +      val intross5 = map_specs (map (remove_equalities thy2)) intross4
    3.29 +      val _ = print_intross options thy2 "After removing equality premises:" intross5
    3.30        val intross6 =
    3.31 -        map (fn (s, ths) => (overload_const thy''' s, map (AxClass.overload thy''') ths)) intross5
    3.32 -      val intross7 = map_specs (map (expand_tuples thy''')) intross6
    3.33 -      val intross8 = map_specs (map (eta_contract_ho_arguments thy''')) intross7
    3.34 -      val _ = case !intro_hook of NONE => () | SOME f => (map_specs (map (f thy''')) intross8; ())
    3.35 -      val _ = print_intross options thy''' "introduction rules before registering: " intross8
    3.36 +        map (fn (s, ths) => (overload_const thy2 s, map (AxClass.overload thy2) ths)) intross5
    3.37 +      val intross7 = map_specs (map (expand_tuples thy2)) intross6
    3.38 +      val intross8 = map_specs (map (eta_contract_ho_arguments thy2)) intross7
    3.39 +      val _ = case !intro_hook of NONE => () | SOME f => (map_specs (map (f thy2)) intross8; ())
    3.40 +      val _ = print_step options ("Looking for specialisations in " ^ commas (map fst intross8) ^ "...")
    3.41 +      val (intross9, thy3) = Predicate_Compile_Specialisation.find_specialisations [] intross8 thy2
    3.42 +      val _ = print_intross options thy3 "introduction rules before registering: " intross9
    3.43        val _ = print_step options "Registering introduction rules..."
    3.44 -      val thy'''' = fold Predicate_Compile_Core.register_intros intross8 thy'''
    3.45 +      val thy4 = fold Predicate_Compile_Core.register_intros intross9 thy3
    3.46      in
    3.47 -      thy''''
    3.48 +      thy4
    3.49      end;
    3.50  
    3.51  fun preprocess options t thy =
     4.1 --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Mon Mar 29 17:30:50 2010 +0200
     4.2 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Mon Mar 29 17:30:52 2010 +0200
     4.3 @@ -295,12 +295,13 @@
     4.4        (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
     4.5        (Symtab.dest (Datatype.get_all thy)));
     4.6      fun check t = (case strip_comb t of
     4.7 -        (Free _, []) => true
     4.8 +        (Var _, []) => true
     4.9 +      | (Free _, []) => true
    4.10        | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
    4.11              (SOME (i, Tname), Type (Tname', _)) => length ts = i andalso Tname = Tname' andalso forall check ts
    4.12            | _ => false)
    4.13        | _ => false)
    4.14 -  in check end;  
    4.15 +  in check end;
    4.16  
    4.17  fun is_funtype (Type ("fun", [_, _])) = true
    4.18    | is_funtype _ = false;
     5.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     5.2 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML	Mon Mar 29 17:30:52 2010 +0200
     5.3 @@ -0,0 +1,200 @@
     5.4 +(*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_specialisation.ML
     5.5 +    Author:     Lukas Bulwahn, TU Muenchen
     5.6 +
     5.7 +Deriving specialised predicates and their intro rules
     5.8 +*)
     5.9 +
    5.10 +signature PREDICATE_COMPILE_SPECIALISATION =
    5.11 +sig
    5.12 +  val find_specialisations : string list -> (string * thm list) list -> theory -> (string * thm list) list * theory
    5.13 +end;
    5.14 +
    5.15 +structure Predicate_Compile_Specialisation : PREDICATE_COMPILE_SPECIALISATION =
    5.16 +struct
    5.17 +
    5.18 +open Predicate_Compile_Aux;
    5.19 +
    5.20 +(* table of specialisations *)
    5.21 +structure Specialisations = Theory_Data
    5.22 +(
    5.23 +  type T = (term * term) Item_Net.T;
    5.24 +  val empty = Item_Net.init ((op aconv o pairself fst) : (term * term) * (term * term) -> bool)
    5.25 +    (single o fst);
    5.26 +  val extend = I;
    5.27 +  val merge = Item_Net.merge;
    5.28 +)
    5.29 +
    5.30 +fun specialisation_of thy atom =
    5.31 +  Item_Net.retrieve (Specialisations.get thy) atom
    5.32 +
    5.33 +fun print_specialisations thy =
    5.34 +  tracing (cat_lines (map (fn (t, spec_t) =>
    5.35 +      Syntax.string_of_term_global thy t ^ " ~~~> " ^ Syntax.string_of_term_global thy spec_t)
    5.36 +    (Item_Net.content (Specialisations.get thy))))
    5.37 +
    5.38 +fun import (pred, intros) args ctxt =
    5.39 +  let
    5.40 +    val thy = ProofContext.theory_of ctxt
    5.41 +    val ((Tinst, intros'), ctxt') = Variable.importT intros ctxt
    5.42 +    val pred' = fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of (hd intros')))))
    5.43 +    val Ts = binder_types (fastype_of pred')
    5.44 +    val argTs = map fastype_of args
    5.45 +    val Tsubst = Type.raw_matches (argTs, Ts) Vartab.empty
    5.46 +    val args' = map (Envir.subst_term_types Tsubst) args
    5.47 +  in
    5.48 +    (((pred', intros'), args'), ctxt')
    5.49 +  end
    5.50 +
    5.51 +
    5.52 +
    5.53 +
    5.54 +fun specialise_intros black_list (pred, intros) pats thy =
    5.55 +  let
    5.56 +    val ctxt = ProofContext.init thy
    5.57 +    val maxidx = fold (Term.maxidx_term o prop_of) intros ~1
    5.58 +    val pats = map (Logic.incr_indexes ([],  maxidx + 1)) pats
    5.59 +    val (((pred, intros), pats), ctxt') = import (pred, intros) pats ctxt
    5.60 +    val intros_t = map prop_of intros
    5.61 +    val result_pats = map Var (fold_rev Term.add_vars pats [])
    5.62 +    fun mk_fresh_name names =
    5.63 +      let
    5.64 +        val name =
    5.65 +          Name.variant names ("specialised_" ^ Long_Name.base_name (fst (dest_Const pred)))
    5.66 +        val bname = Sign.full_bname thy name
    5.67 +      in
    5.68 +        if Sign.declared_const thy bname then
    5.69 +          mk_fresh_name (name :: names)
    5.70 +        else
    5.71 +          bname
    5.72 +      end
    5.73 +    val constname = mk_fresh_name []
    5.74 +    val constT = map fastype_of result_pats ---> @{typ bool}
    5.75 +    val specialised_const = Const (constname, constT)
    5.76 +    val specialisation =
    5.77 +      [(HOLogic.mk_Trueprop (list_comb (pred, pats)),
    5.78 +        HOLogic.mk_Trueprop (list_comb (specialised_const, result_pats)))]
    5.79 +    fun specialise_intro intro =
    5.80 +      (let
    5.81 +        val (prems, concl) = Logic.strip_horn (prop_of intro)
    5.82 +        val env = Pattern.unify thy
    5.83 +          (HOLogic.mk_Trueprop (list_comb (pred, pats)), concl) (Envir.empty 0)
    5.84 +        val prems = map (Envir.norm_term env) prems
    5.85 +        val args = map (Envir.norm_term env) result_pats
    5.86 +        val concl = HOLogic.mk_Trueprop (list_comb (specialised_const, args))
    5.87 +        val intro = Logic.list_implies (prems, concl)
    5.88 +      in
    5.89 +        SOME intro
    5.90 +      end handle Pattern.Unif => NONE)
    5.91 +    val specialised_intros_t = map_filter I (map specialise_intro intros)
    5.92 +    val thy' = Sign.add_consts_i [(Binding.name (Long_Name.base_name constname), constT, NoSyn)] thy
    5.93 +    val specialised_intros = map (Skip_Proof.make_thm thy') specialised_intros_t
    5.94 +    val exported_intros = Variable.exportT ctxt' ctxt specialised_intros
    5.95 +    val [t, specialised_t] = Variable.exportT_terms ctxt' ctxt
    5.96 +      [list_comb (pred, pats), list_comb (specialised_const, result_pats)]
    5.97 +    val thy'' = Specialisations.map (Item_Net.update (t, specialised_t)) thy'
    5.98 +    val ([spec], thy''') = find_specialisations black_list [(constname, exported_intros)] thy''
    5.99 +    val thy'''' = Predicate_Compile_Core.register_intros spec thy'''
   5.100 +  in
   5.101 +    thy''''
   5.102 +  end
   5.103 +
   5.104 +and find_specialisations black_list specs thy =
   5.105 +  let
   5.106 +    val add_vars = fold_aterms (fn Var v => cons v | _ => I);
   5.107 +    fun is_nontrivial_constrt thy t = not (is_Var t) andalso (is_constrt thy t)
   5.108 +    fun fresh_free T free_names =
   5.109 +      let
   5.110 +        val free_name = Name.variant free_names "x"
   5.111 +      in
   5.112 +        (Free (free_name, T), free_name :: free_names)
   5.113 +      end
   5.114 +    fun replace_term_and_restrict thy T t Tts free_names =
   5.115 +      let
   5.116 +        val (free, free_names') = fresh_free T free_names
   5.117 +        val Tts' = map (apsnd (Pattern.rewrite_term thy [(t, free)] [])) Tts
   5.118 +        val (ts', free_names'') = restrict_pattern' thy Tts' free_names'
   5.119 +      in
   5.120 +        (free :: ts', free_names'')
   5.121 +      end
   5.122 +    and restrict_pattern' thy [] free_names = ([], free_names)
   5.123 +      | restrict_pattern' thy ((T, Free (x, _)) :: Tts) free_names =
   5.124 +      let
   5.125 +        val (ts', free_names') = restrict_pattern' thy Tts free_names
   5.126 +      in
   5.127 +        (Free (x, T) :: ts', free_names')
   5.128 +      end
   5.129 +      | restrict_pattern' thy ((T as TFree _, t) :: Tts) free_names =
   5.130 +        replace_term_and_restrict thy T t Tts free_names
   5.131 +      | restrict_pattern' thy ((T as Type (Tcon, Ts), t) :: Tts) free_names =
   5.132 +        case Datatype_Data.get_constrs thy Tcon of
   5.133 +          NONE => replace_term_and_restrict thy T t Tts free_names
   5.134 +        | SOME constrs => (case strip_comb t of
   5.135 +          (Const (s, _), ats) => (case AList.lookup (op =) constrs s of
   5.136 +            SOME constr_T =>
   5.137 +              let
   5.138 +                val (Ts', T') = strip_type constr_T
   5.139 +                val Tsubst = Type.raw_match (T', T) Vartab.empty
   5.140 +                val Ts = map (Envir.subst_type Tsubst) Ts'
   5.141 +                val (bts', free_names') = restrict_pattern' thy ((Ts ~~ ats) @ Tts) free_names
   5.142 +                val (ats', ts') = chop (length ats) bts'
   5.143 +              in
   5.144 +                (list_comb (Const (s, map fastype_of ats' ---> T), ats') :: ts', free_names')
   5.145 +              end
   5.146 +            | NONE => replace_term_and_restrict thy T t Tts free_names))
   5.147 +    fun restrict_pattern thy Ts args =
   5.148 +      let
   5.149 +        val args = map Logic.unvarify_global args
   5.150 +        val Ts = map Logic.unvarifyT_global Ts
   5.151 +        val free_names = fold Term.add_free_names args []
   5.152 +        val (pat, _) = restrict_pattern' thy (Ts ~~ args) free_names
   5.153 +      in map Logic.varify_global pat end
   5.154 +    fun detect' atom thy =
   5.155 +      case strip_comb atom of
   5.156 +        (pred as Const (pred_name, _), args) =>
   5.157 +          let
   5.158 +          val Ts = binder_types (Sign.the_const_type thy pred_name)
   5.159 +          val vnames = map fst (fold Term.add_var_names args [])
   5.160 +          val pats = restrict_pattern thy Ts args
   5.161 +        in
   5.162 +          if (exists (is_nontrivial_constrt thy) pats)
   5.163 +            orelse (has_duplicates (op =) (fold add_vars pats [])) then
   5.164 +            let
   5.165 +              val thy' =
   5.166 +                case specialisation_of thy atom of
   5.167 +                  [] =>
   5.168 +                    if member (op =) ((map fst specs) @ black_list) pred_name then
   5.169 +                      thy
   5.170 +                    else
   5.171 +                      (case try (Predicate_Compile_Core.intros_of thy) pred_name of
   5.172 +                        NONE => thy
   5.173 +                      | SOME intros =>
   5.174 +                          specialise_intros ((map fst specs) @ (pred_name :: black_list))
   5.175 +                            (pred, intros) pats thy)
   5.176 +                  | (t, specialised_t) :: _ => thy
   5.177 +                val atom' =
   5.178 +                  case specialisation_of thy' atom of
   5.179 +                    [] => atom
   5.180 +                  | (t, specialised_t) :: _ =>
   5.181 +                    let
   5.182 +                      val subst = Pattern.match thy' (t, atom) (Vartab.empty, Vartab.empty)
   5.183 +                    in Envir.subst_term subst specialised_t end handle Pattern.MATCH => atom
   5.184 +                    (*FIXME: this exception could be caught earlier in specialisation_of *)
   5.185 +            in
   5.186 +              (atom', thy')
   5.187 +            end
   5.188 +          else (atom, thy)
   5.189 +        end
   5.190 +      | _ => (atom, thy)
   5.191 +    fun specialise' (constname, intros) thy =
   5.192 +      let
   5.193 +        (* FIXME: only necessary because of sloppy Logic.unvarify in restrict_pattern *)
   5.194 +        val intros = Drule.zero_var_indexes_list intros
   5.195 +        val (intros_t', thy') = (fold_map o fold_map_atoms) detect' (map prop_of intros) thy
   5.196 +      in
   5.197 +        ((constname, map (Skip_Proof.make_thm thy') intros_t'), thy')
   5.198 +      end
   5.199 +  in
   5.200 +    fold_map specialise' specs thy
   5.201 +  end
   5.202 +
   5.203 +end;
   5.204 \ No newline at end of file