[llvm-objcopy] [COFF] Add support for removing symbols

Differential Revision: https://reviews.llvm.org/D55881

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@350893 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/test/tools/llvm-objcopy/COFF/Inputs/strip-symbols.yaml b/test/tools/llvm-objcopy/COFF/Inputs/strip-symbols.yaml
new file mode 100644
index 0000000..1d652691
--- /dev/null
+++ b/test/tools/llvm-objcopy/COFF/Inputs/strip-symbols.yaml
@@ -0,0 +1,53 @@
+--- !COFF
+header:          
+  Machine:         IMAGE_FILE_MACHINE_AMD64
+  Characteristics: [  ]
+sections:        
+  - Name:            .text
+    Characteristics: [  ]
+    Alignment:       4
+    SectionData:     488B0500000000488B0500000000488B0500000000
+    Relocations:     
+      - VirtualAddress:  3
+        SymbolTableIndex: 0
+        Type:            IMAGE_REL_AMD64_REL32
+      - VirtualAddress:  10
+        SymbolTableIndex: 1
+        Type:            IMAGE_REL_AMD64_REL32
+      - VirtualAddress:  17
+        SymbolName:      foo
+        Type:            IMAGE_REL_AMD64_REL32
+  - Name:            .rdata
+    Characteristics: [  ]
+    Alignment:       1
+    SectionData:     '00'
+  - Name:            .rdata
+    Characteristics: [  ]
+    Alignment:       1
+    SectionData:     '01'
+symbols:         
+  - Name:            .rdata
+    Value:           0
+    SectionNumber:   2
+    SimpleType:      IMAGE_SYM_TYPE_NULL
+    ComplexType:     IMAGE_SYM_DTYPE_NULL
+    StorageClass:    IMAGE_SYM_CLASS_STATIC
+  - Name:            .rdata
+    Value:           0
+    SectionNumber:   3
+    SimpleType:      IMAGE_SYM_TYPE_NULL
+    ComplexType:     IMAGE_SYM_DTYPE_NULL
+    StorageClass:    IMAGE_SYM_CLASS_STATIC
+  - Name:            mainfunc
+    Value:           0
+    SectionNumber:   1
+    SimpleType:      IMAGE_SYM_TYPE_NULL
+    ComplexType:     IMAGE_SYM_DTYPE_NULL
+    StorageClass:    IMAGE_SYM_CLASS_EXTERNAL
+  - Name:            foo
+    Value:           0
+    SectionNumber:   3
+    SimpleType:      IMAGE_SYM_TYPE_NULL
+    ComplexType:     IMAGE_SYM_DTYPE_NULL
+    StorageClass:    IMAGE_SYM_CLASS_EXTERNAL
+...
diff --git a/test/tools/llvm-objcopy/COFF/strip-reloc-symbol.test b/test/tools/llvm-objcopy/COFF/strip-reloc-symbol.test
new file mode 100644
index 0000000..50740d9
--- /dev/null
+++ b/test/tools/llvm-objcopy/COFF/strip-reloc-symbol.test
@@ -0,0 +1,5 @@
+# RUN: yaml2obj %p/Inputs/strip-symbols.yaml > %t.o
+# RUN: not llvm-objcopy -N foo %t.o 2>&1 | FileCheck %s --check-prefix=ERROR
+# RUN: not llvm-objcopy --strip-symbol foo %t.o 2>&1 | FileCheck %s --check-prefix=ERROR
+
+# ERROR: error: '{{.*}}/strip-reloc-symbol.test.tmp.o': not stripping symbol 'foo' because it is named in a relocation.
diff --git a/test/tools/llvm-objcopy/COFF/strip-symbol.test b/test/tools/llvm-objcopy/COFF/strip-symbol.test
new file mode 100644
index 0000000..5355b22
--- /dev/null
+++ b/test/tools/llvm-objcopy/COFF/strip-symbol.test
@@ -0,0 +1,32 @@
+# RUN: yaml2obj %p/Inputs/strip-symbols.yaml > %t.in.o
+
+# RUN: llvm-readobj -relocations %t.in.o | FileCheck %s --check-prefixes=RELOCS,RELOCS-PRE
+# RUN: llvm-objdump -t %t.in.o | FileCheck %s --check-prefixes=SYMBOLS,SYMBOLS-PRE
+
+# RUN: llvm-objcopy -N mainfunc %t.in.o %t.out.o
+# RUN: llvm-readobj -relocations %t.out.o | FileCheck %s --check-prefixes=RELOCS,RELOCS-POST
+# RUN: llvm-objdump -t %t.out.o | FileCheck %s --check-prefix=SYMBOLS
+
+# RUN: llvm-objcopy --strip-symbol mainfunc %t.in.o %t.out.o
+# RUN: llvm-readobj -relocations %t.out.o | FileCheck %s --check-prefixes=RELOCS,RELOCS-POST
+# RUN: llvm-objdump -t %t.out.o | FileCheck %s --check-prefix=SYMBOLS
+
+# Explicitly listing the relocations for the input as well, to show
+# that the symbol index of the symbol foo is updated in the relocations,
+# while keeping relocations to two distinct .rdata symbols separate.
+
+# RELOCS:      Relocations [
+# RELOCS-NEXT:   Section (1) .text {
+# RELOCS-NEXT:     0x3 IMAGE_REL_AMD64_REL32 .rdata (0)
+# RELOCS-NEXT:     0xA IMAGE_REL_AMD64_REL32 .rdata (1)
+# RELOCS-PRE-NEXT:  0x11 IMAGE_REL_AMD64_REL32 foo (3)
+# RELOCS-POST-NEXT: 0x11 IMAGE_REL_AMD64_REL32 foo (2)
+# RELOCS-NEXT:   }
+# RELOCS-NEXT: ]
+
+# SYMBOLS: SYMBOL TABLE:
+# SYMBOLS-NEXT: .rdata
+# SYMBOLS-NEXT: .rdata
+# SYMBOLS-PRE-NEXT: mainfunc
+# SYMBOLS-NEXT: foo
+# SYMBOLS-EMPTY:
diff --git a/tools/llvm-objcopy/CMakeLists.txt b/tools/llvm-objcopy/CMakeLists.txt
index 3b6c345..1beb737 100644
--- a/tools/llvm-objcopy/CMakeLists.txt
+++ b/tools/llvm-objcopy/CMakeLists.txt
@@ -18,6 +18,7 @@
   CopyConfig.cpp
   llvm-objcopy.cpp
   COFF/COFFObjcopy.cpp
