[X86][AARCH64] Improve ISD::ABS support

This patch takes some of the code from D49837 to allow us to enable ISD::ABS support for all SSE vector types.

Differential Revision: https://reviews.llvm.org/D56544

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@350998 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/include/llvm/CodeGen/TargetLowering.h b/include/llvm/CodeGen/TargetLowering.h
index f18ba19..7cb6341 100644
--- a/include/llvm/CodeGen/TargetLowering.h
+++ b/include/llvm/CodeGen/TargetLowering.h
@@ -3787,6 +3787,14 @@
   /// \returns True, if the expansion was successful, false otherwise
   bool expandCTTZ(SDNode *N, SDValue &Result, SelectionDAG &DAG) const;
 
+  /// Expand ABS nodes. Expands vector/scalar ABS nodes,
+  /// vector nodes can only succeed if all operations are legal/custom.
+  /// (ABS x) -> (XOR (ADD x, (SRA x, type_size)), (SRA x, type_size))
+  /// \param N Node to expand
+  /// \param Result output after conversion
+  /// \returns True, if the expansion was successful, false otherwise
+  bool expandABS(SDNode *N, SDValue &Result, SelectionDAG &DAG) const;
+
   /// Turn load of vector type into a load of the individual elements.
   /// \param LD load to expand
   /// \returns MERGE_VALUEs of the scalar loads with their chains.
diff --git a/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 9d6a693..0e5d924 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -2645,6 +2645,10 @@
   SDValue Tmp1, Tmp2, Tmp3, Tmp4;
   bool NeedInvert;
   switch (Node->getOpcode()) {
+  case ISD::ABS:
+    if (TLI.expandABS(Node, Tmp1, DAG))
+      Results.push_back(Tmp1);
+    break;
   case ISD::CTPOP:
     if (TLI.expandCTPOP(Node, Tmp1, DAG))
       Results.push_back(Tmp1);
diff --git a/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 753158d..6e0bc97 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -117,6 +117,12 @@
   /// the remaining lanes, finally bitcasting to the proper type.
   SDValue ExpandZERO_EXTEND_VECTOR_INREG(SDValue Op);
 
+  /// Implement expand-based legalization of ABS vector operations.
+  /// If following expanding is legal/custom then do it:
+  /// (ABS x) --> (XOR (ADD x, (SRA x, sizeof(x)-1)), (SRA x, sizeof(x)-1))
+  /// else unroll the operation.
+  SDValue ExpandABS(SDValue Op);
+
   /// Expand bswap of vectors into a shuffle if legal.
   SDValue ExpandBSWAP(SDValue Op);
 
@@ -355,6 +361,7 @@
   case ISD::FSHR:
   case ISD::ROTL:
   case ISD::ROTR:
+  case ISD::ABS:
   case ISD::BSWAP:
   case ISD::BITREVERSE:
   case ISD::CTLZ:
@@ -749,6 +756,8 @@
     return ExpandFSUB(Op);
   case ISD::SETCC:
     return UnrollVSETCC(Op);
+  case ISD::ABS:
+    return ExpandABS(Op);
   case ISD::BITREVERSE:
     return ExpandBITREVERSE(Op);
   case ISD::CTPOP:
@@ -1064,6 +1073,16 @@
   return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Val);
 }
 
+SDValue VectorLegalizer::ExpandABS(SDValue Op) {
+  // Attempt to expand using TargetLowering.
+  SDValue Result;
+  if (TLI.expandABS(Op.getNode(), Result, DAG))
+    return Result;
+
+  // Otherwise go ahead and unroll.
+  return DAG.UnrollVectorOp(Op.getNode());
+}
+
 SDValue VectorLegalizer::ExpandFP_TO_UINT(SDValue Op) {
   // Attempt to expand using TargetLowering.
   SDValue Result;
diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index ca0db5c..1d8615c 100644
--- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -4715,6 +4715,26 @@
   return true;
 }
 
+bool TargetLowering::expandABS(SDNode *N, SDValue &Result,
+                               SelectionDAG &DAG) const {
+  SDLoc dl(N);
+  EVT VT = N->getValueType(0);
+  SDValue Op = N->getOperand(0);
+
+  // Only expand vector types if we have the appropriate vector operations.
+  if (VT.isVector() && (!isOperationLegalOrCustom(ISD::SRA, VT) ||
+                        !isOperationLegalOrCustom(ISD::ADD, VT) ||
+                        !isOperationLegalOrCustomOrPromote(ISD::XOR, VT)))
+    return false;
+
+  SDValue Shift =
+      DAG.getNode(ISD::SRA, dl, VT, Op,
+                  DAG.getConstant(VT.getScalarSizeInBits() - 1, dl, VT));
+  SDValue Add = DAG.getNode(ISD::ADD, dl, VT, Op, Shift);
+  Result = DAG.getNode(ISD::XOR, dl, VT, Add, Shift);
+  return true;
+}
+
 SDValue TargetLowering::scalarizeVectorLoad(LoadSDNode *LD,
                                             SelectionDAG &DAG) const {
   SDLoc SL(LD);
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index d4a7a94..4989c61 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -867,6 +867,7 @@
     for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
       setOperationAction(ISD::SETCC,              VT, Custom);
       setOperationAction(ISD::CTPOP,              VT, Custom);
+      setOperationAction(ISD::ABS,                VT, Custom);
 
       // The condition codes aren't legal in SSE/AVX and under AVX512 we use
       // setcc all the way to isel and prefer SETGT in some isel patterns.
@@ -1207,6 +1208,7 @@
     setOperationAction(ISD::MULHU,     MVT::v32i8,  Custom);
     setOperationAction(ISD::MULHS,     MVT::v32i8,  Custom);
 
+    setOperationAction(ISD::ABS,       MVT::v4i64,  Custom);
     setOperationAction(ISD::SMAX,      MVT::v4i64,  Custom);
     setOperationAction(ISD::UMAX,      MVT::v4i64,  Custom);
     setOperationAction(ISD::SMIN,      MVT::v4i64,  Custom);
@@ -23585,7 +23587,8 @@
   return split256IntArith(Op, DAG);
 }
 
