blob: 1c28c8afc52eb9adeb989172868740089be98826 [file] [log] [blame] [edit]
//=====-- Rematerializer.cpp - MIR rematerialization support ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//==-----------------------------------------------------------------------===//
//
/// \file
/// Implements helpers for target-independent rematerialization at the MIR
/// level.
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/Rematerializer.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "rematerializer"
using namespace llvm;
using RegisterIdx = Rematerializer::RegisterIdx;
/// Checks whether the value in \p LI at \p UseIdx is identical to \p OVNI (this
/// implies it is also live there). When \p LI has sub-ranges, checks that
/// all sub-ranges intersecting with \p Mask are also live at \p UseIdx.
static bool isIdenticalAtUse(const VNInfo &OVNI, LaneBitmask Mask,
SlotIndex UseIdx, const LiveInterval &LI) {
if (&OVNI != LI.getVNInfoAt(UseIdx))
return false;
if (LI.hasSubRanges()) {
// Check that intersecting subranges are live at user.
for (const LiveInterval::SubRange &SR : LI.subranges()) {
if ((SR.LaneMask & Mask).none())
continue;
if (!SR.liveAt(UseIdx))
return false;
// Early exit if all used lanes are checked. No need to continue.
Mask &= ~SR.LaneMask;
if (Mask.none())
break;
}
}
return true;
}
/// Returns the register read by \p MO when that register is not physical.
/// Returns the sentinel (default-constructed) register when \p MO is not a
/// register operand, does not read its register, or reads a physical one.
static Register getRegDependency(const MachineOperand &MO) {
  const bool IsRegRead = MO.isReg() && MO.readsReg();
  if (!IsRegRead)
    return Register();
  // Trivially rematerializable instructions are required to only use physical
  // registers that are constant or ignorable, so those are never reported as
  // dependencies.
  const Register ReadReg = MO.getReg();
  if (ReadReg.isPhysical())
    return Register();
  return ReadReg;
}
/// Rematerializes register \p RootIdx before its first user in region
/// \p UseRegion, then moves all of the root's users in that region over to
/// the new register. Returns the index of the new register.
RegisterIdx Rematerializer::rematerializeToRegion(RegisterIdx RootIdx,
                                                  unsigned UseRegion,
                                                  DependencyReuseInfo &DRI) {
  // The insertion point is the earliest user of the root within the region.
  const auto UseBounds = getReg(RootIdx).getRegionUseBounds(UseRegion, LIS);
  MachineInstr *EarliestUser = UseBounds.first;
  const RegisterIdx RematIdx = rematerializeToPos(RootIdx, EarliestUser, DRI);
  transferRegionUsers(RootIdx, RematIdx, UseRegion);
  return RematIdx;
}
/// Rematerializes register \p RootIdx right before \p InsertPos, together
/// with every register of its dependency DAG that does not already have a
/// rematerialization recorded in \p DRI.DependencyMap. Each register
/// rematerialized here is added to \p DRI.DependencyMap. Returns the index of
/// the root's rematerialization.
RegisterIdx
Rematerializer::rematerializeToPos(RegisterIdx RootIdx,
                                   MachineBasicBlock::iterator InsertPos,
                                   DependencyReuseInfo &DRI) {
  assert(!DRI.DependencyMap.contains(RootIdx));
  LLVM_DEBUG(dbgs() << "Rematerializing " << printID(RootIdx) << " to "
                    << printUser(&*InsertPos) << '\n');

  // Compute a post-order of the root's dependency DAG: a register is appended
  // to RematOrder only once all of its dependencies have been appended (or
  // already have a reusable rematerialization). Recording registers in
  // discovery order and reversing is not sufficient: if the root depends on B
  // and A (in that order) and A itself depends on B, A would be processed
  // before B and the lookup of B's rematerialization below would fail.
  SmallVector<RegisterIdx, 8> RematOrder;
  SmallDenseSet<RegisterIdx, 8> Visited;
  // Each worklist entry is a register and the position of its next
  // dependency to examine.
  SmallVector<std::pair<RegisterIdx, unsigned>, 4> Worklist;
  Visited.insert(RootIdx);
  Worklist.emplace_back(RootIdx, 0u);
  while (!Worklist.empty()) {
    const RegisterIdx RegIdx = Worklist.back().first;
    const auto &Deps = getReg(RegIdx).Dependencies;
    unsigned &NextDep = Worklist.back().second;
    if (NextDep == Deps.size()) {
      // All dependencies are handled, the register itself can be emitted.
      RematOrder.push_back(RegIdx);
      Worklist.pop_back();
      continue;
    }
    const RegisterIdx DepIdx = Deps[NextDep++].RegIdx;
    // The dependency may already have a rematerialization ready to use.
    if (DRI.DependencyMap.contains(DepIdx))
      continue;
    // We may have already seen the dependency in the dependency DAG.
    if (!Visited.insert(DepIdx).second)
      continue;
    // NextDep is not used past this point; growing the worklist may
    // invalidate it.
    Worklist.emplace_back(DepIdx, 0u);
  }

  // Rematerialize all necessary registers in the root's dependency DAG. At
  // each rematerialization, dependencies are already available. The root is
  // always last in the post-order.
  RegisterIdx LastNewIdx = NoReg;
  for (RegisterIdx RegIdx : RematOrder) {
    assert(!DRI.DependencyMap.contains(RegIdx) && "useless remat");
    SmallVector<Reg::Dependency, 2> Dependencies;
    for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies)
      Dependencies.emplace_back(Dep.MOIdx, DRI.DependencyMap.at(Dep.RegIdx));
    LastNewIdx = rematerializeReg(RegIdx, InsertPos, std::move(Dependencies));
    DRI.DependencyMap.insert({RegIdx, LastNewIdx});
  }
  return LastNewIdx;
}
/// Moves every user of every rematerialization of \p RootIdx back onto
/// \p RootIdx (reviving it first if it is dead) and drops the root's entry
/// from the rematerialization map. Does nothing when \p RootIdx has no
/// recorded rematerializations.
void Rematerializer::rollbackRematsOf(RegisterIdx RootIdx) {
  auto Remats = Rematerializations.find(RootIdx);
  if (Remats == Rematerializations.end())
    return;
  LLVM_DEBUG(dbgs() << "Rolling back rematerializations of " << printID(RootIdx)
                    << '\n');
  reviveRegIfDead(RootIdx);

  // Transferring users erases regions from the source register's use map and,
  // once a rematerialization has no user left, removes it from the root's set
  // of rematerializations. Work on snapshots so that neither container is
  // mutated while it is being iterated.
  const SmallVector<RegisterIdx, 4> RematRegs(Remats->getSecond().begin(),
                                              Remats->getSecond().end());
  // All of the rematerialization's users must use the revived register.
  for (RegisterIdx RematRegIdx : RematRegs) {
    SmallVector<unsigned, 4> UseRegions;
    for (const auto &UseEntry : Regs[RematRegIdx].Uses)
      UseRegions.push_back(UseEntry.getFirst());
    for (unsigned UseRegion : UseRegions)
      transferRegionUsers(RematRegIdx, RootIdx, UseRegion);
  }
  Rematerializations.erase(RootIdx);
  LLVM_DEBUG(dbgs() << "** Rolled back rematerializations of "
                    << printID(RootIdx) << '\n');
}
/// Moves all users of rematerialized register \p RematIdx back onto its
/// origin register, reviving the origin first if it is dead.
void Rematerializer::rollback(RegisterIdx RematIdx) {
  assert(getReg(RematIdx).DefMI && !Revivable.contains(RematIdx) &&
         "cannot rollback dead register");
  const RegisterIdx OriginRegIdx = getOriginOf(RematIdx);
  reviveRegIfDead(OriginRegIdx);

  // Each transfer erases the region from the rematerialization's use map, so
  // collect the region keys first instead of iterating the map while it is
  // being modified.
  SmallVector<unsigned, 4> UseRegions;
  for (const auto &UseEntry : Regs[RematIdx].Uses)
    UseRegions.push_back(UseEntry.getFirst());
  for (unsigned UseRegion : UseRegions)
    transferRegionUsers(RematIdx, OriginRegIdx, UseRegion);
}
/// Revives register \p RootIdx if it is not alive, along with every dead
/// register in its dependency DAG. A revived register gets its opcode back
/// from one of its rematerializations and its read operands back from the
/// mapping stored in Revivable. Each dependency of a revived register is
/// registered as being used by it again.
void Rematerializer::reviveRegIfDead(RegisterIdx RootIdx) {
  if (getReg(RootIdx).isAlive())
    return;
  assert(Revivable.contains(RootIdx) && "not revivable");

  // Traverse the root's dependency DAG depth-first to find the set of
  // registers we must revive and a legal order to revive them in.
  SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
  SmallSetVector<RegisterIdx, 8> ReviveOrder;
  ReviveOrder.insert(RootIdx);
  do {
    // All dependencies of a revived register need to be alive too.
    const Reg &ReviveReg = getReg(DepDAG.pop_back_val());
    for (const Reg::Dependency &Dep : ReviveReg.Dependencies) {
      Reg &DepReg = Regs[Dep.RegIdx];
      // Dead dependencies need to be revived; schedule each of them once even
      // when it is reachable through several paths of the DAG.
      if (!DepReg.isAlive() && ReviveOrder.insert(Dep.RegIdx)) {
        assert(Revivable.contains(Dep.RegIdx) && "not revivable");
        DepDAG.push_back(Dep.RegIdx);
      }
      // All dependencies get a new user (the revived register). This must
      // happen for every (revived register, dependency) pair, including
      // dependencies that were already scheduled through another path:
      // deleting a register removes it from the users of all its
      // dependencies, so reviving it has to add it back to all of them.
      DepReg.addUser(ReviveReg.DefMI, ReviveReg.DefRegion);
      LISUpdates.insert(Dep.RegIdx);
    }
  } while (!DepDAG.empty());

  for (RegisterIdx RegIdx : reverse(ReviveOrder)) {
    // Pick any rematerialization to retrieve the original opcode from.
    Reg &ReviveReg = Regs[RegIdx];
    assert(Rematerializations.contains(RegIdx) && "no remats");
    RegisterIdx RematIdx = *Rematerializations.at(RegIdx).begin();
    ReviveReg.DefMI->setDesc(getReg(RematIdx).DefMI->getDesc());
    // Restore the registers that were read by the instruction before it was
    // marked dead.
    for (const auto &[MOIdx, OrigReg] : Revivable.at(RegIdx))
      ReviveReg.DefMI->getOperand(MOIdx).setReg(OrigReg);
    Revivable.erase(RegIdx);
    LISUpdates.insert(RegIdx);
    LLVM_DEBUG({
      dbgs() << "** Revived " << printID(RegIdx) << " @ ";
      LIS.getInstructionIndex(*ReviveReg.DefMI).print(dbgs());
      dbgs() << '\n';
    });
  }
}
/// Makes \p UserMI use register \p ToRegIdx instead of \p FromRegIdx and
/// updates both registers' user sets accordingly. The source register is
/// deleted if this leaves it without users.
void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
                                  MachineInstr &UserMI) {
  // Rewrite the instruction first, then move it between the user sets.
  transferUserImpl(FromRegIdx, ToRegIdx, UserMI);
  MachineInstr *const User = &UserMI;
  const unsigned Region = MIRegion.at(User);
  Regs[FromRegIdx].eraseUser(User, Region);
  Regs[ToRegIdx].addUser(User, Region);
  // The source register may have just lost its last user.
  deleteRegIfUnused(FromRegIdx);
}
/// Makes every user of \p FromRegIdx located in region \p UseRegion use
/// \p ToRegIdx instead. Does nothing when the source register has no user in
/// that region. The source register is deleted if this leaves it without
/// users.
void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx,
                                         RegisterIdx ToRegIdx,
                                         unsigned UseRegion) {
  auto &SrcUses = Regs[FromRegIdx].Uses;
  const auto RegionIt = SrcUses.find(UseRegion);
  if (RegionIt == SrcUses.end())
    return;

  // Rewrite each user, then hand the whole user set over to the destination
  // register before dropping it from the source register.
  const SmallDenseSet<MachineInstr *, 4> &Users = RegionIt->getSecond();
  for (MachineInstr *User : Users)
    transferUserImpl(FromRegIdx, ToRegIdx, *User);
  Regs[ToRegIdx].addUsers(Users, UseRegion);
  SrcUses.erase(UseRegion);
  deleteRegIfUnused(FromRegIdx);
}
/// Rewrites \p UserMI so that it reads the register defined by \p ToRegIdx
/// instead of the one defined by \p FromRegIdx, and marks both registers for
/// a live interval update. If \p UserMI itself defines a tracked register,
/// that register's dependency on \p FromRegIdx is redirected to \p ToRegIdx.
/// User sets are not touched here.
void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx,
                                      RegisterIdx ToRegIdx,
                                      MachineInstr &UserMI) {
  assert(MIRegion.contains(&UserMI) && "unknown user");
  assert(getReg(FromRegIdx).Uses.at(MIRegion.at(&UserMI)).contains(&UserMI) &&
         "not a user");
  assert(FromRegIdx != ToRegIdx && "identical registers");
  assert(getOriginOrSelf(FromRegIdx) == getOriginOrSelf(ToRegIdx) &&
         "unrelated registers");
  LLVM_DEBUG(dbgs() << "User transfer from " << printID(FromRegIdx) << " to "
                    << printID(ToRegIdx) << ": " << printUser(&UserMI) << '\n');

  const Register OldDefReg = getReg(FromRegIdx).getDefReg();
  const Register NewDefReg = getReg(ToRegIdx).getDefReg();
  UserMI.substituteRegister(OldDefReg, NewDefReg, 0, TRI);
  LISUpdates.insert(FromRegIdx);
  LISUpdates.insert(ToRegIdx);

  // Nothing else to do unless the user defines a tracked register, in which
  // case its dependency list must point at the new register.
  const RegisterIdx UserRegIdx = getDefRegIdx(UserMI);
  if (UserRegIdx == NoReg)
    return;
  auto &UserDeps = Regs[UserRegIdx].Dependencies;
  auto DepIt = llvm::find_if(UserDeps, [FromRegIdx](const Reg::Dependency &Dep) {
    return Dep.RegIdx == FromRegIdx;
  });
  if (DepIt == UserDeps.end())
    llvm_unreachable("broken dependency");
  DepIt->RegIdx = ToRegIdx;
}
void Rematerializer::updateLiveIntervals() {
DenseSet<Register> SeenUnrematRegs;
for (RegisterIdx RegIdx : LISUpdates) {
const Reg &UpdateReg = getReg(RegIdx);
assert((UpdateReg.DefMI || Revivable.contains(RegIdx)) && "dead reg");
Register DefReg = UpdateReg.getDefReg();
if (LIS.hasInterval(DefReg))
LIS.removeInterval(DefReg);
LIS.createAndComputeVirtRegInterval(DefReg);
LLVM_DEBUG({
dbgs() << "Re-computed interval for " << printID(RegIdx) << ": ";
LIS.getInterval(DefReg).print(dbgs());
dbgs() << '\n' << printRegUsers(RegIdx);
});
// Update intervals for unrematerializable operands.
for (unsigned MOIdx : getUnrematableOprds(RegIdx)) {
Register UnrematReg = UpdateReg.DefMI->getOperand(MOIdx).getReg();
if (!SeenUnrematRegs.insert(UnrematReg).second)
continue;
LIS.removeInterval(UnrematReg);
LIS.createAndComputeVirtRegInterval(UnrematReg);
LLVM_DEBUG(
dbgs() << " Re-computed interval for register "
<< printReg(UnrematReg, &TRI,
UpdateReg.DefMI->getOperand(MOIdx).getSubReg(),
&MRI)
<< '\n');
}
}
LISUpdates.clear();
}
void Rematerializer::commitRematerializations() {
for (auto &[RegIdx, _] : Revivable)
deleteReg(RegIdx);
Revivable.clear();
}
/// Returns true when the value read by register operand \p MO at its parent
/// instruction is the same value that is available, on all lanes read by
/// \p MO, at every slot in \p Uses. Trivially true when \p Uses is empty.
bool Rematerializer::isMOIdenticalAtUses(MachineOperand &MO,
                                         ArrayRef<SlotIndex> Uses) const {
  if (Uses.empty())
    return true;

  // Lanes read by the operand: the sub-register's lanes when there is one,
  // otherwise all lanes of the virtual register.
  const Register OpReg = MO.getReg();
  LaneBitmask Mask;
  if (unsigned SubIdx = MO.getSubReg())
    Mask = TRI.getSubRegIndexLaneMask(SubIdx);
  else
    Mask = MRI.getMaxLaneMaskForVReg(OpReg);

  // Value number read by the operand at its parent instruction.
  const LiveInterval &LI = LIS.getInterval(OpReg);
  const SlotIndex ReadSlot =
      LIS.getInstructionIndex(*MO.getParent()).getRegSlot(true);
  const VNInfo *DefVN = LI.getVNInfoAt(ReadSlot);

  return llvm::all_of(Uses, [&](SlotIndex Use) {
    return isIdenticalAtUse(*DefVN, Mask, Use, LI);
  });
}
/// Among the rematerializations sharing \p RegIdx's origin, returns the one
/// defined in \p Region that still has users and whose definition slot is the
/// latest one strictly before \p Before. Returns NoReg when there is no such
/// rematerialization.
RegisterIdx Rematerializer::findRematInRegion(RegisterIdx RegIdx,
                                              unsigned Region,
                                              SlotIndex Before) const {
  const auto RematsIt = Rematerializations.find(getOriginOrSelf(RegIdx));
  if (RematsIt == Rematerializations.end())
    return NoReg;

  RegisterIdx ClosestIdx = NoReg;
  SlotIndex ClosestSlot;
  for (RegisterIdx CandIdx : RematsIt->getSecond()) {
    const Reg &Cand = getReg(CandIdx);
    const bool Usable = Cand.DefRegion == Region && !Cand.Uses.empty();
    if (!Usable)
      continue;
    const SlotIndex CandSlot =
        LIS.getInstructionIndex(*Cand.DefMI).getRegSlot();
    // Only definitions strictly before the requested slot qualify.
    if (CandSlot >= Before)
      continue;
    // Keep the candidate closest to the requested slot.
    if (ClosestIdx != NoReg && CandSlot <= ClosestSlot)
      continue;
    ClosestSlot = CandSlot;
    ClosestIdx = CandIdx;
  }
  return ClosestIdx;
}
/// Deletes register \p RootIdx if it has no user left, along with every
/// register of its dependency DAG that ends up without users as a result.
/// When rollback is supported, non-rematerialized registers are not erased:
/// their defining instruction is turned into a DBG_VALUE and their read
/// operands are saved in Revivable. Rematerialized registers are always
/// erased and unlinked from their origin.
void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
if (!getReg(RootIdx).Uses.empty())
return;
// Traverse the root's dependency DAG depth-first to find the set of registers
// we can delete and a legal order to delete them in.
SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
SmallSetVector<RegisterIdx, 8> DeleteOrder;
DeleteOrder.insert(RootIdx);
do {
// A deleted register's dependencies may be deletable too.
const Reg &DeleteReg = getReg(DepDAG.pop_back_val());
for (const Reg::Dependency &Dep : DeleteReg.Dependencies) {
// All dependencies lose a user (the deleted register).
Reg &DepReg = Regs[Dep.RegIdx];
DepReg.eraseUser(DeleteReg.DefMI, DeleteReg.DefRegion);
// A dependency that just lost its last user is deleted as well.
if (DepReg.Uses.empty()) {
DeleteOrder.insert(Dep.RegIdx);
DepDAG.push_back(Dep.RegIdx);
}
}
} while (!DepDAG.empty());
for (RegisterIdx RegIdx : reverse(DeleteOrder)) {
Reg &DeleteReg = Regs[RegIdx];
// The register's interval is gone; make sure it is not recomputed later.
LIS.removeInterval(DeleteReg.getDefReg());
LISUpdates.erase(RegIdx);
const bool IsRematerializedReg = isRematerializedRegister(RegIdx);
if (SupportRollback && !IsRematerializedReg) {
// Replace all read registers with the null one to prevent them from
// showing up in use-lists, which is disallowed for debug instructions in
// live interval calculations. Store mappings between operand indices and
// original registers for potential rollback.
DenseMap<unsigned, Register> &RegMap =
Revivable.try_emplace(RegIdx).first->getSecond();
for (auto [Idx, MO] : enumerate(DeleteReg.DefMI->operands())) {
if (MO.isReg() && MO.readsReg()) {
RegMap.insert({Idx, MO.getReg()});
MO.setReg(Register());
}
}
// Keep the instruction around as a DBG_VALUE so it can be revived.
DeleteReg.DefMI->setDesc(TII.get(TargetOpcode::DBG_VALUE));
} else {
deleteReg(RegIdx);
}
if (IsRematerializedReg) {
// Delete rematerialized register from its origin's rematerializations.
RematsOf &OriginRemats = Rematerializations.at(getOriginOf(RegIdx));
assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link");
OriginRemats.erase(RegIdx);
// NOTE(review): Rematerializations is looked up by origin index just
// above, but the erase below uses the rematerialized register's own index.
// This looks like it was meant to erase getOriginOf(RegIdx); as written the
// origin's (now empty) entry presumably stays in the map. Confirm the
// intent before changing it: rollbackRematsOf iterates this very set while
// users are transferred, so erasing the origin's entry from here needs
// that caller to stop iterating the set in place.
if (OriginRemats.empty())
Rematerializations.erase(RegIdx);
}
LLVM_DEBUG(dbgs() << "** Deleted " << printID(RegIdx) << "\n");
}
}
/// Erases the instruction defining register \p RegIdx from its parent, from
/// the slot index maps and from the MI-to-region map, and marks the register
/// as having no defining instruction anymore.
void Rematerializer::deleteReg(RegisterIdx RegIdx) {
  Reg &Doomed = Regs[RegIdx];
  MachineInstr *const DefMI = Doomed.DefMI;
  assert(DefMI && "register was already deleted");

  // Only the region's first instruction needs handling here: the deleted
  // instruction can never be the upper region boundary since those are never
  // considered rematerializable.
  MachineBasicBlock::iterator &RegionBegin = Regions[Doomed.DefRegion].first;
  if (RegionBegin == DefMI)
    RegionBegin = std::next(MachineBasicBlock::iterator(DefMI));

  // Drop all bookkeeping for the instruction before erasing it.
  LIS.RemoveMachineInstrFromMaps(*DefMI);
  MIRegion.erase(DefMI);
  DefMI->eraseFromParent();
  Doomed.DefMI = nullptr;
}
/// Binds the rematerializer to function \p MF, its \p Regions and its live
/// intervals \p LIS. With EXPENSIVE_CHECKS, additionally verifies that the
/// regions are non-empty, free of terminators, and pairwise disjoint
/// (including their upper boundary instruction).
Rematerializer::Rematerializer(MachineFunction &MF,
                               SmallVectorImpl<RegionBoundaries> &Regions,
                               LiveIntervals &LIS)
    : Regions(Regions), MRI(MF.getRegInfo()), LIS(LIS),
      TII(*MF.getSubtarget().getInstrInfo()), TRI(TII.getRegisterInfo()) {
#ifdef EXPENSIVE_CHECKS
  // Check that regions are valid.
  DenseSet<MachineInstr *> VisitedMIs;
  for (const RegionBoundaries &Bounds : Regions) {
    const auto &[RegionBegin, RegionEnd] = Bounds;
    assert(RegionBegin != RegionEnd && "empty region");
    for (auto MI = RegionBegin; MI != RegionEnd; ++MI) {
      bool IsNewMI = VisitedMIs.insert(&*MI).second;
      assert(IsNewMI && "overlapping regions");
      assert(!MI->isTerminator() && "terminator in region");
    }
    // The instruction past the region, if any, belongs to the region too and
    // cannot be shared with another one.
    if (RegionEnd != RegionBegin->getParent()->end()) {
      bool IsNewMI = VisitedMIs.insert(&*RegionEnd).second;
      assert(IsNewMI && "overlapping regions (upper bound)");
    }
  }
#endif
}
bool Rematerializer::analyze(bool SupportRollback) {
Regs.clear();
UnrematableOprds.clear();
Origins.clear();
Rematerializations.clear();
MIRegion.clear();
RegToIdx.clear();
LISUpdates.clear();
Revivable.clear();
this->SupportRollback = SupportRollback;
if (Regions.empty())
return false;
// Initialize MI to containing region mapping.
for (unsigned I = 0, E = Regions.size(); I < E; ++I) {
RegionBoundaries Region = Regions[I];
assert(Region.first != Region.second && "empty cannot be region");
for (auto MI = Region.first; MI != Region.second; ++MI) {
assert(!MIRegion.contains(&*MI) && "regions should not intersect");
MIRegion.insert({&*MI, I});
}
// A terminator instruction is considered part of the region it terminates.
if (Region.second != Region.first->getParent()->end()) {
MachineInstr *RegionTerm = &*Region.second;
assert(!MIRegion.contains(RegionTerm) && "regions should not intersect");
MIRegion.insert({RegionTerm, I});
}
}
const unsigned NumVirtRegs = MRI.getNumVirtRegs();
BitVector SeenRegs(NumVirtRegs);
for (unsigned I = 0, E = NumVirtRegs; I != E; ++I) {
if (!SeenRegs[I])
addRegIfRematerializable(I, SeenRegs);
}
assert(Regs.size() == UnrematableOprds.size());
LLVM_DEBUG({
for (RegisterIdx I = 0, E = getNumRegs(); I < E; ++I)
dbgs() << printDependencyDAG(I) << '\n';
});
return !Regs.empty();
}
void Rematerializer::addRegIfRematerializable(unsigned VirtRegIdx,
BitVector &SeenRegs) {
assert(!SeenRegs[VirtRegIdx] && "register already seen");
Register DefReg = Register::index2VirtReg(VirtRegIdx);
SeenRegs.set(VirtRegIdx);
MachineOperand *MO = MRI.getOneDef(DefReg);
if (!MO)
return;
MachineInstr &DefMI = *MO->getParent();
if (!isMIRematerializable(DefMI))
return;
auto DefRegion = MIRegion.find(&DefMI);
if (DefRegion == MIRegion.end())
return;
Reg RematReg;
RematReg.DefMI = &DefMI;
RematReg.DefRegion = DefRegion->second;
unsigned SubIdx = DefMI.getOperand(0).getSubReg();
RematReg.Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx)
: MRI.getMaxLaneMaskForVReg(DefReg);
// Collect the candidate's direct users, both rematerializable and
// unrematerializable. MIs outside provided regions cannot be tracked so the
// registers they use are not safely rematerializable.
for (MachineInstr &UseMI : MRI.use_nodbg_instructions(DefReg)) {
if (auto UseRegion = MIRegion.find(&UseMI); UseRegion != MIRegion.end())
RematReg.addUser(&UseMI, UseRegion->second);
else
return;
}
if (RematReg.Uses.empty())
return;
// Collect the candidate's dependencies. If the same register is used
// multiple times we just need to consider it once.
SmallDenseSet<Register, 4> AllDepRegs;
SmallVector<unsigned, 2> UnrematDeps;
for (const auto &[MOIdx, MO] : enumerate(RematReg.DefMI->operands())) {
Register DepReg = getRegDependency(MO);
if (!DepReg || !AllDepRegs.insert(DepReg).second)
continue;
unsigned DepRegIdx = DepReg.virtRegIndex();
if (!SeenRegs[DepRegIdx])
addRegIfRematerializable(DepRegIdx, SeenRegs);
if (auto DepIt = RegToIdx.find(DepReg); DepIt != RegToIdx.end())
RematReg.Dependencies.push_back(Reg::Dependency(MOIdx, DepIt->second));
else
UnrematDeps.push_back(MOIdx);
}
// The register is rematerializable.
RegToIdx.insert({DefReg, Regs.size()});
Regs.push_back(RematReg);
UnrematableOprds.push_back(UnrematDeps);
}
/// Returns true when the target reports \p MI as rematerializable and every
/// physical register it uses is either constant or an ignorable use.
bool Rematerializer::isMIRematerializable(const MachineInstr &MI) const {
  if (!TII.isReMaterializable(MI))
    return false;
  assert(MI.getOperand(0).getReg().isVirtual() && "should be virtual");
  assert(MRI.hasOneDef(MI.getOperand(0).getReg()) && "should have single def");

  for (const MachineOperand &UseMO : MI.all_uses()) {
    const Register UseReg = UseMO.getReg();
    if (!UseReg.isPhysical())
      continue;
    // We can't remat physreg uses, unless it is a constant or an ignorable
    // use (e.g. implicit exec use on VALU instructions).
    const bool Harmless =
        MRI.isConstantPhysReg(UseReg) || TII.isIgnorableUse(UseMO);
    if (!Harmless)
      return false;
  }
  return true;
}
/// Returns the index of the tracked register written by the first operand of
/// \p MI, or NoReg when \p MI has no operand, its first operand is not a
/// register, that operand reads its register, or the register is not tracked.
RegisterIdx Rematerializer::getDefRegIdx(const MachineInstr &MI) const {
  if (MI.getNumOperands() == 0)
    return NoReg;
  const MachineOperand &FirstMO = MI.getOperand(0);
  if (!FirstMO.isReg() || FirstMO.readsReg())
    return NoReg;
  const auto RegIt = RegToIdx.find(FirstMO.getReg());
  if (RegIt == RegToIdx.end())
    return NoReg;
  return RegIt->second;
}
/// Creates a new register that rematerializes register \p RegIdx right before
/// \p InsertPos, using \p Dependencies as the new register's dependencies
/// (one per dependency of \p RegIdx, with matching operand indices). Updates
/// the origin/rematerialization links, the region bounds, the slot index
/// maps, and the users of the new dependencies. Returns the new register's
/// index.
RegisterIdx Rematerializer::rematerializeReg(
    RegisterIdx RegIdx, MachineBasicBlock::iterator InsertPos,
    SmallVectorImpl<Reg::Dependency> &&Dependencies) {
  const unsigned TargetRegion = MIRegion.at(&*InsertPos);
  const RegisterIdx RematIdx = Regs.size();
  // References into Regs are taken only after the container has grown.
  Reg &RematReg = Regs.emplace_back();
  Reg &SrcReg = Regs[RegIdx];
  RematReg.Mask = SrcReg.Mask;
  RematReg.DefRegion = TargetRegion;
  RematReg.Dependencies = std::move(Dependencies);

  // Track rematerialization link between registers. Origins are always
  // registers that existed originally, and rematerializations are always
  // attached to them.
  RegisterIdx OriginIdx = RegIdx;
  if (isRematerializedRegister(RegIdx))
    OriginIdx = getOriginOf(RegIdx);
  Origins.push_back(OriginIdx);
  Rematerializations[OriginIdx].insert(RematIdx);

  // Let the TII clone the defining instruction onto a fresh virtual register.
  // The clone is the instruction that now sits right before InsertPos.
  const Register RematDefReg = MRI.cloneVirtualRegister(SrcReg.getDefReg());
  TII.reMaterialize(*InsertPos->getParent(), InsertPos, RematDefReg, 0,
                    *SrcReg.DefMI);
  MachineInstr *const RematMI = &*std::prev(InsertPos);
  RematReg.DefMI = RematMI;
  RegToIdx.insert({RematDefReg, RematIdx});

  // Update the DAG. The region grows upwards if the clone was inserted in
  // front of its first instruction.
  RegionBoundaries &Bounds = Regions[TargetRegion];
  if (Bounds.first == std::next(MachineBasicBlock::iterator(RematMI)))
    Bounds.first = RematMI;
  LIS.InsertMachineInstrInMaps(*RematMI);
  MIRegion.emplace_or_assign(RematMI, TargetRegion);
  LISUpdates.insert(RematIdx);

  // Replace dependencies as needed in the rematerialized MI. All dependencies
  // of the latter gain a new user.
  auto PairedDeps = zip_equal(SrcReg.Dependencies, RematReg.Dependencies);
  for (const auto &[OldDep, NewDep] : PairedDeps) {
    assert(OldDep.MOIdx == NewDep.MOIdx && "operand mismatch");
    LLVM_DEBUG(dbgs() << " Operand #" << OldDep.MOIdx << ": "
                      << printID(OldDep.RegIdx) << " -> "
                      << printID(NewDep.RegIdx) << '\n');
    Reg &NewDepReg = Regs[NewDep.RegIdx];
    if (OldDep.RegIdx != NewDep.RegIdx) {
      // The clone still reads the register of the old dependency.
      const Register OldDepDefReg =
          SrcReg.DefMI->getOperand(OldDep.MOIdx).getReg();
      RematMI->substituteRegister(OldDepDefReg, NewDepReg.getDefReg(), 0, TRI);
      LISUpdates.insert(OldDep.RegIdx);
    }
    NewDepReg.addUser(RematMI, TargetRegion);
    LISUpdates.insert(NewDep.RegIdx);
  }

  LLVM_DEBUG({
    dbgs() << "** Rematerialized " << printID(RegIdx) << " as "
           << printRematReg(RematIdx) << '\n';
  });
  return RematIdx;
}
/// Returns the users of the register in region \p UseRegion with the lowest
/// and highest slot index according to \p LIS, in that order. Returns a pair
/// of null pointers when the register has no user in that region.
std::pair<MachineInstr *, MachineInstr *>
Rematerializer::Reg::getRegionUseBounds(unsigned UseRegion,
                                        const LiveIntervals &LIS) const {
  const auto RegionIt = Uses.find(UseRegion);
  if (RegionIt == Uses.end())
    return {nullptr, nullptr};
  const auto &Users = RegionIt->getSecond();
  assert(!Users.empty() && "empty userset in region");

  MachineInstr *Earliest = nullptr;
  MachineInstr *Latest = nullptr;
  SlotIndex EarliestIdx, LatestIdx;
  for (MachineInstr *User : Users) {
    const SlotIndex UserIdx = LIS.getInstructionIndex(*User);
    if (!Earliest) {
      // The first user seeds both bounds.
      Earliest = Latest = User;
      EarliestIdx = LatestIdx = UserIdx;
      continue;
    }
    if (UserIdx < EarliestIdx) {
      EarliestIdx = UserIdx;
      Earliest = User;
    } else if (UserIdx > LatestIdx) {
      LatestIdx = UserIdx;
      Latest = User;
    }
  }
  return {Earliest, Latest};
}
/// Records \p MI as a user of the register in region \p Region, creating the
/// region's user set if it does not exist yet.
void Rematerializer::Reg::addUser(MachineInstr *MI, unsigned Region) {
  auto &Users = Uses[Region];
  Users.insert(MI);
}
void Rematerializer::Reg::addUsers(const RegionUsers &NewUsers,
unsigned Region) {
Uses[Region].insert_range(NewUsers);
}
/// Removes \p MI from the register's users in region \p Region. The region's
/// entry is dropped entirely when \p MI was its only user, so that regions
/// present in Uses always have at least one user.
void Rematerializer::Reg::eraseUser(MachineInstr *MI, unsigned Region) {
  assert(Uses.contains(Region) && "no user in region");
  assert(Uses.at(Region).contains(MI) && "user not in region");
  RegionUsers &Users = Uses[Region];
  if (Users.size() != 1) {
    Users.erase(MI);
    return;
  }
  Uses.erase(Region);
}
/// Returns a printable listing every register in the dependency DAG of
/// \p RootIdx, deepest registers first and the root last, each indented
/// according to its maximum depth in the DAG, followed by the root's users.
Printable Rematerializer::printDependencyDAG(RegisterIdx RootIdx) const {
  return Printable([&, RootIdx](raw_ostream &OS) {
    // Maximum depth at which each register of the DAG is reachable.
    DenseMap<RegisterIdx, unsigned> DeepestAt;
    std::function<void(RegisterIdx, unsigned)> Visit =
        [&](RegisterIdx RegIdx, unsigned Depth) -> void {
      unsigned &Recorded = DeepestAt.try_emplace(RegIdx, Depth).first->second;
      Recorded = std::max(Recorded, Depth);
      for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies)
        Visit(Dep.RegIdx, Depth + 1);
    };
    Visit(RootIdx, 0);

    // Sort in decreasing depth order to print root at the bottom.
    SmallVector<std::pair<RegisterIdx, unsigned>> ByDepth(DeepestAt.begin(),
                                                          DeepestAt.end());
    llvm::sort(ByDepth, [](const auto &LHS, const auto &RHS) {
      return LHS.second > RHS.second;
    });

    OS << printID(RootIdx) << " has " << ByDepth.size() - 1
       << " dependencies\n";
    for (const auto &[RegIdx, Depth] : ByDepth) {
      OS << indent(Depth, 2) << (Depth ? '|' : '*') << ' '
         << printRematReg(RegIdx, /*SkipRegions=*/Depth) << '\n';
    }
    OS << printRegUsers(RootIdx);
  });
}
/// Returns a printable of the form "(<index>/<register>)[<defining region>]"
/// for register \p RegIdx, where <register> is "<dead>" when the register has
/// no defining instruction.
Printable Rematerializer::printID(RegisterIdx RegIdx) const {
  return Printable([&, RegIdx](raw_ostream &OS) {
    const Reg &PrintReg = getReg(RegIdx);
    OS << '(' << RegIdx << '/';
    if (PrintReg.DefMI)
      OS << printReg(PrintReg.getDefReg(), &TRI,
                     PrintReg.DefMI->getOperand(0).getSubReg(), &MRI);
    else
      OS << "<dead>";
    OS << ")[" << PrintReg.DefRegion << "]";
  });
}
/// Returns a printable showing register \p RegIdx's ID, defining instruction
/// and slot index. Unless \p SkipRegions is set, this is preceded by a
/// summary of the defining region, the regions the register is live through
/// without being used in, and the regions it is used in.
Printable Rematerializer::printRematReg(RegisterIdx RegIdx,
                                        bool SkipRegions) const {
  return Printable([&, RegIdx, SkipRegions](raw_ostream &OS) {
    const Reg &PrintReg = getReg(RegIdx);
    if (!SkipRegions) {
      OS << printID(RegIdx) << " [" << PrintReg.DefRegion;
      if (!PrintReg.Uses.empty()) {
        assert(PrintReg.DefMI && "dead register cannot have uses");
        const LiveInterval &LI = LIS.getInterval(PrintReg.getDefReg());

        // First display all regions in which the register is live-through and
        // not used: live at both the first and the last instruction of the
        // region.
        bool AnyLiveThrough = false;
        for (const auto [RegionIdx, Bounds] : enumerate(Regions)) {
          if (Bounds.first == Bounds.second)
            continue;
          if (PrintReg.Uses.contains(RegionIdx))
            continue;
          if (!LI.liveAt(LIS.getInstructionIndex(*Bounds.first)))
            continue;
          const SlotIndex LastSlot =
              LIS.getInstructionIndex(*std::prev(Bounds.second)).getRegSlot();
          if (!LI.liveAt(LastSlot))
            continue;
          OS << (AnyLiveThrough ? "," : " - ") << RegionIdx;
          AnyLiveThrough = true;
        }
        OS << (AnyLiveThrough ? " -> " : " --> ");

        // Then display regions in which the register is used.
        bool FirstUse = true;
        for (const auto &UseEntry : PrintReg.Uses) {
          if (!FirstUse)
            OS << ",";
          OS << UseEntry.first;
          FirstUse = false;
        }
      }
      OS << "] ";
    }
    OS << printID(RegIdx) << ' ';
    PrintReg.DefMI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false,
                          /*SkipDebugLoc=*/false, /*AddNewLine=*/false);
    OS << " @ ";
    LIS.getInstructionIndex(*PrintReg.DefMI).print(OS);
  });
}
/// Returns a printable listing every user of register \p RegIdx, one per
/// line, across all regions.
Printable Rematerializer::printRegUsers(RegisterIdx RegIdx) const {
  return Printable([&, RegIdx](raw_ostream &OS) {
    for (const auto &[_, Users] : getReg(RegIdx).Uses) {
      // Write to the stream the printable is sent to, not to dbgs(), so the
      // output lands where the caller asked for it.
      for (MachineInstr *MI : Users)
        OS << " User " << printUser(MI) << '\n';
    }
  });
}
/// Returns a printable for user instruction \p MI: the ID of the tracked
/// register it defines (or "(-/-)[<region>]" when it defines none), followed
/// by the instruction itself and its slot index.
Printable Rematerializer::printUser(const MachineInstr *MI) const {
  return Printable([&, MI](raw_ostream &OS) {
    RegisterIdx RegIdx = getDefRegIdx(*MI);
    if (RegIdx != NoReg)
      OS << printID(RegIdx);
    else
      OS << "(-/-)[" << MIRegion.at(MI) << ']';
    OS << ' ';
    MI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false,
              /*SkipDebugLoc=*/false, /*AddNewLine=*/false);
    OS << " @ ";
    // Print the slot index to the same stream as everything else rather than
    // unconditionally to dbgs().
    LIS.getInstructionIndex(*MI).print(OS);
  });
}