[WebAssembly] Add new `export_name` clang attribute for controlling wasm export names

This is equivalent to the existing `import_name` and `import_module`
attributes which control the import names in the final wasm binary
produced by lld.

This maps the existing

This attribute currently requires a string rather than using the
symbol name for a couple of reasons:

1. Avoid confusion with static and dynamic linking which is
   based on symbol name.  Exporting a function from a wasm module using
   this directive is orthogonal to both static and dynamic linking.
2. Avoids name mangling.

Differential Revision: https://reviews.llvm.org/D70520

GitOrigin-RevId: 881d877846e2904c731d616731969421ce8cc825
diff --git a/include/llvm/BinaryFormat/Wasm.h b/include/llvm/BinaryFormat/Wasm.h
index f550d88..59f99cc 100644
--- a/include/llvm/BinaryFormat/Wasm.h
+++ b/include/llvm/BinaryFormat/Wasm.h
@@ -131,6 +131,7 @@
   uint32_t CodeSectionOffset;
   uint32_t Size;
   uint32_t CodeOffset;  // start of Locals and Body
+  StringRef ExportName; // from the "export" section
   StringRef SymbolName; // from the "linking" section
   StringRef DebugName;  // from the "name" section
   uint32_t Comdat;      // from the "comdat info" section
@@ -179,6 +180,7 @@
   uint32_t Flags;
   StringRef ImportModule; // For undefined symbols the module of the import
   StringRef ImportName;   // For undefined symbols the name of the import
+  StringRef ExportName;   // For symbols to be exported from the final module
   union {
     // For function or global symbols, the index in function or global index
     // space.
diff --git a/include/llvm/MC/MCSymbolWasm.h b/include/llvm/MC/MCSymbolWasm.h
index c60d406..ba2068a 100644
--- a/include/llvm/MC/MCSymbolWasm.h
+++ b/include/llvm/MC/MCSymbolWasm.h
@@ -21,6 +21,7 @@
   mutable bool IsUsedInGOT = false;
   Optional<std::string> ImportModule;
   Optional<std::string> ImportName;
+  Optional<std::string> ExportName;
   wasm::WasmSignature *Signature = nullptr;
   Optional<wasm::WasmGlobalType> GlobalType;
   Optional<wasm::WasmEventType> EventType;
@@ -87,6 +88,10 @@
   }
   void setImportName(StringRef Name) { ImportName = Name; }
 
+  bool hasExportName() const { return ExportName.hasValue(); }
+  const StringRef getExportName() const { return ExportName.getValue(); }
+  void setExportName(StringRef Name) { ExportName = Name; }
+
   void setUsedInGOT() const { IsUsedInGOT = true; }
   bool isUsedInGOT() const { return IsUsedInGOT; }
 
diff --git a/include/llvm/Object/Wasm.h b/include/llvm/Object/Wasm.h
index e130ea3..8af94c4 100644
--- a/include/llvm/Object/Wasm.h
+++ b/include/llvm/Object/Wasm.h
@@ -280,6 +280,7 @@
   uint32_t StartFunction = -1;
   bool HasLinkingSection = false;
   bool HasDylinkSection = false;
+  bool SeenCodeSection = false;
   wasm::WasmLinkingData LinkingData;
   uint32_t NumImportedGlobals = 0;
   uint32_t NumImportedFunctions = 0;
diff --git a/lib/MC/WasmObjectWriter.cpp b/lib/MC/WasmObjectWriter.cpp
index b22a393..321f93d 100644
--- a/lib/MC/WasmObjectWriter.cpp
+++ b/lib/MC/WasmObjectWriter.cpp
@@ -1324,6 +1324,14 @@
           Comdats[C->getName()].emplace_back(
               WasmComdatEntry{wasm::WASM_COMDAT_FUNCTION, Index});
         }
+
+        if (WS.hasExportName()) {
+          wasm::WasmExport Export;
+          Export.Name = WS.getExportName();
+          Export.Kind = wasm::WASM_EXTERNAL_FUNCTION;
+          Export.Index = Index;
+          Exports.push_back(Export);
+        }
       } else {
         // An import; the index was assigned above.
         Index = WasmIndices.find(&WS)->second;
@@ -1454,6 +1462,8 @@
     }
     if (WS.hasImportName())
       Flags |= wasm::WASM_SYMBOL_EXPLICIT_NAME;
+    if (WS.hasExportName())
+      Flags |= wasm::WASM_SYMBOL_EXPORTED;
 
     wasm::WasmSymbolInfo Info;
     Info.Name = WS.getName();
diff --git a/lib/Object/WasmObjectFile.cpp b/lib/Object/WasmObjectFile.cpp
index 014b403..ab8918c 100644
--- a/lib/Object/WasmObjectFile.cpp
+++ b/lib/Object/WasmObjectFile.cpp
@@ -343,7 +343,7 @@
 
 Error WasmObjectFile::parseNameSection(ReadContext &Ctx) {
   llvm::DenseSet<uint64_t> Seen;
-  if (Functions.size() != FunctionTypes.size()) {
+  if (FunctionTypes.size() && !SeenCodeSection) {
     return make_error<GenericBinaryError>("Names must come after code section",
                                           object_error::parse_failed);
   }
@@ -389,7 +389,7 @@
 
 Error WasmObjectFile::parseLinkingSection(ReadContext &Ctx) {
   HasLinkingSection = true;
-  if (Functions.size() != FunctionTypes.size()) {
+  if (FunctionTypes.size() && !SeenCodeSection) {
     return make_error<GenericBinaryError>(
         "Linking data must come after code section",
         object_error::parse_failed);
@@ -940,6 +940,7 @@
 Error WasmObjectFile::parseFunctionSection(ReadContext &Ctx) {
   uint32_t Count = readVaruint32(Ctx);
   FunctionTypes.reserve(Count);
+  Functions.resize(Count);
   uint32_t NumTypes = Signatures.size();
   while (Count--) {
     uint32_t Type = readVaruint32(Ctx);
@@ -1029,9 +1030,11 @@
     Ex.Index = readVaruint32(Ctx);
     switch (Ex.Kind) {
     case wasm::WASM_EXTERNAL_FUNCTION:
-      if (!isValidFunctionIndex(Ex.Index))
+
+      if (!isDefinedFunctionIndex(Ex.Index))
         return make_error<GenericBinaryError>("Invalid function export",
                                               object_error::parse_failed);
+      getDefinedFunction(Ex.Index).ExportName = Ex.Name;
       break;
     case wasm::WASM_EXTERNAL_GLOBAL:
       if (!isValidGlobalIndex(Ex.Index))
@@ -1132,6 +1135,7 @@
 }
 
 Error WasmObjectFile::parseCodeSection(ReadContext &Ctx) {
+  SeenCodeSection = true;
   CodeSection = Sections.size();
   uint32_t FunctionCount = readVaruint32(Ctx);
   if (FunctionCount != FunctionTypes.size()) {
@@ -1139,14 +1143,14 @@
                                           object_error::parse_failed);
   }
 
-  while (FunctionCount--) {
-    wasm::WasmFunction Function;
+  for (uint32_t i = 0; i < FunctionCount; i++) {
+    wasm::WasmFunction& Function = Functions[i];
     const uint8_t *FunctionStart = Ctx.Ptr;
     uint32_t Size = readVaruint32(Ctx);
     const uint8_t *FunctionEnd = Ctx.Ptr + Size;
 
     Function.CodeOffset = Ctx.Ptr - FunctionStart;
-    Function.Index = NumImportedFunctions + Functions.size();
+    Function.Index = NumImportedFunctions + i;
     Function.CodeSectionOffset = FunctionStart - Ctx.Start;
     Function.Size = FunctionEnd - FunctionStart;
 
@@ -1165,7 +1169,6 @@
     Function.Comdat = UINT32_MAX;
     Ctx.Ptr += BodySize;
     assert(Ctx.Ptr == FunctionEnd);
-    Functions.push_back(Function);
   }
   if (Ctx.Ptr != Ctx.End)
     return make_error<GenericBinaryError>("Code section ended prematurely",
diff --git a/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp b/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
index 138ce85..1f0bdde 100644
--- a/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
+++ b/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
@@ -712,6 +712,18 @@
       return expect(AsmToken::EndOfStatement, "EOL");
     }
 
+    if (DirectiveID.getString() == ".export_name") {
+      auto SymName = expectIdent();
+      if (SymName.empty())
+        return true;
+      if (expect(AsmToken::Comma, ","))
+        return true;
+      auto ExportName = expectIdent();
+      auto WasmSym = cast<MCSymbolWasm>(Ctx.getOrCreateSymbol(SymName));
+      WasmSym->setExportName(ExportName);
+      TOut.emitExportName(WasmSym, ExportName);
+    }
+
     if (DirectiveID.getString() == ".import_module") {
       auto SymName = expectIdent();
       if (SymName.empty())
diff --git a/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.cpp b/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.cpp
index 4092620..7c21ed5 100644
--- a/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.cpp
+++ b/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.cpp
@@ -94,6 +94,12 @@
                            << ImportName << '\n';
 }
 
+void WebAssemblyTargetAsmStreamer::emitExportName(const MCSymbolWasm *Sym,
+                                                  StringRef ExportName) {
+  OS << "\t.export_name\t" << Sym->getName() << ", "
+                           << ExportName << '\n';
+}
+
 void WebAssemblyTargetAsmStreamer::emitIndIdx(const MCExpr *Value) {
   OS << "\t.indidx  \t" << *Value << '\n';
 }
diff --git a/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.h b/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.h
index 0164f8e..9aee1a0 100644
--- a/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.h
+++ b/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyTargetStreamer.h
@@ -48,6 +48,9 @@
   /// .import_name
   virtual void emitImportName(const MCSymbolWasm *Sym,
                               StringRef ImportName) = 0;
+  /// .export_name
+  virtual void emitExportName(const MCSymbolWasm *Sym,
+                              StringRef ExportName) = 0;
 
 protected:
   void emitValueType(wasm::ValType Type);
@@ -68,6 +71,7 @@
   void emitEventType(const MCSymbolWasm *Sym) override;
   void emitImportModule(const MCSymbolWasm *Sym, StringRef ImportModule) override;
   void emitImportName(const MCSymbolWasm *Sym, StringRef ImportName) override;
+  void emitExportName(const MCSymbolWasm *Sym, StringRef ExportName) override;
 };
 
 /// This part is for Wasm object output
@@ -85,6 +89,8 @@
                         StringRef ImportModule) override {}
   void emitImportName(const MCSymbolWasm *Sym,
                       StringRef ImportName) override {}
+  void emitExportName(const MCSymbolWasm *Sym,
+                      StringRef ExportName) override {}
 };
 
 /// This part is for null output
@@ -101,6 +107,7 @@
   void emitEventType(const MCSymbolWasm *) override {}
   void emitImportModule(const MCSymbolWasm *, StringRef) override {}
   void emitImportName(const MCSymbolWasm *, StringRef) override {}
+  void emitExportName(const MCSymbolWasm *, StringRef) override {}
 };
 
 } // end namespace llvm
diff --git a/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp b/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
index 5d8b873..cb95d5d 100644
--- a/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
+++ b/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
@@ -96,8 +96,11 @@
   }
 
   for (const auto &F : M) {
+    if (F.isIntrinsic())
+      continue;
+
     // Emit function type info for all undefined functions
-    if (F.isDeclarationForLinker() && !F.isIntrinsic()) {
+    if (F.isDeclarationForLinker()) {
       SmallVector<MVT, 4> Results;
       SmallVector<MVT, 4> Params;
       computeSignatureVTs(F.getFunctionType(), F, TM, Params, Results);
@@ -130,6 +133,13 @@
         getTargetStreamer()->emitImportName(Sym, Name);
       }
     }
+
+    if (F.hasFnAttribute("wasm-export-name")) {
+      auto *Sym = cast<MCSymbolWasm>(getSymbol(&F));
+      StringRef Name = F.getFnAttribute("wasm-export-name").getValueAsString();
+      Sym->setExportName(Name);
+      getTargetStreamer()->emitExportName(Sym, Name);
+    }
   }
 
   for (const auto &G : M.globals()) {
diff --git a/test/CodeGen/WebAssembly/export-name.ll b/test/CodeGen/WebAssembly/export-name.ll
new file mode 100644
index 0000000..d1d4c21
--- /dev/null
+++ b/test/CodeGen/WebAssembly/export-name.ll
@@ -0,0 +1,17 @@
+; RUN: llc < %s -asm-verbose=false -wasm-keep-registers | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128"
+target triple = "wasm32-unknown-unknown"
+
+define void @test() #0 {
+  ret void
+}
+
+declare void @test2() #1
+
+
+attributes #0 = { "wasm-export-name"="foo" }
+attributes #1 = { "wasm-export-name"="bar" }
+
+; CHECK: .export_name test, foo
+; CHECK: .export_name test2, bar
diff --git a/test/MC/WebAssembly/export-name.s b/test/MC/WebAssembly/export-name.s
new file mode 100644
index 0000000..51e1bcf
--- /dev/null
+++ b/test/MC/WebAssembly/export-name.s
@@ -0,0 +1,26 @@
+# RUN: llvm-mc -triple=wasm32-unknown-unknown < %s | FileCheck %s
+# Check that it also comiled to object for format.
+# RUN: llvm-mc -triple=wasm32-unknown-unknown -filetype=obj -o - < %s | obj2yaml | FileCheck -check-prefix=CHECK-OBJ %s
+
+foo:
+    .globl foo
+    .functype foo () -> ()
+    .export_name foo, bar
+    end_function
+
+# CHECK: .export_name foo, bar
+
+# CHECK-OBJ:        - Type:            EXPORT
+# CHECK-OBJ-NEXT:     Exports:
+# CHECK-OBJ-NEXT:       - Name:            bar
+# CHECK-OBJ-NEXT:         Kind:            FUNCTION
+# CHECK-OBJ-NEXT:         Index:           0
+
+# CHECK-OBJ:          Name:            linking
+# CHECK-OBJ-NEXT:     Version:         2
+# CHECK-OBJ-NEXT:     SymbolTable:
+# CHECK-OBJ-NEXT:       - Index:           0
+# CHECK-OBJ-NEXT:         Kind:            FUNCTION
+# CHECK-OBJ-NEXT:         Name:            foo
+# CHECK-OBJ-NEXT:         Flags:           [ EXPORTED ]
+# CHECK-OBJ-NEXT:         Function:        0