| //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements an applicator that applies pattern rewrites based upon a |
| // user defined cost model. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Rewrite/PatternApplicator.h" |
| #include "ByteCode.h" |
| #include "llvm/Support/Debug.h" |
| |
| #define DEBUG_TYPE "pattern-application" |
| |
| using namespace mlir; |
| using namespace mlir::detail; |
| |
| PatternApplicator::PatternApplicator( |
| const FrozenRewritePatternSet &frozenPatternList) |
| : frozenPatternList(frozenPatternList) { |
| if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { |
| mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>(); |
| bytecode->initializeMutableState(*mutableByteCodeState); |
| } |
| } |
| PatternApplicator::~PatternApplicator() {} |
| |
| #ifndef NDEBUG |
| /// Log a message for a pattern that is impossible to match. |
| static void logImpossibleToMatch(const Pattern &pattern) { |
| llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind() |
| << "' because it is impossible to match or cannot lead " |
| "to legal IR (by cost model)\n"; |
| } |
| |
| /// Log IR after pattern application. |
| static Operation *getDumpRootOp(Operation *op) { |
| return op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>(); |
| } |
| static void logSucessfulPatternApplication(Operation *op) { |
| llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n"; |
| op->dump(); |
| llvm::dbgs() << "\n\n"; |
| } |
| #endif |
| |
| void PatternApplicator::applyCostModel(CostModel model) { |
| // Apply the cost model to the bytecode patterns first, and then the native |
| // patterns. |
| if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { |
| for (auto it : llvm::enumerate(bytecode->getPatterns())) |
| mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value())); |
| } |
| |
| // Copy over the patterns so that we can sort by benefit based on the cost |
| // model. Patterns that are already impossible to match are ignored. |
| patterns.clear(); |
| for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) { |
| for (const RewritePattern *pattern : it.second) { |
| if (pattern->getBenefit().isImpossibleToMatch()) |
| LLVM_DEBUG(logImpossibleToMatch(*pattern)); |
| else |
| patterns[it.first].push_back(pattern); |
| } |
| } |
| anyOpPatterns.clear(); |
| for (const RewritePattern &pattern : |
| frozenPatternList.getMatchAnyOpNativePatterns()) { |
| if (pattern.getBenefit().isImpossibleToMatch()) |
| LLVM_DEBUG(logImpossibleToMatch(pattern)); |
| else |
| anyOpPatterns.push_back(&pattern); |
| } |
| |
| // Sort the patterns using the provided cost model. |
| llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits; |
| auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { |
| return benefits[lhs] > benefits[rhs]; |
| }; |
| auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) { |
| // Special case for one pattern in the list, which is the most common case. |
| if (list.size() == 1) { |
| if (model(*list.front()).isImpossibleToMatch()) { |
| LLVM_DEBUG(logImpossibleToMatch(*list.front())); |
| list.clear(); |
| } |
| return; |
| } |
| |
| // Collect the dynamic benefits for the current pattern list. |
| benefits.clear(); |
| for (const Pattern *pat : list) |
| benefits.try_emplace(pat, model(*pat)); |
| |
| // Sort patterns with highest benefit first, and remove those that are |
| // impossible to match. |
| std::stable_sort(list.begin(), list.end(), cmp); |
| while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) { |
| LLVM_DEBUG(logImpossibleToMatch(*list.back())); |
| list.pop_back(); |
| } |
| }; |
| for (auto &it : patterns) |
| processPatternList(it.second); |
| processPatternList(anyOpPatterns); |
| } |
| |
| void PatternApplicator::walkAllPatterns( |
| function_ref<void(const Pattern &)> walk) { |
| for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) |
| for (const auto &pattern : it.second) |
| walk(*pattern); |
| for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns()) |
| walk(it); |
| if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { |
| for (const Pattern &it : bytecode->getPatterns()) |
| walk(it); |
| } |
| } |
| |
| LogicalResult PatternApplicator::matchAndRewrite( |
| Operation *op, PatternRewriter &rewriter, |
| function_ref<bool(const Pattern &)> canApply, |
| function_ref<void(const Pattern &)> onFailure, |
| function_ref<LogicalResult(const Pattern &)> onSuccess) { |
| // Before checking native patterns, first match against the bytecode. This |
| // won't automatically perform any rewrites so there is no need to worry about |
| // conflicts. |
| SmallVector<PDLByteCode::MatchResult, 4> pdlMatches; |
| const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); |
| if (bytecode) |
| bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState); |
| |
| // Check to see if there are patterns matching this specific operation type. |
| MutableArrayRef<const RewritePattern *> opPatterns; |
| auto patternIt = patterns.find(op->getName()); |
| if (patternIt != patterns.end()) |
| opPatterns = patternIt->second; |
| |
| // Process the patterns for that match the specific operation type, and any |
| // operation type in an interleaved fashion. |
| unsigned opIt = 0, opE = opPatterns.size(); |
| unsigned anyIt = 0, anyE = anyOpPatterns.size(); |
| unsigned pdlIt = 0, pdlE = pdlMatches.size(); |
| LogicalResult result = failure(); |
| do { |
| // Find the next pattern with the highest benefit. |
| const Pattern *bestPattern = nullptr; |
| unsigned *bestPatternIt = &opIt; |
| const PDLByteCode::MatchResult *pdlMatch = nullptr; |
| |
| /// Operation specific patterns. |
| if (opIt < opE) |
| bestPattern = opPatterns[opIt]; |
| /// Operation agnostic patterns. |
| if (anyIt < anyE && |
| (!bestPattern || |
| bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) { |
| bestPatternIt = &anyIt; |
| bestPattern = anyOpPatterns[anyIt]; |
| } |
| /// PDL patterns. |
| if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() < |
| pdlMatches[pdlIt].benefit)) { |
| bestPatternIt = &pdlIt; |
| pdlMatch = &pdlMatches[pdlIt]; |
| bestPattern = pdlMatch->pattern; |
| } |
| if (!bestPattern) |
| break; |
| |
| // Update the pattern iterator on failure so that this pattern isn't |
| // attempted again. |
| ++(*bestPatternIt); |
| |
| // Check that the pattern can be applied. |
| if (canApply && !canApply(*bestPattern)) |
| continue; |
| |
| // Try to match and rewrite this pattern. The patterns are sorted by |
| // benefit, so if we match we can immediately rewrite. For PDL patterns, the |
| // match has already been performed, we just need to rewrite. |
| rewriter.setInsertionPoint(op); |
| #ifndef NDEBUG |
| // Operation `op` may be invalidated after applying the rewrite pattern. |
| Operation *dumpRootOp = getDumpRootOp(op); |
| #endif |
| if (pdlMatch) { |
| bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); |
| result = success(!onSuccess || succeeded(onSuccess(*bestPattern))); |
| } else { |
| const auto *pattern = static_cast<const RewritePattern *>(bestPattern); |
| |
| LLVM_DEBUG(llvm::dbgs() |
| << "Trying to match \"" << pattern->getDebugName() << "\"\n"); |
| result = pattern->matchAndRewrite(op, rewriter); |
| LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result " |
| << succeeded(result) << "\n"); |
| |
| if (succeeded(result) && onSuccess && failed(onSuccess(*pattern))) |
| result = failure(); |
| } |
| if (succeeded(result)) { |
| LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp)); |
| break; |
| } |
| |
| // Perform any necessary cleanups. |
| if (onFailure) |
| onFailure(*bestPattern); |
| } while (true); |
| |
| if (mutableByteCodeState) |
| mutableByteCodeState->cleanupAfterMatchAndRewrite(); |
| return result; |
| } |