blob: 54c5bd1976a1f1d3ef96673b856db42bb22b18fa [file] [log] [blame]
//===- TestLinalgDistribution.cpp - Test Linalg distribution --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg tiled-loop distribution.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
/// Returns the processor information for GPU blocks along dimension `dim`
/// (e.g. 'x' or 'y').
///
/// The processor id is materialized as a `gpu::BlockIdOp` and the processor
/// count as a `gpu::GridDimOp`. Both ops are created at `loc` with index result
/// type, and both use the single-character dimension name as their string
/// attribute. The block-id op is created before the grid-dim op.
template <char dim>
static linalg::ProcInfo getGpuBlockInfo(OpBuilder &b, Location loc) {
  StringAttr dimName = b.getStringAttr(std::string(1, dim));
  Type idxTy = b.getIndexType();
  // Elements of a braced initializer are evaluated left to right, so the op
  // creation order matches the field order {procId, nprocs}.
  return {b.create<gpu::BlockIdOp>(loc, idxTy, dimName),
          b.create<gpu::GridDimOp>(loc, idxTy, dimName)};
}
/// Builds the loop-distribution options used by this test pass.
///
/// The distribution-type names "block_x" and "block_y" are bound to the
/// callbacks that produce GPU block id / grid dimension values along the x and
/// y dimensions, respectively.
static LinalgLoopDistributionOptions getDistributionOptions() {
  LinalgLoopDistributionOptions options;
  auto &nameToProcInfo = options.procInfoMap;
  nameToProcInfo.insert(std::make_pair("block_x", getGpuBlockInfo<'x'>));
  nameToProcInfo.insert(std::make_pair("block_y", getGpuBlockInfo<'y'>));
  return options;
}
namespace {
/// Function pass that applies the Linalg tiled-loop distribution patterns.
/// It is registered under the command-line name "test-linalg-distribution".
/// The actual rewriting lives in runOnFunction().
struct TestLinalgDistribution
    : public PassWrapper<TestLinalgDistribution, FunctionPass> {
  StringRef getArgument() const final { return "test-linalg-distribution"; }
  StringRef getDescription() const final { return "Test Linalg distribution."; }

  TestLinalgDistribution() = default;

  /// Copy constructor used when the pass is cloned. It forwards to the base
  /// class copy constructor so that any state held by PassWrapper/Pass is
  /// copied from `pass`, instead of being default-constructed while the
  /// argument is silently ignored.
  TestLinalgDistribution(const TestLinalgDistribution &pass)
      : PassWrapper(pass) {}

  /// Declares the dialects this pass may create ops from: GPU (the block-id /
  /// grid-dim ops built for the distribution callbacks) and Affine (assumed
  /// to be used by the distribution rewrite itself — TODO confirm).
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, gpu::GPUDialect>();
  }

  void runOnFunction() override;
};
} // namespace
void TestLinalgDistribution::runOnFunction() {
auto funcOp = getFunction();
OwningRewritePatternList distributeTiledLoopsPatterns(&getContext());
populateLinalgDistributeTiledLoopPattern(
distributeTiledLoopsPatterns, getDistributionOptions(),
LinalgTransformationFilter(
ArrayRef<StringAttr>{},
{StringAttr::get("distributed", funcOp.getContext())})
.addFilter([](Operation *op) {
return success(!op->getParentOfType<linalg::TiledLoopOp>());
}));
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(distributeTiledLoopsPatterns));
// Ensure we drop the marker in the end.
funcOp.walk([](LinalgOp op) {
op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
});
}
namespace mlir {
namespace test {
void registerTestLinalgDistribution() {
PassRegistration<TestLinalgDistribution>();
}
} // namespace test
} // namespace mlir