blob: 45e791b8d4c468d9f8724193ecb8659879686449 [file] [log] [blame]
//===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_QUANT_QUANT_TYPES_H_
#define MLIR_DIALECT_QUANT_QUANT_TYPES_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"
#include "llvm/Support/MathExtras.h"
namespace mlir {
namespace quant {
// Forward declaration only; no definition appears in this section of the
// header.
class QuantizedIntegerType;
namespace detail {
// Storage classes backing the quantized types declared below. Forward
// declarations suffice here because this header only names them (as the
// ImplType alias of QuantizedType and as template arguments to
// Type::TypeBase); their definitions are not part of this header.
struct QuantizedTypeStorage;
struct AnyQuantizedTypeStorage;
struct UniformQuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;
struct CalibratedQuantizedTypeStorage;
} // namespace detail
/// Bit flags that qualify a quantized type. Each enumerator occupies its own
/// bit so that several flags can be OR-ed together into one `unsigned` mask.
namespace QuantizationFlags {
enum FlagValue {
  /// When set, the storage type is to be read as a signed integer. When
  /// clear (the default), it is read as an unsigned value.
  Signed = 1 << 0,
};
} // namespace QuantizationFlags
/// Common base of every quantized type known to this dialect.
///
/// Every quantized type carries:
///   - storageType: the (narrower) numeric type used to approximate some
///     expressed type.
///   - expressedType: the type that is being approximated.
///
/// This class offers the generic machinery for inspecting and converting
/// types in terms of those two fields.
class QuantizedType : public Type {
public:
  using ImplType = detail::QuantizedTypeStorage;
  using Type::Type;

  /// Largest bit width supported for a storage type.
  static constexpr unsigned MaxStorageBits = 32;

  /// Checks the construction parameters shared by all quantized types,
  /// reporting problems through `emitError`.
  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                              unsigned flags, Type storageType,
                              Type expressedType, int64_t storageTypeMin,
                              int64_t storageTypeMax);

  /// Hook for LLVM-style casting (isa / cast / dyn_cast).
  static bool classof(Type type);

  /// Returns the smallest value an integer storage type of the given
  /// signedness and width can hold. A storageTypeMin must be greater than or
  /// equal to this value.
  static int64_t getDefaultMinimumForInteger(bool isSigned,
                                             unsigned integralWidth) {
    // Unsigned storage bottoms out at zero; signed storage at the
    // two's-complement minimum for the width.
    return isSigned ? llvm::minIntN(integralWidth) : int64_t{0};
  }

  /// Returns the largest value an integer storage type of the given
  /// signedness and width can hold. A storageTypeMax must be less than or
  /// equal to this value.
  static int64_t getDefaultMaximumForInteger(bool isSigned,
                                             unsigned integralWidth) {
    if (!isSigned) {
      return llvm::maxUIntN(integralWidth);
    }
    return llvm::maxIntN(integralWidth);
  }

  /// Returns the original expressed type that this quantized type
  /// approximates.
  ///
  /// This presumes the quantized type was derived from a floating point type.
  /// In the broadest definition that need not hold (the expressed type could
  /// itself be integral, fixed point or affine), but no such usage is
  /// presently known, and the restriction is useful: a transformation can
  /// always be reversed and its error measured. In most cases this is f32.
  Type getExpressedType() const;

  /// Returns the raw flag bits of this type. A more specific accessor (such
  /// as isSigned()) is usually the better choice.
  unsigned getFlags() const;

  // Convenience helpers.

  /// True when the storage type is to be interpreted as a signed quantity,
  /// false when it is to be interpreted as unsigned.
  bool isSigned() const {
    const unsigned signedBit = QuantizationFlags::Signed;
    return (getFlags() & signedBit) == signedBit;
  }

  /// Returns the underlying type used to store values. It may be signed or
  /// unsigned; consult isSigned() to tell which.
  Type getStorageType() const;

