GlobalISel: Implement bitcast action for G_INSERT_VECTOR_ELT

This mirrors the existing bitcast support for G_EXTRACT_VECTOR_ELT. The
resulting bit-manipulation sequences are verbose; they would be greatly
simplified by combines for bit operations, which do not exist yet.
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 8202458..0b07dd0 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -2369,6 +2369,28 @@
   return UnableToLegalize;
 }
 
+/// Compute the bit offset of an old-typed element within the wide element
+/// that contains it, after a vector has been bitcast to one with larger
+/// elements. Requires NewEltSize / OldEltSize to be a power of 2.
+///
+/// Emits:
+///   %offset_idx  = G_AND %idx, (1 << Log2(NewEltSize / OldEltSize)) - 1
+///   %offset_bits = G_SHL %offset_idx, Log2(OldEltSize)
+static Register getBitcastWiderVectorElementOffset(MachineIRBuilder &B,
+                                                   Register Idx,
+                                                   unsigned NewEltSize,
+                                                   unsigned OldEltSize) {
+  const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
+  LLT IdxTy = B.getMRI()->getType(Idx);
+
+  // The low index bits select the element's slot within the wide element.
+  auto OffsetMask = B.buildConstant(
+    IdxTy, APInt::getLowBitsSet(IdxTy.getSizeInBits(), Log2EltRatio));
+  auto OffsetIdx = B.buildAnd(IdxTy, Idx, OffsetMask);
+  return B.buildShl(IdxTy, OffsetIdx,
+                    B.buildConstant(IdxTy, Log2_32(OldEltSize))).getReg(0);
+}
+
 /// Perform a G_EXTRACT_VECTOR_ELT in a different sized vector element. If this
 /// is casting to a vector with a smaller element size, perform multiple element
 /// extracts and merge the results. If this is coercing to a vector with larger
@@ -2467,13 +2489,9 @@
                                                      ScaledIdx).getReg(0);
     }
 
-    // Now figure out the amount we need to shift to get the target bits.
-    auto OffsetMask = MIRBuilder.buildConstant(
-      IdxTy, ~(APInt::getAllOnesValue(IdxTy.getSizeInBits()) << Log2EltRatio));
-    auto OffsetIdx = MIRBuilder.buildAnd(IdxTy, Idx, OffsetMask);
-    auto OffsetBits = MIRBuilder.buildShl(
-      IdxTy, OffsetIdx,
-      MIRBuilder.buildConstant(IdxTy, Log2_32(OldEltSize)));
+    // Compute the bit offset into the register of the target element.
+    Register OffsetBits = getBitcastWiderVectorElementOffset(
+      MIRBuilder, Idx, NewEltSize, OldEltSize);
 
     // Shift the wide element to get the target element.
     auto ExtractedBits = MIRBuilder.buildLShr(NewEltTy, WideElt, OffsetBits);
@@ -2485,6 +2503,104 @@
   return UnableToLegalize;
 }
 