+  COFF/Object.cpp
   COFF/Reader.cpp
   COFF/Writer.cpp
   ELF/ELFObjcopy.cpp
diff --git a/tools/llvm-objcopy/COFF/COFFObjcopy.cpp b/tools/llvm-objcopy/COFF/COFFObjcopy.cpp
index 9ed965c..9087cf6 100644
--- a/tools/llvm-objcopy/COFF/COFFObjcopy.cpp
+++ b/tools/llvm-objcopy/COFF/COFFObjcopy.cpp
@@ -17,6 +17,7 @@
 
 #include "llvm/Object/Binary.h"
 #include "llvm/Object/COFF.h"
+#include "llvm/Support/Errc.h"
 #include <cassert>
 
 namespace llvm {
@@ -26,6 +27,30 @@
 using namespace object;
 using namespace COFF;
 
+static Error handleArgs(const CopyConfig &Config, Object &Obj) {
+  // If we need to do per-symbol removals, initialize the Referenced field.
+  if (!Config.SymbolsToRemove.empty())
+    if (Error E = Obj.markSymbols())
+      return E;
+
+  // Actually do removals of symbols.
+  Obj.removeSymbols([&](const Symbol &Sym) {
+    if (is_contained(Config.SymbolsToRemove, Sym.Name)) {
+      // Explicitly removing a referenced symbol is an error.
+      if (Sym.Referenced)
+        reportError(Config.OutputFilename,
+                    make_error<StringError>(
+                        "not stripping symbol '" + Sym.Name +
+                            "' because it is named in a relocation.",
+                        llvm::errc::invalid_argument));
+      return true;
+    }
+
+    return false;
+  });
+  return Error::success();
+}
+
 void executeObjcopyOnBinary(const CopyConfig &Config,
                             object::COFFObjectFile &In, Buffer &Out) {
   COFFReader Reader(In);
@@ -34,6 +59,8 @@
     reportError(Config.InputFilename, ObjOrErr.takeError());
   Object *Obj = ObjOrErr->get();
   assert(Obj && "Unable to deserialize COFF object");
+  if (Error E = handleArgs(Config, *Obj))
+    reportError(Config.InputFilename, std::move(E));
   COFFWriter Writer(*Obj, Out);
   if (Error E = Writer.write())
     reportError(Config.OutputFilename, std::move(E));
diff --git a/tools/llvm-objcopy/COFF/Object.cpp b/tools/llvm-objcopy/COFF/Object.cpp
new file mode 100644
index 0000000..315d3a7
--- /dev/null
+++ b/tools/llvm-objcopy/COFF/Object.cpp
@@ -0,0 +1,70 @@
+//===- Object.cpp ---------------------------------------------------------===//
+//
+//                      The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Object.h"
+#include <algorithm>
+
+namespace llvm {
+namespace objcopy {
+namespace coff {
+
+using namespace object;
+
+void Object::addSymbols(ArrayRef<Symbol> NewSymbols) {
+  for (Symbol S : NewSymbols) {
+    S.UniqueId = NextSymbolUniqueId++;
+    Symbols.emplace_back(S);
+  }
+  updateSymbols();
+}
+
+void Object::updateSymbols() {
+  SymbolMap = DenseMap<size_t, Symbol *>(Symbols.size());
+  size_t RawSymIndex = 0;
+  for (Symbol &Sym : Symbols) {
+    SymbolMap[Sym.UniqueId] = &Sym;
+    Sym.RawIndex = RawSymIndex;
+    RawSymIndex += 1 + Sym.Sym.NumberOfAuxSymbols;
+  }
+}
+
+const Symbol *Object::findSymbol(size_t UniqueId) const {
+  auto It = SymbolMap.find(UniqueId);
+  if (It == SymbolMap.end())
+    return nullptr;
+  return It->second;
+}
+
+void Object::removeSymbols(function_ref<bool(const Symbol &)> ToRemove) {
+  Symbols.erase(
+      std::remove_if(std::begin(Symbols), std::end(Symbols),
+                     [ToRemove](const Symbol &Sym) { return ToRemove(Sym); }),
+      std::end(Symbols));
+  updateSymbols();
+}
+
+Error Object::markSymbols() {
+  for (Symbol &Sym : Symbols)
+    Sym.Referenced = false;
+  for (const Section &Sec : Sections) {
+    for (const Relocation &R : Sec.Relocs) {
+      auto It = SymbolMap.find(R.Target);
+      if (It == SymbolMap.end())
+        return make_error<StringError>("Relocation target " + Twine(R.Target) +
+                                           " not found",
+                                       object_error::invalid_symbol_index);
+      It->second->Referenced = true;
+    }
+  }
+  return Error::success();
+}
+
+} // end namespace coff
+} // end namespace objcopy
+} // end namespace llvm
diff --git a/tools/llvm-objcopy/COFF/Object.h b/tools/llvm-objcopy/COFF/Object.h
index 89f9903..ca1ff7f 100644
--- a/tools/llvm-objcopy/COFF/Object.h
+++ b/tools/llvm-objcopy/COFF/Object.h
@@ -11,7 +11,9 @@
 #define LLVM_TOOLS_OBJCOPY_COFF_OBJECT_H
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/iterator_range.h"
 #include "llvm/BinaryFormat/COFF.h"
 #include "llvm/Object/COFF.h"
 #include <cstddef>
@@ -22,10 +24,16 @@
 namespace objcopy {
 namespace coff {
 
+struct Relocation {
+  object::coff_relocation Reloc;
+  size_t Target;
+  StringRef TargetName; // Used for diagnostics only
+};
+
 struct Section {
   object::coff_section Header;
   ArrayRef<uint8_t> Contents;
-  std::vector<object::coff_relocation> Relocs;
+  std::vector<Relocation> Relocs;
   StringRef Name;
 };
 
@@ -33,6 +41,9 @@
   object::coff_symbol32 Sym;
   StringRef Name;
   ArrayRef<uint8_t> AuxData;
+  size_t UniqueId;
+  size_t RawIndex;
+  bool Referenced;
 };
 
 struct Object {
@@ -49,7 +60,31 @@
 
   std::vector<object::data_directory> DataDirectories;
   std::vector<Section> Sections;
+
+  ArrayRef<Symbol> getSymbols() const { return Symbols; }
+  // This allows mutating individual Symbols, but not mutating the list
+  // of symbols itself.
+  iterator_range<std::vector<Symbol>::iterator> getMutableSymbols() {
+    return make_range(Symbols.begin(), Symbols.end());
+  }
+
+  const Symbol *findSymbol(size_t UniqueId) const;
+
+  void addSymbols(ArrayRef<Symbol> NewSymbols);
+  void removeSymbols(function_ref<bool(const Symbol &)> ToRemove);
+
+  // Set the Referenced field on all Symbols, based on relocations in
+  // all sections.
+  Error markSymbols();
+
+private:
   std::vector<Symbol> Symbols;
+  DenseMap<size_t, Symbol *> SymbolMap;
+
+  size_t NextSymbolUniqueId = 0;
+
+  // Update SymbolMap and RawIndex in each Symbol.
+  void updateSymbols();
 };
 
 // Copy between coff_symbol16 and coff_symbol32.
diff --git a/tools/llvm-objcopy/COFF/Reader.cpp b/tools/llvm-objcopy/COFF/Reader.cpp
index 5d80596..76b3f73 100644
--- a/tools/llvm-objcopy/COFF/Reader.cpp
+++ b/tools/llvm-objcopy/COFF/Reader.cpp
@@ -73,7 +73,7 @@
       return errorCodeToError(EC);
     ArrayRef<coff_relocation> Relocs = COFFObj.getRelocations(Sec);
     for (const coff_relocation &R : Relocs)
-      S.Relocs.push_back(R);
+      S.Relocs.push_back(Relocation{R});
     if (auto EC = COFFObj.getSectionName(Sec, S.Name))
       return errorCodeToError(EC);
     if (Sec->hasExtendedRelocations())
@@ -84,14 +84,16 @@
 }
 
 Error COFFReader::readSymbols(Object &Obj, bool IsBigObj) const {
+  std::vector<Symbol> Symbols;
+  Symbols.reserve(COFFObj.getRawNumberOfSymbols());
   for (uint32_t I = 0, E = COFFObj.getRawNumberOfSymbols(); I < E;) {
     Expected<COFFSymbolRef> SymOrErr = COFFObj.getSymbol(I);
     if (!SymOrErr)
       return SymOrErr.takeError();
     COFFSymbolRef SymRef = *SymOrErr;
 
-    Obj.Symbols.push_back(Symbol());
-    Symbol &Sym = Obj.Symbols.back();
+    Symbols.push_back(Symbol());
+    Symbol &Sym = Symbols.back();
     // Copy symbols from the original form into an intermediate coff_symbol32.
     if (IsBigObj)
       copySymbol(Sym.Sym,
@@ -106,6 +108,30 @@
             (IsBigObj ? sizeof(coff_symbol32) : sizeof(coff_symbol16))) == 0);
     I += 1 + SymRef.getNumberOfAuxSymbols();
   }
+  Obj.addSymbols(Symbols);
+  return Error::success();
+}
+
+Error COFFReader::setRelocTargets(Object &Obj) const {
+  std::vector<const Symbol *> RawSymbolTable;
+  for (const Symbol &Sym : Obj.getSymbols()) {
+    RawSymbolTable.push_back(&Sym);
+    for (size_t I = 0; I < Sym.Sym.NumberOfAuxSymbols; I++)
+      RawSymbolTable.push_back(nullptr);
+  }
+  for (Section &Sec : Obj.Sections) {
+    for (Relocation &R : Sec.Relocs) {
+      if (R.Reloc.SymbolTableIndex >= RawSymbolTable.size())
+        return make_error<StringError>("SymbolTableIndex out of range",
+                                       object_error::parse_failed);
+      const Symbol *Sym = RawSymbolTable[R.Reloc.SymbolTableIndex];
+      if (Sym == nullptr)
+        return make_error<StringError>("Invalid SymbolTableIndex",
+                                       object_error::parse_failed);
+      R.Target = Sym->UniqueId;
+      R.TargetName = Sym->Name;
+    }
+  }
   return Error::success();
 }
 
@@ -136,6 +162,8 @@
     return std::move(E);
   if (Error E = readSymbols(*Obj, IsBigObj))
     return std::move(E);
+  if (Error E = setRelocTargets(*Obj))
+    return std::move(E);
 
   return std::move(Obj);
 }
