[X86][AVX] lowerShuffleAsBroadcast - improve load folding by avoiding bitcasts

AVX1 broadcasts were failing as we were adding bitcasts that caused MayFoldLoad's hasOneUse to return false.

This patch stops introducing bitcasts so early and also replaces the broadcast index scaling through bitcasts (which can't succeed in some cases) to instead just keep track of the bitoffset which can be converted back to the broadcast index later on.

Differential Revision: https://reviews.llvm.org/D58888

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@356043 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index e8d47c6..fd7bfcb 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -11930,6 +11930,7 @@
   // With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise
   // we can only broadcast from a register with AVX2.
   unsigned NumElts = Mask.size();
+  unsigned NumEltBits = VT.getScalarSizeInBits();
   unsigned Opcode = (VT == MVT::v2f64 && !Subtarget.hasAVX2())
                         ? X86ISD::MOVDDUP
                         : X86ISD::VBROADCAST;
@@ -11953,29 +11954,19 @@
 
   // Go up the chain of (vector) values to find a scalar load that we can
   // combine with the broadcast.
+  int BitOffset = BroadcastIdx * NumEltBits;
   SDValue V = V1;
   for (;;) {
     switch (V.getOpcode()) {
     case ISD::BITCAST: {
-      // Peek through bitcasts as long as BroadcastIdx can be adjusted.
-      SDValue VSrc = V.getOperand(0);
-      unsigned NumEltBits = V.getScalarValueSizeInBits();
-      unsigned NumSrcBits = VSrc.getScalarValueSizeInBits();
-      if ((NumEltBits % NumSrcBits) == 0)
-        BroadcastIdx *= (NumEltBits / NumSrcBits);
-      else if ((NumSrcBits % NumEltBits) == 0 &&
-               (BroadcastIdx % (NumSrcBits / NumEltBits)) == 0)
-        BroadcastIdx /= (NumSrcBits / NumEltBits);
-      else
-        break;
-      V = VSrc;
+      V = V.getOperand(0);
       continue;
     }
     case ISD::CONCAT_VECTORS: {
-      int OperandSize =
-          V.getOperand(0).getSimpleValueType().getVectorNumElements();
-      V = V.getOperand(BroadcastIdx / OperandSize);
-      BroadcastIdx %= OperandSize;
+      int OpBitWidth = V.getOperand(0).getValueSizeInBits();
+      int OpIdx = BitOffset / OpBitWidth;
+      V = V.getOperand(OpIdx);
+      BitOffset %= OpBitWidth;
       continue;
     }
     case ISD::INSERT_SUBVECTOR: {
@@ -11984,11 +11975,13 @@
       if (!ConstantIdx)
         break;
 
-      int BeginIdx = (int)ConstantIdx->getZExtValue();
-      int EndIdx =
-          BeginIdx + (int)VInner.getSimpleValueType().getVectorNumElements();
-      if (BroadcastIdx >= BeginIdx && BroadcastIdx < EndIdx) {
-        BroadcastIdx -= BeginIdx;
+      int EltBitWidth = VOuter.getScalarValueSizeInBits();
+      int Idx = (int)ConstantIdx->getZExtValue();
+      int NumSubElts = (int)VInner.getSimpleValueType().getVectorNumElements();
+      int BeginOffset = Idx * EltBitWidth;
+      int EndOffset = BeginOffset + NumSubElts * EltBitWidth;
+      if (BeginOffset <= BitOffset && BitOffset < EndOffset) {
+        BitOffset -= BeginOffset;
         V = VInner;
       } else {
         V = VOuter;
@@ -11998,48 +11991,34 @@
     }
     break;
   }
+  assert((BitOffset % NumEltBits) == 0 && "Illegal bit-offset");
+  BroadcastIdx = BitOffset / NumEltBits;
 
-  // Ensure the source vector and BroadcastIdx are for a suitable type.
-  if (VT.getScalarSizeInBits() != V.getScalarValueSizeInBits()) {
-    unsigned NumEltBits = VT.getScalarSizeInBits();
-    unsigned NumSrcBits = V.getScalarValueSizeInBits();
-    if ((NumSrcBits % NumEltBits) == 0)
-      BroadcastIdx *= (NumSrcBits / NumEltBits);
-    else if ((NumEltBits % NumSrcBits) == 0 &&
-             (BroadcastIdx % (NumEltBits / NumSrcBits)) == 0)
-      BroadcastIdx /= (NumEltBits / NumSrcBits);
-    else
-      return SDValue();
-
-    unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits;
-    MVT SrcVT = MVT::getVectorVT(VT.getScalarType(), NumSrcElts);
-    V = DAG.getBitcast(SrcVT, V);
-  }
+  // Do we need to bitcast the source to retrieve the original broadcast index?
+  bool BitCastSrc = V.getScalarValueSizeInBits() != NumEltBits;
 
   // Check if this is a broadcast of a scalar. We special case lowering
   // for scalars so that we can more effectively fold with loads.
-  // First, look through bitcast: if the original value has a larger element
-  // type than the shuffle, the broadcast element is in essence truncated.
-  // Make that explicit to ease folding.
-  if (V.getOpcode() == ISD::BITCAST && VT.isInteger())
+  // If the original value has a larger element type than the shuffle, the
+  // broadcast element is in essence truncated. Make that explicit to ease
+  // folding.
+  if (BitCastSrc && VT.isInteger())
     if (SDValue TruncBroadcast = lowerShuffleAsTruncBroadcast(
-            DL, VT, V.getOperand(0), BroadcastIdx, Subtarget, DAG))
+            DL, VT, V, BroadcastIdx, Subtarget, DAG))
       return TruncBroadcast;
 
   MVT BroadcastVT = VT;
 
-  // Peek through any bitcast (only useful for loads).
-  SDValue BC = peekThroughBitcasts(V);
-
   // Also check the simpler case, where we can directly reuse the scalar.
-  if ((V.getOpcode() == ISD::BUILD_VECTOR && V.hasOneUse()) ||
-      (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0)) {
+  if (!BitCastSrc &&
+      ((V.getOpcode() == ISD::BUILD_VECTOR && V.hasOneUse()) ||
+       (V.getOpcode() == ISD::SCALAR_TO_VECTOR && BroadcastIdx == 0))) {
     V = V.getOperand(BroadcastIdx);
 
     // If we can't broadcast from a register, check that the input is a load.
     if (!BroadcastFromReg && !isShuffleFoldableLoad(V))
       return SDValue();
-  } else if (MayFoldLoad(BC) && !cast<LoadSDNode>(BC)->isVolatile()) {
+  } else if (MayFoldLoad(V) && !cast<LoadSDNode>(V)->isVolatile()) {
     // 32-bit targets need to load i64 as a f64 and then bitcast the result.
     if (!Subtarget.is64Bit() && VT.getScalarType() == MVT::i64) {
       BroadcastVT = MVT::getVectorVT(MVT::f64, VT.getVectorNumElements());
@@ -12050,10 +12029,11 @@
 
     // If we are broadcasting a load that is only used by the shuffle
     // then we can reduce the vector load to the broadcasted scalar load.
-    LoadSDNode *Ld = cast<LoadSDNode>(BC);
+    LoadSDNode *Ld = cast<LoadSDNode>(V);
     SDValue BaseAddr = Ld->getOperand(1);
     EVT SVT = BroadcastVT.getScalarType();
     unsigned Offset = BroadcastIdx * SVT.getStoreSize();
+    assert((Offset * 8) == BitOffset && "Unexpected bit-offset");
     SDValue NewAddr = DAG.getMemBasePlusOffset(BaseAddr, Offset, DL);
     V = DAG.getLoad(SVT, DL, Ld->getChain(), NewAddr,
                     DAG.getMachineFunction().getMachineMemOperand(
@@ -12062,7 +12042,7 @@
   } else if (!BroadcastFromReg) {
     // We can't broadcast from a vector register.
     return SDValue();
-  } else if (BroadcastIdx != 0) {
+  } else if (BitOffset != 0) {
     // We can only broadcast from the zero-element of a vector register,
     // but it can be advantageous to broadcast from the zero-element of a
     // subvector.
@@ -12074,18 +12054,15 @@
       return SDValue();
 
     // Only broadcast the zero-element of a 128-bit subvector.
-    unsigned EltSize = VT.getScalarSizeInBits();
-    if (((BroadcastIdx * EltSize) % 128) != 0)
+    if ((BitOffset % 128) != 0)
       return SDValue();
 
-    // The shuffle input might have been a bitcast we looked through; look at
-    // the original input vector.  Emit an EXTRACT_SUBVECTOR of that type; we'll
-    // later bitcast it to BroadcastVT.
-    assert(V.getScalarValueSizeInBits() == BroadcastVT.getScalarSizeInBits() &&
-           "Unexpected vector element size");
+    assert((BitOffset % V.getScalarValueSizeInBits()) == 0 &&
+           "Unexpected bit-offset");
     assert((V.getValueSizeInBits() == 256 || V.getValueSizeInBits() == 512) &&
            "Unexpected vector size");
-    V = extract128BitVector(V, BroadcastIdx, DAG, DL);
+    unsigned ExtractIdx = BitOffset / V.getScalarValueSizeInBits();
+    V = extract128BitVector(V, ExtractIdx, DAG, DL);
   }
 
   if (Opcode == X86ISD::MOVDDUP && !V.getValueType().isVector())
@@ -12093,21 +12070,21 @@
                     DAG.getBitcast(MVT::f64, V));
 
   // Bitcast back to the same scalar type as BroadcastVT.
-  MVT SrcVT = V.getSimpleValueType();
-  if (SrcVT.getScalarType() != BroadcastVT.getScalarType()) {
-    assert(SrcVT.getScalarSizeInBits() == BroadcastVT.getScalarSizeInBits() &&
+  if (V.getValueType().getScalarType() != BroadcastVT.getScalarType()) {
+    assert(NumEltBits == BroadcastVT.getScalarSizeInBits() &&
            "Unexpected vector element size");
-    if (SrcVT.isVector()) {
-      unsigned NumSrcElts = SrcVT.getVectorNumElements();
-      SrcVT = MVT::getVectorVT(BroadcastVT.getScalarType(), NumSrcElts);
+    MVT ExtVT;
+    if (V.getValueType().isVector()) {
+      unsigned NumSrcElts = V.getValueSizeInBits() / NumEltBits;
+      ExtVT = MVT::getVectorVT(BroadcastVT.getScalarType(), NumSrcElts);
     } else {
-      SrcVT = BroadcastVT.getScalarType();
+      ExtVT = BroadcastVT.getScalarType();
     }
-    V = DAG.getBitcast(SrcVT, V);
+    V = DAG.getBitcast(ExtVT, V);
   }
 
   // 32-bit targets need to load i64 as a f64 and then bitcast the result.
-  if (!Subtarget.is64Bit() && SrcVT == MVT::i64) {
+  if (!Subtarget.is64Bit() && V.getValueType() == MVT::i64) {
     V = DAG.getBitcast(MVT::f64, V);
     unsigned NumBroadcastElts = BroadcastVT.getVectorNumElements();
     BroadcastVT = MVT::getVectorVT(MVT::f64, NumBroadcastElts);
@@ -12116,9 +12093,9 @@
   // We only support broadcasting from 128-bit vectors to minimize the
   // number of patterns we need to deal with in isel. So extract down to
   // 128-bits, removing as many bitcasts as possible.
-  if (SrcVT.getSizeInBits() > 128) {
-    MVT ExtVT = MVT::getVectorVT(SrcVT.getScalarType(),
-                                 128 / SrcVT.getScalarSizeInBits());
+  if (V.getValueSizeInBits() > 128) {
+    MVT ExtVT = V.getSimpleValueType().getScalarType();
+    ExtVT = MVT::getVectorVT(ExtVT, 128 / ExtVT.getScalarSizeInBits());
     V = extract128BitVector(peekThroughBitcasts(V), 0, DAG, DL);
     V = DAG.getBitcast(ExtVT, V);
   }
diff --git a/test/CodeGen/X86/widened-broadcast.ll b/test/CodeGen/X86/widened-broadcast.ll
index b43c8a4..2ffc413 100644
--- a/test/CodeGen/X86/widened-broadcast.ll
+++ b/test/CodeGen/X86/widened-broadcast.ll
@@ -110,21 +110,10 @@
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_8i32_4i32_01010101:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,1,0,1]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_8i32_4i32_01010101:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_8i32_4i32_01010101:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_8i32_4i32_01010101:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <4 x i32>, <4 x i32>* %ptr
   %ret = shufflevector <4 x i32> %ld, <4 x i32> undef, <8 x i32> <i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1>
@@ -207,21 +196,10 @@
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_16i16_8i16_0101010101010101:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,0,0,0]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_16i16_8i16_0101010101010101:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_16i16_8i16_0101010101010101:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_16i16_8i16_0101010101010101:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastss (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <8 x i16>, <8 x i16>* %ptr
   %ret = shufflevector <8 x i16> %ld, <8 x i16> undef, <16 x i32> <i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1>
@@ -235,21 +213,10 @@
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_16i16_8i16_0123012301230123:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,1,0,1]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_16i16_8i16_0123012301230123:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_16i16_8i16_0123012301230123:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_16i16_8i16_0123012301230123:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <8 x i16>, <8 x i16>* %ptr
   %ret = shufflevector <8 x i16> %ld, <8 x i16> undef, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3,i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
@@ -407,21 +374,10 @@
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,0,0,0]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastss (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_32i8_16i8_01230123012301230123012301230123:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastss (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <16 x i8>, <16 x i8>* %ptr
   %ret = shufflevector <16 x i8> %ld, <16 x i8> undef, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
@@ -435,21 +391,10 @@
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
 ; SSE-NEXT:    retq
 ;
-; AVX1-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
-; AVX1:       # %bb.0: # %entry
-; AVX1-NEXT:    vpermilps {{.*#+}} xmm0 = mem[0,1,0,1]
-; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
-; AVX1-NEXT:    retq
-;
-; AVX2-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
-; AVX2:       # %bb.0: # %entry
-; AVX2-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX2-NEXT:    retq
-;
-; AVX512-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
-; AVX512:       # %bb.0: # %entry
-; AVX512-NEXT:    vbroadcastsd (%rdi), %ymm0
-; AVX512-NEXT:    retq
+; AVX-LABEL: load_splat_32i8_16i8_01234567012345670123456701234567:
+; AVX:       # %bb.0: # %entry
+; AVX-NEXT:    vbroadcastsd (%rdi), %ymm0
+; AVX-NEXT:    retq
 entry:
   %ld = load <16 x i8>, <16 x i8>* %ptr
   %ret = shufflevector <16 x i8> %ld, <16 x i8> undef, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>