thread through bound types
authorblanchet
Wed, 25 Sep 2013 16:43:46 +0200
changeset 550275f647a5bd46e
parent 55026 d1bd94eb5d0e
child 55028 27da6373a64f
thread through bound types
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Wed Sep 25 16:43:46 2013 +0200
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Wed Sep 25 16:43:46 2013 +0200
     1.3 @@ -577,9 +577,9 @@
     1.4      if is_none maybe_sel_eqn then (I, I, I) else
     1.5      let
     1.6        val {fun_args, rhs_term, ... } = the maybe_sel_eqn;
     1.7 -      fun rewrite_q t = if has_call t then @{term False} else @{term True};
     1.8 -      fun rewrite_g t = if has_call t then undef_const else t;
     1.9 -      fun rewrite_h t = if has_call t then HOLogic.mk_tuple (snd (strip_comb t)) else undef_const;
    1.10 +      fun rewrite_q _ t = if has_call t then @{term False} else @{term True};
    1.11 +      fun rewrite_g _ t = if has_call t then undef_const else t;
    1.12 +      fun rewrite_h _ t = if has_call t then HOLogic.mk_tuple (snd (strip_comb t)) else undef_const;
    1.13        fun massage f t = massage_direct_corec_call lthy has_call f [] rhs_term |> abs_tuple fun_args;
    1.14      in
    1.15        (massage rewrite_q,
    1.16 @@ -604,8 +604,7 @@
    1.17        | rewrite _ U T t = if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t;
    1.18      fun massage NONE t = t
    1.19        | massage (SOME {fun_args, rhs_term, ...}) t =
    1.20 -        massage_indirect_corec_call lthy has_call (rewrite []) []
    1.21 -          (range_type (fastype_of t)) rhs_term
    1.22 +        massage_indirect_corec_call lthy has_call rewrite [] (range_type (fastype_of t)) rhs_term
    1.23          |> abs_tuple fun_args;
    1.24    in
    1.25      massage maybe_sel_eqn
     2.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Wed Sep 25 16:43:46 2013 +0200
     2.2 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Wed Sep 25 16:43:46 2013 +0200
     2.3 @@ -60,13 +60,13 @@
     2.4  
     2.5    val massage_indirect_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
     2.6      typ list -> term -> term -> term -> term
     2.7 -  val massage_direct_corec_call: Proof.context -> (term -> bool) -> (term -> term) -> typ list ->
     2.8 -    term -> term
     2.9 +  val massage_direct_corec_call: Proof.context -> (term -> bool) -> (typ list -> term -> term) ->
    2.10 +    typ list -> term -> term
    2.11    val massage_indirect_corec_call: Proof.context -> (term -> bool) ->
    2.12 -    (typ -> typ -> term -> term) -> typ list -> typ -> term -> term
    2.13 +    (typ list -> typ -> typ -> term -> term) -> typ list -> typ -> term -> term
    2.14    val expand_corec_code_rhs: Proof.context -> (term -> bool) -> typ list -> term -> term
    2.15 -  val massage_corec_code_rhs: Proof.context -> (term -> term list -> term) -> typ list -> term ->
    2.16 -    term
    2.17 +  val massage_corec_code_rhs: Proof.context -> (typ list -> term -> term list -> term) ->
    2.18 +    typ list -> term -> term
    2.19    val fold_rev_corec_code_rhs: Proof.context -> (term list -> term -> term list -> 'a -> 'a) ->
    2.20      typ list -> term -> 'a -> 'a
    2.21  
    2.22 @@ -249,38 +249,40 @@
    2.23  
    2.24  fun case_of ctxt = ctr_sugar_of ctxt #> Option.map (fst o dest_Const o #casex);
    2.25  
    2.26 -fun massage_let_if_case ctxt has_call massage_leaf bound_Ts =
    2.27 +fun massage_let_if_case ctxt has_call massage_leaf =
    2.28    let
    2.29      val thy = Proof_Context.theory_of ctxt;
    2.30  
    2.31 -    val typof = curry fastype_of1 bound_Ts; (*###*)
    2.32      fun check_no_call t = if has_call t then unexpected_corec_call ctxt t else ();
    2.33  
    2.34 -    fun massage_rec t =
    2.35 -      (case Term.strip_comb t of
    2.36 -        (Const (@{const_name Let}, _), [arg1, arg2]) => massage_rec (betapply (arg2, arg1))
    2.37 -      | (Const (@{const_name If}, _), obj :: (branches as [then_branch, _])) =>
    2.38 -        let val branches' = map massage_rec branches in
    2.39 -          Term.list_comb (If_const (typof (hd branches')) $ tap check_no_call obj, branches')
    2.40 -        end
    2.41 -      | (Const (c, _), args as _ :: _) =>
    2.42 -        let val n = num_binder_types (Sign.the_const_type thy c) in
    2.43 -          (case fastype_of1 (bound_Ts, nth args (n - 1)) of
    2.44 -            Type (s, Ts) =>
    2.45 -            if case_of ctxt s = SOME c then
    2.46 -              let
    2.47 -                val (branches, obj_leftovers) = chop n args;
    2.48 -                val branches' = map massage_rec branches;
    2.49 -                val casex' = Const (c, map typof branches' ---> map typof obj_leftovers --->
    2.50 -                  typof t);
    2.51 -              in
    2.52 -                betapplys (casex', branches' @ tap (List.app check_no_call) obj_leftovers)
    2.53 -              end
    2.54 -            else
    2.55 -              massage_leaf t
    2.56 -          | _ => massage_leaf t)
    2.57 -        end
    2.58 -      | _ => massage_leaf t)
    2.59 +    fun massage_rec bound_Ts t =
    2.60 +      let val typof = curry fastype_of1 bound_Ts in
    2.61 +        (case Term.strip_comb t of
    2.62 +          (Const (@{const_name Let}, _), [arg1, arg2]) =>
    2.63 +          massage_rec bound_Ts (betapply (arg2, arg1))
    2.64 +        | (Const (@{const_name If}, _), obj :: (branches as [then_branch, _])) =>
    2.65 +          let val branches' = map (massage_rec bound_Ts) branches in
    2.66 +            Term.list_comb (If_const (typof (hd branches')) $ tap check_no_call obj, branches')
    2.67 +          end
    2.68 +        | (Const (c, _), args as _ :: _) =>
    2.69 +          let val n = num_binder_types (Sign.the_const_type thy c) in
    2.70 +            (case fastype_of1 (bound_Ts, nth args (n - 1)) of
    2.71 +              Type (s, Ts) =>
    2.72 +              if case_of ctxt s = SOME c then
    2.73 +                let
    2.74 +                  val (branches, obj_leftovers) = chop n args;
    2.75 +                  val branches' = map (massage_rec bound_Ts) branches;
    2.76 +                  val casex' = Const (c, map typof branches' ---> map typof obj_leftovers --->
    2.77 +                    typof t);
    2.78 +                in
    2.79 +                  betapplys (casex', branches' @ tap (List.app check_no_call) obj_leftovers)
    2.80 +                end
    2.81 +              else
    2.82 +                massage_leaf bound_Ts t
    2.83 +            | _ => massage_leaf bound_Ts t)
    2.84 +          end
    2.85 +        | _ => massage_leaf bound_Ts t)
    2.86 +      end
    2.87    in
    2.88      massage_rec
    2.89    end;
    2.90 @@ -289,63 +291,71 @@
    2.91  
    2.92  fun massage_indirect_corec_call ctxt has_call raw_massage_call bound_Ts U t =
    2.93    let
    2.94 -    val typof = curry fastype_of1 bound_Ts;
    2.95      val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o snd)
    2.96  
    2.97 -    fun massage_direct_call U T t =
    2.98 -      if has_call t then factor_out_types ctxt raw_massage_call dest_sumT U T t
    2.99 +    fun massage_direct_call bound_Ts U T t =
   2.100 +      if has_call t then factor_out_types ctxt (raw_massage_call bound_Ts) dest_sumT U T t
   2.101        else build_map_Inl (T, U) $ t;
   2.102  
   2.103 -    fun massage_direct_fun U T t =
   2.104 -      let val var = Var ((Name.uu, Term.maxidx_of_term t + 1), domain_type (typof t)) in
   2.105 -        Term.lambda var (massage_direct_call U T (t $ var))
   2.106 +    fun massage_direct_fun bound_Ts U T t =
   2.107 +      let
   2.108 +        val var = Var ((Name.uu, Term.maxidx_of_term t + 1),
   2.109 +          domain_type (fastype_of1 (bound_Ts, t)));
   2.110 +      in
   2.111 +        Term.lambda var (massage_direct_call bound_Ts U T (t $ var))
   2.112        end;
   2.113  
   2.114 -    fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
   2.115 +    fun massage_map bound_Ts (Type (_, Us)) (Type (s, Ts)) t =
   2.116          (case try (dest_map ctxt s) t of
   2.117            SOME (map0, fs) =>
   2.118            let
   2.119 -            val Type (_, dom_Ts) = domain_type (typof t);
   2.120 +            val Type (_, dom_Ts) = domain_type (fastype_of1 (bound_Ts, t));
   2.121              val map' = mk_map (length fs) dom_Ts Us map0;
   2.122 -            val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
   2.123 +            val fs' =
   2.124 +              map_flattened_map_args ctxt s (map3 (massage_map_or_map_arg bound_Ts) Us Ts) fs;
   2.125            in
   2.126              Term.list_comb (map', fs')
   2.127            end
   2.128          | NONE => raise AINT_NO_MAP t)
   2.129 -      | massage_map _ _ t = raise AINT_NO_MAP t
   2.130 -    and massage_map_or_map_arg U T t =
   2.131 +      | massage_map _ _ _ t = raise AINT_NO_MAP t
   2.132 +    and massage_map_or_map_arg bound_Ts U T t =
   2.133        if T = U then
   2.134          if has_call t then unexpected_corec_call ctxt t else t
   2.135        else
   2.136 -        massage_map U T t
   2.137 -        handle AINT_NO_MAP _ => massage_direct_fun U T t;
   2.138 +        massage_map bound_Ts U T t
   2.139 +        handle AINT_NO_MAP _ => massage_direct_fun bound_Ts U T t;
   2.140  
   2.141 -    fun massage_call U T =
   2.142 -      massage_let_if_case ctxt has_call (fn t =>
   2.143 +    fun massage_call bound_Ts U T =
   2.144 +      massage_let_if_case ctxt has_call (fn bound_Ts => fn t =>
   2.145          if has_call t then
   2.146            (case U of
   2.147              Type (s, Us) =>
   2.148              (case try (dest_ctr ctxt s) t of
   2.149                SOME (f, args) =>
   2.150 -              let val f' = mk_ctr Us f in
   2.151 -                Term.list_comb (f',
   2.152 -                  map3 massage_call (binder_types (typof f')) (map typof args) args)
   2.153 +              let
   2.154 +                val typof = curry fastype_of1 bound_Ts;
   2.155 +                val f' = mk_ctr Us f
   2.156 +                val f'_T = typof f';
   2.157 +                val arg_Ts = map typof args;
   2.158 +              in
   2.159 +                Term.list_comb (f', map3 (massage_call bound_Ts) (binder_types f'_T) arg_Ts args)
   2.160                end
   2.161              | NONE =>
   2.162                (case t of
   2.163                  t1 $ t2 =>
   2.164                  (if has_call t2 then
   2.165 -                  massage_direct_call U T t
   2.166 +                  massage_direct_call bound_Ts U T t
   2.167                  else
   2.168 -                  massage_map U T t1 $ t2
   2.169 -                  handle AINT_NO_MAP _ => massage_direct_call U T t)
   2.170 -              | Abs (s, T', t') => Abs (s, T', massage_call (range_type U) (range_type T) t')
   2.171 -              | _ => massage_direct_call U T t))
   2.172 +                  massage_map bound_Ts U T t1 $ t2
   2.173 +                  handle AINT_NO_MAP _ => massage_direct_call bound_Ts U T t)
   2.174 +              | Abs (s, T', t') =>
   2.175 +                Abs (s, T', massage_call (T' :: bound_Ts) (range_type U) (range_type T) t')
   2.176 +              | _ => massage_direct_call bound_Ts U T t))
   2.177            | _ => ill_formed_corec_call ctxt t)
   2.178          else
   2.179            build_map_Inl (T, U) $ t) bound_Ts;
   2.180    in
   2.181 -    massage_call U (typof t) t
   2.182 +    massage_call bound_Ts U (fastype_of1 (bound_Ts, t)) t
   2.183    end;
   2.184  
   2.185  fun expand_ctr_term ctxt s Ts t =
   2.186 @@ -357,15 +367,16 @@
   2.187  fun expand_corec_code_rhs ctxt has_call bound_Ts t =
   2.188    (case fastype_of1 (bound_Ts, t) of
   2.189      Type (s, Ts) =>
   2.190 -    massage_let_if_case ctxt has_call (fn t =>
   2.191 +    massage_let_if_case ctxt has_call (fn bound_Ts => fn t =>
   2.192        if can (dest_ctr ctxt s) t then
   2.193          t
   2.194        else
   2.195 -        massage_let_if_case ctxt has_call I bound_Ts (expand_ctr_term ctxt s Ts t)) bound_Ts t
   2.196 +        massage_let_if_case ctxt has_call (K I) bound_Ts (expand_ctr_term ctxt s Ts t)) bound_Ts t
   2.197    | _ => raise Fail "expand_corec_code_rhs");
   2.198  
   2.199  fun massage_corec_code_rhs ctxt massage_ctr =
   2.200 -  massage_let_if_case ctxt (K false) (uncurry massage_ctr o Term.strip_comb);
   2.201 +  massage_let_if_case ctxt (K false)
   2.202 +    (fn bound_Ts => uncurry (massage_ctr bound_Ts) o Term.strip_comb);
   2.203  
   2.204  fun fold_rev_corec_code_rhs ctxt f =
   2.205    fold_rev_let_if_case ctxt (fn conds => uncurry (f conds) o Term.strip_comb);