blob: c13a1a1858eb14aaa62a39b5e0549f66d8cefd21 [file] [log] [blame]
//===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines functionality for registering and extending dialects.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIALECTREGISTRY_H
#define MLIR_IR_DIALECTREGISTRY_H
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <map>
#include <tuple>
namespace mlir {
/// Forward declaration; only pointers and references to Dialect are used in
/// the declarations below.
class Dialect;
/// Owning callable that takes an MLIRContext and returns a Dialect pointer.
/// Used as the dialect constructor type stored by DialectRegistry.
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
/// `function_ref` flavour of DialectAllocatorFunction with the same call
/// signature. NOTE(review): presumably llvm::function_ref, i.e. a non-owning
/// reference whose target must outlive it -- confirm against the include that
/// provides `function_ref`.
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
/// Callable invoked with an MLIRContext and a DynamicDialect; this is the
/// constructor type accepted by DialectRegistry::insertDynamic.
using DynamicDialectPopulationFunction =
std::function<void(MLIRContext *, DynamicDialect *)>;
//===----------------------------------------------------------------------===//
// DialectExtension
//===----------------------------------------------------------------------===//
/// Type-erased base class of every dialect extension. An extension bundles a
/// list of required dialect names with an application hook. The required
/// dialects determine when the extension fires: it is applied once all of them
/// are loaded. The hook may attach additional functionality to attributes,
/// dialects, operations, types, etc., and may also load further dialects it
/// needs.
class DialectExtensionBase {
public:
  virtual ~DialectExtensionBase();

  /// Return the names of the dialects that are required to be loaded before
  /// this extension is applied. If the list is empty, the extension is invoked
  /// for every loaded dialect independently.
  ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }

  /// Apply this extension to `context`, handing it the required `dialects`.
  virtual void apply(MLIRContext *context,
                     MutableArrayRef<Dialect *> dialects) const = 0;

  /// Create a copy of this extension.
  virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;

protected:
  /// Set up the extension with the names of its required dialects. An empty
  /// list means the extension is invoked for every loaded dialect
  /// independently. Only the StringRef values are stored; the character data
  /// they point to is not copied.
  DialectExtensionBase(ArrayRef<StringRef> dialectNames) {
    this->dialectNames.assign(dialectNames.begin(), dialectNames.end());
  }

private:
  /// Names of the dialects this extension is anchored on.
  SmallVector<StringRef> dialectNames;
};
/// CRTP helper for extensions anchored on a fixed, statically known set of
/// dialects. Once every dialect in `DialectsT` has been loaded, the typed
/// `apply` hook of the derived extension is executed.
template <typename DerivedT, typename... DialectsT>
class DialectExtension : public DialectExtensionBase {
public:
  /// Hook implemented by the derived extension; receives the context and one
  /// correctly typed pointer per required dialect.
  virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;

  /// Create a copy of this extension by copy-constructing the derived type.
  std::unique_ptr<DialectExtensionBase> clone() const final {
    const auto &self = static_cast<const DerivedT &>(*this);
    return std::make_unique<DerivedT>(self);
  }

protected:
  /// Register the namespace of each dialect in `DialectsT` as required.
  DialectExtension()
      : DialectExtensionBase(
            ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}

  /// Bridge from the type-erased base hook to the typed hook above: each
  /// entry of `dialects` is cast to the dialect type at the same position in
  /// `DialectsT`.
  void apply(MLIRContext *context,
             MutableArrayRef<Dialect *> dialects) const final {
    // Elements of a braced initializer list are evaluated left to right, so
    // the running index pairs dialects[i] with the i-th type of the pack.
    unsigned nextIdx = 0;
    std::tuple<DialectsT *...> typedDialects{
        static_cast<DialectsT *>(dialects[nextIdx++])...};
    std::apply(
        [this, context](DialectsT *...typed) {
          this->apply(context, typed...);
        },
        typedDialects);
  }
};
namespace dialect_extension_detail {

/// Called when an interface is about to be used. If that interface was
/// promised by `dialect` but has not been implemented yet, a fatal error is
/// emitted.
void handleUseOfUndefinedPromisedInterface(Dialect &dialect,
                                           TypeID interfaceRequestorID,
                                           TypeID interfaceID,
                                           StringRef interfaceName);

/// Called when an interface is about to be attached. If that interface was
/// promised by `dialect` but has not been implemented yet, the promise is
/// marked as resolved.
void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect,
                                                TypeID interfaceRequestorID,
                                                TypeID interfaceID);

/// Report whether a promise exists for the given interface/requestor pair.
bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
                          TypeID interfaceID);

