| //===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef MLIR_IR_DIALECTINTERFACE_H |
| #define MLIR_IR_DIALECTINTERFACE_H |
| |
| #include "mlir/Support/TypeID.h" |
| #include "llvm/ADT/DenseSet.h" |
| #include "llvm/ADT/STLExtras.h" |
| |
| namespace mlir { |
| class Dialect; |
| class MLIRContext; |
| class Operation; |
| |
| //===----------------------------------------------------------------------===// |
| // DialectInterface |
| //===----------------------------------------------------------------------===// |
| namespace detail { |
| /// The base class used for all derived interface types. This class provides |
| /// utilities necessary for registration. |
| template <typename ConcreteType, typename BaseT> |
| class DialectInterfaceBase : public BaseT { |
| public: |
| using Base = DialectInterfaceBase<ConcreteType, BaseT>; |
| |
| /// Get a unique id for the derived interface type. |
| static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); } |
| |
| protected: |
| DialectInterfaceBase(Dialect *dialect) : BaseT(dialect, getInterfaceID()) {} |
| }; |
| } // end namespace detail |
| |
| /// This class represents an interface overridden for a single dialect. |
| class DialectInterface { |
| public: |
| virtual ~DialectInterface(); |
| |
| /// The base class used for all derived interface types. This class provides |
| /// utilities necessary for registration. |
| template <typename ConcreteType> |
| using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>; |
| |
| /// Return the dialect that this interface represents. |
| Dialect *getDialect() const { return dialect; } |
| |
| /// Return the derived interface id. |
| TypeID getID() const { return interfaceID; } |
| |
| protected: |
| DialectInterface(Dialect *dialect, TypeID id) |
| : dialect(dialect), interfaceID(id) {} |
| |
| private: |
| /// The dialect that represents this interface. |
| Dialect *dialect; |
| |
| /// The unique identifier for the derived interface type. |
| TypeID interfaceID; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // DialectInterfaceCollection |
| //===----------------------------------------------------------------------===// |
| |
| namespace detail { |
| /// This class is the base class for a collection of instances for a specific |
| /// interface kind. |
| class DialectInterfaceCollectionBase { |
| /// DenseMap info for dialect interfaces that allows lookup by the dialect. |
| struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> { |
| using DenseMapInfo<const DialectInterface *>::isEqual; |
| |
| static unsigned getHashValue(Dialect *key) { return llvm::hash_value(key); } |
| static unsigned getHashValue(const DialectInterface *key) { |
| return getHashValue(key->getDialect()); |
| } |
| |
| static bool isEqual(Dialect *lhs, const DialectInterface *rhs) { |
| if (rhs == getEmptyKey() || rhs == getTombstoneKey()) |
| return false; |
| return lhs == rhs->getDialect(); |
| } |
| }; |
| |
| /// A set of registered dialect interface instances. |
| using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>; |
| using InterfaceVectorT = std::vector<const DialectInterface *>; |
| |
| public: |
| DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind); |
| virtual ~DialectInterfaceCollectionBase(); |
| |
| protected: |
| /// Get the interface for the dialect of given operation, or null if one |
| /// is not registered. |
| const DialectInterface *getInterfaceFor(Operation *op) const; |
| |
| /// Get the interface for the given dialect. |
| const DialectInterface *getInterfaceFor(Dialect *dialect) const { |
| auto it = interfaces.find_as(dialect); |
| return it == interfaces.end() ? nullptr : *it; |
| } |
| |
| /// An iterator class that iterates the held interface objects of the given |
| /// derived interface type. |
| template <typename InterfaceT> |
| struct iterator |
| : public llvm::mapped_iterator_base<iterator<InterfaceT>, |
| InterfaceVectorT::const_iterator, |
| const InterfaceT &> { |
| using llvm::mapped_iterator_base<iterator<InterfaceT>, |
| InterfaceVectorT::const_iterator, |
| const InterfaceT &>::mapped_iterator_base; |
| |
| /// Map the element to the iterator result type. |
| const InterfaceT &mapElement(const DialectInterface *interface) const { |
| return *static_cast<const InterfaceT *>(interface); |
| } |
| }; |
| |
| /// Iterator access to the held interfaces. |
| template <typename InterfaceT> iterator<InterfaceT> interface_begin() const { |
| return iterator<InterfaceT>(orderedInterfaces.begin()); |
| } |
| template <typename InterfaceT> iterator<InterfaceT> interface_end() const { |
| return iterator<InterfaceT>(orderedInterfaces.end()); |
| } |
| |
| private: |
| /// A set of registered dialect interface instances. |
| InterfaceSetT interfaces; |
| /// An ordered list of the registered interface instances, necessary for |
| /// deterministic iteration. |
| // NOTE: SetVector does not provide find access, so it can't be used here. |
| InterfaceVectorT orderedInterfaces; |
| }; |
| } // namespace detail |
| |
| /// A collection of dialect interfaces within a context, for a given concrete |
| /// interface type. |
| template <typename InterfaceType> |
| class DialectInterfaceCollection |
| : public detail::DialectInterfaceCollectionBase { |
| public: |
| using Base = DialectInterfaceCollection<InterfaceType>; |
| |
| /// Collect the registered dialect interfaces within the provided context. |
| DialectInterfaceCollection(MLIRContext *ctx) |
| : detail::DialectInterfaceCollectionBase( |
| ctx, InterfaceType::getInterfaceID()) {} |
| |
| /// Get the interface for a given object, or null if one is not registered. |
| /// The object may be a dialect or an operation instance. |
| template <typename Object> |
| const InterfaceType *getInterfaceFor(Object *obj) const { |
| return static_cast<const InterfaceType *>( |
| detail::DialectInterfaceCollectionBase::getInterfaceFor(obj)); |
| } |
| |
| /// Iterator access to the held interfaces. |
| using iterator = |
| detail::DialectInterfaceCollectionBase::iterator<InterfaceType>; |
| iterator begin() const { return interface_begin<InterfaceType>(); } |
| iterator end() const { return interface_end<InterfaceType>(); } |
| |
| private: |
| using detail::DialectInterfaceCollectionBase::interface_begin; |
| using detail::DialectInterfaceCollectionBase::interface_end; |
| }; |
| |
| } // namespace mlir |
| |
| #endif |