//=====-- Rematerializer.cpp - MIR rematerialization support ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Implements helpers for target-independent rematerialization at the MIR
/// level.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/Rematerializer.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "rematerializer"

using namespace llvm;
using RegisterIdx = Rematerializer::RegisterIdx;

// Pin the vtable to this file. The body is intentionally empty; the function
// only exists so that it has an out-of-line definition here.
void Rematerializer::Listener::anchor() {}

/// Checks whether the value in \p LI at \p UseIdx is identical to \p OVNI (this
/// implies it is also live there). When \p LI has sub-ranges, checks that
/// all sub-ranges intersecting with \p Mask are also live at \p UseIdx.
static bool isIdenticalAtUse(const VNInfo &OVNI, LaneBitmask Mask,
                             SlotIndex UseIdx, const LiveInterval &LI) {
  if (&OVNI != LI.getVNInfoAt(UseIdx))
    return false;

  if (LI.hasSubRanges()) {
    // Check that intersecting subranges are live at user.
    for (const LiveInterval::SubRange &SR : LI.subranges()) {
      if ((SR.LaneMask & Mask).none())
        continue;
      if (!SR.liveAt(UseIdx))
        return false;

      // Early exit if all used lanes are checked. No need to continue.
      Mask &= ~SR.LaneMask;
      if (Mask.none())
        break;
    }
  }
  return true;
}

/// Returns the register read by \p MO when that register is virtual. For any
/// other operand (non-register, not read, or physical register) the sentinel
/// register is returned.
static Register getRegDependency(const MachineOperand &MO) {
  const bool IsRegisterRead = MO.isReg() && MO.readsReg();
  if (!IsRegisterRead)
    return Register();

  // By the requirements on trivially rematerializable instructions, a physical
  // register use is either constant or ignorable, so it is not a dependency.
  const Register ReadReg = MO.getReg();
  return ReadReg.isPhysical() ? Register() : ReadReg;
}

/// Rematerializes register \p RootIdx into region \p UseRegion, just before
/// its first user in that region (or at the end of the region when it has no
/// user there), then moves all of the region's users to the new register.
/// Returns the index of the new register.
RegisterIdx Rematerializer::rematerializeToRegion(RegisterIdx RootIdx,
                                                  unsigned UseRegion,
                                                  DependencyReuseInfo &DRI) {
  // Default to the very end of the region; move up to the first user if the
  // region has any.
  MachineBasicBlock::iterator InsertPos = Regions[UseRegion].second;
  MachineInstr *FirstUser =
      getReg(RootIdx).getRegionUseBounds(UseRegion, LIS).first;
  if (FirstUser)
    InsertPos = MachineBasicBlock::iterator(FirstUser);

  const RegisterIdx RematIdx =
      rematerializeToPos(RootIdx, UseRegion, InsertPos, DRI);
  transferRegionUsers(RootIdx, RematIdx, UseRegion);
  return RematIdx;
}

/// Rematerializes register \p RootIdx before \p InsertPos in region
/// \p UseRegion. Dependencies already present in \p DRI's dependency map are
/// reused; the others are rematerialized recursively at the same position.
/// The new register is recorded in the map and its index is returned.
RegisterIdx
Rematerializer::rematerializeToPos(RegisterIdx RootIdx, unsigned UseRegion,
                                   MachineBasicBlock::iterator InsertPos,
                                   DependencyReuseInfo &DRI) {
  assert(!DRI.DependencyMap.contains(RootIdx));
  LLVM_DEBUG(dbgs() << "Rematerializing " << printID(RootIdx) << '\n');

  // Work on a snapshot of the root's dependencies: the recursive calls below
  // may invalidate references to the backing vector of registers.
  const SmallVector<Reg::Dependency, 2> RootDeps(getReg(RootIdx).Dependencies);
  SmallVector<Reg::Dependency, 2> RematDeps;
  for (const Reg::Dependency &RootDep : RootDeps) {
    // Registers form a DAG so the recursion is guaranteed to terminate.
    RegisterIdx DepRematIdx;
    auto Reusable = DRI.DependencyMap.find(RootDep.RegIdx);
    if (Reusable != DRI.DependencyMap.end())
      DepRematIdx = Reusable->second;
    else
      DepRematIdx =
          rematerializeToPos(RootDep.RegIdx, UseRegion, InsertPos, DRI);
    RematDeps.emplace_back(RootDep.MOIdx, DepRematIdx);
  }

  const RegisterIdx RematIdx =
      rematerializeReg(RootIdx, UseRegion, InsertPos, std::move(RematDeps));
  DRI.DependencyMap.insert({RootIdx, RematIdx});
  return RematIdx;
}

/// Moves the single user \p UserMI, located in \p UserRegion, from register
/// \p FromRegIdx to register \p ToRegIdx. The source register is deleted if
/// this leaves it without users.
void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
                                  unsigned UserRegion, MachineInstr &UserMI) {
  MachineInstr *const User = &UserMI;

  // Rewrite the instruction first, then mirror the change in the user sets.
  transferUserImpl(FromRegIdx, ToRegIdx, *User);
  Reg &Source = Regs[FromRegIdx];
  Source.eraseUser(User, UserRegion);
  Reg &Target = Regs[ToRegIdx];
  Target.addUser(User, UserRegion);

  deleteRegIfUnused(FromRegIdx);
}

