handling partiality in the case where the equality optimisation is applied
authorbulwahn
Sat, 21 Jul 2012 10:53:26 +0200
changeset 4942943875bab3a4c
parent 49428 3e730188f328
child 49430 b42067a3188f
handling partiality in the case where the equality optimisation is applied
src/HOL/Tools/Quickcheck/exhaustive_generators.ML
     1.1 --- a/src/HOL/Tools/Quickcheck/exhaustive_generators.ML	Fri Jul 20 23:38:15 2012 +0200
     1.2 +++ b/src/HOL/Tools/Quickcheck/exhaustive_generators.ML	Sat Jul 21 10:53:26 2012 +0200
     1.3 @@ -299,6 +299,16 @@
     1.4      Const (@{const_name Let}, T1 --> (T1 --> T2) --> T2) $ t $ lambda x (e genuine)
     1.5    end
     1.6  
     1.7 +fun mk_safe_let_expr genuine_only none safe (x, t, e) genuine =
     1.8 +  let
     1.9 +    val (T1, T2) = (fastype_of x, fastype_of (e genuine))
    1.10 +    val if_t = Const (@{const_name "If"}, @{typ bool} --> T2 --> T2 --> T2)
    1.11 +  in
    1.12 +    Const (@{const_name "Quickcheck.catch_match"}, T2 --> T2 --> T2) $ 
    1.13 +      (Const (@{const_name Let}, T1 --> (T1 --> T2) --> T2) $ t $ lambda x (e genuine)) $
    1.14 +      (if_t $ genuine_only $ none $ safe false)
    1.15 +  end
    1.16 +
    1.17  fun mk_test_term lookup mk_closure mk_if mk_let none_t return ctxt =
    1.18    let
    1.19      val cnstrs = flat (maps
    1.20 @@ -311,6 +321,7 @@
    1.21      fun mk_naive_test_term t =
    1.22        fold_rev mk_closure (map lookup (Term.add_free_names t []))
    1.23          (mk_if (t, none_t, return) true)
    1.24 +    fun mk_test (vars, check) = fold_rev mk_closure (map lookup vars) check
    1.25      fun mk_smart_test_term' concl bound_vars assms genuine =
    1.26        let
    1.27          fun vars_of t = subtract (op =) bound_vars (Term.add_free_names t [])
    1.28 @@ -318,9 +329,16 @@
    1.29            if member (op =) (Term.add_free_names lhs bound_vars) x then
    1.30              c (assm, assms)
    1.31            else
    1.32 -            (remove (op =) x (vars_of assm),
    1.33 -              mk_let f (try lookup x) lhs 
    1.34 -                (mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms) genuine)
    1.35 +            (let
    1.36 +               val rec_call = mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms
    1.37 +               fun safe genuine =
    1.38 +                 the_default I (Option.map mk_closure (try lookup x)) (rec_call genuine)
    1.39 +            in
    1.40 +              mk_test (remove (op =) x (vars_of assm),
    1.41 +                mk_let safe f (try lookup x) lhs 
    1.42 +                  (mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms) genuine)
    1.43 +            
    1.44 +            end)
    1.45            | mk_equality_term (lhs, t) c (assm, assms) =
    1.46              if is_constrt (strip_comb t) then
    1.47                let
    1.48 @@ -335,24 +353,23 @@
    1.49                  val bound_vars' = union (op =) (vars_of lhs) (union (op =) varnames bound_vars)
    1.50                  val cont_t = mk_smart_test_term' concl bound_vars' (new_assms @ assms) genuine
    1.51                in
    1.52 -                (vars_of lhs, Datatype_Case.make_case ctxt Datatype_Case.Quiet [] lhs
    1.53 +                mk_test (vars_of lhs, Datatype_Case.make_case ctxt Datatype_Case.Quiet [] lhs
    1.54                    [(list_comb (constr, vars), cont_t), (dummy_var, none_t)])
    1.55                end
    1.56              else c (assm, assms)
    1.57 -        fun default (assm, assms) = (vars_of assm,
    1.58 -          mk_if (HOLogic.mk_not assm, none_t, 
    1.59 -          mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms) genuine)
    1.60 -        val (vars, check) =
    1.61 -          case assms of [] => (vars_of concl, mk_if (concl, none_t, return) genuine)
    1.62 -            | assm :: assms =>
    1.63 -              if Config.get ctxt optimise_equality then
    1.64 -                (case try HOLogic.dest_eq assm of
    1.65 -                  SOME (lhs, rhs) =>
    1.66 -                    mk_equality_term (lhs, rhs) (mk_equality_term (rhs, lhs) default) (assm, assms)
    1.67 -                | NONE => default (assm, assms))
    1.68 -              else default (assm, assms)
    1.69 +        fun default (assm, assms) =
    1.70 +          mk_test (vars_of assm,
    1.71 +            mk_if (HOLogic.mk_not assm, none_t, 
    1.72 +            mk_smart_test_term' concl (union (op =) (vars_of assm) bound_vars) assms) genuine)
    1.73        in
    1.74 -        fold_rev mk_closure (map lookup vars) check
    1.75 +        case assms of [] => mk_test (vars_of concl, mk_if (concl, none_t, return) genuine)
    1.76 +          | assm :: assms =>
    1.77 +            if Config.get ctxt optimise_equality then
    1.78 +              (case try HOLogic.dest_eq assm of
    1.79 +                SOME (lhs, rhs) =>
    1.80 +                  mk_equality_term (lhs, rhs) (mk_equality_term (rhs, lhs) default) (assm, assms)
    1.81 +              | NONE => default (assm, assms))
    1.82 +            else default (assm, assms)
    1.83        end
    1.84      val mk_smart_test_term =
    1.85        Quickcheck_Common.strip_imp #> (fn (assms, concl) => mk_smart_test_term' concl [] assms true)
    1.86 @@ -377,7 +394,7 @@
    1.87          $ lambda free t $ depth
    1.88      val none_t = @{term "()"}
    1.89      fun mk_safe_if (cond, then_t, else_t) genuine = mk_if (cond, then_t, else_t genuine)
    1.90 -    fun mk_let def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
    1.91 +    fun mk_let _ def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
    1.92      val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_safe_if mk_let none_t return ctxt 
    1.93    in lambda depth (@{term "catch_Counterexample :: unit => term list option"} $ mk_test_term t) end
    1.94  
    1.95 @@ -406,7 +423,7 @@
    1.96          $ lambda free t $ depth
    1.97      val none_t = Const (@{const_name "None"}, resultT)
    1.98      val mk_if = Quickcheck_Common.mk_safe_if genuine_only none_t
    1.99 -    fun mk_let def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
   1.100 +    fun mk_let safe def v_opt t e = mk_safe_let_expr genuine_only none_t safe (the_default def v_opt, t, e)
   1.101      val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_if mk_let none_t return ctxt
   1.102    in lambda genuine_only (lambda depth (mk_test_term t)) end
   1.103  
   1.104 @@ -436,10 +453,10 @@
   1.105              $ lambda free (lambda term_var t)) $ depth
   1.106      val none_t = Const (@{const_name "None"}, resultT)
   1.107      val mk_if = Quickcheck_Common.mk_safe_if genuine_only none_t
   1.108 -    fun mk_let _ (SOME (v, term_var)) t e =
   1.109 -      mk_let_expr (v, t, 
   1.110 -        e #> subst_free [(term_var, absdummy @{typ unit} (HOLogic.mk_term_of (fastype_of t) t))])
   1.111 -      | mk_let v NONE t e = mk_let_expr (v, t, e)
   1.112 +    fun mk_let safe _ (SOME (v, term_var)) t e =
   1.113 +        mk_safe_let_expr genuine_only none_t safe (v, t, 
   1.114 +          e #> subst_free [(term_var, absdummy @{typ unit} (mk_safe_term t))])
   1.115 +      | mk_let safe v NONE t e = mk_safe_let_expr genuine_only none_t safe (v, t, e)
   1.116      val mk_test_term = mk_test_term lookup mk_exhaustive_closure mk_if mk_let none_t return ctxt
   1.117    in lambda genuine_only (lambda depth (mk_test_term t)) end
   1.118  
   1.119 @@ -462,7 +479,7 @@
   1.120        Const (@{const_name "Quickcheck_Exhaustive.bounded_forall_class.bounded_forall"}, bounded_forallT T)
   1.121          $ lambda (Free (s, T)) t $ depth
   1.122      fun mk_safe_if (cond, then_t, else_t) genuine = mk_if (cond, then_t, else_t genuine)
   1.123 -    fun mk_let def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
   1.124 +    fun mk_let safe def v_opt t e = mk_let_expr (the_default def v_opt, t, e)
   1.125      val mk_test_term =
   1.126        mk_test_term lookup mk_bounded_forall mk_safe_if mk_let @{term True} (K @{term False}) ctxt
   1.127    in lambda depth (mk_test_term t) end