[MLIR] Add a getStaticTripCount method to LoopLikeOpInterface (#158679)

This patch adds a `getStaticTripCount` to the LoopLikeOpInterface,
allowing loops to optionally return a static trip count when possible.
This is implemented on SCF ForOp, revamping the implementation of
`constantTripCount`, removing redundant duplicate implementations from
SCF.cpp.

GitOrigin-RevId: 75469bb376b7fa931d03ea95f6b7142d27ef0e51
diff --git a/include/mlir/Dialect/SCF/IR/SCFOps.td b/include/mlir/Dialect/SCF/IR/SCFOps.td
index d3c01c3..fadd3fc 100644
--- a/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -152,7 +152,7 @@
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
         "getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
-        "getLoopUpperBounds", "getYieldedValuesMutable",
+        "getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
         "promoteIfSingleIteration", "replaceWithAdditionalYields",
         "yieldTiledValuesAndReplace"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
diff --git a/include/mlir/Dialect/Utils/StaticValueUtils.h b/include/mlir/Dialect/Utils/StaticValueUtils.h
index 77c376f..2e7f85c 100644
--- a/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -106,6 +106,10 @@
                                                  ArrayRef<int64_t> values);
 
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
+/// The second return value indicates whether the value is an index type
+/// and thus the bitwidth is not defined (the APInt will be set with 64bits).
+std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr);
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
 std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
 /// If all ofrs are constant integers or IntegerAttrs, return the integers.
 std::optional<SmallVector<int64_t>>
@@ -201,9 +205,26 @@
 LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
-/// bound `ub` and step `step`.
-std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
-                                         OpFoldResult step);
+/// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop
+/// comparison between lb and ub is signed or unsigned. A negative step or a
+/// lower bound greater than the upper bound are considered invalid and will
+/// yield a zero trip count.
+/// The `computeUbMinusLb` callback is invoked to compute the difference between
+/// the upper and lower bound when not constant. It can be used by the client
+/// to compute a static difference when the bounds are not constant.
+///
+/// For example, the following code:
+///
+///   %ub = arith.addi nsw %lb, %c16_i32 : i32
+///   %1 = scf.for %arg0 = %lb to %ub ...
+///
+/// where %ub is computed as a static offset from %lb.
+/// Note: the matched addition should be nsw/nuw (matching the loop comparison)
+/// to avoid overflow, otherwise an overflow would imply a zero trip count.
+std::optional<APInt> constantTripCount(
+    OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
+    llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
+        computeUbMinusLb);
 
 /// Idiomatic saturated operations on values like offsets, sizes, and strides.
 struct SaturatedInteger {
diff --git a/include/mlir/Interfaces/LoopLikeInterface.td b/include/mlir/Interfaces/LoopLikeInterface.td
index 6c95b48..cfd15a7 100644
--- a/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/include/mlir/Interfaces/LoopLikeInterface.td
@@ -232,6 +232,17 @@
       /*defaultImplementation=*/[{
         return ::mlir::failure();
       }]
+    >,
+    InterfaceMethod<[{
+        Compute the static trip count if possible.
+      }],
+      /*retTy=*/"::std::optional<APInt>",
+      /*methodName=*/"getStaticTripCount",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::std::nullopt;
+      }]
     >
   ];
 
diff --git a/lib/Dialect/SCF/IR/SCF.cpp b/lib/Dialect/SCF/IR/SCF.cpp
index c35989e..ae55ead 100644
--- a/lib/Dialect/SCF/IR/SCF.cpp
+++ b/lib/Dialect/SCF/IR/SCF.cpp
@@ -19,6 +19,8 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
@@ -26,6 +28,9 @@
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/DebugLog.h"
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::scf;
@@ -105,6 +110,24 @@
   return nullptr;
 }
 
+/// Helper function to compute the difference between two values. This is used
+/// by the loop implementations to compute the trip count.
+static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
+                                                    bool isSigned) {
+  llvm::APSInt diff;
+  auto addOp = ub.getDefiningOp<arith::AddIOp>();
+  if (!addOp)
+    return std::nullopt;
+  if ((isSigned && !addOp.hasNoSignedWrap()) ||
+      (!isSigned && !addOp.hasNoUnsignedWrap()))
+    return std::nullopt;
+
+  if (addOp.getLhs() != lb ||
+      !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
+    return std::nullopt;
+  return diff;
+}
+
 //===----------------------------------------------------------------------===//
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
@@ -408,11 +431,19 @@
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
-  std::optional<int64_t> tripCount =
-      constantTripCount(getLowerBound(), getUpperBound(), getStep());
-  if (!tripCount.has_value() || tripCount != 1)
+  std::optional<APInt> tripCount = getStaticTripCount();
+  LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
+         << " for loop "
+         << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
+  if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
     return failure();
 
+  if (*tripCount == 0) {
+    rewriter.replaceAllUsesWith(getResults(), getInitArgs());
+    rewriter.eraseOp(*this);
+    return success();
+  }
+
   // Replace all results with the yielded values.
   auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
   rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
@@ -646,7 +677,8 @@
 LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
   for (auto [lb, ub, step] :
        llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
-    auto tripCount = constantTripCount(lb, ub, step);
+    auto tripCount =
+        constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
     if (!tripCount.has_value() || *tripCount != 1)
       return failure();
   }
@@ -1003,27 +1035,6 @@
   }
 };
 
