GlobalISel: Implement fewerElementsVector for shifts

Introduce a new function which handles instructions that have multiple type
indices but the same number of vector elements in each.

Also legalize v2s16 shifts when applicable.

llvm-svn: 353432
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index cd2d8c1..a6ca855 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -29,6 +29,36 @@
 using namespace llvm;
 using namespace LegalizeActions;
 
+/// Try to break down \p OrigTy into \p NarrowTy sized pieces.
+///
+/// Returns the number of \p NarrowTy pieces needed to reconstruct \p OrigTy.
+/// Any leftover piece is reported through \p LeftoverTy, which stays invalid
+/// when \p OrigTy divides evenly. Returns -1 if a vector \p NarrowTy would
+/// leave a remainder that is not a whole number of \p OrigTy elements.
+static int getNarrowTypeBreakDown(LLT OrigTy, LLT NarrowTy, LLT &LeftoverTy) {
+  assert(!LeftoverTy.isValid() && "this is an out argument");
+
+  unsigned Size = OrigTy.getSizeInBits();
+  unsigned NarrowSize = NarrowTy.getSizeInBits();
+  unsigned NumParts = Size / NarrowSize;
+  unsigned LeftoverSize = Size - NumParts * NarrowSize;
+  assert(Size > NarrowSize);
+
+  if (LeftoverSize == 0)
+    return NumParts;
+
+  if (NarrowTy.isVector()) {
+    unsigned EltSize = OrigTy.getScalarSizeInBits();
+    if (LeftoverSize % EltSize != 0)
+      return -1;
+    LeftoverTy = LLT::scalarOrVector(LeftoverSize / EltSize, EltSize);
+  } else {
+    LeftoverTy = LLT::scalar(LeftoverSize);
+  }
+
+  return NumParts;
+}
+
 LegalizerHelper::LegalizerHelper(MachineFunction &MF,
                                  GISelChangeObserver &Observer,
                                  MachineIRBuilder &Builder)
@@ -1728,6 +1758,102 @@
   return Legalized;
 }
 
