| //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
// This file implements helpers shared by the MLIR TableGen code generators:
// emission of uniqued static constraint verifier functions and string escaping.
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/TableGen/CodeGenHelpers.h" |
| #include "mlir/TableGen/Operator.h" |
| #include "mlir/TableGen/Pattern.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/Path.h" |
| #include "llvm/TableGen/Record.h" |
| |
| using namespace llvm; |
| using namespace mlir; |
| using namespace mlir::tblgen; |
| |
| /// Generate a unique label based on the current file name to prevent name |
| /// collisions if multiple generated files are included at once. |
| static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records, |
| StringRef tag) { |
| // Use the input file name when generating a unique name. |
| std::string inputFilename = records.getInputFilename(); |
| |
| // Drop all but the base filename. |
| StringRef nameRef = llvm::sys::path::filename(inputFilename); |
| nameRef.consume_back(".td"); |
| |
| // Sanitize any invalid characters. |
| std::string uniqueName(tag); |
| for (char c : nameRef) { |
| if (llvm::isAlnum(c) || c == '_') |
| uniqueName.push_back(c); |
| else |
| uniqueName.append(llvm::utohexstr((unsigned char)c)); |
| } |
| return uniqueName; |
| } |
| |
| StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( |
| raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag) |
| : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {} |
| |
| void StaticVerifierFunctionEmitter::emitOpConstraints( |
| ArrayRef<llvm::Record *> opDefs) { |
| NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); |
| emitTypeConstraints(); |
| emitAttrConstraints(); |
| emitSuccessorConstraints(); |
| emitRegionConstraints(); |
| } |
| |
| void StaticVerifierFunctionEmitter::emitPatternConstraints( |
| const llvm::ArrayRef<DagLeaf> constraints) { |
| collectPatternConstraints(constraints); |
| emitPatternConstraints(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Constraint Getters |
| |
| StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn( |
| const Constraint &constraint) const { |
| const auto *it = typeConstraints.find(constraint); |
| assert(it != typeConstraints.end() && "expected to find a type constraint"); |
| return it->second; |
| } |
| |
| // Find a uniqued attribute constraint. Since not all attribute constraints can |
| // be uniqued, return std::nullopt if one was not found. |
| std::optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn( |
| const Constraint &constraint) const { |
| const auto *it = attrConstraints.find(constraint); |
| return it == attrConstraints.end() ? std::optional<StringRef>() |
| : StringRef(it->second); |
| } |
| |
| StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn( |
| const Constraint &constraint) const { |
| const auto *it = successorConstraints.find(constraint); |
| assert(it != successorConstraints.end() && |
| "expected to find a sucessor constraint"); |
| return it->second; |
| } |
| |
| StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn( |
| const Constraint &constraint) const { |
| const auto *it = regionConstraints.find(constraint); |
| assert(it != regionConstraints.end() && |
| "expected to find a region constraint"); |
| return it->second; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Constraint Emission |
| |
/// Code templates for emitting type, attribute, successor, and region
/// constraints. Each of these templates requires the following arguments:
///
///   {0}: The unique constraint name.
///   {1}: The constraint code.
///   {2}: The constraint description.
| |
/// Code for a type constraint. These may be called on the type of either
/// operands or results.
///
/// The emitted function receives the op, the type to check, a string naming
/// the kind of value being checked, and the index of that value. On failure it
/// reports an op error that includes the constraint description and the
/// offending type.
static const char *const typeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {
  if (!({1})) {
    return op->emitOpError(valueKind) << " #" << valueIndex
        << " must be {2}, but got " << type;
  }
  return ::mlir::success();
}
)";
| |
/// Code for an attribute constraint. These may be called from ops only.
/// Attribute constraints cannot reference anything other than `$_self` and
/// `$_op`.
///
/// Two overloads are emitted under the same name: one that reports failures
/// through a caller-provided `emitError` callback, and one that takes the op
/// and forwards to the first overload using `op->emitOpError()`. A null
/// attribute is accepted without evaluating the constraint.
///
/// `{{` in this template is the formatv escape for a literal `{`.
///
/// TODO: Unique constraints for adaptors. However, most Adaptor::verify
/// functions are stripped anyways.
static const char *const attrConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Attribute attr, ::llvm::StringRef attrName, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
  if (attr && !({1}))
    return emitError() << "attribute '" << attrName
        << "' failed to satisfy constraint: {2}";
  return ::mlir::success();
}
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {{
  return {0}(attr, attrName, [op]() {{
    return op->emitOpError();
  });
}
)";
| |
/// Code for a successor constraint.
///
/// The emitted function receives the op, the successor block to check, the
/// successor's name, and its index. On failure it reports an op error of the
/// form `successor #<index> ('<name>') failed to verify constraint: ...`.
static const char *const successorConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {
  if (!({1})) {
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
| |
/// Code for a region constraint. Callers will need to pass in the region's name
/// for emitting an error message.
///
/// The emitted function receives the op, the region to check, the region's
/// name (which may be empty), and its index. The name is only included in the
/// error message when it is non-empty.
static const char *const regionConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {
  if (!({1})) {
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
| |
/// Code for a pattern type or attribute constraint.
///
/// In addition to the common template arguments, this template takes:
///
///   {3}: "Type type" or "Attribute attr".
///
/// The emitted function receives the rewriter, the op, the value described by
/// {3}, and a failure string. On failure it notifies the rewriter of a match
/// failure whose message is the failure string followed by the constraint
/// description.
static const char *const patternAttrOrTypeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
    ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
    ::llvm::StringRef failureStr) {
  if (!({1})) {
    return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
      diag << failureStr << ": {2}";
    });
  }
  return ::mlir::success();
}
)";
| |
| void StaticVerifierFunctionEmitter::emitConstraints( |
| const ConstraintMap &constraints, StringRef selfName, |
| const char *const codeTemplate) { |
| FmtContext ctx; |
| ctx.addSubst("_op", "*op").withSelf(selfName); |
| for (auto &it : constraints) { |
| os << formatv(codeTemplate, it.second, |
| tgfmt(it.first.getConditionTemplate(), &ctx), |
| escapeString(it.first.getSummary())); |
| } |
| } |
| |
| void StaticVerifierFunctionEmitter::emitTypeConstraints() { |
| emitConstraints(typeConstraints, "type", typeConstraintCode); |
| } |
| |
| void StaticVerifierFunctionEmitter::emitAttrConstraints() { |
| emitConstraints(attrConstraints, "attr", attrConstraintCode); |
| } |
| |
| void StaticVerifierFunctionEmitter::emitSuccessorConstraints() { |
| emitConstraints(successorConstraints, "successor", successorConstraintCode); |
| } |
| |
| void StaticVerifierFunctionEmitter::emitRegionConstraints() { |
| emitConstraints(regionConstraints, "region", regionConstraintCode); |
| } |
| |
| void StaticVerifierFunctionEmitter::emitPatternConstraints() { |
| FmtContext ctx; |
| ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type"); |
| for (auto &it : typeConstraints) { |
| os << formatv(patternAttrOrTypeConstraintCode, it.second, |
| tgfmt(it.first.getConditionTemplate(), &ctx), |
| escapeString(it.first.getSummary()), "Type type"); |
| } |
| ctx.withSelf("attr"); |
| for (auto &it : attrConstraints) { |
| os << formatv(patternAttrOrTypeConstraintCode, it.second, |
| tgfmt(it.first.getConditionTemplate(), &ctx), |
| escapeString(it.first.getSummary()), "Attribute attr"); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Constraint Uniquing |
| |
| /// An attribute constraint that references anything other than itself and the |
| /// current op cannot be generically extracted into a function. Most |
| /// prohibitive are operands and results, which require calls to |
| /// `getODSOperands` or `getODSResults`. Attribute references are tricky too |
| /// because ops use cached identifiers. |
| static bool canUniqueAttrConstraint(Attribute attr) { |
| FmtContext ctx; |
| auto test = tgfmt(attr.getConditionTemplate(), |
| &ctx.withSelf("attr").addSubst("_op", "*op")) |
| .str(); |
| return !StringRef(test).contains("<no-subst-found>"); |
| } |
| |
| std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind, |
| unsigned index) { |
| return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel + |
| Twine(index)) |
| .str(); |
| } |
| |
| void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map, |
| StringRef kind, |
| Constraint constraint) { |
| auto *it = map.find(constraint); |
| if (it == map.end()) |
| map.insert({constraint, getUniqueName(kind, map.size())}); |
| } |
| |
| void StaticVerifierFunctionEmitter::collectOpConstraints( |
| ArrayRef<Record *> opDefs) { |
| const auto collectTypeConstraints = [&](Operator::const_value_range values) { |
| for (const NamedTypeConstraint &value : values) |
| if (value.hasPredicate()) |
| collectConstraint(typeConstraints, "type", value.constraint); |
| }; |
| |
| for (Record *def : opDefs) { |
| Operator op(*def); |
| /// Collect type constraints. |
| collectTypeConstraints(op.getOperands()); |
| collectTypeConstraints(op.getResults()); |
| /// Collect attribute constraints. |
| for (const NamedAttribute &namedAttr : op.getAttributes()) { |
| if (!namedAttr.attr.getPredicate().isNull() && |
| !namedAttr.attr.isDerivedAttr() && |
| canUniqueAttrConstraint(namedAttr.attr)) |
| collectConstraint(attrConstraints, "attr", namedAttr.attr); |
| } |
| /// Collect successor constraints. |
| for (const NamedSuccessor &successor : op.getSuccessors()) { |
| if (!successor.constraint.getPredicate().isNull()) { |
| collectConstraint(successorConstraints, "successor", |
| successor.constraint); |
| } |
| } |
| /// Collect region constraints. |
| for (const NamedRegion ®ion : op.getRegions()) |
| if (!region.constraint.getPredicate().isNull()) |
| collectConstraint(regionConstraints, "region", region.constraint); |
| } |
| } |
| |
| void StaticVerifierFunctionEmitter::collectPatternConstraints( |
| const llvm::ArrayRef<DagLeaf> constraints) { |
| for (auto &leaf : constraints) { |
| assert(leaf.isOperandMatcher() || leaf.isAttrMatcher()); |
| collectConstraint( |
| leaf.isOperandMatcher() ? typeConstraints : attrConstraints, |
| leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint()); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Public Utility Functions |
| //===----------------------------------------------------------------------===// |
| |
| std::string mlir::tblgen::escapeString(StringRef value) { |
| std::string ret; |
| llvm::raw_string_ostream os(ret); |
| os.write_escaped(value); |
| return os.str(); |
| } |