  /// The minimum value that storageType can take.
  int64_t getStorageTypeMin() const;

  /// The maximum value that storageType can take.
  int64_t getStorageTypeMax() const;

  /// Returns the integral bit width that the underlying storage type can
  /// represent exactly. For integral storage types this is simply their
  /// width.
  unsigned getStorageTypeIntegralWidth() const;

  /// Reports whether candidateExpressedType matches this QuantizedType: it
  /// must be either a primitive type or a container type whose element type
  /// equals this QuantizedType's expressed type.
  /// Examples of compatible candidateExpressedType:
  ///   !quant.uniform<i8:f32, 1.0> =~ f32
  ///   !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32>
  bool isCompatibleExpressedType(Type candidateExpressedType);

  /// Returns the element type as a QuantizedType, or nullptr if it is not a
  /// quantized type. A primitive type is returned as is; for a container
  /// (vector/tensor) the element type is returned.
  /// Examples:
  ///   !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0>
  ///   tensor<4x!quant.uniform<i8:f32, 1.0>> -> !quant.uniform<i8:f32, 1.0>
  static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);

  /// Converts a type built on the storageType into the corresponding type
  /// built on this type (nullptr if the cast is not valid).
  /// Examples:
  ///   i8 -> !quant.uniform<i8:f32, 1.0>
  ///   tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
  ///   vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
  Type castFromStorageType(Type candidateType);

  /// Converts a type built on a QuantizedType into the corresponding type
  /// built on the storageType (nullptr if the cast is not valid).
  /// Inverse of castFromStorageType().
  static Type castToStorageType(Type quantizedType);

  /// Converts a type built on the expressedType into the corresponding type
  /// built on this type (nullptr if the cast is not valid).
  /// Examples:
  ///   f32 -> !quant.uniform<i8:f32, 1.0>
  ///   tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
  ///   vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>>
  Type castFromExpressedType(Type candidateType);

  /// Converts a type built on a QuantizedType into the corresponding type
  /// built on the expressedType (nullptr if the cast is not valid).
  /// Inverse of castFromExpressedType().
  static Type castToExpressedType(Type quantizedType);