+// Handle splitting vector operations which need to have the same number of
+// elements in each type index, but each type index may have a different element
+// type.
+//
+// e.g.  <4 x s64> = G_SHL <4 x s64>, <4 x s32> ->
+//       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
+//       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
+//
+// Also handles some irregular breakdown cases, e.g.
+//       <3 x s64> = G_SHL <3 x s64>, <3 x s32> ->
+//       <2 x s64> = G_SHL <2 x s64>, <2 x s32>
+//             s64 = G_SHL s64, s32
+LegalizerHelper::LegalizeResult
+LegalizerHelper::fewerElementsVectorMultiEltType(
+  MachineInstr &MI, unsigned TypeIdx, LLT NarrowTyArg) {
+  if (TypeIdx != 0)
+    return UnableToLegalize;
+
+  const LLT NarrowTy0 = NarrowTyArg;
+  const unsigned NewNumElts =
+      NarrowTy0.isVector() ? NarrowTy0.getNumElements() : 1;
+
+  const unsigned DstReg = MI.getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+  LLT LeftoverTy0;
+
+  // All of the operands need to have the same number of elements, so if we can
+  // determine a type breakdown for the result type, we can for all of the
+  // source types.
+  int NumParts = getNarrowTypeBreakDown(DstTy, NarrowTy0, LeftoverTy0);
+  if (NumParts < 0)
+    return UnableToLegalize;
+
+  SmallVector<MachineInstrBuilder, 4> NewInsts;
+
+  SmallVector<unsigned, 4> DstRegs, LeftoverDstRegs;
+  SmallVector<unsigned, 4> PartRegs, LeftoverRegs;
+
+  for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I) {
+    unsigned SrcReg = MI.getOperand(I).getReg();
+    LLT SrcTyI = MRI.getType(SrcReg);
+    // Same element count as NarrowTy0, but this operand's own element type.
+    LLT NarrowTyI = LLT::scalarOrVector(NewNumElts, SrcTyI.getScalarType());
+    LLT LeftoverTyI;
+
+    // Split this operand into the requested typed registers, and any leftover
+    // required to reproduce the original type.
+    if (!extractParts(SrcReg, SrcTyI, NarrowTyI, LeftoverTyI, PartRegs,
+                      LeftoverRegs))
+      return UnableToLegalize;
+
+    if (I == 1) {
+      // For the first operand, create an instruction for each part and setup
+      // the result.
+      for (unsigned PartReg : PartRegs) {
+        unsigned PartDstReg = MRI.createGenericVirtualRegister(NarrowTy0);
+        NewInsts.push_back(MIRBuilder.buildInstrNoInsert(MI.getOpcode())
+                               .addDef(PartDstReg)
+                               .addUse(PartReg));
+        DstRegs.push_back(PartDstReg);
+      }
+
+      for (unsigned LeftoverReg : LeftoverRegs) {
+        unsigned PartDstReg = MRI.createGenericVirtualRegister(LeftoverTy0);
+        NewInsts.push_back(MIRBuilder.buildInstrNoInsert(MI.getOpcode())
+                               .addDef(PartDstReg)
+                               .addUse(LeftoverReg));
+        LeftoverDstRegs.push_back(PartDstReg);
+      }
+    } else {
+      assert(NewInsts.size() == PartRegs.size() + LeftoverRegs.size());
+
+      // Add the newly created operand splits to the existing instructions. The
+      // odd-sized pieces are ordered after the requested NarrowTyArg sized
+      // pieces.
+      unsigned InstCount = 0;
+      for (unsigned J = 0, JE = PartRegs.size(); J != JE; ++J)
+        NewInsts[InstCount++].addUse(PartRegs[J]);
+      for (unsigned J = 0, JE = LeftoverRegs.size(); J != JE; ++J)
+        NewInsts[InstCount++].addUse(LeftoverRegs[J]);
+    }
+
+    PartRegs.clear();
+    LeftoverRegs.clear();
+  }
+
+  // Insert the newly built operations and rebuild the result register.
+  for (auto &MIB : NewInsts)
+    MIRBuilder.insertInstr(MIB);
+
+  insertParts(DstReg, DstTy, NarrowTy0, DstRegs, LeftoverTy0, LeftoverDstRegs);
+
+  MI.eraseFromParent();
+  return Legalized;
+}
+
 LegalizerHelper::LegalizeResult
 LegalizerHelper::fewerElementsVectorCasts(MachineInstr &MI, unsigned TypeIdx,
                                           LLT NarrowTy) {
@@ -1916,36 +2042,6 @@
   return Legalized;
 }
 
-/// Try to break down \p OrigTy into \p NarrowTy sized pieces.
-///
-/// Returns the number of \p NarrowTy elements needed to reconstruct \p OrigTy,
-/// with any leftover piece as type \p LeftoverTy
-///
-/// Returns -1 if the breakdown is not satisfiable.
-static int getNarrowTypeBreakDown(LLT OrigTy, LLT NarrowTy, LLT &LeftoverTy) {
-  assert(!LeftoverTy.isValid() && "this is an out argument");
-
-  unsigned Size = OrigTy.getSizeInBits();
-  unsigned NarrowSize = NarrowTy.getSizeInBits();
-  unsigned NumParts = Size / NarrowSize;
-  unsigned LeftoverSize = Size - NumParts * NarrowSize;
-  assert(Size > NarrowSize);
-
-  if (LeftoverSize == 0)
-    return NumParts;
-
-  if (NarrowTy.isVector()) {
-    unsigned EltSize = OrigTy.getScalarSizeInBits();
-    if (LeftoverSize % EltSize != 0)
-      return -1;
-    LeftoverTy = LLT::scalarOrVector(LeftoverSize / EltSize, EltSize);
-  } else {
-    LeftoverTy = LLT::scalar(LeftoverSize);
-  }
-
-  return NumParts;
-}
-
 LegalizerHelper::LegalizeResult
 LegalizerHelper::reduceLoadStoreWidth(MachineInstr &MI, unsigned TypeIdx,
                                       LLT NarrowTy) {
@@ -2069,6 +2165,10 @@
   case G_FSQRT:
   case G_BSWAP:
     return fewerElementsVectorBasic(MI, TypeIdx, NarrowTy);
+  case G_SHL:
+  case G_LSHR:
+  case G_ASHR:
+    return fewerElementsVectorMultiEltType(MI, TypeIdx, NarrowTy);
   case G_ZEXT:
   case G_SEXT:
   case G_ANYEXT: