[MLIR] Add a getStaticTripCount method to LoopLikeOpInterface (#158679)
This patch adds a `getStaticTripCount` to the LoopLikeOpInterface,
allowing loops to optionally return a static trip count when possible.
This is implemented on SCF ForOp, revamping the implementation of
`constantTripCount`, removing redundant duplicate implementations from
SCF.cpp.
GitOrigin-RevId: 75469bb376b7fa931d03ea95f6b7142d27ef0e51
diff --git a/include/mlir/Dialect/SCF/IR/SCFOps.td b/include/mlir/Dialect/SCF/IR/SCFOps.td
index d3c01c3..fadd3fc 100644
--- a/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -152,7 +152,7 @@
[AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
- "getLoopUpperBounds", "getYieldedValuesMutable",
+ "getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
diff --git a/include/mlir/Dialect/Utils/StaticValueUtils.h b/include/mlir/Dialect/Utils/StaticValueUtils.h
index 77c376f..2e7f85c 100644
--- a/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -106,6 +106,10 @@
ArrayRef<int64_t> values);
/// If ofr is a constant integer or an IntegerAttr, return the integer.
+/// The second return value indicates whether the value is an index type
+/// and thus the bitwidth is not defined (the APInt will be set with 64bits).
+std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr);
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr);
/// If all ofrs are constant integers or IntegerAttrs, return the integers.
std::optional<SmallVector<int64_t>>
@@ -201,9 +205,26 @@
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
/// Return the number of iterations for a loop with a lower bound `lb`, upper
-/// bound `ub` and step `step`.
-std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
- OpFoldResult step);
+/// bound `ub` and step `step`. The `isSigned` flag indicates whether the loop
+/// comparison between lb and ub is signed or unsigned. A negative step or a
+/// lower bound greater than the upper bound are considered invalid and will
+/// yield a zero trip count.
+/// The `computeUbMinusLb` callback is invoked to compute the difference between
+/// the upper and lower bound when not constant. It can be used by the client
+/// to compute a static difference when the bounds are not constant.
+///
+/// For example, the following code:
+///
+/// %ub = arith.addi nsw %lb, %c16_i32 : i32
+/// %1 = scf.for %arg0 = %lb to %ub ...
+///
+/// where %ub is computed as a static offset from %lb.
+/// Note: the matched addition should be nsw/nuw (matching the loop comparison)
+/// to avoid overflow, otherwise an overflow would imply a zero trip count.
+std::optional<APInt> constantTripCount(
+ OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
+ llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
+ computeUbMinusLb);
/// Idiomatic saturated operations on values like offsets, sizes, and strides.
struct SaturatedInteger {
diff --git a/include/mlir/Interfaces/LoopLikeInterface.td b/include/mlir/Interfaces/LoopLikeInterface.td
index 6c95b48..cfd15a7 100644
--- a/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/include/mlir/Interfaces/LoopLikeInterface.td
@@ -232,6 +232,17 @@
/*defaultImplementation=*/[{
return ::mlir::failure();
}]
+ >,
+ InterfaceMethod<[{
+ Compute the static trip count if possible.
+ }],
+ /*retTy=*/"::std::optional<APInt>",
+ /*methodName=*/"getStaticTripCount",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::std::nullopt;
+ }]
>
];
diff --git a/lib/Dialect/SCF/IR/SCF.cpp b/lib/Dialect/SCF/IR/SCF.cpp
index c35989e..ae55ead 100644
--- a/lib/Dialect/SCF/IR/SCF.cpp
+++ b/lib/Dialect/SCF/IR/SCF.cpp
@@ -19,6 +19,8 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
@@ -26,6 +28,9 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/DebugLog.h"
+#include <optional>
using namespace mlir;
using namespace mlir::scf;
@@ -105,6 +110,24 @@
return nullptr;
}
+/// Helper function to compute the difference between two values. This is used
+/// by the loop implementations to compute the trip count.
+static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
+ bool isSigned) {
+ llvm::APSInt diff;
+ auto addOp = ub.getDefiningOp<arith::AddIOp>();
+ if (!addOp)
+ return std::nullopt;
+ if ((isSigned && !addOp.hasNoSignedWrap()) ||
+ (!isSigned && !addOp.hasNoUnsignedWrap()))
+ return std::nullopt;
+
+ if (addOp.getLhs() != lb ||
+ !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
+ return std::nullopt;
+ return diff;
+}
+
//===----------------------------------------------------------------------===//
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
@@ -408,11 +431,19 @@
/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
- std::optional<int64_t> tripCount =
- constantTripCount(getLowerBound(), getUpperBound(), getStep());
- if (!tripCount.has_value() || tripCount != 1)
+ std::optional<APInt> tripCount = getStaticTripCount();
+ LDBG() << "promoteIfSingleIteration tripCount is " << tripCount
+ << " for loop "
+ << OpWithFlags(getOperation(), OpPrintingFlags().skipRegions());
+ if (!tripCount.has_value() || tripCount->getSExtValue() > 1)
return failure();
+ if (*tripCount == 0) {
+ rewriter.replaceAllUsesWith(getResults(), getInitArgs());
+ rewriter.eraseOp(*this);
+ return success();
+ }
+
// Replace all results with the yielded values.
auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
@@ -646,7 +677,8 @@
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
for (auto [lb, ub, step] :
llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) {
- auto tripCount = constantTripCount(lb, ub, step);
+ auto tripCount =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (!tripCount.has_value() || *tripCount != 1)
return failure();
}
@@ -1003,27 +1035,6 @@
}
};
-/// Util function that tries to compute a constant diff between u and l.
-/// Returns std::nullopt when the difference between two AffineValueMap is
-/// dynamic.
-static std::optional<APInt> computeConstDiff(Value l, Value u) {
- IntegerAttr clb, cub;
- if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
- llvm::APInt lbValue = clb.getValue();
- llvm::APInt ubValue = cub.getValue();
- return ubValue - lbValue;
- }
-
- // Else a simple pattern match for x + c or c + x
- llvm::APInt diff;
- if (matchPattern(
- u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
- matchPattern(
- u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
- return diff;
- return std::nullopt;
-}
-
/// Rewriting pattern that erases loops that are known not to iterate, replaces
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
@@ -1032,34 +1043,21 @@
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
- // If the upper bound is the same as the lower bound, the loop does not
- // iterate, just remove it.
- if (op.getLowerBound() == op.getUpperBound()) {
+ std::optional<APInt> tripCount = op.getStaticTripCount();
+ if (!tripCount.has_value())
+ return rewriter.notifyMatchFailure(op,
+ "can't compute constant trip count");
+
+ if (tripCount->isZero()) {
+ LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
rewriter.replaceOp(op, op.getInitArgs());
return success();
}
- std::optional<APInt> diff =
- computeConstDiff(op.getLowerBound(), op.getUpperBound());
- if (!diff)
- return failure();
-
- // If the loop is known to have 0 iterations, remove it.
- bool zeroOrLessIterations =
- diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
- if (zeroOrLessIterations) {
- rewriter.replaceOp(op, op.getInitArgs());
- return success();
- }
-
- std::optional<llvm::APInt> maybeStepValue = op.getConstantStep();
- if (!maybeStepValue)
- return failure();
-
- // If the loop is known to have 1 iteration, inline its body and remove the
- // loop.
- llvm::APInt stepValue = *maybeStepValue;
- if (stepValue.sge(*diff)) {
+ if (tripCount->getSExtValue() == 1) {
+ LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
SmallVector<Value, 4> blockArgs;
blockArgs.reserve(op.getInitArgs().size() + 1);
blockArgs.push_back(op.getLowerBound());
@@ -1072,11 +1070,14 @@
Block &block = op.getRegion().front();
if (!llvm::hasSingleElement(block))
return failure();
- // If the loop is empty, iterates at least once, and only returns values
+ // The loop is empty and iterates at least once, if it only returns values
// defined outside of the loop, remove it and replace it with yield values.
if (llvm::any_of(op.getYieldedValues(),
[&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
return failure();
+ LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
+ "yield operands for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
rewriter.replaceOp(op, op.getYieldedValues());
return success();
}
@@ -1172,6 +1173,11 @@
return Speculation::NotSpeculatable;
}
+std::optional<APInt> ForOp::getStaticTripCount() {
+ return constantTripCount(getLowerBound(), getUpperBound(), getStep(),
+ /*isSigned=*/!getUnsignedCmp(), computeUbMinusLb);
+}
+
//===----------------------------------------------------------------------===//
// ForallOp
//===----------------------------------------------------------------------===//
@@ -1768,7 +1774,8 @@
for (auto [lb, ub, step, iv] :
llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
op.getMixedStep(), op.getInductionVars())) {
- auto numIterations = constantTripCount(lb, ub, step);
+ auto numIterations =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (numIterations.has_value()) {
// Remove the loop if it performs zero iterations.
if (*numIterations == 0) {
@@ -1839,7 +1846,8 @@
op.getMixedStep(), op.getInductionVars())) {
if (iv.hasNUses(0))
continue;
- auto numIterations = constantTripCount(lb, ub, step);
+ auto numIterations =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (!numIterations.has_value() || numIterations.value() != 1) {
continue;
}
@@ -3084,7 +3092,8 @@
for (auto [lb, ub, step, iv] :
llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
op.getInductionVars())) {
- auto numIterations = constantTripCount(lb, ub, step);
+ auto numIterations =
+ constantTripCount(lb, ub, step, /*isSigned=*/true, computeUbMinusLb);
if (numIterations.has_value()) {
// Remove the loop if it performs zero iterations.
if (*numIterations == 0) {
diff --git a/lib/Dialect/SCF/Utils/Utils.cpp b/lib/Dialect/SCF/Utils/Utils.cpp
index fc93cf3..18f139c 100644
--- a/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/lib/Dialect/SCF/Utils/Utils.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
@@ -290,14 +291,6 @@
return arith::DivUIOp::create(builder, loc, sum, divisor);
}
-/// Returns the trip count of `forOp` if its' low bound, high bound and step are
-/// constants, or optional otherwise. Trip count is computed as
-/// ceilDiv(highBound - lowBound, step).
-static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
- return constantTripCount(forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep());
-}
-
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -376,7 +369,7 @@
Value stepUnrolled;
bool generateEpilogueLoop = true;
- std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
+ std::optional<APInt> constTripCount = forOp.getStaticTripCount();
if (constTripCount) {
// Constant loop bounds computation.
int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
@@ -390,7 +383,8 @@
}
int64_t tripCountEvenMultiple =
- *constTripCount - (*constTripCount % unrollFactor);
+ constTripCount->getSExtValue() -
+ (constTripCount->getSExtValue() % unrollFactor);
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
int64_t stepUnrolledCst = stepCst * unrollFactor;
@@ -486,15 +480,15 @@
/// Unrolls this loop completely.
LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
IRRewriter rewriter(forOp.getContext());
- std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+ std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
if (!mayBeConstantTripCount.has_value())
return failure();
- uint64_t tripCount = *mayBeConstantTripCount;
- if (tripCount == 0)
+ APInt &tripCount = *mayBeConstantTripCount;
+ if (tripCount.isZero())
return success();
- if (tripCount == 1)
+ if (tripCount.getSExtValue() == 1)
return forOp.promoteIfSingleIteration(rewriter);
- return loopUnrollByFactor(forOp, tripCount);
+ return loopUnrollByFactor(forOp, tripCount.getSExtValue());
}
/// Check if bounds of all inner loops are defined outside of `forOp`
@@ -534,18 +528,18 @@
// Currently, only constant trip count that divided by the unroll factor is
// supported.
- std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+ std::optional<APInt> tripCount = forOp.getStaticTripCount();
if (!tripCount.has_value()) {
// If the trip count is dynamic, do not unroll & jam.
LDBG() << "failed to unroll and jam: trip count could not be determined";
return failure();
}
- if (unrollJamFactor > *tripCount) {
+ if (unrollJamFactor > tripCount->getZExtValue()) {
LDBG() << "unroll and jam factor is greater than trip count, set factor to "
"trip "
"count";
- unrollJamFactor = *tripCount;
- } else if (*tripCount % unrollJamFactor != 0) {
+ unrollJamFactor = tripCount->getZExtValue();
+ } else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
LDBG() << "failed to unroll and jam: unsupported trip count that is not a "
"multiple of unroll jam factor";
return failure();
diff --git a/lib/Dialect/Utils/StaticValueUtils.cpp b/lib/Dialect/Utils/StaticValueUtils.cpp
index 34385d7..5048b19 100644
--- a/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -11,6 +11,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
@@ -112,21 +113,30 @@
}
/// If ofr is a constant integer or an IntegerAttr, return the integer.
-std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+/// The boolean indicates whether the value is an index type.
+std::optional<std::pair<APInt, bool>> getConstantAPIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
- APSInt intVal;
+ APInt intVal;
if (matchPattern(val, m_ConstantInt(&intVal)))
- return intVal.getSExtValue();
+ return std::make_pair(intVal, val.getType().isIndex());
return std::nullopt;
}
// Case 2: Check for IntegerAttr.
Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
- return intAttr.getValue().getSExtValue();
+ return std::make_pair(intAttr.getValue(), intAttr.getType().isIndex());
return std::nullopt;
}
+/// If ofr is a constant integer or an IntegerAttr, return the integer.
+std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
+ std::optional<std::pair<APInt, bool>> apInt = getConstantAPIntValue(ofr);
+ if (!apInt)
+ return std::nullopt;
+ return apInt->first.getSExtValue();
+}
+
std::optional<SmallVector<int64_t>>
getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
bool failed = false;
@@ -264,22 +274,108 @@
/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`.
-std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
- OpFoldResult step) {
+std::optional<APInt> constantTripCount(
+ OpFoldResult lb, OpFoldResult ub, OpFoldResult step, bool isSigned,
+ llvm::function_ref<std::optional<llvm::APSInt>(Value, Value, bool)>
+ computeUbMinusLb) {
+ // This is the bitwidth used to return 0 when loop does not execute.
+ // We infer it from the type of the bound if it isn't an index type.
+ bool isIndex = true;
+ auto getBitwidth = [&](OpFoldResult ofr) -> int {
+ if (auto attr = dyn_cast<Attribute>(ofr)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) {
+ isIndex = intType.isIndex();
+ return intType.getWidth();
+ }
+ }
+ } else {
+ auto val = cast<Value>(ofr);
+ if (auto intType = dyn_cast<IntegerType>(val.getType())) {
+ isIndex = intType.isIndex();
+ return intType.getWidth();
+ }
+ }
+ return IndexType::kInternalStorageBitWidth;
+ };
+ int bitwidth = getBitwidth(lb);
+ assert(bitwidth == getBitwidth(ub) &&
+ "lb and ub must have the same bitwidth");
if (lb == ub)
- return 0;
+ return APInt(bitwidth, 0);
- std::optional<int64_t> lbConstant = getConstantIntValue(lb);
- if (!lbConstant)
- return std::nullopt;
- std::optional<int64_t> ubConstant = getConstantIntValue(ub);
- if (!ubConstant)
- return std::nullopt;
- std::optional<int64_t> stepConstant = getConstantIntValue(step);
- if (!stepConstant || *stepConstant == 0)
- return std::nullopt;
+ std::optional<std::pair<APInt, bool>> maybeStepCst =
+ getConstantAPIntValue(step);
- return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
+ if (maybeStepCst) {
+ auto &stepCst = maybeStepCst->first;
+ assert(static_cast<int>(stepCst.getBitWidth()) == bitwidth &&
+ "step must have the same bitwidth as lb and ub");
+ if (stepCst.isZero())
+ return stepCst;
+ if (stepCst.isNegative())
+ return APInt(bitwidth, 0);
+ }
+
+ if (isIndex) {
+ LDBG()
+ << "Computing loop trip count for index type may break with overflow";
+ // TODO: we can't compute the trip count for index type. We should fix this
+ // but too many tests are failing right now.
+ // return {};
+ }
+
+ /// Compute the difference between the upper and lower bound: either from the
+ /// constant value or using the computeUbMinusLb callback.
+ llvm::APSInt diff;
+ std::optional<std::pair<APInt, bool>> maybeLbCst = getConstantAPIntValue(lb);
+ std::optional<std::pair<APInt, bool>> maybeUbCst = getConstantAPIntValue(ub);
+ if (maybeLbCst) {
+ // If one of the bounds is not a constant, we can't compute the trip count.
+ if (!maybeUbCst)
+ return std::nullopt;
+ APSInt lbCst(maybeLbCst->first, /*isUnsigned=*/!isSigned);
+ APSInt ubCst(maybeUbCst->first, /*isUnsigned=*/!isSigned);
+ if (!maybeUbCst)
+ return std::nullopt;
+ if (ubCst <= lbCst) {
+ LDBG() << "constantTripCount is 0 because ub <= lb (" << lbCst << "("
+ << lbCst.getBitWidth() << ") <= " << ubCst << "("
+ << ubCst.getBitWidth() << "), "
+ << (isSigned ? "isSigned" : "isUnsigned") << ")";
+ return APInt(bitwidth, 0);
+ }
+ diff = ubCst - lbCst;
+ } else {
+ if (maybeUbCst)
+ return std::nullopt;
+
+ /// Non-constant bound, let's try to compute the difference between the
+ /// upper and lower bound
+ std::optional<llvm::APSInt> maybeDiff =
+ computeUbMinusLb(cast<Value>(lb), cast<Value>(ub), isSigned);
+ if (!maybeDiff)
+ return std::nullopt;
+ diff = *maybeDiff;
+ }
+ LDBG() << "constantTripCount: " << (isSigned ? "isSigned" : "isUnsigned")
+ << ", ub-lb: " << diff << "(" << diff.getBitWidth() << "b)";
+ if (diff.isNegative()) {
+ LDBG() << "constantTripCount is 0 because ub-lb diff is negative";
+ return APInt(bitwidth, 0);
+ }
+ if (!maybeStepCst) {
+ LDBG()
+ << "constantTripCount can't be computed because step is not a constant";
+ return std::nullopt;
+ }
+ auto &stepCst = maybeStepCst->first;
+ llvm::APInt tripCount = diff.sdiv(stepCst);
+ llvm::APInt r = diff.srem(stepCst);
+ if (!r.isZero())
+ tripCount = tripCount + 1;
+ LDBG() << "constantTripCount found: " << tripCount;
+ return tripCount;
}
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
diff --git a/test/Dialect/SCF/canonicalize.mlir b/test/Dialect/SCF/canonicalize.mlir
index 4ad2da8..5e89f74 100644
--- a/test/Dialect/SCF/canonicalize.mlir
+++ b/test/Dialect/SCF/canonicalize.mlir
@@ -749,7 +749,7 @@
// CHECK-NEXT: %[[CST:.*]] = arith.constant 2
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
- %5 = arith.addi %arg0, %c1 : index
+ %5 = arith.addi %arg0, %c1 overflow<nsw> : index
// CHECK-NOT: scf.for
scf.for %arg2 = %arg0 to %5 step %c1 {
// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[A0]], %[[CST]]
@@ -1933,8 +1933,9 @@
// -----
+// Step 0 is invalid, the loop is eliminated.
// CHECK-LABEL: func @scf_for_all_step_size_0()
-// CHECK: scf.forall (%{{.*}}) = (0) to (1) step (0)
+// CHECK-NOT: scf.forall
func.func @scf_for_all_step_size_0() {
%x = arith.constant 0 : index
scf.forall (%i, %j) = (0, 4) to (1, 5) step (%x, 8) {
diff --git a/test/Dialect/SCF/for-loop-peeling.mlir b/test/Dialect/SCF/for-loop-peeling.mlir
index 03c446c..be58548 100644
--- a/test/Dialect/SCF/for-loop-peeling.mlir
+++ b/test/Dialect/SCF/for-loop-peeling.mlir
@@ -328,10 +328,9 @@
// -----
// Regression test: Make sure that we do not crash.
-
+// The step is 0, the loop will be eliminated.
// CHECK-LABEL: func @zero_step(
-// CHECK: scf.for
-// CHECK: scf.for
+// CHECK-NOT: scf.for
func.func @zero_step(%arg0: memref<i64>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
diff --git a/test/Dialect/SCF/trip_count.mlir b/test/Dialect/SCF/trip_count.mlir
new file mode 100644
index 0000000..54883d7
--- /dev/null
+++ b/test/Dialect/SCF/trip_count.mlir
@@ -0,0 +1,702 @@
+// RUN: mlir-opt %s -test-scf-for-utils --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @trip_count_index_zero_to_zero(
+func.func @trip_count_index_zero_to_zero(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ // CHECK: "test.trip-count" = 0
+ %r = scf.for %i = %c0 to %c0 step %c1 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_zero_to_zero_step_dyn(
+func.func @trip_count_index_zero_to_zero_step_dyn(%a : i32, %b : i32, %step : index) -> i32 {
+ %c0 = arith.constant 0 : index
+
+ // CHECK: "test.trip-count" = 0
+ %r = scf.for %i = %c0 to %c0 step %step iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_zero_to_zero(
+func.func @trip_count_i32_zero_to_zero(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+
+ // CHECK: "test.trip-count" = 0
+ %r = scf.for %i = %c0 to %c0 step %c1 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+
+// CHECK-LABEL: func.func @trip_count_i32_zero_to_zero_step_dyn(
+func.func @trip_count_i32_zero_to_zero_step_dyn(%a : i32, %b : i32, %step : i32) -> i32 {
+ %c0 = arith.constant 0 : i32
+
+ // CHECK: "test.trip-count" = 0
+ %r = scf.for %i = %c0 to %c0 step %step iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_one_to_zero(
+func.func @trip_count_index_one_to_zero(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ // Index type has a unknown bitwidth, we can't compute a loop tripcount
+ // in theory because of overflow concerns.
+ // CHECK: "test.trip-count" = 0
+ %r2 = scf.for %i = %c1 to %c0 step %c1 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_one_to_zero(
+func.func @trip_count_i32_one_to_zero(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+
+ // CHECK: "test.trip-count" = 0
+ %r2 = scf.for %i = %c1 to %c0 step %c1 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_one_to_zero_dyn_step(
+func.func @trip_count_i32_one_to_zero_dyn_step(%a : i32, %b : i32, %step : i32) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+
+ // CHECK: "test.trip-count" = 0
+ %r2 = scf.for %i = %c1 to %c0 step %step iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_negative_step(
+func.func @trip_count_index_negative_step(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c-1 = arith.constant -1 : index
+
+ // Negative step is invalid, loop won't execute.
+ // CHECK: "test.trip-count" = 0
+ %r3 = scf.for %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_negative_step(
+func.func @trip_count_i32_negative_step(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c-1 = arith.constant -1 : i32
+
+ // Negative step is invalid, loop won't execute.
+ // CHECK: "test.trip-count" = 0
+ %r3 = scf.for %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_negative_step_unsigned_loop(
+func.func @trip_count_index_negative_step_unsigned_loop(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c-1 = arith.constant -1 : index
+
+ // Negative step is invalid, loop won't execute.
+ // CHECK: "test.trip-count" = 0
+ %r3 = scf.for unsigned %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_negative_step_unsigned_loop(
+func.func @trip_count_i32_negative_step_unsigned_loop(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c-1 = arith.constant -1 : i32
+
+ // Negative step is invalid, loop won't execute.
+ // CHECK: "test.trip-count" = 0
+ %r3 = scf.for unsigned %i = %c1 to %c0 step %c-1 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_normal_loop(
+func.func @trip_count_index_normal_loop(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c10 = arith.constant 10 : index
+
+ // Index type has a unknown bitwidth, we can't compute a loop tripcount
+ // in theory because of overflow concerns.
+ // CHECK: "test.trip-count" = 5
+ %r4 = scf.for %i = %c0 to %c10 step %c2 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r4 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_normal_loop(
+func.func @trip_count_i32_normal_loop(%a : i32, %b : i32) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c2 = arith.constant 2 : i32
+ %c10 = arith.constant 10 : i32
+
+ // Normal loop
+ // CHECK: "test.trip-count" = 5
+ %r4 = scf.for %i = %c0 to %c10 step %c2 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r4 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_signed_crossing_zero(
+func.func @trip_count_index_signed_crossing_zero(%a : i32, %b : i32) -> i32 {
+ %c-1 = arith.constant -1 : index
+ %c1 = arith.constant 1 : index
+
+ // Index type has a unknown bitwidth, we can't compute a loop tripcount
+ // in theory because of overflow concerns.
+ // CHECK: "test.trip-count" = 2
+ %r5 = scf.for %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r5 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_signed_crossing_zero(
+func.func @trip_count_i32_signed_crossing_zero(%a : i32, %b : i32) -> i32 {
+ %c-1 = arith.constant -1 : i32
+ %c1 = arith.constant 1 : i32
+
+ // This loop execute with signed comparison, but not unsigned, because it is crossing 0.
+ // CHECK: "test.trip-count" = 2
+ %r5 = scf.for %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r5 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_unsigned_crossing_zero(
+func.func @trip_count_index_unsigned_crossing_zero(%a : i32, %b : i32) -> i32 {
+ %c-1 = arith.constant -1 : index
+ %c1 = arith.constant 1 : index
+
+ // Index type has a unknown bitwidth, we can't compute a loop tripcount
+ // in theory because of overflow concerns.
+ // CHECK: "test.trip-count" = 0
+ %r6 = scf.for unsigned %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r6 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_unsigned_crossing_zero(
+func.func @trip_count_i32_unsigned_crossing_zero(%a : i32, %b : i32) -> i32 {
+ %c-1 = arith.constant -1 : i32
+ %c1 = arith.constant 1 : i32
+
+ // This loop execute with signed comparison, but not unsigned, because it is crossing 0.
+ // CHECK: "test.trip-count" = 0
+ %r6 = scf.for unsigned %i = %c-1 to %c1 step %c1 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r6 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_unsigned_crossing_zero_dyn_step(
+func.func @trip_count_i32_unsigned_crossing_zero_dyn_step(%a : i32, %b : i32, %step : i32) -> i32 {
+ %c-1 = arith.constant -1 : i32
+ %c1 = arith.constant 1 : i32
+
+ // This loop execute with signed comparison, but not unsigned, because it is crossing 0.
+ // CHECK: "test.trip-count" = 0
+ %r6 = scf.for unsigned %i = %c-1 to %c1 step %step iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r6 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_negative_bounds_signed(
+func.func @trip_count_index_negative_bounds_signed(%a : i32, %b : i32) -> i32 {
+ %c-10 = arith.constant -10 : index
+ %c-1 = arith.constant -1 : index
+ %c2 = arith.constant 2 : index
+
+ // Index type has a unknown bitwidth, we can't compute a loop tripcount
+ // in theory because of overflow concerns.
+ // CHECK: "test.trip-count" = 5
+ %r7 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r7 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_negative_bounds_signed(
+func.func @trip_count_i32_negative_bounds_signed(%a : i32, %b : i32) -> i32 {
+ %c-10 = arith.constant -10 : i32
+ %c-1 = arith.constant -1 : i32
+ %c2 = arith.constant 2 : i32
+
+ // This loop execute with signed comparison, because both bounds are
+ // negative and there is no crossing of 0 here.
+ // CHECK: "test.trip-count" = 5
+ %r7 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r7 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_negative_bounds_unsigned(
+func.func @trip_count_index_negative_bounds_unsigned(%a : i32, %b : i32) -> i32 {
+ %c-10 = arith.constant -10 : index
+ %c-1 = arith.constant -1 : index
+ %c2 = arith.constant 2 : index
+
+ // Index type has a unknown bitwidth, we can't compute a loop tripcount
+ // in theory because of overflow concerns.
+ // CHECK: "test.trip-count" = 5
+ %r8 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r8 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_negative_bounds_unsigned(
+func.func @trip_count_i32_negative_bounds_unsigned(%a : i32, %b : i32) -> i32 {
+ %c-10 = arith.constant -10 : i32
+ %c-1 = arith.constant -1 : i32
+ %c2 = arith.constant 2 : i32
+
+ // CHECK: "test.trip-count" = 5
+ %r8 = scf.for %i = %c-10 to %c-1 step %c2 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r8 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_overflow_signed(
+func.func @trip_count_index_overflow_signed(%a : i32, %b : i32) -> i32 {
+ %c1 = arith.constant 1 : index
+ %c_max = arith.constant 2147483647 : index // 2^31 - 1
+ %c_min = arith.constant 2147483648 : index // -2^31
+
+ // Index type has a unknown bitwidth, we can't compute a loop tripcount
+ // in theory because of overflow concerns.
+ // CHECK: "test.trip-count" = 1
+ %r9 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r9 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_overflow_signed(
+func.func @trip_count_i32_overflow_signed(%a : i32, %b : i32) -> i32 {
+ %c1 = arith.constant 1 : i32
+ %c_max = arith.constant 2147483647 : i32 // 2^31 - 1
+ %c_min = arith.constant 2147483648 : i32 // -2^31
+
+ // This loop crosses the 2^31 threshold, which would overflow a signed 32-bit integer.
+ // CHECK: "test.trip-count" = 0
+ %r9 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r9 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_overflow_signed_dyn_step(
+func.func @trip_count_i32_overflow_signed_dyn_step(%a : i32, %b : i32, %step : i32) -> i32 {
+ %c_max = arith.constant 2147483647 : i32 // 2^31 - 1
+ %c_min = arith.constant 2147483648 : i32 // -2^31
+
+ // This loop crosses the 2^31 threshold, which would overflow a signed 32-bit integer.
+ // CHECK: "test.trip-count" = 0
+ %r9 = scf.for %i = %c_max to %c_min step %step iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r9 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_overflow_unsigned(
+func.func @trip_count_index_overflow_unsigned(%a : i32, %b : i32) -> i32 {
+ %c1 = arith.constant 1 : index
+ %c_max = arith.constant 2147483647 : index // 2^31 - 1
+ %c_min = arith.constant 2147483648 : index // -2^31
+
+ // Index type has a unknown bitwidth, we can't compute a loop tripcount
+ // in theory because of overflow concerns.
+ // CHECK: "test.trip-count" = 1
+ %r10 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r10 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_overflow_unsigned(
+func.func @trip_count_i32_overflow_unsigned(%a : i32, %b : i32) -> i32 {
+ %c1 = arith.constant 1 : i32
+ %c_max = arith.constant 2147483647 : i32 // 2^31 - 1
+ %c_min = arith.constant 2147483648 : i32 // -2^31
+
+ // The same loop with unsigned comparison executes normally
+ // CHECK: "test.trip-count" = 1
+ %r10 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i32 {
+ scf.yield %b : i32
+ }
+ return %r10 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_overflow_64bit_signed(
+func.func @trip_count_index_overflow_64bit_signed(%a : i32, %b : i32) -> i32 {
+ %c1 = arith.constant 1 : index
+ %c_max = arith.constant 9223372036854775807 : index // 2^63 - 1
+ %c_min = arith.constant -9223372036854775808 : index // -2^63
+
+ // This loop crosses the 2^63 threshold, which would overflow a signed 64-bit integer.
+ // Index type has a unknown bitwidth, we can't compute a loop tripcount.
+ // CHECK: "test.trip-count" = 0
+ %r11 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r11 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i64_overflow_64bit_signed(
+func.func @trip_count_i64_overflow_64bit_signed(%a : i32, %b : i32) -> i32 {
+ %c1 = arith.constant 1 : i64
+ %c_max = arith.constant 9223372036854775807 : i64 // 2^63 - 1
+ %c_min = arith.constant -9223372036854775808 : i64 // -2^63
+
+ // This loop crosses the 2^63 threshold, which would overflow a signed 64-bit integer.
+ // CHECK: "test.trip-count" = 0
+ %r11 = scf.for %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i64 {
+ scf.yield %b : i32
+ }
+ return %r11 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_index_overflow_64bit_unsigned(
+func.func @trip_count_index_overflow_64bit_unsigned(%a : i32, %b : i32) -> i32 {
+ %c1 = arith.constant 1 : index
+ %c_max = arith.constant 9223372036854775807 : index // 2^63 - 1
+ %c_min = arith.constant -9223372036854775808 : index // -2^63
+
+ // Index type has a unknown bitwidth, we can't compute a loop tripcount
+ // in theory because of overflow concerns.
+ // CHECK: "test.trip-count" = 1
+ %r12 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 {
+ scf.yield %b : i32
+ }
+ return %r12 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @trip_count_i32_overflow_64bit_unsigned(
+func.func @trip_count_i32_overflow_64bit_unsigned(%a : i32, %b : i32) -> i32 {
+ %c1 = arith.constant 1 : i64
+ %c_max = arith.constant 9223372036854775807 : i64 // 2^63 - 1
+ %c_min = arith.constant -9223372036854775808 : i64 // -2^63
+
+ // The same loop with unsigned comparison executes normally
+ // CHECK: "test.trip-count" = 1
+ %r12 = scf.for unsigned %i = %c_max to %c_min step %c1 iter_args(%0 = %a) -> i32 : i64 {
+ scf.yield %b : i32
+ }
+ return %r12 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_step_greater_than_iteration(
+func.func @trip_count_step_greater_than_iteration() -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c17_i32 = arith.constant 17 : i32
+ %c16_i32 = arith.constant 16 : i32
+ // CHECK: "test.trip-count" = 1
+ %1 = scf.for %arg0 = %c16_i32 to %c17_i32 step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add(
+func.func @trip_count_arith_add(%lb : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c17_i32 = arith.constant 17 : i32
+ %c16_i32 = arith.constant 16 : i32
+ // Can't compute a trip-count in the absence of overflow flag.
+ // CHECK: "test.trip-count" = "none"
+ %ub = arith.addi %lb, %c16_i32 : i32
+ %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative(
+func.func @trip_count_arith_add_negative(%lb : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c-16_i32 = arith.constant -16 : i32
+ // Can't compute a trip-count in the absence of overflow flag.
+ // CHECK: "test.trip-count" = "none"
+ %ub = arith.addi %lb, %c-16_i32 : i32
+ %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_nsw_loop_signed(
+func.func @trip_count_arith_add_nsw_loop_signed(%lb : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c16_i32 = arith.constant 16 : i32
+ %ub = arith.addi %lb, %c16_i32 overflow<nsw> : i32
+ // CHECK: "test.trip-count" = 4
+ %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nsw_loop_signed(
+func.func @trip_count_arith_add_negative_nsw_loop_signed(%lb : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c-16_i32 = arith.constant -16 : i32
+ %ub = arith.addi %lb, %c-16_i32 overflow<nsw> : i32
+ // CHECK: "test.trip-count" = 0
+ %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nsw_loop_signed_step_dyn(
+func.func @trip_count_arith_add_negative_nsw_loop_signed_step_dyn(%lb : i32, %step : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c-16_i32 = arith.constant -16 : i32
+ %ub = arith.addi %lb, %c-16_i32 overflow<nsw> : i32
+ // CHECK: "test.trip-count" = 0
+ %1 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_nsw_loop_unsigned(
+func.func @trip_count_arith_add_nsw_loop_unsigned(%lb : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c16_i32 = arith.constant 16 : i32
+ // Can't compute a trip-count when the overflow flag mismatches the loop comparison signess
+ // CHECK: "test.trip-count" = "none"
+ %ub = arith.addi %lb, %c16_i32 overflow<nsw> : i32
+ %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nsw_loop_unsigned(
+func.func @trip_count_arith_add_negative_nsw_loop_unsigned(%lb : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c-16_i32 = arith.constant -16 : i32
+ // Can't compute a trip-count when the overflow flag mismatches the loop comparison signess
+ // CHECK: "test.trip-count" = "none"
+ %ub = arith.addi %lb, %c-16_i32 overflow<nsw> : i32
+ %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_nuw_loop_signed(
+func.func @trip_count_arith_add_nuw_loop_signed(%lb : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c16_i32 = arith.constant 16 : i32
+ // Can't compute a trip-count when the overflow flag mismatches the loop comparison signess
+ // CHECK: "test.trip-count" = "none"
+ %ub = arith.addi %lb, %c16_i32 overflow<nuw> : i32
+ %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nuw_loop_signed(
+func.func @trip_count_arith_add_negative_nuw_loop_signed(%lb : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c-16_i32 = arith.constant -16 : i32
+ // Can't compute a trip-count when the overflow flag mismatches the loop comparison signess
+ // CHECK: "test.trip-count" = "none"
+ %ub = arith.addi %lb, %c-16_i32 overflow<nuw> : i32
+ %1 = scf.for %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_nuw_loop_unsigned(
+func.func @trip_count_arith_add_nuw_loop_unsigned(%lb : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c16_i32 = arith.constant 16 : i32
+ // CHECK: "test.trip-count" = 4
+ %ub = arith.addi %lb, %c16_i32 overflow<nuw> : i32
+ %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nuw_loop_unsigned(
+func.func @trip_count_arith_add_negative_nuw_loop_unsigned(%lb : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c-16_i32 = arith.constant -16 : i32
+ // CHECK: "test.trip-count" = 0
+ %ub = arith.addi %lb, %c-16_i32 overflow<nuw> : i32
+ %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_negative_nuw_loop_unsigned_step_dyn(
+func.func @trip_count_arith_add_negative_nuw_loop_unsigned_step_dyn(%lb : i32, %step : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c-16_i32 = arith.constant -16 : i32
+ // CHECK: "test.trip-count" = 0
+ %ub = arith.addi %lb, %c-16_i32 overflow<nuw> : i32
+ %1 = scf.for unsigned %arg0 = %lb to %ub step %step iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL:func.func @trip_count_arith_add_nuw_loop_unsigned_invalid(
+func.func @trip_count_arith_add_nuw_loop_unsigned_invalid(%lb : i32, %other : i32) -> i32 {
+ %c0_i32 = arith.constant 0 : i32
+ %c4_i32 = arith.constant 4 : i32
+ %c16_i32 = arith.constant 16 : i32
+ // The addition here is not adding from %lb
+ // CHECK: "test.trip-count" = "none"
+ %ub = arith.addi %other, %c16_i32 overflow<nuw> : i32
+ %1 = scf.for unsigned %arg0 = %lb to %ub step %c4_i32 iter_args(%arg1 = %c0_i32) -> (i32) : i32 {
+ scf.yield %arg0 : i32
+ }
+ return %1 : i32
+}
\ No newline at end of file
diff --git a/test/lib/Dialect/SCF/TestSCFUtils.cpp b/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 9a394d2..6199cb1 100644
--- a/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -42,6 +42,20 @@
void runOnOperation() override {
func::FuncOp func = getOperation();
+ // Annotate every loop-like operation with the static trip count.
+ func.walk([&](LoopLikeOpInterface loopOp) {
+ std::optional<APInt> tripCount = loopOp.getStaticTripCount();
+ if (tripCount.has_value())
+ loopOp->setDiscardableAttr(
+ "test.trip-count",
+ IntegerAttr::get(IntegerType::get(&getContext(),
+ tripCount.value().getBitWidth()),
+ tripCount.value().getSExtValue()));
+ else
+ loopOp->setDiscardableAttr("test.trip-count",
+ StringAttr::get(&getContext(), "none"));
+ });
+
if (testReplaceWithNewYields) {
func.walk([&](scf::ForOp forOp) {
if (forOp.getNumResults() == 0)