| //=====-- Rematerializer.cpp - MIR rematerialization support ----*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //==-----------------------------------------------------------------------===// |
| // |
| /// \file |
| /// Implements helpers for target-independent rematerialization at the MIR |
| /// level. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/CodeGen/Rematerializer.h" |
| #include "llvm/ADT/MapVector.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/CodeGen/LiveIntervals.h" |
| #include "llvm/CodeGen/MachineBasicBlock.h" |
| #include "llvm/CodeGen/MachineOperand.h" |
| #include "llvm/CodeGen/MachineRegisterInfo.h" |
| #include "llvm/CodeGen/Register.h" |
| #include "llvm/CodeGen/TargetRegisterInfo.h" |
| #include "llvm/Support/Debug.h" |
| #include <optional> |
| |
| #define DEBUG_TYPE "rematerializer" |
| |
| using namespace llvm; |
| using RegisterIdx = Rematerializer::RegisterIdx; |
| |
| // Pin the vtable to this file. |
| void Rematerializer::Listener::anchor() {} |
| |
| /// Checks whether the value in \p LI at \p UseIdx is identical to \p OVNI (this |
| /// implies it is also live there). When \p LI has sub-ranges, checks that |
| /// all sub-ranges intersecting with \p Mask are also live at \p UseIdx. |
| static bool isIdenticalAtUse(const VNInfo &OVNI, LaneBitmask Mask, |
| SlotIndex UseIdx, const LiveInterval &LI) { |
| if (&OVNI != LI.getVNInfoAt(UseIdx)) |
| return false; |
| |
| if (LI.hasSubRanges()) { |
| // Check that intersecting subranges are live at user. |
| for (const LiveInterval::SubRange &SR : LI.subranges()) { |
| if ((SR.LaneMask & Mask).none()) |
| continue; |
| if (!SR.liveAt(UseIdx)) |
| return false; |
| |
| // Early exit if all used lanes are checked. No need to continue. |
| Mask &= ~SR.LaneMask; |
| if (Mask.none()) |
| break; |
| } |
| } |
| return true; |
| } |
| |
| /// If \p MO is a virtual read register, returns it. Otherwise returns the |
| /// sentinel register. |
| static Register getRegDependency(const MachineOperand &MO) { |
| if (!MO.isReg() || !MO.readsReg()) |
| return Register(); |
| Register Reg = MO.getReg(); |
| if (Reg.isPhysical()) { |
| // By the requirements on trivially rematerializable instructions, a |
| // physical register use is either constant or ignorable. |
| return Register(); |
| } |
| return Reg; |
| } |
| |
| RegisterIdx Rematerializer::rematerializeToRegion(RegisterIdx RootIdx, |
| unsigned UseRegion, |
| DependencyReuseInfo &DRI) { |
| MachineInstr *FirstMI = |
| getReg(RootIdx).getRegionUseBounds(UseRegion, LIS).first; |
| // If there are no users in the region, rematerialize the register at the very |
| // end of the region. |
| MachineBasicBlock::iterator InsertPos = |
| FirstMI ? FirstMI : Regions[UseRegion].second; |
| RegisterIdx NewRegIdx = |
| rematerializeToPos(RootIdx, UseRegion, InsertPos, DRI); |
| transferRegionUsers(RootIdx, NewRegIdx, UseRegion); |
| return NewRegIdx; |
| } |
| |
| RegisterIdx |
| Rematerializer::rematerializeToPos(RegisterIdx RootIdx, unsigned UseRegion, |
| MachineBasicBlock::iterator InsertPos, |
| DependencyReuseInfo &DRI) { |
| assert(!DRI.DependencyMap.contains(RootIdx)); |
| LLVM_DEBUG(dbgs() << "Rematerializing " << printID(RootIdx) << '\n'); |
| |
| SmallVector<Reg::Dependency, 2> NewDeps; |
| // Copy all dependencies because recursive rematerialization of dependencies |
| // may invalidate references to the backing vector of registers. |
| SmallVector<Reg::Dependency, 2> OldDeps(getReg(RootIdx).Dependencies); |
| for (const Reg::Dependency &Dep : OldDeps) { |
| // Recursively rematerialize required dependencies at the same position as |
| // the root. Registers form a DAG so the recursion is guaranteed to |
| // terminate. |
| auto RematIdx = DRI.DependencyMap.find(Dep.RegIdx); |
| RegisterIdx NewDepRegIdx; |
| if (RematIdx == DRI.DependencyMap.end()) |
| NewDepRegIdx = rematerializeToPos(Dep.RegIdx, UseRegion, InsertPos, DRI); |
| else |
| NewDepRegIdx = RematIdx->second; |
| NewDeps.emplace_back(Dep.MOIdx, NewDepRegIdx); |
| } |
| RegisterIdx NewIdx = |
| rematerializeReg(RootIdx, UseRegion, InsertPos, std::move(NewDeps)); |
| DRI.DependencyMap.insert({RootIdx, NewIdx}); |
| return NewIdx; |
| } |
| |
| void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx, |
| unsigned UserRegion, MachineInstr &UserMI) { |
| transferUserImpl(FromRegIdx, ToRegIdx, UserMI); |
| Regs[FromRegIdx].eraseUser(&UserMI, UserRegion); |
| Regs[ToRegIdx].addUser(&UserMI, UserRegion); |
| deleteRegIfUnused(FromRegIdx); |
| } |
| |
| void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx, |
| RegisterIdx ToRegIdx, |
| unsigned UseRegion) { |
| auto &FromRegUsers = Regs[FromRegIdx].Uses; |
| auto UsesIt = FromRegUsers.find(UseRegion); |
| if (UsesIt == FromRegUsers.end()) |
| return; |
| |
| const SmallDenseSet<MachineInstr *, 4> &RegionUsers = UsesIt->getSecond(); |
| for (MachineInstr *UserMI : RegionUsers) |
| transferUserImpl(FromRegIdx, ToRegIdx, *UserMI); |
| Regs[ToRegIdx].addUsers(RegionUsers, UseRegion); |
| FromRegUsers.erase(UseRegion); |
| deleteRegIfUnused(FromRegIdx); |
| } |
| |
| void Rematerializer::transferAllUsers(RegisterIdx FromRegIdx, |
| RegisterIdx ToRegIdx) { |
| Reg &FromReg = Regs[FromRegIdx], &ToReg = Regs[ToRegIdx]; |
| for (const auto &[UseRegion, RegionUsers] : FromReg.Uses) { |
| for (MachineInstr *UserMI : RegionUsers) |
| transferUserImpl(FromRegIdx, ToRegIdx, *UserMI); |
| ToReg.addUsers(RegionUsers, UseRegion); |
| } |
| FromReg.Uses.clear(); |
| deleteRegIfUnused(FromRegIdx); |
| } |
| |
| void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx, |
| RegisterIdx ToRegIdx, |
| MachineInstr &UserMI) { |
| assert(FromRegIdx != ToRegIdx && "identical registers"); |
| assert(getOriginOrSelf(FromRegIdx) == getOriginOrSelf(ToRegIdx) && |
| "unrelated registers"); |
| |
| LLVM_DEBUG(dbgs() << "User transfer from " << printID(FromRegIdx) << " to " |
| << printID(ToRegIdx) << ": " << printUser(&UserMI) << '\n'); |
| |
| UserMI.substituteRegister(getReg(FromRegIdx).getDefReg(), |
| getReg(ToRegIdx).getDefReg(), 0, TRI); |
| LISUpdates.insert(FromRegIdx); |
| LISUpdates.insert(ToRegIdx); |
| |
| // If the user is rematerializable, we must change its dependency to the |
| // new register. |
| if (RegisterIdx UserRegIdx = getDefRegIdx(UserMI); UserRegIdx != NoReg) { |
| // Look for the user's dependency that matches the register. |
| for (Reg::Dependency &Dep : Regs[UserRegIdx].Dependencies) { |
| if (Dep.RegIdx == FromRegIdx) { |
| Dep.RegIdx = ToRegIdx; |
| return; |
| } |
| } |
| llvm_unreachable("broken dependency"); |
| } |
| } |
| |
| void Rematerializer::updateLiveIntervals() { |
| DenseSet<Register> SeenUnrematRegs; |
| for (RegisterIdx RegIdx : LISUpdates) { |
| const Reg &UpdateReg = getReg(RegIdx); |
| assert(UpdateReg.isAlive() && "dead register"); |
| |
| Register DefReg = UpdateReg.getDefReg(); |
| if (LIS.hasInterval(DefReg)) |
| LIS.removeInterval(DefReg); |
| // Rematerializable registers have a single definition by construction so |
| // re-creating their interval cannot yield a live interval with multiple |
| // connected components. |
| LIS.createAndComputeVirtRegInterval(DefReg); |
| |
| LLVM_DEBUG({ |
| dbgs() << "Re-computed interval for " << printID(RegIdx) << ": "; |
| LIS.getInterval(DefReg).print(dbgs()); |
| dbgs() << '\n' << printRegUsers(RegIdx); |
| }); |
| |
| // Update intervals for unrematerializable operands. |
| for (unsigned MOIdx : getUnrematableOprds(RegIdx)) { |
| Register UnrematReg = UpdateReg.DefMI->getOperand(MOIdx).getReg(); |
| if (!SeenUnrematRegs.insert(UnrematReg).second) |
| continue; |
| LIS.removeInterval(UnrematReg); |
| bool NeedSplit = false; |
| |
| // Unrematerializable registers may end up with multiple connected |
| // components in their live interval after it is re-created. It needs to |
| // be split in such cases. We don't track unrematerializable registers by |
| // their actual register index (just by operand index) so we do not need |
| // to update any state in the rematerializer. |
| LiveInterval &LI = |
| LIS.createAndComputeVirtRegInterval(UnrematReg, NeedSplit); |
| if (NeedSplit) { |
| SmallVector<LiveInterval *> SplitLIs; |
| LIS.splitSeparateComponents(LI, SplitLIs); |
| } |
| LLVM_DEBUG( |
| dbgs() << " Re-computed interval for register " |
| << printReg(UnrematReg, &TRI, |
| UpdateReg.DefMI->getOperand(MOIdx).getSubReg(), |
| &MRI) |
| << '\n'); |
| } |
| } |
| LISUpdates.clear(); |
| } |
| |
| bool Rematerializer::isMOIdenticalAtUses(MachineOperand &MO, |
| ArrayRef<SlotIndex> Uses) const { |
| if (Uses.empty()) |
| return true; |
| Register Reg = MO.getReg(); |
| unsigned SubIdx = MO.getSubReg(); |
| LaneBitmask Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx) |
| : MRI.getMaxLaneMaskForVReg(Reg); |
| const LiveInterval &LI = LIS.getInterval(Reg); |
| const VNInfo *DefVN = |
| LI.getVNInfoAt(LIS.getInstructionIndex(*MO.getParent()).getRegSlot(true)); |
| for (SlotIndex Use : Uses) { |
| if (!isIdenticalAtUse(*DefVN, Mask, Use, LI)) |
| return false; |
| } |
| return true; |
| } |
| |
| RegisterIdx Rematerializer::findRematInRegion(RegisterIdx RegIdx, |
| unsigned Region, |
| SlotIndex Before) const { |
| auto It = Rematerializations.find(getOriginOrSelf(RegIdx)); |
| if (It == Rematerializations.end()) |
| return NoReg; |
| const RematsOf &Remats = It->getSecond(); |
| |
| SlotIndex BestSlot; |
| RegisterIdx BestRegIdx = NoReg; |
| for (RegisterIdx RematRegIdx : Remats) { |
| const Reg &RematReg = getReg(RematRegIdx); |
| if (RematReg.DefRegion != Region || RematReg.Uses.empty()) |
| continue; |
| SlotIndex RematRegSlot = |
| LIS.getInstructionIndex(*RematReg.DefMI).getRegSlot(); |
| if (RematRegSlot < Before && |
| (BestRegIdx == NoReg || RematRegSlot > BestSlot)) { |
| BestSlot = RematRegSlot; |
| BestRegIdx = RematRegIdx; |
| } |
| } |
| return BestRegIdx; |
| } |
| |
| void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) { |
| if (!getReg(RootIdx).Uses.empty()) |
| return; |
| |
| // Traverse the root's dependency DAG depth-first to find the set of registers |
| // we can delete and a legal order to delete them in. |
| SmallVector<RegisterIdx, 4> DepDAG{RootIdx}; |
| SmallSetVector<RegisterIdx, 8> DeleteOrder; |
| DeleteOrder.insert(RootIdx); |
| do { |
| // A deleted register's dependencies may be deletable too. |
| const Reg &DeleteReg = getReg(DepDAG.pop_back_val()); |
| for (const Reg::Dependency &Dep : DeleteReg.Dependencies) { |
| // All dependencies loose a user (the deleted register). |
| Reg &DepReg = Regs[Dep.RegIdx]; |
| DepReg.eraseUser(DeleteReg.DefMI, DeleteReg.DefRegion); |
| if (DepReg.Uses.empty()) { |
| DeleteOrder.insert(Dep.RegIdx); |
| DepDAG.push_back(Dep.RegIdx); |
| } |
| } |
| } while (!DepDAG.empty()); |
| |
| for (RegisterIdx RegIdx : reverse(DeleteOrder)) { |
| Reg &DeleteReg = Regs[RegIdx]; |
| |
| // It is possible that the defined register we are deleting doesn't have an |
| // interval yet if the LIS hasn't been updated since it was created. |
| Register DefReg = DeleteReg.getDefReg(); |
| if (LIS.hasInterval(DefReg)) |
| LIS.removeInterval(DefReg); |
| LISUpdates.erase(RegIdx); |
| |
| deleteReg(RegIdx); |
| if (isRematerializedRegister(RegIdx)) { |
| // Delete rematerialized register from its origin's rematerializations. |
| RematsOf &OriginRemats = Rematerializations.at(getOriginOf(RegIdx)); |
| assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link"); |
| OriginRemats.erase(RegIdx); |
| if (OriginRemats.empty()) |
| Rematerializations.erase(RegIdx); |
| } |
| LLVM_DEBUG(dbgs() << "** Deleted " << printID(RegIdx) << "\n"); |
| } |
| } |
| |
| void Rematerializer::deleteReg(RegisterIdx RegIdx) { |
| noteRegDeleted(RegIdx); |
| |
| Reg &DeleteReg = Regs[RegIdx]; |
| assert(DeleteReg.DefMI && "register was already deleted"); |
| // It is not possible for the deleted instruction to be the upper region |
| // boundary since we don't ever consider them rematerializable. |
| MachineBasicBlock::iterator &RegionBegin = Regions[DeleteReg.DefRegion].first; |
| if (RegionBegin == DeleteReg.DefMI) |
| RegionBegin = std::next(MachineBasicBlock::iterator(DeleteReg.DefMI)); |
| LIS.RemoveMachineInstrFromMaps(*DeleteReg.DefMI); |
| DeleteReg.DefMI->eraseFromParent(); |
| DeleteReg.DefMI = nullptr; |
| } |
| |
| Rematerializer::Rematerializer(MachineFunction &MF, |
| SmallVectorImpl<RegionBoundaries> &Regions, |
| LiveIntervals &LIS) |
| : Regions(Regions), MRI(MF.getRegInfo()), LIS(LIS), |
| TII(*MF.getSubtarget().getInstrInfo()), TRI(TII.getRegisterInfo()) { |
| #ifdef EXPENSIVE_CHECKS |
| // Check that regions are valid. |
| DenseSet<MachineInstr *> SeenMIs; |
| for (const auto &[RegionBegin, RegionEnd] : Regions) { |
| assert(RegionBegin != RegionEnd && "empty region"); |
| for (auto MI = RegionBegin; MI != RegionEnd; ++MI) { |
| bool IsNewMI = SeenMIs.insert(&*MI).second; |
| assert(IsNewMI && "overlapping regions"); |
| assert(!MI->isTerminator() && "terminator in region"); |
| } |
| if (RegionEnd != RegionBegin->getParent()->end()) { |
| bool IsNewMI = SeenMIs.insert(&*RegionEnd).second; |
| assert(IsNewMI && "overlapping regions (upper bound)"); |
| } |
| } |
| #endif |
| } |
| |
| bool Rematerializer::analyze() { |
| Regs.clear(); |
| UnrematableOprds.clear(); |
| Origins.clear(); |
| Rematerializations.clear(); |
| RegionMBB.clear(); |
| RegToIdx.clear(); |
| LISUpdates.clear(); |
| if (Regions.empty()) |
| return false; |
| |
| /// Maps all MIs to their parent region. Region terminators are considered |
| /// part of the region they terminate. |
| DenseMap<MachineInstr *, unsigned> MIRegion; |
| |
| // Initialize MI to containing region mapping. |
| RegionMBB.reserve(Regions.size()); |
| for (unsigned I = 0, E = Regions.size(); I < E; ++I) { |
| RegionBoundaries Region = Regions[I]; |
| assert(Region.first != Region.second && "empty cannot be region"); |
| for (auto MI = Region.first; MI != Region.second; ++MI) { |
| assert(!MIRegion.contains(&*MI) && "regions should not intersect"); |
| MIRegion.insert({&*MI, I}); |
| } |
| MachineBasicBlock &MBB = *Region.first->getParent(); |
| RegionMBB.push_back(&MBB); |
| |
| // A terminator instruction is considered part of the region it terminates. |
| if (Region.second != MBB.end()) { |
| MachineInstr *RegionTerm = &*Region.second; |
| assert(!MIRegion.contains(RegionTerm) && "regions should not intersect"); |
| MIRegion.insert({RegionTerm, I}); |
| } |
| } |
| |
| const unsigned NumVirtRegs = MRI.getNumVirtRegs(); |
| BitVector SeenRegs(NumVirtRegs); |
| for (unsigned I = 0, E = NumVirtRegs; I != E; ++I) { |
| if (!SeenRegs[I]) |
| addRegIfRematerializable(I, MIRegion, SeenRegs); |
| } |
| assert(Regs.size() == UnrematableOprds.size()); |
| |
| LLVM_DEBUG({ |
| for (RegisterIdx I = 0, E = getNumRegs(); I < E; ++I) |
| dbgs() << printDependencyDAG(I) << '\n'; |
| }); |
| return !Regs.empty(); |
| } |
| |
| void Rematerializer::addRegIfRematerializable( |
| unsigned VirtRegIdx, const DenseMap<MachineInstr *, unsigned> &MIRegion, |
| BitVector &SeenRegs) { |
| assert(!SeenRegs[VirtRegIdx] && "register already seen"); |
| Register DefReg = Register::index2VirtReg(VirtRegIdx); |
| SeenRegs.set(VirtRegIdx); |
| |
| MachineOperand *MO = MRI.getOneDef(DefReg); |
| if (!MO) |
| return; |
| MachineInstr &DefMI = *MO->getParent(); |
| if (!isMIRematerializable(DefMI)) |
| return; |
| auto DefRegion = MIRegion.find(&DefMI); |
| if (DefRegion == MIRegion.end()) |
| return; |
| |
| Reg RematReg; |
| RematReg.DefMI = &DefMI; |
| RematReg.DefRegion = DefRegion->second; |
| unsigned SubIdx = DefMI.getOperand(0).getSubReg(); |
| RematReg.Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx) |
| : MRI.getMaxLaneMaskForVReg(DefReg); |
| |
| // Collect the candidate's direct users, both rematerializable and |
| // unrematerializable. MIs outside provided regions cannot be tracked so the |
| // registers they use are not safely rematerializable. |
| for (MachineInstr &UseMI : MRI.use_nodbg_instructions(DefReg)) { |
| if (auto UseRegion = MIRegion.find(&UseMI); UseRegion != MIRegion.end()) |
| RematReg.addUser(&UseMI, UseRegion->second); |
| else |
| return; |
| } |
| if (RematReg.Uses.empty()) |
| return; |
| |
| // Collect the candidate's dependencies. If the same register is used |
| // multiple times we just need to consider it once. |
| SmallDenseSet<Register, 4> AllDepRegs; |
| SmallVector<unsigned, 2> UnrematDeps; |
| for (const auto &[MOIdx, MO] : enumerate(RematReg.DefMI->operands())) { |
| Register DepReg = getRegDependency(MO); |
| if (!DepReg || !AllDepRegs.insert(DepReg).second) |
| continue; |
| unsigned DepRegIdx = DepReg.virtRegIndex(); |
| if (!SeenRegs[DepRegIdx]) |
| addRegIfRematerializable(DepRegIdx, MIRegion, SeenRegs); |
| if (auto DepIt = RegToIdx.find(DepReg); DepIt != RegToIdx.end()) |
| RematReg.Dependencies.push_back(Reg::Dependency(MOIdx, DepIt->second)); |
| else |
| UnrematDeps.push_back(MOIdx); |
| } |
| |
| // The register is rematerializable. |
| RegToIdx.insert({DefReg, Regs.size()}); |
| Regs.push_back(RematReg); |
| UnrematableOprds.push_back(UnrematDeps); |
| } |
| |
| bool Rematerializer::isMIRematerializable(const MachineInstr &MI) const { |
| if (!TII.isReMaterializable(MI)) |
| return false; |
| |
| assert(MI.getOperand(0).getReg().isVirtual() && "should be virtual"); |
| assert(MRI.hasOneDef(MI.getOperand(0).getReg()) && "should have single def"); |
| |
| for (const MachineOperand &MO : MI.all_uses()) { |
| // We can't remat physreg uses, unless it is a constant or an ignorable |
| // use (e.g. implicit exec use on VALU instructions) |
| if (MO.getReg().isPhysical()) { |
| if (MRI.isConstantPhysReg(MO.getReg()) || TII.isIgnorableUse(MO)) |
| continue; |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| RegisterIdx Rematerializer::getDefRegIdx(const MachineInstr &MI) const { |
| if (!MI.getNumOperands() || !MI.getOperand(0).isReg() || |
| MI.getOperand(0).readsReg()) |
| return NoReg; |
| Register Reg = MI.getOperand(0).getReg(); |
| auto UserRegIt = RegToIdx.find(Reg); |
| if (UserRegIt == RegToIdx.end()) |
| return NoReg; |
| return UserRegIt->second; |
| } |
| |
| RegisterIdx Rematerializer::rematerializeReg( |
| RegisterIdx RegIdx, unsigned UseRegion, |
| MachineBasicBlock::iterator InsertPos, |
| SmallVectorImpl<Reg::Dependency> &&Dependencies) { |
| RegisterIdx NewRegIdx = Regs.size(); |
| |
| Reg &NewReg = Regs.emplace_back(); |
| Reg &FromReg = Regs[RegIdx]; |
| NewReg.Mask = FromReg.Mask; |
| NewReg.DefRegion = UseRegion; |
| NewReg.Dependencies = std::move(Dependencies); |
| |
| // Track rematerialization link between registers. Origins are always |
| // registers that existed originally, and rematerializations are always |
| // attached to them. |
| const RegisterIdx OriginIdx = getOriginOrSelf(RegIdx); |
| Origins.push_back(OriginIdx); |
| Rematerializations[OriginIdx].insert(NewRegIdx); |
| |
| // Use the TII to rematerialize the defining instruction with a new defined |
| // register. |
| Register NewDefReg = MRI.cloneVirtualRegister(FromReg.getDefReg()); |
| TII.reMaterialize(*RegionMBB[UseRegion], InsertPos, NewDefReg, 0, |
| *FromReg.DefMI); |
| NewReg.DefMI = &*std::prev(InsertPos); |
| RegToIdx.insert({NewDefReg, NewRegIdx}); |
| postRematerialization(RegIdx, NewRegIdx, InsertPos); |
| |
| noteRegCreated(NewRegIdx); |
| LLVM_DEBUG(dbgs() << "** Rematerialized " << printID(RegIdx) << " as " |
| << printRematReg(NewRegIdx) << '\n'); |
| return NewRegIdx; |
| } |
| |
| void Rematerializer::recreateReg( |
| RegisterIdx RegIdx, unsigned DefRegion, |
| MachineBasicBlock::iterator InsertPos, Register DefReg, |
| SmallVectorImpl<Reg::Dependency> &&Dependencies) { |
| assert(RegToIdx.contains(DefReg) && "unknown defined register"); |
| assert(RegToIdx.at(DefReg) == RegIdx && "incorrect defined register"); |
| assert(!getReg(RegIdx).DefMI && "register is still alive"); |
| |
| Reg &OriginReg = Regs[RegIdx]; |
| OriginReg.DefRegion = DefRegion; |
| OriginReg.Dependencies = std::move(Dependencies); |
| |
| // Re-establish the link between origin and rematerialization if necessary. |
| const bool RecreateOriginalReg = isOriginalRegister(RegIdx); |
| if (!RecreateOriginalReg) |
| Rematerializations[getOriginOf(RegIdx)].insert(RegIdx); |
| |
| // Rematerialize from one of the existing rematerializations or from the |
| // origin. We expect at least one to exist, otherwise it would mean the value |
| // held by the original register is no longer available anywhere in the MF. |
| RegisterIdx ModelRegIdx; |
| if (RecreateOriginalReg) { |
| assert(Rematerializations.contains(RegIdx) && "expected remats"); |
| ModelRegIdx = *Rematerializations.at(RegIdx).begin(); |
| } else { |
| assert(getReg(getOriginOf(RegIdx)).DefMI && "expected alive origin"); |
| ModelRegIdx = getOriginOf(RegIdx); |
| } |
| const MachineInstr &ModelDefMI = *getReg(ModelRegIdx).DefMI; |
| |
| TII.reMaterialize(*RegionMBB[DefRegion], InsertPos, DefReg, 0, ModelDefMI); |
| OriginReg.DefMI = &*std::prev(InsertPos); |
| postRematerialization(ModelRegIdx, RegIdx, InsertPos); |
| LLVM_DEBUG(dbgs() << "** Recreated " << printID(RegIdx) << " as " |
| << printRematReg(RegIdx) << '\n'); |
| } |
| |
| void Rematerializer::postRematerialization( |
| RegisterIdx ModelRegIdx, RegisterIdx RematRegIdx, |
| MachineBasicBlock::iterator InsertPos) { |
| |
| // The start of the new register's region may have changed. |
| Reg &ModelReg = Regs[ModelRegIdx], &RematReg = Regs[RematRegIdx]; |
| LIS.InsertMachineInstrInMaps(*RematReg.DefMI); |
| MachineBasicBlock::iterator &RegionBegin = Regions[RematReg.DefRegion].first; |
| if (RegionBegin == std::next(MachineBasicBlock::iterator(RematReg.DefMI))) |
| RegionBegin = RematReg.DefMI; |
| |
| // Replace dependencies as needed in the rematerialized MI. All dependencies |
| // of the latter gain a new user. |
| auto ZipedDeps = zip_equal(ModelReg.Dependencies, RematReg.Dependencies); |
| for (const auto &[OldDep, NewDep] : ZipedDeps) { |
| assert(OldDep.MOIdx == NewDep.MOIdx && "operand mismatch"); |
| LLVM_DEBUG(dbgs() << " Operand #" << OldDep.MOIdx << ": " |
| << printID(OldDep.RegIdx) << " -> " |
| << printID(NewDep.RegIdx) << '\n'); |
| |
| Reg &NewDepReg = Regs[NewDep.RegIdx]; |
| if (OldDep.RegIdx != NewDep.RegIdx) { |
| Register OldDefReg = ModelReg.DefMI->getOperand(OldDep.MOIdx).getReg(); |
| RematReg.DefMI->substituteRegister(OldDefReg, NewDepReg.getDefReg(), 0, |
| TRI); |
| LISUpdates.insert(OldDep.RegIdx); |
| } |
| NewDepReg.addUser(RematReg.DefMI, RematReg.DefRegion); |
| LISUpdates.insert(NewDep.RegIdx); |
| } |
| } |
| |
| std::pair<MachineInstr *, MachineInstr *> |
| Rematerializer::Reg::getRegionUseBounds(unsigned UseRegion, |
| const LiveIntervals &LIS) const { |
| auto It = Uses.find(UseRegion); |
| if (It == Uses.end()) |
| return {nullptr, nullptr}; |
| const RegionUsers &RegionUsers = It->getSecond(); |
| assert(!RegionUsers.empty() && "empty userset in region"); |
| |
| auto User = RegionUsers.begin(), UserEnd = RegionUsers.end(); |
| MachineInstr *FirstMI = *User, *LastMI = FirstMI; |
| SlotIndex FirstIndex = LIS.getInstructionIndex(*FirstMI), |
| LastIndex = FirstIndex; |
| |
| while (++User != UserEnd) { |
| SlotIndex UserIndex = LIS.getInstructionIndex(**User); |
| if (UserIndex < FirstIndex) { |
| FirstIndex = UserIndex; |
| FirstMI = *User; |
| } else if (UserIndex > LastIndex) { |
| LastIndex = UserIndex; |
| LastMI = *User; |
| } |
| } |
| |
| return {FirstMI, LastMI}; |
| } |
| |
| void Rematerializer::Reg::addUser(MachineInstr *MI, unsigned Region) { |
| Uses[Region].insert(MI); |
| } |
| |
| void Rematerializer::Reg::addUsers(const RegionUsers &NewUsers, |
| unsigned Region) { |
| Uses[Region].insert_range(NewUsers); |
| } |
| |
| void Rematerializer::Reg::eraseUser(MachineInstr *MI, unsigned Region) { |
| RegionUsers &RUsers = Uses.at(Region); |
| assert(RUsers.contains(MI) && "user not in region"); |
| if (RUsers.size() == 1) |
| Uses.erase(Region); |
| else |
| RUsers.erase(MI); |
| } |
| |
| Printable Rematerializer::printDependencyDAG(RegisterIdx RootIdx) const { |
| return Printable([&, RootIdx](raw_ostream &OS) { |
| DenseMap<RegisterIdx, unsigned> RegDepths; |
| std::function<void(RegisterIdx, unsigned)> WalkTree = |
| [&](RegisterIdx RegIdx, unsigned Depth) -> void { |
| unsigned MaxDepth = std::max(RegDepths.lookup_or(RegIdx, Depth), Depth); |
| RegDepths.emplace_or_assign(RegIdx, MaxDepth); |
| for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies) |
| WalkTree(Dep.RegIdx, Depth + 1); |
| }; |
| WalkTree(RootIdx, 0); |
| |
| // Sort in decreasing depth order to print root at the bottom. |
| SmallVector<std::pair<RegisterIdx, unsigned>> Regs(RegDepths.begin(), |
| RegDepths.end()); |
| sort(Regs, [](const auto &LHS, const auto &RHS) { |
| return LHS.second > RHS.second; |
| }); |
| |
| OS << printID(RootIdx) << " has " << Regs.size() - 1 << " dependencies\n"; |
| for (const auto &[RegIdx, Depth] : Regs) { |
| OS << indent(Depth, 2) << (Depth ? '|' : '*') << ' ' |
| << printRematReg(RegIdx, /*SkipRegions=*/Depth) << '\n'; |
| } |
| OS << printRegUsers(RootIdx); |
| }); |
| } |
| |
| Printable Rematerializer::printID(RegisterIdx RegIdx) const { |
| return Printable([&, RegIdx](raw_ostream &OS) { |
| const Reg &PrintReg = getReg(RegIdx); |
| OS << '(' << RegIdx << '/'; |
| if (!PrintReg.DefMI) { |
| OS << "<dead>"; |
| } else { |
| OS << printReg(PrintReg.getDefReg(), &TRI, |
| PrintReg.DefMI->getOperand(0).getSubReg(), &MRI); |
| } |
| OS << ")[" << PrintReg.DefRegion << "]"; |
| }); |
| } |
| |
| Printable Rematerializer::printRematReg(RegisterIdx RegIdx, |
| bool SkipRegions) const { |
| return Printable([&, RegIdx, SkipRegions](raw_ostream &OS) { |
| const Reg &PrintReg = getReg(RegIdx); |
| if (!SkipRegions) { |
| OS << printID(RegIdx) << " [" << PrintReg.DefRegion; |
| if (!PrintReg.Uses.empty()) { |
| assert(PrintReg.DefMI && "dead register cannot have uses"); |
| const LiveInterval &LI = LIS.getInterval(PrintReg.getDefReg()); |
| // First display all regions in which the register is live-through and |
| // not used. |
| bool First = true; |
| for (const auto [I, Bounds] : enumerate(Regions)) { |
| if (Bounds.first == Bounds.second) |
| continue; |
| if (!PrintReg.Uses.contains(I) && |
| LI.liveAt(LIS.getInstructionIndex(*Bounds.first)) && |
| LI.liveAt(LIS.getInstructionIndex(*std::prev(Bounds.second)) |
| .getRegSlot())) { |
| OS << (First ? " - " : ",") << I; |
| First = false; |
| } |
| } |
| OS << (First ? " --> " : " -> "); |
| |
| // Then display regions in which the register is used. |
| auto It = PrintReg.Uses.begin(); |
| OS << It->first; |
| while (++It != PrintReg.Uses.end()) |
| OS << "," << It->first; |
| } |
| OS << "] "; |
| } |
| OS << printID(RegIdx) << ' '; |
| PrintReg.DefMI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false, |
| /*SkipDebugLoc=*/false, /*AddNewLine=*/false); |
| OS << " @ "; |
| LIS.getInstructionIndex(*PrintReg.DefMI).print(OS); |
| }); |
| } |
| |
| Printable Rematerializer::printRegUsers(RegisterIdx RegIdx) const { |
| return Printable([&, RegIdx](raw_ostream &OS) { |
| for (const auto &[UseRegion, Users] : getReg(RegIdx).Uses) { |
| for (MachineInstr *MI : Users) |
| OS << " User " << printUser(MI, UseRegion) << '\n'; |
| } |
| }); |
| } |
| |
| Printable Rematerializer::printUser(const MachineInstr *MI, |
| std::optional<unsigned> UseRegion) const { |
| return Printable([&, MI, UseRegion](raw_ostream &OS) { |
| RegisterIdx RegIdx = getDefRegIdx(*MI); |
| if (RegIdx != NoReg) { |
| OS << printID(RegIdx); |
| } else { |
| OS << "(-/-)["; |
| if (UseRegion) |
| OS << *UseRegion; |
| else |
| OS << '?'; |
| OS << ']'; |
| } |
| OS << ' '; |
| MI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false, |
| /*SkipDebugLoc=*/false, /*AddNewLine=*/false); |
| OS << " @ "; |
| LIS.getInstructionIndex(*MI).print(OS); |
| }); |
| } |
| |
| Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer &Remater, |
| RegisterIdx RegIdx) { |
| const Rematerializer::Reg &Reg = Remater.getReg(RegIdx); |
| DefReg = Reg.getDefReg(); |
| DefRegion = Reg.DefRegion; |
| Dependencies = Reg.Dependencies; |
| |
| InsertPos = std::next(Reg.DefMI->getIterator()); |
| if (InsertPos != Reg.DefMI->getParent()->end()) |
| NextRegIdx = Remater.getDefRegIdx(*InsertPos); |
| else |
| NextRegIdx = Rematerializer::NoReg; |
| } |
| |
| void Rollbacker::rematerializerNoteRegCreated(const Rematerializer &Remater, |
| RegisterIdx RegIdx) { |
| if (RollingBack) |
| return; |
| Rematerializations[Remater.getOriginOf(RegIdx)].insert(RegIdx); |
| } |
| |
| void Rollbacker::rematerializerNoteRegDeleted(const Rematerializer &Remater, |
| RegisterIdx RegIdx) { |
| if (RollingBack || Remater.isRematerializedRegister(RegIdx)) |
| return; |
| DeadRegs.try_emplace(RegIdx, Remater, RegIdx); |
| } |
| |
| void Rollbacker::rollback(Rematerializer &Remater) { |
| RollingBack = true; |
| |
| // Re-create deleted registers. |
| for (auto &[RegIdx, Info] : DeadRegs) { |
| assert(!Remater.getReg(RegIdx).isAlive() && "register should be dead"); |
| |
| // The MI that was originally just after the MI defining the register we |
| // are trying to re-create may have been deleted. In such cases, we can |
| // re-create at that MI's own insert position (and apply the same logic |
| // recursively). |
| MachineBasicBlock::iterator InsertPos = Info.InsertPos; |
| RegisterIdx NextRegIdx = Info.NextRegIdx; |
| while (NextRegIdx != Rematerializer::NoReg) { |
| const auto *NextRegRollback = DeadRegs.find(NextRegIdx); |
| if (NextRegRollback == DeadRegs.end()) |
| break; |
| InsertPos = NextRegRollback->second.InsertPos; |
| NextRegIdx = NextRegRollback->second.NextRegIdx; |
| } |
| Remater.recreateReg(RegIdx, Info.DefRegion, InsertPos, Info.DefReg, |
| std::move(Info.Dependencies)); |
| } |
| |
| // Rollback rematerializations. |
| for (const auto &[RegIdx, RematsOf] : Rematerializations) { |
| for (RegisterIdx RematRegIdx : RematsOf) { |
| // It is possible that rematerializations were deleted. Their users would |
| // have been transfered to some other rematerialization so we can safely |
| // ignore them. Original registers that were deleted were just re-created |
| // so we do not need to check for that. |
| if (Remater.getReg(RematRegIdx).isAlive()) |
| Remater.transferAllUsers(RematRegIdx, RegIdx); |
| } |
| } |
| |
| Remater.updateLiveIntervals(); |
| DeadRegs.clear(); |
| Rematerializations.clear(); |
| RollingBack = false; |
| } |