towards proper handling of argument order in comprehensions
authorhaftmann
Thu, 30 Jul 2009 13:52:18 +0200
changeset 32341c8c17c2e6ceb
parent 32340 b4632820e74c
child 32342 3fabf5b5fc83
towards proper handling of argument order in comprehensions
src/HOL/ex/predicate_compile.ML
     1.1 --- a/src/HOL/ex/predicate_compile.ML	Thu Jul 30 13:52:18 2009 +0200
     1.2 +++ b/src/HOL/ex/predicate_compile.ML	Thu Jul 30 13:52:18 2009 +0200
     1.3 @@ -82,9 +82,9 @@
     1.4    | dest_tuple (Const (@{const_name Pair}, _) $ t1 $ t2) = t1 :: (dest_tuple t2)
     1.5    | dest_tuple t = [t]
     1.6  
     1.7 -fun mk_pred_enumT T = Type (@{type_name "Predicate.pred"}, [T])
     1.8 +fun mk_pred_enumT T = Type (@{type_name Predicate.pred}, [T])
     1.9  
    1.10 -fun dest_pred_enumT (Type (@{type_name "Predicate.pred"}, [T])) = T
    1.11 +fun dest_pred_enumT (Type (@{type_name Predicate.pred}, [T])) = T
    1.12    | dest_pred_enumT T = raise TYPE ("dest_pred_enumT", [T], []);
    1.13  
    1.14  fun mk_Enum f =
    1.15 @@ -119,6 +119,10 @@
    1.16  fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT
    1.17    in Const (@{const_name Predicate.not_pred}, T --> T) $ t end
    1.18  
    1.19 +fun mk_pred_map T1 T2 tf tp = Const (@{const_name Predicate.map},
    1.20 +  (T1 --> T2) --> mk_pred_enumT T1 --> mk_pred_enumT T2) $ tf $ tp;
    1.21 +
    1.22 +
    1.23  (* destruction of intro rules *)
    1.24  
    1.25  (* FIXME: look for other place where this functionality was used before *)
    1.26 @@ -383,7 +387,7 @@
    1.27  
    1.28  fun get_args is ts = let
    1.29    fun get_args' _ _ [] = ([], [])
    1.30 -    | get_args' is i (t::ts) = (if i mem is then apfst else apsnd) (cons t)
    1.31 +    | get_args' is i (t::ts) = (if member (op =) is i then apfst else apsnd) (cons t)
    1.32          (get_args' is (i+1) ts)
    1.33  in get_args' is 1 ts end
    1.34  
    1.35 @@ -1527,18 +1531,17 @@
    1.36  
    1.37  val eval_ref = ref (NONE : (unit -> term Predicate.pred) option);
    1.38  
    1.39 +(*FIXME turn this into an LCF-guarded preprocessor for comprehensions*)
    1.40  fun analyze_compr thy t_compr =
    1.41    let
    1.42      val split = case t_compr of (Const (@{const_name Collect}, _) $ t) => t
    1.43        | _ => error ("Not a set comprehension: " ^ Syntax.string_of_term_global thy t_compr);
    1.44      val (body, Ts, fp) = HOLogic.strip_splits split;
    1.45 -      (*FIXME former order of tuple positions must be restored*)
    1.46 -    val (pred as Const (name, T), all_args) = strip_comb body
    1.47 -    val (params, args) = chop (nparams_of thy name) all_args
    1.48 +    val (pred as Const (name, T), all_args) = strip_comb body;
    1.49 +    val (params, args) = chop (nparams_of thy name) all_args;
    1.50      val user_mode = map_filter I (map_index
    1.51        (fn (i, t) => case t of Bound j => if j < length Ts then NONE
    1.52 -        else SOME (i+1) | _ => SOME (i+1)) args) (*FIXME dangling bounds should not occur*)
    1.53 -    val (inargs, _) = get_args user_mode args;
    1.54 +        else SOME (i+1) | _ => SOME (i+1)) args); (*FIXME dangling bounds should not occur*)
    1.55      val modes = filter (fn Mode (_, is, _) => is = user_mode)
    1.56        (modes_of_term (all_modes_of thy) (list_comb (pred, params)));
    1.57      val m = case modes
    1.58 @@ -1547,9 +1550,63 @@
    1.59        | [m] => m
    1.60        | m :: _ :: _ => (warning ("Multiple modes possible for comprehension "
    1.61                  ^ Syntax.string_of_term_global thy t_compr); m);
    1.62 -    val t_eval = list_comb (compile_expr thy (all_modes_of thy) (SOME m, list_comb (pred, params)),
    1.63 -      inargs)
    1.64 +    val (inargs, outargs) = get_args user_mode args;
    1.65 +    val t_pred = list_comb (compile_expr thy (all_modes_of thy) (SOME m, list_comb (pred, params)),
    1.66 +      inargs);
    1.67 +    val t_eval = if null outargs then t_pred else let
    1.68 +        val outargs_bounds = map (fn Bound i => i) outargs;
    1.69 +        val outargsTs = map (nth Ts) outargs_bounds;
    1.70 +        val T_pred = mk_tupleT outargsTs;
    1.71 +        val T_compr = HOLogic.mk_tupleT fp Ts;
    1.72 +        val arrange_bounds = map_index I outargs_bounds
    1.73 +          |> sort (prod_ord (K EQUAL) int_ord)
    1.74 +          |> map fst;
    1.75 +        val arrange = funpow (length outargs_bounds - 1) HOLogic.mk_split
    1.76 +          (Term.list_abs (map (pair "") outargsTs,
    1.77 +            HOLogic.mk_tuple fp T_compr (map Bound arrange_bounds)))
    1.78 +      in mk_pred_map T_pred T_compr arrange t_pred end
    1.79    in t_eval end;
    1.80  
    1.81 +fun eval thy t_compr =
    1.82 +  let
    1.83 +    val t = analyze_compr thy t_compr;
    1.84 +    val T = dest_pred_enumT (fastype_of t);
    1.85 +    val t' = mk_pred_map T HOLogic.termT (HOLogic.term_of_const T) t;
    1.86 +  in (T, Code_ML.eval NONE ("Predicate_Compile.eval_ref", eval_ref) Predicate.map thy t' []) end;
    1.87 +
    1.88 +fun values ctxt k t_compr =
    1.89 +  let
    1.90 +    val thy = ProofContext.theory_of ctxt;
    1.91 +    val (T, t) = eval thy t_compr;
    1.92 +    val setT = HOLogic.mk_setT T;
    1.93 +    val (ts, _) = Predicate.yieldn k t;
    1.94 +    val elemsT = HOLogic.mk_set T ts;
    1.95 +  in if k = ~1 orelse length ts < k then elemsT
    1.96 +    else Const (@{const_name Set.union}, setT --> setT --> setT) $ elemsT $ t_compr
    1.97 +  end;
    1.98 +
    1.99 +fun values_cmd modes k raw_t state =
   1.100 +  let
   1.101 +    val ctxt = Toplevel.context_of state;
   1.102 +    val t = Syntax.read_term ctxt raw_t;
   1.103 +    val t' = values ctxt k t;
   1.104 +    val ty' = Term.type_of t';
   1.105 +    val ctxt' = Variable.auto_fixes t' ctxt;
   1.106 +    val p = PrintMode.with_modes modes (fn () =>
   1.107 +      Pretty.block [Pretty.quote (Syntax.pretty_term ctxt' t'), Pretty.fbrk,
   1.108 +        Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt' ty')]) ();
   1.109 +  in Pretty.writeln p end;
   1.110 +
   1.111 +local structure P = OuterParse in
   1.112 +
   1.113 +val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
   1.114 +
   1.115 +val _ = OuterSyntax.improper_command "values" "enumerate and print comprehensions" OuterKeyword.diag
   1.116 +  (opt_modes -- Scan.optional P.nat ~1 -- P.term
   1.117 +    >> (fn ((modes, k), t) => Toplevel.no_timing o Toplevel.keep
   1.118 +        (values_cmd modes k t)));
   1.119 +
   1.120  end;
   1.121  
   1.122 +end;
   1.123 +