blob: e66dbbc664b4ef03306472b57b4e81e02742b4ed [file] [log] [blame]
//===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares a generic pass for converting between MLIR dialects.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringMap.h"
namespace mlir {
// Forward declarations.
class Block;
class ConversionPatternRewriter;
class FuncOp;
class MLIRContext;
class Operation;
class Type;
class Value;
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
/// Type conversion class. Specific conversions and materializations can be
/// registered using addConversion and the add*Materialization methods,
/// respectively.
class TypeConverter {
public:
/// This class provides all of the information necessary to convert a type
/// signature.
class SignatureConversion {
public:
/// Create a conversion for a signature with `numOrigInputs` original inputs.
/// Initially, none of the original inputs has a mapping.
SignatureConversion(unsigned numOrigInputs)
: remappedInputs(numOrigInputs) {}
/// This struct represents a range of new types or a single value that
/// remaps an existing signature input.
struct InputMapping {
// NOTE(review): presumably `inputNo`/`size` give the first index and the
// number of new inputs that this original input maps to, and
// `replacementValue` is only set by remapInput(origInputNo, Value);
// confirm against the implementation.
size_t inputNo, size;
Value replacementValue;
};
/// Return the argument types for the new signature.
ArrayRef<Type> getConvertedTypes() const { return argTypes; }
/// Get the input mapping for the given original argument, or None if no
/// mapping has been recorded for it.
Optional<InputMapping> getInputMapping(unsigned input) const {
return remappedInputs[input];
}
//===------------------------------------------------------------------===//
// Conversion Hooks
//===------------------------------------------------------------------===//
/// Remap an input of the original signature with a new set of types. The
/// new types are appended to the new signature conversion.
void addInputs(unsigned origInputNo, ArrayRef<Type> types);
/// Append new input types to the signature conversion. This should only be
/// used if the new types are not intended to remap an existing input.
void addInputs(ArrayRef<Type> types);
/// Remap an input of the original signature to another `replacement`
/// value. This drops the original argument.
void remapInput(unsigned origInputNo, Value replacement);
private:
/// Remap an input of the original signature with a range of types in the
/// new signature.
void remapInput(unsigned origInputNo, unsigned newInputNo,
unsigned newInputCount = 1);
/// The remapping information for each of the original arguments, indexed by
/// original input number.
SmallVector<Optional<InputMapping>, 4> remappedInputs;
/// The set of new argument types.
SmallVector<Type, 4> argTypes;
};
/// Register a conversion function. A conversion function must be convertible
/// to any of the following forms (where `T` is a class derived from `Type`):
/// * Optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
/// or `llvm::None` to signify failure. If `llvm::None` is returned, the
/// converter is allowed to try another conversion function to perform
/// the conversion.
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// - This form represents a 1-N type conversion. It should return
/// `failure` or `llvm::None` to signify a failed conversion. If the new
/// set of types is empty, the type is removed and any usages of the
/// existing value are expected to be removed during conversion. If
/// `llvm::None` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
/// - This form represents a 1-N type conversion supporting recursive
/// types. The first two arguments and the return value are the same as
/// for the regular 1-N form. The third argument contains the
/// "call stack" of the recursive conversion: it contains the list of
/// types currently being converted, with the current type being the
/// last one. If it is present more than once in the list, the
/// conversion concerns a recursive type.
/// The callback is only invoked for types that are instances of `T`; for any
/// other type the registered wrapper reports `llvm::None`, so other
/// conversions may be tried. Registering a conversion clears the internal
/// conversion caches.
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// most recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<0>>
void addConversion(FnT &&callback) {
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
}
/// Register a materialization function, which must be convertible to the
/// following form:
/// `Optional<Value>(OpBuilder &, T, ValueRange, Location)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
/// "casts" a range of values into a single value of the given type `T`. It
/// must return a Value of the converted type on success, `llvm::None` if
/// it failed but other materializations can be attempted, and `nullptr` on
/// unrecoverable failure. It will only be called for (sub)types of `T`.
/// Materialization functions must be provided when a type conversion may
/// persist after the conversion has finished.
///
/// This method registers a materialization that will be called when
/// converting an illegal block argument type to a legal type.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
argumentMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
/// converting a legal type to an illegal source type. This is used when
/// conversions to an illegal type must persist beyond the main conversion.
/// See addArgumentMaterialization for the expected callback form.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addSourceMaterialization(FnT &&callback) {
sourceMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
/// converting a type from an illegal, or source, type to a legal type.
/// See addArgumentMaterialization for the expected callback form.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
targetMaterializations.emplace_back(
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// Convert the given type. This function should return failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
/// be removed during conversion.
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results);
/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type to convert to on success, and a null type on failure.
Type convertType(Type t);
/// Convert the given set of types, filling 'results' as necessary. This
/// returns failure if the conversion of any of the types fails, success
/// otherwise.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl<Type> &results);
/// Return true if the given type is legal for this type converter, i.e. the
/// type converts to itself.
bool isLegal(Type type);
/// Return true if all of the given types are legal for this type converter.
/// This overload only participates for range arguments (those not
/// convertible to `Type` or `Operation *`), so the single-type and operation
/// overloads take precedence.
template <typename RangeT>
std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
!std::is_convertible<RangeT, Operation *>::value,
bool>
isLegal(RangeT &&range) {
return llvm::all_of(range, [this](Type type) { return isLegal(type); });
}
/// Return true if the given operation has legal operand and result types.
bool isLegal(Operation *op);
/// Return true if the types of block arguments within the region are legal.
bool isLegal(Region *region);
/// Return true if the inputs and outputs of the given function type are
/// legal.
bool isSignatureLegal(FunctionType ty);
/// This method allows for converting a specific argument of a signature. It
/// takes as inputs the original argument input number and its type.
/// On success, it populates 'result' with any new mappings.
LogicalResult convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result);
/// Convert a range of signature argument types, populating 'result' with any
/// new mappings. Returns failure if any argument fails to convert.
/// NOTE(review): `origInputOffset` presumably is the original input number
/// of `types[0]` (i.e. types[i] is treated as input origInputOffset + i);
/// confirm against the implementation.
LogicalResult convertSignatureArgs(TypeRange types,
SignatureConversion &result,
unsigned origInputOffset = 0);
/// This function converts the type signature of the given block, by invoking
/// 'convertSignatureArg' for each argument. This function should return a
/// valid conversion for the signature on success, None otherwise.
Optional<SignatureConversion> convertBlockSignature(Block *block);
/// Materialize a conversion from a set of types into one result type by
/// generating a cast sequence of some kind. See the respective
/// `add*Materialization` for more information on the context for these
/// methods. Each method tries only the materializations registered through
/// its matching `add*Materialization` hook.
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs) {
return materializeConversion(argumentMaterializations, builder, loc,
resultType, inputs);
}
Value materializeSourceConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs) {
return materializeConversion(sourceMaterializations, builder, loc,
resultType, inputs);
}
Value materializeTargetConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs) {
return materializeConversion(targetMaterializations, builder, loc,
resultType, inputs);
}
private:
/// The signature of the callback used to convert a type. If the new set of
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion. All user callback forms are
/// normalized to this recursive 1-N form by `wrapCallback`.
using ConversionCallbackFn = std::function<Optional<LogicalResult>(
Type, SmallVectorImpl<Type> &, ArrayRef<Type>)>;
/// The signature of the callback used to materialize a conversion.
using MaterializationCallbackFn =
std::function<Optional<Value>(OpBuilder &, Type, ValueRange, Location)>;
/// Attempt to materialize a conversion using one of the provided
/// materialization functions.
Value materializeConversion(
MutableArrayRef<MaterializationCallbackFn> materializations,
OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
/// With callback of form: `Optional<Type>(T)`
/// The 1-1 result is adapted to the recursive 1-N form: `llvm::None` is
/// forwarded unchanged (so another conversion may be tried), a null Type
/// becomes `failure()`, and a non-null Type is appended to `results`. The
/// call stack argument is ignored.
template <typename T, typename FnT>
std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
wrapCallback(FnT &&callback) {
return wrapCallback<T>(
[callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
if (Optional<Type> resultOpt = callback(type)) {
bool wasSuccess = static_cast<bool>(resultOpt.getValue());
if (wasSuccess)
results.push_back(resultOpt.getValue());
return Optional<LogicalResult>(success(wasSuccess));
}
return Optional<LogicalResult>();
});
}
/// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
/// &)`
/// Adapted to the recursive form by ignoring the call stack argument.
template <typename T, typename FnT>
std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &>::value,
ConversionCallbackFn>
wrapCallback(FnT &&callback) {
return wrapCallback<T>(
[callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
return callback(type, results);
});
}
/// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
/// &, ArrayRef<Type>)`.
/// This is the final wrapper that the other forms delegate to. It returns
/// `llvm::None` for types that are not instances of `T`; otherwise it
/// forwards the derived type, the results and the call stack to `callback`.
template <typename T, typename FnT>
std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &,
ArrayRef<Type>>::value,
ConversionCallbackFn>
wrapCallback(FnT &&callback) {
return [callback = std::forward<FnT>(callback)](
Type type, SmallVectorImpl<Type> &results,
ArrayRef<Type> callStack) -> Optional<LogicalResult> {
T derivedType = type.dyn_cast<T>();
if (!derivedType)
return llvm::None;
return callback(derivedType, results, callStack);
};
}
/// Register a type conversion. Both conversion caches are cleared, because
/// the new callback may change the result for previously cached types.
void registerConversion(ConversionCallbackFn callback) {
conversions.emplace_back(std::move(callback));
cachedDirectConversions.clear();
cachedMultiConversions.clear();
}
/// Generate a wrapper for the given materialization callback. The callback
/// may take any subclass of `Type` and the wrapper will check for the target
/// type to be of the expected class before calling the callback. For any
/// other result type the wrapper returns `llvm::None`.
template <typename T, typename FnT>
MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Type resultType, ValueRange inputs,
Location loc) -> Optional<Value> {
if (T derivedType = resultType.dyn_cast<T>())
return callback(builder, derivedType, inputs, loc);
return llvm::None;
};
}
/// The set of registered conversion functions.
SmallVector<ConversionCallbackFn, 4> conversions;
/// The list of registered materialization functions, one list per
/// `add*Materialization` hook.
SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
/// A set of cached conversions to avoid recomputing in the common case.
/// Direct 1-1 conversions are the most common, so this cache stores the
/// successful 1-1 conversions as well as all failed conversions.
DenseMap<Type, Type> cachedDirectConversions;
/// This cache stores the successful 1->N conversions, where N != 1.
DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// Stores the types that are being converted in the case when convertType
/// is being called recursively to convert nested types.
SmallVector<Type, 2> conversionCallStack;
};
//===----------------------------------------------------------------------===//
// Conversion Patterns
//===----------------------------------------------------------------------===//
/// Base class for the conversion patterns. Unlike a plain RewritePattern, a
/// ConversionPattern is handed the already-converted operands of the matched
/// operation and may carry a TypeConverter. As such, patterns of this type
/// can only be used with the 'apply*' conversion drivers.
class ConversionPattern : public RewritePattern {
public:
  /// Rewrite hook for derived classes. `op` is the (first) operation matched
  /// by the pattern, `operands` holds the rewritten operand values that are
  /// passed to `op`, and `rewriter` is used to emit the new operations. This
  /// hook must not fail: unsupported cases should be rejected during matching.
  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
                       ConversionPatternRewriter &rewriter) const {
    llvm_unreachable("unimplemented rewrite");
  }

  /// Combined match-and-rewrite hook for derived classes. The default
  /// implementation runs `match` and, only when it succeeds, `rewrite`.
  virtual LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const {
    if (succeeded(match(op))) {
      rewrite(op, operands, rewriter);
      return success();
    }
    return failure();
  }

  /// Attempt to match and rewrite the IR rooted at the specified operation.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final;

  /// Return the type converter held by this pattern, or nullptr if the
  /// pattern was constructed without one.
  TypeConverter *getTypeConverter() const { return typeConverter; }
  /// Same as above, but statically cast to the derived converter class
  /// `ConverterTy`. The cast is unchecked; the caller must know the dynamic
  /// type of the held converter.
  template <typename ConverterTy>
  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
                   ConverterTy *>
  getTypeConverter() const {
    return static_cast<ConverterTy *>(typeConverter);
  }