-/// Util function that tries to compute a constant diff between u and l.
-/// Returns std::nullopt when the difference between two AffineValueMap is
-/// dynamic.
-static std::optional<APInt> computeConstDiff(Value l, Value u) {
-  IntegerAttr clb, cub;
-  if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
-    llvm::APInt lbValue = clb.getValue();
-    llvm::APInt ubValue = cub.getValue();
-    return ubValue - lbValue;
-  }
-
-  // Else a simple pattern match for x + c or c + x
-  llvm::APInt diff;
-  if (matchPattern(
-          u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
-      matchPattern(
-          u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
-    return diff;
-  return std::nullopt;
-}
-
 /// Rewriting pattern that erases loops that are known not to iterate, replaces
 /// single-iteration loops with their bodies, and removes empty loops that
 /// iterate at least once and only return values defined outside of the loop.
@@ -1032,34 +1043,21 @@
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
-    // If the upper bound is the same as the lower bound, the loop does not
-    // iterate, just remove it.
-    if (op.getLowerBound() == op.getUpperBound()) {
+    std::optional<APInt> tripCount = op.getStaticTripCount();
+    if (!tripCount.has_value())
+      return rewriter.notifyMatchFailure(op,
+                                         "can't compute constant trip count");
+
+    if (tripCount->isZero()) {
+      LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions());
       rewriter.replaceOp(op, op.getInitArgs());
       return success();
     }
 
-    std::optional<APInt> diff =
-        computeConstDiff(op.getLowerBound(), op.getUpperBound());
-    if (!diff)
-      return failure();
-
-    // If the loop is known to have 0 iterations, remove it.
-    bool zeroOrLessIterations =
-        diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
-    if (zeroOrLessIterations) {
-      rewriter.replaceOp(op, op.getInitArgs());
-      return success();
-    }
-
-    std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
-    if (!maybeStepValue)
-      return failure();
-
-    // If the loop is known to have 1 iteration, inline its body and remove the
-    // loop.
-    llvm::APInt stepValue = *maybeStepValue;
-    if (stepValue.sge(*diff)) {
+    if (tripCount->getSExtValue() == 1) {
+      LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
+             << OpWithFlags(op, OpPrintingFlags().skipRegions());
       SmallVector<Value, 4> blockArgs;
       blockArgs.reserve(op.getInitArgs().size() + 1);
       blockArgs.push_back(op.getLowerBound());
@@ -1072,11 +1070,14 @@
     Block &block = op.getRegion().front();
     if (!llvm::hasSingleElement(block))
       return failure();
-    // If the loop is empty, iterates at least once, and only returns values
+    // The loop is empty and iterates at least once, if it only returns values
     // defined outside of the loop, remove it and replace it with yield values.
     if (llvm::any_of(op.getYieldedValues(),
                      [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
       return failure();
+    LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
+              "yield operands for loop "
+           << OpWithFlags(op, OpPrintingFlags().skipRegions());
     rewriter.replaceOp(op, op.getYieldedValues());
     return success();
   }
@@ -1172,6 +1173,11 @@
   return Speculation::NotSpeculatable;
 }
 
+std::optional<APInt> ForOp::getStaticTripCount() {
+  return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
+                           /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
+}
+
 //===----------------------------------------------------------------------===//
 // ForallOp
 //===----------------------------------------------------------------------===//
@@ -1768,7 +1774,8 @@
     for (auto [lb, ub, step, iv] :
          llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
                    op.getMixedStep(), op.getInductionVars())) {
-      auto numIterations = constantTripCount(lb, ub, step);
+      auto numIterations =
+          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
       if (numIterations.has_value()) {
         // Remove the loop if it performs zero iterations.
         if (*numIterations == 0) {
@@ -1839,7 +1846,8 @@
                    op.getMixedStep(), op.getInductionVars())) {
       if (iv.hasNUses(0))
         continue;
-      auto numIterations = constantTripCount(lb, ub, step);
+      auto numIterations =
+          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
       if (!numIterations.has_value() || numIterations.value() != 1) {
         continue;
       }
@@ -3084,7 +3092,8 @@
     for (auto [lb, ub, step, iv] :
          llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
                    op.getInductionVars())) {
-      auto numIterations = constantTripCount(lb, ub, step);
+      auto numIterations =
+          constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
       if (numIterations.has_value()) {
         // Remove the loop if it performs zero iterations.
         if (*numIterations == 0) {
diff --git a/lib/Dialect/SCF/Utils/Utils.cpp b/lib/Dialect/SCF/Utils/Utils.cpp
index fc93cf3..18f139c 100644
--- a/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/lib/Dialect/SCF/Utils/Utils.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/DebugLog.h"
@@ -290,14 +291,6 @@
   return arith::DivUIOp::create(builder, loc, sum, divisor);
 }
 
-/// Returns the trip count of `forOp` if its' low bound, high bound and step are
-/// constants, or optional otherwise. Trip count is computed as
-/// ceilDiv(highBound - lowBound, step).
-static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
-  return constantTripCount(forOp.getLowerBound(), forOp.getUpperBound(),
-                           forOp.getStep());
-}
-
 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -376,7 +369,7 @@
   Value stepUnrolled;
   bool generateEpilogueLoop = true;
 
-  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
+  std::optional<APInt> constTripCount = forOp.getStaticTripCount();
   if (constTripCount) {
     // Constant loop bounds computation.
     int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
@@ -390,7 +383,8 @@
     }
 
     int64_t tripCountEvenMultiple =
-        *constTripCount - (*constTripCount % unrollFactor);
+        constTripCount->getSExtValue() -
+        (constTripCount->getSExtValue() % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
@@ -486,15 +480,15 @@
 /// Unrolls this loop completely.
 LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
   IRRewriter rewriter(forOp.getContext());
-  std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+  std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
   if (!mayBeConstantTripCount.has_value())
     return failure();
-  uint64_t tripCount = *mayBeConstantTripCount;
-  if (tripCount == 0)
+  APInt &tripCount = *mayBeConstantTripCount;
+  if (tripCount.isZero())
     return success();
-  if (tripCount == 1)
+  if (tripCount.getSExtValue() == 1)
     return forOp.promoteIfSingleIteration(rewriter);
-  return loopUnrollByFactor(forOp, tripCount);
+  return loopUnrollByFactor(forOp, tripCount.getSExtValue());
 }
 
 /// Check if bounds of all inner loops are defined outside of `forOp`
@@ -534,18 +528,18 @@
 
   // Currently, only constant trip count that divided by the unroll factor is
   // supported.
-  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  std::optional<APInt> tripCount = forOp.getStaticTripCount();
   if (!tripCount.has_value()) {
     // If the trip count is dynamic, do not unroll & jam.
     LDBG() << "failed to unroll and jam: trip count could not be determined";
     return failure();
   }
-  if (unrollJamFactor > *tripCount) {
+  if (unrollJamFactor > tripCount->getZExtValue()) {
     LDBG() << "unroll and jam factor is greater than trip count, set factor to "
               "trip "
               "count";
-    unrollJamFactor = *tripCount;
-  } else if (*tripCount % unrollJamFactor != 0) {
+    unrollJamFactor = tripCount->getZExtValue();
+  } else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
     LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
               "multiple of unroll jam factor";
     return failure();
diff --git a/lib/Dialect/Utils/StaticValueUtils.cpp b/lib/Dialect/Utils/StaticValueUtils.cpp
index 34385d7..5048b19 100644
--- a/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/MathExtras.h"
 
 namespace mlir {
@@ -112,21 +113,30 @@
 }
 
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
-std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+/// The boolean indicates whether the value is an index type.
+std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.
   if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
-    APSInt intVal;
+    APInt intVal;
     if (matchPattern(val, m_ConstantInt(&intVal)))
-      return intVal.getSExtValue();
+      return std::make_pair(intVal, val.getType().isIndex());
     return std::nullopt;
   }
   // Case 2: Check for IntegerAttr.
   Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
   if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
-    return intAttr.getValue().getSExtValue();
+    return std::make_pair(intAttr.getValue(), intAttr.getType().isIndex());
   return std::nullopt;
 }
 
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
+std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+  std::optional<std::pair<APInt, bool>> apInt = getConstantAPIntValue(ofr);
+  if (!apInt)
+    return std::nullopt;
+  return apInt->first.getSExtValue();
+}
+
 std::optional<SmallVector<int64_t>>
 getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
   bool failed = false;
@@ -264,22 +274,108 @@
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.
-std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
-                                         OpFoldResult step) {
+std::optional<APInt> constantTripCount(
+    OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
+    llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
+        computeUbMinusLb) {
+  // This is the bitwidth used to return 0 when loop does not execute.
+  // We infer it from the type of the bound if it isn't an index type.
+  bool isIndex = true;
+  auto getBitwidth = [&](OpFoldResult ofr) -> int {
+    if (auto attr = dyn_cast<Attribute>(ofr)) {
+      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+        if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) {
+          isIndex = intType.isIndex();
+          return intType.getWidth();
+        }
+      }
+    } else {
+      auto val = cast<Value>(ofr);
+      if (auto intType = dyn_cast<IntegerType>(val.getType())) {
+        isIndex = intType.isIndex();
+        return intType.getWidth();
+      }
+    }
+    return IndexType::kInternalStorageBitWidth;
+  };
+  int bitwidth = getBitwidth(lb);
+  assert(bitwidth == getBitwidth(ub) &&
+         "lb and ub must have the same bitwidth");
   if (lb == ub)
-    return 0;
+    return APInt(bitwidth, 0);
 
-  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
-  if (!lbConstant)
-    return std::nullopt;
-  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
-  if (!ubConstant)
-    return std::nullopt;
-  std::optional<int64_t> stepConstant = getConstantIntValue(step);
-  if (!stepConstant || *stepConstant == 0)
-    return std::nullopt;
+  std::optional<std::pair<APInt, bool>> maybeStepCst =
+      getConstantAPIntValue(step);
 
-  return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
+  if (maybeStepCst) {
+    auto &stepCst = maybeStepCst->first;
+    assert(static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
+           "step must have the same bitwidth as lb and ub");
+    if (stepCst.isZero())
+      return stepCst;
+    if (stepCst.isNegative())
+      return APInt(bitwidth, 0);
+  }
+
+  if (isIndex) {
+    LDBG()
+        << "Computing loop trip count for index type may break with overflow";
+    // TODO: we can't compute the trip count for index type. We should fix this
+    // but too many tests are failing right now.
+    //   return {};
+  }
+
+  /// Compute the difference between the upper and lower bound: either from the
+  /// constant value or using the computeUbMinusLb callback.
+  llvm::APSInt diff;
+  std::optional<std::pair<APInt, bool>> maybeLbCst = getConstantAPIntValue(lb);
+  std::optional<std::pair<APInt, bool>> maybeUbCst = getConstantAPIntValue(ub);
+  if (maybeLbCst) {
+    // If one of the bounds is not a constant, we can't compute the trip count.
+    if (!maybeUbCst)
+      return std::nullopt;
+    APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned);
+    APSInt ubCst(maybeUbCst->first, /*isUnsigned=*/!isSigned);
+    if (!maybeUbCst)
+      return std::nullopt;
+    if (ubCst <= lbCst) {
+      LDBG() << "constantTripCount is 0 because ub <= lb (" << lbCst << "("
+             << lbCst.getBitWidth() << ") <= " << ubCst << "("
+             << ubCst.getBitWidth() << "), "
+             << (isSigned ? "isSigned" : "isUnsigned") << ")";
+      return APInt(bitwidth, 0);
+    }
+    diff = ubCst - lbCst;
+  } else {
+    if (maybeUbCst)
+      return std::nullopt;
+
+    /// Non-constant bound, let's try to compute the difference between the
+    /// upper and lower bound
+    std::optional<llvm::APSInt> maybeDiff =
+        computeUbMinusLb(cast<Value>(lb), cast<Value>(ub), isSigned);
+    if (!maybeDiff)
+      return std::nullopt;
+    diff = *maybeDiff;
+  }
+  LDBG() << "constantTripCount: " << (isSigned ? "isSigned" : "isUnsigned")
+         << ", ub-lb: " << diff << "(" << diff.getBitWidth() << "b)";
+  if (diff.isNegative()) {
+    LDBG() << "constantTripCount is 0 because ub-lb diff is negative";
+    return APInt(bitwidth, 0);
+  }
+  if (!maybeStepCst) {
+    LDBG()
+        << "constantTripCount can't be computed because step is not a constant";
+    return std::nullopt;
+  }
+  auto &stepCst = maybeStepCst->first;
+  llvm::APInt tripCount = diff.sdiv(stepCst);
+  llvm::APInt r = diff.srem(stepCst);
+  if (!r.isZero())
+    tripCount = tripCount + 1;
+  LDBG() << "constantTripCount found: " << tripCount;
+  return tripCount;
 }
 
 bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
diff --git a/test/Dialect/SCF/canonicalize.mlir b/test/Dialect/SCF/canonicalize.mlir
index 4ad2da8..5e89f74 100644
--- a/test/Dialect/SCF/canonicalize.mlir
+++ b/test/Dialect/SCF/canonicalize.mlir
@@ -749,7 +749,7 @@
   // CHECK-NEXT: %[[CST:.*]] = arith.constant 2
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
-  %5 = arith.addi %arg0, %c1 : index
+  %5 = arith.addi %arg0, %c1 overflow<nsw> : index
   // CHECK-NOT: scf.for
   scf.for %arg2 = %arg0 to %5 step %c1 {
     // CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[A0]], %[[CST]]
@@ -1933,8 +1933,9 @@
 
 // -----
 
+// Step 0 is invalid, the loop is eliminated.
 // CHECK-LABEL: func @scf_for_all_step_size_0()
-//       CHECK:   scf.forall (%{{.*}}) = (0) to (1) step (0)
+//       CHECK-NOT:   scf.forall
 func.func @scf_for_all_step_size_0()  {
   %x = arith.constant 0 : index
   scf.forall (%i, %j) = (0, 4) to (1, 5) step (%x, 8) {
diff --git a/test/Dialect/SCF/for-loop-peeling.mlir b/test/Dialect/SCF/for-loop-peeling.mlir
index 03c446c..be58548 100644
--- a/test/Dialect/SCF/for-loop-peeling.mlir
+++ b/test/Dialect/SCF/for-loop-peeling.mlir
@@ -328,10 +328,9 @@
 // -----
 
 // Regression test: Make sure that we do not crash.
-
+// The step is 0, the loop will be eliminated.
 // CHECK-LABEL: func @zero_step(
-//       CHECK:   scf.for
-//       CHECK:   scf.for
+//       CHECK-NOT:   scf.for
 func.func @zero_step(%arg0: memref<i64>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
diff --git a/test/Dialect/SCF/trip_count.mlir b/test/Dialect/SCF/trip_count.mlir
new file mode 100644
index 0000000..54883d7
--- /dev/null
+++ b/test/Dialect/SCF/trip_count.mlir
@@ -0,0 +1,702 @@
+// RUN: mlir-opt %s  -test-scf-for-utils --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @trip_count_index_zero_to_zero(
+func.func @trip_count_index_zero_to_zero(%a : i32, %b : i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: "test.trip-count" = 0
+  %r = scf.for %i = %c0 to %c0 step %c1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_zero_to_zero_step_dyn(
+func.func @trip_count_index_zero_to_zero_step_dyn(%a : i32, %b : i32, %step : index) -> i32 {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: "test.trip-count" = 0
+  %r = scf.for %i = %c0 to %c0 step %step iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_zero_to_zero(
+func.func @trip_count_i32_zero_to_zero(%a : i32, %b : i32) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+
+  // CHECK: "test.trip-count" = 0
+  %r = scf.for %i = %c0 to %c0 step %c1 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r : i32
+}
+
+// -----
+
+
+// CHECK-LABEL: func.func @trip_count_i32_zero_to_zero_step_dyn(
+func.func @trip_count_i32_zero_to_zero_step_dyn(%a : i32, %b : i32, %step : i32) -> i32 {
+  %c0 = arith.constant 0 : i32
+
+  // CHECK: "test.trip-count" = 0
+  %r = scf.for %i = %c0 to %c0 step %step iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_one_to_zero(
+func.func @trip_count_index_one_to_zero(%a : i32, %b : i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // Index type has a unknown bitwidth, we can't compute a loop tripcount
+  // in theory because of overflow concerns.
+  // CHECK: "test.trip-count" = 0
+  %r2 = scf.for %i = %c1 to %c0 step %c1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_one_to_zero(
+func.func @trip_count_i32_one_to_zero(%a : i32, %b : i32) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+
+  // CHECK: "test.trip-count" = 0
+  %r2 = scf.for %i = %c1 to %c0 step %c1 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_one_to_zero_dyn_step(
+func.func @trip_count_i32_one_to_zero_dyn_step(%a : i32, %b : i32, %step : i32) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+
+  // CHECK: "test.trip-count" = 0
+  %r2 = scf.for %i = %c1 to %c0 step %step iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_negative_step(
+func.func @trip_count_index_negative_step(%a : i32, %b : i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c-1 = arith.constant -1 : index
+
+  // Negative step is invalid, loop won't execute.
+  // CHECK: "test.trip-count" = 0
+  %r3 = scf.for %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_negative_step(
+func.func @trip_count_i32_negative_step(%a : i32, %b : i32) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %c-1 = arith.constant -1 : i32
+
+  // Negative step is invalid, loop won't execute.
+  // CHECK: "test.trip-count" = 0
+  %r3 = scf.for %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_negative_step_unsigned_loop(
+func.func @trip_count_index_negative_step_unsigned_loop(%a : i32, %b : i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c-1 = arith.constant -1 : index
+
+  // Negative step is invalid, loop won't execute.
+  // CHECK: "test.trip-count" = 0
+  %r3 = scf.for unsigned %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_negative_step_unsigned_loop(
+func.func @trip_count_i32_negative_step_unsigned_loop(%a : i32, %b : i32) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %c-1 = arith.constant -1 : i32
+
+  // Negative step is invalid, loop won't execute.
+  // CHECK: "test.trip-count" = 0
+  %r3 = scf.for unsigned %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_normal_loop(
+func.func @trip_count_index_normal_loop(%a : i32, %b : i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c10 = arith.constant 10 : index
+
+  // Index type has a unknown bitwidth, we can't compute a loop tripcount
+  // in theory because of overflow concerns.
+  // CHECK: "test.trip-count" = 5
+  %r4 = scf.for %i = %c0 to %c10 step %c2 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r4 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_normal_loop(
+func.func @trip_count_i32_normal_loop(%a : i32, %b : i32) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c2 = arith.constant 2 : i32
+  %c10 = arith.constant 10 : i32
+
+  // Normal loop
+  // CHECK: "test.trip-count" = 5
+  %r4 = scf.for %i = %c0 to %c10 step %c2 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r4 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_signed_crossing_zero(
+func.func @trip_count_index_signed_crossing_zero(%a : i32, %b : i32) -> i32 {
+  %c-1 = arith.constant -1 : index
+  %c1 = arith.constant 1 : index
+
+  // Index type has a unknown bitwidth, we can't compute a loop tripcount
+  // in theory because of overflow concerns.
+  // CHECK: "test.trip-count" = 2
+  %r5 = scf.for %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r5 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_signed_crossing_zero(
+func.func @trip_count_i32_signed_crossing_zero(%a : i32, %b : i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %c1 = arith.constant 1 : i32
+
+  // This loop execute with signed comparison, but not unsigned, because it is crossing 0.
+  // CHECK: "test.trip-count" = 2
+  %r5 = scf.for %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r5 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_unsigned_crossing_zero(
+func.func @trip_count_index_unsigned_crossing_zero(%a : i32, %b : i32) -> i32 {
+  %c-1 = arith.constant -1 : index
+  %c1 = arith.constant 1 : index
+
+  // Index type has a unknown bitwidth, we can't compute a loop tripcount
+  // in theory because of overflow concerns.
+  // CHECK: "test.trip-count" = 0
+  %r6 = scf.for unsigned %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r6 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_unsigned_crossing_zero(
+func.func @trip_count_i32_unsigned_crossing_zero(%a : i32, %b : i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %c1 = arith.constant 1 : i32
+
+  // This loop execute with signed comparison, but not unsigned, because it is crossing 0.
+  // CHECK: "test.trip-count" = 0
+  %r6 = scf.for unsigned %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r6 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_unsigned_crossing_zero_dyn_step(
+func.func @trip_count_i32_unsigned_crossing_zero_dyn_step(%a : i32, %b : i32, %step : i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %c1 = arith.constant 1 : i32
+
+  // This loop execute with signed comparison, but not unsigned, because it is crossing 0.
+  // CHECK: "test.trip-count" = 0
+  %r6 = scf.for unsigned %i = %c-1 to %c1 step %step iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r6 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_negative_bounds_signed(
+func.func @trip_count_index_negative_bounds_signed(%a : i32, %b : i32) -> i32 {
+  %c-10 = arith.constant -10 : index
+  %c-1 = arith.constant -1 : index
+  %c2 = arith.constant 2 : index
+
+  // Index type has a unknown bitwidth, we can't compute a loop tripcount
+  // in theory because of overflow concerns.
+  // CHECK: "test.trip-count" = 5
+  %r7 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r7 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_negative_bounds_signed(
+func.func @trip_count_i32_negative_bounds_signed(%a : i32, %b : i32) -> i32 {
+  %c-10 = arith.constant -10 : i32
+  %c-1 = arith.constant -1 : i32
+  %c2 = arith.constant 2 : i32
+
+  // This loop execute with signed comparison, because both bounds are
+  // negative and there is no crossing of 0 here.
+  // CHECK: "test.trip-count" = 5
+  %r7 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r7 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_negative_bounds_unsigned(
+func.func @trip_count_index_negative_bounds_unsigned(%a : i32, %b : i32) -> i32 {
+  %c-10 = arith.constant -10 : index
+  %c-1 = arith.constant -1 : index
+  %c2 = arith.constant 2 : index
+
+  // Index type has a unknown bitwidth, we can't compute a loop tripcount
+  // in theory because of overflow concerns.
+  // CHECK: "test.trip-count" = 5
+  %r8 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r8 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_negative_bounds_unsigned(
+func.func @trip_count_i32_negative_bounds_unsigned(%a : i32, %b : i32) -> i32 {
+  %c-10 = arith.constant -10 : i32
+  %c-1 = arith.constant -1 : i32
+  %c2 = arith.constant 2 : i32
+
+  // CHECK: "test.trip-count" = 5
+  %r8 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r8 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_overflow_signed(
+func.func @trip_count_index_overflow_signed(%a : i32, %b : i32) -> i32 {
+  %c1 = arith.constant 1 : index
+  %c_max = arith.constant 2147483647 : index   // 2^31 - 1
+  %c_min = arith.constant 2147483648 : index  // -2^31
+
+  // Index type has a unknown bitwidth, we can't compute a loop tripcount
+  // in theory because of overflow concerns.
+  // CHECK: "test.trip-count" = 1
+  %r9 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r9 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_overflow_signed(
+func.func @trip_count_i32_overflow_signed(%a : i32, %b : i32) -> i32 {
+  %c1 = arith.constant 1 : i32
+  %c_max = arith.constant 2147483647 : i32   // 2^31 - 1
+  %c_min = arith.constant 2147483648 : i32  // -2^31
+
+  // This loop crosses the 2^31 threshold, which would overflow a signed 32-bit integer.
+  // CHECK: "test.trip-count" = 0
+  %r9 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r9 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_overflow_signed_dyn_step(
+func.func @trip_count_i32_overflow_signed_dyn_step(%a : i32, %b : i32, %step : i32) -> i32 {
+  %c_max = arith.constant 2147483647 : i32   // 2^31 - 1
+  %c_min = arith.constant 2147483648 : i32  // -2^31
+
+  // This loop crosses the 2^31 threshold, which would overflow a signed 32-bit integer.
+  // CHECK: "test.trip-count" = 0
+  %r9 = scf.for %i = %c_max to %c_min step %step iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r9 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_overflow_unsigned(
+func.func @trip_count_index_overflow_unsigned(%a : i32, %b : i32) -> i32 {
+  %c1 = arith.constant 1 : index
+  %c_max = arith.constant 2147483647 : index   // 2^31 - 1
+  %c_min = arith.constant 2147483648 : index  // -2^31
+
+  // Index type has a unknown bitwidth, we can't compute a loop tripcount
+  // in theory because of overflow concerns.
+  // CHECK: "test.trip-count" = 1
+  %r10 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r10 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_overflow_unsigned(
+func.func @trip_count_i32_overflow_unsigned(%a : i32, %b : i32) -> i32 {
+  %c1 = arith.constant 1 : i32
+  %c_max = arith.constant 2147483647 : i32   // 2^31 - 1
+  %c_min = arith.constant 2147483648 : i32  // -2^31
+
+  // The same loop with unsigned comparison executes normally
+  // CHECK: "test.trip-count" = 1
+  %r10 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i32 {
+    scf.yield %b : i32
+  }
+  return %r10 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_overflow_64bit_signed(
+func.func @trip_count_index_overflow_64bit_signed(%a : i32, %b : i32) -> i32 {
+  %c1 = arith.constant 1 : index
+  %c_max = arith.constant 9223372036854775807 : index   // 2^63 - 1
+  %c_min = arith.constant -9223372036854775808 : index  // -2^63
+
+  // This loop crosses the 2^63 threshold, which would overflow a signed 64-bit integer.
+  // Index type has a unknown bitwidth, we can't compute a loop tripcount.
+  // CHECK: "test.trip-count" = 0
+  %r11 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r11 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i64_overflow_64bit_signed(
+func.func @trip_count_i64_overflow_64bit_signed(%a : i32, %b : i32) -> i32 {
+  %c1 = arith.constant 1 : i64
+  %c_max = arith.constant 9223372036854775807 : i64   // 2^63 - 1
+  %c_min = arith.constant -9223372036854775808 : i64  // -2^63
+
+  // This loop crosses the 2^63 threshold, which would overflow a signed 64-bit integer.
+  // CHECK: "test.trip-count" = 0
+  %r11 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i64 {
+    scf.yield %b : i32
+  }
+  return %r11 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_overflow_64bit_unsigned(
+func.func @trip_count_index_overflow_64bit_unsigned(%a : i32, %b : i32) -> i32 {
+  %c1 = arith.constant 1 : index
+  %c_max = arith.constant 9223372036854775807 : index   // 2^63 - 1
+  %c_min = arith.constant -9223372036854775808 : index  // -2^63
+
+  // Index type has a unknown bitwidth, we can't compute a loop tripcount
+  // in theory because of overflow concerns.
+  // CHECK: "test.trip-count" = 1
+  %r12 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 {
+    scf.yield %b : i32
+  }
+  return %r12 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_overflow_64bit_unsigned(
+func.func @trip_count_i32_overflow_64bit_unsigned(%a : i32, %b : i32) -> i32 {
+  %c1 = arith.constant 1 : i64
+  %c_max = arith.constant 9223372036854775807 : i64   // 2^63 - 1
+  %c_min = arith.constant -9223372036854775808 : i64  // -2^63
+
+  // The same loop with unsigned comparison executes normally
+  // CHECK: "test.trip-count" = 1
+  %r12 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i64 {
+    scf.yield %b : i32
+  }
+  return %r12 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_step_greater_than_iteration(
+func.func @trip_count_step_greater_than_iteration() -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c17_i32 = arith.constant 17 : i32
+  %c16_i32 = arith.constant 16 : i32
+  // CHECK: "test.trip-count" = 1
+  %1 = scf.for %arg0 = %c16_i32 to %c17_i32 step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add(
+func.func @trip_count_arith_add(%lb : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c17_i32 = arith.constant 17 : i32
+  %c16_i32 = arith.constant 16 : i32
+  // Can't compute a trip-count in the absence of overflow flag.
+  // CHECK: "test.trip-count" = "none"
+  %ub = arith.addi %lb, %c16_i32 : i32
+  %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative(
+func.func @trip_count_arith_add_negative(%lb : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c-16_i32 = arith.constant -16 : i32
+  // Can't compute a trip-count in the absence of overflow flag.
+  // CHECK: "test.trip-count" = "none"
+  %ub = arith.addi %lb, %c-16_i32 : i32
+  %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_nsw_loop_signed(
+func.func @trip_count_arith_add_nsw_loop_signed(%lb : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c16_i32 = arith.constant 16 : i32
+  %ub = arith.addi %lb, %c16_i32 overflow<nsw> : i32
+  // CHECK: "test.trip-count" = 4
+  %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nsw_loop_signed(
+func.func @trip_count_arith_add_negative_nsw_loop_signed(%lb : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c-16_i32 = arith.constant -16 : i32
+  %ub = arith.addi %lb, %c-16_i32 overflow<nsw> : i32
+  // CHECK: "test.trip-count" = 0
+  %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nsw_loop_signed_step_dyn(
+func.func @trip_count_arith_add_negative_nsw_loop_signed_step_dyn(%lb : i32, %step : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c-16_i32 = arith.constant -16 : i32
+  %ub = arith.addi %lb, %c-16_i32 overflow<nsw> : i32
+  // CHECK: "test.trip-count" = 0
+  %1 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_nsw_loop_unsigned(
+func.func @trip_count_arith_add_nsw_loop_unsigned(%lb : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c16_i32 = arith.constant 16 : i32
+  // Can't compute a trip-count when the overflow flag mismatches the loop comparison signess
+  // CHECK: "test.trip-count" = "none"
+  %ub = arith.addi %lb, %c16_i32 overflow<nsw> : i32
+  %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nsw_loop_unsigned(
+func.func @trip_count_arith_add_negative_nsw_loop_unsigned(%lb : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c-16_i32 = arith.constant -16 : i32
+  // Can't compute a trip-count when the overflow flag mismatches the loop comparison signess
+  // CHECK: "test.trip-count" = "none"
+  %ub = arith.addi %lb, %c-16_i32 overflow<nsw> : i32
+  %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_nuw_loop_signed(
+func.func @trip_count_arith_add_nuw_loop_signed(%lb : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c16_i32 = arith.constant 16 : i32
+  // Can't compute a trip-count when the overflow flag mismatches the loop comparison signess
+  // CHECK: "test.trip-count" = "none"
+  %ub = arith.addi %lb, %c16_i32 overflow<nuw> : i32
+  %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nuw_loop_signed(
+func.func @trip_count_arith_add_negative_nuw_loop_signed(%lb : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c-16_i32 = arith.constant -16 : i32
+  // Can't compute a trip-count when the overflow flag mismatches the loop comparison signess
+  // CHECK: "test.trip-count" = "none"
+  %ub = arith.addi %lb, %c-16_i32 overflow<nuw> : i32
+  %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_nuw_loop_unsigned(
+func.func @trip_count_arith_add_nuw_loop_unsigned(%lb : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c16_i32 = arith.constant 16 : i32
+  // CHECK: "test.trip-count" = 4
+  %ub = arith.addi %lb, %c16_i32 overflow<nuw> : i32
+  %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nuw_loop_unsigned(
+func.func @trip_count_arith_add_negative_nuw_loop_unsigned(%lb : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c-16_i32 = arith.constant -16 : i32
+  // CHECK: "test.trip-count" = 0
+  %ub = arith.addi %lb, %c-16_i32 overflow<nuw> : i32
+  %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nuw_loop_unsigned_step_dyn(
+func.func @trip_count_arith_add_negative_nuw_loop_unsigned_step_dyn(%lb : i32, %step : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c-16_i32 = arith.constant -16 : i32
+  // CHECK: "test.trip-count" = 0
+  %ub = arith.addi %lb, %c-16_i32 overflow<nuw> : i32
+  %1 = scf.for unsigned %arg0 = %lb to %ub step %step iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_nuw_loop_unsigned_invalid(
+func.func @trip_count_arith_add_nuw_loop_unsigned_invalid(%lb : i32, %other : i32) -> i32 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i32 = arith.constant 4 : i32
+  %c16_i32 = arith.constant 16 : i32
+  // The addition here is not adding from %lb
+  // CHECK: "test.trip-count" = "none"
+  %ub = arith.addi %other, %c16_i32 overflow<nuw> : i32
+  %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32)  : i32 {
+    scf.yield %arg0 : i32
+  }
+  return %1 : i32
+}
\ No newline at end of file
diff --git a/test/lib/Dialect/SCF/TestSCFUtils.cpp b/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 9a394d2..6199cb1 100644
--- a/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -42,6 +42,20 @@
   void runOnOperation() override {
     func::FuncOp func = getOperation();
 
+    // Annotate every loop-like operation with the static trip count.
+    func.walk([&](LoopLikeOpInterface loopOp) {
+      std::optional<APInt> tripCount = loopOp.getStaticTripCount();
+      if (tripCount.has_value())
+        loopOp->setDiscardableAttr(
+            "test.trip-count",
+            IntegerAttr::get(IntegerType::get(&getContext(),
+                                              tripCount.value().getBitWidth()),
+                             tripCount.value().getSExtValue()));
+      else
+        loopOp->setDiscardableAttr("test.trip-count",
+                                   StringAttr::get(&getContext(), "none"));
+    });
+
     if (testReplaceWithNewYields) {
       func.walk([&](scf::ForOp forOp) {
         if (forOp.getNumResults() == 0)