[LoopIdiom] Improve code; use SCEVPatternMatch (NFC) (#139540)

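Replace the hand-rolled dyn_cast<SCEVAddRecExpr> / isAffine() / SCEVConstant checks with
SCEVPatternMatch matchers (m_scev_AffineAddRec, m_SCEVConstant, m_scev_APInt,
m_scev_Specific), and use SCEV::isZero() where only a zero check is needed.
No functional change intended.

A representative before/after, taken from the strided-store legality check below
(shown here as a sketch, not the full surrounding context):

    // Before: manual AddRec + constant-stride checks.
    const SCEVAddRecExpr *StoreEv =
        dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
    if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
      return LegalStoreKind::None;
    if (!isa<SCEVConstant>(StoreEv->getOperand(1)))
      return LegalStoreKind::None;

    // After: one pattern match binds the constant stride directly.
    const SCEV *StoreEv = SE->getSCEV(StorePtr);
    const SCEVConstant *Stride;
    if (!match(StoreEv, m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride))) ||
        cast<SCEVAddRecExpr>(StoreEv)->getLoop() != CurLoop)
      return LegalStoreKind::None;

The same shape is applied to the memcpy/memset stride checks and the strlen
pointer-recurrence check further down.
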
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 0eaed2e..0d5e015 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -32,7 +32,6 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -49,6 +48,7 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -91,6 +91,7 @@
 #include <vector>
 
 using namespace llvm;
+using namespace SCEVPatternMatch;
 
 #define DEBUG_TYPE "loop-idiom"
 
@@ -340,9 +341,8 @@
 
   // If this loop executes exactly one time, then it should be peeled, not
   // optimized by this pass.
-  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
-    if (BECst->getAPInt() == 0)
-      return false;
+  if (BECount->isZero())
+    return false;
 
   SmallVector<BasicBlock *, 8> ExitBlocks;
   CurLoop->getUniqueExitBlocks(ExitBlocks);
@@ -453,13 +453,10 @@
   // See if the pointer expression is an AddRec like {base,+,1} on the current
   // loop, which indicates a strided store.  If we have something else, it's a
   // random store we can't handle.
-  const SCEVAddRecExpr *StoreEv =
-      dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
-  if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
-    return LegalStoreKind::None;
-
-  // Check to see if we have a constant stride.
-  if (!isa<SCEVConstant>(StoreEv->getOperand(1)))
+  const SCEV *StoreEv = SE->getSCEV(StorePtr);
+  const SCEVConstant *Stride;
+  if (!match(StoreEv, m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride))) ||
+      cast<SCEVAddRecExpr>(StoreEv)->getLoop() != CurLoop)
     return LegalStoreKind::None;
 
   // See if the store can be turned into a memset.
@@ -494,9 +491,9 @@
   if (HasMemcpy && !DisableLIRP::Memcpy) {
     // Check to see if the stride matches the size of the store.  If so, then we
     // know that every byte is touched in the loop.
-    APInt Stride = getStoreStride(StoreEv);
     unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType());
-    if (StoreSize != Stride && StoreSize != -Stride)
+    APInt StrideAP = Stride->getAPInt();
+    if (StoreSize != StrideAP && StoreSize != -StrideAP)
       return LegalStoreKind::None;
 
     // The store must be feeding a non-volatile load.
@@ -512,13 +509,12 @@
     // See if the pointer expression is an AddRec like {base,+,1} on the current
     // loop, which indicates a strided load.  If we have something else, it's a
     // random load we can't handle.
-    const SCEVAddRecExpr *LoadEv =
-        dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand()));
-    if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
-      return LegalStoreKind::None;
+    const SCEV *LoadEv = SE->getSCEV(LI->getPointerOperand());
 
     // The store and load must share the same stride.
-    if (StoreEv->getOperand(1) != LoadEv->getOperand(1))
+    if (!match(LoadEv,
+               m_scev_AffineAddRec(m_SCEV(), m_scev_Specific(Stride))) ||
+        cast<SCEVAddRecExpr>(LoadEv)->getLoop() != CurLoop)
       return LegalStoreKind::None;
 
     // Success.  This store can be converted into a memcpy.
@@ -805,20 +801,17 @@
 
   // Check if the stride matches the size of the memcpy. If so, then we know
   // that every byte is touched in the loop.
-  const SCEVConstant *ConstStoreStride =
-      dyn_cast<SCEVConstant>(StoreEv->getOperand(1));
-  const SCEVConstant *ConstLoadStride =
-      dyn_cast<SCEVConstant>(LoadEv->getOperand(1));
-  if (!ConstStoreStride || !ConstLoadStride)
+  const APInt *StoreStrideValue, *LoadStrideValue;
+  if (!match(StoreEv->getOperand(1), m_scev_APInt(StoreStrideValue)) ||
+      !match(LoadEv->getOperand(1), m_scev_APInt(LoadStrideValue)))
     return false;
 
-  APInt StoreStrideValue = ConstStoreStride->getAPInt();
-  APInt LoadStrideValue = ConstLoadStride->getAPInt();
   // Huge stride value - give up
-  if (StoreStrideValue.getBitWidth() > 64 || LoadStrideValue.getBitWidth() > 64)
+  if (StoreStrideValue->getBitWidth() > 64 ||
+      LoadStrideValue->getBitWidth() > 64)
     return false;
 
-  if (SizeInBytes != StoreStrideValue && SizeInBytes != -StoreStrideValue) {
+  if (SizeInBytes != *StoreStrideValue && SizeInBytes != -*StoreStrideValue) {
     ORE.emit([&]() {
       return OptimizationRemarkMissed(DEBUG_TYPE, "SizeStrideUnequal", MCI)
              << ore::NV("Inst", "memcpy") << " in "
@@ -829,8 +822,8 @@
     return false;
   }
 
-  int64_t StoreStrideInt = StoreStrideValue.getSExtValue();
-  int64_t LoadStrideInt = LoadStrideValue.getSExtValue();
+  int64_t StoreStrideInt = StoreStrideValue->getSExtValue();
+  int64_t LoadStrideInt = LoadStrideValue->getSExtValue();
   // Check if the load stride matches the store stride.
   if (StoreStrideInt != LoadStrideInt)
     return false;
@@ -857,15 +850,15 @@
   // See if the pointer expression is an AddRec like {base,+,1} on the current
   // loop, which indicates a strided store.  If we have something else, it's a
   // random store we can't handle.
-  const SCEVAddRecExpr *Ev = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Pointer));
-  if (!Ev || Ev->getLoop() != CurLoop)
-    return false;
-  if (!Ev->isAffine()) {
+  const SCEV *Ev = SE->getSCEV(Pointer);
+  const SCEV *PointerStrideSCEV;
+  if (!match(Ev, m_scev_AffineAddRec(m_SCEV(), m_SCEV(PointerStrideSCEV)))) {
     LLVM_DEBUG(dbgs() << "  Pointer is not affine, abort\n");
     return false;
   }
+  if (cast<SCEVAddRecExpr>(Ev)->getLoop() != CurLoop)
+    return false;
 
-  const SCEV *PointerStrideSCEV = Ev->getOperand(1);
   const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
   if (!PointerStrideSCEV || !MemsetSizeSCEV)
     return false;
@@ -879,15 +872,14 @@
     // we know that every byte is touched in the loop.
     LLVM_DEBUG(dbgs() << "  memset size is constant\n");
     uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
-    const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
-    if (!ConstStride)
+    const APInt *Stride;
+    if (!match(PointerStrideSCEV, m_scev_APInt(Stride)))
       return false;
 
-    APInt Stride = ConstStride->getAPInt();
-    if (SizeInBytes != Stride && SizeInBytes != -Stride)
+    if (SizeInBytes != *Stride && SizeInBytes != -*Stride)
       return false;
 
-    IsNegStride = SizeInBytes == -Stride;
+    IsNegStride = SizeInBytes == -*Stride;
   } else {
     // Memset size is non-constant.
     // Check if the pointer stride matches the memset size.
@@ -944,8 +936,9 @@
   SmallPtrSet<Instruction *, 1> MSIs;
   MSIs.insert(MSI);
   return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
-                                 MSI->getDestAlign(), SplatValue, MSI, MSIs, Ev,
-                                 BECount, IsNegStride, /*IsLoopMemset=*/true);
+                                 MSI->getDestAlign(), SplatValue, MSI, MSIs,
+                                 cast<SCEVAddRecExpr>(Ev), BECount, IsNegStride,
+                                 /*IsLoopMemset=*/true);
 }
 
 /// mayLoopAccessLocation - Return true if the specified loop might access the
@@ -963,11 +956,11 @@
 
   // If the loop iterates a fixed number of times, we can refine the access size
   // to be exactly the size of the memset, which is (BECount+1)*StoreSize
-  const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
-  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
-  if (BECst && ConstSize) {
-    std::optional<uint64_t> BEInt = BECst->getAPInt().tryZExtValue();
-    std::optional<uint64_t> SizeInt = ConstSize->getAPInt().tryZExtValue();
+  const APInt *BECst, *ConstSize;
+  if (match(BECount, m_scev_APInt(BECst)) &&
+      match(StoreSizeSCEV, m_scev_APInt(ConstSize))) {
+    std::optional<uint64_t> BEInt = BECst->tryZExtValue();
+    std::optional<uint64_t> SizeInt = ConstSize->tryZExtValue();
     // FIXME: Should this check for overflow?
     if (BEInt && SizeInt)
       AccessSize = LocationSize::precise((*BEInt + 1) * *SizeInt);
@@ -1579,24 +1572,15 @@
     // See if the pointer expression is an AddRec with constant step a of form
     // ({n,+,a}) where a is the width of the char type.
     Value *IncPtr = LoopLoad->getPointerOperand();
-    const SCEVAddRecExpr *LoadEv =
-        dyn_cast<SCEVAddRecExpr>(SE->getSCEV(IncPtr));
-    if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
+    const SCEV *LoadEv = SE->getSCEV(IncPtr);
+    const APInt *Step;
+    if (!match(LoadEv,
+               m_scev_AffineAddRec(m_SCEV(LoadBaseEv), m_scev_APInt(Step))))
       return false;
-    LoadBaseEv = LoadEv->getStart();
 
     LLVM_DEBUG(dbgs() << "pointer load scev: " << *LoadEv << "\n");
 
-    const SCEVConstant *Step =
-        dyn_cast<SCEVConstant>(LoadEv->getStepRecurrence(*SE));
-    if (!Step)
-      return false;
-
-    unsigned StepSize = 0;
-    StepSizeCI = dyn_cast<ConstantInt>(Step->getValue());
-    if (!StepSizeCI)
-      return false;
-    StepSize = StepSizeCI->getZExtValue();
+    unsigned StepSize = Step->getZExtValue();
 
     // Verify that StepSize is consistent with platform char width.
     OpWidth = OperandType->getIntegerBitWidth();
@@ -3277,9 +3261,7 @@
   // Ok, transform appears worthwhile.
   MadeChange = true;
 
-  bool OffsetIsZero = false;
-  if (auto *ExtraOffsetExprC = dyn_cast<SCEVConstant>(ExtraOffsetExpr))
-    OffsetIsZero = ExtraOffsetExprC->isZero();
+  bool OffsetIsZero = ExtraOffsetExpr->isZero();
 
   // Step 1: Compute the loop's final IV value / trip count.