blob: 66ea00b23b9d428112cc53a64253ca0abdc68f15 [file]
//===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
using namespace mlir::tosa;
/// Constructs the compliance checker and fills in its compliance tables.
///
/// NOTE(review): the generated file included at the end of this constructor
/// presumably populates `profileComplianceMap` and `extensionComplianceMap`
/// (the tables returned by getProfileComplianceMap) and refers to the TypeInfo
/// constants declared below by name; the .inc contents are not visible here,
/// so confirm against TosaComplianceData.h.inc before renaming anything.
TosaProfileCompliance::TosaProfileCompliance() {
// Each TypeInfo pairs an MLIR type ID with a width; the values used (1, 4, 8,
// 16, 32, 48) read as bit widths. Booleans are described as 1-bit integers.
const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
// Floating-point types; distinct type IDs distinguish formats of equal width
// (e.g. bf16 vs fp16, and the two 8-bit float encodings).
const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
// The profile-based compliance content below is auto-generated by a script
// in https://git.mlplatform.org/tosa/specification.git
#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
// End of auto-generated metadata
}
/// Profile flavour of getProfileComplianceMap: returns (by value) the table
/// that maps an operation name to its profile compliance entries.
template <>
OperationProfileComplianceMap TosaProfileCompliance::getProfileComplianceMap() {
  return this->profileComplianceMap;
}
/// Extension flavour of getProfileComplianceMap: returns (by value) the table
/// that maps an operation name to its extension compliance entries.
template <>
OperationExtensionComplianceMap
TosaProfileCompliance::getProfileComplianceMap() {
  return this->extensionComplianceMap;
}
// Base populating function: records every operand in order, followed by the
// given output value.
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
                                                    Value output) {
  llvm::for_each(operands, [&](Value operand) { addValue(operand); });
  addValue(output);
  return success();
}
// tosa.concat: only the first input and the output are recorded.
// NOTE(review): presumably all inputs share the first input's element type
// (enforced by the op verifier) — confirm. front() requires a non-empty list.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
  Value firstInput = op.getInput1().front();
  Value output = op.getOutput();
  addValue(firstInput);
  addValue(output);
  return success();
}
// tosa.avg_pool2d: records input, input zero point, output zero point, the
// accumulator type and the output, in that order (entries are compared
// positionally against the compliance metadata).
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
  Value input = op.getInput();
  Value inputZp = op.getInputZp();
  Value outputZp = op.getOutputZp();
  Value output = op.getOutput();
  addValue(input);
  addValue(inputZp);
  addValue(outputZp);
  addType(op.getAccType());
  addValue(output);
  return success();
}
// Shared helper for the convolution family. Records, in order: input, weight,
// bias, input zero point, weight zero point, accumulator type, output.
template <typename T>
LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
  Value input = op.getInput();
  Value weight = op.getWeight();
  Value bias = op.getBias();
  Value inputZp = op.getInputZp();
  Value weightZp = op.getWeightZp();
  addValue(input);
  addValue(weight);
  addValue(bias);
  addValue(inputZp);
  addValue(weightZp);
  addType(op.getAccType());
  addValue(op.getOutput());
  return success();
}
// tosa.conv2d uses the shared convolution operand layout.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
  return this->populateProfileInfoConv<tosa::Conv2DOp>(op);
}
// tosa.conv3d uses the shared convolution operand layout.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
  return this->populateProfileInfoConv<tosa::Conv3DOp>(op);
}
// tosa.transpose_conv2d uses the shared convolution operand layout.
template <>
LogicalResult
ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
  return this->populateProfileInfoConv<tosa::TransposeConv2DOp>(op);
}
// tosa.depthwise_conv2d uses the shared convolution operand layout.
template <>
LogicalResult
ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
  return this->populateProfileInfoConv<tosa::DepthwiseConv2DOp>(op);
}
// tosa.pad: records input, pad constant and output; other operands are not
// part of the compliance signature.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
  Value input = op.getInput1();
  Value padConst = op.getPadConst();
  addValue(input);
  addValue(padConst);
  addValue(op.getOutput());
  return success();
}
// Shared helper for data-layout ops: only the data input and the output
// take part in the compliance signature.
template <typename T>
LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
  Value input = op.getInput1();
  Value output = op.getOutput();
  addValue(input);
  addValue(output);
  return success();
}
// tosa.reshape uses the shared data-layout operand layout.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
  return this->populateProfileInfoDataLayout<tosa::ReshapeOp>(op);
}
// tosa.slice uses the shared data-layout operand layout.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
  return this->populateProfileInfoDataLayout<tosa::SliceOp>(op);
}
// tosa.tile uses the shared data-layout operand layout.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
  return this->populateProfileInfoDataLayout<tosa::TileOp>(op);
}
// tosa.transpose uses the shared data-layout operand layout.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
  return this->populateProfileInfoDataLayout<tosa::TransposeOp>(op);
}
// tosa.gather: records values, indices and output, in that order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
  Value values = op.getValues();
  Value indices = op.getIndices();
  Value output = op.getOutput();
  addValue(values);
  addValue(indices);
  addValue(output);
  return success();
}
// tosa.scatter: records values_in, indices, input and values_out, in that
// order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
  Value valuesIn = op.getValuesIn();
  Value indices = op.getIndices();
  Value input = op.getInput();
  Value valuesOut = op.getValuesOut();
  addValue(valuesIn);
  addValue(indices);
  addValue(input);
  addValue(valuesOut);
  return success();
}
// tosa.mul: only the two multiplicands and the result form the compliance
// signature; any further operand of the op is deliberately not recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
  Value lhs = op.getInput1();
  Value rhs = op.getInput2();
  addValue(lhs);
  addValue(rhs);
  addValue(op.getOutput());
  return success();
}
// tosa.resize: only the data input and the output are recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
  Value input = op.getInput();
  Value output = op.getOutput();
  addValue(input);
  addValue(output);
  return success();
}
// tosa.fft2d: records real input, imaginary input, real output and imaginary
// output, in that order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
  Value inReal = op.getInputReal();
  Value inImag = op.getInputImag();
  Value outReal = op.getOutputReal();
  Value outImag = op.getOutputImag();
  addValue(inReal);
  addValue(inImag);
  addValue(outReal);
  addValue(outImag);
  return success();
}
// tosa.rfft2d: records the real input followed by the real and imaginary
// outputs.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
  Value inReal = op.getInputReal();
  Value outReal = op.getOutputReal();
  Value outImag = op.getOutputImag();
  addValue(inReal);
  addValue(outReal);
  addValue(outImag);
  return success();
}
// tosa.select: input1 (presumably the boolean predicate) is not recorded; the
// two data inputs and the output form the signature.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
  Value onTrue = op.getInput2();
  Value onFalse = op.getInput3();
  addValue(onTrue);
  addValue(onFalse);
  addValue(op.getOutput());
  return success();
}
// tosa.rescale: records input, input zero point, output zero point and output;
// other operands are not part of the compliance signature.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
  Value input = op.getInput();
  Value inputZp = op.getInputZp();
  Value outputZp = op.getOutputZp();
  addValue(input);
  addValue(inputZp);
  addValue(outputZp);
  addValue(op.getOutput());
  return success();
}
// tosa.matmul: records A, B, A zero point, B zero point and output, in that
// order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
  Value a = op.getA();
  Value b = op.getB();
  Value aZp = op.getAZp();
  Value bZp = op.getBZp();
  addValue(a);
  addValue(b);
  addValue(aZp);
  addValue(bZp);
  addValue(op.getOutput());
  return success();
}
// tosa.variable: the signature is the element type of the initial value.
// Fails when there is no initial value or when it is not a typed attribute.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
  ::mlir::Attribute initialValue = op.getInitialValueAttr();
  auto typedInit = dyn_cast_if_present<TypedAttr>(initialValue);
  if (!typedInit)
    return failure();
  addType(getElementTypeOrSelf(typedInit));
  return success();
}
// tosa.variable_write: only the value being written is recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) {
  Value written = op.getInput1();
  addValue(written);
  return success();
}
// tosa.cond_if: only the branch condition is recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
  Value cond = op.getCondition();
  addValue(cond);
  return success();
}
// tosa.while_loop: the signature is the first value yielded by the terminator
// of the condition region's first block.
//
// Returns failure, instead of invoking undefined behaviour, when the condition
// region has no block, the block has no terminator, or the terminator yields
// no value.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
  Region &condRegion = op.getCondGraph();
  if (condRegion.empty())
    return failure();
  Block &condBlock = condRegion.front();
  // getTerminator() only asserts; guard explicitly so release builds do not
  // dereference a missing terminator.
  if (!condBlock.mightHaveTerminator())
    return failure();
  Operation *terminator = condBlock.getTerminator();
  if (terminator->getNumOperands() == 0)
    return failure();
  addValue(terminator->getOperand(0));
  return success();
}
/// Records the profile/extension-relevant types of `op` into the depot.
///
/// Operations listed under POPULATE_PROFILE_INFO_CUSTOM use a dedicated
/// populateProfileInfo specialization that records only the relevant
/// operands. Operations under POPULATE_PROFILE_INFO_COMMON record all operands
/// followed by their first result. Operations under POPULATE_PROFILE_INFO_SKIP
/// record nothing and succeed. Any other operation yields failure.
LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// Populates only the customised (relevant) operands via a specialization.
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)                                   \
  if (auto typedOp = dyn_cast<tosa::tosaOp##Op>(op)) {                         \
    return populateProfileInfo(typedOp);                                       \
  }