/// Moves every user of register \p FromRegIdx inside region \p UseRegion to
/// register \p ToRegIdx. Does nothing when the source register has no user in
/// that region; otherwise the source register is deleted if it ends up unused.
void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx,
                                         RegisterIdx ToRegIdx,
                                         unsigned UseRegion) {
  auto &SourceUses = Regs[FromRegIdx].Uses;
  auto RegionIt = SourceUses.find(UseRegion);
  const bool HasUsersInRegion = RegionIt != SourceUses.end();
  if (!HasUsersInRegion)
    return;

  // Rewrite all users before touching the bookkeeping so that the set we
  // iterate over stays intact.
  const SmallDenseSet<MachineInstr *, 4> &Movers = RegionIt->getSecond();
  for (MachineInstr *Mover : Movers)
    transferUserImpl(FromRegIdx, ToRegIdx, *Mover);

  Reg &Target = Regs[ToRegIdx];
  Target.addUsers(Movers, UseRegion);
  SourceUses.erase(UseRegion);
  deleteRegIfUnused(FromRegIdx);
}

/// Moves every user of register \p FromRegIdx, in all regions, to register
/// \p ToRegIdx, then deletes the now unused source register.
void Rematerializer::transferAllUsers(RegisterIdx FromRegIdx,
                                      RegisterIdx ToRegIdx) {
  Reg &Source = Regs[FromRegIdx];
  Reg &Target = Regs[ToRegIdx];

  for (const auto &RegionAndUsers : Source.Uses) {
    const auto &Movers = RegionAndUsers.second;
    for (MachineInstr *Mover : Movers)
      transferUserImpl(FromRegIdx, ToRegIdx, *Mover);
    Target.addUsers(Movers, RegionAndUsers.first);
  }

  Source.Uses.clear();
  deleteRegIfUnused(FromRegIdx);
}

/// Rewrites \p UserMI so that it reads the register defined by \p ToRegIdx
/// instead of the one defined by \p FromRegIdx, and schedules both registers
/// for a live-interval update. If \p UserMI itself defines a tracked register,
/// that register's dependency on \p FromRegIdx is redirected to \p ToRegIdx.
/// User sets are not modified here.
void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx,
                                      RegisterIdx ToRegIdx,
                                      MachineInstr &UserMI) {
  assert(FromRegIdx != ToRegIdx && "identical registers");
  assert(getOriginOrSelf(FromRegIdx) == getOriginOrSelf(ToRegIdx) &&
         "unrelated registers");

  LLVM_DEBUG(dbgs() << "User transfer from " << printID(FromRegIdx) << " to "
                    << printID(ToRegIdx) << ": " << printUser(&UserMI) << '\n');

  const Register OldDefReg = getReg(FromRegIdx).getDefReg();
  const Register NewDefReg = getReg(ToRegIdx).getDefReg();
  UserMI.substituteRegister(OldDefReg, NewDefReg, 0, TRI);
  LISUpdates.insert(FromRegIdx);
  LISUpdates.insert(ToRegIdx);

  // Only users that are themselves tracked registers carry a dependency list
  // that needs patching.
  const RegisterIdx UserRegIdx = getDefRegIdx(UserMI);
  if (UserRegIdx == NoReg)
    return;

  auto &UserDeps = Regs[UserRegIdx].Dependencies;
  auto StaleDep =
      llvm::find_if(UserDeps, [FromRegIdx](const Reg::Dependency &Dep) {
        return Dep.RegIdx == FromRegIdx;
      });
  if (StaleDep == UserDeps.end())
    llvm_unreachable("broken dependency");
  StaleDep->RegIdx = ToRegIdx;
}

void Rematerializer::updateLiveIntervals() {
  DenseSet<Register> SeenUnrematRegs;
  for (RegisterIdx RegIdx : LISUpdates) {
    const Reg &UpdateReg = getReg(RegIdx);
    assert(UpdateReg.isAlive() && "dead register");

    Register DefReg = UpdateReg.getDefReg();
    if (LIS.hasInterval(DefReg))
      LIS.removeInterval(DefReg);
    // Rematerializable registers have a single definition by construction so
    // re-creating their interval cannot yield a live interval with multiple
    // connected components.
    LIS.createAndComputeVirtRegInterval(DefReg);

    LLVM_DEBUG({
      dbgs() << "Re-computed interval for " << printID(RegIdx) << ": ";
      LIS.getInterval(DefReg).print(dbgs());
      dbgs() << '\n' << printRegUsers(RegIdx);
    });

    // Update intervals for unrematerializable operands.
    for (unsigned MOIdx : getUnrematableOprds(RegIdx)) {
      Register UnrematReg = UpdateReg.DefMI->getOperand(MOIdx).getReg();
      if (!SeenUnrematRegs.insert(UnrematReg).second)
        continue;
      LIS.removeInterval(UnrematReg);
      bool NeedSplit = false;

      // Unrematerializable registers may end up with multiple connected
      // components in their live interval after it is re-created. It needs to
      // be split in such cases. We don't track unrematerializable registers by
      // their actual register index (just by operand index) so we do not need
      // to update any state in the rematerializer.
      LiveInterval &LI =
          LIS.createAndComputeVirtRegInterval(UnrematReg, NeedSplit);
      if (NeedSplit) {
        SmallVector<LiveInterval *> SplitLIs;
        LIS.splitSeparateComponents(LI, SplitLIs);
      }
      LLVM_DEBUG(
          dbgs() << "  Re-computed interval for register "
                 << printReg(UnrematReg, &TRI,
                             UpdateReg.DefMI->getOperand(MOIdx).getSubReg(),
                             &MRI)
                 << '\n');
    }
  }
  LISUpdates.clear();
}

/// Returns true when the value read by register operand \p MO at its parent
/// instruction is the same value that is live at every slot index in \p Uses.
/// Trivially true when \p Uses is empty.
bool Rematerializer::isMOIdenticalAtUses(MachineOperand &MO,
                                         ArrayRef<SlotIndex> Uses) const {
  if (Uses.empty())
    return true;

  // Lanes read by the operand: those of its sub-register index if it has one,
  // all lanes of the virtual register otherwise.
  const Register ReadReg = MO.getReg();
  LaneBitmask Lanes;
  if (const unsigned SubIdx = MO.getSubReg())
    Lanes = TRI.getSubRegIndexLaneMask(SubIdx);
  else
    Lanes = MRI.getMaxLaneMaskForVReg(ReadReg);

  const LiveInterval &LI = LIS.getInterval(ReadReg);
  const SlotIndex MOSlot =
      LIS.getInstructionIndex(*MO.getParent()).getRegSlot(true);
  const VNInfo *ValueAtMO = LI.getVNInfoAt(MOSlot);
  return llvm::all_of(Uses, [&](SlotIndex Use) {
    return isIdenticalAtUse(*ValueAtMO, Lanes, Use, LI);
  });
}

/// Looks for a rematerialization of \p RegIdx's origin that is defined in
/// \p Region, still has users, and whose definition slot is strictly before
/// \p Before. Among all such candidates the one defined latest is returned;
/// NoReg is returned when there is none.
RegisterIdx Rematerializer::findRematInRegion(RegisterIdx RegIdx,
                                              unsigned Region,
                                              SlotIndex Before) const {
  auto RematsIt = Rematerializations.find(getOriginOrSelf(RegIdx));
  if (RematsIt == Rematerializations.end())
    return NoReg;

  RegisterIdx LatestIdx = NoReg;
  SlotIndex LatestSlot;
  for (RegisterIdx CandidateIdx : RematsIt->getSecond()) {
    const Reg &Candidate = getReg(CandidateIdx);
    const bool Eligible =
        Candidate.DefRegion == Region && !Candidate.Uses.empty();
    if (!Eligible)
      continue;

    const SlotIndex CandidateSlot =
        LIS.getInstructionIndex(*Candidate.DefMI).getRegSlot();
    if (!(CandidateSlot < Before))
      continue;
    // Keep the current best unless this candidate is defined later.
    if (LatestIdx != NoReg && !(CandidateSlot > LatestSlot))
      continue;
    LatestSlot = CandidateSlot;
    LatestIdx = CandidateIdx;
  }
  return LatestIdx;
}

/// Deletes register \p RootIdx if it has no users left, together with every
/// register in its dependency DAG that becomes unused as a consequence.
/// Deleted rematerializations are unlinked from their origin, and an origin
/// whose last rematerialization disappears is removed from the
/// Rematerializations map.
void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
  if (!getReg(RootIdx).Uses.empty())
    return;

  // Traverse the root's dependency DAG depth-first to find the set of registers
  // we can delete and a legal order to delete them in.
  SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
  SmallSetVector<RegisterIdx, 8> DeleteOrder;
  DeleteOrder.insert(RootIdx);
  do {
    // A deleted register's dependencies may be deletable too.
    const Reg &DeleteReg = getReg(DepDAG.pop_back_val());
    for (const Reg::Dependency &Dep : DeleteReg.Dependencies) {
      // All dependencies lose a user (the deleted register).
      Reg &DepReg = Regs[Dep.RegIdx];
      DepReg.eraseUser(DeleteReg.DefMI, DeleteReg.DefRegion);
      if (DepReg.Uses.empty()) {
        DeleteOrder.insert(Dep.RegIdx);
        DepDAG.push_back(Dep.RegIdx);
      }
    }
  } while (!DepDAG.empty());

  // A register enters DeleteOrder only after all of its users did, so walking
  // the order backwards deletes dependencies before the registers using them.
  for (RegisterIdx RegIdx : reverse(DeleteOrder)) {
    Reg &DeleteReg = Regs[RegIdx];

    // It is possible that the defined register we are deleting doesn't have an
    // interval yet if the LIS hasn't been updated since it was created.
    Register DefReg = DeleteReg.getDefReg();
    if (LIS.hasInterval(DefReg))
      LIS.removeInterval(DefReg);
    LISUpdates.erase(RegIdx);

    deleteReg(RegIdx);
    if (isRematerializedRegister(RegIdx)) {
      // Delete rematerialized register from its origin's rematerializations.
      const RegisterIdx OriginIdx = getOriginOf(RegIdx);
      RematsOf &OriginRemats = Rematerializations.at(OriginIdx);
      assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link");
      OriginRemats.erase(RegIdx);
      // The map is keyed by origin index, so the entry to drop once the origin
      // has no rematerialization left is the origin's, not the deleted
      // register's. Leaving an empty set behind would make the map claim the
      // origin still has rematerializations.
      if (OriginRemats.empty())
        Rematerializations.erase(OriginIdx);
    }
    LLVM_DEBUG(dbgs() << "** Deleted " << printID(RegIdx) << "\n");
  }
}

