[mlir][linalg] Adapt the decompose patterns to use a filter (NFC).
The revision updates the convolution decomposition patterns to take a linalg transformation filter. The transformation filter in a later revision allows use the patterns from CodegenStrategy.
Depends On D114690
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D114797
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index d324256..b4f3f07 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -112,9 +112,10 @@
linalg::LinalgTransformationFilter());
/// Create a LinalgStrategyDecomposePass.
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
-std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyDecomposePass();
+// TODO: if/when we need finer control add an `opName` parameter.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgStrategyDecomposePass(linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter());
/// Create a LinalgStrategyInterchangePass.
std::unique_ptr<OperationPass<FuncOp>>
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 37bb40e..060f462 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -286,8 +286,7 @@
];
}
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
+// TODO: if/when we need finer control add an anchorOp option.
def LinalgStrategyDecomposePass
: FunctionPass<"linalg-strategy-decompose-pass"> {
let summary = "Configurable pass to apply pattern-based generalization.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5268da7..82e6080 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -49,12 +49,6 @@
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
ArrayRef<int64_t> tileSizes);
-/// Populates patterns to decompose high-D convolution ops into low-D ones. This
-/// is a step in progressive lowering for convolution ops, afterwards we can
-/// vectorize the low-D convolution ops.
-void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
/// Populates patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops, it assume high-D convolution ops
/// were decomposed previously.
@@ -1178,6 +1172,16 @@
RewritePatternSet &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter());
+/// Linalg decompose convolutions patterns
+
+/// Populates patterns to decompose high-D convolution ops into low-D ones. This
+/// is a step in progressive lowering for convolution ops, afterwards we can
+/// vectorize the low-D convolution ops.
+void populateDecomposeConvolutionPatterns(
+ RewritePatternSet &patterns,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+
/// Linalg distribution patterns
//
/// Populates `patterns` with patterns to distribute linalg.tiled_loop.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index c1b887e..c006a1e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -191,16 +191,21 @@
LinalgStrategyDecomposePass() = default;
+ LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
+ : filter(filter) {}
+
void runOnFunction() override {
auto funcOp = getFunction();
if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
return;
RewritePatternSet decompositionPattern(funcOp.getContext());
- populateDecomposeConvolutionPatterns(decompositionPattern);
+ populateDecomposeConvolutionPatterns(decompositionPattern, filter);
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(decompositionPattern))))
signalPassFailure();
}
+
+ LinalgTransformationFilter filter;
};
/// Configurable pass to apply pattern-based linalg generalization.
@@ -478,12 +483,12 @@
LinalgTransformationFilter filter) {
return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
}
+
/// Create a LinalgStrategyDecomposePass.
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
+// TODO: if/when we need finer control add an `opName` parameter.
std::unique_ptr<OperationPass<FuncOp>>
-mlir::createLinalgStrategyDecomposePass() {
- return std::make_unique<LinalgStrategyDecomposePass>();
+mlir::createLinalgStrategyDecomposePass(LinalgTransformationFilter filter) {
+ return std::make_unique<LinalgStrategyDecomposePass>(filter);
}
/// Create a LinalgStrategyInterchangePass.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 73db8b7..917ff77 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -929,31 +929,36 @@
/// convolution ops.
struct DownscaleSizeOneWindowed2DConvolution final
: public OpRewritePattern<Conv2DNhwcHwcfOp> {
- using OpRewritePattern::OpRewritePattern;
+ DownscaleSizeOneWindowed2DConvolution(
+ MLIRContext *context,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit), filter(filter) {}
LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
PatternRewriter &rewriter) const override {
- auto linalgOp = cast<linalg::LinalgOp>(*convOp);
- if (linalgOp.hasBufferSemantics())
+ if (failed(filter.checkAndNotify(rewriter, convOp)))
+ return failure();
+ if (convOp.hasBufferSemantics())
return failure(); // To be implemented
Value input = convOp.inputs().front();
- Value filter = convOp.inputs().back();
+ Value kernel = convOp.inputs().back();
Value output = convOp.outputs().front();
auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto filterType = filter.getType().dyn_cast<RankedTensorType>();
+ auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
auto outputType = output.getType().dyn_cast<RankedTensorType>();
- auto filterShape = filterType.getShape();
+ auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
// Only handle the case where at least one of the window dimensions is
// of size 1. Other cases can rely on tiling to reduce to such cases.
- int64_t fhSize = filterShape[0], fwSize = filterShape[1];
+ int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
int64_t ohSize = outputShape[1], owSize = outputShape[2];
- bool removeH = (fhSize == 1 && ohSize == 1);
- bool removeW = (fwSize == 1 && owSize == 1);
+ bool removeH = (khSize == 1 && ohSize == 1);
+ bool removeW = (kwSize == 1 && owSize == 1);
if (!removeH && !removeW)
return failure();
@@ -962,8 +967,8 @@
using RTTBuilder = RankedTensorType::Builder;
RankedTensorType newInputType =
RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
- RankedTensorType newFilterType =
- RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
+ RankedTensorType newKernelType =
+ RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
RankedTensorType newOutputType =
RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
@@ -971,8 +976,8 @@
Location loc = convOp.getLoc();
Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, input, newInputType);
- Value newFilter = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, filter, newFilterType);
+ Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, kernel, newKernelType);
Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
rewriter, loc, output, newOutputType);
@@ -988,7 +993,7 @@
auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
- loc, newOutputType, ValueRange{newInput, newFilter},
+ loc, newOutputType, ValueRange{newInput, newKernel},
ValueRange{newOutput}, stridesAttr, dilationsAttr);
// Insert back.
@@ -996,20 +1001,31 @@
rewriter, loc, conv1DOp.getResult(0), output);
rewriter.replaceOp(convOp, inserted);
+ filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
return success();
};
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ LinalgTransformationFilter filter;
};
/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
: public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
- using OpRewritePattern::OpRewritePattern;
+ DownscaleDepthwiseConv2DNhwcHwcOp(
+ MLIRContext *context,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
+ filter(filter) {}
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
PatternRewriter &rewriter) const override {
- auto linalgOp = cast<linalg::LinalgOp>(*convOp);
- if (linalgOp.hasBufferSemantics())
+ if (failed(filter.checkAndNotify(rewriter, convOp)))
+ return failure();
+ if (convOp.hasBufferSemantics())
return failure(); // To be implemented
Value input = convOp.inputs().front();
@@ -1071,15 +1087,21 @@
rewriter, loc, conv1DOp.getResult(0), output);
rewriter.replaceOp(convOp, inserted);
+ filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
return success();
};
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ LinalgTransformationFilter filter;
};
} // namespace
-void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit) {
+void linalg::populateDecomposeConvolutionPatterns(
+ RewritePatternSet &patterns, LinalgTransformationFilter filter,
+ PatternBenefit benefit) {
patterns.add<DownscaleSizeOneWindowed2DConvolution,
- DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
+ DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
benefit);
}