src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 54538 2101a97e6220
parent 54497 7ffc4a746a73
child 54548 ab4edf89992f
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Tue Sep 03 21:46:42 2013 +0100
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Wed Sep 04 02:11:50 2013 +0200
     1.3 @@ -36,8 +36,7 @@
     1.4  
     1.5  fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
     1.6    |> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n);
     1.7 -fun abs_tuple t = if t = undef_const then t else
     1.8 -  strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
     1.9 +val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple;
    1.10  
    1.11  val simp_attrs = @{attributes [simp]};
    1.12  
    1.13 @@ -107,7 +106,7 @@
    1.14       user_eqn = eqn'}
    1.15    end;
    1.16  
    1.17 -fun rewrite_map_arg fun_name_ctr_pos_list rec_type res_type =
    1.18 +fun rewrite_map_arg get_ctr_pos rec_type res_type =
    1.19    let
    1.20      val pT = HOLogic.mk_prodT (rec_type, res_type);
    1.21  
    1.22 @@ -117,11 +116,9 @@
    1.23        | subst d t =
    1.24          let
    1.25            val (u, vs) = strip_comb t;
    1.26 -          val maybe_fun_name_ctr_pos =
    1.27 -            find_first (equal (free_name u) o SOME o fst) fun_name_ctr_pos_list;
    1.28 -          val (fun_name, ctr_pos) = the_default ("", ~1) maybe_fun_name_ctr_pos;
    1.29 +          val ctr_pos = try (get_ctr_pos o the) (free_name u) |> the_default ~1;
    1.30          in
    1.31 -          if is_some maybe_fun_name_ctr_pos then
    1.32 +          if ctr_pos >= 0 then
    1.33              if d = SOME ~1 andalso length vs = ctr_pos then
    1.34                list_comb (permute_args ctr_pos (snd_const pT), vs)
    1.35              else if length vs > ctr_pos andalso is_some d
    1.36 @@ -138,7 +135,7 @@
    1.37      subst (SOME ~1)
    1.38    end;
    1.39  
    1.40 -fun subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls t =
    1.41 +fun subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls t =
    1.42    let
    1.43      fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
    1.44        | subst bound_Ts (t as g' $ y) =
    1.45 @@ -146,19 +143,18 @@
    1.46            val maybe_direct_y' = AList.lookup (op =) direct_calls y;
    1.47            val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
    1.48            val (g, g_args) = strip_comb g';
    1.49 -          val maybe_ctr_pos =
    1.50 -            try (snd o the o find_first (equal (free_name g) o SOME o fst)) fun_name_ctr_pos_list;
    1.51 -          val _ = is_none maybe_ctr_pos orelse length g_args >= the maybe_ctr_pos orelse
    1.52 +          val ctr_pos = try (get_ctr_pos o the) (free_name g) |> the_default ~1;
    1.53 +          val _ = ctr_pos < 0 orelse length g_args >= ctr_pos orelse
    1.54              primrec_error_eqn "too few arguments in recursive call" t;
    1.55          in
    1.56            if not (member (op =) ctr_args y) then
    1.57              pairself (subst bound_Ts) (g', y) |> (op $)
    1.58 -          else if is_some maybe_ctr_pos then
    1.59 +          else if ctr_pos >= 0 then
    1.60              list_comb (the maybe_direct_y', g_args)
    1.61            else if is_some maybe_indirect_y' then
    1.62              (if has_call g' then t else y)
    1.63              |> massage_indirect_rec_call lthy has_call
    1.64 -              (rewrite_map_arg fun_name_ctr_pos_list) bound_Ts y (the maybe_indirect_y')
    1.65 +              (rewrite_map_arg get_ctr_pos) bound_Ts y (the maybe_indirect_y')
    1.66              |> (if has_call g' then I else curry (op $) g')
    1.67            else
    1.68              t
    1.69 @@ -211,16 +207,17 @@
    1.70              nth_map arg_idx (K (nth ctr_args ctr_arg_idx |> map_types make_indirect_type)))
    1.71            indirect_calls';
    1.72  
    1.73 +      val fun_name_ctr_pos_list =
    1.74 +        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
    1.75 +      val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
    1.76        val direct_calls = map (apfst (nth ctr_args) o apsnd (nth args)) direct_calls';
    1.77        val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';
    1.78  
    1.79 -      val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data);
    1.80 -      val fun_name_ctr_pos_list =
    1.81 -        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
    1.82 +      val abstractions = args @ #left_args eqn_data @ #right_args eqn_data;
    1.83      in
    1.84        t
    1.85 -      |> subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls
    1.86 -      |> fold_rev absfree abstractions
    1.87 +      |> subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls
    1.88 +      |> fold_rev lambda abstractions
    1.89      end;
    1.90  
    1.91  fun build_defs lthy bs mxs funs_data rec_specs has_call =
    1.92 @@ -372,15 +369,16 @@
    1.93  
    1.94  type co_eqn_data_disc = {
    1.95    fun_name: string,
    1.96 +  fun_args: term list,
    1.97    ctr_no: int, (*###*)
    1.98    cond: term,
    1.99    user_eqn: term
   1.100  };
   1.101  type co_eqn_data_sel = {
   1.102    fun_name: string,
   1.103 +  fun_args: term list,
   1.104    ctr: term,
   1.105    sel: term,
   1.106 -  fun_args: term list,
   1.107    rhs_term: term,
   1.108    user_eqn: term
   1.109  };
   1.110 @@ -388,11 +386,10 @@
   1.111    Disc of co_eqn_data_disc |
   1.112    Sel of co_eqn_data_sel;
   1.113  
   1.114 -fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps =
   1.115 +fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds =
   1.116    let
   1.117      fun find_subterm p = let (* FIXME \<exists>? *)
   1.118 -      fun f (t as u $ v) =
   1.119 -        fold_rev (curry merge_options) [if p t then SOME t else NONE, f u, f v] NONE
   1.120 +      fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
   1.121          | f t = if p t then SOME t else NONE
   1.122        in f end;
   1.123  
   1.124 @@ -406,9 +403,8 @@
   1.125  
   1.126      val discs = #ctr_specs corec_spec |> map #disc;
   1.127      val ctrs = #ctr_specs corec_spec |> map #ctr;
   1.128 -    val n_ctrs = length ctrs;
   1.129      val not_disc = head_of imp_rhs = @{term Not};
   1.130 -    val _ = not_disc andalso n_ctrs <> 2 andalso
   1.131 +    val _ = not_disc andalso length ctrs <> 2 andalso
   1.132        primrec_error_eqn "\<not>ed discriminator for a type with \<noteq> 2 constructors" imp_rhs;
   1.133      val disc = find_subterm (member (op =) discs o head_of) imp_rhs;
   1.134      val eq_ctr0 = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (HOLogic.dest_eq #> snd)
   1.135 @@ -428,32 +424,28 @@
   1.136      val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
   1.137      val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
   1.138      val catch_all = try (fst o dest_Free o the_single) imp_lhs' = SOME Name.uu_;
   1.139 -    val matched_conds = filter (equal fun_name o fst) matched_conds_ps |> map snd;
   1.140 -    val imp_lhs = mk_conjs imp_lhs';
   1.141 +    val matched_cond = filter (equal fun_name o fst) matched_conds |> map snd |> mk_disjs;
   1.142 +    val imp_lhs = mk_conjs imp_lhs'
   1.143 +      |> incr_boundvars (length fun_args)
   1.144 +      |> subst_atomic (fun_args ~~ map Bound (length fun_args - 1 downto 0))
   1.145      val cond =
   1.146        if catch_all then
   1.147 -        if null matched_conds then fold_rev absfree (map dest_Free fun_args) @{const True} else
   1.148 -          (strip_abs_vars (hd matched_conds),
   1.149 -            mk_disjs (map strip_abs_body matched_conds) |> HOLogic.mk_not)
   1.150 -          |-> fold_rev (fn (v, T) => fn u => Abs (v, T, u))
   1.151 +        matched_cond |> HOLogic.mk_not
   1.152        else if sequential then
   1.153 -        HOLogic.mk_conj (HOLogic.mk_not (mk_disjs (map strip_abs_body matched_conds)), imp_lhs)
   1.154 -        |> fold_rev absfree (map dest_Free fun_args)
   1.155 +        HOLogic.mk_conj (HOLogic.mk_not matched_cond, imp_lhs)
   1.156        else
   1.157 -        imp_lhs |> fold_rev absfree (map dest_Free fun_args);
   1.158 -    val matched_cond =
   1.159 -      if sequential then fold_rev absfree (map dest_Free fun_args) imp_lhs else cond;
   1.160 +        imp_lhs;
   1.161  
   1.162 -    val matched_conds_ps' = if catch_all
   1.163 -      then (fun_name, cond) :: filter (not_equal fun_name o fst) matched_conds_ps
   1.164 -      else (fun_name, matched_cond) :: matched_conds_ps;
   1.165 +    val matched_conds' =
   1.166 +      (fun_name, if catch_all orelse not sequential then cond else imp_lhs) :: matched_conds;
   1.167    in
   1.168      (Disc {
   1.169        fun_name = fun_name,
   1.170 +      fun_args = fun_args,
   1.171        ctr_no = ctr_no,
   1.172        cond = cond,
   1.173        user_eqn = eqn'
   1.174 -    }, matched_conds_ps')
   1.175 +    }, matched_conds')
   1.176    end;
   1.177  
   1.178  fun co_dissect_eqn_sel fun_name_corec_spec_list eqn' eqn =
   1.179 @@ -473,15 +465,15 @@
   1.180    in
   1.181      Sel {
   1.182        fun_name = fun_name,
   1.183 +      fun_args = fun_args,
   1.184        ctr = #ctr ctr_spec,
   1.185        sel = sel,
   1.186 -      fun_args = fun_args,
   1.187        rhs_term = rhs,
   1.188        user_eqn = eqn'
   1.189      }
   1.190    end;
   1.191  
   1.192 -fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps =
   1.193 +fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds =
   1.194    let 
   1.195      val (lhs, rhs) = HOLogic.dest_eq imp_rhs;
   1.196      val fun_name = head_of lhs |> fst o dest_Free;
   1.197 @@ -491,10 +483,10 @@
   1.198        handle Option.Option => primrec_error_eqn "not a constructor" ctr;
   1.199  
   1.200      val disc_imp_rhs = betapply (#disc ctr_spec, lhs);
   1.201 -    val (maybe_eqn_data_disc, matched_conds_ps') = if length (#ctr_specs corec_spec) = 1
   1.202 -      then (NONE, matched_conds_ps)
   1.203 +    val (maybe_eqn_data_disc, matched_conds') = if length (#ctr_specs corec_spec) = 1
   1.204 +      then (NONE, matched_conds)
   1.205        else apfst SOME (co_dissect_eqn_disc
   1.206 -          sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps);
   1.207 +          sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds);
   1.208  
   1.209      val sel_imp_rhss = (#sels ctr_spec ~~ ctr_args)
   1.210        |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
   1.211 @@ -506,10 +498,10 @@
   1.212      val eqns_data_sel =
   1.213        map (co_dissect_eqn_sel fun_name_corec_spec_list eqn') sel_imp_rhss;
   1.214    in
   1.215 -    (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds_ps')
   1.216 +    (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds')
   1.217    end;
   1.218  
   1.219 -fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds_ps =
   1.220 +fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds =
   1.221    let
   1.222      val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev,
   1.223          strip_qnt_body @{const_name all} eqn')
   1.224 @@ -531,65 +523,68 @@
   1.225      if member (op =) discs head orelse
   1.226        is_some maybe_rhs andalso
   1.227          member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
   1.228 -      co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps
   1.229 +      co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds
   1.230        |>> single
   1.231      else if member (op =) sels head then
   1.232 -      ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds_ps)
   1.233 +      ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds)
   1.234      else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
   1.235 -      co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps
   1.236 +      co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds
   1.237      else
   1.238        primrec_error_eqn "malformed function equation" eqn
   1.239    end;
   1.240  
   1.241  fun build_corec_args_discs disc_eqns ctr_specs =
   1.242 -  let
   1.243 -    val conds = map #cond disc_eqns;
   1.244 -    val args' =
   1.245 -      if length ctr_specs = 1 then []
   1.246 -      else if length disc_eqns = length ctr_specs then
   1.247 -        fst (split_last conds)
   1.248 -      else if length disc_eqns = length ctr_specs - 1 then
   1.249 -        let val n = 0 upto length ctr_specs - 1
   1.250 -            |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)) (*###*) in
   1.251 -          if n = length ctr_specs - 1 then
   1.252 -            conds
   1.253 -          else
   1.254 -            split_last conds
   1.255 -            ||> (fn t => fold_rev absfree (strip_abs_vars t) (strip_abs_body t |> HOLogic.mk_not))
   1.256 -            |>> chop n
   1.257 -            |> (fn ((l, r), x) => l @ (x :: r))
   1.258 -        end
   1.259 -      else
   1.260 -        0 upto length ctr_specs - 1
   1.261 -        |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns
   1.262 -          |> Option.map #cond
   1.263 -          |> the_default undef_const)
   1.264 -        |> fst o split_last;
   1.265 -  in
   1.266 -    (* FIXME: deal with #preds above *)
   1.267 -    fold2 (fn idx => nth_map idx o K o abs_tuple) (map_filter #pred ctr_specs) args'
   1.268 -  end;
   1.269 +  if null disc_eqns then I else
   1.270 +    let
   1.271 +      val conds = map #cond disc_eqns;
   1.272 +      val fun_args = #fun_args (hd disc_eqns);
   1.273 +      val args =
   1.274 +        if length ctr_specs = 1 then []
   1.275 +        else if length disc_eqns = length ctr_specs then
   1.276 +          fst (split_last conds)
   1.277 +        else if length disc_eqns = length ctr_specs - 1 then
   1.278 +          let val n = 0 upto length ctr_specs - 1
   1.279 +              |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)) (*###*) in
   1.280 +            if n = length ctr_specs - 1 then
   1.281 +              conds
   1.282 +            else
   1.283 +              split_last conds
   1.284 +              ||> HOLogic.mk_not
   1.285 +              |>> chop n
   1.286 +              |> (fn ((l, r), x) => l @ (x :: r))
   1.287 +          end
   1.288 +        else
   1.289 +          0 upto length ctr_specs - 1
   1.290 +          |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns
   1.291 +            |> Option.map #cond
   1.292 +            |> the_default undef_const)
   1.293 +          |> fst o split_last;
   1.294 +    in
   1.295 +      (* FIXME deal with #preds above *)
   1.296 +      (map_filter #pred ctr_specs, args)
   1.297 +      |-> fold2 (fn idx => fn t => nth_map idx
   1.298 +        (K (subst_bounds (List.rev fun_args, t)
   1.299 +          |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args))))
   1.300 +    end;
   1.301  
   1.302  fun build_corec_arg_no_call sel_eqns sel = find_first (equal sel o #sel) sel_eqns
   1.303 -  |> try (fn SOME sel_eqn => (#fun_args sel_eqn |> map dest_Free, #rhs_term sel_eqn))
   1.304 +  |> try (fn SOME sel_eqn => (#fun_args sel_eqn, #rhs_term sel_eqn))
   1.305    |> the_default ([], undef_const)
   1.306 -  |-> abs_tuple oo fold_rev absfree;
   1.307 +  |-> abs_tuple;
   1.308  
   1.309  fun build_corec_arg_direct_call lthy has_call sel_eqns sel =
   1.310    let
   1.311      val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns
   1.312 -
   1.313 -    fun rewrite U T t =
   1.314 +    fun rewrite is_end U T t =
   1.315        if U = @{typ bool} then @{term True} |> has_call t ? K @{term False} (* stop? *)
   1.316 -      else if T = U = has_call t then undef_const
   1.317 -      else if T = U then t (* end *)
   1.318 +      else if is_end = has_call t then undef_const
   1.319 +      else if is_end then t (* end *)
   1.320        else HOLogic.mk_tuple (snd (strip_comb t)); (* continue *)
   1.321 -    fun massage rhs_term t =
   1.322 -      massage_direct_corec_call lthy has_call rewrite [] (body_type (fastype_of t)) rhs_term;
   1.323 -    val abstract = abs_tuple oo fold_rev absfree o map dest_Free;
   1.324 +    fun massage rhs_term is_end t = massage_direct_corec_call
   1.325 +      lthy has_call (rewrite is_end) [] (range_type (fastype_of t)) rhs_term;
   1.326    in
   1.327 -    if is_none maybe_sel_eqn then I else
   1.328 -      massage (#rhs_term (the maybe_sel_eqn)) #> abstract (#fun_args (the maybe_sel_eqn))
   1.329 +    if is_none maybe_sel_eqn then K I else
   1.330 +      abs_tuple (#fun_args (the maybe_sel_eqn)) oo massage (#rhs_term (the maybe_sel_eqn))
   1.331    end;
   1.332  
   1.333  fun build_corec_arg_indirect_call sel_eqns sel =
   1.334 @@ -614,7 +609,7 @@
   1.335            (build_corec_arg_no_call sel_eqns sel |> K)) no_calls'
   1.336          #> fold (fn (sel, (q, g, h)) =>
   1.337            let val f = build_corec_arg_direct_call lthy has_call sel_eqns sel in
   1.338 -            nth_map h f o nth_map g f o nth_map q f end) direct_calls'
   1.339 +            nth_map h (f false) o nth_map g (f true) o nth_map q (f true) end) direct_calls'
   1.340          #> fold (fn (sel, n) => nth_map n
   1.341            (build_corec_arg_indirect_call sel_eqns sel |> K)) indirect_calls'
   1.342        end
   1.343 @@ -651,24 +646,25 @@
   1.344        |> fold2 build_corec_args_discs disc_eqnss ctr_specss
   1.345        |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
   1.346  
   1.347 +    fun currys Ts t = if length Ts <= 1 then t else
   1.348 +      t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v)
   1.349 +        (length Ts - 1 downto 0 |> map Bound)
   1.350 +      |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts;
   1.351 +
   1.352  val _ = tracing ("corecursor arguments:\n    \<cdot> " ^
   1.353   space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) corec_args));
   1.354  
   1.355      fun uneq_pairs_rev xs = xs (* FIXME \<exists>? *)
   1.356        |> these o try (split_last #> (fn (ys, y) => uneq_pairs_rev ys @ map (pair y) ys));
   1.357      val proof_obligations = if sequential then [] else
   1.358 -      maps (uneq_pairs_rev o map #cond) disc_eqnss
   1.359 -      |> map (fn (x, y) => ((strip_abs_body x, strip_abs_body y), strip_abs_vars x))
   1.360 -      |> map (apfst (apsnd HOLogic.mk_not #> pairself HOLogic.mk_Trueprop
   1.361 -        #> apfst (curry (op $) @{const ==>}) #> (op $)))
   1.362 -      |> map (fn (t, abs_vars) => fold_rev (fn (v, T) => fn u =>
   1.363 -          Const (@{const_name all}, (T --> @{typ prop}) --> @{typ prop}) $
   1.364 -            Abs (v, T, u)) abs_vars t);
   1.365 +      maps (uneq_pairs_rev o map (fn {fun_args, cond, ...} => (fun_args, cond))) disc_eqnss
   1.366 +      |> map (fn ((fun_args, x), (_, y)) => [x, HOLogic.mk_not y]
   1.367 +        |> map (HOLogic.mk_Trueprop o curry subst_bounds (List.rev fun_args))
   1.368 +        |> curry list_comb @{const ==>});
   1.369  
   1.370 -    fun currys Ts t = if length Ts <= 1 then t else
   1.371 -      t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v)
   1.372 -        (length Ts - 1 downto 0 |> map Bound)
   1.373 -      |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts;
   1.374 +val _ = tracing ("proof obligations:\n    \<cdot> " ^
   1.375 + space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) proof_obligations));
   1.376 +
   1.377    in
   1.378      map (list_comb o rpair corec_args) corecs
   1.379      |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss