| //===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
// This file implements an ODS (and C++) generator from a YAML form
// derived from the mathematical expression of linalg named ops. Typically a
// math-oriented DSL exports the essential representation to this form;
// keeping the source of truth (SOT) at the math level (versus recreating it
// in MLIR) is deemed to have systemic value.
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/Parser.h" |
| #include "mlir/Support/FileUtilities.h" |
| #include "mlir/Support/LLVM.h" |
| #include "llvm/ADT/Optional.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/ToolOutputFile.h" |
| #include "llvm/Support/YAMLTraits.h" |
| |
| using namespace mlir; |
| |
| using llvm::yaml::Input; |
| using llvm::yaml::IO; |
| using llvm::yaml::MappingTraits; |
| using llvm::yaml::ScalarEnumerationTraits; |
| using llvm::yaml::ScalarTraits; |
| |
| #define DEBUG_TYPE "linalg-ods-gen" |
| |
| //===----------------------------------------------------------------------===// |
| // Mapping structs (correspond to data types in the YAML description). |
| // TODO: Since this is a schema/part of the contract, it should be moved to |
| // a real header. |
| //===----------------------------------------------------------------------===// |
| |
namespace {

/// Context object handed to the YAML reader as its opaque `void *` context.
/// `ScalarTraits<SerializedAffineMap>::input` casts it back to parse affine
/// map attributes in `mlirContext`.
struct LinalgYAMLContext {
  MLIRContext *mlirContext;
};

/// Contents of the `metadata` section of an op config (keys `name`,
/// `cpp_class_name`, optional `doc`, optional `implements`).
struct LinalgOpMetadata {
  std::string name;
  std::string cppClassName;
  Optional<std::string> doc;
  // Names of interfaces the op declares; emitted as the extra interface list.
  SmallVector<std::string> implements;
};

/// An AffineMapAttr that is read from / written to YAML as its printed
/// textual form (see `ScalarTraits<SerializedAffineMap>`).
struct SerializedAffineMap {
  AffineMapAttr affineMapAttr;

