src/HOL/SMT/Tools/smt_monomorph.ML
changeset 32618 42865636d006
child 32950 5d5e123443b3
equal deleted inserted replaced
32608:c0056c2c1d17 32618:42865636d006
       
(*  Title:      HOL/SMT/Tools/smt_monomorph.ML
    Author:     Sascha Boehme, TU Muenchen

Monomorphization of terms, i.e., computation of all (necessary) instances.
*)
       
     6 
       
     7 signature SMT_MONOMORPH =
       
     8 sig
       
     9   val monomorph: theory -> term list -> term list
       
    10 end
       
    11 
       
    12 structure SMT_Monomorph: SMT_MONOMORPH =
       
    13 struct
       
    14 
       
    15 fun selection [] = []
       
    16   | selection (x :: xs) = (x, xs) :: map (apsnd (cons x)) (selection xs)
       
    17 
       
    18 fun permute [] = []
       
    19   | permute [x] = [[x]]
       
    20   | permute xs = maps (fn (y, ys) => map (cons y) (permute ys)) (selection xs)
       
    21 
       
    22 fun fold_all f = fold (fn x => maps (f x))
       
    23 
       
    24 
       
    25 val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)
       
    26 val term_has_tvars = Term.exists_type typ_has_tvars
       
    27 
       
    28 val ignored = member (op =) [
       
    29   @{const_name All}, @{const_name Ex}, @{const_name Let}, @{const_name If},
       
    30   @{const_name "op ="}, @{const_name zero_class.zero},
       
    31   @{const_name one_class.one}, @{const_name number_of}]
       
    32 fun consts_of ts = AList.group (op =) (fold Term.add_consts ts [])
       
    33   |> filter_out (ignored o fst)
       
    34 
       
    35 val join_consts = curry (AList.join (op =) (K (merge (op =))))
       
    36 fun diff_consts cs ds = 
       
    37   let fun diff (n, Ts) =
       
    38     (case AList.lookup (op =) cs n of
       
    39       NONE => SOME (n, Ts)
       
    40     | SOME Us =>
       
    41         let val Ts' = fold (remove (op =)) Us Ts
       
    42         in if null Ts' then NONE else SOME (n, Ts') end)
       
    43   in map_filter diff ds end
       
    44 
       
    45 fun instances thy is (n, Ts) env =
       
    46   let
       
    47     val Us = these (AList.lookup (op =) is n)
       
    48     val Ts' = filter typ_has_tvars (map (Envir.subst_type env) Ts)
       
    49   in
       
    50     (case map_product pair Ts' Us of
       
    51       [] => [env]
       
    52     | TUs => map_filter (try (fn TU => Sign.typ_match thy TU env)) TUs)
       
    53   end
       
    54 
       
    55 fun proper_match ps env =
       
    56   forall (forall (not o typ_has_tvars o Envir.subst_type env) o snd) ps
       
    57 
       
    58 val eq_tab = gen_eq_set (op =) o pairself Vartab.dest
       
    59 
       
    60 fun specialize thy cs is ((r, ps), ces) (ts, ns) =
       
    61   let
       
    62     val ps' = filter (AList.defined (op =) is o fst) ps
       
    63 
       
    64     val envs = permute ps'
       
    65       |> maps (fn ps => fold_all (instances thy is) ps [Vartab.empty])
       
    66       |> filter (proper_match ps')
       
    67       |> filter_out (member eq_tab ces)
       
    68       |> distinct eq_tab
       
    69 
       
    70     val us = map (fn env => Envir.subst_term_types env r) envs
       
    71     val ns' = join_consts (diff_consts is (diff_consts cs (consts_of us))) ns
       
    72   in (envs @ ces, (fold (insert (op aconv)) us ts, ns')) end
       
    73 
       
    74 
       
    75 fun incr_tvar_indices i t =
       
    76   let
       
    77     val incrT = Logic.incr_tvar i
       
    78 
       
    79     fun incr t =
       
    80       (case t of
       
    81         Const (n, T) => Const (n, incrT T)
       
    82       | Free (n, T) => Free (n, incrT T)
       
    83       | Abs (n, T, t1) => Abs (n, incrT T, incr t1)
       
    84       | t1 $ t2 => incr t1 $ incr t2
       
    85       | _ => t)
       
    86   in incr t end
       
    87 
       
    88 
       
    89 val monomorph_limit = 10
       
    90 
       
    91 (* Instantiate all polymorphic constants (i.e., constants occurring both with
       
    92    ground types and type variables) with all (necessary) ground types; thereby
       
    93    create copies of terms containing those constants.
       
    94    To prevent non-termination, there is an upper limit for the number of
       
    95    recursions involved in the fixpoint construction. *)
       
    96 fun monomorph thy ts =
       
    97   let
       
    98     val (ps, ms) = List.partition term_has_tvars ts
       
    99 
       
   100     fun with_tvar (n, Ts) =
       
   101       let val Ts' = filter typ_has_tvars Ts
       
   102       in if null Ts' then NONE else SOME (n, Ts') end
       
   103     fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1)
       
   104     val rps = fst (fold_map incr ps 0)
       
   105       |> map (fn r => (r, map_filter with_tvar (consts_of [r])))
       
   106 
       
   107     fun mono count is ces cs ts =
       
   108       let
       
   109         val spec = specialize thy cs is
       
   110         val (ces', (ts', is')) = fold_map spec (rps ~~ ces) (ts, [])
       
   111         val cs' = join_consts is cs
       
   112       in
       
   113         if null is' then ts'
       
   114         else if count > monomorph_limit then
       
   115           (Output.warning "monomorphization limit reached"; ts')
       
   116         else mono (count + 1) is' ces' cs' ts'
       
   117       end
       
   118   in mono 0 (consts_of ms) (map (K []) rps) [] ms end
       
   119 
       
   120 end