  /// Converts a type built on the expressedType into the equivalent type
  /// built on the storageType, going through this QuantizedType. Equivalent
  /// to:
  ///   QuantizedType::castToStorageType(castFromExpressedType(candidateType))
  /// (but with validity checks).
  /// Example (for this = !quant.uniform<i8:f32, 1.0>):
  ///   tensor<4xf32> -> tensor<4xi8>
  Type castExpressedToStorageType(Type candidateType);

private:
  /// The predicates below are inherited from `Type` and deliberately made
  /// private: calling them on a `QuantizedType` is almost certainly a bug.
  /// Callers should fetch the type they actually mean to inspect via
  /// `getStorageType` or `getExpressedType` first.
  using Type::isBF16;
  using Type::isF16;
  using Type::isF32;
  using Type::isF64;
  using Type::isIndex;
  using Type::isInteger;
};
/// A quantized type that maps storage to/from expressed types in an
/// unspecified way.
///
/// Typical syntax:
/// quant.any<i8:f32>
/// quant.any<i8>
/// quant.any<i8<-16,15>>
///
/// Note that for the any type, the expressed type is optional.
///
/// The factory and verification methods below all take the same parameters:
/// - flags: bit mask of QuantizationFlags values (e.g. Signed).
/// - storageType: the type used to store values.
/// - expressedType: the type being approximated (optional for this type).
/// - storageTypeMin / storageTypeMax: bounds on the values storageType can
/// take.
class AnyQuantizedType
: public Type::TypeBase<AnyQuantizedType, QuantizedType,
detail::AnyQuantizedTypeStorage> {
public:
using Base::Base;
// Keep the inherited getChecked overloads visible next to the overload
// declared in this class.
using Base::getChecked;
/// Gets an instance of the type with all parameters specified but not
/// checked.
static AnyQuantizedType get(unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Gets an instance of the type with all specified parameters checked.
/// Problems are reported through `emitError`.
/// Returns a nullptr convertible type on failure.
static AnyQuantizedType
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Verifies construction invariants and issues errors/warnings through
/// `emitError`.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
};
/// A family of uniform, quantized types.
///
/// An instance describes how real values (most often expressed in floating
/// point f32) map onto quantized values (either fixed point or affine).
///
/// The relationship is:
///   real_value = scale * (quantized_value - zero_point)
///
/// The type is used by high level graph transformations whose goal is to
/// re-express parts of a computation in this common form for more efficient
/// execution at runtime. It is also meant to be expressive enough to support
/// lowering to precise types and operations in target hardware.
///
/// Being a high-level type aimed at intermediate passes, it holds opinions
/// consistent with high-level usage. When math kernels are lowered below the
/// high level arithmetic ops (i.e. to LLVM IR or hardware specific
/// instruction sets), the information carried here is expected to drive low
/// level codegen and target specific type selection, while the type itself
/// will likely be erased along the way.
///
/// Syntax synopsis:
///   Per-layer, all parameters expressed:
///     !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
///   Per-layer, optional parameters omitted:
///     !quant<uniform[StorageType]{Scale}>
///
///   StorageType: 'i'|'u' NumBits
///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
///   Scale: A legal double value
///   ZeroPoint: An integer value
class UniformQuantizedType
    : public Type::TypeBase<UniformQuantizedType, QuantizedType,
                            detail::UniformQuantizedTypeStorage> {
public:
  using Base::Base;
  using Base::getChecked;

  /// Builds an instance from fully specified parameters without checking
  /// them.
  static UniformQuantizedType get(unsigned flags, Type storageType,
                                  Type expressedType, double scale,
                                  int64_t zeroPoint, int64_t storageTypeMin,
                                  int64_t storageTypeMax);

  /// Builds an instance from fully specified parameters after checking them.
  /// Yields a nullptr convertible type on failure.
  static UniformQuantizedType
  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
             Type storageType, Type expressedType, double scale,
             int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);

  /// Checks construction invariants, issuing errors/warnings as needed.
  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                              unsigned flags, Type storageType,
                              Type expressedType, double scale,
                              int64_t zeroPoint, int64_t storageTypeMin,
                              int64_t storageTypeMax);

  /// Returns the scale term: the difference between the real values that
  /// correspond to two consecutive quantized values (values differing by 1).
  double getScale() const;

  /// Returns the storage value that corresponds to the real value 0 in the
  /// affine equation.
  int64_t getZeroPoint() const;

  /// Reports whether this type is a fixed point type.
  ///
  /// Fixed point values are real numbers divided by a scale. Currently, only
  /// signed storage types are treated as fixed point. A fixed point value can
  /// be obtained from an affine value by subtracting the zeroPoint. In the
  /// future, this may be explicit versus implied by type and zeroPoint.
  bool isFixedPoint() const {
    // Unsigned storage is never considered fixed point.
    if (!isSigned())
      return false;
    return getZeroPoint() == 0;
  }
};
/// Represents per-axis (also known as per-channel) quantization.
///
/// Syntax synopsis:
///   Per-axis, all parameters expressed:
///     !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
///   Per-axis, optional parameters omitted:
///     !quant<uniform[StorageType]{Scale}>
///
///   StorageType: 'i'|'u' NumBits
///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
///   QuantizedDim: An integer value
///   QuantParams: (Scale ':' ZeroPoint)+
///   Scale: A legal double value
///   ZeroPoint: An integer value
class UniformQuantizedPerAxisType
    : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
                            detail::UniformQuantizedPerAxisTypeStorage> {
public:
  using Base::Base;
  using Base::getChecked;

  /// Gets an instance of the type with all parameters specified but not
  /// checked.
  static UniformQuantizedPerAxisType
  get(unsigned flags, Type storageType, Type expressedType,
      ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
      int32_t quantizedDimension, int64_t storageTypeMin,
      int64_t storageTypeMax);

  /// Gets an instance of the type with all specified parameters checked.
  /// Returns a nullptr convertible type on failure.
  static UniformQuantizedPerAxisType
  getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
             Type storageType, Type expressedType, ArrayRef<double> scales,
             ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
             int64_t storageTypeMin, int64_t storageTypeMax);

  /// Verifies construction invariants and issues errors/warnings.
  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                              unsigned flags, Type storageType,
                              Type expressedType, ArrayRef<double> scales,
                              ArrayRef<int64_t> zeroPoints,
                              int32_t quantizedDimension,
                              int64_t storageTypeMin, int64_t storageTypeMax);

  /// Gets the quantization scales. The scales designate the difference between
  /// the real values corresponding to consecutive quantized values differing
  /// by 1. The ith scale corresponds to the ith slice in the
  /// quantized_dimension.
  ArrayRef<double> getScales() const;

  /// Gets the storage values corresponding to the real value 0 in the affine
  /// equation. The ith zero point corresponds to the ith slice in the
  /// quantized_dimension.
  ArrayRef<int64_t> getZeroPoints() const;

  /// Specifies the dimension of the Tensor's shape that the scales and
  /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
  /// with quantization params:
  ///   scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
  /// will be quantized across the second dimension of t.
  ///   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
  ///   t[:, 1, :, :] will have scale[1]=2.0, zero_point[1]=2
  ///   t[:, 2, :, :] will have scale[2]=3.0, zero_point[2]=3
  int32_t getQuantizedDimension() const;

  /// Reports whether this type is a fixed point type.
  ///
  /// Fixed point values are real numbers divided by a scale.
  /// Currently, only signed storage types are treated as fixed point.
  /// A fixed point value can be obtained from an affine value by subtracting
  /// the zeroPoint, so the type is fixed point only when every per-axis zero
  /// point is 0 (mirroring the per-layer UniformQuantizedType::isFixedPoint).
  /// In the future, this may be explicit versus implied by type and zeroPoint.
  bool isFixedPoint() const {
    if (!isSigned())
      return false;
    // Every slice must have a zero offset; a single non-zero zero point makes
    // the mapping affine rather than fixed point.
    return llvm::all_of(getZeroPoints(),
                        [](int64_t zeroPoint) { return zeroPoint == 0; });
  }
};
/// A quantized type that infers its range from given min/max values.
///
/// Typical syntax:
/// quant.calibrated<f32<-0.922,0.981>>
///
/// Unlike the other quantized types in this header, the factory methods here
/// take only an expressed type and a [min, max] pair; no flags, storage type
/// or storage bounds are passed.
class CalibratedQuantizedType
: public Type::TypeBase<CalibratedQuantizedType, QuantizedType,
detail::CalibratedQuantizedTypeStorage> {
public:
using Base::Base;
// Keep the inherited getChecked overloads visible next to the overload
// declared in this class.
using Base::getChecked;
/// Gets an instance of the type with all parameters specified but not
/// checked.
static CalibratedQuantizedType get(Type expressedType, double min,
double max);
/// Gets an instance of the type with all specified parameters checked.
/// Problems are reported through `emitError`.
/// Returns a nullptr convertible type on failure.
static CalibratedQuantizedType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type expressedType,
double min, double max);
/// Verifies construction invariants and issues errors/warnings through
/// `emitError`.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max);
/// Lower end of the calibrated range (presumably the `min` supplied at
/// construction; the definition is not in this header).
double getMin() const;
/// Upper end of the calibrated range (presumably the `max` supplied at
/// construction; the definition is not in this header).
double getMax() const;
};
} // namespace quant
} // namespace mlir
#endif // MLIR_DIALECT_QUANT_QUANT_TYPES_H_