[VectorCombine] try to form a better extractelement

Extracting to the same index that we are going to insert back into
allows forming select ("blend") shuffles and enables further transforms.

Admittedly, this is a quick-fix for a more general problem that I'm
hoping to solve by adding transforms for patterns that start with an
insertelement.

But this might resolve some regressions known to be caused by the
extract-extract transform (although I have not gotten more details on
those yet).

In the motivating case from PR34724:
https://bugs.llvm.org/show_bug.cgi?id=34724

The combination of subsequent instcombine and codegen transforms gets us this improvement:

  vmovshdup	%xmm0, %xmm2    ## xmm2 = xmm0[1,1,3,3]
  vhaddps	%xmm1, %xmm1, %xmm4
  vmovshdup	%xmm1, %xmm3    ## xmm3 = xmm1[1,1,3,3]
  vaddps	%xmm0, %xmm2, %xmm0
  vaddps	%xmm1, %xmm3, %xmm1
  vshufps	$200, %xmm4, %xmm0, %xmm0 ## xmm0 = xmm0[0,2],xmm4[0,3]
  vinsertps	$177, %xmm1, %xmm0, %xmm0 ## xmm0 = zero,xmm0[1,2],xmm1[2]

  -->

  vmovshdup	%xmm0, %xmm2    ## xmm2 = xmm0[1,1,3,3]
  vhaddps	%xmm1, %xmm1, %xmm1
  vaddps	%xmm0, %xmm2, %xmm0
  vshufps	$200, %xmm1, %xmm0, %xmm0 ## xmm0 = xmm0[0,2],xmm1[0,3]

Differential Revision: https://reviews.llvm.org/D76623
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 66095af..444290b 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -52,7 +52,8 @@
 static bool isExtractExtractCheap(Instruction *Ext0, Instruction *Ext1,
                                   unsigned Opcode,
                                   const TargetTransformInfo &TTI,
-                                  Instruction *&ConvertToShuffle) {
+                                  Instruction *&ConvertToShuffle,
+                                  unsigned PreferredExtractIndex) {
   assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
          isa<ConstantInt>(Ext1->getOperand(1)) &&
          "Expected constant extract indexes");
@@ -131,12 +132,17 @@
     NewCost +=
         TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
 
-    // The more expensive extract will be replaced by a shuffle. If the extracts
-    // have the same cost, replace the extract with the higher index.
+    // The more expensive extract will be replaced by a shuffle. If the costs
+    // are equal and there is a preferred extract index, shuffle the opposite
+    // operand. Otherwise, replace the extract with the higher index.
     if (Extract0Cost > Extract1Cost)
       ConvertToShuffle = Ext0;
     else if (Extract1Cost > Extract0Cost)
       ConvertToShuffle = Ext1;
+    else if (PreferredExtractIndex == Ext0Index)
+      ConvertToShuffle = Ext1;
+    else if (PreferredExtractIndex == Ext1Index)
+      ConvertToShuffle = Ext0;
     else
       ConvertToShuffle = Ext0Index > Ext1Index ? Ext0 : Ext1;
   }
@@ -209,8 +215,19 @@
       V0->getType() != V1->getType())
     return false;
 
+  // If the scalar value 'I' is going to be re-inserted into a vector, then try
+  // to create an extract to that same element. The extract/insert can be
+  // reduced to a "select shuffle".
+  // TODO: If we add a larger pattern match that starts from an insert, this
+  //       probably becomes unnecessary.
+  uint64_t InsertIndex = std::numeric_limits<uint64_t>::max();
+  if (I.hasOneUse())
+    match(I.user_back(), m_InsertElement(m_Value(), m_Value(),
+                                         m_ConstantInt(InsertIndex)));
+
   Instruction *ConvertToShuffle;
