adding type inference for disambiguation annotations in code equation
authorbulwahn
Wed, 07 Sep 2011 13:51:34 +0200
changeset 45658c13fdf710a40
parent 45657 5a062c23c7db
child 45659 7ecb4124a3a3
adding type inference for disambiguation annotations in code equation
src/Tools/Code/code_thingol.ML
     1.1 --- a/src/Tools/Code/code_thingol.ML	Wed Sep 07 13:51:32 2011 +0200
     1.2 +++ b/src/Tools/Code/code_thingol.ML	Wed Sep 07 13:51:34 2011 +0200
     1.3 @@ -609,6 +609,43 @@
     1.4        (err_typ ^ "\n" ^ err_class)
     1.5    end;
     1.6  
     1.7 +(* inference of type annotations for disambiguation with type classes *)
     1.8 +
     1.9 +
    1.10 +fun annotate_term (Const (c', T'), Const (c, T)) tvar_names =
    1.11 +    let
    1.12 +      val tvar_names' = Term.add_tvar_namesT T' tvar_names
    1.13 +    in
    1.14 +      (Const (c, if eq_set (op =) (tvar_names, tvar_names') then T else Type("", [T])), tvar_names')
    1.15 +    end
    1.16 +  | annotate_term (t1 $ u1, t $ u) tvar_names =
    1.17 +    let
    1.18 +      val (u', tvar_names') = annotate_term (u1, u) tvar_names
    1.19 +      val (t', tvar_names'') = annotate_term (t1, t) tvar_names'    
    1.20 +    in
    1.21 +      (t' $ u', tvar_names'')
    1.22 +    end
    1.23 +  | annotate_term (Abs (_, _, t1) , Abs (x, T, t)) tvar_names =
    1.24 +    apfst (fn t => Abs (x, T, t)) (annotate_term (t1, t) tvar_names)
    1.25 +  | annotate_term (_, t) tvar_names = (t, tvar_names)
    1.26 +
    1.27 +fun annotate_eqns thy eqns = 
    1.28 +  let
    1.29 +    val ctxt = ProofContext.init_global thy
    1.30 +    val erase = map_types (fn _ => Type_Infer.anyT [])
    1.31 +    val reinfer = singleton (Type_Infer_Context.infer_types ctxt)
    1.32 +    fun add_annotations ((args, (rhs, some_abs)), (SOME th, proper)) =
    1.33 +      let
    1.34 +        val (lhs, drhs) = Logic.dest_equals (prop_of (Thm.unvarify_global th))
    1.35 +        val drhs' = snd (Logic.dest_equals (reinfer (Logic.mk_equals (lhs, erase drhs))))
    1.36 +        val (rhs', _) = annotate_term (drhs', rhs) []
    1.37 +     in
    1.38 +        ((args, (rhs', some_abs)), (SOME th, proper))
    1.39 +     end
    1.40 +     | add_annotations eqn = eqn
    1.41 +  in
    1.42 +    map add_annotations eqns
    1.43 +  end;
    1.44  
    1.45  (* translation *)
    1.46