| //===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines the dataflow analysis class for integer range inference |
| // which is used in transformations over the `arith` dialect such as |
| // branch elimination or signed->unsigned rewriting |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" |
| #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
| #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
| #include "mlir/Analysis/DataFlowFramework.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/Dialect.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| #include "mlir/Interfaces/InferIntRangeInterface.h" |
| #include "mlir/Interfaces/LoopLikeInterface.h" |
| #include "mlir/Support/DebugStringHelper.h" |
| #include "mlir/Support/LLVM.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/DebugLog.h" |
| #include <cassert> |
| #include <optional> |
| #include <utility> |
| |
| #define DEBUG_TYPE "int-range-analysis" |
| |
| using namespace mlir; |
| using namespace mlir::dataflow; |
| |
| namespace mlir::dataflow { |
| LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { |
| auto *result = solver.lookupState<IntegerValueRangeLattice>(v); |
| if (!result || result->getValue().isUninitialized()) |
| return failure(); |
| const ConstantIntRanges &range = result->getValue().getValue(); |
| return success(range.smin().isNonNegative()); |
| } |
| |
| LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) { |
| auto nonNegativePred = [&solver](Value v) -> bool { |
| return succeeded(staticallyNonNegative(solver, v)); |
| }; |
| return success(llvm::all_of(op->getOperands(), nonNegativePred) && |
| llvm::all_of(op->getResults(), nonNegativePred)); |
| } |
| } // namespace mlir::dataflow |
| |
| void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { |
| Lattice::onUpdate(solver); |
| |
| // If the integer range can be narrowed to a constant, update the constant |
| // value of the SSA value. |
| std::optional<APInt> constant = getValue().getValue().getConstantValue(); |
| auto value = cast<Value>(anchor); |
| auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value); |
| if (!constant) |
| return solver->propagateIfChanged( |
| cv, cv->join(ConstantValue::getUnknownConstant())); |
| |
| Dialect *dialect; |
| if (auto *parent = value.getDefiningOp()) |
| dialect = parent->getDialect(); |
| else |
| dialect = value.getParentBlock()->getParentOp()->getDialect(); |
| |
| Attribute cstAttr; |
| if (isa<IntegerType, IndexType>(value.getType())) { |
| cstAttr = IntegerAttr::get(value.getType(), *constant); |
| } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) { |
| cstAttr = SplatElementsAttr::get(shapedTy, *constant); |
| } else { |
| llvm::report_fatal_error( |
| Twine("FIXME: Don't know how to create a constant for this type: ") + |
| mlir::debugString(value.getType())); |
| } |
| solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect))); |
| } |
| |
| LogicalResult IntegerRangeAnalysis::visitOperation( |
| Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands, |
| ArrayRef<IntegerValueRangeLattice *> results) { |
| auto inferrable = dyn_cast<InferIntRangeInterface>(op); |
| if (!inferrable) { |
| setAllToEntryStates(results); |
| return success(); |
| } |
| |
| LDBG() << "Inferring ranges for " |
| << OpWithFlags(op, OpPrintingFlags().skipRegions()); |
| auto argRanges = llvm::map_to_vector( |
| operands, [](const IntegerValueRangeLattice *lattice) { |
| return lattice->getValue(); |
| }); |
| |
| auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { |
| auto result = dyn_cast<OpResult>(v); |
| if (!result) |
| return; |
| assert(llvm::is_contained(op->getResults(), result)); |
| |
| LDBG() << "Inferred range " << attrs; |
| IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; |
| IntegerValueRange oldRange = lattice->getValue(); |
| |
| ChangeResult changed = lattice->join(attrs); |
| |
| // Catch loop results with loop variant bounds and conservatively make |
| // them [-inf, inf] so we don't circle around infinitely often (because |
| // the dataflow analysis in MLIR doesn't attempt to work out trip counts |
| // and often can't). |
| bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { |
| return op->hasTrait<OpTrait::IsTerminator>(); |
| }); |
| if (isYieldedResult && !oldRange.isUninitialized() && |
| !(lattice->getValue() == oldRange)) { |
| LDBG() << "Loop variant loop result detected"; |
| changed |= lattice->join(IntegerValueRange::getMaxRange(v)); |
| } |
| propagateIfChanged(lattice, changed); |
| }; |
| |
| inferrable.inferResultRangesFromOptional(argRanges, joinCallback); |
| return success(); |
| } |
| |
| void IntegerRangeAnalysis::visitNonControlFlowArguments( |
| Operation *op, const RegionSuccessor &successor, |
| ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) { |
| if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { |
| LDBG() << "Inferring ranges for " |
| << OpWithFlags(op, OpPrintingFlags().skipRegions()); |
| |
| auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { |
| return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); |
| }); |
| |
| auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { |
| auto arg = dyn_cast<BlockArgument>(v); |
| if (!arg) |
| return; |
| if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg)) |
| return; |
| |
| LDBG() << "Inferred range " << attrs; |
| IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; |
| IntegerValueRange oldRange = lattice->getValue(); |
| |
| ChangeResult changed = lattice->join(attrs); |
| |
| // Catch loop results with loop variant bounds and conservatively make |
| // them [-inf, inf] so we don't circle around infinitely often (because |
| // the dataflow analysis in MLIR doesn't attempt to work out trip counts |
| // and often can't). |
| bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { |
| return op->hasTrait<OpTrait::IsTerminator>(); |
| }); |
| if (isYieldedValue && !oldRange.isUninitialized() && |
| !(lattice->getValue() == oldRange)) { |
| LDBG() << "Loop variant loop result detected"; |
| changed |= lattice->join(IntegerValueRange::getMaxRange(v)); |
| } |
| propagateIfChanged(lattice, changed); |
| }; |
| |
| inferrable.inferResultRangesFromOptional(argRanges, joinCallback); |
| return; |
| } |
| |
| /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() |
| /// on a LoopLikeInterface return the lower/upper bound for that result if |
| /// possible. |
| auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound, |
| Type boundType, Block *block, bool getUpper) { |
| unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); |
| if (loopBound.has_value()) { |
| if (auto attr = dyn_cast<Attribute>(*loopBound)) { |
| if (auto bound = dyn_cast_or_null<IntegerAttr>(attr)) |
| return bound.getValue(); |
| } else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) { |
| const IntegerValueRangeLattice *lattice = |
| getLatticeElementFor(getProgramPointBefore(block), value); |
| if (lattice != nullptr && !lattice->getValue().isUninitialized()) |
| return getUpper ? lattice->getValue().getValue().smax() |
| : lattice->getValue().getValue().smin(); |
| } |
| } |
| // Given the results of getConstant{Lower,Upper}Bound() |
| // or getConstantStep() on a LoopLikeInterface return the lower/upper |
| // bound |
| return getUpper ? APInt::getSignedMaxValue(width) |
| : APInt::getSignedMinValue(width); |
| }; |
| |
| // Infer bounds for loop arguments that have static bounds |
| if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) { |
| std::optional<Value> iv = loop.getSingleInductionVar(); |
| if (!iv) { |
| return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( |
| op, successor, argLattices, firstIndex); |
| } |
| Block *block = iv->getParentBlock(); |
| std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound(); |
| std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound(); |
| std::optional<OpFoldResult> step = loop.getSingleStep(); |
| APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block, |
| /*getUpper=*/false); |
| APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block, |
| /*getUpper=*/true); |
| // Assume positivity for uniscoverable steps by way of getUpper = true. |
| APInt stepVal = |
| getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true); |
| |
| if (stepVal.isNegative()) { |
| std::swap(min, max); |
| } else { |
| // Correct the upper bound by subtracting 1 so that it becomes a <= |
| // bound, because loops do not generally include their upper bound. |
| max -= 1; |
| } |
| |
| // If we infer the lower bound to be larger than the upper bound, the |
| // resulting range is meaningless and should not be used in further |
| // inferences. |
| if (max.sge(min)) { |
| IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); |
| auto ivRange = ConstantIntRanges::fromSigned(min, max); |
| propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); |
| } |
| return; |
| } |
| |
| return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( |
| op, successor, argLattices, firstIndex); |
| } |