//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace quant {
#define GEN_PASS_DEF_LOWERQUANTOPS
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
namespace {
// If 'inputType' is a tensor, return its element type. If it is a scalar,
// return it as is.
Type getScalarType(Type inputType) {
if (auto tensorType = dyn_cast<TensorType>(inputType))
return tensorType.getElementType();
return inputType;
}
// Return the shape of an input value as a list of attributes (static
// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
// list is returned. If 'input' is a tensor, its shape is returned.
SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
Location loc, Value input) {
if (isa<TensorType>(input.getType()))
return tensor::getMixedSizes(builder, loc, input);
return {};
}
// If 'referenceType' is a scalar, return 'elementType' as is. If
// 'referenceType' is a tensor, return another tensor with the same shape and
// elements of type 'elementType'.
Type getScalarOrTensorType(Type elementType, Type referenceType) {
if (auto tensorType = dyn_cast<TensorType>(referenceType))
return tensorType.clone(elementType);
return elementType;
}
// Return a constant with the given value. If 'referenceType' is a tensor, a
// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
// scalar, 'referenceShape' is ignored and a scalar constant is returned.
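//
// With a tensor reference type, for example, this emits a splat (a sketch;
// SSA names are arbitrary):
//
//   %cst = tensor.splat %scalar : tensor<4xf32>
//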
Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
Type referenceType,
ArrayRef<OpFoldResult> referenceShape) {
// If the result type is a scalar, return the unmodified scalar constant.
auto tensorType = dyn_cast<TensorType>(referenceType);
if (!tensorType) {
assert(referenceShape.empty());
return scalar;
}
// Create tensor splat
auto tensorConstant =
tensor::SplatOp::create(builder, loc, scalar, referenceShape);
return tensorConstant;
}
// Reshape an unranked tensor into a 1D ranked tensor.
//
// - input
// Unranked tensor.
//
// Return values:
//
// - flatInput
// 1D ranked, dynamically shaped tensor.
//
// - inputShape
// 1D extent tensor containing the shape of the original unranked input.
//
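// As an illustrative sketch (SSA names are arbitrary), an input
// %input : tensor<*xf32> is lowered to:
//
//   %shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
//   %size = shape.num_elements %shape : tensor<?xindex> -> index
//   %flat_shape = tensor.from_elements %size : tensor<1xindex>
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
//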
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
Value input) {
// Get unranked input shape and total size
auto *context = builder.getContext();
auto shapeType = shape::getExtentTensorType(context);
auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
Value inputSize = shape::NumElementsOp::create(
builder, loc, builder.getIndexType(), inputShape);
// Turn input size into 1D tensor
auto flatShapeType = shape::getExtentTensorType(context, 1);
auto flatInputShape =
tensor::FromElementsOp::create(builder, loc, flatShapeType, inputSize);
// Reshape input tensor into 1D
auto inputType = cast<UnrankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto flatInputType =
RankedTensorType::get({ShapedType::kDynamic}, elementType);
auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
flatInputShape);
return std::make_pair(flatInput, inputShape);
}
// Reshape an unranked tensor into a 3D ranked tensor where the central
// dimension of the result tensor corresponds to dimension 'axis' of the input
// tensor.
//
// - input
// Unranked tensor.
//
// - axis
//   Index of the input dimension around which other input dimensions will be
// collapsed.
//
// - axisSize
// Size of input dimension 'axis'.
//
// Return values:
//
// - flatInput
// 3D ranked tensor of shape [?, axisSize, ?].
//
// - inputShape
// 1D extent tensor containing the shape of the original unranked input.
//
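// Schematically, for axis = 1 and axisSize = 3, the final reshape is:
//
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x3x?xf32>
//
// where %flat_shape holds [sizeLeft, 3, sizeRight], computed from the full
// input shape with 'shape.split_at' and 'shape.num_elements'.
//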
std::pair<Value, Value>
flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
int64_t axis, int64_t axisSize) {
// Get full tensor shape
auto *context = builder.getContext();
auto indexType = builder.getIndexType();
auto shapeType = shape::getExtentTensorType(context);
auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
// Get shape and sizes on left and right of axis
auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis);
auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1);
auto shapeLeft =
builder
.create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
inputShape, axisValue)
.getResult(0);
auto sizeLeft =
shape::NumElementsOp::create(builder, loc, indexType, shapeLeft);
auto shapeRight =
builder
.create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
inputShape, axisNextValue)
.getResult(1);
auto sizeRight =
shape::NumElementsOp::create(builder, loc, indexType, shapeRight);
// Compute flat input shape as a 3-element 1D tensor
auto axisSizeValue = arith::ConstantIndexOp::create(builder, loc, axisSize);
auto flatShapeType = shape::getExtentTensorType(context, 3);
auto flatInputShape = tensor::FromElementsOp::create(
builder, loc, flatShapeType,
ValueRange{sizeLeft, axisSizeValue, sizeRight});
// Reshape input to 3D tensor
auto inputType = cast<UnrankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto flatInputType = RankedTensorType::get(
{ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
flatInputShape);
return std::make_pair(flatInput, inputShape);
}
// Reshape an input tensor into its original unranked shape.
//
// - input
// Ranked tensor.
//
// - inputShape
// 1D extent tensor.
//
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
Value inputShape) {
auto inputType = cast<RankedTensorType>(input.getType());
auto elementType = inputType.getElementType();
auto unrankedType = UnrankedTensorType::get(elementType);
return tensor::ReshapeOp::create(builder, loc, unrankedType, input,
inputShape);
}
// Create a tensor constant containing all scales in a per-channel quantized
// type. Example:
//
// !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
// %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
//
Value materializePerChannelScales(OpBuilder &builder, Location loc,
UniformQuantizedPerAxisType quantizedType) {
auto scales = quantizedType.getScales();
auto expressedType = quantizedType.getExpressedType();
auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
return builder.getFloatAttr(expressedType, scale);
});
auto tensorType =
RankedTensorType::get({(int64_t)scales.size()}, expressedType);
auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
}
// Create a tensor constant containing all zero points in a per-channel
// quantized type. Example:
//
// !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
// %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
//
Value materializePerChannelZeroPoints(
OpBuilder &builder, Location loc,
UniformQuantizedPerAxisType quantizedType) {
auto zeroPoints = quantizedType.getZeroPoints();
auto storageType = quantizedType.getStorageType();
auto zeroPointAttrs =
llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
return builder.getIntegerAttr(storageType, zeroPoint);
});
auto tensorType =
RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
}
// Create a tensor constant containing all scales in a sub-channel quantized
// type. Example:
//
// !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
//
// produces
//
// %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
//
Value materializeSubChannelScales(
OpBuilder &builder, Location loc,
UniformQuantizedSubChannelType quantizedType) {
auto scales = quantizedType.getScales();
auto expressedType = quantizedType.getExpressedType();
auto scaleAttrs = llvm::map_to_vector(
scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
return builder.getFloatAttr(expressedType, scale);
});
auto tensorType =
RankedTensorType::get(scales.getType().getShape(), expressedType);
auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
}
// Create a tensor constant containing all zero points in a sub-channel
// quantized type. Example:
//
// !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
//
// produces
//
// %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
//
Value materializeSubChannelZeroPoints(
OpBuilder &builder, Location loc,
UniformQuantizedSubChannelType quantizedType) {
auto zeroPoints = quantizedType.getZeroPoints();
auto storageType = quantizedType.getStorageType();
auto zeroPointAttrs = llvm::map_to_vector(
zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
return builder.getIntegerAttr(storageType, zeroPoint);
});
auto tensorType =
RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
}
// Clamp the given scalar or tensor input using the storage bounds encoded in
// the given quantized type, if present.
//
// - input
// Scalar or ranked tensor input. The element type must match the storage type
// of 'quantizedType'.
//
// - inputShape
// If 'input' is a tensor, combination of attributes/values representing its
// static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - quantizedType
//   Quantized type. Its storage type bounds, if present, define the clamping
//   range.
//
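// For a signed type with narrowed bounds, e.g.
// !quant.uniform<i8<-127:127>:f32, 1.0>, this emits (a sketch; SSA names are
// arbitrary):
//
//   %min = arith.constant -127 : i8
//   %max = arith.constant 127 : i8
//   %0 = arith.maxsi %input, %min : i8
//   %1 = arith.minsi %0, %max : i8
//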
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
ArrayRef<OpFoldResult> inputShape,
QuantizedType quantizedType) {
// If quantized type does not narrow down the storage type range, there is
// nothing to do.
if (!quantizedType.hasStorageTypeBounds())
return input;
// Materialize bounds
auto inputType = input.getType();
auto storageType = quantizedType.getStorageType();
auto storageMinScalar = arith::ConstantIntOp::create(
builder, loc, storageType, quantizedType.getStorageTypeMin());
auto storageMaxScalar = arith::ConstantIntOp::create(
builder, loc, storageType, quantizedType.getStorageTypeMax());
auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
inputType, inputShape);
auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
inputType, inputShape);
// Clamp
if (quantizedType.isSigned()) {
input = arith::MaxSIOp::create(builder, loc, input, storageMin);
input = arith::MinSIOp::create(builder, loc, input, storageMax);
} else {
input = arith::MaxUIOp::create(builder, loc, input, storageMin);
input = arith::MinUIOp::create(builder, loc, input, storageMax);
}
return input;
}
// Emit op 'arith.fptosi' or 'arith.fptoui'.
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
Type resultType, bool isSigned) {
if (isSigned)
return arith::FPToSIOp::create(builder, loc, resultType, input);
return arith::FPToUIOp::create(builder, loc, resultType, input);
}
// Emit op 'arith.sitofp' or 'arith.uitofp'.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
Type resultType, bool isSigned) {
if (isSigned)
return arith::SIToFPOp::create(builder, loc, resultType, input);
return arith::UIToFPOp::create(builder, loc, resultType, input);
}
// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
//
// See function 'convertRanked()' below for a description of the arguments.
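//
// Schematically, the emitted computation is
//
//   stored = clamp(int(input / scale + zeroPoint))
//
// where the float-to-int conversion uses the truncating 'arith.fptosi' /
// 'arith.fptoui' semantics rather than round-to-nearest.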
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
ArrayRef<OpFoldResult> inputShape, Value scale,
Value zeroPoint, QuantizedType quantizedType) {
// Convert scale to tensor if necessary
auto inputType = input.getType();
scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
// Scale input
auto scaledValue = arith::DivFOp::create(builder, loc, input, scale);
// Skip unnecessary computations if no zero point is given
Value storedValueFloat = scaledValue;
if (!matchPattern(zeroPoint, m_Zero())) {
// Convert zero point to tensor if necessary
zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
inputShape);
// Convert zero point from storage to expressed type
zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
quantizedType.isSigned());
// Add zero point to stored value
storedValueFloat =
arith::AddFOp::create(builder, loc, scaledValue, zeroPoint);
}
// Convert stored value to storage type
auto storageScalarOrTensorType =
getScalarOrTensorType(quantizedType.getStorageType(), inputType);
auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
storageScalarOrTensorType,
quantizedType.isSigned());
// Clamp stored value if the quantized type narrows the storage type range
auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
inputShape, quantizedType);
return storedValueClamped;
}
// Dequantize a scalar or ranked tensor input.
//
// See function 'convertRanked()' below for a description of the arguments.
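//
// Schematically, the emitted computation is
//
//   result = (float(input) - float(zeroPoint)) * scale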
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
ArrayRef<OpFoldResult> inputShape, Value scale,
Value zeroPoint, QuantizedType quantizedType) {
// Convert scale to tensor if necessary
auto inputType = input.getType();
scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
// Convert stored value to float
auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
quantizedType.isSigned());
// Skip unnecessary computations if no zero point is given
if (!matchPattern(zeroPoint, m_Zero())) {
// Convert zero point to tensor if necessary
zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
inputShape);
// Convert zero point from storage to expressed type
zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
quantizedType.isSigned());
// Subtract zero point from stored value
result = arith::SubFOp::create(builder, loc, result, zeroPoint);
}
// Multiply by scale
result = arith::MulFOp::create(builder, loc, result, scale);
return result;
}
// Convert a scalar or ranked tensor input with the given scale and zero point
// values.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
// Scalar or ranked tensor value.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
// static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - scale
// Scale as a floating-point scalar value.
//
// - zeroPoint
// Zero point as an integer scalar value.
//
// - quantizedType
// Scalar quantized type of the result ('quant.qcast') or of the input
// ('quant.dcast').
//
Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
Value zeroPoint, QuantizedType quantizedType) {
if (isa<QuantizeCastOp>(op))
return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
quantizedType);
if (isa<DequantizeCastOp>(op))
return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
quantizedType);
llvm_unreachable("unexpected quant op");
}
// Convert an operation using per-layer quantization with a scalar or ranked
// tensor input.
//
// - op
// 'quant.dcast' or 'quant.qcast' op.
//
// - input
// Scalar or ranked tensor.
//
// - quantizedType
// Per-layer quantized type.
//
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
Value input, UniformQuantizedType quantizedType) {
// Create scale and zero point constants
auto expressedType = quantizedType.getExpressedType();
auto storageType = quantizedType.getStorageType();
auto scaleAttr =
builder.getFloatAttr(expressedType, quantizedType.getScale());
auto scale =
arith::ConstantOp::create(builder, loc, expressedType, scaleAttr);
auto zeroPointAttr =
builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
auto zeroPoint =
arith::ConstantOp::create(builder, loc, storageType, zeroPointAttr);
auto inputShape = getScalarOrTensorShape(builder, loc, input);
return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
quantizedType);
}
// Convert an operation using per-layer quantization.
//
// - op
// 'quant.dcast' or 'quant.qcast' op.
//
// - input
// Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
// Per-layer quantized type.
//
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
Value input, UniformQuantizedType quantizedType) {
// Flatten input if unranked
bool isUnranked = isa<UnrankedTensorType>(input.getType());
Value inputShape;
if (isUnranked)
std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
// Process ranked tensor
auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
// Restore original shape if unranked
if (isUnranked)
result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
return result;
}
// Convert an operation using per-channel quantization with a ranked tensor
// input.
//
// - op
// 'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Ranked tensor.
//
// - quantizedType
// Per-channel quantized type.
//
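// For a rank-2 input quantized along axis 1, the emitted op is, schematically
// (a sketch; types and region body abbreviated):
//
//   #id = affine_map<(d0, d1) -> (d0, d1)>
//   #ch = affine_map<(d0, d1) -> (d1)>
//   %res = linalg.generic
//       {indexing_maps = [#id, #ch, #ch, #id],
//        iterator_types = ["parallel", "parallel"]}
//       ins(%input, %scales, %zero_points : ...)
//       outs(%init : ...) { per-element quantize/dequantize }
//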
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
Value input,
UniformQuantizedPerAxisType quantizedType,
int64_t channelAxis) {
auto *context = builder.getContext();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
auto scales = materializePerChannelScales(builder, loc, quantizedType);
auto zeroPoints =
materializePerChannelZeroPoints(builder, loc, quantizedType);
auto elementType = isa<FloatType>(inputType.getElementType())
? quantizedType.getStorageType()
: quantizedType.getExpressedType();
auto initShape = tensor::getMixedSizes(builder, loc, input);
Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
utils::IteratorType::parallel);
auto channelAxisAffineMap = AffineMap::get(
inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
SmallVector<AffineMap> indexingMaps{
builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
auto result = builder
.create<linalg::GenericOp>(
loc,
init.getType(), // resultType
ValueRange{input, scales, zeroPoints}, // inputs
ValueRange{init}, // outputs
indexingMaps, iteratorTypes,
[&](OpBuilder &builder, Location loc, ValueRange args) {
assert(args.size() == 4);
auto input = args[0];
auto scale = args[1];
auto zeroPoint = args[2];
auto result =
convertRanked(builder, loc, op, input, {}, scale,
zeroPoint, quantizedType);
linalg::YieldOp::create(builder, loc, result);
})
.getResult(0);
return result;
}
// Convert an operation using per-channel quantization.
//
// - op
// 'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Ranked or unranked tensor.
//
// - quantizedType
// Per-channel quantized type.
//
Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
Value input,
UniformQuantizedPerAxisType quantizedType) {
// Flatten unranked tensor into a 3D ranked tensor if necessary
bool isUnranked = isa<UnrankedTensorType>(input.getType());
int64_t channelAxis = quantizedType.getQuantizedDimension();
int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
Value inputShape;
if (isUnranked) {
std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
builder, loc, input, channelAxis, channelAxisSize);
channelAxis = 1;
}
// Work on a ranked tensor
auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
channelAxis);
// Restore original tensor shape if unranked
if (isUnranked)
result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
return result;
}
// Convert an operation using sub-channel quantization.
//
// - op
// 'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Ranked tensor.
//
// - quantizedType
// Sub-channel quantized type.
//
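// For a 2D input with block sizes {0:1, 1:2}, the scale and zero-point
// operands are indexed through (a sketch; the floordiv by 1 folds away):
//
//   affine_map<(d0, d1) -> (d0, d1 floordiv 2)>
//
// so each 1x2 block of the input shares one scale and one zero point.
//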
Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
Value input,
UniformQuantizedSubChannelType quantizedType) {
auto *context = builder.getContext();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
auto scales = materializeSubChannelScales(builder, loc, quantizedType);
auto zeroPoints =
materializeSubChannelZeroPoints(builder, loc, quantizedType);
auto elementType = isa<FloatType>(inputType.getElementType())
? quantizedType.getStorageType()
: quantizedType.getExpressedType();
auto initShape = tensor::getMixedSizes(builder, loc, input);
Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
utils::IteratorType::parallel);
const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
quantizedType.getBlockSizeInfo();
SmallVector<AffineExpr> affineExprs(inputRank,
builder.getAffineConstantExpr(0));
for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
affineExprs[quantizedDimension] =
builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
}
auto affineMap = AffineMap::get(inputRank, 0, affineExprs, context);
SmallVector<AffineMap> indexingMaps{
builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
builder.getMultiDimIdentityMap(inputRank)};
auto result = builder
.create<linalg::GenericOp>(
loc,
init.getType(), // resultType
ValueRange{input, scales, zeroPoints}, // inputs
ValueRange{init}, // outputs
indexingMaps, iteratorTypes,
[&](OpBuilder &builder, Location loc, ValueRange args) {
assert(args.size() == 4);
auto input = args[0];
auto scale = args[1];
auto zeroPoint = args[2];
auto result =
convertRanked(builder, loc, op, input, {}, scale,
zeroPoint, quantizedType);
linalg::YieldOp::create(builder, loc, result);
})
.getResult(0);
return result;
}
// Convert a quantization operation.
//
// - op
// 'quant.dcast' or 'quant.qcast' op.
//
// - input
// Scalar, ranked tensor, or unranked tensor. The element type matches
// the storage type (quant.dcast) or expressed type (quant.qcast) of
// 'quantizedType'.
//
// - quantizedType
//   Per-layer, per-channel, or sub-channel quantized type.
//
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
Value input, Type quantizedType) {
if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
return convertPerLayer(builder, loc, op, input, uniformQuantizedType);
if (auto uniformQuantizedPerAxisType =
dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
return convertPerChannel(builder, loc, op, input,
uniformQuantizedPerAxisType);
if (auto uniformQuantizedSubChannelType =
dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
return convertSubChannel(builder, loc, op, input,
uniformQuantizedSubChannelType);
llvm_unreachable("unexpected quantized type");
}
// Lowering pattern for 'quant.dcast'
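//
// For a per-layer scalar, for example, this rewrites (a sketch; SSA names are
// arbitrary):
//
//   %0 = quant.dcast %arg : !quant.uniform<i8:f32, 2.0:10> to f32
//
// into:
//
//   %i = quant.scast %arg : !quant.uniform<i8:f32, 2.0:10> to i8
//   %f = arith.sitofp %i : i8 to f32
//   %sub = arith.subf %f, %zpf : f32
//   %0 = arith.mulf %sub, %scale : f32
//
// where %scale and %zpf are 'arith.constant'-derived values for the scale and
// the float-converted zero point.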
struct DequantizeCastOpConversion
: public OpConversionPattern<quant::DequantizeCastOp> {
using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto input = op.getInput();
auto quantizedType =
cast<QuantizedType>(getScalarType(op.getInput().getType()));
// Convert quantized input to storage type
auto storageScalarOrTensorType =
getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
input = quant::StorageCastOp::create(rewriter, loc,
storageScalarOrTensorType, input);
auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
rewriter.replaceOp(op, result);
return success();
}
};
// Lowering pattern for 'quant.qcast'
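//
// The inverse of the pattern above: the expressed-type arithmetic is emitted
// first, and a final 'quant.scast' casts the stored value to the quantized
// type (a sketch; SSA names are arbitrary):
//
//   %0 = quant.qcast %arg : f32 to !quant.uniform<i8:f32, 2.0:10>
//
// becomes, roughly:
//
//   %div = arith.divf %arg, %scale : f32
//   %add = arith.addf %div, %zpf : f32
//   %i = arith.fptosi %add : f32 to i8
//   %0 = quant.scast %i : i8 to !quant.uniform<i8:f32, 2.0:10>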
struct QuantizeCastOpConversion
: public OpConversionPattern<quant::QuantizeCastOp> {
using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto input = op.getInput();
auto quantizedType = getScalarType(op.getResult().getType());
// Compute the stored value from the expressed-type input
auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
// Cast stored value to result quantized value
rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
op, op.getResult().getType(), result);
return success();
}
};
struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateLowerQuantOpsPatterns(patterns);
ConversionTarget target(getContext());
target.addLegalOp<quant::StorageCastOp>();
target.addIllegalDialect<quant::QuantDialect>();
target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
shape::ShapeDialect, tensor::TensorDialect>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};

} // namespace

void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
patterns.getContext());
}
} // namespace quant
} // namespace mlir