+/// Emit code to insert \p InsertReg into \p TargetReg at bit \p OffsetBits,
+/// preserving all other bits of \p TargetReg. With M = 2^size(InsertReg) - 1:
+///
+///   (zext(InsertReg) << Offset) | (TargetReg & ~(M << Offset))
+static Register buildBitFieldInsert(MachineIRBuilder &B,
+                                    Register TargetReg, Register InsertReg,
+                                    Register OffsetBits) {
+  LLT TargetTy = B.getMRI()->getType(TargetReg);
+  LLT InsertTy = B.getMRI()->getType(InsertReg);
+  auto ZextVal = B.buildZExt(TargetTy, InsertReg);
+  auto ShiftedInsertVal = B.buildShl(TargetTy, ZextVal, OffsetBits);
+
+  // Produce a mask covering the low InsertTy-sized bits (M above).
+  auto EltMask = B.buildConstant(
+    TargetTy, APInt::getLowBitsSet(TargetTy.getSizeInBits(),
+                                   InsertTy.getSizeInBits()));
+  // Shift it into position and invert it to select the bits to keep.
+  auto ShiftedMask = B.buildShl(TargetTy, EltMask, OffsetBits);
+  auto InvShiftedMask = B.buildNot(TargetTy, ShiftedMask);
+
+  // Clear out the destination bit field in the wide element.
+  auto MaskedOldElt = B.buildAnd(TargetTy, TargetReg, InvShiftedMask);
+
+  // The zero-extended value is zero outside the field, so OR-ing it into the
+  // cleared wide element completes the insert.
+  return B.buildOr(TargetTy, MaskedOldElt, ShiftedInsertVal).getReg(0);
+}
+
+/// Perform a G_INSERT_VECTOR_ELT in a different sized vector element. If this
+/// is increasing the element size, perform the indexing in the target element
+/// type, and use bit operations to insert at the element position. This is
+/// intended for architectures that can dynamically index the register file and
+/// want to force indexing in the native register size.
+LegalizerHelper::LegalizeResult
+LegalizerHelper::bitcastInsertVectorElt(MachineInstr &MI, unsigned TypeIdx,
+                                        LLT CastTy) {
+  if (TypeIdx != 0)
+    return UnableToLegalize;
+
+  Register Dst = MI.getOperand(0).getReg();
+  Register SrcVec = MI.getOperand(1).getReg();
+  Register Val = MI.getOperand(2).getReg();
+  Register Idx = MI.getOperand(3).getReg();
+
+  LLT VecTy = MRI.getType(Dst);
+  LLT IdxTy = MRI.getType(Idx);
+
+  LLT VecEltTy = VecTy.getElementType();
+  LLT NewEltTy = CastTy.isVector() ? CastTy.getElementType() : CastTy;
+  const unsigned NewEltSize = NewEltTy.getSizeInBits();
+  const unsigned OldEltSize = VecEltTy.getSizeInBits();
+
+  unsigned NewNumElts = CastTy.isVector() ? CastTy.getNumElements() : 1;
+  unsigned OldNumElts = VecTy.getNumElements();
+
+  // Only casting to fewer (hence wider) elements is handled. Validate before
+  // emitting anything so that bailing out leaves no dead instructions behind.
+  if (NewNumElts >= OldNumElts || NewEltSize % OldEltSize != 0)
+    return UnableToLegalize;
+
+  // This only depends on powers of 2 because we use bit tricks to figure out
+  // the bit offset we need to shift to get the target element. A general
+  // expansion could emit division/multiply.
+  if (!isPowerOf2_32(NewEltSize / OldEltSize))
+    return UnableToLegalize;
+
+  Register CastVec = MIRBuilder.buildBitcast(CastTy, SrcVec).getReg(0);
+  const unsigned Log2EltRatio = Log2_32(NewEltSize / OldEltSize);
+  auto Log2Ratio = MIRBuilder.buildConstant(IdxTy, Log2EltRatio);
+
+  // Divide to get the index in the wider element type.
+  auto ScaledIdx = MIRBuilder.buildLShr(IdxTy, Idx, Log2Ratio);
+
+  // Fetch the wide element holding the target element; a scalar cast type is
+  // already a single wide element.
+  Register ExtractedElt = CastVec;
+  if (CastTy.isVector()) {
+    ExtractedElt = MIRBuilder.buildExtractVectorElement(NewEltTy, CastVec,
+                                                        ScaledIdx).getReg(0);
+  }
+
+  // Compute the bit offset into the register of the target element.
+  Register OffsetBits = getBitcastWiderVectorElementOffset(
+    MIRBuilder, Idx, NewEltSize, OldEltSize);
+
+  // Overwrite the target bits of the wide element, then write it back.
+  Register InsertedElt = buildBitFieldInsert(MIRBuilder, ExtractedElt,
+                                             Val, OffsetBits);
+  if (CastTy.isVector()) {
+    InsertedElt = MIRBuilder.buildInsertVectorElement(
+      CastTy, CastVec, InsertedElt, ScaledIdx).getReg(0);
+  }
+
+  MIRBuilder.buildBitcast(Dst, InsertedElt);
+  MI.eraseFromParent();
+  return Legalized;
+}
+
 LegalizerHelper::LegalizeResult
 LegalizerHelper::lowerLoad(MachineInstr &MI) {
   // Lower to a memory-width G_LOAD and a G_SEXT/G_ZEXT/G_ANYEXT
@@ -2674,6 +2790,8 @@
   }
   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
     return bitcastExtractVectorElt(MI, TypeIdx, CastTy);
+  case TargetOpcode::G_INSERT_VECTOR_ELT:
+    return bitcastInsertVectorElt(MI, TypeIdx, CastTy);
   default:
     return UnableToLegalize;
   }