[Types] Define a getWithNewBitWidth for Types and make use of it
This is designed to change the bitwidth of a type without altering the number
of vector lanes. Also useful in D68651. Otherwise an NFC.
Differential Revision: https://reviews.llvm.org/D69139
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@375417 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/include/llvm/CodeGen/BasicTTIImpl.h b/include/llvm/CodeGen/BasicTTIImpl.h
index f73976d..2e57b4c 100644
--- a/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1149,9 +1149,7 @@
OpPropsBW);
// For non-rotates (X != Y) we must add shift-by-zero handling costs.
if (X != Y) {
- Type *CondTy = Type::getInt1Ty(RetTy->getContext());
- if (RetVF > 1)
- CondTy = VectorType::get(CondTy, RetVF);
+ Type *CondTy = RetTy->getWithNewBitWidth(1);
Cost += ConcreteTTI->getCmpSelInstrCost(BinaryOperator::ICmp, RetTy,
CondTy, nullptr);
Cost += ConcreteTTI->getCmpSelInstrCost(BinaryOperator::Select, RetTy,
@@ -1169,7 +1167,6 @@
unsigned getIntrinsicInstrCost(
Intrinsic::ID IID, Type *RetTy, ArrayRef<Type *> Tys, FastMathFlags FMF,
unsigned ScalarizationCostPassed = std::numeric_limits<unsigned>::max()) {
- unsigned RetVF = (RetTy->isVectorTy() ? RetTy->getVectorNumElements() : 1);
auto *ConcreteTTI = static_cast<T *>(this);
SmallVector<unsigned, 2> ISDs;
@@ -1326,9 +1323,7 @@
/*IsUnsigned=*/false);
case Intrinsic::sadd_sat:
case Intrinsic::ssub_sat: {
- Type *CondTy = Type::getInt1Ty(RetTy->getContext());
- if (RetVF > 1)
- CondTy = VectorType::get(CondTy, RetVF);
+ Type *CondTy = RetTy->getWithNewBitWidth(1);
Type *OpTy = StructType::create({RetTy, CondTy});
Intrinsic::ID OverflowOp = IID == Intrinsic::sadd_sat
@@ -1348,9 +1343,7 @@
}
case Intrinsic::uadd_sat:
case Intrinsic::usub_sat: {
- Type *CondTy = Type::getInt1Ty(RetTy->getContext());
- if (RetVF > 1)
- CondTy = VectorType::get(CondTy, RetVF);
+ Type *CondTy = RetTy->getWithNewBitWidth(1);
Type *OpTy = StructType::create({RetTy, CondTy});
Intrinsic::ID OverflowOp = IID == Intrinsic::uadd_sat
@@ -1367,9 +1360,7 @@
case Intrinsic::smul_fix:
case Intrinsic::umul_fix: {
unsigned ExtSize = RetTy->getScalarSizeInBits() * 2;
- Type *ExtTy = Type::getIntNTy(RetTy->getContext(), ExtSize);
- if (RetVF > 1)
- ExtTy = VectorType::get(ExtTy, RetVF);
+ Type *ExtTy = RetTy->getWithNewBitWidth(ExtSize);
unsigned ExtOp =
IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt;
@@ -1433,9 +1424,7 @@
Type *MulTy = RetTy->getContainedType(0);
Type *OverflowTy = RetTy->getContainedType(1);
unsigned ExtSize = MulTy->getScalarSizeInBits() * 2;
- Type *ExtTy = Type::getIntNTy(RetTy->getContext(), ExtSize);
- if (MulTy->isVectorTy())
- ExtTy = VectorType::get(ExtTy, MulTy->getVectorNumElements() );
+ Type *ExtTy = MulTy->getWithNewBitWidth(ExtSize);
unsigned ExtOp =
IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt;
diff --git a/include/llvm/IR/DerivedTypes.h b/include/llvm/IR/DerivedTypes.h
index ade6376..20097ef 100644
--- a/include/llvm/IR/DerivedTypes.h
+++ b/include/llvm/IR/DerivedTypes.h
@@ -571,6 +571,10 @@
return cast<VectorType>(this)->isScalable();
}
+ElementCount Type::getVectorElementCount() const {
+ return cast<VectorType>(this)->getElementCount();
+}
+
/// Class to represent pointers.
class PointerType : public Type {
explicit PointerType(Type *ElType, unsigned AddrSpace);
@@ -618,6 +622,16 @@
return cast<IntegerType>(this)->getExtendedType();
}
+Type *Type::getWithNewBitWidth(unsigned NewBitWidth) const {
+ assert(
+ isIntOrIntVectorTy() &&
+ "Original type expected to be a vector of integers or a scalar integer.");
+ Type *NewType = getIntNTy(getContext(), NewBitWidth);
+ if (isVectorTy())
+ NewType = VectorType::get(NewType, getVectorElementCount());
+ return NewType;
+}
+
unsigned Type::getPointerAddressSpace() const {
return cast<PointerType>(getScalarType())->getAddressSpace();
}
diff --git a/include/llvm/IR/Type.h b/include/llvm/IR/Type.h
index 63bc884..d0961da 100644
--- a/include/llvm/IR/Type.h
+++ b/include/llvm/IR/Type.h
@@ -372,6 +372,7 @@
inline bool getVectorIsScalable() const;
inline unsigned getVectorNumElements() const;
+ inline ElementCount getVectorElementCount() const;
Type *getVectorElementType() const {
assert(getTypeID() == VectorTyID);
return ContainedTys[0];
@@ -382,6 +383,10 @@
return ContainedTys[0];
}
+ /// Given an integer or vector type, change the lane bitwidth to NewBitwidth,
+ /// whilst keeping the old number of lanes.
+ inline Type *getWithNewBitWidth(unsigned NewBitWidth) const;
+
/// Given scalar/vector integer type, returns a type with elements twice as
/// wide as in the original type. For vectors, preserves element count.
inline Type *getExtendedType() const;