[AArch64][CostModel] Alter sdiv/srem cost where the divisor is constant (#123552)

This patch revises the cost model for sdiv/srem with a constant divisor, drawing inspiration from the udiv/urem patch #122236.

The typical codegen for the different scenarios is described in notes/comments in the code itself (there are too many scenarios to list them all in the patch description).
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 53c6a02..7cec8a1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -18,6 +18,7 @@
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/CodeGen/CostTable.h"
 #include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
@@ -3531,23 +3532,111 @@
   default:
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
                                          Op2Info);
+  case ISD::SREM:
   case ISD::SDIV:
-    if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
-      // On AArch64, scalar signed division by constants power-of-two are
-      // normally expanded to the sequence ADD + CMP + SELECT + SRA.
-      // The OperandValue properties many not be same as that of previous
-      // operation; conservatively assume OP_None.
-      InstructionCost Cost = getArithmeticInstrCost(
-          Instruction::Add, Ty, CostKind,
-          Op1Info.getNoProps(), Op2Info.getNoProps());
-      Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
-                                     Op1Info.getNoProps(), Op2Info.getNoProps());
-      Cost += getArithmeticInstrCost(
-          Instruction::Select, Ty, CostKind,
-          Op1Info.getNoProps(), Op2Info.getNoProps());
-      Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
-                                     Op1Info.getNoProps(), Op2Info.getNoProps());
-      return Cost;
+    /*
+    Notes for sdiv/srem specific costs:
+    1. This only considers the cases where the divisor is constant, uniform and
+    (pow-of-2/non-pow-of-2). Other cases are not important since they either
+    result in some form of (ldr + adrp), corresponding to constant vectors, or
+    scalarization of the division operation.
+    2. Constant divisors, either negative in whole or partially, don't result in
+    significantly different codegen as compared to positive constant divisors.
+    So, we don't consider negative divisors separately.
+    3. If the codegen is significantly different with SVE, it has been indicated
+    using comments at appropriate places.
+
+    sdiv specific cases:
+    -----------------------------------------------------------------------
+    codegen                       | pow-of-2               | Type
+    -----------------------------------------------------------------------
+    add + cmp + csel + asr        | Y                      | i64
+    add + cmp + csel + asr        | Y                      | i32
+    -----------------------------------------------------------------------
+
+    srem specific cases:
+    -----------------------------------------------------------------------
+    codegen                       | pow-of-2               | Type
+    -----------------------------------------------------------------------
+    negs + and + and + csneg      | Y                      | i64
+    negs + and + and + csneg      | Y                      | i32
+    -----------------------------------------------------------------------
+
+    other sdiv/srem cases:
+    -------------------------------------------------------------------------
+    common codegen            | + srem     | + sdiv     | pow-of-2  | Type
+    -------------------------------------------------------------------------
+    smulh + asr + add + add   | -          | -          | N         | i64
+    smull + lsr + add + add   | -          | -          | N         | i32
+    usra                      | and + sub  | sshr       | Y         | <2 x i64>
+    2 * (scalar code)         | -          | -          | N         | <2 x i64>
+    usra                      | bic + sub  | sshr + neg | Y         | <4 x i32>
+    smull2 + smull + uzp2     | mls        | -          | N         | <4 x i32>
+           + sshr  + usra     |            |            |           |
+    -------------------------------------------------------------------------
+    */
+    if (Op2Info.isConstant() && Op2Info.isUniform()) {
+      InstructionCost AddCost =
+          getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
+                                 Op1Info.getNoProps(), Op2Info.getNoProps());
+      InstructionCost AsrCost =
+          getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
+                                 Op1Info.getNoProps(), Op2Info.getNoProps());
+      InstructionCost MulCost =
+          getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
+                                 Op1Info.getNoProps(), Op2Info.getNoProps());
+      // add/cmp/csel/csneg should have similar cost while asr/negs/and should
+      // have similar cost.
+      auto VT = TLI->getValueType(DL, Ty);
+      if (LT.second.isScalarInteger() && VT.getSizeInBits() <= 64) {
+        if (Op2Info.isPowerOf2()) {
+          return ISD == ISD::SDIV ? (3 * AddCost + AsrCost)
+                                  : (3 * AsrCost + AddCost);
+        } else {
+          return MulCost + AsrCost + 2 * AddCost;
+        }
+      } else if (VT.isVector()) {
+        InstructionCost UsraCost = 2 * AsrCost;
+        if (Op2Info.isPowerOf2()) {
+          // Division with scalable types corresponds to native 'asrd'
+          // instruction when SVE is available.
+          // e.g. %1 = sdiv <vscale x 4 x i32> %a, splat (i32 8)
+          if (Ty->isScalableTy() && ST->hasSVE())
+            return 2 * AsrCost;
+          return UsraCost +
+                 (ISD == ISD::SDIV
+                      ? (LT.second.getScalarType() == MVT::i64 ? 1 : 2) *
+                            AsrCost
+                      : 2 * AddCost);
+        } else if (LT.second == MVT::v2i64) {
+          return VT.getVectorNumElements() *
+                 getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind,
+                                        Op1Info.getNoProps(),
+                                        Op2Info.getNoProps());
+        } else {
+          // When SVE is available, we get:
+          // smulh + lsr + add/sub + asr + add/sub.
+          if (Ty->isScalableTy() && ST->hasSVE())
+            return MulCost /*smulh cost*/ + 2 * AddCost + 2 * AsrCost;
+          return 2 * MulCost + AddCost /*uzp2 cost*/ + AsrCost + UsraCost;
+        }
+      }
+    }
+    if (Op2Info.isConstant() && !Op2Info.isUniform() &&
+        LT.second.isFixedLengthVector()) {
+      // FIXME: When the constant vector is non-uniform, this may result in
+      // loading the vector from constant pool or in some cases, may also result
+      // in scalarization. For now, we are approximating this with the
+      // scalarization cost.
+      auto ExtractCost = 2 * getVectorInstrCost(Instruction::ExtractElement, Ty,
+                                                CostKind, -1, nullptr, nullptr);
+      auto InsertCost = getVectorInstrCost(Instruction::InsertElement, Ty,
+                                           CostKind, -1, nullptr, nullptr);
+      unsigned NElts = cast<FixedVectorType>(Ty)->getNumElements();
+      return ExtractCost + InsertCost +
+             NElts * getArithmeticInstrCost(Opcode, Ty->getScalarType(),
+                                            CostKind, Op1Info.getNoProps(),
+                                            Op2Info.getNoProps());
     }
     [[fallthrough]];
   case ISD::UDIV:
@@ -3587,23 +3676,6 @@
                                   AddCost * 2 + ShrCost;
         return DivCost + (ISD == ISD::UREM ? MulCost + AddCost : 0);
       }
-
-      // TODO: Fix SDIV and SREM costs, similar to the above.
-      if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT) &&
-          Op2Info.isUniform() && !VT.isScalableVector()) {
-        // Vector signed division by constant are expanded to the
-        // sequence MULHS + ADD/SUB + SRA + SRL + ADD.
-        InstructionCost MulCost =
-            getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
-                                   Op1Info.getNoProps(), Op2Info.getNoProps());
-        InstructionCost AddCost =
-            getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
-                                   Op1Info.getNoProps(), Op2Info.getNoProps());
-        InstructionCost ShrCost =
-            getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
-                                   Op1Info.getNoProps(), Op2Info.getNoProps());
-        return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
-      }
     }
 
     // div i128's are lowered as libcalls.  Pass nullptr as (u)divti3 calls are