blob: 7e4e1d8335b1645725e416085d5c7d4dac302620 [file] [log] [blame]
//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines abstraction and registration mechanism for dialect hooks.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIALECT_HOOKS_H
#define MLIR_IR_DIALECT_HOOKS_H
#include "mlir/IR/Dialect.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
using DialectHooksSetter = std::function<void(MLIRContext *)>;
/// Dialect hooks allow external components to register functions that are
/// called for dialect-specific tasks, such as decoding of opaque constants.
/// To register concrete dialect hooks, define a DialectHooks subclass and use
/// it as the template argument of DialectHooksRegistration, passing the name
/// of the target dialect to its constructor. For example,
///
///   class MyHooks : public DialectHooks {...};
///   static DialectHooksRegistration<MyHooks> hooksReg("my_dialect");
///
/// The methods below are NOT virtual. DialectHooksRegistration calls them on
/// an instance of the concrete subclass, so a subclass provides a hook by
/// declaring a method with the same name, which hides the base version.
/// Hooks that the subclass does not redeclare return nullptr here, and
/// DialectHooksRegistration leaves the corresponding dialect hook unchanged.
class DialectHooks {
public:
  /// Returns the hook used to constant fold an operation, or nullptr when
  /// this set of hooks does not provide one.
  DialectConstantFoldHook getConstantFoldHook() { return nullptr; }

  /// Returns the hook used to decode an opaque constant tensor, or nullptr
  /// when this set of hooks does not provide one.
  DialectConstantDecodeHook getDecodeHook() { return nullptr; }

  /// Returns the hook used to extract an element of an opaque constant
  /// tensor, or nullptr when this set of hooks does not provide one.
  DialectExtractElementHook getExtractElementHook() { return nullptr; }
};
/// Records `function` as a dialect hooks setter. A setter is called with an
/// MLIRContext and sets hooks on the dialects registered in that context;
/// DialectHooksRegistration builds such setters from a hooks class.
void registerDialectHooksSetter(
    const DialectHooksSetter &function);
/// DialectHooksRegistration provides a global initializer that registers
/// a dialect hooks setter routine.
///
/// When the registered setter is invoked with an MLIRContext, it does three
/// things:
///   1. Looks up the dialect named by the constructor argument among the
///      dialects registered in that context. An unknown name prints an error
///      and aborts.
///   2. Default-constructs a `ConcreteHooks`.
///   3. Copies every non-null hook it returns into the dialect. Hooks that
///      return nullptr leave the dialect's existing hook untouched.
///
/// Usage:
///
///   // At namespace scope.
///   static DialectHooksRegistration<MyHooks> unused("my_dialect");
template <typename ConcreteHooks> struct DialectHooksRegistration {
  /// Registers a setter that installs `ConcreteHooks` into the dialect named
  /// `dialectName`. The name is copied, so the caller's string does not need
  /// to outlive this object.
  DialectHooksRegistration(StringRef dialectName) {
    // Capture an owning copy of the name. The setter is registered now but
    // only runs once an MLIRContext is available. A captured StringRef would
    // dangle if `dialectName` pointed into a temporary or short-lived string.
    registerDialectHooksSetter([name = dialectName.str()](MLIRContext *ctx) {
      Dialect *dialect = ctx->getRegisteredDialect(name);
      if (!dialect) {
        llvm::errs() << "error: cannot register hooks for unknown dialect '"
                     << name << "'\n";
        abort();
      }
      // Set only the hooks the concrete hooks class actually provides;
      // nullptr means "keep whatever the dialect already has".
      ConcreteHooks hooks;
      if (auto h = hooks.getConstantFoldHook())
        dialect->constantFoldHook = h;
      if (auto h = hooks.getDecodeHook())
        dialect->decodeHook = h;
      if (auto h = hooks.getExtractElementHook())
        dialect->extractElementHook = h;
    });
  }
};
} // namespace mlir
#endif