// blob: fe84c613006462b24be40fd5cc2c2e635c10c8d6 [file] [log] [blame]
//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements mlir::applyPatternsGreedily.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
#include <random>
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
using namespace mlir;
#define DEBUG_TYPE "greedy-rewriter"
namespace {
//===----------------------------------------------------------------------===//
// Debugging Infrastructure
//===----------------------------------------------------------------------===//
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// A helper struct that performs various "expensive checks" to detect broken
/// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
/// broken if:
/// * IR does not verify after pattern application / folding.
/// * Pattern returns "failure" but the IR has changed.
/// * Pattern returns "success" but the IR has not changed.
///
/// This struct stores finger prints of ops to determine whether the IR has
/// changed or not.
struct ExpensiveChecks : public RewriterBase::ForwardingListener {
ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
: RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
/// Compute finger prints of the given op and its nested ops.
void computeFingerPrints(Operation *topLevel) {
this->topLevel = topLevel;
this->topLevelFingerPrint.emplace(topLevel);
topLevel->walk([&](Operation *op) {
fingerprints.try_emplace(op, op, /*includeNested=*/false);
});
}
/// Clear all finger prints.
void clear() {
topLevel = nullptr;
topLevelFingerPrint.reset();
fingerprints.clear();
}
void notifyRewriteSuccess() {
if (!topLevel)
return;
// Make sure that the IR still verifies.
if (failed(verify(topLevel)))
llvm::report_fatal_error("IR failed to verify after pattern application");
// Pattern application success => IR must have changed.
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint == afterFingerPrint) {
// Note: Run "mlir-opt -debug" to see which pattern is broken.
llvm::report_fatal_error(
"pattern returned success but IR did not change");
}
for (const auto &it : fingerprints) {
// Skip top-level op, its finger print is never invalidated.
if (it.first == topLevel)
continue;
// Note: Finger print computation may crash when an op was erased
// without notifying the rewriter. (Run with ASAN to see where the op was
// erased; the op was probably erased directly, bypassing the rewriter
// API.) Finger print computation does may not crash if a new op was
// created at the same memory location. (But then the finger print should
// have changed.)
if (it.second !=
OperationFingerPrint(it.first, /*includeNested=*/false)) {
// Note: Run "mlir-opt -debug" to see which pattern is broken.
llvm::report_fatal_error("operation finger print changed");
}
}
}
void notifyRewriteFailure() {
if (!topLevel)
return;
// Pattern application failure => IR must not have changed.
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint != afterFingerPrint) {
// Note: Run "mlir-opt -debug" to see which pattern is broken.
llvm::report_fatal_error("pattern returned failure but IR did change");
}
}
void notifyFoldingSuccess() {
if (!topLevel)
return;
// Make sure that the IR still verifies.
if (failed(verify(topLevel)))
llvm::report_fatal_error("IR failed to verify after folding");
}
protected:
/// Invalidate the finger print of the given op, i.e., remove it from the map.
void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
void notifyBlockErased(Block *block) override {
RewriterBase::ForwardingListener::notifyBlockErased(block);
// The block structure (number of blocks, types of block arguments, etc.)
// is part of the fingerprint of the parent op.
// TODO: The parent op fingerprint should also be invalidated when modifying
// the block arguments of a block, but we do not have a
// `notifyBlockModified` callback yet.
invalidateFingerPrint(block->getParentOp());
}
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override {
RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
invalidateFingerPrint(op->getParentOp());
}
void notifyOperationModified(Operation *op) override {
RewriterBase::ForwardingListener::notifyOperationModified(op);
invalidateFingerPrint(op);
}
void notifyOperationErased(Operation *op) override {
RewriterBase::ForwardingListener::notifyOperationErased(op);
op->walk([this](Operation *op) { invalidateFingerPrint(op); });
}
/// Operation finger prints to detect invalid pattern API usage. IR is checked
/// against these finger prints after pattern application to detect cases
/// where IR was modified directly, bypassing the rewriter API.
DenseMap<Operation *, OperationFingerPrint> fingerprints;
/// Top-level operation of the current greedy rewrite.
Operation *topLevel = nullptr;
/// Finger print of the top-level operation.
std::optional<OperationFingerPrint> topLevelFingerPrint;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
#ifndef NDEBUG
/// Return the op that should be dumped for `op`: its parent op if it has one
/// (so that materialized constants are visible), otherwise `op` itself.
static Operation *getDumpRootOp(Operation *op) {
  Operation *enclosingOp = op->getParentOp();
  return enclosingOp ? enclosingOp : op;
}
/// Print a banner to the debug stream, dump `op`, and print two trailing
/// newlines.
static void logSuccessfulFolding(Operation *op) {
  llvm::raw_ostream &os = llvm::dbgs();
  os << "// *** IR Dump After Successful Folding ***\n";
  op->dump();
  os << "\n\n";
}
#endif // NDEBUG
//===----------------------------------------------------------------------===//
// Worklist
//===----------------------------------------------------------------------===//
/// A LIFO worklist of operations with efficient removal and set semantics.
///
/// This class maintains a vector of operations and a mapping of operations to
/// their positions in that vector, so that an arbitrary operation can be
/// removed efficiently. A removed operation is replaced with nullptr in the
/// vector; such nullptr slots are skipped when popping elements.
class Worklist {
public:
Worklist();
/// Clear the worklist.
void clear();
/// Return whether the worklist is empty, i.e., contains no non-null entry.
bool empty() const;
/// Push an operation to the end of the worklist, unless the operation is
/// already on the worklist. The operation must not be nullptr.
void push(Operation *op);
/// Pop an operation from the end of the worklist. Only allowed on
/// non-empty worklists.
Operation *pop();
/// Remove an operation from the worklist. Does nothing if the operation is
/// not on the worklist. The operation must not be nullptr.
void remove(Operation *op);
/// Reverse the worklist.
void reverse();
protected:
/// The worklist of operations. Removed entries are nullptr.
std::vector<Operation *> list;
/// A mapping of operations to positions in `list`.
DenseMap<Operation *, unsigned> map;
};
/// Create an empty worklist with storage reserved for 64 entries.
Worklist::Worklist() {
  constexpr size_t kInitialCapacity = 64;
  list.reserve(kInitialCapacity);
}
/// Remove all entries from both the position map and the list.
void Worklist::clear() {
  map.clear();
  list.clear();
}
bool Worklist::empty() const {
// Skip all nullptr.
return !llvm::any_of(list,
[](Operation *op) { return static_cast<bool>(op); });
}
/// Append `op` to the worklist unless it is already on it.
void Worklist::push(Operation *op) {
  assert(op && "cannot push nullptr to worklist");
  // Try to register the op at the position it is about to occupy. If the op
  // is already in the map, it is already on the worklist.
  const bool isNewEntry = map.insert({op, list.size()}).second;
  if (isNewEntry)
    list.push_back(op);
}
/// Remove and return the last non-null entry of the worklist. Must only be
/// called on a non-empty worklist.
Operation *Worklist::pop() {
  assert(!empty() && "cannot pop from empty worklist");
  // Pop entries from the back until a non-null one is found. This discards
  // any trailing nullptr slots along the way.
  Operation *popped = nullptr;
  do {
    popped = list.back();
    list.pop_back();
  } while (!popped);
  map.erase(popped);
  // Cleanup: drop nullptr slots that are now at the end of the list.
  while (!list.empty() && list.back() == nullptr)
    list.pop_back();
  return popped;
}
/// Take `op` off the worklist, if present. Its slot in the list is set to
/// nullptr (the list is not compacted) and its map entry is erased.
void Worklist::remove(Operation *op) {
  assert(op && "cannot remove nullptr from worklist");
  auto position = map.find(op);
  if (position == map.end())
    return;
  assert(list[position->second] == op && "malformed worklist data structure");
  list[position->second] = nullptr;
  map.erase(position);
}
void Worklist::reverse() {
std::reverse(list.begin(), list.end());
for (size_t i = 0, e = list.size(); i != e; ++i)
map[list[i]] = i;
}
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
/// A worklist that pops elements at a random position. This worklist is for
/// testing/debugging purposes only. It can be used to ensure that lowering
/// pipelines work correctly regardless of the order in which ops are processed
/// by the GreedyPatternRewriteDriver.
class RandomizedWorklist : public Worklist {
public:
  /// Seed the generator with MLIR_GREEDY_REWRITE_RANDOMIZER_SEED.
  RandomizedWorklist() : Worklist() {
    generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED);
  }

  /// Pop a random non-null op from the worklist. Randomly selected nullptr
  /// slots (left behind by `remove`) are discarded and another position is
  /// drawn. Must only be called on a non-empty worklist.
  Operation *pop() {
    Operation *op = nullptr;
    do {
      assert(!list.empty() && "cannot pop from empty worklist");
      int64_t pos = generator() % list.size();
      op = list[pos];
      list.erase(list.begin() + pos);
      // All entries behind `pos` moved one slot to the front; update their
      // recorded positions. nullptr slots are not tracked in `map`.
      for (int64_t i = pos, e = list.size(); i < e; ++i)
        if (Operation *shifted = list[i])
          map[shifted] = i;
      if (op)
        map.erase(op);
    } while (!op);
    return op;
  }