diff --git a/tools/llvm-objcopy/COFF/Reader.h b/tools/llvm-objcopy/COFF/Reader.h
index f2b7f78..c972a14 100644
--- a/tools/llvm-objcopy/COFF/Reader.h
+++ b/tools/llvm-objcopy/COFF/Reader.h
@@ -35,6 +35,7 @@
   Error readExecutableHeaders(Object &Obj) const;
   Error readSections(Object &Obj) const;
   Error readSymbols(Object &Obj, bool IsBigObj) const;
+  Error setRelocTargets(Object &Obj) const;
 
 public:
   explicit COFFReader(const COFFObjectFile &O) : COFFObj(O) {}
diff --git a/tools/llvm-objcopy/COFF/Writer.cpp b/tools/llvm-objcopy/COFF/Writer.cpp
index 388c62c..d7a5224 100644
--- a/tools/llvm-objcopy/COFF/Writer.cpp
+++ b/tools/llvm-objcopy/COFF/Writer.cpp
@@ -27,6 +27,21 @@
 
 Writer::~Writer() {}
 
+Error COFFWriter::finalizeRelocTargets() {
+  for (Section &Sec : Obj.Sections) {
+    for (Relocation &R : Sec.Relocs) {
+      const Symbol *Sym = Obj.findSymbol(R.Target);
+      if (Sym == nullptr)
+        return make_error<StringError>("Relocation target " + R.TargetName +
+                                           " (" + Twine(R.Target) +
+                                           ") not found",
+                                       object_error::invalid_symbol_index);
+      R.Reloc.SymbolTableIndex = Sym->RawIndex;
+    }
+  }
+  return Error::success();
+}
+
 void COFFWriter::layoutSections() {
   for (auto &S : Obj.Sections) {
     if (S.Header.SizeOfRawData > 0)
@@ -48,7 +63,7 @@
     if (S.Name.size() > COFF::NameSize)
       StrTabBuilder.add(S.Name);
 
-  for (const auto &S : Obj.Symbols)
+  for (const auto &S : Obj.getSymbols())
     if (S.Name.size() > COFF::NameSize)
       StrTabBuilder.add(S.Name);
 
@@ -62,7 +77,7 @@
       strncpy(S.Header.Name, S.Name.data(), COFF::NameSize);
     }
   }
