GlobalISel: Implement narrowScalar for shift main type

This is pretty much directly ported from SelectionDAG. It doesn't
include the expansion for a shift by a non-constant amount with known
bits, since there isn't a GlobalISel version of computeKnownBits yet.
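
For reference, when the shift amount is a constant Amt smaller than the
half width NVTBits, the expansion computes the following pieces, in
C-like notation with InL/InH being the low/high halves of the input:

  G_SHL:  Lo = InL << Amt
          Hi = (InH << Amt) | (InL >> (NVTBits - Amt))
  G_LSHR: Lo = (InL >> Amt) | (InH << (NVTBits - Amt))
          Hi = InH >> Amt
  G_ASHR: Lo = (InL >> Amt) | (InH << (NVTBits - Amt))
          Hi = InH >> Amt  (arithmetic)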

This shows a disadvantage of targets not specifying which type should
be used for the shift amount. If type 0 is legalized before type 1, the
operations produced for the shift amount use the wider type, and are
also less likely to legalize. This can be avoided by targets specifying
legalization actions on type 1 earlier than on type 0.
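
As an illustrative sketch only (not part of this patch), a target could
express that ordering in its LegalizerInfo by placing the rule for the
shift amount (type 1) ahead of the rule for the shifted value (type 0);
the concrete types below are hypothetical:

  const LLT S32 = LLT::scalar(32);
  const LLT S64 = LLT::scalar(64);
  getActionDefinitionsBuilder(
      {TargetOpcode::G_SHL, TargetOpcode::G_LSHR, TargetOpcode::G_ASHR})
      .clampScalar(1, S32, S32)  // legalize the amount type first
      .clampScalar(0, S32, S64); // then constrain the shifted value

Since rules are checked in order, the shift amount reaches a legal type
before the main type is narrowed, so the amount operations are not
emitted on the wider type.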

llvm-svn: 353455
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index a6ca855..3dd5c32 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -811,14 +811,8 @@
   }
   case TargetOpcode::G_SHL:
   case TargetOpcode::G_LSHR:
-  case TargetOpcode::G_ASHR: {
-    if (TypeIdx != 1)
-      return UnableToLegalize; // TODO
-    Observer.changingInstr(MI);
-    narrowScalarSrc(MI, NarrowTy, 2);
-    Observer.changedInstr(MI);
-    return Legalized;
-  }
+  case TargetOpcode::G_ASHR:
+    return narrowScalarShift(MI, TypeIdx, NarrowTy);
   case TargetOpcode::G_CTLZ:
   case TargetOpcode::G_CTLZ_ZERO_UNDEF:
   case TargetOpcode::G_CTTZ:
@@ -2195,6 +2189,221 @@
 }
 
 LegalizerHelper::LegalizeResult
+LegalizerHelper::narrowScalarShiftByConstant(MachineInstr &MI, const APInt &Amt,
+                                             const LLT HalfTy, const LLT AmtTy) {
+
+  unsigned InL = MRI.createGenericVirtualRegister(HalfTy);
+  unsigned InH = MRI.createGenericVirtualRegister(HalfTy);
+  MIRBuilder.buildUnmerge({InL, InH}, MI.getOperand(1).getReg());
+
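+  // A zero shift amount is a no-op; just re-merge the unchanged halves.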
+  if (Amt.isNullValue()) {
+    MIRBuilder.buildMerge(MI.getOperand(0).getReg(), {InL, InH});
+    MI.eraseFromParent();
+    return Legalized;
+  }
+
+  LLT NVT = HalfTy;
+  unsigned NVTBits = HalfTy.getSizeInBits();
+  unsigned VTBits = 2 * NVTBits;
+
+  SrcOp Lo(0), Hi(0);
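+  // The constant amount selects one of four cases: wider than the full type,
+  // wider than the half type, exactly the half width, or within the half.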
+  if (MI.getOpcode() == TargetOpcode::G_SHL) {
+    if (Amt.ugt(VTBits)) {
+      Lo = Hi = MIRBuilder.buildConstant(NVT, 0);
+    } else if (Amt.ugt(NVTBits)) {
+      Lo = MIRBuilder.buildConstant(NVT, 0);
+      Hi = MIRBuilder.buildShl(NVT, InL,
+                               MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
+    } else if (Amt == NVTBits) {
+      Lo = MIRBuilder.buildConstant(NVT, 0);
+      Hi = InL;
+    } else {
+      Lo = MIRBuilder.buildShl(NVT, InL, MIRBuilder.buildConstant(AmtTy, Amt));
+      Hi = MIRBuilder.buildOr(
+          NVT,
+          MIRBuilder.buildShl(NVT, InH, MIRBuilder.buildConstant(AmtTy, Amt)),
+          MIRBuilder.buildLShr(
+              NVT, InL, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits)));
+    }
+  } else if (MI.getOpcode() == TargetOpcode::G_LSHR) {
+    if (Amt.ugt(VTBits)) {
+      Lo = Hi = MIRBuilder.buildConstant(NVT, 0);
+    } else if (Amt.ugt(NVTBits)) {
+      Lo = MIRBuilder.buildLShr(NVT, InH,
+                                MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
+      Hi = MIRBuilder.buildConstant(NVT, 0);
+    } else if (Amt == NVTBits) {
+      Lo = InH;
+      Hi = MIRBuilder.buildConstant(NVT, 0);
+    } else {
+      auto ShiftAmtConst = MIRBuilder.buildConstant(AmtTy, Amt);
+
+      auto OrLHS = MIRBuilder.buildLShr(NVT, InL, ShiftAmtConst);
+      auto OrRHS = MIRBuilder.buildShl(
+          NVT, InH, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
+
+      Lo = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
+      Hi = MIRBuilder.buildLShr(NVT, InH, ShiftAmtConst);
+    }
+  } else {
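+    // G_ASHR: the oversized cases fill with copies of the sign bit of the
+    // high half instead of with zero.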
+    if (Amt.ugt(VTBits)) {
+      Hi = Lo = MIRBuilder.buildAShr(
+          NVT, InH, MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
+    } else if (Amt.ugt(NVTBits)) {
+      Lo = MIRBuilder.buildAShr(NVT, InH,
+                                MIRBuilder.buildConstant(AmtTy, Amt - NVTBits));
+      Hi = MIRBuilder.buildAShr(NVT, InH,
+                                MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
+    } else if (Amt == NVTBits) {
+      Lo = InH;
+      Hi = MIRBuilder.buildAShr(NVT, InH,
+                                MIRBuilder.buildConstant(AmtTy, NVTBits - 1));
+    } else {
+      auto ShiftAmtConst = MIRBuilder.buildConstant(AmtTy, Amt);
+
+      auto OrLHS = MIRBuilder.buildLShr(NVT, InL, ShiftAmtConst);
+      auto OrRHS = MIRBuilder.buildShl(
+          NVT, InH, MIRBuilder.buildConstant(AmtTy, -Amt + NVTBits));
+
+      Lo = MIRBuilder.buildOr(NVT, OrLHS, OrRHS);
+      Hi = MIRBuilder.buildAShr(NVT, InH, ShiftAmtConst);
+    }
+  }
+
+  MIRBuilder.buildMerge(MI.getOperand(0).getReg(), {Lo.getReg(), Hi.getReg()});
+  MI.eraseFromParent();
+
+  return Legalized;
+}
+
+// TODO: Optimize if constant shift amount.
+LegalizerHelper::LegalizeResult
+LegalizerHelper::narrowScalarShift(MachineInstr &MI, unsigned TypeIdx,
+                                   LLT RequestedTy) {
+  if (TypeIdx == 1) {
+    Observer.changingInstr(MI);
+    narrowScalarSrc(MI, RequestedTy, 2);
+    Observer.changedInstr(MI);
+    return Legalized;
+  }
+
+  unsigned DstReg = MI.getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+  if (DstTy.isVector())
+    return UnableToLegalize;
+
+  unsigned Amt = MI.getOperand(2).getReg();
+  LLT ShiftAmtTy = MRI.getType(Amt);
+  const unsigned DstEltSize = DstTy.getScalarSizeInBits();
+  if (DstEltSize % 2 != 0)
+    return UnableToLegalize;
+
+  // Ignore the requested narrow type. We can only split to exactly half the
+  // size of the input. If that isn't small enough, the resulting pieces will
+  // be further legalized.
+  const unsigned NewBitSize = DstEltSize / 2;
+  const LLT HalfTy = LLT::scalar(NewBitSize);
+  const LLT CondTy = LLT::scalar(1);
+
+  if (const MachineInstr *KShiftAmt =
+          getOpcodeDef(TargetOpcode::G_CONSTANT, Amt, MRI)) {
+    return narrowScalarShiftByConstant(
+        MI, KShiftAmt->getOperand(1).getCImm()->getValue(), HalfTy, ShiftAmtTy);
+  }
+
+  // TODO: Expand with known bits.
+
+  // Handle the fully general expansion by an unknown amount.
+  auto NewBits = MIRBuilder.buildConstant(ShiftAmtTy, NewBitSize);
+
+  unsigned InL = MRI.createGenericVirtualRegister(HalfTy);
+  unsigned InH = MRI.createGenericVirtualRegister(HalfTy);
+  MIRBuilder.buildUnmerge({InL, InH}, MI.getOperand(1).getReg());
+
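+  // AmtExcess is how far the amount reaches past the half width; AmtLack is
+  // how many bits of the half remain. Only one of them is meaningful for a
+  // given amount, and the selects below pick the matching expansion.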
+  auto AmtExcess = MIRBuilder.buildSub(ShiftAmtTy, Amt, NewBits);
+  auto AmtLack = MIRBuilder.buildSub(ShiftAmtTy, NewBits, Amt);
+
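+  // Classify the amount: a "short" shift stays within one half. A zero amount
+  // is handled separately because the short expansion would then shift by a
+  // full NewBitSize, which is out of range; instead the affected input half is
+  // passed through unchanged (the high half for G_SHL, the low half for the
+  // right shifts).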
+  auto Zero = MIRBuilder.buildConstant(ShiftAmtTy, 0);
+  auto IsShort = MIRBuilder.buildICmp(ICmpInst::ICMP_ULT, CondTy, Amt, NewBits);
+  auto IsZero = MIRBuilder.buildICmp(ICmpInst::ICMP_EQ, CondTy, Amt, Zero);
+
+  unsigned ResultRegs[2];
+  switch (MI.getOpcode()) {
+  case TargetOpcode::G_SHL: {
+    // Short: ShAmt < NewBitSize
+    auto LoS = MIRBuilder.buildShl(HalfTy, InL, Amt);
+
+    auto OrLHS = MIRBuilder.buildShl(HalfTy, InH, Amt);
+    auto OrRHS = MIRBuilder.buildLShr(HalfTy, InL, AmtLack);
+    auto HiS = MIRBuilder.buildOr(HalfTy, OrLHS, OrRHS);
+
+    // Long: ShAmt >= NewBitSize
+    auto LoL = MIRBuilder.buildConstant(HalfTy, 0);         // Lo part is zero.
+    auto HiL = MIRBuilder.buildShl(HalfTy, InL, AmtExcess); // Hi from Lo part.
+
+    auto Lo = MIRBuilder.buildSelect(HalfTy, IsShort, LoS, LoL);
+    auto Hi = MIRBuilder.buildSelect(
+        HalfTy, IsZero, InH, MIRBuilder.buildSelect(HalfTy, IsShort, HiS, HiL));
+
+    ResultRegs[0] = Lo.getReg(0);
+    ResultRegs[1] = Hi.getReg(0);
+    break;
+  }
+  case TargetOpcode::G_LSHR: {
+    // Short: ShAmt < NewBitSize
+    auto HiS = MIRBuilder.buildLShr(HalfTy, InH, Amt);
+
+    auto OrLHS = MIRBuilder.buildLShr(HalfTy, InL, Amt);
+    auto OrRHS = MIRBuilder.buildShl(HalfTy, InH, AmtLack);
+    auto LoS = MIRBuilder.buildOr(HalfTy, OrLHS, OrRHS);
+
+    // Long: ShAmt >= NewBitSize
+    auto HiL = MIRBuilder.buildConstant(HalfTy, 0);          // Hi part is zero.
+    auto LoL = MIRBuilder.buildLShr(HalfTy, InH, AmtExcess); // Lo from Hi part.
+
+    auto Lo = MIRBuilder.buildSelect(
+        HalfTy, IsZero, InL, MIRBuilder.buildSelect(HalfTy, IsShort, LoS, LoL));
+    auto Hi = MIRBuilder.buildSelect(HalfTy, IsShort, HiS, HiL);
+
+    ResultRegs[0] = Lo.getReg(0);
+    ResultRegs[1] = Hi.getReg(0);
+    break;
+  }
+  case TargetOpcode::G_ASHR: {
+    // Short: ShAmt < NewBitSize
+    auto HiS = MIRBuilder.buildAShr(HalfTy, InH, Amt);
+
+    auto OrLHS = MIRBuilder.buildLShr(HalfTy, InL, Amt);
+    auto OrRHS = MIRBuilder.buildShl(HalfTy, InH, AmtLack);
+    auto LoS = MIRBuilder.buildOr(HalfTy, OrLHS, OrRHS);
+
+    // Long: ShAmt >= NewBitSize
+
+    // Sign of Hi part.
+    auto HiL = MIRBuilder.buildAShr(
+        HalfTy, InH, MIRBuilder.buildConstant(ShiftAmtTy, NewBitSize - 1));
+
+    auto LoL = MIRBuilder.buildAShr(HalfTy, InH, AmtExcess); // Lo from Hi part.
+
+    auto Lo = MIRBuilder.buildSelect(
+        HalfTy, IsZero, InL, MIRBuilder.buildSelect(HalfTy, IsShort, LoS, LoL));
+
+    auto Hi = MIRBuilder.buildSelect(HalfTy, IsShort, HiS, HiL);
+
+    ResultRegs[0] = Lo.getReg(0);
+    ResultRegs[1] = Hi.getReg(0);
+    break;
+  }
+  default:
+    llvm_unreachable("not a shift");
+  }
+
+  MIRBuilder.buildMerge(DstReg, ResultRegs);
+  MI.eraseFromParent();
+  return Legalized;
+}
+
+LegalizerHelper::LegalizeResult
 LegalizerHelper::narrowScalarMul(MachineInstr &MI, unsigned TypeIdx, LLT NewTy) {
   unsigned DstReg = MI.getOperand(0).getReg();
   unsigned Src0 = MI.getOperand(1).getReg();