/// Notifies that register \p RegIdx is being deleted, then erases its defining
/// instruction from the function and from the live-interval maps. The
/// register is marked dead by clearing its defining instruction pointer.
void Rematerializer::deleteReg(RegisterIdx RegIdx) {
  noteRegDeleted(RegIdx);

  Reg &Doomed = Regs[RegIdx];
  MachineInstr *const DoomedMI = Doomed.DefMI;
  assert(DoomedMI && "register was already deleted");

  // If the instruction is the first of its region, the region now starts at
  // the following instruction. It is not possible for the deleted instruction
  // to be the upper region boundary since we don't ever consider them
  // rematerializable.
  MachineBasicBlock::iterator &RegionBegin = Regions[Doomed.DefRegion].first;
  if (RegionBegin == DoomedMI)
    RegionBegin = std::next(MachineBasicBlock::iterator(DoomedMI));

  LIS.RemoveMachineInstrFromMaps(*DoomedMI);
  DoomedMI->eraseFromParent();
  Doomed.DefMI = nullptr;
}

/// Constructs a rematerializer operating on the regions \p Regions of function
/// \p MF, using \p LIS for live-interval queries and updates. The register
/// info, instruction info and target register info are taken from \p MF. The
/// region list is stored, not analyzed, here.
Rematerializer::Rematerializer(MachineFunction &MF,
                               SmallVectorImpl<RegionBoundaries> &Regions,
                               LiveIntervals &LIS)
    : Regions(Regions), MRI(MF.getRegInfo()), LIS(LIS),
      TII(*MF.getSubtarget().getInstrInfo()), TRI(TII.getRegisterInfo()) {
#ifdef EXPENSIVE_CHECKS
  // Check that regions are valid: non-empty, free of terminators, and pairwise
  // disjoint.
  DenseSet<MachineInstr *> SeenMIs;
  for (const auto &[RegionBegin, RegionEnd] : Regions) {
    assert(RegionBegin != RegionEnd && "empty region");
    for (auto MI = RegionBegin; MI != RegionEnd; ++MI) {
      bool IsNewMI = SeenMIs.insert(&*MI).second;
      assert(IsNewMI && "overlapping regions");
      assert(!MI->isTerminator() && "terminator in region");
    }
    // The instruction at the upper bound, when there is one, must not belong
    // to any other region either.
    if (RegionEnd != RegionBegin->getParent()->end()) {
      bool IsNewMI = SeenMIs.insert(&*RegionEnd).second;
      assert(IsNewMI && "overlapping regions (upper bound)");
    }
  }
#endif
}

bool Rematerializer::analyze() {
  Regs.clear();
  UnrematableOprds.clear();
  Origins.clear();
  Rematerializations.clear();
  RegionMBB.clear();
  RegToIdx.clear();
  LISUpdates.clear();
  if (Regions.empty())
    return false;

  /// Maps all MIs to their parent region. Region terminators are considered
  /// part of the region they terminate.
  DenseMap<MachineInstr *, unsigned> MIRegion;

  // Initialize MI to containing region mapping.
  RegionMBB.reserve(Regions.size());
  for (unsigned I = 0, E = Regions.size(); I < E; ++I) {
    RegionBoundaries Region = Regions[I];
    assert(Region.first != Region.second && "empty cannot be region");
    for (auto MI = Region.first; MI != Region.second; ++MI) {
      assert(!MIRegion.contains(&*MI) && "regions should not intersect");
      MIRegion.insert({&*MI, I});
    }
    MachineBasicBlock &MBB = *Region.first->getParent();
    RegionMBB.push_back(&MBB);

    // A terminator instruction is considered part of the region it terminates.
    if (Region.second != MBB.end()) {
      MachineInstr *RegionTerm = &*Region.second;
      assert(!MIRegion.contains(RegionTerm) && "regions should not intersect");
      MIRegion.insert({RegionTerm, I});
    }
  }

  const unsigned NumVirtRegs = MRI.getNumVirtRegs();
  BitVector SeenRegs(NumVirtRegs);
  for (unsigned I = 0, E = NumVirtRegs; I != E; ++I) {
    if (!SeenRegs[I])
      addRegIfRematerializable(I, MIRegion, SeenRegs);
  }
  assert(Regs.size() == UnrematableOprds.size());

  LLVM_DEBUG({
    for (RegisterIdx I = 0, E = getNumRegs(); I < E; ++I)
      dbgs() << printDependencyDAG(I) << '\n';
  });
  return !Regs.empty();
}

