[SLP][REVEC] Support more mask pattern usage in shufflevector. (#106212)
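
REVEC previously only recognized groups of extract-subvector shufflevectors
whose masks read the source elements in order, and it always modelled such
groups as free. This patch relaxes getShufflevectorNumGroups so a group only
has to cover every element of the source, tracking the expected subvector
indices with a SmallBitVector instead of collecting extraction indices and
checking that they are sorted. The cost model then distinguishes the two
cases: in-order groups stay TCC_Free because instcombine eliminates them,
while out-of-order groups are costed as a single SK_PermuteSingleSrc shuffle
built from calculateShufflevectorMask. When emitting the vectorized IR, the
wrapping shufflevector is folded into its shufflevector source by composing
the two masks rather than always duplicating the source.

With this, @test2 in revec-shufflevector.ll becomes one <8 x i64> zext, one
permute, and one store, and the new @test3 (four <4 x i32> subvectors of a
<16 x i32> extracted in reverse order) is vectorized into a single
<16 x i32> permute.
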
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 93e7bfc..e6a0e9b 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -293,8 +293,7 @@
/// A group has the following features
/// 1. All of value in a group are shufflevector.
/// 2. The mask of all shufflevector is isExtractSubvectorMask.
-/// 3. The mask of all shufflevector uses all of the elements of the source (and
-/// the elements are used in order).
+/// 3. The masks of all shufflevectors combined use all source elements.
/// e.g., it is 1 group (%0)
/// %1 = shufflevector <16 x i8> %0, <16 x i8> poison,
/// <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
@@ -322,7 +321,8 @@
auto *SV = cast<ShuffleVectorInst>(VL.front());
unsigned SVNumElements =
cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
- unsigned GroupSize = SVNumElements / SV->getShuffleMask().size();
+ unsigned ShuffleMaskSize = SV->getShuffleMask().size();
+ unsigned GroupSize = SVNumElements / ShuffleMaskSize;
if (GroupSize == 0 || (VL.size() % GroupSize) != 0)
return 0;
unsigned NumGroup = 0;
@@ -330,7 +330,7 @@
auto *SV = cast<ShuffleVectorInst>(VL[I]);
Value *Src = SV->getOperand(0);
ArrayRef<Value *> Group = VL.slice(I, GroupSize);
- SmallVector<int> ExtractionIndex(SVNumElements);
+ SmallBitVector ExpectedIndex(GroupSize);
if (!all_of(Group, [&](Value *V) {
auto *SV = cast<ShuffleVectorInst>(V);
// From the same source.
@@ -339,12 +339,11 @@
int Index;
if (!SV->isExtractSubvectorMask(Index))
return false;
- for (int I : seq<int>(Index, Index + SV->getShuffleMask().size()))
- ExtractionIndex.push_back(I);
+ ExpectedIndex.set(Index / ShuffleMaskSize);
return true;
}))
return 0;
- if (!is_sorted(ExtractionIndex))
+ if (!ExpectedIndex.all())
return 0;
++NumGroup;
}
@@ -10289,12 +10288,40 @@
return VecCost;
};
if (SLPReVec && !E->isAltShuffle())
- return GetCostDiff(GetScalarCost, [](InstructionCost) {
- // shufflevector will be eliminated by instcombine because the
- // shufflevector masks are used in order (guaranteed by
- // getShufflevectorNumGroups). The vector cost is 0.
- return TTI::TCC_Free;
- });
+ return GetCostDiff(
+ GetScalarCost, [&](InstructionCost) -> InstructionCost {
+ // If a group uses the masks in order, the shufflevectors can be
+ // eliminated by instcombine and the vector cost is 0.
+ assert(isa<ShuffleVectorInst>(VL.front()) &&
+ "Not supported shufflevector usage.");
+ auto *SV = cast<ShuffleVectorInst>(VL.front());
+ unsigned SVNumElements =
+ cast<FixedVectorType>(SV->getOperand(0)->getType())
+ ->getNumElements();
+ unsigned GroupSize = SVNumElements / SV->getShuffleMask().size();
+ for (size_t I = 0, End = VL.size(); I != End; I += GroupSize) {
+ ArrayRef<Value *> Group = VL.slice(I, GroupSize);
+ int NextIndex = 0;
+ if (!all_of(Group, [&](Value *V) {
+ assert(isa<ShuffleVectorInst>(V) &&
+ "Not supported shufflevector usage.");
+ auto *SV = cast<ShuffleVectorInst>(V);
+ int Index;
+ [[maybe_unused]] bool IsExtractSubvectorMask =
+ SV->isExtractSubvectorMask(Index);
+ assert(IsExtractSubvectorMask &&
+ "Not supported shufflevector usage.");
+ if (NextIndex != Index)
+ return false;
+ NextIndex += SV->getShuffleMask().size();
+ return true;
+ }))
+ return ::getShuffleCost(
+ *TTI, TargetTransformInfo::SK_PermuteSingleSrc, VecTy,
+ calculateShufflevectorMask(E->Scalars));
+ }
+ return TTI::TCC_Free;
+ });
return GetCostDiff(GetScalarCost, GetVectorCost);
}
case Instruction::Freeze:
@@ -14072,9 +14099,16 @@
LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
return E->VectorizedValue;
}
- // The current shufflevector usage always duplicate the source.
- V = Builder.CreateShuffleVector(Src,
- calculateShufflevectorMask(E->Scalars));
+ assert(isa<ShuffleVectorInst>(Src) &&
+ "Not supported shufflevector usage.");
+ auto *SVSrc = cast<ShuffleVectorInst>(Src);
+ assert(isa<PoisonValue>(SVSrc->getOperand(1)) &&
+ "Not supported shufflevector usage.");
+ SmallVector<int> ThisMask(calculateShufflevectorMask(E->Scalars));
+ SmallVector<int> NewMask(ThisMask.size());
+ transform(ThisMask, NewMask.begin(),
+ [&SVSrc](int Mask) { return SVSrc->getShuffleMask()[Mask]; });
+ V = Builder.CreateShuffleVector(SVSrc->getOperand(0), NewMask);
propagateIRFlags(V, E->Scalars, VL0);
} else {
assert(E->isAltShuffle() &&
diff --git a/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll b/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
index 6028a8b..1fc0b03 100644
--- a/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
+++ b/llvm/test/Transforms/SLPVectorizer/revec-shufflevector.ll
@@ -34,17 +34,9 @@
; CHECK-LABEL: @test2(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = load <8 x i32>, ptr [[IN:%.*]], align 1
-; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i32> [[TMP0]], <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
-; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i32> [[TMP1]] to <4 x i64>
-; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i32> [[TMP2]] to <4 x i64>
-; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i8, ptr [[OUT:%.*]], i64 16
-; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i8, ptr [[OUT]], i64 32
-; CHECK-NEXT: store <2 x i64> [[TMP5]], ptr [[OUT]], align 8
-; CHECK-NEXT: store <2 x i64> [[TMP6]], ptr [[TMP7]], align 8
-; CHECK-NEXT: store <4 x i64> [[TMP4]], ptr [[TMP8]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = zext <8 x i32> [[TMP0]] to <8 x i64>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i64> [[TMP1]], <8 x i64> poison, <8 x i32> <i32 2, i32 3, i32 0, i32 1, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT: store <8 x i64> [[TMP2]], ptr [[OUT:%.*]], align 8
; CHECK-NEXT: ret void
;
entry:
@@ -67,3 +59,26 @@
store <2 x i64> %8, ptr %12, align 8
ret void
}
+
+define void @test3(<16 x i32> %0, ptr %out) {
+; CHECK-LABEL: @test3(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <16 x i32> [[TMP0:%.*]], <16 x i32> poison, <16 x i32> <i32 12, i32 13, i32 14, i32 15, i32 8, i32 9, i32 10, i32 11, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT: store <16 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT: ret void
+;
+entry:
+ %1 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 12, i32 13, i32 14, i32 15>
+ %2 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 8, i32 9, i32 10, i32 11>
+ %3 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+ %4 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ %5 = getelementptr inbounds i32, ptr %out, i64 0
+ %6 = getelementptr inbounds i32, ptr %out, i64 4
+ %7 = getelementptr inbounds i32, ptr %out, i64 8
+ %8 = getelementptr inbounds i32, ptr %out, i64 12
+ store <4 x i32> %1, ptr %5, align 4
+ store <4 x i32> %2, ptr %6, align 4
+ store <4 x i32> %3, ptr %7, align 4
+ store <4 x i32> %4, ptr %8, align 4
+ ret void
+}
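
Note on the codegen change: when the source of the tree entry is itself an
extract-subvector shufflevector, the new code composes the two masks instead
of emitting a shuffle of a shuffle. A minimal IR sketch of that folding, with
illustrative values that do not come from the tests above:

  ; the source entry extracts the upper half of %v
  %src = shufflevector <8 x i32> %v, <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
  ; calculateShufflevectorMask(E->Scalars) asks for elements <2, 3, 0, 1> of %src;
  ; composing through %src's mask (NewMask[I] = SrcMask[ThisMask[I]]) gives <6, 7, 4, 5>,
  ; so the emitted shuffle reads directly from %v:
  %res = shufflevector <8 x i32> %v, <8 x i32> poison, <4 x i32> <i32 6, i32 7, i32 4, i32 5>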