blob: 3ffed3d932d9bb637a8b455511c66da44b3fbc93 [file] [log] [blame]
//===- DialectInterface.h - IR Dialect Interfaces ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIALECTINTERFACE_H
#define MLIR_IR_DIALECTINTERFACE_H
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
class Dialect;
class MLIRContext;
class Operation;
//===----------------------------------------------------------------------===//
// DialectInterface
//===----------------------------------------------------------------------===//
namespace detail {
/// CRTP helper that sits between a concrete interface `ConcreteType` and the
/// class it ultimately derives from, `BaseT`. It supplies the TypeID that
/// uniquely identifies `ConcreteType` and hands it, together with the owning
/// dialect, to `BaseT`'s constructor.
template <typename ConcreteType, typename BaseT>
class DialectInterfaceBase : public BaseT {
public:
  /// Shorthand naming this exact instantiation.
  using Base = DialectInterfaceBase;

  /// Returns the unique identifier of the concrete interface type.
  static TypeID getInterfaceID() { return TypeID::get<ConcreteType>(); }

protected:
  /// Forwards `owningDialect` and this interface's TypeID to `BaseT`.
  DialectInterfaceBase(Dialect *owningDialect)
      : BaseT(owningDialect, getInterfaceID()) {}
};
} // namespace detail
/// Root of every dialect interface: one instance describes how a single
/// dialect implements a particular interface. Each instance records the
/// dialect it belongs to and the TypeID of the concrete interface type.
class DialectInterface {
public:
  /// Alias concrete interfaces derive from; it injects the TypeID of
  /// `ConcreteType` when the interface is constructed.
  template <typename ConcreteType>
  using Base = detail::DialectInterfaceBase<ConcreteType, DialectInterface>;

  virtual ~DialectInterface();

  /// The dialect this interface instance was created for.
  Dialect *getDialect() const { return dialect; }

  /// The TypeID identifying the concrete (derived) interface type.
  TypeID getID() const { return interfaceID; }

protected:
  DialectInterface(Dialect *owningDialect, TypeID concreteID)
      : dialect(owningDialect), interfaceID(concreteID) {}

private:
  /// Dialect that owns / implements this interface.
  Dialect *dialect;
  /// Identifier of the concrete interface type.
  TypeID interfaceID;
};
//===----------------------------------------------------------------------===//
// DialectInterfaceCollection
//===----------------------------------------------------------------------===//
namespace detail {
/// Non-templated storage shared by all DialectInterfaceCollection
/// instantiations. It holds the interface instances of a single interface
/// kind, lets them be looked up by dialect, and iterates them in a
/// deterministic order.
class DialectInterfaceCollectionBase {
  /// Hashing/equality traits that key every stored interface by its owning
  /// dialect. With these, the set can be probed with a plain `Dialect *`
  /// through `find_as`.
  struct InterfaceKeyInfo : public DenseMapInfo<const DialectInterface *> {
    // Re-expose the inherited interface-vs-interface `isEqual`, which the
    // `Dialect *` overload below would otherwise hide.
    using DenseMapInfo<const DialectInterface *>::isEqual;

    static unsigned getHashValue(Dialect *dialectKey) {
      return llvm::hash_value(dialectKey);
    }
    static unsigned getHashValue(const DialectInterface *interfaceKey) {
      // Hash through the owning dialect so this agrees with the
      // `Dialect *` overload used for heterogeneous lookup.
      return getHashValue(interfaceKey->getDialect());
    }
    static bool isEqual(Dialect *lhs, const DialectInterface *rhs) {
      // The empty/tombstone markers are not real interfaces and must never
      // be dereferenced; the short-circuit guards the `getDialect()` call.
      const bool isSentinel = rhs == getEmptyKey() || rhs == getTombstoneKey();
      return !isSentinel && lhs == rhs->getDialect();
    }
  };

