Haskell uses generic flat_program combinator
authorhaftmann
Tue, 07 Sep 2010 16:05:18 +0200
changeset 394343d30f501b7c2
parent 39433 b2f9a6f4b84b
child 39435 13c6e91efcb6
Haskell uses generic flat_program combinator
src/Tools/Code/code_haskell.ML
     1.1 --- a/src/Tools/Code/code_haskell.ML	Tue Sep 07 11:08:58 2010 +0200
     1.2 +++ b/src/Tools/Code/code_haskell.ML	Tue Sep 07 16:05:18 2010 +0200
     1.3 @@ -261,7 +261,7 @@
     1.4            end;
     1.5    in print_stmt end;
     1.6  
     1.7 -type flat_program = ((string * Code_Thingol.stmt) Graph.T * ((string * (string list * string list)) list)) Graph.T;
     1.8 +type flat_program = ((string * Code_Thingol.stmt option) Graph.T * string list) Graph.T;
     1.9  
    1.10  fun flat_program labelled_name { module_alias, module_prefix, reserved,
    1.11        empty_nsp, namify_stmt, modify_stmt } program =
    1.12 @@ -277,11 +277,9 @@
    1.13      fun add_stmt name stmt =
    1.14        let
    1.15          val (module_name, base) = dest_name name;
    1.16 -      in case modify_stmt stmt
    1.17 -       of SOME stmt' => 
    1.18 -            Graph.default_node (module_name, (Graph.empty, []))
    1.19 -            #> (Graph.map_node module_name o apfst) (Graph.new_node (name, (base, stmt')))
    1.20 -        | NONE => I
    1.21 +      in
    1.22 +        Graph.default_node (module_name, (Graph.empty, []))
    1.23 +        #> (Graph.map_node module_name o apfst) (Graph.new_node (name, (base, stmt)))
    1.24        end;
    1.25      fun add_dependency name name' =
    1.26        let
    1.27 @@ -289,14 +287,13 @@
    1.28          val (module_name', base') = dest_name name';
    1.29        in if module_name = module_name'
    1.30          then (Graph.map_node module_name o apfst) (Graph.add_edge (name, name'))
    1.31 -        else (Graph.map_node module_name o apsnd)
    1.32 -          (AList.map_default (op =) (module_name', []) (insert (op =) name'))
    1.33 +        else (Graph.map_node module_name o apsnd) (AList.map_default (op =) (module_name', []) (insert (op =) name'))
    1.34        end;
    1.35      val proto_program = Graph.empty
    1.36        |> Graph.fold (fn (name, (stmt, _)) => add_stmt name stmt) program
    1.37        |> Graph.fold (fn (name, (_, (_, names))) => fold (add_dependency name) names) program;
    1.38  
    1.39 -    (* name declarations *)
    1.40 +    (* name declarations and statement modifications *)
    1.41      fun declare name (base, stmt) (gr, nsp) = 
    1.42        let
    1.43          val (base', nsp') = namify_stmt stmt base nsp;
    1.44 @@ -304,45 +301,36 @@
    1.45        in (gr', nsp') end;
    1.46      fun declarations gr = (gr, empty_nsp)
    1.47        |> fold (fn name => declare name (Graph.get_node gr name)) (Graph.keys gr) 
    1.48 -      |> fst;
    1.49 -    val intermediate_program = proto_program
    1.50 -      |> Graph.map ((K o apfst) declarations);
    1.51 +      |> fst
    1.52 +      |> (Graph.map o K o apsnd) modify_stmt;
    1.53 +    val flat_program = proto_program
    1.54 +      |> (Graph.map o K o apfst) declarations;
    1.55  
    1.56      (* qualified and unqualified imports, deresolving *)
    1.57      fun base_deresolver name = fst (Graph.get_node
    1.58 -      (fst (Graph.get_node intermediate_program (fst (dest_name name)))) name);
    1.59 -    fun classify_imports gr imports =
    1.60 +      (fst (Graph.get_node flat_program (fst (dest_name name)))) name);
    1.61 +    fun classify_names gr imports =
    1.62        let
    1.63          val import_tab = maps
    1.64            (fn (module_name, names) => map (rpair module_name) names) imports;
    1.65          val imported_names = map fst import_tab;
    1.66          val here_names = Graph.keys gr;
    1.67 -        val qualified_names = []
    1.68 -          |> fold (fn name => AList.map_default (op =) (base_deresolver name, [])
    1.69 -               (insert (op =) name)) (here_names @ imported_names)
    1.70 -          |> filter (fn (_, names) => length names > 1)
    1.71 -          |> maps snd;
    1.72 -        val name_tab = Symtab.empty
    1.73 -          |> fold (fn name => Symtab.update (name, base_deresolver name)) here_names
    1.74 -          |> fold (fn name => Symtab.update (name,
    1.75 -               if member (op =) qualified_names name
    1.76 -               then Long_Name.append (the (AList.lookup (op =) import_tab name))
    1.77 -                 (base_deresolver name)
    1.78 -               else base_deresolver name)) imported_names;
    1.79 -        val imports' = (map o apsnd) (List.partition (member (op =) qualified_names))
    1.80 -          imports;
    1.81 -      in (name_tab, imports') end;
    1.82 -    val classified = AList.make (uncurry classify_imports o Graph.get_node intermediate_program)
    1.83 -      (Graph.keys intermediate_program);
    1.84 -    val flat_program = Graph.map (apsnd o K o snd o the o AList.lookup (op =) classified)
    1.85 -      intermediate_program;
    1.86 +      in
    1.87 +        Symtab.empty
    1.88 +        |> fold (fn name => Symtab.update (name, base_deresolver name)) here_names
    1.89 +        |> fold (fn name => Symtab.update (name,
    1.90 +            Long_Name.append (the (AList.lookup (op =) import_tab name))
    1.91 +              (base_deresolver name))) imported_names
    1.92 +      end;
    1.93 +    val name_tabs = AList.make (uncurry classify_names o Graph.get_node flat_program)
    1.94 +      (Graph.keys flat_program);
    1.95      val deresolver_tab = Symtab.empty
    1.96 -      |> fold (fn (module_name, (name_tab, _)) => Symtab.update (module_name, name_tab)) classified;
    1.97 +      |> fold (fn (module_name, name_tab) => Symtab.update (module_name, name_tab)) name_tabs;
    1.98      fun deresolver module_name name =
    1.99        the (Symtab.lookup (the (Symtab.lookup deresolver_tab module_name)) name)
   1.100        handle Option => error ("Unknown statement name: " ^ labelled_name name);
   1.101  
   1.102 -  in (deresolver, flat_program) end;
   1.103 +  in { deresolver = deresolver, flat_program = flat_program } end;
   1.104  
   1.105  fun haskell_program_of_program labelled_name module_alias module_prefix reserved =
   1.106    let
   1.107 @@ -379,70 +367,16 @@
   1.108          modify_stmt = fn stmt => if select_stmt stmt then SOME stmt else NONE }
   1.109    end;
   1.110  
   1.111 -fun mk_name_module reserved module_prefix module_alias program =
   1.112 -  let
   1.113 -    val fragments_tab = Code_Namespace.build_module_namespace { module_alias = module_alias,
   1.114 -      module_prefix = module_prefix, reserved = reserved } program;
   1.115 -  in Long_Name.implode o the o Symtab.lookup fragments_tab end;
   1.116 -
   1.117 -fun haskell_program_of_program labelled_name module_prefix reserved module_alias program =
   1.118 -  let
   1.119 -    val reserved = Name.make_context reserved;
   1.120 -    val mk_name_module = mk_name_module reserved module_prefix module_alias program;
   1.121 -    fun add_stmt (name, (stmt, deps)) =
   1.122 -      let
   1.123 -        val (module_name, base) = Code_Namespace.dest_name name;
   1.124 -        val module_name' = mk_name_module module_name;
   1.125 -        val mk_name_stmt = yield_singleton Name.variants;
   1.126 -        fun add_fun upper (nsp_fun, nsp_typ) =
   1.127 -          let
   1.128 -            val (base', nsp_fun') =
   1.129 -              mk_name_stmt (if upper then first_upper base else base) nsp_fun
   1.130 -          in (base', (nsp_fun', nsp_typ)) end;
   1.131 -        fun add_typ (nsp_fun, nsp_typ) =
   1.132 -          let
   1.133 -            val (base', nsp_typ') = mk_name_stmt (first_upper base) nsp_typ
   1.134 -          in (base', (nsp_fun, nsp_typ')) end;
   1.135 -        val add_name = case stmt
   1.136 -         of Code_Thingol.Fun (_, (_, SOME _)) => pair base
   1.137 -          | Code_Thingol.Fun _ => add_fun false
   1.138 -          | Code_Thingol.Datatype _ => add_typ
   1.139 -          | Code_Thingol.Datatypecons _ => add_fun true
   1.140 -          | Code_Thingol.Class _ => add_typ
   1.141 -          | Code_Thingol.Classrel _ => pair base
   1.142 -          | Code_Thingol.Classparam _ => add_fun false
   1.143 -          | Code_Thingol.Classinst _ => pair base;
   1.144 -        fun add_stmt' base' = case stmt
   1.145 -         of Code_Thingol.Fun (_, (_, SOME _)) =>
   1.146 -              I
   1.147 -          | Code_Thingol.Datatypecons _ =>
   1.148 -              cons (name, (Long_Name.append module_name' base', NONE))
   1.149 -          | Code_Thingol.Classrel _ => I
   1.150 -          | Code_Thingol.Classparam _ =>
   1.151 -              cons (name, (Long_Name.append module_name' base', NONE))
   1.152 -          | _ => cons (name, (Long_Name.append module_name' base', SOME stmt));
   1.153 -      in
   1.154 -        Symtab.map_default (module_name', ([], ([], (reserved, reserved))))
   1.155 -              (apfst (fold (insert (op = : string * string -> bool)) deps))
   1.156 -        #> `(fn program => add_name ((snd o snd o the o Symtab.lookup program) module_name'))
   1.157 -        #-> (fn (base', names) =>
   1.158 -              (Symtab.map_entry module_name' o apsnd) (fn (stmts, _) =>
   1.159 -              (add_stmt' base' stmts, names)))
   1.160 -      end;
   1.161 -    val hs_program = fold add_stmt (AList.make (fn name =>
   1.162 -      (Graph.get_node program name, Graph.imm_succs program name))
   1.163 -      (Graph.strong_conn program |> flat)) Symtab.empty;
   1.164 -    fun deresolver name = (fst o the o AList.lookup (op =) ((fst o snd o the
   1.165 -      o Symtab.lookup hs_program) ((mk_name_module o fst o Code_Namespace.dest_name) name))) name
   1.166 -      handle Option => error ("Unknown statement name: " ^ labelled_name name);
   1.167 -  in (deresolver, hs_program) end;
   1.168 -
   1.169  fun serialize_haskell module_prefix string_classes { labelled_name, reserved_syms,
   1.170      includes, module_alias, class_syntax, tyco_syntax, const_syntax, program } =
   1.171    let
   1.172 +
   1.173 +    (* build program *)
   1.174      val reserved = fold (insert (op =) o fst) includes reserved_syms;
   1.175 -    val (deresolver, hs_program) = haskell_program_of_program labelled_name
   1.176 -      module_prefix reserved module_alias program;
   1.177 +    val { deresolver, flat_program = haskell_program } = haskell_program_of_program
   1.178 +      labelled_name module_alias module_prefix (Name.make_context reserved) program;
   1.179 +
   1.180 +    (* print statements *)
   1.181      val contr_classparam_typs = Code_Thingol.contr_classparam_typs program;
   1.182      fun deriving_show tyco =
   1.183        let
   1.184 @@ -457,58 +391,52 @@
   1.185                andalso forall (deriv' tycos) tys
   1.186            | deriv' _ (ITyVar _) = true
   1.187        in deriv [] tyco end;
   1.188 -    val reserved = make_vars reserved;
   1.189 -    fun print_stmt qualified = print_haskell_stmt labelled_name
   1.190 -      class_syntax tyco_syntax const_syntax reserved
   1.191 -      (if qualified then deresolver else Long_Name.base_name o deresolver)
   1.192 -      contr_classparam_typs
   1.193 +    fun print_stmt deresolve = print_haskell_stmt labelled_name
   1.194 +      class_syntax tyco_syntax const_syntax (make_vars reserved)
   1.195 +      deresolve contr_classparam_typs
   1.196        (if string_classes then deriving_show else K false);
   1.197 -    fun print_module name content =
   1.198 -      (name, Pretty.chunks2 [
   1.199 -        str ("module " ^ name ^ " where {"),
   1.200 -        content,
   1.201 -        str "}"
   1.202 -      ]);
   1.203 -    fun serialize_module (module_name', (deps, (stmts, _))) =
   1.204 +
   1.205 +    (* print modules *)
   1.206 +    val import_includes_ps =
   1.207 +      map (fn (name, _) => str ("import qualified " ^ name ^ ";")) includes;
   1.208 +    fun print_module_frame module_name ps =
   1.209 +      (module_name, Pretty.chunks2 (
   1.210 +        str "{-# OPTIONS_GHC -fglasgow-exts #-}"
   1.211 +        :: str ("module " ^ module_name ^ " where {")
   1.212 +        :: ps
   1.213 +        @| str "}"
   1.214 +      ));
   1.215 +    fun print_module module_name (gr, imports) =
   1.216        let
   1.217 -        val stmt_names = map fst stmts;
   1.218 -        val qualified = true;
   1.219 -        val imports = subtract (op =) stmt_names deps
   1.220 -          |> distinct (op =)
   1.221 -          |> map_filter (try deresolver)
   1.222 -          |> map Long_Name.qualifier
   1.223 -          |> distinct (op =);
   1.224 -        fun print_import_include (name, _) = str ("import qualified " ^ name ^ ";");
   1.225 -        fun print_import_module name = str ((if qualified
   1.226 -          then "import qualified "
   1.227 -          else "import ") ^ name ^ ";");
   1.228 -        val import_ps = map print_import_include includes @ map print_import_module imports
   1.229 -        val content = Pretty.chunks2 ((if null import_ps then [] else [Pretty.chunks import_ps])
   1.230 -            @ map_filter
   1.231 -              (fn (name, (_, SOME stmt)) => SOME (markup_stmt name (print_stmt qualified (name, stmt)))
   1.232 -                | (_, (_, NONE)) => NONE) stmts
   1.233 -          );
   1.234 -      in print_module module_name' content end;
   1.235 -    fun write_module width (SOME destination) (modlname, content) =
   1.236 +        val deresolve = deresolver module_name
   1.237 +        fun print_import module_name = (semicolon o map str) ["import qualified", module_name];
   1.238 +        val import_ps = import_includes_ps @ map (print_import o fst) imports;
   1.239 +        fun print_stmt' gr name = case Graph.get_node gr name
   1.240 +         of (_, NONE) => NONE
   1.241 +          | (_, SOME stmt) => SOME (markup_stmt name (print_stmt deresolve (name, stmt)));
   1.242 +        val body_ps = map_filter (print_stmt' gr) ((flat o rev o Graph.strong_conn) gr);
   1.243 +      in
   1.244 +        print_module_frame module_name
   1.245 +          ((if null import_ps then [] else [Pretty.chunks import_ps]) @ body_ps)
   1.246 +      end;
   1.247 +
   1.248 +    (*serialization*)
   1.249 +    fun write_module width (SOME destination) (module_name, content) =
   1.250            let
   1.251              val _ = File.check destination;
   1.252 -            val filename = case modlname
   1.253 -             of "" => Path.explode "Main.hs"
   1.254 -              | _ => (Path.ext "hs" o Path.explode o implode o separate "/"
   1.255 -                    o Long_Name.explode) modlname;
   1.256 -            val pathname = Path.append destination filename;
   1.257 -            val _ = File.mkdir_leaf (Path.dir pathname);
   1.258 -          in File.write pathname
   1.259 -            ("{-# OPTIONS_GHC -fglasgow-exts #-}\n\n"
   1.260 -              ^ format [] width content)
   1.261 -          end
   1.262 +            val filepath = (Path.append destination o Path.ext "hs" o Path.explode o implode
   1.263 +              o separate "/" o Long_Name.explode) module_name;
   1.264 +            val _ = File.mkdir_leaf (Path.dir filepath);
   1.265 +          in File.write filepath (format [] width content) end
   1.266        | write_module width NONE (_, content) = writeln (format [] width content);
   1.267    in
   1.268      Code_Target.serialization
   1.269        (fn width => fn destination => K () o map (write_module width destination))
   1.270 -      (fn present => fn width => rpair (fn _ => error "no deresolving") o format present width o Pretty.chunks o map snd)
   1.271 -      (map (uncurry print_module) includes
   1.272 -        @ map serialize_module (Symtab.dest hs_program))
   1.273 +      (fn present => fn width => rpair (fn _ => error "no deresolving")
   1.274 +        o format present width o Pretty.chunks o map snd)
   1.275 +      (map (uncurry print_module_frame o apsnd single) includes
   1.276 +        @ map (fn module_name => print_module module_name (Graph.get_node haskell_program module_name))
   1.277 +          ((flat o rev o Graph.strong_conn) haskell_program))
   1.278    end;
   1.279  
   1.280  val serializer : Code_Target.serializer =