| //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include <utility> |
| |
| #include "llvm/ADT/TypeSwitch.h" |
| |
| #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
| #include "mlir/Analysis/DataFlow/Utils.h" |
| #include "mlir/Analysis/DataFlowFramework.h" |
| #include "mlir/Dialect/Arith/Transforms/Passes.h" |
| |
| #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
| #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| #include "mlir/IR/IRMapping.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Interfaces/LoopLikeInterface.h" |
| #include "mlir/Interfaces/SideEffectInterfaces.h" |
| #include "mlir/Transforms/FoldUtils.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir::arith { |
| #define GEN_PASS_DEF_ARITHINTRANGEOPTS |
| #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
| |
| #define GEN_PASS_DEF_ARITHINTRANGENARROWING |
| #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
| } // namespace mlir::arith |
| |
| using namespace mlir; |
| using namespace mlir::arith; |
| using namespace mlir::dataflow; |
| |
| static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver, |
| Value value) { |
| auto *maybeInferredRange = |
| solver.lookupState<IntegerValueRangeLattice>(value); |
| if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) |
| return std::nullopt; |
| const ConstantIntRanges &inferredRange = |
| maybeInferredRange->getValue().getValue(); |
| return inferredRange.getConstantValue(); |
| } |
| |
| static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, |
| Value newVal) { |
| auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal); |
| if (!oldState) |
| return; |
| (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join( |
| *oldState); |
| } |
| |
| namespace mlir::dataflow { |
| /// Patterned after SCCP |
| LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, |
| RewriterBase &rewriter, Value value) { |
| if (value.use_empty()) |
| return failure(); |
| std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value); |
| if (!maybeConstValue.has_value()) |
| return failure(); |
| |
| Type type = value.getType(); |
| // If the type or element type is non-integral, the attribute constructor |
| // will crash, so eagerly check for an integer type to avoid this. |
| if (!getElementTypeOrSelf(type).isIntOrIndex()) |
| return failure(); |
| Location loc = value.getLoc(); |
| Operation *maybeDefiningOp = value.getDefiningOp(); |
| Dialect *valueDialect = |
| maybeDefiningOp ? maybeDefiningOp->getDialect() |
| : value.getParentRegion()->getParentOp()->getDialect(); |
| |
| Attribute constAttr; |
| if (auto shaped = dyn_cast<ShapedType>(type)) { |
| constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue); |
| } else { |
| constAttr = rewriter.getIntegerAttr(type, *maybeConstValue); |
| } |
| Operation *constOp = |
| valueDialect->materializeConstant(rewriter, constAttr, type, loc); |
| // Fall back to arith.constant if the dialect materializer doesn't know what |
| // to do with an integer constant. |
| if (!constOp) |
| constOp = rewriter.getContext() |
| ->getLoadedDialect<ArithDialect>() |
| ->materializeConstant(rewriter, constAttr, type, loc); |
| if (!constOp) |
| return failure(); |
| |
| OpResult res = constOp->getResult(0); |
| if (solver.lookupState<dataflow::IntegerValueRangeLattice>(res)) |
| solver.eraseState(res); |
| copyIntegerRange(solver, value, res); |
| rewriter.replaceAllUsesWith(value, res); |
| return success(); |
| } |
| } // namespace mlir::dataflow |
| |
| namespace { |
| class DataFlowListener : public RewriterBase::Listener { |
| public: |
| DataFlowListener(DataFlowSolver &s) : s(s) {} |
| |
| protected: |
| void notifyOperationErased(Operation *op) override { |
| s.eraseState(s.getProgramPointAfter(op)); |
| for (Value res : op->getResults()) |
| s.eraseState(res); |
| } |
| |
| DataFlowSolver &s; |
| }; |
| |
| /// Rewrite any results of `op` that were inferred to be constant integers to |
| /// and replace their uses with that constant. Return success() if all results |
| /// where thus replaced and the operation is erased. Also replace any block |
| /// arguments with their constant values. |
| struct MaterializeKnownConstantValues : public RewritePattern { |
| MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s) |
| : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(), |
| /*benefit=*/1, context), |
| solver(s) {} |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| if (matchPattern(op, m_Constant())) |
| return failure(); |
| |
| // We need to check isIntOrIndex() here as well to avoid infinite loops in |
| // the greedy pattern rewriter. If we only check it in |
| // maybeReplaceWithConstant, this lambda might still return true for |
| // non-integral types, causing the pattern to match and claim success |
| // without making any changes, leading to non-convergence. |
| auto needsReplacing = [&](Value v) { |
| return getElementTypeOrSelf(v.getType()).isIntOrIndex() && |
| getMaybeConstantValue(solver, v).has_value() && !v.use_empty(); |
| }; |
| bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing); |
| if (op->getNumRegions() == 0) |
| if (!hasConstantResults) |
| return failure(); |
| bool hasConstantRegionArgs = false; |
| for (Region ®ion : op->getRegions()) { |
| for (Block &block : region.getBlocks()) { |
| hasConstantRegionArgs |= |
| llvm::any_of(block.getArguments(), needsReplacing); |
| } |
| } |
| if (!hasConstantResults && !hasConstantRegionArgs) |
| return failure(); |
| |
| bool replacedAll = (op->getNumResults() != 0); |
| for (Value v : op->getResults()) |
| replacedAll &= |
| (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) || |
| v.use_empty()); |
| if (replacedAll && isOpTriviallyDead(op)) { |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| |
| PatternRewriter::InsertionGuard guard(rewriter); |
| for (Region ®ion : op->getRegions()) { |
| for (Block &block : region.getBlocks()) { |
| rewriter.setInsertionPointToStart(&block); |
| for (BlockArgument &arg : block.getArguments()) { |
| (void)maybeReplaceWithConstant(solver, rewriter, arg); |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| private: |
| DataFlowSolver &solver; |
| }; |
| |
| template <typename RemOp> |
| struct DeleteTrivialRem : public OpRewritePattern<RemOp> { |
| DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s) |
| : OpRewritePattern<RemOp>(context), solver(s) {} |
| |
| LogicalResult matchAndRewrite(RemOp op, |
| PatternRewriter &rewriter) const override { |
| Value lhs = op.getOperand(0); |
| Value rhs = op.getOperand(1); |
| auto maybeModulus = getConstantIntValue(rhs); |
| if (!maybeModulus.has_value()) |
| return failure(); |
| int64_t modulus = *maybeModulus; |
| if (modulus <= 0) |
| return failure(); |
| auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs); |
| if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized()) |
| return failure(); |
| const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue(); |
| const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin(); |
| const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax(); |
| // The minima and maxima here are given as closed ranges, we must be |
| // strictly less than the modulus. |
| if (min.isNegative() || min.uge(modulus)) |
| return failure(); |
| if (max.isNegative() || max.uge(modulus)) |
| return failure(); |
| if (!min.ule(max)) |
| return failure(); |
| |
| // With all those conditions out of the way, we know thas this invocation of |
| // a remainder is a noop because the input is strictly within the range |
| // [0, modulus), so get rid of it. |
| rewriter.replaceOp(op, ValueRange{lhs}); |
| return success(); |
| } |
| |
| private: |
| DataFlowSolver &solver; |
| }; |
| |
| /// Gather ranges for all the values in `values`. Appends to the existing |
| /// vector. |
| static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values, |
| SmallVectorImpl<ConstantIntRanges> &ranges) { |
| for (Value val : values) { |
| auto *maybeInferredRange = |
| solver.lookupState<IntegerValueRangeLattice>(val); |
| if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) |
| return failure(); |
| |
| const ConstantIntRanges &inferredRange = |
| maybeInferredRange->getValue().getValue(); |
| ranges.push_back(inferredRange); |
| } |
| return success(); |
| } |
| |
| /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped, |
| /// return shaped type as well. |
| static Type getTargetType(Type srcType, unsigned targetBitwidth) { |
| auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth); |
| if (auto shaped = dyn_cast<ShapedType>(srcType)) |
| return shaped.clone(dstType); |
| |
| assert(srcType.isIntOrIndex() && "Invalid src type"); |
| return dstType; |
| } |
| |
/// Describes which kind of cast, if any, may be used to narrow a value (or an
/// operation) without changing its meaning.
enum class CastKind : uint8_t {
  None,     ///< No narrowing is possible.
  Signed,   ///< Only the signed interpretation survives narrowing.
  Unsigned, ///< Only the unsigned interpretation survives narrowing.
  Both      ///< Signed and unsigned interpretations both survive narrowing.
};
| |
| /// If the values within `range` can be represented using only `width` bits, |
| /// return the kind of truncation needed to preserve that property. |
| /// |
| /// This check relies on the fact that the signed and unsigned ranges are both |
| /// always correct, but that one might be an approximation of the other, |
| /// so we want to use the correct truncation operation. |
| static CastKind checkTruncatability(const ConstantIntRanges &range, |
| unsigned targetWidth) { |
| unsigned srcWidth = range.smin().getBitWidth(); |
| if (srcWidth <= targetWidth) |
| return CastKind::None; |
| unsigned removedWidth = srcWidth - targetWidth; |
| // The sign bits need to extend into the sign bit of the target width. For |
| // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign |
| // bits. |
| bool canTruncateSigned = |
| range.smin().getNumSignBits() >= (removedWidth + 1) && |
| range.smax().getNumSignBits() >= (removedWidth + 1); |
| bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth && |
| range.umax().countLeadingZeros() >= removedWidth; |
| if (canTruncateSigned && canTruncateUnsigned) |
| return CastKind::Both; |
| if (canTruncateSigned) |
| return CastKind::Signed; |
| if (canTruncateUnsigned) |
| return CastKind::Unsigned; |
| return CastKind::None; |
| } |
| |
| static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) { |
| if (lhs == CastKind::None || rhs == CastKind::None) |
| return CastKind::None; |
| if (lhs == CastKind::Both) |
| return rhs; |
| if (rhs == CastKind::Both) |
| return lhs; |
| if (lhs == rhs) |
| return lhs; |
| return CastKind::None; |
| } |
| |
| static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType, |
| CastKind castKind) { |
| Type srcType = src.getType(); |
| assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) && |
| "Mixing vector and non-vector types"); |
| assert(castKind != CastKind::None && "Can't cast when casting isn't allowed"); |
| Type srcElemType = getElementTypeOrSelf(srcType); |
| Type dstElemType = getElementTypeOrSelf(dstType); |
| assert(srcElemType.isIntOrIndex() && "Invalid src type"); |
| assert(dstElemType.isIntOrIndex() && "Invalid dst type"); |
| if (srcType == dstType) |
| return src; |
| |
| if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) { |
| if (castKind == CastKind::Signed) |
| return arith::IndexCastOp::create(builder, loc, dstType, src); |
| return arith::IndexCastUIOp::create(builder, loc, dstType, src); |
| } |
| |
| auto srcInt = cast<IntegerType>(srcElemType); |
| auto dstInt = cast<IntegerType>(dstElemType); |
| if (dstInt.getWidth() < srcInt.getWidth()) |
| return arith::TruncIOp::create(builder, loc, dstType, src); |
| |
| if (castKind == CastKind::Signed) |
| return arith::ExtSIOp::create(builder, loc, dstType, src); |
| return arith::ExtUIOp::create(builder, loc, dstType, src); |
| } |
| |
| struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> { |
| NarrowElementwise(MLIRContext *context, DataFlowSolver &s, |
| ArrayRef<unsigned> target) |
| : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {} |
| |
| using OpTraitRewritePattern::OpTraitRewritePattern; |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| if (op->getNumResults() == 0) |
| return rewriter.notifyMatchFailure(op, "can't narrow resultless op"); |
| |
| // Inline size chosen empirically based on compilation profiling. |
| // Profiled: 2.6M calls, avg=1.7+-1.3. N=4 covers >95% of cases inline. |
| SmallVector<ConstantIntRanges, 4> ranges; |
| if (failed(collectRanges(solver, op->getOperands(), ranges))) |
| return rewriter.notifyMatchFailure(op, "input without specified range"); |
| if (failed(collectRanges(solver, op->getResults(), ranges))) |
| return rewriter.notifyMatchFailure(op, "output without specified range"); |
| |
| Type srcType = op->getResult(0).getType(); |
| if (!llvm::all_equal(op->getResultTypes())) |
| return rewriter.notifyMatchFailure(op, "mismatched result types"); |
| if (op->getNumOperands() == 0 || |
| !llvm::all_of(op->getOperandTypes(), |
| [=](Type t) { return t == srcType; })) |
| return rewriter.notifyMatchFailure( |
| op, "no operands or operand types don't match result type"); |
| |
| for (unsigned targetBitwidth : targetBitwidths) { |
| CastKind castKind = CastKind::Both; |
| for (const ConstantIntRanges &range : ranges) { |
| castKind = mergeCastKinds(castKind, |
| checkTruncatability(range, targetBitwidth)); |
| if (castKind == CastKind::None) |
| break; |
| } |
| // For operations that explicitly treat the values as signed, we should |
| // only do signed casts, if those are deemed possible as such based on the |
| // value range. |
| auto castKindForOp = |
| llvm::TypeSwitch<Operation *, CastKind>(op) |
| .Case<arith::DivSIOp, arith::CeilDivSIOp, arith::FloorDivSIOp, |
| arith::RemSIOp, arith::MaxSIOp, arith::MinSIOp, |
| arith::ShRSIOp>([](auto) { return CastKind::Signed; }) |
| .Default(CastKind::Both); |
| castKind = mergeCastKinds(castKind, castKindForOp); |
| if (castKind == CastKind::None) |
| continue; |
| Type targetType = getTargetType(srcType, targetBitwidth); |
| if (targetType == srcType) |
| continue; |
| |
| Location loc = op->getLoc(); |
| IRMapping mapping; |
| for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) { |
| CastKind argCastKind = castKind; |
| // When dealing with `index` values, preserve non-negativity in the |
| // index_casts since we can't recover this in unsigned when equivalent. |
| if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative()) |
| argCastKind = CastKind::Both; |
| Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind); |
| mapping.map(arg, newArg); |
| } |
| |
| Operation *newOp = rewriter.clone(*op, mapping); |
| rewriter.modifyOpInPlace(newOp, [&]() { |
| for (OpResult res : newOp->getResults()) { |
| res.setType(targetType); |
| } |
| }); |
| SmallVector<Value> newResults; |
| for (auto [newRes, oldRes] : |
| llvm::zip_equal(newOp->getResults(), op->getResults())) { |
| Value castBack = doCast(rewriter, loc, newRes, srcType, castKind); |
| copyIntegerRange(solver, oldRes, castBack); |
| newResults.push_back(castBack); |
| } |
| |
| rewriter.replaceOp(op, newResults); |
| return success(); |
| } |
| return failure(); |
| } |
| |
| private: |
| DataFlowSolver &solver; |
| SmallVector<unsigned, 4> targetBitwidths; |
| }; |
| |
| struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> { |
| NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target) |
| : OpRewritePattern(context), solver(s), targetBitwidths(target) {} |
| |
| LogicalResult matchAndRewrite(arith::CmpIOp op, |
| PatternRewriter &rewriter) const override { |
| Value lhs = op.getLhs(); |
| Value rhs = op.getRhs(); |
| |
| SmallVector<ConstantIntRanges> ranges; |
| if (failed(collectRanges(solver, op.getOperands(), ranges))) |
| return failure(); |
| const ConstantIntRanges &lhsRange = ranges[0]; |
| const ConstantIntRanges &rhsRange = ranges[1]; |
| |
| auto isSignedCmpPredicate = [](arith::CmpIPredicate pred) -> bool { |
| return pred == arith::CmpIPredicate::sge || |
| pred == arith::CmpIPredicate::sgt || |
| pred == arith::CmpIPredicate::sle || |
| pred == arith::CmpIPredicate::slt; |
| }; |
| // If we're to narrow the input values via a cast, we should preserve the |
| // sign. |
| CastKind predicateBasedCastRestriction = |
| isSignedCmpPredicate(op.getPredicate()) ? CastKind::Signed |
| : CastKind::Both; |
| |
| Type srcType = lhs.getType(); |
| for (unsigned targetBitwidth : targetBitwidths) { |
| CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth); |
| CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth); |
| CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind); |
| castKind = mergeCastKinds(castKind, predicateBasedCastRestriction); |
| // Note: this includes target width > src width, as well as the unsigned |
| // truncatability & signed predicate scenario. |
| if (castKind == CastKind::None) |
| continue; |
| |
| Type targetType = getTargetType(srcType, targetBitwidth); |
| if (targetType == srcType) |
| continue; |
| |
| Location loc = op->getLoc(); |
| IRMapping mapping; |
| Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind); |
| Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind); |
| mapping.map(lhs, lhsCast); |
| mapping.map(rhs, rhsCast); |
| |
| Operation *newOp = rewriter.clone(*op, mapping); |
| copyIntegerRange(solver, op.getResult(), newOp->getResult(0)); |
| rewriter.replaceOp(op, newOp->getResults()); |
| return success(); |
| } |
| return failure(); |
| } |
| |
| private: |
| DataFlowSolver &solver; |
| SmallVector<unsigned, 4> targetBitwidths; |
| }; |
| |
| /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg |
| /// This pattern assumes all passed `targetBitwidths` are not wider than index |
| /// type. |
| template <typename CastOp> |
| struct FoldIndexCastChain final : OpRewritePattern<CastOp> { |
| FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target) |
| : OpRewritePattern<CastOp>(context), targetBitwidths(target) {} |
| |
| LogicalResult matchAndRewrite(CastOp op, |
| PatternRewriter &rewriter) const override { |
| auto srcOp = op.getIn().template getDefiningOp<CastOp>(); |
| if (!srcOp) |
| return rewriter.notifyMatchFailure(op, "doesn't come from an index cast"); |
| |
| Value src = srcOp.getIn(); |
| if (src.getType() != op.getType()) |
| return rewriter.notifyMatchFailure(op, "outer types don't match"); |
| |
| if (!srcOp.getType().isIndex()) |
| return rewriter.notifyMatchFailure(op, "intermediate type isn't index"); |
| |
| auto intType = dyn_cast<IntegerType>(op.getType()); |
| if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth())) |
| return failure(); |
| |
| rewriter.replaceOp(op, src); |
| return success(); |
| } |
| |
| private: |
| SmallVector<unsigned, 4> targetBitwidths; |
| }; |
| |
| struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> { |
| NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s, |
| ArrayRef<unsigned> target) |
| : OpInterfaceRewritePattern<LoopLikeOpInterface>(context), solver(s), |
| targetBitwidths(target), |
| boundsNarrowingFailedAttr( |
| StringAttr::get(context, "arith.bounds_narrowing_failed")) {} |
| |
| LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike, |
| PatternRewriter &rewriter) const override { |
| // Skip ops where bounds narrowing previously failed. |
| if (loopLike->hasAttr(boundsNarrowingFailedAttr)) |
| return rewriter.notifyMatchFailure(loopLike, |
| "bounds narrowing previously failed"); |
| |
| std::optional<SmallVector<Value>> inductionVars = |
| loopLike.getLoopInductionVars(); |
| if (!inductionVars.has_value() || inductionVars->empty()) |
| return rewriter.notifyMatchFailure(loopLike, "no induction variables"); |
| |
| std::optional<SmallVector<OpFoldResult>> lowerBounds = |
| loopLike.getLoopLowerBounds(); |
| std::optional<SmallVector<OpFoldResult>> upperBounds = |
| loopLike.getLoopUpperBounds(); |
| std::optional<SmallVector<OpFoldResult>> steps = loopLike.getLoopSteps(); |
| |
| if (!lowerBounds.has_value() || !upperBounds.has_value() || |
| !steps.has_value()) |
| return rewriter.notifyMatchFailure(loopLike, "no loop bounds or steps"); |
| |
| if (lowerBounds->size() != inductionVars->size() || |
| upperBounds->size() != inductionVars->size() || |
| steps->size() != inductionVars->size()) |
| return rewriter.notifyMatchFailure(loopLike, |
| "mismatched bounds/steps count"); |
| |
| Location loc = loopLike->getLoc(); |
| SmallVector<OpFoldResult> newLowerBounds(*lowerBounds); |
| SmallVector<OpFoldResult> newUpperBounds(*upperBounds); |
| SmallVector<OpFoldResult> newSteps(*steps); |
| SmallVector<std::tuple<size_t, Type, CastKind>> narrowings; |
| |
| // Check each (indVar, lb, ub, step) tuple. |
| for (auto [idx, indVar, lbOFR, ubOFR, stepOFR] : |
| llvm::enumerate(*inductionVars, *lowerBounds, *upperBounds, *steps)) { |
| |
| // Only process value operands, skip attributes. |
| auto maybeLb = dyn_cast<Value>(lbOFR); |
| auto maybeUb = dyn_cast<Value>(ubOFR); |
| auto maybeStep = dyn_cast<Value>(stepOFR); |
| |
| if (!maybeLb || !maybeUb || !maybeStep) |
| continue; |
| |
| // Collect ranges for (lb, ub, step, indVar). |
| SmallVector<ConstantIntRanges> ranges; |
| if (failed(collectRanges( |
| solver, ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges))) |
| continue; |
| |
| const ConstantIntRanges &stepRange = ranges[2]; |
| const ConstantIntRanges &indVarRange = ranges[3]; |
| |
| Type srcType = maybeLb.getType(); |
| |
| // Try each target bitwidth. |
| for (unsigned targetBitwidth : targetBitwidths) { |
| Type targetType = getTargetType(srcType, targetBitwidth); |
| if (targetType == srcType) |
| continue; |
| |
| // Check if the target type is valid for this loop's induction |
| // variables. |
| if (!loopLike.isValidInductionVarType(targetType)) |
| continue; |
| |
| // Check if all values in this tuple can be truncated. |
| CastKind castKind = CastKind::Both; |
| for (const ConstantIntRanges &range : ranges) { |
| castKind = mergeCastKinds(castKind, |
| checkTruncatability(range, targetBitwidth)); |
| if (castKind == CastKind::None) |
| break; |
| } |
| |
| if (castKind == CastKind::None) |
| continue; |
| |
| // Check if indVar + step fits in the narrowed type. |
| // This is critical for loop correctness: the loop computes |
| // iv_next = iv_current + step in the narrowed type, then compares |
| // iv_next < ub. If iv_current + step overflows, the comparison may |
| // produce incorrect results and break loop termination. |
| // Both signed and unsigned interpretations must fit because loop |
| // semantics are unknown (integer types are signless). |
| ConstantIntRanges indVarPlusStepRange( |
| indVarRange.smin().sadd_sat(stepRange.smin()), |
| indVarRange.smax().sadd_sat(stepRange.smax()), |
| indVarRange.umin().uadd_sat(stepRange.umin()), |
| indVarRange.umax().uadd_sat(stepRange.umax())); |
| |
| if (checkTruncatability(indVarPlusStepRange, targetBitwidth) != |
| CastKind::Both) |
| continue; |
| |
| // Narrow the bounds and step values. |
| Value newLb = doCast(rewriter, loc, maybeLb, targetType, castKind); |
| Value newUb = doCast(rewriter, loc, maybeUb, targetType, castKind); |
| Value newStep = doCast(rewriter, loc, maybeStep, targetType, castKind); |
| |
| newLowerBounds[idx] = newLb; |
| newUpperBounds[idx] = newUb; |
| newSteps[idx] = newStep; |
| narrowings.push_back({idx, targetType, castKind}); |
| break; |
| } |
| } |
| |
| if (narrowings.empty()) |
| return rewriter.notifyMatchFailure(loopLike, "no narrowings found"); |
| |
| // Save original types before modifying. |
| SmallVector<Type> origTypes; |
| for (auto [idx, targetType, castKind] : narrowings) { |
| Value indVar = (*inductionVars)[idx]; |
| origTypes.push_back(indVar.getType()); |
| } |
| |
| // Attempt to update bounds and induction variable types. |
| // If this fails, mark the op so we don't try again. |
| bool updateFailed = false; |
| rewriter.modifyOpInPlace(loopLike, [&]() { |
| // Update the loop bounds and steps. |
| if (failed(loopLike.setLoopLowerBounds(newLowerBounds)) || |
| failed(loopLike.setLoopUpperBounds(newUpperBounds)) || |
| failed(loopLike.setLoopSteps(newSteps))) { |
| // Mark op to prevent future attempts. IR was modified (attribute |
| // added), so we must return success() from the pattern. |
| loopLike->setAttr(boundsNarrowingFailedAttr, rewriter.getUnitAttr()); |
| updateFailed = true; |
| return; |
| } |
| |
| // Update induction variable types. |
| for (auto [idx, targetType, castKind] : narrowings) { |
| Value indVar = (*inductionVars)[idx]; |
| auto blockArg = cast<BlockArgument>(indVar); |
| |
| // Change the block argument type. |
| blockArg.setType(targetType); |
| } |
| }); |
| |
| if (updateFailed) |
| return success(); |
| |
| // Insert casts back to original type for uses. |
| for (auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) { |
| auto [idx, targetType, castKind] = narrowingInfo; |
| Value indVar = (*inductionVars)[idx]; |
| auto blockArg = cast<BlockArgument>(indVar); |
| Type origType = origTypes[narrowingIdx]; |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPointToStart(blockArg.getOwner()); |
| Value casted = doCast(rewriter, loc, blockArg, origType, castKind); |
| copyIntegerRange(solver, blockArg, casted); |
| |
| // Replace all uses of the narrowed indVar with the casted value. |
| rewriter.replaceAllUsesExcept(blockArg, casted, casted.getDefiningOp()); |
| } |
| |
| return success(); |
| } |
| |
| private: |
| DataFlowSolver &solver; |
| SmallVector<unsigned, 4> targetBitwidths; |
| StringAttr boundsNarrowingFailedAttr; |
| }; |
| |
| struct IntRangeOptimizationsPass final |
| : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> { |
| |
| void runOnOperation() override { |
| Operation *op = getOperation(); |
| MLIRContext *ctx = op->getContext(); |
| DataFlowSolver solver; |
| loadBaselineAnalyses(solver); |
| solver.load<IntegerRangeAnalysis>(); |
| if (failed(solver.initializeAndRun(op))) |
| return signalPassFailure(); |
| |
| DataFlowListener listener(solver); |
| |
| RewritePatternSet patterns(ctx); |
| populateIntRangeOptimizationsPatterns(patterns, solver); |
| |
| // Disable folding and region simplification to avoid breaking the solver |
| // state. Both can remove block arguments (folding via control-flow |
| // simplification, region simplification via dead-arg elimination), which |
| // frees their underlying storage. A subsequent allocation may reuse the |
| // same address for a different block argument, causing stale solver state |
| // to be associated with the new argument and producing incorrect constants. |
| if (failed( |
| applyPatternsGreedily(op, std::move(patterns), |
| GreedyRewriteConfig() |
| .enableFolding(false) |
| .setRegionSimplificationLevel( |
| GreedySimplifyRegionLevel::Disabled) |
| .setListener(&listener)))) |
| signalPassFailure(); |
| } |
| }; |
| |
| struct IntRangeNarrowingPass final |
| : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> { |
| using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase; |
| |
| void runOnOperation() override { |
| Operation *op = getOperation(); |
| MLIRContext *ctx = op->getContext(); |
| DataFlowSolver solver; |
| loadBaselineAnalyses(solver); |
| solver.load<IntegerRangeAnalysis>(); |
| if (failed(solver.initializeAndRun(op))) |
| return signalPassFailure(); |
| |
| DataFlowListener listener(solver); |
| |
| RewritePatternSet patterns(ctx); |
| populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported); |
| populateControlFlowValuesNarrowingPatterns(patterns, solver, |
| bitwidthsSupported); |
| |
| // We specifically need bottom-up traversal as cmpi pattern needs range |
| // data, attached to its original argument values. |
| if (failed(applyPatternsGreedily( |
| op, std::move(patterns), |
| GreedyRewriteConfig().setUseTopDownTraversal(false).setListener( |
| &listener)))) |
| signalPassFailure(); |
| } |
| }; |
| } // namespace |
| |
| void mlir::arith::populateIntRangeOptimizationsPatterns( |
| RewritePatternSet &patterns, DataFlowSolver &solver) { |
| patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>, |
| DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver); |
| } |
| |
| void mlir::arith::populateIntRangeNarrowingPatterns( |
| RewritePatternSet &patterns, DataFlowSolver &solver, |
| ArrayRef<unsigned> bitwidthsSupported) { |
| patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver, |
| bitwidthsSupported); |
| patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>, |
| FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(), |
| bitwidthsSupported); |
| } |
| |
| void mlir::arith::populateControlFlowValuesNarrowingPatterns( |
| RewritePatternSet &patterns, DataFlowSolver &solver, |
| ArrayRef<unsigned> bitwidthsSupported) { |
| patterns.add<NarrowLoopBounds>(patterns.getContext(), solver, |
| bitwidthsSupported); |
| } |
| |
| std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() { |
| return std::make_unique<IntRangeOptimizationsPass>(); |
| } |