GlobalISel: Implement lowering for G_INSERT_VECTOR_ELT
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index b56d1a0..949c0b4 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -2908,7 +2908,8 @@
     return Legalized;
   }
   case G_EXTRACT_VECTOR_ELT:
-    return lowerExtractVectorElt(MI);
+  case G_INSERT_VECTOR_ELT:
+    return lowerExtractInsertVectorElt(MI);
   case G_SHUFFLE_VECTOR:
     return lowerShuffleVector(MI);
   case G_DYN_STACKALLOC:
@@ -3522,7 +3523,7 @@
   //
   // TODO: We could emit a chain of compare/select to figure out which piece to
   // index.
-  return lowerExtractVectorElt(MI);
+  return lowerExtractInsertVectorElt(MI);
 }
 
 LegalizerHelper::LegalizeResult
@@ -5423,8 +5424,8 @@
   return Legalized;
 }
 
-/// Lower a vector extract by writing the vector to a stack temporary and
-/// reloading the element.
+/// Lower a vector extract or insert by writing the vector to a stack temporary
+/// and reloading either the element (extract) or the updated vector (insert).
 ///
 /// %dst = G_EXTRACT_VECTOR_ELT %vec, %idx
 ///  =>
@@ -5434,10 +5435,15 @@
 ///  %element_ptr = G_PTR_ADD %stack_temp, %idx
 ///  %dst = G_LOAD %element_ptr
 LegalizerHelper::LegalizeResult
-LegalizerHelper::lowerExtractVectorElt(MachineInstr &MI) {
+LegalizerHelper::lowerExtractInsertVectorElt(MachineInstr &MI) {
   Register DstReg = MI.getOperand(0).getReg();
   Register SrcVec = MI.getOperand(1).getReg();
-  Register Idx = MI.getOperand(2).getReg();
+  Register InsertVal;
+  if (MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT)
+    InsertVal = MI.getOperand(2).getReg();
+
+  Register Idx = MI.getOperand(MI.getNumOperands() - 1).getReg();
+
   LLT VecTy = MRI.getType(SrcVec);
   LLT EltTy = VecTy.getElementType();
   if (!EltTy.isByteSized()) { // Not implemented.
@@ -5446,30 +5452,41 @@
   }
 
   unsigned EltBytes = EltTy.getSizeInBytes();
-  Align StoreAlign = getStackTemporaryAlignment(VecTy);
-  Align LoadAlign;
+  Align VecAlign = getStackTemporaryAlignment(VecTy);
+  Align EltAlign;
 
   MachinePointerInfo PtrInfo;
   auto StackTemp = createStackTemporary(TypeSize::Fixed(VecTy.getSizeInBytes()),
-                                        StoreAlign, PtrInfo);
-  MIRBuilder.buildStore(SrcVec, StackTemp, PtrInfo, StoreAlign);
+                                        VecAlign, PtrInfo);
+  MIRBuilder.buildStore(SrcVec, StackTemp, PtrInfo, VecAlign);
+  // Keep the slot's PtrInfo: PtrInfo is narrowed to the element below.
+  MachinePointerInfo VecPtrInfo = PtrInfo;
 
   // Get the pointer to the element, and be sure not to hit undefined behavior
   // if the index is out of bounds.
-  Register LoadPtr = getVectorElementPointer(StackTemp.getReg(0), VecTy, Idx);
+  Register EltPtr = getVectorElementPointer(StackTemp.getReg(0), VecTy, Idx);
 
   int64_t IdxVal;
   if (mi_match(Idx, MRI, m_ICst(IdxVal))) {
     int64_t Offset = IdxVal * EltBytes;
     PtrInfo = PtrInfo.getWithOffset(Offset);
-    LoadAlign = commonAlignment(StoreAlign, Offset);
+    EltAlign = commonAlignment(VecAlign, Offset);
   } else {
     // We lose information with a variable offset.
-    LoadAlign = getStackTemporaryAlignment(EltTy);
-    PtrInfo = MachinePointerInfo(MRI.getType(LoadPtr).getAddressSpace());
+    EltAlign = getStackTemporaryAlignment(EltTy);
+    PtrInfo = MachinePointerInfo(MRI.getType(EltPtr).getAddressSpace());
   }
 
-  MIRBuilder.buildLoad(DstReg, LoadPtr, PtrInfo, LoadAlign);
+  if (InsertVal) {
+    // Write the inserted element.
+    MIRBuilder.buildStore(InsertVal, EltPtr, PtrInfo, EltAlign);
+
+    // Reload the whole vector using the slot's PtrInfo, not the element's.
+    MIRBuilder.buildLoad(DstReg, StackTemp, VecPtrInfo, VecAlign);
+  } else {
+    MIRBuilder.buildLoad(DstReg, EltPtr, PtrInfo, EltAlign);
+  }
+
   MI.eraseFromParent();
   return Legalized;
 }