blob: d3b4b055bc0c96ba221432e26b787f98f04fe164 [file] [log] [blame]
//===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the 'dialect' abstraction.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIALECT_H
#define MLIR_IR_DIALECT_H
#include "mlir/IR/OperationSupport.h"
namespace mlir {
class DialectAsmParser;
class DialectAsmPrinter;
class DialectInterface;
class OpBuilder;
class Type;
using DialectConstantDecodeHook =
std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
using DialectConstantFoldHook = std::function<LogicalResult(
Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
using DialectExtractElementHook =
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
/// Dialects are groups of MLIR operations and behavior associated with the
/// entire group. For example, hooks into other systems for constant folding,
/// default named types for asm printing, etc.
///
/// Each dialect instance is bound to the MLIRContext passed to its
/// constructor, and that context owns the dialect object.
///
class Dialect {
public:
  virtual ~Dialect();

  /// Utility function that returns true if the given string is a valid
  /// dialect namespace.
  static bool isValidNamespace(StringRef str);

  /// Return the context that owns this dialect.
  MLIRContext *getContext() const { return context; }

  /// Return the namespace of this dialect (e.g. "tf" for the TensorFlow ops).
  StringRef getNamespace() const { return name; }

  /// Returns true if this dialect allows for unregistered operations, i.e.
  /// operations prefixed with the dialect namespace but not registered with
  /// addOperation.
  bool allowsUnknownOperations() const { return unknownOpsAllowed; }

  /// Return true if this dialect allows for unregistered types, i.e., types
  /// prefixed with the dialect namespace but not registered with addType.
  /// These are represented with OpaqueType.
  bool allowsUnknownTypes() const { return unknownTypesAllowed; }

  //===--------------------------------------------------------------------===//
  // Constant Hooks
  //===--------------------------------------------------------------------===//

  /// Registered fallback constant fold hook for the dialect. Like the constant
  /// fold hook of each operation, it attempts to constant fold the operation
  /// with the specified constant operand values - the elements in "operands"
  /// will correspond directly to the operands of the operation, but may be null
  /// if non-constant. If constant folding is successful, this fills in the
  /// `results` vector. If not, this returns failure and `results` is
  /// unspecified. The default hook always returns failure.
  DialectConstantFoldHook constantFoldHook =
      [](Operation *op, ArrayRef<Attribute> operands,
         SmallVectorImpl<Attribute> &results) { return failure(); };

  /// Registered hook to decode opaque constants associated with this
  /// dialect. The hook function attempts to decode an opaque constant tensor
  /// into a tensor with non-opaque content. If decoding is successful, this
  /// method returns false and sets 'output' attribute. If not, it returns true
  /// and leaves 'output' unspecified. The default hook fails to decode.
  DialectConstantDecodeHook decodeHook =
      [](const OpaqueElementsAttr input, ElementsAttr &output) { return true; };

  /// Registered hook to extract an element from an opaque constant associated
  /// with this dialect. If element has been successfully extracted, this
  /// method returns that element. If not, it returns an empty attribute.
  /// The default hook fails to extract an element.
  DialectExtractElementHook extractElementHook =
      [](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) {
        return Attribute();
      };

  /// Registered hook to materialize a single constant operation from a given
  /// attribute value with the desired resultant type. This method should use
  /// the provided builder to create the operation without changing the
  /// insertion position. The generated operation is expected to be constant
  /// like, i.e. single result, zero operands, non side-effecting, etc. On
  /// success, this hook should return the operation generated to represent
  /// the constant value. Otherwise, it should return null on failure. The
  /// default implementation always returns null.
  virtual Operation *materializeConstant(OpBuilder &builder, Attribute value,
                                         Type type, Location loc) {
    return nullptr;
  }

  //===--------------------------------------------------------------------===//
  // Parsing Hooks
  //===--------------------------------------------------------------------===//

  /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
  /// refers to the expected type of the attribute.
  virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;

  /// Print an attribute registered to this dialect. Note: The type of the
  /// attribute need not be printed by this method as it is always printed by
  /// the caller. The default implementation is unreachable (llvm_unreachable);
  /// override it to print this dialect's attributes.
  virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
    llvm_unreachable("dialect has no registered attribute printing hook");
  }

  /// Parse a type registered to this dialect.
  virtual Type parseType(DialectAsmParser &parser) const;

  /// Print a type registered to this dialect. The default implementation is
  /// unreachable (llvm_unreachable); override it to print this dialect's
  /// types.
  virtual void printType(Type, DialectAsmPrinter &) const {
    llvm_unreachable("dialect has no registered type printing hook");
  }

  //===--------------------------------------------------------------------===//
  // Verification Hooks
  //===--------------------------------------------------------------------===//

  /// Verify an attribute from this dialect on the argument at 'argIndex' for
  /// the region at 'regionIndex' on the given operation. Returns failure if
  /// the verification failed, success otherwise. This hook may optionally be
  /// invoked from any operation containing a region.
  virtual LogicalResult verifyRegionArgAttribute(Operation *,
                                                 unsigned regionIndex,
                                                 unsigned argIndex,
                                                 NamedAttribute);

  /// Verify an attribute from this dialect on the result at 'resultIndex' for
  /// the region at 'regionIndex' on the given operation. Returns failure if
  /// the verification failed, success otherwise. This hook may optionally be
  /// invoked from any operation containing a region.
  virtual LogicalResult verifyRegionResultAttribute(Operation *,
                                                    unsigned regionIndex,
                                                    unsigned resultIndex,
                                                    NamedAttribute);

  /// Verify an attribute from this dialect on the given operation. Returns
  /// failure if the verification failed, success otherwise. The default
  /// implementation accepts every attribute.
  virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) {
    return success();
  }