// Records nothing for the operation.
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)                                     \
  if (isa<tosa::tosaOp##Op>(op))                                               \
    return success();
// Populates the info for all operands plus the first result.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)                                   \
  if (isa<tosa::tosaOp##Op>(op)) {                                             \
    return populateProfileInfo(op->getOperands(), op->getResult(0));           \
  }

  // Skip irrelevant operands when they are independent and not tied to any
  // specific profile/extension.
  POPULATE_PROFILE_INFO_CUSTOM(AvgPool2d)
  POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Conv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
  POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Mul)
  POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
  POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
  POPULATE_PROFILE_INFO_CUSTOM(Concat)
  POPULATE_PROFILE_INFO_CUSTOM(Pad)
  POPULATE_PROFILE_INFO_CUSTOM(Reshape)
  POPULATE_PROFILE_INFO_CUSTOM(Slice)
  POPULATE_PROFILE_INFO_CUSTOM(Tile)
  POPULATE_PROFILE_INFO_CUSTOM(Transpose)
  POPULATE_PROFILE_INFO_CUSTOM(Gather)
  POPULATE_PROFILE_INFO_CUSTOM(Scatter)
  POPULATE_PROFILE_INFO_CUSTOM(Resize)
  POPULATE_PROFILE_INFO_CUSTOM(Select)
  POPULATE_PROFILE_INFO_CUSTOM(Rescale)
  POPULATE_PROFILE_INFO_CUSTOM(MatMul)
  POPULATE_PROFILE_INFO_CUSTOM(Variable)
  POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
  POPULATE_PROFILE_INFO_CUSTOM(If)
  POPULATE_PROFILE_INFO_CUSTOM(While)

  // For most TOSA operators, all operands are profile/extension related and
  // hence are all considered in this profile-based compliance check.
  POPULATE_PROFILE_INFO_COMMON(Cast)
  POPULATE_PROFILE_INFO_COMMON(Const)
  POPULATE_PROFILE_INFO_COMMON(ArgMax)
  POPULATE_PROFILE_INFO_COMMON(Sub)
  POPULATE_PROFILE_INFO_COMMON(Maximum)
  POPULATE_PROFILE_INFO_COMMON(Minimum)
  POPULATE_PROFILE_INFO_COMMON(MaxPool2d)
  POPULATE_PROFILE_INFO_COMMON(Clamp)
  POPULATE_PROFILE_INFO_COMMON(Erf)
  POPULATE_PROFILE_INFO_COMMON(Sigmoid)
  POPULATE_PROFILE_INFO_COMMON(Tanh)
  POPULATE_PROFILE_INFO_COMMON(Add)
  POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift)
  POPULATE_PROFILE_INFO_COMMON(BitwiseAnd)
  POPULATE_PROFILE_INFO_COMMON(BitwiseNot)
  POPULATE_PROFILE_INFO_COMMON(BitwiseOr)
  POPULATE_PROFILE_INFO_COMMON(BitwiseXor)
  POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift)
  POPULATE_PROFILE_INFO_COMMON(LogicalRightShift)
  POPULATE_PROFILE_INFO_COMMON(LogicalAnd)
  POPULATE_PROFILE_INFO_COMMON(LogicalNot)
  POPULATE_PROFILE_INFO_COMMON(LogicalOr)
  POPULATE_PROFILE_INFO_COMMON(LogicalXor)
  POPULATE_PROFILE_INFO_COMMON(IntDiv)
  POPULATE_PROFILE_INFO_COMMON(Pow)
  POPULATE_PROFILE_INFO_COMMON(Table)
  POPULATE_PROFILE_INFO_COMMON(Abs)
  POPULATE_PROFILE_INFO_COMMON(Ceil)
  POPULATE_PROFILE_INFO_COMMON(Clz)
  POPULATE_PROFILE_INFO_COMMON(Sin)
  POPULATE_PROFILE_INFO_COMMON(Cos)
  POPULATE_PROFILE_INFO_COMMON(Exp)
  POPULATE_PROFILE_INFO_COMMON(Floor)
  POPULATE_PROFILE_INFO_COMMON(Log)
  POPULATE_PROFILE_INFO_COMMON(Negate)
  POPULATE_PROFILE_INFO_COMMON(Reciprocal)
  POPULATE_PROFILE_INFO_COMMON(Rsqrt)
  POPULATE_PROFILE_INFO_COMMON(ReduceAll)
  POPULATE_PROFILE_INFO_COMMON(ReduceAny)
  POPULATE_PROFILE_INFO_COMMON(ReduceMax)
  POPULATE_PROFILE_INFO_COMMON(ReduceMin)
  POPULATE_PROFILE_INFO_COMMON(ReduceProduct)
  POPULATE_PROFILE_INFO_COMMON(ReduceSum)
  POPULATE_PROFILE_INFO_COMMON(Equal)
  POPULATE_PROFILE_INFO_COMMON(GreaterEqual)
  POPULATE_PROFILE_INFO_COMMON(Greater)
  POPULATE_PROFILE_INFO_COMMON(Reverse)
  POPULATE_PROFILE_INFO_COMMON(Identity)
  POPULATE_PROFILE_INFO_COMMON(VariableRead)

  // Type Invariant Extension, a capability extension that is independent
  // of the data type, meaning any compatible type can be used. No type
  // constraint for those operations.
  POPULATE_PROFILE_INFO_SKIP(ConstShape)
  POPULATE_PROFILE_INFO_SKIP(Yield)

