[mlir][linalg][bufferize][NFC] Move FuncOp boundary bufferization to ModuleBufferization

Differential Revision: https://reviews.llvm.org/D114670

GitOrigin-RevId: 5e1c038f7da50f09e4f39a66d5e10aa06f9546c3
diff --git a/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 1a3ce68..b887af0 100644
--- a/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -16,10 +16,9 @@
 // Composability with extensible set of ops is not a first-class concern.
 //
 // Bufferization occurs by:
-//  a. performing an inPlace analysis `inPlaceAnalysisFuncOpBody`
-//     which marks each operation within the function with the
-//     `kInPlaceResultsAttrName` attribute.
-//  b. traversing each operation in the function and rewriting it in
+//  a. performing an inPlace analysis `inPlaceAnalysis` which marks each
+//     operation within the op with the `kInPlaceResultsAttrName` attribute.
+//  b. traversing each operation in the op and rewriting it in
 //     buffer form and keeping a BlockAndValueMapping mapping of the
 //     rewrites. New allocations are introduced during this step.
 //     TODO: Allocation + depending op hoisting to outermost enclosing
@@ -544,37 +543,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Bufferization as simple BlockAndValueMapping rewrites.
-//===----------------------------------------------------------------------===//
-
-/// FuncOp always creates TensorToMemRef ops.
-static LogicalResult bufferizeFuncOp(FuncOp funcOp, BufferizationState &state) {
-  // Take a guard before anything else.
-  OpBuilder b(funcOp->getContext());
-  b.setInsertionPointToStart(&funcOp.body().front());
-
-  // Create BufferCastOps for function args.
-  for (auto bbArg : funcOp.getArguments()) {
-    auto tensorType = bbArg.getType().dyn_cast<TensorType>();
-    if (!tensorType)
-      continue;
-    auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
-    // Cast the tensor to the most dynamic buffer possible. Further
-    // canonicalizations will clean up.
-    Type memRefType = rankedTensorType
-                          ? getDynamicMemRefType(rankedTensorType)
-                          : getContiguousOrUnrankedMemRefType(tensorType);
-    Value bufferCast =
-        b.create<bufferization::ToMemrefOp>(funcOp.getLoc(), memRefType, bbArg);
-    state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
-    state.mapBuffer(bbArg, bufferCast);
-  }
-
-  // Bufferize function body.
-  return bufferize(&funcOp.body(), state);
-}
-
-//===----------------------------------------------------------------------===//
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
 
@@ -654,42 +622,6 @@
   return success();
 }
 
-/// Analyze the `funcOp` body to determine which OpResults are inplaceable.
-static LogicalResult
-inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
-                          const DominanceInfo &domInfo,
-                          unsigned analysisFuzzerSeed = 0) {
-  LLVM_DEBUG(llvm::dbgs() << "\n\n");
-  LDBG("Begin InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
-  assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
-         "expected a funcOp definition with a body");
-
-  // Collect ops so we can build our own reverse traversal.
-  SmallVector<Operation *> ops;
-  funcOp.walk([&](Operation *op) {
-    // No tensors => no buffers.
-    if (none_of(op->getOperandTypes(), isaTensor) &&
-        none_of(op->getResultTypes(), isaTensor))
-      return;
-    ops.push_back(op);
-  });
-
-  // Set the function arguments marked with inplaceable to be known as
-  // bufferizing to a writeable memory.
-  for (BlockArgument bbArg : funcOp.getArguments()) {
-    BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
-        bbArg.getArgNumber(), BufferizableOpInterface::kInplaceableAttrName);
-    if (inplaceAttr && inplaceAttr.getValue())
-      aliasInfo.setBufferizesToWritableMemory(bbArg);
-  }
-
-  LogicalResult res =
-      inPlaceAnalysis(ops, aliasInfo, domInfo, analysisFuzzerSeed);
-  LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
-
-  return res;
-}
-
 /// Assert that the current bufferization decisions are consistent.
 static LogicalResult
 checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo,
@@ -753,9 +685,19 @@
   if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo)))
     return failure();
 
+  // Collect ops so we can build our own reverse traversal.
+  SmallVector<Operation *> ops;
+  funcOp.walk([&](Operation *op) {
+    // No tensors => no buffers.
+    if (none_of(op->getOperandTypes(), isaTensor) &&
+        none_of(op->getResultTypes(), isaTensor))
+      return;
+    ops.push_back(op);
+  });
+
   // If the analysis fails, just return.
-  if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
-                                       options.analysisFuzzerSeed)))
+  if (failed(
+          inPlaceAnalysis(ops, aliasInfo, domInfo, options.analysisFuzzerSeed)))
     return failure();
 
   for (const std::unique_ptr<PostAnalysisStep> &step :
@@ -775,7 +717,11 @@
   }
 
   // Bufferize all ops in funcOp.
-  if (failed(bufferizeFuncOp(funcOp, state)))
+  OpBuilder b(funcOp.getContext());
+  auto bufferizableOp =
+      dyn_cast<BufferizableOpInterface>(funcOp.getOperation());
+  assert(bufferizableOp && "must use ModuleBufferization");
+  if (failed(bufferizableOp.bufferize(b, state)))
     return failure();
 
   // Erase all obsolete ops.
diff --git a/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 5770651..9ed8055 100644
--- a/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -629,6 +629,40 @@
   }
 };
 
+struct FuncOpInterface
+    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto funcOp = cast<FuncOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPointToStart(&funcOp.body().front());
+
+    // Create BufferCastOps for function args.
+    for (auto bbArg : funcOp.getArguments()) {
+      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+      if (!tensorType)
+        continue;
+      auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
+      // Cast the tensor to the most dynamic buffer possible. Further
+      // canonicalizations will clean up.
+      Type memRefType = rankedTensorType
+                            ? getDynamicMemRefType(rankedTensorType)
+                            : getContiguousOrUnrankedMemRefType(tensorType);
+      Value bufferCast = b.create<bufferization::ToMemrefOp>(funcOp.getLoc(),
+                                                             memRefType, bbArg);
+      state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
+      state.mapBuffer(bbArg, bufferCast);
+    }
+
+    // Bufferize function body.
+    return comprehensive_bufferize::bufferize(&funcOp.body(), state);
+  }
+
+  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
+};
+
 } // namespace std_ext
 } // namespace comprehensive_bufferize
 } // namespace linalg
@@ -638,7 +672,7 @@
     registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
   registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
   registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
-  registry.addOpInterface<FuncOp, AllocationHoistingBarrierOnly<FuncOp>>();
+  registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
@@ -671,6 +705,15 @@
         if (bbArg.getType().isa<TensorType>())
           aliasInfo.setBufferizesToWritableMemory(bbArg);
 
+    // Set the function arguments marked with inplaceable to be known as
+    // bufferizing to a writeable memory.
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
+          bbArg.getArgNumber(), BufferizableOpInterface::kInplaceableAttrName);
+      if (inplaceAttr && inplaceAttr.getValue())
+        aliasInfo.setBufferizesToWritableMemory(bbArg);
+    }
+
     // Analyze and bufferize funcOp.
     if (failed(runComprehensiveBufferize(funcOp, options, state)))
       return failure();