  //===--------------------------------------------------------------------===//
  // Interfaces
  //===--------------------------------------------------------------------===//

  /// Lookup an interface for the given ID if one is registered, otherwise
  /// return nullptr. The lookup does not modify the dialect, so it is const
  /// and may be used through a `const Dialect`.
  const DialectInterface *getRegisteredInterface(ClassID *interfaceID) const {
    auto it = registeredInterfaces.find(interfaceID);
    return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
  }

  /// Typed convenience wrapper: look up the interface registered under
  /// `InterfaceT::getInterfaceID()` and cast it to `InterfaceT`, or return
  /// nullptr if none is registered.
  template <typename InterfaceT>
  const InterfaceT *getRegisteredInterface() const {
    return static_cast<const InterfaceT *>(
        getRegisteredInterface(InterfaceT::getInterfaceID()));
  }

protected:
  /// The constructor takes a unique namespace for this dialect as well as the
  /// context to bind to.
  /// Note: The namespace must not contain '.' characters.
  /// Note: All operations belonging to this dialect must have names starting
  /// with the namespace followed by '.'.
  /// Example:
  ///  - "tf" for the TensorFlow ops like "tf.add".
  Dialect(StringRef name, MLIRContext *context);

  /// This method is used by derived classes to add their operations to the set.
  ///
  template <typename... Args> void addOperations() {
    VariadicOperationAdder<Args...>::addToSet(*this);
  }

  // It would be nice to define this as variadic functions instead of a nested
  // variadic type, but we can't do that: function template partial
  // specialization is not allowed, and we can't define an overload set because
  // we don't have any arguments of the types we are pushing around.
  template <typename First, typename... Rest> class VariadicOperationAdder {
  public:
    static void addToSet(Dialect &dialect) {
      dialect.addOperation(AbstractOperation::get<First>(dialect));
      VariadicOperationAdder<Rest...>::addToSet(dialect);
    }
  };

