src/Tools/Code/code_haskell.ML
author haftmann
Mon, 23 Jul 2012 09:28:03 +0200
changeset 49446 6efff142bb54
parent 49087 ace701efe203
child 49583 084cd758a8ab
permissions -rw-r--r--
restrict unqualified imports from Haskell Prelude to a small set of fundamental operations
     1 (*  Title:      Tools/Code/code_haskell.ML
     2     Author:     Florian Haftmann, TU Muenchen
     3 
     4 Serializer for Haskell.
     5 *)
     6 
(*Interface of the Haskell serializer.*)
signature CODE_HASKELL =
sig
  (*required language extensions, each rendered as a "-X..." flag, separated by blanks*)
  val language_params: string
  (*name of the target language: "Haskell"*)
  val target: string
  (*theory setup: registers the target, the syntax of the function type and reserved names*)
  val setup: theory -> theory
end;
    13 
    14 structure Code_Haskell : CODE_HASKELL =
    15 struct
    16 
    17 val target = "Haskell";
    18 
    19 val language_extensions =
    20   ["EmptyDataDecls", "RankNTypes", "ScopedTypeVariables"];
    21 
    22 val language_pragma =
    23   "{-# LANGUAGE " ^ commas language_extensions ^ " #-}";
    24 
    25 val language_params =
    26   space_implode " " (map (prefix "-X") language_extensions);
    27 
    28 open Basic_Code_Thingol;
    29 open Code_Printer;
    30 
    31 infixr 5 @@;
    32 infixr 5 @|;
    33 
    34 
    35 (** Haskell serializer **)
    36 
(*Printer for a single statement (name, stmt) of the intermediate program as Haskell source.

  class_syntax, tyco_syntax, const_syntax: lookup of optional target-specific syntax
    for classes, type constructors and constants;
  reserved: variable context the local variable contexts are derived from;
  deresolve: maps a statement name to the name printed in the output;
  deriving_show: decides per datatype name whether "deriving (Read, Show)" is emitted.

  Only the statement kinds Fun, Datatype, Class and Classinst are handled here.*)
