GlobalISel: Reimplement widenScalar for G_UNMERGE_VALUES results

Only use shifts if the requested type exactly matches the source type,
and create sub-unmerges otherwise.
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index e3882bf..a6fe2c3 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -1396,7 +1396,7 @@
   if (TypeIdx != 0)
     return UnableToLegalize;
 
-  unsigned NumDst = MI.getNumOperands() - 1;
+  int NumDst = MI.getNumOperands() - 1;
   Register SrcReg = MI.getOperand(NumDst).getReg();
   LLT SrcTy = MRI.getType(SrcReg);
   if (!SrcTy.isScalar())
@@ -1407,26 +1407,66 @@
   if (!DstTy.isScalar())
     return UnableToLegalize;
 
-  unsigned NewSrcSize = NumDst * WideTy.getSizeInBits();
-  LLT NewSrcTy = LLT::scalar(NewSrcSize);
-  unsigned SizeDiff = WideTy.getSizeInBits() - DstTy.getSizeInBits();
+  if (WideTy == SrcTy) {
+    // There's no unmerge type to target. Directly extract the bits from the
+    // source type.
+    unsigned DstSize = DstTy.getSizeInBits();
 
-  auto WideSrc = MIRBuilder.buildZExt(NewSrcTy, SrcReg);
+    MIRBuilder.buildTrunc(Dst0Reg, SrcReg);
+    for (int I = 1; I != NumDst; ++I) {
+      auto ShiftAmt = MIRBuilder.buildConstant(SrcTy, DstSize * I);
+      auto Shr = MIRBuilder.buildLShr(SrcTy, SrcReg, ShiftAmt);
+      MIRBuilder.buildTrunc(MI.getOperand(I), Shr);
+    }
 
-  for (unsigned I = 1; I != NumDst; ++I) {
-    auto ShiftAmt = MIRBuilder.buildConstant(NewSrcTy, SizeDiff * I);
-    auto Shl = MIRBuilder.buildShl(NewSrcTy, WideSrc, ShiftAmt);
-    WideSrc = MIRBuilder.buildOr(NewSrcTy, WideSrc, Shl);
+    MI.eraseFromParent();
+    return Legalized;
   }
 
-  Observer.changingInstr(MI);
+  // TODO: Handle widening to a type wider than the source.
+  if (WideTy.getSizeInBits() > SrcTy.getSizeInBits())
+    return UnableToLegalize;
 
-  MI.getOperand(NumDst).setReg(WideSrc.getReg(0));
-  for (unsigned I = 0; I != NumDst; ++I)
-    widenScalarDst(MI, WideTy, I);
+  // Extend the source to a wider type.
+  LLT LCMTy = getLCMType(SrcTy, WideTy);
 
-  Observer.changedInstr(MI);
+  Register WideSrc = SrcReg;
+  if (LCMTy != SrcTy)
+    WideSrc = MIRBuilder.buildAnyExt(LCMTy, WideSrc).getReg(0);
+  auto Unmerge = MIRBuilder.buildUnmerge(WideTy, WideSrc);
 
+  // Create a sequence of unmerges to the original results. Since we may have
+  // widened the source, we will need to pad the results with dead defs to cover
+  // the source register.
+  // e.g. widen s16 to s32:
+  // %1:_(s16), %2:_(s16), %3:_(s16) = G_UNMERGE_VALUES %0:_(s48)
+  //
+  // =>
+  //  %4:_(s64) = G_ANYEXT %0:_(s48)
+  //  %5:_(s32), %6:_(s32) = G_UNMERGE_VALUES %4 ; Requested unmerge
+  //  %1:_(s16), %2:_(s16) = G_UNMERGE_VALUES %5 ; unpack to original regs
+  //  %3:_(s16), dead %7 = G_UNMERGE_VALUES %6 ; original reg + extra dead def
+
+  const int NumUnmerge = Unmerge->getNumOperands() - 1;
+  const int PartsPerUnmerge = WideTy.getSizeInBits() / DstTy.getSizeInBits();
+
+  for (int I = 0; I != NumUnmerge; ++I) {
+    auto MIB = MIRBuilder.buildInstr(TargetOpcode::G_UNMERGE_VALUES);
+
+    for (int J = 0; J != PartsPerUnmerge; ++J) {
+      int Idx = I * PartsPerUnmerge + J;
+      if (Idx < NumDst)
+        MIB.addDef(MI.getOperand(Idx).getReg());
+      else {
+        // Create dead def for excess components.
+        MIB.addDef(MRI.createGenericVirtualRegister(DstTy));
+      }
+    }
+
+    MIB.addUse(Unmerge.getReg(I));
+  }
+
+  MI.eraseFromParent();
   return Legalized;
 }