[SDAG] Use shift amount type in MULO promotion; NFC

Directly use the correct shift amount type if it is possible, and
future-proof the code against vectors. The added test makes sure that
bitwidths that do not fit into the shift amount type do not assert.

Split out from D57997.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@354359 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index aee4194..e664e06 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -952,9 +952,11 @@
   SDValue Overflow;
   if (N->getOpcode() == ISD::UMULO) {
     // Unsigned overflow occurred if the high part is non-zero.
+    unsigned Shift = SmallVT.getScalarSizeInBits();
+    EVT ShiftTy = getShiftAmountTyForConstant(Shift, Mul.getValueType(),
+                                              TLI, DAG);
     SDValue Hi = DAG.getNode(ISD::SRL, DL, Mul.getValueType(), Mul,
-                             DAG.getIntPtrConstant(SmallVT.getSizeInBits(),
-                                                   DL));
+                             DAG.getConstant(Shift, DL, ShiftTy));
     Overflow = DAG.getSetCC(DL, N->getValueType(1), Hi,
                             DAG.getConstant(0, DL, Hi.getValueType()),
                             ISD::SETNE);
diff --git a/test/CodeGen/X86/umul-with-overflow.ll b/test/CodeGen/X86/umul-with-overflow.ll
index 64a8933..c2a0dc0 100644
--- a/test/CodeGen/X86/umul-with-overflow.ll
+++ b/test/CodeGen/X86/umul-with-overflow.ll
@@ -68,3 +68,12 @@
 	%tmp2 = extractvalue { i32, i1 } %tmp1, 0
 	ret i32 %tmp2
 }
+
+; Check that shifts larger than the shift amount type are handled.
+; Intentionally not testing codegen here, only that this doesn't assert.
+declare {i300, i1} @llvm.umul.with.overflow.i300(i300 %a, i300 %b)
+define i300 @test4(i300 %a, i300 %b) nounwind {
+  %x = call {i300, i1} @llvm.umul.with.overflow.i300(i300 %a, i300 %b)
+  %y = extractvalue {i300, i1} %x, 0
+  ret i300 %y
+}