[X86] Refactor X86IndirectThunks.cpp to Accommodate Mitigations other than Retpoline

Introduce a ThunkInserter CRTP base class from which new thunk types can inherit, e.g., thunks to mitigate Load Value Injection (https://software.intel.com/security-software-guidance/software-guidance/load-value-injection).
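
For illustration only (not part of this change), a future mitigation would plug in roughly as sketched below. LVIThunkInserter, LVIThunkNamePrefix, LVIThunkName, and the useLVIControlFlowIntegrity() feature check are placeholder names, not code introduced by this patch:

  struct LVIThunkInserter : ThunkInserter<LVIThunkInserter> {
    const char *getThunkPrefix() { return LVIThunkNamePrefix; }
    bool mayUseThunk(const MachineFunction &MF) {
      // Placeholder predicate; a real mitigation would key off its own
      // subtarget feature bits.
      return MF.getSubtarget<X86Subtarget>().useLVIControlFlowIntegrity();
    }
    void insertThunks(MachineModuleInfo &MMI) {
      // createThunkFunction is inherited from ThunkInserter.
      createThunkFunction(MMI, LVIThunkName);
    }
    void populateThunk(MachineFunction &MF) {
      // Emit the mitigation's thunk body into MF here.
    }
  };

  // X86IndirectThunks then only needs to list the new inserter; initTIs and
  // runTIs iterate over the tuple, so no other pass changes are required.
  std::tuple<RetpolineThunkInserter, LVIThunkInserter> TIs;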

Differential Revision: https://reviews.llvm.org/D76811
diff --git a/llvm/lib/Target/X86/X86IndirectThunks.cpp b/llvm/lib/Target/X86/X86IndirectThunks.cpp
index 0bf3493..e6408e9 100644
--- a/llvm/lib/Target/X86/X86IndirectThunks.cpp
+++ b/llvm/lib/Target/X86/X86IndirectThunks.cpp
@@ -51,6 +51,35 @@
 static const char EDIRetpolineName[]    = "__llvm_retpoline_edi";
 
 namespace {
+template <typename Derived> class ThunkInserter {
+  Derived &getDerived() { return *static_cast<Derived *>(this); }
+
+protected:
+  bool InsertedThunks;
+  void doInitialization(Module &M) {}
+  void createThunkFunction(MachineModuleInfo &MMI, StringRef Name);
+
+public:
+  void init(Module &M) {
+    InsertedThunks = false;
+    getDerived().doInitialization(M);
+  }
+  // return `true` if `MMI` or `MF` was modified
+  bool run(MachineModuleInfo &MMI, MachineFunction &MF);
+};
+
+struct RetpolineThunkInserter : ThunkInserter<RetpolineThunkInserter> {
+  const char *getThunkPrefix() { return RetpolineNamePrefix; }
+  bool mayUseThunk(const MachineFunction &MF) {
+    const auto &STI = MF.getSubtarget<X86Subtarget>();
+    return (STI.useRetpolineIndirectCalls() ||
+            STI.useRetpolineIndirectBranches()) &&
+           !STI.useRetpolineExternalThunk();
+  }
+  void insertThunks(MachineModuleInfo &MMI);
+  void populateThunk(MachineFunction &MF);
+};
+
 class X86IndirectThunks : public MachineFunctionPass {
 public:
   static char ID;
@@ -60,7 +89,7 @@
   StringRef getPassName() const override { return "X86 Indirect Thunks"; }
 
   bool doInitialization(Module &M) override;
-  bool runOnMachineFunction(MachineFunction &F) override;
+  bool runOnMachineFunction(MachineFunction &MF) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     MachineFunctionPass::getAnalysisUsage(AU);
@@ -69,78 +98,39 @@
   }
 
 private:
-  MachineModuleInfo *MMI = nullptr;
-  const TargetMachine *TM = nullptr;
-  bool Is64Bit = false;
-  const X86Subtarget *STI = nullptr;
-  const X86InstrInfo *TII = nullptr;
+  std::tuple<RetpolineThunkInserter> TIs;
 
-  bool InsertedThunks = false;
-
-  void createThunkFunction(Module &M, StringRef Name);
-  void insertRegReturnAddrClobber(MachineBasicBlock &MBB, Register Reg);
-  void populateThunk(MachineFunction &MF, Register Reg);
+  // FIXME: When LLVM moves to C++17, these can become folds
+  template <typename... ThunkInserterT>
+  static void initTIs(Module &M,
+                      std::tuple<ThunkInserterT...> &ThunkInserters) {
+    (void)std::initializer_list<int>{
+        (std::get<ThunkInserterT>(ThunkInserters).init(M), 0)...};
+  }
+  template <typename... ThunkInserterT>
+  static bool runTIs(MachineModuleInfo &MMI, MachineFunction &MF,
+                     std::tuple<ThunkInserterT...> &ThunkInserters) {
+    bool Modified = false;
+    (void)std::initializer_list<int>{
+        Modified |= std::get<ThunkInserterT>(ThunkInserters).run(MMI, MF)...};
+    return Modified;
+  }
 };
 
 } // end anonymous namespace
 
-FunctionPass *llvm::createX86IndirectThunksPass() {
-  return new X86IndirectThunks();
+void RetpolineThunkInserter::insertThunks(MachineModuleInfo &MMI) {
+  if (MMI.getTarget().getTargetTriple().getArch() == Triple::x86_64)
+    createThunkFunction(MMI, R11RetpolineName);
+  else
+    for (StringRef Name : {EAXRetpolineName, ECXRetpolineName, EDXRetpolineName,
+                           EDIRetpolineName})
+      createThunkFunction(MMI, Name);
 }
 
-char X86IndirectThunks::ID = 0;
-
-bool X86IndirectThunks::doInitialization(Module &M) {
-  InsertedThunks = false;
-  return false;
-}
-
-bool X86IndirectThunks::runOnMachineFunction(MachineFunction &MF) {
-  LLVM_DEBUG(dbgs() << getPassName() << '\n');
-
-  TM = &MF.getTarget();;
-  STI = &MF.getSubtarget<X86Subtarget>();
-  TII = STI->getInstrInfo();
-  Is64Bit = TM->getTargetTriple().getArch() == Triple::x86_64;
-
-  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
-  Module &M = const_cast<Module &>(*MMI->getModule());
-
-  // If this function is not a thunk, check to see if we need to insert
-  // a thunk.
-  if (!MF.getName().startswith(RetpolineNamePrefix)) {
-    // If we've already inserted a thunk, nothing else to do.
-    if (InsertedThunks)
-      return false;
-
-    // Only add a thunk if one of the functions has the retpoline feature
-    // enabled in its subtarget, and doesn't enable external thunks.
-    // FIXME: Conditionalize on indirect calls so we don't emit a thunk when
-    // nothing will end up calling it.
-    // FIXME: It's a little silly to look at every function just to enumerate
-    // the subtargets, but eventually we'll want to look at them for indirect
-    // calls, so maybe this is OK.
-    if ((!STI->useRetpolineIndirectCalls() &&
-         !STI->useRetpolineIndirectBranches()) ||
-        STI->useRetpolineExternalThunk())
-      return false;
-
-    // Otherwise, we need to insert the thunk.
-    // WARNING: This is not really a well behaving thing to do in a function
-    // pass. We extract the module and insert a new function (and machine
-    // function) directly into the module.
-    if (Is64Bit)
-      createThunkFunction(M, R11RetpolineName);
-    else
-      for (StringRef Name :
-           {EAXRetpolineName, ECXRetpolineName, EDXRetpolineName,
-            EDIRetpolineName})
-        createThunkFunction(M, Name);
-    InsertedThunks = true;
-    return true;
-  }
-
-  // If this *is* a thunk function, we need to populate it with the correct MI.
+void RetpolineThunkInserter::populateThunk(MachineFunction &MF) {
+  bool Is64Bit = MF.getTarget().getTargetTriple().getArch() == Triple::x86_64;
+  Register ThunkReg;
   if (Is64Bit) {
     assert(MF.getName() == "__llvm_retpoline_r11" &&
            "Should only have an r11 thunk on 64-bit targets");
@@ -155,7 +145,7 @@
     // .Lr11_call_target:
     //   movq %r11, (%rsp)
     //   retq
-    populateThunk(MF, X86::R11);
+    ThunkReg = X86::R11;
   } else {
     // For 32-bit targets we need to emit a collection of thunks for various
     // possible scratch registers as well as a fallback that uses EDI, which is
@@ -185,24 +175,80 @@
     //         movl %edi, (%esp)
     //         retl
     if (MF.getName() == EAXRetpolineName)
-      populateThunk(MF, X86::EAX);
+      ThunkReg = X86::EAX;
     else if (MF.getName() == ECXRetpolineName)
-      populateThunk(MF, X86::ECX);
+      ThunkReg = X86::ECX;
     else if (MF.getName() == EDXRetpolineName)
-      populateThunk(MF, X86::EDX);
+      ThunkReg = X86::EDX;
     else if (MF.getName() == EDIRetpolineName)
-      populateThunk(MF, X86::EDI);
+      ThunkReg = X86::EDI;
     else
       llvm_unreachable("Invalid thunk name on x86-32!");
   }
 
-  return true;
+  const TargetInstrInfo *TII = MF.getSubtarget<X86Subtarget>().getInstrInfo();
+  // Grab the entry MBB and erase any other blocks. O0 codegen appears to
+  // generate two bbs for the entry block.
+  MachineBasicBlock *Entry = &MF.front();
+  Entry->clear();
+  while (MF.size() > 1)
+    MF.erase(std::next(MF.begin()));
+
+  MachineBasicBlock *CaptureSpec =
+      MF.CreateMachineBasicBlock(Entry->getBasicBlock());
+  MachineBasicBlock *CallTarget =
+      MF.CreateMachineBasicBlock(Entry->getBasicBlock());
+  MCSymbol *TargetSym = MF.getContext().createTempSymbol();
+  MF.push_back(CaptureSpec);
+  MF.push_back(CallTarget);
+
+  const unsigned CallOpc = Is64Bit ? X86::CALL64pcrel32 : X86::CALLpcrel32;
+  const unsigned RetOpc = Is64Bit ? X86::RETQ : X86::RETL;
+
+  Entry->addLiveIn(ThunkReg);
+  BuildMI(Entry, DebugLoc(), TII->get(CallOpc)).addSym(TargetSym);
+
+  // The MIR verifier thinks that the CALL in the entry block will fall through
+  // to CaptureSpec, so mark it as the successor. Technically, CallTarget is
+  // the successor, but the MIR verifier doesn't know how to cope with that.
+  Entry->addSuccessor(CaptureSpec);
+
+  // In the capture loop for speculation, we want to stop the processor from
+  // speculating as fast as possible. On Intel processors, the PAUSE instruction
+  // will block speculation without consuming any execution resources. On AMD
+  // processors, the PAUSE instruction is (essentially) a nop, so we also use an
+  // LFENCE instruction which they have advised will stop speculation as well
+  // with minimal resource utilization. We still end the capture with a jump to
+  // form an infinite loop to fully guarantee that no matter what implementation
+  // of the x86 ISA, speculating this code path never escapes.
+  BuildMI(CaptureSpec, DebugLoc(), TII->get(X86::PAUSE));
+  BuildMI(CaptureSpec, DebugLoc(), TII->get(X86::LFENCE));
+  BuildMI(CaptureSpec, DebugLoc(), TII->get(X86::JMP_1)).addMBB(CaptureSpec);
+  CaptureSpec->setHasAddressTaken();
+  CaptureSpec->addSuccessor(CaptureSpec);
+
+  CallTarget->addLiveIn(ThunkReg);
+  CallTarget->setHasAddressTaken();
+  CallTarget->setAlignment(Align(16));
+
+  // Insert return address clobber
+  const unsigned MovOpc = Is64Bit ? X86::MOV64mr : X86::MOV32mr;
+  const Register SPReg = Is64Bit ? X86::RSP : X86::ESP;
+  addRegOffset(BuildMI(CallTarget, DebugLoc(), TII->get(MovOpc)), SPReg, false,
+               0)
+      .addReg(ThunkReg);
+
+  CallTarget->back().setPreInstrSymbol(MF, TargetSym);
+  BuildMI(CallTarget, DebugLoc(), TII->get(RetOpc));
 }
 
-void X86IndirectThunks::createThunkFunction(Module &M, StringRef Name) {
-  assert(Name.startswith(RetpolineNamePrefix) &&
+template <typename Derived>
+void ThunkInserter<Derived>::createThunkFunction(MachineModuleInfo &MMI,
+                                                 StringRef Name) {
+  assert(Name.startswith(getDerived().getThunkPrefix()) &&
          "Created a thunk with an unexpected prefix!");
 
+  Module &M = const_cast<Module &>(*MMI.getModule());
   LLVMContext &Ctx = M.getContext();
   auto Type = FunctionType::get(Type::getVoidTy(Ctx), false);
   Function *F =
@@ -226,70 +272,56 @@
   // MachineFunctions/MachineBasicBlocks aren't created automatically for the
   // IR-level constructs we already made. Create them and insert them into the
   // module.
-  MachineFunction &MF = MMI->getOrCreateMachineFunction(*F);
+  MachineFunction &MF = MMI.getOrCreateMachineFunction(*F);
   MachineBasicBlock *EntryMBB = MF.CreateMachineBasicBlock(Entry);
 
   // Insert EntryMBB into MF. It's not in the module until we do this.
   MF.insert(MF.end(), EntryMBB);
-}
-
-void X86IndirectThunks::insertRegReturnAddrClobber(MachineBasicBlock &MBB,
-                                                   Register Reg) {
-  const unsigned MovOpc = Is64Bit ? X86::MOV64mr : X86::MOV32mr;
-  const Register SPReg = Is64Bit ? X86::RSP : X86::ESP;
-  addRegOffset(BuildMI(&MBB, DebugLoc(), TII->get(MovOpc)), SPReg, false, 0)
-      .addReg(Reg);
-}
-
-void X86IndirectThunks::populateThunk(MachineFunction &MF,
-                                      Register Reg) {
   // Set MF properties. We never use vregs...
   MF.getProperties().set(MachineFunctionProperties::Property::NoVRegs);
+}
 
-  // Grab the entry MBB and erase any other blocks. O0 codegen appears to
-  // generate two bbs for the entry block.
-  MachineBasicBlock *Entry = &MF.front();
-  Entry->clear();
-  while (MF.size() > 1)
-    MF.erase(std::next(MF.begin()));
+template <typename Derived>
+bool ThunkInserter<Derived>::run(MachineModuleInfo &MMI, MachineFunction &MF) {
+  // If MF is not a thunk, check to see if we need to insert a thunk.
+  if (!MF.getName().startswith(getDerived().getThunkPrefix())) {
+    // If we've already inserted a thunk, nothing else to do.
+    if (InsertedThunks)
+      return false;
 
-  MachineBasicBlock *CaptureSpec =
-      MF.CreateMachineBasicBlock(Entry->getBasicBlock());
-  MachineBasicBlock *CallTarget =
-      MF.CreateMachineBasicBlock(Entry->getBasicBlock());
-  MCSymbol *TargetSym = MF.getContext().createTempSymbol();
-  MF.push_back(CaptureSpec);
-  MF.push_back(CallTarget);
+    // Only add a thunk if one of the functions has the corresponding feature
+    // enabled in its subtarget, and doesn't enable external thunks.
+    // FIXME: Conditionalize on indirect calls so we don't emit a thunk when
+    // nothing will end up calling it.
+    // FIXME: It's a little silly to look at every function just to enumerate
+    // the subtargets, but eventually we'll want to look at them for indirect
+    // calls, so maybe this is OK.
+    if (!getDerived().mayUseThunk(MF))
+      return false;
 
-  const unsigned CallOpc = Is64Bit ? X86::CALL64pcrel32 : X86::CALLpcrel32;
-  const unsigned RetOpc = Is64Bit ? X86::RETQ : X86::RETL;
+    getDerived().insertThunks(MMI);
+    InsertedThunks = true;
+    return true;
+  }
 
-  Entry->addLiveIn(Reg);
-  BuildMI(Entry, DebugLoc(), TII->get(CallOpc)).addSym(TargetSym);
+  // If this *is* a thunk function, we need to populate it with the correct MI.
+  getDerived().populateThunk(MF);
+  return true;
+}
 
-  // The MIR verifier thinks that the CALL in the entry block will fall through
-  // to CaptureSpec, so mark it as the successor. Technically, CaptureTarget is
-  // the successor, but the MIR verifier doesn't know how to cope with that.
-  Entry->addSuccessor(CaptureSpec);
+FunctionPass *llvm::createX86IndirectThunksPass() {
+  return new X86IndirectThunks();
+}
 
-  // In the capture loop for speculation, we want to stop the processor from
-  // speculating as fast as possible. On Intel processors, the PAUSE instruction
-  // will block speculation without consuming any execution resources. On AMD
-  // processors, the PAUSE instruction is (essentially) a nop, so we also use an
-  // LFENCE instruction which they have advised will stop speculation as well
-  // with minimal resource utilization. We still end the capture with a jump to
-  // form an infinite loop to fully guarantee that no matter what implementation
-  // of the x86 ISA, speculating this code path never escapes.
-  BuildMI(CaptureSpec, DebugLoc(), TII->get(X86::PAUSE));
-  BuildMI(CaptureSpec, DebugLoc(), TII->get(X86::LFENCE));
-  BuildMI(CaptureSpec, DebugLoc(), TII->get(X86::JMP_1)).addMBB(CaptureSpec);
-  CaptureSpec->setHasAddressTaken();
-  CaptureSpec->addSuccessor(CaptureSpec);
+char X86IndirectThunks::ID = 0;
 
-  CallTarget->addLiveIn(Reg);
-  CallTarget->setHasAddressTaken();
-  CallTarget->setAlignment(Align(16));
-  insertRegReturnAddrClobber(*CallTarget, Reg);
-  CallTarget->back().setPreInstrSymbol(MF, TargetSym);
-  BuildMI(CallTarget, DebugLoc(), TII->get(RetOpc));
+bool X86IndirectThunks::doInitialization(Module &M) {
+  initTIs(M, TIs);
+  return false;
+}
+
+bool X86IndirectThunks::runOnMachineFunction(MachineFunction &MF) {
+  LLVM_DEBUG(dbgs() << getPassName() << '\n');
+  auto &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
+  return runTIs(MMI, MF, TIs);
 }