blob: e23a0d6aba825336480c07c3d73f2b9ca5ceca6d [file] [log] [blame] [edit]
//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "QuantDialectBytecode.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
namespace mlir {
namespace quant {
namespace {
// Checks per-axis quantization parameters against the type of the value that
// carries them, emitting diagnostics on `op`.
//
// - uniformQuantizedPerAxisType
//   The per-axis quantized element type being used.
//
// - containerType
//   The input or result type of `op` that holds the quantized type. It must be
//   a tensor. For a ranked tensor, the quantized dimension must be within the
//   rank, and a static size along that dimension must equal the number of
//   scales. Unranked tensors are accepted without further checks.
//
LogicalResult verifyPerAxisQuantization(
    Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
    Type containerType) {
  auto tensorType = dyn_cast<TensorType>(containerType);
  if (!tensorType)
    return op->emitError("scalar types may not use per-axis quantization");

  // Without a known rank there is nothing further to compare against.
  if (!tensorType.hasRank())
    return success();

  const int32_t axis = uniformQuantizedPerAxisType.getQuantizedDimension();
  if (static_cast<int64_t>(axis) >= tensorType.getRank())
    return op->emitError("quantized dimension must be less than tensor rank");

  const int64_t axisSize = tensorType.getDimSize(axis);
  // A dynamic size along the quantized axis cannot be checked statically.
  const bool axisSizeIsStatic = axisSize != ShapedType::kDynamic;
  const auto scaleCount =
      static_cast<int64_t>(uniformQuantizedPerAxisType.getScales().size());
  if (axisSizeIsStatic && axisSize != scaleCount)
    return op->emitError(
        "quantized dimension size does not match number of scales");

  return success();
}
// Verifies that the sub-channel quantization parameters are consistent with
// the given container type, emitting diagnostics on `op`. The checks are:
//
// - The container type must be a ranked tensor type.
// - Each quantized dimension must be non-negative and less than the rank of
//   the tensor.
// - Each block size must be positive.
// - A static tensor dimension at a quantized axis must be divisible by the
//   corresponding block size.
// - The tensor must not have a zero-sized dimension.
// - The scales tensor must have the same rank as the container tensor. Its
//   size at each axis must match the expected size: the tensor dimension
//   divided by the block size at quantized axes, and 1 at every other axis.
//   Axes whose expected size is dynamic are not checked.
//
// The `uniformQuantizedSubChannelType` argument provides the sub-channel
// quantization parameters, and the `containerType` argument specifies the
// type of the container holding the quantized data.
//
LogicalResult verifySubChannelQuantization(
    Operation *op,
    UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
    Type containerType) {
  auto tensorType = dyn_cast<TensorType>(containerType);
  if (!tensorType)
    return op->emitError("scalar types may not use sub-channel quantization");

  if (!tensorType.hasRank())
    return op->emitError(
        "tensor containing the sub-channel quantized type must be ranked");

  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
      uniformQuantizedSubChannelType.getBlockSizeInfo();
  auto shape = tensorType.getShape();

  // The dimension size of scale for an axis which is not specified as a
  // quantized dimension should be 1.
  SmallVector<int64_t> expectedScaleShape(shape.size(), 1);
  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
    // Reject a negative axis before it is used to index the shape and
    // `expectedScaleShape` below.
    if (quantizedDimension < 0)
      return op->emitError() << "quantized dimension " << quantizedDimension
                             << " must be non-negative";
    if (quantizedDimension >= tensorType.getRank())
      return op->emitError()
             << "quantized dimension " << quantizedDimension
             << " must be less than tensor rank " << tensorType.getRank();
    // Reject a zero or negative block size before the `%` and `/` below
    // divide by it.
    if (blockSize <= 0)
      return op->emitError() << "block size " << blockSize << " at axis "
                             << quantizedDimension << " must be positive";

    if (tensorType.isDynamicDim(quantizedDimension)) {
      // The expected scale size cannot be computed for a dynamic dimension.
      expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
      continue;
    }

    int64_t dimSize = tensorType.getDimSize(quantizedDimension);
    if (dimSize % blockSize != 0)
      return op->emitError()
             << "tensor dimension size " << dimSize << " at axis "
             << quantizedDimension
             << " must be divisible by the corresponding block size "
             << blockSize;
    expectedScaleShape[quantizedDimension] = dimSize / blockSize;
  }

