Reapply "[DemandedBits][BDCE] Support vectors of integers"
DemandedBits and BDCE currently only support scalar integers. This
patch extends them to also handle vector integer operations. In this
case bits are not tracked for individual vector elements; instead, a
bit is demanded if it is demanded for any of the elements. This matches
the behavior of computeKnownBits in ValueTracking and
SimplifyDemandedBits in InstCombine.
Unlike the previous iteration of this patch, getDemandedBits() can now
again be called on arbitrary (sized) instructions, even if they don't
have integer or vector of integer type. (For vector types the size of the
returned mask will now be the scalar size in bits though.)
The added LoopVectorize test case shows a case which triggered an
assertion failure with the previous attempt, because getDemandedBits()
was called on a pointer-typed instruction.
Differential Revision: https://reviews.llvm.org/D55297
llvm-svn: 348602
diff --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp
index 6bef771..0382787 100644
--- a/llvm/lib/Analysis/DemandedBits.cpp
+++ b/llvm/lib/Analysis/DemandedBits.cpp
@@ -39,6 +39,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/Pass.h"
@@ -50,6 +51,7 @@
#include <cstdint>
using namespace llvm;
+using namespace llvm::PatternMatch;
#define DEBUG_TYPE "demanded-bits"
@@ -143,17 +145,17 @@
}
break;
case Intrinsic::fshl:
- case Intrinsic::fshr:
+ case Intrinsic::fshr: {
+ const APInt *SA;
if (OperandNo == 2) {
// Shift amount is modulo the bitwidth. For powers of two we have
// SA % BW == SA & (BW - 1).
if (isPowerOf2_32(BitWidth))
AB = BitWidth - 1;
- } else if (auto *SA = dyn_cast<ConstantInt>(II->getOperand(2))) {
- // TODO: Support vectors.
+ } else if (match(II->getOperand(2), m_APInt(SA))) {
// Normalize to funnel shift left. APInt shifts of BitWidth are well-
// defined, so no need to special-case zero shifts here.
- uint64_t ShiftAmt = SA->getValue().urem(BitWidth);
+ uint64_t ShiftAmt = SA->urem(BitWidth);
if (II->getIntrinsicID() == Intrinsic::fshr)
ShiftAmt = BitWidth - ShiftAmt;
@@ -164,6 +166,7 @@
}
break;
}
+ }
break;
case Instruction::Add:
case Instruction::Sub:
@@ -174,8 +177,9 @@
AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
break;
case Instruction::Shl:
- if (OperandNo == 0)
- if (auto *ShiftAmtC = dyn_cast<ConstantInt>(UserI->getOperand(1))) {
+ if (OperandNo == 0) {
+ const APInt *ShiftAmtC;
+ if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
AB = AOut.lshr(ShiftAmt);
@@ -187,10 +191,12 @@
else if (S->hasNoUnsignedWrap())
AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
}
+ }
break;
case Instruction::LShr:
- if (OperandNo == 0)
- if (auto *ShiftAmtC = dyn_cast<ConstantInt>(UserI->getOperand(1))) {
+ if (OperandNo == 0) {
+ const APInt *ShiftAmtC;
+ if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
AB = AOut.shl(ShiftAmt);
@@ -199,10 +205,12 @@
if (cast<LShrOperator>(UserI)->isExact())
AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
}
+ }
break;
case Instruction::AShr:
- if (OperandNo == 0)
- if (auto *ShiftAmtC = dyn_cast<ConstantInt>(UserI->getOperand(1))) {
+ if (OperandNo == 0) {
+ const APInt *ShiftAmtC;
+ if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
AB = AOut.shl(ShiftAmt);
// Because the high input bit is replicated into the
@@ -217,6 +225,7 @@
if (cast<AShrOperator>(UserI)->isExact())
AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
}
+ }
break;
case Instruction::And:
AB = AOut;
@@ -274,6 +283,15 @@
if (OperandNo != 0)
AB = AOut;
break;
+ case Instruction::ExtractElement:
+ if (OperandNo == 0)
+ AB = AOut;
+ break;
+ case Instruction::InsertElement:
+ case Instruction::ShuffleVector:
+ if (OperandNo == 0 || OperandNo == 1)
+ AB = AOut;
+ break;
}
}
@@ -309,8 +327,9 @@
// bits and add the instruction to the work list. For other instructions
// add their operands to the work list (for integer values operands, mark
// all bits as live).
- if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) {
- if (AliveBits.try_emplace(&I, IT->getBitWidth(), 0).second)
+ Type *T = I.getType();
+ if (T->isIntOrIntVectorTy()) {
+ if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second)
Worklist.push_back(&I);
continue;
@@ -319,8 +338,9 @@
// Non-integer-typed instructions...
for (Use &OI : I.operands()) {
if (Instruction *J = dyn_cast<Instruction>(OI)) {
- if (IntegerType *IT = dyn_cast<IntegerType>(J->getType()))
- AliveBits[J] = APInt::getAllOnesValue(IT->getBitWidth());
+ Type *T = J->getType();
+ if (T->isIntOrIntVectorTy())
+ AliveBits[J] = APInt::getAllOnesValue(T->getScalarSizeInBits());
Worklist.push_back(J);
}
}
@@ -336,13 +356,13 @@
LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
APInt AOut;
- if (UserI->getType()->isIntegerTy()) {
+ if (UserI->getType()->isIntOrIntVectorTy()) {
AOut = AliveBits[UserI];
LLVM_DEBUG(dbgs() << " Alive Out: " << AOut);
}
LLVM_DEBUG(dbgs() << "\n");
- if (!UserI->getType()->isIntegerTy())
+ if (!UserI->getType()->isIntOrIntVectorTy())
Visited.insert(UserI);
KnownBits Known, Known2;
@@ -351,10 +371,11 @@
// operand is added to the work-list.
for (Use &OI : UserI->operands()) {
if (Instruction *I = dyn_cast<Instruction>(OI)) {
- if (IntegerType *IT = dyn_cast<IntegerType>(I->getType())) {
- unsigned BitWidth = IT->getBitWidth();
+ Type *T = I->getType();
+ if (T->isIntOrIntVectorTy()) {
+ unsigned BitWidth = T->getScalarSizeInBits();
APInt AB = APInt::getAllOnesValue(BitWidth);
- if (UserI->getType()->isIntegerTy() && !AOut &&
+ if (UserI->getType()->isIntOrIntVectorTy() && !AOut &&
!isAlwaysLive(UserI)) {
AB = APInt(BitWidth, 0);
} else {
@@ -389,11 +410,13 @@
APInt DemandedBits::getDemandedBits(Instruction *I) {
performAnalysis();
- const DataLayout &DL = I->getModule()->getDataLayout();
auto Found = AliveBits.find(I);
if (Found != AliveBits.end())
return Found->second;
- return APInt::getAllOnesValue(DL.getTypeSizeInBits(I->getType()));
+
+ const DataLayout &DL = I->getModule()->getDataLayout();
+ return APInt::getAllOnesValue(
+ DL.getTypeSizeInBits(I->getType()->getScalarType()));
}
bool DemandedBits::isInstructionDead(Instruction *I) {