| //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements mlir::applyPatternsGreedily. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #include "mlir/Config/mlir-config.h" |
| #include "mlir/IR/Action.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/Verifier.h" |
| #include "mlir/Interfaces/SideEffectInterfaces.h" |
| #include "mlir/Rewrite/PatternApplicator.h" |
| #include "mlir/Transforms/FoldUtils.h" |
| #include "mlir/Transforms/RegionUtils.h" |
| #include "llvm/ADT/BitVector.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/ScopeExit.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/ScopedPrinter.h" |
| #include "llvm/Support/raw_ostream.h" |
| |
| #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED |
| #include <random> |
| #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED |
| |
| using namespace mlir; |
| |
| #define DEBUG_TYPE "greedy-rewriter" |
| |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Debugging Infrastructure |
| //===----------------------------------------------------------------------===// |
| |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| /// A helper struct that performs various "expensive checks" to detect broken |
| /// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is |
| /// broken if: |
| /// * IR does not verify after pattern application / folding. |
| /// * Pattern returns "failure" but the IR has changed. |
| /// * Pattern returns "success" but the IR has not changed. |
| /// |
| /// This struct stores finger prints of ops to determine whether the IR has |
| /// changed or not. |
| struct ExpensiveChecks : public RewriterBase::ForwardingListener { |
| ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel) |
| : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {} |
| |
| /// Compute finger prints of the given op and its nested ops. |
| void computeFingerPrints(Operation *topLevel) { |
| this->topLevel = topLevel; |
| this->topLevelFingerPrint.emplace(topLevel); |
| topLevel->walk([&](Operation *op) { |
| fingerprints.try_emplace(op, op, /*includeNested=*/false); |
| }); |
| } |
| |
| /// Clear all finger prints. |
| void clear() { |
| topLevel = nullptr; |
| topLevelFingerPrint.reset(); |
| fingerprints.clear(); |
| } |
| |
| void notifyRewriteSuccess() { |
| if (!topLevel) |
| return; |
| |
| // Make sure that the IR still verifies. |
| if (failed(verify(topLevel))) |
| llvm::report_fatal_error("IR failed to verify after pattern application"); |
| |
| // Pattern application success => IR must have changed. |
| OperationFingerPrint afterFingerPrint(topLevel); |
| if (*topLevelFingerPrint == afterFingerPrint) { |
| // Note: Run "mlir-opt -debug" to see which pattern is broken. |
| llvm::report_fatal_error( |
| "pattern returned success but IR did not change"); |
| } |
| for (const auto &it : fingerprints) { |
| // Skip top-level op, its finger print is never invalidated. |
| if (it.first == topLevel) |
| continue; |
| // Note: Finger print computation may crash when an op was erased |
| // without notifying the rewriter. (Run with ASAN to see where the op was |
| // erased; the op was probably erased directly, bypassing the rewriter |
| // API.) Finger print computation does may not crash if a new op was |
| // created at the same memory location. (But then the finger print should |
| // have changed.) |
| if (it.second != |
| OperationFingerPrint(it.first, /*includeNested=*/false)) { |
| // Note: Run "mlir-opt -debug" to see which pattern is broken. |
| llvm::report_fatal_error("operation finger print changed"); |
| } |
| } |
| } |
| |
| void notifyRewriteFailure() { |
| if (!topLevel) |
| return; |
| |
| // Pattern application failure => IR must not have changed. |
| OperationFingerPrint afterFingerPrint(topLevel); |
| if (*topLevelFingerPrint != afterFingerPrint) { |
| // Note: Run "mlir-opt -debug" to see which pattern is broken. |
| llvm::report_fatal_error("pattern returned failure but IR did change"); |
| } |
| } |
| |
| void notifyFoldingSuccess() { |
| if (!topLevel) |
| return; |
| |
| // Make sure that the IR still verifies. |
| if (failed(verify(topLevel))) |
| llvm::report_fatal_error("IR failed to verify after folding"); |
| } |
| |
| protected: |
| /// Invalidate the finger print of the given op, i.e., remove it from the map. |
| void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); } |
| |
| void notifyBlockErased(Block *block) override { |
| RewriterBase::ForwardingListener::notifyBlockErased(block); |
| |
| // The block structure (number of blocks, types of block arguments, etc.) |
| // is part of the fingerprint of the parent op. |
| // TODO: The parent op fingerprint should also be invalidated when modifying |
| // the block arguments of a block, but we do not have a |
| // `notifyBlockModified` callback yet. |
| invalidateFingerPrint(block->getParentOp()); |
| } |
| |
| void notifyOperationInserted(Operation *op, |
| OpBuilder::InsertPoint previous) override { |
| RewriterBase::ForwardingListener::notifyOperationInserted(op, previous); |
| invalidateFingerPrint(op->getParentOp()); |
| } |
| |
| void notifyOperationModified(Operation *op) override { |
| RewriterBase::ForwardingListener::notifyOperationModified(op); |
| invalidateFingerPrint(op); |
| } |
| |
| void notifyOperationErased(Operation *op) override { |
| RewriterBase::ForwardingListener::notifyOperationErased(op); |
| op->walk([this](Operation *op) { invalidateFingerPrint(op); }); |
| } |
| |
| /// Operation finger prints to detect invalid pattern API usage. IR is checked |
| /// against these finger prints after pattern application to detect cases |
| /// where IR was modified directly, bypassing the rewriter API. |
| DenseMap<Operation *, OperationFingerPrint> fingerprints; |
| |
| /// Top-level operation of the current greedy rewrite. |
| Operation *topLevel = nullptr; |
| |
| /// Finger print of the top-level operation. |
| std::optional<OperationFingerPrint> topLevelFingerPrint; |
| }; |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| |
| #ifndef NDEBUG |
| static Operation *getDumpRootOp(Operation *op) { |
| // Dump the parent op so that materialized constants are visible. If the op |
| // is a top-level op, dump it directly. |
| if (Operation *parentOp = op->getParentOp()) |
| return parentOp; |
| return op; |
| } |
| static void logSuccessfulFolding(Operation *op) { |
| llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n"; |
| op->dump(); |
| llvm::dbgs() << "\n\n"; |
| } |
| #endif // NDEBUG |
| |
| //===----------------------------------------------------------------------===// |
| // Worklist |
| //===----------------------------------------------------------------------===// |
| |
| /// A LIFO worklist of operations with efficient removal and set semantics. |
| /// |
| /// This class maintains a vector of operations and a mapping of operations to |
| /// positions in the vector, so that operations can be removed efficiently at |
| /// random. When an operation is removed, it is replaced with nullptr. Such |
| /// nullptr are skipped when pop'ing elements. |
| class Worklist { |
| public: |
| Worklist(); |
| |
| /// Clear the worklist. |
| void clear(); |
| |
| /// Return whether the worklist is empty. |
| bool empty() const; |
| |
| /// Push an operation to the end of the worklist, unless the operation is |
| /// already on the worklist. |
| void push(Operation *op); |
| |
| /// Pop the an operation from the end of the worklist. Only allowed on |
| /// non-empty worklists. |
| Operation *pop(); |
| |
| /// Remove an operation from the worklist. |
| void remove(Operation *op); |
| |
| /// Reverse the worklist. |
| void reverse(); |
| |
| protected: |
| /// The worklist of operations. |
| std::vector<Operation *> list; |
| |
| /// A mapping of operations to positions in `list`. |
| DenseMap<Operation *, unsigned> map; |
| }; |
| |
| Worklist::Worklist() { list.reserve(64); } |
| |
| void Worklist::clear() { |
| list.clear(); |
| map.clear(); |
| } |
| |
| bool Worklist::empty() const { |
| // Skip all nullptr. |
| return !llvm::any_of(list, |
| [](Operation *op) { return static_cast<bool>(op); }); |
| } |
| |
| void Worklist::push(Operation *op) { |
| assert(op && "cannot push nullptr to worklist"); |
| // Check to see if the worklist already contains this op. |
| if (!map.insert({op, list.size()}).second) |
| return; |
| list.push_back(op); |
| } |
| |
| Operation *Worklist::pop() { |
| assert(!empty() && "cannot pop from empty worklist"); |
| // Skip and remove all trailing nullptr. |
| while (!list.back()) |
| list.pop_back(); |
| Operation *op = list.back(); |
| list.pop_back(); |
| map.erase(op); |
| // Cleanup: Remove all trailing nullptr. |
| while (!list.empty() && !list.back()) |
| list.pop_back(); |
| return op; |
| } |
| |
| void Worklist::remove(Operation *op) { |
| assert(op && "cannot remove nullptr from worklist"); |
| auto it = map.find(op); |
| if (it != map.end()) { |
| assert(list[it->second] == op && "malformed worklist data structure"); |
| list[it->second] = nullptr; |
| map.erase(it); |
| } |
| } |
| |
| void Worklist::reverse() { |
| std::reverse(list.begin(), list.end()); |
| for (size_t i = 0, e = list.size(); i != e; ++i) |
| map[list[i]] = i; |
| } |
| |
| #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED |
| /// A worklist that pops elements at a random position. This worklist is for |
| /// testing/debugging purposes only. It can be used to ensure that lowering |
| /// pipelines work correctly regardless of the order in which ops are processed |
| /// by the GreedyPatternRewriteDriver. |
| class RandomizedWorklist : public Worklist { |
| public: |
| RandomizedWorklist() : Worklist() { |
| generator.seed(MLIR_GREEDY_REWRITE_RANDOMIZER_SEED); |
| } |
| |
| /// Pop a random non-empty op from the worklist. |
| Operation *pop() { |
| Operation *op = nullptr; |
| do { |
| assert(!list.empty() && "cannot pop from empty worklist"); |
| int64_t pos = generator() % list.size(); |
| op = list[pos]; |
| list.erase(list.begin() + pos); |
| for (int64_t i = pos, e = list.size(); i < e; ++i) |
| map[list[i]] = i; |
| map.erase(op); |
| } while (!op); |
| return op; |
| } |
| |
| private: |
| std::minstd_rand0 generator; |
| }; |
| #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED |
| |
| //===----------------------------------------------------------------------===// |
| // GreedyPatternRewriteDriver |
| //===----------------------------------------------------------------------===// |
| |
| /// This is a worklist-driven driver for the PatternMatcher, which repeatedly |
| /// applies the locally optimal patterns. |
| /// |
| /// This abstract class manages the worklist and contains helper methods for |
| /// rewriting ops on the worklist. Derived classes specify how ops are added |
| /// to the worklist in the beginning. |
| class GreedyPatternRewriteDriver : public RewriterBase::Listener { |
| protected: |
| explicit GreedyPatternRewriteDriver(MLIRContext *ctx, |
| const FrozenRewritePatternSet &patterns, |
| const GreedyRewriteConfig &config); |
| |
| /// Add the given operation to the worklist. |
| void addSingleOpToWorklist(Operation *op); |
| |
| /// Add the given operation and its ancestors to the worklist. |
| void addToWorklist(Operation *op); |
| |
| /// Notify the driver that the specified operation may have been modified |
| /// in-place. The operation is added to the worklist. |
| void notifyOperationModified(Operation *op) override; |
| |
| /// Notify the driver that the specified operation was inserted. Update the |
| /// worklist as needed: The operation is enqueued depending on scope and |
| /// strict mode. |
| void notifyOperationInserted(Operation *op, |
| OpBuilder::InsertPoint previous) override; |
| |
| /// Notify the driver that the specified operation was removed. Update the |
| /// worklist as needed: The operation and its children are removed from the |
| /// worklist. |
| void notifyOperationErased(Operation *op) override; |
| |
| /// Notify the driver that the specified operation was replaced. Update the |
| /// worklist as needed: New users are added enqueued. |
| void notifyOperationReplaced(Operation *op, ValueRange replacement) override; |
| |
| /// Process ops until the worklist is empty or `config.maxNumRewrites` is |
| /// reached. Return `true` if any IR was changed. |
| bool processWorklist(); |
| |
| /// The pattern rewriter that is used for making IR modifications and is |
| /// passed to rewrite patterns. |
| PatternRewriter rewriter; |
| |
| /// The worklist for this transformation keeps track of the operations that |
| /// need to be (re)visited. |
| #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED |
| RandomizedWorklist worklist; |
| #else |
| Worklist worklist; |
| #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED |
| |
| /// Configuration information for how to simplify. |
| const GreedyRewriteConfig config; |
| |
| /// The list of ops we are restricting our rewrites to. These include the |
| /// supplied set of ops as well as new ops created while rewriting those ops |
| /// depending on `strictMode`. This set is not maintained when |
| /// `config.strictMode` is GreedyRewriteStrictness::AnyOp. |
| llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps; |
| |
| private: |
| /// Look over the provided operands for any defining operations that should |
| /// be re-added to the worklist. This function should be called when an |
| /// operation is modified or removed, as it may trigger further |
| /// simplifications. |
| void addOperandsToWorklist(Operation *op); |
| |
| /// Notify the driver that the given block was inserted. |
| void notifyBlockInserted(Block *block, Region *previous, |
| Region::iterator previousIt) override; |
| |
| /// Notify the driver that the given block is about to be removed. |
| void notifyBlockErased(Block *block) override; |
| |
| /// For debugging only: Notify the driver of a pattern match failure. |
| void |
| notifyMatchFailure(Location loc, |
| function_ref<void(Diagnostic &)> reasonCallback) override; |
| |
| #ifndef NDEBUG |
| /// A logger used to emit information during the application process. |
| llvm::ScopedPrinter logger{llvm::dbgs()}; |
| #endif |
| |
| /// The low-level pattern applicator. |
| PatternApplicator matcher; |
| |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| ExpensiveChecks expensiveChecks; |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| }; |
| } // namespace |
| |
| GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( |
| MLIRContext *ctx, const FrozenRewritePatternSet &patterns, |
| const GreedyRewriteConfig &config) |
| : rewriter(ctx), config(config), matcher(patterns) |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| // clang-format off |
| , expensiveChecks( |
| /*driver=*/this, |
| /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr) |
| // clang-format on |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| { |
| // Apply a simple cost model based solely on pattern benefit. |
| matcher.applyDefaultCostModel(); |
| |
| // Set up listener. |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| // Send IR notifications to the debug handler. This handler will then forward |
| // all notifications to this GreedyPatternRewriteDriver. |
| rewriter.setListener(&expensiveChecks); |
| #else |
| rewriter.setListener(this); |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| } |
| |
| bool GreedyPatternRewriteDriver::processWorklist() { |
| #ifndef NDEBUG |
| const char *logLineComment = |
| "//===-------------------------------------------===//\n"; |
| |
| /// A utility function to log a process result for the given reason. |
| auto logResult = [&](StringRef result, const llvm::Twine &msg = {}) { |
| logger.unindent(); |
| logger.startLine() << "} -> " << result; |
| if (!msg.isTriviallyEmpty()) |
| logger.getOStream() << " : " << msg; |
| logger.getOStream() << "\n"; |
| }; |
| auto logResultWithLine = [&](StringRef result, const llvm::Twine &msg = {}) { |
| logResult(result, msg); |
| logger.startLine() << logLineComment; |
| }; |
| #endif |
| |
| bool changed = false; |
| int64_t numRewrites = 0; |
| while (!worklist.empty() && |
| (numRewrites < config.maxNumRewrites || |
| config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) { |
| auto *op = worklist.pop(); |
| |
| LLVM_DEBUG({ |
| logger.getOStream() << "\n"; |
| logger.startLine() << logLineComment; |
| logger.startLine() << "Processing operation : '" << op->getName() << "'(" |
| << op << ") {\n"; |
| logger.indent(); |
| |
| // If the operation has no regions, just print it here. |
| if (op->getNumRegions() == 0) { |
| op->print( |
| logger.startLine(), |
| OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs()); |
| logger.getOStream() << "\n\n"; |
| } |
| }); |
| |
| // If the operation is trivially dead - remove it. |
| if (isOpTriviallyDead(op)) { |
| rewriter.eraseOp(op); |
| changed = true; |
| |
| LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead")); |
| continue; |
| } |
| |
| // Try to fold this op. Do not fold constant ops. That would lead to an |
| // infinite folding loop, as every constant op would be folded to an |
| // Attribute and then immediately be rematerialized as a constant op, which |
| // is then put on the worklist. |
| if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) { |
| SmallVector<OpFoldResult> foldResults; |
| if (succeeded(op->fold(foldResults))) { |
| LLVM_DEBUG(logResultWithLine("success", "operation was folded")); |
| #ifndef NDEBUG |
| Operation *dumpRootOp = getDumpRootOp(op); |
| #endif // NDEBUG |
| if (foldResults.empty()) { |
| // Op was modified in-place. |
| notifyOperationModified(op); |
| changed = true; |
| LLVM_DEBUG(logSuccessfulFolding(dumpRootOp)); |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| expensiveChecks.notifyFoldingSuccess(); |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| continue; |
| } |
| |
| // Op results can be replaced with `foldResults`. |
| assert(foldResults.size() == op->getNumResults() && |
| "folder produced incorrect number of results"); |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPoint(op); |
| SmallVector<Value> replacements; |
| bool materializationSucceeded = true; |
| for (auto [ofr, resultType] : |
| llvm::zip_equal(foldResults, op->getResultTypes())) { |
| if (auto value = ofr.dyn_cast<Value>()) { |
| assert(value.getType() == resultType && |
| "folder produced value of incorrect type"); |
| replacements.push_back(value); |
| continue; |
| } |
| // Materialize Attributes as SSA values. |
| Operation *constOp = op->getDialect()->materializeConstant( |
| rewriter, cast<Attribute>(ofr), resultType, op->getLoc()); |
| |
| if (!constOp) { |
| // If materialization fails, cleanup any operations generated for |
| // the previous results. |
| llvm::SmallDenseSet<Operation *> replacementOps; |
| for (Value replacement : replacements) { |
| assert(replacement.use_empty() && |
| "folder reused existing op for one result but constant " |
| "materialization failed for another result"); |
| replacementOps.insert(replacement.getDefiningOp()); |
| } |
| for (Operation *op : replacementOps) { |
| rewriter.eraseOp(op); |
| } |
| |
| materializationSucceeded = false; |
| break; |
| } |
| |
| assert(constOp->hasTrait<OpTrait::ConstantLike>() && |
| "materializeConstant produced op that is not a ConstantLike"); |
| assert(constOp->getResultTypes()[0] == resultType && |
| "materializeConstant produced incorrect result type"); |
| replacements.push_back(constOp->getResult(0)); |
| } |
| |
| if (materializationSucceeded) { |
| rewriter.replaceOp(op, replacements); |
| changed = true; |
| LLVM_DEBUG(logSuccessfulFolding(dumpRootOp)); |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| expensiveChecks.notifyFoldingSuccess(); |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| continue; |
| } |
| } |
| } |
| |
| // Try to match one of the patterns. The rewriter is automatically |
| // notified of any necessary changes, so there is nothing else to do |
| // here. |
| auto canApplyCallback = [&](const Pattern &pattern) { |
| LLVM_DEBUG({ |
| logger.getOStream() << "\n"; |
| logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '" |
| << op->getName() << " -> ("; |
| llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream()); |
| logger.getOStream() << ")' {\n"; |
| logger.indent(); |
| }); |
| if (config.listener) |
| config.listener->notifyPatternBegin(pattern, op); |
| return true; |
| }; |
| function_ref<bool(const Pattern &)> canApply = canApplyCallback; |
| auto onFailureCallback = [&](const Pattern &pattern) { |
| LLVM_DEBUG(logResult("failure", "pattern failed to match")); |
| if (config.listener) |
| config.listener->notifyPatternEnd(pattern, failure()); |
| }; |
| function_ref<void(const Pattern &)> onFailure = onFailureCallback; |
| auto onSuccessCallback = [&](const Pattern &pattern) { |
| LLVM_DEBUG(logResult("success", "pattern applied successfully")); |
| if (config.listener) |
| config.listener->notifyPatternEnd(pattern, success()); |
| return success(); |
| }; |
| function_ref<LogicalResult(const Pattern &)> onSuccess = onSuccessCallback; |
| |
| #ifdef NDEBUG |
| // Optimization: PatternApplicator callbacks are not needed when running in |
| // optimized mode and without a listener. |
| if (!config.listener) { |
| canApply = nullptr; |
| onFailure = nullptr; |
| onSuccess = nullptr; |
| } |
| #endif // NDEBUG |
| |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| if (config.scope) { |
| expensiveChecks.computeFingerPrints(config.scope->getParentOp()); |
| } |
| auto clearFingerprints = |
| llvm::make_scope_exit([&]() { expensiveChecks.clear(); }); |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| |
| LogicalResult matchResult = |
| matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess); |
| |
| if (succeeded(matchResult)) { |
| LLVM_DEBUG(logResultWithLine("success", "at least one pattern matched")); |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| expensiveChecks.notifyRewriteSuccess(); |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| changed = true; |
| ++numRewrites; |
| } else { |
| LLVM_DEBUG(logResultWithLine("failure", "all patterns failed to match")); |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| expensiveChecks.notifyRewriteFailure(); |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| } |
| } |
| |
| return changed; |
| } |
| |
| void GreedyPatternRewriteDriver::addToWorklist(Operation *op) { |
| assert(op && "expected valid op"); |
| // Gather potential ancestors while looking for a "scope" parent region. |
| SmallVector<Operation *, 8> ancestors; |
| Region *region = nullptr; |
| do { |
| ancestors.push_back(op); |
| region = op->getParentRegion(); |
| if (config.scope == region) { |
| // Scope (can be `nullptr`) was reached. Stop traveral and enqueue ops. |
| for (Operation *op : ancestors) |
| addSingleOpToWorklist(op); |
| return; |
| } |
| if (region == nullptr) |
| return; |
| } while ((op = region->getParentOp())); |
| } |
| |
| void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) { |
| if (config.strictMode == GreedyRewriteStrictness::AnyOp || |
| strictModeFilteredOps.contains(op)) |
| worklist.push(op); |
| } |
| |
| void GreedyPatternRewriteDriver::notifyBlockInserted( |
| Block *block, Region *previous, Region::iterator previousIt) { |
| if (config.listener) |
| config.listener->notifyBlockInserted(block, previous, previousIt); |
| } |
| |
| void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) { |
| if (config.listener) |
| config.listener->notifyBlockErased(block); |
| } |
| |
| void GreedyPatternRewriteDriver::notifyOperationInserted( |
| Operation *op, OpBuilder::InsertPoint previous) { |
| LLVM_DEBUG({ |
| logger.startLine() << "** Insert : '" << op->getName() << "'(" << op |
| << ")\n"; |
| }); |
| if (config.listener) |
| config.listener->notifyOperationInserted(op, previous); |
| if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps) |
| strictModeFilteredOps.insert(op); |
| addToWorklist(op); |
| } |
| |
| void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) { |
| LLVM_DEBUG({ |
| logger.startLine() << "** Modified: '" << op->getName() << "'(" << op |
| << ")\n"; |
| }); |
| if (config.listener) |
| config.listener->notifyOperationModified(op); |
| addToWorklist(op); |
| } |
| |
| void GreedyPatternRewriteDriver::addOperandsToWorklist(Operation *op) { |
| for (Value operand : op->getOperands()) { |
| // If this operand currently has at most 2 users, add its defining op to the |
| // worklist. Indeed, after the op is deleted, then the operand will have at |
| // most 1 user left. If it has 0 users left, it can be deleted too, |
| // and if it has 1 user left, there may be further canonicalization |
| // opportunities. |
| if (!operand) |
| continue; |
| |
| auto *defOp = operand.getDefiningOp(); |
| if (!defOp) |
| continue; |
| |
| Operation *otherUser = nullptr; |
| bool hasMoreThanTwoUses = false; |
| for (auto user : operand.getUsers()) { |
| if (user == op || user == otherUser) |
| continue; |
| if (!otherUser) { |
| otherUser = user; |
| continue; |
| } |
| hasMoreThanTwoUses = true; |
| break; |
| } |
| if (hasMoreThanTwoUses) |
| continue; |
| |
| addToWorklist(defOp); |
| } |
| } |
| |
| void GreedyPatternRewriteDriver::notifyOperationErased(Operation *op) { |
| LLVM_DEBUG({ |
| logger.startLine() << "** Erase : '" << op->getName() << "'(" << op |
| << ")\n"; |
| }); |
| |
| #ifndef NDEBUG |
| // Only ops that are within the configured scope are added to the worklist of |
| // the greedy pattern rewriter. Moreover, the parent op of the scope region is |
| // the part of the IR that is taken into account for the "expensive checks". |
| // A greedy pattern rewrite is not allowed to erase the parent op of the scope |
| // region, as that would break the worklist handling and the expensive checks. |
| if (config.scope && config.scope->getParentOp() == op) |
| llvm_unreachable( |
| "scope region must not be erased during greedy pattern rewrite"); |
| #endif // NDEBUG |
| |
| if (config.listener) |
| config.listener->notifyOperationErased(op); |
| |
| addOperandsToWorklist(op); |
| worklist.remove(op); |
| |
| if (config.strictMode != GreedyRewriteStrictness::AnyOp) |
| strictModeFilteredOps.erase(op); |
| } |
| |
| void GreedyPatternRewriteDriver::notifyOperationReplaced( |
| Operation *op, ValueRange replacement) { |
| LLVM_DEBUG({ |
| logger.startLine() << "** Replace : '" << op->getName() << "'(" << op |
| << ")\n"; |
| }); |
| if (config.listener) |
| config.listener->notifyOperationReplaced(op, replacement); |
| } |
| |
| void GreedyPatternRewriteDriver::notifyMatchFailure( |
| Location loc, function_ref<void(Diagnostic &)> reasonCallback) { |
| LLVM_DEBUG({ |
| Diagnostic diag(loc, DiagnosticSeverity::Remark); |
| reasonCallback(diag); |
| logger.startLine() << "** Match Failure : " << diag.str() << "\n"; |
| }); |
| if (config.listener) |
| config.listener->notifyMatchFailure(loc, reasonCallback); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RegionPatternRewriteDriver |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This driver simplfies all ops in a region. |
| class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver { |
| public: |
| explicit RegionPatternRewriteDriver(MLIRContext *ctx, |
| const FrozenRewritePatternSet &patterns, |
| const GreedyRewriteConfig &config, |
| Region ®ions); |
| |
| /// Simplify ops inside `region` and simplify the region itself. Return |
| /// success if the transformation converged. |
| LogicalResult simplify(bool *changed) &&; |
| |
| private: |
| /// The region that is simplified. |
| Region ®ion; |
| }; |
| } // namespace |
| |
| RegionPatternRewriteDriver::RegionPatternRewriteDriver( |
| MLIRContext *ctx, const FrozenRewritePatternSet &patterns, |
| const GreedyRewriteConfig &config, Region ®ion) |
| : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) { |
| // Populate strict mode ops. |
| if (config.strictMode != GreedyRewriteStrictness::AnyOp) { |
| region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); |
| } |
| } |
| |
| namespace { |
| class GreedyPatternRewriteIteration |
| : public tracing::ActionImpl<GreedyPatternRewriteIteration> { |
| public: |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration) |
| GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration) |
| : tracing::ActionImpl<GreedyPatternRewriteIteration>(units), |
| iteration(iteration) {} |
| static constexpr StringLiteral tag = "GreedyPatternRewriteIteration"; |
| void print(raw_ostream &os) const override { |
| os << "GreedyPatternRewriteIteration(" << iteration << ")"; |
| } |
| |
| private: |
| int64_t iteration = 0; |
| }; |
| } // namespace |
| |
| LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { |
| bool continueRewrites = false; |
| int64_t iteration = 0; |
| MLIRContext *ctx = rewriter.getContext(); |
| do { |
| // Check if the iteration limit was reached. |
| if (++iteration > config.maxIterations && |
| config.maxIterations != GreedyRewriteConfig::kNoLimit) |
| break; |
| |
| // New iteration: start with an empty worklist. |
| worklist.clear(); |
| |
| // `OperationFolder` CSE's constant ops (and may move them into parents |
| // regions to enable more aggressive CSE'ing). |
| OperationFolder folder(ctx, this); |
| auto insertKnownConstant = [&](Operation *op) { |
| // Check for existing constants when populating the worklist. This avoids |
| // accidentally reversing the constant order during processing. |
| Attribute constValue; |
| if (matchPattern(op, m_Constant(&constValue))) |
| if (!folder.insertKnownConstant(op, constValue)) |
| return true; |
| return false; |
| }; |
| |
| if (!config.useTopDownTraversal) { |
| // Add operations to the worklist in postorder. |
| region.walk([&](Operation *op) { |
| if (!config.cseConstants || !insertKnownConstant(op)) |
| addToWorklist(op); |
| }); |
| } else { |
| // Add all nested operations to the worklist in preorder. |
| region.walk<WalkOrder::PreOrder>([&](Operation *op) { |
| if (!config.cseConstants || !insertKnownConstant(op)) { |
| addToWorklist(op); |
| return WalkResult::advance(); |
| } |
| return WalkResult::skip(); |
| }); |
| |
| // Reverse the list so our pop-back loop processes them in-order. |
| worklist.reverse(); |
| } |
| |
| ctx->executeAction<GreedyPatternRewriteIteration>( |
| [&] { |
| continueRewrites = processWorklist(); |
| |
| // After applying patterns, make sure that the CFG of each of the |
| // regions is kept up to date. |
| if (config.enableRegionSimplification != |
| GreedySimplifyRegionLevel::Disabled) { |
| continueRewrites |= succeeded(simplifyRegions( |
| rewriter, region, |
| /*mergeBlocks=*/config.enableRegionSimplification == |
| GreedySimplifyRegionLevel::Aggressive)); |
| } |
| }, |
| {®ion}, iteration); |
| } while (continueRewrites); |
| |
| if (changed) |
| *changed = iteration > 1; |
| |
| // Whether the rewrite converges, i.e. wasn't changed in the last iteration. |
| return success(!continueRewrites); |
| } |
| |
| LogicalResult |
| mlir::applyPatternsGreedily(Region ®ion, |
| const FrozenRewritePatternSet &patterns, |
| GreedyRewriteConfig config, bool *changed) { |
| // The top-level operation must be known to be isolated from above to |
| // prevent performing canonicalizations on operations defined at or above |
| // the region containing 'op'. |
| assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() && |
| "patterns can only be applied to operations IsolatedFromAbove"); |
| |
| // Set scope if not specified. |
| if (!config.scope) |
| config.scope = ®ion; |
| |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| if (failed(verify(config.scope->getParentOp()))) |
| llvm::report_fatal_error( |
| "greedy pattern rewriter input IR failed to verify"); |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| |
| // Start the pattern driver. |
| RegionPatternRewriteDriver driver(region.getContext(), patterns, config, |
| region); |
| LogicalResult converged = std::move(driver).simplify(changed); |
| LLVM_DEBUG(if (failed(converged)) { |
| llvm::dbgs() << "The pattern rewrite did not converge after scanning " |
| << config.maxIterations << " times\n"; |
| }); |
| return converged; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MultiOpPatternRewriteDriver |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This driver simplfies a list of ops. |
| class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { |
| public: |
| explicit MultiOpPatternRewriteDriver( |
| MLIRContext *ctx, const FrozenRewritePatternSet &patterns, |
| const GreedyRewriteConfig &config, ArrayRef<Operation *> ops, |
| llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr); |
| |
| /// Simplify `ops`. Return `success` if the transformation converged. |
| LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&; |
| |
| private: |
| void notifyOperationErased(Operation *op) override { |
| GreedyPatternRewriteDriver::notifyOperationErased(op); |
| if (survivingOps) |
| survivingOps->erase(op); |
| } |
| |
| /// An optional set of ops that survived the rewrite. This set is populated |
| /// at the beginning of `simplifyLocally` with the inititally provided list |
| /// of ops. |
| llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr; |
| }; |
| } // namespace |
| |
| MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver( |
| MLIRContext *ctx, const FrozenRewritePatternSet &patterns, |
| const GreedyRewriteConfig &config, ArrayRef<Operation *> ops, |
| llvm::SmallDenseSet<Operation *, 4> *survivingOps) |
| : GreedyPatternRewriteDriver(ctx, patterns, config), |
| survivingOps(survivingOps) { |
| if (config.strictMode != GreedyRewriteStrictness::AnyOp) |
| strictModeFilteredOps.insert(ops.begin(), ops.end()); |
| |
| if (survivingOps) { |
| survivingOps->clear(); |
| survivingOps->insert(ops.begin(), ops.end()); |
| } |
| } |
| |
| LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops, |
| bool *changed) && { |
| // Populate the initial worklist. |
| for (Operation *op : ops) |
| addSingleOpToWorklist(op); |
| |
| // Process ops on the worklist. |
| bool result = processWorklist(); |
| if (changed) |
| *changed = result; |
| |
| return success(worklist.empty()); |
| } |
| |
| /// Find the region that is the closest common ancestor of all given ops. |
| /// |
| /// Note: This function returns `nullptr` if there is a top-level op among the |
| /// given list of ops. |
| static Region *findCommonAncestor(ArrayRef<Operation *> ops) { |
| assert(!ops.empty() && "expected at least one op"); |
| // Fast path in case there is only one op. |
| if (ops.size() == 1) |
| return ops.front()->getParentRegion(); |
| |
| Region *region = ops.front()->getParentRegion(); |
| ops = ops.drop_front(); |
| int sz = ops.size(); |
| llvm::BitVector remainingOps(sz, true); |
| while (region) { |
| int pos = -1; |
| // Iterate over all remaining ops. |
| while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) { |
| // Is this op contained in `region`? |
| if (region->findAncestorOpInRegion(*ops[pos])) |
| remainingOps.reset(pos); |
| } |
| if (remainingOps.none()) |
| break; |
| region = region->getParentRegion(); |
| } |
| return region; |
| } |
| |
| LogicalResult mlir::applyOpPatternsGreedily( |
| ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns, |
| GreedyRewriteConfig config, bool *changed, bool *allErased) { |
| if (ops.empty()) { |
| if (changed) |
| *changed = false; |
| if (allErased) |
| *allErased = true; |
| return success(); |
| } |
| |
| // Determine scope of rewrite. |
| if (!config.scope) { |
| // Compute scope if none was provided. The scope will remain `nullptr` if |
| // there is a top-level op among `ops`. |
| config.scope = findCommonAncestor(ops); |
| } else { |
| // If a scope was provided, make sure that all ops are in scope. |
| #ifndef NDEBUG |
| bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) { |
| return static_cast<bool>(config.scope->findAncestorOpInRegion(*op)); |
| }); |
| assert(allOpsInScope && "ops must be within the specified scope"); |
| #endif // NDEBUG |
| } |
| |
| #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| if (config.scope && failed(verify(config.scope->getParentOp()))) |
| llvm::report_fatal_error( |
| "greedy pattern rewriter input IR failed to verify"); |
| #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS |
| |
| // Start the pattern driver. |
| llvm::SmallDenseSet<Operation *, 4> surviving; |
| MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, |
| config, ops, |
| allErased ? &surviving : nullptr); |
| LogicalResult converged = std::move(driver).simplify(ops, changed); |
| if (allErased) |
| *allErased = surviving.empty(); |
| LLVM_DEBUG(if (failed(converged)) { |
| llvm::dbgs() << "The pattern rewrite did not converge after " |
| << config.maxNumRewrites << " rewrites"; |
| }); |
| return converged; |
| } |