[mlir][linalg] Adapt the decompose patterns to use a filter (NFC).

The revision updates the convolution decomposition patterns to take a linalg transformation filter. The transformation filter in a later revision allows use the patterns from CodegenStrategy.

Depends On D114690

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114797
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index d324256..b4f3f07 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -112,9 +112,10 @@
                                        linalg::LinalgTransformationFilter());
 
 /// Create a LinalgStrategyDecomposePass.
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
-std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyDecomposePass();
+// TODO: if/when we need finer control add an `opName` parameter.
+std::unique_ptr<OperationPass<FuncOp>>
+createLinalgStrategyDecomposePass(linalg::LinalgTransformationFilter filter =
+                                      linalg::LinalgTransformationFilter());
 
 /// Create a LinalgStrategyInterchangePass.
 std::unique_ptr<OperationPass<FuncOp>>
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 37bb40e..060f462 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -286,8 +286,7 @@
   ];
 }
 
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
+// TODO: if/when we need finer control add an anchorOp option.
 def LinalgStrategyDecomposePass
     : FunctionPass<"linalg-strategy-decompose-pass"> {
   let summary = "Configurable pass to apply pattern-based generalization.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5268da7..82e6080 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -49,12 +49,6 @@
     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
     ArrayRef<int64_t> tileSizes);
 
-/// Populates patterns to decompose high-D convolution ops into low-D ones. This
-/// is a step in progressive lowering for convolution ops, afterwards we can
-/// vectorize the low-D convolution ops.
-void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
-
 /// Populates patterns for vectorizing low-D convolution ops. This is a step in
 /// progressive lowering for convolution ops, it assume high-D convolution ops
 /// were decomposed previously.
@@ -1178,6 +1172,16 @@
     RewritePatternSet &patterns,
     LinalgTransformationFilter filter = LinalgTransformationFilter());
 
+/// Linalg decompose convolutions patterns
+
+/// Populates patterns to decompose high-D convolution ops into low-D ones. This
+/// is a step in progressive lowering for convolution ops, afterwards we can
+/// vectorize the low-D convolution ops.
+void populateDecomposeConvolutionPatterns(
+    RewritePatternSet &patterns,
+    LinalgTransformationFilter filter = LinalgTransformationFilter(),
+    PatternBenefit benefit = 1);
+
 /// Linalg distribution patterns
 //
 /// Populates `patterns` with patterns to distribute linalg.tiled_loop.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index c1b887e..c006a1e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -191,16 +191,21 @@
 
   LinalgStrategyDecomposePass() = default;
 
+  LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
+      : filter(filter) {}
+
   void runOnFunction() override {
     auto funcOp = getFunction();
     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
       return;
     RewritePatternSet decompositionPattern(funcOp.getContext());
-    populateDecomposeConvolutionPatterns(decompositionPattern);
+    populateDecomposeConvolutionPatterns(decompositionPattern, filter);
     if (failed(applyPatternsAndFoldGreedily(funcOp,
                                             std::move(decompositionPattern))))
       signalPassFailure();
   }
+
+  LinalgTransformationFilter filter;
 };
 
 /// Configurable pass to apply pattern-based linalg generalization.
@@ -478,12 +483,12 @@
                                          LinalgTransformationFilter filter) {
   return std::make_unique<LinalgStrategyGeneralizePass>(opName, filter);
 }
+
 /// Create a LinalgStrategyDecomposePass.
-// TODO: atm this is applied to all supported ops. If/when we need finer control
-// this should be exposed with an opName + filter and a proper pattern.
+// TODO: if/when we need finer control add an `opName` parameter.
 std::unique_ptr<OperationPass<FuncOp>>
