| //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements the MLIR AsmPrinter class, which is used to implement |
| // the various print() methods on the core IR objects. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/AsmState.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinDialect.h" |
| #include "mlir/IR/BuiltinTypeInterfaces.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Dialect.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "mlir/IR/DialectResourceBlobManager.h" |
| #include "mlir/IR/IntegerSet.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/IR/Verifier.h" |
| #include "llvm/ADT/APFloat.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/MapVector.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/ScopeExit.h" |
| #include "llvm/ADT/ScopedHashTable.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SmallString.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/Endian.h" |
| #include "llvm/Support/ManagedStatic.h" |
| #include "llvm/Support/Regex.h" |
| #include "llvm/Support/SaveAndRestore.h" |
| #include "llvm/Support/Threading.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include <type_traits> |
| |
| #include <optional> |
| #include <tuple> |
| |
| using namespace mlir; |
| using namespace mlir::detail; |
| |
| #define DEBUG_TYPE "mlir-asm-printer" |
| |
| void OperationName::print(raw_ostream &os) const { os << getStringRef(); } |
| |
| void OperationName::dump() const { print(llvm::errs()); } |
| |
| //===--------------------------------------------------------------------===// |
| // AsmParser |
| //===--------------------------------------------------------------------===// |
| |
// Out-of-line definitions of the parser destructors. Each one uses the
// compiler-generated behavior; only the point of definition lives in this file.
AsmParser::~AsmParser() = default;
DialectAsmParser::~DialectAsmParser() = default;
OpAsmParser::~OpAsmParser() = default;
| |
| MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); } |
| |
| /// Parse a type list. |
| /// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918 |
| ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) { |
| return parseCommaSeparatedList( |
| [&]() { return parseType(result.emplace_back()); }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DialectAsmPrinter |
| //===----------------------------------------------------------------------===// |
| |
// Out-of-line definition of the destructor, using the default behavior.
DialectAsmPrinter::~DialectAsmPrinter() = default;
| |
| //===----------------------------------------------------------------------===// |
| // OpAsmPrinter |
| //===----------------------------------------------------------------------===// |
| |
// Out-of-line definition of the destructor, using the default behavior.
OpAsmPrinter::~OpAsmPrinter() = default;
| |
| void OpAsmPrinter::printFunctionalType(Operation *op) { |
| auto &os = getStream(); |
| os << '('; |
| llvm::interleaveComma(op->getOperands(), os, [&](Value operand) { |
| // Print the types of null values as <<NULL TYPE>>. |
| *this << (operand ? operand.getType() : Type()); |
| }); |
| os << ") -> "; |
| |
| // Print the result list. We don't parenthesize single result types unless |
| // it is a function (avoiding a grammar ambiguity). |
| bool wrapped = op->getNumResults() != 1; |
| if (!wrapped && op->getResult(0).getType() && |
| llvm::isa<FunctionType>(op->getResult(0).getType())) |
| wrapped = true; |
| |
| if (wrapped) |
| os << '('; |
| |
| llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) { |
| // Print the types of null values as <<NULL TYPE>>. |
| *this << (result ? result.getType() : Type()); |
| }); |
| |
| if (wrapped) |
| os << ')'; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Operation OpAsm interface. |
| //===----------------------------------------------------------------------===// |
| |
| /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
| #include "mlir/IR/OpAsmInterface.cpp.inc" |
| |
| LogicalResult |
| OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const { |
| return entry.emitError() << "unknown 'resource' key '" << entry.getKey() |
| << "' for dialect '" << getDialect()->getNamespace() |
| << "'"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpPrintingFlags |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This struct contains command line options that can be used to initialize |
| /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need |
| /// for global command line options. |
| struct AsmPrinterOptions { |
| llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{ |
| "mlir-print-elementsattrs-with-hex-if-larger", |
| llvm::cl::desc( |
| "Print DenseElementsAttrs with a hex string that have " |
| "more elements than the given upper limit (use -1 to disable)")}; |
| |
| llvm::cl::opt<unsigned> elideElementsAttrIfLarger{ |
| "mlir-elide-elementsattrs-if-larger", |
| llvm::cl::desc("Elide ElementsAttrs with \"...\" that have " |
| "more elements than the given upper limit")}; |
| |
| llvm::cl::opt<unsigned> elideResourceStringsIfLarger{ |
| "mlir-elide-resource-strings-if-larger", |
| llvm::cl::desc( |
| "Elide printing value of resources if string is too long in chars.")}; |
| |
| llvm::cl::opt<bool> printDebugInfoOpt{ |
| "mlir-print-debuginfo", llvm::cl::init(false), |
| llvm::cl::desc("Print debug info in MLIR output")}; |
| |
| llvm::cl::opt<bool> printPrettyDebugInfoOpt{ |
| "mlir-pretty-debuginfo", llvm::cl::init(false), |
| llvm::cl::desc("Print pretty debug info in MLIR output")}; |
| |
| // Use the generic op output form in the operation printer even if the custom |
| // form is defined. |
| llvm::cl::opt<bool> printGenericOpFormOpt{ |
| "mlir-print-op-generic", llvm::cl::init(false), |
| llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden}; |
| |
| llvm::cl::opt<bool> assumeVerifiedOpt{ |
| "mlir-print-assume-verified", llvm::cl::init(false), |
| llvm::cl::desc("Skip op verification when using custom printers"), |
| llvm::cl::Hidden}; |
| |
| llvm::cl::opt<bool> printLocalScopeOpt{ |
| "mlir-print-local-scope", llvm::cl::init(false), |
| llvm::cl::desc("Print with local scope and inline information (eliding " |
| "aliases for attributes, types, and locations")}; |
| |
| llvm::cl::opt<bool> skipRegionsOpt{ |
| "mlir-print-skip-regions", llvm::cl::init(false), |
| llvm::cl::desc("Skip regions when printing ops.")}; |
| |
| llvm::cl::opt<bool> printValueUsers{ |
| "mlir-print-value-users", llvm::cl::init(false), |
| llvm::cl::desc( |
| "Print users of operation results and block arguments as a comment")}; |
| |
| llvm::cl::opt<bool> printUniqueSSAIDs{ |
| "mlir-print-unique-ssa-ids", llvm::cl::init(false), |
| llvm::cl::desc("Print unique SSA ID numbers for values, block arguments " |
| "and naming conflicts across all regions")}; |
| }; |
| } // namespace |
| |
/// The AsmPrinter command line options. The OpPrintingFlags constructor checks
/// `isConstructed()` before reading them; registerAsmPrinterCLOptions
/// dereferences this object to make sure the options have been initialized.
static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
| |
| /// Register a set of useful command-line options that can be used to configure |
| /// various flags within the AsmPrinter. |
| void mlir::registerAsmPrinterCLOptions() { |
| // Make sure that the options struct has been initialized. |
| *clOptions; |
| } |
| |
| /// Initialize the printing flags with default supplied by the cl::opts above. |
| OpPrintingFlags::OpPrintingFlags() |
| : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), |
| printGenericOpFormFlag(false), skipRegionsFlag(false), |
| assumeVerifiedFlag(false), printLocalScope(false), |
| printValueUsersFlag(false), printUniqueSSAIDsFlag(false) { |
| // Initialize based upon command line options, if they are available. |
| if (!clOptions.isConstructed()) |
| return; |
| if (clOptions->elideElementsAttrIfLarger.getNumOccurrences()) |
| elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger; |
| if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) |
| elementsAttrHexElementLimit = |
| clOptions->printElementsAttrWithHexIfLarger.getValue(); |
| if (clOptions->elideResourceStringsIfLarger.getNumOccurrences()) |
| resourceStringCharLimit = clOptions->elideResourceStringsIfLarger; |
| printDebugInfoFlag = clOptions->printDebugInfoOpt; |
| printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; |
| printGenericOpFormFlag = clOptions->printGenericOpFormOpt; |
| assumeVerifiedFlag = clOptions->assumeVerifiedOpt; |
| printLocalScope = clOptions->printLocalScopeOpt; |
| skipRegionsFlag = clOptions->skipRegionsOpt; |
| printValueUsersFlag = clOptions->printValueUsers; |
| printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs; |
| } |
| |
| /// Enable the elision of large elements attributes, by printing a '...' |
| /// instead of the element data, when the number of elements is greater than |
| /// `largeElementLimit`. Note: The IR generated with this option is not |
| /// parsable. |
| OpPrintingFlags & |
| OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) { |
| elementsAttrElementLimit = largeElementLimit; |
| return *this; |
| } |
| |
| OpPrintingFlags & |
| OpPrintingFlags::printLargeElementsAttrWithHex(int64_t largeElementLimit) { |
| elementsAttrHexElementLimit = largeElementLimit; |
| return *this; |
| } |
| |
| OpPrintingFlags & |
| OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) { |
| resourceStringCharLimit = largeResourceLimit; |
| return *this; |
| } |
| |
| /// Enable printing of debug information. If 'prettyForm' is set to true, |
| /// debug information is printed in a more readable 'pretty' form. |
| OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable, |
| bool prettyForm) { |
| printDebugInfoFlag = enable; |
| printDebugInfoPrettyFormFlag = prettyForm; |
| return *this; |
| } |
| |
| /// Always print operations in the generic form. |
| OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) { |
| printGenericOpFormFlag = enable; |
| return *this; |
| } |
| |
| /// Always skip Regions. |
| OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) { |
| skipRegionsFlag = skip; |
| return *this; |
| } |
| |
| /// Do not verify the operation when using custom operation printers. |
| OpPrintingFlags &OpPrintingFlags::assumeVerified() { |
| assumeVerifiedFlag = true; |
| return *this; |
| } |
| |
| /// Use local scope when printing the operation. This allows for using the |
| /// printer in a more localized and thread-safe setting, but may not necessarily |
| /// be identical of what the IR will look like when dumping the full module. |
| OpPrintingFlags &OpPrintingFlags::useLocalScope() { |
| printLocalScope = true; |
| return *this; |
| } |
| |
| /// Print users of values as comments. |
| OpPrintingFlags &OpPrintingFlags::printValueUsers() { |
| printValueUsersFlag = true; |
| return *this; |
| } |
| |
| /// Return if the given ElementsAttr should be elided. |
| bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { |
| return elementsAttrElementLimit && |
| *elementsAttrElementLimit < int64_t(attr.getNumElements()) && |
| !llvm::isa<SplatElementsAttr>(attr); |
| } |
| |
| /// Return if the given ElementsAttr should be printed as hex string. |
| bool OpPrintingFlags::shouldPrintElementsAttrWithHex(ElementsAttr attr) const { |
| // -1 is used to disable hex printing. |
| return (elementsAttrHexElementLimit != -1) && |
| (elementsAttrHexElementLimit < int64_t(attr.getNumElements())) && |
| !llvm::isa<SplatElementsAttr>(attr); |
| } |
| |
| /// Return the size limit for printing large ElementsAttr. |
| std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const { |
| return elementsAttrElementLimit; |
| } |
| |
| /// Return the size limit for printing large ElementsAttr as hex string. |
| int64_t OpPrintingFlags::getLargeElementsAttrHexLimit() const { |
| return elementsAttrHexElementLimit; |
| } |
| |
| /// Return the size limit for printing large ElementsAttr. |
| std::optional<uint64_t> OpPrintingFlags::getLargeResourceStringLimit() const { |
| return resourceStringCharLimit; |
| } |
| |
| /// Return if debug information should be printed. |
| bool OpPrintingFlags::shouldPrintDebugInfo() const { |
| return printDebugInfoFlag; |
| } |
| |
| /// Return if debug information should be printed in the pretty form. |
| bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const { |
| return printDebugInfoPrettyFormFlag; |
| } |
| |
| /// Return if operations should be printed in the generic form. |
| bool OpPrintingFlags::shouldPrintGenericOpForm() const { |
| return printGenericOpFormFlag; |
| } |
| |
| /// Return if Region should be skipped. |
| bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; } |
| |
| /// Return if operation verification should be skipped. |
| bool OpPrintingFlags::shouldAssumeVerified() const { |
| return assumeVerifiedFlag; |
| } |
| |
| /// Return if the printer should use local scope when dumping the IR. |
| bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } |
| |
| /// Return if the printer should print users of values. |
| bool OpPrintingFlags::shouldPrintValueUsers() const { |
| return printValueUsersFlag; |
| } |
| |
| /// Return if the printer should use unique IDs. |
| bool OpPrintingFlags::shouldPrintUniqueSSAIDs() const { |
| return printUniqueSSAIDsFlag || shouldPrintGenericOpForm(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // NewLineCounter |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
/// A small formatter object: streaming it into an output stream emits a new
/// line and bumps a counter, which makes it possible to count how many
/// newlines were emitted. The printer should use this whenever it emits a
/// newline.
struct NewLineCounter {
  /// The current line number; counting starts at line 1.
  unsigned curLine{1};
};
| |
| static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) { |
| ++newLine.curLine; |
| return os << '\n'; |
| } |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // AsmPrinter::Impl |
| //===----------------------------------------------------------------------===// |
| |
namespace mlir {
/// The implementation class behind AsmPrinter. It bundles the output stream,
/// the shared printer state, the printing flags and a newline counter, and
/// declares the printing routines for attributes, types, locations, affine
/// structures and related helpers.
class AsmPrinter::Impl {
public:
  /// Construct a printer that writes to `os` and uses `state`.
  Impl(raw_ostream &os, AsmStateImpl &state);
  /// Construct a printer from another one by delegating to the primary
  /// constructor with the stream and state of `other`.
  explicit Impl(Impl &other) : Impl(other.os, other.state) {}

  /// Returns the output stream of the printer.
  raw_ostream &getStream() { return os; }

  /// Forwards to llvm::interleaveComma, using the stream of this printer as
  /// the output and `eachFn` as the per-element callback.
  template <typename Container, typename UnaryFunctor>
  inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
    llvm::interleaveComma(c, os, eachFn);
  }

  /// This enum describes the different kinds of elision for the type of an
  /// attribute when printing it.
  enum class AttrTypeElision {
    /// The type must not be elided,
    Never,
    /// The type may be elided when it matches the default used in the parser
    /// (for example i64 is the default for integer attributes).
    May,
    /// The type must be elided.
    Must
  };

  /// Print the given attribute or an alias.
  void printAttribute(Attribute attr,
                      AttrTypeElision typeElision = AttrTypeElision::Never);
  /// Print the given attribute without considering an alias.
  void printAttributeImpl(Attribute attr,
                          AttrTypeElision typeElision = AttrTypeElision::Never);

  /// Print the alias for the given attribute, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Attribute attr);

  /// Print the given type or an alias.
  void printType(Type type);
  /// Print the given type.
  void printTypeImpl(Type type);

  /// Print the alias for the given type, return failure if no alias could
  /// be printed.
  LogicalResult printAlias(Type type);

  /// Print the given location to the stream. If `allowAlias` is true, this
  /// allows for the internal location to use an attribute alias.
  void printLocation(LocationAttr loc, bool allowAlias = false);

  /// Print a reference to the given resource that is owned by the given
  /// dialect.
  void printResourceHandle(const AsmDialectResourceHandle &resource);

  /// Printing routines for affine maps, affine expressions, affine constraints
  /// and integer sets.
  /// NOTE(review): `printValueName` presumably lets the caller print
  /// dimension/symbol names itself; the (unsigned, bool) arguments are not
  /// explained in this part of the file -- confirm against the definition.
  void printAffineMap(AffineMap map);
  void
  printAffineExpr(AffineExpr expr,
                  function_ref<void(unsigned, bool)> printValueName = nullptr);
  void printAffineConstraint(AffineExpr expr, bool isEq);
  void printIntegerSet(IntegerSet set);

  /// NOTE(review): push/pop pair keyed on an opaque pointer; presumably used
  /// to detect cycles while printing -- confirm against the definitions.
  LogicalResult pushCyclicPrinting(const void *opaquePointer);

  void popCyclicPrinting();

  /// Print the given shape as a dimension list.
  /// NOTE(review): the exact output format is defined elsewhere.
  void printDimensionList(ArrayRef<int64_t> shape);

protected:
  /// Print the given attributes as an optional attribute dictionary.
  /// NOTE(review): judging by the names, attributes listed in `elidedAttrs`
  /// are left out and `withKeyword` adds a leading keyword -- confirm against
  /// the definition.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {},
                             bool withKeyword = false);
  /// Print a single named attribute.
  void printNamedAttribute(NamedAttribute attr);
  /// Print the location that trails an entity; `allowAlias` defaults to true
  /// here, unlike in printLocation.
  void printTrailingLocation(Location loc, bool allowAlias = true);
  /// Internal location printing routine.
  /// NOTE(review): the meaning of `pretty` and `isTopLevel` is defined by the
  /// implementation, which is not visible here.
  void printLocationInternal(LocationAttr loc, bool pretty = false,
                             bool isTopLevel = false);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);

  /// Print a dense string elements attribute.
  void printDenseStringElementsAttr(DenseStringElementsAttr attr);

  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
  /// used instead of individual elements when the elements attr is large.
  void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
                                     bool allowHex);

  /// Print a dense array attribute.
  void printDenseArrayAttr(DenseArrayAttr attr);

  /// Print an attribute / a type that belongs to a dialect.
  void printDialectAttribute(Attribute attr);
  void printDialectType(Type type);

  /// Print an escaped string, wrapped with "".
  void printEscapedString(StringRef str);

  /// Print a hex string, wrapped with "".
  void printHexString(StringRef str);
  void printHexString(ArrayRef<char> data);

  /// This enum is used to represent the binding strength of the enclosing
  /// context that an AffineExprStorage is being printed in, so we can
  /// intelligently produce parens.
  enum class BindingStrength {
    Weak,   // + and -
    Strong, // All other binary operators.
  };
  /// Print an affine expression, given the binding strength of the enclosing
  /// context.
  void printAffineExprInternal(
      AffineExpr expr, BindingStrength enclosingTightness,
      function_ref<void(unsigned, bool)> printValueName = nullptr);

  /// The output stream for the printer.
  raw_ostream &os;

  /// An underlying assembly printer state.
  AsmStateImpl &state;

  /// A set of flags to control the printer's behavior.
  OpPrintingFlags printerFlags;

  /// A tracker for the number of new lines emitted during printing.
  NewLineCounter newLine;
};
} // namespace mlir
| |
| //===----------------------------------------------------------------------===// |
| // AliasInitializer |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This class represents a specific instance of a symbol Alias. |
| class SymbolAlias { |
| public: |
| SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType, |
| bool isDeferrable) |
| : name(name), suffixIndex(suffixIndex), isType(isType), |
| isDeferrable(isDeferrable) {} |
| |
| /// Print this alias to the given stream. |
| void print(raw_ostream &os) const { |
| os << (isType ? "!" : "#") << name; |
| if (suffixIndex) |
| os << suffixIndex; |
| } |
| |
| /// Returns true if this is a type alias. |
| bool isTypeAlias() const { return isType; } |
| |
| /// Returns true if this alias supports deferred resolution when parsing. |
| bool canBeDeferred() const { return isDeferrable; } |
| |
| private: |
| /// The main name of the alias. |
| StringRef name; |
| /// The suffix index of the alias. |
| uint32_t suffixIndex : 30; |
| /// A flag indicating whether this alias is for a type. |
| bool isType : 1; |
| /// A flag indicating whether this alias may be deferred or not. |
| bool isDeferrable : 1; |
| }; |
| |
| /// This class represents a utility that initializes the set of attribute and |
| /// type aliases, without the need to store the extra information within the |
| /// main AliasState class or pass it around via function arguments. |
| class AliasInitializer { |
| public: |
| AliasInitializer( |
| DialectInterfaceCollection<OpAsmDialectInterface> &interfaces, |
| llvm::BumpPtrAllocator &aliasAllocator) |
| : interfaces(interfaces), aliasAllocator(aliasAllocator), |
| aliasOS(aliasBuffer) {} |
| |
| void initialize(Operation *op, const OpPrintingFlags &printerFlags, |
| llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias); |
| |
| /// Visit the given attribute to see if it has an alias. `canBeDeferred` is |
| /// set to true if the originator of this attribute can resolve the alias |
| /// after parsing has completed (e.g. in the case of operation locations). |
| /// `elideType` indicates if the type of the attribute should be skipped when |
| /// looking for nested aliases. Returns the maximum alias depth of the |
| /// attribute, and the alias index of this attribute. |
| std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false, |
| bool elideType = false) { |
| return visitImpl(attr, aliases, canBeDeferred, elideType); |
| } |
| |
| /// Visit the given type to see if it has an alias. `canBeDeferred` is |
| /// set to true if the originator of this attribute can resolve the alias |
| /// after parsing has completed. Returns the maximum alias depth of the type, |
| /// and the alias index of this type. |
| std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) { |
| return visitImpl(type, aliases, canBeDeferred); |
| } |
| |
| private: |
| struct InProgressAliasInfo { |
| InProgressAliasInfo() |
| : aliasDepth(0), isType(false), canBeDeferred(false) {} |
| InProgressAliasInfo(StringRef alias) |
| : alias(alias), aliasDepth(1), isType(false), canBeDeferred(false) {} |
| |
| bool operator<(const InProgressAliasInfo &rhs) const { |
| // Order first by depth, then by attr/type kind, and then by name. |
| if (aliasDepth != rhs.aliasDepth) |
| return aliasDepth < rhs.aliasDepth; |
| if (isType != rhs.isType) |
| return isType; |
| return alias < rhs.alias; |
| } |
| |
| /// The alias for the attribute or type, or std::nullopt if the value has no |
| /// alias. |
| std::optional<StringRef> alias; |
| /// The alias depth of this attribute or type, i.e. an indication of the |
| /// relative ordering of when to print this alias. |
| unsigned aliasDepth : 30; |
| /// If this alias represents a type or an attribute. |
| bool isType : 1; |
| /// If this alias can be deferred or not. |
| bool canBeDeferred : 1; |
| /// Indices for child aliases. |
| SmallVector<size_t> childIndices; |
| }; |
| |
| /// Visit the given attribute or type to see if it has an alias. |
| /// `canBeDeferred` is set to true if the originator of this value can resolve |
| /// the alias after parsing has completed (e.g. in the case of operation |
| /// locations). Returns the maximum alias depth of the value, and its alias |
| /// index. |
| template <typename T, typename... PrintArgs> |
| std::pair<size_t, size_t> |
| visitImpl(T value, |
| llvm::MapVector<const void *, InProgressAliasInfo> &aliases, |
| bool canBeDeferred, PrintArgs &&...printArgs); |
| |
| /// Mark the given alias as non-deferrable. |
| void markAliasNonDeferrable(size_t aliasIndex); |
| |
| /// Try to generate an alias for the provided symbol. If an alias is |
| /// generated, the provided alias mapping and reverse mapping are updated. |
| template <typename T> |
| void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred); |
| |
| /// Given a collection of aliases and symbols, initialize a mapping from a |
| /// symbol to a given alias. |
| static void initializeAliases( |
| llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols, |
| llvm::MapVector<const void *, SymbolAlias> &symbolToAlias); |
| |
| /// The set of asm interfaces within the context. |
| DialectInterfaceCollection<OpAsmDialectInterface> &interfaces; |
| |
| /// An allocator used for alias names. |
| llvm::BumpPtrAllocator &aliasAllocator; |
| |
| /// The set of built aliases. |
| llvm::MapVector<const void *, InProgressAliasInfo> aliases; |
| |
| /// Storage and stream used when generating an alias. |
| SmallString<32> aliasBuffer; |
| llvm::raw_svector_ostream aliasOS; |
| }; |
| |
| /// This class implements a dummy OpAsmPrinter that doesn't print any output, |
| /// and merely collects the attributes and types that *would* be printed in a |
| /// normal print invocation so that we can generate proper aliases. This allows |
| /// for us to generate aliases only for the attributes and types that would be |
| /// in the output, and trims down unnecessary output. |
| class DummyAliasOperationPrinter : private OpAsmPrinter { |
| public: |
| explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags, |
| AliasInitializer &initializer) |
| : printerFlags(printerFlags), initializer(initializer) {} |
| |
| /// Prints the entire operation with the custom assembly form, if available, |
| /// or the generic assembly form, otherwise. |
| void printCustomOrGenericOp(Operation *op) override { |
| // Visit the operation location. |
| if (printerFlags.shouldPrintDebugInfo()) |
| initializer.visit(op->getLoc(), /*canBeDeferred=*/true); |
| |
| // If requested, always print the generic form. |
| if (!printerFlags.shouldPrintGenericOpForm()) { |
| op->getName().printAssembly(op, *this, /*defaultDialect=*/""); |
| return; |
| } |
| |
| // Otherwise print with the generic assembly form. |
| printGenericOp(op); |
| } |
| |
| private: |
| /// Print the given operation in the generic form. |
| void printGenericOp(Operation *op, bool printOpName = true) override { |
| // Consider nested operations for aliases. |
| if (!printerFlags.shouldSkipRegions()) { |
| for (Region ®ion : op->getRegions()) |
| printRegion(region, /*printEntryBlockArgs=*/true, |
| /*printBlockTerminators=*/true); |
| } |
| |
| // Visit all the types used in the operation. |
| for (Type type : op->getOperandTypes()) |
| printType(type); |
| for (Type type : op->getResultTypes()) |
| printType(type); |
| |
| // Consider the attributes of the operation for aliases. |
| for (const NamedAttribute &attr : op->getAttrs()) |
| printAttribute(attr.getValue()); |
| } |
| |
| /// Print the given block. If 'printBlockArgs' is false, the arguments of the |
| /// block are not printed. If 'printBlockTerminator' is false, the terminator |
| /// operation of the block is not printed. |
| void print(Block *block, bool printBlockArgs = true, |
| bool printBlockTerminator = true) { |
| // Consider the types of the block arguments for aliases if 'printBlockArgs' |
| // is set to true. |
| if (printBlockArgs) { |
| for (BlockArgument arg : block->getArguments()) { |
| printType(arg.getType()); |
| |
| // Visit the argument location. |
| if (printerFlags.shouldPrintDebugInfo()) |
| // TODO: Allow deferring argument locations. |
| initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); |
| } |
| } |
| |
| // Consider the operations within this block, ignoring the terminator if |
| // requested. |
| bool hasTerminator = |
| !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); |
| auto range = llvm::make_range( |
| block->begin(), |
| std::prev(block->end(), |
| (!hasTerminator || printBlockTerminator) ? 0 : 1)); |
| for (Operation &op : range) |
| printCustomOrGenericOp(&op); |
| } |
| |
| /// Print the given region. |
| void printRegion(Region ®ion, bool printEntryBlockArgs, |
| bool printBlockTerminators, |
| bool printEmptyBlock = false) override { |
| if (region.empty()) |
| return; |
| if (printerFlags.shouldSkipRegions()) { |
| os << "{...}"; |
| return; |
| } |
| |
| auto *entryBlock = ®ion.front(); |
| print(entryBlock, printEntryBlockArgs, printBlockTerminators); |
| for (Block &b : llvm::drop_begin(region, 1)) |
| print(&b); |
| } |
| |
| void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs, |
| bool omitType) override { |
| printType(arg.getType()); |
| // Visit the argument location. |
| if (printerFlags.shouldPrintDebugInfo()) |
| // TODO: Allow deferring argument locations. |
| initializer.visit(arg.getLoc(), /*canBeDeferred=*/false); |
| } |
| |
| /// Consider the given type to be printed for an alias. |
| void printType(Type type) override { initializer.visit(type); } |
| |
| /// Consider the given attribute to be printed for an alias. |
| void printAttribute(Attribute attr) override { initializer.visit(attr); } |
| void printAttributeWithoutType(Attribute attr) override { |
| printAttribute(attr); |
| } |
| LogicalResult printAlias(Attribute attr) override { |
| initializer.visit(attr); |
| return success(); |
| } |
| LogicalResult printAlias(Type type) override { |
| initializer.visit(type); |
| return success(); |
| } |
| |
| /// Consider the given location to be printed for an alias. |
| void printOptionalLocationSpecifier(Location loc) override { |
| printAttribute(loc); |
| } |
| |
| /// Print the given set of attributes with names not included within |
| /// 'elidedAttrs'. |
| void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| ArrayRef<StringRef> elidedAttrs = {}) override { |
| if (attrs.empty()) |
| return; |
| if (elidedAttrs.empty()) { |
| for (const NamedAttribute &attr : attrs) |
| printAttribute(attr.getValue()); |
| return; |
| } |
| llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), |
| elidedAttrs.end()); |
| for (const NamedAttribute &attr : attrs) |
| if (!elidedAttrsSet.contains(attr.getName().strref())) |
| printAttribute(attr.getValue()); |
| } |
| void printOptionalAttrDictWithKeyword( |
| ArrayRef<NamedAttribute> attrs, |
| ArrayRef<StringRef> elidedAttrs = {}) override { |
| printOptionalAttrDict(attrs, elidedAttrs); |
| } |
| |
| /// Return a null stream as the output stream, this will ignore any data fed |
| /// to it. |
| raw_ostream &getStream() const override { return os; } |
| |
| /// The following are hooks of `OpAsmPrinter` that are not necessary for |
| /// determining potential aliases. |
| void printFloat(const APFloat &) override {} |
| void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} |
| void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} |
| void printNewline() override {} |
| void increaseIndent() override {} |
| void decreaseIndent() override {} |
| void printOperand(Value) override {} |
| void printOperand(Value, raw_ostream &os) override { |
| // Users expect the output string to have at least the prefixed % to signal |
| // a value name. To maintain this invariant, emit a name even if it is |
| // guaranteed to go unused. |
| os << "%"; |
| } |
| void printKeywordOrString(StringRef) override {} |
| void printString(StringRef) override {} |
| void printResourceHandle(const AsmDialectResourceHandle &) override {} |
| void printSymbolName(StringRef) override {} |
| void printSuccessor(Block *) override {} |
| void printSuccessorAndUseList(Block *, ValueRange) override {} |
| void shadowRegionArgs(Region &, ValueRange) override {} |
| |
| /// The printer flags to use when determining potential aliases. |
| const OpPrintingFlags &printerFlags; |
| |
| /// The initializer to use when identifying aliases. |
| AliasInitializer &initializer; |
| |
| /// A dummy output stream. |
| mutable llvm::raw_null_ostream os; |
| }; |
| |
| class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { |
| public: |
| explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer, |
| bool canBeDeferred, |
| SmallVectorImpl<size_t> &childIndices) |
| : initializer(initializer), canBeDeferred(canBeDeferred), |
| childIndices(childIndices) {} |
| |
| /// Print the given attribute/type, visiting any nested aliases that would be |
| /// generated as part of printing. Returns the maximum alias depth found while |
| /// printing the given value. |
| template <typename T, typename... PrintArgs> |
| size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) { |
| printAndVisitNestedAliasesImpl(value, printArgs...); |
| return maxAliasDepth; |
| } |
| |
| private: |
| /// Print the given attribute/type, visiting any nested aliases that would be |
| /// generated as part of printing. |
| void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) { |
| if (!isa<BuiltinDialect>(attr.getDialect())) { |
| attr.getDialect().printAttribute(attr, *this); |
| |
| // Process the builtin attributes. |
| } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr, |
| IntegerSetAttr, UnitAttr>(attr)) { |
| return; |
| } else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) { |
| printAttribute(distinctAttr.getReferencedAttr()); |
| } else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) { |
| for (const NamedAttribute &nestedAttr : dictAttr.getValue()) { |
| printAttribute(nestedAttr.getName()); |
| printAttribute(nestedAttr.getValue()); |
| } |
| } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) { |
| for (Attribute nestedAttr : arrayAttr.getValue()) |
| printAttribute(nestedAttr); |
| } else if (auto typeAttr = dyn_cast<TypeAttr>(attr)) { |
| printType(typeAttr.getValue()); |
| } else if (auto locAttr = dyn_cast<OpaqueLoc>(attr)) { |
| printAttribute(locAttr.getFallbackLocation()); |
| } else if (auto locAttr = dyn_cast<NameLoc>(attr)) { |
| if (!isa<UnknownLoc>(locAttr.getChildLoc())) |
| printAttribute(locAttr.getChildLoc()); |
| } else if (auto locAttr = dyn_cast<CallSiteLoc>(attr)) { |
| printAttribute(locAttr.getCallee()); |
| printAttribute(locAttr.getCaller()); |
| } else if (auto locAttr = dyn_cast<FusedLoc>(attr)) { |
| if (Attribute metadata = locAttr.getMetadata()) |
| printAttribute(metadata); |
| for (Location nestedLoc : locAttr.getLocations()) |
| printAttribute(nestedLoc); |
| } |
| |
| // Don't print the type if we must elide it, or if it is a None type. |
| if (!elideType) { |
| if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) { |
| Type attrType = typedAttr.getType(); |
| if (!llvm::isa<NoneType>(attrType)) |
| printType(attrType); |
| } |
| } |
| } |
| void printAndVisitNestedAliasesImpl(Type type) { |
| if (!isa<BuiltinDialect>(type.getDialect())) |
| return type.getDialect().printType(type, *this); |
| |
| // Only visit the layout of memref if it isn't the identity. |
| if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) { |
| printType(memrefTy.getElementType()); |
| MemRefLayoutAttrInterface layout = memrefTy.getLayout(); |
| if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) |
| printAttribute(memrefTy.getLayout()); |
| if (memrefTy.getMemorySpace()) |
| printAttribute(memrefTy.getMemorySpace()); |
| return; |
| } |
| |
| // For most builtin types, we can simply walk the sub elements. |
| auto visitFn = [&](auto element) { |
| if (element) |
| (void)printAlias(element); |
| }; |
| type.walkImmediateSubElements(visitFn, visitFn); |
| } |
| |
| /// Consider the given type to be printed for an alias. |
| void printType(Type type) override { |
| recordAliasResult(initializer.visit(type, canBeDeferred)); |
| } |
| |
| /// Consider the given attribute to be printed for an alias. |
| void printAttribute(Attribute attr) override { |
| recordAliasResult(initializer.visit(attr, canBeDeferred)); |
| } |
| void printAttributeWithoutType(Attribute attr) override { |
| recordAliasResult( |
| initializer.visit(attr, canBeDeferred, /*elideType=*/true)); |
| } |
| LogicalResult printAlias(Attribute attr) override { |
| printAttribute(attr); |
| return success(); |
| } |
| LogicalResult printAlias(Type type) override { |
| printType(type); |
| return success(); |
| } |
| |
| /// Record the alias result of a child element. |
| void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) { |
| childIndices.push_back(aliasDepthAndIndex.second); |
| if (aliasDepthAndIndex.first > maxAliasDepth) |
| maxAliasDepth = aliasDepthAndIndex.first; |
| } |
| |
| /// Return a null stream as the output stream, this will ignore any data fed |
| /// to it. |
| raw_ostream &getStream() const override { return os; } |
| |
| /// The following are hooks of `DialectAsmPrinter` that are not necessary for |
| /// determining potential aliases. |
| void printFloat(const APFloat &) override {} |
| void printKeywordOrString(StringRef) override {} |
| void printString(StringRef) override {} |
| void printSymbolName(StringRef) override {} |
| void printResourceHandle(const AsmDialectResourceHandle &) override {} |
| |
| LogicalResult pushCyclicPrinting(const void *opaquePointer) override { |
| return success(cyclicPrintingStack.insert(opaquePointer)); |
| } |
| |
| void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); } |
| |
| /// Stack of potentially cyclic mutable attributes or type currently being |
| /// printed. |
| SetVector<const void *> cyclicPrintingStack; |
| |
| /// The initializer to use when identifying aliases. |
| AliasInitializer &initializer; |
| |
| /// If the aliases visited by this printer can be deferred. |
| bool canBeDeferred; |
| |
| /// The indices of child aliases. |
| SmallVectorImpl<size_t> &childIndices; |
| |
| /// The maximum alias depth found by the printer. |
| size_t maxAliasDepth = 0; |
| |
| /// A dummy output stream. |
| mutable llvm::raw_null_ostream os; |
| }; |
| } // namespace |
| |
| /// Sanitize the given name such that it can be used as a valid identifier. If |
| /// the string needs to be modified in any way, the provided buffer is used to |
| /// store the new copy, |
| static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, |
| StringRef allowedPunctChars = "$._-", |
| bool allowTrailingDigit = true) { |
| assert(!name.empty() && "Shouldn't have an empty name here"); |
| |
| auto validChar = [&](char ch) { |
| return llvm::isAlnum(ch) || allowedPunctChars.contains(ch); |
| }; |
| |
| auto copyNameToBuffer = [&] { |
| for (char ch : name) { |
| if (validChar(ch)) |
| buffer.push_back(ch); |
| else if (ch == ' ') |
| buffer.push_back('_'); |
| else |
| buffer.append(llvm::utohexstr((unsigned char)ch)); |
| } |
| }; |
| |
| // Check to see if this name is valid. If it starts with a digit, then it |
| // could conflict with the autogenerated numeric ID's, so add an underscore |
| // prefix to avoid problems. |
| if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) { |
| buffer.push_back('_'); |
| copyNameToBuffer(); |
| return buffer; |
| } |
| |
| // If the name ends with a trailing digit, add a '_' to avoid potential |
| // conflicts with autogenerated ID's. |
| if (!allowTrailingDigit && isdigit(name.back())) { |
| copyNameToBuffer(); |
| buffer.push_back('_'); |
| return buffer; |
| } |
| |
| // Check to see that the name consists of only valid identifier characters. |
| for (char ch : name) { |
| if (!validChar(ch)) { |
| copyNameToBuffer(); |
| return buffer; |
| } |
| } |
| |
| // If there are no invalid characters, return the original name. |
| return name; |
| } |
| |
| /// Given a collection of aliases and symbols, initialize a mapping from a |
| /// symbol to a given alias. |
| void AliasInitializer::initializeAliases( |
| llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols, |
| llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) { |
| SmallVector<std::pair<const void *, InProgressAliasInfo>, 0> |
| unprocessedAliases = visitedSymbols.takeVector(); |
| llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) { |
| return lhs.second < rhs.second; |
| }); |
| |
| llvm::StringMap<unsigned> nameCounts; |
| for (auto &[symbol, aliasInfo] : unprocessedAliases) { |
| if (!aliasInfo.alias) |
| continue; |
| StringRef alias = *aliasInfo.alias; |
| unsigned nameIndex = nameCounts[alias]++; |
| symbolToAlias.insert( |
| {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType, |
| aliasInfo.canBeDeferred)}); |
| } |
| } |
| |
| void AliasInitializer::initialize( |
| Operation *op, const OpPrintingFlags &printerFlags, |
| llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) { |
| // Use a dummy printer when walking the IR so that we can collect the |
| // attributes/types that will actually be used during printing when |
| // considering aliases. |
| DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); |
| aliasPrinter.printCustomOrGenericOp(op); |
| |
| // Initialize the aliases. |
| initializeAliases(aliases, attrTypeToAlias); |
| } |
| |
| template <typename T, typename... PrintArgs> |
| std::pair<size_t, size_t> AliasInitializer::visitImpl( |
| T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases, |
| bool canBeDeferred, PrintArgs &&...printArgs) { |
| auto [it, inserted] = |
| aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()}); |
| size_t aliasIndex = std::distance(aliases.begin(), it); |
| if (!inserted) { |
| // Make sure that the alias isn't deferred if we don't permit it. |
| if (!canBeDeferred) |
| markAliasNonDeferrable(aliasIndex); |
| return {static_cast<size_t>(it->second.aliasDepth), aliasIndex}; |
| } |
| |
| // Try to generate an alias for this value. |
| generateAlias(value, it->second, canBeDeferred); |
| it->second.isType = std::is_base_of_v<Type, T>; |
| it->second.canBeDeferred = canBeDeferred; |
| |
| // Print the value, capturing any nested elements that require aliases. |
| SmallVector<size_t> childAliases; |
| DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases); |
| size_t maxAliasDepth = |
| printer.printAndVisitNestedAliases(value, printArgs...); |
| |
| // Make sure to recompute `it` in case the map was reallocated. |
| it = std::next(aliases.begin(), aliasIndex); |
| |
| // If we had sub elements, update to account for the depth. |
| it->second.childIndices = std::move(childAliases); |
| if (maxAliasDepth) |
| it->second.aliasDepth = maxAliasDepth + 1; |
| |
| // Propagate the alias depth of the value. |
| return {(size_t)it->second.aliasDepth, aliasIndex}; |
| } |
| |
| void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) { |
| auto *it = std::next(aliases.begin(), aliasIndex); |
| |
| // If already marked non-deferrable stop the recursion. |
| // All children should already be marked non-deferrable as well. |
| if (!it->second.canBeDeferred) |
| return; |
| |
| it->second.canBeDeferred = false; |
| |
| // Propagate the non-deferrable flag to any child aliases. |
| for (size_t childIndex : it->second.childIndices) |
| markAliasNonDeferrable(childIndex); |
| } |
| |
| template <typename T> |
| void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias, |
| bool canBeDeferred) { |
| SmallString<32> nameBuffer; |
| for (const auto &interface : interfaces) { |
| OpAsmDialectInterface::AliasResult result = |
| interface.getAlias(symbol, aliasOS); |
| if (result == OpAsmDialectInterface::AliasResult::NoAlias) |
| continue; |
| nameBuffer = std::move(aliasBuffer); |
| assert(!nameBuffer.empty() && "expected valid alias name"); |
| if (result == OpAsmDialectInterface::AliasResult::FinalAlias) |
| break; |
| } |
| |
| if (nameBuffer.empty()) |
| return; |
| |
| SmallString<16> tempBuffer; |
| StringRef name = |
| sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-", |
| /*allowTrailingDigit=*/false); |
| name = name.copy(aliasAllocator); |
| alias = InProgressAliasInfo(name); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AliasState |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This class manages the state for type and attribute aliases. |
| class AliasState { |
| public: |
| // Initialize the internal aliases. |
| void |
| initialize(Operation *op, const OpPrintingFlags &printerFlags, |
| DialectInterfaceCollection<OpAsmDialectInterface> &interfaces); |
| |
| /// Get an alias for the given attribute if it has one and print it in `os`. |
| /// Returns success if an alias was printed, failure otherwise. |
| LogicalResult getAlias(Attribute attr, raw_ostream &os) const; |
| |
| /// Get an alias for the given type if it has one and print it in `os`. |
| /// Returns success if an alias was printed, failure otherwise. |
| LogicalResult getAlias(Type ty, raw_ostream &os) const; |
| |
| /// Print all of the referenced aliases that can not be resolved in a deferred |
| /// manner. |
| void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { |
| printAliases(p, newLine, /*isDeferred=*/false); |
| } |
| |
| /// Print all of the referenced aliases that support deferred resolution. |
| void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { |
| printAliases(p, newLine, /*isDeferred=*/true); |
| } |
| |
| private: |
| /// Print all of the referenced aliases that support the provided resolution |
| /// behavior. |
| void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, |
| bool isDeferred); |
| |
| /// Mapping between attribute/type and alias. |
| llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias; |
| |
| /// An allocator used for alias names. |
| llvm::BumpPtrAllocator aliasAllocator; |
| }; |
| } // namespace |
| |
| void AliasState::initialize( |
| Operation *op, const OpPrintingFlags &printerFlags, |
| DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) { |
| AliasInitializer initializer(interfaces, aliasAllocator); |
| initializer.initialize(op, printerFlags, attrTypeToAlias); |
| } |
| |
| LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const { |
| const auto *it = attrTypeToAlias.find(attr.getAsOpaquePointer()); |
| if (it == attrTypeToAlias.end()) |
| return failure(); |
| it->second.print(os); |
| return success(); |
| } |
| |
| LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { |
| const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer()); |
| if (it == attrTypeToAlias.end()) |
| return failure(); |
| |
| it->second.print(os); |
| return success(); |
| } |
| |
| void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, |
| bool isDeferred) { |
| auto filterFn = [=](const auto &aliasIt) { |
| return aliasIt.second.canBeDeferred() == isDeferred; |
| }; |
| for (auto &[opaqueSymbol, alias] : |
| llvm::make_filter_range(attrTypeToAlias, filterFn)) { |
| alias.print(p.getStream()); |
| p.getStream() << " = "; |
| |
| if (alias.isTypeAlias()) { |
| // TODO: Support nested aliases in mutable types. |
| Type type = Type::getFromOpaquePointer(opaqueSymbol); |
| if (type.hasTrait<TypeTrait::IsMutable>()) |
| p.getStream() << type; |
| else |
| p.printTypeImpl(type); |
| } else { |
| // TODO: Support nested aliases in mutable attributes. |
| Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol); |
| if (attr.hasTrait<AttributeTrait::IsMutable>()) |
| p.getStream() << attr; |
| else |
| p.printAttributeImpl(attr); |
| } |
| |
| p.getStream() << newLine; |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SSANameState |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// Info about block printing: a number which is its position in the visitation |
| /// order, and a name that is used to print reference to it, e.g. ^bb42. |
| struct BlockInfo { |
| int ordering; |
| StringRef name; |
| }; |
| |
| /// This class manages the state of SSA value names. |
| class SSANameState { |
| public: |
| /// A sentinel value used for values with names set. |
| enum : unsigned { NameSentinel = ~0U }; |
| |
| SSANameState(Operation *op, const OpPrintingFlags &printerFlags); |
| SSANameState() = default; |
| |
| /// Print the SSA identifier for the given value to 'stream'. If |
| /// 'printResultNo' is true, it also presents the result number ('#' number) |
| /// of this value. |
| void printValueID(Value value, bool printResultNo, raw_ostream &stream) const; |
| |
| /// Print the operation identifier. |
| void printOperationID(Operation *op, raw_ostream &stream) const; |
| |
| /// Return the result indices for each of the result groups registered by this |
| /// operation, or empty if none exist. |
| ArrayRef<int> getOpResultGroups(Operation *op); |
| |
| /// Get the info for the given block. |
| BlockInfo getBlockInfo(Block *block); |
| |
| /// Renumber the arguments for the specified region to the same names as the |
| /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for |
| /// details. |
| void shadowRegionArgs(Region ®ion, ValueRange namesToUse); |
| |
| private: |
| /// Number the SSA values within the given IR unit. |
| void numberValuesInRegion(Region ®ion); |
| void numberValuesInBlock(Block &block); |
| void numberValuesInOp(Operation &op); |
| |
| /// Given a result of an operation 'result', find the result group head |
| /// 'lookupValue' and the result of 'result' within that group in |
| /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group |
| /// has more than 1 result. |
| void getResultIDAndNumber(OpResult result, Value &lookupValue, |
| std::optional<int> &lookupResultNo) const; |
| |
| /// Set a special value name for the given value. |
| void setValueName(Value value, StringRef name); |
| |
| /// Uniques the given value name within the printer. If the given name |
| /// conflicts, it is automatically renamed. |
| StringRef uniqueValueName(StringRef name); |
| |
| /// This is the value ID for each SSA value. If this returns NameSentinel, |
| /// then the valueID has an entry in valueNames. |
| DenseMap<Value, unsigned> valueIDs; |
| DenseMap<Value, StringRef> valueNames; |
| |
| /// When printing users of values, an operation without a result might |
| /// be the user. This map holds ids for such operations. |
| DenseMap<Operation *, unsigned> operationIDs; |
| |
| /// This is a map of operations that contain multiple named result groups, |
| /// i.e. there may be multiple names for the results of the operation. The |
| /// value of this map are the result numbers that start a result group. |
| DenseMap<Operation *, SmallVector<int, 1>> opResultGroups; |
| |
| /// This maps blocks to there visitation number in the current region as well |
| /// as the string representing their name. |
| DenseMap<Block *, BlockInfo> blockNames; |
| |
| /// This keeps track of all of the non-numeric names that are in flight, |
| /// allowing us to check for duplicates. |
| /// Note: the value of the map is unused. |
| llvm::ScopedHashTable<StringRef, char> usedNames; |
| llvm::BumpPtrAllocator usedNameAllocator; |
| |
| /// This is the next value ID to assign in numbering. |
| unsigned nextValueID = 0; |
| /// This is the next ID to assign to a region entry block argument. |
| unsigned nextArgumentID = 0; |
| /// This is the next ID to assign when a name conflict is detected. |
| unsigned nextConflictID = 0; |
| |
| /// These are the printing flags. They control, eg., whether to print in |
| /// generic form. |
| OpPrintingFlags printerFlags; |
| }; |
| } // namespace |
| |
| SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags) |
| : printerFlags(printerFlags) { |
| llvm::SaveAndRestore valueIDSaver(nextValueID); |
| llvm::SaveAndRestore argumentIDSaver(nextArgumentID); |
| llvm::SaveAndRestore conflictIDSaver(nextConflictID); |
| |
| // The naming context includes `nextValueID`, `nextArgumentID`, |
| // `nextConflictID` and `usedNames` scoped HashTable. This information is |
| // carried from the parent region. |
| using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy; |
| using NamingContext = |
| std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>; |
| |
| // Allocator for UsedNamesScopeTy |
| llvm::BumpPtrAllocator allocator; |
| |
| // Add a scope for the top level operation. |
| auto *topLevelNamesScope = |
| new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames); |
| |
| SmallVector<NamingContext, 8> nameContext; |
| for (Region ®ion : op->getRegions()) |
| nameContext.push_back(std::make_tuple(®ion, nextValueID, nextArgumentID, |
| nextConflictID, topLevelNamesScope)); |
| |
| numberValuesInOp(*op); |
| |
| while (!nameContext.empty()) { |
| Region *region; |
| UsedNamesScopeTy *parentScope; |
| |
| if (printerFlags.shouldPrintUniqueSSAIDs()) |
| // To print unique SSA IDs, ignore saved ID counts from parent regions |
| std::tie(region, std::ignore, std::ignore, std::ignore, parentScope) = |
| nameContext.pop_back_val(); |
| else |
| std::tie(region, nextValueID, nextArgumentID, nextConflictID, |
| parentScope) = nameContext.pop_back_val(); |
| |
| // When we switch from one subtree to another, pop the scopes(needless) |
| // until the parent scope. |
| while (usedNames.getCurScope() != parentScope) { |
| usedNames.getCurScope()->~UsedNamesScopeTy(); |
| assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) && |
| "top level parentScope must be a nullptr"); |
| } |
| |
| // Add a scope for the current region. |
| auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>()) |
| UsedNamesScopeTy(usedNames); |
| |
| numberValuesInRegion(*region); |
| |
| for (Operation &op : region->getOps()) |
| for (Region ®ion : op.getRegions()) |
| nameContext.push_back(std::make_tuple(®ion, nextValueID, |
| nextArgumentID, nextConflictID, |
| curNamesScope)); |
| } |
| |
| // Manually remove all the scopes. |
| while (usedNames.getCurScope() != nullptr) |
| usedNames.getCurScope()->~UsedNamesScopeTy(); |
| } |
| |
| void SSANameState::printValueID(Value value, bool printResultNo, |
| raw_ostream &stream) const { |
| if (!value) { |
| stream << "<<NULL VALUE>>"; |
| return; |
| } |
| |
| std::optional<int> resultNo; |
| auto lookupValue = value; |
| |
| // If this is an operation result, collect the head lookup value of the result |
| // group and the result number of 'result' within that group. |
| if (OpResult result = dyn_cast<OpResult>(value)) |
| getResultIDAndNumber(result, lookupValue, resultNo); |
| |
| auto it = valueIDs.find(lookupValue); |
| if (it == valueIDs.end()) { |
| stream << "<<UNKNOWN SSA VALUE>>"; |
| return; |
| } |
| |
| stream << '%'; |
| if (it->second != NameSentinel) { |
| stream << it->second; |
| } else { |
| auto nameIt = valueNames.find(lookupValue); |
| assert(nameIt != valueNames.end() && "Didn't have a name entry?"); |
| stream << nameIt->second; |
| } |
| |
| if (resultNo && printResultNo) |
| stream << '#' << *resultNo; |
| } |
| |
| void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const { |
| auto it = operationIDs.find(op); |
| if (it == operationIDs.end()) { |
| stream << "<<UNKNOWN OPERATION>>"; |
| } else { |
| stream << '%' << it->second; |
| } |
| } |
| |
| ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) { |
| auto it = opResultGroups.find(op); |
| return it == opResultGroups.end() ? ArrayRef<int>() : it->second; |
| } |
| |
| BlockInfo SSANameState::getBlockInfo(Block *block) { |
| auto it = blockNames.find(block); |
| BlockInfo invalidBlock{-1, "INVALIDBLOCK"}; |
| return it != blockNames.end() ? it->second : invalidBlock; |
| } |
| |
| void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { |
| assert(!region.empty() && "cannot shadow arguments of an empty region"); |
| assert(region.getNumArguments() == namesToUse.size() && |
| "incorrect number of names passed in"); |
| assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && |
| "only KnownIsolatedFromAbove ops can shadow names"); |
| |
| SmallVector<char, 16> nameStr; |
| for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { |
| auto nameToUse = namesToUse[i]; |
| if (nameToUse == nullptr) |
| continue; |
| auto nameToReplace = region.getArgument(i); |
| |
| nameStr.clear(); |
| llvm::raw_svector_ostream nameStream(nameStr); |
| printValueID(nameToUse, /*printResultNo=*/true, nameStream); |
| |
| // Entry block arguments should already have a pretty "arg" name. |
| assert(valueIDs[nameToReplace] == NameSentinel); |
| |
| // Use the name without the leading %. |
| auto name = StringRef(nameStream.str()).drop_front(); |
| |
| // Overwrite the name. |
| valueNames[nameToReplace] = name.copy(usedNameAllocator); |
| } |
| } |
| |
| void SSANameState::numberValuesInRegion(Region ®ion) { |
| auto setBlockArgNameFn = [&](Value arg, StringRef name) { |
| assert(!valueIDs.count(arg) && "arg numbered multiple times"); |
| assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == ®ion && |
| "arg not defined in current region"); |
| setValueName(arg, name); |
| }; |
| |
| if (!printerFlags.shouldPrintGenericOpForm()) { |
| if (Operation *op = region.getParentOp()) { |
| if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op)) |
| asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); |
| } |
| } |
| |
| // Number the values within this region in a breadth-first order. |
| unsigned nextBlockID = 0; |
| for (auto &block : region) { |
| // Each block gets a unique ID, and all of the operations within it get |
| // numbered as well. |
| auto blockInfoIt = blockNames.insert({&block, {-1, ""}}); |
| if (blockInfoIt.second) { |
| // This block hasn't been named through `getAsmBlockArgumentNames`, use |
| // default `^bbNNN` format. |
| std::string name; |
| llvm::raw_string_ostream(name) << "^bb" << nextBlockID; |
| blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator); |
| } |
| blockInfoIt.first->second.ordering = nextBlockID++; |
| |
| numberValuesInBlock(block); |
| } |
| } |
| |
| void SSANameState::numberValuesInBlock(Block &block) { |
| // Number the block arguments. We give entry block arguments a special name |
| // 'arg'. |
| bool isEntryBlock = block.isEntryBlock(); |
| SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); |
| llvm::raw_svector_ostream specialName(specialNameBuffer); |
| for (auto arg : block.getArguments()) { |
| if (valueIDs.count(arg)) |
| continue; |
| if (isEntryBlock) { |
| specialNameBuffer.resize(strlen("arg")); |
| specialName << nextArgumentID++; |
| } |
| setValueName(arg, specialName.str()); |
| } |
| |
| // Number the operations in this block. |
| for (auto &op : block) |
| numberValuesInOp(op); |
| } |
| |
| void SSANameState::numberValuesInOp(Operation &op) { |
| // Function used to set the special result names for the operation. |
| SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0); |
| auto setResultNameFn = [&](Value result, StringRef name) { |
| assert(!valueIDs.count(result) && "result numbered multiple times"); |
| assert(result.getDefiningOp() == &op && "result not defined by 'op'"); |
| setValueName(result, name); |
| |
| // Record the result number for groups not anchored at 0. |
| if (int resultNo = llvm::cast<OpResult>(result).getResultNumber()) |
| resultGroups.push_back(resultNo); |
| }; |
| // Operations can customize the printing of block names in OpAsmOpInterface. |
| auto setBlockNameFn = [&](Block *block, StringRef name) { |
| assert(block->getParentOp() == &op && |
| "getAsmBlockArgumentNames callback invoked on a block not directly " |
| "nested under the current operation"); |
| assert(!blockNames.count(block) && "block numbered multiple times"); |
| SmallString<16> tmpBuffer{"^"}; |
| name = sanitizeIdentifier(name, tmpBuffer); |
| if (name.data() != tmpBuffer.data()) { |
| tmpBuffer.append(name); |
| name = tmpBuffer.str(); |
| } |
| name = name.copy(usedNameAllocator); |
| blockNames[block] = {-1, name}; |
| }; |
| |
| if (!printerFlags.shouldPrintGenericOpForm()) { |
| if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) { |
| asmInterface.getAsmBlockNames(setBlockNameFn); |
| asmInterface.getAsmResultNames(setResultNameFn); |
| } |
| } |
| |
| unsigned numResults = op.getNumResults(); |
| if (numResults == 0) { |
| // If value users should be printed, operations with no result need an id. |
| if (printerFlags.shouldPrintValueUsers()) { |
| if (operationIDs.try_emplace(&op, nextValueID).second) |
| ++nextValueID; |
| } |
| return; |
| } |
| Value resultBegin = op.getResult(0); |
| |
| // If the first result wasn't numbered, give it a default number. |
| if (valueIDs.try_emplace(resultBegin, nextValueID).second) |
| ++nextValueID; |
| |
| // If this operation has multiple result groups, mark it. |
| if (resultGroups.size() != 1) { |
| llvm::array_pod_sort(resultGroups.begin(), resultGroups.end()); |
| opResultGroups.try_emplace(&op, std::move(resultGroups)); |
| } |
| } |
| |
| void SSANameState::getResultIDAndNumber( |
| OpResult result, Value &lookupValue, |
| std::optional<int> &lookupResultNo) const { |
| Operation *owner = result.getOwner(); |
| if (owner->getNumResults() == 1) |
| return; |
| int resultNo = result.getResultNumber(); |
| |
| // If this operation has multiple result groups, we will need to find the |
| // one corresponding to this result. |
| auto resultGroupIt = opResultGroups.find(owner); |
| if (resultGroupIt == opResultGroups.end()) { |
| // If not, just use the first result. |
| lookupResultNo = resultNo; |
| lookupValue = owner->getResult(0); |
| return; |
| } |
| |
| // Find the correct index using a binary search, as the groups are ordered. |
| ArrayRef<int> resultGroups = resultGroupIt->second; |
| const auto *it = llvm::upper_bound(resultGroups, resultNo); |
| int groupResultNo = 0, groupSize = 0; |
| |
| // If there are no smaller elements, the last result group is the lookup. |
| if (it == resultGroups.end()) { |
| groupResultNo = resultGroups.back(); |
| groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back(); |
| } else { |
| // Otherwise, the previous element is the lookup. |
| groupResultNo = *std::prev(it); |
| groupSize = *it - groupResultNo; |
| } |
| |
| // We only record the result number for a group of size greater than 1. |
| if (groupSize != 1) |
| lookupResultNo = resultNo - groupResultNo; |
| lookupValue = owner->getResult(groupResultNo); |
| } |
| |
| void SSANameState::setValueName(Value value, StringRef name) { |
| // If the name is empty, the value uses the default numbering. |
| if (name.empty()) { |
| valueIDs[value] = nextValueID++; |
| return; |
| } |
| |
| valueIDs[value] = NameSentinel; |
| valueNames[value] = uniqueValueName(name); |
| } |
| |
| StringRef SSANameState::uniqueValueName(StringRef name) { |
| SmallString<16> tmpBuffer; |
| name = sanitizeIdentifier(name, tmpBuffer); |
| |
| // Check to see if this name is already unique. |
| if (!usedNames.count(name)) { |
| name = name.copy(usedNameAllocator); |
| } else { |
| // Otherwise, we had a conflict - probe until we find a unique name. This |
| // is guaranteed to terminate (and usually in a single iteration) because it |
| // generates new names by incrementing nextConflictID. |
| SmallString<64> probeName(name); |
| probeName.push_back('_'); |
| while (true) { |
| probeName += llvm::utostr(nextConflictID++); |
| if (!usedNames.count(probeName)) { |
| name = probeName.str().copy(usedNameAllocator); |
| break; |
| } |
| probeName.resize(name.size() + 1); |
| } |
| } |
| |
| usedNames.insert(name, char()); |
| return name; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DistinctState |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This class manages the state for distinct attributes. |
| class DistinctState { |
| public: |
| /// Returns a unique identifier for the given distinct attribute. |
| uint64_t getId(DistinctAttr distinctAttr); |
| |
| private: |
| uint64_t distinctCounter = 0; |
| DenseMap<DistinctAttr, uint64_t> distinctAttrMap; |
| }; |
| } // namespace |
| |
| uint64_t DistinctState::getId(DistinctAttr distinctAttr) { |
| auto [it, inserted] = |
| distinctAttrMap.try_emplace(distinctAttr, distinctCounter); |
| if (inserted) |
| distinctCounter++; |
| return it->getSecond(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Resources |
| //===----------------------------------------------------------------------===// |
| |
| AsmParsedResourceEntry::~AsmParsedResourceEntry() = default; |
| AsmResourceBuilder::~AsmResourceBuilder() = default; |
| AsmResourceParser::~AsmResourceParser() = default; |
| AsmResourcePrinter::~AsmResourcePrinter() = default; |
| |
| StringRef mlir::toString(AsmResourceEntryKind kind) { |
| switch (kind) { |
| case AsmResourceEntryKind::Blob: |
| return "blob"; |
| case AsmResourceEntryKind::Bool: |
| return "bool"; |
| case AsmResourceEntryKind::String: |
| return "string"; |
| } |
| llvm_unreachable("unknown AsmResourceEntryKind"); |
| } |
| |
| AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) { |
| std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()]; |
| if (!collection) |
| collection = std::make_unique<ResourceCollection>(key); |
| return *collection; |
| } |
| |
| std::vector<std::unique_ptr<AsmResourcePrinter>> |
| FallbackAsmResourceMap::getPrinters() { |
| std::vector<std::unique_ptr<AsmResourcePrinter>> printers; |
| for (auto &it : keyToResources) { |
| ResourceCollection *collection = it.second.get(); |
| auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) { |
| return collection->buildResources(op, builder); |
| }; |
| printers.emplace_back( |
| AsmResourcePrinter::fromCallable(collection->getName(), buildValues)); |
| } |
| return printers; |
| } |
| |
| LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource( |
| AsmParsedResourceEntry &entry) { |
| switch (entry.getKind()) { |
| case AsmResourceEntryKind::Blob: { |
| FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(); |
| if (failed(blob)) |
| return failure(); |
| resources.emplace_back(entry.getKey(), std::move(*blob)); |
| return success(); |
| } |
| case AsmResourceEntryKind::Bool: { |
| FailureOr<bool> value = entry.parseAsBool(); |
| if (failed(value)) |
| return failure(); |
| resources.emplace_back(entry.getKey(), *value); |
| break; |
| } |
| case AsmResourceEntryKind::String: { |
| FailureOr<std::string> str = entry.parseAsString(); |
| if (failed(str)) |
| return failure(); |
| resources.emplace_back(entry.getKey(), std::move(*str)); |
| break; |
| } |
| } |
| return success(); |
| } |
| |
| void FallbackAsmResourceMap::ResourceCollection::buildResources( |
| Operation *op, AsmResourceBuilder &builder) const { |
| for (const auto &entry : resources) { |
| if (const auto *value = std::get_if<AsmResourceBlob>(&entry.value)) |
| builder.buildBlob(entry.key, *value); |
| else if (const auto *value = std::get_if<bool>(&entry.value)) |
| builder.buildBool(entry.key, *value); |
| else if (const auto *value = std::get_if<std::string>(&entry.value)) |
| builder.buildString(entry.key, *value); |
| else |
| llvm_unreachable("unknown AsmResourceEntryKind"); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AsmState |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| namespace detail { |
| class AsmStateImpl { |
| public: |
| explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, |
| AsmState::LocationMap *locationMap) |
| : interfaces(op->getContext()), nameState(op, printerFlags), |
| printerFlags(printerFlags), locationMap(locationMap) {} |
| explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, |
| AsmState::LocationMap *locationMap) |
| : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} |
| |
| /// Initialize the alias state to enable the printing of aliases. |
| void initializeAliases(Operation *op) { |
| aliasState.initialize(op, printerFlags, interfaces); |
| } |
| |
| /// Get the state used for aliases. |
| AliasState &getAliasState() { return aliasState; } |
| |
| /// Get the state used for SSA names. |
| SSANameState &getSSANameState() { return nameState; } |
| |
| /// Get the state used for distinct attribute identifiers. |
| DistinctState &getDistinctState() { return distinctState; } |
| |
| /// Return the dialects within the context that implement |
| /// OpAsmDialectInterface. |
| DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() { |
| return interfaces; |
| } |
| |
| /// Return the non-dialect resource printers. |
| auto getResourcePrinters() { |
| return llvm::make_pointee_range(externalResourcePrinters); |
| } |
| |
| /// Get the printer flags. |
| const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } |
| |
| /// Register the location, line and column, within the buffer that the given |
| /// operation was printed at. |
| void registerOperationLocation(Operation *op, unsigned line, unsigned col) { |
| if (locationMap) |
| (*locationMap)[op] = std::make_pair(line, col); |
| } |
| |
| /// Return the referenced dialect resources within the printer. |
| DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> & |
| getDialectResources() { |
| return dialectResources; |
| } |
| |
| LogicalResult pushCyclicPrinting(const void *opaquePointer) { |
| return success(cyclicPrintingStack.insert(opaquePointer)); |
| } |
| |
| void popCyclicPrinting() { cyclicPrintingStack.pop_back(); } |
| |
| private: |
| /// Collection of OpAsm interfaces implemented in the context. |
| DialectInterfaceCollection<OpAsmDialectInterface> interfaces; |
| |
| /// A collection of non-dialect resource printers. |
| SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters; |
| |
| /// A set of dialect resources that were referenced during printing. |
| DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources; |
| |
| /// The state used for attribute and type aliases. |
| AliasState aliasState; |
| |
| /// The state used for SSA value names. |
| SSANameState nameState; |
| |
| /// The state used for distinct attribute identifiers. |
| DistinctState distinctState; |
| |
| /// Flags that control op output. |
| OpPrintingFlags printerFlags; |
| |
| /// An optional location map to be populated. |
| AsmState::LocationMap *locationMap; |
| |
| /// Stack of potentially cyclic mutable attributes or type currently being |
| /// printed. |
| SetVector<const void *> cyclicPrintingStack; |
| |
| // Allow direct access to the impl fields. |
| friend AsmState; |
| }; |
| |
| template <typename Range> |
| void printDimensionList(raw_ostream &stream, Range &&shape) { |
| llvm::interleave( |
| shape, stream, |
| [&stream](const auto &dimSize) { |
| if (ShapedType::isDynamic(dimSize)) |
| stream << "?"; |
| else |
| stream << dimSize; |
| }, |
| "x"); |
| } |
| |
| } // namespace detail |
| } // namespace mlir |
| |
| /// Verifies the operation and switches to generic op printing if verification |
| /// fails. We need to do this because custom print functions may fail for |
| /// invalid ops. |
| static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, |
| OpPrintingFlags printerFlags) { |
| if (printerFlags.shouldPrintGenericOpForm() || |
| printerFlags.shouldAssumeVerified()) |
| return printerFlags; |
| |
| // Ignore errors emitted by the verifier. We check the thread id to avoid |
| // consuming other threads' errors. |
| auto parentThreadId = llvm::get_threadid(); |
| ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) { |
| if (parentThreadId == llvm::get_threadid()) { |
| LLVM_DEBUG({ |
| diag.print(llvm::dbgs()); |
| llvm::dbgs() << "\n"; |
| }); |
| return success(); |
| } |
| return failure(); |
| }); |
| if (failed(verify(op))) { |
| LLVM_DEBUG(llvm::dbgs() |
| << DEBUG_TYPE << ": '" << op->getName() |
| << "' failed to verify and will be printed in generic form\n"); |
| printerFlags.printGenericOpForm(); |
| } |
| |
| return printerFlags; |
| } |
| |
| AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, |
| LocationMap *locationMap, FallbackAsmResourceMap *map) |
| : impl(std::make_unique<AsmStateImpl>( |
| op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) { |
| if (map) |
| attachFallbackResourcePrinter(*map); |
| } |
| AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags, |
| LocationMap *locationMap, FallbackAsmResourceMap *map) |
| : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) { |
| if (map) |
| attachFallbackResourcePrinter(*map); |
| } |
| AsmState::~AsmState() = default; |
| |
| const OpPrintingFlags &AsmState::getPrinterFlags() const { |
| return impl->getPrinterFlags(); |
| } |
| |
| void AsmState::attachResourcePrinter( |
| std::unique_ptr<AsmResourcePrinter> printer) { |
| impl->externalResourcePrinters.emplace_back(std::move(printer)); |
| } |
| |
| DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> & |
| AsmState::getDialectResources() const { |
| return impl->getDialectResources(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AsmPrinter::Impl |
| //===----------------------------------------------------------------------===// |
| |
| AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state) |
| : os(os), state(state), printerFlags(state.getPrinterFlags()) {} |
| |
| void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) { |
| // Check to see if we are printing debug information. |
| if (!printerFlags.shouldPrintDebugInfo()) |
| return; |
| |
| os << " "; |
| printLocation(loc, /*allowAlias=*/allowAlias); |
| } |
| |
| void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty, |
| bool isTopLevel) { |
| // If this isn't a top-level location, check for an alias. |
| if (!isTopLevel && succeeded(state.getAliasState().getAlias(loc, os))) |
| return; |
| |
| TypeSwitch<LocationAttr>(loc) |
| .Case<OpaqueLoc>([&](OpaqueLoc loc) { |
| printLocationInternal(loc.getFallbackLocation(), pretty); |
| }) |
| .Case<UnknownLoc>([&](UnknownLoc loc) { |
| if (pretty) |
| os << "[unknown]"; |
| else |
| os << "unknown"; |
| }) |
| .Case<FileLineColLoc>([&](FileLineColLoc loc) { |
| if (pretty) |
| os << loc.getFilename().getValue(); |
| else |
| printEscapedString(loc.getFilename()); |
| os << ':' << loc.getLine() << ':' << loc.getColumn(); |
| }) |
| .Case<NameLoc>([&](NameLoc loc) { |
| printEscapedString(loc.getName()); |
| |
| // Print the child if it isn't unknown. |
| auto childLoc = loc.getChildLoc(); |
| if (!llvm::isa<UnknownLoc>(childLoc)) { |
| os << '('; |
| printLocationInternal(childLoc, pretty); |
| os << ')'; |
| } |
| }) |
| .Case<CallSiteLoc>([&](CallSiteLoc loc) { |
| Location caller = loc.getCaller(); |
| Location callee = loc.getCallee(); |
| if (!pretty) |
| os << "callsite("; |
| printLocationInternal(callee, pretty); |
| if (pretty) { |
| if (llvm::isa<NameLoc>(callee)) { |
| if (llvm::isa<FileLineColLoc>(caller)) { |
| os << " at "; |
| } else { |
| os << newLine << " at "; |
| } |
| } else { |
| os << newLine << " at "; |
| } |
| } else { |
| os << " at "; |
| } |
| printLocationInternal(caller, pretty); |
| if (!pretty) |
| os << ")"; |
| }) |
| .Case<FusedLoc>([&](FusedLoc loc) { |
| if (!pretty) |
| os << "fused"; |
| if (Attribute metadata = loc.getMetadata()) { |
| os << '<'; |
| printAttribute(metadata); |
| os << '>'; |
| } |
| os << '['; |
| interleave( |
| loc.getLocations(), |
| [&](Location loc) { printLocationInternal(loc, pretty); }, |
| [&]() { os << ", "; }); |
| os << ']'; |
| }); |
| } |
| |
| /// Print a floating point value in a way that the parser will be able to |
| /// round-trip losslessly. |
| static void printFloatValue(const APFloat &apValue, raw_ostream &os, |
| bool *printedHex = nullptr) { |
| // We would like to output the FP constant value in exponential notation, |
| // but we cannot do this if doing so will lose precision. Check here to |
| // make sure that we only output it in exponential format if we can parse |
| // the value back and get the same value. |
| bool isInf = apValue.isInfinity(); |
| bool isNaN = apValue.isNaN(); |
| if (!isInf && !isNaN) { |
| SmallString<128> strValue; |
| apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, |
| /*TruncateZero=*/false); |
| |
| // Check to make sure that the stringized number is not some string like |
| // "Inf" or NaN, that atof will accept, but the lexer will not. Check |
| // that the string matches the "[-+]?[0-9]" regex. |
| assert(((strValue[0] >= '0' && strValue[0] <= '9') || |
| ((strValue[0] == '-' || strValue[0] == '+') && |
| (strValue[1] >= '0' && strValue[1] <= '9'))) && |
| "[-+]?[0-9] regex does not match!"); |
| |
| // Parse back the stringized version and check that the value is equal |
| // (i.e., there is no precision loss). |
| if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { |
| os << strValue; |
| return; |
| } |
| |
| // If it is not, use the default format of APFloat instead of the |
| // exponential notation. |
| strValue.clear(); |
| apValue.toString(strValue); |
| |
| // Make sure that we can parse the default form as a float. |
| if (strValue.str().contains('.')) { |
| os << strValue; |
| return; |
| } |
| } |
| |
| // Print special values in hexadecimal format. The sign bit should be included |
| // in the literal. |
| if (printedHex) |
| *printedHex = true; |
| SmallVector<char, 16> str; |
| APInt apInt = apValue.bitcastToAPInt(); |
| apInt.toString(str, /*Radix=*/16, /*Signed=*/false, |
| /*formatAsCLiteral=*/true); |
| os << str; |
| } |
| |
| void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { |
| if (printerFlags.shouldPrintDebugInfoPrettyForm()) |
| return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true); |
| |
| os << "loc("; |
| if (!allowAlias || failed(printAlias(loc))) |
| printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true); |
| os << ')'; |
| } |
| |
| void AsmPrinter::Impl::printResourceHandle( |
| const AsmDialectResourceHandle &resource) { |
| auto *interface = cast<OpAsmDialectInterface>(resource.getDialect()); |
| os << interface->getResourceKey(resource); |
| state.getDialectResources()[resource.getDialect()].insert(resource); |
| } |
| |
| /// Returns true if the given dialect symbol data is simple enough to print in |
| /// the pretty form. This is essentially when the symbol takes the form: |
| /// identifier (`<` body `>`)? |
| static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { |
| // The name must start with an identifier. |
| if (symName.empty() || !isalpha(symName.front())) |
| return false; |
| |
| // Ignore all the characters that are valid in an identifier in the symbol |
| // name. |
| symName = symName.drop_while( |
| [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; }); |
| if (symName.empty()) |
| return true; |
| |
| // If we got to an unexpected character, then it must be a <>. Check that the |
| // rest of the symbol is wrapped within <>. |
| return symName.front() == '<' && symName.back() == '>'; |
| } |
| |
| /// Print the given dialect symbol to the stream. |
| static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, |
| StringRef dialectName, StringRef symString) { |
| os << symPrefix << dialectName; |
| |
| // If this symbol name is simple enough, print it directly in pretty form, |
| // otherwise, we print it as an escaped string. |
| if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { |
| os << '.' << symString; |
| return; |
| } |
| |
| os << '<' << symString << '>'; |
| } |
| |
| /// Returns true if the given string can be represented as a bare identifier. |
| static bool isBareIdentifier(StringRef name) { |
| // By making this unsigned, the value passed in to isalnum will always be |
| // in the range 0-255. This is important when building with MSVC because |
| // its implementation will assert. This situation can arise when dealing |
| // with UTF-8 multibyte characters. |
| if (name.empty() || (!isalpha(name[0]) && name[0] != '_')) |
| return false; |
| return llvm::all_of(name.drop_front(), [](unsigned char c) { |
| return isalnum(c) || c == '_' || c == '$' || c == '.'; |
| }); |
| } |
| |
| /// Print the given string as a keyword, or a quoted and escaped string if it |
| /// has any special or non-printable characters in it. |
| static void printKeywordOrString(StringRef keyword, raw_ostream &os) { |
| // If it can be represented as a bare identifier, write it directly. |
| if (isBareIdentifier(keyword)) { |
| os << keyword; |
| return; |
| } |
| |
| // Otherwise, output the keyword wrapped in quotes with proper escaping. |
| os << "\""; |
| printEscapedString(keyword, os); |
| os << '"'; |
| } |
| |
| /// Print the given string as a symbol reference. A symbol reference is |
| /// represented as a string prefixed with '@'. The reference is surrounded with |
| /// ""'s and escaped if it has any special or non-printable characters in it. |
| static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { |
| if (symbolRef.empty()) { |
| os << "@<<INVALID EMPTY SYMBOL>>"; |
| return; |
| } |
| os << '@'; |
| printKeywordOrString(symbolRef, os); |
| } |
| |
| // Print out a valid ElementsAttr that is succinct and can represent any |
| // potential shape/type, for use when eliding a large ElementsAttr. |
| // |
| // We choose to use a dense resource ElementsAttr literal with conspicuous |
| // content to hopefully alert readers to the fact that this has been elided. |
| static void printElidedElementsAttr(raw_ostream &os) { |
| os << R"(dense_resource<__elided__>)"; |
| } |
| |
| LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { |
| return state.getAliasState().getAlias(attr, os); |
| } |
| |
| LogicalResult AsmPrinter::Impl::printAlias(Type type) { |
| return state.getAliasState().getAlias(type, os); |
| } |
| |
| void AsmPrinter::Impl::printAttribute(Attribute attr, |
| AttrTypeElision typeElision) { |
| if (!attr) { |
| os << "<<NULL ATTRIBUTE>>"; |
| return; |
| } |
| |
| // Try to print an alias for this attribute. |
| if (succeeded(printAlias(attr))) |
| return; |
| return printAttributeImpl(attr, typeElision); |
| } |
| |
| void AsmPrinter::Impl::printAttributeImpl(Attribute attr, |
| AttrTypeElision typeElision) { |
| if (!isa<BuiltinDialect>(attr.getDialect())) { |
| printDialectAttribute(attr); |
| } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) { |
| printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), |
| opaqueAttr.getAttrData()); |
| } else if (llvm::isa<UnitAttr>(attr)) { |
| os << "unit"; |
| return; |
| } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) { |
| os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<"; |
| if (!llvm::isa<UnitAttr>(distinctAttr.getReferencedAttr())) { |
| printAttribute(distinctAttr.getReferencedAttr()); |
| } |
| os << '>'; |
| return; |
| } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) { |
| os << '{'; |
| interleaveComma(dictAttr.getValue(), |
| [&](NamedAttribute attr) { printNamedAttribute(attr); }); |
| os << '}'; |
| |
| } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) { |
| Type intType = intAttr.getType(); |
| if (intType.isSignlessInteger(1)) { |
| os << (intAttr.getValue().getBoolValue() ? "true" : "false"); |
| |
| // Boolean integer attributes always elides the type. |
| return; |
| } |
| |
| // Only print attributes as unsigned if they are explicitly unsigned or are |
| // signless 1-bit values. Indexes, signed values, and multi-bit signless |
| // values print as signed. |
| bool isUnsigned = |
| intType.isUnsignedInteger() || intType.isSignlessInteger(1); |
| intAttr.getValue().print(os, !isUnsigned); |
| |
| // IntegerAttr elides the type if I64. |
| if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64)) |
| return; |
| |
| } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) { |
| bool printedHex = false; |
| printFloatValue(floatAttr.getValue(), os, &printedHex); |
| |
| // FloatAttr elides the type if F64. |
| if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64() && |
| !printedHex) |
| return; |
| |
| } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) { |
| printEscapedString(strAttr.getValue()); |
| |
| } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) { |
| os << '['; |
| interleaveComma(arrayAttr.getValue(), [&](Attribute attr) { |
| printAttribute(attr, AttrTypeElision::May); |
| }); |
| os << ']'; |
| |
| } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) { |
| os << "affine_map<"; |
| affineMapAttr.getValue().print(os); |
| os << '>'; |
| |
| // AffineMap always elides the type. |
| return; |
| |
| } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) { |
| os << "affine_set<"; |
| integerSetAttr.getValue().print(os); |
| os << '>'; |
| |
| // IntegerSet always elides the type. |
| return; |
| |
| } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) { |
| printType(typeAttr.getValue()); |
| |
| } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) { |
| printSymbolReference(refAttr.getRootReference().getValue(), os); |
| for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { |
| os << "::"; |
| printSymbolReference(nestedRef.getValue(), os); |
| } |
| |
| } else if (auto intOrFpEltAttr = |
| llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) { |
| if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) { |
| printElidedElementsAttr(os); |
| } else { |
| os << "dense<"; |
| printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true); |
| os << '>'; |
| } |
| |
| } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) { |
| if (printerFlags.shouldElideElementsAttr(strEltAttr)) { |
| printElidedElementsAttr(os); |
| } else { |
| os << "dense<"; |
| printDenseStringElementsAttr(strEltAttr); |
| os << '>'; |
| } |
| |
| } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) { |
| if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) || |
| printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) { |
| printElidedElementsAttr(os); |
| } else { |
| os << "sparse<"; |
| DenseIntElementsAttr indices = sparseEltAttr.getIndices(); |
| if (indices.getNumElements() != 0) { |
| printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false); |
| os << ", "; |
| printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true); |
| } |
| os << '>'; |
| } |
| } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) { |
| stridedLayoutAttr.print(os); |
| } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) { |
| os << "array<"; |
| printType(denseArrayAttr.getElementType()); |
| if (!denseArrayAttr.empty()) { |
| os << ": "; |
| printDenseArrayAttr(denseArrayAttr); |
| } |
| os << ">"; |
| return; |
| } else if (auto resourceAttr = |
| llvm::dyn_cast<DenseResourceElementsAttr>(attr)) { |
| os << "dense_resource<"; |
| printResourceHandle(resourceAttr.getRawHandle()); |
| os << ">"; |
| } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(attr)) { |
| printLocation(locAttr); |
| } else { |
| llvm::report_fatal_error("Unknown builtin attribute"); |
| } |
| // Don't print the type if we must elide it, or if it is a None type. |
| if (typeElision != AttrTypeElision::Must) { |
| if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) { |
| Type attrType = typedAttr.getType(); |
| if (!llvm::isa<NoneType>(attrType)) { |
| os << " : "; |
| printType(attrType); |
| } |
| } |
| } |
| } |
| |
| /// Print the integer element of a DenseElementsAttr. |
| static void printDenseIntElement(const APInt &value, raw_ostream &os, |
| Type type) { |
| if (type.isInteger(1)) |
| os << (value.getBoolValue() ? "true" : "false"); |
| else |
| value.print(os, !type.isUnsignedInteger()); |
| } |
| |
| static void |
| printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os, |
| function_ref<void(unsigned)> printEltFn) { |
| // Special case for 0-d and splat tensors. |
| if (isSplat) |
| return printEltFn(0); |
| |
| // Special case for degenerate tensors. |
| auto numElements = type.getNumElements(); |
| if (numElements == 0) |
| return; |
| |
| // We use a mixed-radix counter to iterate through the shape. When we bump a |
| // non-least-significant digit, we emit a close bracket. When we next emit an |
| // element we re-open all closed brackets. |
| |
| // The mixed-radix counter, with radices in 'shape'. |
| int64_t rank = type.getRank(); |
| SmallVector<unsigned, 4> counter(rank, 0); |
| // The number of brackets that have been opened and not closed. |
| unsigned openBrackets = 0; |
| |
| auto shape = type.getShape(); |
| auto bumpCounter = [&] { |
| // Bump the least significant digit. |
| ++counter[rank - 1]; |
| // Iterate backwards bubbling back the increment. |
| for (unsigned i = rank - 1; i > 0; --i) |
| if (counter[i] >= shape[i]) { |
| // Index 'i' is rolled over. Bump (i-1) and close a bracket. |
| counter[i] = 0; |
| ++counter[i - 1]; |
| --openBrackets; |
| os << ']'; |
| } |
| }; |
| |
| for (unsigned idx = 0, e = numElements; idx != e; ++idx) { |
| if (idx != 0) |
| os << ", "; |
| while (openBrackets++ < rank) |
| os << '['; |
| openBrackets = rank; |
| printEltFn(idx); |
| bumpCounter(); |
| } |
| while (openBrackets-- > 0) |
| os << ']'; |
| } |
| |
| void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr, |
| bool allowHex) { |
| if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) |
| return printDenseStringElementsAttr(stringAttr); |
| |
| printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr), |
| allowHex); |
| } |
| |
| void AsmPrinter::Impl::printDenseIntOrFPElementsAttr( |
| DenseIntOrFPElementsAttr attr, bool allowHex) { |
| auto type = attr.getType(); |
| auto elementType = type.getElementType(); |
| |
| // Check to see if we should format this attribute as a hex string. |
| if (allowHex && printerFlags.shouldPrintElementsAttrWithHex(attr)) { |
| ArrayRef<char> rawData = attr.getRawData(); |
| if (llvm::endianness::native == llvm::endianness::big) { |
| // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE |
| // machines. It is converted here to print in LE format. |
| SmallVector<char, 64> outDataVec(rawData.size()); |
| MutableArrayRef<char> convRawData(outDataVec); |
| DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( |
| rawData, convRawData, type); |
| printHexString(convRawData); |
| } else { |
| printHexString(rawData); |
| } |
| |
| return; |
| } |
| |
| if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) { |
| Type complexElementType = complexTy.getElementType(); |
| // Note: The if and else below had a common lambda function which invoked |
| // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2 |
| // and hence was replaced. |
| if (llvm::isa<IntegerType>(complexElementType)) { |
| auto valueIt = attr.value_begin<std::complex<APInt>>(); |
| printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
| auto complexValue = *(valueIt + index); |
| os << "("; |
| printDenseIntElement(complexValue.real(), os, complexElementType); |
| os << ","; |
| printDenseIntElement(complexValue.imag(), os, complexElementType); |
| os << ")"; |
| }); |
| } else { |
| auto valueIt = attr.value_begin<std::complex<APFloat>>(); |
| printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
| auto complexValue = *(valueIt + index); |
| os << "("; |
| printFloatValue(complexValue.real(), os); |
| os << ","; |
| printFloatValue(complexValue.imag(), os); |
| os << ")"; |
| }); |
| } |
| } else if (elementType.isIntOrIndex()) { |
| auto valueIt = attr.value_begin<APInt>(); |
| printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
| printDenseIntElement(*(valueIt + index), os, elementType); |
| }); |
| } else { |
| assert(llvm::isa<FloatType>(elementType) && "unexpected element type"); |
| auto valueIt = attr.value_begin<APFloat>(); |
| printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { |
| printFloatValue(*(valueIt + index), os); |
| }); |
| } |
| } |
| |
| void AsmPrinter::Impl::printDenseStringElementsAttr( |
| DenseStringElementsAttr attr) { |
| ArrayRef<StringRef> data = attr.getRawStringData(); |
| auto printFn = [&](unsigned index) { printEscapedString(data[index]); }; |
| printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); |
| } |
| |
| void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) { |
| Type type = attr.getElementType(); |
| unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth(); |
| unsigned byteSize = bitwidth / 8; |
| ArrayRef<char> data = attr.getRawData(); |
| |
| auto printElementAt = [&](unsigned i) { |
| APInt value(bitwidth, 0); |
| if (bitwidth) { |
| llvm::LoadIntFromMemory( |
| value, reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i), |
| byteSize); |
| } |
| // Print the data as-is or as a float. |
| if (type.isIntOrIndex()) { |
| printDenseIntElement(value, getStream(), type); |
| } else { |
| APFloat fltVal(llvm::cast<FloatType>(type).getFloatSemantics(), value); |
| printFloatValue(fltVal, getStream()); |
| } |
| }; |
| llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(), |
| printElementAt); |
| } |
| |
| void AsmPrinter::Impl::printType(Type type) { |
| if (!type) { |
| os << "<<NULL TYPE>>"; |
| return; |
| } |
| |
| // Try to print an alias for this type. |
| if (succeeded(printAlias(type))) |
| return; |
| return printTypeImpl(type); |
| } |
| |
| void AsmPrinter::Impl::printTypeImpl(Type type) { |
| TypeSwitch<Type>(type) |
| .Case<OpaqueType>([&](OpaqueType opaqueTy) { |
| printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), |
| opaqueTy.getTypeData()); |
| }) |
| .Case<IndexType>([&](Type) { os << "index"; }) |
| .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; }) |
| .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; }) |
| .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; }) |
| .Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; }) |
| .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; }) |
| .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; }) |
| .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; }) |
| .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; }) |
| .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; }) |
| .Case<BFloat16Type>([&](Type) { os << "bf16"; }) |
| .Case<Float16Type>([&](Type) { os << "f16"; }) |
| .Case<FloatTF32Type>([&](Type) { os << "tf32"; }) |
| .Case<Float32Type>([&](Type) { os << "f32"; }) |
| .Case<Float64Type>([&](Type) { os << "f64"; }) |
| .Case<Float80Type>([&](Type) { os << "f80"; }) |
| .Case<Float128Type>([&](Type) { os << "f128"; }) |
| .Case<IntegerType>([&](IntegerType integerTy) { |
| if (integerTy.isSigned()) |
| os << 's'; |
| else if (integerTy.isUnsigned()) |
| os << 'u'; |
| os << 'i' << integerTy.getWidth(); |
| }) |
| .Case<FunctionType>([&](FunctionType funcTy) { |
| os << '('; |
| interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); |
| os << ") -> "; |
| ArrayRef<Type> results = funcTy.getResults(); |
| if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) { |
| printType(results[0]); |
| } else { |
| os << '('; |
| interleaveComma(results, [&](Type ty) { printType(ty); }); |
| os << ')'; |
| } |
| }) |
| .Case<VectorType>([&](VectorType vectorTy) { |
| auto scalableDims = vectorTy.getScalableDims(); |
| os << "vector<"; |
| auto vShape = vectorTy.getShape(); |
| unsigned lastDim = vShape.size(); |
| unsigned dimIdx = 0; |
| for (dimIdx = 0; dimIdx < lastDim; dimIdx++) { |
| if (!scalableDims.empty() && scalableDims[dimIdx]) |
| os << '['; |
| os << vShape[dimIdx]; |
| if (!scalableDims.empty() && scalableDims[dimIdx]) |
| os << ']'; |
| os << 'x'; |
| } |
| printType(vectorTy.getElementType()); |
| os << '>'; |
| }) |
| .Case<RankedTensorType>([&](RankedTensorType tensorTy) { |
| os << "tensor<"; |
| printDimensionList(tensorTy.getShape()); |
| if (!tensorTy.getShape().empty()) |
| os << 'x'; |
| printType(tensorTy.getElementType()); |
| // Only print the encoding attribute value if set. |
| if (tensorTy.getEncoding()) { |
| os << ", "; |
| printAttribute(tensorTy.getEncoding()); |
| } |
| os << '>'; |
| }) |
| .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) { |
| os << "tensor<*x"; |
| printType(tensorTy.getElementType()); |
| os << '>'; |
| }) |
| .Case<MemRefType>([&](MemRefType memrefTy) { |
| os << "memref<"; |
| printDimensionList(memrefTy.getShape()); |
| if (!memrefTy.getShape().empty()) |
| os << 'x'; |
| printType(memrefTy.getElementType()); |
| MemRefLayoutAttrInterface layout = memrefTy.getLayout(); |
| if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) { |
| os << ", "; |
| printAttribute(memrefTy.getLayout(), AttrTypeElision::May); |
| } |
| // Only print the memory space if it is the non-default one. |
| if (memrefTy.getMemorySpace()) { |
| os << ", "; |
| printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); |
| } |
| os << '>'; |
| }) |
| .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) { |
| os << "memref<*x"; |
| printType(memrefTy.getElementType()); |
| // Only print the memory space if it is the non-default one. |
| if (memrefTy.getMemorySpace()) { |
| os << ", "; |
| printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); |
| } |
| os << '>'; |
| }) |
| .Case<ComplexType>([&](ComplexType complexTy) { |
| os << "complex<"; |
| printType(complexTy.getElementType()); |
| os << '>'; |
| }) |
| .Case<TupleType>([&](TupleType tupleTy) { |
| os << "tuple<"; |
| interleaveComma(tupleTy.getTypes(), |
| [&](Type type) { printType(type); }); |
| os << '>'; |
| }) |
| .Case<NoneType>([&](Type) { os << "none"; }) |
| .Default([&](Type type) { return printDialectType(type); }); |
| } |
| |
| void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| ArrayRef<StringRef> elidedAttrs, |
| bool withKeyword) { |
| // If there are no attributes, then there is nothing to be done. |
| if (attrs.empty()) |
| return; |
| |
| // Functor used to print a filtered attribute list. |
| auto printFilteredAttributesFn = [&](auto filteredAttrs) { |
| // Print the 'attributes' keyword if necessary. |
| if (withKeyword) |
| os << " attributes"; |
| |
| // Otherwise, print them all out in braces. |
| os << " {"; |
| interleaveComma(filteredAttrs, |
| [&](NamedAttribute attr) { printNamedAttribute(attr); }); |
| os << '}'; |
| }; |
| |
| // If no attributes are elided, we can directly print with no filtering. |
| if (elidedAttrs.empty()) |
| return printFilteredAttributesFn(attrs); |
| |
| // Otherwise, filter out any attributes that shouldn't be included. |
| llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(), |
| elidedAttrs.end()); |
| auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) { |
| return !elidedAttrsSet.contains(attr.getName().strref()); |
| }); |
| if (!filteredAttrs.empty()) |
| printFilteredAttributesFn(filteredAttrs); |
| } |
| void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { |
| // Print the name without quotes if possible. |
| ::printKeywordOrString(attr.getName().strref(), os); |
| |
| // Pretty printing elides the attribute value for unit attributes. |
| if (llvm::isa<UnitAttr>(attr.getValue())) |
| return; |
| |
| os << " = "; |
| printAttribute(attr.getValue()); |
| } |
| |
| void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { |
| auto &dialect = attr.getDialect(); |
| |
| // Ask the dialect to serialize the attribute to a string. |
| std::string attrName; |
| { |
| llvm::raw_string_ostream attrNameStr(attrName); |
| Impl subPrinter(attrNameStr, state); |
| DialectAsmPrinter printer(subPrinter); |
| dialect.printAttribute(attr, printer); |
| } |
| printDialectSymbol(os, "#", dialect.getNamespace(), attrName); |
| } |
| |
| void AsmPrinter::Impl::printDialectType(Type type) { |
| auto &dialect = type.getDialect(); |
| |
| // Ask the dialect to serialize the type to a string. |
| std::string typeName; |
| { |
| llvm::raw_string_ostream typeNameStr(typeName); |
| Impl subPrinter(typeNameStr, state); |
| DialectAsmPrinter printer(subPrinter); |
| dialect.printType(type, printer); |
| } |
| printDialectSymbol(os, "!", dialect.getNamespace(), typeName); |
| } |
| |
| void AsmPrinter::Impl::printEscapedString(StringRef str) { |
| os << "\""; |
| llvm::printEscapedString(str, os); |
| os << "\""; |
| } |
| |
| void AsmPrinter::Impl::printHexString(StringRef str) { |
| os << "\"0x" << llvm::toHex(str) << "\""; |
| } |
| void AsmPrinter::Impl::printHexString(ArrayRef<char> data) { |
| printHexString(StringRef(data.data(), data.size())); |
| } |
| |
| LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) { |
| return state.pushCyclicPrinting(opaquePointer); |
| } |
| |
| void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); } |
| |
| void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) { |
| detail::printDimensionList(os, shape); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // AsmPrinter |
| //===--------------------------------------------------------------------===// |
| |
| AsmPrinter::~AsmPrinter() = default; |
| |
| raw_ostream &AsmPrinter::getStream() const { |
| assert(impl && "expected AsmPrinter::getStream to be overriden"); |
| return impl->getStream(); |
| } |
| |
| /// Print the given floating point value in a stablized form. |
| void AsmPrinter::printFloat(const APFloat &value) { |
| assert(impl && "expected AsmPrinter::printFloat to be overriden"); |
| printFloatValue(value, impl->getStream()); |
| } |
| |
| void AsmPrinter::printType(Type type) { |
| assert(impl && "expected AsmPrinter::printType to be overriden"); |
| impl->printType(type); |
| } |
| |
| void AsmPrinter::printAttribute(Attribute attr) { |
| assert(impl && "expected AsmPrinter::printAttribute to be overriden"); |
| impl->printAttribute(attr); |
| } |
| |
| LogicalResult AsmPrinter::printAlias(Attribute attr) { |
| assert(impl && "expected AsmPrinter::printAlias to be overriden"); |
| return impl->printAlias(attr); |
| } |
| |
| LogicalResult AsmPrinter::printAlias(Type type) { |
| assert(impl && "expected AsmPrinter::printAlias to be overriden"); |
| return impl->printAlias(type); |
| } |
| |
| void AsmPrinter::printAttributeWithoutType(Attribute attr) { |
| assert(impl && |
| "expected AsmPrinter::printAttributeWithoutType to be overriden"); |
| impl->printAttribute(attr, Impl::AttrTypeElision::Must); |
| } |
| |
| void AsmPrinter::printKeywordOrString(StringRef keyword) { |
| assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden"); |
| ::printKeywordOrString(keyword, impl->getStream()); |
| } |
| |
| void AsmPrinter::printString(StringRef keyword) { |
| assert(impl && "expected AsmPrinter::printString to be overriden"); |
| *this << '"'; |
| printEscapedString(keyword, getStream()); |
| *this << '"'; |
| } |
| |
| void AsmPrinter::printSymbolName(StringRef symbolRef) { |
| assert(impl && "expected AsmPrinter::printSymbolName to be overriden"); |
| ::printSymbolReference(symbolRef, impl->getStream()); |
| } |
| |
| void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) { |
| assert(impl && "expected AsmPrinter::printResourceHandle to be overriden"); |
| impl->printResourceHandle(resource); |
| } |
| |
| void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) { |
| detail::printDimensionList(getStream(), shape); |
| } |
| |
| LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) { |
| return impl->pushCyclicPrinting(opaquePointer); |
| } |
| |
| void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); } |
| |
| //===----------------------------------------------------------------------===// |
| // Affine expressions and maps |
| //===----------------------------------------------------------------------===// |
| |
| void AsmPrinter::Impl::printAffineExpr( |
| AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) { |
| printAffineExprInternal(expr, BindingStrength::Weak, printValueName); |
| } |
| |
| void AsmPrinter::Impl::printAffineExprInternal( |
| AffineExpr expr, BindingStrength enclosingTightness, |
| function_ref<void(unsigned, bool)> printValueName) { |
| const char *binopSpelling = nullptr; |
| switch (expr.getKind()) { |
| case AffineExprKind::SymbolId: { |
| unsigned pos = cast<AffineSymbolExpr>(expr).getPosition(); |
| if (printValueName) |
| printValueName(pos, /*isSymbol=*/true); |
| else |
| os << 's' << pos; |
| return; |
| } |
| case AffineExprKind::DimId: { |
| unsigned pos = cast<AffineDimExpr>(expr).getPosition(); |
| if (printValueName) |
| printValueName(pos, /*isSymbol=*/false); |
| else |
| os << 'd' << pos; |
| return; |
| } |
| case AffineExprKind::Constant: |
| os << cast<AffineConstantExpr>(expr).getValue(); |
| return; |
| case AffineExprKind::Add: |
| binopSpelling = " + "; |
| break; |
| case AffineExprKind::Mul: |
| binopSpelling = " * "; |
| break; |
| case AffineExprKind::FloorDiv: |
| binopSpelling = " floordiv "; |
| break; |
| case AffineExprKind::CeilDiv: |
| binopSpelling = " ceildiv "; |
| break; |
| case AffineExprKind::Mod: |
| binopSpelling = " mod "; |
| break; |
| } |
| |
| auto binOp = cast<AffineBinaryOpExpr>(expr); |
| AffineExpr lhsExpr = binOp.getLHS(); |
| AffineExpr rhsExpr = binOp.getRHS(); |
| |
| // Handle tightly binding binary operators. |
| if (binOp.getKind() != AffineExprKind::Add) { |
| if (enclosingTightness == BindingStrength::Strong) |
| os << '('; |
| |
| // Pretty print multiplication with -1. |
| auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr); |
| if (rhsConst && binOp.getKind() == AffineExprKind::Mul && |
| rhsConst.getValue() == -1) { |
| os << "-"; |
| printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| return; |
| } |
| |
| printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); |
| |
| os << binopSpelling; |
| printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); |
| |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| return; |
| } |
| |
| // Print out special "pretty" forms for add. |
| if (enclosingTightness == BindingStrength::Strong) |
| os << '('; |
| |
| // Pretty print addition to a product that has a negative operand as a |
| // subtraction. |
| if (auto rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) { |
| if (rhs.getKind() == AffineExprKind::Mul) { |
| AffineExpr rrhsExpr = rhs.getRHS(); |
| if (auto rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) { |
| if (rrhs.getValue() == -1) { |
| printAffineExprInternal(lhsExpr, BindingStrength::Weak, |
| printValueName); |
| os << " - "; |
| if (rhs.getLHS().getKind() == AffineExprKind::Add) { |
| printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, |
| printValueName); |
| } else { |
| printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, |
| printValueName); |
| } |
| |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| return; |
| } |
| |
| if (rrhs.getValue() < -1) { |
| printAffineExprInternal(lhsExpr, BindingStrength::Weak, |
| printValueName); |
| os << " - "; |
| printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, |
| printValueName); |
| os << " * " << -rrhs.getValue(); |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| return; |
| } |
| } |
| } |
| } |
| |
| // Pretty print addition to a negative number as a subtraction. |
| if (auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr)) { |
| if (rhsConst.getValue() < 0) { |
| printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); |
| os << " - " << -rhsConst.getValue(); |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| return; |
| } |
| } |
| |
| printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); |
| |
| os << " + "; |
| printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); |
| |
| if (enclosingTightness == BindingStrength::Strong) |
| os << ')'; |
| } |
| |
| void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) { |
| printAffineExprInternal(expr, BindingStrength::Weak); |
| isEq ? os << " == 0" : os << " >= 0"; |
| } |
| |
| void AsmPrinter::Impl::printAffineMap(AffineMap map) { |
| // Dimension identifiers. |
| os << '('; |
| for (int i = 0; i < (int)map.getNumDims() - 1; ++i) |
| os << 'd' << i << ", "; |
| if (map.getNumDims() >= 1) |
| os << 'd' << map.getNumDims() - 1; |
| os << ')'; |
| |
| // Symbolic identifiers. |
| if (map.getNumSymbols() != 0) { |
| os << '['; |
| for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) |
| os << 's' << i << ", "; |
| if (map.getNumSymbols() >= 1) |
| os << 's' << map.getNumSymbols() - 1; |
| os << ']'; |
| } |
| |
| // Result affine expressions. |
| os << " -> ("; |
| interleaveComma(map.getResults(), |
| [&](AffineExpr expr) { printAffineExpr(expr); }); |
| os << ')'; |
| } |
| |
| void AsmPrinter::Impl::printIntegerSet(IntegerSet set) { |
| // Dimension identifiers. |
| os << '('; |
| for (unsigned i = 1; i < set.getNumDims(); ++i) |
| os << 'd' << i - 1 << ", "; |
| if (set.getNumDims() >= 1) |
| os << 'd' << set.getNumDims() - 1; |
| os << ')'; |
| |
| // Symbolic identifiers. |
| if (set.getNumSymbols() != 0) { |
| os << '['; |
| for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) |
| os << 's' << i << ", "; |
| if (set.getNumSymbols() >= 1) |
| os << 's' << set.getNumSymbols() - 1; |
| os << ']'; |
| } |
| |
| // Print constraints. |
| os << " : ("; |
| int numConstraints = set.getNumConstraints(); |
| for (int i = 1; i < numConstraints; ++i) { |
| printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); |
| os << ", "; |
| } |
| if (numConstraints >= 1) |
| printAffineConstraint(set.getConstraint(numConstraints - 1), |
| set.isEq(numConstraints - 1)); |
| os << ')'; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OperationPrinter |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This class contains the logic for printing operations, regions, and blocks. |
| class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { |
| public: |
| using Impl = AsmPrinter::Impl; |
| using Impl::printType; |
| |
| explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) |
| : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {} |
| |
| /// Print the given top-level operation. |
| void printTopLevelOperation(Operation *op); |
| |
| /// Print the given operation, including its left-hand side and its right-hand |
| /// side, with its indent and location. |
| void printFullOpWithIndentAndLoc(Operation *op); |
| /// Print the given operation, including its left-hand side and its right-hand |
| /// side, but not including indentation and location. |
| void printFullOp(Operation *op); |
| /// Print the right-hand size of the given operation in the custom or generic |
| /// form. |
| void printCustomOrGenericOp(Operation *op) override; |
| /// Print the right-hand side of the given operation in the generic form. |
| void printGenericOp(Operation *op, bool printOpName) override; |
| |
| /// Print the name of the given block. |
| void printBlockName(Block *block); |
| |
| /// Print the given block. If 'printBlockArgs' is false, the arguments of the |
| /// block are not printed. If 'printBlockTerminator' is false, the terminator |
| /// operation of the block is not printed. |
| void print(Block *block, bool printBlockArgs = true, |
| bool printBlockTerminator = true); |
| |
| /// Print the ID of the given value, optionally with its result number. |
| void printValueID(Value value, bool printResultNo = true, |
| raw_ostream *streamOverride = nullptr) const; |
| |
| /// Print the ID of the given operation. |
| void printOperationID(Operation *op, |
| raw_ostream *streamOverride = nullptr) const; |
| |
| //===--------------------------------------------------------------------===// |
| // OpAsmPrinter methods |
| //===--------------------------------------------------------------------===// |
| |
| /// Print a loc(...) specifier if printing debug info is enabled. Locations |
| /// may be deferred with an alias. |
| void printOptionalLocationSpecifier(Location loc) override { |
| printTrailingLocation(loc); |
| } |
| |
| /// Print a newline and indent the printer to the start of the current |
| /// operation. |
| void printNewline() override { |
| os << newLine; |
| os.indent(currentIndent); |
| } |
| |
| /// Increase indentation. |
| void increaseIndent() override { currentIndent += indentWidth; } |
| |
| /// Decrease indentation. |
| void decreaseIndent() override { currentIndent -= indentWidth; } |
| |
| /// Print a block argument in the usual format of: |
| /// %ssaName : type {attr1=42} loc("here") |
| /// where location printing is controlled by the standard internal option. |
| /// You may pass omitType=true to not print a type, and pass an empty |
| /// attribute list if you don't care for attributes. |
| void printRegionArgument(BlockArgument arg, |
| ArrayRef<NamedAttribute> argAttrs = {}, |
| bool omitType = false) override; |
| |
| /// Print the ID for the given value. |
| void printOperand(Value value) override { printValueID(value); } |
| void printOperand(Value value, raw_ostream &os) override { |
| printValueID(value, /*printResultNo=*/true, &os); |
| } |
| |
| /// Print an optional attribute dictionary with a given set of elided values. |
| void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
| ArrayRef<StringRef> elidedAttrs = {}) override { |
| Impl::printOptionalAttrDict(attrs, elidedAttrs); |
| } |
| void printOptionalAttrDictWithKeyword( |
| ArrayRef<NamedAttribute> attrs, |
| ArrayRef<StringRef> elidedAttrs = {}) override { |
| Impl::printOptionalAttrDict(attrs, elidedAttrs, |
| /*withKeyword=*/true); |
| } |
| |
| /// Print the given successor. |
| void printSuccessor(Block *successor) override; |
| |
| /// Print an operation successor with the operands used for the block |
| /// arguments. |
| void printSuccessorAndUseList(Block *successor, |
| ValueRange succOperands) override; |
| |
| /// Print the given region. |
| void printRegion(Region ®ion, bool printEntryBlockArgs, |
| bool printBlockTerminators, bool printEmptyBlock) override; |
| |
| /// Renumber the arguments for the specified region to the same names as the |
| /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
| /// operations. If any entry in namesToUse is null, the corresponding |
| /// argument name is left alone. |
| void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { |
| state.getSSANameState().shadowRegionArgs(region, namesToUse); |
| } |
| |
| /// Print the given affine map with the symbol and dimension operands printed |
| /// inline with the map. |
| void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
| ValueRange operands) override; |
| |
| /// Print the given affine expression with the symbol and dimension operands |
| /// printed inline with the expression. |
| void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
| ValueRange symOperands) override; |
| |
| /// Print users of this operation or id of this operation if it has no result. |
| void printUsersComment(Operation *op); |
| |
| /// Print users of this block arg. |
| void printUsersComment(BlockArgument arg); |
| |
| /// Print the users of a value. |
| void printValueUsers(Value value); |
| |
| /// Print either the ids of the result values or the id of the operation if |
| /// the operation has no results. |
| void printUserIDs(Operation *user, bool prefixComma = false); |
| |
| private: |
| /// This class represents a resource builder implementation for the MLIR |
| /// textual assembly format. |
| class ResourceBuilder : public AsmResourceBuilder { |
| public: |
| using ValueFn = function_ref<void(raw_ostream &)>; |
| using PrintFn = function_ref<void(StringRef, ValueFn)>; |
| |
| ResourceBuilder(PrintFn printFn) : printFn(printFn) {} |
| ~ResourceBuilder() override = default; |
| |
| void buildBool(StringRef key, bool data) final { |
| printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false"); }); |
| } |
| |
| void buildString(StringRef key, StringRef data) final { |
| printFn(key, [&](raw_ostream &os) { |
| os << "\""; |
| llvm::printEscapedString(data, os); |
| os << "\""; |
| }); |
| } |
| |
| void buildBlob(StringRef key, ArrayRef<char> data, |
| uint32_t dataAlignment) final { |
| printFn(key, [&](raw_ostream &os) { |
| // Store the blob in a hex string containing the alignment and the data. |
| llvm::support::ulittle32_t dataAlignmentLE(dataAlignment); |
| os << "\"0x" |
| << llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignmentLE), |
| sizeof(dataAlignment))) |
| << llvm::toHex(StringRef(data.data(), data.size())) << "\""; |
| }); |
| } |
| |
| private: |
| PrintFn printFn; |
| }; |
| |
| /// Print the metadata dictionary for the file, eliding it if it is empty. |
| void printFileMetadataDictionary(Operation *op); |
| |
| /// Print the resource sections for the file metadata dictionary. |
| /// `checkAddMetadataDict` is used to indicate that metadata is going to be |
| /// added, and the file metadata dictionary should be started if it hasn't |
| /// yet. |
| void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict, |
| Operation *op); |
| |
| // Contains the stack of default dialects to use when printing regions. |
| // A new dialect is pushed to the stack before parsing regions nested under an |
| // operation implementing `OpAsmOpInterface`, and popped when done. At the |
| // top-level we start with "builtin" as the default, so that the top-level |
| // `module` operation prints as-is. |
| SmallVector<StringRef> defaultDialectStack{"builtin"}; |
| |
| /// The number of spaces used for indenting nested operations. |
| const static unsigned indentWidth = 2; |
| |
| // This is the current indentation level for nested structures. |
| unsigned currentIndent = 0; |
| }; |
| } // namespace |
| |
| void OperationPrinter::printTopLevelOperation(Operation *op) { |
| // Output the aliases at the top level that can't be deferred. |
| state.getAliasState().printNonDeferredAliases(*this, newLine); |
| |
| // Print the module. |
| printFullOpWithIndentAndLoc(op); |
| os << newLine; |
| |
| // Output the aliases at the top level that can be deferred. |
| state.getAliasState().printDeferredAliases(*this, newLine); |
| |
| // Output any file level metadata. |
| printFileMetadataDictionary(op); |
| } |
| |
| void OperationPrinter::printFileMetadataDictionary(Operation *op) { |
| bool sawMetadataEntry = false; |
| auto checkAddMetadataDict = [&] { |
| if (!std::exchange(sawMetadataEntry, true)) |
| os << newLine << "{-#" << newLine; |
| }; |
| |
| // Add the various types of metadata. |
| printResourceFileMetadata(checkAddMetadataDict, op); |
| |
| // If the file dictionary exists, close it. |
| if (sawMetadataEntry) |
| os << newLine << "#-}" << newLine; |
| } |
| |
| void OperationPrinter::printResourceFileMetadata( |
| function_ref<void()> checkAddMetadataDict, Operation *op) { |
| // Functor used to add data entries to the file metadata dictionary. |
| bool hadResource = false; |
| bool needResourceComma = false; |
| bool needEntryComma = false; |
| auto processProvider = [&](StringRef dictName, StringRef name, auto &provider, |
| auto &&...providerArgs) { |
| bool hadEntry = false; |
| auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) { |
| checkAddMetadataDict(); |
| |
| auto printFormatting = [&]() { |
| // Emit the top-level resource entry if we haven't yet. |
| if (!std::exchange(hadResource, true)) { |
| if (needResourceComma) |
| os << "," << newLine; |
| os << " " << dictName << "_resources: {" << newLine; |
| } |
| // Emit the parent resource entry if we haven't yet. |
| if (!std::exchange(hadEntry, true)) { |
| if (needEntryComma) |
| os << "," << newLine; |
| os << " " << name << ": {" << newLine; |
| } else { |
| os << "," << newLine; |
| } |
| }; |
| |
| std::optional<uint64_t> charLimit = |
| printerFlags.getLargeResourceStringLimit(); |
| if (charLimit.has_value()) { |
| std::string resourceStr; |
| llvm::raw_string_ostream ss(resourceStr); |
| valueFn(ss); |
| |
| // Only print entry if it's string is small enough |
| if (resourceStr.size() > charLimit.value()) |
| return; |
| |
| printFormatting(); |
| os << " " << key << ": " << resourceStr; |
| } else { |
| printFormatting(); |
| os << " " << key << ": "; |
| valueFn(os); |
| } |
| }; |
| ResourceBuilder entryBuilder(printFn); |
| provider.buildResources(op, providerArgs..., entryBuilder); |
| |
| needEntryComma |= hadEntry; |
| if (hadEntry) |
| os << newLine << " }"; |
| }; |
| |
| // Print the `dialect_resources` section if we have any dialects with |
| // resources. |
| for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) { |
| auto &dialectResources = state.getDialectResources(); |
| StringRef name = interface.getDialect()->getNamespace(); |
| auto it = dialectResources.find(interface.getDialect()); |
| if (it != dialectResources.end()) |
| processProvider("dialect", name, interface, it->second); |
| else |
| processProvider("dialect", name, interface, |
| SetVector<AsmDialectResourceHandle>()); |
| } |
| if (hadResource) |
| os << newLine << " }"; |
| |
| // Print the `external_resources` section if we have any external clients with |
| // resources. |
| needEntryComma = false; |
| needResourceComma = hadResource; |
| hadResource = false; |
| for (const auto &printer : state.getResourcePrinters()) |
| processProvider("external", printer.getName(), printer); |
| if (hadResource) |
| os << newLine << " }"; |
| } |
| |
| /// Print a block argument in the usual format of: |
| /// %ssaName : type {attr1=42} loc("here") |
| /// where location printing is controlled by the standard internal option. |
| /// You may pass omitType=true to not print a type, and pass an empty |
| /// attribute list if you don't care for attributes. |
| void OperationPrinter::printRegionArgument(BlockArgument arg, |
| ArrayRef<NamedAttribute> argAttrs, |
| bool omitType) { |
| printOperand(arg); |
| if (!omitType) { |
| os << ": "; |
| printType(arg.getType()); |
| } |
| printOptionalAttrDict(argAttrs); |
| // TODO: We should allow location aliases on block arguments. |
| printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); |
| } |
| |
| void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) { |
| // Track the location of this operation. |
| state.registerOperationLocation(op, newLine.curLine, currentIndent); |
| |
| os.indent(currentIndent); |
| printFullOp(op); |
| printTrailingLocation(op->getLoc()); |
| if (printerFlags.shouldPrintValueUsers()) |
| printUsersComment(op); |
| } |
| |
| void OperationPrinter::printFullOp(Operation *op) { |
| if (size_t numResults = op->getNumResults()) { |
| auto printResultGroup = [&](size_t resultNo, size_t resultCount) { |
| printValueID(op->getResult(resultNo), /*printResultNo=*/false); |
| if (resultCount > 1) |
| os << ':' << resultCount; |
| }; |
| |
| // Check to see if this operation has multiple result groups. |
| ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op); |
| if (!resultGroups.empty()) { |
| // Interleave the groups excluding the last one, this one will be handled |
| // separately. |
| interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) { |
| printResultGroup(resultGroups[i], |
| resultGroups[i + 1] - resultGroups[i]); |
| }); |
| os << ", "; |
| printResultGroup(resultGroups.back(), numResults - resultGroups.back()); |
| |
| } else { |
| printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults); |
| } |
| |
| os << " = "; |
| } |
| |
| printCustomOrGenericOp(op); |
| } |
| |
| void OperationPrinter::printUsersComment(Operation *op) { |
| unsigned numResults = op->getNumResults(); |
| if (!numResults && op->getNumOperands()) { |
| os << " // id: "; |
| printOperationID(op); |
| } else if (numResults && op->use_empty()) { |
| os << " // unused"; |
| } else if (numResults && !op->use_empty()) { |
| // Print "user" if the operation has one result used to compute one other |
| // result, or is used in one operation with no result. |
| unsigned usedInNResults = 0; |
| unsigned usedInNOperations = 0; |
| SmallPtrSet<Operation *, 1> userSet; |
| for (Operation *user : op->getUsers()) { |
| if (userSet.insert(user).second) { |
| ++usedInNOperations; |
| usedInNResults += user->getNumResults(); |
| } |
| } |
| |
| // We already know that users is not empty. |
| bool exactlyOneUniqueUse = |
| usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1; |
| os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": "; |
| bool shouldPrintBrackets = numResults > 1; |
| auto printOpResult = [&](OpResult opResult) { |
| if (shouldPrintBrackets) |
| os << "("; |
| printValueUsers(opResult); |
| if (shouldPrintBrackets) |
| os << ")"; |
| }; |
| |
| interleaveComma(op->getResults(), printOpResult); |
| } |
| } |
| |
| void OperationPrinter::printUsersComment(BlockArgument arg) { |
| os << "// "; |
| printValueID(arg); |
| if (arg.use_empty()) { |
| os << " is unused"; |
| } else { |
| os << " is used by "; |
| printValueUsers(arg); |
| } |
| os << newLine; |
| } |
| |
| void OperationPrinter::printValueUsers(Value value) { |
| if (value.use_empty()) |
| os << "unused"; |
| |
| // One value might be used as the operand of an operation more than once. |
| // Only print the operations results once in that case. |
| SmallPtrSet<Operation *, 1> userSet; |
| for (auto [index, user] : enumerate(value.getUsers())) { |
| if (userSet.insert(user).second) |
| printUserIDs(user, index); |
| } |
| } |
| |
| void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) { |
| if (prefixComma) |
| os << ", "; |
| |
| if (!user->getNumResults()) { |
| printOperationID(user); |
| } else { |
| interleaveComma(user->getResults(), |
| [this](Value result) { printValueID(result); }); |
| } |
| } |
| |
| void OperationPrinter::printCustomOrGenericOp(Operation *op) { |
| // If requested, always print the generic form. |
| if (!printerFlags.shouldPrintGenericOpForm()) { |
| // Check to see if this is a known operation. If so, use the registered |
| // custom printer hook. |
| if (auto opInfo = op->getRegisteredInfo()) { |
| opInfo->printAssembly(op, *this, defaultDialectStack.back()); |
| return; |
| } |
| // Otherwise try to dispatch to the dialect, if available. |
| if (Dialect *dialect = op->getDialect()) { |
| if (auto opPrinter = dialect->getOperationPrinter(op)) { |
| // Print the op name first. |
| StringRef name = op->getName().getStringRef(); |
| // Only drop the default dialect prefix when it cannot lead to |
| // ambiguities. |
| if (name.count('.') == 1) |
| name.consume_front((defaultDialectStack.back() + ".").str()); |
| os << name; |
| |
| // Print the rest of the op now. |
| opPrinter(op, *this); |
| return; |
| } |
| } |
| } |
| |
| // Otherwise print with the generic assembly form. |
| printGenericOp(op, /*printOpName=*/true); |
| } |
| |
| void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { |
| if (printOpName) |
| printEscapedString(op->getName().getStringRef()); |
| os << '('; |
| interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); }); |
| os << ')'; |
| |
| // For terminators, print the list of successors and their operands. |
| if (op->getNumSuccessors() != 0) { |
| os << '['; |
| interleaveComma(op->getSuccessors(), |
| [&](Block *successor) { printBlockName(successor); }); |
| os << ']'; |
| } |
| |
| // Print the properties. |
| if (Attribute prop = op->getPropertiesAsAttribute()) { |
| os << " <"; |
| Impl::printAttribute(prop); |
| os << '>'; |
| } |
| |
| // Print regions. |
| if (op->getNumRegions() != 0) { |
| os << " ("; |
| interleaveComma(op->getRegions(), [&](Region ®ion) { |
| printRegion(region, /*printEntryBlockArgs=*/true, |
| /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); |
| }); |
| os << ')'; |
| } |
| |
| printOptionalAttrDict(op->getPropertiesStorage() |
| ? llvm::to_vector(op->getDiscardableAttrs()) |
| : op->getAttrs()); |
| |
| // Print the type signature of the operation. |
| os << " : "; |
| printFunctionalType(op); |
| } |
| |
| void OperationPrinter::printBlockName(Block *block) { |
| os << state.getSSANameState().getBlockInfo(block).name; |
| } |
| |
| void OperationPrinter::print(Block *block, bool printBlockArgs, |
| bool printBlockTerminator) { |
| // Print the block label and argument list if requested. |
| if (printBlockArgs) { |
| os.indent(currentIndent); |
| printBlockName(block); |
| |
| // Print the argument list if non-empty. |
| if (!block->args_empty()) { |
| os << '('; |
| interleaveComma(block->getArguments(), [&](BlockArgument arg) { |
| printValueID(arg); |
| os << ": "; |
| printType(arg.getType()); |
| // TODO: We should allow location aliases on block arguments. |
| printTrailingLocation(arg.getLoc(), /*allowAlias*/ false); |
| }); |
| os << ')'; |
| } |
| os << ':'; |
| |
| // Print out some context information about the predecessors of this block. |
| if (!block->getParent()) { |
| os << " // block is not in a region!"; |
| } else if (block->hasNoPredecessors()) { |
| if (!block->isEntryBlock()) |
| os << " // no predecessors"; |
| } else if (auto *pred = block->getSinglePredecessor()) { |
| os << " // pred: "; |
| printBlockName(pred); |
| } else { |
| // We want to print the predecessors in a stable order, not in |
| // whatever order the use-list is in, so gather and sort them. |
| SmallVector<BlockInfo, 4> predIDs; |
| for (auto *pred : block->getPredecessors()) |
| predIDs.push_back(state.getSSANameState().getBlockInfo(pred)); |
| llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) { |
| return lhs.ordering < rhs.ordering; |
| }); |
| |
| os << " // " << predIDs.size() << " preds: "; |
| |
| interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; }); |
| } |
| os << newLine; |
| } |
| |
| currentIndent += indentWidth; |
| |
| if (printerFlags.shouldPrintValueUsers()) { |
| for (BlockArgument arg : block->getArguments()) { |
| os.indent(currentIndent); |
| printUsersComment(arg); |
| } |
| } |
| |
| bool hasTerminator = |
| !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>(); |
| auto range = llvm::make_range( |
| block->begin(), |
| std::prev(block->end(), |
| (!hasTerminator || printBlockTerminator) ? 0 : 1)); |
| for (auto &op : range) { |
| printFullOpWithIndentAndLoc(&op); |
| os << newLine; |
| } |
| currentIndent -= indentWidth; |
| } |
| |
| void OperationPrinter::printValueID(Value value, bool printResultNo, |
| raw_ostream *streamOverride) const { |
| state.getSSANameState().printValueID(value, printResultNo, |
| streamOverride ? *streamOverride : os); |
| } |
| |
| void OperationPrinter::printOperationID(Operation *op, |
| raw_ostream *streamOverride) const { |
| state.getSSANameState().printOperationID(op, streamOverride ? *streamOverride |
| : os); |
| } |
| |
| void OperationPrinter::printSuccessor(Block *successor) { |
| printBlockName(successor); |
| } |
| |
| void OperationPrinter::printSuccessorAndUseList(Block *successor, |
| ValueRange succOperands) { |
| printBlockName(successor); |
| if (succOperands.empty()) |
| return; |
| |
| os << '('; |
| interleaveComma(succOperands, |
| [this](Value operand) { printValueID(operand); }); |
| os << " : "; |
| interleaveComma(succOperands, |
| [this](Value operand) { printType(operand.getType()); }); |
| os << ')'; |
| } |
| |
| void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, |
| bool printBlockTerminators, |
| bool printEmptyBlock) { |
| if (printerFlags.shouldSkipRegions()) { |
| os << "{...}"; |
| return; |
| } |
| os << "{" << newLine; |
| if (!region.empty()) { |
| auto restoreDefaultDialect = |
| llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); }); |
| if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp())) |
| defaultDialectStack.push_back(iface.getDefaultDialect()); |
| else |
| defaultDialectStack.push_back(""); |
| |
| auto *entryBlock = ®ion.front(); |
| // Force printing the block header if printEmptyBlock is set and the block |
| // is empty or if printEntryBlockArgs is set and there are arguments to |
| // print. |
| bool shouldAlwaysPrintBlockHeader = |
| (printEmptyBlock && entryBlock->empty()) || |
| (printEntryBlockArgs && entryBlock->getNumArguments() != 0); |
| print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators); |
| for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) |
| print(&b); |
| } |
| os.indent(currentIndent) << "}"; |
| } |
| |
| void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
| ValueRange operands) { |
| if (!mapAttr) { |
| os << "<<NULL AFFINE MAP>>"; |
| return; |
| } |
| AffineMap map = mapAttr.getValue(); |
| unsigned numDims = map.getNumDims(); |
| auto printValueName = [&](unsigned pos, bool isSymbol) { |
| unsigned index = isSymbol ? numDims + pos : pos; |
| assert(index < operands.size()); |
| if (isSymbol) |
| os << "symbol("; |
| printValueID(operands[index]); |
| if (isSymbol) |
| os << ')'; |
| }; |
| |
| interleaveComma(map.getResults(), [&](AffineExpr expr) { |
| printAffineExpr(expr, printValueName); |
| }); |
| } |
| |
| void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, |
| ValueRange dimOperands, |
| ValueRange symOperands) { |
| auto printValueName = [&](unsigned pos, bool isSymbol) { |
| if (!isSymbol) |
| return printValueID(dimOperands[pos]); |
| os << "symbol("; |
| printValueID(symOperands[pos]); |
| os << ')'; |
| }; |
| printAffineExpr(expr, printValueName); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // print and dump methods |
| //===----------------------------------------------------------------------===// |
| |
| void Attribute::print(raw_ostream &os, bool elideType) const { |
| if (!*this) { |
| os << "<<NULL ATTRIBUTE>>"; |
| return; |
| } |
| |
| AsmState state(getContext()); |
| print(os, state, elideType); |
| } |
| void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const { |
| using AttrTypeElision = AsmPrinter::Impl::AttrTypeElision; |
| AsmPrinter::Impl(os, state.getImpl()) |
| .printAttribute(*this, elideType ? AttrTypeElision::Must |
| : AttrTypeElision::Never); |
| } |
| |
| void Attribute::dump() const { |
| print(llvm::errs()); |
| llvm::errs() << "\n"; |
| } |
| |
| void Attribute::printStripped(raw_ostream &os, AsmState &state) const { |
| if (!*this) { |
| os << "<<NULL ATTRIBUTE>>"; |
| return; |
| } |
| |
| AsmPrinter::Impl subPrinter(os, state.getImpl()); |
| if (succeeded(subPrinter.printAlias(*this))) |
| return; |
| |
| auto &dialect = this->getDialect(); |
| uint64_t posPrior = os.tell(); |
| DialectAsmPrinter printer(subPrinter); |
| dialect.printAttribute(*this, printer); |
| if (posPrior != os.tell()) |
| return; |
| |
| // Fallback to printing with prefix if the above failed to write anything |
| // to the output stream. |
| print(os, state); |
| } |
| void Attribute::printStripped(raw_ostream &os) const { |
| if (!*this) { |
| os << "<<NULL ATTRIBUTE>>"; |
| return; |
| } |
| |
| AsmState state(getContext()); |
| printStripped(os, state); |
| } |
| |
| void Type::print(raw_ostream &os) const { |
| if (!*this) { |
| os << "<<NULL TYPE>>"; |
| return; |
| } |
| |
| AsmState state(getContext()); |
| print(os, state); |
| } |
| void Type::print(raw_ostream &os, AsmState &state) const { |
| AsmPrinter::Impl(os, state.getImpl()).printType(*this); |
| } |
| |
| void Type::dump() const { |
| print(llvm::errs()); |
| llvm::errs() << "\n"; |
| } |
| |
| void AffineMap::dump() const { |
| print(llvm::errs()); |
| llvm::errs() << "\n"; |
| } |
| |
| void IntegerSet::dump() const { |
| print(llvm::errs()); |
| llvm::errs() << "\n"; |
| } |
| |
| void AffineExpr::print(raw_ostream &os) const { |
| if (!expr) { |
| os << "<<NULL AFFINE EXPR>>"; |
| return; |
| } |
| AsmState state(getContext()); |
| AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(*this); |
| } |
| |
| void AffineExpr::dump() const { |
| print(llvm::errs()); |
| llvm::errs() << "\n"; |
| } |
| |
| void AffineMap::print(raw_ostream &os) const { |
| if (!map) { |
| os << "<<NULL AFFINE MAP>>"; |
| return; |
| } |
| AsmState state(getContext()); |
| AsmPrinter::Impl(os, state.getImpl()).printAffineMap(*this); |
| } |
| |
| void IntegerSet::print(raw_ostream &os) const { |
| AsmState state(getContext()); |
| AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this); |
| } |
| |
| void Value::print(raw_ostream &os) const { print(os, OpPrintingFlags()); } |
| void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const { |
| if (!impl) { |
| os << "<<NULL VALUE>>"; |
| return; |
| } |
| |
| if (auto *op = getDefiningOp()) |
| return op->print(os, flags); |
| // TODO: Improve BlockArgument print'ing. |
| BlockArgument arg = llvm::cast<BlockArgument>(*this); |
| os << "<block argument> of type '" << arg.getType() |
| << "' at index: " << arg.getArgNumber(); |
| } |
| void Value::print(raw_ostream &os, AsmState &state) const { |
| if (!impl) { |
| os << "<<NULL VALUE>>"; |
| return; |
| } |
| |
| if (auto *op = getDefiningOp()) |
| return op->print(os, state); |
| |
| // TODO: Improve BlockArgument print'ing. |
| BlockArgument arg = llvm::cast<BlockArgument>(*this); |
| os << "<block argument> of type '" << arg.getType() |
| << "' at index: " << arg.getArgNumber(); |
| } |
| |
| void Value::dump() const { |
| print(llvm::errs()); |
| llvm::errs() << "\n"; |
| } |
| |
| void Value::printAsOperand(raw_ostream &os, AsmState &state) const { |
| // TODO: This doesn't necessarily capture all potential cases. |
| // Currently, region arguments can be shadowed when printing the main |
| // operation. If the IR hasn't been printed, this will produce the old SSA |
| // name and not the shadowed name. |
| state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true, |
| os); |
| } |
| |
| static Operation *findParent(Operation *op, bool shouldUseLocalScope) { |
| do { |
| // If we are printing local scope, stop at the first operation that is |
| // isolated from above. |
| if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
| break; |
| |
| // Otherwise, traverse up to the next parent. |
| Operation *parentOp = op->getParentOp(); |
| if (!parentOp) |
| break; |
| op = parentOp; |
| } while (true); |
| return op; |
| } |
| |
| void Value::printAsOperand(raw_ostream &os, |
| const OpPrintingFlags &flags) const { |
| Operation *op; |
| if (auto result = llvm::dyn_cast<OpResult>(*this)) { |
| op = result.getOwner(); |
| } else { |
| op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp(); |
| if (!op) { |
| os << "<<UNKNOWN SSA VALUE>>"; |
| return; |
| } |
| } |
| op = findParent(op, flags.shouldUseLocalScope()); |
| AsmState state(op, flags); |
| printAsOperand(os, state); |
| } |
| |
| void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) { |
| // Find the operation to number from based upon the provided flags. |
| Operation *op = findParent(this, printerFlags.shouldUseLocalScope()); |
| AsmState state(op, printerFlags); |
| print(os, state); |
| } |
| void Operation::print(raw_ostream &os, AsmState &state) { |
| OperationPrinter printer(os, state.getImpl()); |
| if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) { |
| state.getImpl().initializeAliases(this); |
| printer.printTopLevelOperation(this); |
| } else { |
| printer.printFullOpWithIndentAndLoc(this); |
| } |
| } |
| |
| void Operation::dump() { |
| print(llvm::errs(), OpPrintingFlags().useLocalScope()); |
| llvm::errs() << "\n"; |
| } |
| |
| void Block::print(raw_ostream &os) { |
| Operation *parentOp = getParentOp(); |
| if (!parentOp) { |
| os << "<<UNLINKED BLOCK>>\n"; |
| return; |
| } |
| // Get the top-level op. |
| while (auto *nextOp = parentOp->getParentOp()) |
| parentOp = nextOp; |
| |
| AsmState state(parentOp); |
| print(os, state); |
| } |
| void Block::print(raw_ostream &os, AsmState &state) { |
| OperationPrinter(os, state.getImpl()).print(this); |
| } |
| |
| void Block::dump() { print(llvm::errs()); } |
| |
| /// Print out the name of the block without printing its body. |
| void Block::printAsOperand(raw_ostream &os, bool printType) { |
| Operation *parentOp = getParentOp(); |
| if (!parentOp) { |
| os << "<<UNLINKED BLOCK>>\n"; |
| return; |
| } |
| AsmState state(parentOp); |
| printAsOperand(os, state); |
| } |
| void Block::printAsOperand(raw_ostream &os, AsmState &state) { |
| OperationPrinter printer(os, state.getImpl()); |
| printer.printBlockName(this); |
| } |
| |
| raw_ostream &mlir::operator<<(raw_ostream &os, Block &block) { |
| block.print(os); |
| return os; |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Custom printers |
| //===--------------------------------------------------------------------===// |
| namespace mlir { |
| |
| void printDimensionList(OpAsmPrinter &printer, Operation *op, |
| ArrayRef<int64_t> dimensions) { |
| if (dimensions.empty()) |
| printer << "["; |
| printer.printDimensionList(dimensions); |
| if (dimensions.empty()) |
| printer << "]"; |
| } |
| |
| ParseResult parseDimensionList(OpAsmParser &parser, |
| DenseI64ArrayAttr &dimensions) { |
| // Empty list case denoted by "[]". |
| if (succeeded(parser.parseOptionalLSquare())) { |
| if (failed(parser.parseRSquare())) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "Failed parsing dimension list."; |
| } |
| dimensions = |
| DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>()); |
| return success(); |
| } |
| |
| // Non-empty list case. |
| SmallVector<int64_t> shapeArr; |
| if (failed(parser.parseDimensionList(shapeArr, true, false))) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "Failed parsing dimension list."; |
| } |
| if (shapeArr.empty()) { |
| return parser.emitError(parser.getCurrentLocation()) |
| << "Failed parsing dimension list. Did you mean an empty list? It " |
| "must be denoted by \"[]\"."; |
| } |
| dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr); |
| return success(); |
| } |
| |
| } // namespace mlir |