| //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Quant/QuantTypes.h" |
| #include "TypeDetail.h" |
| #include "mlir/Dialect/Quant/QuantOps.h" |
| |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/Twine.h" |
| #include "llvm/Support/MathExtras.h" |
| |
| using namespace mlir; |
| using namespace mlir::quant; |
| using namespace mlir::quant::detail; |
| |
| unsigned QuantizedType::getFlags() const { |
| return static_cast<ImplType *>(impl)->flags; |
| } |
| |
| bool QuantizedType::classof(Type type) { |
| return llvm::isa<QuantizationDialect>(type.getDialect()); |
| } |
| |
| LogicalResult |
| QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError, |
| unsigned flags, Type storageType, Type expressedType, |
| int64_t storageTypeMin, int64_t storageTypeMax) { |
| // Verify that the storage type is integral. |
| // This restriction may be lifted at some point in favor of using bf16 |
| // or f16 as exact representations on hardware where that is advantageous. |
| auto intStorageType = storageType.dyn_cast<IntegerType>(); |
| if (!intStorageType) |
| return emitError() << "storage type must be integral"; |
| unsigned integralWidth = intStorageType.getWidth(); |
| |
| // Verify storage width. |
| if (integralWidth == 0 || integralWidth > MaxStorageBits) |
| return emitError() << "illegal storage type size: " << integralWidth; |
| |
| // Verify storageTypeMin and storageTypeMax. |
| bool isSigned = |
| (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed; |
| int64_t defaultIntegerMin = |
| getDefaultMinimumForInteger(isSigned, integralWidth); |
| int64_t defaultIntegerMax = |
| getDefaultMaximumForInteger(isSigned, integralWidth); |
| if (storageTypeMax - storageTypeMin <= 0 || |
| storageTypeMin < defaultIntegerMin || |
| storageTypeMax > defaultIntegerMax) { |
| return emitError() << "illegal storage min and storage max: (" |
| << storageTypeMin << ":" << storageTypeMax << ")"; |
| } |
| return success(); |
| } |
| |
| Type QuantizedType::getStorageType() const { |
| return static_cast<ImplType *>(impl)->storageType; |
| } |
| |
| int64_t QuantizedType::getStorageTypeMin() const { |
| return static_cast<ImplType *>(impl)->storageTypeMin; |
| } |
| |
| int64_t QuantizedType::getStorageTypeMax() const { |
| return static_cast<ImplType *>(impl)->storageTypeMax; |
| } |
| |
| unsigned QuantizedType::getStorageTypeIntegralWidth() const { |
| // NOTE: If ever supporting non-integral storage types, some other scheme |
| // for determining the width will be needed. |
| return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth(); |
| } |
| |
| Type QuantizedType::getExpressedType() const { |
| return static_cast<ImplType *>(impl)->expressedType; |
| } |
| |
| bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) { |
| if (candidateExpressedType.isa<ShapedType>()) { |
| return candidateExpressedType.cast<ShapedType>().getElementType() == |
| getExpressedType(); |
| } |
| return candidateExpressedType == getExpressedType(); |
| } |
| |
| QuantizedType |
| QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) { |
| if (primitiveOrContainerType.isa<ShapedType>()) { |
| Type elementType = |
| primitiveOrContainerType.cast<ShapedType>().getElementType(); |
| return elementType.dyn_cast<QuantizedType>(); |
| } |
| return primitiveOrContainerType.dyn_cast<QuantizedType>(); |
| } |
| |
| Type QuantizedType::castFromStorageType(Type candidateType) { |
| if (candidateType == getStorageType()) { |
| // i.e. i32 -> quant<"uniform[i8:f32]{1.0}"> |
| return *this; |
| } else if (candidateType.isa<RankedTensorType>()) { |
| // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
| return RankedTensorType::get( |
| candidateType.cast<RankedTensorType>().getShape(), getStorageType()); |
| } else if (candidateType.isa<UnrankedTensorType>()) { |
| // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">> |
| return UnrankedTensorType::get(getStorageType()); |
| } else if (candidateType.isa<VectorType>()) { |
| // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
| return VectorType::get(candidateType.cast<VectorType>().getShape(), |
| getStorageType()); |
| } |
| |
| return nullptr; |
| } |
| |
| Type QuantizedType::castToStorageType(Type quantizedType) { |
| if (quantizedType.isa<QuantizedType>()) { |
| // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 |
| return quantizedType.cast<QuantizedType>().getStorageType(); |
| } else if (quantizedType.isa<ShapedType>()) { |
| // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
| ShapedType sType = quantizedType.cast<ShapedType>(); |
| if (!sType.getElementType().isa<QuantizedType>()) { |
| return nullptr; |
| } |
| Type storageType = |
| sType.getElementType().cast<QuantizedType>().getStorageType(); |
| if (quantizedType.isa<RankedTensorType>()) { |
| return RankedTensorType::get(sType.getShape(), storageType); |
| } else if (quantizedType.isa<UnrankedTensorType>()) { |
| return UnrankedTensorType::get(storageType); |
| } else if (quantizedType.isa<VectorType>()) { |
| return VectorType::get(sType.getShape(), storageType); |
| } |
| } |
| |
| return nullptr; |
| } |
| |
| Type QuantizedType::castFromExpressedType(Type candidateType) { |
| if (candidateType == getExpressedType()) { |
| // i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> |
| return *this; |
| } else if (candidateType.isa<ShapedType>()) { |
| ShapedType candidateShapedType = candidateType.cast<ShapedType>(); |
| if (candidateShapedType.getElementType() != getExpressedType()) { |
| return nullptr; |
| } |
| |
| if (candidateType.isa<RankedTensorType>()) { |
| // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
| return RankedTensorType::get(candidateShapedType.getShape(), *this); |
| } else if (candidateType.isa<UnrankedTensorType>()) { |
| // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">> |
| return UnrankedTensorType::get(*this); |
| } else if (candidateType.isa<VectorType>()) { |
| // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
| return VectorType::get(candidateShapedType.getShape(), *this); |
| } |
| } |
| |
| return nullptr; |
| } |
| |
| Type QuantizedType::castToExpressedType(Type quantizedType) { |
| if (quantizedType.isa<QuantizedType>()) { |
| // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 |
| return quantizedType.cast<QuantizedType>().getExpressedType(); |
| } else if (quantizedType.isa<ShapedType>()) { |
| // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> |
| ShapedType sType = quantizedType.cast<ShapedType>(); |
| if (!sType.getElementType().isa<QuantizedType>()) { |
| return nullptr; |
| } |
| Type expressedType = |
| sType.getElementType().cast<QuantizedType>().getExpressedType(); |
| if (quantizedType.isa<RankedTensorType>()) { |
| return RankedTensorType::get(sType.getShape(), expressedType); |
| } else if (quantizedType.isa<UnrankedTensorType>()) { |
| return UnrankedTensorType::get(expressedType); |
| } else if (quantizedType.isa<VectorType>()) { |
| return VectorType::get(sType.getShape(), expressedType); |
| } |
| } |
| |
| return nullptr; |
| } |
| |
| Type QuantizedType::castExpressedToStorageType(Type candidateType) { |
| Type expressedQuantizedType = castFromExpressedType(candidateType); |
| if (!expressedQuantizedType) { |
| return nullptr; |
| } |
| return QuantizedType::castToStorageType(expressedQuantizedType); |
| } |
| |
| AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType, |
| Type expressedType, |
| int64_t storageTypeMin, |
| int64_t storageTypeMax) { |
| return Base::get(storageType.getContext(), flags, storageType, expressedType, |
| storageTypeMin, storageTypeMax); |
| } |
| |
| AnyQuantizedType |
| AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
| unsigned flags, Type storageType, |
| Type expressedType, int64_t storageTypeMin, |
| int64_t storageTypeMax) { |
| return Base::getChecked(emitError, storageType.getContext(), flags, |
| storageType, expressedType, storageTypeMin, |
| storageTypeMax); |
| } |
| |
| LogicalResult |
| AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError, |
| unsigned flags, Type storageType, Type expressedType, |
| int64_t storageTypeMin, int64_t storageTypeMax) { |
| if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType, |
| storageTypeMin, storageTypeMax))) { |
| return failure(); |
| } |
| |
| // Verify that the expressed type is floating point. |
| // If this restriction is ever eliminated, the parser/printer must be |
| // extended. |
| if (expressedType && !expressedType.isa<FloatType>()) |
| return emitError() << "expressed type must be floating point"; |
| |
| return success(); |
| } |
| |
| UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType, |
| Type expressedType, double scale, |
| int64_t zeroPoint, |
| int64_t storageTypeMin, |
| int64_t storageTypeMax) { |
| return Base::get(storageType.getContext(), flags, storageType, expressedType, |
| scale, zeroPoint, storageTypeMin, storageTypeMax); |
| } |
| |
| UniformQuantizedType UniformQuantizedType::getChecked( |
| function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
| Type storageType, Type expressedType, double scale, int64_t zeroPoint, |
| int64_t storageTypeMin, int64_t storageTypeMax) { |
| return Base::getChecked(emitError, storageType.getContext(), flags, |
| storageType, expressedType, scale, zeroPoint, |
| storageTypeMin, storageTypeMax); |
| } |
| |
| LogicalResult UniformQuantizedType::verify( |
| function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
| Type storageType, Type expressedType, double scale, int64_t zeroPoint, |
| int64_t storageTypeMin, int64_t storageTypeMax) { |
| if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType, |
| storageTypeMin, storageTypeMax))) { |
| return failure(); |
| } |
| |
| // Uniform quantization requires fully expressed parameters, including |
| // expressed type. |
| if (!expressedType) |
| return emitError() << "uniform quantization requires expressed type"; |
| |
| // Verify that the expressed type is floating point. |
| // If this restriction is ever eliminated, the parser/printer must be |
| // extended. |
| if (!expressedType.isa<FloatType>()) |
| return emitError() << "expressed type must be floating point"; |
| |
| // Verify scale. |
| if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) |
| return emitError() << "illegal scale: " << scale; |
| |
| return success(); |
| } |
| |
| double UniformQuantizedType::getScale() const { return getImpl()->scale; } |
| |
| int64_t UniformQuantizedType::getZeroPoint() const { |
| return getImpl()->zeroPoint; |
| } |
| |
| UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get( |
| unsigned flags, Type storageType, Type expressedType, |
| ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, |
| int32_t quantizedDimension, int64_t storageTypeMin, |
| int64_t storageTypeMax) { |
| return Base::get(storageType.getContext(), flags, storageType, expressedType, |
| scales, zeroPoints, quantizedDimension, storageTypeMin, |
| storageTypeMax); |
| } |
| |
| UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked( |
| function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
| Type storageType, Type expressedType, ArrayRef<double> scales, |
| ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, |
| int64_t storageTypeMin, int64_t storageTypeMax) { |
| return Base::getChecked(emitError, storageType.getContext(), flags, |
| storageType, expressedType, scales, zeroPoints, |
| quantizedDimension, storageTypeMin, storageTypeMax); |
| } |
| |
| LogicalResult UniformQuantizedPerAxisType::verify( |
| function_ref<InFlightDiagnostic()> emitError, unsigned flags, |
| Type storageType, Type expressedType, ArrayRef<double> scales, |
| ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, |
| int64_t storageTypeMin, int64_t storageTypeMax) { |
| if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType, |
| storageTypeMin, storageTypeMax))) { |
| return failure(); |
| } |
| |
| // Uniform quantization requires fully expressed parameters, including |
| // expressed type. |
| if (!expressedType) |
| return emitError() << "uniform quantization requires expressed type"; |
| |
| // Verify that the expressed type is floating point. |
| // If this restriction is ever eliminated, the parser/printer must be |
| // extended. |
| if (!expressedType.isa<FloatType>()) |
| return emitError() << "expressed type must be floating point"; |
| |
| // Ensure that the number of scales and zeroPoints match. |
| if (scales.size() != zeroPoints.size()) |
| return emitError() << "illegal number of scales and zeroPoints: " |
| << scales.size() << ", " << zeroPoints.size(); |
| |
| // Verify scale. |
| for (double scale : scales) { |
| if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) |
| return emitError() << "illegal scale: " << scale; |
| } |
| |
| return success(); |
| } |
| |
| ArrayRef<double> UniformQuantizedPerAxisType::getScales() const { |
| return getImpl()->getScales(); |
| } |
| |
| ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const { |
| return getImpl()->getZeroPoints(); |
| } |
| |
| int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const { |
| return getImpl()->quantizedDimension; |
| } |
| |
| CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType, |
| double min, double max) { |
| return Base::get(expressedType.getContext(), expressedType, min, max); |
| } |
| |
| CalibratedQuantizedType CalibratedQuantizedType::getChecked( |
| function_ref<InFlightDiagnostic()> emitError, Type expressedType, |
| double min, double max) { |
| return Base::getChecked(emitError, expressedType.getContext(), expressedType, |
| min, max); |
| } |
| |
| LogicalResult |
| CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError, |
| Type expressedType, double min, double max) { |
| // Verify that the expressed type is floating point. |
| // If this restriction is ever eliminated, the parser/printer must be |
| // extended. |
| if (!expressedType.isa<FloatType>()) |
| return emitError() << "expressed type must be floating point"; |
| if (max <= min) |
| return emitError() << "illegal min and max: (" << min << ":" << max << ")"; |
| |
| return success(); |
| } |
| |
| double CalibratedQuantizedType::getMin() const { return getImpl()->min; } |
| |
| double CalibratedQuantizedType::getMax() const { return getImpl()->max; } |