different handling of eq class for nbe
authorhaftmann
Tue, 22 Apr 2008 22:00:25 +0200
changeset 26739947b6013e863
parent 26738 615e1a86787b
child 26740 6c8cd101f875
different handling of eq class for nbe
src/HOL/HOL.thy
src/HOL/ex/NormalForm.thy
src/Tools/nbe.ML
     1.1 --- a/src/HOL/HOL.thy	Tue Apr 22 13:35:26 2008 +0200
     1.2 +++ b/src/HOL/HOL.thy	Tue Apr 22 22:00:25 2008 +0200
     1.3 @@ -1659,8 +1659,6 @@
     1.4  
     1.5  subsection {* Code generator basic setup -- see further @{text Code_Setup.thy} *}
     1.6  
     1.7 -setup "CodeName.setup #> CodeTarget.setup #> Nbe.setup"
     1.8 -
     1.9  code_datatype Trueprop "prop"
    1.10  
    1.11  code_datatype "TYPE('a\<Colon>{})"
    1.12 @@ -1699,6 +1697,9 @@
    1.13  
    1.14  setup {*
    1.15    CodeUnit.add_const_alias @{thm equals_eq}
    1.16 +  #> CodeName.setup
    1.17 +  #> CodeTarget.setup
    1.18 +  #> Nbe.setup @{sort eq} [(@{const_name eq_class.eq}, @{const_name "op ="})]
    1.19  *}
    1.20  
    1.21  lemma [code func]:
     2.1 --- a/src/HOL/ex/NormalForm.thy	Tue Apr 22 13:35:26 2008 +0200
     2.2 +++ b/src/HOL/ex/NormalForm.thy	Tue Apr 22 22:00:25 2008 +0200
     2.3 @@ -58,19 +58,19 @@
     2.4  lemma "exp (S(S Z)) (S(S(S(S Z)))) = exp (S(S(S(S Z)))) (S(S Z))" by normalization
     2.5  
     2.6  lemma "(let ((x,y),(u,v)) = ((Z,Z),(Z,Z)) in add (add x y) (add u v)) = Z" by normalization
     2.7 -lemma "split (%(x\<Colon>'a\<Colon>eq) y. x) (a, b) = a" by normalization rule
     2.8 +lemma "split (%x y. x) (a, b) = a" by normalization rule
     2.9  lemma "(%((x,y),(u,v)). add (add x y) (add u v)) ((Z,Z),(Z,Z)) = Z" by normalization
    2.10  
    2.11  lemma "case Z of Z \<Rightarrow> True | S x \<Rightarrow> False" by normalization
    2.12  
    2.13  lemma "[] @ [] = []" by normalization
    2.14 -lemma "map f [x,y,z::'x] = [f x \<Colon> 'a\<Colon>eq, f y, f z]" by normalization rule+
    2.15 -lemma "[a \<Colon> 'a\<Colon>eq, b, c] @ xs = a # b # c # xs" by normalization rule+
    2.16 -lemma "[] @ xs = (xs \<Colon> 'a\<Colon>eq list)" by normalization rule
    2.17 +lemma "map f [x,y,z::'x] = [f x, f y, f z]" by normalization rule+
    2.18 +lemma "[a, b, c] @ xs = a # b # c # xs" by normalization rule+
    2.19 +lemma "[] @ xs = xs" by normalization rule
    2.20  lemma "map (%f. f True) [id, g, Not] = [True, g True, False]" by normalization rule+
    2.21  lemma "map (%f. f True) ([id, g, Not] @ fs) = [True, g True, False] @ map (%f. f True) fs" by normalization rule+
    2.22 -lemma "rev [a, b, c] = [c \<Colon> 'a\<Colon>eq, b, a]" by normalization rule+
    2.23 -normal_form "rev (a#b#cs) = rev cs @ [b, a \<Colon> 'a\<Colon>eq]"
    2.24 +lemma "rev [a, b, c] = [c, b, a]" by normalization rule+
    2.25 +normal_form "rev (a#b#cs) = rev cs @ [b, a]"
    2.26  normal_form "map (%F. F [a,b,c::'x]) (map map [f,g,h])"
    2.27  normal_form "map (%F. F ([a,b,c] @ ds)) (map map ([f,g,h]@fs))"
    2.28  normal_form "map (%F. F [Z,S Z,S(S Z)]) (map map [S,add (S Z),mul (S(S Z)),id])"
    2.29 @@ -78,19 +78,19 @@
    2.30    by normalization
    2.31  normal_form "case xs of [] \<Rightarrow> True | x#xs \<Rightarrow> False"
    2.32  normal_form "map (%x. case x of None \<Rightarrow> False | Some y \<Rightarrow> True) xs = P"
    2.33 -lemma "let x = y in [x, x] = [y \<Colon> 'a\<Colon>eq, y]" by normalization rule+
    2.34 -lemma "Let y (%x. [x,x]) = [y \<Colon> 'a\<Colon>eq, y]" by normalization rule+
    2.35 +lemma "let x = y in [x, x] = [y, y]" by normalization rule+
    2.36 +lemma "Let y (%x. [x,x]) = [y, y]" by normalization rule+
    2.37  normal_form "case n of Z \<Rightarrow> True | S x \<Rightarrow> False"
    2.38  lemma "(%(x,y). add x y) (S z,S z) = S (add z (S z))" by normalization rule+
    2.39  normal_form "filter (%x. x) ([True,False,x]@xs)"
    2.40  normal_form "filter Not ([True,False,x]@xs)"
    2.41  
    2.42 -lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b ,c \<Colon> 'a\<Colon>eq]" by normalization rule+
    2.43 -lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f \<Colon> 'a\<Colon>eq]" by normalization rule+
    2.44 +lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b, c]" by normalization rule+
    2.45 +lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f]" by normalization rule+
    2.46  lemma "map (%x. case x of None \<Rightarrow> False | Some y \<Rightarrow> True) [None, Some ()] = [False, True]" by normalization
    2.47  
    2.48 -lemma "last [a, b, c \<Colon> 'a\<Colon>eq] = c" by normalization rule
    2.49 -lemma "last ([a, b, c \<Colon> 'a\<Colon>eq] @ xs) = (if null xs then c else last xs)"
    2.50 +lemma "last [a, b, c] = c" by normalization rule
    2.51 +lemma "last ([a, b, c] @ xs) = (if null xs then c else last xs)"
    2.52    by normalization rule
    2.53  
    2.54  lemma "(2::int) + 3 - 1 + (- k) * 2 = 4 + - k * 2" by normalization rule
    2.55 @@ -111,10 +111,10 @@
    2.56  lemma "(42::rat) / 1704 = 1 / 284 + 3 / 142" by normalization
    2.57  normal_form "Suc 0 \<in> set ms"
    2.58  
    2.59 -lemma "f = (f \<Colon> 'a\<Colon>eq)" by normalization rule+
    2.60 -lemma "f x = (f x \<Colon> 'a\<Colon>eq)" by normalization rule+
    2.61 -lemma "(f o g) x = (f (g x) \<Colon> 'a\<Colon>eq)" by normalization rule+
    2.62 -lemma "(f o id) x = (f x \<Colon> 'a\<Colon>eq)" by normalization rule+
    2.63 +lemma "f = f" by normalization rule+
    2.64 +lemma "f x = f x" by normalization rule+
    2.65 +lemma "(f o g) x = f (g x)" by normalization rule+
    2.66 +lemma "(f o id) x = f x" by normalization rule+
    2.67  normal_form "(\<lambda>x. x)"
    2.68  
    2.69  (* Church numerals: *)
     3.1 --- a/src/Tools/nbe.ML	Tue Apr 22 13:35:26 2008 +0200
     3.2 +++ b/src/Tools/nbe.ML	Tue Apr 22 22:00:25 2008 +0200
     3.3 @@ -23,7 +23,7 @@
     3.4    val univs_ref: (unit -> Univ list -> Univ list) option ref
     3.5    val trace: bool ref
     3.6  
     3.7 -  val setup: theory -> theory
     3.8 +  val setup: class list -> (string * string) list -> theory -> theory
     3.9  end;
    3.10  
    3.11  structure Nbe: NBE =
    3.12 @@ -327,7 +327,7 @@
    3.13              val ts' = take_until is_dict ts;
    3.14              val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx;
    3.15              val T = Code.default_typ thy c;
    3.16 -            val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
    3.17 +            val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, [])) T;
    3.18              val typidx' = typidx + maxidx_of_typ T' + 1;
    3.19            in of_apps bounds (Term.Const (c, T'), ts') typidx' end
    3.20        | of_univ bounds (Free (name, ts)) typidx =
    3.21 @@ -373,20 +373,51 @@
    3.22      |> term_of_univ thy idx_tab
    3.23    end;
    3.24  
    3.25 +(* trivial type classes *)
    3.26 +
    3.27 +structure Nbe_Triv_Classes = TheoryDataFun
    3.28 +(
    3.29 +  type T = class list * (string * string) list;
    3.30 +  val empty = ([], []);
    3.31 +  val copy = I;
    3.32 +  val extend = I;
    3.33 +  fun merge _ ((classes1, consts1), (classes2, consts2)) =
    3.34 +    (Library.merge (op =) (classes1, classes2), Library.merge (op =) (consts1, consts2));
    3.35 +)
    3.36 +
    3.37 +fun add_triv_classes thy =
    3.38 +  let
    3.39 +    val (trivs, _) = Nbe_Triv_Classes.get thy;
    3.40 +    val inters = curry (Sorts.inter_sort (Sign.classes_of thy)) trivs;
    3.41 +    fun map_sorts f = (map_types o map_atyps)
    3.42 +      (fn TVar (v, sort) => TVar (v, f sort)
    3.43 +        | TFree (v, sort) => TFree (v, f sort));
    3.44 +  in map_sorts inters end;
    3.45 +
    3.46 +fun subst_triv_consts thy =
    3.47 +  let
    3.48 +    fun subst_const f = map_aterms (fn t as Term.Const (c, ty) => (case f c
    3.49 +         of SOME c' => Term.Const (c', ty)
    3.50 +          | NONE => t)
    3.51 +      | t => t);
    3.52 +    val (_, consts) = Nbe_Triv_Classes.get thy;
    3.53 +    val subst_inst = perhaps (Option.map fst o AxClass.inst_of_param thy);
    3.54 +  in map_aterms (subst_const (AList.lookup (op =) consts o subst_inst)) end;
    3.55 +
    3.56  (* evaluation with type reconstruction *)
    3.57  
    3.58 -fun eval thy code t vs_ty_t deps =
    3.59 +fun eval thy t code vs_ty_t deps =
    3.60    let
    3.61      val ty = type_of t;
    3.62 -    fun subst_Frees [] = I
    3.63 -      | subst_Frees inst =
    3.64 -          Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
    3.65 -                            | t => t);
    3.66 -    val anno_vars =
    3.67 -      subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
    3.68 -      #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
    3.69 -    fun constrain t =
    3.70 -      singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t);
    3.71 +    val type_free = AList.lookup (op =)
    3.72 +      (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []));
    3.73 +    val type_frees = Term.map_aterms
    3.74 +      (fn (t as Term.Free (s, _)) => the_default t (type_free s) | t => t);
    3.75 +    fun type_infer t = [(t, ty)]
    3.76 +      |> TypeInfer.infer_types (Sign.pp thy) (Sign.tsig_of thy) I
    3.77 +           (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE)
    3.78 +           Name.context 0 NONE
    3.79 +      |> fst |> the_single |> fst;
    3.80      fun check_tvars t = if null (Term.term_tvars t) then t else
    3.81        error ("Illegal schematic type variables in normalized term: "
    3.82          ^ setmp show_types true (Sign.string_of_term thy) t);
    3.83 @@ -394,40 +425,39 @@
    3.84    in
    3.85      compile_eval thy code vs_ty_t deps
    3.86      |> tracing (fn t => "Normalized:\n" ^ string_of_term t)
    3.87 -    |> anno_vars
    3.88 +    |> subst_triv_consts thy
    3.89 +    |> type_frees
    3.90      |> tracing (fn t => "Vars typed:\n" ^ string_of_term t)
    3.91 -    |> constrain
    3.92 +    |> type_infer
    3.93      |> tracing (fn t => "Types inferred:\n" ^ string_of_term t)
    3.94 +    |> check_tvars
    3.95      |> tracing (fn t => "---\n")
    3.96 -    |> check_tvars
    3.97    end;
    3.98  
    3.99  (* evaluation oracle *)
   3.100  
   3.101 -exception Norm of CodeThingol.code * term
   3.102 +exception Norm of term * CodeThingol.code
   3.103    * (CodeThingol.typscheme * CodeThingol.iterm) * string list;
   3.104  
   3.105 -fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) =
   3.106 -  Logic.mk_equals (t, eval thy code t vs_ty_t deps);
   3.107 +fun norm_oracle (thy, Norm (t, code, vs_ty_t, deps)) =
   3.108 +  Logic.mk_equals (t, eval thy t code vs_ty_t deps);
   3.109  
   3.110 -fun norm_invoke thy code t vs_ty_t deps =
   3.111 -  Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps));
   3.112 +fun norm_invoke thy t code vs_ty_t deps =
   3.113 +  Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (t, code, vs_ty_t, deps));
   3.114    (*FIXME get rid of hardwired theory name*)
   3.115  
   3.116  fun norm_conv ct =
   3.117    let
   3.118      val thy = Thm.theory_of_cterm ct;
   3.119 -    fun conv code vs_ty_t deps ct =
   3.120 -      let
   3.121 -        val t = Thm.term_of ct;
   3.122 -      in norm_invoke thy code t vs_ty_t deps end;
   3.123 -  in CodePackage.evaluate_conv thy conv ct end;
   3.124 +    fun evaluator' t code vs_ty_t deps = norm_invoke thy t code vs_ty_t deps;
   3.125 +    fun evaluator t = (add_triv_classes thy t, evaluator' t);
   3.126 +  in CodePackage.evaluate_conv thy evaluator ct end;
   3.127  
   3.128 -fun norm_term thy =
   3.129 +fun norm_term thy t =
   3.130    let
   3.131 -    fun invoke code vs_ty_t deps t =
   3.132 -      eval thy code t vs_ty_t deps;
   3.133 -  in CodePackage.evaluate_term thy invoke #> Code.postprocess_term thy end;
   3.134 +    fun evaluator' t code vs_ty_t deps = eval thy t code vs_ty_t deps;
   3.135 +    fun evaluator t = (add_triv_classes thy t, evaluator' t);
   3.136 +  in (Code.postprocess_term thy o CodePackage.evaluate_term thy evaluator) t end;
   3.137  
   3.138  (* evaluation command *)
   3.139  
   3.140 @@ -448,7 +478,9 @@
   3.141    let val ctxt = Toplevel.context_of state
   3.142    in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;
   3.143  
   3.144 -val setup = Theory.add_oracle ("norm", norm_oracle)
   3.145 +fun setup nbe_classes nbe_consts =
   3.146 +  Theory.add_oracle ("norm", norm_oracle)
   3.147 +  #> Nbe_Triv_Classes.map (K (nbe_classes, nbe_consts));
   3.148  
   3.149  local structure P = OuterParse and K = OuterKeyword in
   3.150