[AArch64][InstCombine] Simplify repeated complex patterns in dupqlane
Repeated floating-point complex patterns in dupqlane such as (f32 a, f32 b, f32
a, f32 b) can be simplified to shufflevector(f64(a, b), undef, 0)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ce073ce..471b05b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1436,6 +1436,86 @@
return std::nullopt;
}
+// Repeatedly halve Vec for as long as its first half is element-wise identical
+// to its second half, e.g. (a, b, a, b) -> (a, b) and (a, a, a, a) -> (a).
+//
+// Returns true if Vec already holds exactly one element, or if it could be
+// halved at least once; Vec is then left holding the shortest pattern found.
+// Returns false and leaves Vec untouched if its size is not a power of two or
+// if the two halves differ. A nullptr entry never compares equal to anything,
+// so any half containing one is treated as a mismatch.
+//
+// Declared static: this is a file-local helper, like the other combines here.
+static bool SimplifyValuePattern(SmallVector<Value *> &Vec) {
+  size_t VecSize = Vec.size();
+  if (VecSize == 1)
+    return true;
+  if (!isPowerOf2_64(VecSize))
+    return false;
+  size_t HalfVecSize = VecSize / 2;
+
+  for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize;
+       RHS != Vec.end(); ++LHS, ++RHS) {
+    // If *LHS is non-null and equal to *RHS then *RHS is non-null as well, so
+    // a single null test is sufficient.
+    if (*LHS == nullptr || *LHS != *RHS)
+      return false;
+  }
+
+  Vec.resize(HalfVecSize);
+  // Keep halving if possible. The nested result is deliberately ignored: one
+  // successful halving is already a simplification, and a failed nested
+  // attempt leaves Vec at the size reached here.
+  (void)SimplifyValuePattern(Vec);
+  return true;
+}
+
+// Try to simplify a dupq_lane whose source quadword is a repeated pattern,
+// e.g. dupq_lane(vector.insert(_, (f32 A, f32 B, f32 A, f32 B), 0), 0), into a
+// splat of the shortest repeating unit viewed as a single wider element (here
+// the 64-bit value formed by A followed by B).
+//
+// Returns std::nullopt if the operand does not have that shape or no
+// repetition is found; otherwise returns the result of replacing all uses of
+// II with the new splat.
+static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC,
+                                                           IntrinsicInst &II) {
+  // Only quadword 0 is rebuilt below, so the fixed vector must have been
+  // inserted at element 0 and dupq_lane must be selecting quadword 0.
+  Value *CurrentInsertElt = nullptr;
+  if (!match(II.getOperand(0),
+             m_Intrinsic<Intrinsic::vector_insert>(
+                 m_Value(), m_Value(CurrentInsertElt), m_ZeroInt())) ||
+      !match(II.getOperand(1), m_ZeroInt()) ||
+      !isa<FixedVectorType>(CurrentInsertElt->getType()))
+    return std::nullopt;
+  auto *IIScalableTy = cast<ScalableVectorType>(II.getType());
+
+  // Insert the scalars into a container ordered by InsertElement index. The
+  // chain is walked from the last insertion back to the first, so the first
+  // value recorded for a lane is the live one; older insertions into the same
+  // lane are dead and must not overwrite it.
+  SmallVector<Value *> Elts(IIScalableTy->getMinNumElements(), nullptr);
+  while (auto *InsertElt = dyn_cast<InsertElementInst>(CurrentInsertElt)) {
+    // A variable or out-of-range lane index cannot be mapped onto Elts.
+    auto *Idx = dyn_cast<ConstantInt>(InsertElt->getOperand(2));
+    if (!Idx || Idx->getValue().uge(Elts.size()))
+      return std::nullopt;
+    Value *&Lane = Elts[Idx->getZExtValue()];
+    if (Lane == nullptr)
+      Lane = InsertElt->getOperand(1);
+    CurrentInsertElt = InsertElt->getOperand(0);
+  }
+
+  if (!SimplifyValuePattern(Elts))
+    return std::nullopt;
+
+  // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b)
+  IRBuilder<> Builder(II.getContext());
+  Builder.SetInsertPoint(&II);
+  Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType());
+  for (size_t I = 0; I < Elts.size(); I++) {
+    InsertEltChain = Builder.CreateInsertElement(InsertEltChain, Elts[I],
+                                                 Builder.getInt64(I));
+  }
+
+  // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64
+  // value or (f16 a, f16 b) as one i32 value. This requires an InsertSubvector
+  // be bitcast to a type wide enough to fit the sequence, be splatted, and then
+  // be narrowed back to the original type.
+  unsigned PatternWidth = IIScalableTy->getScalarSizeInBits() * Elts.size();
+  // Number of pattern-wide elements that cover the same minimum number of bits
+  // as the original scalable type.
+  unsigned PatternElementCount = IIScalableTy->getScalarSizeInBits() *
+                                 IIScalableTy->getMinNumElements() /
+                                 PatternWidth;
+
+  IntegerType *WideTy = Builder.getIntNTy(PatternWidth);
+  auto *WideScalableTy = ScalableVectorType::get(WideTy, PatternElementCount);
+  auto *WideShuffleMaskTy =
+      ScalableVectorType::get(Builder.getInt32Ty(), PatternElementCount);
+
+  auto *ZeroIdx = ConstantInt::get(Builder.getInt64Ty(), APInt(64, 0));
+  auto *InsertSubvector = Builder.CreateInsertVector(
+      II.getType(), PoisonValue::get(II.getType()), InsertEltChain, ZeroIdx);
+  auto *WideBitcast =
+      Builder.CreateBitOrPointerCast(InsertSubvector, WideScalableTy);
+  // An all-zero mask broadcasts wide element 0 to every lane.
+  auto *WideShuffleMask = ConstantAggregateZero::get(WideShuffleMaskTy);
+  auto *WideShuffle = Builder.CreateShuffleVector(
+      WideBitcast, PoisonValue::get(WideScalableTy), WideShuffleMask);
+  auto *NarrowBitcast =
+      Builder.CreateBitOrPointerCast(WideShuffle, II.getType());
+
+  return IC.replaceInstUsesWith(II, NarrowBitcast);
+}
+
static std::optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC,
IntrinsicInst &II) {
Value *A = II.getArgOperand(0);
@@ -1553,6 +1633,8 @@
return instCombineSVESel(IC, II);
case Intrinsic::aarch64_sve_srshl:
return instCombineSVESrshl(IC, II);
+ case Intrinsic::aarch64_sve_dupq_lane:
+ return instCombineSVEDupqLane(IC, II);
}
return std::nullopt;