/// Typed convenience wrapper: reports whether a promise exists for interface
/// `InterfaceT` with `ConcreteT` as the requestor.
template <typename ConcreteT, typename InterfaceT>
bool hasPromisedInterface(Dialect &dialect) {
  auto requestorID = TypeID::get<ConcreteT>();
  auto promisedID = InterfaceT::getInterfaceID();
  return hasPromisedInterface(dialect, requestorID, promisedID);
}

} // namespace dialect_extension_detail
//===----------------------------------------------------------------------===//
// DialectRegistry
//===----------------------------------------------------------------------===//
/// The DialectRegistry maps a dialect namespace to a constructor for the
/// matching dialect. This allows for decoupling the list of dialects
/// "available" from the dialects loaded in the Context. The parser in
/// particular will lazily load dialects in the Context as operations are
/// encountered.
///
/// In addition to dialect constructors, the registry holds a list of dialect
/// extensions (see DialectExtensionBase).
class DialectRegistry {
  /// Dialect namespace -> (dialect TypeID, dialect constructor).
  using MapTy =
      std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;

public:
  explicit DialectRegistry();

  /// Register `ConcreteDialect` under its namespace, using a constructor that
  /// loads the dialect through MLIRContext::getOrLoadDialect.
  template <typename ConcreteDialect>
  void insert() {
    insert(TypeID::get<ConcreteDialect>(),
           ConcreteDialect::getDialectNamespace(),
           static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
             // Just allocate the dialect, the context
             // takes ownership of it.
             return ctx->getOrLoadDialect<ConcreteDialect>();
           })));
  }

  /// Register several dialects at once, in the order they are listed.
  template <typename ConcreteDialect, typename OtherDialect,
            typename... MoreDialects>
  void insert() {
    insert<ConcreteDialect>();
    insert<OtherDialect, MoreDialects...>();
  }

  /// Add a new dialect constructor to the registry. The constructor must be
  /// calling MLIRContext::getOrLoadDialect in order for the context to take
  /// ownership of the dialect and for delayed interface registration to happen.
  void insert(TypeID typeID, StringRef name,
              const DialectAllocatorFunction &ctor);

  /// Add a new dynamic dialect constructor in the registry. The constructor
  /// provides as argument the created dynamic dialect, and is expected to
  /// register the dialect types, attributes, and ops, using the
  /// methods defined in ExtensibleDialect such as registerDynamicOperation.
  void insertDynamic(StringRef name,
                     const DynamicDialectPopulationFunction &ctor);

  /// Return an allocation function for constructing the dialect identified
  /// by its namespace, or nullptr if the namespace is not in this registry.
  DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;

  /// Register all dialects available in the current registry with the
  /// `destination` registry, and append a clone of every extension held by
  /// this registry to it. Appending a registry to itself is a no-op.
  void appendTo(DialectRegistry &destination) const {
    // Guard against self-append: the extension loop below would otherwise
    // push onto the very vector it is iterating, invalidating its iterators.
    if (&destination == this)
      return;
    for (const auto &[name, registration] : registry)
      destination.insert(registration.first, name, registration.second);
    // Merge the extensions.
    for (const auto &extension : extensions)
      destination.extensions.push_back(extension->clone());
  }

  /// Return the names of dialects known to this registry.
  auto getDialectNames() const {
    return llvm::map_range(
        registry,
        [](const MapTy::value_type &item) -> StringRef { return item.first; });
  }

  /// Apply any held extensions that require the given dialect. Users are not
  /// expected to call this directly.
  void applyExtensions(Dialect *dialect) const;

  /// Apply any applicable extensions to the given context. Users are not
  /// expected to call this directly.
  void applyExtensions(MLIRContext *ctx) const;

  /// Add the given extension to the registry.
  void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
    extensions.push_back(std::move(extension));
  }

  /// Add the given extensions to the registry. Each extension type is
  /// default-constructed.
  template <typename... ExtensionsT>
  void addExtensions() {
    (addExtension(std::make_unique<ExtensionsT>()), ...);
  }

  /// Add an extension function that requires the given dialects.
  /// Note: This bare functor overload is provided in addition to the
  /// std::function variant to enable dialect type deduction, e.g.:
  ///  registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
  ///
  /// is equivalent to:
  ///  registry.addExtension<MyDialect>(
  ///     [](MLIRContext *ctx, MyDialect *dialect){ ... }
  ///  )
  template <typename... DialectsT>
  void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
    addExtension<DialectsT...>(
        std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
  }

  /// Add an extension function that requires the given dialects. The function
  /// is wrapped in a DialectExtension that forwards the context and the typed
  /// dialect pointers to it.
  template <typename... DialectsT>
  void
  addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
    using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;

    struct Extension : public DialectExtension<Extension, DialectsT...> {
      Extension(const Extension &) = default;
      explicit Extension(ExtensionFnT extensionFn)
          : extensionFn(std::move(extensionFn)) {}
      ~Extension() override = default;

      void apply(MLIRContext *context, DialectsT *...dialects) const final {
        extensionFn(context, dialects...);
      }
      ExtensionFnT extensionFn;
    };
    addExtension(std::make_unique<Extension>(std::move(extensionFn)));
  }

  /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
  /// contains all of the components of this registry.
  bool isSubsetOf(const DialectRegistry &rhs) const;

private:
  /// Registered dialect constructors, keyed by dialect namespace.
  MapTy registry;
  /// Extensions held by this registry, in insertion order.
  std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
};
} // namespace mlir
#endif // MLIR_IR_DIALECTREGISTRY_H