  // Block sizes must be greater than 0 and divide the corresponding dimension
  // size. A block size b must also be less than or equal to the corresponding
  // dimension size d. That constraint is implicitly enforced by requiring
  // d % b == 0 when d != 0.
  //
  // However, a problem arises when d = 0. The divisibility constraint allows b
  // to be any value, potentially violating the requirement that b <= d.
  // Furthermore, if b is unspecified (implicitly equal to d), it violates the
  // constraint that b > 0.
  //
  // Therefore, we explicitly disallow the case where d = 0 to maintain
  // consistency and avoid these issues.
  if (llvm::is_contained(shape, 0)) {
    return op->emitError() << "tensor dimension size of zero is not allowed "
                              "with sub-channel quantization";
  }

  auto scaleShape =
      uniformQuantizedSubChannelType.getScales().getType().getShape();
  if (scaleShape.size() != shape.size()) {
    return op->emitError() << "Rank of scales " << scaleShape.size()
                           << " must match "
                           << "the rank of the tensor " << shape.size();
  }

  for (auto [index, expectedDim] : llvm::enumerate(expectedScaleShape)) {
    // Report the *actual* scales dimension alongside the expected one. The
    // previous message printed the expected value in both places.
    if (expectedDim != ShapedType::kDynamic && expectedDim != scaleShape[index])
      return op->emitError() << "dimension size " << scaleShape[index]
                             << " of scales tensor at axis " << index
                             << " should match (tensor dimension at axis / "
                                "block sizes at axis) = "
                             << expectedDim;
  }
  return success();
}
// Shared verifier for 'quant.dcast' and 'quant.qcast'.
//
// - quantizedType
//   The quantized element type: the input of 'quant.dcast' or the result of
//   'quant.qcast', either as a scalar or as a tensor element.
//
// - floatType
//   The float element type: the input of 'quant.qcast' or the result of
//   'quant.dcast', either as a scalar or as a tensor element.
//
// - containerType
//   The full type of the op's input (or result) that holds the values.
//
// The expressed type of `quantizedType` must equal `floatType`. Per-axis and
// sub-channel quantized types additionally get shape checks against
// `containerType`.
//
LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
                                   FloatType floatType, Type containerType) {
  const bool expressedTypeMatches =
      quantizedType.getExpressedType() == floatType;
  if (!expressedTypeMatches)
    return op->emitError(
        "expressed type in quantized type expected to match float type");

  // Per-axis and sub-channel parameters depend on the container's shape.
  if (auto perAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
    return verifyPerAxisQuantization(op, perAxisType, containerType);
  if (auto subChannelType =
          dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
    return verifySubChannelQuantization(op, subChannelType, containerType);

  // No container-dependent checks apply to any other quantized type.
  return success();
}
} // namespace
//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//
// Registers the Quant dialect's contents with MLIR: its quantized types, the
// operations from the ODS-generated op list, and the dialect's bytecode
// interface.
void QuantDialect::initialize() {
// The quantized type classes owned by this dialect.
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType, UniformQuantizedSubChannelType>();
// With GET_OP_LIST defined, the generated .inc expands to the list of op
// classes, which is used here as the template argument list.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
>();
// Attaches the dialect's bytecode interface. The implementation is not
// visible here; presumably it is declared in QuantDialectBytecode.h.
detail::addBytecodeInterface(this);
}
//===----------------------------------------------------------------------===//
// DequantizeCastOp
//===----------------------------------------------------------------------===//
LogicalResult DequantizeCastOp::verify() {
  // The quantized values arrive through the input, so the input type is the
  // container checked against any per-axis / sub-channel parameters.
  Type containerType = getInput().getType();
  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
                              containerType);
}
OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
  // Matches x -> quant.qcast -> quant.dcast -> y and replaces the quant.dcast
  // result with x.
  //
  // A fold must return a value of the op's result type. The verifiers are
  // expected to make x and y the same type in this pattern. This is checked
  // explicitly instead of asserted, so that any mismatch simply disables the
  // fold rather than producing a mistyped replacement in release builds
  // (NOTE(review): e.g. differing tensor encodings — confirm). This mirrors
  // the type check in the quant.qcast and quant.scast folds.
  auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
  if (!srcQcastOp)
    return {};
  auto source = srcQcastOp.getInput();
  if (source.getType() != getType())
    return {};
  return source;
}
FloatType DequantizeCastOp::getFloatType() {
  // The float side of a dequantize cast is its result (scalar or element).
  Type resultElementType = getElementTypeOrSelf(getResult().getType());
  return cast<FloatType>(resultElementType);
}
QuantizedType DequantizeCastOp::getQuantizedType() {
  // The quantized side of a dequantize cast is its input (scalar or element).
  Type inputElementType = getElementTypeOrSelf(getInput().getType());
  return cast<QuantizedType>(inputElementType);
}
//===----------------------------------------------------------------------===//
// QuantizeCastOp
//===----------------------------------------------------------------------===//
LogicalResult QuantizeCastOp::verify() {
  // The input type is passed as the container for the shape-dependent
  // per-axis / sub-channel checks.
  Type containerType = getInput().getType();
  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
                              containerType);
}
OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
  // Folds x -> quant.dcast -> quant.qcast -> y into x when the casts undo
  // each other. Unlike the quant.dcast fold, x and y may legitimately have
  // different types here (e.g. different quantization parameters). The fold
  // therefore fires only when x already has exactly the result type.
  if (auto producer = getInput().getDefiningOp<DequantizeCastOp>()) {
    auto original = producer.getInput();
    if (original.getType() == getType())
      return original;
  }
  return {};
}
FloatType QuantizeCastOp::getFloatType() {
  // The float side of a quantize cast is its input (scalar or element).
  Type inputElementType = getElementTypeOrSelf(getInput().getType());
  return cast<FloatType>(inputElementType);
}
QuantizedType QuantizeCastOp::getQuantizedType() {
  // The quantized side of a quantize cast is its result (scalar or element).
  Type resultElementType = getElementTypeOrSelf(getResult().getType());
  return cast<QuantizedType>(resultElementType);
}
//===----------------------------------------------------------------------===//
// StorageCastOp
//===----------------------------------------------------------------------===//
LogicalResult StorageCastOp::verify() {
  QuantizedType quantType = getQuantizedType();
  IntegerType storageIntType = getIntegerType();
  // The integer side of the cast must be exactly the quantized type's
  // storage type.
  if (quantType.getStorageType() != storageIntType)
    return emitError(
        "storage type in quantized type expected to match integer type");

  // The quantized type may sit on either side of the cast, but the input and
  // result tensor shapes are identical at this point. The input type can
  // therefore serve as the container for the shape-dependent checks.
  Type containerType = getInput().getType();
  if (auto perAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantType))
    return verifyPerAxisQuantization(*this, perAxisType, containerType);
  if (auto subChannelType =
          dyn_cast<UniformQuantizedSubChannelType>(quantType))
    return verifySubChannelQuantization(*this, subChannelType, containerType);

  // No container-dependent checks apply to any other quantized type.
  return success();
}
OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
  // Folds x -> quant.scast -> quant.scast -> y into x when the second cast
  // undoes the first, i.e. x already has this op's result type.
  if (auto producer = getInput().getDefiningOp<StorageCastOp>()) {
    auto original = producer.getInput();
    if (original.getType() == getType())
      return original;
  }
  return {};
}
IntegerType StorageCastOp::getIntegerType() {
  // Look at the input's element type first. If it is not an integer, the
  // integer side must be the result.
  Type inputElementType = getElementTypeOrSelf(getInput().getType());
  if (auto inputIntType = dyn_cast<IntegerType>(inputElementType))
    return inputIntType;
  Type resultElementType = getElementTypeOrSelf(getResult().getType());
  return cast<IntegerType>(resultElementType);
}
QuantizedType StorageCastOp::getQuantizedType() {
  // Look at the input's element type first. If it is not quantized, the
  // quantized side must be the result.
  Type inputElementType = getElementTypeOrSelf(getInput().getType());
  if (auto inputQuantType = dyn_cast<QuantizedType>(inputElementType))
    return inputQuantType;
  Type resultElementType = getElementTypeOrSelf(getResult().getType());
  return cast<QuantizedType>(resultElementType);
}
} // namespace quant
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"