| //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements MLIR to byte-code generation and the interpreter. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "ByteCode.h" |
| #include "mlir/Analysis/Liveness.h" |
| #include "mlir/Dialect/PDL/IR/PDLTypes.h" |
| #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/RegionGraphTraits.h" |
| #include "llvm/ADT/IntervalMap.h" |
| #include "llvm/ADT/PostOrderIterator.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/Support/Debug.h" |
| |
| #define DEBUG_TYPE "pdl-bytecode" |
| |
| using namespace mlir; |
| using namespace mlir::detail; |
| |
| //===----------------------------------------------------------------------===// |
| // PDLByteCodePattern |
| //===----------------------------------------------------------------------===// |
| |
| PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, |
| ByteCodeAddr rewriterAddr) { |
| SmallVector<StringRef, 8> generatedOps; |
| if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) |
| generatedOps = |
| llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); |
| |
| PatternBenefit benefit = matchOp.benefit(); |
| MLIRContext *ctx = matchOp.getContext(); |
| |
| // Check to see if this is pattern matches a specific operation type. |
| if (Optional<StringRef> rootKind = matchOp.rootKind()) |
| return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit, |
| ctx); |
| return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx, |
| MatchAnyOpTypeTag()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // PDLByteCodeMutableState |
| //===----------------------------------------------------------------------===// |
| |
| /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds |
| /// to the position of the pattern within the range returned by |
| /// `PDLByteCode::getPatterns`. |
| void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex, |
| PatternBenefit benefit) { |
| currentPatternBenefits[patternIndex] = benefit; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Bytecode OpCodes |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| enum OpCode : ByteCodeField { |
| /// Apply an externally registered constraint. |
| ApplyConstraint, |
| /// Apply an externally registered rewrite. |
| ApplyRewrite, |
| /// Check if two generic values are equal. |
| AreEqual, |
| /// Unconditional branch. |
| Branch, |
| /// Compare the operand count of an operation with a constant. |
| CheckOperandCount, |
| /// Compare the name of an operation with a constant. |
| CheckOperationName, |
| /// Compare the result count of an operation with a constant. |
| CheckResultCount, |
| /// Invoke a native creation method. |
| CreateNative, |
| /// Create an operation. |
| CreateOperation, |
| /// Erase an operation. |
| EraseOp, |
| /// Terminate a matcher or rewrite sequence. |
| Finalize, |
| /// Get a specific attribute of an operation. |
| GetAttribute, |
| /// Get the type of an attribute. |
| GetAttributeType, |
| /// Get the defining operation of a value. |
| GetDefiningOp, |
| /// Get a specific operand of an operation. |
| GetOperand0, |
| GetOperand1, |
| GetOperand2, |
| GetOperand3, |
| GetOperandN, |
| /// Get a specific result of an operation. |
| GetResult0, |
| GetResult1, |
| GetResult2, |
| GetResult3, |
| GetResultN, |
| /// Get the type of a value. |
| GetValueType, |
| /// Check if a generic value is not null. |
| IsNotNull, |
| /// Record a successful pattern match. |
| RecordMatch, |
| /// Replace an operation. |
| ReplaceOp, |
| /// Compare an attribute with a set of constants. |
| SwitchAttribute, |
| /// Compare the operand count of an operation with a set of constants. |
| SwitchOperandCount, |
| /// Compare the name of an operation with a set of constants. |
| SwitchOperationName, |
| /// Compare the result count of an operation with a set of constants. |
| SwitchResultCount, |
| /// Compare a type with a set of constants. |
| SwitchType, |
| }; |
| |
| enum class PDLValueKind { Attribute, Operation, Type, Value }; |
| } // end anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // ByteCode Generation |
| //===----------------------------------------------------------------------===// |
| |
| //===----------------------------------------------------------------------===// |
| // Generator |
| |
| namespace { |
| struct ByteCodeWriter; |
| |
| /// This class represents the main generator for the pattern bytecode. |
| class Generator { |
| public: |
| Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData, |
| SmallVectorImpl<ByteCodeField> &matcherByteCode, |
| SmallVectorImpl<ByteCodeField> &rewriterByteCode, |
| SmallVectorImpl<PDLByteCodePattern> &patterns, |
| ByteCodeField &maxValueMemoryIndex, |
| llvm::StringMap<PDLConstraintFunction> &constraintFns, |
| llvm::StringMap<PDLCreateFunction> &createFns, |
| llvm::StringMap<PDLRewriteFunction> &rewriteFns) |
| : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), |
| rewriterByteCode(rewriterByteCode), patterns(patterns), |
| maxValueMemoryIndex(maxValueMemoryIndex) { |
| for (auto it : llvm::enumerate(constraintFns)) |
| constraintToMemIndex.try_emplace(it.value().first(), it.index()); |
| for (auto it : llvm::enumerate(createFns)) |
| nativeCreateToMemIndex.try_emplace(it.value().first(), it.index()); |
| for (auto it : llvm::enumerate(rewriteFns)) |
| externalRewriterToMemIndex.try_emplace(it.value().first(), it.index()); |
| } |
| |
| /// Generate the bytecode for the given PDL interpreter module. |
| void generate(ModuleOp module); |
| |
| /// Return the memory index to use for the given value. |
| ByteCodeField &getMemIndex(Value value) { |
| assert(valueToMemIndex.count(value) && |
| "expected memory index to be assigned"); |
| return valueToMemIndex[value]; |
| } |
| |
| /// Return an index to use when referring to the given data that is uniqued in |
| /// the MLIR context. |
| template <typename T> |
| std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &> |
| getMemIndex(T val) { |
| const void *opaqueVal = val.getAsOpaquePointer(); |
| |
| // Get or insert a reference to this value. |
| auto it = uniquedDataToMemIndex.try_emplace( |
| opaqueVal, maxValueMemoryIndex + uniquedData.size()); |
| if (it.second) |
| uniquedData.push_back(opaqueVal); |
| return it.first->second; |
| } |
| |
| private: |
| /// Allocate memory indices for the results of operations within the matcher |
| /// and rewriters. |
| void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); |
| |
| /// Generate the bytecode for the given operation. |
| void generate(Operation *op, ByteCodeWriter &writer); |
| void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); |
| void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); |
| |
| /// Mapping from value to its corresponding memory index. |
| DenseMap<Value, ByteCodeField> valueToMemIndex; |
| |
| /// Mapping from the name of an externally registered rewrite to its index in |
| /// the bytecode registry. |
| llvm::StringMap<ByteCodeField> externalRewriterToMemIndex; |
| |
| /// Mapping from the name of an externally registered constraint to its index |
| /// in the bytecode registry. |
| llvm::StringMap<ByteCodeField> constraintToMemIndex; |
| |
| /// Mapping from the name of an externally registered creation method to its |
| /// index in the bytecode registry. |
| llvm::StringMap<ByteCodeField> nativeCreateToMemIndex; |
| |
| /// Mapping from rewriter function name to the bytecode address of the |
| /// rewriter function in byte. |
| llvm::StringMap<ByteCodeAddr> rewriterToAddr; |
| |
| /// Mapping from a uniqued storage object to its memory index within |
| /// `uniquedData`. |
| DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex; |
| |
| /// The current MLIR context. |
| MLIRContext *ctx; |
| |
| /// Data of the ByteCode class to be populated. |
| std::vector<const void *> &uniquedData; |
| SmallVectorImpl<ByteCodeField> &matcherByteCode; |
| SmallVectorImpl<ByteCodeField> &rewriterByteCode; |
| SmallVectorImpl<PDLByteCodePattern> &patterns; |
| ByteCodeField &maxValueMemoryIndex; |
| }; |
| |
| /// This class provides utilities for writing a bytecode stream. |
| struct ByteCodeWriter { |
| ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator) |
| : bytecode(bytecode), generator(generator) {} |
| |
| /// Append a field to the bytecode. |
| void append(ByteCodeField field) { bytecode.push_back(field); } |
| void append(OpCode opCode) { bytecode.push_back(opCode); } |
| |
| /// Append an address to the bytecode. |
| void append(ByteCodeAddr field) { |
| static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, |
| "unexpected ByteCode address size"); |
| |
| ByteCodeField fieldParts[2]; |
| std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr)); |
| bytecode.append({fieldParts[0], fieldParts[1]}); |
| } |
| |
| /// Append a successor range to the bytecode, the exact address will need to |
| /// be resolved later. |
| void append(SuccessorRange successors) { |
| // Add back references to the any successors so that the address can be |
| // resolved later. |
| for (Block *successor : successors) { |
| unresolvedSuccessorRefs[successor].push_back(bytecode.size()); |
| append(ByteCodeAddr(0)); |
| } |
| } |
| |
| /// Append a range of values that will be read as generic PDLValues. |
| void appendPDLValueList(OperandRange values) { |
| bytecode.push_back(values.size()); |
| for (Value value : values) { |
| // Append the type of the value in addition to the value itself. |
| PDLValueKind kind = |
| TypeSwitch<Type, PDLValueKind>(value.getType()) |
| .Case<pdl::AttributeType>( |
| [](Type) { return PDLValueKind::Attribute; }) |
| .Case<pdl::OperationType>( |
| [](Type) { return PDLValueKind::Operation; }) |
| .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; }) |
| .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; }); |
| bytecode.push_back(static_cast<ByteCodeField>(kind)); |
| append(value); |
| } |
| } |
| |
| /// Check if the given class `T` has an iterator type. |
| template <typename T, typename... Args> |
| using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer()); |
| |
| /// Append a value that will be stored in a memory slot and not inline within |
| /// the bytecode. |
| template <typename T> |
| std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value || |
| std::is_pointer<T>::value> |
| append(T value) { |
| bytecode.push_back(generator.getMemIndex(value)); |
| } |
| |
| /// Append a range of values. |
| template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>> |
| std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value> |
| append(T range) { |
| bytecode.push_back(llvm::size(range)); |
| for (auto it : range) |
| append(it); |
| } |
| |
| /// Append a variadic number of fields to the bytecode. |
| template <typename FieldTy, typename Field2Ty, typename... FieldTys> |
| void append(FieldTy field, Field2Ty field2, FieldTys... fields) { |
| append(field); |
| append(field2, fields...); |
| } |
| |
| /// Successor references in the bytecode that have yet to be resolved. |
| DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs; |
| |
| /// The underlying bytecode buffer. |
| SmallVectorImpl<ByteCodeField> &bytecode; |
| |
| /// The main generator producing PDL. |
| Generator &generator; |
| }; |
| } // end anonymous namespace |
| |
| void Generator::generate(ModuleOp module) { |
| FuncOp matcherFunc = module.lookupSymbol<FuncOp>( |
| pdl_interp::PDLInterpDialect::getMatcherFunctionName()); |
| ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>( |
| pdl_interp::PDLInterpDialect::getRewriterModuleName()); |
| assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module"); |
| |
| // Allocate memory indices for the results of operations within the matcher |
| // and rewriters. |
| allocateMemoryIndices(matcherFunc, rewriterModule); |
| |
| // Generate code for the rewriter functions. |
| ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this); |
| for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { |
| rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size()); |
| for (Operation &op : rewriterFunc.getOps()) |
| generate(&op, rewriterByteCodeWriter); |
| } |
| assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() && |
| "unexpected branches in rewriter function"); |
| |
| // Generate code for the matcher function. |
| DenseMap<Block *, ByteCodeAddr> blockToAddr; |
| llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody()); |
| ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); |
| for (Block *block : rpot) { |
| // Keep track of where this block begins within the matcher function. |
| blockToAddr.try_emplace(block, matcherByteCode.size()); |
| for (Operation &op : *block) |
| generate(&op, matcherByteCodeWriter); |
| } |
| |
| // Resolve successor references in the matcher. |
| for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { |
| ByteCodeAddr addr = blockToAddr[it.first]; |
| for (unsigned offsetToFix : it.second) |
| std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr)); |
| } |
| } |
| |
| void Generator::allocateMemoryIndices(FuncOp matcherFunc, |
| ModuleOp rewriterModule) { |
| // Rewriters use simplistic allocation scheme that simply assigns an index to |
| // each result. |
| for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) { |
| ByteCodeField index = 0; |
| for (BlockArgument arg : rewriterFunc.getArguments()) |
| valueToMemIndex.try_emplace(arg, index++); |
| rewriterFunc.getBody().walk([&](Operation *op) { |
| for (Value result : op->getResults()) |
| valueToMemIndex.try_emplace(result, index++); |
| }); |
| if (index > maxValueMemoryIndex) |
| maxValueMemoryIndex = index; |
| } |
| |
| // The matcher function uses a more sophisticated numbering that tries to |
| // minimize the number of memory indices assigned. This is done by determining |
| // a live range of the values within the matcher, then the allocation is just |
| // finding the minimal number of overlapping live ranges. This is essentially |
| // a simplified form of register allocation where we don't necessarily have a |
| // limited number of registers, but we still want to minimize the number used. |
| DenseMap<Operation *, ByteCodeField> opToIndex; |
| matcherFunc.getBody().walk([&](Operation *op) { |
| opToIndex.insert(std::make_pair(op, opToIndex.size())); |
| }); |
| |
| // Liveness info for each of the defs within the matcher. |
| using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>; |
| LivenessSet::Allocator allocator; |
| DenseMap<Value, LivenessSet> valueDefRanges; |
| |
| // Assign the root operation being matched to slot 0. |
| BlockArgument rootOpArg = matcherFunc.getArgument(0); |
| valueToMemIndex[rootOpArg] = 0; |
| |
| // Walk each of the blocks, computing the def interval that the value is used. |
| Liveness matcherLiveness(matcherFunc); |
| for (Block &block : matcherFunc.getBody()) { |
| const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); |
| assert(info && "expected liveness info for block"); |
| auto processValue = [&](Value value, Operation *firstUseOrDef) { |
| // We don't need to process the root op argument, this value is always |
| // assigned to the first memory slot. |
| if (value == rootOpArg) |
| return; |
| |
| // Set indices for the range of this block that the value is used. |
| auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; |
| defRangeIt->second.insert( |
| opToIndex[firstUseOrDef], |
| opToIndex[info->getEndOperation(value, firstUseOrDef)], |
| /*dummyValue*/ 0); |
| }; |
| |
| // Process the live-ins of this block. |
| for (Value liveIn : info->in()) |
| processValue(liveIn, &block.front()); |
| |
| // Process any new defs within this block. |
| for (Operation &op : block) |
| for (Value result : op.getResults()) |
| processValue(result, &op); |
| } |
| |
| // Greedily allocate memory slots using the computed def live ranges. |
| std::vector<LivenessSet> allocatedIndices; |
| for (auto &defIt : valueDefRanges) { |
| ByteCodeField &memIndex = valueToMemIndex[defIt.first]; |
| LivenessSet &defSet = defIt.second; |
| |
| // Try to allocate to an existing index. |
| for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { |
| LivenessSet &existingIndex = existingIndexIt.value(); |
| llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps( |
| defIt.second, existingIndex); |
| if (overlaps.valid()) |
| continue; |
| // Union the range of the def within the existing index. |
| for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) |
| existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); |
| memIndex = existingIndexIt.index() + 1; |
| } |
| |
| // If no existing index could be used, add a new one. |
| if (memIndex == 0) { |
| allocatedIndices.emplace_back(allocator); |
| for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) |
| allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); |
| memIndex = allocatedIndices.size(); |
| } |
| } |
| |
| // Update the max number of indices. |
| ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; |
| if (numMatcherIndices > maxValueMemoryIndex) |
| maxValueMemoryIndex = numMatcherIndices; |
| } |
| |
| void Generator::generate(Operation *op, ByteCodeWriter &writer) { |
| TypeSwitch<Operation *>(op) |
| .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp, |
| pdl_interp::AreEqualOp, pdl_interp::BranchOp, |
| pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, |
| pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, |
| pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, |
| pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp, |
| pdl_interp::CreateTypeOp, pdl_interp::EraseOp, |
| pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp, |
| pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, |
| pdl_interp::GetOperandOp, pdl_interp::GetResultOp, |
| pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp, |
| pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, |
| pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, |
| pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp, |
| pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( |
| [&](auto interpOp) { this->generate(interpOp, writer); }) |
| .Default([](Operation *) { |
| llvm_unreachable("unknown `pdl_interp` operation"); |
| }); |
| } |
| |
| void Generator::generate(pdl_interp::ApplyConstraintOp op, |
| ByteCodeWriter &writer) { |
| assert(constraintToMemIndex.count(op.name()) && |
| "expected index for constraint function"); |
| writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()], |
| op.constParamsAttr()); |
| writer.appendPDLValueList(op.args()); |
| writer.append(op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::ApplyRewriteOp op, |
| ByteCodeWriter &writer) { |
| assert(externalRewriterToMemIndex.count(op.name()) && |
| "expected index for rewrite function"); |
| writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()], |
| op.constParamsAttr(), op.root()); |
| writer.appendPDLValueList(op.args()); |
| } |
| void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { |
| writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { |
| writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); |
| } |
| void Generator::generate(pdl_interp::CheckAttributeOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(), |
| op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::CheckOperandCountOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), |
| op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::CheckOperationNameOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::CheckOperationName, op.operation(), |
| OperationName(op.name(), ctx), op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::CheckResultCountOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::CheckResultCount, op.operation(), op.count(), |
| op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { |
| writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::CreateAttributeOp op, |
| ByteCodeWriter &writer) { |
| // Simply repoint the memory index of the result to the constant. |
| getMemIndex(op.attribute()) = getMemIndex(op.value()); |
| } |
| void Generator::generate(pdl_interp::CreateNativeOp op, |
| ByteCodeWriter &writer) { |
| assert(nativeCreateToMemIndex.count(op.name()) && |
| "expected index for creation function"); |
| writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()], |
| op.result(), op.constParamsAttr()); |
| writer.appendPDLValueList(op.args()); |
| } |
| void Generator::generate(pdl_interp::CreateOperationOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::CreateOperation, op.operation(), |
| OperationName(op.name(), ctx), op.operands()); |
| |
| // Add the attributes. |
| OperandRange attributes = op.attributes(); |
| writer.append(static_cast<ByteCodeField>(attributes.size())); |
| for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { |
| writer.append( |
| Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx), |
| std::get<1>(it)); |
| } |
| writer.append(op.types()); |
| } |
| void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { |
| // Simply repoint the memory index of the result to the constant. |
| getMemIndex(op.result()) = getMemIndex(op.value()); |
| } |
| void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { |
| writer.append(OpCode::EraseOp, op.operation()); |
| } |
| void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { |
| writer.append(OpCode::Finalize); |
| } |
| void Generator::generate(pdl_interp::GetAttributeOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), |
| Identifier::get(op.name(), ctx)); |
| } |
| void Generator::generate(pdl_interp::GetAttributeTypeOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::GetAttributeType, op.result(), op.value()); |
| } |
| void Generator::generate(pdl_interp::GetDefiningOpOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); |
| } |
| void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { |
| uint32_t index = op.index(); |
| if (index < 4) |
| writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index)); |
| else |
| writer.append(OpCode::GetOperandN, index); |
| writer.append(op.operation(), op.value()); |
| } |
| void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { |
| uint32_t index = op.index(); |
| if (index < 4) |
| writer.append(static_cast<OpCode>(OpCode::GetResult0 + index)); |
| else |
| writer.append(OpCode::GetResultN, index); |
| writer.append(op.operation(), op.value()); |
| } |
| void Generator::generate(pdl_interp::GetValueTypeOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::GetValueType, op.result(), op.value()); |
| } |
| void Generator::generate(pdl_interp::InferredTypeOp op, |
| ByteCodeWriter &writer) { |
| // InferType maps to a null type as a marker for inferring a result type. |
| getMemIndex(op.type()) = getMemIndex(Type()); |
| } |
| void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { |
| writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { |
| ByteCodeField patternIndex = patterns.size(); |
| patterns.emplace_back(PDLByteCodePattern::create( |
| op, rewriterToAddr[op.rewriter().getLeafReference()])); |
| writer.append(OpCode::RecordMatch, patternIndex, |
| SuccessorRange(op.getOperation()), op.matchedOps(), |
| op.inputs()); |
| } |
| void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { |
| writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); |
| } |
| void Generator::generate(pdl_interp::SwitchAttributeOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(), |
| op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::SwitchOperandCountOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(), |
| op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::SwitchOperationNameOp op, |
| ByteCodeWriter &writer) { |
| auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) { |
| return OperationName(attr.cast<StringAttr>().getValue(), ctx); |
| }); |
| writer.append(OpCode::SwitchOperationName, op.operation(), cases, |
| op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::SwitchResultCountOp op, |
| ByteCodeWriter &writer) { |
| writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(), |
| op.getSuccessors()); |
| } |
| void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) { |
| writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), |
| op.getSuccessors()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // PDLByteCode |
| //===----------------------------------------------------------------------===// |
| |
| PDLByteCode::PDLByteCode(ModuleOp module, |
| llvm::StringMap<PDLConstraintFunction> constraintFns, |
| llvm::StringMap<PDLCreateFunction> createFns, |
| llvm::StringMap<PDLRewriteFunction> rewriteFns) { |
| Generator generator(module.getContext(), uniquedData, matcherByteCode, |
| rewriterByteCode, patterns, maxValueMemoryIndex, |
| constraintFns, createFns, rewriteFns); |
| generator.generate(module); |
| |
| // Initialize the external functions. |
| for (auto &it : constraintFns) |
| constraintFunctions.push_back(std::move(it.second)); |
| for (auto &it : createFns) |
| createFunctions.push_back(std::move(it.second)); |
| for (auto &it : rewriteFns) |
| rewriteFunctions.push_back(std::move(it.second)); |
| } |
| |
| /// Initialize the given state such that it can be used to execute the current |
| /// bytecode. |
| void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { |
| state.memory.resize(maxValueMemoryIndex, nullptr); |
| state.currentPatternBenefits.reserve(patterns.size()); |
| for (const PDLByteCodePattern &pattern : patterns) |
| state.currentPatternBenefits.push_back(pattern.getBenefit()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ByteCode Execution |
| |
| namespace { |
| /// This class provides support for executing a bytecode stream. |
| class ByteCodeExecutor { |
| public: |
| ByteCodeExecutor(const ByteCodeField *curCodeIt, |
| MutableArrayRef<const void *> memory, |
| ArrayRef<const void *> uniquedMemory, |
| ArrayRef<ByteCodeField> code, |
| ArrayRef<PatternBenefit> currentPatternBenefits, |
| ArrayRef<PDLByteCodePattern> patterns, |
| ArrayRef<PDLConstraintFunction> constraintFunctions, |
| ArrayRef<PDLCreateFunction> createFunctions, |
| ArrayRef<PDLRewriteFunction> rewriteFunctions) |
| : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), |
| code(code), currentPatternBenefits(currentPatternBenefits), |
| patterns(patterns), constraintFunctions(constraintFunctions), |
| createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {} |
| |
| /// Start executing the code at the current bytecode index. `matches` is an |
| /// optional field provided when this function is executed in a matching |
| /// context. |
| void execute(PatternRewriter &rewriter, |
| SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, |
| Optional<Location> mainRewriteLoc = {}); |
| |
| private: |
| /// Read a value from the bytecode buffer, optionally skipping a certain |
| /// number of prefix values. These methods always update the buffer to point |
| /// to the next field after the read data. |
| template <typename T = ByteCodeField> |
| T read(size_t skipN = 0) { |
| curCodeIt += skipN; |
| return readImpl<T>(); |
| } |
| ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); } |
| |
| /// Read a list of values from the bytecode buffer. |
| template <typename ValueT, typename T> |
| void readList(SmallVectorImpl<T> &list) { |
| list.clear(); |
| for (unsigned i = 0, e = read(); i != e; ++i) |
| list.push_back(read<ValueT>()); |
| } |
| |
| /// Jump to a specific successor based on a predicate value. |
| void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } |
| /// Jump to a specific successor based on a destination index. |
| void selectJump(size_t destIndex) { |
| curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)]; |
| } |
| |
| /// Handle a switch operation with the provided value and cases. |
| template <typename T, typename RangeT> |
| void handleSwitch(const T &value, RangeT &&cases) { |
| LLVM_DEBUG({ |
| llvm::dbgs() << " * Value: " << value << "\n" |
| << " * Cases: "; |
| llvm::interleaveComma(cases, llvm::dbgs()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| |
| // Check to see if the attribute value is within the case list. Jump to |
| // the correct successor index based on the result. |
| for (auto it = cases.begin(), e = cases.end(); it != e; ++it) |
| if (*it == value) |
| return selectJump(size_t((it - cases.begin()) + 1)); |
| selectJump(size_t(0)); |
| } |
| |
| /// Internal implementation of reading various data types from the bytecode |
| /// stream. |
| template <typename T> |
| const void *readFromMemory() { |
| size_t index = *curCodeIt++; |
| |
| // If this type is an SSA value, it can only be stored in non-const memory. |
| if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size()) |
| return memory[index]; |
| |
| // Otherwise, if this index is not inbounds it is uniqued. |
| return uniquedMemory[index - memory.size()]; |
| } |
| template <typename T> |
| std::enable_if_t<std::is_pointer<T>::value, T> readImpl() { |
| return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>())); |
| } |
| template <typename T> |
| std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value, |
| T> |
| readImpl() { |
| return T(T::getFromOpaquePointer(readFromMemory<T>())); |
| } |
| template <typename T> |
| std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() { |
| switch (static_cast<PDLValueKind>(read())) { |
| case PDLValueKind::Attribute: |
| return read<Attribute>(); |
| case PDLValueKind::Operation: |
| return read<Operation *>(); |
| case PDLValueKind::Type: |
| return read<Type>(); |
| case PDLValueKind::Value: |
| return read<Value>(); |
| } |
| llvm_unreachable("unhandled PDLValueKind"); |
| } |
| template <typename T> |
| std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() { |
| static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, |
| "unexpected ByteCode address size"); |
| ByteCodeAddr result; |
| std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); |
| curCodeIt += 2; |
| return result; |
| } |
| template <typename T> |
| std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() { |
| return *curCodeIt++; |
| } |
| |
| /// The underlying bytecode buffer. |
| const ByteCodeField *curCodeIt; |
| |
| /// The current execution memory. |
| MutableArrayRef<const void *> memory; |
| |
| /// References to ByteCode data necessary for execution. |
| ArrayRef<const void *> uniquedMemory; |
| ArrayRef<ByteCodeField> code; |
| ArrayRef<PatternBenefit> currentPatternBenefits; |
| ArrayRef<PDLByteCodePattern> patterns; |
| ArrayRef<PDLConstraintFunction> constraintFunctions; |
| ArrayRef<PDLCreateFunction> createFunctions; |
| ArrayRef<PDLRewriteFunction> rewriteFunctions; |
| }; |
| } // end anonymous namespace |
| |
| void ByteCodeExecutor::execute( |
| PatternRewriter &rewriter, |
| SmallVectorImpl<PDLByteCode::MatchResult> *matches, |
| Optional<Location> mainRewriteLoc) { |
| while (true) { |
| OpCode opCode = static_cast<OpCode>(read()); |
| switch (opCode) { |
| case ApplyConstraint: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); |
| const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; |
| ArrayAttr constParams = read<ArrayAttr>(); |
| SmallVector<PDLValue, 16> args; |
| readList<PDLValue>(args); |
| LLVM_DEBUG({ |
| llvm::dbgs() << " * Arguments: "; |
| llvm::interleaveComma(args, llvm::dbgs()); |
| llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; |
| }); |
| |
| // Invoke the constraint and jump to the proper destination. |
| selectJump(succeeded(constraintFn(args, constParams, rewriter))); |
| break; |
| } |
| case ApplyRewrite: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); |
| const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; |
| ArrayAttr constParams = read<ArrayAttr>(); |
| Operation *root = read<Operation *>(); |
| SmallVector<PDLValue, 16> args; |
| readList<PDLValue>(args); |
| |
| LLVM_DEBUG({ |
| llvm::dbgs() << " * Root: " << *root << "\n" |
| << " * Arguments: "; |
| llvm::interleaveComma(args, llvm::dbgs()); |
| llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; |
| }); |
| rewriteFn(root, args, constParams, rewriter); |
| break; |
| } |
| case AreEqual: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); |
| const void *lhs = read<const void *>(); |
| const void *rhs = read<const void *>(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); |
| selectJump(lhs == rhs); |
| break; |
| } |
| case Branch: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n"); |
| curCodeIt = &code[read<ByteCodeAddr>()]; |
| break; |
| } |
| case CheckOperandCount: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); |
| Operation *op = read<Operation *>(); |
| uint32_t expectedCount = read<uint32_t>(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" |
| << " * Expected: " << expectedCount << "\n\n"); |
| selectJump(op->getNumOperands() == expectedCount); |
| break; |
| } |
| case CheckOperationName: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); |
| Operation *op = read<Operation *>(); |
| OperationName expectedName = read<OperationName>(); |
| |
| LLVM_DEBUG(llvm::dbgs() |
| << " * Found: \"" << op->getName() << "\"\n" |
| << " * Expected: \"" << expectedName << "\"\n\n"); |
| selectJump(op->getName() == expectedName); |
| break; |
| } |
| case CheckResultCount: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); |
| Operation *op = read<Operation *>(); |
| uint32_t expectedCount = read<uint32_t>(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" |
| << " * Expected: " << expectedCount << "\n\n"); |
| selectJump(op->getNumResults() == expectedCount); |
| break; |
| } |
| case CreateNative: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); |
| const PDLCreateFunction &createFn = createFunctions[read()]; |
| ByteCodeField resultIndex = read(); |
| ArrayAttr constParams = read<ArrayAttr>(); |
| SmallVector<PDLValue, 16> args; |
| readList<PDLValue>(args); |
| |
| LLVM_DEBUG({ |
| llvm::dbgs() << " * Arguments: "; |
| llvm::interleaveComma(args, llvm::dbgs()); |
| llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; |
| }); |
| |
| PDLValue result = createFn(args, constParams, rewriter); |
| memory[resultIndex] = result.getAsOpaquePointer(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n"); |
| break; |
| } |
| case CreateOperation: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); |
| assert(mainRewriteLoc && "expected rewrite loc to be provided when " |
| "executing the rewriter bytecode"); |
| |
| unsigned memIndex = read(); |
| OperationState state(*mainRewriteLoc, read<OperationName>()); |
| readList<Value>(state.operands); |
| for (unsigned i = 0, e = read(); i != e; ++i) { |
| Identifier name = read<Identifier>(); |
| if (Attribute attr = read<Attribute>()) |
| state.addAttribute(name, attr); |
| } |
| |
| bool hasInferredTypes = false; |
| for (unsigned i = 0, e = read(); i != e; ++i) { |
| Type resultType = read<Type>(); |
| hasInferredTypes |= !resultType; |
| state.types.push_back(resultType); |
| } |
| |
| // Handle the case where the operation has inferred types. |
| if (hasInferredTypes) { |
| InferTypeOpInterface::Concept *concept = |
| state.name.getAbstractOperation() |
| ->getInterface<InferTypeOpInterface>(); |
| |
| // TODO: Handle failure. |
| SmallVector<Type, 2> inferredTypes; |
| if (failed(concept->inferReturnTypes( |
| state.getContext(), state.location, state.operands, |
| state.attributes.getDictionary(state.getContext()), |
| state.regions, inferredTypes))) |
| return; |
| |
| for (unsigned i = 0, e = state.types.size(); i != e; ++i) |
| if (!state.types[i]) |
| state.types[i] = inferredTypes[i]; |
| } |
| Operation *resultOp = rewriter.createOperation(state); |
| memory[memIndex] = resultOp; |
| |
| LLVM_DEBUG({ |
| llvm::dbgs() << " * Attributes: " |
| << state.attributes.getDictionary(state.getContext()) |
| << "\n * Operands: "; |
| llvm::interleaveComma(state.operands, llvm::dbgs()); |
| llvm::dbgs() << "\n * Result Types: "; |
| llvm::interleaveComma(state.types, llvm::dbgs()); |
| llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n"; |
| }); |
| break; |
| } |
| case EraseOp: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); |
| Operation *op = read<Operation *>(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n"); |
| rewriter.eraseOp(op); |
| break; |
| } |
| case Finalize: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); |
| return; |
| } |
| case GetAttribute: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); |
| unsigned memIndex = read(); |
| Operation *op = read<Operation *>(); |
| Identifier attrName = read<Identifier>(); |
| Attribute attr = op->getAttr(attrName); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" |
| << " * Attribute: " << attrName << "\n" |
| << " * Result: " << attr << "\n\n"); |
| memory[memIndex] = attr.getAsOpaquePointer(); |
| break; |
| } |
| case GetAttributeType: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); |
| unsigned memIndex = read(); |
| Attribute attr = read<Attribute>(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" |
| << " * Result: " << attr.getType() << "\n\n"); |
| memory[memIndex] = attr.getType().getAsOpaquePointer(); |
| break; |
| } |
| case GetDefiningOp: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); |
| unsigned memIndex = read(); |
| Value value = read<Value>(); |
| Operation *op = value ? value.getDefiningOp() : nullptr; |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" |
| << " * Result: " << *op << "\n\n"); |
| memory[memIndex] = op; |
| break; |
| } |
| case GetOperand0: |
| case GetOperand1: |
| case GetOperand2: |
| case GetOperand3: |
| case GetOperandN: { |
| LLVM_DEBUG({ |
| llvm::dbgs() << "Executing GetOperand" |
| << (opCode == GetOperandN ? Twine("N") |
| : Twine(opCode - GetOperand0)) |
| << ":\n"; |
| }); |
| unsigned index = |
| opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0); |
| Operation *op = read<Operation *>(); |
| unsigned memIndex = read(); |
| Value operand = |
| index < op->getNumOperands() ? op->getOperand(index) : Value(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" |
| << " * Index: " << index << "\n" |
| << " * Result: " << operand << "\n\n"); |
| memory[memIndex] = operand.getAsOpaquePointer(); |
| break; |
| } |
| case GetResult0: |
| case GetResult1: |
| case GetResult2: |
| case GetResult3: |
| case GetResultN: { |
| LLVM_DEBUG({ |
| llvm::dbgs() << "Executing GetResult" |
| << (opCode == GetResultN ? Twine("N") |
| : Twine(opCode - GetResult0)) |
| << ":\n"; |
| }); |
| unsigned index = |
| opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0); |
| Operation *op = read<Operation *>(); |
| unsigned memIndex = read(); |
| OpResult result = |
| index < op->getNumResults() ? op->getResult(index) : OpResult(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" |
| << " * Index: " << index << "\n" |
| << " * Result: " << result << "\n\n"); |
| memory[memIndex] = result.getAsOpaquePointer(); |
| break; |
| } |
| case GetValueType: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); |
| unsigned memIndex = read(); |
| Value value = read<Value>(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" |
| << " * Result: " << value.getType() << "\n\n"); |
| memory[memIndex] = value.getType().getAsOpaquePointer(); |
| break; |
| } |
| case IsNotNull: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); |
| const void *value = read<const void *>(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n"); |
| selectJump(value != nullptr); |
| break; |
| } |
| case RecordMatch: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); |
| assert(matches && |
| "expected matches to be provided when executing the matcher"); |
| unsigned patternIndex = read(); |
| PatternBenefit benefit = currentPatternBenefits[patternIndex]; |
| const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; |
| |
| // If the benefit of the pattern is impossible, skip the processing of the |
| // rest of the pattern. |
| if (benefit.isImpossibleToMatch()) { |
| LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n"); |
| curCodeIt = dest; |
| break; |
| } |
| |
| // Create a fused location containing the locations of each of the |
| // operations used in the match. This will be used as the location for |
| // created operations during the rewrite that don't already have an |
| // explicit location set. |
| unsigned numMatchLocs = read(); |
| SmallVector<Location, 4> matchLocs; |
| matchLocs.reserve(numMatchLocs); |
| for (unsigned i = 0; i != numMatchLocs; ++i) |
| matchLocs.push_back(read<Operation *>()->getLoc()); |
| Location matchLoc = rewriter.getFusedLoc(matchLocs); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" |
| << " * Location: " << matchLoc << "\n\n"); |
| matches->emplace_back(matchLoc, patterns[patternIndex], benefit); |
| readList<const void *>(matches->back().values); |
| curCodeIt = dest; |
| break; |
| } |
| case ReplaceOp: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); |
| Operation *op = read<Operation *>(); |
| SmallVector<Value, 16> args; |
| readList<Value>(args); |
| |
| LLVM_DEBUG({ |
| llvm::dbgs() << " * Operation: " << *op << "\n" |
| << " * Values: "; |
| llvm::interleaveComma(args, llvm::dbgs()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| rewriter.replaceOp(op, args); |
| break; |
| } |
| case SwitchAttribute: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); |
| Attribute value = read<Attribute>(); |
| ArrayAttr cases = read<ArrayAttr>(); |
| handleSwitch(value, cases); |
| break; |
| } |
| case SwitchOperandCount: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); |
| Operation *op = read<Operation *>(); |
| auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); |
| handleSwitch(op->getNumOperands(), cases); |
| break; |
| } |
| case SwitchOperationName: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); |
| OperationName value = read<Operation *>()->getName(); |
| size_t caseCount = read(); |
| |
| // The operation names are stored in-line, so to print them out for |
| // debugging purposes we need to read the array before executing the |
| // switch so that we can display all of the possible values. |
| LLVM_DEBUG({ |
| const ByteCodeField *prevCodeIt = curCodeIt; |
| llvm::dbgs() << " * Value: " << value << "\n" |
| << " * Cases: "; |
| llvm::interleaveComma( |
| llvm::map_range(llvm::seq<size_t>(0, caseCount), |
| [&](size_t i) { return read<OperationName>(); }), |
| llvm::dbgs()); |
| llvm::dbgs() << "\n\n"; |
| curCodeIt = prevCodeIt; |
| }); |
| |
| // Try to find the switch value within any of the cases. |
| size_t jumpDest = 0; |
| for (size_t i = 0; i != caseCount; ++i) { |
| if (read<OperationName>() == value) { |
| curCodeIt += (caseCount - i - 1); |
| jumpDest = i + 1; |
| break; |
| } |
| } |
| selectJump(jumpDest); |
| break; |
| } |
| case SwitchResultCount: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); |
| Operation *op = read<Operation *>(); |
| auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); |
| |
| LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); |
| handleSwitch(op->getNumResults(), cases); |
| break; |
| } |
| case SwitchType: { |
| LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); |
| Type value = read<Type>(); |
| auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); |
| handleSwitch(value, cases); |
| break; |
| } |
| } |
| } |
| } |
| |
| /// Run the pattern matcher on the given root operation, collecting the matched |
| /// patterns in `matches`. |
| void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, |
| SmallVectorImpl<MatchResult> &matches, |
| PDLByteCodeMutableState &state) const { |
| // The first memory slot is always the root operation. |
| state.memory[0] = op; |
| |
| // The matcher function always starts at code address 0. |
| ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, |
| matcherByteCode, state.currentPatternBenefits, |
| patterns, constraintFunctions, createFunctions, |
| rewriteFunctions); |
| executor.execute(rewriter, &matches); |
| |
| // Order the found matches by benefit. |
| std::stable_sort(matches.begin(), matches.end(), |
| [](const MatchResult &lhs, const MatchResult &rhs) { |
| return lhs.benefit > rhs.benefit; |
| }); |
| } |
| |
| /// Run the rewriter of the given pattern on the root operation `op`. |
| void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, |
| PDLByteCodeMutableState &state) const { |
| // The arguments of the rewrite function are stored at the start of the |
| // memory buffer. |
| llvm::copy(match.values, state.memory.begin()); |
| |
| ByteCodeExecutor executor( |
| &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, |
| uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, |
| constraintFunctions, createFunctions, rewriteFunctions); |
| executor.execute(rewriter, /*matches=*/nullptr, match.location); |
| } |