blob: 11a9db79b322e9aabec33e9dd3a4f336dc8347f1 [file] [log] [blame]
//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements MLIR to byte-code generation and the interpreter.
//
//===----------------------------------------------------------------------===//
#include "ByteCode.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <memory>
#define DEBUG_TYPE "pdl-bytecode"
using namespace mlir;
using namespace mlir::detail;
//===----------------------------------------------------------------------===//
// PDLByteCodePattern
//===----------------------------------------------------------------------===//
/// Build the pattern description for a `pdl_interp.record_match` operation
/// whose rewriter starts at `rewriterAddr` in the rewriter bytecode.
PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
                                              ByteCodeAddr rewriterAddr) {
  // Collect the names of the operations this pattern may generate, when the
  // match operation lists any.
  SmallVector<StringRef, 8> generatedOps;
  ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr();
  if (generatedOpsAttr) {
    auto opNames = generatedOpsAttr.getAsValueRange<StringAttr>();
    generatedOps = llvm::to_vector<8>(opNames);
  }

  PatternBenefit benefit = matchOp.benefit();
  MLIRContext *ctx = matchOp.getContext();

  // Without a root kind, the pattern may match any operation type.
  Optional<StringRef> rootKind = matchOp.rootKind();
  if (!rootKind)
    return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
                              MatchAnyOpTypeTag());
  return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
                            ctx);
}
//===----------------------------------------------------------------------===//
// PDLByteCodeMutableState
//===----------------------------------------------------------------------===//
/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
/// to the position of the pattern within the range returned by
/// `PDLByteCode::getPatterns`. It must be smaller than the number of benefits
/// held by this state, which `PDLByteCode::initializeMutableState` seeds with
/// one entry per pattern.
void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
                                                   PatternBenefit benefit) {
  // An out-of-range index would silently write past the end of the storage.
  assert(patternIndex < currentPatternBenefits.size() &&
         "pattern index out of range");
  currentPatternBenefits[patternIndex] = benefit;
}
//===----------------------------------------------------------------------===//
// Bytecode OpCodes
//===----------------------------------------------------------------------===//
namespace {
/// The opcodes of the PDL bytecode. Each opcode is written into the bytecode
/// stream as a single ByteCodeField and is followed by its operands. The
/// generator (`Generator::generate` overloads) and the executor
/// (`ByteCodeExecutor::execute`) must agree on the operand layout of each
/// opcode. `pdl_interp.check_attribute` and `pdl_interp.check_type` have no
/// dedicated opcode; they are lowered to `AreEqual`.
enum OpCode : ByteCodeField {
/// Apply an externally registered constraint.
ApplyConstraint,
/// Apply an externally registered rewrite.
ApplyRewrite,
/// Check if two generic values are equal.
AreEqual,
/// Unconditional branch.
Branch,
/// Compare the operand count of an operation with a constant.
CheckOperandCount,
/// Compare the name of an operation with a constant.
CheckOperationName,
/// Compare the result count of an operation with a constant.
CheckResultCount,
/// Invoke a native creation method.
CreateNative,
/// Create an operation.
CreateOperation,
/// Erase an operation.
EraseOp,
/// Terminate a matcher or rewrite sequence.
Finalize,
/// Get a specific attribute of an operation.
GetAttribute,
/// Get the type of an attribute.
GetAttributeType,
/// Get the defining operation of a value.
GetDefiningOp,
/// Get a specific operand of an operation. Indices 0-3 have dedicated opcodes.
/// The generator computes them as `GetOperand0 + index`, so these four must
/// stay contiguous. Larger indices use `GetOperandN` with the index stored
/// inline.
GetOperand0,
GetOperand1,
GetOperand2,
GetOperand3,
GetOperandN,
/// Get a specific result of an operation. Same encoding scheme as the
/// GetOperand* opcodes: `GetResult0 + index` for indices 0-3, otherwise
/// `GetResultN`.
GetResult0,
GetResult1,
GetResult2,
GetResult3,
GetResultN,
/// Get the type of a value.
GetValueType,
/// Check if a generic value is not null.
IsNotNull,
/// Record a successful pattern match.
RecordMatch,
/// Replace an operation.
ReplaceOp,
/// Compare an attribute with a set of constants.
SwitchAttribute,
/// Compare the operand count of an operation with a set of constants.
SwitchOperandCount,
/// Compare the name of an operation with a set of constants.
SwitchOperationName,
/// Compare the result count of an operation with a set of constants.
SwitchResultCount,
/// Compare a type with a set of constants.
SwitchType,
};
/// The kind tag written before each entry of a generic PDL value list. The
/// executor reads it back to decide how to decode the memory slot that follows.
enum class PDLValueKind { Attribute, Operation, Type, Value };
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// ByteCode Generation
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Generator
namespace {
struct ByteCodeWriter;
/// This class represents the main generator for the pattern bytecode. It
/// lowers a module containing the `pdl_interp` matcher function and rewriter
/// module into matcher and rewriter bytecode streams. It fills in the owning
/// `PDLByteCode`'s pattern list, uniqued constant data and value memory size
/// through the references it is constructed with.
class Generator {
public:
/// Construct a generator that writes into the provided output buffers. The
/// external function maps are only read here: each registered name is assigned
/// the position it has in its map's iteration order. Consumers must store the
/// function objects in that same order for the indices to line up.
Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
SmallVectorImpl<ByteCodeField> &matcherByteCode,
SmallVectorImpl<ByteCodeField> &rewriterByteCode,
SmallVectorImpl<PDLByteCodePattern> &patterns,
ByteCodeField &maxValueMemoryIndex,
llvm::StringMap<PDLConstraintFunction> &constraintFns,
llvm::StringMap<PDLCreateFunction> &createFns,
llvm::StringMap<PDLRewriteFunction> &rewriteFns)
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
rewriterByteCode(rewriterByteCode), patterns(patterns),
maxValueMemoryIndex(maxValueMemoryIndex) {
for (auto it : llvm::enumerate(constraintFns))
constraintToMemIndex.try_emplace(it.value().first(), it.index());
for (auto it : llvm::enumerate(createFns))
nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
for (auto it : llvm::enumerate(rewriteFns))
externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
}
/// Generate the bytecode for the given PDL interpreter module.
void generate(ModuleOp module);
/// Return the memory index to use for the given value. The value must already
/// have an index assigned by `allocateMemoryIndices`. A reference is returned
/// so that callers can alias one slot onto another.
ByteCodeField &getMemIndex(Value value) {
assert(valueToMemIndex.count(value) &&
"expected memory index to be assigned");
return valueToMemIndex[value];
}
/// Return an index to use when referring to the given data that is uniqued in
/// the MLIR context. Indices for uniqued data start at `maxValueMemoryIndex`,
/// so they are only meaningful once all value indices have been allocated.
/// Each distinct opaque pointer is appended to `uniquedData` the first time it
/// is requested.
template <typename T>
std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
getMemIndex(T val) {
const void *opaqueVal = val.getAsOpaquePointer();
// Get or insert a reference to this value.
auto it = uniquedDataToMemIndex.try_emplace(
opaqueVal, maxValueMemoryIndex + uniquedData.size());
if (it.second)
uniquedData.push_back(opaqueVal);
return it.first->second;
}
private:
/// Allocate memory indices for the results of operations within the matcher
/// and rewriters.
void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
/// Generate the bytecode for the given operation. The `Operation *` overload
/// dispatches to the per-op overloads below.
void generate(Operation *op, ByteCodeWriter &writer);
void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
/// Mapping from value to its corresponding memory index.
DenseMap<Value, ByteCodeField> valueToMemIndex;
/// Mapping from the name of an externally registered rewrite to its index in
/// the bytecode registry.
llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
/// Mapping from the name of an externally registered constraint to its index
/// in the bytecode registry.
llvm::StringMap<ByteCodeField> constraintToMemIndex;
/// Mapping from the name of an externally registered creation method to its
/// index in the bytecode registry.
llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
/// Mapping from rewriter function name to the address at which that function
/// starts within the rewriter bytecode.
llvm::StringMap<ByteCodeAddr> rewriterToAddr;
/// Mapping from a uniqued storage object to its memory index within
/// `uniquedData`.
DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
/// The current MLIR context.
MLIRContext *ctx;
/// Data of the ByteCode class to be populated.
std::vector<const void *> &uniquedData;
SmallVectorImpl<ByteCodeField> &matcherByteCode;
SmallVectorImpl<ByteCodeField> &rewriterByteCode;
SmallVectorImpl<PDLByteCodePattern> &patterns;
ByteCodeField &maxValueMemoryIndex;
};
/// This class provides utilities for writing a bytecode stream. Values that
/// live in memory (SSA values and uniqued constants) are written as memory
/// indices obtained from the owning `Generator`. Successor addresses are
/// written as placeholders that the generator patches later.
struct ByteCodeWriter {
ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
: bytecode(bytecode), generator(generator) {}
/// Append a field to the bytecode.
void append(ByteCodeField field) { bytecode.push_back(field); }
/// Append an opcode; it occupies a single field.
void append(OpCode opCode) { bytecode.push_back(opCode); }
/// Append an address to the bytecode. An address spans two fields and is
/// copied byte-wise.
void append(ByteCodeAddr field) {
static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
"unexpected ByteCode address size");
ByteCodeField fieldParts[2];
std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
bytecode.append({fieldParts[0], fieldParts[1]});
}
/// Append a successor range to the bytecode, the exact address will need to
/// be resolved later.
void append(SuccessorRange successors) {
// Record a back reference for each successor and emit a zero placeholder
// address, so that the real address can be patched in once known.
for (Block *successor : successors) {
unresolvedSuccessorRefs[successor].push_back(bytecode.size());
append(ByteCodeAddr(0));
}
}
/// Append a range of values that will be read as generic PDLValues. The
/// layout is the number of values, then for each value a PDLValueKind tag
/// followed by the value's memory index.
void appendPDLValueList(OperandRange values) {
bytecode.push_back(values.size());
for (Value value : values) {
// Append the type of the value in addition to the value itself.
PDLValueKind kind =
TypeSwitch<Type, PDLValueKind>(value.getType())
.Case<pdl::AttributeType>(
[](Type) { return PDLValueKind::Attribute; })
.Case<pdl::OperationType>(
[](Type) { return PDLValueKind::Operation; })
.Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
.Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
bytecode.push_back(static_cast<ByteCodeField>(kind));
append(value);
}
}
/// Detection helper: well-formed only when `T` provides a
/// `getAsOpaquePointer()` method, i.e. when it can be referenced through a
/// memory index. (`Args` is unused.)
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
/// Append a value that will be stored in a memory slot and not inline within
/// the bytecode.
template <typename T>
std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
std::is_pointer<T>::value>
append(T value) {
bytecode.push_back(generator.getMemIndex(value));
}
/// Append a range of values: the element count followed by each element.
template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
append(T range) {
bytecode.push_back(llvm::size(range));
for (auto it : range)
append(it);
}
/// Append a variadic number of fields to the bytecode, in argument order.
template <typename FieldTy, typename Field2Ty, typename... FieldTys>
void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
append(field);
append(field2, fields...);
}
/// Successor references in the bytecode that have yet to be resolved, keyed by
/// target block; each entry lists the offsets of placeholder addresses.
DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
/// The underlying bytecode buffer.
SmallVectorImpl<ByteCodeField> &bytecode;
/// The main generator producing PDL.
Generator &generator;
};
} // end anonymous namespace
/// Lower a PDL interpreter module into matcher and rewriter bytecode.
void Generator::generate(ModuleOp module) {
  auto matcherFunc = module.lookupSymbol<FuncOp>(
      pdl_interp::PDLInterpDialect::getMatcherFunctionName());
  auto rewriterModule = module.lookupSymbol<ModuleOp>(
      pdl_interp::PDLInterpDialect::getRewriterModuleName());
  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");

  // Memory slots must be known before any bytecode is written: uniqued
  // constants are indexed after the last value slot.
  allocateMemoryIndices(matcherFunc, rewriterModule);

  // Emit the rewriter functions back-to-back, remembering the start address of
  // each one so that record_match operations can refer to it.
  ByteCodeWriter rewriterWriter(rewriterByteCode, *this);
  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
    ByteCodeAddr startAddr = rewriterByteCode.size();
    rewriterToAddr.try_emplace(rewriterFunc.getName(), startAddr);
    for (Operation &op : rewriterFunc.getOps())
      generate(&op, rewriterWriter);
  }
  assert(rewriterWriter.unresolvedSuccessorRefs.empty() &&
         "unexpected branches in rewriter function");

  // Emit the matcher blocks in reverse post-order, recording where each block
  // begins so that successor placeholders can be patched afterwards.
  DenseMap<Block *, ByteCodeAddr> blockToAddr;
  ByteCodeWriter matcherWriter(matcherByteCode, *this);
  llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
  for (Block *block : rpot) {
    blockToAddr.try_emplace(block, matcherByteCode.size());
    for (Operation &op : *block)
      generate(&op, matcherWriter);
  }

  // Overwrite every placeholder address with the address of its target block.
  for (auto &blockRefs : matcherWriter.unresolvedSuccessorRefs) {
    ByteCodeAddr targetAddr = blockToAddr[blockRefs.first];
    for (unsigned fieldOffset : blockRefs.second)
      std::memcpy(&matcherByteCode[fieldOffset], &targetAddr,
                  sizeof(ByteCodeAddr));
  }
}
/// Assign memory indices to the values of every rewriter function and of the
/// matcher function, and raise `maxValueMemoryIndex` to the number of slots the
/// largest of them requires.
void Generator::allocateMemoryIndices(FuncOp matcherFunc,
                                      ModuleOp rewriterModule) {
  // Rewriters use a simplistic allocation scheme that assigns a fresh index to
  // each argument and each operation result.
  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
    ByteCodeField index = 0;
    for (BlockArgument arg : rewriterFunc.getArguments())
      valueToMemIndex.try_emplace(arg, index++);
    rewriterFunc.getBody().walk([&](Operation *op) {
      for (Value result : op->getResults())
        valueToMemIndex.try_emplace(result, index++);
    });
    if (index > maxValueMemoryIndex)
      maxValueMemoryIndex = index;
  }

  // The matcher function uses a more sophisticated numbering that tries to
  // minimize the number of memory indices assigned. This is done by determining
  // a live range of the values within the matcher, then the allocation is just
  // finding the minimal number of overlapping live ranges. This is essentially
  // a simplified form of register allocation where we don't necessarily have a
  // limited number of registers, but we still want to minimize the number used.
  //
  // Operations are numbered in walk order. `unsigned` is used rather than
  // ByteCodeField so that matchers with more operations than a ByteCodeField
  // can count do not wrap around and produce bogus live ranges.
  DenseMap<Operation *, unsigned> opToIndex;
  matcherFunc.getBody().walk([&](Operation *op) {
    opToIndex.insert(std::make_pair(op, opToIndex.size()));
  });

  // Liveness info for each of the defs within the matcher. Each IntervalMap
  // allocates its nodes from `allocator` (declared first so that it outlives
  // every map). The maps are held through unique_ptr so that growing
  // `valueDefRanges` or `allocatedIndices` never relocates a map object.
  // Relocating would leave two objects sharing, and later double-freeing, the
  // same nodes. NOTE(review): newer LLVM releases give IntervalMap proper
  // copy/move support; holding the maps by pointer is correct either way.
  using LivenessSet = llvm::IntervalMap<unsigned, char, 16>;
  LivenessSet::Allocator allocator;
  DenseMap<Value, std::unique_ptr<LivenessSet>> valueDefRanges;

  // Assign the root operation being matched to slot 0.
  BlockArgument rootOpArg = matcherFunc.getArgument(0);
  valueToMemIndex[rootOpArg] = 0;

  // Walk each of the blocks, computing the def interval that the value is used.
  Liveness matcherLiveness(matcherFunc);
  for (Block &block : matcherFunc.getBody()) {
    const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
    assert(info && "expected liveness info for block");
    auto processValue = [&](Value value, Operation *firstUseOrDef) {
      // We don't need to process the root op argument, this value is always
      // assigned to the first memory slot.
      if (value == rootOpArg)
        return;

      // Set indices for the range of this block that the value is used.
      std::unique_ptr<LivenessSet> &defRange = valueDefRanges[value];
      if (!defRange)
        defRange = std::make_unique<LivenessSet>(allocator);
      defRange->insert(opToIndex[firstUseOrDef],
                       opToIndex[info->getEndOperation(value, firstUseOrDef)],
                       /*dummyValue*/ 0);
    };

    // Process the live-ins of this block.
    for (Value liveIn : info->in())
      processValue(liveIn, &block.front());

    // Process any new defs within this block.
    for (Operation &op : block)
      for (Value result : op.getResults())
        processValue(result, &op);
  }

  // Greedily allocate memory slots using the computed def live ranges. Slot 0
  // is reserved for the root operation, so the slot of allocatedIndices[i] is
  // i + 1.
  std::vector<std::unique_ptr<LivenessSet>> allocatedIndices;
  for (auto &defIt : valueDefRanges) {
    ByteCodeField &memIndex = valueToMemIndex[defIt.first];
    LivenessSet &defSet = *defIt.second;

    // Try to allocate to the first existing index whose live range does not
    // overlap this def. Stop at the first fit: continuing would also reserve
    // the def's range in later slots it never uses.
    for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
      LivenessSet &existingIndex = *existingIndexIt.value();
      llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
          defSet, existingIndex);
      if (overlaps.valid())
        continue;

      // Union the range of the def within the existing index.
      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
        existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
      memIndex = existingIndexIt.index() + 1;
      break;
    }

    // If no existing index could be used, add a new one.
    if (memIndex == 0) {
      allocatedIndices.push_back(std::make_unique<LivenessSet>(allocator));
      LivenessSet &newIndex = *allocatedIndices.back();
      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
        newIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
      memIndex = allocatedIndices.size();
    }
  }

  // Update the max number of indices (+1 accounts for the root slot).
  ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
  if (numMatcherIndices > maxValueMemoryIndex)
    maxValueMemoryIndex = numMatcherIndices;
}
/// Dispatch a `pdl_interp` operation to the overload for its concrete type.
void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  auto forwardToOverload = [&](auto concreteOp) {
    this->generate(concreteOp, writer);
  };
  TypeSwitch<Operation *>(op)
      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
            pdl_interp::AreEqualOp, pdl_interp::BranchOp>(forwardToOverload)
      .Case<pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
            pdl_interp::CheckTypeOp>(forwardToOverload)
      .Case<pdl_interp::CreateAttributeOp, pdl_interp::CreateNativeOp,
            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
            pdl_interp::EraseOp, pdl_interp::FinalizeOp>(forwardToOverload)
      .Case<pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
            pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
            pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp>(
          forwardToOverload)
      .Case<pdl_interp::InferredTypeOp, pdl_interp::IsNotNullOp,
            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp>(
          forwardToOverload)
      .Case<pdl_interp::SwitchAttributeOp, pdl_interp::SwitchOperandCountOp,
            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp,
            pdl_interp::SwitchTypeOp>(forwardToOverload)
      .Default([](Operation *) {
        llvm_unreachable("unknown `pdl_interp` operation");
      });
}
/// Emit: opcode, constraint function index, constant params, the generic
/// argument list, then the true/false successor addresses.
void Generator::generate(pdl_interp::ApplyConstraintOp op,
                         ByteCodeWriter &writer) {
  StringRef constraintName = op.name();
  assert(constraintToMemIndex.count(constraintName) &&
         "expected index for constraint function");
  ByteCodeField constraintIndex = constraintToMemIndex[constraintName];
  writer.append(OpCode::ApplyConstraint, constraintIndex, op.constParamsAttr());
  writer.appendPDLValueList(op.args());
  writer.append(op.getSuccessors());
}
/// Emit: opcode, rewrite function index, constant params, root operation, then
/// the generic argument list.
void Generator::generate(pdl_interp::ApplyRewriteOp op,
                         ByteCodeWriter &writer) {
  StringRef rewriteName = op.name();
  assert(externalRewriterToMemIndex.count(rewriteName) &&
         "expected index for rewrite function");
  ByteCodeField rewriteIndex = externalRewriterToMemIndex[rewriteName];
  writer.append(OpCode::ApplyRewrite, rewriteIndex, op.constParamsAttr(),
                op.root());
  writer.appendPDLValueList(op.args());
}
/// Emit an equality check between two memory slots, followed by the
/// successor addresses.
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
  Value lhs = op.lhs();
  Value rhs = op.rhs();
  writer.append(OpCode::AreEqual, lhs, rhs, op.getSuccessors());
}
/// Emit an unconditional branch; the destination is written as an unresolved
/// successor address.
void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
  SuccessorRange destination(op.getOperation());
  writer.append(OpCode::Branch, destination);
}
/// An attribute check is lowered to an equality test against the uniqued
/// constant attribute.
void Generator::generate(pdl_interp::CheckAttributeOp op,
                         ByteCodeWriter &writer) {
  auto expectedAttr = op.constantValue();
  writer.append(OpCode::AreEqual, op.attribute(), expectedAttr,
                op.getSuccessors());
}
/// Emit: opcode, operation slot, expected operand count, successors.
void Generator::generate(pdl_interp::CheckOperandCountOp op,
                         ByteCodeWriter &writer) {
  auto expectedCount = op.count();
  writer.append(OpCode::CheckOperandCount, op.operation(), expectedCount,
                op.getSuccessors());
}
/// Emit: opcode, operation slot, uniqued expected OperationName, successors.
void Generator::generate(pdl_interp::CheckOperationNameOp op,
                         ByteCodeWriter &writer) {
  OperationName expectedName(op.name(), ctx);
  writer.append(OpCode::CheckOperationName, op.operation(), expectedName,
                op.getSuccessors());
}
/// Emit: opcode, operation slot, expected result count, successors.
void Generator::generate(pdl_interp::CheckResultCountOp op,
                         ByteCodeWriter &writer) {
  auto expectedCount = op.count();
  writer.append(OpCode::CheckResultCount, op.operation(), expectedCount,
                op.getSuccessors());
}
/// A type check is lowered to an equality test against the uniqued constant
/// type.
void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
  auto expectedType = op.type();
  writer.append(OpCode::AreEqual, op.value(), expectedType,
                op.getSuccessors());
}
/// No bytecode is emitted: the result's memory index is redirected to the
/// index of the uniqued constant attribute.
void Generator::generate(pdl_interp::CreateAttributeOp op,
                         ByteCodeWriter &writer) {
  ByteCodeField constantIndex = getMemIndex(op.value());
  getMemIndex(op.attribute()) = constantIndex;
}
/// Emit: opcode, creation function index, result slot, constant params, then
/// the generic argument list.
void Generator::generate(pdl_interp::CreateNativeOp op,
                         ByteCodeWriter &writer) {
  StringRef creatorName = op.name();
  assert(nativeCreateToMemIndex.count(creatorName) &&
         "expected index for creation function");
  ByteCodeField creatorIndex = nativeCreateToMemIndex[creatorName];
  writer.append(OpCode::CreateNative, creatorIndex, op.result(),
                op.constParamsAttr());
  writer.appendPDLValueList(op.args());
}
/// Emit: opcode, result slot, OperationName, operand list, attribute count,
/// (name, value) pairs, then the result type list.
void Generator::generate(pdl_interp::CreateOperationOp op,
                         ByteCodeWriter &writer) {
  writer.append(OpCode::CreateOperation, op.operation(),
                OperationName(op.name(), ctx), op.operands());

  // Attributes are written as a count followed by (name, value) pairs.
  OperandRange attrValues = op.attributes();
  auto attrNames = op.attributeNames();
  writer.append(static_cast<ByteCodeField>(attrValues.size()));
  for (auto nameAndValue : llvm::zip(attrNames, attrValues)) {
    StringRef attrName =
        std::get<0>(nameAndValue).cast<StringAttr>().getValue();
    writer.append(Identifier::get(attrName, ctx), std::get<1>(nameAndValue));
  }

  writer.append(op.types());
}
/// No bytecode is emitted: the result's memory index is redirected to the
/// index of the uniqued constant type.
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
  ByteCodeField constantIndex = getMemIndex(op.value());
  getMemIndex(op.result()) = constantIndex;
}
/// Emit: opcode, slot of the operation to erase.
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
  Value erasedOp = op.operation();
  writer.append(OpCode::EraseOp, erasedOp);
}
/// Finalize takes no operands; only the opcode is emitted.
void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
  const OpCode terminator = OpCode::Finalize;
  writer.append(terminator);
}
/// Emit: opcode, result slot, operation slot, uniqued attribute name.
void Generator::generate(pdl_interp::GetAttributeOp op,
                         ByteCodeWriter &writer) {
  Identifier attrName = Identifier::get(op.name(), ctx);
  writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), attrName);
}
/// Emit: opcode, result slot, slot of the attribute whose type is queried.
void Generator::generate(pdl_interp::GetAttributeTypeOp op,
                         ByteCodeWriter &writer) {
  Value resultSlot = op.result();
  writer.append(OpCode::GetAttributeType, resultSlot, op.value());
}
/// Emit: opcode, result operation slot, slot of the value whose defining
/// operation is queried.
void Generator::generate(pdl_interp::GetDefiningOpOp op,
                         ByteCodeWriter &writer) {
  Value definingOpSlot = op.operation();
  writer.append(OpCode::GetDefiningOp, definingOpSlot, op.value());
}
/// Indices 0-3 use dedicated opcodes; any larger index uses GetOperandN with
/// the index written inline. Both forms are followed by the operation and
/// result slots.
void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
  uint32_t operandIndex = op.index();
  if (operandIndex >= 4)
    writer.append(OpCode::GetOperandN, operandIndex);
  else
    writer.append(static_cast<OpCode>(OpCode::GetOperand0 + operandIndex));
  writer.append(op.operation(), op.value());
}
/// Indices 0-3 use dedicated opcodes; any larger index uses GetResultN with
/// the index written inline. Both forms are followed by the operation and
/// result slots.
void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
  uint32_t resultIndex = op.index();
  if (resultIndex >= 4)
    writer.append(OpCode::GetResultN, resultIndex);
  else
    writer.append(static_cast<OpCode>(OpCode::GetResult0 + resultIndex));
  writer.append(op.operation(), op.value());
}
/// Emit: opcode, result slot, slot of the value whose type is queried.
void Generator::generate(pdl_interp::GetValueTypeOp op,
                         ByteCodeWriter &writer) {
  Value resultSlot = op.result();
  writer.append(OpCode::GetValueType, resultSlot, op.value());
}
/// No bytecode is emitted: an inferred type is represented by the null type,
/// so the result simply aliases the uniqued slot of `Type()`.
void Generator::generate(pdl_interp::InferredTypeOp op,
                         ByteCodeWriter &writer) {
  ByteCodeField nullTypeIndex = getMemIndex(Type());
  getMemIndex(op.type()) = nullTypeIndex;
}
/// Emit: opcode, slot of the checked value, successors.
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
  Value checkedValue = op.value();
  writer.append(OpCode::IsNotNull, checkedValue, op.getSuccessors());
}
/// Register a new pattern for this match and emit: opcode, pattern index,
/// successor, matched operations, then rewriter inputs.
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
  // Rewriter functions are emitted before the matcher, so the address must
  // already be recorded. Assert it, as the sibling constraint/rewrite/create
  // lowerings do, instead of letting `operator[]` silently produce address 0.
  StringRef rewriterName = op.rewriter().getLeafReference();
  assert(rewriterToAddr.count(rewriterName) &&
         "expected address for rewriter function");

  ByteCodeField patternIndex = patterns.size();
  patterns.emplace_back(
      PDLByteCodePattern::create(op, rewriterToAddr[rewriterName]));
  writer.append(OpCode::RecordMatch, patternIndex,
                SuccessorRange(op.getOperation()), op.matchedOps(),
                op.inputs());
}
/// Emit: opcode, slot of the operation to replace, replacement value list.
void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
  Value replacedOp = op.operation();
  writer.append(OpCode::ReplaceOp, replacedOp, op.replValues());
}
/// Emit: opcode, attribute slot, uniqued case array, successors.
void Generator::generate(pdl_interp::SwitchAttributeOp op,
                         ByteCodeWriter &writer) {
  auto caseValues = op.caseValuesAttr();
  writer.append(OpCode::SwitchAttribute, op.attribute(), caseValues,
                op.getSuccessors());
}
/// Emit: opcode, operation slot, uniqued case values, successors.
void Generator::generate(pdl_interp::SwitchOperandCountOp op,
                         ByteCodeWriter &writer) {
  auto caseValues = op.caseValuesAttr();
  writer.append(OpCode::SwitchOperandCount, op.operation(), caseValues,
                op.getSuccessors());
}
/// Emit: opcode, operation slot, a count-prefixed list of uniqued
/// OperationNames (one per case), successors.
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
                         ByteCodeWriter &writer) {
  auto toOperationName = [&](Attribute caseAttr) {
    return OperationName(caseAttr.cast<StringAttr>().getValue(), ctx);
  };
  auto caseNames = llvm::map_range(op.caseValuesAttr(), toOperationName);
  writer.append(OpCode::SwitchOperationName, op.operation(), caseNames,
                op.getSuccessors());
}
/// Emit: opcode, operation slot, uniqued case values, successors.
void Generator::generate(pdl_interp::SwitchResultCountOp op,
                         ByteCodeWriter &writer) {
  auto caseValues = op.caseValuesAttr();
  writer.append(OpCode::SwitchResultCount, op.operation(), caseValues,
                op.getSuccessors());
}
/// Emit: opcode, type slot, uniqued case array, successors.
void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
  auto caseValues = op.caseValuesAttr();
  writer.append(OpCode::SwitchType, op.value(), caseValues,
                op.getSuccessors());
}
//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//
/// Compile `module` into bytecode and take ownership of the externally
/// registered functions it may reference.
PDLByteCode::PDLByteCode(ModuleOp module,
                         llvm::StringMap<PDLConstraintFunction> constraintFns,
                         llvm::StringMap<PDLCreateFunction> createFns,
                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      constraintFns, createFns, rewriteFns);
  generator.generate(module);

  // The generator numbered each function by its position in the map's
  // iteration order; store them in that same order so the indices match.
  auto takeFunctions = [](auto &fnMap, auto &fnList) {
    for (auto &nameAndFn : fnMap)
      fnList.push_back(std::move(nameAndFn.second));
  };
  takeFunctions(constraintFns, constraintFunctions);
  takeFunctions(createFns, createFunctions);
  takeFunctions(rewriteFns, rewriteFunctions);
}
/// Prepare `state` for executing this bytecode. The value memory is sized to
/// the maximum number of value slots with null entries, and each pattern's
/// static benefit is appended to the state's current benefits.
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
  state.memory.resize(maxValueMemoryIndex, nullptr);

  auto &benefits = state.currentPatternBenefits;
  benefits.reserve(patterns.size());
  for (unsigned i = 0, e = patterns.size(); i != e; ++i)
    benefits.push_back(patterns[i].getBenefit());
}
//===----------------------------------------------------------------------===//
// ByteCode Execution
namespace {
/// This class provides support for executing a bytecode stream. Operands are
/// decoded from the stream as it advances. Memory indices below `memory.size()`
/// refer to mutable value slots; indices at or above it refer to
/// `uniquedMemory`.
class ByteCodeExecutor {
public:
ByteCodeExecutor(const ByteCodeField *curCodeIt,
MutableArrayRef<const void *> memory,
ArrayRef<const void *> uniquedMemory,
ArrayRef<ByteCodeField> code,
ArrayRef<PatternBenefit> currentPatternBenefits,
ArrayRef<PDLByteCodePattern> patterns,
ArrayRef<PDLConstraintFunction> constraintFunctions,
ArrayRef<PDLCreateFunction> createFunctions,
ArrayRef<PDLRewriteFunction> rewriteFunctions)
: curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
code(code), currentPatternBenefits(currentPatternBenefits),
patterns(patterns), constraintFunctions(constraintFunctions),
createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
/// Start executing the code at the current bytecode index. `matches` is an
/// optional field provided when this function is executed in a matching
/// context.
void execute(PatternRewriter &rewriter,
SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
Optional<Location> mainRewriteLoc = {});
private:
/// Read a value from the bytecode buffer, optionally skipping a certain
/// number of prefix values. These methods always update the buffer to point
/// to the next field after the read data.
template <typename T = ByteCodeField>
T read(size_t skipN = 0) {
curCodeIt += skipN;
return readImpl<T>();
}
ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
/// Read a list of values from the bytecode buffer: a count field followed by
/// that many `ValueT` entries. Any previous contents of `list` are discarded.
template <typename ValueT, typename T>
void readList(SmallVectorImpl<T> &list) {
list.clear();
for (unsigned i = 0, e = read(); i != e; ++i)
list.push_back(read<ValueT>());
}
/// Jump to a specific successor based on a predicate value: successor 0 when
/// true, successor 1 when false.
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
/// Jump to a specific successor based on a destination index. Successor
/// addresses are stored consecutively, each spanning two fields.
void selectJump(size_t destIndex) {
curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
}
/// Handle a switch operation with the provided value and cases. Successor 0
/// is the default destination; case `i` jumps to successor `i + 1`.
template <typename T, typename RangeT>
void handleSwitch(const T &value, RangeT &&cases) {
LLVM_DEBUG({
llvm::dbgs() << " * Value: " << value << "\n"
<< " * Cases: ";
llvm::interleaveComma(cases, llvm::dbgs());
llvm::dbgs() << "\n\n";
});
// Check to see if the value is within the case list. Jump to the correct
// successor index based on the result.
for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
if (*it == value)
return selectJump(size_t((it - cases.begin()) + 1));
selectJump(size_t(0));
}
/// Internal implementation of reading various data types from the bytecode
/// stream.
/// Read a memory index and return the opaque pointer stored at it.
template <typename T>
const void *readFromMemory() {
size_t index = *curCodeIt++;
// If this type is an SSA value, it can only be stored in non-const memory.
if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
return memory[index];
// Otherwise, if this index is not inbounds it is uniqued.
return uniquedMemory[index - memory.size()];
}
/// Pointer types are reinterpreted directly from the memory slot.
template <typename T>
std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
}
/// Class types (other than PDLValue) are rebuilt from their opaque pointer.
template <typename T>
std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
T>
readImpl() {
return T(T::getFromOpaquePointer(readFromMemory<T>()));
}
/// A PDLValue is encoded as a PDLValueKind tag followed by the memory index
/// of the underlying value.
template <typename T>
std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
switch (static_cast<PDLValueKind>(read())) {
case PDLValueKind::Attribute:
return read<Attribute>();
case PDLValueKind::Operation:
return read<Operation *>();
case PDLValueKind::Type:
return read<Type>();
case PDLValueKind::Value:
return read<Value>();
}
llvm_unreachable("unhandled PDLValueKind");
}
/// An address is stored inline across two consecutive fields.
template <typename T>
std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
"unexpected ByteCode address size");
ByteCodeAddr result;
std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
curCodeIt += 2;
return result;
}
/// A plain field is stored inline.
template <typename T>
std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
return *curCodeIt++;
}
/// The underlying bytecode buffer.
const ByteCodeField *curCodeIt;
/// The current execution memory.
MutableArrayRef<const void *> memory;
/// References to ByteCode data necessary for execution.
ArrayRef<const void *> uniquedMemory;
ArrayRef<ByteCodeField> code;
ArrayRef<PatternBenefit> currentPatternBenefits;
ArrayRef<PDLByteCodePattern> patterns;
ArrayRef<PDLConstraintFunction> constraintFunctions;
ArrayRef<PDLCreateFunction> createFunctions;
ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
} // end anonymous namespace
/// Interpret bytecode starting at the current position of `curCodeIt` until a
/// `Finalize` instruction is reached (or, in the rewriter, until result type
/// inference fails while creating an operation).
///
/// \param rewriter Passed to native constraint/create/rewrite callbacks and
///        used to create, replace, and erase operations.
/// \param matches When executing the matcher bytecode, receives one entry per
///        executed `RecordMatch` whose benefit is not impossible. Must be
///        non-null in that case.
/// \param mainRewriteLoc When executing the rewriter bytecode, the location
///        given to operations created by `CreateOperation`. Must be provided
///        in that case.
void ByteCodeExecutor::execute(
    PatternRewriter &rewriter,
    SmallVectorImpl<PDLByteCode::MatchResult> *matches,
    Optional<Location> mainRewriteLoc) {
  while (true) {
    OpCode opCode = static_cast<OpCode>(read());
    switch (opCode) {
    case ApplyConstraint: {
      LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
      const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
      ArrayAttr constParams = read<ArrayAttr>();
      SmallVector<PDLValue, 16> args;
      readList<PDLValue>(args);
      LLVM_DEBUG({
        llvm::dbgs() << " * Arguments: ";
        llvm::interleaveComma(args, llvm::dbgs());
        llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
      });

      // Invoke the constraint and jump to the proper destination.
      selectJump(succeeded(constraintFn(args, constParams, rewriter)));
      break;
    }
    case ApplyRewrite: {
      LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
      const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
      ArrayAttr constParams = read<ArrayAttr>();
      Operation *root = read<Operation *>();
      SmallVector<PDLValue, 16> args;
      readList<PDLValue>(args);
      LLVM_DEBUG({
        llvm::dbgs() << " * Root: " << *root << "\n"
                     << " * Arguments: ";
        llvm::interleaveComma(args, llvm::dbgs());
        llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
      });
      rewriteFn(root, args, constParams, rewriter);
      break;
    }
    case AreEqual: {
      LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
      const void *lhs = read<const void *>();
      const void *rhs = read<const void *>();
      LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
      selectJump(lhs == rhs);
      break;
    }
    case Branch: {
      LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n");
      curCodeIt = &code[read<ByteCodeAddr>()];
      break;
    }
    case CheckOperandCount: {
      LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
      Operation *op = read<Operation *>();
      uint32_t expectedCount = read<uint32_t>();
      LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
                              << " * Expected: " << expectedCount << "\n\n");
      selectJump(op->getNumOperands() == expectedCount);
      break;
    }
    case CheckOperationName: {
      LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
      Operation *op = read<Operation *>();
      OperationName expectedName = read<OperationName>();
      LLVM_DEBUG(llvm::dbgs()
                 << " * Found: \"" << op->getName() << "\"\n"
                 << " * Expected: \"" << expectedName << "\"\n\n");
      selectJump(op->getName() == expectedName);
      break;
    }
    case CheckResultCount: {
      LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
      Operation *op = read<Operation *>();
      uint32_t expectedCount = read<uint32_t>();
      LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
                              << " * Expected: " << expectedCount << "\n\n");
      selectJump(op->getNumResults() == expectedCount);
      break;
    }
    case CreateNative: {
      LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
      const PDLCreateFunction &createFn = createFunctions[read()];
      ByteCodeField resultIndex = read();
      ArrayAttr constParams = read<ArrayAttr>();
      SmallVector<PDLValue, 16> args;
      readList<PDLValue>(args);
      LLVM_DEBUG({
        llvm::dbgs() << " * Arguments: ";
        llvm::interleaveComma(args, llvm::dbgs());
        llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
      });

      PDLValue result = createFn(args, constParams, rewriter);
      memory[resultIndex] = result.getAsOpaquePointer();
      LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n");
      break;
    }
    case CreateOperation: {
      LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
      assert(mainRewriteLoc && "expected rewrite loc to be provided when "
                               "executing the rewriter bytecode");

      unsigned memIndex = read();
      OperationState state(*mainRewriteLoc, read<OperationName>());
      readList<Value>(state.operands);
      for (unsigned i = 0, e = read(); i != e; ++i) {
        Identifier name = read<Identifier>();
        if (Attribute attr = read<Attribute>())
          state.addAttribute(name, attr);
      }

      // A null result type in the bytecode marks a type that must be inferred.
      bool hasInferredTypes = false;
      for (unsigned i = 0, e = read(); i != e; ++i) {
        Type resultType = read<Type>();
        hasInferredTypes |= !resultType;
        state.types.push_back(resultType);
      }

      // Handle the case where the operation has inferred types.
      if (hasInferredTypes) {
        // Note: this must not be named `concept`, which is a keyword in C++20.
        auto *abstractOp = state.name.getAbstractOperation();
        assert(abstractOp && "expected a registered operation when inferring "
                             "result types");
        InferTypeOpInterface::Concept *inferInterface =
            abstractOp->getInterface<InferTypeOpInterface>();
        assert(inferInterface && "expected operation to implement "
                                 "InferTypeOpInterface to infer result types");

        // TODO: Handle failure. Currently this aborts execution of the
        // remaining rewriter bytecode.
        SmallVector<Type, 2> inferredTypes;
        if (failed(inferInterface->inferReturnTypes(
                state.getContext(), state.location, state.operands,
                state.attributes.getDictionary(state.getContext()),
                state.regions, inferredTypes)))
          return;

        // Guard the indexing below against a short inferred type list.
        assert(inferredTypes.size() >= state.types.size() &&
               "expected a type to be inferred for every result");
        for (unsigned i = 0, e = state.types.size(); i != e; ++i)
          if (!state.types[i])
            state.types[i] = inferredTypes[i];
      }
      Operation *resultOp = rewriter.createOperation(state);
      memory[memIndex] = resultOp;

      LLVM_DEBUG({
        llvm::dbgs() << " * Attributes: "
                     << state.attributes.getDictionary(state.getContext())
                     << "\n * Operands: ";
        llvm::interleaveComma(state.operands, llvm::dbgs());
        llvm::dbgs() << "\n * Result Types: ";
        llvm::interleaveComma(state.types, llvm::dbgs());
        llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n";
      });
      break;
    }
    case EraseOp: {
      LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
      Operation *op = read<Operation *>();

      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n");
      rewriter.eraseOp(op);
      break;
    }
    case Finalize: {
      LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
      return;
    }
    case GetAttribute: {
      LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
      unsigned memIndex = read();
      Operation *op = read<Operation *>();
      Identifier attrName = read<Identifier>();
      Attribute attr = op->getAttr(attrName);

      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
                              << " * Attribute: " << attrName << "\n"
                              << " * Result: " << attr << "\n\n");
      memory[memIndex] = attr.getAsOpaquePointer();
      break;
    }
    case GetAttributeType: {
      LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
      unsigned memIndex = read();
      Attribute attr = read<Attribute>();
      // Like GetDefiningOp, a null input produces a null result instead of
      // being dereferenced.
      Type type = attr ? attr.getType() : Type();

      LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
                              << " * Result: " << type << "\n\n");
      memory[memIndex] = type.getAsOpaquePointer();
      break;
    }
    case GetDefiningOp: {
      LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
      unsigned memIndex = read();
      Value value = read<Value>();
      // The result may be null (at least for a null input value), so it must
      // not be dereferenced unconditionally below.
      Operation *op = value ? value.getDefiningOp() : nullptr;

      LLVM_DEBUG({
        llvm::dbgs() << " * Value: " << value << "\n * Result: ";
        if (op)
          llvm::dbgs() << *op;
        else
          llvm::dbgs() << "<null>";
        llvm::dbgs() << "\n\n";
      });
      memory[memIndex] = op;
      break;
    }
    case GetOperand0:
    case GetOperand1:
    case GetOperand2:
    case GetOperand3:
    case GetOperandN: {
      LLVM_DEBUG({
        llvm::dbgs() << "Executing GetOperand"
                     << (opCode == GetOperandN ? Twine("N")
                                               : Twine(opCode - GetOperand0))
                     << ":\n";
      });
      unsigned index =
          opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0);
      Operation *op = read<Operation *>();
      unsigned memIndex = read();
      Value operand =
          index < op->getNumOperands() ? op->getOperand(index) : Value();

      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
                              << " * Index: " << index << "\n"
                              << " * Result: " << operand << "\n\n");
      memory[memIndex] = operand.getAsOpaquePointer();
      break;
    }
    case GetResult0:
    case GetResult1:
    case GetResult2:
    case GetResult3:
    case GetResultN: {
      LLVM_DEBUG({
        llvm::dbgs() << "Executing GetResult"
                     << (opCode == GetResultN ? Twine("N")
                                              : Twine(opCode - GetResult0))
                     << ":\n";
      });
      unsigned index =
          opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0);
      Operation *op = read<Operation *>();
      unsigned memIndex = read();
      OpResult result =
          index < op->getNumResults() ? op->getResult(index) : OpResult();

      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
                              << " * Index: " << index << "\n"
                              << " * Result: " << result << "\n\n");
      memory[memIndex] = result.getAsOpaquePointer();
      break;
    }
    case GetValueType: {
      LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
      unsigned memIndex = read();
      Value value = read<Value>();
      // Like GetDefiningOp, a null input produces a null result instead of
      // being dereferenced.
      Type type = value ? value.getType() : Type();

      LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
                              << " * Result: " << type << "\n\n");
      memory[memIndex] = type.getAsOpaquePointer();
      break;
    }
    case IsNotNull: {
      LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
      const void *value = read<const void *>();

      LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n");
      selectJump(value != nullptr);
      break;
    }
    case RecordMatch: {
      LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
      assert(matches &&
             "expected matches to be provided when executing the matcher");
      unsigned patternIndex = read();
      PatternBenefit benefit = currentPatternBenefits[patternIndex];
      const ByteCodeField *dest = &code[read<ByteCodeAddr>()];

      // If the benefit of the pattern is impossible, skip the processing of the
      // rest of the pattern.
      if (benefit.isImpossibleToMatch()) {
        LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n");
        curCodeIt = dest;
        break;
      }

      // Create a fused location containing the locations of each of the
      // operations used in the match. This will be used as the location for
      // created operations during the rewrite that don't already have an
      // explicit location set.
      unsigned numMatchLocs = read();
      SmallVector<Location, 4> matchLocs;
      matchLocs.reserve(numMatchLocs);
      for (unsigned i = 0; i != numMatchLocs; ++i)
        matchLocs.push_back(read<Operation *>()->getLoc());
      Location matchLoc = rewriter.getFusedLoc(matchLocs);

      LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
                              << " * Location: " << matchLoc << "\n\n");
      matches->emplace_back(matchLoc, patterns[patternIndex], benefit);
      readList<const void *>(matches->back().values);
      curCodeIt = dest;
      break;
    }
    case ReplaceOp: {
      LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
      Operation *op = read<Operation *>();
      SmallVector<Value, 16> args;
      readList<Value>(args);

      LLVM_DEBUG({
        llvm::dbgs() << " * Operation: " << *op << "\n"
                     << " * Values: ";
        llvm::interleaveComma(args, llvm::dbgs());
        llvm::dbgs() << "\n\n";
      });
      rewriter.replaceOp(op, args);
      break;
    }
    case SwitchAttribute: {
      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
      Attribute value = read<Attribute>();
      ArrayAttr cases = read<ArrayAttr>();
      handleSwitch(value, cases);
      break;
    }
    case SwitchOperandCount: {
      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
      Operation *op = read<Operation *>();
      auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();

      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
      handleSwitch(op->getNumOperands(), cases);
      break;
    }
    case SwitchOperationName: {
      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
      OperationName value = read<Operation *>()->getName();
      size_t caseCount = read();

      // The operation names are stored in-line, so to print them out for
      // debugging purposes we need to read the array before executing the
      // switch so that we can display all of the possible values.
      LLVM_DEBUG({
        const ByteCodeField *prevCodeIt = curCodeIt;
        llvm::dbgs() << " * Value: " << value << "\n"
                     << " * Cases: ";
        llvm::interleaveComma(
            llvm::map_range(llvm::seq<size_t>(0, caseCount),
                            [&](size_t i) { return read<OperationName>(); }),
            llvm::dbgs());
        llvm::dbgs() << "\n\n";
        curCodeIt = prevCodeIt;
      });

      // Try to find the switch value within any of the cases.
      size_t jumpDest = 0;
      for (size_t i = 0; i != caseCount; ++i) {
        if (read<OperationName>() == value) {
          curCodeIt += (caseCount - i - 1);
          jumpDest = i + 1;
          break;
        }
      }
      selectJump(jumpDest);
      break;
    }
    case SwitchResultCount: {
      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
      Operation *op = read<Operation *>();
      auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();

      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
      handleSwitch(op->getNumResults(), cases);
      break;
    }
    case SwitchType: {
      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
      Type value = read<Type>();
      auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
      handleSwitch(value, cases);
      break;
    }
    }
  }
}
/// Execute the matcher bytecode against the root operation `op`, appending
/// every recorded match to `matches`. Afterwards, order `matches` from highest
/// to lowest benefit.
void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
                        SmallVectorImpl<MatchResult> &matches,
                        PDLByteCodeMutableState &state) const {
  // Slot zero of the execution memory is reserved for the root operation.
  state.memory[0] = op;

  // Matching always begins at the very start of the matcher bytecode.
  ByteCodeExecutor matcher(matcherByteCode.data(), state.memory, uniquedData,
                           matcherByteCode, state.currentPatternBenefits,
                           patterns, constraintFunctions, createFunctions,
                           rewriteFunctions);
  matcher.execute(rewriter, &matches);

  // A stable sort keeps matches with equal benefit in discovery order.
  auto hasHigherBenefit = [](const MatchResult &lhs, const MatchResult &rhs) {
    return lhs.benefit > rhs.benefit;
  };
  std::stable_sort(matches.begin(), matches.end(), hasHigherBenefit);
}
/// Execute the rewriter bytecode of the pattern recorded in `match`. Values
/// captured by the matcher are used as the rewrite function's arguments.
void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
                          PDLByteCodeMutableState &state) const {
  // Seed the front of the memory buffer with the captured values; the
  // rewriter bytecode expects its arguments there.
  llvm::copy(match.values, state.memory.begin());

  const ByteCodeField *rewriterStart =
      &rewriterByteCode[match.pattern->getRewriterAddr()];
  ByteCodeExecutor rewriterExec(rewriterStart, state.memory, uniquedData,
                                rewriterByteCode, state.currentPatternBenefits,
                                patterns, constraintFunctions, createFunctions,
                                rewriteFunctions);
  rewriterExec.execute(rewriter, /*matches=*/nullptr, match.location);
}