fun print_haskell_stmt class_syntax tyco_syntax const_syntax
    reserved deresolve deriving_show =
  let
    (*class name: explicit class syntax takes precedence over the deresolved name*)
    fun class_name class = case class_syntax class
     of NONE => deresolve class
      | SOME class => class;
    (*class context "(C a, D b) => " for sorted type variables vs;
      nothing at all if no variable carries a class constraint*)
    fun print_typcontext tyvars vs = case maps (fn (v, sort) => map (pair v) sort) vs
     of [] => []
      | constraints => enum "," "(" ")" (
          map (fn (v, class) =>
            str (class_name class ^ " " ^ lookup_var tyvars v)) constraints)
          @@ str " => ";
    (*explicit quantifier "forall a b. "; nothing if there are no type variables*)
    fun print_typforall tyvars vs = case map fst vs
     of [] => []
      | vnames => str "forall " :: Pretty.breaks
          (map (str o lookup_var tyvars) vnames) @ str "." @@ Pretty.brk 1;
    (*type constructor applied to its arguments, using an already printed constructor name*)
    fun print_tyco_expr tyvars fxy (tyco, tys) =
      brackify fxy (str tyco :: map (print_typ tyvars BR) tys)
    (*types: a printer from tyco_syntax, if present, replaces the default application syntax*)
    and print_typ tyvars fxy (tyco `%% tys) = (case tyco_syntax tyco
         of NONE => print_tyco_expr tyvars fxy (deresolve tyco, tys)
          | SOME (_, print) => print (print_typ tyvars) fxy tys)
      | print_typ tyvars fxy (ITyVar v) = (str o lookup_var tyvars) v;
    (*left-hand side of a type declaration: constructor name applied to type variables*)
    fun print_typdecl tyvars (tyco, vs) =
      print_tyco_expr tyvars NOBR (tyco, map ITyVar vs);
    (*type scheme: quantifier, class context, then the type itself*)
    fun print_typscheme tyvars (vs, ty) =
      Pretty.block (print_typforall tyvars vs @ print_typcontext tyvars vs @| print_typ tyvars NOBR ty);
    (*terms*)
    fun print_term tyvars some_thm vars fxy (IConst const) =
          print_app tyvars some_thm vars fxy (const, [])
      | print_term tyvars some_thm vars fxy (t as (t1 `$ t2)) =
          (*applications headed by a constant go through print_app,
            all others are printed as plain juxtaposition*)
          (case Code_Thingol.unfold_const_app t
           of SOME app => print_app tyvars some_thm vars fxy app
            | _ =>
                brackify fxy [
                  print_term tyvars some_thm vars NOBR t1,
                  print_term tyvars some_thm vars BR t2
                ])
      | print_term tyvars some_thm vars fxy (IVar NONE) =
          str "_"
      | print_term tyvars some_thm vars fxy (IVar (SOME v)) =
          (str o lookup_var vars) v
      | print_term tyvars some_thm vars fxy (t as _ `|=> _) =
          (*nested abstractions are collected into a single lambda "\p1 ... pn -> t'"*)
          let
            val (binds, t') = Code_Thingol.unfold_pat_abs t;
            val (ps, vars') = fold_map (print_bind tyvars some_thm BR o fst) binds vars;
          in brackets (str "\\" :: ps @ str "->" @@ print_term tyvars some_thm vars' NOBR t') end
      | print_term tyvars some_thm vars fxy (ICase case_expr) =
          (*a case whose primitive form is headed by a constant with const syntax is printed
            as that application; otherwise as a proper case expression*)
          (case Code_Thingol.unfold_const_app (#primitive case_expr)
           of SOME (app as ({ name = c, ... }, _)) => if is_none (const_syntax c)
                then print_case tyvars some_thm vars fxy case_expr
                else print_app tyvars some_thm vars fxy app
            | NONE => print_case tyvars some_thm vars fxy case_expr)
    (*constant with arguments; if annotate is set the constant is printed as "(c :: ty)",
      where ty is the function type built from dom and range*)
    and print_app_expr tyvars some_thm vars ({ name = c, dom, range, annotate, ... }, ts) =
      let
        val ty = Library.foldr (fn (ty1, ty2) => Code_Thingol.fun_tyco `%% [ty1, ty2]) (dom, range)
        val printed_const =
          if annotate then
            brackets [(str o deresolve) c, str "::", print_typ tyvars NOBR ty]
          else
            (str o deresolve) c
      in
        printed_const :: map (print_term tyvars some_thm vars BR) ts
      end
    and print_app tyvars = gen_print_app (print_app_expr tyvars) (print_term tyvars) const_syntax
    and print_bind tyvars some_thm fxy p = gen_print_bind (print_term tyvars) some_thm fxy p
    (*case expressions: no clause yields an error call, a single clause yields
      "let { ... } in ...", several clauses yield "(case ... of { ... })"*)
    and print_case tyvars some_thm vars fxy { clauses = [], ... } =
          (brackify fxy o Pretty.breaks o map str) ["error", "\"empty case\""]
      | print_case tyvars some_thm vars fxy (case_expr as { clauses = [_], ... }) =
          let
            val (binds, body) = Code_Thingol.unfold_let (ICase case_expr);
            (*the bound term t is printed in the context before the pattern is introduced*)
            fun print_match ((pat, _), t) vars =
              vars
              |> print_bind tyvars some_thm BR pat
              |>> (fn p => semicolon [p, str "=", print_term tyvars some_thm vars NOBR t])
            val (ps, vars') = fold_map print_match binds vars;
          in brackify_block fxy (str "let {")
            ps
            (concat [str "}", str "in", print_term tyvars some_thm vars' NOBR body])
          end
      | print_case tyvars some_thm vars fxy { term = t, typ = ty, clauses = clauses as _ :: _, ... } =
          let
            (*each clause extends the outer context vars independently of the others*)
            fun print_select (pat, body) =
              let
                val (p, vars') = print_bind tyvars some_thm NOBR pat vars;
              in semicolon [p, str "->", print_term tyvars some_thm vars' NOBR body] end;
          in Pretty.block_enclose
            (concat [str "(case", print_term tyvars some_thm vars NOBR t, str "of", str "{"], str "})")
            (map print_select clauses)
          end;
    (*function: type signature followed by its equations*)
    fun print_stmt (name, Code_Thingol.Fun (_, (((vs, ty), raw_eqs), _))) =
          let
            val tyvars = intro_vars (map fst vs) reserved;
            (*substitute for a function without equations: n wildcard arguments
              and an error call as right-hand side*)
            fun print_err n =
              semicolon (
                (str o deresolve) name
                :: map str (replicate n "_")
                @ str "="
                :: str "error"
                @@ (str o ML_Syntax.print_string
                    o Long_Name.base_name o Long_Name.qualifier) name
              );
            fun print_eqn ((ts, t), (some_thm, _)) =
              let
                val consts = fold Code_Thingol.add_constnames (t :: ts) [];
                (*variable context: base names of the constants without const syntax,
                  then the variables occurring in the argument patterns*)
                val vars = reserved
                  |> intro_base_names
                      (is_none o const_syntax) deresolve consts
                  |> intro_vars ((fold o Code_Thingol.fold_varnames)
                      (insert (op =)) ts []);
              in
                semicolon (
                  (str o deresolve) name
                  :: map (print_term tyvars some_thm vars BR) ts
                  @ str "="
                  @@ print_term tyvars some_thm vars NOBR t
                )
              end;
          in
            Pretty.chunks (
              semicolon [
                (str o suffix " ::" o deresolve) name,
                print_typscheme tyvars (vs, ty)
              ]
              (*only equations whose flag is set are printed; if none remain, print_err
                is used with one wildcard per argument type of ty*)
              :: (case filter (snd o snd) raw_eqs
               of [] => [print_err ((length o fst o Code_Thingol.unfold_fun) ty)]
                | eqs => map print_eqn eqs)
            )
          end
      (*datatype without constructors: bare "data" declaration, never with a deriving clause*)
      | print_stmt (name, Code_Thingol.Datatype (_, (vs, []))) =
          let
            val tyvars = intro_vars vs reserved;
          in
            semicolon [
              str "data",
              print_typdecl tyvars (deresolve name, vs)
            ]
          end
      (*datatype with exactly one constructor taking exactly one argument: "newtype"*)
      | print_stmt (name, Code_Thingol.Datatype (_, (vs, [((co, _), [ty])]))) =
          let
            val tyvars = intro_vars vs reserved;
          in
            semicolon (
              str "newtype"
              :: print_typdecl tyvars (deresolve name, vs)
              :: str "="
              :: (str o deresolve) co
              :: print_typ tyvars BR ty
              :: (if deriving_show name then [str "deriving (Read, Show)"] else [])
            )
          end
      (*any other datatype: "data" with constructors separated by "|"*)
      | print_stmt (name, Code_Thingol.Datatype (_, (vs, co :: cos))) =
          let
            val tyvars = intro_vars vs reserved;
            fun print_co ((co, _), tys) =
              concat (
                (str o deresolve) co
                :: map (print_typ tyvars BR) tys
              )
          in
            semicolon (
              str "data"
              :: print_typdecl tyvars (deresolve name, vs)
              :: str "="
              :: print_co co
              :: map ((fn p => Pretty.block [str "| ", p]) o print_co) cos
              @ (if deriving_show name then [str "deriving (Read, Show)"] else [])
            )
          end
      (*class: super classes form the context, class parameters become signatures*)
      | print_stmt (name, Code_Thingol.Class (_, (v, (super_classes, classparams)))) =
          let
            val tyvars = intro_vars [v] reserved;
            fun print_classparam (classparam, ty) =
              semicolon [
                (str o deresolve) classparam,
                str "::",
                print_typ tyvars NOBR ty
              ]
          in
            Pretty.block_enclose (
              Pretty.block [
                str "class ",
                Pretty.block (print_typcontext tyvars [(v, map fst super_classes)]),
                str (deresolve name ^ " " ^ lookup_var tyvars v),
                str " where {"
              ],
              str "};"
            ) (map print_classparam classparams)
          end
      (*class instance: one definition per instance parameter*)
      | print_stmt (_, Code_Thingol.Classinst { class, tyco, vs, inst_params, ... }) =
          let
            val tyvars = intro_vars (map fst vs) reserved;
            (*number of arguments demanded by the const syntax of a class parameter;
              NONE if the parameter has no const syntax*)
            fun requires_args classparam = case const_syntax classparam
             of NONE => NONE
              | SOME (Code_Printer.Plain_const_syntax _) => SOME 0
              | SOME (Code_Printer.Complex_const_syntax (k,_ )) => SOME k;
            fun print_classparam_instance ((classparam, const), (thm, _)) =
              case requires_args classparam
               of NONE => semicolon [
                      (str o Long_Name.base_name o deresolve) classparam,
                      str "=",
                      print_app tyvars (SOME thm) reserved NOBR (const, [])
                    ]
                | SOME k =>
                    (*class parameter with const syntax: the instance constant is eta-expanded
                      to k arguments and both sides are printed applied to these variables*)
                    let
                      val { name = c, dom, range, ... } = const;
                      val (vs, rhs) = (apfst o map) fst
                        (Code_Thingol.unfold_abs (Code_Thingol.eta_expand k (const, [])));
                      val s = if (is_some o const_syntax) c
                        then NONE else (SOME o Long_Name.base_name o deresolve) c;
                      val vars = reserved
                        |> intro_vars (map_filter I (s :: vs));
                      val lhs = IConst { name = classparam, typargs = [],
                        dicts = [], dom = dom, range = range, annotate = false } `$$ map IVar vs;
                        (*dictionaries are not relevant at this late stage,
                          and these consts never need type annotations for disambiguation *)
                    in
                      semicolon [
                        print_term tyvars (SOME thm) vars NOBR lhs,
                        str "=",
                        print_term tyvars (SOME thm) vars NOBR rhs
                      ]
                    end;
          in
            Pretty.block_enclose (
              Pretty.block [
                str "instance ",
                Pretty.block (print_typcontext tyvars vs),
                str (class_name class ^ " "),
                print_typ tyvars BR (tyco `%% map (ITyVar o fst) vs),
                str " where {"
              ],
              str "};"
            ) (map print_classparam_instance inst_params)
          end;
  in print_stmt end;
   271 
   272 fun haskell_program_of_program labelled_name module_alias module_prefix reserved =
   273   let
   274     fun namify_fun upper base (nsp_fun, nsp_typ) =
   275       let
   276         val (base', nsp_fun') =
   277           Name.variant (if upper then first_upper base else base) nsp_fun;
   278       in (base', (nsp_fun', nsp_typ)) end;
   279     fun namify_typ base (nsp_fun, nsp_typ) =
   280       let
   281         val (base', nsp_typ') = Name.variant (first_upper base) nsp_typ;
   282       in (base', (nsp_fun, nsp_typ')) end;
   283     fun namify_stmt (Code_Thingol.Fun (_, (_, SOME _))) = pair
   284       | namify_stmt (Code_Thingol.Fun _) = namify_fun false
   285       | namify_stmt (Code_Thingol.Datatype _) = namify_typ
   286       | namify_stmt (Code_Thingol.Datatypecons _) = namify_fun true
   287       | namify_stmt (Code_Thingol.Class _) = namify_typ
   288       | namify_stmt (Code_Thingol.Classrel _) = pair
   289       | namify_stmt (Code_Thingol.Classparam _) = namify_fun false
   290       | namify_stmt (Code_Thingol.Classinst _) = pair;
   291     fun select_stmt (Code_Thingol.Fun (_, (_, SOME _))) = false
   292       | select_stmt (Code_Thingol.Fun _) = true
   293       | select_stmt (Code_Thingol.Datatype _) = true
   294       | select_stmt (Code_Thingol.Datatypecons _) = false
   295       | select_stmt (Code_Thingol.Class _) = true
   296       | select_stmt (Code_Thingol.Classrel _) = false
   297       | select_stmt (Code_Thingol.Classparam _) = false
   298       | select_stmt (Code_Thingol.Classinst _) = true;
   299   in
   300     Code_Namespace.flat_program labelled_name
   301       { module_alias = module_alias, module_prefix = module_prefix,
   302         reserved = reserved, empty_nsp = (reserved, reserved), namify_stmt = namify_stmt,
   303         modify_stmt = fn stmt => if select_stmt stmt then SOME stmt else NONE }
   304   end;
   305 
   306 val prelude_import_operators = [
   307   "==", "/=", "<", "<=", ">=", ">", "+", "-", "*", "/", "**", ">>=", ">>", "=<<", "&&", "||", "^", "^^", ".", "$", "$!", "++", "!!"
   308 ];
   309 
   310 val prelude_import_unqualified = [
   311   "Eq",
   312   "error",
   313   "id",
   314   "return",
   315   "not",
   316   "fst", "snd",
   317   "map", "filter", "concat", "concatMap", "reverse", "zip", "null", "takeWhile", "dropWhile", "all", "any",
   318   "Integer", "negate", "abs", "divMod",
   319   "String"
   320 ];
   321 
   322 val prelude_import_unqualified_constr = [
   323   ("Bool", ["True", "False"]),
   324   ("Maybe", ["Nothing", "Just"])
   325 ];
   326 
   327 fun serialize_haskell module_prefix string_classes { labelled_name, reserved_syms,
   328     includes, module_alias, class_syntax, tyco_syntax, const_syntax } program =
   329   let
   330 
   331     (* build program *)
   332     val reserved = fold (insert (op =) o fst) includes reserved_syms;
   333     val { deresolver, flat_program = haskell_program } = haskell_program_of_program
   334       labelled_name module_alias module_prefix (Name.make_context reserved) program;
   335 
   336     (* print statements *)
   337     fun deriving_show tyco =
   338       let
   339         fun deriv _ "fun" = false
   340           | deriv tycos tyco = not (tyco = Code_Thingol.fun_tyco)
   341               andalso (member (op =) tycos tyco
   342               orelse case try (Graph.get_node program) tyco
   343                 of SOME (Code_Thingol.Datatype (_, (_, cs))) => forall (deriv' (tyco :: tycos))
   344                     (maps snd cs)
   345                  | NONE => true)
   346         and deriv' tycos (tyco `%% tys) = deriv tycos tyco
   347               andalso forall (deriv' tycos) tys
   348           | deriv' _ (ITyVar _) = true
   349       in deriv [] tyco end;
   350     fun print_stmt deresolve = print_haskell_stmt
   351       class_syntax tyco_syntax const_syntax (make_vars reserved)
   352       deresolve (if string_classes then deriving_show else K false);
   353 
   354     (* print modules *)
   355     fun print_module_frame module_name ps =
   356       (module_name, Pretty.chunks2 (
   357         str ("module " ^ module_name ^ " where {")
   358         :: ps
   359         @| str "}"
   360       ));
   361     fun print_qualified_import module_name = semicolon [str "import qualified", str module_name];
   362     val import_common_ps =
   363       enclose "import Prelude (" ");" (commas (map str
   364         (map (Library.enclose "(" ")") prelude_import_operators @ prelude_import_unqualified)
   365           @ map (fn (tyco, constrs) => (enclose (tyco ^ "(") ")" o commas o map str) constrs) prelude_import_unqualified_constr))
   366       :: print_qualified_import "Prelude"
   367       :: map (print_qualified_import o fst) includes;
   368     fun print_module module_name (gr, imports) =
   369       let
   370         val deresolve = deresolver module_name;
   371         fun print_import module_name = (semicolon o map str) ["import qualified", module_name];
   372         val import_ps = import_common_ps @ map (print_qualified_import o fst) imports;
   373         fun print_stmt' name = case Graph.get_node gr name
   374          of (_, NONE) => NONE
   375           | (_, SOME stmt) => SOME (markup_stmt name (print_stmt deresolve (name, stmt)));
   376         val body_ps = map_filter print_stmt' ((flat o rev o Graph.strong_conn) gr);
   377       in
   378         print_module_frame module_name
   379           ((if null import_ps then [] else [Pretty.chunks import_ps]) @ body_ps)
   380       end;
   381 
   382     (*serialization*)
   383     fun write_module width (SOME destination) (module_name, content) =
   384           let
   385             val _ = File.check_dir destination;
   386             val filepath = (Path.append destination o Path.ext "hs" o Path.explode o implode
   387               o separate "/" o Long_Name.explode) module_name;
   388             val _ = Isabelle_System.mkdirs (Path.dir filepath);
   389           in
   390             (File.write filepath o format [] width o Pretty.chunks2)
   391               [str language_pragma, content]
   392           end
   393       | write_module width NONE (_, content) = writeln (format [] width content);
   394   in
   395     Code_Target.serialization
   396       (fn width => fn destination => K () o map (write_module width destination))
   397       (fn present => fn width => rpair (try (deresolver ""))
   398         o format present width o Pretty.chunks o map snd)
   399       (map (uncurry print_module_frame o apsnd single) includes
   400         @ map (fn module_name => print_module module_name (Graph.get_node haskell_program module_name))
   401           ((flat o rev o Graph.strong_conn) haskell_program))
   402   end;
   403 
   404 val serializer : Code_Target.serializer =
   405   Code_Target.parse_args (Scan.optional (Args.$$$ "root" -- Args.colon |-- Args.name) ""
   406     -- Scan.optional (Args.$$$ "string_classes" >> K true) false
   407     >> (fn (module_prefix, string_classes) =>
   408       serialize_haskell module_prefix string_classes));
   409 
   410 val literals = let
   411   fun char_haskell c =
   412     let
   413       val s = ML_Syntax.print_char c;
   414     in if s = "'" then "\\'" else s end;
   415   fun numeral_haskell k = if k >= 0 then string_of_int k
   416     else Library.enclose "(" ")" (signed_string_of_int k);
   417 in Literals {
   418   literal_char = Library.enclose "'" "'" o char_haskell,
   419   literal_string = quote o translate_string char_haskell,
   420   literal_numeral = numeral_haskell,
   421   literal_positive_numeral = numeral_haskell,
   422   literal_alternative_numeral = numeral_haskell,
   423   literal_naive_numeral = numeral_haskell,
   424   literal_list = enum "," "[" "]",
   425   infix_cons = (5, ":")
   426 } end;
   427 
   428 
   429 (** optional monad syntax **)
   430 
   431 fun pretty_haskell_monad c_bind =
   432   let
   433     fun dest_bind t1 t2 = case Code_Thingol.split_pat_abs t2
   434      of SOME ((pat, ty), t') =>
   435           SOME ((SOME ((pat, ty), true), t1), t')
   436       | NONE => NONE;
   437     fun dest_monad c_bind_name (IConst { name = c, ... } `$ t1 `$ t2) =
   438           if c = c_bind_name then dest_bind t1 t2
   439           else NONE
   440       | dest_monad _ t = case Code_Thingol.split_let t
   441          of SOME (((pat, ty), tbind), t') =>
   442               SOME ((SOME ((pat, ty), false), tbind), t')
   443           | NONE => NONE;
   444     fun implode_monad c_bind_name = Code_Thingol.unfoldr (dest_monad c_bind_name);
   445     fun print_monad print_bind print_term (NONE, t) vars =
   446           (semicolon [print_term vars NOBR t], vars)
   447       | print_monad print_bind print_term (SOME ((bind, _), true), t) vars = vars
   448           |> print_bind NOBR bind
   449           |>> (fn p => semicolon [p, str "<-", print_term vars NOBR t])
   450       | print_monad print_bind print_term (SOME ((bind, _), false), t) vars = vars
   451           |> print_bind NOBR bind
   452           |>> (fn p => semicolon [str "let", str "{", p, str "=", print_term vars NOBR t, str "}"]);
   453     fun pretty _ [c_bind'] print_term thm vars fxy [(t1, _), (t2, _)] = case dest_bind t1 t2
   454      of SOME (bind, t') => let
   455           val (binds, t'') = implode_monad c_bind' t'
   456           val (ps, vars') = fold_map (print_monad (gen_print_bind (K print_term) thm) print_term)
   457             (bind :: binds) vars;
   458         in
   459           (brackify fxy o single o enclose "do { " " }" o Pretty.breaks)
   460             (ps @| print_term vars' NOBR t'')
   461         end
   462       | NONE => brackify_infix (1, L) fxy
   463           (print_term vars (INFX (1, L)) t1, str ">>=", print_term vars (INFX (1, X)) t2)
   464   in (2, ([c_bind], pretty)) end;
   465 
   466 fun add_monad target' raw_c_bind thy =
   467   let
   468     val c_bind = Code.read_const thy raw_c_bind;
   469   in if target = target' then
   470     thy
   471     |> Code_Target.add_const_syntax target c_bind
   472         (SOME (Code_Printer.complex_const_syntax (pretty_haskell_monad c_bind)))
   473   else error "Only Haskell target allows for monad syntax" end;
   474 
   475 
   476 (** Isar setup **)
   477 
   478 val _ =
   479   Outer_Syntax.command @{command_spec "code_monad"} "define code syntax for monads"
   480     (Parse.term_group -- Parse.name >> (fn (raw_bind, target) =>
   481       Toplevel.theory  (add_monad target raw_bind)));
   482 
   483 val setup =
   484   Code_Target.add_target
   485     (target, { serializer = serializer, literals = literals,
   486       check = { env_var = "ISABELLE_GHC", make_destination = I,
   487         make_command = fn module_name =>
   488           "\"$ISABELLE_GHC\" " ^ language_params  ^ " -odir build -hidir build -stubdir build -e \"\" " ^
   489             module_name ^ ".hs" } })
   490   #> Code_Target.add_tyco_syntax target "fun" (SOME (2, fn print_typ => fn fxy => fn [ty1, ty2] =>
   491       brackify_infix (1, R) fxy (
   492         print_typ (INFX (1, X)) ty1,
   493         str "->",
   494         print_typ (INFX (1, R)) ty2
   495       )))
   496   #> fold (Code_Target.add_reserved target) [
   497       "hiding", "deriving", "where", "case", "of", "infix", "infixl", "infixr",
   498       "import", "default", "forall", "let", "in", "class", "qualified", "data",
   499       "newtype", "instance", "if", "then", "else", "type", "as", "do", "module"
   500     ]
   501   #> fold (Code_Target.add_reserved target) prelude_import_unqualified
   502   #> fold (Code_Target.add_reserved target o fst) prelude_import_unqualified_constr
   503   #> fold (fold (Code_Target.add_reserved target) o snd) prelude_import_unqualified_constr;
   504 
   505 end; (*struct*)