GlobalISel: Handle odd splits in fewerElementsVector for load/store

llvm-svn: 352720
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 08ea4d9..9e61498 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -86,6 +86,91 @@
   MIRBuilder.buildUnmerge(VRegs, Reg);
 }
 
+bool LegalizerHelper::extractParts(unsigned Reg, LLT RegTy,
+                                   LLT MainTy, LLT &LeftoverTy,
+                                   SmallVectorImpl<unsigned> &VRegs,
+                                   SmallVectorImpl<unsigned> &LeftoverRegs) {
+  assert(!LeftoverTy.isValid() && "this is an out argument");
+  // Split Reg (typed RegTy) into MainTy pieces plus a smaller leftover piece.
+  unsigned RegSize = RegTy.getSizeInBits();
+  unsigned MainSize = MainTy.getSizeInBits();
+  unsigned NumParts = RegSize / MainSize; // whole MainTy pieces that fit
+  unsigned LeftoverSize = RegSize - NumParts * MainSize; // bits left over
+
+  // Even split: one unmerge defines every part; LeftoverTy is left invalid.
+  if (LeftoverSize == 0) {
+    for (unsigned I = 0; I < NumParts; ++I)
+      VRegs.push_back(MRI.createGenericVirtualRegister(MainTy));
+    MIRBuilder.buildUnmerge(VRegs, Reg);
+    return true;
+  }
+  // Pick the leftover type; a vector split needs a whole number of elements.
+  if (MainTy.isVector()) {
+    unsigned EltSize = MainTy.getScalarSizeInBits();
+    if (LeftoverSize % EltSize != 0)
+      return false; // nothing has been created or emitted yet
+    LeftoverTy = LLT::scalarOrVector(LeftoverSize / EltSize, EltSize);
+  } else {
+    LeftoverTy = LLT::scalar(LeftoverSize);
+  }
+
+  // For irregular sizes, extract each MainTy part at its own bit offset.
+  for (unsigned I = 0; I != NumParts; ++I) {
+    unsigned NewReg = MRI.createGenericVirtualRegister(MainTy);
+    VRegs.push_back(NewReg);
+    MIRBuilder.buildExtract(NewReg, Reg, MainSize * I);
+  }
+  // The tail starts after the last full part; this loop body runs exactly once.
+  for (unsigned Offset = MainSize * NumParts; Offset < RegSize;
+       Offset += LeftoverSize) {
+    unsigned NewReg = MRI.createGenericVirtualRegister(LeftoverTy);
+    LeftoverRegs.push_back(NewReg);
+    MIRBuilder.buildExtract(NewReg, Reg, Offset);
+  }
+
+  return true;
+}
+
+void LegalizerHelper::insertParts(unsigned DstReg,
+                                  LLT ResultTy, LLT PartTy,
+                                  ArrayRef<unsigned> PartRegs,
+                                  LLT LeftoverTy,
+                                  ArrayRef<unsigned> LeftoverRegs) {
+  if (!LeftoverTy.isValid()) {
+    assert(LeftoverRegs.empty());
+    // No leftover: one concat_vectors/build_vector defines DstReg.
+    if (PartTy.isVector())
+      MIRBuilder.buildConcatVectors(DstReg, PartRegs);
+    else
+      MIRBuilder.buildBuildVector(DstReg, PartRegs);
+    return;
+  }
+  // Otherwise assemble DstReg from parts and leftovers via chained inserts.
+  unsigned PartSize = PartTy.getSizeInBits();
+  unsigned LeftoverPartSize = LeftoverTy.getSizeInBits();
+
+  unsigned CurResultReg = MRI.createGenericVirtualRegister(ResultTy);
+  MIRBuilder.buildUndef(CurResultReg); // initial value the inserts build on
+
+  unsigned Offset = 0; // bit position of the next piece within the result
+  for (unsigned PartReg : PartRegs) {
+    unsigned NewResultReg = MRI.createGenericVirtualRegister(ResultTy);
+    MIRBuilder.buildInsert(NewResultReg, CurResultReg, PartReg, Offset);
+    CurResultReg = NewResultReg;
+    Offset += PartSize;
+  }
+  // NOTE(review): if LeftoverRegs is empty here, DstReg is never defined.
+  for (unsigned I = 0, E = LeftoverRegs.size(); I != E; ++I) {
+    // Use the original output register for the final insert to avoid a copy.
+    unsigned NewResultReg = (I + 1 == E) ?
+      DstReg : MRI.createGenericVirtualRegister(ResultTy);
+
+    MIRBuilder.buildInsert(NewResultReg, CurResultReg, LeftoverRegs[I], Offset);
+    CurResultReg = NewResultReg;
+    Offset += LeftoverPartSize;
+  }
+}
+
 static RTLIB::Libcall getRTLibDesc(unsigned Opcode, unsigned Size) {
   switch (Opcode) {
   case TargetOpcode::G_SDIV:
@@ -1810,6 +1895,36 @@
   return Legalized;
 }
 
+/// Try to break down \p OrigTy into \p NarrowTy sized pieces.
+///
+/// Returns the number of whole \p NarrowTy pieces that fit in \p OrigTy. Any
+/// remaining bits are described by \p LeftoverTy, which stays invalid if none.
+///
+/// Returns -1 if a vector \p NarrowTy leaves a partial element left over.
+static int getNarrowTypeBreakDown(LLT OrigTy, LLT NarrowTy, LLT &LeftoverTy) {
+  assert(!LeftoverTy.isValid() && "this is an out argument");
+
+  unsigned Size = OrigTy.getSizeInBits();
+  unsigned NarrowSize = NarrowTy.getSizeInBits();
+  unsigned NumParts = Size / NarrowSize;
+  unsigned LeftoverSize = Size - NumParts * NarrowSize; // bits left over
+  assert(Size > NarrowSize); // hence NumParts >= 1
+
+  if (LeftoverSize == 0)
+    return NumParts; // even split, LeftoverTy is not set
+  // A vector NarrowTy requires the leftover to be a whole number of elements.
+  if (NarrowTy.isVector()) {
+    unsigned EltSize = OrigTy.getScalarSizeInBits(); // element size of OrigTy
+    if (LeftoverSize % EltSize != 0)
+      return -1;
+    LeftoverTy = LLT::scalarOrVector(LeftoverSize / EltSize, EltSize);
+  } else {
+    LeftoverTy = LLT::scalar(LeftoverSize);
+  }
+
+  return NumParts;
+}
+
 LegalizerHelper::LegalizeResult
 LegalizerHelper::fewerElementsVectorLoadStore(MachineInstr &MI, unsigned TypeIdx,
                                               LLT NarrowTy) {
@@ -1828,40 +1943,68 @@
   bool IsLoad = MI.getOpcode() == TargetOpcode::G_LOAD;
   unsigned ValReg = MI.getOperand(0).getReg();
   unsigned AddrReg = MI.getOperand(1).getReg();
-  unsigned NarrowSize = NarrowTy.getSizeInBits();
-  unsigned Size = MRI.getType(ValReg).getSizeInBits();
-  unsigned NumParts = Size / NarrowSize;
+  LLT ValTy = MRI.getType(ValReg);
 
-  SmallVector<unsigned, 8> NarrowRegs;
-  if (!IsLoad)
-    extractParts(ValReg, NarrowTy, NumParts, NarrowRegs);
-
-  const LLT OffsetTy =
-    LLT::scalar(MRI.getType(AddrReg).getScalarSizeInBits());
-  MachineFunction &MF = *MI.getMF();
-
-  for (unsigned Idx = 0; Idx < NumParts; ++Idx) {
-    unsigned Adjustment = Idx * NarrowTy.getSizeInBits() / 8;
-    unsigned Alignment = MinAlign(MMO->getAlignment(), Adjustment);
-    unsigned NewAddrReg = 0;
-    MIRBuilder.materializeGEP(NewAddrReg, AddrReg, OffsetTy, Adjustment);
-    MachineMemOperand &NewMMO = *MF.getMachineMemOperand(
-      MMO->getPointerInfo().getWithOffset(Adjustment), MMO->getFlags(),
-      NarrowTy.getSizeInBits() / 8, Alignment);
-    if (IsLoad) {
-      unsigned Dst = MRI.createGenericVirtualRegister(NarrowTy);
-      NarrowRegs.push_back(Dst);
-      MIRBuilder.buildLoad(Dst, NewAddrReg, NewMMO);
-    } else {
-      MIRBuilder.buildStore(NarrowRegs[Idx], NewAddrReg, NewMMO);
-    }
-  }
+  int NumParts = -1;
+  LLT LeftoverTy;
+  SmallVector<unsigned, 8> NarrowRegs, NarrowLeftoverRegs;
   if (IsLoad) {
-    if (NarrowTy.isVector())
-      MIRBuilder.buildConcatVectors(ValReg, NarrowRegs);
-    else
-      MIRBuilder.buildBuildVector(ValReg, NarrowRegs);
+    NumParts = getNarrowTypeBreakDown(ValTy, NarrowTy, LeftoverTy);
+  } else {
+    if (extractParts(ValReg, ValTy, NarrowTy, LeftoverTy, NarrowRegs,
+                     NarrowLeftoverRegs))
+      NumParts = NarrowRegs.size();
   }
+
+  if (NumParts == -1)
+    return UnableToLegalize;
+
+  const LLT OffsetTy = LLT::scalar(MRI.getType(AddrReg).getScalarSizeInBits());
+
+  unsigned TotalSize = ValTy.getSizeInBits();
+
+  // Split the load/store into PartTy sized pieces starting at Offset. If this
+  // is a load, return the new registers in ValRegs. For a store, each element
+  // of ValRegs should be PartTy. Returns the next offset that needs to be
+  // handled.
+  auto splitTypePieces = [=](LLT PartTy, SmallVectorImpl<unsigned> &ValRegs,
+                             unsigned Offset) -> unsigned {
+    MachineFunction &MF = MIRBuilder.getMF();
+    unsigned PartSize = PartTy.getSizeInBits();
+    for (unsigned Idx = 0, E = NumParts; Idx != E && Offset < TotalSize;
+         Offset += PartSize, ++Idx) {
+      unsigned ByteSize = PartSize / 8;
+      unsigned ByteOffset = Offset / 8;
+      unsigned NewAddrReg = 0;
+
+      MIRBuilder.materializeGEP(NewAddrReg, AddrReg, OffsetTy, ByteOffset);
+
+      MachineMemOperand *NewMMO =
+        MF.getMachineMemOperand(MMO, ByteOffset, ByteSize);
+
+      if (IsLoad) {
+        unsigned Dst = MRI.createGenericVirtualRegister(PartTy);
+        ValRegs.push_back(Dst);
+        MIRBuilder.buildLoad(Dst, NewAddrReg, *NewMMO);
+      } else {
+        MIRBuilder.buildStore(ValRegs[Idx], NewAddrReg, *NewMMO);
+      }
+    }
+
+    return Offset;
+  };
+
+  unsigned HandledOffset = splitTypePieces(NarrowTy, NarrowRegs, 0);
+
+  // Handle the rest of the register if this isn't an even type breakdown.
+  if (LeftoverTy.isValid())
+    splitTypePieces(LeftoverTy, NarrowLeftoverRegs, HandledOffset);
+
+  if (IsLoad) {
+    insertParts(ValReg, ValTy, NarrowTy, NarrowRegs,
+                LeftoverTy, NarrowLeftoverRegs);
+  }
+
   MI.eraseFromParent();
   return Legalized;
 }