[mlir][linalg][bufferize] Compose dialect-specific bufferization state

Use composition instead of inheritance for storing dialect-specific bufferization state. This is in preparation of adding "tensor dialect"-specific bufferization state.

Differential Revision: https://reviews.llvm.org/D114508
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index f52a9aa..e03aaea 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -230,6 +230,13 @@
   MemCpyFn memCpyFn;
 };
 
+/// Dialect-specific bufferization state. Analysis/bufferization information
+/// that is specific to ops from a certain dialect can be stored in derived
+/// variants of this struct.
+struct DialectBufferizationState {
+  virtual ~DialectBufferizationState() = default;
+};
+
 /// BufferizationState keeps track of bufferization state and provides access to
 /// the results of the analysis.
 struct BufferizationState {
@@ -271,6 +278,14 @@
   /// Erase all ops that were marked obsolete.
   void eraseObsoleteOps();
 
+  /// Return dialect-specific bufferization state.
+  template <typename StateT> StateT &getDialectState(StringRef name) {
+    // Create state if it does not exist yet.
+    if (!dialectState.count(name))
+      dialectState[name] = std::make_unique<StateT>();
+    return static_cast<StateT &>(*dialectState[name]);
+  }
+
   /// `aliasInfo` keeps track of aliasing and equivalent values.
   BufferizationAliasInfo aliasInfo;
 
@@ -284,6 +299,9 @@
 
   /// Obsolete ops that should be deleted after bufferization.
   SmallVector<Operation *> obsoleteOps;
+
+  /// Dialect-specific bufferization state.
+  DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
 };
 
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index e354195..c98cc1d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -27,11 +27,9 @@
 using namespace comprehensive_bufferize;
 
 namespace {
-/// A specialization of BufferizationState that keeps track of additional
-/// state required for bufferization of function boundaries.
-struct ModuleBufferizationState : public BufferizationState {
-  using BufferizationState::BufferizationState;
-
+/// Extra bufferization state that is required for bufferization of function
+/// boundaries.
+struct ModuleBufferizationState : public DialectBufferizationState {
   /// A map for looking up bufferized function types.
   DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
 
@@ -40,6 +38,12 @@
 };
 } // namespace
 
+static ModuleBufferizationState &
+getModuleBufferizationState(BufferizationState &state) {
+  return state.getDialectState<ModuleBufferizationState>(
+      StandardOpsDialect::getDialectNamespace());
+}
+
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
 
 /// If `value` is a memref::CastOp, return its source. Otherwise, return
@@ -127,7 +131,9 @@
 /// Store function BlockArguments that are equivalent to a returned value in
 /// the given ModuleBufferizationState.
 static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
-                                           ModuleBufferizationState &state) {
+                                           BufferizationState &state) {
+  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+
   // Support only single return-terminated block in the function.
   ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
   assert(returnOp && "expected func with single return op");
@@ -137,7 +143,7 @@
       for (BlockArgument bbArg : funcOp.getArguments())
         if (bbArg.getType().isa<RankedTensorType>())
           if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
-            state.equivalentReturnValToBBArg[returnVal] = bbArg;
+            moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
 }
 
 /// Rewrite the `funcOp` arguments analysis return values and terminator into
@@ -155,8 +161,9 @@
 /// originate from an op with an Alloc effect, they could be hoisted in the
 /// future.
 static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
-                                             ModuleBufferizationState &state) {
+                                             BufferizationState &state) {
   LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
+  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // If nothing to do then we are done.
@@ -188,7 +195,7 @@
                                  << "returns a tensor";
     FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
         funcOp, funcOp.getType().getInputs(), TypeRange{},
-        state.bufferizedFunctionTypes);
+        moduleState.bufferizedFunctionTypes);
     funcOp.setType(bufferizedFuncType);
     LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
     return success();
@@ -210,7 +217,7 @@
     }
 
     // If return operand is equivalent to some bbArg, no need to return it.
-    if (state.equivalentReturnValToBBArg.count(returnVal))
+    if (moduleState.equivalentReturnValToBBArg.count(returnVal))
       continue;
 
     // Cast values at the call site if necessary.
@@ -221,7 +228,7 @@
   ValueRange retValues{returnValues};
   FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
       funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
-      state.bufferizedFunctionTypes);
+      moduleState.bufferizedFunctionTypes);
   OpBuilder b(returnOp);
   b.create<ReturnOp>(returnOp.getLoc(), returnValues);
   returnOp->erase();
@@ -474,7 +481,7 @@
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
            "expected Callop to a FuncOp");
-    auto &moduleState = static_cast<ModuleBufferizationState &>(state);
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
@@ -649,7 +656,7 @@
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return failure();
 
-  ModuleBufferizationState state(moduleOp, *options.allocationFns);
+  BufferizationState state(moduleOp, *options.allocationFns);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // Interestingly, all function args that are not visible outside of a module