// blob: 1587a4ea13138e821e8004221877061d9ac812e7 [file] [log] [blame]
//===- NestedMatcher.h - Nested matcher for Function ------------*- C++ -*-===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Allocator.h"
namespace mlir {
// Forward declarations of classes referenced by the declarations below.
class Operation;
class NestedPattern;
/// A NestedPattern captures nested patterns in the IR.
/// It is used in conjunction with a scoped NestedPatternContext which is an
/// llvm::BumpPtrAllocator that handles memory allocations efficiently and
/// avoids ownership issues.
/// In order to use NestedPatterns, first create a scoped context.
/// When the context goes out of scope, everything is freed.
/// This design simplifies the API by avoiding references to the context and
/// makes it clear that references to matchers must not escape.
/// Example:
/// {
/// NestedPatternContext context;
/// auto gemmLike = Doall(Doall(Red(LoadStores())));
/// auto matches = gemmLike.match(f);
/// // do work on matches
/// } // everything is freed
/// Nested abstraction for matching results.
/// Provides access to the nested Operation* captured by a Matcher.
/// A NestedMatch contains an Operation* and the children NestedMatch and is
/// thus cheap to copy. NestedMatch is stored in a scoped bump allocator whose
/// lifetime is managed by an RAII NestedPatternContext.
class NestedMatch {
public:
  /// Creates a NestedMatch for `operation` whose children are
  /// `nestedMatches`. Defined out of line; presumably allocates in the arena
  /// installed by the active NestedPatternContext (TODO confirm).
  static NestedMatch build(Operation *operation,
                           ArrayRef<NestedMatch> nestedMatches);
  NestedMatch(const NestedMatch &) = default;
  NestedMatch &operator=(const NestedMatch &) = default;

  /// Returns true iff this match captured an operation.
  explicit operator bool() const { return matchedOperation != nullptr; }

  /// Returns the operation captured at this level of the match.
  Operation *getMatchedOperation() const { return matchedOperation; }

  /// Returns the (non-owning) list of matches captured for nested patterns.
  ArrayRef<NestedMatch> getMatchedChildren() const { return matchedChildren; }

private:
  friend class NestedPattern;
  friend class NestedPatternContext;

  /// Underlying global bump allocator managed by a NestedPatternContext.
  static llvm::BumpPtrAllocator *&allocator();

  /// Only friends may create an empty match; it holds no operation and
  /// therefore converts to false.
  NestedMatch() = default;

