[GlobalISel] Implement splitting of G_SHUFFLE_VECTOR.
This is a port from the DAG legalization. We're still missing some of the
canonicalizations of shuffles but it's a start.
Differential Revision: https://reviews.llvm.org/D102828
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index e07d4d8..fd20b50 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -23,7 +23,9 @@
#include "llvm/CodeGen/TargetFrameLowering.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
@@ -4244,11 +4246,154 @@
return fewerElementsVectorSextInReg(MI, TypeIdx, NarrowTy);
GISEL_VECREDUCE_CASES_NONSEQ
return fewerElementsVectorReductions(MI, TypeIdx, NarrowTy);
+ case G_SHUFFLE_VECTOR:
+ return fewerElementsVectorShuffle(MI, TypeIdx, NarrowTy);
default:
return UnableToLegalize;
}
}
+LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorShuffle(
+ MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
+ assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
+ if (TypeIdx != 0)
+ return UnableToLegalize;
+
+ Register DstReg = MI.getOperand(0).getReg();
+ Register Src1Reg = MI.getOperand(1).getReg();
+ Register Src2Reg = MI.getOperand(2).getReg();
+ ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
+ LLT DstTy = MRI.getType(DstReg);
+ LLT Src1Ty = MRI.getType(Src1Reg);
+ LLT Src2Ty = MRI.getType(Src2Reg);
+ // The shuffle should be canonicalized by now: sources match the result type.
+ if (DstTy != Src1Ty)
+ return UnableToLegalize;
+ if (DstTy != Src2Ty)
+ return UnableToLegalize;
+
+ if (!isPowerOf2_32(DstTy.getNumElements()))
+ return UnableToLegalize;
+
+ // We only support splitting a shuffle into 2, so the requested NarrowTy is
+ // overridden with half of DstTy. Further legalization may split it again.
+ NarrowTy = DstTy.changeNumElements(DstTy.getNumElements() / 2);
+ unsigned NewElts = NarrowTy.getNumElements();
+
+ SmallVector<Register> SplitSrc1Regs, SplitSrc2Regs;
+ extractParts(Src1Reg, NarrowTy, 2, SplitSrc1Regs);
+ extractParts(Src2Reg, NarrowTy, 2, SplitSrc2Regs);
+ Register Inputs[4] = {SplitSrc1Regs[0], SplitSrc1Regs[1], SplitSrc2Regs[0],
+ SplitSrc2Regs[1]};
+
+ Register Hi, Lo;
+
+ // If Lo or Hi uses elements from at most two of the four input vectors, then
+ // express it as a vector shuffle of those two inputs. Otherwise extract the
+ // input elements by hand and construct the Lo/Hi output with G_BUILD_VECTOR.
+ SmallVector<int, 16> Ops;
+ for (unsigned High = 0; High < 2; ++High) {
+ Register &Output = High ? Hi : Lo;
+
+ // Build a shuffle mask for the output, discovering on the fly which
+ // input vectors to use as shuffle operands (recorded in InputUsed).
+ // If building a suitable shuffle vector proves too hard, then bail
+ // out with UseBuildVector set.
+ unsigned InputUsed[2] = {-1U, -1U}; // Not yet discovered.
+ unsigned FirstMaskIdx = High * NewElts;
+ bool UseBuildVector = false;
+ for (unsigned MaskOffset = 0; MaskOffset < NewElts; ++MaskOffset) {
+ // The mask element. This indexes into the four split inputs in order.
+ int Idx = Mask[FirstMaskIdx + MaskOffset];
+
+ // The input vector this mask element indexes into (huge if Idx < 0).
+ unsigned Input = (unsigned)Idx / NewElts;
+
+ if (Input >= array_lengthof(Inputs)) {
+ // The mask element does not index into any input vector; keep it undef.
+ Ops.push_back(-1);
+ continue;
+ }
+
+ // Turn the index into an offset from the start of the input vector.
+ Idx -= Input * NewElts;
+
+ // Find or create a shuffle vector operand to hold this input.
+ unsigned OpNo;
+ for (OpNo = 0; OpNo < array_lengthof(InputUsed); ++OpNo) {
+ if (InputUsed[OpNo] == Input) {
+ // This input vector is already an operand.
+ break;
+ } else if (InputUsed[OpNo] == -1U) {
+ // Create a new operand for this input vector.
+ InputUsed[OpNo] = Input;
+ break;
+ }
+ }
+
+ if (OpNo >= array_lengthof(InputUsed)) {
+ // More than two input vectors used! Give up on trying to create a
+ // shuffle vector. Insert all elements into a G_BUILD_VECTOR instead.
+ UseBuildVector = true;
+ break;
+ }
+
+ // Add the mask index for the new shuffle vector.
+ Ops.push_back(Idx + OpNo * NewElts);
+ }
+
+ if (UseBuildVector) {
+ LLT EltTy = NarrowTy.getElementType();
+ SmallVector<Register, 16> SVOps;
+
+ // Extract the input elements by hand.
+ for (unsigned MaskOffset = 0; MaskOffset < NewElts; ++MaskOffset) {
+ // The mask element. This indexes into the four split inputs in order.
+ int Idx = Mask[FirstMaskIdx + MaskOffset];
+
+ // The input vector this mask element indexes into (huge if Idx < 0).
+ unsigned Input = (unsigned)Idx / NewElts;
+
+ if (Input >= array_lengthof(Inputs)) {
+ // The mask element is "undef" or indexes off the end of the input.
+ SVOps.push_back(MIRBuilder.buildUndef(EltTy).getReg(0));
+ continue;
+ }
+
+ // Turn the index into an offset from the start of the input vector.
+ Idx -= Input * NewElts;
+
+ // Extract the vector element by hand, using a 32-bit constant index.
+ SVOps.push_back(MIRBuilder
+ .buildExtractVectorElement(
+ EltTy, Inputs[Input],
+ MIRBuilder.buildConstant(LLT::scalar(32), Idx))
+ .getReg(0));
+ }
+
+ // Construct the Lo/Hi output using a G_BUILD_VECTOR.
+ Output = MIRBuilder.buildBuildVector(NarrowTy, SVOps).getReg(0);
+ } else if (InputUsed[0] == -1U) {
+ // No input vectors were used! The result is undefined.
+ Output = MIRBuilder.buildUndef(NarrowTy).getReg(0);
+ } else {
+ Register Op0 = Inputs[InputUsed[0]];
+ // If only one input was used, use an undefined vector for the other.
+ Register Op1 = InputUsed[1] == -1U
+ ? MIRBuilder.buildUndef(NarrowTy).getReg(0)
+ : Inputs[InputUsed[1]];
+ // At least one input vector was used. Create a new shuffle vector.
+ Output = MIRBuilder.buildShuffleVector(NarrowTy, Op0, Op1, Ops).getReg(0);
+ }
+
+ Ops.clear(); // May hold a partial mask if we bailed out; reset for Hi.
+ }
+
+ MIRBuilder.buildConcatVectors(DstReg, {Lo, Hi}); // Lo half comes first.
+ MI.eraseFromParent();
+ return Legalized;
+}
+
LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
unsigned Opc = MI.getOpcode();