void Rematerializer::addRegIfRematerializable(
    unsigned VirtRegIdx, const DenseMap<MachineInstr *, unsigned> &MIRegion,
    BitVector &SeenRegs) {
  assert(!SeenRegs[VirtRegIdx] && "register already seen");
  Register DefReg = Register::index2VirtReg(VirtRegIdx);
  SeenRegs.set(VirtRegIdx);

  MachineOperand *MO = MRI.getOneDef(DefReg);
  if (!MO)
    return;
  MachineInstr &DefMI = *MO->getParent();
  if (!isMIRematerializable(DefMI))
    return;
  auto DefRegion = MIRegion.find(&DefMI);
  if (DefRegion == MIRegion.end())
    return;

  Reg RematReg;
  RematReg.DefMI = &DefMI;
  RematReg.DefRegion = DefRegion->second;
  unsigned SubIdx = DefMI.getOperand(0).getSubReg();
  RematReg.Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx)
                         : MRI.getMaxLaneMaskForVReg(DefReg);

  // Collect the candidate's direct users, both rematerializable and
  // unrematerializable. MIs outside provided regions cannot be tracked so the
  // registers they use are not safely rematerializable.
  for (MachineInstr &UseMI : MRI.use_nodbg_instructions(DefReg)) {
    if (auto UseRegion = MIRegion.find(&UseMI); UseRegion != MIRegion.end())
      RematReg.addUser(&UseMI, UseRegion->second);
    else
      return;
  }
  if (RematReg.Uses.empty())
    return;

  // Collect the candidate's dependencies. If the same register is used
  // multiple times we just need to consider it once.
  SmallDenseSet<Register, 4> AllDepRegs;
  SmallVector<unsigned, 2> UnrematDeps;
  for (const auto &[MOIdx, MO] : enumerate(RematReg.DefMI->operands())) {
    Register DepReg = getRegDependency(MO);
    if (!DepReg || !AllDepRegs.insert(DepReg).second)
      continue;
    unsigned DepRegIdx = DepReg.virtRegIndex();
    if (!SeenRegs[DepRegIdx])
      addRegIfRematerializable(DepRegIdx, MIRegion, SeenRegs);
    if (auto DepIt = RegToIdx.find(DepReg); DepIt != RegToIdx.end())
      RematReg.Dependencies.push_back(Reg::Dependency(MOIdx, DepIt->second));
    else
      UnrematDeps.push_back(MOIdx);
  }

  // The register is rematerializable.
  RegToIdx.insert({DefReg, Regs.size()});
  Regs.push_back(RematReg);
  UnrematableOprds.push_back(UnrematDeps);
}

bool Rematerializer::isMIRematerializable(const MachineInstr &MI) const {
  if (!TII.isReMaterializable(MI))
    return false;

  assert(MI.getOperand(0).getReg().isVirtual() && "should be virtual");
  assert(MRI.hasOneDef(MI.getOperand(0).getReg()) && "should have single def");

  for (const MachineOperand &MO : MI.all_uses()) {
    // We can't remat physreg uses, unless it is a constant or an ignorable
    // use (e.g. implicit exec use on VALU instructions)
    if (MO.getReg().isPhysical()) {
      if (MRI.isConstantPhysReg(MO.getReg()) || TII.isIgnorableUse(MO))
        continue;
      return false;
    }
  }

  return true;
}

/// Returns the index of the tracked register defined by the first operand of
/// \p MI, or NoReg if \p MI has no operands, its first operand is not a
/// register, that operand reads its register, or the register is not tracked.
RegisterIdx Rematerializer::getDefRegIdx(const MachineInstr &MI) const {
  if (MI.getNumOperands() == 0)
    return NoReg;
  const MachineOperand &FirstMO = MI.getOperand(0);
  if (!FirstMO.isReg() || FirstMO.readsReg())
    return NoReg;

  auto TrackedIt = RegToIdx.find(FirstMO.getReg());
  if (TrackedIt != RegToIdx.end())
    return TrackedIt->second;
  return NoReg;
}

/// Creates a new register that rematerializes \p RegIdx just before
/// \p InsertPos in region \p UseRegion. \p Dependencies becomes the dependency
/// list of the new register. The new register is linked to the origin of
/// \p RegIdx, creation is notified through noteRegCreated, and the new
/// register's index is returned. Users are not transferred here.
RegisterIdx Rematerializer::rematerializeReg(
    RegisterIdx RegIdx, unsigned UseRegion,
    MachineBasicBlock::iterator InsertPos,
    SmallVectorImpl<Reg::Dependency> &&Dependencies) {
  RegisterIdx NewRegIdx = Regs.size();

  // Grow the register list before taking any reference into it; both
  // references below are taken after the insertion.
  Reg &NewReg = Regs.emplace_back();
  Reg &FromReg = Regs[RegIdx];
  NewReg.Mask = FromReg.Mask;
  NewReg.DefRegion = UseRegion;
  NewReg.Dependencies = std::move(Dependencies);

  // Track rematerialization link between registers. Origins are always
  // registers that existed originally, and rematerializations are always
  // attached to them.
  const RegisterIdx OriginIdx = getOriginOrSelf(RegIdx);
  Origins.push_back(OriginIdx);
  Rematerializations[OriginIdx].insert(NewRegIdx);

  // Use the TII to rematerialize the defining instruction with a new defined
  // register.
  Register NewDefReg = MRI.cloneVirtualRegister(FromReg.getDefReg());
  TII.reMaterialize(*RegionMBB[UseRegion], InsertPos, NewDefReg, 0,
                    *FromReg.DefMI);
  // The instruction right before the insertion point is taken to be the newly
  // created one.
  NewReg.DefMI = &*std::prev(InsertPos);
  RegToIdx.insert({NewDefReg, NewRegIdx});
  postRematerialization(RegIdx, NewRegIdx, InsertPos);

  noteRegCreated(NewRegIdx);
  LLVM_DEBUG(dbgs() << "** Rematerialized " << printID(RegIdx) << " as "
                    << printRematReg(NewRegIdx) << '\n');
  return NewRegIdx;
}

