[GlobalISel][AMDGPU] Lower G_UMULO/G_SMULO
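
Implement widenScalar and fewerElementsVector legalization for the
overflow-checking multiplies G_UMULO and G_SMULO in LegalizerHelper.

For example (hand-written, illustrative MIR; the register names are not
taken from the in-tree tests), widening

  %res:_(s16), %ovf:_(s1) = G_UMULO %lhs, %rhs

to 32 bits produces:

  %lhs32:_(s32) = G_ZEXT %lhs(s16)
  %rhs32:_(s32) = G_ZEXT %rhs(s16)
  %wide:_(s32), %wideovf:_(s1) = G_UMULO %lhs32, %rhs32
  %res:_(s16) = G_TRUNC %wide(s32)
  %mask:_(s32) = G_CONSTANT i32 65535
  %low:_(s32) = G_AND %wide, %mask
  %ovf:_(s1) = G_ICMP intpred(ne), %wide(s32), %low

Because s32 is at least twice as wide as s16, the wide multiply itself can
never overflow and %wideovf is left unused; for narrower widenings its
overflow bit is ORed with the high-bits check instead.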
Reviewed By: foad
Differential Revision: https://reviews.llvm.org/D93963
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index f74ef8e..97a5c64 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -1824,6 +1824,62 @@
}
LegalizerHelper::LegalizeResult
+LegalizerHelper::widenScalarMulo(MachineInstr &MI, unsigned TypeIdx,
+ LLT WideTy) {
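+  // Only widening of the result type (TypeIdx 0) is implemented; requests to
+  // widen the overflow type are rejected.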
+ if (TypeIdx == 1)
+ return UnableToLegalize;
+
+ bool IsSigned = MI.getOpcode() == TargetOpcode::G_SMULO;
+ Register Result = MI.getOperand(0).getReg();
+ Register OriginalOverflow = MI.getOperand(1).getReg();
+ Register LHS = MI.getOperand(2).getReg();
+ Register RHS = MI.getOperand(3).getReg();
+ LLT SrcTy = MRI.getType(LHS);
+ LLT OverflowTy = MRI.getType(OriginalOverflow);
+ unsigned SrcBitWidth = SrcTy.getScalarSizeInBits();
+
+  // To determine whether the original (narrow) multiply overflowed, extend
+  // both inputs to the larger type and multiply there, checking whether the
+  // wide multiply itself overflows, and then also check whether the high
+  // bits of the wide result are a sign/zero extension of the low bits.
+ unsigned ExtOp = IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
+ auto LeftOperand = MIRBuilder.buildInstr(ExtOp, {WideTy}, {LHS});
+ auto RightOperand = MIRBuilder.buildInstr(ExtOp, {WideTy}, {RHS});
+
+ auto Mulo = MIRBuilder.buildInstr(MI.getOpcode(), {WideTy, OverflowTy},
+ {LeftOperand, RightOperand});
+ auto Mul = Mulo->getOperand(0);
+ MIRBuilder.buildTrunc(Result, Mul);
+
+ MachineInstrBuilder ExtResult;
+ // Overflow occurred if it occurred in the larger type, or if the high part
+ // of the result does not zero/sign-extend the low part. Check this second
+ // possibility first.
+ if (IsSigned) {
+ // For signed, overflow occurred when the high part does not sign-extend
+ // the low part.
+ ExtResult = MIRBuilder.buildSExtInReg(WideTy, Mul, SrcBitWidth);
+ } else {
+ // Unsigned overflow occurred when the high part does not zero-extend the
+ // low part.
+ ExtResult = MIRBuilder.buildZExtInReg(WideTy, Mul, SrcBitWidth);
+ }
+
+  // The wide multiply cannot overflow if WideTy is at least twice the
+  // original width: the product of two N-bit values always fits in 2*N bits
+  // (for signed, INT_MIN * INT_MIN = 2^(2*N-2) still fits). In that case the
+  // overflow result of the wide Mulo is known to be clear and need not be
+  // checked.
+ if (WideTy.getScalarSizeInBits() < 2 * SrcBitWidth) {
+ auto Overflow =
+ MIRBuilder.buildICmp(CmpInst::ICMP_NE, OverflowTy, Mul, ExtResult);
+ // Finally check if the multiplication in the larger type itself overflowed.
+ MIRBuilder.buildOr(OriginalOverflow, Mulo->getOperand(1), Overflow);
+ } else {
+ MIRBuilder.buildICmp(CmpInst::ICMP_NE, OriginalOverflow, Mul, ExtResult);
+ }
+ MI.eraseFromParent();
+ return Legalized;
+}
+
+LegalizerHelper::LegalizeResult
LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
switch (MI.getOpcode()) {
default:
@@ -1845,6 +1901,9 @@
case TargetOpcode::G_UADDE:
case TargetOpcode::G_USUBE:
return widenScalarAddSubOverflow(MI, TypeIdx, WideTy);
+ case TargetOpcode::G_UMULO:
+ case TargetOpcode::G_SMULO:
+ return widenScalarMulo(MI, TypeIdx, WideTy);
case TargetOpcode::G_SADDSAT:
case TargetOpcode::G_SSUBSAT:
case TargetOpcode::G_SSHLSAT:
@@ -3619,6 +3678,55 @@
return Legalized;
}
+LegalizerHelper::LegalizeResult
+LegalizerHelper::fewerElementsVectorMulo(MachineInstr &MI, unsigned TypeIdx,
+ LLT NarrowTy) {
+ Register Result = MI.getOperand(0).getReg();
+ Register Overflow = MI.getOperand(1).getReg();
+ Register LHS = MI.getOperand(2).getReg();
+ Register RHS = MI.getOperand(3).getReg();
+
+ LLT SrcTy = MRI.getType(LHS);
+ if (!SrcTy.isVector())
+ return UnableToLegalize;
+
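+  // For example (illustrative MIR; the register names are made up), narrowing
+  //   %res:_(<4 x s32>), %ovf:_(<4 x s1>) = G_UMULO %lhs, %rhs
+  // with NarrowTy = <2 x s32> splits the operands in two, multiplies the
+  // halves separately, and concatenates the partial results:
+  //   %l0:_(<2 x s32>), %l1:_(<2 x s32>) = G_UNMERGE_VALUES %lhs(<4 x s32>)
+  //   %r0:_(<2 x s32>), %r1:_(<2 x s32>) = G_UNMERGE_VALUES %rhs(<4 x s32>)
+  //   %m0:_(<2 x s32>), %o0:_(<2 x s1>) = G_UMULO %l0, %r0
+  //   %m1:_(<2 x s32>), %o1:_(<2 x s1>) = G_UMULO %l1, %r1
+  //   %res:_(<4 x s32>) = G_CONCAT_VECTORS %m0, %m1
+  //   %ovf:_(<4 x s1>) = G_CONCAT_VECTORS %o0, %o1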
+ LLT ElementType = SrcTy.getElementType();
+ LLT OverflowElementTy = MRI.getType(Overflow).getElementType();
+ const int NumResult = SrcTy.getNumElements();
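+  // The GCD type is the widest piece that both SrcTy and NarrowTy can be
+  // evenly split into.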
+ LLT GCDTy = getGCDType(SrcTy, NarrowTy);
+
+  // Unmerge the operands into smaller parts of the GCD type.
+ auto UnmergeLHS = MIRBuilder.buildUnmerge(GCDTy, LHS);
+ auto UnmergeRHS = MIRBuilder.buildUnmerge(GCDTy, RHS);
+
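+  // All operands of an unmerge are defs except the last one, the source.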
+ const int NumOps = UnmergeLHS->getNumOperands() - 1;
+ const int PartsPerUnmerge = NumResult / NumOps;
+ LLT OverflowTy = LLT::scalarOrVector(PartsPerUnmerge, OverflowElementTy);
+ LLT ResultTy = LLT::scalarOrVector(PartsPerUnmerge, ElementType);
+
+ // Perform the operation over unmerged parts.
+ SmallVector<Register, 8> ResultParts;
+ SmallVector<Register, 8> OverflowParts;
+ for (int I = 0; I != NumOps; ++I) {
+ Register Operand1 = UnmergeLHS->getOperand(I).getReg();
+ Register Operand2 = UnmergeRHS->getOperand(I).getReg();
+ auto PartMul = MIRBuilder.buildInstr(MI.getOpcode(), {ResultTy, OverflowTy},
+ {Operand1, Operand2});
+ ResultParts.push_back(PartMul->getOperand(0).getReg());
+ OverflowParts.push_back(PartMul->getOperand(1).getReg());
+ }
+
+ LLT ResultLCMTy = buildLCMMergePieces(SrcTy, NarrowTy, GCDTy, ResultParts);
+ LLT OverflowLCMTy =
+ LLT::scalarOrVector(ResultLCMTy.getNumElements(), OverflowElementTy);
+
+  // Recombine the pieces into the original result and overflow registers.
+ buildWidenedRemergeToDst(Result, ResultLCMTy, ResultParts);
+ buildWidenedRemergeToDst(Overflow, OverflowLCMTy, OverflowParts);
+ MI.eraseFromParent();
+ return Legalized;
+}
+
// Handle FewerElementsVector a G_BUILD_VECTOR or G_CONCAT_VECTORS that produces
// a vector
//
@@ -4026,6 +4134,9 @@
case G_UADDSAT:
case G_USUBSAT:
return reduceOperationWidth(MI, TypeIdx, NarrowTy);
+ case G_UMULO:
+ case G_SMULO:
+ return fewerElementsVectorMulo(MI, TypeIdx, NarrowTy);
case G_SHL:
case G_LSHR:
case G_ASHR: