[InstCombine] canEvaluateTruncated - use KnownBits to check for inrange shift amounts

Currently canEvaluateTruncated can only attempt to truncate shifts if they are scalar/uniform constant amounts that are in range.

This patch replaces the constant extraction code with KnownBits handling, using the KnownBits::getMaxValue to check that the amounts are inrange.

This enables support for nonuniform constant cases, and also variable shift amounts that have been masked somehow. Annoyingly, this still won't work for vectors with (demanded) undefs as KnownBits returns nothing in those cases, but its a definite improvement on what we currently have.

Differential Revision: https://reviews.llvm.org/D83127
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 8d9ebe4..7b3c503 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -377,29 +377,31 @@
     break;
   }
   case Instruction::Shl: {
-    // If we are truncating the result of this SHL, and if it's a shift of a
-    // constant amount, we can always perform a SHL in a smaller type.
-    const APInt *Amt;
-    if (match(I->getOperand(1), m_APInt(Amt))) {
-      uint32_t BitWidth = Ty->getScalarSizeInBits();
-      if (Amt->getLimitedValue(BitWidth) < BitWidth)
-        return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
-    }
+    // If we are truncating the result of this SHL, and if it's a shift of an
+    // inrange amount, we can always perform a SHL in a smaller type.
+    uint32_t BitWidth = Ty->getScalarSizeInBits();
+    KnownBits AmtKnownBits =
+        llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+    if (AmtKnownBits.getMaxValue().ult(BitWidth))
+      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
     break;
   }
   case Instruction::LShr: {
     // If this is a truncate of a logical shr, we can truncate it to a smaller
     // lshr iff we know that the bits we would otherwise be shifting in are
     // already zeros.
-    const APInt *Amt;
-    if (match(I->getOperand(1), m_APInt(Amt))) {
-      uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
-      uint32_t BitWidth = Ty->getScalarSizeInBits();
-      if (Amt->getLimitedValue(BitWidth) < BitWidth &&
-          IC.MaskedValueIsZero(I->getOperand(0),
-            APInt::getBitsSetFrom(OrigBitWidth, BitWidth), 0, CxtI)) {
-        return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
-      }
+    // TODO: It is enough to check that the bits we would be shifting in are
+    //       zero - use AmtKnownBits.getMaxValue().
+    uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
+    uint32_t BitWidth = Ty->getScalarSizeInBits();
+    KnownBits AmtKnownBits =
+        llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+    APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+    if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
+        IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) {
+      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
     }
     break;
   }
@@ -409,15 +411,15 @@
     // original type and the sign bit of the truncate type are similar.
     // TODO: It is enough to check that the bits we would be shifting in are
     //       similar to sign bit of the truncate type.
-    const APInt *Amt;
-    if (match(I->getOperand(1), m_APInt(Amt))) {
-      uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
-      uint32_t BitWidth = Ty->getScalarSizeInBits();
-      if (Amt->getLimitedValue(BitWidth) < BitWidth &&
-          OrigBitWidth - BitWidth <
-              IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
-        return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
-    }
+    uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
+    uint32_t BitWidth = Ty->getScalarSizeInBits();
+    KnownBits AmtKnownBits =
+        llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+    unsigned ShiftedBits = OrigBitWidth - BitWidth;
+    if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
+        ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
+      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
     break;
   }
   case Instruction::Trunc:
