//===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements passes to test SCF dialect utils.
//
//===----------------------------------------------------------------------===//
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| #include "mlir/Dialect/SCF/Utils/Utils.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| using namespace mlir; |
| |
| namespace { |
| struct TestSCFForUtilsPass |
| : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> { |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass) |
| |
| StringRef getArgument() const final { return "test-scf-for-utils"; } |
| StringRef getDescription() const final { return "test scf.for utils"; } |
| explicit TestSCFForUtilsPass() = default; |
| TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {} |
| |
| Option<bool> testReplaceWithNewYields{ |
| *this, "test-replace-with-new-yields", |
| llvm::cl::desc("Test replacing a loop with a new loop that returns new " |
| "additional yield values"), |
| llvm::cl::init(false)}; |
| |
| void runOnOperation() override { |
| func::FuncOp func = getOperation(); |
| SmallVector<scf::ForOp, 4> toErase; |
| |
| if (testReplaceWithNewYields) { |
| func.walk([&](scf::ForOp forOp) { |
| if (forOp.getNumResults() == 0) |
| return; |
| auto newInitValues = forOp.getInitArgs(); |
| if (newInitValues.empty()) |
| return; |
| SmallVector<Value> oldYieldValues = |
| llvm::to_vector(forOp.getYieldedValues()); |
| NewYieldValuesFn fn = [&](OpBuilder &b, Location loc, |
| ArrayRef<BlockArgument> newBBArgs) { |
| SmallVector<Value> newYieldValues; |
| for (auto yieldVal : oldYieldValues) { |
| newYieldValues.push_back( |
| b.create<arith::AddFOp>(loc, yieldVal, yieldVal)); |
| } |
| return newYieldValues; |
| }; |
| IRRewriter rewriter(forOp.getContext()); |
| if (failed(forOp.replaceWithAdditionalYields( |
| rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true, |
| fn))) |
| signalPassFailure(); |
| }); |
| } |
| } |
| }; |
| |
| struct TestSCFIfUtilsPass |
| : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> { |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass) |
| |
| StringRef getArgument() const final { return "test-scf-if-utils"; } |
| StringRef getDescription() const final { return "test scf.if utils"; } |
| explicit TestSCFIfUtilsPass() = default; |
| |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<func::FuncDialect>(); |
| } |
| |
| void runOnOperation() override { |
| int count = 0; |
| getOperation().walk([&](scf::IfOp ifOp) { |
| auto strCount = std::to_string(count++); |
| func::FuncOp thenFn, elseFn; |
| OpBuilder b(ifOp); |
| IRRewriter rewriter(b); |
| if (failed(outlineIfOp(rewriter, ifOp, &thenFn, |
| std::string("outlined_then") + strCount, &elseFn, |
| std::string("outlined_else") + strCount))) { |
| this->signalPassFailure(); |
| return WalkResult::interrupt(); |
| } |
| return WalkResult::advance(); |
| }); |
| } |
| }; |
| |
| static const StringLiteral kTestPipeliningLoopMarker = |
| "__test_pipelining_loop__"; |
| static const StringLiteral kTestPipeliningStageMarker = |
| "__test_pipelining_stage__"; |
| /// Marker to express the order in which operations should be after |
| /// pipelining. |
| static const StringLiteral kTestPipeliningOpOrderMarker = |
| "__test_pipelining_op_order__"; |
| |
| static const StringLiteral kTestPipeliningAnnotationPart = |
| "__test_pipelining_part"; |
| static const StringLiteral kTestPipeliningAnnotationIteration = |
| "__test_pipelining_iteration"; |
| |
| struct TestSCFPipeliningPass |
| : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> { |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass) |
| |
| TestSCFPipeliningPass() = default; |
| TestSCFPipeliningPass(const TestSCFPipeliningPass &) {} |
| StringRef getArgument() const final { return "test-scf-pipelining"; } |
| StringRef getDescription() const final { return "test scf.forOp pipelining"; } |
| |
| Option<bool> annotatePipeline{ |
| *this, "annotate", |
| llvm::cl::desc("Annote operations during loop pipelining transformation"), |
| llvm::cl::init(false)}; |
| |
| Option<bool> noEpiloguePeeling{ |
| *this, "no-epilogue-peeling", |
| llvm::cl::desc("Use predicates instead of peeling the epilogue."), |
| llvm::cl::init(false)}; |
| |
| static void |
| getSchedule(scf::ForOp forOp, |
| std::vector<std::pair<Operation *, unsigned>> &schedule) { |
| if (!forOp->hasAttr(kTestPipeliningLoopMarker)) |
| return; |
| |
| schedule.resize(forOp.getBody()->getOperations().size() - 1); |
| forOp.walk([&schedule](Operation *op) { |
| auto attrStage = |
| op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker); |
| auto attrCycle = |
| op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker); |
| if (attrCycle && attrStage) { |
| // TODO: Index can be out-of-bounds if ops of the loop body disappear |
| // due to folding. |
| schedule[attrCycle.getInt()] = |
| std::make_pair(op, unsigned(attrStage.getInt())); |
| } |
| }); |
| } |
| |
| /// Helper to generate "predicated" version of `op`. For simplicity we just |
| /// wrap the operation in a scf.ifOp operation. |
| static Operation *predicateOp(RewriterBase &rewriter, Operation *op, |
| Value pred) { |
| Location loc = op->getLoc(); |
| auto ifOp = |
| rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true); |
| // True branch. |
| rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(), |
| ifOp.getThenRegion().front().begin()); |
| rewriter.setInsertionPointAfter(op); |
| if (op->getNumResults() > 0) |
| rewriter.create<scf::YieldOp>(loc, op->getResults()); |
| // False branch. |
| rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| SmallVector<Value> elseYieldOperands; |
| elseYieldOperands.reserve(ifOp.getNumResults()); |
| if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) { |
| // For sub-views, just clone the op. |
| // NOTE: This is okay in the test because we use dynamic memref sizes, so |
| // the verifier will not complain. Otherwise, we may create a logically |
| // out-of-bounds view and a different technique should be used. |
| Operation *opClone = rewriter.clone(*op); |
| elseYieldOperands.append(opClone->result_begin(), opClone->result_end()); |
| } else { |
| // Default to assuming constant numeric values. |
| for (Type type : op->getResultTypes()) { |
| elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getZeroAttr(type))); |
| } |
| } |
| if (op->getNumResults() > 0) |
| rewriter.create<scf::YieldOp>(loc, elseYieldOperands); |
| return ifOp.getOperation(); |
| } |
| |
| static void annotate(Operation *op, |
| mlir::scf::PipeliningOption::PipelinerPart part, |
| unsigned iteration) { |
| OpBuilder b(op); |
| switch (part) { |
| case mlir::scf::PipeliningOption::PipelinerPart::Prologue: |
| op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue")); |
| break; |
| case mlir::scf::PipeliningOption::PipelinerPart::Kernel: |
| op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel")); |
| break; |
| case mlir::scf::PipeliningOption::PipelinerPart::Epilogue: |
| op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue")); |
| break; |
| } |
| op->setAttr(kTestPipeliningAnnotationIteration, |
| b.getI32IntegerAttr(iteration)); |
| } |
| |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<arith::ArithDialect, memref::MemRefDialect>(); |
| } |
| |
| void runOnOperation() override { |
| RewritePatternSet patterns(&getContext()); |
| mlir::scf::PipeliningOption options; |
| options.getScheduleFn = getSchedule; |
| options.supportDynamicLoops = true; |
| options.predicateFn = predicateOp; |
| if (annotatePipeline) |
| options.annotateFn = annotate; |
| if (noEpiloguePeeling) { |
| options.peelEpilogue = false; |
| } |
| scf::populateSCFLoopPipeliningPatterns(patterns, options); |
| (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
| getOperation().walk([](Operation *op) { |
| // Clean up the markers. |
| op->removeAttr(kTestPipeliningStageMarker); |
| op->removeAttr(kTestPipeliningOpOrderMarker); |
| }); |
| } |
| }; |
| } // namespace |
| |
| namespace mlir { |
| namespace test { |
| void registerTestSCFUtilsPass() { |
| PassRegistration<TestSCFForUtilsPass>(); |
| PassRegistration<TestSCFIfUtilsPass>(); |
| PassRegistration<TestSCFPipeliningPass>(); |
| } |
| } // namespace test |
| } // namespace mlir |