[GlobalISel][AMDGPU] Lower G_UMULO/G_SMULO

Reviewed By: foad

Differential Revision: https://reviews.llvm.org/D93963
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index f74ef8e..97a5c64 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -1824,6 +1824,62 @@
 }
 
 LegalizerHelper::LegalizeResult
+LegalizerHelper::widenScalarMulo(MachineInstr &MI, unsigned TypeIdx,
+                                 LLT WideTy) {
+  // Widen G_UMULO/G_SMULO by multiplying zero/sign-extended operands in
+  // WideTy, truncating the product back to the result, and deriving the
+  // overflow flag from the high bits of the wide product. Only type index 0
+  // (result and operands) is widened.
+  if (TypeIdx == 1)
+    return UnableToLegalize;
+
+  bool IsSigned = MI.getOpcode() == TargetOpcode::G_SMULO;
+  Register Result = MI.getOperand(0).getReg();
+  Register OriginalOverflow = MI.getOperand(1).getReg();
+  Register LHS = MI.getOperand(2).getReg();
+  Register RHS = MI.getOperand(3).getReg();
+  LLT SrcTy = MRI.getType(LHS);
+  LLT OverflowTy = MRI.getType(OriginalOverflow);
+  unsigned SrcBitWidth = SrcTy.getScalarSizeInBits();
+
+  // Zero/sign-extend the operands so the low SrcBitWidth bits of the wide
+  // product are exactly the narrow product.
+  unsigned ExtOp = IsSigned ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
+  auto LeftOperand = MIRBuilder.buildInstr(ExtOp, {WideTy}, {LHS});
+  auto RightOperand = MIRBuilder.buildInstr(ExtOp, {WideTy}, {RHS});
+
+  // The product of two N-bit values fits in 2 * N bits, so the wide multiply
+  // can only overflow if WideTy is narrower than that. Otherwise use a plain
+  // G_MUL rather than a wide MULO whose overflow result would be unused.
+  bool WideMulCanOverflow = WideTy.getScalarSizeInBits() < 2 * SrcBitWidth;
+  MachineInstrBuilder Mulo =
+      WideMulCanOverflow
+          ? MIRBuilder.buildInstr(MI.getOpcode(), {WideTy, OverflowTy},
+                                  {LeftOperand, RightOperand})
+          : MIRBuilder.buildMul(WideTy, LeftOperand, RightOperand);
+  Register Mul = Mulo.getReg(0);
+  MIRBuilder.buildTrunc(Result, Mul);
+
+  // The narrow multiply overflowed if the high part of the wide product does
+  // not sign-extend (signed) or zero-extend (unsigned) its low SrcBitWidth
+  // bits; detect that by re-extending the low bits in place and comparing.
+  MachineInstrBuilder ExtResult =
+      IsSigned ? MIRBuilder.buildSExtInReg(WideTy, Mul, SrcBitWidth)
+               : MIRBuilder.buildZExtInReg(WideTy, Mul, SrcBitWidth);
+
+  if (WideMulCanOverflow) {
+    auto Overflow =
+        MIRBuilder.buildICmp(CmpInst::ICMP_NE, OverflowTy, Mul, ExtResult);
+    // Also report overflow if the wide multiply itself overflowed.
+    MIRBuilder.buildOr(OriginalOverflow, Mulo.getReg(1), Overflow);
+  } else {
+    MIRBuilder.buildICmp(CmpInst::ICMP_NE, OriginalOverflow, Mul, ExtResult);
+  }
+  MI.eraseFromParent();
+  return Legalized;
+}
+
+LegalizerHelper::LegalizeResult
 LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
   switch (MI.getOpcode()) {
   default:
@@ -1845,6 +1901,9 @@
   case TargetOpcode::G_UADDE:
   case TargetOpcode::G_USUBE:
     return widenScalarAddSubOverflow(MI, TypeIdx, WideTy);
+  case TargetOpcode::G_UMULO:
+  case TargetOpcode::G_SMULO:
+    return widenScalarMulo(MI, TypeIdx, WideTy);
   case TargetOpcode::G_SADDSAT:
   case TargetOpcode::G_SSUBSAT:
   case TargetOpcode::G_SSHLSAT:
@@ -3619,6 +3678,55 @@
   return Legalized;
 }
 
+// Split a vector G_UMULO/G_SMULO into GCDTy-sized pieces and remerge them.
+LegalizerHelper::LegalizeResult
+LegalizerHelper::fewerElementsVectorMulo(MachineInstr &MI, unsigned TypeIdx,
+                                         LLT NarrowTy) {
+  Register Result = MI.getOperand(0).getReg();
+  Register Overflow = MI.getOperand(1).getReg();
+  Register LHS = MI.getOperand(2).getReg();
+  Register RHS = MI.getOperand(3).getReg();
+
+  LLT SrcTy = MRI.getType(LHS);
+  if (!SrcTy.isVector())
+    return UnableToLegalize;
+
+  LLT ElementType = SrcTy.getElementType();
+  LLT OverflowElementTy = MRI.getType(Overflow).getElementType();
+  const int NumResult = SrcTy.getNumElements();
+  LLT GCDTy = getGCDType(SrcTy, NarrowTy);
+
+  // Unmerge the operands to smaller parts of GCD type.
+  auto UnmergeLHS = MIRBuilder.buildUnmerge(GCDTy, LHS);
+  auto UnmergeRHS = MIRBuilder.buildUnmerge(GCDTy, RHS);
+
+  const int NumOps = UnmergeLHS->getNumOperands() - 1;
+  const int PartsPerUnmerge = NumResult / NumOps;
+  LLT OverflowTy = LLT::scalarOrVector(PartsPerUnmerge, OverflowElementTy);
+  LLT ResultTy = LLT::scalarOrVector(PartsPerUnmerge, ElementType);
+
+  // Perform the operation over unmerged parts.
+  SmallVector<Register, 8> ResultParts;
+  SmallVector<Register, 8> OverflowParts;
+  for (int I = 0; I != NumOps; ++I) {
+    auto PartMul =
+        MIRBuilder.buildInstr(MI.getOpcode(), {ResultTy, OverflowTy},
+                              {UnmergeLHS.getReg(I), UnmergeRHS.getReg(I)});
+    ResultParts.push_back(PartMul.getReg(0));
+    OverflowParts.push_back(PartMul.getReg(1));
+  }
+
+  // Recombine the pieces into the original result and overflow registers.
+  LLT ResultLCMTy = buildLCMMergePieces(SrcTy, NarrowTy, GCDTy, ResultParts);
+  buildWidenedRemergeToDst(Result, ResultLCMTy, ResultParts);
+  // ResultLCMTy may be wider than SrcTy when NarrowTy does not evenly divide
+  // it, but the NumOps overflow pieces exactly tile the overflow type, so they
+  // are merged straight into the destination instead.
+  MIRBuilder.buildMerge(Overflow, OverflowParts);
+  MI.eraseFromParent();
+  return Legalized;
+}
+
 // Handle FewerElementsVector a G_BUILD_VECTOR or G_CONCAT_VECTORS that produces
 // a vector
 //
@@ -4026,6 +4134,9 @@
   case G_UADDSAT:
   case G_USUBSAT:
     return reduceOperationWidth(MI, TypeIdx, NarrowTy);
+  case G_UMULO:
+  case G_SMULO:
+    return fewerElementsVectorMulo(MI, TypeIdx, NarrowTy);
   case G_SHL:
   case G_LSHR:
   case G_ASHR: