| //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Pass/PassManager.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| using namespace mlir; |
| |
| /// Custom constraint invoked from PDL. |
| static LogicalResult customSingleEntityConstraint(PDLValue value, |
| ArrayAttr constantParams, |
| PatternRewriter &rewriter) { |
| Operation *rootOp = value.cast<Operation *>(); |
| return success(rootOp->getName().getStringRef() == "test.op"); |
| } |
| static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values, |
| ArrayAttr constantParams, |
| PatternRewriter &rewriter) { |
| return customSingleEntityConstraint(values[1], constantParams, rewriter); |
| } |
| static LogicalResult |
| customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values, |
| ArrayAttr constantParams, |
| PatternRewriter &rewriter) { |
| if (llvm::any_of(values, [](const PDLValue &value) { return !value; })) |
| return failure(); |
| ValueRange operandValues = values[0].cast<ValueRange>(); |
| TypeRange typeValues = values[1].cast<TypeRange>(); |
| if (operandValues.size() != 2 || typeValues.size() != 2) |
| return failure(); |
| return success(); |
| } |
| |
| // Custom creator invoked from PDL. |
| static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams, |
| PatternRewriter &rewriter, PDLResultList &results) { |
| results.push_back(rewriter.createOperation( |
| OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"))); |
| } |
| static void customVariadicResultCreate(ArrayRef<PDLValue> args, |
| ArrayAttr constantParams, |
| PatternRewriter &rewriter, |
| PDLResultList &results) { |
| Operation *root = args[0].cast<Operation *>(); |
| results.push_back(root->getOperands()); |
| results.push_back(root->getOperands().getTypes()); |
| } |
| static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams, |
| PatternRewriter &rewriter, |
| PDLResultList &results) { |
| results.push_back(rewriter.getF32Type()); |
| } |
| |
| /// Custom rewriter invoked from PDL. |
| static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams, |
| PatternRewriter &rewriter, PDLResultList &results) { |
| Operation *root = args[0].cast<Operation *>(); |
| OperationState successOpState(root->getLoc(), "test.success"); |
| successOpState.addOperands(args[1].cast<Value>()); |
| successOpState.addAttribute("constantParams", constantParams); |
| rewriter.createOperation(successOpState); |
| rewriter.eraseOp(root); |
| } |
| |
| namespace { |
| struct TestPDLByteCodePass |
| : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { |
| StringRef getArgument() const final { return "test-pdl-bytecode-pass"; } |
| StringRef getDescription() const final { |
| return "Test PDL ByteCode functionality"; |
| } |
| void runOnOperation() final { |
| ModuleOp module = getOperation(); |
| |
| // The test cases are encompassed via two modules, one containing the |
| // patterns and one containing the operations to rewrite. |
| ModuleOp patternModule = module.lookupSymbol<ModuleOp>( |
| StringAttr::get(module->getContext(), "patterns")); |
| ModuleOp irModule = module.lookupSymbol<ModuleOp>( |
| StringAttr::get(module->getContext(), "ir")); |
| if (!patternModule || !irModule) |
| return; |
| |
| // Process the pattern module. |
| patternModule.getOperation()->remove(); |
| PDLPatternModule pdlPattern(patternModule); |
| pdlPattern.registerConstraintFunction("multi_entity_constraint", |
| customMultiEntityConstraint); |
| pdlPattern.registerConstraintFunction("single_entity_constraint", |
| customSingleEntityConstraint); |
| pdlPattern.registerConstraintFunction("multi_entity_var_constraint", |
| customMultiEntityVariadicConstraint); |
| pdlPattern.registerRewriteFunction("creator", customCreate); |
| pdlPattern.registerRewriteFunction("var_creator", |
| customVariadicResultCreate); |
| pdlPattern.registerRewriteFunction("type_creator", customCreateType); |
| pdlPattern.registerRewriteFunction("rewriter", customRewriter); |
| |
| RewritePatternSet patternList(std::move(pdlPattern)); |
| |
| // Invoke the pattern driver with the provided patterns. |
| (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), |
| std::move(patternList)); |
| } |
| }; |
| } // end anonymous namespace |
| |
| namespace mlir { |
| namespace test { |
| void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); } |
| } // namespace test |
| } // namespace mlir |