blob: 61e1fbcf3feaf3a35b104d0bf9da7b74c4b52fd7 [file] [log] [blame]
//===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
using namespace mlir;
/// Rewrites `regionOp` (a `test.one_region_op`) into a
/// `test.isolated_one_region_op` with the help of
/// `makeRegionIsolatedFromAbove`.
///
/// `callBack` is handed unchanged to `makeRegionIsolatedFromAbove`. In this
/// file it serves as the per-operation predicate that chooses whether a
/// producer defined above the region gets cloned into it. Each value the
/// utility returns is appended after the original operands of the new op.
/// The region body is then moved into the new op and `regionOp` is erased.
/// This helper always reports success.
static LogicalResult
makeIsolatedFromAboveImpl(RewriterBase &rewriter,
                          test::OneRegionWithOperandsOp regionOp,
                          llvm::function_ref<bool(Operation *)> callBack) {
  Region &sourceRegion = regionOp.getRegion();

  // Isolate the region in place first. The values handed back are the ones
  // that the new op must now receive explicitly as operands.
  SmallVector<Value> capturedValues =
      makeRegionIsolatedFromAbove(rewriter, sourceRegion, callBack);

  // Operand list of the replacement: original operands, then the captures.
  SmallVector<Value> newOperands = regionOp.getOperands();
  newOperands.append(capturedValues);

  auto isolatedOp = rewriter.create<test::IsolatedOneRegionOp>(
      regionOp.getLoc(), newOperands);

  // Move the body into the freshly created op, then drop the original op.
  Region &targetRegion = isolatedOp.getRegion();
  rewriter.inlineRegionBefore(sourceRegion, targetRegion,
                              targetRegion.begin());
  rewriter.eraseOp(regionOp);
  return success();
}
namespace {
/// Pattern that makes the region of a `test::OneRegionWithOperandsOp`
/// isolated from above without cloning any operation into it. Every captured
/// value is forwarded as an operand of the replacement op instead.
struct SimpleMakeIsolatedFromAbove
    : OpRewritePattern<test::OneRegionWithOperandsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp op,
                                PatternRewriter &rewriter) const override {
    // Reject every clone candidate.
    auto neverClone = [](Operation *) { return false; };
    return makeIsolatedFromAboveImpl(rewriter, op, neverClone);
  }
};
/// Pattern that makes the region of a `test::OneRegionWithOperandsOp`
/// isolated from above. Operations that have no operands are accepted for
/// cloning into the region; all others are rejected.
struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands
    : OpRewritePattern<test::OneRegionWithOperandsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp op,
                                PatternRewriter &rewriter) const override {
    // Only operand-free producers are eligible for cloning.
    auto hasNoOperands = [](Operation *candidate) {
      return candidate->getNumOperands() == 0;
    };
    return makeIsolatedFromAboveImpl(rewriter, op, hasNoOperands);
  }
};
/// Test for making a region isolated from above while cloning every operation
/// offered to the clone predicate, including operations that have operands.
struct MakeIsolatedFromAboveAndCloneOpsWithOperands
    : OpRewritePattern<test::OneRegionWithOperandsOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
                                PatternRewriter &rewriter) const override {
    // The predicate accepts every operation. Unlike the no-operands variant,
    // cloning is therefore not restricted to operand-free producers.
    return makeIsolatedFromAboveImpl(rewriter, regionOp,
                                     [](Operation *) { return true; });
  }
};
/// Test pass for testing the `makeIsolatedFromAbove` function.
///
/// The boolean options pick which of the patterns above is applied greedily
/// to the function. If several are set, `runOnOperation` honours them in
/// declaration order and uses only the first enabled one.
struct TestMakeIsolatedFromAbovePass
    : public PassWrapper<TestMakeIsolatedFromAbovePass,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass)

  TestMakeIsolatedFromAbovePass() = default;
  // Explicit copy constructor: the copy goes through PassWrapper, and the
  // options are re-created from their in-class initializers.
  TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final {
    return "test-make-isolated-from-above";
  }
  StringRef getDescription() const final {
    return "Test making a region isolated from above";
  }

  /// Selects `SimpleMakeIsolatedFromAbove` (no cloning).
  Option<bool> simple{
      *this, "simple",
      llvm::cl::desc("Test simple case with no cloning of operations"),
      llvm::cl::init(false)};
  /// Selects `MakeIsolatedFromAboveAndCloneOpsWithNoOperands`.
  Option<bool> cloneOpsWithNoOperands{
      *this, "clone-ops-with-no-operands",
      llvm::cl::desc("Test case with cloning of operations with no operands"),
      llvm::cl::init(false)};
  /// Selects `MakeIsolatedFromAboveAndCloneOpsWithOperands` (clone all).
  Option<bool> cloneOpsWithOperands{
      *this, "clone-ops-with-operands",
      llvm::cl::desc("Test case with cloning of all operations, including "
                     "those with operands"),
      llvm::cl::init(false)};

  void runOnOperation() override;
};
} // namespace
void TestMakeIsolatedFromAbovePass::runOnOperation() {
MLIRContext *context = &getContext();
func::FuncOp funcOp = getOperation();
if (simple) {
RewritePatternSet patterns(context);
patterns.insert<SimpleMakeIsolatedFromAbove>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
return;
}
if (cloneOpsWithNoOperands) {
RewritePatternSet patterns(context);
patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithNoOperands>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
return;
}
if (cloneOpsWithOperands) {
RewritePatternSet patterns(context);
patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithOperands>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
return;
}
}
namespace mlir {
namespace test {
void registerTestMakeIsolatedFromAbovePass() {
PassRegistration<TestMakeIsolatedFromAbovePass>();
}
} // namespace test
} // namespace mlir