private:
  std::minstd_rand0 generator;
};
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
//===----------------------------------------------------------------------===//
// GreedyPatternRewriteDriver
//===----------------------------------------------------------------------===//
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
/// applies the locally optimal patterns.
///
/// This abstract class manages the worklist and contains helper methods for
/// rewriting ops on the worklist. Derived classes specify how ops are added
/// to the worklist in the beginning.
///
/// The driver is itself a rewriter listener: it reacts to IR change
/// notifications by updating the worklist and forwards them to the optional
/// `config.listener`.
class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config);
/// Add the given operation to the worklist. When `config.strictMode` is not
/// GreedyRewriteStrictness::AnyOp, the op is only added if it is contained
/// in `strictModeFilteredOps`.
void addSingleOpToWorklist(Operation *op);
/// Add the given operation and its ancestors to the worklist. Nothing is
/// added if the operation is not nested within `config.scope`.
void addToWorklist(Operation *op);
/// Notify the driver that the specified operation may have been modified
/// in-place. The operation is added to the worklist.
void notifyOperationModified(Operation *op) override;
/// Notify the driver that the specified operation was inserted. Update the
/// worklist as needed: The operation is enqueued depending on scope and
/// strict mode.
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override;
/// Notify the driver that the specified operation was removed. Update the
/// worklist as needed: The defining ops of its operands may be re-enqueued,
/// and the operation itself is removed from the worklist and from
/// `strictModeFilteredOps`.
void notifyOperationErased(Operation *op) override;
/// Notify the driver that the specified operation was replaced. The
/// notification is logged and forwarded to `config.listener`; the worklist
/// is not updated by this callback.
void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
/// Process ops until the worklist is empty or `config.maxNumRewrites` is
/// reached. Return `true` if any IR was changed.
bool processWorklist();
/// The pattern rewriter that is used for making IR modifications and is
/// passed to rewrite patterns.
PatternRewriter rewriter;
/// The worklist for this transformation keeps track of the operations that
/// need to be (re)visited.
#ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
RandomizedWorklist worklist;
#else
Worklist worklist;
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
/// Configuration information for how to simplify.
const GreedyRewriteConfig config;
/// The list of ops we are restricting our rewrites to. These include the
/// supplied set of ops as well as new ops created while rewriting those ops
/// depending on `strictMode`. This set is not maintained when
/// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
private:
/// Look over the operands of the given operation for any defining operations
/// that should be re-added to the worklist. This function should be called
/// when an operation is modified or removed, as it may trigger further
/// simplifications.
void addOperandsToWorklist(Operation *op);
/// Notify the driver that the given block was inserted.
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override;
/// Notify the driver that the given block is about to be removed.
void notifyBlockErased(Block *block) override;
/// For debugging only: Notify the driver of a pattern match failure.
void
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
#ifndef NDEBUG
/// A logger used to emit information during the application process.
llvm::ScopedPrinter logger{llvm::dbgs()};
#endif
/// The low-level pattern applicator.
PatternApplicator matcher;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// Debug listener that sits between the rewriter and this driver and checks
/// for incorrect rewriter API usage.
ExpensiveChecks expensiveChecks;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
};
} // namespace
/// Set up the rewriter, a copy of the configuration and the pattern
/// applicator, then register the listener that receives IR change
/// notifications from the rewriter.
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
: rewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// The expensive checks wrap this driver and use the parent op of the scope
// region (if a scope is set) as the top-level op.
// clang-format off
, expensiveChecks(
/*driver=*/this,
/*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
// clang-format on
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
{
// Apply a simple cost model based solely on pattern benefit.
matcher.applyDefaultCostModel();
// Set up listener.
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Send IR notifications to the debug handler. This handler will then forward
// all notifications to this GreedyPatternRewriteDriver.
rewriter.setListener(&expensiveChecks);
#else
// Without expensive checks, this driver receives IR notifications directly.
rewriter.setListener(this);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
/// Pop ops off the worklist until it is empty or the rewrite budget
/// (`config.maxNumRewrites`) is used up. Each popped op is handled in this
/// order:
///   1. If it is trivially dead, it is erased.
///   2. Otherwise, if folding is enabled and the op is not constant-like, an
///      attempt is made to fold it.
///   3. Otherwise (or if folding did not succeed), the patterns are tried.
/// Returns `true` if any of these steps changed the IR.
///
/// Note: Only successful pattern applications count towards
/// `config.maxNumRewrites`; dead op removal and folding do not.
bool GreedyPatternRewriteDriver::processWorklist() {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
/// A utility function to log a process result for the given reason.
auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) {
logger.unindent();
logger.startLine() << "} -> " << result;
if (!msg.isTriviallyEmpty())
logger.getOStream() << " : " << msg;
logger.getOStream() << "\n";
};
/// Same as `logResult`, followed by a separator line.
auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) {
logResult(result, msg);
logger.startLine() << logLineComment;
};
#endif
bool changed = false;
int64_t numRewrites = 0;
while (!worklist.empty() &&
(numRewrites < config.maxNumRewrites ||
config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
auto *op = worklist.pop();
LLVM_DEBUG({
logger.getOStream() << "\n";
logger.startLine() << logLineComment;
logger.startLine() << "Processing operation : '" << op->getName() << "'("
<< op << ") {\n";
logger.indent();
// If the operation has no regions, just print it here.
if (op->getNumRegions() == 0) {
op->print(
logger.startLine(),
OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
logger.getOStream() << "\n\n";
}
});
// If the operation is trivially dead - remove it.
if (isOpTriviallyDead(op)) {
rewriter.eraseOp(op);
changed = true;
LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
continue;
}
// Try to fold this op. Do not fold constant ops. That would lead to an
// infinite folding loop, as every constant op would be folded to an
// Attribute and then immediately be rematerialized as a constant op, which
// is then put on the worklist.
if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
#ifndef NDEBUG
Operation *dumpRootOp = getDumpRootOp(op);
#endif // NDEBUG
if (foldResults.empty()) {
// Op was modified in-place.
notifyOperationModified(op);
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
expensiveChecks.notifyFoldingSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
}
// Op results can be replaced with `foldResults`.
assert(foldResults.size() == op->getNumResults() &&
"folder produced incorrect number of results");
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
SmallVector<Value> replacements;
bool materializationSucceeded = true;
for (auto [ofr, resultType] :
llvm::zip_equal(foldResults, op->getResultTypes())) {
// Fold results that already are SSA values are used as they are.
if (auto value = ofr.dyn_cast<Value>()) {
assert(value.getType() == resultType &&
"folder produced value of incorrect type");
replacements.push_back(value);
continue;
}
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
rewriter, cast<Attribute>(ofr), resultType, op->getLoc());
if (!constOp) {
// If materialization fails, cleanup any operations generated for
// the previous results.
llvm::SmallDenseSet<Operation *> replacementOps;
for (Value replacement : replacements) {
assert(replacement.use_empty() &&
"folder reused existing op for one result but constant "
"materialization failed for another result");
replacementOps.insert(replacement.getDefiningOp());
}
// Note: This loop variable shadows the `op` that is being processed.
for (Operation *op : replacementOps) {
rewriter.eraseOp(op);
}
materializationSucceeded = false;
break;
}
assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
"materializeConstant produced op that is not a ConstantLike");
assert(constOp->getResultTypes()[0] == resultType &&
"materializeConstant produced incorrect result type");
replacements.push_back(constOp->getResult(0));
}
// If materialization failed, fall through to pattern matching below.
if (materializationSucceeded) {
rewriter.replaceOp(op, replacements);
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
expensiveChecks.notifyFoldingSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
}
}
}
// Try to match one of the patterns. The rewriter is automatically
// notified of any necessary changes, so there is nothing else to do
// here.
// Callback invoked before a pattern is tried: logs the pattern, notifies the
// listener and always allows the pattern to be applied.
auto canApplyCallback = [&](const Pattern &pattern) {
LLVM_DEBUG({
logger.getOStream() << "\n";
logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
<< op->getName() << " -> (";
llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
logger.getOStream() << ")' {\n";
logger.indent();
});
if (config.listener)
config.listener->notifyPatternBegin(pattern, op);
return true;
};
function_ref<bool(const Pattern &)> canApply = canApplyCallback;
// Callback invoked when a pattern failed: logs and notifies the listener.
auto onFailureCallback = [&](const Pattern &pattern) {
LLVM_DEBUG(logResult("failure", "pattern failed to match"));
if (config.listener)
config.listener->notifyPatternEnd(pattern, failure());
};
function_ref<void(const Pattern &)> onFailure = onFailureCallback;
// Callback invoked when a pattern succeeded: logs and notifies the listener.
auto onSuccessCallback = [&](const Pattern &pattern) {
LLVM_DEBUG(logResult("success", "pattern applied successfully"));
if (config.listener)
config.listener->notifyPatternEnd(pattern, success());
return success();
};
function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback;
#ifdef NDEBUG
// Optimization: PatternApplicator callbacks are not needed when running in
// optimized mode and without a listener.
if (!config.listener) {
canApply = nullptr;
onFailure = nullptr;
onSuccess = nullptr;
}
#endif // NDEBUG
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Record finger prints right before pattern application; they are cleared
// again when leaving the current loop iteration.
if (config.scope) {
expensiveChecks.computeFingerPrints(config.scope->getParentOp());
}
auto clearFingerprints =
llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
LogicalResult matchResult =
matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "at least one pattern matched"));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
expensiveChecks.notifyRewriteSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
changed = true;
// Only successful pattern applications count towards the rewrite budget.
++numRewrites;
} else {
LLVM_DEBUG(logResultWithLine("failure", "all patterns failed to match"));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
expensiveChecks.notifyRewriteFailure();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
}
return changed;
}
/// Enqueue `op` together with all of its ancestors up to the scope region.
/// Nothing is enqueued if the scope region is not reached while walking up
/// the parent chain.
void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
  assert(op && "expected valid op");
  // Collect `op` and its ancestors while searching for the "scope" region.
  SmallVector<Operation *, 8> chain;
  Operation *current = op;
  while (true) {
    chain.push_back(current);
    Region *parentRegion = current->getParentRegion();
    if (parentRegion == config.scope) {
      // Scope (can be `nullptr`) was reached. Stop traversal and enqueue ops.
      for (Operation *candidate : chain)
        addSingleOpToWorklist(candidate);
      return;
    }
    // Ran out of enclosing regions without reaching the scope.
    if (!parentRegion)
      return;
    current = parentRegion->getParentOp();
    if (!current)
      return;
  }
}
/// Push `op` onto the worklist. In any strict mode other than AnyOp, ops that
/// are not part of `strictModeFilteredOps` are ignored.
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
  const bool isRestricted =
      config.strictMode != GreedyRewriteStrictness::AnyOp;
  if (isRestricted && !strictModeFilteredOps.contains(op))
    return;
  worklist.push(op);
}
/// Forward the block insertion notification to the configured listener, if
/// there is one.
void GreedyPatternRewriteDriver::notifyBlockInserted(
    Block *block, Region *previous, Region::iterator previousIt) {
  if (!config.listener)
    return;
  config.listener->notifyBlockInserted(block, previous, previousIt);
}
/// Forward the block erasure notification to the configured listener, if
/// there is one.
void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
  if (!config.listener)
    return;
  config.listener->notifyBlockErased(block);
}
/// Handle the insertion of `op`: log it, forward the notification to the
/// configured listener, register the op as rewritable when new ops are allowed
/// in strict mode, and enqueue it.
void GreedyPatternRewriteDriver::notifyOperationInserted(
    Operation *op, OpBuilder::InsertPoint previous) {
  LLVM_DEBUG(logger.startLine()
             << "** Insert : '" << op->getName() << "'(" << op << ")\n");

  if (config.listener)
    config.listener->notifyOperationInserted(op, previous);

  // Newly created ops may only be rewritten in `ExistingAndNewOps` mode (or
  // when there is no restriction at all, which is handled when enqueuing).
  const bool trackNewOps =
      config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps;
  if (trackNewOps)
    strictModeFilteredOps.insert(op);

  addToWorklist(op);
}
/// Handle an in-place modification of `op`: log it, forward the notification
/// to the configured listener and enqueue the op for another visit.
void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
  LLVM_DEBUG(logger.startLine()
             << "** Modified: '" << op->getName() << "'(" << op << ")\n");

  if (config.listener)
    config.listener->notifyOperationModified(op);

  addToWorklist(op);
}
/// Enqueue the defining ops of those operands of `op` that have at most one
/// user besides `op`.
///
/// Rationale: once `op` is deleted, such an operand has at most one user left.
/// With no user left, its defining op can be deleted too; with a single user
/// left, there may be further canonicalization opportunities.
void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) {
  for (Value operand : op->getOperands()) {
    if (!operand)
      continue;
    // Only operands that are defined by an op are of interest.
    Operation *producer = operand.getDefiningOp();
    if (!producer)
      continue;

    // Look for a second distinct user that is not `op`.
    Operation *soleOtherUser = nullptr;
    bool hasFurtherUsers = false;
    for (Operation *user : operand.getUsers()) {
      if (user == op || user == soleOtherUser)
        continue;
      if (soleOtherUser) {
        hasFurtherUsers = true;
        break;
      }
      soleOtherUser = user;
    }

    if (!hasFurtherUsers)
      addToWorklist(producer);
  }
}
/// Handle the erasure of `op`: log it, forward the notification to the
/// configured listener, revisit the producers of its operands, and drop the
/// op from the worklist and from the strict mode op set.
void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) {
  LLVM_DEBUG(logger.startLine()
             << "** Erase : '" << op->getName() << "'(" << op << ")\n");

#ifndef NDEBUG
  // Only ops that are within the configured scope are added to the worklist of
  // the greedy pattern rewriter. Moreover, the parent op of the scope region is
  // the part of the IR that is taken into account for the "expensive checks".
  // A greedy pattern rewrite is not allowed to erase the parent op of the scope
  // region, as that would break the worklist handling and the expensive checks.
  if (config.scope) {
    Operation *scopeParent = config.scope->getParentOp();
    if (scopeParent == op)
      llvm_unreachable(
          "scope region must not be erased during greedy pattern rewrite");
  }
#endif // NDEBUG

  if (config.listener)
    config.listener->notifyOperationErased(op);

  addOperandsToWorklist(op);
  worklist.remove(op);

  // The strict mode op set is not maintained in `AnyOp` mode.
  if (config.strictMode == GreedyRewriteStrictness::AnyOp)
    return;
  strictModeFilteredOps.erase(op);
}
/// Handle the replacement of `op`: log it and forward the notification to the
/// configured listener. The worklist is not touched here.
void GreedyPatternRewriteDriver::notifyOperationReplaced(
    Operation *op, ValueRange replacement) {
  LLVM_DEBUG(logger.startLine()
             << "** Replace : '" << op->getName() << "'(" << op << ")\n");

  if (!config.listener)
    return;
  config.listener->notifyOperationReplaced(op, replacement);
}
/// For debugging only: log the reason of a pattern match failure and forward
/// the notification to the configured listener.
void GreedyPatternRewriteDriver::notifyMatchFailure(
    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
  LLVM_DEBUG({
    // Let the callback render its reason into a temporary remark.
    Diagnostic remark(loc, DiagnosticSeverity::Remark);
    reasonCallback(remark);
    logger.startLine() << "** Match Failure : " << remark.str() << "\n";
  });

  if (!config.listener)
    return;
  config.listener->notifyMatchFailure(loc, reasonCallback);
}
//===----------------------------------------------------------------------===//
// RegionPatternRewriteDriver
//===----------------------------------------------------------------------===//
namespace {
/// This driver simplifies all ops in a region.
class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
  /// Create a driver that applies `patterns` to the ops nested in `region`.
  /// The driver keeps a reference to `region`.
  explicit RegionPatternRewriteDriver(MLIRContext *ctx,
                                      const FrozenRewritePatternSet &patterns,
                                      const GreedyRewriteConfig &config,
                                      Region &region);

  /// Simplify ops inside `region` and simplify the region itself. Return
  /// success if the transformation converged. If `changed` is not null, it is
  /// set to whether the rewrite went past its first iteration.
  LogicalResult simplify(bool *changed) &&;

private:
  /// The region that is simplified.
  Region &region;
};
} // namespace
/// Set up the base driver and, in any strict mode other than AnyOp, register
/// every op nested in `region` as an op that may be rewritten.
RegionPatternRewriteDriver::RegionPatternRewriteDriver(
    MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
    const GreedyRewriteConfig &config, Region &region)
    : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
  // The strict mode op set is not maintained in `AnyOp` mode.
  if (config.strictMode == GreedyRewriteStrictness::AnyOp)
    return;
  // Populate strict mode ops.
  region.walk([&](Operation *nested) { strictModeFilteredOps.insert(nested); });
}
namespace {
/// Action that represents one iteration of the region-based greedy rewrite.
/// It carries the number of the iteration and prints as
/// `GreedyPatternRewriteIteration(<iteration>)`.
class GreedyPatternRewriteIteration
    : public tracing::ActionImpl<GreedyPatternRewriteIteration> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)

  /// Tag that identifies this kind of action.
  static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";

  GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
      : tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
        iteration(iteration) {}

  void print(raw_ostream &os) const override {
    os << "GreedyPatternRewriteIteration(";
    os << iteration;
    os << ")";
  }

private:
  /// Number of the iteration that this action stands for.
  int64_t iteration = 0;
};
} // namespace
/// Repeatedly populate the worklist with the ops of the region, process it and
/// (optionally) simplify the region, until an iteration makes no change or the
/// iteration limit is reached.
///
/// If `changed` is not null, it is set to `true` iff the loop went past its
/// first iteration. Returns success if the last executed iteration did not
/// change anything, i.e., if the rewrite converged.
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
  MLIRContext *ctx = rewriter.getContext();
  int64_t iteration = 0;
  bool continueRewrites = false;
  do {
    // Stop if the iteration limit (if there is one) was reached.
    ++iteration;
    if (iteration > config.maxIterations &&
        config.maxIterations != GreedyRewriteConfig::kNoLimit)
      break;

    // Every iteration starts from an empty worklist.
    worklist.clear();

    // `OperationFolder` CSE's constant ops (and may move them into parents
    // regions to enable more aggressive CSE'ing).
    OperationFolder folder(ctx, this);
    // Hand constants to the folder while the worklist is populated; this
    // avoids accidentally reversing the constant order during processing.
    // Evaluates to `true` iff `op` is a constant for which the folder's
    // `insertKnownConstant` reports `false`; such ops are not enqueued.
    auto insertKnownConstant = [&](Operation *op) {
      Attribute constValue;
      return matchPattern(op, m_Constant(&constValue)) &&
             !folder.insertKnownConstant(op, constValue);
    };

    if (config.useTopDownTraversal) {
      // Add all nested operations to the worklist in preorder.
      region.walk<WalkOrder::PreOrder>([&](Operation *op) {
        if (config.cseConstants && insertKnownConstant(op))
          return WalkResult::skip();
        addToWorklist(op);
        return WalkResult::advance();
      });
      // Ops are popped from the back: reverse the list so that they are
      // processed in the order in which they were visited.
      worklist.reverse();
    } else {
      // Add operations to the worklist in postorder.
      region.walk([&](Operation *op) {
        if (config.cseConstants && insertKnownConstant(op))
          return;
        addToWorklist(op);
      });
    }

    auto runIteration = [&] {
      continueRewrites = processWorklist();

      // After applying patterns, make sure that the CFG of each of the
      // regions is kept up to date.
      if (config.enableRegionSimplification !=
          GreedySimplifyRegionLevel::Disabled) {
        const bool mergeBlocks = config.enableRegionSimplification ==
                                 GreedySimplifyRegionLevel::Aggressive;
        continueRewrites |=
            succeeded(simplifyRegions(rewriter, region, mergeBlocks));
      }
    };
    ctx->executeAction<GreedyPatternRewriteIteration>(runIteration, {&region},
                                                      iteration);
  } while (continueRewrites);

  if (changed)
    *changed = iteration > 1;

  // The rewrite converged if the last iteration did not change anything.
  return success(!continueRewrites);
}
/// Apply `patterns` greedily to the ops nested in `region` until a fixpoint is
/// reached or the limits in `config` are hit. If no scope is configured, the
/// region itself is used as scope. Returns success if the rewrite converged;
/// `changed` (if not null) is filled in by the driver.
LogicalResult
mlir::applyPatternsGreedily(Region &region,
                            const FrozenRewritePatternSet &patterns,
                            GreedyRewriteConfig config, bool *changed) {
  // The top-level operation must be known to be isolated from above to
  // prevent performing canonicalizations on operations defined at or above
  // the region containing 'op'.
  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
         "patterns can only be applied to operations IsolatedFromAbove");

  // Fall back to the region itself if no scope was specified.
  if (config.scope == nullptr)
    config.scope = &region;

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  // Make sure that the input IR is valid before rewriting it.
  Operation *scopeParent = config.scope->getParentOp();
  if (failed(verify(scopeParent)))
    llvm::report_fatal_error(
        "greedy pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

  // Run the pattern driver.
  RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
                                    region);
  LogicalResult converged = std::move(driver).simplify(changed);
  LLVM_DEBUG({
    if (failed(converged))
      llvm::dbgs() << "The pattern rewrite did not converge after scanning "
                   << config.maxIterations << " times\n";
  });
  return converged;
}
//===----------------------------------------------------------------------===//
// MultiOpPatternRewriteDriver
//===----------------------------------------------------------------------===//
namespace {
/// This driver simplifies a list of ops.
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
/// Create a driver for the given list of ops. If `survivingOps` is provided,
/// it is reset to contain exactly `ops`.
explicit MultiOpPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
/// Simplify `ops`. Return `success` if the transformation converged.
LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
private:
/// In addition to the handling in the base class, remove the erased op from
/// the set of surviving ops (if such a set was provided).
void notifyOperationErased(Operation *op) override {
GreedyPatternRewriteDriver::notifyOperationErased(op);
if (survivingOps)
survivingOps->erase(op);
}
/// An optional set of ops that survived the rewrite. This set is populated
/// in the constructor with the initially provided list of ops; ops are
/// removed from it when they are erased.
llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
};
} // namespace
/// Set up the base driver. In any strict mode other than AnyOp, `ops` are
/// registered as the ops that may be rewritten. If a set of surviving ops was
/// provided, it is reset to contain exactly `ops`.
MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
    MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
    const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
    llvm::SmallDenseSet<Operation *, 4> *survivingOps)
    : GreedyPatternRewriteDriver(ctx, patterns, config),
      survivingOps(survivingOps) {
  const bool restrictToKnownOps =
      config.strictMode != GreedyRewriteStrictness::AnyOp;
  if (restrictToKnownOps)
    strictModeFilteredOps.insert(ops.begin(), ops.end());

  // Nothing else to do if the caller is not interested in surviving ops.
  if (!survivingOps)
    return;
  survivingOps->clear();
  survivingOps->insert(ops.begin(), ops.end());
}
/// Enqueue `ops` (subject to the strict mode filter) and process the worklist.
/// If `changed` is not null, it is set to whether any IR was changed. Returns
/// success iff the worklist is empty afterwards.
LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
                                                    bool *changed) && {
  // Seed the worklist with the given ops.
  for (Operation *seed : ops)
    addSingleOpToWorklist(seed);

  // Rewrite until the worklist is drained or the rewrite budget is used up.
  const bool anyChange = processWorklist();
  if (changed)
    *changed = anyChange;

  // A non-empty worklist means that processing stopped early.
  return worklist.empty() ? success() : failure();
}
/// Find the region that is the closest common ancestor of all given ops.
///
/// Starting from the parent region of the first op, enclosing regions are
/// visited one by one until a region is found that contains all other ops.
///
/// Note: This function returns `nullptr` if there is a top-level op among the
/// given list of ops.
static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
  assert(!ops.empty() && "expected at least one op");
  Region *candidate = ops.front()->getParentRegion();
  // Fast path in case there is only one op.
  if (ops.size() == 1)
    return candidate;

  // One bit per remaining op; a set bit means that the op has not been found
  // in any of the regions visited so far.
  ArrayRef<Operation *> others = ops.drop_front();
  llvm::BitVector pending(others.size(), true);
  for (; candidate; candidate = candidate->getParentRegion()) {
    // Check off all pending ops that are contained in `candidate`.
    for (int idx = pending.find_first(); idx != -1;
         idx = pending.find_next(idx))
      if (candidate->findAncestorOpInRegion(*others[idx]))
        pending.reset(idx);
    if (pending.none())
      break;
  }
  return candidate;
}
/// Apply `patterns` greedily to the given list of ops.
///
/// If no scope is configured, the closest common ancestor region of `ops` is
/// used as scope. If not null, `changed` is set to whether any IR was changed
/// and `allErased` is set to whether all of `ops` were erased. An empty list
/// of ops trivially succeeds with `changed = false` and `allErased = true`.
/// Returns success if the rewrite converged.
LogicalResult mlir::applyOpPatternsGreedily(
    ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
    GreedyRewriteConfig config, bool *changed, bool *allErased) {
  if (ops.empty()) {
    if (changed)
      *changed = false;
    if (allErased)
      *allErased = true;
    return success();
  }

  // Determine scope of rewrite.
  if (!config.scope) {
    // Compute scope if none was provided. The scope will remain `nullptr` if
    // there is a top-level op among `ops`.
    config.scope = findCommonAncestor(ops);
  } else {
    // If a scope was provided, make sure that all ops are in scope.
#ifndef NDEBUG
    bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
      return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
    });
    assert(allOpsInScope && "ops must be within the specified scope");
#endif // NDEBUG
  }

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  // Make sure that the input IR is valid before rewriting it.
  if (config.scope && failed(verify(config.scope->getParentOp())))
    llvm::report_fatal_error(
        "greedy pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

  // Start the pattern driver. The set of surviving ops is only tracked if the
  // caller asked for `allErased`.
  llvm::SmallDenseSet<Operation *, 4> surviving;
  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
                                     config, ops,
                                     allErased ? &surviving : nullptr);
  LogicalResult converged = std::move(driver).simplify(ops, changed);
  if (allErased)
    *allErased = surviving.empty();
  // Terminate the message with a newline so that it does not run into
  // subsequent debug output.
  LLVM_DEBUG(if (failed(converged)) {
    llvm::dbgs() << "The pattern rewrite did not converge after "
                 << config.maxNumRewrites << " rewrites\n";
  });
  return converged;
}