src/HOL/Matrix_LP/Compute_Oracle/linker.ML
changeset 47859 9f492f5b0cec
parent 47402 eff798e48efc
child 48326 26315a545e26
equal deleted inserted replaced
47858:15ce93dfe6da 47859:9f492f5b0cec
       
     1 (*  Title:      HOL/Matrix_LP/Compute_Oracle/linker.ML
       
     2     Author:     Steven Obua
       
     3 
       
     4 This module solves the problem that the computing oracle does not
       
     5 instantiate polymorphic rules. By going through the PCompute
       
     6 interface, all possible instantiations are resolved by compiling new
       
     7 programs, if necessary. The obvious disadvantage of this approach is
       
     8 that in the worst case for each new term to be rewritten, a new
       
     9 program may be compiled.
       
    10 *)
       
    11 
       
    12 (*
       
    13    Given constants/frees c_1::t_1, c_2::t_2, ...., c_n::t_n,
       
    14    and constants/frees d_1::s_1, d_2::s_2, ..., d_m::s_m
       
    15 
       
    16    Find all substitutions S such that
       
    17    a) the domain of S is tvars (t_1, ..., t_n)
       
    18    b) there are indices i_1, ..., i_k, and j_1, ..., j_k with
       
    19       1. S (c_i_1::t_i_1) = d_j_1::s_j_1, ..., S (c_i_k::t_i_k) = d_j_k::s_j_k
       
    20       2. tvars (t_i_1, ..., t_i_k) = tvars (t_1, ..., t_n)
       
    21 *)
       
    signature LINKER =
    sig
        (* Raised for terms that are neither constants nor frees, and on internal errors. *)
        exception Link of string

        (* Constant (is_free, name, type): is_free is true for a Free, false for a Const. *)
        datatype constant = Constant of bool * string * typ
        (* A type substitution, as produced by type matching. *)
        type subst = Type.tyenv
        (* Abstract state: tracked polymorphic constants, their known instances and
           the substitutions derived from them. *)
        type instances

        (* --- working with constants --- *)
        val constant_of : term -> constant
        val typ_of_constant : constant -> typ
        val is_polymorphic : constant -> bool
        val distinct_constants : constant list -> constant list
        val collect_consts : term list -> constant list

        (* --- instance tracking --- *)
        (* Start tracking the polymorphic constants of the given list. *)
        val empty : constant list -> instances
        (* Register new candidate instances; returns only the substitutions that are new. *)
        val add_instances : theory -> instances -> constant list -> subst list * instances
        (* All substitutions found so far. *)
        val substs_of : instances -> subst list
    end
       
    40 
       
    structure Linker : LINKER = struct

    exception Link of string;

    type subst = Type.tyenv

    (* Constant (is_free, name, type): the flag is true for a Free and false for a Const. *)
    datatype constant = Constant of bool * string * typ

    (* Convert a Const or Free term into a constant; any other term raises Link. *)
    fun constant_of (Const (name, ty)) = Constant (false, name, ty)
      | constant_of (Free (name, ty)) = Constant (true, name, ty)
      | constant_of _ = raise Link "constant_of"

    (* Order on booleans with false < true. *)
    fun bool_ord (x,y) = if x then (if y then EQUAL else GREATER) else (if y then LESS else EQUAL)

    (* Total order on constants: by free flag, then name, then type. *)
    fun constant_ord (Constant (x1,x2,x3), Constant (y1,y2,y3)) =
        (prod_ord (prod_ord bool_ord fast_string_ord) Term_Ord.typ_ord) (((x1,x2),x3), ((y1,y2),y3))

    (* Order on constants that ignores the type: by free flag, then name. *)
    fun constant_modty_ord (Constant (x1,x2,_), Constant (y1,y2,_)) =
        (prod_ord bool_ord fast_string_ord) ((x1,x2), (y1,y2))

    structure Consttab = Table(type key = constant val ord = constant_ord);
    structure ConsttabModTy = Table(type key = constant val ord = constant_modty_ord);

    fun typ_of_constant (Constant (_, _, ty)) = ty

    val empty_subst = (Vartab.empty : Type.tyenv)

    (* Merge two type substitutions.  Returns NONE if they bind some type variable
       to different values, otherwise SOME of their union. *)
    fun merge_subst (A:Type.tyenv) (B:Type.tyenv) =
        SOME (Vartab.fold (fn (v, t) =>
                           fn tab =>
                              (case Vartab.lookup tab v of
                                   NONE => Vartab.update (v, t) tab
                                 | SOME t' => if t = t' then tab else raise Type.TYPE_MATCH)) A B)
        handle Type.TYPE_MATCH => NONE

    (* Total order on substitutions, comparing their sorted binding lists. *)
    fun subst_ord (A:Type.tyenv, B:Type.tyenv) =
        (list_ord (prod_ord Term_Ord.fast_indexname_ord (prod_ord Term_Ord.sort_ord Term_Ord.typ_ord))) (Vartab.dest A, Vartab.dest B)

    structure Substtab = Table(type key = Type.tyenv val ord = subst_ord);

    (* Union of sets of substitutions. *)
    fun substtab_union c = Substtab.fold Substtab.update c
    fun substtab_unions [] = Substtab.empty
      | substtab_unions [c] = c
      | substtab_unions (c::cs) = substtab_union c (substtab_unions cs)

    (* Instances (cfilter, ctab, minsets, substs):
         cfilter - the tracked polymorphic constants, compared modulo type; lets
                   add_instances skip candidates that cannot match any of them
         ctab    - each tracked polymorphic constant, mapped to a table from the
                   instances recorded so far to the matching type substitution
         minsets - groups of constants that a substitution must instantiate together
         substs  - every substitution produced so far *)
    datatype instances = Instances of unit ConsttabModTy.table * Type.tyenv Consttab.table Consttab.table * constant list list * unit Substtab.table

    (* A constant is polymorphic if its type contains schematic type variables. *)
    fun is_polymorphic (Constant (_, _, ty)) = not (null (Term.add_tvarsT ty []))

    (* Remove duplicates; the result is sorted by constant_ord. *)
    fun distinct_constants cs =
        Consttab.keys (fold (fn c => Consttab.update (c, ())) cs Consttab.empty)

    (* Fresh instance state for the polymorphic constants among cs (others are
       dropped).  All tracked constants form one single group, so every substitution
       must instantiate all of them at once.  An earlier, disabled design instead
       computed minimal subsets spanning all type variables. *)
    fun empty cs =
        let
            val constants = distinct_constants (filter is_polymorphic cs)
        in
            Instances (
            fold (fn c => ConsttabModTy.update (c, ())) constants ConsttabModTy.empty,
            Consttab.make (map (fn c => (c, Consttab.empty : Type.tyenv Consttab.table)) constants),
            [constants], Substtab.empty)
        end

    local
    (* Combine substtab with one recorded instance substitution of each constant
       in turn, keeping only the consistent merges. *)
    fun calc ctab substtab [] = substtab
      | calc ctab substtab (c::cs) =
        let
            val csubsts = map snd (Consttab.dest (the (Consttab.lookup ctab c)))
            fun merge_substs substtab subst =
                Substtab.fold (fn (s,_) =>
                               fn tab =>
                                  (case merge_subst subst s of NONE => tab | SOME s => Substtab.update (s, ()) tab))
                              substtab Substtab.empty
            val substtab = substtab_unions (map (merge_substs substtab) csubsts)
        in
            calc ctab substtab cs
        end
    in
    (* All substitutions that consistently combine one instance of every constant in cs. *)
    fun calc_substs ctab (cs:constant list) = calc ctab (Substtab.update (empty_subst, ()) Substtab.empty) cs
    end

    (* Record the constants cs as instances of the tracked polymorphic constants (by
       type matching) and recompute the substitutions.  Returns the substitutions
       that were not known before, together with the updated state. *)
    fun add_instances thy (Instances (cfilter, ctab, minsets, substs)) cs =
        let
            (* Prepend (constant', (constant, subst)) if constant is a not yet recorded
               instance of the tracked constant constant'. *)
            fun instantiate_with (constant as Constant (free, name, ty))
                                 (constant' as Constant (free', name', ty'), insttab) instantiations =
                if free <> free' orelse name <> name' then
                    instantiations
                else if is_some (Consttab.lookup insttab constant) then
                    instantiations
                else
                    ((constant', (constant, Sign.typ_match thy (ty', ty) empty_subst)) :: instantiations
                     handle Type.TYPE_MATCH => instantiations)
            (* cfilter has the same keys as ctab modulo type, so a constant without an
               entry there cannot match anything in ctab: skip the full scan. *)
            fun calc_instantiations constant instantiations =
                if is_none (ConsttabModTy.lookup cfilter constant) then
                    instantiations
                else
                    Consttab.fold (instantiate_with constant) ctab instantiations
            val instantiations = fold calc_instantiations cs []
            fun update_ctab (constant', entry) ctab =
                (case Consttab.lookup ctab constant' of
                     NONE => raise Link "internal error: update_ctab"
                   | SOME tab => Consttab.update (constant', Consttab.update entry tab) ctab)
            val ctab = fold update_ctab instantiations ctab
            val new_substs = fold (fn minset => fn substs => substtab_union (calc_substs ctab minset) substs)
                                  minsets Substtab.empty
            (* split new_substs into those already known and those added now *)
            val (added_substs, substs) =
                Substtab.fold (fn (ns, _) =>
                               fn (added, substtab) =>
                                  (case Substtab.lookup substs ns of
                                       NONE => (ns::added, Substtab.update (ns, ()) substtab)
                                     | SOME () => (added, substtab)))
                              new_substs ([], substs)
        in
            (added_substs, Instances (cfilter, ctab, minsets, substs))
        end

    fun substs_of (Instances (_,_,_,substs)) = Substtab.keys substs


    local

    (* Add every Const and Free occurring in a term; Vars and Bounds are ignored. *)
    fun collect (Var _) tab = tab
      | collect (Bound _) tab = tab
      | collect (a $ b) tab = collect b (collect a tab)
      | collect (Abs (_, _, body)) tab = collect body tab
      | collect t tab = Consttab.update (constant_of t, ()) tab

    in
      (* Distinct constants and frees of the given terms, sorted by constant_ord. *)
      fun collect_consts tms = Consttab.keys (fold collect tms Consttab.empty)
    end

    end
       
   207 
       
   signature PCOMPUTE =
   sig
       (* A Compute computer that instantiates polymorphic rules on demand. *)
       type pcomputer

       (* --- construction --- *)
       val make : Compute.machine -> theory -> thm list -> Linker.constant list -> pcomputer
       val make_with_cache : Compute.machine -> theory -> term list -> thm list -> Linker.constant list -> pcomputer

       (* --- instantiation; the result is true if new instances were added --- *)
       val add_instances : pcomputer -> Linker.constant list -> bool
       val add_instances' : pcomputer -> term list -> bool

       (* --- rewriting --- *)
       val rewrite : pcomputer -> cterm list -> thm list
       val simplify : pcomputer -> Compute.theorem -> thm

       (* --- Compute.theorem manipulation --- *)
       val make_theorem : pcomputer -> thm -> string list -> Compute.theorem
       val instantiate : pcomputer -> (string * cterm) list -> Compute.theorem -> Compute.theorem
       val evaluate_prem : pcomputer -> int -> Compute.theorem -> Compute.theorem
       val modus_ponens : pcomputer -> int -> thm -> Compute.theorem -> Compute.theorem
   end
       
   227 
       
   structure PCompute : PCOMPUTE = struct

   exception PCompute of string

   (* A rewrite rule: monomorphic, or polymorphic together with its linker state and
      the monomorphic instances generated from it so far. *)
   datatype theorem = MonoThm of thm | PolyThm of thm * Linker.instances * thm list
   (* A pattern, split into mono/polymorphic in the same way. *)
   datatype pattern = MonoPattern of term | PolyPattern of term * Linker.instances * term list

   (* Theory, underlying Compute computer, and the current rules and patterns, which
      add_instances replaces when new instances are generated. *)
   datatype pcomputer =
     PComputer of theory_ref * Compute.computer * theorem list Unsynchronized.ref *
       pattern list Unsynchronized.ref

   fun computer_of (PComputer (_,computer,_,_)) = computer

   (* For an equation th (possibly under premises): the constants of its left-hand
      side, and the constants of its right-hand side and premises. *)
   fun collect_consts_of_thm th =
       let
           val th = prop_of th
           val (prems, th) = (Logic.strip_imp_prems th, Logic.strip_imp_concl th)
           val (left, right) = Logic.dest_equals th
       in
           (Linker.collect_consts [left], Linker.collect_consts (right::prems))
       end

   (* Classify th by the polymorphic constants of its left-hand side.  Raises
      PCompute if the right-hand side or premises mention type variables that do not
      occur in those constants.  Also returns the monomorphic constants of the
      right-hand side and premises. *)
   fun create_theorem th =
   let
       val (left, right) = collect_consts_of_thm th
       val polycs = filter Linker.is_polymorphic left
       val tytab = fold (fn p => fn tab => fold (fn n => fn tab => Typtab.update (TVar n, ()) tab) (Misc_Legacy.typ_tvars (Linker.typ_of_constant p)) tab) polycs Typtab.empty
       fun check_const (c::cs) cs' =
           let
               val tvars = Misc_Legacy.typ_tvars (Linker.typ_of_constant c)
               val wrong = exists (fn n => is_none (Typtab.lookup tytab (TVar n))) tvars
           in
               if wrong then raise PCompute "right hand side of theorem contains type variables which do not occur on the left hand side"
               else if null tvars then check_const cs (c::cs')
               else check_const cs cs'
           end
         | check_const [] cs' = cs'
       val monocs = check_const right []
   in
       if null polycs then
           (monocs, MonoThm th)
       else
           (monocs, PolyThm (th, Linker.empty polycs, []))
   end

   (* Classify a pattern by whether it contains polymorphic constants. *)
   fun create_pattern pat =
   let
       val cs = Linker.collect_consts [pat]
       val polycs = filter Linker.is_polymorphic cs
   in
       if null polycs then
           MonoPattern pat
       else
           PolyPattern (pat, Linker.empty polycs, [])
   end

   (* The monomorphic theorems handed to Compute: each MonoThm itself, and the
      generated instances of each PolyThm, in list order. *)
   fun flatten_theorems ths =
       let
           fun add (MonoThm th) ths = th::ths
             | add (PolyThm (_, _, ths')) ths = ths'@ths
       in
           fold_rev add ths []
       end

   (* The monomorphic patterns handed to Compute, analogous to flatten_theorems. *)
   fun flatten_patterns pats =
       let
           fun add (MonoPattern p) pats = p::pats
             | add (PolyPattern (_, _, ps)) pats = ps@pats
       in
           fold_rev add pats []
       end

   fun create_computer machine thy pats ths =
       Compute.make_with_cache machine thy (flatten_patterns pats) (flatten_theorems ths)

   fun update_computer computer pats ths =
       Compute.update_with_cache computer (flatten_patterns pats) (flatten_theorems ths)

   (* Turn a type substitution into the ctyp pairs expected by Thm.instantiate. *)
   fun conv_subst thy (subst : Type.tyenv) =
       map (fn (iname, (sort, ty)) => (ctyp_of thy (TVar (iname, sort)), ctyp_of thy ty)) (Vartab.dest subst)

   (* Instantiate the polymorphic theorems and patterns using the constants monocs.
      Each round feeds the constants found in the right-hand sides and premises of
      the newly created theorem instances into the next round, until a round yields
      none.  Returns whether any instance was added, plus the new patterns and
      theorems. *)
   fun add_monos thy monocs pats ths =
       let
           val changed = Unsynchronized.ref false
           fun add monocs (th as (MonoThm _)) = ([], th)
             | add monocs (PolyThm (th, instances, instanceths)) =
               let
                   val (newsubsts, instances) = Linker.add_instances thy instances monocs
                   val _ = if not (null newsubsts) then changed := true else ()
                   val newths = map (fn subst => Thm.instantiate (conv_subst thy subst, []) th) newsubsts
                   val newmonos = fold (fn th => fn monos => (snd (collect_consts_of_thm th))@monos) newths []
               in
                   (newmonos, PolyThm (th, instances, instanceths@newths))
               end
           fun addpats monocs (pat as (MonoPattern _)) = pat
             | addpats monocs (PolyPattern (p, instances, instancepats)) =
               let
                   val (newsubsts, instances) = Linker.add_instances thy instances monocs
                   val _ = if not (null newsubsts) then changed := true else ()
                   val newpats = map (fn subst => Envir.subst_term_types subst p) newsubsts
               in
                   PolyPattern (p, instances, instancepats@newpats)
               end
           fun step monocs ths =
               fold_rev (fn th =>
                         fn (newmonos, ths) =>
                            let
                                val (newmonos', th') = add monocs th
                            in
                                (newmonos'@newmonos, th'::ths)
                            end)
                        ths ([], [])
           fun loop monocs pats ths =
               let
                   val (monocs', ths') = step monocs ths
                   val pats' = map (addpats monocs) pats
               in
                   if null monocs' then
                       (pats', ths')
                   else
                       loop monocs' pats' ths'
               end
           (* loop must run before changed is read *)
           val result = loop monocs pats ths
       in
           (!changed, result)
       end

   (* Identity of a theorem for duplicate detection: hypotheses, sort hypotheses and
      proposition. *)
   datatype cthm = ComputeThm of term list * sort list * term

   fun thm2cthm th =
       let
           val {hyps, prop, shyps, ...} = Thm.rep_thm th
       in
           ComputeThm (hyps, shyps, prop)
       end

   val cthm_ord' = prod_ord (prod_ord (list_ord Term_Ord.term_ord) (list_ord Term_Ord.sort_ord)) Term_Ord.term_ord

   fun cthm_ord (ComputeThm (h1, sh1, p1), ComputeThm (h2, sh2, p2)) = cthm_ord' (((h1,sh1), p1), ((h2, sh2), p2))

   structure CThmtab = Table(type key = cthm val ord = cthm_ord)

   (* Drop theorems with the same hypotheses, sort hypotheses and proposition as an
      earlier one, keeping the first occurrence of each and the original order. *)
   fun remove_duplicates ths =
       let
           fun add th (seen, kept) =
               let
                   val key = thm2cthm th
               in
                   if is_some (CThmtab.lookup seen key) then (seen, kept)
                   else (CThmtab.update (key, ()) seen, th :: kept)
               end
       in
           rev (snd (fold add ths (CThmtab.empty, [])))
       end

   (* Build a pcomputer from the rules ths and the patterns pats.  Polymorphic rules
      are instantiated with the constants cs and with the monomorphic constants of
      the rules' right-hand sides and premises. *)
   fun make_with_cache machine thy pats ths cs =
       let
           val ths = remove_duplicates ths
           val (monocs, ths) = fold_rev (fn th =>
                                         fn (monocs, ths) =>
                                            let val (m, t) = create_theorem th in
                                                (m@monocs, t::ths)
                                            end)
                                        ths (cs, [])
           val pats = map create_pattern pats
           val (_, (pats, ths)) = add_monos thy monocs pats ths
           val computer = create_computer machine thy pats ths
       in
           PComputer (Theory.check_thy thy, computer, Unsynchronized.ref ths, Unsynchronized.ref pats)
       end

   fun make machine thy ths cs = make_with_cache machine thy [] ths cs

   (* Instantiate the polymorphic rules and patterns with the constants cs.  If new
      instances arise, the computer and the stored rules and patterns are updated. *)
   fun add_instances (PComputer (thyref, computer, rths, rpats)) cs =
       let
           val thy = Theory.deref thyref
           val (changed, (pats, ths)) = add_monos thy cs (!rpats) (!rths)
       in
           if changed then
               (update_computer computer pats ths;
                rths := ths;
                rpats := pats;
                true)
           else
               false
       end

   fun add_instances' pc ts = add_instances pc (Linker.collect_consts ts)

   (* The following operations first instantiate the rules for the constants of
      their input terms, then delegate to Compute. *)
   fun rewrite pc cts =
       let
           val _ = add_instances' pc (map term_of cts)
           val computer = computer_of pc
       in
           map (Compute.rewrite computer) cts
       end

   fun simplify pc th = Compute.simplify (computer_of pc) th

   fun make_theorem pc th vars =
       let
           val _ = add_instances' pc [prop_of th]
       in
           Compute.make_theorem (computer_of pc) th vars
       end

   fun instantiate pc insts th =
       let
           val _ = add_instances' pc (map (term_of o snd) insts)
       in
           Compute.instantiate (computer_of pc) insts th
       end

   fun evaluate_prem pc prem_no th = Compute.evaluate_prem (computer_of pc) prem_no th

   fun modus_ponens pc prem_no th' th =
       let
           val _ = add_instances' pc [prop_of th']
       in
           Compute.modus_ponens (computer_of pc) prem_no th' th
       end

   end