  AffineMap affineMap() { return affineMapAttr.getValue(); }
};

/// How a named argument is used; spelled `InputOperand`, `OutputOperand` and
/// `IndexAttribute` in YAML.
enum class LinalgOperandDefUsage { input, output, attribute };

/// One named argument of a structured op (an entry of `args`).
struct LinalgOperandDef {
  std::string name;
  LinalgOperandDefUsage usage;
  std::string typeVar;
  // Only present for tensor arguments.
  Optional<SerializedAffineMap> shapeMap;
  // Only present for index attribute arguments.
  Optional<SerializedAffineMap> attributeMap;
};

/// Iterator type of one loop dimension; spelled `parallel` / `reduction`.
enum class LinalgIteratorTypeDef {
  parallel,
  reduction,
};

/// How the op's indexing maps are provided. Only the static form exists.
struct LinalgIndexingMapsConfig {
  Optional<SmallVector<SerializedAffineMap>> staticIndexingMaps;
};

struct ScalarExpression;

/// Application of the named function `fnName` to `operands`.
struct ScalarApply {
  std::string fnName;
  // NOTE: Must be pure heap allocated container (not SmallVector)
  // due to recursive data type.
  std::vector<ScalarExpression> operands;
};

/// Cast of a single operand to the type bound to the symbolic `typeVar`.
struct ScalarSymbolicCast {
  std::string typeVar;
  // NOTE: This must be of arity 1, but to break the self-referential cycle,
  // we use a heap allocated vector.
  std::vector<ScalarExpression> operands;
  // Not default-initialized; filled from the required `is_unsigned_cast` key.
  bool isUnsignedCast;
};

/// Right-hand side expression of an assignment. The YAML description expects
/// exactly one of these members to be set (this is not validated here).
struct ScalarExpression {
  Optional<std::string> arg;
  Optional<std::string> constant;
  Optional<int64_t> index;
  Optional<ScalarApply> apply;
  Optional<ScalarSymbolicCast> symbolicCast;
};

/// Assignment of `value` to the output argument named `arg`.
struct ScalarAssign {
  std::string arg;
  ScalarExpression value;
};

/// The `structured_op` section: arguments, indexing maps, iterator types and
/// the scalar assignments forming the op body.
struct LinalgStructuredOpConfig {
  SmallVector<LinalgOperandDef> args;
  LinalgIndexingMapsConfig indexingMaps;
  SmallVector<LinalgIteratorTypeDef> iteratorTypes;
  std::vector<ScalarAssign> assignments;
};

/// One YAML document: the op metadata plus its concrete op definition. Both
/// sections are optional in the YAML mapping.
struct LinalgOpConfig {
  Optional<LinalgOpMetadata> metadata;
  Optional<LinalgStructuredOpConfig> structuredOp;
};

} // namespace
| |
| //===----------------------------------------------------------------------===// |
| // Mapping traits. |
| //===----------------------------------------------------------------------===// |
| |
// Teach YAML I/O that vectors of these element types map to YAML sequences.
// A vector of `LinalgOpConfig` maps to a multi-document stream, with one op
// config per YAML document.
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef)
LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap)
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef)
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign)
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression)
LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig)
| |
| namespace llvm { |
| namespace yaml { |
| |
| /// Top-level type containing op metadata and one of a concrete op type. |
| /// Currently, the only defined op type is `structured_op` (maps to |
| /// `LinalgStructuredOpConfig`). |
| template <> struct MappingTraits<LinalgOpConfig> { |
| static void mapping(IO &io, LinalgOpConfig &info) { |
| io.mapOptional("metadata", info.metadata); |
| io.mapOptional("structured_op", info.structuredOp); |
| } |
| }; |
| |
| /// A structured op models (at most) a single contraction by modeling |
| /// - A list of named arguments (`LinalgOperandDef`), which can be inputs, |
| /// outputs, or index attributes. |
| /// - List of indexing maps (see `LinalgIndexingMaps`). |
| /// - Iterator types (see `LinalgIteratorTypeDef`). |
| /// - List of scalar level assignment (see `ScalarAssign`). |
| template <> struct MappingTraits<LinalgStructuredOpConfig> { |
| static void mapping(IO &io, LinalgStructuredOpConfig &info) { |
| io.mapRequired("args", info.args); |
| io.mapRequired("indexing_maps", info.indexingMaps); |
| io.mapRequired("iterator_types", info.iteratorTypes); |
| io.mapRequired("assignments", info.assignments); |
| } |
| }; |
| |
| /// Maps a named tensor, scalar or attribute argument to an operation, |
| /// consisting of: |
| /// - `name`: Must be unique within the operation. |
| /// - `usage`: How the argument is used (input, output, attribute, etc). |
| /// - `type_var`: The symbolic type variable that binds to the element or self |
| /// type of the tensor or scalar argument, respectively. |
| /// - `shape_map`: An optional AffineMap from all op symbols to the shape of |
| /// the argument. Only tensor arguments have a `shape_map`. Each shape must |
| /// be normalized over the same list of symbols and have no dimension |
| /// inputs. |
| /// - `attribute_map`: An optional AffineMap from all op symbols to the |
| /// attribute symbols. During op creation these symbols are replaced by the |
| /// corresponding `name` attribute values. Only attribute arguments have |
| /// an `attribute_map`. |
| template <> struct MappingTraits<LinalgOperandDef> { |
| static void mapping(IO &io, LinalgOperandDef &info) { |
| io.mapRequired("name", info.name); |
| io.mapRequired("usage", info.usage); |
| io.mapRequired("type_var", info.typeVar); |
| io.mapOptional("shape_map", info.shapeMap); |
| io.mapOptional("attribute_map", info.attributeMap); |
| } |
| }; |
| |
| /// Usage enum for a named argument. |
| template <> struct ScalarEnumerationTraits<LinalgOperandDefUsage> { |
| static void enumeration(IO &io, LinalgOperandDefUsage &value) { |
| io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input); |
| io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output); |
| io.enumCase(value, "IndexAttribute", LinalgOperandDefUsage::attribute); |
| } |
| }; |
| |
| /// Iterator type enum. |
| template <> struct ScalarEnumerationTraits<LinalgIteratorTypeDef> { |
| static void enumeration(IO &io, LinalgIteratorTypeDef &value) { |
| io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel); |
| io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction); |
| } |
| }; |
| |
| /// Metadata about the op (name, C++ name, and documentation). |
| template <> struct MappingTraits<LinalgOpMetadata> { |
| static void mapping(IO &io, LinalgOpMetadata &info) { |
| io.mapRequired("name", info.name); |
| io.mapRequired("cpp_class_name", info.cppClassName); |
| io.mapOptional("doc", info.doc); |
| io.mapOptional("implements", info.implements); |
| } |
| }; |
| |
| /// How the ops indexing maps are produced. Must be one of: |
| /// - static_indexing_maps: A static list of AffineMaps, possibly with |
| /// some symbols that bind to attributes of the op. Each indexing map must |
| /// be normalized over the same list of dimensions, and its symbols must |
| /// match the symbols for argument shapes. |
| template <> struct MappingTraits<LinalgIndexingMapsConfig> { |
| static void mapping(IO &io, LinalgIndexingMapsConfig &info) { |
| io.mapOptional("static_indexing_maps", info.staticIndexingMaps); |
| } |
| }; |
| |
| /// Models an assignment to a named output. |
| /// - The `arg` name must match a named output. |
| /// - The `value` is a scalar expression for computing the value to |
| /// assign (see `ScalarExpression`). |
| template <> struct MappingTraits<ScalarAssign> { |
| static void mapping(IO &io, ScalarAssign &info) { |
| io.mapRequired("arg", info.arg); |
| io.mapRequired("value", info.value); |
| } |
| }; |
| |
| /// A scalar expression (RHS of an assignment). Must be one of: |
| /// - `scalar_arg`: Name of an argument to the op. |
| /// - `scalar_apply`: Result of evaluating a named function (see |
| /// `ScalarApply`). |
| /// - `symbolic_cast`: Cast to a symbolic TypeVar bound elsewhere. |
| template <> struct MappingTraits<ScalarExpression> { |
| static void mapping(IO &io, ScalarExpression &info) { |
| io.mapOptional("scalar_arg", info.arg); |
| io.mapOptional("scalar_const", info.constant); |
| io.mapOptional("scalar_index", info.index); |
| io.mapOptional("scalar_apply", info.apply); |
| io.mapOptional("symbolic_cast", info.symbolicCast); |
| } |
| }; |
| |
| /// A scalar expression that evaluates a named function. |
| /// Functions are generally "math" level and type polymorphic. Builtin |
| /// functions include: |
| /// - `add(lhs, rhs)` |
| /// - `mul(lhs, rhs)` |
| template <> struct MappingTraits<ScalarApply> { |
| static void mapping(IO &io, ScalarApply &info) { |
| io.mapRequired("fn_name", info.fnName); |
| io.mapRequired("operands", info.operands); |
| } |
| }; |
| |
| template <> struct MappingTraits<ScalarSymbolicCast> { |
| static void mapping(IO &io, ScalarSymbolicCast &info) { |
| io.mapRequired("type_var", info.typeVar); |
| io.mapRequired("operands", info.operands); |
| io.mapRequired("is_unsigned_cast", info.isUnsignedCast); |
| } |
| }; |
| |
| /// Helper mapping which accesses an AffineMapAttr as a serialized string of |
| /// the same. |
| template <> struct ScalarTraits<SerializedAffineMap> { |
| static void output(const SerializedAffineMap &value, void *rawYamlContext, |
| raw_ostream &out) { |
| assert(value.affineMapAttr); |
| value.affineMapAttr.print(out); |
| } |
| static StringRef input(StringRef scalar, void *rawYamlContext, |
| SerializedAffineMap &value) { |
| assert(rawYamlContext); |
| auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext); |
| if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext) |
| .dyn_cast_or_null<AffineMapAttr>()) |
| value.affineMapAttr = attr; |
| else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>()) |
| return "could not parse as an affine map attribute"; |
| return StringRef(); |
| } |
| static QuotingType mustQuote(StringRef) { return QuotingType::None; } |
| }; |
| |
| } // namespace yaml |
| } // namespace llvm |
| |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Generation utilities |
| //===----------------------------------------------------------------------===// |
| |
| class GenerationContext { |
| public: |
| GenerationContext(MLIRContext *context, raw_ostream *odsOut, |
| raw_ostream *defnOut) |
| : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut), |
| defnOut(defnOut) {} |
| |
| MLIRContext *getContext() { return context; } |
| |
| void setLoc(Location loc) { this->loc = loc; } |
| Location getLoc() { return loc; } |
| |
| bool shouldGenerateOds() { return odsOut; } |
| bool shouldGenerateDefns() { return defnOut; } |
| |
| raw_ostream &odss() { |
| assert(odsOut && "ODS stream not defined"); |
| return *odsOut; |
| } |
| |
| raw_ostream &defns() { |
| assert(defnOut && "Definition stream not defined"); |
| return *defnOut; |
| } |
| |
| private: |
| MLIRContext *context; |
| Location loc; |
| raw_ostream *odsOut; |
| raw_ostream *defnOut; |
| }; |
| |
| } // namespace |
| |
| static std::string generateCppExpression(SerializedAffineMap self, |
| StringRef contextName) { |
| std::string printedStr; |
| llvm::raw_string_ostream printedSs(printedStr); |
| self.affineMapAttr.print(printedSs); |
| printedSs.flush(); |
| |
| static const char exprFormat[] = |
| R"FMT(mlir::parseAttribute("{0}", {1}).cast<AffineMapAttr>().getValue())FMT"; |
| return llvm::formatv(exprFormat, printedStr, contextName); |
| } |
| |
| template <typename Container> |
| static std::string interleaveToString(Container &container, |
| StringRef separator) { |
| std::string result; |
| llvm::raw_string_ostream ss(result); |
| llvm::interleave(container, ss, separator); |
| ss.flush(); |
| return result; |
| } |
| |
| static Optional<int> |
| findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) { |
| for (auto it : llvm::enumerate(args)) { |
| if (it.value().name == name) |
| return it.index(); |
| } |
| return None; |
| } |
| |
| // Try to map the TypeVar to a predefined or an argument type. |
| static Optional<std::string> |
| findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) { |
| // Handle all predefined types. |
| if (typeVar == "I32") |
| return std::string("helper.getIntegerType(32)"); |
| if (typeVar == "I64") |
| return std::string("helper.getIntegerType(64)"); |
| if (typeVar == "F32") |
| return std::string("helper.getFloat32Type()"); |
| if (typeVar == "F64") |
| return std::string("helper.getFloat64Type()"); |
| |
| // Search all argument types. |
| for (auto it : llvm::enumerate(args)) { |
| if (it.value().typeVar == typeVar) |
| return llvm::formatv("block.getArgument({0}).getType()", it.index()) |
| .str(); |
| } |
| |
| return None; |
| } |
| |
| static ScalarAssign *findAssignment(StringRef name, |
| std::vector<ScalarAssign> &assignments) { |
| for (auto &assign : assignments) { |
| if (assign.arg == name) |
| return &assign; |
| } |
| return nullptr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Templates |
| //===----------------------------------------------------------------------===// |
| |
// A single line banner format, emitted ahead of each op's generated C++
// definitions. It is a llvm::formatv format string. Parameters:
//   {0}: Single line comment
static const char bannerFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)FMT";
| |
| //===----------------------------------------------------------------------===// |
| // Named generic op generation. |
| // These ops map at most a single contraction that complies with the limitations |
| // of a linalg.generic. |
| //===----------------------------------------------------------------------===// |
| |
// Template for Linalg named ops' ODS definitions. This is a llvm::formatv
// format string, so literal braces are written as `{{`. Parameters:
//   {0}: ODS/C++ op name
//   {1}: assembly op mnemonic
//   {2}: op interface list
//   {3}: documentation (summary + description)
//   {4}: op attribute list (empty, or starting with ",\n" so it can follow
//        the `$outputs` argument directly)
//   {5}: builder methods taking standalone attribute parameters (empty, or
//        starting with "," so it can follow the last builder)
//   {6}: additional methods for attributes used by indexing maps
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
//===----------------------------------------------------------------------===//

