[mlir] Add base class type aliases for rewrites/conversions. NFC. (#158433)
This is to simplify writing rewrite/conversion patterns that usually
start with:
```c++
struct MyPattern : public OpRewritePattern<MyOp> {
using OpRewritePattern::OpRewritePattern;
```
and allow for:
```c++
struct MyPattern : public OpRewritePattern<MyOp> {
using Base::Base;
```
similar to how we enable it for pass classes.
GitOrigin-RevId: c88f3c582dc2ef5f2fdfd0c5887f5f7562f49095
diff --git a/include/mlir/IR/PatternMatch.h b/include/mlir/IR/PatternMatch.h
index 7b0b9ce..576481a 100644
--- a/include/mlir/IR/PatternMatch.h
+++ b/include/mlir/IR/PatternMatch.h
@@ -312,6 +312,9 @@
template <typename SourceOp>
struct OpRewritePattern
: public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpRewritePattern;
/// Patterns must specify the root operation name they match against, and can
/// also specify the benefit of the pattern matching and a list of generated
@@ -328,6 +331,9 @@
template <typename SourceOp>
struct OpInterfaceRewritePattern
: public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpInterfaceRewritePattern;
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
@@ -341,6 +347,10 @@
template <template <typename> class TraitType>
class OpTraitRewritePattern : public RewritePattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpTraitRewritePattern;
+
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
benefit, context) {}
diff --git a/include/mlir/Transforms/DialectConversion.h b/include/mlir/Transforms/DialectConversion.h
index bfbe12d..6ef649e 100644
--- a/include/mlir/Transforms/DialectConversion.h
+++ b/include/mlir/Transforms/DialectConversion.h
@@ -40,6 +40,10 @@
/// registered using addConversion and addMaterialization, respectively.
class TypeConverter {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = TypeConverter;
+
virtual ~TypeConverter() = default;
TypeConverter() = default;
// Copy the registered conversions, but not the caches
@@ -679,6 +683,10 @@
template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpConversionPattern;
+
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
@@ -729,6 +737,10 @@
template <typename SourceOp>
class OpInterfaceConversionPattern : public ConversionPattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpInterfaceConversionPattern;
+
OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
SourceOp::getInterfaceID(), benefit, context) {}
@@ -773,6 +785,10 @@
template <template <typename> class TraitType>
class OpTraitConversionPattern : public ConversionPattern {
public:
+ /// Type alias to allow derived classes to inherit constructors with
+ /// `using Base::Base;`.
+ using Base = OpTraitConversionPattern;
+
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(Pattern::MatchTraitOpTypeTag(),
TypeID::get<TraitType>(), benefit, context) {}
diff --git a/test/lib/Dialect/Test/TestPatterns.cpp b/test/lib/Dialect/Test/TestPatterns.cpp
index 93b007c..f8b5144 100644
--- a/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/test/lib/Dialect/Test/TestPatterns.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -114,7 +115,8 @@
struct FolderInsertBeforePreviouslyFoldedConstantPattern
: public OpRewritePattern<TestCastOp> {
public:
- using OpRewritePattern<TestCastOp>::OpRewritePattern;
+ static_assert(std::is_same_v<Base, OpRewritePattern<TestCastOp>>);
+ using Base::Base;
LogicalResult matchAndRewrite(TestCastOp op,
PatternRewriter &rewriter) const override {
@@ -1306,7 +1308,8 @@
/// b) or: drops all block arguments and replaces each with 2x the first
/// operand.
class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
- using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
+ static_assert(std::is_same_v<Base, OpConversionPattern<ConvertBlockArgsOp>>);
+ using Base::Base;
LogicalResult
matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
@@ -1431,7 +1434,9 @@
namespace {
struct TestTypeConverter : public TypeConverter {
- using TypeConverter::TypeConverter;
+ static_assert(std::is_same_v<Base, TypeConverter>);
+ using Base::Base;
+
TestTypeConverter() {
addConversion(convertType);
addSourceMaterialization(materializeCast);