| //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// |
| // |
| // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef MLIR_IR_SYMBOLTABLE_H |
| #define MLIR_IR_SYMBOLTABLE_H |
| |
| #include "mlir/IR/OpDefinition.h" |
| #include "llvm/ADT/StringMap.h" |
| |
| namespace mlir { |
| class Identifier; |
| class Operation; |
| |
| /// This class allows for representing and managing the symbol table used by |
| /// operations with the 'SymbolTable' trait. Inserting into and erasing from |
| /// this SymbolTable will also insert and erase from the Operation given to it |
| /// at construction. |
| class SymbolTable { |
| public: |
| /// Build a symbol table with the symbols within the given operation. |
| SymbolTable(Operation *symbolTableOp); |
| |
| /// Look up a symbol with the specified name, returning null if no such |
| /// name exists. Names never include the @ on them. |
| Operation *lookup(StringRef name) const; |
| template <typename T> T lookup(StringRef name) const { |
| return dyn_cast_or_null<T>(lookup(name)); |
| } |
| |
| /// Erase the given symbol from the table. |
| void erase(Operation *symbol); |
| |
| /// Insert a new symbol into the table, and rename it as necessary to avoid |
| /// collisions. Also insert at the specified location in the body of the |
| /// associated operation. |
| void insert(Operation *symbol, Block::iterator insertPt = {}); |
| |
| /// Return the name of the attribute used for symbol names. |
| static StringRef getSymbolAttrName() { return "sym_name"; } |
| |
| /// Returns the associated operation. |
| Operation *getOp() const { return symbolTableOp; } |
| |
| /// Return the name of the attribute used for symbol visibility. |
| static StringRef getVisibilityAttrName() { return "sym_visibility"; } |
| |
| //===--------------------------------------------------------------------===// |
| // Symbol Utilities |
| //===--------------------------------------------------------------------===// |
| |
| /// An enumeration detailing the different visibility types that a symbol may |
| /// have. |
| enum class Visibility { |
| /// The symbol is public and may be referenced anywhere internal or external |
| /// to the visible references in the IR. |
| Public, |
| |
| /// The symbol is private and may only be referenced by SymbolRefAttrs local |
| /// to the operations within the current symbol table. |
| Private, |
| |
| /// The symbol is visible to the current IR, which may include operations in |
| /// symbol tables above the one that owns the current symbol. `Nested` |
| /// visibility allows for referencing a symbol outside of its current symbol |
| /// table, while retaining the ability to observe all uses. |
| Nested, |
| }; |
| |
| /// Returns true if the given operation defines a symbol. |
| static bool isSymbol(Operation *op); |
| |
| /// Returns the name of the given symbol operation. |
| static StringRef getSymbolName(Operation *symbol); |
| /// Sets the name of the given symbol operation. |
| static void setSymbolName(Operation *symbol, StringRef name); |
| |
| /// Returns the visibility of the given symbol operation. |
| static Visibility getSymbolVisibility(Operation *symbol); |
| /// Sets the visibility of the given symbol operation. |
| static void setSymbolVisibility(Operation *symbol, Visibility vis); |
| |
| /// Returns the operation registered with the given symbol name with the |
| /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation |
| /// with the 'OpTrait::SymbolTable' trait. |
| static Operation *lookupSymbolIn(Operation *op, StringRef symbol); |
| static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); |
| |
| /// Returns the operation registered with the given symbol name within the |
| /// closest parent operation of, or including, 'from' with the |
| /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was |
| /// found. |
| static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); |
| static Operation *lookupNearestSymbolFrom(Operation *from, |
| SymbolRefAttr symbol); |
| |
| /// This class represents a specific symbol use. |
| class SymbolUse { |
| public: |
| SymbolUse(Operation *op, SymbolRefAttr symbolRef) |
| : owner(op), symbolRef(symbolRef) {} |
| |
| /// Return the operation user of this symbol reference. |
| Operation *getUser() const { return owner; } |
| |
| /// Return the symbol reference that this use represents. |
| SymbolRefAttr getSymbolRef() const { return symbolRef; } |
| |
| private: |
| /// The operation that this access is held by. |
| Operation *owner; |
| |
| /// The symbol reference that this use represents. |
| SymbolRefAttr symbolRef; |
| }; |
| |
| /// This class implements a range of SymbolRef uses. |
| class UseRange { |
| public: |
| UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {} |
| |
| using iterator = std::vector<SymbolUse>::const_iterator; |
| iterator begin() const { return uses.begin(); } |
| iterator end() const { return uses.end(); } |
| |
| private: |
| std::vector<SymbolUse> uses; |
| }; |
| |
| /// Get an iterator range for all of the uses, for any symbol, that are nested |
| /// within the given operation 'from'. This does not traverse into any nested |
| /// symbol tables, and will also only return uses on 'from' if it does not |
| /// also define a symbol table. This is because we treat the region as the |
| /// boundary of the symbol table, and not the op itself. This function returns |
| /// None if there are any unknown operations that may potentially be symbol |
| /// tables. |
| static Optional<UseRange> getSymbolUses(Operation *from); |
| |
| /// Get all of the uses of the given symbol that are nested within the given |
| /// operation 'from'. This does not traverse into any nested symbol tables, |
| /// and will also only return uses on 'from' if it does not also define a |
| /// symbol table. This is because we treat the region as the boundary of the |
| /// symbol table, and not the op itself. This function returns None if there |
| /// are any unknown operations that may potentially be symbol tables. |
| static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from); |
| static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from); |
| |
| /// Return if the given symbol is known to have no uses that are nested |
| /// within the given operation 'from'. This does not traverse into any nested |
| /// symbol tables, and will also only count uses on 'from' if it does not also |
| /// define a symbol table. This is because we treat the region as the boundary |
| /// of the symbol table, and not the op itself. This function will also return |
| /// false if there are any unknown operations that may potentially be symbol |
| /// tables. This doesn't necessarily mean that there are no uses, we just |
| /// can't conservatively prove it. |
| static bool symbolKnownUseEmpty(StringRef symbol, Operation *from); |
| static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); |
| |
| /// Attempt to replace all uses of the given symbol 'oldSymbol' with the |
| /// provided symbol 'newSymbol' that are nested within the given operation |
| /// 'from'. This does not traverse into any nested symbol tables, and will |
| /// also only replace uses on 'from' if it does not also define a symbol |
| /// table. This is because we treat the region as the boundary of the symbol |
| /// table, and not the op itself. If there are any unknown operations that may |
| /// potentially be symbol tables, no uses are replaced and failure is |
| /// returned. |
| LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, |
| StringRef newSymbol, |
| Operation *from); |
| LLVM_NODISCARD static LogicalResult |
| replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName, |
| Operation *from); |
| |
| private: |
| Operation *symbolTableOp; |
| |
| /// This is a mapping from a name to the symbol with that name. |
| llvm::StringMap<Operation *> symbolTable; |
| |
| /// This is used when name conflicts are detected. |
| unsigned uniquingCounter = 0; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // SymbolTable Trait Types |
| //===----------------------------------------------------------------------===// |
| |
| namespace OpTrait { |
| namespace impl { |
| LogicalResult verifySymbolTable(Operation *op); |
| LogicalResult verifySymbol(Operation *op); |
| } // namespace impl |
| |
| /// A trait used to provide symbol table functionalities to a region operation. |
| /// This operation must hold exactly 1 region. Once attached, all operations |
| /// that are directly within the region, i.e not including those within child |
| /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will |
| /// be verified to ensure that the names are uniqued. These operations must also |
| /// adhere to the constraints defined by the `Symbol` trait, even if they do not |
| /// inherit from it. |
| template <typename ConcreteType> |
| class SymbolTable : public TraitBase<ConcreteType, SymbolTable> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifySymbolTable(op); |
| } |
| |
| /// Look up a symbol with the specified name, returning null if no such |
| /// name exists. Symbol names never include the @ on them. Note: This |
| /// performs a linear scan of held symbols. |
| Operation *lookupSymbol(StringRef name) { |
| return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); |
| } |
| template <typename T> T lookupSymbol(StringRef name) { |
| return dyn_cast_or_null<T>(lookupSymbol(name)); |
| } |
| }; |
| |
| /// A trait used to define a symbol that can be used on operations within a |
| /// symbol table. Operations using this trait must adhere to the following: |
| /// * Have a StringAttr attribute named 'SymbolTable::getSymbolAttrName()'. |
| template <typename ConcreteType> |
| class Symbol : public TraitBase<ConcreteType, Symbol> { |
| public: |
| using Visibility = mlir::SymbolTable::Visibility; |
| |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifySymbol(op); |
| } |
| |
| /// Returns the name of this symbol. |
| StringRef getName() { |
| return this->getOperation() |
| ->template getAttrOfType<StringAttr>( |
| mlir::SymbolTable::getSymbolAttrName()) |
| .getValue(); |
| } |
| |
| /// Set the name of this symbol. |
| void setName(StringRef name) { |
| this->getOperation()->setAttr( |
| mlir::SymbolTable::getSymbolAttrName(), |
| StringAttr::get(name, this->getOperation()->getContext())); |
| } |
| |
| /// Returns the visibility of the current symbol. |
| Visibility getVisibility() { |
| return mlir::SymbolTable::getSymbolVisibility(this->getOperation()); |
| } |
| |
| /// Sets the visibility of the current symbol. |
| void setVisibility(Visibility vis) { |
| mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis); |
| } |
| |
| /// Get all of the uses of the current symbol that are nested within the given |
| /// operation 'from'. |
| /// Note: See mlir::SymbolTable::getSymbolUses for more details. |
| Optional<::mlir::SymbolTable::UseRange> getSymbolUses(Operation *from) { |
| return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from); |
| } |
| |
| /// Return if the current symbol is known to have no uses that are nested |
| /// within the given operation 'from'. |
| /// Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details. |
| bool symbolKnownUseEmpty(Operation *from) { |
| return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), from); |
| } |
| |
| /// Attempt to replace all uses of the current symbol with the provided symbol |
| /// 'newSymbol' that are nested within the given operation 'from'. |
| /// Note: See mlir::SymbolTable::replaceAllSymbolUses for more details. |
| LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol, |
| Operation *from) { |
| return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(), |
| newSymbol, from); |
| } |
| }; |
| |
| } // end namespace OpTrait |
| } // end namespace mlir |
| |
| #endif // MLIR_IR_SYMBOLTABLE_H |