[DAGCombiner] Transform (zext (select c, load1, load2)) -> (select c, zextload1, zextload2)

If the extload is legal for the target, the transform
    (zext (select c, load1, load2)) -> (select c, zextload1, zextload2)
saves one extend instruction.
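
For example, given IR like the zext_scalar test below (a sketch; in the test
the second load actually reads from p+1 rather than a separate pointer %q):

    %ld1 = load volatile i8, i8* %p
    %ld2 = load volatile i8, i8* %q
    %cond = select i1 %c, i8 %ld1, i8 %ld2
    %ext = zext i8 %cond to i64

Previously the selected i8 value had to be re-extended after the cmov (the
old "movzbl %cl, %eax" CHECK line); with the fold, both loads become
zextloads and the select itself runs at i64 width (a single cmoveq), so the
extra extend disappears.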

Differential Revision: https://reviews.llvm.org/D95086

GitOrigin-RevId: 66f2d09ebf8d81d019a5524cdc5e7f88acbb7504
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 737997a..7f3aeeb 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10029,6 +10029,77 @@
   return SDValue();
 }
 
+/// Check if N satisfies:
+///   N is used once.
+///   N is a Load.
+///   The load is compatible with ExtOpcode, i.e.
+///     if the load has an explicit zero/sign extension, ExtOpcode must be the
+///     matching extension;
+///     otherwise any extension opcode is compatible.
+static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
+  if (!N.hasOneUse())
+    return false;
+
+  if (!isa<LoadSDNode>(N))
+    return false;
+
+  LoadSDNode *Load = cast<LoadSDNode>(N);
+  ISD::LoadExtType LoadExt = Load->getExtensionType();
+  if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
+    return true;
+
+  // LoadExt is now either SEXTLOAD or ZEXTLOAD; ExtOpcode must be the
+  // matching extension.
+  if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
+      (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
+    return false;
+
+  return true;
+}
+
+/// Fold
+///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
+///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
+///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
+/// This function is called by the DAGCombiner when visiting sext/zext/aext
+/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
+static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
+                                         SelectionDAG &DAG) {
+  unsigned Opcode = N->getOpcode();
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
+          Opcode == ISD::ANY_EXTEND) &&
+         "Expected EXTEND dag node in input!");
+
+  if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
+      !N0.hasOneUse())
+    return SDValue();
+
+  SDValue Op1 = N0->getOperand(1);
+  SDValue Op2 = N0->getOperand(2);
+  if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
+    return SDValue();
+
+  auto ExtLoadOpcode = ISD::EXTLOAD;
+  if (Opcode == ISD::SIGN_EXTEND)
+    ExtLoadOpcode = ISD::SEXTLOAD;
+  else if (Opcode == ISD::ZERO_EXTEND)
+    ExtLoadOpcode = ISD::ZEXTLOAD;
+
+  LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
+  LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
+  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
+      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
+    return SDValue();
+
+  SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
+  SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
+  return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
+}
+
 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
 /// a build_vector of constants.
 /// This function is called by the DAGCombiner when visiting sext/zext/aext
@@ -10813,6 +10884,9 @@
     return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
   }
 
+  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+    return Res;
+
   return SDValue();
 }
 
@@ -11125,6 +11199,9 @@
   if (SDValue NewCtPop = widenCtPop(N, DAG))
     return NewCtPop;
 
+  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+    return Res;
+
   return SDValue();
 }
 
@@ -11277,6 +11354,9 @@
   if (SDValue NewCtPop = widenCtPop(N, DAG))
     return NewCtPop;
 
+  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+    return Res;
+
   return SDValue();
 }
 
diff --git a/test/CodeGen/X86/select-ext.ll b/test/CodeGen/X86/select-ext.ll
index acbd757..82e79b1 100644
--- a/test/CodeGen/X86/select-ext.ll
+++ b/test/CodeGen/X86/select-ext.ll
@@ -1,15 +1,14 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+sse4.1 | FileCheck %s
 
-; TODO: (zext(select c, load1, load2)) -> (select c, zextload1, zextload2)
+; (zext(select c, load1, load2)) -> (select c, zextload1, zextload2)
 define i64 @zext_scalar(i8* %p, i1 zeroext %c) {
 ; CHECK-LABEL: zext_scalar:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    movzbl (%rdi), %eax
-; CHECK-NEXT:    movzbl 1(%rdi), %ecx
+; CHECK-NEXT:    movzbl (%rdi), %ecx
+; CHECK-NEXT:    movzbl 1(%rdi), %eax
 ; CHECK-NEXT:    testl %esi, %esi
-; CHECK-NEXT:    cmovel %eax, %ecx
-; CHECK-NEXT:    movzbl %cl, %eax
+; CHECK-NEXT:    cmoveq %rcx, %rax
 ; CHECK-NEXT:    retq
   %ld1 = load volatile i8, i8* %p
   %arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1
@@ -22,13 +21,10 @@
 define i64 @zext_scalar2(i8* %p, i16* %q, i1 zeroext %c) {
 ; CHECK-LABEL: zext_scalar2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    movzbl (%rdi), %eax
-; CHECK-NEXT:    testl %edx, %edx
-; CHECK-NEXT:    je .LBB1_2
-; CHECK-NEXT:  # %bb.1:
+; CHECK-NEXT:    movzbl (%rdi), %ecx
 ; CHECK-NEXT:    movzwl (%rsi), %eax
-; CHECK-NEXT:  .LBB1_2:
-; CHECK-NEXT:    movzwl %ax, %eax
+; CHECK-NEXT:    testl %edx, %edx
+; CHECK-NEXT:    cmoveq %rcx, %rax
 ; CHECK-NEXT:    retq
   %ld1 = load volatile i8, i8* %p
   %ext_ld1 = zext i8 %ld1 to i16
@@ -58,15 +54,14 @@
   ret i64 %cond
 }
 
-; TODO: (sext(select c, load1, load2)) -> (select c, sextload1, sextload2)
+; (sext(select c, load1, load2)) -> (select c, sextload1, sextload2)
 define i64 @sext_scalar(i8* %p, i1 zeroext %c) {
 ; CHECK-LABEL: sext_scalar:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    movzbl (%rdi), %eax
-; CHECK-NEXT:    movzbl 1(%rdi), %ecx
+; CHECK-NEXT:    movsbq (%rdi), %rcx
+; CHECK-NEXT:    movsbq 1(%rdi), %rax
 ; CHECK-NEXT:    testl %esi, %esi
-; CHECK-NEXT:    cmovel %eax, %ecx
-; CHECK-NEXT:    movsbq %cl, %rax
+; CHECK-NEXT:    cmoveq %rcx, %rax
 ; CHECK-NEXT:    retq
   %ld1 = load volatile i8, i8* %p
   %arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1
@@ -80,14 +75,13 @@
 define <2 x i64> @zext_vector_i1(<2 x i32>* %p, i1 zeroext %c) {
 ; CHECK-LABEL: zext_vector_i1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    movq {{.*#+}} xmm1 = mem[0],zero
-; CHECK-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    pmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero
+; CHECK-NEXT:    pmovzxdq {{.*#+}} xmm0 = mem[0],zero,mem[1],zero
 ; CHECK-NEXT:    testl %esi, %esi
 ; CHECK-NEXT:    jne .LBB4_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    movdqa %xmm1, %xmm0
 ; CHECK-NEXT:  .LBB4_2:
-; CHECK-NEXT:    pmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
 ; CHECK-NEXT:    retq
   %ld1 = load volatile <2 x i32>, <2 x i32>* %p
   %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
@@ -100,12 +94,11 @@
 define <2 x i64> @zext_vector_v2i1(<2 x i32>* %p, <2 x i1> %c) {
 ; CHECK-LABEL: zext_vector_v2i1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; CHECK-NEXT:    pslld $31, %xmm0
-; CHECK-NEXT:    movsd {{.*#+}} xmm1 = mem[0],zero
-; CHECK-NEXT:    movsd {{.*#+}} xmm2 = mem[0],zero
-; CHECK-NEXT:    blendvps %xmm0, %xmm2, %xmm1
-; CHECK-NEXT:    pmovzxdq {{.*#+}} xmm0 = xmm1[0],zero,xmm1[1],zero
+; CHECK-NEXT:    psllq $63, %xmm0
+; CHECK-NEXT:    pmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero
+; CHECK-NEXT:    pmovzxdq {{.*#+}} xmm2 = mem[0],zero,mem[1],zero
+; CHECK-NEXT:    blendvpd %xmm0, %xmm2, %xmm1
+; CHECK-NEXT:    movapd %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %ld1 = load volatile <2 x i32>, <2 x i32>* %p
   %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
@@ -119,14 +112,13 @@
 define <2 x i64> @sext_vector_i1(<2 x i32>* %p, i1 zeroext %c) {
 ; CHECK-LABEL: sext_vector_i1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    movq {{.*#+}} xmm1 = mem[0],zero
-; CHECK-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    pmovsxdq (%rdi), %xmm1
+; CHECK-NEXT:    pmovsxdq 8(%rdi), %xmm0
 ; CHECK-NEXT:    testl %esi, %esi
 ; CHECK-NEXT:    jne .LBB6_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    movdqa %xmm1, %xmm0
 ; CHECK-NEXT:  .LBB6_2:
-; CHECK-NEXT:    pmovsxdq %xmm0, %xmm0
 ; CHECK-NEXT:    retq
   %ld1 = load volatile <2 x i32>, <2 x i32>* %p
   %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
@@ -139,12 +131,11 @@
 define <2 x i64> @sext_vector_v2i1(<2 x i32>* %p, <2 x i1> %c) {
 ; CHECK-LABEL: sext_vector_v2i1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; CHECK-NEXT:    pslld $31, %xmm0
-; CHECK-NEXT:    movsd {{.*#+}} xmm1 = mem[0],zero
-; CHECK-NEXT:    movsd {{.*#+}} xmm2 = mem[0],zero
-; CHECK-NEXT:    blendvps %xmm0, %xmm2, %xmm1
-; CHECK-NEXT:    pmovsxdq %xmm1, %xmm0
+; CHECK-NEXT:    psllq $63, %xmm0
+; CHECK-NEXT:    pmovsxdq (%rdi), %xmm1
+; CHECK-NEXT:    pmovsxdq 8(%rdi), %xmm2
+; CHECK-NEXT:    blendvpd %xmm0, %xmm2, %xmm1
+; CHECK-NEXT:    movapd %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %ld1 = load volatile <2 x i32>, <2 x i32>* %p
   %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1