[GlobalISel] Refactor extractParts() (#75223)
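Moved extractParts() and extractVectorParts() from LegalizerHelper
to Utils so that they can be used in other passes. Because they are no
longer LegalizerHelper members, they now take the builder (MIRBuilder)
and register info (MRI) as explicit trailing arguments, and every call
site in LegalizerHelper has been updated to pass them.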
extractParts() now also tries to use an unmerge for irregular splits
when possible, falling back to extracting the individual parts when it
is not.
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 21947a5..91d2497 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -158,100 +158,6 @@
}
}
-void LegalizerHelper::extractParts(Register Reg, LLT Ty, int NumParts,
- SmallVectorImpl<Register> &VRegs) {
- for (int i = 0; i < NumParts; ++i)
- VRegs.push_back(MRI.createGenericVirtualRegister(Ty));
- MIRBuilder.buildUnmerge(VRegs, Reg);
-}
-
-bool LegalizerHelper::extractParts(Register Reg, LLT RegTy,
- LLT MainTy, LLT &LeftoverTy,
- SmallVectorImpl<Register> &VRegs,
- SmallVectorImpl<Register> &LeftoverRegs) {
- assert(!LeftoverTy.isValid() && "this is an out argument");
-
- unsigned RegSize = RegTy.getSizeInBits();
- unsigned MainSize = MainTy.getSizeInBits();
- unsigned NumParts = RegSize / MainSize;
- unsigned LeftoverSize = RegSize - NumParts * MainSize;
-
- // Use an unmerge when possible.
- if (LeftoverSize == 0) {
- for (unsigned I = 0; I < NumParts; ++I)
- VRegs.push_back(MRI.createGenericVirtualRegister(MainTy));
- MIRBuilder.buildUnmerge(VRegs, Reg);
- return true;
- }
-
- // Perform irregular split. Leftover is last element of RegPieces.
- if (MainTy.isVector()) {
- SmallVector<Register, 8> RegPieces;
- extractVectorParts(Reg, MainTy.getNumElements(), RegPieces);
- for (unsigned i = 0; i < RegPieces.size() - 1; ++i)
- VRegs.push_back(RegPieces[i]);
- LeftoverRegs.push_back(RegPieces[RegPieces.size() - 1]);
- LeftoverTy = MRI.getType(LeftoverRegs[0]);
- return true;
- }
-
- LeftoverTy = LLT::scalar(LeftoverSize);
- // For irregular sizes, extract the individual parts.
- for (unsigned I = 0; I != NumParts; ++I) {
- Register NewReg = MRI.createGenericVirtualRegister(MainTy);
- VRegs.push_back(NewReg);
- MIRBuilder.buildExtract(NewReg, Reg, MainSize * I);
- }
-
- for (unsigned Offset = MainSize * NumParts; Offset < RegSize;
- Offset += LeftoverSize) {
- Register NewReg = MRI.createGenericVirtualRegister(LeftoverTy);
- LeftoverRegs.push_back(NewReg);
- MIRBuilder.buildExtract(NewReg, Reg, Offset);
- }
-
- return true;
-}
-
-void LegalizerHelper::extractVectorParts(Register Reg, unsigned NumElts,
- SmallVectorImpl<Register> &VRegs) {
- LLT RegTy = MRI.getType(Reg);
- assert(RegTy.isVector() && "Expected a vector type");
-
- LLT EltTy = RegTy.getElementType();
- LLT NarrowTy = (NumElts == 1) ? EltTy : LLT::fixed_vector(NumElts, EltTy);
- unsigned RegNumElts = RegTy.getNumElements();
- unsigned LeftoverNumElts = RegNumElts % NumElts;
- unsigned NumNarrowTyPieces = RegNumElts / NumElts;
-
- // Perfect split without leftover
- if (LeftoverNumElts == 0)
- return extractParts(Reg, NarrowTy, NumNarrowTyPieces, VRegs);
-
- // Irregular split. Provide direct access to all elements for artifact
- // combiner using unmerge to elements. Then build vectors with NumElts
- // elements. Remaining element(s) will be (used to build vector) Leftover.
- SmallVector<Register, 8> Elts;
- extractParts(Reg, EltTy, RegNumElts, Elts);
-
- unsigned Offset = 0;
- // Requested sub-vectors of NarrowTy.
- for (unsigned i = 0; i < NumNarrowTyPieces; ++i, Offset += NumElts) {
- ArrayRef<Register> Pieces(&Elts[Offset], NumElts);
- VRegs.push_back(MIRBuilder.buildMergeLikeInstr(NarrowTy, Pieces).getReg(0));
- }
-
- // Leftover element(s).
- if (LeftoverNumElts == 1) {
- VRegs.push_back(Elts[Offset]);
- } else {
- LLT LeftoverTy = LLT::fixed_vector(LeftoverNumElts, EltTy);
- ArrayRef<Register> Pieces(&Elts[Offset], LeftoverNumElts);
- VRegs.push_back(
- MIRBuilder.buildMergeLikeInstr(LeftoverTy, Pieces).getReg(0));
- }
-}
-
void LegalizerHelper::insertParts(Register DstReg,
LLT ResultTy, LLT PartTy,
ArrayRef<Register> PartRegs,
@@ -293,7 +199,8 @@
Register Reg) {
LLT Ty = MRI.getType(Reg);
SmallVector<Register, 8> RegElts;
- extractParts(Reg, Ty.getScalarType(), Ty.getNumElements(), RegElts);
+ extractParts(Reg, Ty.getScalarType(), Ty.getNumElements(), RegElts,
+ MIRBuilder, MRI);
Elts.append(RegElts);
}
@@ -1542,7 +1449,7 @@
MachineBasicBlock &OpMBB = *MI.getOperand(i + 1).getMBB();
MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
extractParts(MI.getOperand(i).getReg(), NarrowTy, NumParts,
- SrcRegs[i / 2]);
+ SrcRegs[i / 2], MIRBuilder, MRI);
}
MachineBasicBlock &MBB = *MI.getParent();
MIRBuilder.setInsertPt(MBB, MI);
@@ -1584,13 +1491,13 @@
LLT LeftoverTy; // Example: s88 -> s64 (NarrowTy) + s24 (leftover)
SmallVector<Register, 4> LHSPartRegs, LHSLeftoverRegs;
if (!extractParts(LHS, SrcTy, NarrowTy, LeftoverTy, LHSPartRegs,
- LHSLeftoverRegs))
+ LHSLeftoverRegs, MIRBuilder, MRI))
return UnableToLegalize;
LLT Unused; // Matches LeftoverTy; G_ICMP LHS and RHS are the same type.
SmallVector<Register, 4> RHSPartRegs, RHSLeftoverRegs;
if (!extractParts(MI.getOperand(3).getReg(), SrcTy, NarrowTy, Unused,
- RHSPartRegs, RHSLeftoverRegs))
+ RHSPartRegs, RHSLeftoverRegs, MIRBuilder, MRI))
return UnableToLegalize;
// We now have the LHS and RHS of the compare split into narrow-type
@@ -1744,7 +1651,8 @@
Observer.changingInstr(MI);
SmallVector<Register, 2> SrcRegs, DstRegs;
unsigned NumParts = SizeOp0 / NarrowSize;
- extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs);
+ extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs,
+ MIRBuilder, MRI);
for (unsigned i = 0; i < NumParts; ++i) {
auto DstPart = MIRBuilder.buildInstr(MI.getOpcode(), {NarrowTy},
@@ -4194,7 +4102,8 @@
MI.getOperand(UseIdx));
} else {
SmallVector<Register, 8> SplitPieces;
- extractVectorParts(MI.getReg(UseIdx), NumElts, SplitPieces);
+ extractVectorParts(MI.getReg(UseIdx), NumElts, SplitPieces, MIRBuilder,
+ MRI);
for (auto Reg : SplitPieces)
InputOpsPieces[UseNo].push_back(Reg);
}
@@ -4250,7 +4159,8 @@
UseIdx += 2, ++UseNo) {
MachineBasicBlock &OpMBB = *MI.getOperand(UseIdx + 1).getMBB();
MIRBuilder.setInsertPt(OpMBB, OpMBB.getFirstTerminatorForward());
- extractVectorParts(MI.getReg(UseIdx), NumElts, InputOpsPieces[UseNo]);
+ extractVectorParts(MI.getReg(UseIdx), NumElts, InputOpsPieces[UseNo],
+ MIRBuilder, MRI);
}
// Build PHIs with fewer elements.
@@ -4519,7 +4429,7 @@
std::tie(NumParts, NumLeftover) = getNarrowTypeBreakDown(ValTy, NarrowTy, LeftoverTy);
} else {
if (extractParts(ValReg, ValTy, NarrowTy, LeftoverTy, NarrowRegs,
- NarrowLeftoverRegs)) {
+ NarrowLeftoverRegs, MIRBuilder, MRI)) {
NumParts = NarrowRegs.size();
NumLeftover = NarrowLeftoverRegs.size();
}
@@ -4765,8 +4675,8 @@
unsigned NewElts = NarrowTy.getNumElements();
SmallVector<Register> SplitSrc1Regs, SplitSrc2Regs;
- extractParts(Src1Reg, NarrowTy, 2, SplitSrc1Regs);
- extractParts(Src2Reg, NarrowTy, 2, SplitSrc2Regs);
+ extractParts(Src1Reg, NarrowTy, 2, SplitSrc1Regs, MIRBuilder, MRI);
+ extractParts(Src2Reg, NarrowTy, 2, SplitSrc2Regs, MIRBuilder, MRI);
Register Inputs[4] = {SplitSrc1Regs[0], SplitSrc1Regs[1], SplitSrc2Regs[0],
SplitSrc2Regs[1]};
@@ -4900,7 +4810,7 @@
NarrowTy.isVector() ? SrcTy.getNumElements() / NarrowTy.getNumElements()
: SrcTy.getNumElements();
- extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
+ extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs, MIRBuilder, MRI);
if (NarrowTy.isScalar()) {
if (DstTy != NarrowTy)
return UnableToLegalize; // FIXME: handle implicit extensions.
@@ -4983,7 +4893,7 @@
SmallVector<Register> SplitSrcs;
unsigned NumParts = SrcTy.getNumElements();
- extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
+ extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs, MIRBuilder, MRI);
Register Acc = ScalarReg;
for (unsigned i = 0; i < NumParts; i++)
Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[i]})
@@ -5001,7 +4911,8 @@
SmallVector<Register> SplitSrcs;
// Split the sources into NarrowTy size pieces.
extractParts(SrcReg, NarrowTy,
- SrcTy.getNumElements() / NarrowTy.getNumElements(), SplitSrcs);
+ SrcTy.getNumElements() / NarrowTy.getNumElements(), SplitSrcs,
+ MIRBuilder, MRI);
// We're going to do a tree reduction using vector operations until we have
// one NarrowTy size value left.
while (SplitSrcs.size() > 1) {
@@ -5640,8 +5551,10 @@
LLT RegTy = MRI.getType(MI.getOperand(0).getReg());
LLT LeftoverTy, DummyTy;
SmallVector<Register, 2> Src1Regs, Src2Regs, Src1Left, Src2Left, DstRegs;
- extractParts(Src1, RegTy, NarrowTy, LeftoverTy, Src1Regs, Src1Left);
- extractParts(Src2, RegTy, NarrowTy, DummyTy, Src2Regs, Src2Left);
+ extractParts(Src1, RegTy, NarrowTy, LeftoverTy, Src1Regs, Src1Left,
+ MIRBuilder, MRI);
+ extractParts(Src2, RegTy, NarrowTy, DummyTy, Src2Regs, Src2Left, MIRBuilder,
+ MRI);
int NarrowParts = Src1Regs.size();
for (int I = 0, E = Src1Left.size(); I != E; ++I) {
@@ -5699,8 +5612,8 @@
SmallVector<Register, 2> Src1Parts, Src2Parts;
SmallVector<Register, 2> DstTmpRegs(DstTmpParts);
- extractParts(Src1, NarrowTy, NumParts, Src1Parts);
- extractParts(Src2, NarrowTy, NumParts, Src2Parts);
+ extractParts(Src1, NarrowTy, NumParts, Src1Parts, MIRBuilder, MRI);
+ extractParts(Src2, NarrowTy, NumParts, Src2Parts, MIRBuilder, MRI);
multiplyRegisters(DstTmpRegs, Src1Parts, Src2Parts, NarrowTy);
// Take only high half of registers if this is high mul.
@@ -5752,7 +5665,8 @@
SmallVector<Register, 2> SrcRegs, DstRegs;
SmallVector<uint64_t, 2> Indexes;
- extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs);
+ extractParts(MI.getOperand(1).getReg(), NarrowTy, NumParts, SrcRegs,
+ MIRBuilder, MRI);
Register OpReg = MI.getOperand(0).getReg();
uint64_t OpStart = MI.getOperand(2).getImm();
@@ -5814,7 +5728,7 @@
LLT RegTy = MRI.getType(MI.getOperand(0).getReg());
LLT LeftoverTy;
extractParts(MI.getOperand(1).getReg(), RegTy, NarrowTy, LeftoverTy, SrcRegs,
- LeftoverRegs);
+ LeftoverRegs, MIRBuilder, MRI);
for (Register Reg : LeftoverRegs)
SrcRegs.push_back(Reg);
@@ -5899,12 +5813,12 @@
SmallVector<Register, 4> Src1Regs, Src1LeftoverRegs;
LLT LeftoverTy;
if (!extractParts(MI.getOperand(1).getReg(), DstTy, NarrowTy, LeftoverTy,
- Src0Regs, Src0LeftoverRegs))
+ Src0Regs, Src0LeftoverRegs, MIRBuilder, MRI))
return UnableToLegalize;
LLT Unused;
if (!extractParts(MI.getOperand(2).getReg(), DstTy, NarrowTy, Unused,
- Src1Regs, Src1LeftoverRegs))
+ Src1Regs, Src1LeftoverRegs, MIRBuilder, MRI))
llvm_unreachable("inconsistent extractParts result");
for (unsigned I = 0, E = Src1Regs.size(); I != E; ++I) {
@@ -5967,12 +5881,12 @@
SmallVector<Register, 4> Src2Regs, Src2LeftoverRegs;
LLT LeftoverTy;
if (!extractParts(MI.getOperand(2).getReg(), DstTy, NarrowTy, LeftoverTy,
- Src1Regs, Src1LeftoverRegs))
+ Src1Regs, Src1LeftoverRegs, MIRBuilder, MRI))
return UnableToLegalize;
LLT Unused;
if (!extractParts(MI.getOperand(3).getReg(), DstTy, NarrowTy, Unused,
- Src2Regs, Src2LeftoverRegs))
+ Src2Regs, Src2LeftoverRegs, MIRBuilder, MRI))
llvm_unreachable("inconsistent extractParts result");
for (unsigned I = 0, E = Src1Regs.size(); I != E; ++I) {
@@ -6468,7 +6382,7 @@
// First, split the source into two smaller vectors.
SmallVector<Register, 2> SplitSrcs;
- extractParts(SrcReg, SplitSrcTy, 2, SplitSrcs);
+ extractParts(SrcReg, SplitSrcTy, 2, SplitSrcs, MIRBuilder, MRI);
// Truncate the splits into intermediate narrower elements.
LLT InterTy;
@@ -7208,7 +7122,7 @@
int64_t IdxVal;
if (mi_match(Idx, MRI, m_ICst(IdxVal)) && IdxVal <= NumElts) {
SmallVector<Register, 8> SrcRegs;
- extractParts(SrcVec, EltTy, NumElts, SrcRegs);
+ extractParts(SrcVec, EltTy, NumElts, SrcRegs, MIRBuilder, MRI);
if (InsertVal) {
SrcRegs[IdxVal] = MI.getOperand(2).getReg();