[GlobalISel] Implement fewerElements legalization for vector reductions.
This patch adds 3 methods, one of which handles power-of-2 vectors by
performing a tree reduction with vector ops before a final reduction op. For
non-pow-2 types it generates multiple narrow reductions and combines the
resulting values with scalar ops.
Differential Revision: https://reviews.llvm.org/D97163
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 7680f61..9eb4c80 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -17,6 +17,7 @@
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
+#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetFrameLowering.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
@@ -4207,11 +4208,139 @@
return reduceLoadStoreWidth(MI, TypeIdx, NarrowTy);
case G_SEXT_INREG:
return fewerElementsVectorSextInReg(MI, TypeIdx, NarrowTy);
+ GISEL_VECREDUCE_CASES_NONSEQ
+ return fewerElementsVectorReductions(MI, TypeIdx, NarrowTy);
default:
return UnableToLegalize;
}
}
+// Legalize a non-sequential G_VECREDUCE_* by splitting its vector source
+// (type index 1) into NarrowTy-sized pieces.
+//
+// When both the source and NarrowTy element counts are powers of 2, this
+// defers to tryNarrowPow2Reduction, which combines the pieces with vector ops
+// and keeps a single NarrowTy reduction. Otherwise each piece is reduced
+// separately and the partial results are folded, left to right, with the
+// matching binary opcode into the original destination register.
+//
+// Returns UnableToLegalize, without building any instructions, when TypeIdx
+// is not 1, NarrowTy is not a vector, NarrowTy does not evenly divide the
+// source, the split would yield fewer than two pieces, or the opcode is not a
+// known reduction.
+LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
+    MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
+  unsigned Opc = MI.getOpcode();
+  assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
+         Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
+         "Sequential reductions not expected");
+
+  if (TypeIdx != 1)
+    return UnableToLegalize;
+
+  // The semantics of the normal non-sequential reductions allow us to freely
+  // re-associate the operation.
+  Register SrcReg = MI.getOperand(1).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+
+  // LLT::getNumElements() is only valid on vector types, so a scalar NarrowTy
+  // (a request to fully scalarize) is not handled here.
+  if (!NarrowTy.isVector())
+    return UnableToLegalize;
+
+  if (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0)
+    return UnableToLegalize;
+
+  const unsigned NumParts = SrcTy.getNumElements() / NarrowTy.getNumElements();
+  // With a single part there is nothing to split. The non-pow-2 path below
+  // would erase MI without ever defining DstReg, and the pow-2 path would only
+  // rewrite MI's source operand, making no progress.
+  if (NumParts < 2)
+    return UnableToLegalize;
+
+  // Map the reduction to the binary opcode used to combine two partial
+  // results. This is resolved before any instruction is built, so an unknown
+  // opcode leaves the function untouched.
+  unsigned ScalarOpc;
+  switch (Opc) {
+  case TargetOpcode::G_VECREDUCE_FADD:
+    ScalarOpc = TargetOpcode::G_FADD;
+    break;
+  case TargetOpcode::G_VECREDUCE_FMUL:
+    ScalarOpc = TargetOpcode::G_FMUL;
+    break;
+  case TargetOpcode::G_VECREDUCE_FMAX:
+    ScalarOpc = TargetOpcode::G_FMAXNUM;
+    break;
+  case TargetOpcode::G_VECREDUCE_FMIN:
+    ScalarOpc = TargetOpcode::G_FMINNUM;
+    break;
+  case TargetOpcode::G_VECREDUCE_ADD:
+    ScalarOpc = TargetOpcode::G_ADD;
+    break;
+  case TargetOpcode::G_VECREDUCE_MUL:
+    ScalarOpc = TargetOpcode::G_MUL;
+    break;
+  case TargetOpcode::G_VECREDUCE_AND:
+    ScalarOpc = TargetOpcode::G_AND;
+    break;
+  case TargetOpcode::G_VECREDUCE_OR:
+    ScalarOpc = TargetOpcode::G_OR;
+    break;
+  case TargetOpcode::G_VECREDUCE_XOR:
+    ScalarOpc = TargetOpcode::G_XOR;
+    break;
+  case TargetOpcode::G_VECREDUCE_SMAX:
+    ScalarOpc = TargetOpcode::G_SMAX;
+    break;
+  case TargetOpcode::G_VECREDUCE_SMIN:
+    ScalarOpc = TargetOpcode::G_SMIN;
+    break;
+  case TargetOpcode::G_VECREDUCE_UMAX:
+    ScalarOpc = TargetOpcode::G_UMAX;
+    break;
+  case TargetOpcode::G_VECREDUCE_UMIN:
+    ScalarOpc = TargetOpcode::G_UMIN;
+    break;
+  default:
+    LLVM_DEBUG(dbgs() << "Can't legalize: unknown reduction kind.\n");
+    return UnableToLegalize;
+  }
+
+  // If the types involved are powers of 2, we can generate intermediate vector
+  // ops, before generating a final reduction operation. That path does its
+  // own split, so nothing has been built for it yet.
+  if (isPowerOf2_32(SrcTy.getNumElements()) &&
+      isPowerOf2_32(NarrowTy.getNumElements()))
+    return tryNarrowPow2Reduction(MI, SrcReg, SrcTy, NarrowTy, ScalarOpc);
+
+  // Otherwise reduce every NarrowTy piece to a DstTy value...
+  SmallVector<Register> SplitSrcs;
+  extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
+  SmallVector<Register> PartialReductions;
+  for (Register Piece : SplitSrcs)
+    PartialReductions.push_back(
+        MIRBuilder.buildInstr(Opc, {DstTy}, {Piece}).getReg(0));
+
+  // ...and fold those left to right. The last combine defines DstReg directly
+  // so that MI's users keep seeing the same register.
+  Register Acc = PartialReductions[0];
+  for (unsigned Part = 1; Part < NumParts; ++Part) {
+    if (Part == NumParts - 1) {
+      MIRBuilder.buildInstr(ScalarOpc, {DstReg},
+                            {Acc, PartialReductions[Part]});
+    } else {
+      Acc = MIRBuilder
+                .buildInstr(ScalarOpc, {DstTy}, {Acc, PartialReductions[Part]})
+                .getReg(0);
+    }
+  }
+  MI.eraseFromParent();
+  return Legalized;
+}
+
+// Tree-reduce SrcReg down to a single NarrowTy value, then make MI (the
+// original reduction) consume that value instead of SrcReg.
+//
+// SrcReg is split into SrcTy.getNumElements() / NarrowTy.getNumElements()
+// NarrowTy pieces. Each round combines adjacent pairs (0,1), (2,3), ... with
+// ScalarOpc built at type NarrowTy, halving the number of pieces. MI itself is
+// kept; only its source operand is replaced, so its destination is unchanged.
+LegalizerHelper::LegalizeResult
+LegalizerHelper::tryNarrowPow2Reduction(MachineInstr &MI, Register SrcReg,
+                                        LLT SrcTy, LLT NarrowTy,
+                                        unsigned ScalarOpc) {
+  const unsigned NumPieces =
+      SrcTy.getNumElements() / NarrowTy.getNumElements();
+  SmallVector<Register> Pieces;
+  extractParts(SrcReg, NarrowTy, NumPieces, Pieces);
+
+  // Combine pairs in place: slot I receives Pieces[2*I] op Pieces[2*I+1].
+  // Both reads are at indices >= I, so no piece is overwritten before it is
+  // consumed. An unpaired trailing piece would be dropped; the caller only
+  // gets here with power-of-2 element counts, where the count is always even.
+  while (Pieces.size() > 1) {
+    const size_t NumPairs = Pieces.size() / 2;
+    for (size_t I = 0; I != NumPairs; ++I)
+      Pieces[I] = MIRBuilder
+                      .buildInstr(ScalarOpc, {NarrowTy},
+                                  {Pieces[2 * I], Pieces[2 * I + 1]})
+                      .getReg(0);
+    Pieces.resize(NumPairs);
+  }
+
+  // Hand the single remaining NarrowTy value to the original reduction.
+  Observer.changingInstr(MI);
+  MI.getOperand(1).setReg(Pieces.front());
+  Observer.changedInstr(MI);
+  return Legalized;
+}
+
LegalizerHelper::LegalizeResult
LegalizerHelper::narrowScalarShiftByConstant(MachineInstr &MI, const APInt &Amt,
const LLT HalfTy, const LLT AmtTy) {