blob: 00c62163c23bc2d6d19d7e62dde1d881b6bd3d93 [file] [log] [blame]
//===- FrozenRewritePatternSet.h --------------------------------*- C++ -*-===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace detail {
class PDLByteCode;
} // end namespace detail
/// This class represents a frozen set of patterns that can be processed by a
/// pattern applicator. This class is designed to enable caching pattern lists
/// such that they need not be continuously recomputed. Note that all copies of
/// this class share the same compiled pattern list, allowing for a reduction in
/// the number of duplicated patterns that need to be created.
class FrozenRewritePatternSet {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
/// A map of operation specific native patterns.
using OpSpecificNativePatternListT =
DenseMap<OperationName, std::vector<RewritePattern *>>;
FrozenRewritePatternSet(FrozenRewritePatternSet &&patterns) = default;
FrozenRewritePatternSet(const FrozenRewritePatternSet &patterns) = default;
FrozenRewritePatternSet &
operator=(const FrozenRewritePatternSet &patterns) = default;
FrozenRewritePatternSet &
operator=(FrozenRewritePatternSet &&patterns) = default;
/// Freeze the patterns held in `patterns`, and take ownership.
/// `disabledPatternLabels` is a set of labels used to filter out input
/// patterns with a label in this set. `enabledPatternLabels` is a set of
/// labels used to filter out input patterns that do not have one of the
/// labels in this set.
RewritePatternSet &&patterns,
ArrayRef<std::string> disabledPatternLabels = llvm::None,
ArrayRef<std::string> enabledPatternLabels = llvm::None);
/// Return the op specific native patterns held by this list.
const OpSpecificNativePatternListT &getOpSpecificNativePatterns() const {
return impl->nativeOpSpecificPatternMap;
/// Return the "match any" native patterns held by this list.
getMatchAnyOpNativePatterns() const {
const NativePatternListT &nativeList = impl->nativeAnyOpPatterns;
return llvm::make_pointee_range(nativeList);
/// Return the compiled PDL bytecode held by this list. Returns null if
/// there are no PDL patterns within the list.
const detail::PDLByteCode *getPDLByteCode() const {
return impl->pdlByteCode.get();
/// The internal implementation of the frozen pattern list.
struct Impl {
/// The set of native C++ rewrite patterns that are matched to specific
/// operation kinds.
OpSpecificNativePatternListT nativeOpSpecificPatternMap;
/// The full op-specific native rewrite list. This allows for the map above
/// to contain duplicate patterns, e.g. for interfaces and traits.
NativePatternListT nativeOpSpecificPatternList;
/// The set of native C++ rewrite patterns that are matched to "any"
/// operation.
NativePatternListT nativeAnyOpPatterns;
/// The bytecode containing the compiled PDL patterns.
std::unique_ptr<detail::PDLByteCode> pdlByteCode;
/// A pointer to the internal pattern list. This uses a shared_ptr to avoid
/// the need to compile the same pattern list multiple times. For example,
/// during multi-threaded pass execution, all copies of a pass can share the
/// same pattern list.
std::shared_ptr<Impl> impl;
} // end namespace mlir