-mlir::createLinalgStrategyDecomposePass() {
-  return std::make_unique<LinalgStrategyDecomposePass>();
+mlir::createLinalgStrategyDecomposePass(LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyDecomposePass>(filter);
 }
 
 /// Create a LinalgStrategyInterchangePass.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 73db8b7..917ff77 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -929,31 +929,36 @@
 /// convolution ops.
 struct DownscaleSizeOneWindowed2DConvolution final
     : public OpRewritePattern<Conv2DNhwcHwcfOp> {
-  using OpRewritePattern::OpRewritePattern;
+  DownscaleSizeOneWindowed2DConvolution(
+      MLIRContext *context,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit), filter(filter) {}
 
   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
-    if (linalgOp.hasBufferSemantics())
+    if (failed(filter.checkAndNotify(rewriter, convOp)))
+      return failure();
+    if (convOp.hasBufferSemantics())
       return failure(); // To be implemented
 
     Value input = convOp.inputs().front();
-    Value filter = convOp.inputs().back();
+    Value kernel = convOp.inputs().back();
     Value output = convOp.outputs().front();
 
     auto inputType = input.getType().dyn_cast<RankedTensorType>();
-    auto filterType = filter.getType().dyn_cast<RankedTensorType>();
+    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
     auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
-    auto filterShape = filterType.getShape();
+    auto kernelShape = kernelType.getShape();
     auto outputShape = outputType.getShape();
 
     // Only handle the case where at least one of the window dimensions is
     // of size 1. Other cases can rely on tiling to reduce to such cases.
-    int64_t fhSize = filterShape[0], fwSize = filterShape[1];
+    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
     int64_t ohSize = outputShape[1], owSize = outputShape[2];
-    bool removeH = (fhSize == 1 && ohSize == 1);
-    bool removeW = (fwSize == 1 && owSize == 1);
+    bool removeH = (khSize == 1 && ohSize == 1);
+    bool removeW = (kwSize == 1 && owSize == 1);
     if (!removeH && !removeW)
       return failure();
 
@@ -962,8 +967,8 @@
     using RTTBuilder = RankedTensorType::Builder;
     RankedTensorType newInputType =
         RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    RankedTensorType newFilterType =
-        RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
+    RankedTensorType newKernelType =
+        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
     RankedTensorType newOutputType =
         RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
 
@@ -971,8 +976,8 @@
     Location loc = convOp.getLoc();
     Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
         rewriter, loc, input, newInputType);
-    Value newFilter = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, filter, newFilterType);
+    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, kernel, newKernelType);
     Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
         rewriter, loc, output, newOutputType);
 
@@ -988,7 +993,7 @@
     auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
 
     auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
-        loc, newOutputType, ValueRange{newInput, newFilter},
+        loc, newOutputType, ValueRange{newInput, newKernel},
         ValueRange{newOutput}, stridesAttr, dilationsAttr);
 
     // Insert back.
@@ -996,20 +1001,31 @@
         rewriter, loc, conv1DOp.getResult(0), output);
     rewriter.replaceOp(convOp, inserted);
 
+    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
     return success();
   };
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
 };
 
 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
 /// dimensions into 1-D depthwise convolution ops.
 struct DownscaleDepthwiseConv2DNhwcHwcOp final
     : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
-  using OpRewritePattern::OpRewritePattern;
+  DownscaleDepthwiseConv2DNhwcHwcOp(
+      MLIRContext *context,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
+        filter(filter) {}
 
   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
-    if (linalgOp.hasBufferSemantics())
+    if (failed(filter.checkAndNotify(rewriter, convOp)))
+      return failure();
+    if (convOp.hasBufferSemantics())
       return failure(); // To be implemented
 
     Value input = convOp.inputs().front();
@@ -1071,15 +1087,21 @@
         rewriter, loc, conv1DOp.getResult(0), output);
     rewriter.replaceOp(convOp, inserted);
 
+    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
     return success();
   };
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
 };
 
 } // namespace
 
-void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
-                                                  PatternBenefit benefit) {
+void linalg::populateDecomposeConvolutionPatterns(
+    RewritePatternSet &patterns, LinalgTransformationFilter filter,
+    PatternBenefit benefit) {
   patterns.add<DownscaleSizeOneWindowed2DConvolution,
-               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
+               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
                                                   benefit);
 }