diff --git a/llvm/test/Transforms/InstCombine/2008-01-21-MulTrunc.ll b/llvm/test/Transforms/InstCombine/2008-01-21-MulTrunc.ll
index 999b5d5..89e4a3c 100644
--- a/llvm/test/Transforms/InstCombine/2008-01-21-MulTrunc.ll
+++ b/llvm/test/Transforms/InstCombine/2008-01-21-MulTrunc.ll
@@ -35,12 +35,10 @@
 
 define <2 x i16> @test1_vec_nonuniform(<2 x i16> %a) {
 ; CHECK-LABEL: @test1_vec_nonuniform(
-; CHECK-NEXT:    [[B:%.*]] = zext <2 x i16> [[A:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = lshr <2 x i32> [[B]], <i32 8, i32 9>
-; CHECK-NEXT:    [[D:%.*]] = mul nuw nsw <2 x i32> [[B]], <i32 5, i32 6>
-; CHECK-NEXT:    [[E:%.*]] = or <2 x i32> [[C]], [[D]]
-; CHECK-NEXT:    [[F:%.*]] = trunc <2 x i32> [[E]] to <2 x i16>
-; CHECK-NEXT:    ret <2 x i16> [[F]]
+; CHECK-NEXT:    [[C:%.*]] = lshr <2 x i16> [[A:%.*]], <i16 8, i16 9>
+; CHECK-NEXT:    [[D:%.*]] = mul <2 x i16> [[A]], <i16 5, i16 6>
+; CHECK-NEXT:    [[E:%.*]] = or <2 x i16> [[C]], [[D]]
+; CHECK-NEXT:    ret <2 x i16> [[E]]
 ;
   %b = zext <2 x i16> %a to <2 x i32>
   %c = lshr <2 x i32> %b, <i32 8, i32 9>
diff --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll
index 10d59bf..18b4111 100644
--- a/llvm/test/Transforms/InstCombine/cast.ll
+++ b/llvm/test/Transforms/InstCombine/cast.ll
@@ -502,12 +502,10 @@
 
 define <2 x i16> @test40vec_nonuniform(<2 x i16> %a) {
 ; ALL-LABEL: @test40vec_nonuniform(
-; ALL-NEXT:    [[T:%.*]] = zext <2 x i16> [[A:%.*]] to <2 x i32>
-; ALL-NEXT:    [[T21:%.*]] = lshr <2 x i32> [[T]], <i32 9, i32 10>
-; ALL-NEXT:    [[T5:%.*]] = shl <2 x i32> [[T]], <i32 8, i32 9>
-; ALL-NEXT:    [[T32:%.*]] = or <2 x i32> [[T21]], [[T5]]
-; ALL-NEXT:    [[R:%.*]] = trunc <2 x i32> [[T32]] to <2 x i16>
-; ALL-NEXT:    ret <2 x i16> [[R]]
+; ALL-NEXT:    [[T21:%.*]] = lshr <2 x i16> [[A:%.*]], <i16 9, i16 10>
+; ALL-NEXT:    [[T5:%.*]] = shl <2 x i16> [[A]], <i16 8, i16 9>
+; ALL-NEXT:    [[T32:%.*]] = or <2 x i16> [[T21]], [[T5]]
+; ALL-NEXT:    ret <2 x i16> [[T32]]
 ;
   %t = zext <2 x i16> %a to <2 x i32>
   %t21 = lshr <2 x i32> %t, <i32 9, i32 10>
diff --git a/llvm/test/Transforms/InstCombine/trunc.ll b/llvm/test/Transforms/InstCombine/trunc.ll
index 4e9f440..d8a615c 100644
--- a/llvm/test/Transforms/InstCombine/trunc.ll
+++ b/llvm/test/Transforms/InstCombine/trunc.ll
@@ -286,12 +286,11 @@
 
 define <2 x i64> @test8_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test8_vec_nonuniform(
-; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
-; CHECK-NEXT:    [[D:%.*]] = zext <2 x i32> [[B:%.*]] to <2 x i128>
-; CHECK-NEXT:    [[E:%.*]] = shl <2 x i128> [[D]], <i128 32, i128 48>
-; CHECK-NEXT:    [[F:%.*]] = or <2 x i128> [[E]], [[C]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
+; CHECK-NEXT:    [[D:%.*]] = zext <2 x i32> [[B:%.*]] to <2 x i64>
+; CHECK-NEXT:    [[E:%.*]] = shl <2 x i64> [[D]], <i64 32, i64 48>
+; CHECK-NEXT:    [[F:%.*]] = or <2 x i64> [[E]], [[C]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = zext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -343,12 +342,11 @@
 
 define i64 @test11(i32 %A, i32 %B) {
 ; CHECK-LABEL: @test11(
-; CHECK-NEXT:    [[C:%.*]] = zext i32 [[A:%.*]] to i128
+; CHECK-NEXT:    [[C:%.*]] = zext i32 [[A:%.*]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[B:%.*]], 31
-; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i128
-; CHECK-NEXT:    [[F:%.*]] = shl i128 [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc i128 [[F]] to i64
-; CHECK-NEXT:    ret i64 [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[F:%.*]] = shl i64 [[C]], [[E]]
+; CHECK-NEXT:    ret i64 [[F]]
 ;
   %C = zext i32 %A to i128
   %D = zext i32 %B to i128
@@ -360,12 +358,11 @@
 
 define <2 x i64> @test11_vec(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test11_vec(
-; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 31>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = shl <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = shl <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = zext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -377,12 +374,11 @@
 
 define <2 x i64> @test11_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test11_vec_nonuniform(
-; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 15>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = shl <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = shl <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = zext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -411,12 +407,11 @@
 
 define i64 @test12(i32 %A, i32 %B) {
 ; CHECK-LABEL: @test12(
-; CHECK-NEXT:    [[C:%.*]] = zext i32 [[A:%.*]] to i128
+; CHECK-NEXT:    [[C:%.*]] = zext i32 [[A:%.*]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[B:%.*]], 31
-; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i128
-; CHECK-NEXT:    [[F:%.*]] = lshr i128 [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc i128 [[F]] to i64
-; CHECK-NEXT:    ret i64 [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[F:%.*]] = lshr i64 [[C]], [[E]]
+; CHECK-NEXT:    ret i64 [[F]]
 ;
   %C = zext i32 %A to i128
   %D = zext i32 %B to i128
@@ -428,12 +423,11 @@
 
 define <2 x i64> @test12_vec(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test12_vec(
-; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 31>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = lshr <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = lshr <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = zext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -445,12 +439,11 @@
 
 define <2 x i64> @test12_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test12_vec_nonuniform(
-; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = zext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 15>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = lshr <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = lshr <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = zext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -479,12 +472,11 @@
 
 define i64 @test13(i32 %A, i32 %B) {
 ; CHECK-LABEL: @test13(
-; CHECK-NEXT:    [[C:%.*]] = sext i32 [[A:%.*]] to i128
+; CHECK-NEXT:    [[C:%.*]] = sext i32 [[A:%.*]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[B:%.*]], 31
-; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i128
-; CHECK-NEXT:    [[F:%.*]] = ashr i128 [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc i128 [[F]] to i64
-; CHECK-NEXT:    ret i64 [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[F:%.*]] = ashr i64 [[C]], [[E]]
+; CHECK-NEXT:    ret i64 [[F]]
 ;
   %C = sext i32 %A to i128
   %D = zext i32 %B to i128
@@ -496,12 +488,11 @@
 
 define <2 x i64> @test13_vec(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test13_vec(
-; CHECK-NEXT:    [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 31>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = ashr <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = ashr <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = sext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>
@@ -513,12 +504,11 @@
 
 define <2 x i64> @test13_vec_nonuniform(<2 x i32> %A, <2 x i32> %B) {
 ; CHECK-LABEL: @test13_vec_nonuniform(
-; CHECK-NEXT:    [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i128>
+; CHECK-NEXT:    [[C:%.*]] = sext <2 x i32> [[A:%.*]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[B:%.*]], <i32 31, i32 15>
-; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i128>
-; CHECK-NEXT:    [[F:%.*]] = ashr <2 x i128> [[C]], [[E]]
-; CHECK-NEXT:    [[G:%.*]] = trunc <2 x i128> [[F]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[G]]
+; CHECK-NEXT:    [[E:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    [[F:%.*]] = ashr <2 x i64> [[C]], [[E]]
+; CHECK-NEXT:    ret <2 x i64> [[F]]
 ;
   %C = sext <2 x i32> %A to <2 x i128>
   %D = zext <2 x i32> %B to <2 x i128>