src/HOL/Tools/function_package/fundef_common.ML
author krauss
Fri, 25 Apr 2008 16:28:06 +0200
changeset 26749 397a1aeede7d
parent 26748 4d51ddd6aa5c
child 26989 9b2acb536228
permissions -rw-r--r--
* New attribute "termination_simp": Simp rules for termination proofs
* General lemmas about list_size
     1 (*  Title:      HOL/Tools/function_package/fundef_common.ML
     2     ID:         $Id$
     3     Author:     Alexander Krauss, TU Muenchen
     4 
     5 A package for general recursive function definitions. 
     6 Common definitions and other infrastructure.
     7 *)
     8 
     9 structure FundefCommon =
    10 struct
    11 
    12 local open FundefLib in
    13 
    14 (* Profiling: while the flag profile is set, PROFILE msg wraps a function with
       timeap_msg msg (presumably reporting its run time under msg); otherwise
       PROFILE msg is the identity I. *)
    15 val profile = ref false;
    16 
    17 fun PROFILE msg = if !profile then timeap_msg msg else I
    18 
    19 
       (* mk_acc domT R builds the term  accp R : R is a relation of type
          domT => domT => bool and the result is a predicate of type domT => bool. *)
    20 val acc_const_name = @{const_name "accp"}
    21 fun mk_acc domT R =
    22     Const (acc_const_name, (domT --> domT --> HOLogic.boolT) --> domT --> HOLogic.boolT) $ R 
    23 
       (* Naming scheme for auxiliary constants: each helper appends a fixed
          suffix to the given base name. *)
    24 val function_name = suffix "C"
    25 val graph_name = suffix "_graph"
    26 val rel_name = suffix "_rel"
    27 val dom_name = suffix "_dom"
    28 
    29 
       (* Record of the logical artefacts produced for one function definition.
          trsimps and domintros are options; presumably NONE when the corresponding
          rules were not generated -- confirm at the producer. *)
    30 datatype fundef_result =
    31   FundefResult of
    32      {
    33       fs: term list,
    34       G: term,
    35       R: term,
    36 
    37       psimps : thm list, 
    38       trsimps : thm list option, 
    39 
    40       simple_pinducts : thm list, 
    41       cases : thm,
    42       termination : thm,
    43       domintros : thm list option
    44      }
    45 
    46 
       (* Per-definition data kept in the generic context (see FundefData below).
          Under a morphism, defname, fs, R and the theorems are transformed by
          morph_fundef_data; add_simps and case_names are carried over unchanged. *)
    47 datatype fundef_context_data =
    48   FundefCtxData of
    49      {
    50       defname : string,
    51 
    52       (* contains no logical entities: invariant under morphisms *)
    53       add_simps : (string -> string) -> string -> Attrib.src list -> thm list 
    54                   -> local_theory -> thm list * local_theory,
    55       case_names : string list,
    56 
    57       fs : term list,
    58       R : term,
    59       
    60       psimps: thm list,
    61       pinducts: thm list,
    62       termination: thm
    63      }
    64 
       (* Transform a fundef_context_data record under morphism phi: terms through
          Morphism.term, theorem lists through Morphism.fact, the termination rule
          through Morphism.thm and defname through Morphism.name. *)
    65 fun morph_fundef_data (FundefCtxData {add_simps, case_names, fs, R, 
    66                                       psimps, pinducts, termination, defname}) phi =
    67     let
    68       val term = Morphism.term phi val thm = Morphism.thm phi val fact = Morphism.fact phi
    69       val name = Morphism.name phi
    70     in
    71       FundefCtxData { add_simps = add_simps, case_names = case_names,
    72                       fs = map term fs, R = term R, psimps = fact psimps, 
    73                       pinducts = fact pinducts, termination = thm termination,
    74                       defname = name defname }
    75     end
    76 
       (* Generic context data: a net of (function term, data) pairs.  Two entries
          count as equal when their terms are alpha-equivalent (aconv); fst is passed
          to NetRules.init as the key extractor, so entries are indexed by the
          function term.  Merging delegates to NetRules.merge. *)
    77 structure FundefData = GenericDataFun
    78 (
    79   type T = (term * fundef_context_data) NetRules.T;
    80   val empty = NetRules.init
    81     (op aconv o pairself fst : (term * fundef_context_data) * (term * fundef_context_data) -> bool)
    82     fst;
    83   val copy = I;
    84   val extend = I;
    85   fun merge _ (tab1, tab2) = NetRules.merge (tab1, tab2)
    86 );
    87 
    88 
    89 (* Lift a theorem transformation f to a morphism that also acts on terms
       (through Drule.term_rule) and on types (through Logic.type_map).
       NOTE(review): the original author asked whether this is generally useful,
       i.e. whether it belongs in a shared library. *)
    90 fun lift_morphism thy f = 
    91     let 
    92       val term = Drule.term_rule thy f
    93     in
    94       Morphism.thm_morphism f $> Morphism.term_morphism term $> Morphism.typ_morphism (Logic.type_map term)
    95     end
    96 
       (* Look up stored data for the term t in the generic context ctxt.  Candidates
          are retrieved from the net; for the first candidate whose stored term
          matches t (Thm.match), its data is instantiated with the matching
          substitution via lift_morphism.  Returns NONE when no candidate matches. *)
    97 fun import_fundef_data t ctxt =
    98     let
    99       val thy = Context.theory_of ctxt
   100       val ct = cterm_of thy t
   101       val inst_morph = lift_morphism thy o Thm.instantiate 
   102 
   103       fun match (trm, data) = 
   104           SOME (morph_fundef_data data (inst_morph (Thm.match (cterm_of thy trm, ct))))
   105           handle Pattern.MATCH => NONE
   106     in 
   107       get_first match (NetRules.retrieve (FundefData.get ctxt) t)
   108     end
   109 
       (* Import the head of NetRules.rules of the stored data (presumably the most
          recently added definition -- depends on NetRules ordering, verify).  Its
          term is first imported into the proof context with Variable.import_terms
          and then looked up with import_fundef_data.  NONE if nothing is stored. *)
   110 fun import_last_fundef ctxt =
   111     case NetRules.rules (FundefData.get ctxt) of
   112       [] => NONE
   113     | (t, data) :: _ =>
   114       let 
   115         val ([t'], ctxt') = Variable.import_terms true [t] (Context.proof_of ctxt)
   116       in
   117         import_fundef_data t' (Context.Proof ctxt')
   118       end
   119 
       (* All stored (function term, data) pairs. *)
   120 val all_fundef_data = NetRules.rules o FundefData.get
   121 
       (* Named theorem collection backing the termination_simp attribute. *)
   122 structure TerminationSimps = NamedThmsFun(
   123   val name = "termination_simp" 
   124   val description = "Simplification rule for termination proofs"
   125 );
   126 
       (* Generic context data: the list of known termination rules.
          store_termination_rule prepends a rule; apply_termination_rule, given a
          proof context, is resolve_tac with all stored rules. *)
   127 structure TerminationRule = GenericDataFun
   128 (
   129   type T = thm list
   130   val empty = []
   131   val extend = I
   132   fun merge _ = Thm.merge_thms
   133 );
   134 
   135 val get_termination_rules = TerminationRule.get
   136 val store_termination_rule = TerminationRule.map o cons
   137 val apply_termination_rule = resolve_tac o get_termination_rules o Context.Proof
   138 
       (* Register data under every function term in fs and also store its
          termination rule. *)
   139 fun add_fundef_data (data as FundefCtxData {fs, termination, ...}) =
   140     FundefData.map (fold (fn f => NetRules.insert (f, data)) fs)
   141     #> store_termination_rule termination
   142 
   143 (* Configuration management *)
       (* A fundef_opt is one user-supplied option; apply_opt sets the matching
          field of a fundef_config and copies all other fields unchanged. *)
   144 datatype fundef_opt 
   145   = Sequential
   146   | Default of string
   147   | Target of xstring
   148   | DomIntros
   149   | Tailrec
   150 
   151 datatype fundef_config
   152   = FundefConfig of
   153    {
   154     sequential: bool,
   155     default: string,
   156     target: xstring option,
   157     domintros: bool,
   158     tailrec: bool
   159    }
   160 
   161 fun apply_opt Sequential (FundefConfig {sequential, default, target, domintros,tailrec}) = 
   162     FundefConfig {sequential=true, default=default, target=target, domintros=domintros, tailrec=tailrec}
   163   | apply_opt (Default d) (FundefConfig {sequential, default, target, domintros,tailrec}) = 
   164     FundefConfig {sequential=sequential, default=d, target=target, domintros=domintros, tailrec=tailrec}
   165   | apply_opt (Target t) (FundefConfig {sequential, default, target, domintros,tailrec}) =
   166     FundefConfig {sequential=sequential, default=default, target=SOME t, domintros=domintros, tailrec=tailrec}
   167   | apply_opt DomIntros (FundefConfig {sequential, default, target, domintros,tailrec}) =
   168     FundefConfig {sequential=sequential, default=default, target=target, domintros=true,tailrec=tailrec}
   169   | apply_opt Tailrec (FundefConfig {sequential, default, target, domintros,tailrec}) =
   170     FundefConfig {sequential=sequential, default=default, target=target, domintros=domintros,tailrec=true}
   171 
   172 fun target_of (FundefConfig {target, ...}) = target
   173 
       (* Defaults: not sequential, default value  %x. arbitrary , no target,
          no domintros, no tailrec. *)
   174 val default_config = FundefConfig { sequential=false, default="%x. arbitrary", 
   175                                     target=NONE, domintros=false, tailrec=false }
   176 
   177 
   178 (* Common operations on equations *)
   179 
       (* Strip all outermost meta-level quantifiers (Const all applied to an Abs).
          Returns the bound (name, type) pairs, outermost first, together with the
          remaining body, in which those variables are still loose Bounds. *)
   180 fun open_all_all (Const ("all", _) $ Abs (n, T, b)) = apfst (cons (n, T)) (open_all_all b)
   181   | open_all_all t = ([], t)
   182 
       (* Split a specification equation  !!qs. gs ==> f args = rhs  into
          (fname, qs, gs, args, rhs).  Raises error if the conclusion is not a HOL
          equation, or if the head of the left-hand side is not a Free.
          NOTE(review): the second message mentions bound variables, but it is
          raised for any non-Free head (for instance a constant). *)
   183 fun split_def ctxt geq =
   184     let
   185       fun input_error msg = cat_lines [msg, Syntax.string_of_term ctxt geq]
   186       val (qs, imp) = open_all_all geq
   187       val (gs, eq) = Logic.strip_horn imp
   188 
   189       val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
   190           handle TERM _ => error (input_error "Not an equation")
   191 
   192       val (head, args) = strip_comb f_args
   193 
   194       val fname = fst (dest_Free head)
   195           handle TERM _ => error (input_error "Head symbol must not be a bound variable")
   196     in
   197       (fname, qs, gs, args, rhs)
   198     end
   199 
       (* Raised by mk_arities, carrying the name of the offending function. *)
   200 exception ArgumentCount of string
   201 
       (* Build a table mapping each function name to its number of arguments, from
          a list of split equations (as returned by split_def).  Raises
          ArgumentCount fname if two equations for fname disagree. *)
   202 fun mk_arities fqgars =
   203     let fun f (fname, _, _, args, _) arities =
   204             let val k = length args
   205             in
   206               case Symtab.lookup arities fname of
   207                 NONE => Symtab.update (fname, k) arities
   208               | SOME i => (if i = k then arities else raise ArgumentCount fname)
   209             end
   210     in
   211       fold f fqgars Symtab.empty
   212     end
   213 
   214 
   215 (* Check for all sorts of errors in the input *)
       (* Per equation: the head must be one of the fixed names; no variable may
          occur only on the right-hand side; premises must not mention the functions
          being defined; a warning (not an error) is issued when a quantified
          variable is applied to arguments on the left-hand side.  Globally: every
          fixed function type must have sort type, and argument counts must agree.
          Returns unit; violations are reported through error. *)
   216 fun check_defs ctxt fixes eqs =
   217     let
   218       val fnames = map (fst o fst) fixes
   219                                 
   220       fun check geq = 
   221           let
   222             fun input_error msg = cat_lines [msg, Syntax.string_of_term ctxt geq]
   223                                   
   224             val fqgar as (fname, qs, gs, args, rhs) = split_def ctxt geq
   225                                  
   226             val _ = fname mem fnames 
   227                     orelse error (input_error ("Head symbol of left hand side must be " ^ plural "" "one out of " fnames 
   228                                                ^ commas_quote fnames))
   229                                             
                    (* Loose Bound indices of rhs that do not occur in args, mapped back
                       to variable names.  qs is outermost first while Bound 0 is the
                       innermost quantifier, hence the lookup in rev qs. *)
   230             fun add_bvs t is = add_loose_bnos (t, 0, is)
   231             val rvs = (add_bvs rhs [] \\ fold add_bvs args [])
   232                         |> map (fst o nth (rev qs))
   233                       
   234             val _ = null rvs orelse error (input_error ("Variable" ^ plural " " "s " rvs ^ commas_quote rvs
   235                                                         ^ " occur" ^ plural "s" "" rvs ^ " on right hand side only:"))
   236                                     
   237             val _ = forall (not o Term.exists_subterm (fn Free (n, _) => n mem fnames | _ => false)) gs 
   238                     orelse error (input_error "Recursive Calls not allowed in premises")
   239 
                    (* Replace the quantified Bounds in args by Frees so that the
                       quantified variables can be found as subterms. *)
   240             val freeargs = map (fn t => subst_bounds (rev (map Free qs), t)) args
   241             val funvars = filter (fn q => exists (exists_subterm (fn (Free q') $ _ => q = q' | _ => false)) freeargs) qs
                    (* Only a warning: the trailing true keeps this check non-fatal. *)
   242             val _ = null funvars
   243                     orelse (warning (cat_lines ["Bound variable" ^ plural " " "s " funvars ^ commas_quote (map fst funvars) ^  
   244                                                 " occur" ^ plural "s" "" funvars ^ " in function position.",  
   245                                                 "Misspelled constructor???"]); true)
   246           in
   247             fqgar
   248           end
   249           
   250       fun check_sorts ((fname, fT), _) =
   251           Sorts.of_sort (Sign.classes_of (ProofContext.theory_of ctxt)) (fT, HOLogic.typeS)
   252           orelse error ("Type of " ^ quote fname ^ " is not of sort " ^ quote "type" ^ ".")
   253 
              (* map is used only for its effect: check_sorts raises error on failure. *)
   254       val _ = map check_sorts fixes
   255 
   256       val _ = mk_arities (map check eqs)
   257           handle ArgumentCount fname => 
   258                  error ("Function " ^ quote fname ^ " has different numbers of arguments in different equations")
   259     in
   260       ()
   261     end
   262 
   263 (* Preprocessors *)
       (* A preprocessor receives the configuration, a bool list (presumably the
          per-statement (otherwise) flags produced by fundef_parser -- verify at the
          caller), the context, the fixes and the raw specification.  It returns the
          equations, a function regrouping a flat theorem list into the spec shape,
          a function grouping a list parallel to the equations per defined function,
          and the case names. *)
   264 
   265 type fixes = ((string * typ) * mixfix) list
   266 type 'a spec = ((bstring * Attrib.src list) * 'a list) list
   267 type preproc = fundef_config -> bool list -> Proof.context -> fixes -> term spec 
   268                -> (term list * (thm list -> thm spec) * (thm list -> thm list list) * string list)
   269 
       (* Name of the function defined by an equation  !!xs. gs ==> f args = rhs . *)
   270 val fname_of = fst o dest_Free o fst o strip_comb o fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o Logic.strip_imp_concl o snd o dest_all_all
   271 
       (* Case names for k equations of specification item i with base name n: an
          empty n is replaced by i+1; k = 0 gives [], k = 1 gives [n], otherwise
          n_1, ..., n_k. *)
   272 fun mk_case_names i "" k = mk_case_names i (string_of_int (i + 1)) k
   273   | mk_case_names _ n 0 = []
   274   | mk_case_names _ n 1 = [n]
   275   | mk_case_names _ n k = map (fn i => n ^ "_" ^ string_of_int i) (1 upto k)
   276 
       (* The default preprocessor: runs check on the equations and passes them
          through unchanged.  sort groups a list parallel to the equations by the
          position of each equation's function in fixes; the reassembly function
          splits a flat theorem list in the shape of tss and re-attaches the names.
          One case name is produced per specification item. *)
   277 fun empty_preproc check _ _ ctxt fixes spec =
   278     let 
   279       val (nas,tss) = split_list spec
   280       val ts = flat tss
   281       val _ = check ctxt fixes ts
   282       val fnames = map (fst o fst) fixes
   283       val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) ts
   284 
   285       fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
   286                         |> map (map snd)
   287 
   288       (* using theorem names for case name currently disabled *)
   289       val cnames = map_index (fn (i, (n,_)) => mk_case_names i "" 1) nas |> flat
   290     in
   291       (ts, curry op ~~ nas o Library.unflat tss, sort, cnames)
   292     end
   293 
       (* Generic context data holding the active preprocessor; initially
          empty_preproc check_defs.  merge keeps the left argument. *)
   294 structure Preprocessor = GenericDataFun
   295 (
   296   type T = preproc
   297   val empty : T = empty_preproc check_defs
   298   val extend = I
   299   fun merge _ (a, _) = a
   300 );
   301 
   302 val get_preproc = Preprocessor.get o Context.Proof
   303 val set_preproc = Preprocessor.map o K
   304 
   305 
   306 
       (* Outer syntax of function specifications. *)
   307 local 
   308   structure P = OuterParse and K = OuterKeyword
   309 
         (* A single option: sequential | default t | domintros | tailrec | in target. *)
   310   val option_parser = (P.reserved "sequential" >> K Sequential)
   311                    || ((P.reserved "default" |-- P.term) >> Default)
   312                    || (P.reserved "domintros" >> K DomIntros)
   313                    || (P.reserved "tailrec" >> K Tailrec)
   314                    || ((P.$$$ "in" |-- P.xname) >> Target)
   315 
         (* Optional parenthesised, comma-separated option list, applied in order
            to the given default configuration. *)
   316   fun config_parser default = (Scan.optional (P.$$$ "(" |-- P.!!! (P.list1 (P.group "option" option_parser)) --| P.$$$ ")") [])
   317                               >> (fn opts => fold apply_opt opts default)
   318 
         (* The marker  (otherwise)  after a statement. *)
   319   val otherwise = P.$$$ "(" |-- P.$$$ "otherwise" --| P.$$$ ")"
   320 
         (* Fatal failure reporting the term t that directly follows a statement,
            i.e. a missing separator. *)
   321   fun pipe_error t = P.!!! (Scan.fail_with (K (cat_lines ["Equations must be separated by " ^ quote "|", quote t])))
   322 
         (* One statement: optional name, proposition and optional (otherwise) flag.
            The lookahead raises pipe_error if another term follows immediately. *)
   323   val statement_ow = SpecParse.opt_thm_name ":" -- (P.prop -- Scan.optional (otherwise >> K true) false)
   324                      --| Scan.ahead ((P.term :-- pipe_error) || Scan.succeed ("",""))
   325 
   326   val statements_ow = P.enum1 "|" statement_ow
   327 
         (* Split the statements into the list of otherwise flags and the list of
            (name, proposition) pairs. *)
   328   val flags_statements = statements_ow
   329                          >> (fn sow => (map (snd o snd) sow, map (apsnd fst) sow))
   330 in
         (* Full specification: options, fixes, the keyword where, then statements. *)
   331   fun fundef_parser default_cfg = (config_parser default_cfg -- P.fixes --| P.$$$ "where" -- flags_statements)
   332 end
   333 
   334 
   335 end
   336 end
   337