[MLIR][Shape] Lower `size_to_index` and `index_to_size` with declarative rules
Replace implemented rewrite patterns with equivalent declarative rules.
Differential Revision: https://reviews.llvm.org/D82023
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 6cce898..e774114 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -38,32 +38,6 @@
}
};
-class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> {
-public:
- using OpConversionPattern<IndexToSizeOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(IndexToSizeOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- IndexToSizeOp::Adaptor transformed(operands);
- rewriter.replaceOp(op.getOperation(), transformed.arg());
- return success();
- }
-};
-
-class SizeToIndexOpConversion : public OpConversionPattern<SizeToIndexOp> {
-public:
- using OpConversionPattern<SizeToIndexOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(SizeToIndexOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- SizeToIndexOp::Adaptor transformed(operands);
- rewriter.replaceOp(op.getOperation(), transformed.arg());
- return success();
- }
-};
-
class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
public:
using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
@@ -132,9 +106,7 @@
patterns.insert<
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
- ConstSizeOpConverter,
- IndexToSizeOpConversion,
- SizeToIndexOpConversion>(ctx);
+ ConstSizeOpConverter>(ctx);
// clang-format on
}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
index 3ad5421..a133548 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -10,3 +10,12 @@
(Shape_ToExtentTensorOp $input),
(replaceWithValue $input)>;
+// Convert `index_to_size` and `size_to_index` to no-ops as sizes will be
+// represented as indices.
+def IndexToSizeOpConversion : Pat<
+ (Shape_IndexToSizeOp $arg),
+ (replaceWithValue $arg)>;
+def SizeToIndexOpConversion : Pat<
+ (Shape_SizeToIndexOp $arg),
+ (replaceWithValue $arg)>;
+