diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h
index 74881a5..4917a35 100644
--- a/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -858,7 +858,8 @@
                                                int DmaskIdx = -1);
 
   Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
-                                    APInt &UndefElts, unsigned Depth = 0);
+                                    APInt &UndefElts, unsigned Depth = 0,
+                                    bool AllowMultipleUsers = false);
 
   /// Canonicalize the position of binops relative to shufflevector.
   Instruction *foldVectorBinop(BinaryOperator &Inst);
diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 4680cb3..d30ab80 100644
--- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -1074,16 +1074,22 @@
 }
 
 /// The specified value produces a vector with any number of elements.
+/// This method analyzes which elements of the operand are undef and returns
+/// that information in UndefElts.
+///
 /// DemandedElts contains the set of elements that are actually used by the
-/// caller. This method analyzes which elements of the operand are undef and
-/// returns that information in UndefElts.
+/// caller, and by default (AllowMultipleUsers equals false) the value is
+/// simplified only if it has a single caller. If AllowMultipleUsers is set
+/// to true, DemandedElts refers to the union of sets of elements that are
+/// used by all callers.
 ///
 /// If the information about demanded elements can be used to simplify the
 /// operation, the operation is simplified, then the resultant value is
 /// returned.  This returns null if no change was made.
 Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
                                                 APInt &UndefElts,
-                                                unsigned Depth) {
+                                                unsigned Depth,
+                                                bool AllowMultipleUsers) {
   unsigned VWidth = V->getType()->getVectorNumElements();
   APInt EltMask(APInt::getAllOnesValue(VWidth));
   assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
@@ -1137,19 +1143,21 @@
   if (Depth == 10)
     return nullptr;
 
-  // If multiple users are using the root value, proceed with
-  // simplification conservatively assuming that all elements
-  // are needed.
-  if (!V->hasOneUse()) {
-    // Quit if we find multiple users of a non-root value though.
-    // They'll be handled when it's their turn to be visited by
-    // the main instcombine process.
-    if (Depth != 0)
-      // TODO: Just compute the UndefElts information recursively.
-      return nullptr;
+  if (!AllowMultipleUsers) {
+    // If multiple users are using the root value, proceed with
+    // simplification conservatively assuming that all elements
+    // are needed.
+    if (!V->hasOneUse()) {
+      // Quit if we find multiple users of a non-root value though.
+      // They'll be handled when it's their turn to be visited by
+      // the main instcombine process.
+      if (Depth != 0)
+        // TODO: Just compute the UndefElts information recursively.
+        return nullptr;
 
-    // Conservatively assume that all elements are needed.
-    DemandedElts = EltMask;
+      // Conservatively assume that all elements are needed.
+      DemandedElts = EltMask;
+    }
   }
 
   Instruction *I = dyn_cast<Instruction>(V);
diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 13c38ca..9c89074 100644
--- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -253,6 +253,69 @@
   return nullptr;
 }
 
+/// Find elements of V demanded by UserInstr.
+static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
+  unsigned VWidth = V->getType()->getVectorNumElements();
+
+  // Conservatively assume that all elements are needed.
+  APInt UsedElts(APInt::getAllOnesValue(VWidth));
+
+  switch (UserInstr->getOpcode()) {
+  case Instruction::ExtractElement: {
+    ExtractElementInst *EEI = cast<ExtractElementInst>(UserInstr);
+    assert(EEI->getVectorOperand() == V);
+    ConstantInt *EEIIndexC = dyn_cast<ConstantInt>(EEI->getIndexOperand());
+    if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) {
+      UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue());
+    }
+    break;
+  }
+  case Instruction::ShuffleVector: {
+    ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr);
+    unsigned MaskNumElts = UserInstr->getType()->getVectorNumElements();
+
+    UsedElts = APInt(VWidth, 0);
+    for (unsigned i = 0; i < MaskNumElts; i++) {
+      unsigned MaskVal = Shuffle->getMaskValue(i);
+      if (MaskVal == -1u || MaskVal >= 2 * VWidth)
+        continue;
+      if (Shuffle->getOperand(0) == V && (MaskVal < VWidth))
+        UsedElts.setBit(MaskVal);
+      if (Shuffle->getOperand(1) == V &&
+          ((MaskVal >= VWidth) && (MaskVal < 2 * VWidth)))
+        UsedElts.setBit(MaskVal - VWidth);
+    }
+    break;
+  }
+  default:
+    break;
+  }
+  return UsedElts;
+}
+
+/// Find union of elements of V demanded by all its users.
+/// If it is known by querying findDemandedEltsBySingleUser that
+/// no user demands an element of V, then the corresponding bit
+/// remains unset in the returned value.
+static APInt findDemandedEltsByAllUsers(Value *V) {
+  unsigned VWidth = V->getType()->getVectorNumElements();
+
+  APInt UnionUsedElts(VWidth, 0);
+  for (const Use &U : V->uses()) {
+    if (Instruction *I = dyn_cast<Instruction>(U.getUser())) {
+      UnionUsedElts |= findDemandedEltsBySingleUser(V, I);
+    } else {
+      UnionUsedElts = APInt::getAllOnesValue(VWidth);
+      break;
+    }
+
+    if (UnionUsedElts.isAllOnesValue())
+      break;
+  }
+
+  return UnionUsedElts;
+}
+
 Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
   Value *SrcVec = EI.getVectorOperand();
   Value *Index = EI.getIndexOperand();
