|
1 (* Title: HOL/SMT/Tools/smt_monomorph.ML |
|
2 Author: Sascha Boehme, TU Muenchen |
|
3 |
|
4 Monomorphization of terms, i.e., computation of all (necessary) instances. |
|
5 *) |
|
6 |
|
(* Interface of the monomorphization step used by the SMT integration. *)
signature SMT_MONOMORPH =
sig
  (* monomorph thy ts: keep the terms of ts that contain no type variables and
     add instances of the polymorphic terms of ts, obtained from the types at
     which their constants occur; the polymorphic terms themselves are not
     part of the result. *)
  val monomorph : theory -> term list -> term list
end
|
11 |
|
structure SMT_Monomorph: SMT_MONOMORPH =
struct

(** utility functions **)

(* All ways of picking one element out of a list, each paired with the
   remaining elements (in their original order). *)
fun selection [] = []
  | selection (x :: xs) = (x, xs) :: map (apsnd (cons x)) (selection xs)

(* All permutations of a list.  Note: the empty list yields no permutation at
   all (not [[]]); specialize relies on this to produce no instance when a
   term has no constant with known instances. *)
fun permute [] = []
  | permute [x] = [[x]]
  | permute xs = maps (fn (y, ys) => map (cons y) (permute ys)) (selection xs)

(* Thread a list of alternatives through xs: for each x in turn, every current
   alternative a is replaced by all results of f x a. *)
fun fold_all f = fold (fn x => maps (f x))


(** constant tables (constant name -> list of types) **)

val typ_has_tvars = Term.exists_subtype (fn TVar _ => true | _ => false)
val term_has_tvars = Term.exists_type typ_has_tvars

(* Constants that are never considered for instantiation. *)
val ignored = member (op =) [
  @{const_name All}, @{const_name Ex}, @{const_name Let}, @{const_name If},
  @{const_name "op ="}, @{const_name zero_class.zero},
  @{const_name one_class.one}, @{const_name number_of}]

(* All constants occurring in ts, grouped by name, without the ignored ones. *)
fun consts_of ts = AList.group (op =) (fold Term.add_consts ts [])
  |> filter_out (ignored o fst)

(* Merge two constant tables; the type lists of shared names are merged. *)
val join_consts = curry (AList.join (op =) (K (merge (op =))))

(* Remove from table ds every (name, type) entry already present in table cs;
   names left without any type are dropped altogether. *)
fun diff_consts cs ds =
  let
    fun diff (n, Ts) =
      (case AList.lookup (op =) cs n of
        NONE => SOME (n, Ts)
      | SOME Us =>
          let val Ts' = fold (remove (op =)) Us Ts
          in if null Ts' then NONE else SOME (n, Ts') end)
  in map_filter diff ds end


(** instantiation of a single polymorphic term **)

(* Extend env for constant n with types Ts: every type of n that still has
   type variables under env is matched against every known instance of n in
   is.  Each successful match is one resulting environment; failed matches
   are dropped.  If there is nothing to match, env is returned unchanged. *)
fun instances thy is (n, Ts) env =
  let
    val Us = these (AList.lookup (op =) is n)
    val Ts' = filter typ_has_tvars (map (Envir.subst_type env) Ts)
  in
    (case map_product pair Ts' Us of
      [] => [env]
    | TUs => map_filter (try (fn TU => Sign.typ_match thy TU env)) TUs)
  end

(* Does env turn all types of the constants in ps into types without type
   variables? *)
fun proper_match ps env =
  forall (forall (not o typ_has_tvars o Envir.subst_type env) o snd) ps

(* Equality of type environments, compared as sets of bindings. *)
val eq_tab = gen_eq_set (op =) o pairself Vartab.dest

(* Instantiate the polymorphic term r, whose polymorphic constants are ps,
   with the instances is.
   - ces: environments already used for r; they are not used again.
   - cs: instances handled in earlier rounds.
   Returns the extended ces together with ts extended by the new instances of
   r (modulo aconv) and ns extended by those constants of the new instances
   that are neither in cs nor in is. *)
fun specialize thy cs is ((r, ps), ces) (ts, ns) =
  let
    (* only constants with known instances can drive the matching *)
    val ps' = filter (AList.defined (op =) is o fst) ps

    (* match the constants in every order, since earlier matches constrain
       later ones *)
    val envs = permute ps'
      |> maps (fn order => fold_all (instances thy is) order [Vartab.empty])
      |> filter (proper_match ps')
      |> filter_out (member eq_tab ces)
      |> distinct eq_tab

    val us = map (fn env => Envir.subst_term_types env r) envs
    val ns' = join_consts (diff_consts is (diff_consts cs (consts_of us))) ns
  in (envs @ ces, (fold (insert (op aconv)) us ts, ns')) end


(** keeping type variables of different terms apart **)

(* Shift the indices of all type variables in t by i.  Term.map_types also
   maps the types of schematic variables (Var); a traversal that skips them
   would leave a Var's type variables unshifted.  Those variables would then
   no longer coincide with the shifted type variables of the constants they
   share them with, and instantiation would produce ill-typed or still
   polymorphic terms. *)
fun incr_tvar_indices i t = Term.map_types (Logic.incr_tvar i) t


(** monomorphization **)

val monomorph_limit = 10

(* Instantiate all polymorphic constants (i.e., constants occurring both with
   ground types and with type variables) with all (necessary) ground types,
   and thereby create copies of the terms containing those constants.
   The result consists of the terms of ts without type variables plus the
   created copies; the polymorphic terms themselves are not returned.
   To prevent non-termination, the number of rounds of the fixpoint
   construction is bounded by monomorph_limit; when the bound is exceeded a
   warning is issued and the instances found so far are returned. *)
fun monomorph thy ts =
  let
    val (ps, ms) = List.partition term_has_tvars ts

    (* keep only those types of a constant that contain type variables *)
    fun with_tvar (n, Ts) =
      let val Ts' = filter typ_has_tvars Ts
      in if null Ts' then NONE else SOME (n, Ts') end

    (* give each polymorphic term a disjoint range of type variable indices *)
    fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1)
    val rps = fst (fold_map incr ps 0)
      |> map (fn r => (r, map_filter with_tvar (consts_of [r])))

    (* One round: instantiate every polymorphic term with the new instances
       is; constants appearing for the first time in the created copies form
       the instances of the next round. *)
    fun mono count is ces cs ts =
      let
        val spec = specialize thy cs is
        val (ces', (ts', is')) = fold_map spec (rps ~~ ces) (ts, [])
        val cs' = join_consts is cs
      in
        if null is' then ts'
        else if count > monomorph_limit then
          (Output.warning "monomorphization limit reached"; ts')
        else mono (count + 1) is' ces' cs' ts'
      end
  in mono 0 (consts_of ms) (map (K []) rps) [] ms end

end