src/HOL/Tools/SMT/smt_monomorph.ML
changeset 36890 8e55aa1306c5
child 39093 4abe644fcea5
equal deleted inserted replaced
36889:6d1ecdb81ff0 36890:8e55aa1306c5
       
     1 (*  Title:      HOL/Tools/SMT/smt_monomorph.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3 
       
     4 Monomorphization of theorems, i.e., computation of all (necessary) instances.
       
     5 *)
       
     6 
       
     (* Interface of the SMT monomorphizer: replace theorems mentioning
        schematic type variables by ground instances, returning the new
        theorem list together with the (possibly extended) proof context. *)
     signature SMT_MONOMORPH =
     sig val monomorph: thm list -> Proof.context -> thm list * Proof.context end
       
    11 
       
    structure SMT_Monomorph: SMT_MONOMORPH =
    struct

    (* Does the type contain any schematic type variable? *)
    val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)

    (* Constants never collected as instantiation candidates.
       NOTE(review): presumably handled directly by the SMT translation --
       confirm. *)
    val ignored = member (op =) [
      @{const_name All}, @{const_name Ex}, @{const_name Let}, @{const_name If},
      @{const_name "op ="}, @{const_name zero_class.zero},
      @{const_name one_class.one}, @{const_name number_of}]

    (* A constant "(n, T)" is relevant if it is not ignored and its type "T"
       satisfies the predicate "f". *)
    fun is_const f (n, T) = not (ignored n) andalso f T

    (* Apply "g" to a relevant constant; leave the accumulator unchanged
       for every other atom. *)
    fun add_const_if f g (Const c) = if is_const f c then g c else I
      | add_const_if _ _ _ = I

    (* Fold "g" over all relevant constants (as judged by "f") occurring in
       the proposition of "thm". *)
    fun collect_consts_if f g thm =
      Term.fold_aterms (add_const_if f g) (Thm.prop_of thm)

    (* Add the types of the relevant constants of a theorem to a table from
       constant names to type lists.  Symtab.map_entry only modifies existing
       entries, so only names already present in the table are recorded
       (mono_all pre-seeds the table with the names of interest). *)
    fun add_consts f =
      collect_consts_if f (fn (n, T) => Symtab.map_entry n (insert (op =) T))

    (* Insert a constant into a list ordered by name, then by type. *)
    val insert_const = OrdList.insert (prod_ord fast_string_ord Term_Ord.typ_ord)

    (* The ordered list of relevant constants of "thm" whose types contain
       schematic type variables. *)
    fun tvar_consts_of thm = collect_consts_if typ_has_tvars insert_const thm []


    (* Shift the schematic indexes of each theorem beyond the maximal index
       of all preceding theorems, so that the resulting theorems share no
       schematic variables. *)
    fun incr_indexes thms =
      let fun inc thm idx = (Thm.incr_indexes idx thm, Thm.maxidx_of thm + idx + 1)
      in fst (fold_map inc thms 0) end


    (* Compute all substitutions obtained by matching the type "T" of the
       constant "n" against the types recorded for "n" in "grounds", each one
       extending "subst".  The unextended "subst" is always part of the
       result; if "T" has no schematic type variables, it is the only one. *)
    fun new_substitutions thy grounds (n, T) subst =
      if not (typ_has_tvars T) then [subst]
      else
        Symtab.lookup_list grounds n
        |> map_filter (try (fn U => Sign.typ_match thy (T, U) subst))
        |> cons subst


    (* Instantiate the constants "consts" with "subst", yielding
       "(subst, consts')".  Also add to the accumulator table every resulting
       instance that is ground and not yet recorded in "grounds"; these feed
       the next round of specialization. *)
    fun apply_subst grounds consts subst =
      let
        fun is_new_ground (n, T) = not (typ_has_tvars T) andalso
          not (member (op =) (Symtab.lookup_list grounds n) T)

        fun apply_const (n, T) new_grounds =
          let val c = (n, Envir.subst_type subst T)
          in
            new_grounds
            |> is_new_ground c ? Symtab.insert_list (op =) c
            |> pair c
          end
      in fold_map apply_const consts #>> pair subst end


    (* Compute new substitutions for the theorem "thm" by extending the
       previously found pairs "scs" (substitution, instantiated constants)
       with the grounds discovered in the previous round ("new_grounds").
         Also collect new grounds, i.e., instantiated constants (without
       schematic types) which occur in none of the previous rounds
       ("all_grounds").  Pairs with identical instantiated constants are
       merged.  This relies on no schematic type variables being shared
       among theorems (cf. incr_indexes). *)
    fun specialize thy all_grounds new_grounds (thm, scs) =
      let
        fun spec (subst, consts) next_grounds =
          [subst]
          |> fold (maps o new_substitutions thy new_grounds) consts
          |> rpair next_grounds
          |-> fold_map (apply_subst all_grounds consts)
      in
        fold_map spec scs #>> (fn scss =>
        (thm, fold (fold (insert (eq_snd (op =)))) scss []))
      end


    (* Compute all necessary substitutions by repeating "specialize" until
       no new grounds arise.  After "limit" additional rounds, stop with a
       warning.
         Instead of operating on the propositions of the theorems, the
       computation uses only the constants occurring with schematic type
       variables in the propositions.  To ease comparisons, such sets of
       constants are always kept in their initial order. *)
    fun incremental_monomorph thy limit all_grounds new_grounds ths =
      let
        val all_grounds' = Symtab.merge_list (op =) (all_grounds, new_grounds)
        val spec = specialize thy all_grounds' new_grounds
        val (ths', new_grounds') = fold_map spec ths Symtab.empty
      in
        if Symtab.is_empty new_grounds' then ths'
        else if limit > 0
        then incremental_monomorph thy (limit-1) all_grounds' new_grounds' ths'
        else (warning "SMT: monomorphization limit reached"; ths')
      end


    (* From a list of (substitution, instantiated constants) pairs, keep only
       the substitutions whose constants are not generalized by another
       entry.  An entry is dropped whenever the constants of some other entry
       are a non-trivial instance of its own constants. *)
    fun filter_most_specific thy =
      let
        fun typ_match (_, T) (_, U) = Sign.typ_match thy (T, U)

        (* trivial: every variable is mapped to itself (vacuously true for
           the empty substitution) *)
        fun is_trivial subst =
          forall (fn (v, (S, T)) => TVar (v, S) = T) (Vartab.dest subst)

        (* "specific" is a non-trivial instance of "general"; lists of
           different lengths or failing matches yield false *)
        fun match general specific =
          (case try (fold2 typ_match general specific) Vartab.empty of
            NONE => false
          | SOME subst => not (is_trivial subst))

        (* "css" holds the constant lists already visited, "scs" the rest *)
        fun most_specific _ [] = []
          | most_specific css ((ss, cs) :: scs) =
              let val substs = most_specific (cs :: css) scs
              in
                if exists (match cs) css orelse exists (match cs o snd) scs
                then substs else ss :: substs
              end

      in most_specific [] end


    (* Instantiate a theorem with each substitution of a list.  Before
       instantiation, every substitution is completed with the fixed types of
       "Tenv": variables it leaves unassigned are mapped to their "Tenv"
       type, and occurrences of those variables in its range are replaced
       likewise. *)
    fun instantiate thy Tenv =
      let
        fun replace (v, (_, T)) (U as TVar (u, _)) = if u = v then T else U
          | replace _ T = T

        fun complete (vT as (v, _)) subst =
          subst
          |> not (Vartab.defined subst v) ? Vartab.update vT
          |> Vartab.map (apsnd (Term.map_atyps (replace vT)))

        fun cert (ix, (S, T)) = pairself (Thm.ctyp_of thy) (TVar (ix, S), T)

        fun inst thm subst =
          let val cTs = Vartab.fold (cons o cert) (fold complete Tenv subst) []
          in Thm.instantiate (cTs, []) thm end

      in uncurry (map o inst) end


    (* Monomorphize the polymorphic theorems "polys" with respect to the
       ground constant instances occurring in "monos" and "polys".  The
       result lists the theorems of "monos" followed by the instances of
       "polys".  The context is extended by the type names invented to
       replace schematic type variables left over after instantiation. *)
    fun mono_all ctxt _ [] monos = (monos, ctxt)
      | mono_all ctxt limit polys monos =
          let
            (* a fresh type for each schematic type variable of "thm" *)
            fun invent_types thm ctxt =
              let val (vs, Ss) = split_list (Term.add_tvars (Thm.prop_of thm) [])
              in
                ctxt
                |> Variable.invent_types Ss
                |>> map2 (fn v => fn (n, S) => (v, (S, TFree (n, S)))) vs
              end
            val (Tenvs, ctxt') = fold_map invent_types polys ctxt

            val thy = ProofContext.theory_of ctxt'

            (* initial state: the empty substitution together with the
               constants occurring with schematic types; computed once and
               used both for "ns" and as input to the fixpoint iteration *)
            val ths = polys
              |> map (fn thm => (thm, [(Vartab.empty, tvar_consts_of thm)]))

            (* all constant names occurring with schematic types *)
            val ns = fold (fold (fold (insert (op =) o fst) o snd) o snd) ths []

            (* all known instances with non-schematic types *)
            val grounds =
              Symtab.make (map (rpair []) ns)
              |> fold (add_consts (K true)) monos
              |> fold (add_consts (not o typ_has_tvars)) polys
          in
            ths
            |> incremental_monomorph thy limit Symtab.empty grounds
            |> map (apsnd (filter_most_specific thy))
            |> flat o map2 (instantiate thy) Tenvs
            |> append monos
            |> rpair ctxt'
          end


    (* Maximal number of additional specialization rounds performed by
       incremental_monomorph before giving up with a warning. *)
    val monomorph_limit = 10


    (* Instantiate all polymorphic constants (i.e., constants occurring
       both with ground types and type variables) with all (necessary)
       ground types; thereby create copies of theorems containing those
       constants.
         To prevent non-termination, there is an upper limit for the
       number of recursions involved in the fixpoint construction.
         The initial set of theorems must not contain any schematic term
       variables, and the final list of theorems does not contain any
       schematic type variables anymore. *)
    fun monomorph thms ctxt =
      thms
      |> List.partition (Term.exists_type typ_has_tvars o Thm.prop_of)
      |>> incr_indexes
      |-> mono_all ctxt monomorph_limit

    end