blob: 72d679054308ca8fca2dbf1eca22045e55ccf38d [file] [log] [blame]
//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace quant;
/// Parses the storage type of a quantized type and reports its signedness.
///
/// Two spellings are accepted:
///   * an integer type recognized by `parseOptionalType`; `isSigned` becomes
///     true unless the parsed type reports `isUnsigned()`;
///   * the keyword form `u` integer-literal; `isSigned` becomes false and the
///     result is the type returned by `Builder::getIntegerType(width)`.
///
/// The width must be non-zero and at most `QuantizedType::MaxStorageBits`,
/// otherwise "illegal storage type size" is diagnosed.
///
/// \returns the storage type, or a null type on failure.
static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
  auto typeLoc = parser.getCurrentLocation();
  IntegerType type;

  // Parse storage type (alpha_ident, integer_literal).
  StringRef identifier;
  unsigned storageTypeWidth = 0;
  bool buildFromKeywordWidth = false;
  OptionalParseResult result = parser.parseOptionalType(type);
  if (result.hasValue()) {
    if (!succeeded(*result)) {
      // NOTE(review): the result of this second parse is discarded;
      // presumably it is kept only for the diagnostic it may emit. Confirm
      // before removing.
      parser.parseType(type);
      return nullptr;
    }
    isSigned = !type.isUnsigned();
    storageTypeWidth = type.getWidth();
  } else if (succeeded(parser.parseKeyword(&identifier))) {
    // Otherwise, this must be an unsigned integer (`u` integer-literal).
    if (!identifier.consume_front("u")) {
      parser.emitError(typeLoc, "illegal storage type prefix");
      return nullptr;
    }
    if (identifier.getAsInteger(10, storageTypeWidth)) {
      parser.emitError(typeLoc, "expected storage type width");
      return nullptr;
    }
    isSigned = false;
    // The type is created only after the width has been validated below, so
    // that an out-of-range, user-supplied width never reaches
    // `getIntegerType`.
    buildFromKeywordWidth = true;
  } else {
    return nullptr;
  }

  if (storageTypeWidth == 0 ||
      storageTypeWidth > QuantizedType::MaxStorageBits) {
    parser.emitError(typeLoc, "illegal storage type size: ")
        << storageTypeWidth;
    return nullptr;
  }

  if (buildFromKeywordWidth)
    type = parser.getBuilder().getIntegerType(storageTypeWidth);
  return type;
}
/// Parses the optional `<min:max>` range that may follow a storage type.
///
/// When no `<` follows, `storageTypeMin`/`storageTypeMax` receive the default
/// bounds that `QuantizedType` reports for the given signedness and width.
/// With an explicit range, a minimum below the default minimum or a maximum
/// above the default maximum is diagnosed at the offending literal.
static ParseResult parseStorageRange(DialectAsmParser &parser,
                                     IntegerType storageType, bool isSigned,
                                     int64_t &storageTypeMin,
                                     int64_t &storageTypeMax) {
  const int64_t lowestAllowed = QuantizedType::getDefaultMinimumForInteger(
      isSigned, storageType.getWidth());
  const int64_t highestAllowed = QuantizedType::getDefaultMaximumForInteger(
      isSigned, storageType.getWidth());

  if (succeeded(parser.parseOptionalLess())) {
    // Explicit storage min and storage max.
    llvm::SMLoc minLoc = parser.getCurrentLocation();
    llvm::SMLoc maxLoc;
    if (parser.parseInteger(storageTypeMin))
      return failure();
    if (parser.parseColon())
      return failure();
    if (parser.getCurrentLocation(&maxLoc))
      return failure();
    if (parser.parseInteger(storageTypeMax))
      return failure();
    if (parser.parseGreater())
      return failure();

    if (storageTypeMin < lowestAllowed)
      return parser.emitError(minLoc, "illegal storage type minimum: ")
             << storageTypeMin;
    if (storageTypeMax > highestAllowed)
      return parser.emitError(maxLoc, "illegal storage type maximum: ")
             << storageTypeMax;
    return success();
  }

  // No explicit range: use the full default range of the storage type.
  storageTypeMin = lowestAllowed;
  storageTypeMax = highestAllowed;
  return success();
}
/// Parses a float expressed type followed by a mandatory `<min:max>` range of
/// calibrated values.
///
/// On success the two parsed bounds are stored in `min` and `max` and the
/// expressed type is returned. On failure an additional diagnostic is emitted
/// at the start of the expressed type and a null type is returned.
static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
                                            double &min, double &max) {
  auto startLoc = parser.getCurrentLocation();

  FloatType expressed;
  if (parser.parseType(expressed)) {
    parser.emitError(startLoc, "expecting float expressed type");
    return nullptr;
  }

  // Calibrated min and max values.
  const bool rangeMalformed =
      parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
      parser.parseFloat(max) || parser.parseGreater();
  if (rangeMalformed) {
    parser.emitError(startLoc, "calibrated values must be present");
    return nullptr;
  }

  return expressed;
}
/// Parses the body of an AnyQuantizedType (the `any` keyword has already been
/// consumed by the caller).
///
///   any ::= `any<` storage-spec (expressed-type-spec)?`>`
///   storage-spec ::= storage-type (`<` storage-range `>`)?
///   storage-range ::= integer-literal `:` integer-literal
///   storage-type ::= (`i` | `u`) integer-literal
///   expressed-type-spec ::= `:` `f` integer-literal
///
/// Returns a null type if any component fails to parse.
static Type parseAnyType(DialectAsmParser &parser) {
  // Opening `<` of the type specification.
  if (parser.parseLess())
    return nullptr;

  // Storage type and the signedness flag derived from it.
  bool isSigned = false;
  IntegerType storageType = parseStorageType(parser, isSigned);
  if (!storageType)
    return nullptr;
  unsigned typeFlags = 0;
  if (isSigned)
    typeFlags |= QuantizationFlags::Signed;

  // Storage type range (explicit or defaulted).
  int64_t storageTypeMin = 0;
  int64_t storageTypeMax = 0;
  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
                        storageTypeMax))
    return nullptr;

  // Optional expressed type; stays null when no `:` follows.
  FloatType expressedType;
  if (succeeded(parser.parseOptionalColon()) &&
      parser.parseType(expressedType))
    return nullptr;

  // Closing `>`.
  if (parser.parseGreater())
    return nullptr;

  return parser.getChecked<AnyQuantizedType>(
      typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
}
/// Parses one `scale (":" zeroPoint)?` pair.
///
/// `zeroPoint` is set to 0 once the scale has been parsed and is overwritten
/// only when a `:` followed by an integer is present.
static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
                                    int64_t &zeroPoint) {
  // Mandatory scale.
  if (parser.parseFloat(scale))
    return failure();

  // Optional zero point, defaulting to 0.
  zeroPoint = 0;
  if (succeeded(parser.parseOptionalColon()))
    return parser.parseInteger(zeroPoint);
  return success();
}
/// Parses the body of a uniform quantized type (the `uniform` keyword has
/// already been consumed by the caller).
///
///   uniform_type ::= uniform_per_layer
///                  | uniform_per_axis
///   uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
///                          `,` scale-zero `>`
///   uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
///                        axis-spec `,` scale-zero-list `>`
///   storage-spec ::= storage-type (`<` storage-range `>`)?
///   storage-range ::= integer-literal `:` integer-literal
///   storage-type ::= (`i` | `u`) integer-literal
///   expressed-type-spec ::= `:` `f` integer-literal
///   axis-spec ::= `:` integer-literal
///   scale-zero ::= float-literal (`:` integer-literal)?
///   scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
///
/// A UniformQuantizedPerAxisType is built when an axis-spec is present,
/// otherwise a UniformQuantizedType. Returns a null type on any parse error.
static Type parseUniformType(DialectAsmParser &parser) {
  IntegerType storageType;
  FloatType expressedType;
  unsigned typeFlags = 0;
  int64_t storageTypeMin;
  int64_t storageTypeMax;
  bool isPerAxis = false;
  // Only meaningful when `isPerAxis` is set; initialized so it never holds an
  // indeterminate value.
  int32_t quantizedDimension = 0;
  SmallVector<double, 1> scales;
  SmallVector<int64_t, 1> zeroPoints;

  // Type specification.
  if (parser.parseLess())
    return nullptr;

  // Storage type.
  bool isSigned = false;
  storageType = parseStorageType(parser, isSigned);
  if (!storageType)
    return nullptr;
  if (isSigned)
    typeFlags |= QuantizationFlags::Signed;

  // Storage type range.
  if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
                        storageTypeMax))
    return nullptr;

  // Expressed type.
  if (parser.parseColon() || parser.parseType(expressedType))
    return nullptr;

  // Optionally parse quantized dimension for per-axis quantization.
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseInteger(quantizedDimension))
      return nullptr;
    isPerAxis = true;
  }

  // Comma leading into range_spec.
  if (parser.parseComma())
    return nullptr;

  // Parameter specification.
  // For per-axis, ranges are in a {} delimited list.
  if (isPerAxis && parser.parseLBrace())
    return nullptr;

  // Parse scales/zeroPoints. The loop continues past the first pair only in
  // the per-axis form, so the per-layer form always yields exactly one pair;
  // any extra `, scale` there is rejected by the `>` parse below.
  do {
    double scale = 0.0;
    int64_t zeroPoint = 0;
    if (parseQuantParams(parser, scale, zeroPoint))
      return nullptr;
    scales.push_back(scale);
    zeroPoints.push_back(zeroPoint);
  } while (isPerAxis && succeeded(parser.parseOptionalComma()));

  if (isPerAxis && parser.parseRBrace())
    return nullptr;

  if (parser.parseGreater())
    return nullptr;

  if (isPerAxis) {
    ArrayRef<double> scalesRef(scales.begin(), scales.end());
    ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
    return parser.getChecked<UniformQuantizedPerAxisType>(
        typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
        quantizedDimension, storageTypeMin, storageTypeMax);
  }

  return parser.getChecked<UniformQuantizedType>(
      typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
      storageTypeMin, storageTypeMax);
}
/// Parses the body of a CalibratedQuantizedType (the `calibrated` keyword has
/// already been consumed by the caller).
///
///   calibrated ::= `calibrated<` expressed-spec `>`
///   expressed-spec ::= expressed-type `<` calibrated-range `>`
///   expressed-type ::= `f` integer-literal
///   calibrated-range ::= float-literal `:` float-literal
///
/// Returns a null type if any component fails to parse.
static Type parseCalibratedType(DialectAsmParser &parser) {
  // Opening `<` of the type specification.
  if (parser.parseLess())
    return nullptr;

  // Expressed type with its calibrated range, then the closing `>`.
  double rangeLow;
  double rangeHigh;
  FloatType expressedType =
      parseExpressedTypeAndRange(parser, rangeLow, rangeHigh);
  if (!expressedType || parser.parseGreater())
    return nullptr;

  return parser.getChecked<CalibratedQuantizedType>(expressedType, rangeLow,
                                                    rangeHigh);
}
/// Parse a type registered to this dialect.
///
/// The leading keyword selects the concrete parser: `uniform`, `any` or
/// `calibrated`. Any other keyword is diagnosed as an unknown quantized type
/// and a null type is returned.
Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
  // All types start with an identifier that selects the parser to run.
  StringRef typeNameSpelling;
  if (parser.parseKeyword(&typeNameSpelling))
    return nullptr;

  using TypeParserFn = Type (*)(DialectAsmParser &);
  TypeParserFn parseFn = llvm::StringSwitch<TypeParserFn>(typeNameSpelling)
                             .Case("uniform", parseUniformType)
                             .Case("any", parseAnyType)
                             .Case("calibrated", parseCalibratedType)
                             .Default(nullptr);
  if (parseFn)
    return parseFn(parser);

  parser.emitError(parser.getNameLoc(),
                   "unknown quantized type " + typeNameSpelling);
  return nullptr;
}
/// Prints the storage spec of `type`: `i<width>` for signed or `u<width>` for
/// unsigned storage, followed by `<min:max>` only when the stored range
/// differs from the default range for that signedness and width.
static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
  // Storage type.
  const unsigned width = type.getStorageTypeIntegralWidth();
  const bool isSigned = type.isSigned();
  StringRef prefix = isSigned ? "i" : "u";
  out << prefix << width;

  // storageTypeMin and storageTypeMax are printed only if not default.
  const bool hasDefaultMin =
      QuantizedType::getDefaultMinimumForInteger(isSigned, width) ==
      type.getStorageTypeMin();
  const bool hasDefaultMax =
      QuantizedType::getDefaultMaximumForInteger(isSigned, width) ==
      type.getStorageTypeMax();
  if (hasDefaultMin && hasDefaultMax)
    return;

  out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
      << ">";
}
/// Prints `scale`, followed by `:zeroPoint` only when the zero point is
/// non-zero.
static void printQuantParams(double scale, int64_t zeroPoint,
                             DialectAsmPrinter &out) {
  out << scale;
  if (zeroPoint == 0)
    return;
  out << ":" << zeroPoint;
}
/// Helper that prints an AnyQuantizedType as
/// `any<` storage-spec (`:` expressed-type)? `>`; the expressed type is
/// omitted when the type has none.
static void printAnyQuantizedType(AnyQuantizedType type,
                                  DialectAsmPrinter &out) {
  out << "any<";
  printStorageType(type, out);

  Type expressedType = type.getExpressedType();
  if (expressedType)
    out << ":" << expressedType;

  out << ">";
}
/// Helper that prints a UniformQuantizedType as
/// `uniform<` storage-spec `:` expressed-type `, ` scale-zero `>`.
static void printUniformQuantizedType(UniformQuantizedType type,
                                      DialectAsmPrinter &out) {
  out << "uniform<";
  printStorageType(type, out);

  Type expressedType = type.getExpressedType();
  out << ":" << expressedType << ", ";

  // Scheme specific parameters.
  printQuantParams(type.getScale(), type.getZeroPoint(), out);

  out << ">";
}
/// Helper that prints a UniformQuantizedPerAxisType as
/// `uniform<` storage-spec `:` expressed-type `:` axis `, {` scale-zero
/// (`,` scale-zero)* `}>`.
static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
                                             DialectAsmPrinter &out) {
  out << "uniform<";
  printStorageType(type, out);
  out << ":" << type.getExpressedType() << ":";
  out << type.getQuantizedDimension();
  out << ", ";

  // Scheme specific parameters: one scale/zero-point pair per entry of the
  // scales array, separated by commas.
  ArrayRef<double> scales = type.getScales();
  ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
  out << "{";
  for (size_t index = 0, count = scales.size(); index != count; ++index) {
    if (index != 0)
      out << ",";
    printQuantParams(scales[index], zeroPoints[index], out);
  }
  out << "}>";
}
/// Helper that prints a CalibratedQuantizedType as
/// `calibrated<` expressed-type `<` min `:` max `>>`.
static void printCalibratedQuantizedType(CalibratedQuantizedType type,
                                         DialectAsmPrinter &out) {
  out << "calibrated<";
  out << type.getExpressedType();
  out << "<";
  out << type.getMin();
  out << ":";
  out << type.getMax();
  out << ">";
  out << ">";
}
/// Print a type registered to this dialect.
///
/// Dispatches on the concrete quantized type; reaching the end without a
/// match is treated as unreachable.
void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
  if (auto anyType = type.dyn_cast<AnyQuantizedType>()) {
    printAnyQuantizedType(anyType, os);
    return;
  }
  if (auto uniformType = type.dyn_cast<UniformQuantizedType>()) {
    printUniformQuantizedType(uniformType, os);
    return;
  }
  if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>()) {
    printUniformQuantizedPerAxisType(perAxisType, os);
    return;
  }
  if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>()) {
    printCalibratedQuantizedType(calibratedType, os);
    return;
  }
  llvm_unreachable("Unhandled quantized type");
}