blob: e9f72e500686afa81a2e8325b34a989dfcbfc218 [file] [log] [blame]
//===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements passes for testing fusion of elementwise operations in
// Linalg, mainly the control options of elementwise fusion, together with
// reshape fusion by expansion and reshape push patterns.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
/// Inserts into `operandSet` the operands of `op` that are counted when
/// estimating the operand list of a fused operation.
///
/// For Linalg ops only the input operands are added; for any other operation
/// every operand is added. A null `op` is ignored.
static void addOperands(Operation *op, SetVector<Value> &operandSet) {
  if (op == nullptr)
    return;

  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp) {
    operandSet.insert(op->operand_begin(), op->operand_end());
    return;
  }

  // Only the inputs of a Linalg op are considered; its output operands are
  // not added to the set.
  SmallVector<Value> inputs = linalgOp.getInputOperands();
  operandSet.insert(inputs.begin(), inputs.end());
}
/// Fusion control function: allows fusing `producer` into the op owning
/// `consumer` only when the producer has exactly one result and the fused op
/// would have at most `limit` distinct operands.
///
/// The operand set is estimated as: the consumer's operands, minus the
/// producer result being fused away, plus the producer's operands (see
/// `addOperands` for which operands of each op are counted). A SetVector is
/// used, so an operand shared by both ops is counted once.
///
/// \tparam limit maximum number of operands of the fused op; must be >= 0.
/// \returns true if the fusion should be performed.
template <int limit = 3>
static bool setFusedOpOperandLimit(const OpResult &producer,
                                   const OpOperand &consumer) {
  // A negative limit would be converted to a huge unsigned value in the
  // size comparison below and silently permit every fusion; reject it at
  // compile time instead.
  static_assert(limit >= 0, "fused operand limit must be non-negative");

  if (producer.getOwner()->getNumResults() != 1)
    return false;

  SetVector<Value> fusedOpOperands;
  addOperands(consumer.getOwner(), fusedOpOperands);
  // The producer's result is replaced by the fused computation, so it is no
  // longer an operand of the fused op.
  fusedOpOperands.remove(producer);
  addOperands(producer.getOwner(), fusedOpOperands);
  return fusedOpOperands.size() <= static_cast<size_t>(limit);
}
namespace {
/// Test pass that applies the Linalg elementwise-op fusion patterns to the
/// body of a function. Fusion is only allowed when the fused op would have at
/// most four operands (controlled by `setFusedOpOperandLimit<4>`).
struct TestLinalgElementwiseFusion
    : public PassWrapper<TestLinalgElementwiseFusion, FunctionPass> {
  StringRef getArgument() const final {
    return "test-linalg-elementwise-fusion-patterns";
  }

  StringRef getDescription() const final {
    return "Test Linalg element wise operation fusion patterns";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
                    tensor::TensorDialect>();
  }

  void runOnFunction() override {
    linalg::LinalgElementwiseFusionOptions fusionOptions;
    fusionOptions.setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>);

    RewritePatternSet patterns(&getContext());
    linalg::populateElementwiseOpsFusionPatterns(patterns, fusionOptions);

    // The result of the greedy driver (convergence) is intentionally ignored.
    (void)applyPatternsAndFoldGreedily(getFunction().getBody(),
                                       std::move(patterns));
  }
};
/// Test pass exercising a custom control function for folding reshape ops
/// into Linalg ops by expansion.
struct TestLinalgControlFuseByExpansion
    : public PassWrapper<TestLinalgControlFuseByExpansion, FunctionPass> {
  StringRef getArgument() const final {
    return "test-linalg-control-fusion-by-expansion";
  }

  StringRef getDescription() const final {
    return "Test controlling of fusion of elementwise ops with reshape by "
           "expansion";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
  }

  void runOnFunction() override {
    // Decides whether `producer` may be fused into the op owning `consumer`.
    // The rules are checked in order:
    //  1. a collapse_shape producer whose source is not defined by a Linalg
    //     op is rejected;
    //  2. an expand_shape consumer whose only use is an output tensor operand
    //     of a Linalg op is accepted;
    //  3. otherwise the decision is delegated to linalg::skipUnitDimReshape.
    linalg::ControlElementwiseOpsFusionFn allowReshapeFusion =
        [](const OpResult &producer, OpOperand &consumer) -> bool {
      auto collapseOp =
          producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
      if (collapseOp && !collapseOp.src().getDefiningOp<linalg::LinalgOp>())
        return false;

      auto expandOp =
          dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner());
      if (expandOp && expandOp->hasOneUse()) {
        OpOperand &soleUse = *expandOp->getUses().begin();
        auto user = dyn_cast<linalg::LinalgOp>(soleUse.getOwner());
        if (user && user.isOutputTensor(&soleUse))
          return true;
      }

      return linalg::skipUnitDimReshape(producer, consumer);
    };

    RewritePatternSet patterns(&getContext());
    linalg::populateFoldReshapeOpsByExpansionPatterns(patterns,
                                                      allowReshapeFusion);

    // The result of the greedy driver (convergence) is intentionally ignored.
    (void)applyPatternsAndFoldGreedily(getFunction().getBody(),
                                       std::move(patterns));
  }
};
/// Test pass that greedily applies the patterns registered by
/// `linalg::populatePushReshapeOpsPatterns` to the body of a function.
struct TestPushExpandingReshape
    : public PassWrapper<TestPushExpandingReshape, FunctionPass> {
  StringRef getArgument() const final { return "test-linalg-push-reshape"; }

  StringRef getDescription() const final {
    return "Test Linalg reshape push patterns";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
  }

  void runOnFunction() override {
    RewritePatternSet pushPatterns(&getContext());
    linalg::populatePushReshapeOpsPatterns(pushPatterns);

    // The result of the greedy driver (convergence) is intentionally ignored.
    (void)applyPatternsAndFoldGreedily(getFunction().getBody(),
                                       std::move(pushPatterns));
  }
};
} // namespace
namespace test {
void registerTestLinalgElementwiseFusion() {
PassRegistration<TestLinalgElementwiseFusion>();
}
void registerTestLinalgControlFuseByExpansion() {
PassRegistration<TestLinalgControlFuseByExpansion>();
}
void registerTestPushExpandingReshape() {
PassRegistration<TestPushExpandingReshape>();
}
} // namespace test
} // namespace mlir