| //===- Distibution.cpp - linalg named ops to generic ops --------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements the Linalg distibution pass. It updates `tiled_loop` |
| // control variables depending on the distribution type. |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| #define DEBUG_TYPE "linalg-distribution" |
| |
| #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
| |
| using namespace mlir; |
| using namespace mlir::linalg; |
| |
| namespace { |
| |
| struct DistributeTiledLoopPattern |
| : public OpRewritePattern<linalg::TiledLoopOp> { |
| DistributeTiledLoopPattern(MLIRContext *context, |
| LinalgLoopDistributionOptions options, |
| LinalgTransformationFilter marker) |
| : OpRewritePattern<linalg::TiledLoopOp>(context), options(options), |
| marker(marker) {} |
| LogicalResult matchAndRewrite(linalg::TiledLoopOp op, |
| PatternRewriter &rewriter) const override { |
| if (failed(marker.checkAndNotify(rewriter, op))) |
| return failure(); |
| if (!op.distribution_types().hasValue()) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| SmallVector<Value, 2> newLowerBounds = op.lowerBound(); |
| SmallVector<Value, 2> newUpperBounds = op.upperBound(); |
| SmallVector<Value, 2> newSteps = op.step(); |
| |
| // Update bounds and steps. |
| auto distributionTypes = op.distribution_types().getValue(); |
| for (int i = 0, e = op.getNumLoops(); i < e; ++i) { |
| StringRef type = distributionTypes[i].cast<StringAttr>().getValue(); |
| auto procInfoCallback = options.procInfoMap.find(type); |
| if (procInfoCallback == options.procInfoMap.end()) |
| continue; |
| |
| if (!isParallelIterator(op.iterator_types()[i])) { |
| op.emitOpError("only support for parallel loops is implemented"); |
| return failure(); |
| } |
| ProcInfo info = procInfoCallback->second(rewriter, loc); |
| updateBoundsForCyclicDistribution(rewriter, loc, info.procId, info.nprocs, |
| newLowerBounds[i], newUpperBounds[i], |
| newSteps[i]); |
| } |
| rewriter.updateRootInPlace(op, [&] { |
| op.setLowerBounds(newLowerBounds); |
| op.setUpperBounds(newUpperBounds); |
| op.setSteps(newSteps); |
| }); |
| marker.replaceLinalgTransformationFilter(rewriter, op); |
| return success(); |
| } |
| |
| private: |
| LinalgLoopDistributionOptions options; |
| LinalgTransformationFilter marker; |
| }; |
| |
| } // namespace |
| |
| void mlir::linalg::populateLinalgDistributeTiledLoopPattern( |
| RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts, |
| const LinalgTransformationFilter &marker) { |
| patterns.add<DistributeTiledLoopPattern>(patterns.getContext(), opts, marker); |
| } |