-  for (auto &S : Obj.Symbols) {
+  for (auto &S : Obj.getMutableSymbols()) {
     if (S.Name.size() > COFF::NameSize) {
       S.Sym.Name.Offset.Zeroes = 0;
       S.Sym.Name.Offset.Offset = StrTabBuilder.getOffset(S.Name);
@@ -75,13 +90,16 @@
 
 template <class SymbolTy>
 std::pair<size_t, size_t> COFFWriter::finalizeSymbolTable() {
-  size_t SymTabSize = Obj.Symbols.size() * sizeof(SymbolTy);
-  for (const auto &S : Obj.Symbols)
+  size_t SymTabSize = Obj.getSymbols().size() * sizeof(SymbolTy);
+  for (const auto &S : Obj.getSymbols())
     SymTabSize += S.AuxData.size();
   return std::make_pair(SymTabSize, sizeof(SymbolTy));
 }
 
-void COFFWriter::finalize(bool IsBigObj) {
+Error COFFWriter::finalize(bool IsBigObj) {
+  if (Error E = finalizeRelocTargets())
+    return E;
+
   size_t SizeOfHeaders = 0;
   FileAlignment = 1;
   size_t PeHeaderSize = 0;
@@ -149,6 +167,8 @@
   Obj.CoffFileHeader.NumberOfSymbols = NumRawSymbols;
   FileSize += SymTabSize + StrTabSize;
   FileSize = alignTo(FileSize, FileAlignment);
+
+  return Error::success();
 }
 
 void COFFWriter::writeHeaders(bool IsBigObj) {
@@ -225,14 +245,16 @@
              S.Header.SizeOfRawData - S.Contents.size());
 
     Ptr += S.Header.SizeOfRawData;
-    std::copy(S.Relocs.begin(), S.Relocs.end(),
-              reinterpret_cast<coff_relocation *>(Ptr));
+    for (const auto &R : S.Relocs) {
+      memcpy(Ptr, &R.Reloc, sizeof(R.Reloc));
+      Ptr += sizeof(R.Reloc);
+    }
   }
 }
 
 template <class SymbolTy> void COFFWriter::writeSymbolStringTables() {
   uint8_t *Ptr = Buf.getBufferStart() + Obj.CoffFileHeader.PointerToSymbolTable;
-  for (const auto &S : Obj.Symbols) {
+  for (const auto &S : Obj.getSymbols()) {
     // Convert symbols back to the right size, from coff_symbol32.
     copySymbol<SymbolTy, coff_symbol32>(*reinterpret_cast<SymbolTy *>(Ptr),
                                         S.Sym);
@@ -248,7 +270,8 @@
 }
 
 Error COFFWriter::write(bool IsBigObj) {
-  finalize(IsBigObj);
+  if (Error E = finalize(IsBigObj))
+    return E;
 
   Buf.allocate(FileSize);
 
diff --git a/tools/llvm-objcopy/COFF/Writer.h b/tools/llvm-objcopy/COFF/Writer.h
index f2e814a..a2612ca 100644
--- a/tools/llvm-objcopy/COFF/Writer.h
+++ b/tools/llvm-objcopy/COFF/Writer.h
@@ -40,11 +40,12 @@
   size_t SizeOfInitializedData;
   StringTableBuilder StrTabBuilder;
 
+  Error finalizeRelocTargets();
   void layoutSections();
   size_t finalizeStringTable();
   template <class SymbolTy> std::pair<size_t, size_t> finalizeSymbolTable();
 
-  void finalize(bool IsBigObj);
+  Error finalize(bool IsBigObj);
 
   void writeHeaders(bool IsBigObj);
   void writeSections();