| //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements an ODS (and C++) generator from a YAML form |
| // derived from the mathematical expression of linalg named ops. Typically a |
| // math oriented DSL will be used to export the essential representation to |
| // this form, and maintaining the SOT at the math level (versus recreating it |
| // in MLIR) is deemed to have systemic value. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/AsmParser/AsmParser.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/Support/FileUtilities.h" |
| #include "mlir/Support/LLVM.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/ToolOutputFile.h" |
| #include "llvm/Support/YAMLTraits.h" |
| #include <optional> |
| |
| using namespace mlir; |
| |
| using llvm::yaml::Input; |
| using llvm::yaml::MappingTraits; |
| using llvm::yaml::ScalarEnumerationTraits; |
| using llvm::yaml::ScalarTraits; |
| |
| #define DEBUG_TYPE "linalg-ods-gen" |
| |
| //===----------------------------------------------------------------------===// |
| // Mapping structs (correspond to data types in the YAML description). |
| // TODO: Since this is a schema/part of the contract, it should be moved to |
| // a real header. |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| struct LinalgYAMLContext { |
| MLIRContext *mlirContext; |
| }; |
| |
| struct LinalgOpMetadata { |
| std::string name; |
| std::string cppClassName; |
| std::optional<std::string> doc; |
| SmallVector<std::string> implements; |
| SmallVector<std::string> defines; |
| }; |
| |
| struct SerializedAffineMap { |
| AffineMapAttr affineMapAttr; |
| |
| AffineMap affineMap() { return affineMapAttr.getValue(); } |
| }; |
| |
/// Kind of a named op argument. The YAML spelling of each enumerator is given
/// next to it.
enum class LinalgOperandDefKind {
  InputTensor,   // `input_tensor`
  Scalar,        // `scalar`
  OutputTensor,  // `output_tensor`
  IndexAttr,     // `index_attr`
  UnaryFnAttr,   // `unary_fn_attr`
  BinaryFnAttr,  // `binary_fn_attr`
  TernaryFnAttr, // `ternary_fn_attr`
  TypeFnAttr     // `type_fn_attr`
};
| |
| struct LinalgOperandDef { |
| std::string name; |
| LinalgOperandDefKind kind; |
| std::optional<std::string> typeVar; |
| std::optional<SerializedAffineMap> shapeMap; |
| std::optional<SerializedAffineMap> indexAttrMap; |
| std::optional<SmallVector<int64_t>> defaultIndices; |
| std::optional<std::string> defaultFn; |
| }; |
| |
/// Iterator type of one loop dimension; the enumerator names double as the
/// YAML spellings.
enum class LinalgIteratorTypeDef { parallel, reduction };
| |
| struct LinalgIndexingMapsConfig { |
| std::optional<SmallVector<SerializedAffineMap>> staticIndexingMaps; |
| }; |
| |
struct ScalarExpression;

/// Category of a scalar function (YAML spellings: `unary`, `binary`,
/// `ternary`, `type`).
enum class ScalarFnKind { Unary, Binary, Ternary, Type };

/// Application of a function to scalar operands (YAML key `scalar_fn`).
struct ScalarFn {
  /// Function category (YAML key `kind`).
  ScalarFnKind kind;
  /// YAML key `fn_name`, when present.
  std::optional<std::string> fnName;
  /// YAML key `attr_name`, when present.
  std::optional<std::string> attrName;
  /// YAML key `type_var`, when present.
  std::optional<std::string> typeVar;
  // ScalarExpression is only forward declared at this point; holding the
  // operands in a heap allocated vector breaks the self-referential cycle.
  std::vector<ScalarExpression> operands;
};

/// Right-hand side of an assignment. Each member corresponds to one of the
/// alternative YAML keys `scalar_arg`, `scalar_const`, `scalar_index` and
/// `scalar_fn`.
struct ScalarExpression {
  std::optional<std::string> arg;
  std::optional<std::string> constant;
  std::optional<int64_t> index;
  std::optional<ScalarFn> scalarFn;
};

/// Assignment of a scalar expression (`value`) to the argument named `arg`.
struct ScalarAssign {
  std::string arg;
  ScalarExpression value;
};
| |
| struct LinalgStructuredOpConfig { |
| SmallVector<LinalgOperandDef> args; |
| LinalgIndexingMapsConfig indexingMaps; |
| SmallVector<LinalgIteratorTypeDef> iteratorTypes; |
| std::vector<ScalarAssign> assignments; |
| }; |
| |
| struct LinalgOpConfig { |
| std::optional<LinalgOpMetadata> metadata; |
| std::optional<LinalgStructuredOpConfig> structuredOp; |
| }; |
| |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Mapping traits. |
| //===----------------------------------------------------------------------===// |
| |
| LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef) |
| LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap) |
| LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef) |
| LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign) |
| LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression) |
| LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig) |
| |
| namespace llvm { |
| namespace yaml { |
| |
| /// Top-level type containing op metadata and one of a concrete op type. |
| /// Currently, the only defined op type is `structured_op` (maps to |
| /// `LinalgStructuredOpConfig`). |
| template <> |
| struct MappingTraits<LinalgOpConfig> { |
| static void mapping(IO &io, LinalgOpConfig &info) { |
| io.mapOptional("metadata", info.metadata); |
| io.mapOptional("structured_op", info.structuredOp); |
| } |
| }; |
| |
| /// A structured op models (at most) a single contraction by modeling |
| /// - A list of named arguments (`LinalgOperandDef`), which can be inputs, |
| /// outputs, or index attributes. |
| /// - List of indexing maps (see `LinalgIndexingMaps`). |
| /// - Iterator types (see `LinalgIteratorTypeDef`). |
| /// - List of scalar level assignment (see `ScalarAssign`). |
| template <> |
| struct MappingTraits<LinalgStructuredOpConfig> { |
| static void mapping(IO &io, LinalgStructuredOpConfig &info) { |
| io.mapRequired("args", info.args); |
| io.mapRequired("indexing_maps", info.indexingMaps); |
| io.mapRequired("iterator_types", info.iteratorTypes); |
| io.mapRequired("assignments", info.assignments); |
| } |
| }; |
| |
| /// Maps a named tensor, scalar or attribute argument to an operation, |
| /// consisting of: |
| /// - `name`: Must be unique within the operation. |
| /// - `usage`: How the argument is used (input, output, attribute, etc). |
| /// - `type_var`: The symbolic type variable that binds to the element or self |
| /// type of the tensor or scalar argument, respectively. |
| /// - `shape_map`: An optional AffineMap from all op symbols to the shape of |
| /// the argument. Only tensor arguments have a `shape_map`. Each shape must |
| /// be normalized over the same list of symbols and have no dimension |
| /// inputs. |
| /// - `index_attr_map`: An optional AffineMap from all op symbols to the |
| /// index attribute symbols. During op creation these symbols are replaced |
| /// by the corresponding `name` index attribue values. Only index attribute |
| /// arguments have an `index_attr_map`. |
| /// - `default_indices`: An optional default initialization for index |
| /// attribute arguments. |
| /// - `default_fn`: An optional default initialization for function attribute |
| /// arguments. |
| template <> |
| struct MappingTraits<LinalgOperandDef> { |
| static void mapping(IO &io, LinalgOperandDef &info) { |
| io.mapRequired("name", info.name); |
| io.mapRequired("kind", info.kind); |
| io.mapOptional("type_var", info.typeVar); |
| io.mapOptional("shape_map", info.shapeMap); |
| io.mapOptional("index_attr_map", info.indexAttrMap); |
| io.mapOptional("default_indices", info.defaultIndices); |
| io.mapOptional("default_fn", info.defaultFn); |
| } |
| }; |
| |
| /// Usage enum for a named argument. |
| template <> |
| struct ScalarEnumerationTraits<LinalgOperandDefKind> { |
| static void enumeration(IO &io, LinalgOperandDefKind &value) { |
| io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor); |
| io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar); |
| io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor); |
| io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr); |
| io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr); |
| io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr); |
| io.enumCase(value, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr); |
| io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr); |
| } |
| }; |
| |
| /// Iterator type enum. |
| template <> |
| struct ScalarEnumerationTraits<LinalgIteratorTypeDef> { |
| static void enumeration(IO &io, LinalgIteratorTypeDef &value) { |
| io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel); |
| io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction); |
| } |
| }; |
| |
| /// Metadata about the op (name, C++ name, and documentation). |
| template <> |
| struct MappingTraits<LinalgOpMetadata> { |
| static void mapping(IO &io, LinalgOpMetadata &info) { |
| io.mapRequired("name", info.name); |
| io.mapRequired("cpp_class_name", info.cppClassName); |
| io.mapOptional("doc", info.doc); |
| io.mapOptional("implements", info.implements); |
| io.mapOptional("defines", info.defines); |
| } |
| }; |
| |
| /// How the ops indexing maps are produced. Must be one of: |
| /// - static_indexing_maps: A static list of AffineMaps, possibly with |
| /// some symbols that bind to attributes of the op. Each indexing map must |
| /// be normalized over the same list of dimensions, and its symbols must |
| /// match the symbols for argument shapes. |
| template <> |
| struct MappingTraits<LinalgIndexingMapsConfig> { |
| static void mapping(IO &io, LinalgIndexingMapsConfig &info) { |
| io.mapOptional("static_indexing_maps", info.staticIndexingMaps); |
| } |
| }; |
| |
| /// Models an assignment to a named output. |
| /// - The `arg` name must match a named output. |
| /// - The `value` is a scalar expression for computing the value to |
| /// assign (see `ScalarExpression`). |
| template <> |
| struct MappingTraits<ScalarAssign> { |
| static void mapping(IO &io, ScalarAssign &info) { |
| io.mapRequired("arg", info.arg); |
| io.mapRequired("value", info.value); |
| } |
| }; |
| |
| /// A scalar expression (RHS of an assignment). Must be one of: |
| /// - `scalar_arg`: An operation argument. |
| /// - `scalar_const`: A constant definition. |
| /// - `scalar_index`: An iteration index. |
| /// - `scalar_fn`: A named function (see `ScalarFn`). |
| template <> |
| struct MappingTraits<ScalarExpression> { |
| static void mapping(IO &io, ScalarExpression &info) { |
| io.mapOptional("scalar_arg", info.arg); |
| io.mapOptional("scalar_const", info.constant); |
| io.mapOptional("scalar_index", info.index); |
| io.mapOptional("scalar_fn", info.scalarFn); |
| } |
| }; |
| |
| /// Scalar function kind enum. |
| template <> |
| struct ScalarEnumerationTraits<ScalarFnKind> { |
| static void enumeration(IO &io, ScalarFnKind &value) { |
| io.enumCase(value, "unary", ScalarFnKind::Unary); |
| io.enumCase(value, "binary", ScalarFnKind::Binary); |
| io.enumCase(value, "ternary", ScalarFnKind::Ternary); |
| io.enumCase(value, "type", ScalarFnKind::Type); |
| } |
| }; |
| |
| /// A scalar expression that evaluates a named function. |
| /// Functions are generally "math" level and type polymorphic. Builtin |
| /// functions include: |
| /// - `add(lhs, rhs)` |
| /// - `mul(lhs, rhs)` |
| template <> |
| struct MappingTraits<ScalarFn> { |
| static void mapping(IO &io, ScalarFn &info) { |
| io.mapRequired("kind", info.kind); |
| io.mapOptional("fn_name", info.fnName); |
| io.mapOptional("attr_name", info.attrName); |
| io.mapOptional("type_var", info.typeVar); |
| io.mapRequired("operands", info.operands); |
| } |
| }; |
| |
| /// Helper mapping which accesses an AffineMapAttr as a serialized string of |
| /// the same. |
| template <> |
| struct ScalarTraits<SerializedAffineMap> { |
| static void output(const SerializedAffineMap &value, void *rawYamlContext, |
| raw_ostream &out) { |
| assert(value.affineMapAttr); |
| value.affineMapAttr.print(out); |
| } |
| static StringRef input(StringRef scalar, void *rawYamlContext, |
| SerializedAffineMap &value) { |
| assert(rawYamlContext); |
| auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext); |
| if (auto attr = dyn_cast_or_null<AffineMapAttr>( |
| mlir::parseAttribute(scalar, yamlContext->mlirContext))) |
| value.affineMapAttr = attr; |
| else if (!value.affineMapAttr || !isa<AffineMapAttr>(value.affineMapAttr)) |
| return "could not parse as an affine map attribute"; |
| return StringRef(); |
| } |
| static QuotingType mustQuote(StringRef) { return QuotingType::None; } |
| }; |
| |
| } // namespace yaml |
| } // namespace llvm |
| |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Generation utilities |
| //===----------------------------------------------------------------------===// |
| |
| class GenerationContext { |
| public: |
| GenerationContext(MLIRContext *context, raw_ostream *odsOut, |
| raw_ostream *defnOut) |
| : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut), |
| defnOut(defnOut) {} |
| |
| MLIRContext *getContext() { return context; } |
| |
| void setLoc(Location loc) { this->loc = loc; } |
| Location getLoc() { return loc; } |
| |
| bool shouldGenerateOds() { return odsOut; } |
| bool shouldGenerateDefns() { return defnOut; } |
| |
| raw_ostream &odss() { |
| assert(odsOut && "ODS stream not defined"); |
| return *odsOut; |
| } |
| |
| raw_ostream &defns() { |
| assert(defnOut && "Definition stream not defined"); |
| return *defnOut; |
| } |
| |
| private: |
| MLIRContext *context; |
| Location loc; |
| raw_ostream *odsOut; |
| raw_ostream *defnOut; |
| }; |
| |
| } // namespace |
| |
| static std::string generateCppExpression(SerializedAffineMap self, |
| StringRef contextName) { |
| std::string printedStr; |
| llvm::raw_string_ostream printedSs(printedStr); |
| self.affineMapAttr.print(printedSs); |
| printedSs.flush(); |
| |
| static const char exprFormat[] = |
| R"FMT(llvm::cast<AffineMapAttr>(mlir::parseAttribute("{0}", {1})).getValue())FMT"; |
| return llvm::formatv(exprFormat, printedStr, contextName); |
| } |
| |
| template <typename Container> |
| static std::string interleaveToString(Container &container, |
| StringRef separator) { |
| std::string result; |
| llvm::raw_string_ostream ss(result); |
| llvm::interleave(container, ss, separator); |
| ss.flush(); |
| return result; |
| } |
| |
| static std::optional<int> |
| findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) { |
| for (const auto &it : llvm::enumerate(args)) { |
| if (it.value().name == name) |
| return it.index(); |
| } |
| return std::nullopt; |
| } |
| |
| // Try to map the TypeVar to a predefined or an argument type. |
| static std::optional<std::string> |
| findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) { |
| // Handle all predefined types. |
| if (typeVar == "I32") |
| return std::string("helper.getIntegerType(32)"); |
| if (typeVar == "I64") |
| return std::string("helper.getIntegerType(64)"); |
| if (typeVar == "F32") |
| return std::string("helper.getFloat32Type()"); |
| if (typeVar == "F64") |
| return std::string("helper.getFloat64Type()"); |
| |
| // Search all argument types. |
| for (const auto &it : llvm::enumerate(args)) { |
| if (it.value().kind != LinalgOperandDefKind::InputTensor && |
| it.value().kind != LinalgOperandDefKind::Scalar && |
| it.value().kind != LinalgOperandDefKind::OutputTensor) |
| continue; |
| if (*it.value().typeVar == typeVar) |
| return llvm::formatv("block.getArgument({0}).getType()", it.index()) |
| .str(); |
| } |
| |
| return std::nullopt; |
| } |
| |
| static ScalarAssign *findAssignment(StringRef name, |
| std::vector<ScalarAssign> &assignments) { |
| for (auto &assign : assignments) { |
| if (assign.arg == name) |
| return &assign; |
| } |
| return nullptr; |
| } |
| |
| // Return true if the operand is a function attribute. |
| static bool isFunctionAttribute(LinalgOperandDefKind kind) { |
| return kind == LinalgOperandDefKind::UnaryFnAttr || |
| kind == LinalgOperandDefKind::BinaryFnAttr || |
| kind == LinalgOperandDefKind::TernaryFnAttr || |
| kind == LinalgOperandDefKind::TypeFnAttr; |
| } |
| |
| // Return true if the operand is an attribute. |
| static bool isAttribute(LinalgOperandDefKind kind) { |
| return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind); |
| } |
| |
| // Get the enum name for the given operand kind. |
| std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) { |
| switch (kind) { |
| case LinalgOperandDefKind::UnaryFnAttr: |
| return std::string("UnaryFn"); |
| case LinalgOperandDefKind::BinaryFnAttr: |
| return std::string("BinaryFn"); |
| case LinalgOperandDefKind::TernaryFnAttr: |
| return std::string("TernaryFn"); |
| case LinalgOperandDefKind::TypeFnAttr: |
| return std::string("TypeFn"); |
| default: |
| break; |
| } |
| llvm_unreachable("unsupported function attribute kind"); |
| } |
| |
| // Get the enum name for the given function kind. |
| std::string convertFunctionKindToEnumName(ScalarFnKind kind) { |
| switch (kind) { |
| case ScalarFnKind::Unary: |
| return std::string("UnaryFn"); |
| case ScalarFnKind::Binary: |
| return std::string("BinaryFn"); |
| case ScalarFnKind::Ternary: |
| return std::string("TernaryFn"); |
| case ScalarFnKind::Type: |
| return std::string("TypeFn"); |
| } |
| llvm_unreachable("unsupported function kind"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Templates |
| //===----------------------------------------------------------------------===// |
| |
| // Format of a banner comment wrapping a single line of text. Parameters: |
| // {0}: Text of the banner (a single line) |
| static const char bannerFormat[] = R"FMT( |
| //===----------------------------------------------------------------------===// |
| // {0} |
| //===----------------------------------------------------------------------===// |
| )FMT"; |
| |
| //===----------------------------------------------------------------------===// |
| // Named generic op generation. |
| // These ops map at most a single contraction that complies with the limitations |
| // of a linalg.generic. |
| //===----------------------------------------------------------------------===// |
| |
| // Template for the ODS definition of a Linalg named op. Parameters: |
| // {0}: ODS/C++ op name |
| // {1}: assembly op mnemonic |
| // {2}: op interface list (comma separated) |
| // {3}: documentation (summary + description) |
| // {4}: op attribute list (appended after the outputs) |
| // {5}: builder methods taking standalone attribute parameters |
| // {6}: additional definitions (`let <name> = 1;` statements) |
| // {7}: additional methods for attributes used by indexing maps |
| static const char structuredOpOdsHeaderFormat[] = R"FMT( |
| //===----------------------------------------------------------------------===// |
| // Op definition for {0} |
| //===----------------------------------------------------------------------===// |
| |
| def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], |
| /*extraInterfaces=*/[{2}])> { |
| {3} |
| let arguments = (ins |
| Variadic<AnyType>:$inputs, |
| Variadic<AnyShaped>:$outputs{4} |
| ); |
| let results = (outs Variadic<AnyRankedTensor>:$result_tensors); |
| let regions = (region AnyRegion:$region); |
| |
| let skipDefaultBuilders = 1; |
| let builders = [ |
| OpBuilder< |
| (ins "ValueRange":$inputs, "ValueRange":$outputs, |
| CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes), |
| [{{ |
| buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs, |
| attributes, {0}::getRegionBuilder()); |
| }]>, |
| OpBuilder< |
| (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, |
| "ValueRange":$outputs, |
| CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes), |
| [{{ |
| buildStructuredOp($_builder, $_state, resultTensorTypes, |
| inputs, outputs, attributes, {0}::getRegionBuilder()); |
| }]>, |
| OpBuilder< |
| (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, |
| CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes), |
| [{{ |
| $_state.addOperands(operands); |
| $_state.addAttributes(attributes); |
| $_state.addTypes(resultTensorTypes); |
| (void)$_state.addRegion(); |
| }]> |
| {5} |
| ]; |
| let hasCustomAssemblyFormat = 1; |
| let hasFolder = 1; |
| {6} |
| |
| let extraClassDeclaration = structuredOpsBaseDecls # [{{ |
| // Auto-generated. |
| SmallVector<utils::IteratorType> getIteratorTypesArray(); |
| ArrayAttr getIndexingMaps(); |
| static void regionBuilder(ImplicitLocOpBuilder &b, |
| Block &block, ArrayRef<NamedAttribute> attrs); |
| static std::function<void(ImplicitLocOpBuilder &, |
| Block &, ArrayRef<NamedAttribute>)> |
| getRegionBuilder() {{ |
| return regionBuilder; |
| } |
| |
| ::mlir::MutableOperandRange getDpsInitsMutable() {{ |
| return getOutputsMutable(); |
| } |
| |
| // Generic methods. |
| static unsigned getNumRegionArgs(); |
| std::string getLibraryCallName(); |
| {7} |
| }]; |
| } |
| )FMT"; |
| |
| // Additional ODS builder taking standalone attribute parameters. Parameters: |
| // {0}: Class name |
| // {1}: Comma interleaved attribute parameters |
| // {2}: Attribute initialization statements (added to the op state) |
| static const char structuredOpBuilderFormat[] = R"FMT( |
| , OpBuilder< |
| (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, |
| "ValueRange":$outputs, {1}, |
| CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes), |
| [{{ |
| {2} |
| buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs, |
| attributes, {0}::getRegionBuilder()); |
| }]> |
| )FMT"; |
| |
| // The getIteratorTypesArray() method for structured ops. Parameters: |
| // {0}: Class name |
| // {1}: Comma interleaved iterator types (utils::IteratorType enumerators). |
| static const char structuredOpIteratorTypesFormat[] = |
| R"FMT( |
| SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{ |
| return SmallVector<utils::IteratorType>{{ {1} }; |
| } |
| )FMT"; |
| |
| // The getIteratorTypesArray() method for rank polymorphic structured ops: |
| // all-parallel iterators, one per dimension of the first init operand. Parameters: |
| // {0}: Class name |
| static const char rankPolyStructuredOpIteratorTypesFormat[] = |
| R"FMT( |
| SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{ |
| int64_t rank = getRank(getDpsInitOperand(0)); |
| return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel); |
| } |
| )FMT"; |
| |
| // The getIndexingMaps() method for structured ops. Parameters: |
| // {0}: Class name |
| // {1}: Comma-separated list of dimension variable names (unused by the template). |
| // {2}: Statements that append the indexing maps to `maps` |
| static const char structuredOpIndexingMapsFormat[] = R"FMT( |
| ArrayAttr {0}::getIndexingMaps() {{ |
| static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; |
| ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr); |
| if (cached) |
| return cached; |
| |
| MLIRContext *context = getContext(); |
| auto symbolBindings = getSymbolBindings(*this); |
| SmallVector<AffineMap> maps; |
| {2} |
| cached = Builder(context).getAffineMapArrayAttr(maps); |
| getOperation()->setAttr(memoizeAttr, cached); |
| return cached; |
| } |
| )FMT"; |
| |
| // The getIndexingMaps() method for rank polymorphic structured ops. Parameters: |
| // {0}: Class name |
| static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT( |
| ArrayAttr {0}::getIndexingMaps() {{ |
| MLIRContext *context = getContext(); |
| AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context); |
| AffineMap tensorMap = AffineMap::getMultiDimIdentityMap( |
| getNumParallelLoops(), context); |
| SmallVector<AffineMap> indexingMaps; |
| for (OpOperand &opOperand : getOperation()->getOpOperands()) |
| indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap); |
| return Builder(getContext()).getAffineMapArrayAttr(indexingMaps); |
| } |
| )FMT"; |
| |
| // Implementations of the fold() and getEffects() methods for structured ops. |
| // Parameters: |
| // {0}: Class name |
| const char structuredOpFoldersFormat[] = R"FMT( |
| LogicalResult {0}::fold(FoldAdaptor, |
| SmallVectorImpl<OpFoldResult> &) {{ |
| return memref::foldMemRefCast(*this); |
| } |
| void {0}::getEffects(SmallVectorImpl< |
| SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{ |
| if (hasPureTensorSemantics()) return; |
| getGenericEffectsImpl(effects, |
| getOperation()->getResults(), getDpsInputs(), getDpsInits()); |
| } |
| )FMT"; |
| |
| // Implementations of the parse() and print() methods for structured ops. |
| // Parameters: |
| // {0}: Class name |
| static const char structuredOpParserFormat[] = R"FMT( |
| ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{ |
| return ::parseNamedStructuredOp(parser, result, |
| {0}::getNumRegionArgs(), {0}::getRegionBuilder()); |
| } |
| void {0}::print(OpAsmPrinter &p) {{ |
| ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs()); |
| } |
| )FMT"; |
| |
| static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig, |
| GenerationContext &genContext) { |
| if (!genContext.shouldGenerateOds()) |
| return success(); |
| |
| raw_ostream &os = genContext.odss(); |
| |
| std::string interfaceNameList; |
| std::string attrList; |
| std::string attrMethods; |
| std::string attrBuilder; |
| |
| std::string doc; |
| if (opConfig.metadata->doc) { |
| static const char structuredOpDocFmt[] = R"FMT( |
| let summary = [{{{0}}]; |
| let description = [{{{1}}]; |
| )FMT"; |
| StringRef summary, description; |
| std::tie(summary, description) = |
| StringRef(*opConfig.metadata->doc).trim().split("\n\n"); |
| |
| doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim()); |
| } |
| |
| interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); |
| |
| std::string definitionList; |
| for (const std::string &definition : opConfig.metadata->defines) { |
| static const char definitionFmt[] = "let {0} = 1;\n"; |
| definitionList.append(llvm::formatv(definitionFmt, definition)); |
| } |
| |
| if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { |
| return isAttribute(arg.kind); |
| })) { |
| SmallVector<std::string> attrDefs; |
| SmallVector<std::string> attrParams; |
| SmallVector<std::string> attrStmts; |
| for (LinalgOperandDef &arg : opConfig.structuredOp->args) { |
| static const char paramFmt[] = "\"Attribute\":${0}"; |
| static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});"; |
| // Add the type conversion attributes to the op definition and builders. |
| if (isFunctionAttribute(arg.kind)) { |
| assert(arg.defaultFn); |
| std::string enumName = convertOperandKindToEnumName(arg.kind); |
| static const char typeFmt[] = "{0}::{1}"; |
| static const char defFmt[] = |
| "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}"; |
| attrDefs.push_back(llvm::formatv( |
| defFmt, llvm::formatv("{0}Attr", enumName), |
| llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name)); |
| attrParams.push_back(llvm::formatv(paramFmt, arg.name)); |
| attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); |
| } |
| // Add the index attributes to the op definition and builders. |
| if (arg.kind == LinalgOperandDefKind::IndexAttr) { |
| assert(arg.indexAttrMap.has_value()); |
| assert(arg.defaultIndices.has_value()); |
| size_t size = arg.indexAttrMap->affineMap().getNumResults(); |
| assert(arg.defaultIndices->size() == size); |
| static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>"; |
| static const char defFmt[] = |
| "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}"; |
| std::string defaultVals; |
| llvm::raw_string_ostream ss(defaultVals); |
| llvm::interleave( |
| *arg.defaultIndices, ss, |
| [&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; }, |
| ", "); |
| attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size), |
| ss.str(), arg.name)); |
| attrParams.push_back(llvm::formatv(paramFmt, arg.name)); |
| attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); |
| } |
| } |
| if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { |
| return arg.kind == LinalgOperandDefKind::IndexAttr; |
| })) { |
| attrMethods = R"( |
| bool hasDynamicIndexingMaps(); |
| LogicalResult verifyIndexingMapRequiredAttributes(); |
| )"; |
| } |
| attrList = ",\n" + llvm::join(attrDefs, ",\n"); |
| attrBuilder = llvm::formatv( |
| structuredOpBuilderFormat, opConfig.metadata->cppClassName, |
| llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n")); |
| } |
| |
| os << llvm::formatv(structuredOpOdsHeaderFormat, |
| opConfig.metadata->cppClassName, opConfig.metadata->name, |
| interfaceNameList, doc, attrList, attrBuilder, |
| definitionList, attrMethods); |
| |
| return success(); |
| } |
| |
| static LogicalResult |
| generateNamedGenericOpDefns(LinalgOpConfig &opConfig, |
| GenerationContext &genContext) { |
| if (!genContext.shouldGenerateDefns()) |
| return success(); |
| |
| raw_ostream &os = genContext.defns(); |
| StringRef className = opConfig.metadata->cppClassName; |
| |
| // Implementation banner. |
| std::string bannerComment = llvm::formatv("Implementation of {0}", className); |
| os << llvm::formatv(bannerFormat, bannerComment); |
| |
| // Compute the number of scalar and tensor arguments. |
| int64_t numOfArgs = |
| llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { |
| return arg.kind == LinalgOperandDefKind::InputTensor || |
| arg.kind == LinalgOperandDefKind::Scalar || |
| arg.kind == LinalgOperandDefKind::OutputTensor; |
| }); |
| |
| // An operation that accesses only scalars and scalar/rank zero tensors is |
| // rank polymorphic. We implement rank polymorphism by generating different |
| // indexing maps and iterators that match the rank of the first output tensor. |
| // An operation is rank polymorphic if the iteration domain has rank zero. |
| bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty(); |
| |
| // Generate the iterator_types() method. |
| if (!isRankPolymorphic) { |
| std::string iteratorsStr; |
| llvm::raw_string_ostream ss(iteratorsStr); |
| llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss, |
| [&](LinalgIteratorTypeDef it) { |
| switch (it) { |
| case LinalgIteratorTypeDef::parallel: |
| ss << "utils::IteratorType::parallel"; |
| break; |
| case LinalgIteratorTypeDef::reduction: |
| ss << "utils::IteratorType::reduction"; |
| break; |
| } |
| }); |
| ss.flush(); |
| os << llvm::formatv(structuredOpIteratorTypesFormat, className, |
| iteratorsStr); |
| } else { |
| os << llvm::formatv(rankPolyStructuredOpIteratorTypesFormat, className); |
| } |
| |
| // Generating the getIndexingMaps() method. |
| if (auto &staticMaps = |
| opConfig.structuredOp->indexingMaps.staticIndexingMaps) { |
| if (staticMaps->empty()) |
| return emitError(genContext.getLoc()) << "op has no indexing maps"; |
| if (!isRankPolymorphic) { |
| AffineMap firstMap = staticMaps->front().affineMap(); |
| |
| // Symbol bindings. |
| { |
| // For each symbol, generate a declaration for it, either with an |
| // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from |
| // an attribute). |
| // TODO: Possibly lift into a top-level method. |
| static const char structuredOpSymbolBindingsFormat[] = R"FMT( |
| static SmallVector<AffineExpr> getSymbolBindings({0} self) { |
| MLIRContext *context = self.getContext(); |
| SmallVector<AffineExpr> exprs; |
| {1} |
| return exprs; |
| } |
| )FMT"; |
| |
| unsigned symbolCount = firstMap.getNumSymbols(); |
| SmallVector<std::string> symbolBindings; |
| for (unsigned i = 0; i < symbolCount; ++i) { |
| symbolBindings.push_back(llvm::formatv( |
| " exprs.push_back(getAffineSymbolExpr({0}, context));", i)); |
| } |
| |
| // Access an index attribute. Parameters: |
| // {0}: Attribute name |
| // {1}: Symbol position |
| // {2}: Attribute index |
| static const char structuredOpAccessAttrFormat[] = R"FMT( |
| int64_t cst{1} = self.get{0}().getValues<int64_t>()[{2}]; |
| exprs.push_back(getAffineConstantExpr(cst{1}, context)); |
| )FMT"; |
| // Update all symbol bindings mapped to an attribute. |
| for (LinalgOperandDef &arg : opConfig.structuredOp->args) { |
| if (arg.kind != LinalgOperandDefKind::IndexAttr) |
| continue; |
| assert(arg.indexAttrMap); |
| for (auto [idx, result] : |
| llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) { |
| if (auto symbol = dyn_cast<AffineSymbolExpr>(result)) { |
| std::string argName = arg.name; |
| argName[0] = toupper(argName[0]); |
| symbolBindings[symbol.getPosition()] = |
| llvm::formatv(structuredOpAccessAttrFormat, argName, |
| symbol.getPosition(), idx); |
| } |
| } |
| } |
| |
| std::string symbolBindingsStr; |
| llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr); |
| llvm::interleave(symbolBindings, symbolBindingsSs, "\n"); |
| symbolBindingsSs.flush(); |
| |
| os << llvm::formatv(structuredOpSymbolBindingsFormat, className, |
| symbolBindingsStr); |
| } |
| |
| // Indexing maps. |
| { |
| unsigned dimCount = firstMap.getNumDims(); |
| |
| // Generate a comma-separated list of dim identifiers to be passed to |
| // bindDims, ensuring that AffineExpr identifiers are bound in the right
| // order to the proper AffineDimExpr. |
| // This results in vars in scope like: d0, d1, d2... |
| SmallVector<unsigned> dimIndices; |
| for (unsigned i = 0; i < dimCount; ++i) |
| dimIndices.push_back(i); |
| std::string dimIdentsStr; |
| llvm::raw_string_ostream dimIdentsSs(dimIdentsStr); |
| llvm::interleaveComma(dimIndices, dimIdentsSs, |
| [&](unsigned i) { dimIdentsSs << "d" << i; }); |
| dimIdentsSs.flush(); |
| |
| // Statements to add and simplify each affine map. |
| SmallVector<std::string> stmts; |
| for (auto &indexingMap : *staticMaps) { |
| // TODO: Assert that dim and symbol count match the first. |
| stmts.push_back( |
| llvm::formatv("maps.push_back({0});", |
| generateCppExpression(indexingMap, "context"))); |
| stmts.push_back(llvm::formatv( |
| "maps.back() = " |
| "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, " |
| "symbolBindings, {0}, 0));", |
| dimCount)); |
| } |
| |
| // TODO: This needs to be memoized and/or converted to non-parser based |
| // C++ codegen prior to real use. |
| os << llvm::formatv(structuredOpIndexingMapsFormat, className, |
| dimIdentsStr, interleaveToString(stmts, "\n ")); |
| } |
| } else { |
| os << llvm::formatv(rankPolyStructuredOpIndexingMapsFormat, className); |
| } |
| } else { |
| return emitError(genContext.getLoc()) |
| << "generating code for non static indexing maps not currently " |
| "supported"; |
| } |
| |
| // getNumRegionArgs() |
| { |
| // Generates a getNumRegionArgs() method. Parameters: |
| // {0}: Class name |
| // {1}: Number of region args |
| static const char structuredOpGetNumRegionArgsFormat[] = R"FMT( |
| unsigned {0}::getNumRegionArgs() {{ return {1}; } |
| )FMT"; |
| os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className, |
| numOfArgs); |
| } |
| |
| // getLibraryCallName() |
| { |
| // Generates a getLibraryCallName method. Parameters: |
| // {0}: Class name |
| static const char structuredOpGetLibraryCallFormat[] = R"FMT( |
| std::string {0}::getLibraryCallName() {{ |
| return generateLibraryCallName(getOperation()); |
| } |
| )FMT"; |
| os << llvm::formatv(structuredOpGetLibraryCallFormat, className); |
| } |
| |
| // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes() |
| if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { |
| return arg.kind == LinalgOperandDefKind::IndexAttr; |
| })) { |
| std::vector<std::string> attrVerifications; |
| for (LinalgOperandDef &arg : opConfig.structuredOp->args) { |
| if (arg.kind != LinalgOperandDefKind::IndexAttr) |
| continue; |
| assert(arg.indexAttrMap); |
| // Verify index attribute. Parameters:
| // {0}: Attribute name |
| // {1}: Attribute size |
| static const char attrFmt[] = R"FMT( |
| if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{ |
| if (!attr.getType().getElementType().isInteger(64)) |
| return op->emitError("incorrect element type for index attribute '{0}'"); |
| if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} }) |
| return op->emitError("incorrect shape for index attribute '{0}'"); |
| } |
| )FMT"; |
| attrVerifications.push_back(llvm::formatv( |
| attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults())); |
| } |
| |
| // Generates the verifyIndexingMapRequiredAttributes method. Parameters: |
| // {0}: Class name |
| // {1}: Attribute verification |
| static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT( |
| bool {0}::hasDynamicIndexingMaps() {{ return true; } |
| LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{ |
| Operation *op = getOperation(); |
| {1} |
| return success(); |
| } |
| )FMT"; |
| os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes, |
| className, llvm::join(attrVerifications, "\n")); |
| } |
| |
| // regionBuilder() |
| { |
| // Generates a regionBuilder method. Parameters. |
| // {0}: Class name |
| // {1}: Number of args |
| // {2}: Attributes |
| // {3}: Statements |
| static const char structuredOpRegionBuilderFormat[] = R"FMT( |
| void {0}::regionBuilder(ImplicitLocOpBuilder &b, |
| Block &block, ArrayRef<NamedAttribute> attrs) {{ |
| assert({1} > 0 && block.getNumArguments() == {1} && |
| "{0} regionBuilder expects {1} (>=0) args"); |
| RegionBuilderHelper helper(b, block); |
| SmallVector<Value> yields; |
| {2} |
| {3} |
| helper.yieldOutputs(yields); |
| } |
| )FMT"; |
| auto &args = opConfig.structuredOp->args; |
| auto &assignments = opConfig.structuredOp->assignments; |
| size_t generatedAssignmentCount = 0; |
| int localCounter = 0; |
| SmallVector<std::string> attrs; |
| SmallVector<std::string> stmts; |
| for (LinalgOperandDef &arg : args) { |
| if (!isFunctionAttribute(arg.kind)) |
| continue; |
| // Obtain the type function attribute values. Parameters. |
| // {0}: enum name |
| // {1}: attribute name |
| // {2}: default type function name |
| static const char attrDef[] = R"FMT( |
| {0} {1}Val = {0}::{2}; |
| auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{ |
| return attr.getName() == "{1}"; }); |
| if ({1}Iter != attrs.end()) {{ |
| if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue())) |
| {1}Val = attr.getValue(); |
| } |
| )FMT"; |
| std::string enumName = convertOperandKindToEnumName(arg.kind); |
| attrs.push_back( |
| llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn)); |
| } |
| for (LinalgOperandDef &arg : args) { |
| if (arg.kind != LinalgOperandDefKind::OutputTensor) |
| continue; |
| |
| // Find the assignment that correlates with the argument. |
| ScalarAssign *assignment = findAssignment(arg.name, assignments); |
| if (!assignment) |
| return emitError(genContext.getLoc()) |
| << "no assignment found for output argument " << arg.name; |
| ++generatedAssignmentCount; |
| |
| // Recursively generate the expression. |
| std::function<std::optional<std::string>(ScalarExpression &)> |
| generateExpression = |
| [&](ScalarExpression &expression) -> std::optional<std::string> { |
| if (expression.arg) { |
| // Argument reference. |
| std::optional<int> argIndex = |
| findTensorDefArgIndex(*expression.arg, args); |
| if (!argIndex) { |
| emitError(genContext.getLoc()) |
| << "scalar argument not defined on the op: " << *expression.arg; |
| return std::nullopt; |
| } |
| return std::string( |
| llvm::formatv("block.getArgument({0})", *argIndex)); |
| } |
| if (expression.constant) { |
| std::string cppIdent = llvm::formatv("value{0}", ++localCounter); |
| stmts.push_back( |
| llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT", |
| cppIdent, expression.constant)); |
| return cppIdent; |
| } |
| if (expression.index) { |
| // Access an iteration index. |
| std::string cppIdent = llvm::formatv("value{0}", ++localCounter); |
| stmts.push_back(llvm::formatv("Value {0} = helper.index({1});", |
| cppIdent, *expression.index)); |
| return cppIdent; |
| } |
| if (expression.scalarFn) { |
| std::string enumName = |
| convertFunctionKindToEnumName(expression.scalarFn->kind); |
| |
| // Get the function or attribute name. |
| assert(expression.scalarFn->fnName || expression.scalarFn->attrName); |
| std::string funcType; |
| if (expression.scalarFn->fnName) { |
| funcType = llvm::formatv("{0}::{1}", enumName, |
| *expression.scalarFn->fnName); |
| } |
| if (expression.scalarFn->attrName) { |
| if (llvm::none_of(args, [&](LinalgOperandDef &arg) { |
| return isFunctionAttribute(arg.kind) && |
| arg.name == *expression.scalarFn->attrName; |
| })) { |
| emitError(genContext.getLoc()) << "missing function attribute " |
| << *expression.scalarFn->attrName; |
| } |
| funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName); |
| } |
| assert(!funcType.empty()); |
| |
| // Add the optional type parameter to the operands. |
| SmallVector<std::string> operandCppValues; |
| if (expression.scalarFn->kind == ScalarFnKind::Type) { |
| assert(expression.scalarFn->typeVar.has_value()); |
| std::optional<std::string> typeCppValue = |
| findTypeValue(*expression.scalarFn->typeVar, args); |
| if (!typeCppValue) { |
| emitError(genContext.getLoc()) |
| << "type variable " << *expression.scalarFn->typeVar |
| << ", used in a type conversion, must map to a predefined or " |
| << "an argument type but it does not"; |
| return std::nullopt; |
| } |
| operandCppValues.push_back(*typeCppValue); |
| } |
| |
| // Collect the scalar operands. |
| for (ScalarExpression &operand : expression.scalarFn->operands) { |
| auto operandCppValue = generateExpression(operand); |
| if (!operandCppValue) |
| return std::nullopt; |
| operandCppValues.push_back(*operandCppValue); |
| } |
| |
| // Call the function builder. |
| std::string cppIdent = llvm::formatv("value{0}", ++localCounter); |
| stmts.push_back(llvm::formatv( |
| "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName, |
| funcType, interleaveToString(operandCppValues, ", "))); |
| return cppIdent; |
| } |
| emitError(genContext.getLoc()) << "unknown ScalarExpression type"; |
| return std::nullopt; |
| }; |
| std::optional<std::string> cppValue = |
| generateExpression(assignment->value); |
| if (!cppValue) |
| return failure(); |
| stmts.push_back(llvm::formatv("yields.push_back({0});", *cppValue)); |
| } |
| |
| if (generatedAssignmentCount != assignments.size()) |
| return emitError(genContext.getLoc()) |
| << "mismatched number of assignments vs output arguments"; |
| |
| os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs, |
| interleaveToString(attrs, "\n "), |
| interleaveToString(stmts, "\n ")); |
| } |
| |
| // Parser and printer. |
| os << llvm::formatv(structuredOpParserFormat, className); |
| |
| // Canonicalizers and folders. |
| os << llvm::formatv(structuredOpFoldersFormat, className); |
| |
| return success(); |
| } |
| |
| /// Generates the output for a single op configuration.
| ///
| /// Only configurations carrying a `structuredOp` are handled: the ODS
| /// generation runs first and the C++ definitions are generated only if it
| /// succeeded. Any other configuration is reported as an error at the current
| /// location of `genContext`.
| static LogicalResult generateOp(LinalgOpConfig &opConfig,
|                                 GenerationContext &genContext) {
|   // Reject anything that is not a structured op.
|   if (!opConfig.structuredOp)
|     return emitError(genContext.getLoc()) << "unsupported operation type";
|
|   // Stop at the first generation step that fails.
|   if (failed(generateNamedGenericOpOds(opConfig, genContext)))
|     return failure();
|   if (failed(generateNamedGenericOpDefns(opConfig, genContext)))
|     return failure();
|   return success();
| }
| |
| //===----------------------------------------------------------------------===// |
| // Command line options and main |
| //===----------------------------------------------------------------------===// |
| |
| /// Positional argument naming the YAML input; defaults to "-".
| static llvm::cl::opt<std::string>
|     inputFilename(llvm::cl::Positional, llvm::cl::init("-"),
|                   llvm::cl::value_desc("YAML filename"),
|                   llvm::cl::desc("<input file>"));
|
| /// `-o-ods-decl`: file receiving the generated ODS; empty by default.
| static llvm::cl::opt<std::string>
|     outputOdsDeclFilename("o-ods-decl", llvm::cl::init(""),
|                           llvm::cl::value_desc("filename"),
|                           llvm::cl::desc("ODS output filename"));
|
| /// `-o-impl`: file receiving the generated C++ implementation; empty by
| /// default.
| static llvm::cl::opt<std::string>
|     outputCppImplFilename("o-impl", llvm::cl::init(""),
|                           llvm::cl::value_desc("filename"),
|                           llvm::cl::desc("C++ implementation file name"));
| |
| /// Tool entry point.
| ///
| /// Parses the command line, reads the op configurations from the YAML input,
| /// opens whichever output files were requested and runs `generateOp` on every
| /// configuration. Returns 0 on success and 1 on any failure.
| int main(int argc, char **argv) {
|   llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML");
|
|   // Receives the message of whichever file helper fails below.
|   std::string errorMessage;
|
|   // Load the YAML input.
|   std::unique_ptr<llvm::MemoryBuffer> inputBuffer =
|       mlir::openInputFile(inputFilename, &errorMessage);
|   if (!inputBuffer) {
|     llvm::errs() << errorMessage << "\n";
|     return 1;
|   }
|
|   MLIRContext mlirContext;
|   LinalgYAMLContext yamlContext{&mlirContext};
|
|   // Deserialize all op configurations from the input buffer.
|   std::vector<LinalgOpConfig> opConfigs;
|   Input yamlInput(inputBuffer->getBuffer(), &yamlContext);
|   yamlInput >> opConfigs;
|   if (yamlInput.error())
|     return 1;
|
|   // Opens `filename` into `output` unless the name is empty. Returns false
|   // (after printing the error message) only when opening was attempted and
|   // failed.
|   auto openRequestedOutput =
|       [&errorMessage](const std::string &filename,
|                       std::unique_ptr<llvm::ToolOutputFile> &output) {
|         if (filename.empty())
|           return true;
|         output = openOutputFile(filename, &errorMessage);
|         if (output)
|           return true;
|         llvm::errs() << errorMessage << "\n";
|         return false;
|       };
|
|   // The ODS output is opened first, then the C++ implementation output.
|   std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl;
|   if (!openRequestedOutput(outputOdsDeclFilename, outputOdsDecl))
|     return 1;
|
|   std::unique_ptr<llvm::ToolOutputFile> outputCppImpl;
|   if (!openRequestedOutput(outputCppImplFilename, outputCppImpl))
|     return 1;
|
|   // At least one of the two outputs has to be requested.
|   if (!outputOdsDecl && !outputCppImpl) {
|     llvm::errs() << "error: No output files specified\n";
|     return 1;
|   }
|
|   // An output that was not requested is handed over as a null stream.
|   auto *odsStream = outputOdsDecl ? &outputOdsDecl->os() : nullptr;
|   auto *implStream = outputCppImpl ? &outputCppImpl->os() : nullptr;
|   GenerationContext genContext(&mlirContext, odsStream, implStream);
|
|   // Generate each op, aborting on the first failure.
|   for (auto &opConfig : opConfigs) {
|     if (!opConfig.metadata) {
|       emitError(genContext.getLoc())
|           << "missing operation metadata on subsequent op";
|       return 1;
|     }
|
|     // Use the op's C++ class name as the location for its diagnostics.
|     StringAttr className =
|         StringAttr::get(&mlirContext, opConfig.metadata->cppClassName);
|     genContext.setLoc(NameLoc::get(className));
|
|     if (failed(generateOp(opConfig, genContext)))
|       return 1;
|   }
|
|   // Everything was generated: mark the opened outputs as ones to keep.
|   if (outputOdsDecl)
|     outputOdsDecl->keep();
|   if (outputCppImpl)
|     outputCppImpl->keep();
|
|   return 0;
| }