[AArch64-SVE]: Improve cost model for sdiv/udiv/mul 128-bit vector operations

Differential Revision: https://reviews.llvm.org/D132477
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index d1dcbdf..ebceac3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2084,12 +2084,40 @@
     InstructionCost Cost = BaseT::getArithmeticInstrCost(
         Opcode, Ty, CostKind, Op1Info, Op2Info);
     if (Ty->isVectorTy()) {
-      // On AArch64, vector divisions are not supported natively and are
-      // expanded into scalar divisions of each pair of elements.
-      Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty, CostKind,
-                                     Op1Info, Op2Info);
-      Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
-                                     Op1Info, Op2Info);
+      if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) {
+        // With SVE, SDIV/UDIV are legal or custom-lowered, so they cost less.
+        if (isa<FixedVectorType>(Ty) &&
+            cast<FixedVectorType>(Ty)->getPrimitiveSizeInBits().getFixedSize() <
+                128) {
+          EVT VT = TLI->getValueType(DL, Ty);
+          static const CostTblEntry DivTbl[]{
+              {ISD::SDIV, MVT::v2i8, 5},  {ISD::SDIV, MVT::v4i8, 8},
+              {ISD::SDIV, MVT::v8i8, 8},  {ISD::SDIV, MVT::v2i16, 5},
+              {ISD::SDIV, MVT::v4i16, 5}, {ISD::SDIV, MVT::v2i32, 1},
+              {ISD::UDIV, MVT::v2i8, 5},  {ISD::UDIV, MVT::v4i8, 8},
+              {ISD::UDIV, MVT::v8i8, 8},  {ISD::UDIV, MVT::v2i16, 5},
+              {ISD::UDIV, MVT::v4i16, 5}, {ISD::UDIV, MVT::v2i32, 1}};
+
+          const auto *Entry = CostTableLookup(DivTbl, ISD, VT.getSimpleVT());
+          if (nullptr != Entry)
+            return Entry->Cost;
+        }
+        // For 8/16-bit elements, the cost is higher because the type
+        // requires promotion and possibly splitting:
+        if (LT.second.getScalarType() == MVT::i8)
+          Cost *= 8;
+        else if (LT.second.getScalarType() == MVT::i16)
+          Cost *= 4;
+        return Cost;
+      } else {
+        // On AArch64, without SVE, vector divisions are expanded
+        // into scalar divisions of each pair of elements.
+        Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty,
+                                       CostKind, Op1Info, Op2Info);
+        Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
+                                       Op1Info, Op2Info);
+      }
+
       // TODO: if one of the arguments is scalar, then it's not necessary to
       // double the cost of handling the vector elements.
       Cost += Cost;
@@ -2097,16 +2125,23 @@
     return Cost;
   }
   case ISD::MUL:
-    // Since we do not have a MUL.2d instruction, a mul <2 x i64> is expensive
-    // as elements are extracted from the vectors and the muls scalarized.
-    // As getScalarizationOverhead is a bit too pessimistic, we estimate the
-    // cost for a i64 vector directly here, which is:
+    // When SVE is available, then we can lower the v2i64 operation using
+    // the SVE mul instruction, which has a lower cost.
+    if (LT.second == MVT::v2i64 && ST->hasSVE())
+      return LT.first;
+
+    // When SVE is not available, there is no MUL.2d instruction,
+    // which means mul <2 x i64> is expensive as elements are extracted
+    // from the vectors and the muls scalarized.
+    // As getScalarizationOverhead is a bit too pessimistic, we
+    // estimate the cost for a i64 vector directly here, which is:
     // - four 2-cost i64 extracts,
     // - two 2-cost i64 inserts, and
     // - two 1-cost muls.
     // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
     // LT.first = 2 the cost is 28. If both operands are extensions it will not
     // need to scalarize so the cost can be cheaper (smull or umull).
+    // Non-v2i64 types and widening multiplies therefore cost just LT.first.
     if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
       return LT.first;
     return LT.first * 14;