[AArch64][GlobalISel] Add G_VECREDUCE fewerElements support for full scalarization.
For some reductions like G_VECREDUCE_OR on AArch64, we need to scalarize
completely if the source is <= 64b. This change adds support for that in
the legalizer. If the source has a power-of-2 number of elements, then we can
do a tree reduction using the scalar operation on the individual elements.
Otherwise, we just create a sequential chain of operations.
For AArch64, we only need to scalarize if the input is <= 64b. If it's greater
than 64b, then we can first do a fewerElements step down to 64b, taking
advantage of vector instructions until we reach the point of scalarization.
I also had to relax the verifier checks for reductions because the reduction
intrinsics support <1 x EltTy> types, which we lower to scalars for GlobalISel.
A reduction with such a scalar source is lowered to a plain COPY.
Differential Revision: https://reviews.llvm.org/D108276
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 611cf10..463437a 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -3489,6 +3489,8 @@
return lowerRotate(MI);
case G_ISNAN:
return lowerIsNaN(MI);
+ GISEL_VECREDUCE_CASES_NONSEQ
+ return lowerVectorReduction(MI);
}
}
@@ -4637,35 +4639,7 @@
return Legalized;
}
-LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
- MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
- unsigned Opc = MI.getOpcode();
- assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
- Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
- "Sequential reductions not expected");
-
- if (TypeIdx != 1)
- return UnableToLegalize;
-
- // The semantics of the normal non-sequential reductions allow us to freely
- // re-associate the operation.
- Register SrcReg = MI.getOperand(1).getReg();
- LLT SrcTy = MRI.getType(SrcReg);
- Register DstReg = MI.getOperand(0).getReg();
- LLT DstTy = MRI.getType(DstReg);
-
- if (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0)
- return UnableToLegalize;
-
- SmallVector<Register> SplitSrcs;
- const unsigned NumParts = SrcTy.getNumElements() / NarrowTy.getNumElements();
- extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
- SmallVector<Register> PartialReductions;
- for (unsigned Part = 0; Part < NumParts; ++Part) {
- PartialReductions.push_back(
- MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
- }
-
+static unsigned getScalarOpcForReduction(unsigned Opc) {
unsigned ScalarOpc;
switch (Opc) {
case TargetOpcode::G_VECREDUCE_FADD:
@@ -4708,9 +4682,80 @@
ScalarOpc = TargetOpcode::G_UMIN;
break;
default:
- LLVM_DEBUG(dbgs() << "Can't legalize: unknown reduction kind.\n");
- return UnableToLegalize;
+ llvm_unreachable("Unhandled reduction");
}
+ return ScalarOpc;
+}
+
+LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
+ MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
+ unsigned Opc = MI.getOpcode();
+ assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
+ Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
+ "Sequential reductions not expected");
+
+ if (TypeIdx != 1)
+ return UnableToLegalize;
+
+ // The semantics of the normal non-sequential reductions allow us to freely
+ // re-associate the operation.
+ Register SrcReg = MI.getOperand(1).getReg();
+ LLT SrcTy = MRI.getType(SrcReg);
+ Register DstReg = MI.getOperand(0).getReg();
+ LLT DstTy = MRI.getType(DstReg);
+
+ if (NarrowTy.isVector() &&
+ (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0))
+ return UnableToLegalize;
+
+ unsigned ScalarOpc = getScalarOpcForReduction(Opc);
+ SmallVector<Register> SplitSrcs;
+ // If NarrowTy is a scalar then we're being asked to scalarize.
+ const unsigned NumParts =
+ NarrowTy.isVector() ? SrcTy.getNumElements() / NarrowTy.getNumElements()
+ : SrcTy.getNumElements();
+
+ extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
+ if (NarrowTy.isScalar()) {
+ if (DstTy != NarrowTy)
+ return UnableToLegalize; // FIXME: handle implicit extensions.
+
+ if (isPowerOf2_32(NumParts)) {
+ // Generate a tree of scalar operations to reduce the critical path.
+ SmallVector<Register> PartialResults;
+ unsigned NumPartsLeft = NumParts;
+ while (NumPartsLeft > 1) {
+ for (unsigned Idx = 0; Idx < NumPartsLeft - 1; Idx += 2) {
+ PartialResults.emplace_back(
+ MIRBuilder
+ .buildInstr(ScalarOpc, {NarrowTy},
+ {SplitSrcs[Idx], SplitSrcs[Idx + 1]})
+ .getReg(0));
+ }
+ SplitSrcs = PartialResults;
+ PartialResults.clear();
+ NumPartsLeft = SplitSrcs.size();
+ }
+ assert(SplitSrcs.size() == 1);
+ MIRBuilder.buildCopy(DstReg, SplitSrcs[0]);
+ MI.eraseFromParent();
+ return Legalized;
+ }
+ // If we can't generate a tree, then just do sequential operations.
+ Register Acc = SplitSrcs[0];
+ for (unsigned Idx = 1; Idx < NumParts; ++Idx)
+ Acc = MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {Acc, SplitSrcs[Idx]})
+ .getReg(0);
+ MIRBuilder.buildCopy(DstReg, Acc);
+ MI.eraseFromParent();
+ return Legalized;
+ }
+ SmallVector<Register> PartialReductions;
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ PartialReductions.push_back(
+ MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
+ }
+
// If the types involved are powers of 2, we can generate intermediate vector
// ops, before generating a final reduction operation.
@@ -7389,3 +7434,31 @@
   MI.eraseFromParent();
   return Legalized;
 }
+
+/// Lower a non-sequential G_VECREDUCE_* whose source operand is a scalar.
+///
+/// The IR reduction intrinsics accept <1 x EltTy> vectors, which GlobalISel
+/// represents as plain scalars. Reducing a single element is the identity, so
+/// when the result has the same width as the source, the reduction is
+/// rewritten in place into a COPY. Every other case returns UnableToLegalize.
+LegalizerHelper::LegalizeResult
+LegalizerHelper::lowerVectorReduction(MachineInstr &MI) {
+  Register DstReg = MI.getOperand(0).getReg();
+  Register SrcReg = MI.getOperand(1).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+  LLT SrcTy = MRI.getType(SrcReg);
+
+  // The source could be a scalar if the IR type was <1 x sN>.
+  if (SrcTy.isScalar()) {
+    // The MachineVerifier rejects a COPY between generic vregs of different
+    // sizes, so bail out on any width mismatch instead of emitting one.
+    if (DstTy.getSizeInBits() != SrcTy.getSizeInBits())
+      return UnableToLegalize; // FIXME: handle extension/truncation.
+    // This can be just a plain copy.
+    Observer.changingInstr(MI);
+    MI.setDesc(MIRBuilder.getTII().get(TargetOpcode::COPY));
+    Observer.changedInstr(MI);
+    return Legalized;
+  }
+  return UnableToLegalize;
+}