simplified code; eliminated some dummyTs
authorpanny
Thu, 19 Sep 2013 16:12:43 +0200
changeset 5487299331dac1e1c
parent 54871 7613573f023a
child 54873 82799e03fff7
simplified code; eliminated some dummyTs
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Thu Sep 19 12:20:12 2013 +0200
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Thu Sep 19 16:12:43 2013 +0200
     1.3 @@ -56,6 +56,8 @@
     1.4          | a n t = let val idx = find_index (equal t) vs in
     1.5              if idx < 0 then t else Bound (n + idx) end
     1.6    in a 0 end;
     1.7 +fun mk_prod1 Ts (t, u) = HOLogic.pair_const (fastype_of1 (Ts, t)) (fastype_of1 (Ts, u)) $ t $ u;
     1.8 +fun mk_tuple1 Ts = the_default HOLogic.unit o try (foldr1 (mk_prod1 Ts));
     1.9  
    1.10  val simp_attrs = @{attributes [simp]};
    1.11  
    1.12 @@ -561,48 +563,43 @@
    1.13    |> the_default undef_const
    1.14    |> K;
    1.15  
    1.16 -fun build_corec_arg_direct_call lthy has_call sel_eqns sel =
    1.17 +fun build_corec_args_direct_call lthy has_call sel_eqns sel =
    1.18    let
    1.19      val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns;
    1.20 -    fun massage rhs_term is_end t =
    1.21 -      let
    1.22 -        val U = range_type (fastype_of t);
    1.23 -        fun rewrite t =
    1.24 -          if U = @{typ bool} then (if has_call t then @{term False} else @{term True}) (* stop? *)
    1.25 -          else if is_end = has_call t then undef_const
    1.26 -          else if is_end then t (* end *)
    1.27 -          else HOLogic.mk_tuple (snd (strip_comb t)); (* continue *)
    1.28 -      in
    1.29 -        massage_direct_corec_call lthy has_call rewrite U rhs_term
    1.30 -      end;
    1.31 +    fun rewrite_q t = if has_call t then @{term False} else @{term True};
    1.32 +    fun rewrite_g t = if has_call t then undef_const else t;
    1.33 +    fun rewrite_h t = if has_call t then HOLogic.mk_tuple (snd (strip_comb t)) else undef_const;
    1.34 +    fun massage _ NONE t = t
    1.35 +      | massage f (SOME {fun_args, rhs_term, ...}) t =
    1.36 +        massage_direct_corec_call lthy has_call f (range_type (fastype_of t)) rhs_term
    1.37 +        |> abs_tuple fun_args;
    1.38    in
    1.39 -    if is_none maybe_sel_eqn then K I else
    1.40 -      abs_tuple (#fun_args (the maybe_sel_eqn)) oo massage (#rhs_term (the maybe_sel_eqn))
    1.41 +    (massage rewrite_q maybe_sel_eqn,
    1.42 +     massage rewrite_g maybe_sel_eqn,
    1.43 +     massage rewrite_h maybe_sel_eqn)
    1.44    end;
    1.45  
    1.46  fun build_corec_arg_indirect_call lthy has_call sel_eqns sel =
    1.47    let
    1.48      val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns;
    1.49 -    fun rewrite (Abs (v, T, b)) = Abs (v, T, rewrite b)
    1.50 -      | rewrite t =
    1.51 +    fun rewrite bound_Ts U T (Abs (v, V, b)) = Abs (v, V, rewrite (V :: bound_Ts) U T b)
    1.52 +      | rewrite bound_Ts U T (t as _ $ _) =
    1.53          let val (u, vs) = strip_comb t in
    1.54            if is_Free u andalso has_call u then
    1.55 -            Const (@{const_name Inr}, dummyT) $
    1.56 -              (if null vs then HOLogic.unit
    1.57 -               else foldr1 (fn (x, y) => Const (@{const_name Pair}, dummyT) $ x $ y) vs)
    1.58 +            Inr_const U T $ mk_tuple1 bound_Ts vs
    1.59            else if try (fst o dest_Const) u = SOME @{const_name prod_case} then
    1.60 -            list_comb (u |> map_types (K dummyT), map rewrite vs)
    1.61 -          else if null vs then
    1.62 -            u
    1.63 +            list_comb (map_types (K dummyT) u, map (rewrite bound_Ts U T) vs)
    1.64            else
    1.65 -            list_comb (rewrite u, map rewrite vs)
    1.66 -        end;
    1.67 -    fun massage rhs_term t =
    1.68 -      massage_indirect_corec_call lthy has_call (K (K rewrite)) [] (range_type (fastype_of t))
    1.69 -        rhs_term;
    1.70 +            list_comb (rewrite bound_Ts U T u, map (rewrite bound_Ts U T) vs)
    1.71 +        end
    1.72 +      | rewrite _ U T t = if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t;
    1.73 +    fun massage NONE t = t
    1.74 +      | massage (SOME {fun_args, rhs_term, ...}) t =
    1.75 +        massage_indirect_corec_call lthy has_call (rewrite []) []
    1.76 +          (range_type (fastype_of t)) rhs_term
    1.77 +        |> abs_tuple fun_args;
    1.78    in
    1.79 -    if is_none maybe_sel_eqn then I else
    1.80 -      abs_tuple (#fun_args (the maybe_sel_eqn)) o massage (#rhs_term (the maybe_sel_eqn))
    1.81 +    massage maybe_sel_eqn
    1.82    end;
    1.83  
    1.84  fun build_corec_args_sel lthy has_call all_sel_eqns ctr_spec =
    1.85 @@ -616,11 +613,10 @@
    1.86          val indirect_calls' = map_filter (try (apsnd (fn Indirect_Corec n => n))) sel_call_list;
    1.87        in
    1.88          I
    1.89 -        #> fold (fn (sel, n) => nth_map n
    1.90 -          (build_corec_arg_no_call sel_eqns sel)) no_calls'
    1.91 +        #> fold (fn (sel, n) => nth_map n (build_corec_arg_no_call sel_eqns sel)) no_calls'
    1.92          #> fold (fn (sel, (q, g, h)) =>
    1.93 -          let val f = build_corec_arg_direct_call lthy has_call sel_eqns sel in
    1.94 -            nth_map h (f false) o nth_map g (f true) o nth_map q (f true) end) direct_calls'
    1.95 +          let val (fq, fg, fh) = build_corec_args_direct_call lthy has_call sel_eqns sel in
    1.96 +            nth_map q fq o nth_map g fg o nth_map h fh end) direct_calls'
    1.97          #> fold (fn (sel, n) => nth_map n
    1.98            (build_corec_arg_indirect_call lthy has_call sel_eqns sel)) indirect_calls'
    1.99        end
   1.100 @@ -636,10 +632,9 @@
   1.101        |> map (Const o pair @{const_name undefined})
   1.102        |> fold2 (fold o build_corec_arg_disc) ctr_specss disc_eqnss
   1.103        |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
   1.104 -    fun currys Ts t = if length Ts <= 1 then t else
   1.105 -      t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v)
   1.106 -        (length Ts - 1 downto 0 |> map Bound)
   1.107 -      |> fold_rev (Term.abs o pair Name.uu) Ts;
   1.108 +    fun currys [] t = t
   1.109 +      | currys Ts t = t $ mk_tuple1 (List.rev Ts) (map Bound (length Ts - 1 downto 0))
   1.110 +          |> fold_rev (Term.abs o pair Name.uu) Ts;
   1.111  
   1.112  val _ = tracing ("corecursor arguments:\n    \<cdot> " ^
   1.113   space_implode "\n    \<cdot> " (map (Syntax.string_of_term lthy) corec_args));
   1.114 @@ -792,7 +787,6 @@
   1.115  
   1.116          fun prove_ctr (_, disc_thms) (_, sel_thms') disc_eqns sel_eqns
   1.117              {ctr, disc, sels, collapse, ...} =
   1.118 -let val _ = tracing ("disc = " ^ @{make_string} disc); in
   1.119            if not (exists (equal ctr o #ctr) disc_eqns)
   1.120                andalso not (exists (equal ctr o #ctr) sel_eqns)
   1.121  andalso (warning ("no eqns for ctr " ^ Syntax.string_of_term lthy ctr); true)
   1.122 @@ -804,7 +798,7 @@
   1.123            then [] else
   1.124              let
   1.125  val _ = tracing ("ctr = " ^ Syntax.string_of_term lthy ctr);
   1.126 -val _ = tracing (the_default "NO disc_eqn" (Option.map (curry (op ^) "disc = " o Syntax.string_of_term lthy o #disc) (find_first (equal ctr o #ctr) disc_eqns)));
   1.127 +val _ = tracing (the_default "no disc_eqn" (Option.map (curry (op ^) "disc = " o Syntax.string_of_term lthy o #disc) (find_first (equal ctr o #ctr) disc_eqns)));
   1.128                val (fun_name, fun_T, fun_args, prems) =
   1.129                  (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
   1.130                  |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x))
   1.131 @@ -831,8 +825,6 @@
   1.132                mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
   1.133                |> K |> Goal.prove lthy [] [] t
   1.134                |> single
   1.135 -(*handle ERROR x => (warning x; []))*)
   1.136 -end
   1.137            end;
   1.138  
   1.139          val (disc_notes, disc_thmss) =