[AArch64][GlobalISel] Add support for lowering trunc stores of vector bools.
This is essentially a port of TargetLowering::scalarizeVectorStore(), which
is used for the case where we have something like a store of <8 x s8> truncating
to <8 x s1> in memory. The naive lowering is a sequence of extracts to compute
a scalar value to store.
AArch64's DAG implementation has further optimizations for this case, which
we can port later.
Reviewers: topperc, davemgreen
Pull Request: https://github.com/llvm/llvm-project/pull/121169
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 7dece93..0bfa897 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -4143,9 +4143,8 @@
}
if (MemTy.isVector()) {
- // TODO: Handle vector trunc stores
if (MemTy != SrcTy)
- return UnableToLegalize;
+ return scalarizeVectorBooleanStore(StoreMI);
// TODO: We can do better than scalarizing the vector and at least split it
// in half.
@@ -4201,6 +4200,50 @@
}
+// Lowers a vector store whose memory type has non-byte-sized elements (e.g. a
+// store of <8 x s8> truncating to <8 x s1>) by packing the truncated elements
+// into one integer as wide as the whole memory type and storing that integer.
+// Returns Legalized after replacing StoreMI with the new store, or
+// UnableToLegalize when the memory element type is byte-sized.
LegalizerHelper::LegalizeResult
+LegalizerHelper::scalarizeVectorBooleanStore(GStore &StoreMI) {
+ Register SrcReg = StoreMI.getValueReg();
+ Register PtrReg = StoreMI.getPointerReg();
+ LLT SrcTy = MRI.getType(SrcReg);
+ // The memory type is read from the store's first memory operand.
+ MachineMemOperand &MMO = **StoreMI.memoperands_begin();
+ LLT MemTy = MMO.getMemoryType();
+ LLT MemScalarTy = MemTy.getElementType();
+ MachineFunction &MF = MIRBuilder.getMF();
+
+ assert(SrcTy.isVector() && "Expect a vector store type");
+
+ if (!MemScalarTy.isByteSized()) {
+ // We need to build an integer scalar of the vector bit pattern.
+ // It's not legal for us to add padding when storing a vector.
+ unsigned NumBits = MemTy.getSizeInBits();
+ LLT IntTy = LLT::scalar(NumBits);
+ // Accumulator for the packed bits; starts at zero and has each element
+ // OR'd into it by the loop below.
+ auto CurrVal = MIRBuilder.buildConstant(IntTy, 0);
+ // Type of the constant element indices handed to the vector extracts.
+ LLT IdxTy = getLLTForMVT(TLI.getVectorIdxTy(MF.getDataLayout()));
+
+ for (unsigned I = 0, E = MemTy.getNumElements(); I < E; ++I) {
+ // Extract element I, narrow it to the memory element type, then
+ // zero-extend it to the full integer width so it can be shifted into place.
+ auto Elt = MIRBuilder.buildExtractVectorElement(
+ SrcTy.getElementType(), SrcReg, MIRBuilder.buildConstant(IdxTy, I));
+ auto Trunc = MIRBuilder.buildTrunc(MemScalarTy, Elt);
+ auto ZExt = MIRBuilder.buildZExt(IntTy, Trunc);
+ // For a big-endian data layout the slot order is reversed, so element 0
+ // lands in the highest slot; otherwise element I goes into slot I.
+ unsigned ShiftIntoIdx = MF.getDataLayout().isBigEndian()
+ ? (MemTy.getNumElements() - 1) - I
+ : I;
+ // Each slot is one memory element wide.
+ auto ShiftAmt = MIRBuilder.buildConstant(
+ IntTy, ShiftIntoIdx * MemScalarTy.getSizeInBits());
+ auto Shifted = MIRBuilder.buildShl(IntTy, ZExt, ShiftAmt);
+ CurrVal = MIRBuilder.buildOr(IntTy, CurrVal, Shifted);
+ }
+ // Derive the new memory operand from the original one, keeping its pointer
+ // info but using the packed integer as the memory type.
+ auto PtrInfo = MMO.getPointerInfo();
+ auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, IntTy);
+ MIRBuilder.buildStore(CurrVal, PtrReg, *NewMMO);
+ // The packed integer store replaces the original vector store.
+ StoreMI.eraseFromParent();
+ return Legalized;
+ }
+
+ // Memory types with byte-sized elements are not handled here yet.
+ // TODO: implement simple scalarization.
+ return UnableToLegalize;
+}
+
+LegalizerHelper::LegalizeResult
LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
switch (MI.getOpcode()) {
case TargetOpcode::G_LOAD: {