| //===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h" |
| #include "llvm/ADT/StringExtras.h" |
| |
| using namespace mlir; |
| using namespace mlir::tosa; |
| |
// Constructs the compliance checker. The tables returned by
// getProfileComplianceMap() (profileComplianceMap / extensionComplianceMap)
// are presumably populated by the generated TosaComplianceData.h.inc included
// below. NOTE(review): confirm against the generator script.
TosaProfileCompliance::TosaProfileCompliance() {
  // Each TypeInfo pairs an MLIR type ID with what appears to be a bit width.
  // Floating-point types of equal width (bf16 vs fp16, fp8e4m3 vs fp8e5m2) are
  // told apart by their type ID. These locals are presumably referenced by
  // name from the generated include, so do not rename or remove them.
  const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
  const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
  const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
  const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
  const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
  const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
  const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
  const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
  const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
  const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
  const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};

  // The profile-based compliance content below is auto-generated by a script
  // in https://git.mlplatform.org/tosa/specification.git
#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
  // End of auto-generated metadata
}
| |
| template <> |
| OperationProfileComplianceMap TosaProfileCompliance::getProfileComplianceMap() { |
| return profileComplianceMap; |
| } |
| |
| template <> |
| OperationExtensionComplianceMap |
| TosaProfileCompliance::getProfileComplianceMap() { |
| return extensionComplianceMap; |
| } |
| |
| // Base populating function |
| LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands, |
| Value output) { |
| for (auto operand : operands) |
| addValue(operand); |
| addValue(output); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) { |
| addValue(op.getInput1().front()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) { |
| addValue(op.getInput()); |
| addValue(op.getInputZp()); |
| addValue(op.getOutputZp()); |
| addType(op.getAccType()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <typename T> |
| LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) { |
| addValue(op.getInput()); |
| addValue(op.getWeight()); |
| addValue(op.getBias()); |
| addValue(op.getInputZp()); |
| addValue(op.getWeightZp()); |
| addType(op.getAccType()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) { |
| return populateProfileInfoConv(op); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) { |
| return populateProfileInfoConv(op); |
| } |
| |
| template <> |
| LogicalResult |
| ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) { |
| return populateProfileInfoConv(op); |
| } |
| |
| template <> |
| LogicalResult |
| ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) { |
| return populateProfileInfoConv(op); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) { |
| addValue(op.getInput1()); |
| addValue(op.getPadConst()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <typename T> |
| LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) { |
| addValue(op.getInput1()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) { |
| return populateProfileInfoDataLayout(op); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) { |
| return populateProfileInfoDataLayout(op); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) { |
| return populateProfileInfoDataLayout(op); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) { |
| return populateProfileInfoDataLayout(op); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) { |
| addValue(op.getValues()); |
| addValue(op.getIndices()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) { |
| addValue(op.getValuesIn()); |
| addValue(op.getIndices()); |
| addValue(op.getInput()); |
| addValue(op.getValuesOut()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) { |
| addValue(op.getInput1()); |
| addValue(op.getInput2()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) { |
| addValue(op.getInput()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) { |
| addValue(op.getInputReal()); |
| addValue(op.getInputImag()); |
| addValue(op.getOutputReal()); |
| addValue(op.getOutputImag()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) { |
| addValue(op.getInputReal()); |
| addValue(op.getOutputReal()); |
| addValue(op.getOutputImag()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) { |
| addValue(op.getInput2()); |
| addValue(op.getInput3()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) { |
| addValue(op.getInput()); |
| addValue(op.getInputZp()); |
| addValue(op.getOutputZp()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) { |
| addValue(op.getA()); |
| addValue(op.getB()); |
| addValue(op.getAZp()); |
| addValue(op.getBZp()); |
| addValue(op.getOutput()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) { |
| ::mlir::Attribute attr = op.getInitialValueAttr(); |
| if (attr == nullptr) |
| return failure(); |
| |
| if (auto typedAttr = dyn_cast<TypedAttr>(attr)) { |
| addType(getElementTypeOrSelf(typedAttr)); |
| return success(); |
| } |
| return failure(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableWriteOp op) { |
| addValue(op.getInput1()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) { |
| addValue(op.getCondition()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) { |
| Block *block = &op.getCondGraph().front(); |
| Operation *terminator = block->getTerminator(); |
| addValue(terminator->getOperands().front()); |
| return success(); |
| } |
| |
| LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { |
| // This helper function only populates the info for the customised operands. |
| #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \ |
| if (isa<tosa::tosaOp##Op>(op)) { \ |
| return populateProfileInfo(cast<tosa::tosaOp##Op>(op)); \ |
| } |
| |
| #define POPULATE_PROFILE_INFO_SKIP(tosaOp) \ |
| if (isa<tosa::tosaOp##Op>(op)) \ |
| return success(); |
| |
| // This helper function populates the info for all operands. |
| #define POPULATE_PROFILE_INFO_COMMON(tosaOp) \ |
| if (isa<tosa::tosaOp##Op>(op)) { \ |
| return populateProfileInfo(op->getOperands(), op->getResult(0)); \ |
| } |
| |
| // Skip irrelevant operands when they are independent and not tied to any |
| // specific profile/extension. |
| POPULATE_PROFILE_INFO_CUSTOM(AvgPool2d) |
| POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D) |
| POPULATE_PROFILE_INFO_CUSTOM(Conv2D) |
| POPULATE_PROFILE_INFO_CUSTOM(Conv3D) |
| POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D) |
| POPULATE_PROFILE_INFO_CUSTOM(Mul) |
| POPULATE_PROFILE_INFO_CUSTOM(FFT2d) |
| POPULATE_PROFILE_INFO_CUSTOM(RFFT2d) |
| POPULATE_PROFILE_INFO_CUSTOM(Concat) |
| POPULATE_PROFILE_INFO_CUSTOM(Pad) |
| POPULATE_PROFILE_INFO_CUSTOM(Reshape) |
| POPULATE_PROFILE_INFO_CUSTOM(Slice) |
| POPULATE_PROFILE_INFO_CUSTOM(Tile) |
| POPULATE_PROFILE_INFO_CUSTOM(Transpose) |
| POPULATE_PROFILE_INFO_CUSTOM(Gather) |
| POPULATE_PROFILE_INFO_CUSTOM(Scatter) |
| POPULATE_PROFILE_INFO_CUSTOM(Resize) |
| POPULATE_PROFILE_INFO_CUSTOM(Select) |
| POPULATE_PROFILE_INFO_CUSTOM(Rescale) |
| POPULATE_PROFILE_INFO_CUSTOM(MatMul) |
| POPULATE_PROFILE_INFO_CUSTOM(Variable) |
| POPULATE_PROFILE_INFO_CUSTOM(VariableWrite) |
| POPULATE_PROFILE_INFO_CUSTOM(If) |
| POPULATE_PROFILE_INFO_CUSTOM(While) |
| |
| // For the most of tosa operators, all operands are profile/extension related |
| // and hence are all considered in this profile-based compilance check. |
| POPULATE_PROFILE_INFO_COMMON(Cast) |
| POPULATE_PROFILE_INFO_COMMON(Const) |
| POPULATE_PROFILE_INFO_COMMON(ArgMax) |
| POPULATE_PROFILE_INFO_COMMON(Sub) |
| POPULATE_PROFILE_INFO_COMMON(Maximum) |
| POPULATE_PROFILE_INFO_COMMON(Minimum) |
| POPULATE_PROFILE_INFO_COMMON(MaxPool2d) |
| POPULATE_PROFILE_INFO_COMMON(Clamp) |
| POPULATE_PROFILE_INFO_COMMON(Erf) |
| POPULATE_PROFILE_INFO_COMMON(Sigmoid) |
| POPULATE_PROFILE_INFO_COMMON(Tanh) |
| POPULATE_PROFILE_INFO_COMMON(Add) |
| POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift) |
| POPULATE_PROFILE_INFO_COMMON(BitwiseAnd) |
| POPULATE_PROFILE_INFO_COMMON(BitwiseNot) |
| POPULATE_PROFILE_INFO_COMMON(BitwiseOr) |
| POPULATE_PROFILE_INFO_COMMON(BitwiseXor) |
| POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift) |
| POPULATE_PROFILE_INFO_COMMON(LogicalRightShift) |
| POPULATE_PROFILE_INFO_COMMON(LogicalAnd) |
| POPULATE_PROFILE_INFO_COMMON(LogicalNot) |
| POPULATE_PROFILE_INFO_COMMON(LogicalOr) |
| POPULATE_PROFILE_INFO_COMMON(LogicalXor) |
| POPULATE_PROFILE_INFO_COMMON(IntDiv) |
| POPULATE_PROFILE_INFO_COMMON(Pow) |
| POPULATE_PROFILE_INFO_COMMON(Table) |
| POPULATE_PROFILE_INFO_COMMON(Abs) |
| POPULATE_PROFILE_INFO_COMMON(Ceil) |
| POPULATE_PROFILE_INFO_COMMON(Clz) |
| POPULATE_PROFILE_INFO_COMMON(Sin) |
| POPULATE_PROFILE_INFO_COMMON(Cos) |
| POPULATE_PROFILE_INFO_COMMON(Exp) |
| POPULATE_PROFILE_INFO_COMMON(Floor) |
| POPULATE_PROFILE_INFO_COMMON(Log) |
| POPULATE_PROFILE_INFO_COMMON(Negate) |
| POPULATE_PROFILE_INFO_COMMON(Reciprocal) |
| POPULATE_PROFILE_INFO_COMMON(Rsqrt) |
| POPULATE_PROFILE_INFO_COMMON(ReduceAll) |
| POPULATE_PROFILE_INFO_COMMON(ReduceAny) |
| POPULATE_PROFILE_INFO_COMMON(ReduceMax) |
| POPULATE_PROFILE_INFO_COMMON(ReduceMin) |
| POPULATE_PROFILE_INFO_COMMON(ReduceProduct) |
| POPULATE_PROFILE_INFO_COMMON(ReduceSum) |
| POPULATE_PROFILE_INFO_COMMON(Equal) |
| POPULATE_PROFILE_INFO_COMMON(GreaterEqual) |
| POPULATE_PROFILE_INFO_COMMON(Greater) |
| POPULATE_PROFILE_INFO_COMMON(Reverse) |
| POPULATE_PROFILE_INFO_COMMON(Identity) |
| POPULATE_PROFILE_INFO_COMMON(VariableRead) |
| |
| // Type Invariant Extension, a capability extension that is independent |
| // of the data type, meaning any compatible type can be used. No type |
| // constraint for those operations. |
| POPULATE_PROFILE_INFO_SKIP(ConstShape) |
| POPULATE_PROFILE_INFO_SKIP(Yield) |
| |
| return failure(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Tosa Profile And Extension Compliance Checker |
| //===----------------------------------------------------------------------===// |
| |
| template <typename T> |
| FailureOr<SmallVector<T>> |
| TosaProfileCompliance::getOperatorDefinition(Operation *op, |
| CheckCondition &condition) { |
| const std::string opName = op->getName().getStringRef().str(); |
| const auto complianceMap = getProfileComplianceMap<T>(); |
| const auto it = complianceMap.find(opName); |
| if (it == complianceMap.end()) |
| return {}; |
| |
| return findMatchedProfile<T>(op, it->second, condition); |
| } |
| |
| template <typename T> |
| LogicalResult TosaProfileCompliance::checkProfileOrExtension( |
| Operation *op, const tosa::TargetEnv &targetEnv, |
| const SmallVector<ArrayRef<T>> &specRequiredModeSet) { |
| |
| // None of profile requirement is set in the specification. |
| if (specRequiredModeSet.size() == 0) |
| return success(); |
| |
| CheckCondition condition = CheckCondition::invalid; |
| const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition); |
| if (failed(maybeOpRequiredMode)) { |
| // Operators such as control-flow and shape ops do not have an operand type |
| // restriction. When the profile compliance information of operation is not |
| // found, confirm if the target have enabled the profile required from the |
| // specification. |
| int mode_count = 0; |
| for (const auto &cands : specRequiredModeSet) { |
| if (targetEnv.allowsAnyOf(cands)) |
| return success(); |
| mode_count += cands.size(); |
| } |
| |
| op->emitOpError() << "illegal: requires" |
| << (mode_count > 1 ? " any of " : " ") << "[" |
| << llvm::join(stringifyProfile<T>(specRequiredModeSet), |
| ", ") |
| << "] but not enabled in target\n"; |
| |
| return failure(); |
| } |
| |
| // Find the required profiles or extensions according to the operand type |
| // combination. |
| const auto opRequiredMode = maybeOpRequiredMode.value(); |
| if (opRequiredMode.size() == 0) { |
| // No matched restriction found. |
| return success(); |
| } |
| |
| if (condition == CheckCondition::allOf && |
| !targetEnv.allowsAllOf(opRequiredMode)) { |
| op->emitOpError() << "illegal: requires" |
| << (opRequiredMode.size() > 1 ? " all of " : " ") << "[" |
| << llvm::join(stringifyProfile<T>(opRequiredMode), ", ") |
| << "] but not enabled in target\n"; |
| return failure(); |
| } |
| |
| if (condition == CheckCondition::anyOf && |
| !targetEnv.allowsAnyOf(opRequiredMode)) { |
| op->emitOpError() << "illegal: requires" |
| << (opRequiredMode.size() > 1 ? " any of " : " ") << "[" |
| << llvm::join(stringifyProfile<T>(opRequiredMode), ", ") |
| << "] but not enabled in target\n"; |
| return failure(); |
| } |
| |
| // Each extension can contain a list of profiles that it works with, usually |
| // have the same data type. |
| if constexpr (std::is_same_v<T, Extension>) { |
| for (const auto &mode : opRequiredMode) { |
| SmallVector<Profile> coProfs = getCooperativeProfiles(mode); |
| if (!targetEnv.allowsAnyOf(coProfs)) { |
| op->emitOpError() << "illegal: requires [" |
| << llvm::join(stringifyProfile<Profile>(coProfs), |
| ", ") |
| << "] to work with but not enabled in target\n"; |
| return failure(); |
| } |
| } |
| } |
| |
| // Ensure the profile inference match the profile knowledge of the |
| // specification. |
| for (const auto &cands : specRequiredModeSet) { |
| for (const auto &mode : opRequiredMode) { |
| if (!llvm::is_contained(cands, mode)) { |
| op->emitOpError() << "illegal: requires [" |
| << llvm::join(stringifyProfile<T>(opRequiredMode), |
| ", ") |
| << "] but not included in the profile compliance [" |
| << llvm::join( |
| stringifyProfile<T>(specRequiredModeSet), ", ") |
| << "]\n"; |
| return failure(); |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult |
| TosaProfileCompliance::checkProfile(Operation *op, |
| const tosa::TargetEnv &targetEnv) { |
| if (auto interface = dyn_cast<tosa::QueryProfileInterface>(op)) |
| return checkProfileOrExtension<Profile>(op, targetEnv, |
| interface.getProfiles()); |
| |
| return success(); |
| } |
| |
| LogicalResult |
| TosaProfileCompliance::checkExtension(Operation *op, |
| const tosa::TargetEnv &targetEnv) { |
| if (auto interface = dyn_cast<tosa::QueryExtensionInterface>(op)) |
| return checkProfileOrExtension<Extension>(op, targetEnv, |
| interface.getExtensions()); |
| |
| return success(); |
| } |
| |
| LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { |
| CheckCondition condition = CheckCondition::invalid; |
| const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); |
| const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); |
| if (!failed(maybeProfDef) && !failed(maybeExtDef) && |
| !maybeProfDef.value().size() && !maybeExtDef.value().size()) |
| return failure(); |
| |
| return success(); |
| } |
| |
| // Find the profiles or extensions requirement according to the signature of |
| // type of the operand list. |
| template <typename T> |
| SmallVector<T> TosaProfileCompliance::findMatchedProfile( |
| Operation *op, SmallVector<OpComplianceInfo<T>> compInfo, |
| CheckCondition &condition) { |
| assert(compInfo.size() != 0 && |
| "profile-based compliance information is empty"); |
| |
| // Populate the type of profile/extension relevant operands. |
| ProfileInfoDepot depot(op); |
| SmallVector<TypeInfo> present = depot.getInfo(); |
| if (present.size() == 0) |
| return {}; |
| |
| for (size_t i = 0; i < compInfo.size(); i++) { |
| SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet; |
| for (SmallVector<TypeInfo> expected : sets) { |
| assert(present.size() == expected.size() && |
| "the entries for profile-based compliance do not match between " |
| "the generated metadata and the type definition retrieved from " |
| " the operation"); |
| |
| bool is_found = true; |
| // Compare the type signature between the given operation and the |
| // compliance metadata. |
| for (size_t j = 0; j < expected.size(); j++) { |
| if (!isSameTypeInfo(present[j], expected[j])) { |
| // Verify the next mode set from the list. |
| is_found = false; |
| break; |
| } |
| } |
| |
| if (is_found == true) { |
| condition = compInfo[i].condition; |
| return compInfo[i].mode; |
| } |
| } |
| } |
| |
| return {}; |
| } |
| |
| // Debug utilites. |
| template <typename T> |
| SmallVector<StringRef> |
| TosaProfileCompliance::stringifyProfile(ArrayRef<T> profiles) { |
| SmallVector<StringRef> debugStrings; |
| for (const auto &profile : profiles) { |
| if constexpr (std::is_same_v<T, Profile>) |
| debugStrings.push_back(tosa::stringifyProfile(profile)); |
| else |
| debugStrings.push_back(tosa::stringifyExtension(profile)); |
| } |
| return debugStrings; |
| } |
| |
| template <typename T> |
| SmallVector<StringRef> TosaProfileCompliance::stringifyProfile( |
| const SmallVector<ArrayRef<T>> &profileSet) { |
| SmallVector<StringRef> debugStrings; |
| |
| for (const auto &profiles : profileSet) { |
| auto tempStrings = stringifyProfile<T>(profiles); |
| llvm::append_range(debugStrings, tempStrings); |
| } |
| |
| return debugStrings; |
| } |