/// Brings the dead register \p RegIdx back to life by rematerializing its
/// defining instruction before \p InsertPos in region \p DefRegion, defining
/// \p DefReg again and using \p Dependencies as its dependency list. The
/// instruction is cloned from one of the register's rematerializations when
/// the register is an original one, and from its origin otherwise.
void Rematerializer::recreateReg(
    RegisterIdx RegIdx, unsigned DefRegion,
    MachineBasicBlock::iterator InsertPos, Register DefReg,
    SmallVectorImpl<Reg::Dependency> &&Dependencies) {
  assert(RegToIdx.contains(DefReg) && "unknown defined register");
  assert(RegToIdx.at(DefReg) == RegIdx && "incorrect defined register");
  assert(!getReg(RegIdx).DefMI && "register is still alive");

  Reg &Revived = Regs[RegIdx];
  Revived.DefRegion = DefRegion;
  Revived.Dependencies = std::move(Dependencies);

  // Pick the register whose defining instruction serves as the model. At
  // least one register holding the value is expected to exist, otherwise it
  // would mean the value held by the original register is no longer available
  // anywhere in the MF.
  RegisterIdx ModelRegIdx;
  if (isOriginalRegister(RegIdx)) {
    assert(Rematerializations.contains(RegIdx) && "expected remats");
    ModelRegIdx = *Rematerializations.at(RegIdx).begin();
  } else {
    // Re-establish the link between the origin and this rematerialization.
    Rematerializations[getOriginOf(RegIdx)].insert(RegIdx);
    assert(getReg(getOriginOf(RegIdx)).DefMI && "expected alive origin");
    ModelRegIdx = getOriginOf(RegIdx);
  }

  const MachineInstr &ModelDefMI = *getReg(ModelRegIdx).DefMI;
  MachineBasicBlock &TargetMBB = *RegionMBB[DefRegion];
  TII.reMaterialize(TargetMBB, InsertPos, DefReg, 0, ModelDefMI);
  Revived.DefMI = &*std::prev(InsertPos);
  postRematerialization(ModelRegIdx, RegIdx, InsertPos);
  LLVM_DEBUG(dbgs() << "** Recreated " << printID(RegIdx) << " as "
                    << printRematReg(RegIdx) << '\n');
}

/// Finishes the creation of register \p RematRegIdx, whose defining
/// instruction was just cloned from that of \p ModelRegIdx: registers the new
/// instruction in the live-interval maps, extends the region's lower bound if
/// the instruction was inserted right before it, rewrites the operands whose
/// dependency differs from the model's, and adds the new instruction as a
/// user of each of its dependencies.
void Rematerializer::postRematerialization(
    RegisterIdx ModelRegIdx, RegisterIdx RematRegIdx,
    MachineBasicBlock::iterator InsertPos) {
  Reg &Model = Regs[ModelRegIdx];
  Reg &Remat = Regs[RematRegIdx];
  MachineInstr *const RematMI = Remat.DefMI;
  LIS.InsertMachineInstrInMaps(*RematMI);

  // The start of the new register's region may have changed.
  MachineBasicBlock::iterator &RegionBegin = Regions[Remat.DefRegion].first;
  const MachineBasicBlock::iterator AfterRemat =
      std::next(MachineBasicBlock::iterator(RematMI));
  if (RegionBegin == AfterRemat)
    RegionBegin = RematMI;

  // Walk the model's and the new register's dependencies in lockstep.
  auto DepPairs = zip_equal(Model.Dependencies, Remat.Dependencies);
  for (const auto &[ModelDep, RematDep] : DepPairs) {
    assert(ModelDep.MOIdx == RematDep.MOIdx && "operand mismatch");
    LLVM_DEBUG(dbgs() << "  Operand #" << ModelDep.MOIdx << ": "
                      << printID(ModelDep.RegIdx) << " -> "
                      << printID(RematDep.RegIdx) << '\n');

    Reg &RematDepReg = Regs[RematDep.RegIdx];
    const bool DepWasReplaced = ModelDep.RegIdx != RematDep.RegIdx;
    if (DepWasReplaced) {
      // The clone still reads the model's register for this operand.
      const Register StaleReg =
          Model.DefMI->getOperand(ModelDep.MOIdx).getReg();
      RematMI->substituteRegister(StaleReg, RematDepReg.getDefReg(), 0, TRI);
      LISUpdates.insert(ModelDep.RegIdx);
    }
    // Every dependency of the new instruction gains a user.
    RematDepReg.addUser(RematMI, Remat.DefRegion);
    LISUpdates.insert(RematDep.RegIdx);
  }
}

std::pair<MachineInstr *, MachineInstr *>
Rematerializer::Reg::getRegionUseBounds(unsigned UseRegion,
                                        const LiveIntervals &LIS) const {
  auto It = Uses.find(UseRegion);
  if (It == Uses.end())
    return {nullptr, nullptr};
  const RegionUsers &RegionUsers = It->getSecond();
  assert(!RegionUsers.empty() && "empty userset in region");

  auto User = RegionUsers.begin(), UserEnd = RegionUsers.end();
  MachineInstr *FirstMI = *User, *LastMI = FirstMI;
  SlotIndex FirstIndex = LIS.getInstructionIndex(*FirstMI),
            LastIndex = FirstIndex;

  while (++User != UserEnd) {
    SlotIndex UserIndex = LIS.getInstructionIndex(**User);
    if (UserIndex < FirstIndex) {
      FirstIndex = UserIndex;
      FirstMI = *User;
    } else if (UserIndex > LastIndex) {
      LastIndex = UserIndex;
      LastMI = *User;
    }
  }

  return {FirstMI, LastMI};
}

