[mlir][linalg][bufferize] Op interface implementation for Bufferization dialect ops

This change provides `BufferizableOpInterface` implementations for ops from the Bufferization dialect. These ops are needed at the bufferization boundaries for partial bufferization.
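
At the boundary, not-yet-bufferized code consumes tensors produced by
`bufferization.to_tensor`, and already-bufferized code consumes memrefs
produced by `bufferization.to_memref`. A minimal sketch of such a boundary
(function name and types are hypothetical, for illustration only):

    func @boundary_example(%m: memref<?xf32>) -> memref<?xf32> {
      // Enter tensor land: the analysis treats this tensor as not writable.
      %t = bufferization.to_tensor %m : memref<?xf32>
      // ... tensor ops to be bufferized by this pass ...
      // Leave tensor land: the operand is conservatively treated as read and
      // written, since it is unknown what happens to the resulting memref.
      %m2 = bufferization.to_memref %t : memref<?xf32>
      return %m2 : memref<?xf32>
    }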

Differential Revision: https://reviews.llvm.org/D114618

GitOrigin-RevId: d30fcadf07ee552f20156ea90be2fdb54cb9cb08
diff --git a/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h b/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
new file mode 100644
index 0000000..23c17f4
--- /dev/null
+++ b/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
@@ -0,0 +1,27 @@
+//===- BufferizationInterfaceImpl.h - Bufferization Impl. of Op Interface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace bufferization_ext {
+
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace bufferization_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
diff --git a/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 4d7c445..e0c5a10 100644
--- a/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -416,10 +416,6 @@
                                                  BufferizationState &state) {
   OpBuilder b(op->getContext());
 
-  // Skip ToMemrefOp and ToTensorOp.
-  if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
-    return success();
-
   // Check if op has tensor results or operands.
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
diff --git a/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
new file mode 100644
index 0000000..c8a2649
--- /dev/null
+++ b/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -0,0 +1,101 @@
+//===- BufferizationInterfaceImpl.cpp - Bufferization Impl. of Interface --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace linalg;
+using namespace comprehensive_bufferize;
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace bufferization_ext {
+
+// TODO: These ops should implement BufferizableOpInterface directly when moved
+// to the Bufferization dialect.
+
+// TODO: These implementations are conservative and will likely have to be
+// loosened for partial bufferization.
+
+/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
+/// location of the incoming tensor once it has been bufferized. In the
+/// analysis, the incoming tensor is assumed to bufferize to a memory read and
+/// to an inplace memory write, since it is unknown what will happen to the
+/// resulting memref.
+struct ToMemrefOpInterface
+    : public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
+                                                    bufferization::ToMemrefOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    // It is unknown whether the resulting MemRef will be read or not.
+    return true;
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {};
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    return success();
+  }
+};
+
+/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do
+/// not lower any further, and they should have disappeared by the time the
+/// input is fully bufferized.
+///
+/// The analysis has no information about the memref that the ToTensorOp loads
+/// from. We have to assume that, after bufferization, the loaded tensor may
+/// alias with any other bufferized tensor. Since ToTensorOp and ToMemrefOp
+/// have no aliasing OpOperand/OpResult pairs, this cannot be encoded directly
+/// in the analysis. However, declaring ToTensorOp results as not writable also
+/// enforces a buffer copy and has the same effect.
+struct ToTensorOpInterface
+    : public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
+                                                    bufferization::ToTensorOp> {
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto tensorLoadOp = cast<bufferization::ToTensorOp>(op);
+    state.mapBuffer(tensorLoadOp.result(), tensorLoadOp.memref());
+    return success();
+  }
+
+  bool isWritable(Operation *op, Value value) const {
+    // It is unknown whether the MemRef operand is writable or not.
+    return false;
+  }
+};
+
+} // namespace bufferization_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+void mlir::linalg::comprehensive_bufferize::bufferization_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<bufferization::ToMemrefOp,
+                          bufferization_ext::ToMemrefOpInterface>();
+  registry.addOpInterface<bufferization::ToTensorOp,
+                          bufferization_ext::ToTensorOpInterface>();
+}
diff --git a/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index 68d5d03..f033196 100644
--- a/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -2,6 +2,7 @@
   AffineInterfaceImpl.cpp
   ArithInterfaceImpl.cpp
   BufferizableOpInterface.cpp
+  BufferizationInterfaceImpl.cpp
   ComprehensiveBufferize.cpp
   LinalgInterfaceImpl.cpp
   ModuleBufferization.cpp
@@ -80,6 +81,7 @@
 )
 
 add_mlir_dialect_library(MLIRComprehensiveBufferize
+  BufferizationInterfaceImpl.cpp
   ComprehensiveBufferize.cpp
   ModuleBufferization.cpp
 
diff --git a/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 6cbc308..d4571d3a 100644
--- a/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -239,6 +239,12 @@
 /// Return true if opOperand has been decided to bufferize in-place.
 static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                  const BufferizationAliasInfo &aliasInfo) {
+  // The analysis does not know what happens to the result of a ToMemrefOp, so
+  // we assume that it is written to.
+  // TODO: This is a conservative implementation. This rule will have to be
+  // relaxed for partial bufferization.
+  if (isa<bufferization::ToMemrefOp>(opOperand.getOwner()))
+    return true;
   // OpOperands without an aliasing OpResult do not write.
   OpResult opResult = getAliasingOpResult(opOperand);
   if (!opResult)
@@ -453,14 +459,23 @@
 /// If `checkConsistencyOnly` is true, this function checks if there is a
 /// read-after-write conflict without bufferizing `operand` inplace. This would
 /// indicate a problem with the current inplace bufferization decisions.
+///
+/// Note: If `checkConsistencyOnly` is true, this function may be called with a
+/// null OpResult. In that case, only the consistency of bufferization decisions
+/// involving aliases of the given OpOperand is checked.
 bool wouldCreateReadAfterWriteInterference(
     OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
     const BufferizationAliasInfo &aliasInfo,
     bool checkConsistencyOnly = false) {
 #ifndef NDEBUG
-  SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
-  assert(llvm::find(opOperands, &operand) != opOperands.end() &&
-         "operand and result do not match");
+  if (result) {
+    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
+    assert(llvm::find(opOperands, &operand) != opOperands.end() &&
+           "operand and result do not match");
+  } else {
+    assert(checkConsistencyOnly &&
+           "result not provided, can only check consistency");
+  }
 #endif // NDEBUG
 
   // Helper function to iterate on aliases of `root` and capture the reads.
@@ -486,9 +501,11 @@
   // Collect reads and writes of all aliases of OpOperand and OpResult.
   DenseSet<OpOperand *> usesRead, usesWrite;
   getAliasingReads(usesRead, operand.get());
-  getAliasingReads(usesRead, result);
+  if (result)
+    getAliasingReads(usesRead, result);
   getAliasingInplaceWrites(usesWrite, operand.get());
-  getAliasingInplaceWrites(usesWrite, result);
+  if (result)
+    getAliasingInplaceWrites(usesWrite, result);
   if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
     usesWrite.insert(&operand);
 
@@ -673,25 +690,38 @@
   return res;
 }
 
-#ifndef NDEBUG
 /// Assert that the current bufferization decisions are consistent.
-static void checkAliasInfoConsistency(FuncOp funcOp,
-                                      const DominanceInfo &domInfo,
-                                      const BufferizationAliasInfo &aliasInfo) {
-  funcOp.walk([&](Operation *op) {
+static LogicalResult
+checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo,
+                          const BufferizationAliasInfo &aliasInfo) {
+  Operation *inconsistentOp = nullptr;
+  WalkResult walkResult = funcOp.walk([&](Operation *op) {
     if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
       for (OpOperand &opOperand : op->getOpOperands())
-        if (opOperand.get().getType().isa<TensorType>())
-          if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
-            // If this assertion fails, there is probably an inconsistent
-            // combination of "mustBufferizeInPlace" decisions.
-            assert(!wouldCreateReadAfterWriteInterference(
-                       opOperand, opResult, domInfo, aliasInfo,
-                       /*checkConsistencyOnly=*/true) &&
-                   "found read after write conflict before running analysis");
+        if (opOperand.get().getType().isa<TensorType>()) {
+          OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand);
+          if (wouldCreateReadAfterWriteInterference(
+                  opOperand, opResult, domInfo, aliasInfo,
+                  /*checkConsistencyOnly=*/true)) {
+            // This error can happen for two reasons: either the input IR
+            // already has a read-after-write conflict, or certain
+            // "mustBufferizeInPlace" interface methods are implemented
+            // incorrectly.
+            inconsistentOp = op;
+            return WalkResult::interrupt();
+          }
+        }
+    return WalkResult::advance();
   });
+
+  if (walkResult.wasInterrupted())
+    // This can currently happen in one situation: a tensor is passed into a
+    // ToMemrefOp and subsequently read by another op. ToMemrefOps are
+    // currently handled conservatively: once a tensor is passed into a
+    // ToMemrefOp, it may no longer be read.
+    return inconsistentOp->emitError("input IR has RaW conflict");
+  return success();
 }
-#endif
 
 /// Annotate the IR with the result of the analysis. For testing/debugging only.
 static void
@@ -720,9 +750,8 @@
   if (funcOp.body().empty())
     return success();
 
-#ifndef NDEBUG
-  checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
-#endif // NDEBUG
+  if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo)))
+    return failure();
 
   // If the analysis fails, just return.
   if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
diff --git a/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index ca713d1..1910298 100644
--- a/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
@@ -47,6 +48,7 @@
                 arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
     affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
     arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
     scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
     std_ext::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 616d845..2e2792b 100644
--- a/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1492,3 +1492,44 @@
   %0 = call @some_use(%A, %v) : (tensor<?xf32>, vector<5xf32>) -> (tensor<?xf32>)
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @to_tensor_op_not_writable
+func @to_tensor_op_not_writable(%m: memref<?xf32>, %v:  vector<5xf32>,
+                                %idx1: index, %idx2: index)
+    -> vector<10xf32> {
+  %0 = bufferization.to_tensor %m : memref<?xf32>
+
+  // Write to the tensor. Cannot be inplace because to_tensor results are not writable.
+  //      CHECK: vector.transfer_write
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %w = vector.transfer_write %v, %0[%idx1] : vector<5xf32>, tensor<?xf32>
+
+  // Read from the tensor and return result.
+  %cst = arith.constant 0.0 : f32
+  %r = vector.transfer_read %w[%idx2], %cst : tensor<?xf32>, vector<10xf32>
+  return %r : vector<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @to_memref_op_is_reading
+func @to_memref_op_is_reading(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                              %idx1: index, %idx2: index, %idx3: index,
+                              %v1: vector<5xf32>)
+    -> (vector<5xf32>, vector<5xf32>) {
+  // Write + read to/from tensor.
+  //      CHECK: vector.transfer_write
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %1 = vector.transfer_write %v1, %t1[%idx2] : vector<5xf32>, tensor<?xf32>
+  %cst = arith.constant 0.0 : f32
+  %r1 = vector.transfer_read %1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
+
+  // Write + read to/from same memref.
+  %0 = bufferization.to_memref %t1 : memref<?xf32>
+  vector.transfer_write %v1, %0[%idx1] : vector<5xf32>, memref<?xf32>
+  %r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
+
+  return %r1, %r2 : vector<5xf32>, vector<5xf32>
+}
diff --git a/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index c0a91e3..edeb0c0 100644
--- a/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -167,3 +167,23 @@
   }
   return %r: tensor<4xi32>
 }
+
+// -----
+
+func @to_memref_op_is_writing(
+    %t1: tensor<?xf32> {linalg.inplaceable = true}, %idx1: index,
+    %idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>, vector<5xf32>) {
+  // This is a RaW conflict because to_memref is an inplace write and %t1 is
+  // read further down. This will likely have to change with partial
+  // bufferization.
+
+  // expected-error @+1 {{input IR has RaW conflict}}
+  %0 = bufferization.to_memref %t1 : memref<?xf32>
+
+  // Read from both.
+  %cst = arith.constant 0.0 : f32
+  %r1 = vector.transfer_read %t1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
+  %r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
+
+  return %r1, %r2 : vector<5xf32>, vector<5xf32>
+}