GlobalISel: Don't ignore requested ext narrowing type
The ext narrowing code assumed the requested narrow type was the same as
the source type. Respect the requested type when the two differ by
building intermediate merges. This avoids producing very wide, illegal
shift expansions.
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 1513f98..601ae26 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -63,6 +63,35 @@
return std::make_pair(NumParts, NumLeftover);
}
+/// Compute a type whose size is a common divisor of \p OrigTy and \p TargetTy.
+///
+/// - Both vectors: the element types must match; the result has that element
+///   type and as many elements as the GCD of the two element counts (built
+///   with LLT::scalarOrVector).
+/// - Vector \p OrigTy, scalar \p TargetTy: \p TargetTy must be the element type
+///   of \p OrigTy and is returned unchanged.
+/// - Both scalars: a scalar whose width is the GCD of the two bit widths.
+///
+/// A scalar \p OrigTy combined with a vector \p TargetTy is not implemented.
+static LLT getGCDType(LLT OrigTy, LLT TargetTy) {
+  if (OrigTy.isVector()) {
+    if (!TargetTy.isVector()) {
+      // Breaking a vector down to a scalar: only its own element type works.
+      assert(OrigTy.getElementType() == TargetTy);
+      return TargetTy;
+    }
+
+    // Vector to vector: keep the element type, reduce the element count.
+    assert(OrigTy.getElementType() == TargetTy.getElementType());
+    const int CommonElts = greatestCommonDivisor(OrigTy.getNumElements(),
+                                                 TargetTy.getNumElements());
+    return LLT::scalarOrVector(CommonElts, OrigTy.getElementType());
+  }
+
+  assert(!OrigTy.isVector() && !TargetTy.isVector() &&
+         "GCD type of vector and scalar not implemented");
+
+  // Plain scalars: the GCD of the bit widths.
+  const int CommonBits = greatestCommonDivisor(OrigTy.getSizeInBits(),
+                                               TargetTy.getSizeInBits());
+  return LLT::scalar(CommonBits);
+}
+
+/// Return a scalar type whose bit width is the least common multiple of the
+/// bit widths of \p Ty0 and \p Ty1, computed as (width0 * width1) / GCD.
+///
+/// Only scalar inputs are handled so far.
+static LLT getLCMType(LLT Ty0, LLT Ty1) {
+  assert(Ty0.isScalar() && Ty1.isScalar() && "not yet handled");
+  const int CommonBits = greatestCommonDivisor(Ty0.getSizeInBits(),
+                                               Ty1.getSizeInBits());
+  const unsigned ProductBits = Ty0.getSizeInBits() * Ty1.getSizeInBits();
+  return LLT::scalar(ProductBits / CommonBits);
+}
+
LegalizerHelper::LegalizerHelper(MachineFunction &MF,
GISelChangeObserver &Observer,
MachineIRBuilder &Builder)
@@ -172,26 +201,6 @@
return true;
}
-static LLT getGCDType(LLT OrigTy, LLT TargetTy) {
- if (OrigTy.isVector() && TargetTy.isVector()) {
- assert(OrigTy.getElementType() == TargetTy.getElementType());
- int GCD = greatestCommonDivisor(OrigTy.getNumElements(),
- TargetTy.getNumElements());
- return LLT::scalarOrVector(GCD, OrigTy.getElementType());
- }
-
- if (OrigTy.isVector() && !TargetTy.isVector()) {
- assert(OrigTy.getElementType() == TargetTy);
- return TargetTy;
- }
-
- assert(!OrigTy.isVector() && !TargetTy.isVector());
-
- int GCD = greatestCommonDivisor(OrigTy.getSizeInBits(),
- TargetTy.getSizeInBits());
- return LLT::scalar(GCD);
-}
-
void LegalizerHelper::insertParts(Register DstReg,
LLT ResultTy, LLT PartTy,
ArrayRef<Register> PartRegs,
@@ -237,6 +246,133 @@
}
}
+/// Collect the result registers of the G_UNMERGE_VALUES \p MI into \p Regs.
+///
+/// Every operand except the last one is read. \p Regs is resized to exactly
+/// that many entries, so whatever it held before is not preserved.
+static void getUnmergeResults(SmallVectorImpl<Register> &Regs,
+                              const MachineInstr &MI) {
+  assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES);
+
+  const int NumDefs = MI.getNumOperands() - 1;
+  Regs.resize(NumDefs);
+
+  // Walk the output slots in order, pulling the matching operand for each.
+  int OpIdx = 0;
+  for (Register &Slot : Regs)
+    Slot = MI.getOperand(OpIdx++).getReg();
+}
+
+/// Break \p SrcReg into pieces of the GCD type of its own type, \p NarrowTy
+/// and \p DstTy, and hand those pieces back in \p Parts.
+///
+/// When the source already has the GCD type it is appended to \p Parts as is.
+/// Otherwise an unmerge to the GCD type is built and \p Parts is set to its
+/// results.
+///
+/// \returns the GCD type that was used for the pieces.
+LLT LegalizerHelper::extractGCDType(SmallVectorImpl<Register> &Parts, LLT DstTy,
+                                    LLT NarrowTy, Register SrcReg) {
+  const LLT SrcTy = MRI.getType(SrcReg);
+
+  // Fold the three types together pairwise: source and narrow type first,
+  // then the destination.
+  const LLT SrcNarrowGCDTy = getGCDType(SrcTy, NarrowTy);
+  const LLT GCDTy = getGCDType(SrcNarrowGCDTy, DstTy);
+
+  if (SrcTy == GCDTy) {
+    // The source is already a single piece of the common type; no split is
+    // required.
+    Parts.push_back(SrcReg);
+    return GCDTy;
+  }
+
+  // Split the source into pieces of the common type.
+  auto Unmerge = MIRBuilder.buildUnmerge(GCDTy, SrcReg);
+  getUnmergeResults(Parts, *Unmerge);
+  return GCDTy;
+}
+
+/// Assemble \p DstReg from the \p GCDTy sized pieces in \p VRegs, going through
+/// intermediate values of type \p NarrowTy.
+///
+/// The pieces are grouped into \p NarrowTy sized merges, which are in turn
+/// merged into the LCM type of the destination type and \p NarrowTy. If that
+/// LCM type differs from the destination type, the wide merge is truncated
+/// into \p DstReg. Positions past the end of \p VRegs are padded according to
+/// \p PadStrategy: zero for G_ZEXT, undef for G_ANYEXT, and copies of the sign
+/// bit of the last piece for G_SEXT.
+void LegalizerHelper::buildLCMMerge(Register DstReg, LLT NarrowTy, LLT GCDTy,
+                                    SmallVectorImpl<Register> &VRegs,
+                                    unsigned PadStrategy) {
+  const LLT DstTy = MRI.getType(DstReg);
+  const LLT LCMTy = getLCMType(DstTy, NarrowTy);
+
+  // Number of NarrowTy values in the LCM type, number of GCDTy pieces in one
+  // NarrowTy value, and number of GCDTy pieces the caller supplied.
+  const int NumNarrow = LCMTy.getSizeInBits() / NarrowTy.getSizeInBits();
+  const int PiecesPerNarrow = NarrowTy.getSizeInBits() / GCDTy.getSizeInBits();
+  const int NumSrcPieces = VRegs.size();
+
+  // GCDTy sized filler. It is only materialized when the supplied pieces do
+  // not cover the whole LCM type.
+  Register PadReg;
+  if (NumSrcPieces < NumNarrow * PiecesPerNarrow) {
+    switch (PadStrategy) {
+    case TargetOpcode::G_ZEXT:
+      PadReg = MIRBuilder.buildConstant(GCDTy, 0).getReg(0);
+      break;
+    case TargetOpcode::G_ANYEXT:
+      PadReg = MIRBuilder.buildUndef(GCDTy).getReg(0);
+      break;
+    default: {
+      assert(PadStrategy == TargetOpcode::G_SEXT);
+
+      // An arithmetic shift of the last source piece by its width minus one
+      // leaves a piece filled with copies of the sign bit.
+      auto ShiftAmt =
+          MIRBuilder.buildConstant(LLT::scalar(64), GCDTy.getSizeInBits() - 1);
+      PadReg = MIRBuilder.buildAShr(GCDTy, VRegs.back(), ShiftAmt).getReg(0);
+      break;
+    }
+    }
+  }
+
+  // Inputs of the final merge, one per NarrowTy sized slot.
+  SmallVector<Register, 4> Remerge;
+  Remerge.resize(NumNarrow);
+
+  // Scratch space holding the GCDTy inputs of the merge that forms a single
+  // NarrowTy value.
+  SmallVector<Register, 4> SubMerge;
+  SubMerge.resize(PiecesPerNarrow);
+
+  // A NarrowTy value made purely of padding. Once one exists, every later
+  // slot can share it because all source bits have been consumed by then.
+  Register AllPadReg;
+
+  for (int Slot = 0; Slot != NumNarrow; ++Slot) {
+    bool OnlyPadding = true;
+
+    // Gather the pieces for this slot, substituting padding past the end of
+    // the source pieces.
+    for (int Sub = 0; Sub != PiecesPerNarrow; ++Sub) {
+      const int SrcIdx = Slot * PiecesPerNarrow + Sub;
+      if (SrcIdx < NumSrcPieces) {
+        // Real source bits: this slot cannot be shared with later ones.
+        SubMerge[Sub] = VRegs[SrcIdx];
+        OnlyPadding = false;
+      } else {
+        SubMerge[Sub] = PadReg;
+      }
+    }
+
+    // For the first slot consisting only of padding, emit a single NarrowTy
+    // undef/zero directly instead of merging smaller ones. Sign extension has
+    // no such trivial constant, so it falls through to the merge below.
+    if (OnlyPadding && !AllPadReg) {
+      if (PadStrategy == TargetOpcode::G_ANYEXT)
+        AllPadReg = MIRBuilder.buildUndef(NarrowTy).getReg(0);
+      else if (PadStrategy == TargetOpcode::G_ZEXT)
+        AllPadReg = MIRBuilder.buildConstant(NarrowTy, 0).getReg(0);
+    }
+
+    if (AllPadReg) {
+      // Reuse the existing all-padding value; no new instructions needed.
+      Remerge[Slot] = AllPadReg;
+      continue;
+    }
+
+    if (PiecesPerNarrow == 1)
+      Remerge[Slot] = SubMerge[0];
+    else
+      Remerge[Slot] = MIRBuilder.buildMerge(NarrowTy, SubMerge).getReg(0);
+
+    // Sign extension: remember the first all-signbit value for later slots.
+    // AllPadReg is necessarily still unset when control reaches this point.
+    if (OnlyPadding)
+      AllPadReg = Remerge[Slot];
+  }
+
+  // Merge the NarrowTy values into the LCM type. When that is already the
+  // destination type, the merge defines the result directly.
+  if (DstTy == LCMTy) {
+    MIRBuilder.buildMerge(DstReg, Remerge);
+    return;
+  }
+
+  // Otherwise build the wide merge and truncate it into the result.
+  auto WideMerge = MIRBuilder.buildMerge(LCMTy, Remerge);
+  MIRBuilder.buildTrunc(DstReg, WideMerge);
+}
+
static RTLIB::Libcall getRTLibDesc(unsigned Opcode, unsigned Size) {
switch (Opcode) {
case TargetOpcode::G_SDIV:
@@ -3658,33 +3794,14 @@
Register DstReg = MI.getOperand(0).getReg();
Register SrcReg = MI.getOperand(1).getReg();
- LLT DstTy = MRI.getType(DstReg);
- LLT SrcTy = MRI.getType(SrcReg);
- unsigned DstSize = DstTy.getSizeInBits();
- unsigned SrcSize = SrcTy.getSizeInBits();
- if (DstSize % SrcSize != 0)
+ LLT DstTy = MRI.getType(DstReg);
+ if (DstTy.isVector())
return UnableToLegalize;
- Register PadReg;
- if (MI.getOpcode() == TargetOpcode::G_ZEXT)
- PadReg = MIRBuilder.buildConstant(SrcTy, 0).getReg(0);
- else if (MI.getOpcode() == TargetOpcode::G_ANYEXT)
- PadReg = MIRBuilder.buildUndef(SrcTy).getReg(0);
- else {
- // Shift the sign bit of the low register through the high register.
- auto ShiftAmt =
- MIRBuilder.buildConstant(LLT::scalar(64), SrcSize - 1);
- PadReg = MIRBuilder.buildAShr(SrcTy, SrcReg, ShiftAmt).getReg(0);
- }
-
- // Generate a merge where the bottom bits are taken from the source, and
- // zero/impdef/sign bit everything else.
- unsigned NumParts = DstSize / SrcSize;
- SmallVector<Register, 4> Srcs = {SrcReg};
- for (unsigned Part = 1; Part < NumParts; ++Part)
- Srcs.push_back(PadReg);
- MIRBuilder.buildMerge(DstReg, Srcs);
+ SmallVector<Register, 8> Parts;
+ LLT GCDTy = extractGCDType(Parts, DstTy, NarrowTy, SrcReg);
+ buildLCMMerge(DstReg, NarrowTy, GCDTy, Parts, MI.getOpcode());
MI.eraseFromParent();
return Legalized;
}