[DAGCombiner] Transform (zext (select c, load1, load2)) -> (select c, zextload1, zextload2)
If extload is legal, following transform
(zext (select c, load1, load2)) -> (select c, zextload1, zextload2)
can save one ext instruction.
Differential Revision: https://reviews.llvm.org/D95086
GitOrigin-RevId: 66f2d09ebf8d81d019a5524cdc5e7f88acbb7504
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 737997a..7f3aeeb 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10029,6 +10029,77 @@
return SDValue();
}
+/// Check if N satisfies:
+/// N is used once.
+/// N is a Load.
+/// The load is compatible with ExtOpcode. It means
+/// If load has explicit zero/sign extension, ExpOpcode must have the same
+/// extension.
+/// Otherwise returns true.
+static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
+ if (!N.hasOneUse())
+ return false;
+
+ if (!isa<LoadSDNode>(N))
+ return false;
+
+ LoadSDNode *Load = cast<LoadSDNode>(N);
+ ISD::LoadExtType LoadExt = Load->getExtensionType();
+ if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
+ return true;
+
+ // Now LoadExt is either SEXTLOAD or ZEXTLOAD, ExtOpcode must have the same
+ // extension.
+ if ((LoadExt == ISD::SEXTLOAD && ExtOpcode != ISD::SIGN_EXTEND) ||
+ (LoadExt == ISD::ZEXTLOAD && ExtOpcode != ISD::ZERO_EXTEND))
+ return false;
+
+ return true;
+}
+
+/// Fold
+/// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
+/// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
+/// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
+/// This function is called by the DAGCombiner when visiting sext/zext/aext
+/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
+static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
+ SelectionDAG &DAG) {
+ unsigned Opcode = N->getOpcode();
+ SDValue N0 = N->getOperand(0);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
+ Opcode == ISD::ANY_EXTEND) &&
+ "Expected EXTEND dag node in input!");
+
+ if (!(N0->getOpcode() == ISD::SELECT || N0->getOpcode() == ISD::VSELECT) ||
+ !N0.hasOneUse())
+ return SDValue();
+
+ SDValue Op1 = N0->getOperand(1);
+ SDValue Op2 = N0->getOperand(2);
+ if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
+ return SDValue();
+
+ auto ExtLoadOpcode = ISD::EXTLOAD;
+ if (Opcode == ISD::SIGN_EXTEND)
+ ExtLoadOpcode = ISD::SEXTLOAD;
+ else if (Opcode == ISD::ZERO_EXTEND)
+ ExtLoadOpcode = ISD::ZEXTLOAD;
+
+ LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
+ LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
+ if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
+ !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
+ return SDValue();
+
+ SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
+ SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
+ return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
+}
+
/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
/// a build_vector of constants.
/// This function is called by the DAGCombiner when visiting sext/zext/aext
@@ -10813,6 +10884,9 @@
return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
}
+ if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+ return Res;
+
return SDValue();
}
@@ -11125,6 +11199,9 @@
if (SDValue NewCtPop = widenCtPop(N, DAG))
return NewCtPop;
+ if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+ return Res;
+
return SDValue();
}
@@ -11277,6 +11354,9 @@
if (SDValue NewCtPop = widenCtPop(N, DAG))
return NewCtPop;
+ if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+ return Res;
+
return SDValue();
}
diff --git a/test/CodeGen/X86/select-ext.ll b/test/CodeGen/X86/select-ext.ll
index acbd757..82e79b1 100644
--- a/test/CodeGen/X86/select-ext.ll
+++ b/test/CodeGen/X86/select-ext.ll
@@ -1,15 +1,14 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+sse4.1 | FileCheck %s
-; TODO: (zext(select c, load1, load2)) -> (select c, zextload1, zextload2)
+; (zext(select c, load1, load2)) -> (select c, zextload1, zextload2)
define i64 @zext_scalar(i8* %p, i1 zeroext %c) {
; CHECK-LABEL: zext_scalar:
; CHECK: # %bb.0:
-; CHECK-NEXT: movzbl (%rdi), %eax
-; CHECK-NEXT: movzbl 1(%rdi), %ecx
+; CHECK-NEXT: movzbl (%rdi), %ecx
+; CHECK-NEXT: movzbl 1(%rdi), %eax
; CHECK-NEXT: testl %esi, %esi
-; CHECK-NEXT: cmovel %eax, %ecx
-; CHECK-NEXT: movzbl %cl, %eax
+; CHECK-NEXT: cmoveq %rcx, %rax
; CHECK-NEXT: retq
%ld1 = load volatile i8, i8* %p
%arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1
@@ -22,13 +21,10 @@
define i64 @zext_scalar2(i8* %p, i16* %q, i1 zeroext %c) {
; CHECK-LABEL: zext_scalar2:
; CHECK: # %bb.0:
-; CHECK-NEXT: movzbl (%rdi), %eax
-; CHECK-NEXT: testl %edx, %edx
-; CHECK-NEXT: je .LBB1_2
-; CHECK-NEXT: # %bb.1:
+; CHECK-NEXT: movzbl (%rdi), %ecx
; CHECK-NEXT: movzwl (%rsi), %eax
-; CHECK-NEXT: .LBB1_2:
-; CHECK-NEXT: movzwl %ax, %eax
+; CHECK-NEXT: testl %edx, %edx
+; CHECK-NEXT: cmoveq %rcx, %rax
; CHECK-NEXT: retq
%ld1 = load volatile i8, i8* %p
%ext_ld1 = zext i8 %ld1 to i16
@@ -58,15 +54,14 @@
ret i64 %cond
}
-; TODO: (sext(select c, load1, load2)) -> (select c, sextload1, sextload2)
+; (sext(select c, load1, load2)) -> (select c, sextload1, sextload2)
define i64 @sext_scalar(i8* %p, i1 zeroext %c) {
; CHECK-LABEL: sext_scalar:
; CHECK: # %bb.0:
-; CHECK-NEXT: movzbl (%rdi), %eax
-; CHECK-NEXT: movzbl 1(%rdi), %ecx
+; CHECK-NEXT: movsbq (%rdi), %rcx
+; CHECK-NEXT: movsbq 1(%rdi), %rax
; CHECK-NEXT: testl %esi, %esi
-; CHECK-NEXT: cmovel %eax, %ecx
-; CHECK-NEXT: movsbq %cl, %rax
+; CHECK-NEXT: cmoveq %rcx, %rax
; CHECK-NEXT: retq
%ld1 = load volatile i8, i8* %p
%arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1
@@ -80,14 +75,13 @@
define <2 x i64> @zext_vector_i1(<2 x i32>* %p, i1 zeroext %c) {
; CHECK-LABEL: zext_vector_i1:
; CHECK: # %bb.0:
-; CHECK-NEXT: movq {{.*#+}} xmm1 = mem[0],zero
-; CHECK-NEXT: movq {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT: pmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero
+; CHECK-NEXT: pmovzxdq {{.*#+}} xmm0 = mem[0],zero,mem[1],zero
; CHECK-NEXT: testl %esi, %esi
; CHECK-NEXT: jne .LBB4_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: movdqa %xmm1, %xmm0
; CHECK-NEXT: .LBB4_2:
-; CHECK-NEXT: pmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
; CHECK-NEXT: retq
%ld1 = load volatile <2 x i32>, <2 x i32>* %p
%arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
@@ -100,12 +94,11 @@
define <2 x i64> @zext_vector_v2i1(<2 x i32>* %p, <2 x i1> %c) {
; CHECK-LABEL: zext_vector_v2i1:
; CHECK: # %bb.0:
-; CHECK-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; CHECK-NEXT: pslld $31, %xmm0
-; CHECK-NEXT: movsd {{.*#+}} xmm1 = mem[0],zero
-; CHECK-NEXT: movsd {{.*#+}} xmm2 = mem[0],zero
-; CHECK-NEXT: blendvps %xmm0, %xmm2, %xmm1
-; CHECK-NEXT: pmovzxdq {{.*#+}} xmm0 = xmm1[0],zero,xmm1[1],zero
+; CHECK-NEXT: psllq $63, %xmm0
+; CHECK-NEXT: pmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero
+; CHECK-NEXT: pmovzxdq {{.*#+}} xmm2 = mem[0],zero,mem[1],zero
+; CHECK-NEXT: blendvpd %xmm0, %xmm2, %xmm1
+; CHECK-NEXT: movapd %xmm1, %xmm0
; CHECK-NEXT: retq
%ld1 = load volatile <2 x i32>, <2 x i32>* %p
%arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
@@ -119,14 +112,13 @@
define <2 x i64> @sext_vector_i1(<2 x i32>* %p, i1 zeroext %c) {
; CHECK-LABEL: sext_vector_i1:
; CHECK: # %bb.0:
-; CHECK-NEXT: movq {{.*#+}} xmm1 = mem[0],zero
-; CHECK-NEXT: movq {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT: pmovsxdq (%rdi), %xmm1
+; CHECK-NEXT: pmovsxdq 8(%rdi), %xmm0
; CHECK-NEXT: testl %esi, %esi
; CHECK-NEXT: jne .LBB6_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: movdqa %xmm1, %xmm0
; CHECK-NEXT: .LBB6_2:
-; CHECK-NEXT: pmovsxdq %xmm0, %xmm0
; CHECK-NEXT: retq
%ld1 = load volatile <2 x i32>, <2 x i32>* %p
%arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
@@ -139,12 +131,11 @@
define <2 x i64> @sext_vector_v2i1(<2 x i32>* %p, <2 x i1> %c) {
; CHECK-LABEL: sext_vector_v2i1:
; CHECK: # %bb.0:
-; CHECK-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; CHECK-NEXT: pslld $31, %xmm0
-; CHECK-NEXT: movsd {{.*#+}} xmm1 = mem[0],zero
-; CHECK-NEXT: movsd {{.*#+}} xmm2 = mem[0],zero
-; CHECK-NEXT: blendvps %xmm0, %xmm2, %xmm1
-; CHECK-NEXT: pmovsxdq %xmm1, %xmm0
+; CHECK-NEXT: psllq $63, %xmm0
+; CHECK-NEXT: pmovsxdq (%rdi), %xmm1
+; CHECK-NEXT: pmovsxdq 8(%rdi), %xmm2
+; CHECK-NEXT: blendvpd %xmm0, %xmm2, %xmm1
+; CHECK-NEXT: movapd %xmm1, %xmm0
; CHECK-NEXT: retq
%ld1 = load volatile <2 x i32>, <2 x i32>* %p
%arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1