@@ -271,19 +334,35 @@
       return nullptr;
 
     // This instruction only demands the single element from the input vector.
-    // If the input vector has a single use, simplify it based on this use
-    // property.
-    if (SrcVec->hasOneUse() && NumElts != 1) {
-      APInt UndefElts(NumElts, 0);
-      APInt DemandedElts(NumElts, 0);
-      DemandedElts.setBit(IndexC->getZExtValue());
-      if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts,
-                                                UndefElts)) {
-        EI.setOperand(0, V);
-        return &EI;
+    if (NumElts != 1) {
+      // If the input vector has a single use, simplify it based on this use
+      // property.
+      if (SrcVec->hasOneUse()) {
+        APInt UndefElts(NumElts, 0);
+        APInt DemandedElts(NumElts, 0);
+        DemandedElts.setBit(IndexC->getZExtValue());
+        if (Value *V =
+                SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) {
+          EI.setOperand(0, V);
+          return &EI;
+        }
+      } else {
+        // If the input vector has multiple uses, simplify it based on a union
+        // of all elements used.
+        APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec);
+        if (!DemandedElts.isAllOnesValue()) {
+          APInt UndefElts(NumElts, 0);
+          if (Value *V = SimplifyDemandedVectorElts(
+                  SrcVec, DemandedElts, UndefElts, 0 /* Depth */,
+                  true /* AllowMultipleUsers */)) {
+            if (V != SrcVec) {
+              SrcVec->replaceAllUsesWith(V);
+              return &EI;
+            }
+          }
+        }
       }
     }
-
     if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian()))
       return I;
 
diff --git a/test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll b/test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll
index db7533a..389ca60 100644
--- a/test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll
+++ b/test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll
@@ -152,11 +152,10 @@
   ret <3 x float> %shuf
 }
 
-; FIXME: Not handled even though only 2 elts used
 ; CHECK-LABEL: @extract_elt0_elt1_buffer_load_v4f32_2(
-; CHECK-NEXT: %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
-; CHECK-NEXT: %elt0 = extractelement <4 x float> %data, i32 0
-; CHECK-NEXT: %elt1 = extractelement <4 x float> %data, i32 1
+; CHECK-NEXT: %data = call <2 x float> @llvm.amdgcn.buffer.load.v2f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
+; CHECK-NEXT: %elt0 = extractelement <2 x float> %data, i32 0
+; CHECK-NEXT: %elt1 = extractelement <2 x float> %data, i32 1
 ; CHECK-NEXT: %ins0 = insertvalue { float, float } undef, float %elt0, 0
 ; CHECK-NEXT: %ins1 = insertvalue { float, float } %ins0, float %elt1, 1
 ; CHECK-NEXT: ret { float, float } %ins1
@@ -169,6 +168,74 @@
   ret { float, float } %ins1
 }
 
+; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_2(
+; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
+; CHECK-NEXT: %elt0 = extractelement <3 x float> %data, i32 0
+; CHECK-NEXT: %elt1 = extractelement <3 x float> %data, i32 1
+; CHECK-NEXT: %elt2 = extractelement <3 x float> %data, i32 2
+; CHECK-NEXT: %ins0 = insertvalue { float, float, float } undef, float %elt0, 0
+; CHECK-NEXT: %ins1 = insertvalue { float, float, float } %ins0, float %elt1, 1
+; CHECK-NEXT: %ins2 = insertvalue { float, float, float } %ins1, float %elt2, 2
+; CHECK-NEXT: ret { float, float, float } %ins2
+define amdgpu_ps { float, float, float } @extract_elt0_elt1_elt2_buffer_load_v4f32_2(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 {
+  %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
+  %elt0 = extractelement <4 x float> %data, i32 0
+  %elt1 = extractelement <4 x float> %data, i32 1
+  %elt2 = extractelement <4 x float> %data, i32 2
+  %ins0 = insertvalue { float, float, float } undef, float %elt0, 0
+  %ins1 = insertvalue { float, float, float } %ins0, float %elt1, 1
+  %ins2 = insertvalue { float, float, float } %ins1, float %elt2, 2
+  ret { float, float, float } %ins2
+}
+
+; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_3(
+; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
+; CHECK-NEXT: %ins1 = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 0, i32 2>
+; CHECK-NEXT: %shuf = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 undef, i32 1>
+; CHECK-NEXT: %ret = fadd <2 x float> %ins1, %shuf
+define amdgpu_ps <2 x float> @extract_elt0_elt1_elt2_buffer_load_v4f32_3(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 {
+  %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
+  %elt0 = extractelement <4 x float> %data, i32 0
+  %elt2 = extractelement <4 x float> %data, i32 2
+  %ins0 = insertelement <2 x float> undef, float %elt0, i32 0
+  %ins1 = insertelement <2 x float> %ins0, float %elt2, i32 1
+  %shuf = shufflevector <4 x float> %data, <4 x float> undef, <2 x i32> <i32 4, i32 1>
+  %ret = fadd <2 x float> %ins1, %shuf
+  ret <2 x float> %ret
+}
+
+; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_4(
+; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
+; CHECK-NEXT: %ins1 = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 0, i32 2>
+; CHECK-NEXT: %shuf = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 1, i32 undef>
+; CHECK-NEXT: %ret = fadd <2 x float> %ins1, %shuf
+; CHECK-NEXT: ret <2 x float> %ret
+define amdgpu_ps <2 x float> @extract_elt0_elt1_elt2_buffer_load_v4f32_4(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 {
+  %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
+  %elt0 = extractelement <4 x float> %data, i32 0
+  %elt2 = extractelement <4 x float> %data, i32 2
+  %ins0 = insertelement <2 x float> undef, float %elt0, i32 0
+  %ins1 = insertelement <2 x float> %ins0, float %elt2, i32 1
+  %shuf = shufflevector <4 x float> undef, <4 x float> %data, <2 x i32> <i32 5, i32 1>
+  %ret = fadd <2 x float> %ins1, %shuf
+  ret <2 x float> %ret
+}
+
+; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_5(
+; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
+; CHECK-NEXT: %ins1 = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 2, i32 2>
+; CHECK-NEXT: %shuf = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: %ret = fadd <2 x float> %ins1, %shuf
+define amdgpu_ps <2 x float> @extract_elt0_elt1_elt2_buffer_load_v4f32_5(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 {
+  %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
+  %elt2 = extractelement <4 x float> %data, i32 2
+  %ins0 = insertelement <2 x float> undef, float %elt2, i32 0
+  %ins1 = insertelement <2 x float> %ins0, float %elt2, i32 1
+  %shuf = shufflevector <4 x float> %data, <4 x float> %data, <2 x i32> <i32 0, i32 5>
+  %ret = fadd <2 x float> %ins1, %shuf
+  ret <2 x float> %ret
+}
+
 ; CHECK-LABEL: @extract_elt0_buffer_load_v3f32(
 ; CHECK-NEXT: %data = call float @llvm.amdgcn.buffer.load.f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
 ; CHECK-NEXT: ret float %data