// The helper macros are local to this function; do not leak them into the
// rest of the translation unit.
#undef POPULATE_PROFILE_INFO_CUSTOM
#undef POPULATE_PROFILE_INFO_SKIP
#undef POPULATE_PROFILE_INFO_COMMON

  return failure();
}
//===----------------------------------------------------------------------===//
// Tosa Profile And Extension Compliance Checker
//===----------------------------------------------------------------------===//
/// Looks up `op` by name in the profile (T = Profile) or extension
/// (T = Extension) compliance table.
///
/// Returns failure when the op has no entry in that table. Otherwise returns
/// the result of findMatchedProfile, which may be an empty vector when no
/// type signature matched. `condition` is updated by findMatchedProfile on a
/// match.
template <typename T>
FailureOr<SmallVector<T>>
TosaProfileCompliance::getOperatorDefinition(Operation *op,
                                             CheckCondition &condition) {
  const auto complianceMap = getProfileComplianceMap<T>();
  const std::string opName = op->getName().getStringRef().str();
  const auto entry = complianceMap.find(opName);
  if (entry == complianceMap.end())
    return failure();
  return findMatchedProfile<T>(op, entry->second, condition);
}
/// Verifies that `op` is legal under the profiles (T = Profile) or extensions
/// (T = Extension) enabled in `targetEnv`.
///
/// `specRequiredModeSet` is the requirement supplied by the op's query
/// interface (checkProfile / checkExtension pass it in). Each inner list is
/// treated as a set of alternatives via `allowsAnyOf`. On any violation an op
/// error is emitted and failure is returned; otherwise success.
template <typename T>
LogicalResult TosaProfileCompliance::checkProfileOrExtension(
Operation *op, const tosa::TargetEnv &targetEnv,
const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
// No profile/extension requirement is set in the specification.
if (specRequiredModeSet.size() == 0)
return success();
// Overwritten with the matched entry's condition (allOf / anyOf) when the
// operand types match a compliance entry.
CheckCondition condition = CheckCondition::invalid;
const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
if (failed(maybeOpRequiredMode)) {
// Operators such as control-flow and shape ops do not have an operand type
// restriction. When the op has no compliance information, only check that
// the target enables at least one mode from some candidate list required
// by the specification.
// `mode_count` only selects "any of" vs. single-mode wording below.
int mode_count = 0;
for (const auto &cands : specRequiredModeSet) {
if (targetEnv.allowsAnyOf(cands))
return success();
mode_count += cands.size();
}
op->emitOpError() << "illegal: requires"
<< (mode_count > 1 ? " any of " : " ") << "["
<< llvm::join(stringifyProfile<T>(specRequiredModeSet),
", ")
<< "] but not enabled in target\n";
return failure();
}
// Find the required profiles or extensions according to the operand type
// combination.
const auto opRequiredMode = maybeOpRequiredMode.value();
if (opRequiredMode.size() == 0) {
// The op is listed but its operand types matched no compliance entry; such
// mismatches are not rejected by this function.
return success();
}
// allOf: every inferred mode must be enabled in the target.
if (condition == CheckCondition::allOf &&
!targetEnv.allowsAllOf(opRequiredMode)) {
op->emitOpError() << "illegal: requires"
<< (opRequiredMode.size() > 1 ? " all of " : " ") << "["
<< llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
<< "] but not enabled in target\n";
return failure();
}
// anyOf: at least one inferred mode must be enabled in the target.
if (condition == CheckCondition::anyOf &&
!targetEnv.allowsAnyOf(opRequiredMode)) {
op->emitOpError() << "illegal: requires"
<< (opRequiredMode.size() > 1 ? " any of " : " ") << "["
<< llvm::join(stringifyProfile<T>(opRequiredMode), ", ")
<< "] but not enabled in target\n";
return failure();
}
// Each extension can contain a list of profiles that it works with, usually
// sharing the same data type. Note this is checked for every inferred
// extension, regardless of the allOf/anyOf condition.
if constexpr (std::is_same_v<T, Extension>) {
for (const auto &mode : opRequiredMode) {
SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
if (!targetEnv.allowsAnyOf(coProfs)) {
op->emitOpError() << "illegal: requires ["
<< llvm::join(stringifyProfile<Profile>(coProfs),
", ")
<< "] to work with but not enabled in target\n";
return failure();
}
}
}
// Ensure the profile inference matches the profile knowledge of the
// specification: every inferred mode must appear in every candidate list.
for (const auto &cands : specRequiredModeSet) {
for (const auto &mode : opRequiredMode) {
if (!llvm::is_contained(cands, mode)) {
op->emitOpError() << "illegal: requires ["
<< llvm::join(stringifyProfile<T>(opRequiredMode),
", ")
<< "] but not included in the profile compliance ["
<< llvm::join(
stringifyProfile<T>(specRequiredModeSet), ", ")
<< "]\n";
return failure();
}
}
}
return success();
}
/// Checks `op` against the profiles enabled in `targetEnv`. Operations that do
/// not implement QueryProfileInterface are accepted unconditionally.
LogicalResult
TosaProfileCompliance::checkProfile(Operation *op,
                                    const tosa::TargetEnv &targetEnv) {
  auto profileQuery = dyn_cast<tosa::QueryProfileInterface>(op);
  if (!profileQuery)
    return success();
  return checkProfileOrExtension<Profile>(op, targetEnv,
                                          profileQuery.getProfiles());
}
/// Checks `op` against the extensions enabled in `targetEnv`. Operations that
/// do not implement QueryExtensionInterface are accepted unconditionally.
LogicalResult
TosaProfileCompliance::checkExtension(Operation *op,
                                      const tosa::TargetEnv &targetEnv) {
  auto extensionQuery = dyn_cast<tosa::QueryExtensionInterface>(op);
  if (!extensionQuery)
    return success();
  return checkProfileOrExtension<Extension>(op, targetEnv,
                                            extensionQuery.getExtensions());
}
/// Detects operations whose operand/result types fit no profile or extension.
///
/// Returns success when `op` has no entry in either compliance table (nothing
/// to check), or when at least one table that lists `op` yields a non-empty
/// requirement for its type signature. Returns failure when `op` is listed in
/// at least one table but its types match no entry in any listed table. This
/// includes ops listed in only one of the two tables, which must not slip
/// through just because the other lookup fails.
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
  CheckCondition condition = CheckCondition::invalid;
  const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
  const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);

  // The operation is not described by any compliance metadata.
  if (failed(maybeProfDef) && failed(maybeExtDef))
    return success();

  const bool hasMatch = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
                        (succeeded(maybeExtDef) && !maybeExtDef->empty());
  return success(hasMatch);
}
// Find the profiles or extensions required by `op` by matching the types of
// its profile-relevant operands (as recorded by ProfileInfoDepot) positionally
// against each type signature listed in `compInfo`.
//
// Returns the modes of the first entry with a matching signature and sets
// `condition` to that entry's condition. Returns an empty vector, leaving
// `condition` untouched, when the depot records no types or no signature
// matches.
template <typename T>
SmallVector<T> TosaProfileCompliance::findMatchedProfile(
    Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
    CheckCondition &condition) {
  assert(!compInfo.empty() &&
         "profile-based compliance information is empty");

  // Populate the type of profile/extension relevant operands.
  ProfileInfoDepot depot(op);
  SmallVector<TypeInfo> present = depot.getInfo();
  if (present.empty())
    return {};

  // Iterate by reference: the signature sets are only read, so copying them
  // per entry would allocate needlessly.
  for (const auto &info : compInfo) {
    for (const auto &expected : info.operandTypeInfoSet) {
      assert(present.size() == expected.size() &&
             "the entries for profile-based compliance do not match between "
             "the generated metadata and the type definition retrieved from "
             " the operation");
      // The assert is compiled out in release builds; never index past the
      // end of either vector on an arity mismatch.
      if (present.size() != expected.size())
        continue;

      // Compare the type signature between the given operation and the
      // compliance metadata.
      bool isFound = true;
      for (size_t j = 0, e = expected.size(); j < e; ++j) {
        if (!isSameTypeInfo(present[j], expected[j])) {
          // Verify the next signature from the list.
          isFound = false;
          break;
        }
      }
      if (isFound) {
        condition = info.condition;
        return info.mode;
      }
    }
  }
  return {};
}
// Debug utilities.
// Converts each profile (T = Profile) or extension (T = Extension) into its
// printable name, preserving order.
template <typename T>
SmallVector<StringRef>
TosaProfileCompliance::stringifyProfile(ArrayRef<T> profiles) {
  auto toName = [](T mode) -> StringRef {
    if constexpr (std::is_same_v<T, Profile>)
      return tosa::stringifyProfile(mode);
    else
      return tosa::stringifyExtension(mode);
  };
  SmallVector<StringRef> names;
  for (T mode : profiles)
    names.push_back(toName(mode));
  return names;
}
// Flattens a set of profile/extension lists into a single list of printable
// names, in list order.
template <typename T>
SmallVector<StringRef> TosaProfileCompliance::stringifyProfile(
    const SmallVector<ArrayRef<T>> &profileSet) {
  SmallVector<StringRef> flattened;
  for (ArrayRef<T> group : profileSet)
    for (StringRef name : stringifyProfile<T>(group))
      flattened.push_back(name);
  return flattened;
}