| //===- QuantizeUtils.cpp - Support utilities for quantization -------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Quant/QuantizeUtils.h" |
| #include "mlir/Dialect/Quant/UniformSupport.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| |
| using namespace mlir; |
| using namespace mlir::quant; |
| |
| /// Converts a possible primitive, real expressed value attribute to a |
| /// corresponding storage attribute (typically FloatAttr -> IntegerAttr). |
| /// quantizedElementType is the QuantizedType that describes the expressed |
| /// origValue. |
| /// Returns a converter Attribute or nullptr if conversion is not possible. |
| static Attribute convertPrimitiveValueAttr( |
| Attribute origRealValue, QuantizedType quantizedElementType, |
| const UniformQuantizedValueConverter &converter, Type &outConvertedType) { |
| if (origRealValue.isa<FloatAttr>()) { |
| FloatAttr floatAttr = origRealValue.cast<FloatAttr>(); |
| outConvertedType = quantizedElementType.getStorageType(); |
| return IntegerAttr::get(quantizedElementType.getStorageType(), |
| converter.quantizeFloatToInt(floatAttr.getValue())); |
| } |
| |
| return nullptr; |
| } |
| |
| /// Converts a real expressed DenseFPElementsAttr to a corresponding |
| /// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized |
| /// storage values assuming the given quantizedElementType and converter. |
| static DenseElementsAttr |
| convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, |
| QuantizedType quantizedElementType, |
| const UniformQuantizedValueConverter &converter) { |
| // Convert to corresponding quantized value attributes. |
| SmallVector<APInt, 8> quantValues; |
| if (realFPElementsAttr.isSplat()) { |
| quantValues.push_back( |
| converter.quantizeFloatToInt(*realFPElementsAttr.begin())); |
| } else { |
| quantValues.reserve(realFPElementsAttr.getNumElements()); |
| for (APFloat realVal : realFPElementsAttr) { |
| quantValues.push_back(converter.quantizeFloatToInt(realVal)); |
| } |
| } |
| |
| // Cast from an expressed-type-based type to storage-type-based type, |
| // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>). |
| ShapedType newDenseType = |
| quantizedElementType |
| .castExpressedToStorageType(realFPElementsAttr.getType()) |
| .dyn_cast_or_null<ShapedType>(); |
| if (!newDenseType) { |
| return nullptr; |
| } |
| return DenseIntElementsAttr::get(newDenseType, quantValues); |
| } |
| |
| /// Converts a real expressed SplatElementsAttr to a corresponding |
| /// SplatElementsAttr containing quantized storage values assuming the given |
| /// quantizedElementType and converter. |
| static SparseElementsAttr |
| convertSparseElementsAttr(SparseElementsAttr realSparseAttr, |
| QuantizedType quantizedElementType, |
| const UniformQuantizedValueConverter &converter) { |
| DenseElementsAttr realDenseAttr = realSparseAttr.getValues(); |
| if (!realDenseAttr.isa<DenseFPElementsAttr>()) { |
| return nullptr; |
| } |
| DenseElementsAttr quantDenseAttr = |
| convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(), |
| quantizedElementType, converter); |
| if (!quantDenseAttr) { |
| return nullptr; |
| } |
| |
| // Cast from an expressed-type-based type to storage-type-based type, |
| // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). |
| ShapedType newSparseType = |
| quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) |
| .dyn_cast_or_null<ShapedType>(); |
| if (!newSparseType) { |
| return nullptr; |
| } |
| return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(), |
| quantDenseAttr); |
| } |
| |
| /// Converts a real expressed Attribute to a corresponding Attribute containing |
| /// quantized storage values assuming the given uniform quantizedElementType and |
| /// converter. |
| Attribute mlir::quant::quantizeAttrUniform( |
| Attribute realValue, UniformQuantizedType quantizedElementType, |
| const UniformQuantizedValueConverter &converter, Type &outConvertedType) { |
| // Fork to handle different variants of constants supported. |
| if (realValue.isa<DenseFPElementsAttr>()) { |
| // Dense tensor or vector constant. |
| auto converted = convertDenseFPElementsAttr( |
| realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter); |
| outConvertedType = converted.getType(); |
| return converted; |
| } else if (realValue.isa<SparseElementsAttr>()) { |
| // Sparse tensor or vector constant. |
| auto converted = convertSparseElementsAttr( |
| realValue.cast<SparseElementsAttr>(), quantizedElementType, converter); |
| outConvertedType = converted.getType(); |
| return converted; |
| } else { |
| // Nothing else matched: try to convert a primitive. |
| return convertPrimitiveValueAttr(realValue, quantizedElementType, converter, |
| outConvertedType); |
| } |
| } |
| |
| /// Convert an attribute from a type based on |
| /// quantizedElementType.getExpressedType() to one based on |
| /// quantizedElementType.getStorageType(). |
| /// Returns nullptr if the conversion is not supported. |
| /// On success, stores the converted type in outConvertedType. |
| Attribute mlir::quant::quantizeAttr(Attribute realValue, |
| QuantizedType quantizedElementType, |
| Type &outConvertedType) { |
| if (auto uniformQuantized = |
| quantizedElementType.dyn_cast<UniformQuantizedType>()) { |
| UniformQuantizedValueConverter converter(uniformQuantized); |
| return quantizeAttrUniform(realValue, uniformQuantized, converter, |
| outConvertedType); |
| |
| } else if (auto uniformQuantizedPerAxis = |
| quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) { |
| UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis); |
| auto converted = converter.convert(realValue); |
| // TODO: why we need this outConvertedType? remove it? |
| if (converted) { |
| outConvertedType = converted.getType(); |
| } |
| return converted; |
| } else { |
| return nullptr; |
| } |
| } |