[InstCombine] Simplify demanded bits of blendv mask operands (#173723)
fixes #173368
- Integer masks: Demands only the sign bit of the operand.
- Float/Double masks: Peeks through bitcasts to demand the sign bit
from the integer source.
diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index ffb8f2e..cbfaf0f 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -2890,7 +2890,26 @@
getNegativeIsTrueBoolVec(ConstantMask, IC.getDataLayout());
return SelectInst::Create(NewSelector, Op1, Op0, "blendv");
}
+ unsigned BitWidth = Mask->getType()->getScalarSizeInBits();
+ if (Mask->getType()->isIntOrIntVectorTy()) {
+ KnownBits Known(BitWidth);
+ if (IC.SimplifyDemandedBits(&II, 2, APInt::getSignMask(BitWidth), Known))
+ return &II;
+ } else if (auto *BC = dyn_cast<BitCastInst>(Mask)) {
+ if (BC->hasOneUse()) {
+ Value *Src = BC->getOperand(0);
+ if (Src->getType()->isIntOrIntVectorTy()) {
+ unsigned SrcBitWidth = Src->getType()->getScalarSizeInBits();
+ if (SrcBitWidth == BitWidth) {
+ KnownBits KnownSrc(SrcBitWidth);
+ if (IC.SimplifyDemandedBits(BC, 0, APInt::getSignMask(SrcBitWidth),
+ KnownSrc))
+ return &II;
+ }
+ }
+ }
+ }
Mask = InstCombiner::peekThroughBitcast(Mask);
// Bitshift upto the signbit can always be converted to an efficient
diff --git a/llvm/test/Transforms/InstCombine/X86/blend_x86.ll b/llvm/test/Transforms/InstCombine/X86/blend_x86.ll
index 0916cf7..90fa512 100644
--- a/llvm/test/Transforms/InstCombine/X86/blend_x86.ll
+++ b/llvm/test/Transforms/InstCombine/X86/blend_x86.ll
@@ -357,6 +357,75 @@
ret <4 x double> %r
}
+define <16 x i8> @pblendvb_demanded_msb(<16 x i8> %a, <16 x i8> %b, <16 x i8> %m) {
+; CHECK-LABEL: @pblendvb_demanded_msb(
+; CHECK-NEXT: [[R:%.*]] = call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[A:%.*]], <16 x i8> [[B:%.*]], <16 x i8> [[OR:%.*]])
+; CHECK-NEXT: ret <16 x i8> [[R]]
+;
+ %or = or <16 x i8> %m, splat (i8 1)
+ %r = call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> %a, <16 x i8> %b, <16 x i8> %or)
+ ret <16 x i8> %r
+}
+
+define <8 x float> @blendvps_demanded_msb(<8 x float> %a, <8 x float> %b, <8 x i32> %m) {
+; CHECK-LABEL: @blendvps_demanded_msb(
+; CHECK-NEXT: [[MASK:%.*]] = bitcast <8 x i32> [[OR:%.*]] to <8 x float>
+; CHECK-NEXT: [[R:%.*]] = call <8 x float> @llvm.x86.avx.blendv.ps.256(<8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x float> [[MASK]])
+; CHECK-NEXT: ret <8 x float> [[R]]
+;
+ %or = or <8 x i32> %m, splat (i32 1)
+ %mask = bitcast <8 x i32> %or to <8 x float>
+ %r = call <8 x float> @llvm.x86.avx.blendv.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %mask)
+ ret <8 x float> %r
+}
+
+define <16 x i8> @pblendvb_or_affects_msb(<16 x i8> %a, <16 x i8> %b, <16 x i8> %m) {
+; CHECK-LABEL: @pblendvb_or_affects_msb(
+; CHECK-NEXT: ret <16 x i8> [[R:%.*]]
+;
+ %or = or <16 x i8> %m, splat (i8 128)
+ %r = call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> %a, <16 x i8> %b, <16 x i8> %or)
+ ret <16 x i8> %r
+}
+
+define <32 x i8> @pblendvb_demanded_msb_avx2(<32 x i8> %a, <32 x i8> %b, <32 x i8> %m) {
+; CHECK-LABEL: @pblendvb_demanded_msb_avx2(
+; CHECK-NEXT: [[R:%.*]] = call <32 x i8> @llvm.x86.avx2.pblendvb(<32 x i8> [[A:%.*]], <32 x i8> [[B:%.*]], <32 x i8> [[OR:%.*]])
+; CHECK-NEXT: ret <32 x i8> [[R]]
+;
+ %or = or <32 x i8> %m, splat (i8 1)
+ %r = call <32 x i8> @llvm.x86.avx2.pblendvb(<32 x i8> %a, <32 x i8> %b, <32 x i8> %or)
+ ret <32 x i8> %r
+}
+
+define <2 x double> @blendvpd_demanded_msb(<2 x double> %a, <2 x double> %b, <2 x i64> %m) {
+; CHECK-LABEL: @blendvpd_demanded_msb(
+; CHECK-NEXT: [[MASK:%.*]] = bitcast <2 x i64> [[M:%.*]] to <2 x double>
+; CHECK-NEXT: [[R:%.*]] = call <2 x double> @llvm.x86.sse41.blendvpd(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]], <2 x double> [[MASK]])
+; CHECK-NEXT: ret <2 x double> [[R]]
+;
+ %or = or <2 x i64> %m, splat (i64 1)
+ %mask = bitcast <2 x i64> %or to <2 x double>
+ %r = call <2 x double> @llvm.x86.sse41.blendvpd(<2 x double> %a, <2 x double> %b, <2 x double> %mask)
+ ret <2 x double> %r
+}
+
+declare void @use_mask(<8 x float>)
+define <8 x float> @blendvps_demanded_msb_multiuse(<8 x float> %a, <8 x float> %b, <8 x i32> %m) {
+; CHECK-LABEL: @blendvps_demanded_msb_multiuse(
+; CHECK-NEXT: [[OR:%.*]] = or <8 x i32> [[M:%.*]], splat (i32 1)
+; CHECK-NEXT: [[MASK:%.*]] = bitcast <8 x i32> [[OR]] to <8 x float>
+; CHECK-NEXT: call void @use_mask(<8 x float> [[MASK]])
+; CHECK-NEXT: [[R:%.*]] = call <8 x float> @llvm.x86.avx.blendv.ps.256(<8 x float> [[A:%.*]], <8 x float> [[B:%.*]], <8 x float> [[MASK]])
+; CHECK-NEXT: ret <8 x float> [[R]]
+;
+ %or = or <8 x i32> %m, splat (i32 1)
+ %mask = bitcast <8 x i32> %or to <8 x float>
+ call void @use_mask(<8 x float> %mask)
+ %r = call <8 x float> @llvm.x86.avx.blendv.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %mask)
+ ret <8 x float> %r
+}
+
declare <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8>, <16 x i8>, <16 x i8>)
declare <4 x float> @llvm.x86.sse41.blendvps(<4 x float>, <4 x float>, <4 x float>)
declare <2 x double> @llvm.x86.sse41.blendvpd(<2 x double>, <2 x double>, <2 x double>)