GlobalISel: Implement fewerElementsVector for select

llvm-svn: 352601
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 550cd05..5fb5aae 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -1679,6 +1679,101 @@
 }
 
+/// Break a G_SELECT into NumParts narrower G_SELECTs and reassemble the
+/// result.
+///
+/// With TypeIdx == 0 the value type (operands 0, 2 and 3) is narrowed to
+/// \p NarrowTy and a vector condition is split into pieces with the same
+/// number of lanes as each value piece. With TypeIdx == 1 the vector
+/// condition is narrowed; only full scalarization is implemented. Returns
+/// UnableToLegalize for any breakdown that would produce ill-typed selects.
 LegalizerHelper::LegalizeResult
+LegalizerHelper::fewerElementsVectorSelect(MachineInstr &MI, unsigned TypeIdx,
+                                           LLT NarrowTy) {
+  unsigned DstReg = MI.getOperand(0).getReg();
+  unsigned CondReg = MI.getOperand(1).getReg();
+
+  unsigned NumParts = 0;
+  LLT NarrowTy0, NarrowTy1;
+
+  LLT DstTy = MRI.getType(DstReg);
+  LLT CondTy = MRI.getType(CondReg);
+  unsigned Size = DstTy.getSizeInBits();
+
+  assert(TypeIdx == 0 || CondTy.isVector());
+
+  // Only a vector result has lanes to split across narrower selects.
+  if (!DstTy.isVector())
+    return UnableToLegalize;
+
+  if (TypeIdx == 0) {
+    NarrowTy0 = NarrowTy;
+    NarrowTy1 = CondTy;
+
+    unsigned NarrowSize = NarrowTy0.getSizeInBits();
+    // FIXME: Don't know how to handle the situation where the small vectors
+    // aren't all the same size yet.
+    if (Size % NarrowSize != 0)
+      return UnableToLegalize;
+
+    NumParts = Size / NarrowSize;
+
+    // Need to break down the condition type so that each condition piece has
+    // exactly as many lanes as the value piece it controls.
+    if (CondTy.isVector()) {
+      unsigned CondElts = CondTy.getNumElements();
+      unsigned PartElts =
+          NarrowTy0.isVector() ? NarrowTy0.getNumElements() : 1;
+      if (CondElts % NumParts != 0 || CondElts / NumParts != PartElts)
+        return UnableToLegalize;
+
+      if (PartElts == 1)
+        NarrowTy1 = CondTy.getElementType();
+      else
+        NarrowTy1 = LLT::vector(PartElts, CondTy.getScalarSizeInBits());
+    }
+  } else {
+    // TODO: Handle breaking the condition into vector pieces. That also
+    // requires splitting the value operands into matching vector pieces.
+    if (NarrowTy.isVector())
+      return UnableToLegalize;
+
+    // Full scalarization: one select per lane. Each scalar piece must be
+    // exactly one condition lane for the condition unmerge to be valid.
+    if (NarrowTy != CondTy.getElementType())
+      return UnableToLegalize;
+
+    NumParts = CondTy.getNumElements();
+    NarrowTy0 = DstTy.getElementType();
+    NarrowTy1 = NarrowTy;
+  }
+
+  // A scalar condition is shared by every narrow select; a vector condition
+  // is split alongside the value operands.
+  SmallVector<unsigned, 2> DstRegs, CondRegs, Src1Regs, Src2Regs;
+  if (CondTy.isVector())
+    extractParts(CondReg, NarrowTy1, NumParts, CondRegs);
+
+  extractParts(MI.getOperand(2).getReg(), NarrowTy0, NumParts, Src1Regs);
+  extractParts(MI.getOperand(3).getReg(), NarrowTy0, NumParts, Src2Regs);
+
+  for (unsigned i = 0; i < NumParts; ++i) {
+    unsigned PartReg = MRI.createGenericVirtualRegister(NarrowTy0);
+    MIRBuilder.buildSelect(PartReg, CondTy.isVector() ? CondRegs[i] : CondReg,
+                           Src1Regs[i], Src2Regs[i]);
+    DstRegs.push_back(PartReg);
+  }
+
+  // Reassemble the result from the narrow pieces.
+  if (NarrowTy0.isVector())
+    MIRBuilder.buildConcatVectors(DstReg, DstRegs);
+  else
+    MIRBuilder.buildBuildVector(DstReg, DstRegs);
+
+  MI.eraseFromParent();
+  return Legalized;
+}
+
+LegalizerHelper::LegalizeResult
 LegalizerHelper::fewerElementsVectorLoadStore(MachineInstr &MI, unsigned TypeIdx,
                                               LLT NarrowTy) {
   // FIXME: Don't know how to handle secondary types yet.
@@ -1784,6 +1879,8 @@
   case G_ICMP:
   case G_FCMP:
     return fewerElementsVectorCmp(MI, TypeIdx, NarrowTy);
+  case G_SELECT:
+    return fewerElementsVectorSelect(MI, TypeIdx, NarrowTy);
   case G_LOAD:
   case G_STORE:
     return fewerElementsVectorLoadStore(MI, TypeIdx, NarrowTy);