def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
/*extraInterfaces=*/[{2}])> {
{3}
let arguments = (ins
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs{4}
);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
let regions = (region AnyRegion:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
$_state.addOperands(inputs);
$_state.addOperands(outputs);
SmallVector<Type> resultTensorTypes;
copy_if(outputs.getTypes(),
std::back_inserter(resultTensorTypes),
[](Type type) {{ return type.isa<RankedTensorType>(); });
$_state.addTypes(resultTensorTypes);
$_state.addAttribute(
"operand_segment_sizes",
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
$_state.addAttributes(attributes);
createAndFillStructuredOpRegion<{0}>(
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs));
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
$_state.addOperands(inputs);
$_state.addOperands(outputs);
$_state.addTypes(resultTensorTypes);
$_state.addAttributes(attributes);
$_state.addAttribute(
"operand_segment_sizes",
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
createAndFillStructuredOpRegion<{0}>(
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs));
}]>,
OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
$_state.addOperands(operands);
$_state.addAttributes(attributes);
$_state.addTypes(resultTensorTypes);
(void)$_state.addRegion();
}]>
{5}
];
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
let parser = [{{
return ::parseNamedStructuredOp<{0}>(parser, result);
}];
let hasFolder = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
static std::function<void(ImplicitLocOpBuilder &b, Block &)>
getRegionBuilder() {{
return regionBuilder;
}

// Generic methods.
static unsigned getNumRegionArgs();
std::string getLibraryCallName();
{6}
}];
}
)FMT";
| |
// Builder method taking attribute parameters. This is a llvm::formatv format
// string. It starts with ", " so it can be appended after the last builder
// of `structuredOpOdsHeaderFormat` (its {5} placeholder). Parameters:
//   {0}: Class name
//   {1}: Comma interleaved attribute parameters
//   {2}: Attribute initialization statements
static const char structuredOpBuilderFormat[] = R"FMT(
, OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, {1},
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{
$_state.addOperands(inputs);
$_state.addOperands(outputs);
$_state.addTypes(resultTensorTypes);
$_state.addAttribute(
"operand_segment_sizes",
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
createAndFillStructuredOpRegion<{0}>(
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs));
{2}
$_state.addAttributes(attributes);
}]>
)FMT";
| |
// The iterator_types() method implementation. This is a llvm::formatv format
// string; `{{ {1} }` expands to a braced initializer list of the names.
// Parameters:
//   {0}: Class name
//   {1}: Comma interleaved iterator type names.
static const char structuredOpIteratorTypesFormat[] =
    R"FMT(
