| //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" |
| #include "../PassDetail.h" |
| #include "PredicateTree.h" |
| #include "mlir/Dialect/PDL/IR/PDL.h" |
| #include "mlir/Dialect/PDL/IR/PDLTypes.h" |
| #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" |
| #include "mlir/Pass/Pass.h" |
| #include "llvm/ADT/MapVector.h" |
| #include "llvm/ADT/ScopedHashTable.h" |
| #include "llvm/ADT/Sequence.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| |
| using namespace mlir; |
| using namespace mlir::pdl_to_pdl_interp; |
| |
| //===----------------------------------------------------------------------===// |
| // PatternLowering |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This class generators operations within the PDL Interpreter dialect from a |
| /// given module containing PDL pattern operations. |
| struct PatternLowering { |
| public: |
| PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule); |
| |
| /// Generate code for matching and rewriting based on the pattern operations |
| /// within the module. |
| void lower(ModuleOp module); |
| |
| private: |
| using ValueMap = llvm::ScopedHashTable<Position *, Value>; |
| using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>; |
| |
| /// Generate interpreter operations for the tree rooted at the given matcher |
| /// node, in the specified region. |
| Block *generateMatcher(MatcherNode &node, Region ®ion); |
| |
| /// Get or create an access to the provided positional value in the current |
| /// block. This operation may mutate the provided block pointer if nested |
| /// regions (i.e., pdl_interp.iterate) are required. |
| Value getValueAt(Block *¤tBlock, Position *pos); |
| |
| /// Create the interpreter predicate operations. This operation may mutate the |
| /// provided current block pointer if nested regions (iterates) are required. |
| void generate(BoolNode *boolNode, Block *¤tBlock, Value val); |
| |
| /// Create the interpreter switch / predicate operations, with several case |
| /// destinations. This operation never mutates the provided current block |
| /// pointer, because the switch operation does not need Values beyond `val`. |
| void generate(SwitchNode *switchNode, Block *currentBlock, Value val); |
| |
| /// Create the interpreter operations to record a successful pattern match |
| /// using the contained root operation. This operation may mutate the current |
| /// block pointer if nested regions (i.e., pdl_interp.iterate) are required. |
| void generate(SuccessNode *successNode, Block *¤tBlock); |
| |
| /// Generate a rewriter function for the given pattern operation, and returns |
| /// a reference to that function. |
| SymbolRefAttr generateRewriter(pdl::PatternOp pattern, |
| SmallVectorImpl<Position *> &usedMatchValues); |
| |
| /// Generate the rewriter code for the given operation. |
| void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue); |
| void generateRewriter(pdl::AttributeOp attrOp, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue); |
| void generateRewriter(pdl::EraseOp eraseOp, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue); |
| void generateRewriter(pdl::OperationOp operationOp, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue); |
| void generateRewriter(pdl::ReplaceOp replaceOp, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue); |
| void generateRewriter(pdl::ResultOp resultOp, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue); |
| void generateRewriter(pdl::ResultsOp resultOp, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue); |
| void generateRewriter(pdl::TypeOp typeOp, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue); |
| void generateRewriter(pdl::TypesOp typeOp, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue); |
| |
| /// Generate the values used for resolving the result types of an operation |
| /// created within a dag rewriter region. |
| void generateOperationResultTypeRewriter( |
| pdl::OperationOp op, SmallVectorImpl<Value> &types, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue); |
| |
| /// A builder to use when generating interpreter operations. |
| OpBuilder builder; |
| |
| /// The matcher function used for all match related logic within PDL patterns. |
| FuncOp matcherFunc; |
| |
| /// The rewriter module containing the all rewrite related logic within PDL |
| /// patterns. |
| ModuleOp rewriterModule; |
| |
| /// The symbol table of the rewriter module used for insertion. |
| SymbolTable rewriterSymbolTable; |
| |
| /// A scoped map connecting a position with the corresponding interpreter |
| /// value. |
| ValueMap values; |
| |
| /// A stack of blocks used as the failure destination for matcher nodes that |
| /// don't have an explicit failure path. |
| SmallVector<Block *, 8> failureBlockStack; |
| |
| /// A mapping between values defined in a pattern match, and the corresponding |
| /// positional value. |
| DenseMap<Value, Position *> valueToPosition; |
| |
| /// The set of operation values whose whose location will be used for newly |
| /// generated operations. |
| SetVector<Value> locOps; |
| }; |
| } // end anonymous namespace |
| |
| PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule) |
| : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), |
| rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} |
| |
| void PatternLowering::lower(ModuleOp module) { |
| PredicateUniquer predicateUniquer; |
| PredicateBuilder predicateBuilder(predicateUniquer, module.getContext()); |
| |
| // Define top-level scope for the arguments to the matcher function. |
| ValueMapScope topLevelValueScope(values); |
| |
| // Insert the root operation, i.e. argument to the matcher, at the root |
| // position. |
| Block *matcherEntryBlock = matcherFunc.addEntryBlock(); |
| values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0)); |
| |
| // Generate a root matcher node from the provided PDL module. |
| std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree( |
| module, predicateBuilder, valueToPosition); |
| Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody()); |
| assert(failureBlockStack.empty() && "failed to empty the stack"); |
| |
| // After generation, merged the first matched block into the entry. |
| matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(), |
| firstMatcherBlock->getOperations()); |
| firstMatcherBlock->erase(); |
| } |
| |
| Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { |
| // Push a new scope for the values used by this matcher. |
| Block *block = ®ion.emplaceBlock(); |
| ValueMapScope scope(values); |
| |
| // If this is the return node, simply insert the corresponding interpreter |
| // finalize. |
| if (isa<ExitNode>(node)) { |
| builder.setInsertionPointToEnd(block); |
| builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc()); |
| return block; |
| } |
| |
| // Get the next block in the match sequence. |
| // This is intentionally executed first, before we get the value for the |
| // position associated with the node, so that we preserve an "there exist" |
| // semantics: if getting a value requires an upward traversal (going from a |
| // value to its consumers), we want to perform the check on all the consumers |
| // before we pass control to the failure node. |
| std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode(); |
| Block *failureBlock; |
| if (failureNode) { |
| failureBlock = generateMatcher(*failureNode, region); |
| failureBlockStack.push_back(failureBlock); |
| } else { |
| assert(!failureBlockStack.empty() && "expected valid failure block"); |
| failureBlock = failureBlockStack.back(); |
| } |
| |
| // If this node contains a position, get the corresponding value for this |
| // block. |
| Block *currentBlock = block; |
| Position *position = node.getPosition(); |
| Value val = position ? getValueAt(currentBlock, position) : Value(); |
| |
| // If this value corresponds to an operation, record that we are going to use |
| // its location as part of a fused location. |
| bool isOperationValue = val && val.getType().isa<pdl::OperationType>(); |
| if (isOperationValue) |
| locOps.insert(val); |
| |
| // Dispatch to the correct method based on derived node type. |
| TypeSwitch<MatcherNode *>(&node) |
| .Case<BoolNode, SwitchNode>([&](auto *derivedNode) { |
| this->generate(derivedNode, currentBlock, val); |
| }) |
| .Case([&](SuccessNode *successNode) { |
| generate(successNode, currentBlock); |
| }); |
| |
| // Pop all the failure blocks that were inserted due to nesting of |
| // pdl_interp.iterate. |
| while (failureBlockStack.back() != failureBlock) { |
| failureBlockStack.pop_back(); |
| assert(!failureBlockStack.empty() && "unable to locate failure block"); |
| } |
| |
| // Pop the new failure block. |
| if (failureNode) |
| failureBlockStack.pop_back(); |
| |
| if (isOperationValue) |
| locOps.remove(val); |
| |
| return block; |
| } |
| |
| Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { |
| if (Value val = values.lookup(pos)) |
| return val; |
| |
| // Get the value for the parent position. |
| Value parentVal = getValueAt(currentBlock, pos->getParent()); |
| |
| // TODO: Use a location from the position. |
| Location loc = parentVal.getLoc(); |
| builder.setInsertionPointToEnd(currentBlock); |
| Value value; |
| switch (pos->getKind()) { |
| case Predicates::OperationPos: { |
| auto *operationPos = cast<OperationPosition>(pos); |
| if (!operationPos->isUpward()) { |
| // Standard (downward) traversal which directly follows the defining op. |
| value = builder.create<pdl_interp::GetDefiningOpOp>( |
| loc, builder.getType<pdl::OperationType>(), parentVal); |
| break; |
| } |
| |
| // The first operation retrieves the representative value of a range. |
| // This applies only when the parent is a range of values. |
| if (parentVal.getType().isa<pdl::RangeType>()) |
| value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0); |
| else |
| value = parentVal; |
| |
| // The second operation retrieves the users. |
| value = builder.create<pdl_interp::GetUsersOp>(loc, value); |
| |
| // The third operation iterates over them. |
| assert(!failureBlockStack.empty() && "expected valid failure block"); |
| auto foreach = builder.create<pdl_interp::ForEachOp>( |
| loc, value, failureBlockStack.back(), /*initLoop=*/true); |
| value = foreach.getLoopVariable(); |
| |
| // Create the success and continuation blocks. |
| Block *successBlock = builder.createBlock(&foreach.region()); |
| Block *continueBlock = builder.createBlock(successBlock); |
| builder.create<pdl_interp::ContinueOp>(loc); |
| failureBlockStack.push_back(continueBlock); |
| |
| // The fourth operation extracts the operand(s) of the user at the specified |
| // index (which can be None, indicating all operands). |
| builder.setInsertionPointToStart(&foreach.region().front()); |
| Value operands = builder.create<pdl_interp::GetOperandsOp>( |
| loc, parentVal.getType(), value, operationPos->getIndex()); |
| |
| // The fifth operation compares the operands to the parent value / range. |
| builder.create<pdl_interp::AreEqualOp>(loc, parentVal, operands, |
| successBlock, continueBlock); |
| currentBlock = successBlock; |
| break; |
| } |
| case Predicates::OperandPos: { |
| auto *operandPos = cast<OperandPosition>(pos); |
| value = builder.create<pdl_interp::GetOperandOp>( |
| loc, builder.getType<pdl::ValueType>(), parentVal, |
| operandPos->getOperandNumber()); |
| break; |
| } |
| case Predicates::OperandGroupPos: { |
| auto *operandPos = cast<OperandGroupPosition>(pos); |
| Type valueTy = builder.getType<pdl::ValueType>(); |
| value = builder.create<pdl_interp::GetOperandsOp>( |
| loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, |
| parentVal, operandPos->getOperandGroupNumber()); |
| break; |
| } |
| case Predicates::AttributePos: { |
| auto *attrPos = cast<AttributePosition>(pos); |
| value = builder.create<pdl_interp::GetAttributeOp>( |
| loc, builder.getType<pdl::AttributeType>(), parentVal, |
| attrPos->getName().strref()); |
| break; |
| } |
| case Predicates::TypePos: { |
| if (parentVal.getType().isa<pdl::AttributeType>()) |
| value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal); |
| else |
| value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal); |
| break; |
| } |
| case Predicates::ResultPos: { |
| auto *resPos = cast<ResultPosition>(pos); |
| value = builder.create<pdl_interp::GetResultOp>( |
| loc, builder.getType<pdl::ValueType>(), parentVal, |
| resPos->getResultNumber()); |
| break; |
| } |
| case Predicates::ResultGroupPos: { |
| auto *resPos = cast<ResultGroupPosition>(pos); |
| Type valueTy = builder.getType<pdl::ValueType>(); |
| value = builder.create<pdl_interp::GetResultsOp>( |
| loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy, |
| parentVal, resPos->getResultGroupNumber()); |
| break; |
| } |
| default: |
| llvm_unreachable("Generating unknown Position getter"); |
| break; |
| } |
| |
| values.insert(pos, value); |
| return value; |
| } |
| |
| void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, |
| Value val) { |
| Location loc = val.getLoc(); |
| Qualifier *question = boolNode->getQuestion(); |
| Qualifier *answer = boolNode->getAnswer(); |
| Region *region = currentBlock->getParent(); |
| |
| // Execute the getValue queries first, so that we create success |
| // matcher in the correct (possibly nested) region. |
| SmallVector<Value> args; |
| if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) { |
| args = {getValueAt(currentBlock, equalToQuestion->getValue())}; |
| } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) { |
| for (Position *position : std::get<1>(cstQuestion->getValue())) |
| args.push_back(getValueAt(currentBlock, position)); |
| } |
| |
| // Generate the matcher in the current (potentially nested) region |
| // and get the failure successor. |
| Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); |
| Block *failure = failureBlockStack.back(); |
| |
| // Finally, create the predicate. |
| builder.setInsertionPointToEnd(currentBlock); |
| Predicates::Kind kind = question->getKind(); |
| switch (kind) { |
| case Predicates::IsNotNullQuestion: |
| builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure); |
| break; |
| case Predicates::OperationNameQuestion: { |
| auto *opNameAnswer = cast<OperationNameAnswer>(answer); |
| builder.create<pdl_interp::CheckOperationNameOp>( |
| loc, val, opNameAnswer->getValue().getStringRef(), success, failure); |
| break; |
| } |
| case Predicates::TypeQuestion: { |
| auto *ans = cast<TypeAnswer>(answer); |
| if (val.getType().isa<pdl::RangeType>()) |
| builder.create<pdl_interp::CheckTypesOp>( |
| loc, val, ans->getValue().cast<ArrayAttr>(), success, failure); |
| else |
| builder.create<pdl_interp::CheckTypeOp>( |
| loc, val, ans->getValue().cast<TypeAttr>(), success, failure); |
| break; |
| } |
| case Predicates::AttributeQuestion: { |
| auto *ans = cast<AttributeAnswer>(answer); |
| builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(), |
| success, failure); |
| break; |
| } |
| case Predicates::OperandCountAtLeastQuestion: |
| case Predicates::OperandCountQuestion: |
| builder.create<pdl_interp::CheckOperandCountOp>( |
| loc, val, cast<UnsignedAnswer>(answer)->getValue(), |
| /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion, |
| success, failure); |
| break; |
| case Predicates::ResultCountAtLeastQuestion: |
| case Predicates::ResultCountQuestion: |
| builder.create<pdl_interp::CheckResultCountOp>( |
| loc, val, cast<UnsignedAnswer>(answer)->getValue(), |
| /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion, |
| success, failure); |
| break; |
| case Predicates::EqualToQuestion: { |
| bool trueAnswer = isa<TrueAnswer>(answer); |
| builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(), |
| trueAnswer ? success : failure, |
| trueAnswer ? failure : success); |
| break; |
| } |
| case Predicates::ConstraintQuestion: { |
| auto value = cast<ConstraintQuestion>(question)->getValue(); |
| builder.create<pdl_interp::ApplyConstraintOp>( |
| loc, std::get<0>(value), args, std::get<2>(value).cast<ArrayAttr>(), |
| success, failure); |
| break; |
| } |
| default: |
| llvm_unreachable("Generating unknown Predicate operation"); |
| } |
| } |
| |
| template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> |
| static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, |
| llvm::MapVector<Qualifier *, Block *> &dests) { |
| std::vector<ValT> values; |
| std::vector<Block *> blocks; |
| values.reserve(dests.size()); |
| blocks.reserve(dests.size()); |
| for (const auto &it : dests) { |
| blocks.push_back(it.second); |
| values.push_back(cast<PredT>(it.first)->getValue()); |
| } |
| builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks); |
| } |
| |
| void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock, |
| Value val) { |
| Qualifier *question = switchNode->getQuestion(); |
| Region *region = currentBlock->getParent(); |
| Block *defaultDest = failureBlockStack.back(); |
| |
| // If the switch question is not an exact answer, i.e. for the `at_least` |
| // cases, we generate a special block sequence. |
| Predicates::Kind kind = question->getKind(); |
| if (kind == Predicates::OperandCountAtLeastQuestion || |
| kind == Predicates::ResultCountAtLeastQuestion) { |
| // Order the children such that the cases are in reverse numerical order. |
| SmallVector<unsigned> sortedChildren = llvm::to_vector<16>( |
| llvm::seq<unsigned>(0, switchNode->getChildren().size())); |
| llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) { |
| return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() > |
| cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue(); |
| }); |
| |
| // Build the destination for each child using the next highest child as a |
| // a failure destination. This essentially creates the following control |
| // flow: |
| // |
| // if (operand_count < 1) |
| // goto failure |
| // if (child1.match()) |
| // ... |
| // |
| // if (operand_count < 2) |
| // goto failure |
| // if (child2.match()) |
| // ... |
| // |
| // failure: |
| // ... |
| // |
| failureBlockStack.push_back(defaultDest); |
| Location loc = val.getLoc(); |
| for (unsigned idx : sortedChildren) { |
| auto &child = switchNode->getChild(idx); |
| Block *childBlock = generateMatcher(*child.second, *region); |
| Block *predicateBlock = builder.createBlock(childBlock); |
| builder.setInsertionPointToEnd(predicateBlock); |
| unsigned ans = cast<UnsignedAnswer>(child.first)->getValue(); |
| switch (kind) { |
| case Predicates::OperandCountAtLeastQuestion: |
| builder.create<pdl_interp::CheckOperandCountOp>( |
| loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); |
| break; |
| case Predicates::ResultCountAtLeastQuestion: |
| builder.create<pdl_interp::CheckResultCountOp>( |
| loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest); |
| break; |
| default: |
| llvm_unreachable("Generating invalid AtLeast operation"); |
| } |
| failureBlockStack.back() = predicateBlock; |
| } |
| Block *firstPredicateBlock = failureBlockStack.pop_back_val(); |
| currentBlock->getOperations().splice(currentBlock->end(), |
| firstPredicateBlock->getOperations()); |
| firstPredicateBlock->erase(); |
| return; |
| } |
| |
| // Otherwise, generate each of the children and generate an interpreter |
| // switch. |
| llvm::MapVector<Qualifier *, Block *> children; |
| for (auto &it : switchNode->getChildren()) |
| children.insert({it.first, generateMatcher(*it.second, *region)}); |
| builder.setInsertionPointToEnd(currentBlock); |
| |
| switch (question->getKind()) { |
| case Predicates::OperandCountQuestion: |
| return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer, |
| int32_t>(val, defaultDest, builder, children); |
| case Predicates::ResultCountQuestion: |
| return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer, |
| int32_t>(val, defaultDest, builder, children); |
| case Predicates::OperationNameQuestion: |
| return createSwitchOp<pdl_interp::SwitchOperationNameOp, |
| OperationNameAnswer>(val, defaultDest, builder, |
| children); |
| case Predicates::TypeQuestion: |
| if (val.getType().isa<pdl::RangeType>()) { |
| return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>( |
| val, defaultDest, builder, children); |
| } |
| return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>( |
| val, defaultDest, builder, children); |
| case Predicates::AttributeQuestion: |
| return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>( |
| val, defaultDest, builder, children); |
| default: |
| llvm_unreachable("Generating unknown switch predicate."); |
| } |
| } |
| |
| void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) { |
| pdl::PatternOp pattern = successNode->getPattern(); |
| Value root = successNode->getRoot(); |
| |
| // Generate a rewriter for the pattern this success node represents, and track |
| // any values used from the match region. |
| SmallVector<Position *, 8> usedMatchValues; |
| SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues); |
| |
| // Process any values used in the rewrite that are defined in the match. |
| std::vector<Value> mappedMatchValues; |
| mappedMatchValues.reserve(usedMatchValues.size()); |
| for (Position *position : usedMatchValues) |
| mappedMatchValues.push_back(getValueAt(currentBlock, position)); |
| |
| // Collect the set of operations generated by the rewriter. |
| SmallVector<StringRef, 4> generatedOps; |
| for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>()) |
| generatedOps.push_back(*op.name()); |
| ArrayAttr generatedOpsAttr; |
| if (!generatedOps.empty()) |
| generatedOpsAttr = builder.getStrArrayAttr(generatedOps); |
| |
| // Grab the root kind if present. |
| StringAttr rootKindAttr; |
| if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>()) |
| if (Optional<StringRef> rootKind = rootOp.name()) |
| rootKindAttr = builder.getStringAttr(*rootKind); |
| |
| builder.setInsertionPointToEnd(currentBlock); |
| builder.create<pdl_interp::RecordMatchOp>( |
| pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), |
| rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(), |
| failureBlockStack.back()); |
| } |
| |
| SymbolRefAttr PatternLowering::generateRewriter( |
| pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) { |
| FuncOp rewriterFunc = |
| FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter", |
| builder.getFunctionType(llvm::None, llvm::None)); |
| rewriterSymbolTable.insert(rewriterFunc); |
| |
| // Generate the rewriter function body. |
| builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock()); |
| |
| // Map an input operand of the pattern to a generated interpreter value. |
| DenseMap<Value, Value> rewriteValues; |
| auto mapRewriteValue = [&](Value oldValue) { |
| Value &newValue = rewriteValues[oldValue]; |
| if (newValue) |
| return newValue; |
| |
| // Prefer materializing constants directly when possible. |
| Operation *oldOp = oldValue.getDefiningOp(); |
| if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) { |
| if (Attribute value = attrOp.valueAttr()) { |
| return newValue = builder.create<pdl_interp::CreateAttributeOp>( |
| attrOp.getLoc(), value); |
| } |
| } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) { |
| if (TypeAttr type = typeOp.typeAttr()) { |
| return newValue = builder.create<pdl_interp::CreateTypeOp>( |
| typeOp.getLoc(), type); |
| } |
| } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) { |
| if (ArrayAttr type = typeOp.typesAttr()) { |
| return newValue = builder.create<pdl_interp::CreateTypesOp>( |
| typeOp.getLoc(), typeOp.getType(), type); |
| } |
| } |
| |
| // Otherwise, add this as an input to the rewriter. |
| Position *inputPos = valueToPosition.lookup(oldValue); |
| assert(inputPos && "expected value to be a pattern input"); |
| usedMatchValues.push_back(inputPos); |
| return newValue = rewriterFunc.front().addArgument(oldValue.getType()); |
| }; |
| |
| // If this is a custom rewriter, simply dispatch to the registered rewrite |
| // method. |
| pdl::RewriteOp rewriter = pattern.getRewriter(); |
| if (StringAttr rewriteName = rewriter.nameAttr()) { |
| SmallVector<Value> args; |
| if (rewriter.root()) |
| args.push_back(mapRewriteValue(rewriter.root())); |
| auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); |
| args.append(mappedArgs.begin(), mappedArgs.end()); |
| builder.create<pdl_interp::ApplyRewriteOp>( |
| rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, |
| rewriter.externalConstParamsAttr()); |
| } else { |
| // Otherwise this is a dag rewriter defined using PDL operations. |
| for (Operation &rewriteOp : *rewriter.getBody()) { |
| llvm::TypeSwitch<Operation *>(&rewriteOp) |
| .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, |
| pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, |
| pdl::TypeOp, pdl::TypesOp>([&](auto op) { |
| this->generateRewriter(op, rewriteValues, mapRewriteValue); |
| }); |
| } |
| } |
| |
| // Update the signature of the rewrite function. |
| rewriterFunc.setType(builder.getFunctionType( |
| llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()), |
| /*results=*/llvm::None)); |
| |
| builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc()); |
| return SymbolRefAttr::get( |
| builder.getContext(), |
| pdl_interp::PDLInterpDialect::getRewriterModuleName(), |
| SymbolRefAttr::get(rewriterFunc)); |
| } |
| |
| void PatternLowering::generateRewriter( |
| pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue) { |
| SmallVector<Value, 2> arguments; |
| for (Value argument : rewriteOp.args()) |
| arguments.push_back(mapRewriteValue(argument)); |
| auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>( |
| rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), |
| arguments, rewriteOp.constParamsAttr()); |
| for (auto it : llvm::zip(rewriteOp.results(), interpOp.results())) |
| rewriteValues[std::get<0>(it)] = std::get<1>(it); |
| } |
| |
| void PatternLowering::generateRewriter( |
| pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue) { |
| Value newAttr = builder.create<pdl_interp::CreateAttributeOp>( |
| attrOp.getLoc(), attrOp.valueAttr()); |
| rewriteValues[attrOp] = newAttr; |
| } |
| |
| void PatternLowering::generateRewriter( |
| pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue) { |
| builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(), |
| mapRewriteValue(eraseOp.operation())); |
| } |
| |
| void PatternLowering::generateRewriter( |
| pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue) { |
| SmallVector<Value, 4> operands; |
| for (Value operand : operationOp.operands()) |
| operands.push_back(mapRewriteValue(operand)); |
| |
| SmallVector<Value, 4> attributes; |
| for (Value attr : operationOp.attributes()) |
| attributes.push_back(mapRewriteValue(attr)); |
| |
| SmallVector<Value, 2> types; |
| generateOperationResultTypeRewriter(operationOp, types, rewriteValues, |
| mapRewriteValue); |
| |
| // Create the new operation. |
| Location loc = operationOp.getLoc(); |
| Value createdOp = builder.create<pdl_interp::CreateOperationOp>( |
| loc, *operationOp.name(), types, operands, attributes, |
| operationOp.attributeNames()); |
| rewriteValues[operationOp.op()] = createdOp; |
| |
| // Generate accesses for any results that have their types constrained. |
| // Handle the case where there is a single range representing all of the |
| // result types. |
| OperandRange resultTys = operationOp.types(); |
| if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) { |
| Value &type = rewriteValues[resultTys[0]]; |
| if (!type) { |
| auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp); |
| type = builder.create<pdl_interp::GetValueTypeOp>(loc, results); |
| } |
| return; |
| } |
| |
| // Otherwise, populate the individual results. |
| bool seenVariableLength = false; |
| Type valueTy = builder.getType<pdl::ValueType>(); |
| Type valueRangeTy = pdl::RangeType::get(valueTy); |
| for (auto it : llvm::enumerate(resultTys)) { |
| Value &type = rewriteValues[it.value()]; |
| if (type) |
| continue; |
| bool isVariadic = it.value().getType().isa<pdl::RangeType>(); |
| seenVariableLength |= isVariadic; |
| |
| // After a variable length result has been seen, we need to use result |
| // groups because the exact index of the result is not statically known. |
| Value resultVal; |
| if (seenVariableLength) |
| resultVal = builder.create<pdl_interp::GetResultsOp>( |
| loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index()); |
| else |
| resultVal = builder.create<pdl_interp::GetResultOp>( |
| loc, valueTy, createdOp, it.index()); |
| type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal); |
| } |
| } |
| |
| void PatternLowering::generateRewriter( |
| pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue) { |
| SmallVector<Value, 4> replOperands; |
| |
| // If the replacement was another operation, get its results. `pdl` allows |
| // for using an operation for simplicitly, but the interpreter isn't as |
| // user facing. |
| if (Value replOp = replaceOp.replOperation()) { |
| // Don't use replace if we know the replaced operation has no results. |
| auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>(); |
| if (!opOp || !opOp.types().empty()) { |
| replOperands.push_back(builder.create<pdl_interp::GetResultsOp>( |
| replOp.getLoc(), mapRewriteValue(replOp))); |
| } |
| } else { |
| for (Value operand : replaceOp.replValues()) |
| replOperands.push_back(mapRewriteValue(operand)); |
| } |
| |
| // If there are no replacement values, just create an erase instead. |
| if (replOperands.empty()) { |
| builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(), |
| mapRewriteValue(replaceOp.operation())); |
| return; |
| } |
| |
| builder.create<pdl_interp::ReplaceOp>( |
| replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands); |
| } |
| |
| void PatternLowering::generateRewriter( |
| pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue) { |
| rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>( |
| resultOp.getLoc(), builder.getType<pdl::ValueType>(), |
| mapRewriteValue(resultOp.parent()), resultOp.index()); |
| } |
| |
| void PatternLowering::generateRewriter( |
| pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue) { |
| rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>( |
| resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()), |
| resultOp.index()); |
| } |
| |
| void PatternLowering::generateRewriter( |
| pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue) { |
| // If the type isn't constant, the users (e.g. OperationOp) will resolve this |
| // type. |
| if (TypeAttr typeAttr = typeOp.typeAttr()) { |
| rewriteValues[typeOp] = |
| builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr); |
| } |
| } |
| |
| void PatternLowering::generateRewriter( |
| pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue) { |
| // If the type isn't constant, the users (e.g. OperationOp) will resolve this |
| // type. |
| if (ArrayAttr typeAttr = typeOp.typesAttr()) { |
| rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>( |
| typeOp.getLoc(), typeOp.getType(), typeAttr); |
| } |
| } |
| |
| void PatternLowering::generateOperationResultTypeRewriter( |
| pdl::OperationOp op, SmallVectorImpl<Value> &types, |
| DenseMap<Value, Value> &rewriteValues, |
| function_ref<Value(Value)> mapRewriteValue) { |
| // Look for an operation that was replaced by `op`. The result types will be |
| // inferred from the results that were replaced. |
| Block *rewriterBlock = op->getBlock(); |
| Value replacedOp; |
| for (OpOperand &use : op.op().getUses()) { |
| // Check that the use corresponds to a ReplaceOp and that it is the |
| // replacement value, not the operation being replaced. |
| pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner()); |
| if (!replOpUser || use.getOperandNumber() == 0) |
| continue; |
| // Make sure the replaced operation was defined before this one. |
| Value replOpVal = replOpUser.operation(); |
| Operation *replacedOp = replOpVal.getDefiningOp(); |
| if (replacedOp->getBlock() == rewriterBlock && |
| !replacedOp->isBeforeInBlock(op)) |
| continue; |
| |
| Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>( |
| replacedOp->getLoc(), mapRewriteValue(replOpVal)); |
| types.push_back(builder.create<pdl_interp::GetValueTypeOp>( |
| replacedOp->getLoc(), replacedOpResults)); |
| return; |
| } |
| |
| // Check if the operation has type inference support. |
| if (op.hasTypeInference()) { |
| types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc())); |
| return; |
| } |
| |
| // Otherwise, handle inference for each of the result types individually. |
| OperandRange resultTypeValues = op.types(); |
| types.reserve(resultTypeValues.size()); |
| for (auto it : llvm::enumerate(resultTypeValues)) { |
| Value resultType = it.value(); |
| |
| // Check for an already translated value. |
| if (Value existingRewriteValue = rewriteValues.lookup(resultType)) { |
| types.push_back(existingRewriteValue); |
| continue; |
| } |
| |
| // Check for an input from the matcher. |
| if (resultType.getDefiningOp()->getBlock() != rewriterBlock) { |
| types.push_back(mapRewriteValue(resultType)); |
| continue; |
| } |
| |
| // The verifier asserts that the result types of each pdl.operation can be |
| // inferred. If we reach here, there is a bug either in the logic above or |
| // in the verifier for pdl.operation. |
| op->emitOpError() << "unable to infer result type for operation"; |
| llvm_unreachable("unable to infer result type for operation"); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion Pass |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| struct PDLToPDLInterpPass |
| : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { |
| void runOnOperation() final; |
| }; |
| } // namespace |
| |
| /// Convert the given module containing PDL pattern operations into a PDL |
| /// Interpreter operations. |
| void PDLToPDLInterpPass::runOnOperation() { |
| ModuleOp module = getOperation(); |
| |
| // Create the main matcher function This function contains all of the match |
| // related functionality from patterns in the module. |
| OpBuilder builder = OpBuilder::atBlockBegin(module.getBody()); |
| FuncOp matcherFunc = builder.create<FuncOp>( |
| module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(), |
| builder.getFunctionType(builder.getType<pdl::OperationType>(), |
| /*results=*/llvm::None), |
| /*attrs=*/llvm::None); |
| |
| // Create a nested module to hold the functions invoked for rewriting the IR |
| // after a successful match. |
| ModuleOp rewriterModule = builder.create<ModuleOp>( |
| module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); |
| |
| // Generate the code for the patterns within the module. |
| PatternLowering generator(matcherFunc, rewriterModule); |
| generator.lower(module); |
| |
| // After generation, delete all of the pattern operations. |
| for (pdl::PatternOp pattern : |
| llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) |
| pattern.erase(); |
| } |
| |
| std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { |
| return std::make_unique<PDLToPDLInterpPass>(); |
| } |