1.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000
1.2 +++ b/src/HOL/SMT/Tools/smt_monomorph.ML Fri Sep 18 18:13:19 2009 +0200
1.3 @@ -0,0 +1,120 @@
1.4 +(* Title: HOL/SMT/Tools/smt_monomorph.ML
1.5 + Author: Sascha Boehme, TU Muenchen
1.6 +
1.7 +Monomorphization of terms, i.e., computation of all (necessary) instances.
1.8 +*)
1.9 +
1.10 +signature SMT_MONOMORPH =
1.11 +sig
1.12 + val monomorph: theory -> term list -> term list
1.13 +end
1.14 +
1.15 +structure SMT_Monomorph: SMT_MONOMORPH =
1.16 +struct
1.17 +
1.18 +fun selection [] = []
1.19 + | selection (x :: xs) = (x, xs) :: map (apsnd (cons x)) (selection xs)
1.20 +
1.21 +fun permute [] = []
1.22 + | permute [x] = [[x]]
1.23 + | permute xs = maps (fn (y, ys) => map (cons y) (permute ys)) (selection xs)
1.24 +
1.25 +fun fold_all f = fold (fn x => maps (f x))
1.26 +
1.27 +
1.28 +val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)
1.29 +val term_has_tvars = Term.exists_type typ_has_tvars
1.30 +
1.31 +val ignored = member (op =) [
1.32 + @{const_name All}, @{const_name Ex}, @{const_name Let}, @{const_name If},
1.33 + @{const_name "op ="}, @{const_name zero_class.zero},
1.34 + @{const_name one_class.one}, @{const_name number_of}]
1.35 +fun consts_of ts = AList.group (op =) (fold Term.add_consts ts [])
1.36 + |> filter_out (ignored o fst)
1.37 +
1.38 +val join_consts = curry (AList.join (op =) (K (merge (op =))))
1.39 +fun diff_consts cs ds =
1.40 + let fun diff (n, Ts) =
1.41 + (case AList.lookup (op =) cs n of
1.42 + NONE => SOME (n, Ts)
1.43 + | SOME Us =>
1.44 + let val Ts' = fold (remove (op =)) Us Ts
1.45 + in if null Ts' then NONE else SOME (n, Ts') end)
1.46 + in map_filter diff ds end
1.47 +
1.48 +fun instances thy is (n, Ts) env =
1.49 + let
1.50 + val Us = these (AList.lookup (op =) is n)
1.51 + val Ts' = filter typ_has_tvars (map (Envir.subst_type env) Ts)
1.52 + in
1.53 + (case map_product pair Ts' Us of
1.54 + [] => [env]
1.55 + | TUs => map_filter (try (fn TU => Sign.typ_match thy TU env)) TUs)
1.56 + end
1.57 +
1.58 +fun proper_match ps env =
1.59 + forall (forall (not o typ_has_tvars o Envir.subst_type env) o snd) ps
1.60 +
1.61 +val eq_tab = gen_eq_set (op =) o pairself Vartab.dest
1.62 +
1.63 +fun specialize thy cs is ((r, ps), ces) (ts, ns) =
1.64 + let
1.65 + val ps' = filter (AList.defined (op =) is o fst) ps
1.66 +
1.67 + val envs = permute ps'
1.68 + |> maps (fn ps => fold_all (instances thy is) ps [Vartab.empty])
1.69 + |> filter (proper_match ps')
1.70 + |> filter_out (member eq_tab ces)
1.71 + |> distinct eq_tab
1.72 +
1.73 + val us = map (fn env => Envir.subst_term_types env r) envs
1.74 + val ns' = join_consts (diff_consts is (diff_consts cs (consts_of us))) ns
1.75 + in (envs @ ces, (fold (insert (op aconv)) us ts, ns')) end
1.76 +
1.77 +
1.78 +fun incr_tvar_indices i t =
1.79 + let
1.80 + val incrT = Logic.incr_tvar i
1.81 +
1.82 + fun incr t =
1.83 + (case t of
1.84 + Const (n, T) => Const (n, incrT T)
1.85 + | Free (n, T) => Free (n, incrT T)
1.86 + | Abs (n, T, t1) => Abs (n, incrT T, incr t1)
1.87 + | t1 $ t2 => incr t1 $ incr t2
1.88 + | _ => t)
1.89 + in incr t end
1.90 +
1.91 +
1.92 +val monomorph_limit = 10
1.93 +
1.94 +(* Instantiate all polymorphic constants (i.e., constants occurring both with
1.95 + ground types and type variables) with all (necessary) ground types; thereby
1.96 + create copies of terms containing those constants.
1.97 + To prevent non-termination, there is an upper limit for the number of
1.98 + recursions involved in the fixpoint construction. *)
1.99 +fun monomorph thy ts =
1.100 + let
1.101 + val (ps, ms) = List.partition term_has_tvars ts
1.102 +
1.103 + fun with_tvar (n, Ts) =
1.104 + let val Ts' = filter typ_has_tvars Ts
1.105 + in if null Ts' then NONE else SOME (n, Ts') end
1.106 + fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1)
1.107 + val rps = fst (fold_map incr ps 0)
1.108 + |> map (fn r => (r, map_filter with_tvar (consts_of [r])))
1.109 +
1.110 + fun mono count is ces cs ts =
1.111 + let
1.112 + val spec = specialize thy cs is
1.113 + val (ces', (ts', is')) = fold_map spec (rps ~~ ces) (ts, [])
1.114 + val cs' = join_consts is cs
1.115 + in
1.116 + if null is' then ts'
1.117 + else if count > monomorph_limit then
1.118 + (Output.warning "monomorphization limit reached"; ts')
1.119 + else mono (count + 1) is' ces' cs' ts'
1.120 + end
1.121 + in mono 0 (consts_of ms) (map (K []) rps) [] ms end
1.122 +
1.123 +end