protected:
  /// See `RewritePattern::RewritePattern` for information on the other
  /// available constructors. Patterns built through them hold no type
  /// converter.
  using RewritePattern::RewritePattern;

  /// Construct a conversion pattern that uses `converter`, forwarding the
  /// remaining arguments to RewritePattern.
  template <typename... Args>
  ConversionPattern(TypeConverter &converter, Args &&...args)
      : RewritePattern(std::forward<Args>(args)...),
        typeConverter(&converter) {}

  /// An optional type converter for use by this pattern; null when none was
  /// provided at construction.
  TypeConverter *typeConverter = nullptr;

private:
  using RewritePattern::rewrite;
};
/// OpConversionPattern adapts ConversionPattern so that derived patterns can
/// match and rewrite an instance of the concrete operation class `SourceOp`,
/// together with its generated `OpAdaptor` over the converted operands,
/// instead of a raw Operation.
template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
  using OpAdaptor = typename SourceOp::Adaptor;

  /// Construct a pattern rooted on `SourceOp` that holds no type converter.
  OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
      : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
  /// Construct a pattern rooted on `SourceOp` that uses `converter`.
  OpConversionPattern(TypeConverter &converter, MLIRContext *context,
                      PatternBenefit benefit = 1)
      : ConversionPattern(converter, SourceOp::getOperationName(), benefit,
                          context) {}

  /// Type-erased hooks inherited from ConversionPattern. Each casts `op` to
  /// `SourceOp` and dispatches to the typed overload below.
  LogicalResult match(Operation *op) const final {
    SourceOp sourceOp = cast<SourceOp>(op);
    return match(sourceOp);
  }
  void rewrite(Operation *op, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const final {
    SourceOp sourceOp = cast<SourceOp>(op);
    rewrite(sourceOp, OpAdaptor(operands, op->getAttrDictionary()), rewriter);
  }
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SourceOp sourceOp = cast<SourceOp>(op);
    return matchAndRewrite(sourceOp,
                           OpAdaptor(operands, op->getAttrDictionary()),
                           rewriter);
  }

  /// Typed hooks operating on `SourceOp`. A derived pattern must override
  /// either matchAndRewrite, or match together with rewrite.
  virtual LogicalResult match(SourceOp op) const {
    llvm_unreachable("must override match or matchAndRewrite");
  }
  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter) const {
    llvm_unreachable("must override matchAndRewrite or a rewrite method");
  }
  /// Default: run `match` and, only when it succeeds, `rewrite`.
  virtual LogicalResult
  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const {
    if (succeeded(match(op))) {
      rewrite(op, adaptor, rewriter);
      return success();
    }
    return failure();
  }

