blob: 9314496ecda15565aecf7345f5aa6a3eafb5b21e [file] [log] [blame]
//===- PatternApplicator.h - PatternApplicator ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares an applicator that applies pattern rewrites based upon a
// user-defined cost model.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
#define MLIR_REWRITE_PATTERNAPPLICATOR_H
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
namespace mlir {
// Forward declarations: this header only refers to these types by pointer,
// reference, or as a unique_ptr element type.
namespace detail {
class PDLByteCodeMutableState;
} // namespace detail
class PatternRewriter;
/// This class manages the application of a group of rewrite patterns, with a
/// user-provided cost model.
class PatternApplicator {
public:
/// The cost model dynamically assigns a PatternBenefit to a particular
/// pattern. Users can query contained patterns and pass analysis results to
/// applyCostModel. Patterns to be discarded should have a benefit of
/// `impossibleToMatch`.
using CostModel = function_ref<PatternBenefit(const Pattern &)>;
explicit PatternApplicator(const FrozenRewritePatternSet &frozenPatternList);
~PatternApplicator();
/// Attempt to match and rewrite the given op with any pattern, allowing a
/// predicate to decide if a pattern can be applied or not, and hooks for if
/// the pattern match was a success or failure.
///
/// canApply: called before each match and rewrite attempt; return false to
/// skip pattern.
/// onFailure: called when a pattern fails to match to perform cleanup.
/// onSuccess: called when a pattern match succeeds; return failure() to
/// invalidate the match and try another pattern.
LogicalResult
matchAndRewrite(Operation *op, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply = {},
function_ref<void(const Pattern &)> onFailure = {},
function_ref<LogicalResult(const Pattern &)> onSuccess = {});
/// Apply a cost model to the patterns within this applicator.
void applyCostModel(CostModel model);
/// Apply the default cost model that solely uses the pattern's static
/// benefit.
void applyDefaultCostModel() {
applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
}
/// Walk all of the patterns within the applicator.
void walkAllPatterns(function_ref<void(const Pattern &)> walk);
private:
/// The list that owns the patterns used within this applicator.
const FrozenRewritePatternSet &frozenPatternList;
/// The set of patterns to match for each operation, stable sorted by benefit.
DenseMap<OperationName, SmallVector<const RewritePattern *, 2>> patterns;
/// The set of patterns that may match against any operation type, stable
/// sorted by benefit.
SmallVector<const RewritePattern *, 1> anyOpPatterns;
/// The mutable state used during execution of the PDL bytecode.
std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState;
};
} // end namespace mlir
#endif // MLIR_REWRITE_PATTERNAPPLICATOR_H