  /// Payload, holds a NestedMatch and all its children along this branch.
  /// Initialized so that a default-constructed match is well defined (falsy)
  /// instead of carrying an indeterminate pointer into `operator bool`.
  Operation *matchedOperation = nullptr;
  ArrayRef<NestedMatch> matchedChildren;
};
/// A NestedPattern is a nested operation walker that:
/// 1. recursively matches a substructure in the tree;
/// 2. uses a filter function to refine matches with extra semantic
/// constraints (passed via a lambda of type FilterFunctionType);
/// 3. TODO: optionally applies actions (lambda).
/// Nested patterns are meant to capture imperfectly nested loops while matching
/// properties over the whole loop nest. For instance, in vectorization we are
/// interested in capturing all the imperfectly nested loops of a certain type
/// and such that all the loads and stores have certain access patterns along the
/// loops' induction variables. Such NestedMatches are first captured using the
/// `match` function and are later processed to analyze properties and apply
/// transformations in a non-greedy way.
/// The NestedMatches captured in the IR can grow large, especially after
/// aggressive unrolling. As experience has shown, it is generally better to use
/// a plain walk over operations to match flat patterns but the current
/// implementation is competitive nonetheless.
/// Predicate a NestedPattern applies to each candidate operation in order to
/// prune matches with extra semantic constraints.
using FilterFunctionType = std::function<bool(Operation &)>;

/// Filter that accepts every operation; used as the default filter.
inline bool defaultFilterFunction(Operation & /*op*/) {
  return true;
}
class NestedPattern {
NestedPattern(ArrayRef<NestedPattern> nested,
FilterFunctionType filter = defaultFilterFunction);
NestedPattern(const NestedPattern &other);
NestedPattern &operator=(const NestedPattern &other);
~NestedPattern() {
// Call destructors manually, ArrayRef is non-owning so it wouldn't call
// them, but we should free the memory allocated by std::function outside of
// the arena allocator.
/// Returns all the top-level matches in `func`.
void match(FuncOp func, SmallVectorImpl<NestedMatch> *matches) {
func.walk([&](Operation *op) { matchOne(op, matches); });
/// Returns all the top-level matches in `op`.
void match(Operation *op, SmallVectorImpl<NestedMatch> *matches) {
op->walk([&](Operation *child) { matchOne(child, matches); });
/// Returns the depth of the pattern.
unsigned getDepth() const;
friend class NestedPatternContext;
friend class NestedMatch;
friend struct State;
/// Copies the list of nested patterns to the arena allocator associated with
/// this pattern.
void copyNestedToThis(ArrayRef<NestedPattern> nested);
/// Calls destructors on nested patterns.
void freeNested();
/// Underlying global bump allocator managed by a NestedPatternContext.
static llvm::BumpPtrAllocator *&allocator();
/// Matches this pattern against a single `op` and fills matches with the
/// result.
void matchOne(Operation *op, SmallVectorImpl<NestedMatch> *matches);
/// Nested patterns to be matched.
ArrayRef<NestedPattern> nestedPatterns;
/// Extra filter function to apply to prune patterns as the IR is walked.
FilterFunctionType filter;
/// skip is an implementation detail needed so that we can implement match
/// without switching on the type of the Operation. The idea is that a
/// NestedPattern first checks if it matches locally and then recursively
/// applies its nested matchers to its elem->nested. Since we want to rely on
/// the existing operation walking functionality rather than duplicate
/// it, we allow an off-by-one traversal to account for the fact that we
/// write:
/// void match(Operation *elem) {
/// for (auto &c : getNestedPatterns()) {
/// NestedPattern childPattern(...);
/// ^~~~ Needs off-by-one skip.
Operation *skip;
/// RAII structure to transparently manage the bump allocator for
/// NestedPattern and NestedMatch classes. This avoids passing a context to
/// all the API functions.
class NestedPatternContext {
NestedPatternContext() {
assert(NestedMatch::allocator() == nullptr &&
"Only a single NestedPatternContext is supported");
assert(NestedPattern::allocator() == nullptr &&
"Only a single NestedPatternContext is supported");
NestedMatch::allocator() = &allocator;
NestedPattern::allocator() = &allocator;
~NestedPatternContext() {
NestedMatch::allocator() = nullptr;
NestedPattern::allocator() = nullptr;
llvm::BumpPtrAllocator allocator;
/// Syntactic-sugar builders for NestedPattern, plus filter predicates whose
/// signature matches FilterFunctionType.
namespace matcher {

/// Builds a pattern from a single filter (defaults to accepting everything).
NestedPattern Op(FilterFunctionType filter = defaultFilterFunction);

/// `If` builders: an optional filter plus either one child pattern or a list
/// of nested patterns (empty by default).
NestedPattern If(NestedPattern child);
NestedPattern If(FilterFunctionType filter, NestedPattern child);
NestedPattern If(ArrayRef<NestedPattern> nested = {});
NestedPattern If(FilterFunctionType filter,
                 ArrayRef<NestedPattern> nested = {});

/// `For` builders: an optional filter plus either one child pattern or a list
/// of nested patterns (empty by default).
NestedPattern For(NestedPattern child);
NestedPattern For(FilterFunctionType filter, NestedPattern child);
NestedPattern For(ArrayRef<NestedPattern> nested = {});
NestedPattern For(FilterFunctionType filter,
                  ArrayRef<NestedPattern> nested = {});

/// Filter predicates (defined out of line).
bool isParallelLoop(Operation &op);
bool isReductionLoop(Operation &op);
bool isLoadOrStore(Operation &op);

} // namespace matcher
} // end namespace mlir