private:
  using ConversionPattern::matchAndRewrite;
};
/// OpInterfaceConversionPattern adapts ConversionPattern so that derived
/// patterns can match and rewrite any operation implementing the interface
/// `SourceOp`, receiving it as an instance of that interface class rather
/// than as a raw Operation.
template <typename SourceOp>
class OpInterfaceConversionPattern : public ConversionPattern {
public:
  /// Construct a pattern matching any op that implements `SourceOp`, holding
  /// no type converter.
  OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
      : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
                          SourceOp::getInterfaceID(), benefit, context) {}
  /// Same as above, but the pattern uses `converter` as its type converter.
  OpInterfaceConversionPattern(TypeConverter &converter, MLIRContext *context,
                               PatternBenefit benefit = 1)
      : ConversionPattern(converter, Pattern::MatchInterfaceOpTypeTag(),
                          SourceOp::getInterfaceID(), benefit, context) {}

  /// Type-erased hooks inherited from ConversionPattern. Each casts `op` to
  /// the interface and dispatches to the typed overload below.
  void rewrite(Operation *op, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const final {
    SourceOp iface = cast<SourceOp>(op);
    rewrite(iface, operands, rewriter);
  }
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SourceOp iface = cast<SourceOp>(op);
    return matchAndRewrite(iface, operands, rewriter);
  }

  /// Typed hooks operating on the `SourceOp` interface. A derived pattern must
  /// override matchAndRewrite, or rewrite together with `match`. This class
  /// declares no typed `match`, so the default matchAndRewrite below calls the
  /// inherited `match(Operation *)`.
  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
                       ConversionPatternRewriter &rewriter) const {
    llvm_unreachable("must override matchAndRewrite or a rewrite method");
  }
  virtual LogicalResult
  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const {
    if (succeeded(match(op))) {
      rewrite(op, operands, rewriter);
      return success();
    }
    return failure();
  }

