src/HOL/Library/simps_case_conv.ML
author noschinl
Fri, 06 Sep 2013 10:56:40 +0200
changeset 54566 9d9945941eab
parent 54563 92db671e0ac6
child 54570 3b356b7f7cad
permissions -rw-r--r--
allowed less exhaustive patterns
noschinl@54563
     1
(*  Title:      HOL/Library/simps_case_conv.ML
noschinl@54563
     2
    Author:     Lars Noschinski, TU Muenchen
noschinl@54563
     3
                Gerwin Klein, NICTA
noschinl@54563
     4
noschinl@54563
     5
  Converts function specifications between the representation as
noschinl@54563
     6
  a list of equations (with patterns on the lhs) and a single
noschinl@54563
     7
  equation (with a nested case expression on the rhs).
noschinl@54563
     8
*)
noschinl@54563
     9
noschinl@54563
    10
signature SIMPS_CASE_CONV =
sig
  (* merge equations f p_i1 ... p_in = t_i into a single equation whose
     rhs is a (nested) case expression *)
  val to_case: Proof.context -> thm list -> thm
  (* split the case expressions on the rhs of an equation using the given
     split rules, returning the resulting theorems *)
  val gen_to_simps: Proof.context -> thm list -> thm -> thm list
  (* gen_to_simps with the split rules of the datatypes occurring in the
     type of the function on the lhs *)
  val to_simps: Proof.context -> thm -> thm list
end
noschinl@54563
    16
noschinl@54563
    17
structure Simps_Case_Conv: SIMPS_CASE_CONV =
noschinl@54563
    18
struct
noschinl@54563
    19
noschinl@54563
    20
(* Collects all type constructors in a type, in pre-order (a constructor
   before those of its arguments, arguments left to right); duplicates are
   kept. *)
fun collect_Tcons T =
  let
    fun add (Type (c, args)) acc = c :: fold_rev add args acc
      | add _ acc = acc
  in add T [] end
noschinl@54563
    24
noschinl@54563
    25
(* Split rules of all datatypes whose type constructor occurs in T
   (type constructors without datatype information are skipped) *)
fun get_split_ths thy T =
  collect_Tcons T
  |> distinct (op =)
  |> map_filter (Datatype_Data.get_info thy)
  |> map #split

(* Destructs a theorem with proposition Trueprop (lhs = rhs) into (lhs, rhs);
   raises TERM for any other shape *)
fun strip_eq thm = HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of thm))
noschinl@54563
    31
noschinl@54563
    32
noschinl@54563
    33
local
noschinl@54563
    34
noschinl@54566
    35
  (* Turns a list of rows into the list of its columns.  A leading empty row
     is dropped; rows are otherwise expected to have equal length (hd raises
     Empty if a later row runs out first). *)
  fun transpose xss =
    (case xss of
      [] => []
    | [] :: rest => transpose rest
    | _ => map hd xss :: transpose (map tl xss));
noschinl@54566
    38
noschinl@54566
    39
  (* If the first term is an application and every term has the same head
     (as computed by strip_comb) as the first one, returns that head together
     with the argument list of each term; NONE otherwise. *)
  fun same_fun (ts as (_ $ _) :: _) =
        let
          val stripped = map strip_comb ts
          val heads = map fst stripped
          val f0 = hd heads
        in
          if exists (fn h => h <> f0) heads then NONE
          else SOME (f0, map snd stripped)
        end
    | same_fun _ = NONE
noschinl@54566
    45
noschinl@54566
    46
  (*
    split_pat pats ctxt -- pats lists, for one argument position, the pattern
    of every equation; it must be non-empty.
      - If the patterns share a common head (see same_fun), the head is kept
        and its arguments are split recursively by split_pats.
      - Otherwise the position becomes a fresh variable (a variant of "x"
        fixed in ctxt) and each equation keeps its pattern as case pattern.
    Returns (((generalized pattern, fresh variables), case patterns per
    equation), extended context).
  *)
  fun split_pat pats ctxt =
    (case same_fun pats of
      SOME (f, argss) =>
        let
          val (((arg_pats, arg_frees), sub_patss), ctxt') = split_pats argss ctxt
        in (((list_comb (f, arg_pats), flat arg_frees), sub_patss), ctxt') end
    | NONE =>
        let
          val (x, ctxt') = yield_singleton Variable.variant_fixes "x" ctxt
          val v = Free (x, fastype_of (hd pats))
        in (((v, [v]), map (fn p => [p]) pats), ctxt') end)

  (*
    split_pats patss ctxt -- patss holds the argument patterns of each
    equation (one list per equation).  Each argument position (column) is
    handled by split_pat, threading the context from left to right; the
    resulting case patterns are then regrouped per equation.
  *)
  and split_pats patss ctxt =
    let
      val (columns, ctxt') = fold_map split_pat (transpose patss) ctxt
      val (defs, case_cols) = split_list columns
      val (def_pats, def_frees) = split_list defs
    in (((def_pats, def_frees), map flat (transpose case_cols)), ctxt') end
noschinl@54566
    66
noschinl@54566
    67
(*
noschinl@54566
    68
  Takes a list lhss of left hand sides (which are lists of patterns)
noschinl@54566
    69
  and a list rhss of right hand sides. Returns, paired with the extended context,
noschinl@54566
    70
    - a single equation with a (nested) case-expression on the rhs
noschinl@54566
    71
    - a list of all split-thms needed to split the rhs
noschinl@54566
    72
  Patterns which have the same outer context in all lhss remain
noschinl@54566
    73
  on the lhs of the computed equation.
noschinl@54566
    74
*)
noschinl@54566
    75
fun build_case_t fun_t lhss rhss ctxt =
  let
    val (((lhs_args, arg_frees), case_patss), ctxt') = split_pats lhss ctxt
    (* the variables that replaced differing argument positions form the
       scrutinee; each equation contributes one tuple pattern *)
    val scrutinee = HOLogic.mk_tuple (flat arg_frees)
    val clauses = map HOLogic.mk_tuple case_patss ~~ rhss
    val case_expr =
      Case_Translation.make_case ctxt' Case_Translation.Warning Name.context scrutinee clauses
    val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (list_comb (fun_t, lhs_args), case_expr))
    val splits = get_split_ths (Proof_Context.theory_of ctxt') (fastype_of scrutinee)
  in ((eq, splits), ctxt') end
noschinl@54563
    88
noschinl@54563
    89
(* Proof tactic for the case equation: split repeatedly (at least once) on
   the first goal where a split applies, then unfold defs, then run
   safe_tac with intros added as safe introduction rules. *)
fun tac ctxt {splits, intros, defs} =
  let
    val safe_ctxt = Classical.addSIs (ctxt, intros)
    val split_all = REPEAT_DETERM1 (FIRSTGOAL (split_tac splits))
  in
    split_all
    THEN Local_Defs.unfold_tac ctxt defs
    THEN safe_tac safe_ctxt
  end
noschinl@54563
    95
noschinl@54563
    96
(*
  import thms ctxt -- fixes the variables of thms in ctxt (Variable.import
  true).  Beforehand, every equation after the first is instantiated (via
  Thm.match) so that the head of its lhs matches the head of the first
  equation's lhs.
*)
fun import thms ctxt =
  (case thms of
    [] => ([], ctxt)
  | first :: rest =>
      let
        val thy = Proof_Context.theory_of ctxt
        fun head_ct th = Thm.cterm_of thy (Logic.mk_term (head_of (fst (strip_eq th))))
        val target = head_ct first
        val heads = map head_ct rest
        val aligned =
          map (fn (th, ct) => Thm.instantiate (Thm.match (ct, target)) th) (rest ~~ heads)
      in Variable.import true (first :: aligned) ctxt |> apfst snd end)
noschinl@54563
   106
noschinl@54563
   107
in
noschinl@54563
   108
noschinl@54563
   109
(*
noschinl@54563
   110
  For a list
noschinl@54563
   111
    f p_11 ... p_1n = t1
noschinl@54563
   112
    f p_21 ... p_2n = t2
noschinl@54563
   113
    ...
noschinl@54563
   114
    f p_m1 ... p_mn = tm
noschinl@54563
   115
  of theorems, prove a single theorem
noschinl@54563
   116
    f x1 ... xn = t
noschinl@54566
   117
  where t is a (nested) case expression. f must not be a function
noschinl@54566
   118
  application. Moreover, the terms p_11, ..., p_mn must be non-overlapping
noschinl@54566
   119
  datatype patterns. The patterns must be exhaustive up to common constructor
noschinl@54566
   120
  contexts.
noschinl@54563
   121
*)
noschinl@54563
   122
fun to_case ctxt ths =
  let
    (* without equations there is no function to build a case expression for;
       fail with a readable message instead of List.Empty from hd below *)
    val _ = if null ths then error "to_case: empty list of equations" else ()

    (* fix the variables of all equations, after matching each head against
       the head of the first equation *)
    val (iths, ctxt') = import ths ctxt
    val fun_t = hd iths |> strip_eq |> fst |> head_of
    (* (argument patterns, rhs) of every equation *)
    val eqs = map (strip_eq #> apfst (snd o strip_comb)) iths

    (* defines a local constant for rhs abstracted over the free variables of
       pat; returns that constant applied to those variables plus its
       definition.  tac unfolds these definitions only after splitting, so
       the rhss stay opaque during the split steps. *)
    fun hide_rhs ((pat, rhs), name) lthy =
      let
        val frees = fold Term.add_frees pat []
        val abs_rhs = fold absfree frees rhs
        val ((f, def), lthy') = Local_Defs.add_def
          ((Binding.name name, Mixfix.NoSyn), abs_rhs) lthy
      in ((list_comb (f, map Free (rev frees)), def), lthy') end

    val ((def_ts, def_thms), ctxt2) =
      let
        val nctxt = Variable.names_of ctxt'
        val names = Name.invent nctxt "rhs" (length eqs)
      in fold_map hide_rhs (eqs ~~ names) ctxt' |> apfst split_list end

    val ((t, split_thms), ctxt3) = build_case_t fun_t (map fst eqs) def_ts ctxt2

    (* the original (schematic) equations serve as safe intro rules *)
    val th = Goal.prove ctxt3 [] [] t (fn {context = ctxt, ...} =>
      tac ctxt {splits = split_thms, intros = ths, defs = def_thms})
  in
    (* export back to the caller's context, eliminating the local defs *)
    th
    |> singleton (Proof_Context.export ctxt3 ctxt)
    |> Goal.norm_result
  end
noschinl@54563
   148
noschinl@54563
   149
end
noschinl@54563
   150
noschinl@54563
   151
local
noschinl@54563
   152
noschinl@54563
   153
(* Checks that t is Trueprop of a conjunction each of whose conjuncts,
   below any leading HOL.All binders, is an implication whose premise is an
   equation with a Free on the left.  Any other shape (TERM raised while
   destructing) yields false. *)
fun was_split t =
  let
    fun strip_alls (Const ("HOL.All", _) $ Abs (_, _, body)) = strip_alls body
      | strip_alls u = u
    fun guarded_by_free u =
      let
        val (premise, _) = HOLogic.dest_imp u
        val (lhs, _) = HOLogic.dest_eq premise
      in is_Free lhs end
    val conjuncts = HOLogic.dest_conj (HOLogic.dest_Trueprop t)
  in List.all (guarded_by_free o strip_alls) conjuncts end
  handle TERM _ => false
noschinl@54563
   162
noschinl@54563
   163
(* Resolves thm (imported, i.e. with its variables fixed) with split, keeps
   only results of the shape accepted by was_split and exports them back to
   ctxt. *)
fun apply_split ctxt split thm =
  let
    val ((_, imported), ctxt') = Variable.import false [thm] ctxt
    val candidates = imported RL [split]
    val split_results = filter (was_split o prop_of) candidates
  in Seq.of_list (Variable.export ctxt' ctxt split_results) end
noschinl@54563
   167
noschinl@54563
   168
(* Forward step on a single theorem: the sequence of all resolvents of t
   with rules (empty, i.e. failure, if none).  NB: shadows the standard
   forward_tac within this local block. *)
fun forward_tac rules t =
  let val resolvents = [t] RL rules
  in Seq.of_list resolvents end

(* (?t = ?t --> ?Q) ==> ?Q, obtained by resolving refl into mp's second premise *)
val refl_imp = refl RSN (2, mp)

(* From the result of a split: take the conjunctions apart, instantiate the
   universal quantifiers, and discharge the equational premise of each
   implication by reflexivity. *)
val get_rules_once_split =
  let
    val split_conjs = REPEAT (forward_tac [conjunct1, conjunct2])
    val inst_alls = REPEAT (forward_tac [spec])
    val discharge_premise = forward_tac [refl_imp]
  in split_conjs THEN inst_alls THEN discharge_premise end
noschinl@54563
   176
noschinl@54563
   177
(* Forward step applying one split rule (used in direction iffD1).  The rule
   is checked up front: its conclusion, with variables fixed, must pass
   was_split; otherwise TERM "malformed split rule" is raised. *)
fun do_split ctxt split =
  let
    val split_fwd = split RS iffD1
    val ((_, fixed_rules), _) = Variable.import false [split_fwd] ctxt
    val split_rhs = concl_of (hd fixed_rules)
  in
    if not (was_split split_rhs) then raise TERM ("malformed split rule", [split_rhs])
    else DETERM (apply_split ctxt split_fwd) THEN get_rules_once_split
  end

(* turns a meta equation into an object equation; fails on other theorems *)
val atomize_meta_eq = forward_tac [meta_eq_to_obj_eq]
noschinl@54563
   187
noschinl@54563
   188
in
noschinl@54563
   189
noschinl@54563
   190
(* All theorems obtained from thm (an object equation, or a meta equation
   that is atomized first) by repeatedly applying the rules splitthms. *)
fun gen_to_simps ctxt splitthms thm =
  let
    val split_steps = map (do_split ctxt) splitthms
    val simps_tac = TRY atomize_meta_eq THEN REPEAT (FIRST split_steps)
  in Seq.list_of (simps_tac thm) end
noschinl@54563
   193
noschinl@54563
   194
(* gen_to_simps with the split rules of all datatypes occurring in the type
   of the function on the lhs of thm. *)
fun to_simps ctxt thm =
  let
    (* gen_to_simps accepts meta equations (f x == t) as well as object
       equations (f x = t) -- it atomizes the former.  strip_eq alone only
       handles the latter, so look for a meta equation first. *)
    val lhs = fst (Logic.dest_equals (prop_of thm))
      handle TERM _ => fst (strip_eq thm)
    val T = fastype_of (head_of lhs)
    val splitthms = get_split_ths (Proof_Context.theory_of ctxt) T
  in gen_to_simps ctxt splitthms thm end
noschinl@54563
   199
noschinl@54563
   200
noschinl@54563
   201
end
noschinl@54563
   202
noschinl@54563
   203
(* case_of_simps: merge the given equations with to_case and note the
   result under the given name and attributes *)
fun case_of_simps_cmd ((name, srcs), thms_ref) lthy =
  let
    val attribs = map (Attrib.intern_src (Proof_Context.theory_of lthy)) srcs
    val eqs = Attrib.eval_thms lthy thms_ref
    val case_thm = to_case lthy eqs
    val (_, lthy') = Local_Theory.note ((name, attribs), [case_thm]) lthy
  in lthy' end
noschinl@54563
   211
noschinl@54563
   212
(* simps_of_case: split the given theorem (with the explicit split rules if
   any are given, otherwise via to_simps) and note the resulting theorems *)
fun simps_of_case_cmd (((name, srcs), thm_ref), splits_ref) lthy =
  let
    val attribs = map (Attrib.intern_src (Proof_Context.theory_of lthy)) srcs
    val thm = singleton (Attrib.eval_thms lthy) thm_ref
    val simps =
      (case splits_ref of
        [] => to_simps lthy thm
      | _ => gen_to_simps lthy (Attrib.eval_thms lthy splits_ref) thm)
    val (_, lthy') = Local_Theory.note ((name, attribs), simps) lthy
  in lthy' end
noschinl@54563
   223
noschinl@54563
   224
(* command: case_of_simps name: eqs *)
val _ =
  let
    val parser = Parse_Spec.opt_thm_name ":" -- Parse_Spec.xthms1
  in
    Outer_Syntax.local_theory @{command_spec "case_of_simps"}
      "turns a list of equations into a case expression"
      (parser >> case_of_simps_cmd)
  end

(* optional argument "(splits: thms)" of simps_of_case *)
val parse_splits =
  @{keyword "("} |-- Parse.reserved "splits" |-- @{keyword ":"} |--
    Parse_Spec.xthms1 --| @{keyword ")"}

(* command: simps_of_case name: thm (splits: thms)? *)
val _ =
  let
    val parser =
      Parse_Spec.opt_thm_name ":" -- Parse_Spec.xthm -- Scan.optional parse_splits []
  in
    Outer_Syntax.local_theory @{command_spec "simps_of_case"}
      "perform case split on rule"
      (parser >> simps_of_case_cmd)
  end
noschinl@54563
   237
noschinl@54563
   238
end
noschinl@54563
   239