-static SDValue LowerABS(SDValue Op, SelectionDAG &DAG) {
+static SDValue LowerABS(SDValue Op, const X86Subtarget &Subtarget,
+                        SelectionDAG &DAG) {
   MVT VT = Op.getSimpleValueType();
   if (VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) {
     // Since X86 does not have CMOV for 8-bit integer, we don't convert
@@ -23599,10 +23602,14 @@
     return DAG.getNode(X86ISD::CMOV, DL, VT, Ops);
   }
 
-  assert(Op.getSimpleValueType().is256BitVector() &&
-         Op.getSimpleValueType().isInteger() &&
-         "Only handle AVX 256-bit vector integer operation");
-  return Lower256IntUnary(Op, DAG);
+  if (VT.is256BitVector() && !Subtarget.hasInt256()) {
+    assert(VT.isInteger() &&
+           "Only handle AVX 256-bit vector integer operation");
+    return Lower256IntUnary(Op, DAG);
+  }
+
+  // Default to expand.
+  return SDValue();
 }
 
 static SDValue LowerMINMAX(SDValue Op, SelectionDAG &DAG) {
@@ -26287,7 +26294,7 @@
   case ISD::SMIN:
   case ISD::UMAX:
   case ISD::UMIN:               return LowerMINMAX(Op, DAG);
-  case ISD::ABS:                return LowerABS(Op, DAG);
+  case ISD::ABS:                return LowerABS(Op, Subtarget, DAG);
   case ISD::FSINCOS:            return LowerFSINCOS(Op, Subtarget, DAG);
   case ISD::MLOAD:              return LowerMLOAD(Op, Subtarget, DAG);
   case ISD::MSTORE:             return LowerMSTORE(Op, Subtarget, DAG);
diff --git a/test/CodeGen/AArch64/arm64-vabs.ll b/test/CodeGen/AArch64/arm64-vabs.ll
index 53669a1..8a0a2dd 100644
--- a/test/CodeGen/AArch64/arm64-vabs.ll
+++ b/test/CodeGen/AArch64/arm64-vabs.ll
@@ -542,7 +542,8 @@
 
 define i64 @abs_1d_honestly(i64 %A) nounwind {
 ; CHECK-LABEL: abs_1d_honestly:
-; CHECK: abs d0, d0
+; CHECK:       cmp x0, #0
+; CHECK-NEXT:  cneg x0, x0, mi
   %abs = call i64 @llvm.aarch64.neon.abs.i64(i64 %A)
   ret i64 %abs
 }
diff --git a/test/CodeGen/X86/combine-abs.ll b/test/CodeGen/X86/combine-abs.ll
index 30bba8d..3a2c477 100644
--- a/test/CodeGen/X86/combine-abs.ll
+++ b/test/CodeGen/X86/combine-abs.ll
@@ -67,9 +67,6 @@
 ; AVX2-LABEL: combine_v4i64_abs_abs:
 ; AVX2:       # %bb.0:
 ; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT:    vpcmpgtq %ymm0, %ymm1, %ymm2
-; AVX2-NEXT:    vpaddq %ymm2, %ymm0, %ymm0
-; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0
 ; AVX2-NEXT:    vpcmpgtq %ymm0, %ymm1, %ymm1
 ; AVX2-NEXT:    vpaddq %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    vpxor %ymm1, %ymm0, %ymm0
diff --git a/test/CodeGen/X86/viabs.ll b/test/CodeGen/X86/viabs.ll
index 7241cd1..06dd168 100644
--- a/test/CodeGen/X86/viabs.ll
+++ b/test/CodeGen/X86/viabs.ll
@@ -583,12 +583,12 @@
 ; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm1
 ; AVX1-NEXT:    vpxor %xmm2, %xmm2, %xmm2
 ; AVX1-NEXT:    vpcmpgtq %xmm1, %xmm2, %xmm3
-; AVX1-NEXT:    vpcmpgtq %xmm0, %xmm2, %xmm2
-; AVX1-NEXT:    vinsertf128 $1, %xmm3, %ymm2, %ymm4
 ; AVX1-NEXT:    vpaddq %xmm3, %xmm1, %xmm1
+; AVX1-NEXT:    vpxor %xmm3, %xmm1, %xmm1
+; AVX1-NEXT:    vpcmpgtq %xmm0, %xmm2, %xmm2
 ; AVX1-NEXT:    vpaddq %xmm2, %xmm0, %xmm0
+; AVX1-NEXT:    vpxor %xmm2, %xmm0, %xmm0
 ; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm0
-; AVX1-NEXT:    vxorps %ymm4, %ymm0, %ymm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: test_abs_gt_v4i64:
@@ -639,20 +639,20 @@
 ; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm2
 ; AVX1-NEXT:    vpxor %xmm3, %xmm3, %xmm3
 ; AVX1-NEXT:    vpcmpgtq %xmm2, %xmm3, %xmm4
-; AVX1-NEXT:    vpcmpgtq %xmm0, %xmm3, %xmm5
-; AVX1-NEXT:    vinsertf128 $1, %xmm4, %ymm5, %ymm6
 ; AVX1-NEXT:    vpaddq %xmm4, %xmm2, %xmm2
-; AVX1-NEXT:    vpaddq %xmm5, %xmm0, %xmm0
+; AVX1-NEXT:    vpxor %xmm4, %xmm2, %xmm2
+; AVX1-NEXT:    vpcmpgtq %xmm0, %xmm3, %xmm4
+; AVX1-NEXT:    vpaddq %xmm4, %xmm0, %xmm0
+; AVX1-NEXT:    vpxor %xmm4, %xmm0, %xmm0
 ; AVX1-NEXT:    vinsertf128 $1, %xmm2, %ymm0, %ymm0
-; AVX1-NEXT:    vxorps %ymm6, %ymm0, %ymm0
 ; AVX1-NEXT:    vextractf128 $1, %ymm1, %xmm2
 ; AVX1-NEXT:    vpcmpgtq %xmm2, %xmm3, %xmm4
-; AVX1-NEXT:    vpcmpgtq %xmm1, %xmm3, %xmm3
-; AVX1-NEXT:    vinsertf128 $1, %xmm4, %ymm3, %ymm5
 ; AVX1-NEXT:    vpaddq %xmm4, %xmm2, %xmm2
+; AVX1-NEXT:    vpxor %xmm4, %xmm2, %xmm2
+; AVX1-NEXT:    vpcmpgtq %xmm1, %xmm3, %xmm3
 ; AVX1-NEXT:    vpaddq %xmm3, %xmm1, %xmm1
+; AVX1-NEXT:    vpxor %xmm3, %xmm1, %xmm1
 ; AVX1-NEXT:    vinsertf128 $1, %xmm2, %ymm1, %ymm1
-; AVX1-NEXT:    vxorps %ymm5, %ymm1, %ymm1
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: test_abs_le_v8i64:
@@ -713,19 +713,19 @@
 ; AVX1-NEXT:    vmovdqu 48(%rdi), %xmm3
 ; AVX1-NEXT:    vpxor %xmm4, %xmm4, %xmm4
 ; AVX1-NEXT:    vpcmpgtq %xmm1, %xmm4, %xmm5
-; AVX1-NEXT:    vpcmpgtq %xmm0, %xmm4, %xmm6
-; AVX1-NEXT:    vinsertf128 $1, %xmm5, %ymm6, %ymm7
 ; AVX1-NEXT:    vpaddq %xmm5, %xmm1, %xmm1
-; AVX1-NEXT:    vpaddq %xmm6, %xmm0, %xmm0
+; AVX1-NEXT:    vpxor %xmm5, %xmm1, %xmm1
+; AVX1-NEXT:    vpcmpgtq %xmm0, %xmm4, %xmm5
+; AVX1-NEXT:    vpaddq %xmm5, %xmm0, %xmm0
+; AVX1-NEXT:    vpxor %xmm5, %xmm0, %xmm0
 ; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm0
-; AVX1-NEXT:    vxorps %ymm7, %ymm0, %ymm0
 ; AVX1-NEXT:    vpcmpgtq %xmm3, %xmm4, %xmm1
-; AVX1-NEXT:    vpcmpgtq %xmm2, %xmm4, %xmm4
-; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm4, %ymm5
-; AVX1-NEXT:    vpaddq %xmm1, %xmm3, %xmm1
-; AVX1-NEXT:    vpaddq %xmm4, %xmm2, %xmm2
+; AVX1-NEXT:    vpaddq %xmm1, %xmm3, %xmm3
+; AVX1-NEXT:    vpxor %xmm1, %xmm3, %xmm1
+; AVX1-NEXT:    vpcmpgtq %xmm2, %xmm4, %xmm3
+; AVX1-NEXT:    vpaddq %xmm3, %xmm2, %xmm2
+; AVX1-NEXT:    vpxor %xmm3, %xmm2, %xmm2
 ; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm2, %ymm1
-; AVX1-NEXT:    vxorps %ymm5, %ymm1, %ymm1
 ; AVX1-NEXT:    retq
 ;
 ; AVX2-LABEL: test_abs_le_v8i64_fold: