blob: e951d688202208330f7b8f286c84f631e5ff4bde [file] [log] [blame]
//===- Distribution.cpp - Distribute linalg.tiled_loop ops ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg distribution patterns. They update the lower
// bounds, upper bounds and steps of `tiled_loop` ops depending on the
// distribution type of each loop.
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#define DEBUG_TYPE "linalg-distribution"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// Rewrites a `linalg.tiled_loop` that carries a `distribution_types`
/// attribute. For every loop whose distribution type has a callback registered
/// in `options.procInfoMap`, the loop's lower bound, upper bound and step are
/// rewritten via `updateBoundsForCyclicDistribution`, using the processor id
/// and processor count returned by that callback. Loops whose distribution
/// type has no registered callback are left unchanged.
///
/// The op is only considered when `marker` accepts it; after a successful
/// rewrite the marker on the op is replaced through `marker`.
struct DistributeTiledLoopPattern
    : public OpRewritePattern<linalg::TiledLoopOp> {
  DistributeTiledLoopPattern(MLIRContext *context,
                             LinalgLoopDistributionOptions options,
                             LinalgTransformationFilter marker)
      : OpRewritePattern<linalg::TiledLoopOp>(context), options(options),
        marker(marker) {}

  LogicalResult matchAndRewrite(linalg::TiledLoopOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(marker.checkAndNotify(rewriter, op)))
      return failure();
    if (!op.distribution_types().hasValue())
      return failure();

    auto distributionTypes = op.distribution_types().getValue();
    unsigned numLoops = op.getNumLoops();
    // `distributionTypes[i]` is read for every loop below; reject ops whose
    // attribute is too short instead of indexing out of bounds.
    if (distributionTypes.size() < numLoops)
      return rewriter.notifyMatchFailure(
          op, "expected a distribution type for every loop");

    // Phase 1 (match): determine which loops get distributed and validate
    // them *before* any IR is created. A pattern that returns failure must
    // leave the IR untouched, but the processor-info callbacks and
    // `updateBoundsForCyclicDistribution` used in phase 2 materialize ops.
    SmallVector<unsigned, 2> distributedLoops;
    for (unsigned i = 0; i < numLoops; ++i) {
      StringRef type = distributionTypes[i].cast<StringAttr>().getValue();
      if (!options.procInfoMap.count(type))
        continue;
      if (!isParallelIterator(op.iterator_types()[i])) {
        op.emitOpError("only support for parallel loops is implemented");
        return failure();
      }
      distributedLoops.push_back(i);
    }

    // Phase 2 (rewrite): compute the distributed bounds and steps.
    Location loc = op.getLoc();
    SmallVector<Value, 2> newLowerBounds = op.lowerBound();
    SmallVector<Value, 2> newUpperBounds = op.upperBound();
    SmallVector<Value, 2> newSteps = op.step();
    for (unsigned i : distributedLoops) {
      StringRef type = distributionTypes[i].cast<StringAttr>().getValue();
      ProcInfo info = options.procInfoMap.find(type)->second(rewriter, loc);
      // Updates the i-th bound/step entries in place (they are passed by
      // reference).
      updateBoundsForCyclicDistribution(rewriter, loc, info.procId, info.nprocs,
                                        newLowerBounds[i], newUpperBounds[i],
                                        newSteps[i]);
    }

    rewriter.updateRootInPlace(op, [&] {
      op.setLowerBounds(newLowerBounds);
      op.setUpperBounds(newUpperBounds);
      op.setSteps(newSteps);
    });
    marker.replaceLinalgTransformationFilter(rewriter, op);
    return success();
  }

private:
  /// Distribution configuration; `procInfoMap` maps a distribution type name
  /// to the callback producing the processor id / count for that type.
  LinalgLoopDistributionOptions options;
  /// Filter deciding which ops the pattern applies to, and used to update the
  /// op's marker after a successful rewrite.
  LinalgTransformationFilter marker;
};
} // namespace
/// Adds the `linalg.tiled_loop` distribution pattern to `patterns`.
///
/// `opts` supplies the per-distribution-type processor-info callbacks, and
/// `marker` restricts which ops the pattern matches. Both are copied into the
/// created pattern.
void mlir::linalg::populateLinalgDistributeTiledLoopPattern(
    RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
    const LinalgTransformationFilter &marker) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<DistributeTiledLoopPattern>(ctx, opts, marker);
}