ArrayAttr {0}::iterator_types() {
return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef>{{ {1} });
}
)FMT";
| |
// Implementations of fold and getEffects for a named structured op. This is a
// llvm::formatv format string, so literal braces are written as `{{`.
// Declared `static` like the other templates in this file. A namespace-scope
// `const` already has internal linkage; this just makes it explicit.
// Parameters:
//   {0}: Class name
static const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
return foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
SmallVector<Value> inputBuffers = getInputBufferOperands();
SmallVector<Value> outputBuffers = getOutputBufferOperands();
getGenericEffectsImpl(effects,
getOperation()->getResults(), inputBuffers, outputBuffers);
}
)FMT";
| |
| static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig, |
| GenerationContext &genContext) { |
| if (!genContext.shouldGenerateOds()) |
| return success(); |
| |
| raw_ostream &os = genContext.odss(); |
| |
| std::string interfaceNameList; |
| std::string attrList; |
| std::string attrMethods; |
| std::string attrBuilder; |
| |
| std::string doc; |
| if (opConfig.metadata->doc) { |
| static const char structuredOpDocFmt[] = R"FMT( |
| let summary = [{ {0} }]; |
| let description = [{ |
| {1} |
| }]; |
| )FMT"; |
| StringRef summary, description; |
| std::tie(summary, description) = |
| StringRef(*opConfig.metadata->doc).trim().split('\n'); |
| doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim()); |
| } |
| |
| interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); |
| |
| // Assemble the attribute specific logic required for the op definition. |
| if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { |
| return arg.usage == LinalgOperandDefUsage::attribute; |
| })) { |
| SmallVector<std::string> attrDefs; |
| SmallVector<std::string> attrParams; |
| SmallVector<std::string> attrStmts; |
| for (LinalgOperandDef &arg : opConfig.structuredOp->args) { |
| if (arg.usage != LinalgOperandDefUsage::attribute) |
| continue; |
| assert(arg.attributeMap.hasValue() && arg.typeVar == "I64"); |
| static const char defFmt[] = "RankedI64ElementsAttr<[{0}]>:${1}"; |
| static const char paramFmt[] = "\"Attribute\":${0}"; |
| static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});"; |
| attrDefs.push_back(llvm::formatv( |
| defFmt, arg.attributeMap->affineMap().getNumResults(), arg.name)); |
| attrParams.push_back(llvm::formatv(paramFmt, arg.name)); |
| attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); |
| } |
| attrList = ",\n" + llvm::join(attrDefs, ",\n"); |
| attrMethods = R"( |
| bool hasDynamicIndexingMaps(); |
| LogicalResult verifyIndexingMapRequiredAttributes(); |
| )"; |
| attrBuilder = llvm::formatv( |
| structuredOpBuilderFormat, opConfig.metadata->cppClassName, |
| llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n")); |
| } |
| |
| os << llvm::formatv(structuredOpOdsHeaderFormat, |
| opConfig.metadata->cppClassName, opConfig.metadata->name, |
| interfaceNameList, doc, attrList, attrBuilder, |
| attrMethods); |
| |
| return success(); |
| } |
| |
| static LogicalResult |
| generateNamedGenericOpDefns(LinalgOpConfig &opConfig, |
| GenerationContext &genContext) { |
| if (!genContext.shouldGenerateDefns()) |
| return success(); |
| |
| raw_ostream &os = genContext.defns(); |
| StringRef className = opConfig.metadata->cppClassName; |
| |
| // Implementation banner. |
| std::string bannerComment = llvm::formatv("Implementation of {0}", className); |
| os << llvm::formatv(bannerFormat, bannerComment); |
| |
| // Compute the number of scalar and tensor arguments. |
| int64_t numOfArgs = |
| llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { |
| return arg.usage != LinalgOperandDefUsage::attribute; |
| }); |
| |
| // Reference iterators. |
| { |
| std::string iteratorsStr; |
| llvm::raw_string_ostream ss(iteratorsStr); |
| llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss, |
| [&](LinalgIteratorTypeDef it) { |
| switch (it) { |
| case LinalgIteratorTypeDef::parallel: |
| ss << "getParallelIteratorTypeName()"; |
| break; |
| case LinalgIteratorTypeDef::reduction: |
| ss << "getReductionIteratorTypeName()"; |
| break; |
| } |
| }); |
| ss.flush(); |
| os << llvm::formatv(structuredOpIteratorTypesFormat, className, |
| iteratorsStr); |
| } |
| |
| // Static indexing maps. |
| if (auto &staticMaps = |
| opConfig.structuredOp->indexingMaps.staticIndexingMaps) { |
| if (staticMaps->empty()) |
| return emitError(genContext.getLoc()) << "op has no indexing maps"; |
| AffineMap firstMap = staticMaps->front().affineMap(); |
| |
| // Symbol bindings. |
| { |
| // For each symbol, generate a declaration for it, either with an |
| // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from |
| // an attribute). |
| // TODO: Possibly lift into a top-level method. |
| static const char structuredOpSymbolBindingsFormat[] = R"FMT( |
| static SmallVector<AffineExpr> getSymbolBindings({0} self) { |
| MLIRContext *context = self.getContext(); |
| SmallVector<AffineExpr> exprs; |
| {1} |
| return exprs; |
| } |
| )FMT"; |
| |
| unsigned symbolCount = firstMap.getNumSymbols(); |
| SmallVector<std::string> symbolBindings; |
| for (unsigned i = 0; i < symbolCount; ++i) { |
| symbolBindings.push_back(llvm::formatv( |
| " exprs.push_back(getAffineSymbolExpr({0}, context));", i)); |
| } |
| |
| // Access an index attribute. Parameters: |
| // {0}: Attribute name |
| // {1}: Symbol position |
| // {2}: Attribute index |
| static const char structuredOpAccessAttrFormat[] = R"FMT( |
| int64_t cst{1} = self.{0}().getValues<int64_t>()[{2}]; |
| exprs.push_back(getAffineConstantExpr(cst{1}, context)); |
| )FMT"; |
| // Update all symbol bindings mapped to an attribute. |
| for (LinalgOperandDef &arg : opConfig.structuredOp->args) { |
| if (arg.usage != LinalgOperandDefUsage::attribute) |
| continue; |
| assert(arg.attributeMap.hasValue()); |
| for (auto &en : |
| llvm::enumerate(arg.attributeMap->affineMap().getResults())) { |
| if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) { |
| symbolBindings[symbol.getPosition()] = |
| llvm::formatv(structuredOpAccessAttrFormat, arg.name, |
| symbol.getPosition(), en.index()); |
| } |
| } |
| } |
| |
| std::string symbolBindingsStr; |
| llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr); |
| llvm::interleave(symbolBindings, symbolBindingsSs, "\n"); |
| symbolBindingsSs.flush(); |
| |
| os << llvm::formatv(structuredOpSymbolBindingsFormat, className, |
| symbolBindingsStr); |
| } |
| |
| // Indexing maps. |
| { |
| // Parameters: |
| // {0}: Class name |
| // {1}: Comma-separated list of dimension variable names. |
| // {2}: Statements |
| static const char structuredOpIndexingMapsFormat[] = R"FMT( |
| ArrayAttr {0}::indexing_maps() { |
| static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; |
| ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr); |
| if (cached) |
| return cached; |
| |
| MLIRContext *context = getContext(); |
| auto symbolBindings = getSymbolBindings(*this); |
| SmallVector<AffineMap> maps; |
| {2} |
| cached = Builder(context).getAffineMapArrayAttr(maps); |
| getOperation()->setAttr(memoizeAttr, cached); |
| return cached; |
| } |
| )FMT"; |
| |
| unsigned dimCount = firstMap.getNumDims(); |
| |
| // Generate a comma-separated list of dim identifiers to be passed to |
| // bindDims, ensuring tht AffineExpr identifiers are bound in the right |
| // order to the proper AffineDimExpr. |
| // This results in vars in scope like: d0, d1, d2... |
| SmallVector<unsigned> dimIndices; |
| for (unsigned i = 0; i < dimCount; ++i) |
| dimIndices.push_back(i); |
| std::string dimIdentsStr; |
| llvm::raw_string_ostream dimIdentsSs(dimIdentsStr); |
| llvm::interleaveComma(dimIndices, dimIdentsSs, |
| [&](unsigned i) { dimIdentsSs << "d" << i; }); |
| dimIdentsSs.flush(); |
| |
| // Statements to add and simplify each affine map. |
| SmallVector<std::string> stmts; |
| for (auto &indexingMap : *staticMaps) { |
| // TODO: Assert that dim and symbol count match the first. |
| stmts.push_back( |
| llvm::formatv("maps.push_back({0});", |
| generateCppExpression(indexingMap, "context"))); |
| stmts.push_back(llvm::formatv( |
| "maps.back() = " |
| "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, " |
| "symbolBindings, {0}, 0));", |
| dimCount)); |
| } |
| |
| // TODO: This needs to be memoized and/or converted to non-parser based |
| // C++ codegen prior to real use. |
| os << llvm::formatv(structuredOpIndexingMapsFormat, className, |
| dimIdentsStr, interleaveToString(stmts, "\n ")); |
| } |
| } else { |
| return emitError(genContext.getLoc()) |
| << "generating code for non static indexing maps not currently " |
| "supported"; |
| } |
| |
| // getNumRegionArgs() |
| { |
| // Generates a getNumRegionArgs() method. Parameters: |
| // {0}: Class name |
| // {1}: Number of region args |
| static const char structuredOpGetNumRegionArgsFormat[] = R"FMT( |
| unsigned {0}::getNumRegionArgs() {{ return {1}; } |
| )FMT"; |
| os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className, |
| numOfArgs); |
| } |
| |
| // getLibraryCallName() |
| { |
| // Generates a getLibraryCallName method. Parameters: |
| // {0}: Class name |
| static const char structuredOpGetLibraryCallFormat[] = R"FMT( |
| std::string {0}::getLibraryCallName() {{ |
| return generateLibraryCallName(getOperation()); |
| } |
| )FMT"; |
| os << llvm::formatv(structuredOpGetLibraryCallFormat, className); |
| } |
| |
| // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes() |
| if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { |
| return arg.usage == LinalgOperandDefUsage::attribute; |
| })) { |
| std::vector<std::string> attrVerifications; |
| for (LinalgOperandDef &arg : opConfig.structuredOp->args) { |
| if (arg.usage != LinalgOperandDefUsage::attribute) |
| continue; |
| assert(arg.attributeMap.hasValue() && arg.typeVar == "I64"); |
| // Verify index attribute. Paramters: |
| // {0}: Attribute name |
| // {1}: Attribute size |
| static const char attrFmt[] = R"FMT( |
| if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{ |
| if (!attr.getType().getElementType().isInteger(64)) |
| return op->emitError( |
| "incorrect element type for indexing map required attribute '{0}'"); |
| if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} }) |
| return op->emitError( |
| "incorrect shape for indexing map required attribute '{0}'"); |
| } else { |
| return op->emitError( |
| "missing indexing map required attribute '{0}'"); |
| } |
| )FMT"; |
| attrVerifications.push_back(llvm::formatv( |
| attrFmt, arg.name, arg.attributeMap->affineMap().getNumResults())); |
| } |
| |
| // Generates the verifyIndexingMapRequiredAttributes method. Parameters: |
| // {0}: Class name |
| // {1}: Attribute verification |
| static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT( |
| bool {0}::hasDynamicIndexingMaps() {{ return true; } |
| LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{ |
| Operation *op = getOperation(); |
| {1} |
| return success(); |
| } |
| )FMT"; |
| os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes, |
| className, llvm::join(attrVerifications, "\n")); |
| } |
| |
| // regionBuilder() |
| { |
| // Generates a regionBuilder method. Parameters. |
| // {0}: Class name |
| // {1}: Number of args |
| // {2}: Statements |
| static const char structuredOpRegionBuilderFormat[] = R"FMT( |
| void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{ |
| assert({1} > 0 && block.getNumArguments() == {1} && |
| "{0} regionBuilder expects {1} (>=0) args"); |
| RegionBuilderHelper helper(block.getArgument(0).getContext(), block); |
| SmallVector<Value> yields; |
| {2} |
| helper.yieldOutputs(yields); |
| } |
| )FMT"; |
| auto &args = opConfig.structuredOp->args; |
| auto &assignments = opConfig.structuredOp->assignments; |
| size_t generatedAssignmentCount = 0; |
| int localCounter = 0; |
| SmallVector<std::string> stmts; |
| for (LinalgOperandDef &arg : args) { |
| if (arg.usage != LinalgOperandDefUsage::output) |
| continue; |
| |
| // Find the assignment that correlates with the argument. |
| ScalarAssign *assignment = findAssignment(arg.name, assignments); |
| if (!assignment) |
| return emitError(genContext.getLoc()) |
| << "no assignment found for output argument " << arg.name; |
| ++generatedAssignmentCount; |
| |
| // Recursively generate the expression. |
| std::function<Optional<std::string>(ScalarExpression &)> |
| generateExpression = |
| [&](ScalarExpression &expression) -> Optional<std::string> { |
| if (expression.arg) { |
| // Argument reference. |
| Optional<int> argIndex = findTensorDefArgIndex(*expression.arg, args); |
| if (!argIndex) { |
| emitError(genContext.getLoc()) |
| << "scalar argument not defined on the op: " << *expression.arg; |
| return None; |
| } |
| return std::string( |
| llvm::formatv("block.getArgument({0})", *argIndex)); |
| } |
| if (expression.constant) { |
| std::string cppIdent = llvm::formatv("value{0}", ++localCounter); |
| stmts.push_back( |
| llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT", |
| cppIdent, expression.constant)); |
| return cppIdent; |
| } |
| if (expression.index) { |
| // Access an iteration index. |
| std::string cppIdent = llvm::formatv("value{0}", ++localCounter); |
| stmts.push_back(llvm::formatv("Value {0} = helper.index({1});", |
| cppIdent, *expression.index)); |
| return cppIdent; |
| } |
| if (expression.apply) { |
| // Apply function. |
| // Recursively generate operands. |
| SmallVector<std::string> operandCppValues; |
| for (ScalarExpression &operand : expression.apply->operands) { |
| auto operandCppValue = generateExpression(operand); |
| if (!operandCppValue) |
| return None; |
| operandCppValues.push_back(*operandCppValue); |
| } |
| std::string cppIdent = llvm::formatv("value{0}", ++localCounter); |
| stmts.push_back( |
| llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent, |
| expression.apply->fnName, |
| interleaveToString(operandCppValues, ", "))); |
| return cppIdent; |
| } |
| if (expression.symbolicCast) { |
| // Symbolic cast. |
| // Operands must be arity 1. |
| if (expression.symbolicCast->operands.size() != 1) { |
| emitError(genContext.getLoc()) |
| << "symbolic_cast operand arity must be 1"; |
| return None; |
| } |
| Optional<std::string> operandCppValue = |
| generateExpression(expression.symbolicCast->operands[0]); |
| if (!operandCppValue) |
| return None; |
| |
| Optional<std::string> typeCppValue = |
| findTypeValue(expression.symbolicCast->typeVar, args); |
| if (!typeCppValue) { |
| emitError(genContext.getLoc()) |
| << "type variable " << expression.symbolicCast->typeVar |
| << ", used in a symbolic cast must map to a predefined or " |
| << "an argument type but it does not"; |
| return None; |
| } |
| std::string cppIdent = llvm::formatv("value{0}", ++localCounter); |
| stmts.push_back( |
| llvm::formatv("Value {0} = helper.cast({1}, {2}, {3});", cppIdent, |
| typeCppValue.getValue(), *operandCppValue, |
| expression.symbolicCast->isUnsignedCast)); |
| return cppIdent; |
| } |
| emitError(genContext.getLoc()) << "unknown ScalarExpression type"; |
| return None; |
| }; |
| Optional<std::string> cppValue = generateExpression(assignment->value); |
| if (!cppValue) |
| return failure(); |
| stmts.push_back(llvm::formatv("yields.push_back({0});", cppValue)); |
| } |
| |
| if (generatedAssignmentCount != assignments.size()) |
| return emitError(genContext.getLoc()) |
| << "mismatched number of assignments vs output arguments"; |
| |
| os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs, |
| interleaveToString(stmts, "\n ")); |
| } |
| |
| // Canonicalizers and folders. |
| os << llvm::formatv(structuredOpFoldersFormat, className); |
| |
| return success(); |
| } |
| |
| static LogicalResult generateOp(LinalgOpConfig &opConfig, |
| GenerationContext &genContext) { |
| // Switch on op type being generated. |
| if (opConfig.structuredOp) { |
| return success( |
| succeeded(generateNamedGenericOpOds(opConfig, genContext)) && |
| succeeded(generateNamedGenericOpDefns(opConfig, genContext))); |
| } else { |
| return emitError(genContext.getLoc()) << "unsupported operation type"; |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Command line options and main |
| //===----------------------------------------------------------------------===// |
| |
| static llvm::cl::opt<std::string> |
| inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"), |
| llvm::cl::init("-"), llvm::cl::value_desc("YAML filename")); |
| |
| static llvm::cl::opt<std::string> |
| outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"), |
| llvm::cl::value_desc("filename"), llvm::cl::init("")); |
| |
| static llvm::cl::opt<std::string> |
| outputCppImplFilename("o-impl", |
| llvm::cl::desc("C++ implementation file name"), |
| llvm::cl::value_desc("filename"), llvm::cl::init("")); |
| |
| int main(int argc, char **argv) { |
| llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML"); |
| |
| // Set up the input file. |
| std::string errorMessage; |
| std::unique_ptr<llvm::MemoryBuffer> file = |
| mlir::openInputFile(inputFilename, &errorMessage); |
| if (!file) { |
| llvm::errs() << errorMessage << "\n"; |
| return 1; |
| } |
| |
| MLIRContext mlirContext; |
| LinalgYAMLContext yamlContext{&mlirContext}; |
| |
| std::vector<LinalgOpConfig> opConfigs; |
| |
| // Parse input. |
| Input yin(file->getBuffer(), &yamlContext); |
| yin >> opConfigs; |
| |
| if (yin.error()) |
| return 1; |
| |
| // Open output files. |
| std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl; |
| if (!outputOdsDeclFilename.empty()) { |
| outputOdsDecl = openOutputFile(outputOdsDeclFilename, &errorMessage); |
| if (!outputOdsDecl) { |
| llvm::errs() << errorMessage << "\n"; |
| return 1; |
| } |
| } |
| |
| std::unique_ptr<llvm::ToolOutputFile> outputCppImpl; |
| if (!outputCppImplFilename.empty()) { |
| outputCppImpl = openOutputFile(outputCppImplFilename, &errorMessage); |
| if (!outputCppImpl) { |
| llvm::errs() << errorMessage << "\n"; |
| return 1; |
| } |
| } |
| |
| if (!outputOdsDecl && !outputCppImpl) { |
| llvm::errs() << "error: No output files specified\n"; |
| return 1; |
| } |
| |
| // Generate. |
| GenerationContext genContext(&mlirContext, |
| outputOdsDecl ? &outputOdsDecl->os() : nullptr, |
| outputCppImpl ? &outputCppImpl->os() : nullptr); |
| |
| for (auto &opConfig : opConfigs) { |
| if (!opConfig.metadata) { |
| emitError(genContext.getLoc()) |
| << "missing operation metadata on subsequent op"; |
| return 1; |
| } |
| |
| genContext.setLoc(NameLoc::get( |
| StringAttr::get(&mlirContext, opConfig.metadata->cppClassName))); |
| if (failed(generateOp(opConfig, genContext))) { |
| return 1; |
| } |
| } |
| |
| if (outputOdsDecl) |
| outputOdsDecl->keep(); |
| if (outputCppImpl) |
| outputCppImpl->keep(); |
| |
| return 0; |
| } |