/// Records \p MI as a user of this register in \p Region, creating the
/// region's user set if it does not exist yet.
void Rematerializer::Reg::addUser(MachineInstr *MI, unsigned Region) {
  RegionUsers &UsersInRegion = Uses[Region];
  UsersInRegion.insert(MI);
}

void Rematerializer::Reg::addUsers(const RegionUsers &NewUsers,
                                   unsigned Region) {
  Uses[Region].insert_range(NewUsers);
}

/// Removes \p MI from this register's users in \p Region. The region's entry
/// is dropped altogether when \p MI was its only user, so that a region never
/// maps to an empty user set.
void Rematerializer::Reg::eraseUser(MachineInstr *MI, unsigned Region) {
  RegionUsers &UsersInRegion = Uses.at(Region);
  assert(UsersInRegion.contains(MI) && "user not in region");
  if (UsersInRegion.size() != 1) {
    UsersInRegion.erase(MI);
    return;
  }
  Uses.erase(Region);
}

/// Returns a printable object that prints the dependency DAG rooted at
/// \p RootIdx, one register per line, indented by the largest depth at which
/// the register appears, with the root printed last and followed by its
/// users.
Printable Rematerializer::printDependencyDAG(RegisterIdx RootIdx) const {
  return Printable([&, RootIdx](raw_ostream &OS) {
    // Largest depth at which each register of the DAG is reachable.
    DenseMap<RegisterIdx, unsigned> DeepestAt;
    std::function<void(RegisterIdx, unsigned)> Visit =
        [&](RegisterIdx RegIdx, unsigned Depth) -> void {
      const unsigned KnownDepth = DeepestAt.lookup_or(RegIdx, Depth);
      DeepestAt.emplace_or_assign(RegIdx, std::max(KnownDepth, Depth));
      for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies)
        Visit(Dep.RegIdx, Depth + 1);
    };
    Visit(RootIdx, 0);

    // Sort in decreasing depth order to print root at the bottom.
    SmallVector<std::pair<RegisterIdx, unsigned>> ByDepth(DeepestAt.begin(),
                                                          DeepestAt.end());
    sort(ByDepth, [](const auto &LHS, const auto &RHS) {
      return LHS.second > RHS.second;
    });

    OS << printID(RootIdx) << " has " << ByDepth.size() - 1
       << " dependencies\n";
    for (const auto &[RegIdx, Depth] : ByDepth) {
      // Region information is only printed for the root (depth 0).
      OS << indent(Depth, 2) << (Depth ? '|' : '*') << ' '
         << printRematReg(RegIdx, /*SkipRegions=*/Depth) << '\n';
    }
    OS << printRegUsers(RootIdx);
  });
}

/// Returns a printable object that prints register \p RegIdx as
/// "(index/register)[defining region]", with "<dead>" in place of the register
/// when it has no defining instruction.
Printable Rematerializer::printID(RegisterIdx RegIdx) const {
  return Printable([&, RegIdx](raw_ostream &OS) {
    const Reg &Subject = getReg(RegIdx);
    OS << '(' << RegIdx << '/';
    if (const MachineInstr *DefMI = Subject.DefMI)
      OS << printReg(Subject.getDefReg(), &TRI,
                     DefMI->getOperand(0).getSubReg(), &MRI);
    else
      OS << "<dead>";
    OS << ")[" << Subject.DefRegion << "]";
  });
}

/// Returns a printable object that prints register \p RegIdx as its ID, its
/// defining instruction, and that instruction's slot index. Unless
/// \p SkipRegions is set, this is preceded by a bracketed region summary: the
/// defining region, the regions the register is live through without being
/// used in, and the regions it is used in.
// NOTE(review): unlike printID, this dereferences DefMI unconditionally at the
// end, so it presumably must not be called on a dead register - confirm.
Printable Rematerializer::printRematReg(RegisterIdx RegIdx,
                                        bool SkipRegions) const {
  return Printable([&, RegIdx, SkipRegions](raw_ostream &OS) {
    const Reg &PrintReg = getReg(RegIdx);
    if (!SkipRegions) {
      // NOTE(review): the ID is printed here and again after the region
      // summary below - confirm the duplication is intended.
      OS << printID(RegIdx) << " [" << PrintReg.DefRegion;
      if (!PrintReg.Uses.empty()) {
        assert(PrintReg.DefMI && "dead register cannot have uses");
        const LiveInterval &LI = LIS.getInterval(PrintReg.getDefReg());
        // First display all regions in which the register is live-through and
        // not used.
        bool First = true;
        for (const auto [I, Bounds] : enumerate(Regions)) {
          // Empty regions have no instruction to query liveness at.
          if (Bounds.first == Bounds.second)
            continue;
          // Live-through means live at both the first and the last instruction
          // of the region.
          if (!PrintReg.Uses.contains(I) &&
              LI.liveAt(LIS.getInstructionIndex(*Bounds.first)) &&
              LI.liveAt(LIS.getInstructionIndex(*std::prev(Bounds.second))
                            .getRegSlot())) {
            OS << (First ? " - " : ",") << I;
            First = false;
          }
        }
        // A longer arrow is printed when there is no live-through region.
        OS << (First ? " --> " : " -> ");

        // Then display regions in which the register is used.
        auto It = PrintReg.Uses.begin();
        OS << It->first;
        while (++It != PrintReg.Uses.end())
          OS << "," << It->first;
      }
      OS << "] ";
    }
    OS << printID(RegIdx) << ' ';
    PrintReg.DefMI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false,
                          /*SkipDebugLoc=*/false, /*AddNewLine=*/false);
    OS << " @ ";
    LIS.getInstructionIndex(*PrintReg.DefMI).print(OS);
  });
}

