src/HOL/SMT/Tools/smt_monomorph.ML
changeset 32618 42865636d006
child 32950 5d5e123443b3
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/SMT/Tools/smt_monomorph.ML	Fri Sep 18 18:13:19 2009 +0200
     1.3 @@ -0,0 +1,120 @@
     1.4 +(*  Title:      HOL/SMT/Tools/smt_monomorph.ML
     1.5 +    Author:     Sascha Boehme, TU Muenchen
     1.6 +
     1.7 +Monomorphization of terms, i.e., computation of all (necessary) instances.
     1.8 +*)
     1.9 +
    1.10 +signature SMT_MONOMORPH =
    1.11 +sig
    1.12 +  val monomorph: theory -> term list -> term list
    1.13 +end
    1.14 +
    1.15 +structure SMT_Monomorph: SMT_MONOMORPH =
    1.16 +struct
    1.17 +
    1.18 +fun selection [] = []
    1.19 +  | selection (x :: xs) = (x, xs) :: map (apsnd (cons x)) (selection xs)
    1.20 +
    1.21 +fun permute [] = []
    1.22 +  | permute [x] = [[x]]
    1.23 +  | permute xs = maps (fn (y, ys) => map (cons y) (permute ys)) (selection xs)
    1.24 +
    1.25 +fun fold_all f = fold (fn x => maps (f x))
    1.26 +
    1.27 +
    1.28 +val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)
    1.29 +val term_has_tvars = Term.exists_type typ_has_tvars
    1.30 +
    1.31 +val ignored = member (op =) [
    1.32 +  @{const_name All}, @{const_name Ex}, @{const_name Let}, @{const_name If},
    1.33 +  @{const_name "op ="}, @{const_name zero_class.zero},
    1.34 +  @{const_name one_class.one}, @{const_name number_of}]
    1.35 +fun consts_of ts = AList.group (op =) (fold Term.add_consts ts [])
    1.36 +  |> filter_out (ignored o fst)
    1.37 +
    1.38 +val join_consts = curry (AList.join (op =) (K (merge (op =))))
    1.39 +fun diff_consts cs ds = 
    1.40 +  let fun diff (n, Ts) =
    1.41 +    (case AList.lookup (op =) cs n of
    1.42 +      NONE => SOME (n, Ts)
    1.43 +    | SOME Us =>
    1.44 +        let val Ts' = fold (remove (op =)) Us Ts
    1.45 +        in if null Ts' then NONE else SOME (n, Ts') end)
    1.46 +  in map_filter diff ds end
    1.47 +
    1.48 +fun instances thy is (n, Ts) env =
    1.49 +  let
    1.50 +    val Us = these (AList.lookup (op =) is n)
    1.51 +    val Ts' = filter typ_has_tvars (map (Envir.subst_type env) Ts)
    1.52 +  in
    1.53 +    (case map_product pair Ts' Us of
    1.54 +      [] => [env]
    1.55 +    | TUs => map_filter (try (fn TU => Sign.typ_match thy TU env)) TUs)
    1.56 +  end
    1.57 +
    1.58 +fun proper_match ps env =
    1.59 +  forall (forall (not o typ_has_tvars o Envir.subst_type env) o snd) ps
    1.60 +
    1.61 +val eq_tab = gen_eq_set (op =) o pairself Vartab.dest
    1.62 +
    1.63 +fun specialize thy cs is ((r, ps), ces) (ts, ns) =
    1.64 +  let
    1.65 +    val ps' = filter (AList.defined (op =) is o fst) ps
    1.66 +
    1.67 +    val envs = permute ps'
    1.68 +      |> maps (fn ps => fold_all (instances thy is) ps [Vartab.empty])
    1.69 +      |> filter (proper_match ps')
    1.70 +      |> filter_out (member eq_tab ces)
    1.71 +      |> distinct eq_tab
    1.72 +
    1.73 +    val us = map (fn env => Envir.subst_term_types env r) envs
    1.74 +    val ns' = join_consts (diff_consts is (diff_consts cs (consts_of us))) ns
    1.75 +  in (envs @ ces, (fold (insert (op aconv)) us ts, ns')) end
    1.76 +
    1.77 +
    1.78 +fun incr_tvar_indices i t =
    1.79 +  let
    1.80 +    val incrT = Logic.incr_tvar i
    1.81 +
    1.82 +    fun incr t =
    1.83 +      (case t of
    1.84 +        Const (n, T) => Const (n, incrT T)
    1.85 +      | Free (n, T) => Free (n, incrT T)
    1.86 +      | Abs (n, T, t1) => Abs (n, incrT T, incr t1)
    1.87 +      | t1 $ t2 => incr t1 $ incr t2
    1.88 +      | _ => t)
    1.89 +  in incr t end
    1.90 +
    1.91 +
    1.92 +val monomorph_limit = 10
    1.93 +
    1.94 +(* Instantiate all polymorphic constants (i.e., constants occurring both with
    1.95 +   ground types and type variables) with all (necessary) ground types; thereby
    1.96 +   create copies of terms containing those constants.
    1.97 +   To prevent non-termination, there is an upper limit for the number of
    1.98 +   recursions involved in the fixpoint construction. *)
    1.99 +fun monomorph thy ts =
   1.100 +  let
   1.101 +    val (ps, ms) = List.partition term_has_tvars ts
   1.102 +
   1.103 +    fun with_tvar (n, Ts) =
   1.104 +      let val Ts' = filter typ_has_tvars Ts
   1.105 +      in if null Ts' then NONE else SOME (n, Ts') end
   1.106 +    fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1)
   1.107 +    val rps = fst (fold_map incr ps 0)
   1.108 +      |> map (fn r => (r, map_filter with_tvar (consts_of [r])))
   1.109 +
   1.110 +    fun mono count is ces cs ts =
   1.111 +      let
   1.112 +        val spec = specialize thy cs is
   1.113 +        val (ces', (ts', is')) = fold_map spec (rps ~~ ces) (ts, [])
   1.114 +        val cs' = join_consts is cs
   1.115 +      in
   1.116 +        if null is' then ts'
   1.117 +        else if count > monomorph_limit then
   1.118 +          (Output.warning "monomorphization limit reached"; ts')
   1.119 +        else mono (count + 1) is' ces' cs' ts'
   1.120 +      end
   1.121 +  in mono 0 (consts_of ms) (map (K []) rps) [] ms end
   1.122 +
   1.123 +end