//===- TosaValidation.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Validate that the TOSA dialect input matches the specification for the
// given requirements.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringExtras.h"
namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAVALIDATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir
using namespace mlir;
using namespace mlir::tosa;
namespace {
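// Verify that each of the listed operand indices is produced by a
// compile-time constant. The per-operator wrappers below decide which
// operands must be constant (zero points, shifts, rescale parameters, ...)
// and skip the check entirely when the 'dynamic' extension is enabled.
// For example (illustrative only), a tosa.mul whose 'shift' operand is a
// block argument rather than a tosa.const is reported here.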
static LogicalResult
checkConstantOperands(Operation *op, ArrayRef<unsigned int> operandIndices) {
for (const auto index : operandIndices) {
Attribute attr;
if (!matchPattern(op->getOperand(index), m_Constant(&attr))) {
return op->emitOpError("expected compile time resolvable constant, but "
"got variable value for operand #")
<< index;
}
}
return success();
}
static LogicalResult checkConstantOperandMul(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::MulOp>(op)) {
// Check 'shift'
return checkConstantOperands(op, {2});
}
return success();
}
static LogicalResult checkConstantOperandTable(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::TableOp>(op)) {
// Check 'table'
return checkConstantOperands(op, {1});
}
return success();
}
static LogicalResult checkConstantOperandPad(Operation *op,
const TargetEnv &env) {
if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
// Assume this op is zero-padding if padConst is not present
if (!env.allows(Extension::dynamic) && padOp.getPadConst())
// Check 'pad_const'
// Note: 'padding' (operand 1) is not checked as it is a tosa.shape type
return checkConstantOperands(op, {2});
}
return success();
}
static LogicalResult checkConstantOperandRescale(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::RescaleOp>(op)) {
// Check 'multiplier', 'shift', 'input_zp' and 'output_zp'
return checkConstantOperands(op, {1, 2, 3, 4});
}
return success();
}
template <typename T>
static LogicalResult checkConstantOperandConvOps(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<T>(op)) {
// Check 'input_zp' and 'weight_zp'
return checkConstantOperands(op, {3, 4});
}
return success();
}
static LogicalResult checkConstantOperandMatMul(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::MatMulOp>(op)) {
// Check 'A_zp' and 'B_zp'
return checkConstantOperands(op, {2, 3});
}
return success();
}
static LogicalResult checkConstantOperandAvgPool2d(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::AvgPool2dOp>(op)) {
// Check 'input_zp' and 'output_zp'
return checkConstantOperands(op, {1, 2});
}
return success();
}
static LogicalResult checkConstantOperandNegate(Operation *op,
const TargetEnv &env) {
if (!env.allows(Extension::dynamic) && isa<tosa::NegateOp>(op)) {
// Check 'input1_zp' and 'output_zp'
return checkConstantOperands(op, {1, 2});
}
return success();
}
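// Operator parameter limits for a TOSA level (see the 'Levels' section of
// the TOSA specification). A level bounds tensor rank, kernel and stride
// sizes, resize scales, tensor size (as a power of two), control-flow
// nesting depth and tensor-list length.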
struct TosaLevel {
int32_t MAX_RANK = 0;
int32_t MAX_KERNEL = 0;
int32_t MAX_STRIDE = 0;
int32_t MAX_SCALE = 0;
int32_t MAX_LOG2_SIZE = 0;
int32_t MAX_NESTING = 0;
int32_t MAX_TENSOR_LIST_SIZE = 0;
bool operator==(const TosaLevel &rhs) {
return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE &&
MAX_LOG2_SIZE == rhs.MAX_LOG2_SIZE &&
MAX_NESTING == rhs.MAX_NESTING &&
MAX_TENSOR_LIST_SIZE == rhs.MAX_TENSOR_LIST_SIZE;
}
};
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256, 31, 6, 64};
static constexpr TosaLevel TOSA_LEVEL_NONE = {32, 2147483647, 2147483647, 2048,
63, 256, 256};
//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//
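// The pass walks every TOSA operation and applies, in order: element-type
// legality, profile/extension compliance, data-type combination checks,
// constant-operand checks, level checks, attribute checks, variable checks
// and ERROR_IF checks. Any failure signals pass failure.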
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
explicit TosaValidation() { populateConstantOperandChecks(); }
explicit TosaValidation(const TosaValidationOptions &options)
: TosaValidation() {
this->profile = options.profile;
this->extension = options.extension;
this->strictOpSpecAlignment = options.strictOpSpecAlignment;
this->allowInvalidOpDatatypeCombinations =
options.allowInvalidOpDatatypeCombinations;
this->level = options.level;
}
void runOnOperation() final;
LogicalResult applyConstantOperandCheck(Operation *op) {
for (auto &checker : constCheckers) {
if (failed(checker(op, targetEnv)))
return failure();
}
return success();
}
LogicalResult applyLevelCheck(Operation *op);
LogicalResult applyAttributeCheck(Operation *op);
// check variable read/write data types against variable declarations
LogicalResult applyVariableCheck(Operation *op);
// check ERROR_IF conditions from the TOSA specification
LogicalResult applyErrorIfCheck(Operation *op);
private:
void populateConstantOperandChecks() {
constCheckers.emplace_back(checkConstantOperandMul);
constCheckers.emplace_back(checkConstantOperandTable);
constCheckers.emplace_back(checkConstantOperandPad);
constCheckers.emplace_back(checkConstantOperandRescale);
constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv2DOp>);
constCheckers.emplace_back(checkConstantOperandConvOps<tosa::Conv3DOp>);
constCheckers.emplace_back(
checkConstantOperandConvOps<tosa::DepthwiseConv2DOp>);
constCheckers.emplace_back(
checkConstantOperandConvOps<tosa::TransposeConv2DOp>);
constCheckers.emplace_back(checkConstantOperandMatMul);
constCheckers.emplace_back(checkConstantOperandAvgPool2d);
constCheckers.emplace_back(checkConstantOperandNegate);
}
bool levelCheckKernel(Operation *op, int32_t v, const StringRef checkDesc) {
if (v > tosaLevel.MAX_KERNEL) {
op->emitOpError() << "failed level check: " << checkDesc;
return false;
}
return true;
}
bool levelCheckStride(Operation *op, int32_t v, const StringRef checkDesc) {
if (v > tosaLevel.MAX_STRIDE) {
op->emitOpError() << "failed level check: " << checkDesc;
return false;
}
return true;
}
bool levelCheckScale(Operation *op, int32_t v, const StringRef checkDesc) {
if (v > tosaLevel.MAX_SCALE) {
op->emitOpError() << "failed level check: " << checkDesc;
return false;
}
return true;
}
bool levelCheckListSize(Operation *op, int32_t v, const StringRef checkDesc) {
if (v > tosaLevel.MAX_TENSOR_LIST_SIZE) {
op->emitOpError() << "failed level check for MAX_TENSOR_LIST_SIZE: "
<< checkDesc;
return false;
}
return true;
}
template <typename T>
bool levelCheckRank(Operation *op, const T &v,
const StringRef operandOrResult, int32_t highest_rank) {
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
if (!type.hasRank()) {
op->emitOpError() << "failed level check: unranked tensor";
return false;
}
if (type.getRank() > highest_rank) {
op->emitOpError() << "failed level check: " << operandOrResult
<< " rank(shape) <= MAX_RANK";
return false;
}
}
return true;
}
// Perform the Level tensor size check on the input tensor.
bool levelCheckSize(Operation *op, const Value &v,
const StringRef operandOrResult);
// Level check sizes of all operands and results of the operation.
template <typename T>
bool levelCheckSizes(T tosaOp) {
auto op = tosaOp.getOperation();
for (auto v : op->getOperands()) {
if (!levelCheckSize(op, v, "operand"))
return false;
}
for (auto v : op->getResults()) {
if (!levelCheckSize(op, v, "result"))
return false;
}
return true;
}
// Level check ranks of all operands, attributes and results of the operation.
template <typename T>
bool levelCheckRanks(T tosaOp) {
auto op = tosaOp.getOperation();
for (auto v : op->getOperands()) {
if (!levelCheckRank(op, v, "operand", tosaLevel.MAX_RANK))
return false;
}
if (!op->getAttrs().empty()) {
for (NamedAttribute attr : op->getAttrs()) {
if (auto elemAttr = dyn_cast<ElementsAttr>(attr.getValue())) {
if (!levelCheckRank(op, elemAttr, "attribute", tosaLevel.MAX_RANK))
return false;
}
}
}
for (auto v : op->getResults()) {
if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK))
return false;
}
return true;
}
// Level check ranks and sizes.
bool levelCheckRanksAndSizes(Operation *op);
// Pool Op: level check kernel/stride/pad values
template <typename T>
bool levelCheckPool(Operation *op) {
if (auto poolOp = dyn_cast<T>(op)) {
for (auto k : poolOp.getKernel()) {
if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
return false;
}
}
for (auto s : poolOp.getStride()) {
if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
return false;
}
}
for (auto p : poolOp.getPad()) {
if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
return false;
}
}
}
return true;
}
// Conv Op: level check dilation/stride/pad values
template <typename T>
bool levelCheckConv(Operation *op) {
if (auto convOp = dyn_cast<T>(op)) {
for (auto k : convOp.getDilation()) {
if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
return false;
}
}
for (auto p : convOp.getPad()) {
if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
return false;
}
}
for (auto s : convOp.getStride()) {
if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
return false;
}
}
auto dilation = convOp.getDilation();
if (ShapedType weightType =
dyn_cast<ShapedType>(op->getOperand(1).getType())) {
auto shape = weightType.getShape();
if (isa<tosa::Conv2DOp>(op)) {
assert(shape.size() == 4);
assert(dilation.size() == 2);
if (!levelCheckKernel(op, dilation[0] * shape[1],
"dilation_y * KH <= MAX_KERNEL") ||
!levelCheckKernel(op, dilation[1] * shape[2],
"dilation_x * KW <= MAX_KERNEL"))
return false;
} else if (isa<tosa::Conv3DOp>(op)) {
assert(shape.size() == 5);
assert(dilation.size() == 3);
if (!levelCheckKernel(op, dilation[0] * shape[1],
"dilation_d * KD <= MAX_KERNEL") ||
!levelCheckKernel(op, dilation[1] * shape[2],
"dilation_y * KH <= MAX_KERNEL") ||
!levelCheckKernel(op, dilation[2] * shape[3],
"dilation_x * KW <= MAX_KERNEL"))
return false;
} else if (isa<tosa::DepthwiseConv2DOp>(op)) {
assert(shape.size() == 4);
assert(dilation.size() == 2);
if (!levelCheckKernel(op, dilation[0] * shape[0],
"dilation_y * KH <= MAX_KERNEL") ||
!levelCheckKernel(op, dilation[1] * shape[1],
"dilation_x * KW <= MAX_KERNEL"))
return false;
}
}
}
return true;
}
// FFT op: level check H, W in input shape [N,H,W]
template <typename T>
bool levelCheckFFT(Operation *op) {
if (isa<T>(op)) {
for (auto v : op->getOperands()) {
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
auto shape = type.getShape();
assert(shape.size() == 3);
if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
!levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
return false;
}
}
}
}
return true;
}
// TransposeConv2d op: level check kH/kW, outpad, and stride
bool levelCheckTransposeConv2d(Operation *op) {
if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
if (ShapedType filterType =
dyn_cast<ShapedType>(transpose.getWeight().getType())) {
auto shape = filterType.getShape();
assert(shape.size() == 4);
// level check kernel sizes for KH and KW
if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
!levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
return false;
}
}
for (auto p : transpose.getOutPad()) {
if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
return false;
}
}
for (auto s : transpose.getStride()) {
if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
return false;
}
}
}
return true;
}
// Resize op: level check max scales
bool levelCheckResize(Operation *op) {
if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
SmallVector<int64_t> scale;
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(),
scale)) {
return false;
}
const int64_t scaleYN = scale[0];
const int64_t scaleYD = scale[1];
const int64_t scaleXN = scale[2];
const int64_t scaleXD = scale[3];
if (!levelCheckScale(op, scaleYN / scaleYD,
"scale_y_n/scale_y_d <= MAX_SCALE") ||
!levelCheckScale(op, scaleXN / scaleXD,
"scale_x_n/scale_x_d <= MAX_SCALE")) {
return false;
}
}
return true;
}
bool levelCheckListSize(Operation *op) {
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
return levelCheckListSize(op, concat.getInput1().size(), "input1");
}
if (auto custom = dyn_cast<tosa::CustomOp>(op)) {
if (!levelCheckListSize(op, custom.getInputList().size(), "input_list") ||
!levelCheckListSize(op, custom.getOutputList().size(),
"output_list")) {
return false;
}
}
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
!levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
return false;
}
}
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
!levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
return false;
}
}
return true;
}
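// Attribute checks: certain attribute values are only legal when the
// corresponding extension is enabled in the target environment.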
bool attributeCheckRescale(Operation *op) {
if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
!targetEnv.allows(Extension::doubleround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = DOUBLE_ROUND "
<< "requires extension [doubleround]";
return false;
} else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
!targetEnv.allows(Extension::inexactround)) {
op->emitOpError()
<< "failed attribute check: rounding_mode = INEXACT_ROUND "
<< "requires extension [inexactround]";
return false;
}
}
return true;
}
// Configure the target environment and level from the 'profile',
// 'extension' and 'level' pass options.
void configLevelAndProfile() {
tosaLevel = TOSA_LEVEL_NONE;
if (level == TosaLevelEnum::EightK) {
tosaLevel = TOSA_LEVEL_EIGHTK;
}
if (!profile.empty()) {
for (std::string &prof : profile) {
auto profSymbol = symbolizeProfile(prof);
if (profSymbol) {
targetEnv.addProfile(profSymbol.value());
} else {
llvm::errs() << "unknown TOSA profile name passed in: " << prof
<< ", supported profiles are `pro_int` and `pro_fp`\n";
return signalPassFailure();
}
}
}
if (!extension.empty()) {
for (std::string &ext : extension) {
auto extSymbol = symbolizeExtension(ext);
if (extSymbol) {
targetEnv.addExtension(extSymbol.value());
} else {
llvm::errs() << "unknown TOSA extension name passed in: " << ext
<< ", supported extension are int16, int4, bf16, "
<< "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
<< "doubleround, inexactround and dynamic\n";
return signalPassFailure();
}
}
}
}
bool CheckVariable(Operation *op);
bool CheckVariableReadOrWrite(Operation *op);
bool isValidElementType(Type type);
SmallVector<
std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
constCheckers;
TosaLevel tosaLevel;
DenseMap<StringAttr, mlir::Type> variablesMap;
TosaProfileCompliance profileComp;
tosa::TargetEnv targetEnv;
};
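// Rank-check specializations for operators whose rank constraints differ
// from the generic "all operands and results" rule.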
template <>
bool TosaValidation::levelCheckRanks(tosa::ArgMaxOp tosaOp) {
auto op = tosaOp.getOperation();
if (!levelCheckRank(op, tosaOp.getInput(), "operand", tosaLevel.MAX_RANK))
return false;
// rank(output) = rank(input) - 1
if (!levelCheckRank(op, tosaOp.getOutput(), "result", tosaLevel.MAX_RANK - 1))
return false;
return true;
}
template <>
bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
auto op = tosaOp.getOperation();
// Only the condition input has rank limitation.
if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
return false;
return true;
}
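// Dispatch the rank/size level checks per operator. The macros expand to an
// isa<> test followed by the corresponding checks, so each listed TOSA
// operator gets the appropriate combination.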
bool TosaValidation::levelCheckRanksAndSizes(Operation *op) {
#define CHECK_RANKS_AND_SIZES(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
if (!levelCheckRanks(cast<tosa::tosaOp##Op>(op))) \
return false; \
if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \
return false; \
}
#define CHECK_SIZES(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
if (!levelCheckSizes(cast<tosa::tosaOp##Op>(op))) \
return false; \
}
// Tensor Operators
CHECK_RANKS_AND_SIZES(ArgMax);
// Activation Functions
CHECK_RANKS_AND_SIZES(Clamp);
CHECK_RANKS_AND_SIZES(Erf);
CHECK_RANKS_AND_SIZES(Sigmoid);
CHECK_RANKS_AND_SIZES(Tanh);
// Elementwise Binary Operators
CHECK_RANKS_AND_SIZES(Add);
CHECK_RANKS_AND_SIZES(ArithmeticRightShift);
CHECK_RANKS_AND_SIZES(BitwiseAnd);
CHECK_RANKS_AND_SIZES(BitwiseOr);
CHECK_RANKS_AND_SIZES(BitwiseXor);
CHECK_RANKS_AND_SIZES(IntDiv);
CHECK_RANKS_AND_SIZES(LogicalAnd);
CHECK_RANKS_AND_SIZES(LogicalLeftShift);
CHECK_RANKS_AND_SIZES(LogicalRightShift);
CHECK_RANKS_AND_SIZES(LogicalOr);
CHECK_RANKS_AND_SIZES(LogicalXor);
CHECK_RANKS_AND_SIZES(Maximum);
CHECK_RANKS_AND_SIZES(Minimum);
CHECK_RANKS_AND_SIZES(Mul);
CHECK_RANKS_AND_SIZES(Pow);
CHECK_RANKS_AND_SIZES(Sub);
CHECK_RANKS_AND_SIZES(Table);
// Elementwise Unary Operators
CHECK_RANKS_AND_SIZES(Abs);
CHECK_RANKS_AND_SIZES(BitwiseNot);
CHECK_RANKS_AND_SIZES(Ceil);
CHECK_RANKS_AND_SIZES(Clz);
CHECK_RANKS_AND_SIZES(Cos);
CHECK_RANKS_AND_SIZES(Exp);
CHECK_RANKS_AND_SIZES(Floor);
CHECK_RANKS_AND_SIZES(Log);
CHECK_RANKS_AND_SIZES(LogicalNot);
CHECK_RANKS_AND_SIZES(Negate);
CHECK_RANKS_AND_SIZES(Reciprocal);
CHECK_RANKS_AND_SIZES(Rsqrt);
CHECK_RANKS_AND_SIZES(Sin);
// Elementwise Ternary Operators
CHECK_RANKS_AND_SIZES(Select);
// Comparison Operators
CHECK_RANKS_AND_SIZES(Equal);
CHECK_RANKS_AND_SIZES(Greater);
CHECK_RANKS_AND_SIZES(GreaterEqual);
// Reduction Operators
CHECK_RANKS_AND_SIZES(ReduceAll);
CHECK_RANKS_AND_SIZES(ReduceAny);
CHECK_RANKS_AND_SIZES(ReduceMax);
CHECK_RANKS_AND_SIZES(ReduceMin);
CHECK_RANKS_AND_SIZES(ReduceProduct);
CHECK_RANKS_AND_SIZES(ReduceSum);
// Data Layout Operators
CHECK_RANKS_AND_SIZES(Concat);
CHECK_RANKS_AND_SIZES(Pad);
CHECK_RANKS_AND_SIZES(Reshape);
CHECK_RANKS_AND_SIZES(Reverse);
CHECK_RANKS_AND_SIZES(Slice);
CHECK_RANKS_AND_SIZES(Tile);
CHECK_RANKS_AND_SIZES(Transpose);
// Type Conversion
CHECK_RANKS_AND_SIZES(Cast);
CHECK_RANKS_AND_SIZES(Rescale);
// Control Flow Operators
CHECK_RANKS_AND_SIZES(If);
// Variable Operators
CHECK_RANKS_AND_SIZES(Variable);
CHECK_RANKS_AND_SIZES(VariableWrite);
CHECK_RANKS_AND_SIZES(VariableRead);
// Data Nodes
CHECK_RANKS_AND_SIZES(Const);
CHECK_RANKS_AND_SIZES(Identity);
// For the following operators, only check whether the size of each tensor
// operand and result is valid for the given level.
// Tensor Operators
CHECK_SIZES(AvgPool2d);
CHECK_SIZES(Conv2D);
CHECK_SIZES(Conv3D);
CHECK_SIZES(DepthwiseConv2D);
CHECK_SIZES(TransposeConv2D);
CHECK_SIZES(FFT2d);
CHECK_SIZES(MatMul);
CHECK_SIZES(MaxPool2d);
CHECK_SIZES(RFFT2d);
// Scatter/Gather Operators
CHECK_SIZES(Gather);
CHECK_SIZES(Scatter);
// Image Operators
CHECK_SIZES(Resize);
// Custom Operators
CHECK_SIZES(Custom);
// Control Flow Operators
CHECK_SIZES(While);
// Shape Operators
CHECK_SIZES(ConstShape);
#undef CHECK_RANKS_AND_SIZES
#undef CHECK_SIZES
return true;
}
// Perform the Level tensor size check
bool TosaValidation::levelCheckSize(Operation *op, const Value &v,
const StringRef operandOrResult) {
if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
if (!type.hasRank()) {
op->emitOpError() << "failed level check: unranked tensor";
return false;
}
auto shape = type.getShape();
for (auto dim : shape) {
if (mlir::ShapedType::isDynamic(dim)) {
op->emitOpError() << "failed level check: " << operandOrResult
<< " shape dimension cannot be dynamic";
return false;
}
}
int64_t element_bits = type.getElementTypeBitWidth();
int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
int64_t size = element_bytes * type.getNumElements();
// According to Section 1.11 (Tensor Definitions) of the TOSA spec, the
// maximum value of tensor_size_t is (1 << MAX_LOG2_SIZE) - 1, where
// MAX_LOG2_SIZE is defined in Section 1.7 (Levels).
// For each tensor, the number of tensor elements multiplied by the
// element size in bytes must be representable as a tensor_size_t.
const int64_t max_size = (INT64_C(1) << tosaLevel.MAX_LOG2_SIZE) - 1;
if (size > max_size) {
op->emitOpError()
<< "failed level check: " << operandOrResult
<< " tensor size (in bytes) <= (1 << MAX_LOG2_SIZE) - 1";
return false;
}
}
return true;
}
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
if (tosaLevel == TOSA_LEVEL_NONE) {
// no need to do level checks
return success();
}
// additional level checks from spec 0.70
if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
!levelCheckConv<tosa::Conv2DOp>(op) ||
!levelCheckConv<tosa::Conv3DOp>(op) ||
!levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
!levelCheckFFT<tosa::FFT2dOp>(op) ||
!levelCheckPool<tosa::MaxPool2dOp>(op) ||
!levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
!levelCheckResize(op)) {
return failure();
}
if (!levelCheckRanksAndSizes(op)) {
return failure();
}
// level check MAX_TENSOR_LIST_SIZE
if (!levelCheckListSize(op)) {
return failure();
}
return success();
}
LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
if (!attributeCheckRescale(op))
return failure();
return success();
}
inline bool CompatibleTypes(const mlir::Type &type,
const mlir::Type &declaredType) {
// for now, simply use type equality comparison
return type == declaredType;
}
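// Record each tosa.variable declaration (name -> declared type) so that
// subsequent reads and writes can be type-checked against it.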
bool TosaValidation::CheckVariable(Operation *op) {
if (isa<mlir::tosa::VariableOp>(op)) {
mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
if (variablesMap.count(nameAttr)) {
op->emitOpError() << "name has already been declared";
return false;
}
auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
mlir::Type type = typeAttr.getValue();
variablesMap[nameAttr] = type;
}
return true;
}
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
if (isa<mlir::tosa::VariableReadOp>(op) ||
isa<mlir::tosa::VariableWriteOp>(op)) {
mlir::StringAttr nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
if (!variablesMap.count(nameAttr)) {
op->emitOpError() << "name has not been declared";
return false;
}
auto varType = variablesMap[nameAttr];
for (auto v : op->getOperands()) {
auto type = v.getType();
if (!CompatibleTypes(type, varType)) {
op->emitOpError() << "operand type does not equal variable type";
return false;
}
}
for (auto v : op->getResults()) {
auto type = v.getType();
if (!CompatibleTypes(type, varType)) {
op->emitOpError() << "result type does not equal variable type";
return false;
}
}
}
return true;
}
LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
return failure();
}
return success();
}
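// The checkErrorIf* helpers below mirror the ERROR_IF conditions from the
// TOSA specification pseudocode for the corresponding operators; they are
// only applied under strictOpSpecAlignment.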
bool checkErrorIfResize(Operation *op) {
auto resize = dyn_cast<tosa::ResizeOp>(op);
if (!resize)
return true;
const Value input = resize.getInput();
const Value output = resize.getOutput();
const RankedTensorType inputType =
llvm::dyn_cast<RankedTensorType>(input.getType());
const RankedTensorType outputType =
llvm::dyn_cast<RankedTensorType>(output.getType());
if (!inputType || !outputType) {
op->emitOpError("expect ranked input/output tensor");
return false;
}
// Ensure the image size is supported by GPU APIs and that for integer
// implementations, position * stride does not overflow int32_t.
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
const SmallVector<int64_t, 4> sizes = {
outputType.getDimSize(1), outputType.getDimSize(2),
inputType.getDimSize(1), inputType.getDimSize(2)};
const int64_t *maxDim = llvm::max_element(sizes);
if (maxDim != sizes.end() && *maxDim >= 16384) {
op->emitOpError("expect input/output height/width dims to be < 16384, ")
<< "got [OH, OW, IH, IW] = " << sizes;
return false;
}
}
SmallVector<int64_t> scale;
if (!tosa::getConstShapeValues(resize.getScale().getDefiningOp(), scale)) {
return false;
}
const int64_t scaleYN = scale[0];
const int64_t scaleYD = scale[1];
const int64_t scaleXN = scale[2];
const int64_t scaleXD = scale[3];
// Ensure scale values don't overflow int32 accumulator
if (scaleYN > (1 << 11) || scaleXN > (1 << 11)) {
op->emitOpError("expect all scale numerator values to be <= (1 << 11), "
"got scale_y_n=")
<< scaleYN << ", scale_x_n=" << scaleXN;
return false;
}
if (scaleYD >= 16 * scaleYN || scaleXD >= 16 * scaleXN) {
op->emitOpError("expect a downscale ratio larger than 1/16, got y=")
<< scaleYN << "/" << scaleYD << ", x=" << scaleXN << "/" << scaleXD;
return false;
}
SmallVector<int64_t> offset;
SmallVector<int64_t> border;
if (!tosa::getConstShapeValues(resize.getOffset().getDefiningOp(), offset) ||
!tosa::getConstShapeValues(resize.getBorder().getDefiningOp(), border)) {
return false;
}
const int64_t offsetY = offset[0];
const int64_t offsetX = offset[1];
// The normalized offsets (offset / scale_n) must stay in the range [-1, 16),
// consistent with the 1/16 downscale limit above.
if (offsetY < -scaleYN || offsetY >= 16 * scaleYN) {
op->emitOpError(
"expect offsetY / scaleYNumerator to be in range [-1, 16), got ")
<< offsetY << "/" << scaleYN;
return false;
}
if (offsetX < -scaleXN || offsetX >= 16 * scaleXN) {
op->emitOpError(
"expect offsetX / scaleXNumerator to be in range [-1, 16), got ")
<< offsetX << "/" << scaleXN;
return false;
}
const int64_t borderY = border[0];
const int64_t borderX = border[1];
if (borderY < -16 * scaleYN || borderY >= scaleYN) {
op->emitOpError(
"expect borderY / scaleYNumerator to be in range [-16, 1), got ")
<< borderY << "/" << scaleYN;
return false;
}
if (borderX < -16 * scaleXN || borderX >= scaleXN) {
op->emitOpError(
"expect borderX / scaleXNumerator to be in range [-16, 1), got ")
<< borderX << "/" << scaleXN;
return false;
}
// The following section of code is mostly duplicated from ResizeOp::verify().
//
// The TOSA specification does not support broadcast behavior. However, a
// rewrite pattern exists that materializes a broadcast ResizeOp, turning an
// invalid TOSA ResizeOp into a valid one. To avoid breaking existing code,
// that rewrite pattern is kept untouched, which means the checking in
// ResizeOp::verify() has to be loosened to accept broadcast ResizeOps.
//
// The check here is the strict one that conforms to the TOSA specification.
// FIXME: Remove the duplicated checks when broadcast ResizeOp is removed.
auto idivCheck = [](const int64_t lhs,
const int64_t rhs) -> std::optional<int64_t> {
if (lhs % rhs != 0)
return std::nullopt;
return lhs / rhs;
};
const int64_t oh = outputType.getDimSize(1);
const int64_t ow = outputType.getDimSize(2);
const int64_t ih = inputType.getDimSize(1);
const int64_t iw = inputType.getDimSize(2);
if (ih != ShapedType::kDynamic) {
const std::optional<int64_t> calculatedOutHeightMinusOne =
idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
if (!calculatedOutHeightMinusOne.has_value()) {
op->emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
"border_y ")
<< "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * "
<< scaleYN << " - " << offsetY << " + " << borderY << ") / "
<< scaleYD;
return false;
}
const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) {
op->emitOpError("calculated output height did not match expected: ")
<< "calculated=" << calculatedOutHeight << ", expected=" << oh;
return false;
}
}
if (iw != ShapedType::kDynamic) {
const std::optional<int64_t> calculatedOutWidthMinusOne =
idivCheck((iw - 1) * scaleXN - offsetX + borderX, scaleXD);
if (!calculatedOutWidthMinusOne.has_value()) {
op->emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
"border_x ")
<< "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * "
<< scaleXN << " - " << offsetX << " + " << borderX << ") / "
<< scaleXD;
return false;
}
const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) {
op->emitOpError("calculated output width did not match expected: ")
<< "calculated=" << calculatedOutWidth << ", expected=" << ow;
return false;
}
}
return true;
}
bool checkErrorIfMul(Operation *op) {
auto mul = dyn_cast<tosa::MulOp>(op);
if (!mul)
return true;
// REQUIRE(0 <= shift && shift <= 63);
// REQUIRE(is_same<in_t,int32_t>() || shift == 0);
ElementsAttr shift_elem;
if (!matchPattern(mul.getShift(), m_Constant(&shift_elem))) {
return true;
}
int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
auto inputElemType = getElementTypeOrSelf(mul.getInput1());
if (inputElemType.isInteger(32)) {
// 0 <= shift <= 63 for int32_t type
if (shift < 0 || shift > 63) {
op->emitOpError() << "requires 0 <= shift && shift <= 63, but got: "
<< shift;
return false;
}
} else {
// shift must be 0 for all other types
if (shift != 0) {
op->emitOpError() << "requires shift = 0 for all input data types that "
"are not int32_t, but got: "
<< shift;
return false;
}
}
return true;
}
bool checkErrorIfTable(Operation *op) {
auto table = dyn_cast<tosa::TableOp>(op);
if (!table)
return true;
// REQUIRE(length(table) == TABLE_SIZE) where TABLE_SIZE is 256 or 513
const auto inputElemType = getElementTypeOrSelf(table.getInput1().getType());
const int tableSize = inputElemType.isInteger(8) ? 256 : 513;
const ShapeAdaptor tableShape(table.getTable().getType());
if (tableShape.hasStaticShape()) {
const auto numElements = tableShape.getNumElements();
if (numElements != tableSize) {
op->emitOpError() << "requires table size of " << tableSize << ", got "
<< numElements;
return false;
}
}
return true;
}
bool checkErrorIfRescale(Operation *op) {
auto rescale = dyn_cast<tosa::RescaleOp>(op);
if (!rescale)
return true;
auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
!outputType.getElementType().isInteger())
return true;
auto inElemType = inputType.getElementType();
auto outElemType = outputType.getElementType();
auto inWidth = inElemType.getIntOrFloatBitWidth();
auto outWidth = outElemType.getIntOrFloatBitWidth();
bool inputUnsigned = rescale.getInputUnsigned();
bool outputUnsigned = rescale.getOutputUnsigned();
bool scale32 = rescale.getScale32();
auto roundingMode = rescale.getRoundingMode();
// ERROR_IF(scale32 && is_same<in_t,i48_t>())
if (scale32 && inWidth == 48) {
op->emitOpError() << "scale32 is not allowed with 48-bit input.";
return false;
}
// ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
if (!scale32 && roundingMode == "DOUBLE_ROUND") {
op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
return false;
}
// ERROR_IF(input_unsigned && output_unsigned)
if (inputUnsigned && outputUnsigned) {
op->emitOpError() << "input and output cannot be both unsigned.";
return false;
}
// ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
if (outWidth == 32 && inputUnsigned) {
op->emitOpError() << "i32 output type is not allowed with unsigned input.";
return false;
}
// ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
if (inWidth == 32 && outputUnsigned) {
op->emitOpError() << "i32 input type is not allowed with unsigned output.";
return false;
}
// ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
if (inWidth == 48 && outputUnsigned) {
op->emitOpError() << "i48 input type is not allowed with unsigned output.";
return false;
}
// ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
if (inWidth == 48 && inputUnsigned) {
op->emitOpError() << "i48 input type cannot be unsigned.";
return false;
}
// ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
if (inWidth == 32 && inputUnsigned) {
op->emitOpError() << "i32 input type cannot be unsigned.";
return false;
}
// ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
if (outWidth == 32 && outputUnsigned) {
op->emitOpError() << "i32 output type cannot be unsigned.";
return false;
}
return true;
}
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
!checkErrorIfTable(op) || !checkErrorIfRescale(op))
return failure();
return success();
}
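// Element types accepted by this pass: f32/f16/bf16, 8-bit floats
// (f8E4M3FN, f8E5M2), signless integers of width 1/4/8/16/32/48, and
// tosa.shape values. Quantized element types are rejected up front.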
bool TosaValidation::isValidElementType(Type type) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
Float8E5M2Type>(type);
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
case 1:
case 4:
case 8:
case 16:
case 32:
case 48:
return true;
}
}
} else if (mlir::isa<tosa::shapeType>(type)) {
return true;
}
return false;
}
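// Walk all TOSA ops and run the checks configured via the pass options.
// A typical invocation looks roughly like the following (option spelling
// may differ between versions):
//   mlir-opt --tosa-validate="profile=pro_int extension=int16 strict-op-spec-alignment" in.mlir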
void TosaValidation::runOnOperation() {
configLevelAndProfile();
TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
if (!tosaDialect)
return;
getOperation().walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;
// Perform the element type check first to protect the rest of the checks
// from unsupported (e.g. quantized) element types.
for (Value operand : op->getOperands()) {
auto elementTy = getElementTypeOrSelf(operand);
if (!isValidElementType(elementTy)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
}
}
for (Type resultTy : op->getResultTypes()) {
auto elementTy = getElementTypeOrSelf(resultTy);
if (!isValidElementType(elementTy)) {
op->emitOpError() << "is not profile-aligned: element type "
<< elementTy << " is not legal";
return signalPassFailure();
}
}
if (strictOpSpecAlignment &&
failed(profileComp.checkProfile(op, targetEnv)))
return signalPassFailure();
if (strictOpSpecAlignment &&
failed(profileComp.checkExtension(op, targetEnv)))
return signalPassFailure();
if (!allowInvalidOpDatatypeCombinations &&
failed(profileComp.checkInvalid(op))) {
op->emitOpError("illegal: operand/result data types not supported");
return signalPassFailure();
}
// Some uses of TOSA rely on the constant operands of particular
// operations.
if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
signalPassFailure();
// do level checks
if (failed(applyLevelCheck(op)))
signalPassFailure();
// check additional attribute restrictions
if (failed(applyAttributeCheck(op)))
signalPassFailure();
// do variable type checks
if (failed(applyVariableCheck(op)))
signalPassFailure();
// do error if checks
if (strictOpSpecAlignment && failed(applyErrorIfCheck(op)))
signalPassFailure();
});
}
} // namespace