  // Base case of the recursion above: a single remaining operation.
  template <typename First> class VariadicOperationAdder<First> {
  public:
    static void addToSet(Dialect &dialect) {
      dialect.addOperation(AbstractOperation::get<First>(dialect));
    }
  };

  void addOperation(AbstractOperation opInfo);

  /// This method is used by derived classes to add their types to the set.
  template <typename... Args> void addTypes() {
    VariadicSymbolAdder<Args...>::addToSet(*this);
  }

  /// This method is used by derived classes to add their attributes to the set.
  template <typename... Args> void addAttributes() {
    VariadicSymbolAdder<Args...>::addToSet(*this);
  }

  // It would be nice to define this as variadic functions instead of a nested
  // variadic type, but we can't do that: function template partial
  // specialization is not allowed, and we can't define an overload set
  // because we don't have any arguments of the types we are pushing around.
  template <typename First, typename... Rest> struct VariadicSymbolAdder {
    static void addToSet(Dialect &dialect) {
      VariadicSymbolAdder<First>::addToSet(dialect);
      VariadicSymbolAdder<Rest...>::addToSet(dialect);
    }
  };

  // Base case: register a single symbol by its class identifier.
  template <typename First> struct VariadicSymbolAdder<First> {
    static void addToSet(Dialect &dialect) {
      dialect.addSymbol(First::getClassID());
    }
  };

  /// Enable support for unregistered operations.
  void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }

  /// Enable support for unregistered types.
  void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }

  /// Register a dialect interface with this dialect instance.
  void addInterface(std::unique_ptr<DialectInterface> interface);

  /// Register a set of dialect interfaces with this dialect instance.
  template <typename T, typename T2, typename... Tys> void addInterfaces() {
    addInterfaces<T>();
    addInterfaces<T2, Tys...>();
  }
  template <typename T> void addInterfaces() {
    addInterface(std::make_unique<T>(this));
  }

private:
  // Register a symbol (e.g. type) with its given unique class identifier.
  void addSymbol(const ClassID *const classID);

  // Dialects are not copyable; use the conventional signature for the
  // deleted copy-assignment operator.
  Dialect(const Dialect &) = delete;
  Dialect &operator=(const Dialect &) = delete;

  /// Register this dialect object with the specified context. The context
  /// takes ownership of the heap allocated dialect.
  void registerDialect(MLIRContext *context);

  /// The namespace of this dialect.
  StringRef name;

  /// This is the context that owns this Dialect object.
  MLIRContext *context;

  /// Flag that specifies whether this dialect supports unregistered operations,
  /// i.e. operations prefixed with the dialect namespace but not registered
  /// with addOperation.
  bool unknownOpsAllowed = false;

  /// Flag that specifies whether this dialect allows unregistered types, i.e.
  /// types prefixed with the dialect namespace but not registered with addType.
  /// These types are represented with OpaqueType.
  bool unknownTypesAllowed = false;

  /// A collection of registered dialect interfaces.
  DenseMap<ClassID *, std::unique_ptr<DialectInterface>> registeredInterfaces;
};
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
/// Registers a specific dialect creation function with the system, typically
/// used through the DialectRegistration template.
void registerDialectAllocator(const DialectAllocatorFunction &function);
/// Registers all dialects with the specified MLIRContext.
void registerAllDialects(MLIRContext *context);
/// Utility to register a dialect. Client can register their dialect with the
/// global registry by calling registerDialect<MyDialect>();
template <typename ConcreteDialect> void registerDialect() {
registerDialectAllocator([](MLIRContext *ctx) {
// Just allocate the dialect, the context takes ownership of it.
new ConcreteDialect(ctx);
});
}
/// DialectRegistration provides a global initializer that registers a Dialect
/// allocation routine.
///
/// Usage:
///
/// // At namespace scope.
/// static DialectRegistration<MyDialect> Unused;
template <typename ConcreteDialect> struct DialectRegistration {
DialectRegistration() { registerDialect<ConcreteDialect>(); }
};
} // namespace mlir
#endif