src/HOL/Tools/datatype_codegen.ML
author haftmann
Thu, 14 May 2009 15:09:48 +0200
changeset 31156 90fed3d4430f
parent 31151 1c64b0827ee8
child 31240 251a34663242
permissions -rw-r--r--
merged module code_unit.ML into code.ML
     1 (*  Title:      HOL/Tools/datatype_codegen.ML
     2     Author:     Stefan Berghofer and Florian Haftmann, TU Muenchen
     3 
     4 Code generator facilities for inductive datatypes.
     5 *)
     6 
     7 signature DATATYPE_CODEGEN =
     8 sig
         (* equations for eq_class.eq on the named datatype; every theorem is paired
            with a boolean flag, which is false only for the final reflexivity equation *)
     9   val mk_eq_eqns: theory -> string -> (thm * bool) list
         (* single theorem derived from the case rewrites of the named datatype *)
    10   val mk_case_cert: theory -> string -> thm
         (* installs the term code generator, the type code generator and the
            datatype interpretation defined in the structure below *)
    11   val setup: theory -> theory
    12 end;
    13 
    14 structure DatatypeCodegen : DATATYPE_CODEGEN =
    15 struct
    16 
    17 (** SML code generator **)
    18 
    19 open Codegen;
    20 
    21 (**** datatype definition ****)
    22 
    23 (* find shortest path to constructor with no recursive arguments *)
    24 
    25 fun find_nonempty (descr: DatatypeAux.descr) is i =
         (* descr: datatype descriptor; is: indices already visited on the current
            path; i: index of the datatype to inspect.
            Each constructor of datatype i is assigned a depth, namely the maximum
            of the depths of its arguments (0 for a constructor without arguments).
            Constructors for which no depth exists are discarded.
            Result: SOME (constructor name, depth) for the first constructor in
            ascending order of depth, or NONE when no constructor remains.
            Raises Option when i has no entry in descr (valOf). *)
    26   let
    27     val (_, _, constrs) = valOf (AList.lookup (op =) descr i);
           (* depth of one argument, given the result of DatatypeAux.strip_dtyp:
              a DtRec index that is already in is yields NONE; any other DtRec
              index yields one more than the depth found recursively for it (NONE
              if that search fails); everything else yields SOME 0 *)
    28     fun arg_nonempty (_, DatatypeAux.DtRec i) = if i mem is then NONE
    29           else Option.map (curry op + 1 o snd) (find_nonempty descr (i::is) i)
    30       | arg_nonempty _ = SOME 0;
           (* maximum of a list of optional depths, starting from SOME 0;
              a single NONE makes the whole result NONE *)
    31     fun max xs = Library.foldl
    32       (fn (NONE, _) => NONE
    33         | (SOME i, SOME j) => SOME (Int.max (i, j))
    34         | (_, NONE) => NONE) (SOME 0, xs);
           (* (name, depth) pairs of the surviving constructors, ordered by depth *)
    35     val xs = sort (int_ord o pairself snd)
    36       (List.mapPartial (fn (s, dts) => Option.map (pair s)
    37         (max (map (arg_nonempty o DatatypeAux.strip_dtyp) dts))) constrs)
    38   in case xs of [] => NONE | x :: _ => SOME x end;
    39 
    40 fun add_dt_defs thy defs dep module (descr: DatatypeAux.descr) sorts gr =
         (* Makes sure the code graph gr contains a node holding the generated ML
            text for the datatypes in descr and returns (module', resulting graph).
            The node is named after the first datatype of the filtered descriptor
            with the suffix " (type)".
            - If the node exists, only an edge (node_id, dep) is added; when that
              edge would create a cycle the graph is returned unchanged.
            - If the node is undefined (Graph.UNDEF), it is created and filled with
              the datatype definition, followed by term_of functions when
              "term_of" is in !mode and by gen_ functions when "test" is in !mode. *)
    41   let
           (* keep only entries whose type parameters all pass dest_DtTFree *)
    42     val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;
           (* names of the types having at least one constructor with an argument
              for which DatatypeAux.is_rec_type holds *)
    43     val rtnames = map (#1 o snd) (List.filter (fn (_, (_, _, cs)) =>
    44       exists (exists DatatypeAux.is_rec_type o snd) cs) descr');
    45 
           (* refutable binding: raises Bind when descr' is empty *)
    46     val (_, (tname, _, _)) :: _ = descr';
    47     val node_id = tname ^ " (type)";
    48     val module' = if_library (thyname_of_type thy tname) module;
    49 
           (* One pretty block per datatype of the form
                prfx [type variables] type_id = C1 of T1 * T2 | C2 | ...
              The first entry uses the given prfx, all later entries use "and ".
              Type and constructor identifiers are registered in the graph, which
              is threaded through and returned together with the blocks. *)
    50     fun mk_dtdef prfx [] gr = ([], gr)
    51       | mk_dtdef prfx ((_, (tname, dts, cs))::xs) gr =
    52           let
    53             val tvs = map DatatypeAux.dest_DtTFree dts;
    54             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
    55             val ((_, type_id), gr') = mk_type_id module' tname gr;
                   (* ps: per constructor, (pretty argument types, constructor id) *)
    56             val (ps, gr'') = gr' |>
    57               fold_map (fn (cname, cargs) =>
    58                 fold_map (invoke_tycodegen thy defs node_id module' false)
    59                   cargs ##>>
    60                 mk_const_id module' cname) cs';
    61             val (rest, gr''') = mk_dtdef "and " xs gr''
    62           in
    63             (Pretty.block (str prfx ::
    64                (if null tvs then [] else
    65                   [mk_tuple (map str tvs), str " "]) @
    66                [str (type_id ^ " ="), Pretty.brk 1] @
    67                List.concat (separate [Pretty.brk 1, str "| "]
    68                  (map (fn (ps', (_, cname)) => [Pretty.block
    69                    (str cname ::
    70                     (if null ps' then [] else
    71                      List.concat ([str " of", Pretty.brk 1] ::
    72                        separate [str " *", Pretty.brk 1]
    73                          (map single ps'))))]) ps))) :: rest, gr''')
    74           end;
    75 
           (* pretty text  Const ("cname", <type of Ts ---> T>) $ p1 $ p2 ...
              where ps are the already printed arguments *)
    76     fun mk_constr_term cname Ts T ps =
    77       List.concat (separate [str " $", Pretty.brk 1]
    78         ([str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
    79           mk_type false (Ts ---> T), str ")"] :: ps));
    80 
           (* One equation per constructor, mapping a constructor pattern with
              variables x1 ... xn to the term built by mk_constr_term.  Within a
              datatype the first equation starts with prfx and the others with
              "  | "; later datatypes start with "and ". *)
    81     fun mk_term_of_def gr prfx [] = []
    82       | mk_term_of_def gr prfx ((_, (tname, dts, cs)) :: xs) =
    83           let
    84             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
    85             val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
    86             val T = Type (tname, dts');
    87             val rest = mk_term_of_def gr "and " xs;
    88             val (_, eqs) = Library.foldl_map (fn (prfx, (cname, Ts)) =>
    89               let val args = map (fn i =>
    90                 str ("x" ^ string_of_int i)) (1 upto length Ts)
    91               in ("  | ", Pretty.blk (4,
    92                 [str prfx, mk_term_of gr module' false T, Pretty.brk 1,
    93                  if null Ts then str (snd (get_const_id gr cname))
    94                  else parens (Pretty.block
    95                    [str (snd (get_const_id gr cname)),
    96                     Pretty.brk 1, mk_tuple args]),
    97                  str " =", Pretty.brk 1] @
    98                  mk_constr_term cname Ts T
    99                    (map (fn (x, U) => [Pretty.block [mk_term_of gr module' false U,
   100                       Pretty.brk 1, x]]) (args ~~ Ts))))
   101               end) (prfx, cs')
   102           in eqs @ rest end;
   103 
           (* Text of the generator functions named "gen_" ^ type id.  When the
              datatype has constructors with recursive arguments (cs1), the main
              function gets a "'" suffix and an additional parameter i, and a
              second function without the suffix is emitted that passes its
              argument i twice to the primed one. *)
   104     fun mk_gen_of_def gr prfx [] = []
   105       | mk_gen_of_def gr prfx ((i, (tname, dts, cs)) :: xs) =
   106           let
   107             val tvs = map DatatypeAux.dest_DtTFree dts;
   108             val Us = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
   109             val T = Type (tname, Us);
                   (* cs1: constructors with a recursive argument; cs2: the others *)
   110             val (cs1, cs2) =
   111               List.partition (exists DatatypeAux.is_rec_type o snd) cs;
                   (* constructor chosen by find_nonempty; refutable binding that
                      raises Bind when find_nonempty returns NONE *)
   112             val SOME (cname, _) = find_nonempty descr [i] i;
   113 
                   (* wrap p as  fn () => p *)
   114             fun mk_delay p = Pretty.block
   115               [str "fn () =>", Pretty.brk 1, p];
   116 
                   (* apply p to () *)
   117             fun mk_force p = Pretty.block [p, Pretty.brk 1, str "()"];
   118 
                   (* let-expression binding a pair (x, t) for every argument to the
                      result of that argument's generator, and yielding the pair of
                      the constructor applied to the xs and a delayed Const term
                      built from the forced ts.  The string s is handed to mk_gen;
                      the generator is applied to "0" instead of "j" when b holds
                      and the argument is recursive. *)
   119             fun mk_constr s b (cname, dts) =
   120               let
   121                 val gs = map (fn dt => mk_app false (mk_gen gr module' false rtnames s
   122                     (DatatypeAux.typ_of_dtyp descr sorts dt))
   123                   [str (if b andalso DatatypeAux.is_rec_type dt then "0"
   124                      else "j")]) dts;
   125                 val Ts = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
   126                 val xs = map str
   127                   (DatatypeProp.indexify_names (replicate (length dts) "x"));
   128                 val ts = map str
   129                   (DatatypeProp.indexify_names (replicate (length dts) "t"));
   130                 val (_, id) = get_const_id gr cname
   131               in
   132                 mk_let
   133                   (map2 (fn p => fn q => mk_tuple [p, q]) xs ts ~~ gs)
   134                   (mk_tuple
   135                     [case xs of
   136                        _ :: _ :: _ => Pretty.block
   137                          [str id, Pretty.brk 1, mk_tuple xs]
   138                      | _ => mk_app false (str id) xs,
   139                      mk_delay (Pretty.block (mk_constr_term cname Ts T
   140                        (map (single o mk_force) ts)))])
   141               end;
   142 
                   (* a single constructor is generated directly; several are
                      wrapped as  one_of [fn () => ..., ...] ()  *)
   143             fun mk_choice [c] = mk_constr "(i-1)" false c
   144               | mk_choice cs = Pretty.block [str "one_of",
   145                   Pretty.brk 1, Pretty.blk (1, str "[" ::
   146                   List.concat (separate [str ",", Pretty.fbrk]
   147                     (map (single o mk_delay o mk_constr "(i-1)" false) cs)) @
   148                   [str "]"]), Pretty.brk 1, str "()"];
   149 
                   (* two parameters per type variable: its stripped name suffixed
                      with "G" and with "T" *)
   150             val gs = maps (fn s =>
   151               let val s' = strip_tname s
   152               in [str (s' ^ "G"), str (s' ^ "T")] end) tvs;
   153             val gen_name = "gen_" ^ snd (get_type_id gr tname)
   154 
   155           in
                   (* Body of the (possibly primed) function:
                      - cs1 and cs2 both non-empty:
                          frequency [(i, fn () => <cs1>), (1, fn () => <cs2>)] ()
                      - cs2 empty:
                          (case i of 0 => <cname built with "0"> | _ => <cs1>)
                      - otherwise: <cs2> *)
   156             Pretty.blk (4, separate (Pretty.brk 1) 
   157                 (str (prfx ^ gen_name ^
   158                    (if null cs1 then "" else "'")) :: gs @
   159                  (if null cs1 then [] else [str "i"]) @
   160                  [str "j"]) @
   161               [str " =", Pretty.brk 1] @
   162               (if not (null cs1) andalso not (null cs2)
   163                then [str "frequency", Pretty.brk 1,
   164                  Pretty.blk (1, [str "[",
   165                    mk_tuple [str "i", mk_delay (mk_choice cs1)],
   166                    str ",", Pretty.fbrk,
   167                    mk_tuple [str "1", mk_delay (mk_choice cs2)],
   168                    str "]"]), Pretty.brk 1, str "()"]
   169                else if null cs2 then
   170                  [Pretty.block [str "(case", Pretty.brk 1,
   171                    str "i", Pretty.brk 1, str "of",
   172                    Pretty.brk 1, str "0 =>", Pretty.brk 1,
   173                    mk_constr "0" true (cname, valOf (AList.lookup (op =) cs cname)),
   174                    Pretty.brk 1, str "| _ =>", Pretty.brk 1,
   175                    mk_choice cs1, str ")"]]
   176                else [mk_choice cs2])) ::
                   (* unprimed wrapper:  and gen_name gs i = gen_name' gs i i  *)
   177             (if null cs1 then []
   178              else [Pretty.blk (4, separate (Pretty.brk 1) 
   179                  (str ("and " ^ gen_name) :: gs @ [str "i"]) @
   180                [str " =", Pretty.brk 1] @
   181                separate (Pretty.brk 1) (str (gen_name ^ "'") :: gs @
   182                  [str "i", str "i"]))]) @
   183             mk_gen_of_def gr "and " xs
   184           end
   185 
   186   in
           (* The inner handler covers only add_edge_acyclic (Graph.CYCLES); the
              outer handler reacts to Graph.UNDEF by building the node. *)
   187     (module', (add_edge_acyclic (node_id, dep) gr
   188         handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ =>
   189          let
                  (* the node is inserted with empty content before the definitions
                     are generated, and is overwritten by map_node below *)
   190            val gr1 = add_edge (node_id, dep)
   191              (new_node (node_id, (NONE, "", "")) gr);
   192            val (dtdef, gr2) = mk_dtdef "datatype " descr' gr1 ;
   193          in
   194            map_node node_id (K (NONE, module',
   195              string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @
   196                [str ";"])) ^ "\n\n" ^
   197              (if "term_of" mem !mode then
   198                 string_of (Pretty.blk (0, separate Pretty.fbrk
   199                   (mk_term_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n"
   200               else "") ^
   201              (if "test" mem !mode then
   202                 string_of (Pretty.blk (0, separate Pretty.fbrk
   203                   (mk_gen_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n"
   204               else ""))) gr2
   205          end)
   206   end;
   207 
   208 
   209 (**** case expressions ****)
   210 
   211 fun pretty_case thy defs dep module brack constrs (c as Const (_, T)) ts gr =
         (* Prints the application of the constant c to the arguments ts as an ML
            case expression and returns it together with the updated graph.
            constrs: (constructor name, argument list) pairs.  With i = length
            constrs, the first i elements of ts are the branches, element i+1 is
            the term being analysed and the remainder are extra arguments.
            When ts has at most i elements, c is eta-expanded to i+1 arguments and
            handed back to invoke_codegen. *)
   212   let val i = length constrs
   213   in if length ts <= i then
   214        invoke_codegen thy defs dep module brack (eta_expand c ts (i+1)) gr
   215     else
   216       let
   217         val ts1 = Library.take (i, ts);
   218         val t :: ts2 = Library.drop (i, ts);
               (* names collected from the branches (term names and names of Vars);
                  variable names for patterns are chosen distinct from these *)
   219         val names = List.foldr OldTerm.add_term_names
   220           (map (fst o fst o dest_Var) (List.foldr OldTerm.add_term_vars [] ts1)) ts1;
               (* first i+1 argument types of T: the first i go to Ts, the last is dT *)
   221         val (Ts, dT) = split_last (Library.take (i+1, fst (strip_type T)));
   222 
               (* one "pattern => body" block per constructor; the three lists are
                  consumed in step and there is no clause for unequal lengths *)
   223         fun pcase [] [] [] gr = ([], gr)
   224           | pcase ((cname, cargs)::cs) (t::ts) (U::Us) gr =
   225               let
   226                 val j = length cargs;
   227                 val xs = Name.variant_list names (replicate j "x");
   228                 val Us' = Library.take (j, fst (strip_type U));
   229                 val frees = map Free (xs ~~ Us');
                       (* pattern: the constructor applied to the fresh variables *)
   230                 val (cp, gr0) = invoke_codegen thy defs dep module false
   231                   (list_comb (Const (cname, Us' ---> dT), frees)) gr;
                       (* body: the branch applied to the same variables, beta-normalised *)
   232                 val t' = Envir.beta_norm (list_comb (t, frees));
   233                 val (p, gr1) = invoke_codegen thy defs dep module false t' gr0;
   234                 val (ps, gr2) = pcase cs ts Us gr1;
   235               in
   236                 ([Pretty.block [cp, str " =>", Pretty.brk 1, p]] :: ps, gr2)
   237               end;
   238 
   239         val (ps1, gr1) = pcase constrs ts1 Ts gr ;
   240         val ps = List.concat (separate [Pretty.brk 1, str "| "] ps1);
   241         val (p, gr2) = invoke_codegen thy defs dep module false t gr1;
   242         val (ps2, gr3) = fold_map (invoke_codegen thy defs dep module true) ts2 gr2;
             (* the case expression followed by the extra arguments; the whole is
                parenthesised only if brack holds and extra arguments exist *)
   243       in ((if not (null ts2) andalso brack then parens else I)
   244         (Pretty.block (separate (Pretty.brk 1)
   245           (Pretty.block ([str "(case ", p, str " of",
   246              Pretty.brk 1] @ ps @ [str ")"]) :: ps2))), gr3)
   247       end
   248   end;
   249 
   250 
   251 (**** constructors ****)
   252 
   253 fun pretty_constr thy defs dep module brack args (c as Const (s, T)) ts gr =
         (* Prints the application of constructor c to the arguments ts.
            args is the constructor's argument list from the datatype descriptor.
            A constructor with several arguments that is applied to too few of them
            is eta-expanded and handed back to invoke_codegen; otherwise it is
            printed with its arguments as a tuple (several arguments) or by plain
            application (at most one argument).  Returns (pretty, graph). *)
   254   let
   255     val arity = length args;
   256     fun saturated () =
   257       let
   258         val qid = mk_qual_id module (get_const_id gr s);
   259         val (ps, gr') = fold_map (invoke_codegen thy defs dep module (arity = 1)) ts gr;
   260         val wrap = if brack then parens else I;
   261         val p = if arity > 1 then wrap (Pretty.block [str qid, Pretty.brk 1, mk_tuple ps])
   262           else mk_app brack (str qid) ps;
   263       in (p, gr') end;
   264   in
   265     if arity > 1 andalso length ts < arity
   266     then invoke_codegen thy defs dep module brack (eta_expand c ts arity) gr else saturated ()
   267   end;
   268 
   269 
   270 (**** code generators for terms and types ****)
   271 
   272 fun datatype_codegen thy defs dep module brack t gr = (case strip_comb t of
         (* Term code generator: handles t when its head is a constant that is
            either a datatype case combinator (printed by pretty_case) or a
            datatype constructor (printed by pretty_constr).  Returns NONE when
            the head is something else or when get_assoc_code has an entry for
            the constant. *)
   273    (c as Const (s, T), ts) =>
   274      (case DatatypePackage.datatype_of_case thy s of
   275         SOME {index, descr, ...} =>
   276           if is_some (get_assoc_code thy (s, T)) then NONE else
   277           SOME (pretty_case thy defs dep module brack
   278             (#3 (the (AList.lookup op = descr index))) c ts gr )
            (* not a case combinator: try s as a constructor whose result type,
               according to T, is a type constructor application *)
   279       | NONE => case (DatatypePackage.datatype_of_constr thy s, strip_type T) of
   280         (SOME {index, descr, ...}, (_, U as Type (tyname, _))) =>
   281           if is_some (get_assoc_code thy (s, T)) then NONE else
   282           let
                  (* refutable bindings: both lookups are performed (and raise Bind
                     on failure) before the type names are compared *)
   283             val SOME (tyname', _, constrs) = AList.lookup op = descr index;
   284             val SOME args = AList.lookup op = constrs s
   285           in
   286             if tyname <> tyname' then NONE
                  (* the type code generator is run on U first; only its graph is kept *)
   287             else SOME (pretty_constr thy defs
   288               dep module brack args c ts (snd (invoke_tycodegen thy defs dep module false U gr)))
   289           end
   290       | _ => NONE)
   291  | _ => NONE);
   292 
   293 fun datatype_tycodegen thy defs dep module brack (Type (s, Ts)) gr =
             (* Type code generator: for a type constructor s known to
                DatatypePackage.get_datatype and without an entry in get_assoc_type,
                prints the argument types (as a tuple, if any) followed by the
                qualified type identifier, after add_dt_defs has registered the
                datatype definition.  Every other type yields NONE. *)
   294       let
   295         fun generate descr sorts =
   296           let
   297             val (arg_ps, gr1) = fold_map (invoke_tycodegen thy defs dep module false) Ts gr;
   298             val (module', gr2) = add_dt_defs thy defs dep module descr sorts gr1;
   299             val (tyid, gr3) = mk_type_id module' s gr2;
   300             val targs = if null Ts then [] else [mk_tuple arg_ps, str " "];
   301           in (Pretty.block (targs @ [str (mk_qual_id module tyid)]), gr3) end;
   302       in case DatatypePackage.get_datatype thy s of
   303           SOME {descr, sorts, ...} =>
   304             if is_some (get_assoc_type thy s) then NONE else SOME (generate descr sorts)
   305         | NONE => NONE
   306       end
   307   | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
   308 
   309 
   310 (** generic code generator **)
   311 
   312 (* case certificates *)
   313 
   314 fun mk_case_cert thy tyco =
         (* Builds a single theorem from the case rewrites of datatype tyco: the
            rewrites are joined into one conjunction and rewritten with the
            hypothesis  case == <head of the first rewrite without its last
            argument>, where "case" is a fresh free variable; the hypothesis is
            then discharged and the collected free variables are generalized. *)
   315   let
   316     val raw_thms =
   317       (#case_rewrites o DatatypePackage.the_datatype thy) tyco;
           (* unvarify all rewrites together (via one conjunction), then turn each
              into a meta equation; refutable binding, raises Bind on an empty list *)
   318     val thms as hd_thm :: _ = raw_thms
   319       |> Conjunction.intr_balanced
   320       |> Thm.unvarify
   321       |> Conjunction.elim_balanced (length raw_thms)
   322       |> map Simpdata.mk_meta_eq
   323       |> map Drule.zero_var_indexes
           (* names of the Free variables occurring in the first rewrite *)
   324     val params = fold_aterms (fn (Free (v, _)) => insert (op =) v
   325       | _ => I) (Thm.prop_of hd_thm) [];
           (* left-hand side of the first rewrite with its last argument dropped;
              called rhs because it is the right-hand side of asm below *)
   326     val rhs = hd_thm
   327       |> Thm.prop_of
   328       |> Logic.dest_equals
   329       |> fst
   330       |> Term.strip_comb
   331       |> apsnd (fst o split_last)
   332       |> list_comb;
           (* free variable named by a variant of "case" with respect to params *)
   333     val lhs = Free (Name.variant params "case", Term.fastype_of rhs);
   334     val asm = (Thm.cterm_of thy o Logic.mk_equals) (lhs, rhs);
   335   in
   336     thms
   337     |> Conjunction.intr_balanced
           (* rewrite with the assumed equation read from right to left *)
   338     |> MetaSimplifier.rewrite_rule [(Thm.symmetric o Thm.assume) asm]
   339     |> Thm.implies_intr asm
   340     |> Thm.generalize ([], params) 0
   341     |> AxClass.unoverload thy
   342     |> Thm.varifyT
   343   end;
   344 
   345 
   346 (* equality *)
   347 
   348 fun mk_eq_eqns thy dtco =
         (* Proves equations for eq_class.eq on datatype dtco and returns them as
            (theorem, flag) pairs: the equations for nullary constructors, the
            injectivity equations and the distinctness equations carry the flag
            true; the reflexivity equation comes last with the flag false. *)
   349   let
   350     val (vs, cos) = DatatypePackage.the_datatype_spec thy dtco;
   351     val { descr, index, inject = inject_thms, ... } = DatatypePackage.the_datatype thy dtco;
   352     val ty = Type (dtco, map TFree vs);
           (* eq_class.eq at type ty applied to t1 and t2 *)
   353     fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT)
   354       $ t1 $ t2;
   355     fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const);
   356     fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const);
           (* for every constructor c without arguments:  eq c c = True  *)
   357     val triv_injects = map_filter
   358      (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty))))
   359        | _ => NONE) cos;
           (* replace the head of the left-hand side of an injectivity proposition
              by eq_class.eq; partial pattern, other shapes raise Match *)
   360     fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) =
   361       trueprop $ (equiv $ mk_eq (t1, t2) $ rhs);
   362     val injects = map prep_inject (nth (DatatypeProp.make_injs [descr] vs) index);
           (* a negated binary proposition on t1, t2 becomes two equations,
              eq t1 t2 = False  and  eq t2 t1 = False; partial pattern as well *)
   363     fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) =
   364       [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)];
   365     val distincts = maps prep_distinct (snd (nth (DatatypeProp.make_distincts [descr] vs) index));
           (* eq x x = True  for a free variable x of type ty *)
   366     val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty)));
           (* simpset: theorems eq and eq_True plus the injectivity theorems, each
              passed through Simpdata.mk_eq, together with the distinctness simproc *)
   367     val simpset = Simplifier.context (ProofContext.init thy) (HOL_basic_ss
   368       addsimps (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms))
   369       addsimprocs [DatatypePackage.distinct_simproc]);
           (* each proposition is proved by simp_tac on all goals *)
   370     fun prove prop = Goal.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
   371       |> Simpdata.mk_eq;
   372   in map (rpair true o prove) (triv_injects @ injects @ distincts) @ [(prove refl, false)] end;
   373 
   374 fun add_equality vs dtcos thy =
         (* Instantiates class HOLogic.class_eq for the datatypes dtcos (with type
            variables vs): eq_class.eq is defined through "op =" for each type, the
            instance is proved from these definitions, the defining theorems are
            removed again with Code.del_eqn, and the equations of mk_eq_eqns are
            registered instead via Code.add_eqnl. *)
   375   let
           (* defines  eq_class.eq x y = (op =) x y  at type dtco in lthy and
              returns the defining theorem exported to a context freshly
              initialised from the underlying theory *)
   376     fun add_def dtco lthy =
   377       let
   378         val ty = Type (dtco, map TFree vs);
   379         fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
   380           $ Free ("x", ty) $ Free ("y", ty);
   381         val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
   382           (mk_side @{const_name eq_class.eq}, mk_side @{const_name "op ="}));
   383         val def' = Syntax.check_term lthy def;
   384         val ((_, (_, thm)), lthy') = Specification.definition
   385           (NONE, (Attrib.empty_binding, def')) lthy;
   386         val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy);
   387         val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
   388       in (thm', lthy') end;
           (* instance proof: class introduction, then every goal is closed by
              one of the given theorems *)
   389     fun tac thms = Class.intro_classes_tac []
   390       THEN ALLGOALS (ProofContext.fact_tac thms);
           (* registers the reversed result of mk_eq_eqns for the instance
              parameter of eq_class.eq at dtco; the equations are wrapped in
              Lazy.lazy and computed from the checked theory reference *)
   391     fun add_eq_thms dtco thy =
   392       let
   393         val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco);
   394         val thy_ref = Theory.check_thy thy;
   395         fun mk_thms () = rev ((mk_eq_eqns (Theory.deref thy_ref) dtco));
   396       in
   397         Code.add_eqnl (const, Lazy.lazy mk_thms) thy
   398       end;
   399   in
   400     thy
   401     |> TheoryTarget.instantiation (dtcos, vs, [HOLogic.class_eq])
   402     |> fold_map add_def dtcos
   403     |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm)
   404          (fn _ => fn def_thms => tac def_thms) def_thms)
   405     |-> (fn def_thms => fold Code.del_eqn def_thms)
   406     |> fold add_eq_thms dtcos
   407   end;
   408 
   409 
   410 (* liberal addition of code data for datatypes *)
   411 
   412 fun mk_constr_consts thy vs dtco cos =
         (* Pairs every constructor name in cos with its full type (argument types
            ---> the datatype dtco applied to the type variables vs).  The typed
            list is returned as SOME only if Code.constrset_of_consts accepts the
            constructors after AxClass.unoverload_const; if it raises, the result
            is NONE. *)
   413   let
   414     val dty = Type (dtco, map TFree vs);
   415     val typed = map (fn (c, tys) => (c, tys ---> dty)) cos;
   416     val plain = map (fn (c_ty as (_, ty)) => (AxClass.unoverload_const thy c_ty, ty)) typed;
   417   in
   418     Option.map (K typed) (try (Code.constrset_of_consts thy) plain)
   419   end;
   420 
   421 fun add_all_code dtcos thy =
         (* Registers code generator data for the datatypes dtcos: the constructor
            sets, the case rewrites as default equations, the case certificates and
            the equality instance.  If mk_constr_consts fails for any one of the
            datatypes, the theory is returned unchanged.  The case rewrites and
            certificates are computed before that decision is taken. *)
   422   let
   423     val specs = map (DatatypePackage.the_datatype_spec thy) dtcos;
   424     val (vs :: _, coss) = split_list specs;
   425     val any_css = map2 (mk_constr_consts thy vs) dtcos coss;
   426     val css = if exists is_none any_css then [] else map the any_css;
   427     val case_rewrites = maps (#case_rewrites o DatatypePackage.the_datatype thy) dtcos;
   428     val certs = map (mk_case_cert thy) dtcos;
   429     fun register thy' = thy'
   430       |> fold Code.add_datatype css
   431       |> fold_rev Code.add_default_eqn case_rewrites
   432       |> fold Code.add_case certs
   433       |> add_equality vs dtcos;
   434   in
   435     if null css then thy else register thy
   436   end;
   437 
   438 
   439 
   440 (** theory setup **)
   441 
   442 val setup = fn thy => thy
         (* theory setup: both code generators are registered under the name
            "datatype", then add_all_code is installed as datatype interpretation *)
   443   |> add_codegen "datatype" datatype_codegen
   444   |> add_tycodegen "datatype" datatype_tycodegen
   445   |> DatatypePackage.interpretation add_all_code
   446 
   447 end;