src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 55517 e30e63d05e58
parent 55181 93ab44e992ae
child 55519 7be49e2bfccc
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Fri Oct 04 17:00:35 2013 +0200
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Fri Oct 04 18:27:07 2013 +0200
     1.3 @@ -34,10 +34,10 @@
     1.4  open BNF_FP_Rec_Sugar_Util
     1.5  open BNF_FP_Rec_Sugar_Tactics
     1.6  
     1.7 -val codeN = "code"
     1.8 -val ctrN = "ctr"
     1.9 -val discN = "disc"
    1.10 -val selN = "sel"
    1.11 +val codeN = "code";
    1.12 +val ctrN = "ctr";
    1.13 +val discN = "disc";
    1.14 +val selN = "sel";
    1.15  
    1.16  val nitpick_attrs = @{attributes [nitpick_simp]};
    1.17  val simp_attrs = @{attributes [simp]};
    1.18 @@ -472,7 +472,7 @@
    1.19    Disc of co_eqn_data_disc |
    1.20    Sel of co_eqn_data_sel;
    1.21  
    1.22 -fun co_dissect_eqn_disc sequential fun_names (corec_specs : corec_spec list) prems' concl
    1.23 +fun co_dissect_eqn_disc seq fun_names (corec_specs : corec_spec list) prems' concl
    1.24      matchedsss =
    1.25    let
    1.26      fun find_subterm p = let (* FIXME \<exists>? *)
    1.27 @@ -507,11 +507,11 @@
    1.28      val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default [];
    1.29      val prems = map (abstract (List.rev fun_args)) prems';
    1.30      val real_prems =
    1.31 -      (if catch_all orelse sequential then maps negate_disj matchedss else []) @
    1.32 +      (if catch_all orelse seq then maps negate_disj matchedss else []) @
    1.33        (if catch_all then [] else prems);
    1.34  
    1.35      val matchedsss' = AList.delete (op =) fun_name matchedsss
    1.36 -      |> cons (fun_name, if sequential then matchedss @ [prems] else matchedss @ [real_prems]);
    1.37 +      |> cons (fun_name, if seq then matchedss @ [prems] else matchedss @ [real_prems]);
    1.38  
    1.39      val user_eqn =
    1.40        (real_prems, betapply (#disc (nth ctr_specs ctr_no), applied_fun))
    1.41 @@ -560,49 +560,72 @@
    1.42      }
    1.43    end;
    1.44  
    1.45 -fun co_dissect_eqn_ctr sequential fun_names (corec_specs : corec_spec list) eqn' imp_prems imp_rhs
    1.46 +fun co_dissect_eqn_ctr seq fun_names (corec_specs : corec_spec list) eqn' prems concl
    1.47      matchedsss =
    1.48    let
    1.49 -    val (lhs, rhs) = HOLogic.dest_eq imp_rhs;
    1.50 +    val (lhs, rhs) = HOLogic.dest_eq concl;
    1.51      val fun_name = head_of lhs |> fst o dest_Free;
    1.52      val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
    1.53      val (ctr, ctr_args) = strip_comb rhs;
    1.54      val {disc, sels, ...} = the (find_first (equal ctr o #ctr) ctr_specs)
    1.55        handle Option.Option => primrec_error_eqn "not a constructor" ctr;
    1.56  
    1.57 -    val disc_imp_rhs = betapply (disc, lhs);
    1.58 +    val disc_concl = betapply (disc, lhs);
    1.59      val (maybe_eqn_data_disc, matchedsss') = if length ctr_specs = 1
    1.60        then (NONE, matchedsss)
    1.61        else apfst SOME (co_dissect_eqn_disc
    1.62 -          sequential fun_names corec_specs imp_prems disc_imp_rhs matchedsss);
    1.63 +          seq fun_names corec_specs prems disc_concl matchedsss);
    1.64  
    1.65 -    val sel_imp_rhss = (sels ~~ ctr_args)
    1.66 +    val sel_concls = (sels ~~ ctr_args)
    1.67        |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
    1.68  
    1.69  (*
    1.70 -val _ = tracing ("reduced\n    " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n    \<cdot> " ^
    1.71 - (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_imp_rhs ^ "\n    \<cdot> ")) "" ^
    1.72 - space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_imp_rhss));
    1.73 +val _ = tracing ("reduced\n    " ^ Syntax.string_of_term @{context} concl ^ "\nto\n    \<cdot> " ^
    1.74 + (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_concl ^ "\n    \<cdot> ")) "" ^
    1.75 + space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_concls));
    1.76  *)
    1.77  
    1.78 -    val eqns_data_sel =
    1.79 -      map (co_dissect_eqn_sel fun_names corec_specs eqn' (SOME ctr)) sel_imp_rhss;
    1.80 +    val eqns_data_sel = map (co_dissect_eqn_sel fun_names corec_specs eqn' (SOME ctr)) sel_concls;
    1.81    in
    1.82      (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
    1.83    end;
    1.84  
    1.85 -fun co_dissect_eqn sequential fun_names (corec_specs : corec_spec list) eqn' of_spec matchedsss =
    1.86 +fun co_dissect_eqn_code lthy has_call fun_names corec_specs eqn' concl matchedsss =
    1.87 +  let
    1.88 +    val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
    1.89 +    val fun_name = head_of lhs |> fst o dest_Free;
    1.90 +    val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
    1.91 +
    1.92 +    val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ =>
    1.93 +        if member ((op =) o apsnd #ctr) ctr_specs ctr
    1.94 +        then cons (ctr, cs)
    1.95 +        else primrec_error_eqn "not a constructor" ctr) [] rhs' []
    1.96 +      |> AList.group (op =);
    1.97 +
    1.98 +    val ctr_premss = map (single o mk_disjs o map mk_conjs o snd) cond_ctrs;
    1.99 +    val ctr_concls = cond_ctrs |> map (fn (ctr, _) =>
   1.100 +        binder_types (fastype_of ctr)
   1.101 +        |> map_index (fn (n, T) => massage_corec_code_rhs lthy (fn _ => fn ctr' => fn args =>
   1.102 +          if ctr' = ctr then nth args n else Const (@{const_name undefined}, T)) [] rhs')
   1.103 +        |> curry list_comb ctr
   1.104 +        |> curry HOLogic.mk_eq lhs);
   1.105 +  in
   1.106 +    fold_map2 (co_dissect_eqn_ctr false fun_names corec_specs eqn') ctr_premss ctr_concls matchedsss
   1.107 +  end;
   1.108 +
   1.109 +fun co_dissect_eqn lthy has_call seq fun_names (corec_specs : corec_spec list) eqn' of_spec
   1.110 +    matchedsss =
   1.111    let
   1.112      val eqn = drop_All eqn'
   1.113        handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
   1.114 -    val (imp_prems, imp_rhs) = Logic.strip_horn eqn
   1.115 +    val (prems, concl) = Logic.strip_horn eqn
   1.116        |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop;
   1.117  
   1.118 -    val head = imp_rhs
   1.119 +    val head = concl
   1.120        |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
   1.121        |> head_of;
   1.122  
   1.123 -    val maybe_rhs = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
   1.124 +    val maybe_rhs = concl |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
   1.125  
   1.126      val discs = maps #ctr_specs corec_specs |> map #disc;
   1.127      val sels = maps #ctr_specs corec_specs |> maps #sels;
   1.128 @@ -611,12 +634,17 @@
   1.129      if member (op =) discs head orelse
   1.130        is_some maybe_rhs andalso
   1.131          member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
   1.132 -      co_dissect_eqn_disc sequential fun_names corec_specs imp_prems imp_rhs matchedsss
   1.133 +      co_dissect_eqn_disc seq fun_names corec_specs prems concl matchedsss
   1.134        |>> single
   1.135      else if member (op =) sels head then
   1.136 -      ([co_dissect_eqn_sel fun_names corec_specs eqn' of_spec imp_rhs], matchedsss)
   1.137 -    else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
   1.138 -      co_dissect_eqn_ctr sequential fun_names corec_specs eqn' imp_prems imp_rhs matchedsss
   1.139 +      ([co_dissect_eqn_sel fun_names corec_specs eqn' of_spec concl], matchedsss)
   1.140 +    else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
   1.141 +      member (op =) ctrs (head_of (the maybe_rhs)) then
   1.142 +      co_dissect_eqn_ctr seq fun_names corec_specs eqn' prems concl matchedsss
   1.143 +    else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
   1.144 +      null prems then
   1.145 +      co_dissect_eqn_code lthy has_call fun_names corec_specs eqn' concl matchedsss
   1.146 +      |>> flat
   1.147      else
   1.148        primrec_error_eqn "malformed function equation" eqn
   1.149    end;
   1.150 @@ -646,7 +674,7 @@
   1.151        fun rewrite_g _ t = if has_call t then undef_const else t;
   1.152        fun rewrite_h bound_Ts t =
   1.153          if has_call t then mk_tuple1 bound_Ts (snd (strip_comb t)) else undef_const;
   1.154 -      fun massage f t = massage_direct_corec_call lthy has_call f [] rhs_term |> abs_tuple fun_args;
   1.155 +      fun massage f _ = massage_direct_corec_call lthy has_call f [] rhs_term |> abs_tuple fun_args;
   1.156      in
   1.157        (massage rewrite_q,
   1.158         massage rewrite_g,
   1.159 @@ -763,7 +791,7 @@
   1.160        chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
   1.161      end;
   1.162  
   1.163 -fun add_primcorec simple sequential fixes specs of_specs lthy =
   1.164 +fun add_primcorec simple seq fixes specs of_specs lthy =
   1.165    let
   1.166      val (bs, mxs) = map_split (apfst fst) fixes;
   1.167      val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;
   1.168 @@ -778,8 +806,10 @@
   1.169      val fun_names = map Binding.name_of bs;
   1.170      val corec_specs = take actual_nn corec_specs'; (*###*)
   1.171  
   1.172 +    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
   1.173      val eqns_data =
   1.174 -      fold_map2 (co_dissect_eqn sequential fun_names corec_specs) (map snd specs) of_specs []
   1.175 +      fold_map2 (co_dissect_eqn lthy has_call seq fun_names corec_specs)
   1.176 +        (map snd specs) of_specs []
   1.177        |> flat o fst;
   1.178  
   1.179      val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
   1.180 @@ -796,14 +826,13 @@
   1.181        |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
   1.182        |> map (flat o snd);
   1.183  
   1.184 -    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
   1.185      val arg_Tss = map (binder_types o snd o fst) fixes;
   1.186      val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss';
   1.187      val (defs, exclss') =
   1.188        co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;
   1.189  
   1.190      fun excl_tac (c, c', a) =
   1.191 -      if a orelse c = c' orelse sequential then
   1.192 +      if a orelse c = c' orelse seq then
   1.193          SOME (K (HEADGOAL (mk_primcorec_assumption_tac lthy [])))
   1.194        else if simple then
   1.195          SOME (K (auto_tac lthy))