| //===- InferIntRangeInterface.cpp - Integer range inference interface ---===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Interfaces/InferIntRangeInterface.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/Interfaces/InferIntRangeInterface.cpp.inc" |
| |
| using namespace mlir; |
| |
| bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const { |
| return umin().getBitWidth() == other.umin().getBitWidth() && |
| umin() == other.umin() && umax() == other.umax() && |
| smin() == other.smin() && smax() == other.smax(); |
| } |
| |
| const APInt &ConstantIntRanges::umin() const { return uminVal; } |
| |
| const APInt &ConstantIntRanges::umax() const { return umaxVal; } |
| |
| const APInt &ConstantIntRanges::smin() const { return sminVal; } |
| |
| const APInt &ConstantIntRanges::smax() const { return smaxVal; } |
| |
| unsigned ConstantIntRanges::getStorageBitwidth(Type type) { |
| if (type.isIndex()) |
| return IndexType::kInternalStorageBitWidth; |
| if (auto integerType = type.dyn_cast<IntegerType>()) |
| return integerType.getWidth(); |
| // Non-integer types have their bounds stored in width 0 `APInt`s. |
| return 0; |
| } |
| |
| ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) { |
| return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth)); |
| } |
| |
| ConstantIntRanges ConstantIntRanges::constant(const APInt &value) { |
| return {value, value, value, value}; |
| } |
| |
| ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max, |
| bool isSigned) { |
| if (isSigned) |
| return fromSigned(min, max); |
| return fromUnsigned(min, max); |
| } |
| |
| ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin, |
| const APInt &smax) { |
| unsigned int width = smin.getBitWidth(); |
| APInt umin, umax; |
| if (smin.isNonNegative() == smax.isNonNegative()) { |
| umin = smin.ult(smax) ? smin : smax; |
| umax = smin.ugt(smax) ? smin : smax; |
| } else { |
| umin = APInt::getMinValue(width); |
| umax = APInt::getMaxValue(width); |
| } |
| return {umin, umax, smin, smax}; |
| } |
| |
| ConstantIntRanges ConstantIntRanges::fromUnsigned(const APInt &umin, |
| const APInt &umax) { |
| unsigned int width = umin.getBitWidth(); |
| APInt smin, smax; |
| if (umin.isNonNegative() == umax.isNonNegative()) { |
| smin = umin.slt(umax) ? umin : umax; |
| smax = umin.sgt(umax) ? umin : umax; |
| } else { |
| smin = APInt::getSignedMinValue(width); |
| smax = APInt::getSignedMaxValue(width); |
| } |
| return {umin, umax, smin, smax}; |
| } |
| |
| ConstantIntRanges |
| ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const { |
| // "Not an integer" poisons everything and also cannot be fed to comparison |
| // operators. |
| if (umin().getBitWidth() == 0) |
| return *this; |
| if (other.umin().getBitWidth() == 0) |
| return other; |
| |
| const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin(); |
| const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax(); |
| const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin(); |
| const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax(); |
| |
| return {uminUnion, umaxUnion, sminUnion, smaxUnion}; |
| } |
| |
| ConstantIntRanges |
| ConstantIntRanges::intersection(const ConstantIntRanges &other) const { |
| // "Not an integer" poisons everything and also cannot be fed to comparison |
| // operators. |
| if (umin().getBitWidth() == 0) |
| return *this; |
| if (other.umin().getBitWidth() == 0) |
| return other; |
| |
| const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin(); |
| const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax(); |
| const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin(); |
| const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax(); |
| |
| return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect}; |
| } |
| |
| Optional<APInt> ConstantIntRanges::getConstantValue() const { |
| // Note: we need to exclude the trivially-equal width 0 values here. |
| if (umin() == umax() && umin().getBitWidth() != 0) |
| return umin(); |
| if (smin() == smax() && smin().getBitWidth() != 0) |
| return smin(); |
| return None; |
| } |
| |
| raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) { |
| return os << "unsigned : [" << range.umin() << ", " << range.umax() |
| << "] signed : [" << range.smin() << ", " << range.smax() << "]"; |
| } |