private:
  using ConversionPattern::matchAndRewrite;
};
/// Add to `patterns` a pattern that converts the signature of the FunctionLike
/// op named `functionLikeOpName` using `converter`. Only FunctionLike ops whose
/// type is represented by a FunctionType are supported.
void populateFunctionLikeTypeConversionPattern(StringRef functionLikeOpName,
                                               RewritePatternSet &patterns,
                                               TypeConverter &converter);
/// Convenience overload that takes the FunctionLike op name from `FuncOpT`.
template <typename FuncOpT>
void populateFunctionLikeTypeConversionPattern(RewritePatternSet &patterns,
                                               TypeConverter &converter) {
  StringRef opName = FuncOpT::getOperationName();
  populateFunctionLikeTypeConversionPattern(opName, patterns, converter);
}
/// Add a pattern to the given pattern list to convert the signature of a FuncOp
/// with the given type converter.
/// @param patterns  the pattern list the new pattern is added to.
/// @param converter the type converter used to convert the signature.
void populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
TypeConverter &converter);
//===----------------------------------------------------------------------===//
// Conversion PatternRewriter
//===----------------------------------------------------------------------===//
namespace detail {
// Implementation state of ConversionPatternRewriter; only forward-declared
// here and owned through `ConversionPatternRewriter::impl`.
struct ConversionPatternRewriterImpl;
} // end namespace detail
/// This class implements a pattern rewriter for use with ConversionPatterns. It
/// extends the base PatternRewriter and provides special conversion specific
/// hooks.
class ConversionPatternRewriter final : public PatternRewriter {
public:
/// Construct a rewriter operating in the given context.
ConversionPatternRewriter(MLIRContext *ctx);
/// Defined out of line because ConversionPatternRewriterImpl is incomplete
/// in this header and is owned through a std::unique_ptr.
~ConversionPatternRewriter() override;
/// Apply a signature conversion to the entry block of the given region. This
/// replaces the entry block with a new block containing the updated
/// signature. The new entry block to the region is returned for convenience.
///
/// If provided, `converter` will be used for any materializations.
Block *
applySignatureConversion(Region *region,
TypeConverter::SignatureConversion &conversion,
TypeConverter *converter = nullptr);
/// Convert the types of block arguments within the given region. This
/// replaces each block with a new block containing the updated signature. The
/// entry block may have a special conversion if `entryConversion` is
/// provided. On success, the new entry block to the region is returned for
/// convenience. Otherwise, failure is returned.
FailureOr<Block *> convertRegionTypes(
Region *region, TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);
/// Convert the types of block arguments within the given region except for
/// the entry block. This replaces each non-entry block with a new block
/// containing the updated signature.
///
/// If special conversion behavior is needed for the non-entry blocks (for
/// example, we need to convert only a subset of a block's arguments), such
/// behavior can be specified in blockConversions.
LogicalResult convertNonEntryRegionTypes(
Region *region, TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions);
/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
/// of failure, the remapped value otherwise.
Value getRemappedValue(Value key);
/// Return the converted values that replace 'keys' with types defined by the
/// type converter of the currently executing pattern. Returns failure if the
/// remap failed, success otherwise.
LogicalResult getRemappedValues(ValueRange keys,
SmallVectorImpl<Value> &results);
//===--------------------------------------------------------------------===//
// PatternRewriter Hooks
//===--------------------------------------------------------------------===//
/// PatternRewriter hook for replacing the results of an operation when the
/// given functor returns true.
void replaceOpWithIf(
Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor) override;
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ValueRange newValues) override;
using PatternRewriter::replaceOp;
/// PatternRewriter hook for erasing a dead operation. The uses of this
/// operation *must* be made dead by the end of the conversion process,
/// otherwise an assert will be issued.
void eraseOp(Operation *op) override;
/// PatternRewriter hook for erasing all operations in a block. This is not
/// yet implemented for dialect conversion.
void eraseBlock(Block *block) override;
/// PatternRewriter hook creating a new block.
void notifyBlockCreated(Block *block) override;
/// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override;
/// PatternRewriter hook for merging a block into another.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override;
/// PatternRewriter hook for moving blocks out of a region.
void inlineRegionBefore(Region &region, Region &parent,
Region::iterator before) override;
using PatternRewriter::inlineRegionBefore;
/// PatternRewriter hook for cloning blocks of one region into another. The
/// given region to clone *must* not have been modified as part of conversion
/// yet, i.e. it must be within an operation that is either in the process of
/// conversion, or has not yet been converted.
void cloneRegionBefore(Region &region, Region &parent,
Region::iterator before,
BlockAndValueMapping &mapping) override;
using PatternRewriter::cloneRegionBefore;
/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op) override;
/// PatternRewriter hook for updating the root operation in-place.
/// Note: These methods only track updates to the top-level operation itself,
/// and not nested regions. Updates to regions will still require notification
/// through other more specific hooks above.
void startRootUpdate(Operation *op) override;
/// PatternRewriter hook for updating the root operation in-place.
void finalizeRootUpdate(Operation *op) override;
/// PatternRewriter hook for updating the root operation in-place.
void cancelRootUpdate(Operation *op) override;
/// PatternRewriter hook for notifying match failure reasons.
LogicalResult
notifyMatchFailure(Operation *op,
function_ref<void(Diagnostic &)> reasonCallback) override;
using PatternRewriter::notifyMatchFailure;
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
private:
/// Owned implementation state (pimpl).
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
//===----------------------------------------------------------------------===//
// ConversionTarget
//===----------------------------------------------------------------------===//
/// This class describes a specific conversion target.
class ConversionTarget {
public:
/// This enumeration corresponds to the specific action to take when
/// considering an operation legal for this conversion target.
enum class LegalizationAction {
/// The target supports this operation.
Legal,
/// This operation has dynamic legalization constraints that must be checked
/// by the target.
Dynamic,
/// The target explicitly does not support this operation.
Illegal,
};
/// A structure containing additional information describing a specific legal
/// operation instance.
struct LegalOpDetails {
/// A flag that indicates if this operation is 'recursively' legal. This
/// means that if an operation is legal, either statically or dynamically,
/// all of the operations nested within are also considered legal.
bool isRecursivelyLegal = false;
};
/// The signature of the callback used to determine if an operation is
/// dynamically legal on the target.
using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
virtual ~ConversionTarget() = default;
//===--------------------------------------------------------------------===//
// Legality Registration
//===--------------------------------------------------------------------===//
/// Register a legality action for the given operation.
void setOpAction(OperationName op, LegalizationAction action);
template <typename OpT>
void setOpAction(LegalizationAction action) {
setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
}
/// Register the given operations as legal.
template <typename OpT>
void addLegalOp() {
setOpAction<OpT>(LegalizationAction::Legal);
}
template <typename OpT, typename OpT2, typename... OpTs>
void addLegalOp() {
addLegalOp<OpT>();
addLegalOp<OpT2, OpTs...>();
}
/// Register the given operation as dynamically legal and set the dynamic
/// legalization callback to the one provided.
template <typename OpT>
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
OperationName opName(OpT::getOperationName(), &ctx);
setOpAction(opName, LegalizationAction::Dynamic);
setLegalityCallback(opName, callback);
}
template <typename OpT, typename OpT2, typename... OpTs>
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
addDynamicallyLegalOp<OpT>(callback);
addDynamicallyLegalOp<OpT2, OpTs...>(callback);
}
template <typename OpT, class Callable>
typename std::enable_if<
!llvm::is_invocable<Callable, Operation *>::value>::type
addDynamicallyLegalOp(Callable &&callback) {
addDynamicallyLegalOp<OpT>(
[=](Operation *op) { return callback(cast<OpT>(op)); });
}
/// Register the given operation as illegal, i.e. this operation is known to
/// not be supported by this target.
template <typename OpT>
void addIllegalOp() {
setOpAction<OpT>(LegalizationAction::Illegal);
}
template <typename OpT, typename OpT2, typename... OpTs>
void addIllegalOp() {
addIllegalOp<OpT>();
addIllegalOp<OpT2, OpTs...>();
}
/// Mark an operation, that *must* have either been set as `Legal` or
/// `DynamicallyLegal`, as being recursively legal. This means that in
/// addition to the operation itself, all of the operations nested within are
/// also considered legal. An optional dynamic legality callback may be
/// provided to mark subsets of legal instances as recursively legal.
template <typename OpT>
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
OperationName opName(OpT::getOperationName(), &ctx);
markOpRecursivelyLegal(opName, callback);
}
template <typename OpT, typename OpT2, typename... OpTs>
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
markOpRecursivelyLegal<OpT>(callback);
markOpRecursivelyLegal<OpT2, OpTs...>(callback);
}
template <typename OpT, class Callable>
typename std::enable_if<
!llvm::is_invocable<Callable, Operation *>::value>::type
markOpRecursivelyLegal(Callable &&callback) {
markOpRecursivelyLegal<OpT>(
[=](Operation *op) { return callback(cast<OpT>(op)); });
}
/// Register a legality action for the given dialects.
void setDialectAction(ArrayRef<StringRef> dialectNames,
LegalizationAction action);
/// Register the operations of the given dialects as legal.
template <typename... Names>
void addLegalDialect(StringRef name, Names... names) {
SmallVector<StringRef, 2> dialectNames({name, names...});
setDialectAction(dialectNames, LegalizationAction::Legal);
}
template <typename... Args>
void addLegalDialect() {
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
setDialectAction(dialectNames, LegalizationAction::Legal);
}
/// Register the operations of the given dialects as dynamically legal, i.e.
/// requiring custom handling by the callback.
template <typename... Names>
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback,
StringRef name, Names... names) {
SmallVector<StringRef, 2> dialectNames({name, names...});
setDialectAction(dialectNames, LegalizationAction::Dynamic);
setLegalityCallback(dialectNames, std::move(callback));
}
template <typename... Args>
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
addDynamicallyLegalDialect(std::move(callback),
Args::getDialectNamespace()...);
}
/// Register unknown operations as dynamically legal. For operations(and
/// dialects) that do not have a set legalization action, treat them as
/// dynamically legal and invoke the given callback.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
setLegalityCallback(fn);
}
/// Register the operations of the given dialects as illegal, i.e.
/// operations of this dialect are not supported by the target.
template <typename... Names>
void addIllegalDialect(StringRef name, Names... names) {
SmallVector<StringRef, 2> dialectNames({name, names...});
setDialectAction(dialectNames, LegalizationAction::Illegal);
}
template <typename... Args>
void addIllegalDialect() {
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
setDialectAction(dialectNames, LegalizationAction::Illegal);
}
//===--------------------------------------------------------------------===//
// Legality Querying
//===--------------------------------------------------------------------===//
/// Get the legality action for the given operation.
/// NOTE(review): presumably empty when no action applies to `op`; confirm
/// against the implementation.
Optional<LegalizationAction> getOpAction(OperationName op) const;
/// If the given operation instance is legal on this target, a structure
/// containing legality information is returned. If the operation is not
/// legal, None is returned. Also returns None if the operation's legality
/// wasn't registered by the user or if dynamic legality callbacks returned
/// None.
///
/// Note: Legality is actually a 4-state: Legal(recursive=true),
/// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
/// either as Legal or Illegal depending on context.
Optional<LegalOpDetails> isLegal(Operation *op) const;
/// Returns true if the operation instance is illegal on this target. Returns
/// false if the operation is legal, if its legality wasn't registered by the
/// user, or if dynamic legality callbacks returned None.
bool isIllegal(Operation *op) const;
private:
/// Set the dynamic legality callback for the given operation.
void setLegalityCallback(OperationName name,
const DynamicLegalityCallbackFn &callback);
/// Set the dynamic legality callback for the given dialects (used by
/// `addDynamicallyLegalDialect`).
void setLegalityCallback(ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback);
/// Set the dynamic legality callback for unknown ops, i.e. operations (and
/// dialects) without an explicitly set legalization action. Used by
/// `markUnknownOpDynamicallyLegal`.
void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
/// Set the recursive legality callback for the given operation and mark the
/// operation as recursively legal. This is the non-template implementation
/// behind the public `markOpRecursivelyLegal<OpT>` templates.
void markOpRecursivelyLegal(OperationName name,
const DynamicLegalityCallbackFn &callback);
/// The set of information that configures the legalization of an operation.
struct LegalizationInfo {
/// The legality action this operation was given (defaults to Illegal).
LegalizationAction action = LegalizationAction::Illegal;
/// Whether some legal instances of this operation may also be recursively
/// legal.
bool isRecursivelyLegal = false;
/// The legality callback if this operation is dynamically legal.
DynamicLegalityCallbackFn legalityFn;
};
/// Get the legalization information for the given operation.
Optional<LegalizationInfo> getOpInfo(OperationName op) const;
/// A deterministic (insertion-ordered) mapping from operation name to its
/// legality information.
llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
/// A mapping from operation name to the callback used to check whether an
/// instance of that operation is recursively legal.
DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
/// A mapping from dialect name to the specific legality action to take.
llvm::StringMap<LegalizationAction> legalDialects;
/// A mapping from dialect name to its dynamic legality callback.
llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
/// An optional legality callback for unknown operations (may be empty).
DynamicLegalityCallbackFn unknownLegalityFn;
/// The current context this target applies to; used to build
/// `OperationName`s from op classes.
MLIRContext &ctx;
};
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
/// Below we declare several entry points for operation conversion. Note that
/// the patterns provided to the conversion framework may have additional
/// constraints; see the `PatternRewriter Hooks` section of the
/// ConversionPatternRewriter for the additional constraints imposed on the use
/// of the PatternRewriter.
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there are ops explicitly marked as illegal. If an
/// `unconvertedOps` set is provided, all operations that are found not to be
/// legalizable to the given `target` are placed within that set. (Note that if
/// there is an op explicitly marked as illegal, the conversion terminates and
/// the `unconvertedOps` set will not necessarily be complete.)
LogicalResult
applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> *unconvertedOps = nullptr);
/// Overload of the above for a single root operation `op`.
LogicalResult
applyPartialConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> *unconvertedOps = nullptr);
/// Apply a complete conversion on the given operations and all nested
/// operations. This method returns failure if the conversion of any operation
/// fails, or if there are unreachable blocks in any of the regions nested
/// within 'ops'.
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
ConversionTarget &target,
const FrozenRewritePatternSet &patterns);
/// Overload of the above for a single root operation `op`.
LogicalResult applyFullConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns);
/// Apply an analysis conversion on the given operations and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
/// provided 'convertedOps' set. Note that no actual rewrites are applied to the
/// operations on success, and only pre-existing operations are added to the
/// set. This method only returns failure if there are unreachable blocks in any
/// of the regions nested within 'ops'.
LogicalResult applyAnalysisConversion(ArrayRef<Operation *> ops,
ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> &convertedOps);
/// Overload of the above for a single root operation `op`.
LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> &convertedOps);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_