  /// A set of registered dialect interface instances, keyed by dialect.
  using InterfaceSetT = DenseSet<const DialectInterface *, InterfaceKeyInfo>;
  /// Sequential storage of the same instances, used for iteration.
  using InterfaceVectorT = std::vector<const DialectInterface *>;

public:
  /// Builds the collection for interface kind `interfaceKind` within `ctx`.
  /// Defined out of line; presumably gathers the matching interface from the
  /// context's dialects (NOTE(review): confirm against the definition).
  DialectInterfaceCollectionBase(MLIRContext *ctx, TypeID interfaceKind);
  virtual ~DialectInterfaceCollectionBase();

protected:
  /// Returns the interface registered for the dialect of `op`, or null if
  /// there is none. Defined out of line.
  const DialectInterface *getInterfaceFor(Operation *op) const;

  /// Returns the interface registered for `dialect`, or null if there is
  /// none.
  const DialectInterface *getInterfaceFor(Dialect *dialect) const {
    auto match = interfaces.find_as(dialect);
    if (match != interfaces.end())
      return *match;
    return nullptr;
  }

  /// Iterator that walks the stored interfaces and yields each one as a
  /// `const InterfaceT &`. The downcast is an unchecked `static_cast`, so
  /// every stored element is assumed to really be an `InterfaceT`.
  template <typename InterfaceT>
  struct iterator
      : public llvm::mapped_iterator_base<iterator<InterfaceT>,
                                          InterfaceVectorT::const_iterator,
                                          const InterfaceT &> {
    using llvm::mapped_iterator_base<iterator<InterfaceT>,
                                     InterfaceVectorT::const_iterator,
                                     const InterfaceT &>::mapped_iterator_base;

    /// Converts one stored element into the iterator's result type.
    const InterfaceT &mapElement(const DialectInterface *element) const {
      const InterfaceT *concrete = static_cast<const InterfaceT *>(element);
      return *concrete;
    }
  };

  /// Iterator access to the held interfaces, in the order kept by
  /// `orderedInterfaces`.
  template <typename InterfaceT>
  iterator<InterfaceT> interface_begin() const {
    return iterator<InterfaceT>(orderedInterfaces.begin());
  }
  template <typename InterfaceT>
  iterator<InterfaceT> interface_end() const {
    return iterator<InterfaceT>(orderedInterfaces.end());
  }

private:
  /// Registered interface instances, used for lookups by dialect.
  InterfaceSetT interfaces;
  /// The same instances in a vector, so iteration order is deterministic
  /// rather than dependent on the hash set's layout.
  // NOTE: SetVector does not provide find access, so it can't be used here.
  InterfaceVectorT orderedInterfaces;
};
} // namespace detail
/// Typed view over the dialect interfaces of kind `InterfaceType` that are
/// registered within a context. It adds typed lookup and iteration on top of
/// the type-erased base.
template <typename InterfaceType>
class DialectInterfaceCollection
    : public detail::DialectInterfaceCollectionBase {
public:
  /// Shorthand naming this exact instantiation.
  using Base = DialectInterfaceCollection;

  /// Collects the dialect interfaces of kind `InterfaceType` registered
  /// within `context`.
  DialectInterfaceCollection(MLIRContext *context)
      : detail::DialectInterfaceCollectionBase(
            context, InterfaceType::getInterfaceID()) {}

  /// Returns the `InterfaceType` registered for `obj`, or null if there is
  /// none. `obj` may be a dialect or an operation.
  template <typename Object>
  const InterfaceType *getInterfaceFor(Object *obj) const {
    const DialectInterface *untyped =
        detail::DialectInterfaceCollectionBase::getInterfaceFor(obj);
    return static_cast<const InterfaceType *>(untyped);
  }

  /// Range access over the held interfaces, each yielded as
  /// `const InterfaceType &`.
  using iterator =
      detail::DialectInterfaceCollectionBase::iterator<InterfaceType>;
  iterator begin() const { return interface_begin<InterfaceType>(); }
  iterator end() const { return interface_end<InterfaceType>(); }

private:
  // Narrow the inherited accessors to private so that only the typed
  // begin()/end() above are exposed.
  using detail::DialectInterfaceCollectionBase::interface_begin;
  using detail::DialectInterfaceCollectionBase::interface_end;
};
} // namespace mlir
#endif