[mlir][linalg] Add switch to disable/enable vector transfer lowering.
Add a switch to code gen strategy to disable/enable the vector transfer lowering and disable it by default.
Differential Revision: https://reviews.llvm.org/D111647
GitOrigin-RevId: c8faeb1edd8447fb67ed7ef04158a07582aa8771
diff --git a/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index 251b850..7454c91 100644
--- a/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -228,6 +228,10 @@
this->lateCodegenStrategyOptions.maxTransferRank = val;
return *this;
}
+ CodegenStrategy &setEnableVectorTransferLowering(bool val) {
+ this->lateCodegenStrategyOptions.enableVectorTransferLowering = val;
+ return *this;
+ }
CodegenStrategy &setEnableVectorTransferPartialRewrite(bool val) {
this->lateCodegenStrategyOptions.enableVectorTransferPartialRewrite = val;
return *this;
diff --git a/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6d8422f..c29e364 100644
--- a/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -836,6 +836,7 @@
/// Vector lowering operations may result in surprising behavior when
/// composing multiple codegen strategies and must be enabled explicitly.
int64_t maxTransferRank = 1;
+ bool enableVectorTransferLowering = false;
bool enableVectorTransferPartialRewrite = false;
bool enableVectorContractLowering = false;
bool enableVectorToSCFConversion = false;
@@ -854,6 +855,7 @@
/// form.
struct LinalgVectorLoweringOptions {
int64_t maxTransferRank = 1;
+ bool enableVectorTransferLowering = false;
bool enableVectorTransferPartialRewrite = false;
bool enableVectorContractLowering = false;
bool enableVectorToSCFConversion = false;
diff --git a/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index 04bdc49..1770cd9 100644
--- a/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -49,6 +49,8 @@
LinalgVectorLoweringOptions vectorLoweringOptions;
vectorLoweringOptions.maxTransferRank =
lateCodegenStrategyOptions.maxTransferRank;
+ vectorLoweringOptions.enableVectorTransferLowering =
+ lateCodegenStrategyOptions.enableVectorTransferLowering;
vectorLoweringOptions.enableVectorTransferPartialRewrite =
lateCodegenStrategyOptions.enableVectorTransferPartialRewrite;
vectorLoweringOptions.enableVectorContractLowering =
diff --git a/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index bb1793d..617ea0a 100644
--- a/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -260,8 +260,10 @@
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
- vector::populateVectorTransferLoweringPatterns(patterns,
- options.maxTransferRank);
+ if (options.enableVectorTransferLowering) {
+ vector::populateVectorTransferLoweringPatterns(patterns,
+ options.maxTransferRank);
+ }
if (options.enableVectorTransferPartialRewrite) {
patterns.add<vector::VectorTransferFullPartialRewriter>(
context, options.vectorTransformOptions);