| //===- InferIntRangeCommon.cpp - Inference for common ops ------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file contains implementations of range inference for operations that are |
| // common to both the `arith` and `index` dialects to facilitate reuse. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" |
| |
| #include "mlir/Interfaces/InferIntRangeInterface.h" |
| #include "mlir/Interfaces/ShapedOpInterfaces.h" |
| |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/STLExtras.h" |
| |
| #include "llvm/Support/Debug.h" |
| |
| #include <iterator> |
| #include <optional> |
| |
| using namespace mlir; |
| |
| #define DEBUG_TYPE "int-range-analysis" |
| |
| //===----------------------------------------------------------------------===// |
| // General utilities |
| //===----------------------------------------------------------------------===// |
| |
| /// Function that evaluates the result of doing something on arithmetic |
| /// constants and returns std::nullopt on overflow. |
| using ConstArithFn = |
| function_ref<std::optional<APInt>(const APInt &, const APInt &)>; |
| using ConstArithStdFn = |
| std::function<std::optional<APInt>(const APInt &, const APInt &)>; |
| |
| /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, |
| /// If either computation overflows, make the result unbounded. |
| static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, |
| const APInt &minRight, |
| const APInt &maxLeft, |
| const APInt &maxRight, bool isSigned) { |
| std::optional<APInt> maybeMin = op(minLeft, minRight); |
| std::optional<APInt> maybeMax = op(maxLeft, maxRight); |
| if (maybeMin && maybeMax) |
| return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); |
| return ConstantIntRanges::maxRange(minLeft.getBitWidth()); |
| } |
| |
| /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, |
| /// ignoring unbounded values. Returns the maximal range if `op` overflows. |
| static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs, |
| ArrayRef<APInt> rhs, bool isSigned) { |
| unsigned width = lhs[0].getBitWidth(); |
| APInt min = |
| isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); |
| APInt max = |
| isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); |
| for (const APInt &left : lhs) { |
| for (const APInt &right : rhs) { |
| std::optional<APInt> maybeThisResult = op(left, right); |
| if (!maybeThisResult) |
| return ConstantIntRanges::maxRange(width); |
| APInt result = std::move(*maybeThisResult); |
| min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; |
| max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; |
| } |
| } |
| return ConstantIntRanges::range(min, max, isSigned); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Ext, trunc, index op handling |
| //===----------------------------------------------------------------------===// |
| |
| ConstantIntRanges |
| mlir::intrange::inferIndexOp(const InferRangeFn &inferFn, |
| ArrayRef<ConstantIntRanges> argRanges, |
| intrange::CmpMode mode) { |
| ConstantIntRanges sixtyFour = inferFn(argRanges); |
| SmallVector<ConstantIntRanges, 2> truncated; |
| llvm::transform(argRanges, std::back_inserter(truncated), |
| [](const ConstantIntRanges &range) { |
| return truncRange(range, /*destWidth=*/indexMinWidth); |
| }); |
| ConstantIntRanges thirtyTwo = inferFn(truncated); |
| ConstantIntRanges thirtyTwoAsSixtyFour = |
| extRange(thirtyTwo, /*destWidth=*/indexMaxWidth); |
| ConstantIntRanges sixtyFourAsThirtyTwo = |
| truncRange(sixtyFour, /*destWidth=*/indexMinWidth); |
| |
| LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour |
| << " 32-bit = " << thirtyTwo << "\n"); |
| bool truncEqual = false; |
| switch (mode) { |
| case intrange::CmpMode::Both: |
| truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo); |
| break; |
| case intrange::CmpMode::Signed: |
| truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() && |
| thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax()); |
| break; |
| case intrange::CmpMode::Unsigned: |
| truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() && |
| thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax()); |
| break; |
| } |
| if (truncEqual) |
| // Returing the 64-bit result preserves more information. |
| return sixtyFour; |
| ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour); |
| return merged; |
| } |
| |
| ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, |
| unsigned int destWidth) { |
| APInt umin = range.umin().zext(destWidth); |
| APInt umax = range.umax().zext(destWidth); |
| APInt smin = range.smin().sext(destWidth); |
| APInt smax = range.smax().sext(destWidth); |
| return {umin, umax, smin, smax}; |
| } |
| |
| ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, |
| unsigned destWidth) { |
| APInt umin = range.umin().zext(destWidth); |
| APInt umax = range.umax().zext(destWidth); |
| return ConstantIntRanges::fromUnsigned(umin, umax); |
| } |
| |
| ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, |
| unsigned destWidth) { |
| APInt smin = range.smin().sext(destWidth); |
| APInt smax = range.smax().sext(destWidth); |
| return ConstantIntRanges::fromSigned(smin, smax); |
| } |
| |
| ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, |
| unsigned int destWidth) { |
| // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], |
| // the range of the resulting value is not contiguous ind includes 0. |
| // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], |
| // but you can't truncate [255, 257] similarly. |
| bool hasUnsignedRollover = |
| range.umin().lshr(destWidth) != range.umax().lshr(destWidth); |
| APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) |
| : range.umin().trunc(destWidth); |
| APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) |
| : range.umax().trunc(destWidth); |
| |
| // Signed post-truncation rollover will not occur when either: |
| // - The high parts of the min and max, plus the sign bit, are the same |
| // - The high halves + sign bit of the min and max are either all 1s or all 0s |
| // and you won't create a [positive, negative] range by truncating. |
| // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 |
| // but not [255, 257]_i16 to a range of i8s. You can also truncate |
| // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. |
| // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) |
| // will truncate to 0x7e, which is greater than 0 |
| APInt sminHighPart = range.smin().ashr(destWidth - 1); |
| APInt smaxHighPart = range.smax().ashr(destWidth - 1); |
| bool hasSignedOverflow = |
| (sminHighPart != smaxHighPart) && |
| !(sminHighPart.isAllOnes() && |
| (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && |
| !(sminHighPart.isZero() && smaxHighPart.isZero()); |
| APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) |
| : range.smin().trunc(destWidth); |
| APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) |
| : range.smax().trunc(destWidth); |
| return {umin, umax, smin, smax}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Addition |
| //===----------------------------------------------------------------------===// |
| |
| ConstantIntRanges |
| mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges, |
| OverflowFlags ovfFlags) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| ConstArithStdFn uadd = [=](const APInt &a, |
| const APInt &b) -> std::optional<APInt> { |
| bool overflowed = false; |
| APInt result = any(ovfFlags & OverflowFlags::Nuw) |
| ? a.uadd_sat(b) |
| : a.uadd_ov(b, overflowed); |
| return overflowed ? std::optional<APInt>() : result; |
| }; |
| ConstArithStdFn sadd = [=](const APInt &a, |
| const APInt &b) -> std::optional<APInt> { |
| bool overflowed = false; |
| APInt result = any(ovfFlags & OverflowFlags::Nsw) |
| ? a.sadd_sat(b) |
| : a.sadd_ov(b, overflowed); |
| return overflowed ? std::optional<APInt>() : result; |
| }; |
| |
| ConstantIntRanges urange = computeBoundsBy( |
| uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); |
| ConstantIntRanges srange = computeBoundsBy( |
| sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); |
| return urange.intersection(srange); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Subtraction |
| //===----------------------------------------------------------------------===// |
| |
| ConstantIntRanges |
| mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges, |
| OverflowFlags ovfFlags) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| ConstArithStdFn usub = [=](const APInt &a, |
| const APInt &b) -> std::optional<APInt> { |
| bool overflowed = false; |
| APInt result = any(ovfFlags & OverflowFlags::Nuw) |
| ? a.usub_sat(b) |
| : a.usub_ov(b, overflowed); |
| return overflowed ? std::optional<APInt>() : result; |
| }; |
| ConstArithStdFn ssub = [=](const APInt &a, |
| const APInt &b) -> std::optional<APInt> { |
| bool overflowed = false; |
| APInt result = any(ovfFlags & OverflowFlags::Nsw) |
| ? a.ssub_sat(b) |
| : a.ssub_ov(b, overflowed); |
| return overflowed ? std::optional<APInt>() : result; |
| }; |
| ConstantIntRanges urange = computeBoundsBy( |
| usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); |
| ConstantIntRanges srange = computeBoundsBy( |
| ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); |
| return urange.intersection(srange); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Multiplication |
| //===----------------------------------------------------------------------===// |
| |
| ConstantIntRanges |
| mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges, |
| OverflowFlags ovfFlags) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| ConstArithStdFn umul = [=](const APInt &a, |
| const APInt &b) -> std::optional<APInt> { |
| bool overflowed = false; |
| APInt result = any(ovfFlags & OverflowFlags::Nuw) |
| ? a.umul_sat(b) |
| : a.umul_ov(b, overflowed); |
| return overflowed ? std::optional<APInt>() : result; |
| }; |
| ConstArithStdFn smul = [=](const APInt &a, |
| const APInt &b) -> std::optional<APInt> { |
| bool overflowed = false; |
| APInt result = any(ovfFlags & OverflowFlags::Nsw) |
| ? a.smul_sat(b) |
| : a.smul_ov(b, overflowed); |
| return overflowed ? std::optional<APInt>() : result; |
| }; |
| |
| ConstantIntRanges urange = |
| minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, |
| /*isSigned=*/false); |
| ConstantIntRanges srange = |
| minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, |
| /*isSigned=*/true); |
| return urange.intersection(srange); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DivU, CeilDivU (Unsigned division) |
| //===----------------------------------------------------------------------===// |
| |
| /// Fix up division results (ex. for ceiling and floor), returning an APInt |
| /// if there has been no overflow |
| using DivisionFixupFn = function_ref<std::optional<APInt>( |
| const APInt &lhs, const APInt &rhs, const APInt &result)>; |
| |
| static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, |
| const ConstantIntRanges &rhs, |
| DivisionFixupFn fixup) { |
| const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), |
| &rhsMax = rhs.umax(); |
| |
| if (!rhsMin.isZero()) { |
| auto udiv = [&fixup](const APInt &a, |
| const APInt &b) -> std::optional<APInt> { |
| return fixup(a, b, a.udiv(b)); |
| }; |
| return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, |
| /*isSigned=*/false); |
| } |
| |
| APInt umin = APInt::getZero(rhsMin.getBitWidth()); |
| if (lhsMin.uge(rhsMax) && !rhsMax.isZero()) |
| umin = lhsMin.udiv(rhsMax); |
| |
| // X u/ Y u<= X. |
| APInt umax = lhsMax; |
| return ConstantIntRanges::fromUnsigned(umin, umax); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferDivU(ArrayRef<ConstantIntRanges> argRanges) { |
| return inferDivURange(argRanges[0], argRanges[1], |
| [](const APInt &lhs, const APInt &rhs, |
| const APInt &result) { return result; }); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferCeilDivU(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| auto ceilDivUIFix = [](const APInt &lhs, const APInt &rhs, |
| const APInt &result) -> std::optional<APInt> { |
| if (!lhs.urem(rhs).isZero()) { |
| bool overflowed = false; |
| APInt corrected = |
| result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); |
| return overflowed ? std::optional<APInt>() : corrected; |
| } |
| return result; |
| }; |
| return inferDivURange(lhs, rhs, ceilDivUIFix); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DivS, CeilDivS, FloorDivS (Signed division) |
| //===----------------------------------------------------------------------===// |
| |
| static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, |
| const ConstantIntRanges &rhs, |
| DivisionFixupFn fixup) { |
| const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), |
| &rhsMax = rhs.smax(); |
| bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); |
| |
| if (canDivide) { |
| auto sdiv = [&fixup](const APInt &a, |
| const APInt &b) -> std::optional<APInt> { |
| bool overflowed = false; |
| APInt result = a.sdiv_ov(b, overflowed); |
| return overflowed ? std::optional<APInt>() : fixup(a, b, result); |
| }; |
| return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, |
| /*isSigned=*/true); |
| } |
| return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferDivS(ArrayRef<ConstantIntRanges> argRanges) { |
| return inferDivSRange(argRanges[0], argRanges[1], |
| [](const APInt &lhs, const APInt &rhs, |
| const APInt &result) { return result; }); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| auto ceilDivSIFix = [](const APInt &lhs, const APInt &rhs, |
| const APInt &result) -> std::optional<APInt> { |
| if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { |
| bool overflowed = false; |
| APInt corrected = |
| result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); |
| return overflowed ? std::optional<APInt>() : corrected; |
| } |
| // Special case where the usual implementation of ceilDiv causes |
| // INT_MIN / [positive number] to be positive. This doesn't match the |
| // definition of signed ceiling division mathematically, but it prevents |
| // inconsistent constant-folding results. This arises because (-int_min) is |
| // still negative, so -(-int_min / b) is -(int_min / b), which is |
| // positive See #115293. |
| if (lhs.isMinSignedValue() && rhs.sgt(1)) { |
| return -result; |
| } |
| return result; |
| }; |
| ConstantIntRanges result = inferDivSRange(lhs, rhs, ceilDivSIFix); |
| if (lhs.smin().isMinSignedValue() && lhs.smax().sgt(lhs.smin())) { |
| // If lhs range includes INT_MIN and lhs is not a single value, we can |
| // suddenly wrap to positive val, skipping entire negative range, add |
| // [INT_MIN + 1, smax()] range to the result to handle this. |
| auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax()); |
| result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix)); |
| } |
| return result; |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferFloorDivS(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| auto floorDivSIFix = [](const APInt &lhs, const APInt &rhs, |
| const APInt &result) -> std::optional<APInt> { |
| if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { |
| bool overflowed = false; |
| APInt corrected = |
| result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); |
| return overflowed ? std::optional<APInt>() : corrected; |
| } |
| return result; |
| }; |
| return inferDivSRange(lhs, rhs, floorDivSIFix); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Signed remainder (RemS) |
| //===----------------------------------------------------------------------===// |
| |
| ConstantIntRanges |
| mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), |
| &rhsMax = rhs.smax(); |
| |
| unsigned width = rhsMax.getBitWidth(); |
| APInt smin = APInt::getSignedMinValue(width); |
| APInt smax = APInt::getSignedMaxValue(width); |
| // No bounds if zero could be a divisor. |
| bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); |
| if (canBound) { |
| APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); |
| bool canNegativeDividend = lhsMin.isNegative(); |
| bool canPositiveDividend = lhsMax.isStrictlyPositive(); |
| APInt zero = APInt::getZero(maxDivisor.getBitWidth()); |
| APInt maxPositiveResult = maxDivisor - 1; |
| APInt minNegativeResult = -maxPositiveResult; |
| smin = canNegativeDividend ? minNegativeResult : zero; |
| smax = canPositiveDividend ? maxPositiveResult : zero; |
| // Special case: sweeping out a contiguous range in N/[modulus]. |
| if (rhsMin == rhsMax) { |
| if ((lhsMax - lhsMin).ult(maxDivisor)) { |
| APInt minRem = lhsMin.srem(maxDivisor); |
| APInt maxRem = lhsMax.srem(maxDivisor); |
| if (minRem.sle(maxRem)) { |
| smin = minRem; |
| smax = maxRem; |
| } |
| } |
| } |
| } |
| return ConstantIntRanges::fromSigned(smin, smax); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Unsigned remainder (RemU) |
| //===----------------------------------------------------------------------===// |
| |
| ConstantIntRanges |
| mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); |
| |
| unsigned width = rhsMin.getBitWidth(); |
| APInt umin = APInt::getZero(width); |
| // Remainder can't be larger than either of its arguments. |
| APInt umax = llvm::APIntOps::umin((rhsMax - 1), lhs.umax()); |
| |
| if (!rhsMin.isZero()) { |
| // Special case: sweeping out a contiguous range in N/[modulus] |
| if (rhsMin == rhsMax) { |
| const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); |
| if ((lhsMax - lhsMin).ult(rhsMax)) { |
| APInt minRem = lhsMin.urem(rhsMax); |
| APInt maxRem = lhsMax.urem(rhsMax); |
| if (minRem.ule(maxRem)) { |
| umin = minRem; |
| umax = maxRem; |
| } |
| } |
| } |
| } |
| return ConstantIntRanges::fromUnsigned(umin, umax); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Max and min (MaxS, MaxU, MinS, MinU) |
| //===----------------------------------------------------------------------===// |
| |
| ConstantIntRanges |
| mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); |
| const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); |
| return ConstantIntRanges::fromSigned(smin, smax); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferMaxU(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); |
| const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); |
| return ConstantIntRanges::fromUnsigned(umin, umax); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferMinS(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); |
| const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); |
| return ConstantIntRanges::fromSigned(smin, smax); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferMinU(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); |
| const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); |
| return ConstantIntRanges::fromUnsigned(umin, umax); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Bitwise operators (And, Or, Xor) |
| //===----------------------------------------------------------------------===// |
| |
| /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, |
| /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits |
| /// that both bonuds have in common. This gives us a consertive approximation |
| /// for what values can be passed to bitwise operations. |
| static std::tuple<APInt, APInt> |
| widenBitwiseBounds(const ConstantIntRanges &bound) { |
| APInt leftVal = bound.umin(), rightVal = bound.umax(); |
| unsigned bitwidth = leftVal.getBitWidth(); |
| unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero(); |
| leftVal.clearLowBits(differingBits); |
| rightVal.setLowBits(differingBits); |
| return std::make_tuple(std::move(leftVal), std::move(rightVal)); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferAnd(ArrayRef<ConstantIntRanges> argRanges) { |
| auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); |
| auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); |
| auto andi = [](const APInt &a, const APInt &b) -> std::optional<APInt> { |
| return a & b; |
| }; |
| return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, |
| /*isSigned=*/false); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) { |
| auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); |
| auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); |
| auto ori = [](const APInt &a, const APInt &b) -> std::optional<APInt> { |
| return a | b; |
| }; |
| return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, |
| /*isSigned=*/false); |
| } |
| |
| /// Get bitmask of all bits which can change while iterating in |
| /// [bound.umin(), bound.umax()]. |
| static APInt getVaryingBitsMask(const ConstantIntRanges &bound) { |
| APInt leftVal = bound.umin(), rightVal = bound.umax(); |
| unsigned bitwidth = leftVal.getBitWidth(); |
| unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero(); |
| return APInt::getLowBitsSet(bitwidth, differingBits); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) { |
| // Construct mask of varying bits for both ranges, xor values and then replace |
| // masked bits with 0s and 1s to get min and max values respectively. |
| ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1]; |
| APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs); |
| APInt res = lhs.umin() ^ rhs.umin(); |
| APInt min = res & ~mask; |
| APInt max = res | mask; |
| return ConstantIntRanges::fromUnsigned(min, max); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Shifts (Shl, ShrS, ShrU) |
| //===----------------------------------------------------------------------===// |
| |
| ConstantIntRanges |
| mlir::intrange::inferShl(ArrayRef<ConstantIntRanges> argRanges, |
| OverflowFlags ovfFlags) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| const APInt &rhsUMin = rhs.umin(), &rhsUMax = rhs.umax(); |
| |
| // The signed/unsigned overflow behavior of shl by `rhs` matches a mul with |
| // 2^rhs. |
| ConstArithStdFn ushl = [=](const APInt &l, |
| const APInt &r) -> std::optional<APInt> { |
| bool overflowed = false; |
| APInt result = any(ovfFlags & OverflowFlags::Nuw) |
| ? l.ushl_sat(r) |
| : l.ushl_ov(r, overflowed); |
| return overflowed ? std::optional<APInt>() : result; |
| }; |
| ConstArithStdFn sshl = [=](const APInt &l, |
| const APInt &r) -> std::optional<APInt> { |
| bool overflowed = false; |
| APInt result = any(ovfFlags & OverflowFlags::Nsw) |
| ? l.sshl_sat(r) |
| : l.sshl_ov(r, overflowed); |
| return overflowed ? std::optional<APInt>() : result; |
| }; |
| |
| ConstantIntRanges urange = |
| minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhsUMin, rhsUMax}, |
| /*isSigned=*/false); |
| ConstantIntRanges srange = |
| minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhsUMin, rhsUMax}, |
| /*isSigned=*/true); |
| return urange.intersection(srange); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferShrS(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| auto ashr = [](const APInt &l, const APInt &r) -> std::optional<APInt> { |
| return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.ashr(r); |
| }; |
| |
| return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, |
| /*isSigned=*/true); |
| } |
| |
| ConstantIntRanges |
| mlir::intrange::inferShrU(ArrayRef<ConstantIntRanges> argRanges) { |
| const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; |
| |
| auto lshr = [](const APInt &l, const APInt &r) -> std::optional<APInt> { |
| return r.uge(r.getBitWidth()) ? std::optional<APInt>() : l.lshr(r); |
| }; |
| return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, |
| /*isSigned=*/false); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Comparisons (Cmp) |
| //===----------------------------------------------------------------------===// |
| |
| static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) { |
| switch (pred) { |
| case intrange::CmpPredicate::eq: |
| return intrange::CmpPredicate::ne; |
| case intrange::CmpPredicate::ne: |
| return intrange::CmpPredicate::eq; |
| case intrange::CmpPredicate::slt: |
| return intrange::CmpPredicate::sge; |
| case intrange::CmpPredicate::sle: |
| return intrange::CmpPredicate::sgt; |
| case intrange::CmpPredicate::sgt: |
| return intrange::CmpPredicate::sle; |
| case intrange::CmpPredicate::sge: |
| return intrange::CmpPredicate::slt; |
| case intrange::CmpPredicate::ult: |
| return intrange::CmpPredicate::uge; |
| case intrange::CmpPredicate::ule: |
| return intrange::CmpPredicate::ugt; |
| case intrange::CmpPredicate::ugt: |
| return intrange::CmpPredicate::ule; |
| case intrange::CmpPredicate::uge: |
| return intrange::CmpPredicate::ult; |
| } |
| llvm_unreachable("unknown cmp predicate value"); |
| } |
| |
| static bool isStaticallyTrue(intrange::CmpPredicate pred, |
| const ConstantIntRanges &lhs, |
| const ConstantIntRanges &rhs) { |
| switch (pred) { |
| case intrange::CmpPredicate::sle: |
| return lhs.smax().sle(rhs.smin()); |
| case intrange::CmpPredicate::slt: |
| return lhs.smax().slt(rhs.smin()); |
| case intrange::CmpPredicate::ule: |
| return lhs.umax().ule(rhs.umin()); |
| case intrange::CmpPredicate::ult: |
| return lhs.umax().ult(rhs.umin()); |
| case intrange::CmpPredicate::sge: |
| return lhs.smin().sge(rhs.smax()); |
| case intrange::CmpPredicate::sgt: |
| return lhs.smin().sgt(rhs.smax()); |
| case intrange::CmpPredicate::uge: |
| return lhs.umin().uge(rhs.umax()); |
| case intrange::CmpPredicate::ugt: |
| return lhs.umin().ugt(rhs.umax()); |
| case intrange::CmpPredicate::eq: { |
| std::optional<APInt> lhsConst = lhs.getConstantValue(); |
| std::optional<APInt> rhsConst = rhs.getConstantValue(); |
| return lhsConst && rhsConst && lhsConst == rhsConst; |
| } |
| case intrange::CmpPredicate::ne: { |
| // While equality requires that there is an interpration of the preceeding |
| // computations that produces equal constants, whether that be signed or |
| // unsigned, statically determining inequality requires that neither |
| // interpretation produce potentially overlapping ranges. |
| bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) || |
| isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs); |
| bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) || |
| isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs); |
| return sne && une; |
| } |
| } |
| return false; |
| } |
| |
| std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred, |
| const ConstantIntRanges &lhs, |
| const ConstantIntRanges &rhs) { |
| if (isStaticallyTrue(pred, lhs, rhs)) |
| return true; |
| if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) |
| return false; |
| return std::nullopt; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Shaped type dimension accessors / ShapedDimOpInterface |
| //===----------------------------------------------------------------------===// |
| |
| ConstantIntRanges |
| mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op, |
| const IntegerValueRange &maybeDim) { |
| unsigned width = |
| ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType()); |
| APInt zero = APInt::getZero(width); |
| APInt typeMax = APInt::getSignedMaxValue(width); |
| |
| auto shapedTy = cast<ShapedType>(op.getShapedValue().getType()); |
| if (!shapedTy.hasRank()) |
| return ConstantIntRanges::fromSigned(zero, typeMax); |
| |
| int64_t rank = shapedTy.getRank(); |
| int64_t minDim = 0; |
| int64_t maxDim = rank - 1; |
| if (!maybeDim.isUninitialized()) { |
| const ConstantIntRanges &dim = maybeDim.getValue(); |
| minDim = std::max(minDim, dim.smin().getSExtValue()); |
| maxDim = std::min(maxDim, dim.smax().getSExtValue()); |
| } |
| |
| std::optional<ConstantIntRanges> result; |
| auto joinResult = [&](const ConstantIntRanges &thisResult) { |
| if (!result.has_value()) |
| result = thisResult; |
| else |
| result = result->rangeUnion(thisResult); |
| }; |
| for (int64_t i = minDim; i <= maxDim; ++i) { |
| int64_t length = shapedTy.getDimSize(i); |
| |
| if (ShapedType::isDynamic(length)) |
| joinResult(ConstantIntRanges::fromSigned(zero, typeMax)); |
| else |
| joinResult(ConstantIntRanges::constant(APInt(width, length))); |
| } |
| return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax)); |
| } |