blob: a26c3b295de96d2d7c2c56c2419fc2481fe1b28e [file] [log] [blame]
//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares a byte-code and interpreter for pattern rewrites in MLIR.
// The byte-code is constructed from the PDL Interpreter dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REWRITE_BYTECODE_H_
#define MLIR_REWRITE_BYTECODE_H_
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace pdl_interp {
class RecordMatchOp;
} // end namespace pdl_interp
namespace detail {
class PDLByteCode;
/// Use generic bytecode types. ByteCodeField refers to the actual bytecode
/// entries. ByteCodeAddr refers to size of indices into the bytecode.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;
using OwningOpRange = llvm::OwningArrayRef<Operation *>;
//===----------------------------------------------------------------------===//
// PDLByteCodePattern
//===----------------------------------------------------------------------===//
/// All of the data pertaining to a specific pattern within the bytecode.
class PDLByteCodePattern : public Pattern {
public:
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
ByteCodeAddr rewriterAddr);
/// Return the bytecode address of the rewriter for this pattern.
ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
private:
template <typename... Args>
PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs)
: Pattern(std::forward<Args>(patternArgs)...),
rewriterAddr(rewriterAddr) {}
/// The address of the rewriter for this pattern.
ByteCodeAddr rewriterAddr;
};
//===----------------------------------------------------------------------===//
// PDLByteCodeMutableState
//===----------------------------------------------------------------------===//
/// This class contains the mutable state of a bytecode instance. This allows
/// for a bytecode instance to be cached and reused across various different
/// threads/drivers.
class PDLByteCodeMutableState {
public:
/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
/// to the position of the pattern within the range returned by
/// `PDLByteCode::getPatterns`.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
/// Cleanup any allocated state after a match/rewrite has been completed. This
/// method should be called irregardless of whether the match+rewrite was a
/// success or not.
void cleanupAfterMatchAndRewrite();
private:
/// Allow access to data fields.
friend class PDLByteCode;
/// The mutable block of memory used during the matching and rewriting phases
/// of the bytecode.
std::vector<const void *> memory;
/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of operations. These are always stored by
/// owning references, because at no point in the execution of the byte code
/// we get an indexed range (view) of operations.
std::vector<OwningOpRange> opRangeMemory;
/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of types.
std::vector<TypeRange> typeRangeMemory;
/// A set of type ranges that have been allocated by the byte code interpreter
/// to provide a guaranteed lifetime.
std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;
/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of values.
std::vector<ValueRange> valueRangeMemory;
/// A set of value ranges that have been allocated by the byte code
/// interpreter to provide a guaranteed lifetime.
std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;
/// The current index of ranges being iterated over for each level of nesting.
/// These are always maintained at 0 for the loops that are not active, so we
/// do not need to have a separate initialization phase for each loop.
std::vector<unsigned> loopIndex;
/// The up-to-date benefits of the patterns held by the bytecode. The order
/// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
std::vector<PatternBenefit> currentPatternBenefits;
};
//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//
/// The bytecode class is also the interpreter. Contains the bytecode itself,
/// the static info, addresses of the rewriter functions, the interpreter
/// memory buffer, and the execution context.
class PDLByteCode {
public:
/// Each successful match returns a MatchResult, which contains information
/// necessary to execute the rewriter and indicates the originating pattern.
struct MatchResult {
MatchResult(Location loc, const PDLByteCodePattern &pattern,
PatternBenefit benefit)
: location(loc), pattern(&pattern), benefit(benefit) {}
MatchResult(const MatchResult &) = delete;
MatchResult &operator=(const MatchResult &) = delete;
MatchResult(MatchResult &&other) = default;
MatchResult &operator=(MatchResult &&) = default;
/// The location of operations to be replaced.
Location location;
/// Memory values defined in the matcher that are passed to the rewriter.
SmallVector<const void *> values;
/// Memory used for the range input values.
SmallVector<TypeRange, 0> typeRangeValues;
SmallVector<ValueRange, 0> valueRangeValues;
/// The originating pattern that was matched. This is always non-null, but
/// represented with a pointer to allow for assignment.
const PDLByteCodePattern *pattern;
/// The current benefit of the pattern that was matched.
PatternBenefit benefit;
};
/// Create a ByteCode instance from the given module containing operations in
/// the PDL interpreter dialect.
PDLByteCode(ModuleOp module,
llvm::StringMap<PDLConstraintFunction> constraintFns,
llvm::StringMap<PDLRewriteFunction> rewriteFns);
/// Return the patterns held by the bytecode.
ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }
/// Initialize the given state such that it can be used to execute the current
/// bytecode.
void initializeMutableState(PDLByteCodeMutableState &state) const;
/// Run the pattern matcher on the given root operation, collecting the
/// matched patterns in `matches`.
void match(Operation *op, PatternRewriter &rewriter,
SmallVectorImpl<MatchResult> &matches,
PDLByteCodeMutableState &state) const;
/// Run the rewriter of the given pattern that was previously matched in
/// `match`.
void rewrite(PatternRewriter &rewriter, const MatchResult &match,
PDLByteCodeMutableState &state) const;
private:
/// Execute the given byte code starting at the provided instruction `inst`.
/// `matches` is an optional field provided when this function is executed in
/// a matching context.
void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
PDLByteCodeMutableState &state,
SmallVectorImpl<MatchResult> *matches) const;
/// A vector containing pointers to uniqued data. The storage is intentionally
/// opaque such that we can store a wide range of data types. The types of
/// data stored here include:
/// * Attribute, OperationName, Type
std::vector<const void *> uniquedData;
/// A vector containing the generated bytecode for the matcher.
SmallVector<ByteCodeField, 64> matcherByteCode;
/// A vector containing the generated bytecode for all of the rewriters.
SmallVector<ByteCodeField, 64> rewriterByteCode;
/// The set of patterns contained within the bytecode.
SmallVector<PDLByteCodePattern, 32> patterns;
/// A set of user defined functions invoked via PDL.
std::vector<PDLConstraintFunction> constraintFunctions;
std::vector<PDLRewriteFunction> rewriteFunctions;
/// The maximum memory index used by a value.
ByteCodeField maxValueMemoryIndex = 0;
/// The maximum number of different types of ranges.
ByteCodeField maxOpRangeCount = 0;
ByteCodeField maxTypeRangeCount = 0;
ByteCodeField maxValueRangeCount = 0;
/// The maximum number of nested loops.
ByteCodeField maxLoopLevel = 0;
};
} // end namespace detail
} // end namespace mlir
#endif // MLIR_REWRITE_BYTECODE_H_