-  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI, ConvertToShuffle))
+  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), TTI, ConvertToShuffle,
+                            InsertIndex))
     return false;
 
   if (ConvertToShuffle) {
diff --git a/llvm/test/Transforms/VectorCombine/X86/extract-binop.ll b/llvm/test/Transforms/VectorCombine/X86/extract-binop.ll
index cc3d35f..bffd6ab 100644
--- a/llvm/test/Transforms/VectorCombine/X86/extract-binop.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/extract-binop.ll
@@ -418,9 +418,9 @@
 
 define <4 x float> @ins_bo_ext_ext(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: @ins_bo_ext_ext(
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[A:%.*]], <4 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 3, i32 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = fadd <4 x float> [[A]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x float> [[TMP2]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[A:%.*]], <4 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 undef, i32 2>
+; CHECK-NEXT:    [[TMP2:%.*]] = fadd <4 x float> [[TMP1]], [[A]]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x float> [[TMP2]], i64 3
 ; CHECK-NEXT:    [[V3:%.*]] = insertelement <4 x float> [[B:%.*]], float [[TMP3]], i32 3
 ; CHECK-NEXT:    ret <4 x float> [[V3]]
 ;
@@ -431,6 +431,9 @@
   ret <4 x float> %v3
 }
 
+; TODO: This is conservatively left to extract from the lower index value,
+;       but it is likely that extracting from index 3 is the better option.
+
 define <4 x float> @ins_bo_ext_ext_uses(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: @ins_bo_ext_ext_uses(
 ; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[A:%.*]], <4 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 3, i32 undef>
@@ -452,13 +455,13 @@
 ; CHECK-LABEL: @PR34724(
 ; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[A:%.*]], <4 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 3, i32 undef>
 ; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x float> [[B:%.*]], <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
-; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <4 x float> [[B]], <4 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 3, i32 undef>
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <4 x float> [[B]], <4 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 undef, i32 2>
 ; CHECK-NEXT:    [[TMP4:%.*]] = fadd <4 x float> [[A]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <4 x float> [[TMP4]], i32 2
 ; CHECK-NEXT:    [[TMP6:%.*]] = fadd <4 x float> [[B]], [[TMP2]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x float> [[TMP6]], i32 0
-; CHECK-NEXT:    [[TMP8:%.*]] = fadd <4 x float> [[B]], [[TMP3]]
-; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <4 x float> [[TMP8]], i32 2
+; CHECK-NEXT:    [[TMP8:%.*]] = fadd <4 x float> [[TMP3]], [[B]]
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <4 x float> [[TMP8]], i64 3
 ; CHECK-NEXT:    [[V1:%.*]] = insertelement <4 x float> undef, float [[TMP5]], i32 1
 ; CHECK-NEXT:    [[V2:%.*]] = insertelement <4 x float> [[V1]], float [[TMP7]], i32 2
 ; CHECK-NEXT:    [[V3:%.*]] = insertelement <4 x float> [[V2]], float [[TMP9]], i32 3
diff --git a/llvm/test/Transforms/VectorCombine/X86/extract-cmp.ll b/llvm/test/Transforms/VectorCombine/X86/extract-cmp.ll
index 807bb80..6f6f6d0 100644
--- a/llvm/test/Transforms/VectorCombine/X86/extract-cmp.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/extract-cmp.ll
@@ -161,9 +161,9 @@
 ; SSE-NEXT:    ret <4 x i1> [[R]]
 ;
 ; AVX-LABEL: @ins_fcmp_ext_ext(
-; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[A:%.*]], <4 x float> undef, <4 x i32> <i32 undef, i32 2, i32 undef, i32 undef>
-; AVX-NEXT:    [[TMP2:%.*]] = fcmp ugt <4 x float> [[TMP1]], [[A]]
-; AVX-NEXT:    [[TMP3:%.*]] = extractelement <4 x i1> [[TMP2]], i64 1
+; AVX-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[A:%.*]], <4 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 1, i32 undef>
+; AVX-NEXT:    [[TMP2:%.*]] = fcmp ugt <4 x float> [[A]], [[TMP1]]
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <4 x i1> [[TMP2]], i32 2
 ; AVX-NEXT:    [[R:%.*]] = insertelement <4 x i1> [[B:%.*]], i1 [[TMP3]], i32 2
 ; AVX-NEXT:    ret <4 x i1> [[R]]
 ;
@@ -176,9 +176,9 @@
 
 define <4 x i1> @ins_icmp_ext_ext(<4 x i32> %a, <4 x i1> %b) {
 ; CHECK-LABEL: @ins_icmp_ext_ext(
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <4 x i32> <i32 undef, i32 undef, i32 3, i32 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ule <4 x i32> [[A]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i1> [[TMP2]], i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <4 x i32> <i32 undef, i32 undef, i32 undef, i32 2>
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ule <4 x i32> [[TMP1]], [[A]]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i1> [[TMP2]], i64 3
 ; CHECK-NEXT:    [[R:%.*]] = insertelement <4 x i1> [[B:%.*]], i1 [[TMP3]], i32 3
 ; CHECK-NEXT:    ret <4 x i1> [[R]]
 ;