/// Returns a printable object that prints every user of register \p RegIdx,
/// one per line, grouped by region.
Printable Rematerializer::printRegUsers(RegisterIdx RegIdx) const {
  return Printable([&, RegIdx](raw_ostream &OS) {
    const Reg &Subject = getReg(RegIdx);
    for (const auto &RegionAndUsers : Subject.Uses) {
      const unsigned UseRegion = RegionAndUsers.first;
      for (MachineInstr *User : RegionAndUsers.second)
        OS << "  User " << printUser(User, UseRegion) << '\n';
    }
  });
}

/// Returns a printable object that prints the user \p MI followed by its slot
/// index. The instruction is prefixed with the ID of the tracked register it
/// defines, or with a placeholder ID showing \p UseRegion ('?' if not
/// provided) when it does not define one.
Printable Rematerializer::printUser(const MachineInstr *MI,
                                    std::optional<unsigned> UseRegion) const {
  return Printable([&, MI, UseRegion](raw_ostream &OS) {
    const RegisterIdx DefinedIdx = getDefRegIdx(*MI);
    if (DefinedIdx == NoReg) {
      OS << "(-/-)[";
      if (UseRegion.has_value())
        OS << *UseRegion;
      else
        OS << '?';
      OS << ']';
    } else {
      OS << printID(DefinedIdx);
    }
    OS << ' ';
    MI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false,
              /*SkipDebugLoc=*/false, /*AddNewLine=*/false);
    OS << " @ ";
    LIS.getInstructionIndex(*MI).print(OS);
  });
}

/// Captures what is needed to re-create register \p RegIdx of \p Remater
/// later on: its defined register, defining region, dependencies, the position
/// right after its defining instruction, and the tracked register defined by
/// the instruction at that position, if any.
Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer &Remater,
                                       RegisterIdx RegIdx) {
  const Rematerializer::Reg &Snapshot = Remater.getReg(RegIdx);
  DefReg = Snapshot.getDefReg();
  DefRegion = Snapshot.DefRegion;
  Dependencies = Snapshot.Dependencies;

  InsertPos = std::next(Snapshot.DefMI->getIterator());
  const bool AtBlockEnd = InsertPos == Snapshot.DefMI->getParent()->end();
  if (AtBlockEnd)
    NextRegIdx = Rematerializer::NoReg;
  else
    NextRegIdx = Remater.getDefRegIdx(*InsertPos);
}

void Rollbacker::rematerializerNoteRegCreated(const Rematerializer &Remater,
                                              RegisterIdx RegIdx) {
  if (RollingBack)
    return;
  Rematerializations[Remater.getOriginOf(RegIdx)].insert(RegIdx);
}

/// Saves rollback information for register \p RegIdx, which is about to be
/// deleted. Nothing is saved for rematerialized registers, nor while a
/// rollback is in progress.
void Rollbacker::rematerializerNoteRegDeleted(const Rematerializer &Remater,
                                              RegisterIdx RegIdx) {
  if (RollingBack)
    return;
  if (Remater.isRematerializedRegister(RegIdx))
    return;
  DeadRegs.try_emplace(RegIdx, Remater, RegIdx);
}

/// Undoes the recorded rematerializations in \p Remater: re-creates the
/// registers that were deleted, moves the users of every rematerialization
/// still alive back to its origin, updates live intervals, and clears the
/// recorded state. Listener notifications received meanwhile are ignored.
void Rollbacker::rollback(Rematerializer &Remater) {
  RollingBack = true;

  // Step 1: re-create deleted registers.
  for (auto &[RegIdx, Info] : DeadRegs) {
    assert(!Remater.getReg(RegIdx).isAlive() && "register should be dead");

    // The instruction that originally followed the deleted definition may
    // itself have been deleted. In that case re-create at that instruction's
    // own recorded position, following the chain for as long as the next
    // instruction is a deleted register as well.
    MachineBasicBlock::iterator InsertPos = Info.InsertPos;
    for (RegisterIdx NextIdx = Info.NextRegIdx;
         NextIdx != Rematerializer::NoReg;) {
      const auto *NextInfo = DeadRegs.find(NextIdx);
      if (NextInfo == DeadRegs.end())
        break;
      InsertPos = NextInfo->second.InsertPos;
      NextIdx = NextInfo->second.NextRegIdx;
    }
    Remater.recreateReg(RegIdx, Info.DefRegion, InsertPos, Info.DefReg,
                        std::move(Info.Dependencies));
  }

  // Step 2: rollback rematerializations.
  for (const auto &[OriginIdx, Remats] : Rematerializations) {
    for (RegisterIdx RematIdx : Remats) {
      // A rematerialization may have been deleted in the meantime; its users
      // would then have been transferred to some other rematerialization, so
      // it can be skipped. Deleted original registers were re-created above
      // and need no such check.
      if (!Remater.getReg(RematIdx).isAlive())
        continue;
      Remater.transferAllUsers(RematIdx, OriginIdx);
    }
  }

  Remater.updateLiveIntervals();
  DeadRegs.clear();
  Rematerializations.clear();
  RollingBack = false;
}
