blob: 7f4166c12ded9e809dd0083c0a1b4ea3ee36c1e9 [file] [log] [blame]
//===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file declares various operation folding utilities. These
// utilities are intended to be used by passes to unify and simplify their
// logic.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TRANSFORMS_FOLDUTILS_H
#define MLIR_TRANSFORMS_FOLDUTILS_H
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Interfaces/FoldInterfaces.h"
namespace mlir {
// Forward declarations of IR types referenced by the declarations below.
class Operation;
class Value;
//===--------------------------------------------------------------------===//
// OperationFolder
//===--------------------------------------------------------------------===//
/// A utility class for folding operations, and unifying duplicated constants
/// generated along the way.
class OperationFolder {
public:
  /// Constructs a folder whose collection of dialect fold interfaces is
  /// initialized from `ctx`.
  OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}

  /// Tries to perform folding on the given `op`, including unifying
  /// duplicated constants. If successful, replaces `op`'s uses with
  /// folded results, and returns success. `preReplaceAction` is invoked on `op`
  /// before it is replaced. `processGeneratedConstants` is invoked for any new
  /// operations generated when folding. If the op was completely folded it is
  /// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
  LogicalResult
  tryToFold(Operation *op,
            function_ref<void(Operation *)> processGeneratedConstants = nullptr,
            function_ref<void(Operation *)> preReplaceAction = nullptr,
            bool *inPlaceUpdate = nullptr);

  /// Notifies that the given constant `op` should be removed from this
  /// OperationFolder's internal bookkeeping.
  ///
  /// Note: this method must be called if a constant op is to be deleted
  /// externally to this OperationFolder. `op` must be a constant op.
  void notifyRemoval(Operation *op);

  /// Create an operation of specific op type with the given builder,
  /// and immediately try to fold it. This function populates `results` with
  /// the results after folding the operation.
  template <typename OpTy, typename... Args>
  void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
              Location location, Args &&... args) {
    // The op needs to be inserted only if the fold (below) fails, or the number
    // of results produced by the successful folding is zero (which is treated
    // as an in-place fold). Using create methods of the builder will insert the
    // op, so not using it here.
    OperationState state(location, OpTy::getOperationName());
    OpTy::build(builder, state, std::forward<Args>(args)...);
    Operation *op = Operation::create(state);

    if (failed(tryToFold(builder, op, results)) || results.empty()) {
      // Not folded away: keep the op, insert it through the builder and hand
      // back its own results.
      builder.insert(op);
      results.assign(op->result_begin(), op->result_end());
      return;
    }

    // Folded to existing/new values: the detached op was never inserted and is
    // no longer needed.
    op->destroy();
  }

  /// Overload to create or fold a single result operation. Returns the single
  /// value produced by the multi-result `create` overload.
  template <typename OpTy, typename... Args>
  typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
                          Value>::type
  create(OpBuilder &builder, Location location, Args &&... args) {
    SmallVector<Value, 1> results;
    create<OpTy>(builder, results, location, std::forward<Args>(args)...);
    return results.front();
  }

  /// Overload to create or fold a zero result operation. The operation is
  /// created (and inserted) through `builder` first, then a fold is attempted;
  /// the outcome of the fold is deliberately ignored and the operation is
  /// always returned.
  template <typename OpTy, typename... Args>
  typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
                          OpTy>::type
  create(OpBuilder &builder, Location location, Args &&... args) {
    auto op = builder.create<OpTy>(location, std::forward<Args>(args)...);
    SmallVector<Value, 0> unused;
    // `builder` must be passed here so the call resolves to the builder-based
    // overload that fills a results vector; the other `tryToFold` overload
    // expects callbacks after the operation, not a vector of values.
    (void)tryToFold(builder, op.getOperation(), unused);

    // Folding cannot remove a zero-result operation, so for convenience we
    // continue to return it.
    return op;
  }

  /// Clear out any constants cached inside of the folder.
  void clear();

  /// Get or create a constant using the given builder. On success this returns
  /// the constant, a null value otherwise.
  /// NOTE(review): the declared return type is `Value`, so this is presumably
  /// the result of the constant operation rather than the operation itself;
  /// confirm against the implementation.
  Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
                            Attribute value, Type type, Location loc);

private:
  /// This map keeps track of uniqued constants by dialect, attribute, and type.
  /// A constant operation materializes an attribute with a type. Dialects may
  /// generate different constants with the same input attribute and type, so we
  /// also need to track per-dialect.
  using ConstantMap =
      DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;

  /// Tries to perform folding on the given `op`. If successful, populates
  /// `results` with the results of the folding.
  LogicalResult tryToFold(
      OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
      function_ref<void(Operation *)> processGeneratedConstants = nullptr);

  /// Try to get or create a new constant entry. On success this returns the
  /// constant operation, nullptr otherwise.
  Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
                                    Dialect *dialect, OpBuilder &builder,
                                    Attribute value, Type type, Location loc);

  /// A mapping between an insertion region and the constants that have been
  /// created within it.
  DenseMap<Region *, ConstantMap> foldScopes;

  /// This map tracks all of the dialects that an operation is referenced by;
  /// given that many dialects may generate the same constant.
  DenseMap<Operation *, SmallVector<Dialect *, 2>> referencedDialects;

  /// A collection of dialect folder interfaces.
  DialectInterfaceCollection<DialectFoldInterface> interfaces;
};
} // end namespace mlir
#endif // MLIR_TRANSFORMS_FOLDUTILS_H