merged
authortraytel
Sat, 31 Aug 2013 23:55:03 +0200
changeset 54491b7469b85ca28
parent 54490 0c1c67e3fccc
parent 54488 4335477c60f5
child 54492 603e6e97c391
merged
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
     1.1 --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Sat Aug 31 23:49:36 2013 +0200
     1.2 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Sat Aug 31 23:55:03 2013 +0200
     1.3 @@ -103,32 +103,6 @@
     1.4       user_eqn = eqn'}
     1.5    end;
     1.6  
     1.7 -(* substitutes (f ls x rs) by (y ls rs) for all f: get_idx f \<ge> 0, (x,y) \<in> substs *)
     1.8 -fun subst_direct_calls get_idx get_ctr_pos substs = 
     1.9 -  let
    1.10 -    fun subst (Abs (v, T, b)) = Abs (v, T, subst b)
    1.11 -      | subst t =
    1.12 -        let
    1.13 -          val (f, args) = strip_comb t;
    1.14 -          val idx = get_idx f;
    1.15 -          val ctr_pos  = if idx >= 0 then get_ctr_pos idx else ~1;
    1.16 -        in
    1.17 -          if idx < 0 then
    1.18 -            list_comb (f, map subst args)
    1.19 -          else if ctr_pos >= length args then
    1.20 -            primrec_error_eqn "too few arguments in recursive call" t
    1.21 -          else
    1.22 -            let
    1.23 -              val (key, repl) = the (find_first (equal (nth args ctr_pos) o fst) substs)
    1.24 -                handle Option.Option => primrec_error_eqn
    1.25 -                  "recursive call not directly applied to constructor argument" t;
    1.26 -            in
    1.27 -              remove (op =) key args |> map subst |> curry list_comb repl
    1.28 -            end
    1.29 -        end
    1.30 -  in subst end;
    1.31 -
    1.32 -(* FIXME get rid of funs_data or get_indices *)
    1.33  fun rewrite_map_arg funs_data get_indices y rec_type res_type =
    1.34    let
    1.35      val pT = HOLogic.mk_prodT (rec_type, res_type);
    1.36 @@ -169,36 +143,41 @@
    1.37    end;
    1.38  
    1.39  (* FIXME get rid of funs_data or get_indices *)
    1.40 -fun subst_indirect_call lthy funs_data get_indices (y, y') =
    1.41 +fun subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls t =
    1.42    let
    1.43 -    fun massage massage_map_arg bound_Ts =
    1.44 -      massage_indirect_rec_call lthy (not o null o get_indices) massage_map_arg bound_Ts y y';
    1.45 -    fun subst bound_Ts (t as _ $ _) =
    1.46 +    val contains_fun = not o null o get_indices;
    1.47 +    fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
    1.48 +      | subst bound_Ts (t as g $ y) =
    1.49          let
    1.50 -          val ctr_args = fold_aterms (curry (op @) o get_indices) t []
    1.51 -            |> maps (maps #ctr_args o nth funs_data);
    1.52 -          val (f', args') = strip_comb t;
    1.53 -          val fun_arg_idx = find_index (exists_subterm (not o null o get_indices)) args';
    1.54 -          val arg_idx = find_index (exists_subterm (equal y)) args';
    1.55 -          val (f, args) = chop (arg_idx + 1) args' |>> curry list_comb f';
    1.56 -          val _ = fun_arg_idx < 0 orelse arg_idx >= 0 orelse
    1.57 -            exists (exists_subterm (member (op =) ctr_args)) args' orelse
    1.58 -            primrec_error_eqn "recursive call not applied to constructor argument" t;
    1.59 +          val is_ctr_arg = exists (exists (exists (equal y) o #ctr_args)) funs_data;
    1.60 +          val maybe_direct_y' = AList.lookup (op =) direct_calls y;
    1.61 +          val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
    1.62 +          val (g_head, g_args) = strip_comb g;
    1.63          in
    1.64 -          if fun_arg_idx <> arg_idx andalso fun_arg_idx >= 0 andalso arg_idx >= 0 then
    1.65 -            if nth args' arg_idx = y then
    1.66 -              list_comb (massage (rewrite_map_arg funs_data get_indices y) bound_Ts f, args)
    1.67 -            else
    1.68 -              primrec_error_eqn "recursive call not directly applied to constructor argument" f
    1.69 +          if not is_ctr_arg then
    1.70 +            pairself (subst bound_Ts) (g, y) |> (op $)
    1.71 +          else if contains_fun g_head then
    1.72 +            (length g_args >= the (funs_data |> get_first (fn {fun_name, left_args, ...} :: _ =>
    1.73 +              if fst (dest_Free g_head) = fun_name then SOME (length left_args) else NONE)) (*###*)
    1.74 +                orelse primrec_error_eqn "too few arguments in recursive call" t;
    1.75 +            list_comb (the maybe_direct_y', g_args))
    1.76 +          else if is_some maybe_indirect_y' then
    1.77 +            (if contains_fun g then t else y)
    1.78 +            |> massage_indirect_rec_call lthy contains_fun
    1.79 +              (rewrite_map_arg funs_data get_indices y) bound_Ts y (the maybe_indirect_y')
    1.80 +            |> (if contains_fun g then I else curry (op $) g)
    1.81            else
    1.82 -            list_comb (f', map (subst bound_Ts) args')
    1.83 +            t
    1.84          end
    1.85 -      | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
    1.86 -      | subst bound_Ts t = t |> t = y ? massage (K I |> K) bound_Ts;
    1.87 -  in subst [] end;
    1.88 +      | subst _ t = t
    1.89 +  in
    1.90 +    subst [] t
    1.91 +    |> (fn u => ((contains_fun u andalso (* FIXME detect this case earlier *)
    1.92 +      primrec_error_eqn "recursive call not directly applied to constructor argument" t); u))
    1.93 +  end;
    1.94  
    1.95  fun build_rec_arg lthy get_indices funs_data ctr_spec maybe_eqn_data =
    1.96 -  if is_some maybe_eqn_data then
    1.97 +  if is_none maybe_eqn_data then Const (@{const_name undefined}, dummyT) else
    1.98      let
    1.99        val eqn_data = the maybe_eqn_data;
   1.100        val t = #rhs_term eqn_data;
   1.101 @@ -241,17 +220,12 @@
   1.102        val direct_calls = map (apfst (nth ctr_args) o apsnd (nth args)) direct_calls';
   1.103        val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';
   1.104  
   1.105 -      val get_idx = (fn Free (v, _) => find_index (equal v o #fun_name o hd) funs_data | _ => ~1);
   1.106 -
   1.107 -      val t' = t
   1.108 -        |> fold (subst_indirect_call lthy funs_data get_indices) indirect_calls
   1.109 -        |> subst_direct_calls get_idx (length o #left_args o hd o nth funs_data) direct_calls;
   1.110 -
   1.111        val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data);
   1.112      in
   1.113 -      t' |> fold_rev absfree abstractions
   1.114 -    end
   1.115 -  else Const (@{const_name undefined}, dummyT)
   1.116 +      t
   1.117 +      |> subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls
   1.118 +      |> fold_rev absfree abstractions
   1.119 +    end;
   1.120  
   1.121  fun build_defs lthy bs mxs funs_data rec_specs get_indices =
   1.122    let