[AArch64][GlobalISel] Add support for lowering trunc stores of vector bools.

This is essentially a port of TargetLowering::scalarizeVectorStore(), which
handles cases such as a store of an <8 x s8> value truncated to <8 x s1> in
memory. The naive lowering extracts each element, truncates and zero-extends
it, then shifts and ORs it into a single scalar integer, which is stored.

AArch64's SelectionDAG implementation has additional optimizations for this
pattern; those can be ported later.

Reviewers: topperc, davemgreen

Pull Request: https://github.com/llvm/llvm-project/pull/121169
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 7dece93..0bfa897 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -4143,9 +4143,8 @@
   }
 
   if (MemTy.isVector()) {
-    // TODO: Handle vector trunc stores
     if (MemTy != SrcTy)
-      return UnableToLegalize;
+      return scalarizeVectorBooleanStore(StoreMI);
 
     // TODO: We can do better than scalarizing the vector and at least split it
     // in half.
@@ -4201,6 +4200,66 @@
 }
 
+/// Lower a G_STORE whose vector memory type differs from the stored value's
+/// type (a truncating vector store), e.g. <8 x s8> stored as <8 x s1>.
+///
+/// Only non-byte-sized memory element types are handled: the truncated
+/// elements are packed into one integer of the memory type's total width,
+/// because padding may not be added when storing a vector. Byte-sized
+/// memory element types still return UnableToLegalize.
 LegalizerHelper::LegalizeResult
+LegalizerHelper::scalarizeVectorBooleanStore(GStore &StoreMI) {
+  Register SrcReg = StoreMI.getValueReg();
+  Register PtrReg = StoreMI.getPointerReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  MachineMemOperand &MMO = **StoreMI.memoperands_begin();
+  LLT MemTy = MMO.getMemoryType();
+  MachineFunction &MF = MIRBuilder.getMF();
+
+  assert(SrcTy.isVector() && "Expect a vector store type");
+  assert(MemTy.isVector() && "Expect a vector memory type");
+  assert(SrcTy.getElementCount() == MemTy.getElementCount() &&
+         "Stored value and memory type must have the same element count");
+  LLT MemScalarTy = MemTy.getElementType();
+
+  if (!MemScalarTy.isByteSized()) {
+    // We need to build an integer scalar of the vector bit pattern.
+    // It's not legal for us to add padding when storing a vector.
+    unsigned NumBits = MemTy.getSizeInBits();
+    LLT IntTy = LLT::scalar(NumBits);
+    auto CurrVal = MIRBuilder.buildConstant(IntTy, 0);
+    LLT IdxTy = getLLTForMVT(TLI.getVectorIdxTy(MF.getDataLayout()));
+    // Loop-invariant packing parameters.
+    const unsigned NumElts = MemTy.getNumElements();
+    const unsigned EltBits = MemScalarTy.getSizeInBits();
+    const bool IsBigEndian = MF.getDataLayout().isBigEndian();
+
+    for (unsigned I = 0; I < NumElts; ++I) {
+      // Truncate element I to the memory element width and zero-extend it, so
+      // bits above the low EltBits are zero when it is shifted into place.
+      auto Elt = MIRBuilder.buildExtractVectorElement(
+          SrcTy.getElementType(), SrcReg, MIRBuilder.buildConstant(IdxTy, I));
+      auto Trunc = MIRBuilder.buildTrunc(MemScalarTy, Elt);
+      auto ZExt = MIRBuilder.buildZExt(IntTy, Trunc);
+      // On big-endian targets element 0 goes in the most significant bits.
+      unsigned ShiftIntoIdx = IsBigEndian ? (NumElts - 1) - I : I;
+      auto ShiftAmt = MIRBuilder.buildConstant(IntTy, ShiftIntoIdx * EltBits);
+      auto Shifted = MIRBuilder.buildShl(IntTy, ZExt, ShiftAmt);
+      CurrVal = MIRBuilder.buildOr(IntTy, CurrVal, Shifted);
+    }
+    // Store the packed integer through a copy of the original memory
+    // operand, retyped to IntTy.
+    auto PtrInfo = MMO.getPointerInfo();
+    auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, IntTy);
+    MIRBuilder.buildStore(CurrVal, PtrReg, *NewMMO);
+    StoreMI.eraseFromParent();
+    return Legalized;
+  }
+
+  // TODO: implement simple scalarization for byte-sized memory elements.
+  return UnableToLegalize;
+}
+
+LegalizerHelper::LegalizeResult
 LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) {
   switch (MI.getOpcode()) {
   case TargetOpcode::G_LOAD: {