[MLIR] Cache symbol tables during OneShotBufferization analyses (#138125)

During bufferization, the callee of each `func::CallOp` / `CallableOpInterface` operation is retrieved by means of a symbol table that is temporarily built for the lookup purpose. The creation of the symbol table requires a linear scan of the operation body (e.g., a linear scan of the `ModuleOp` body). Considering that functions are typically called at least once, this leads to a scaling behavior that is quadratic with respect to the number of symbols. The problem is described in the following Discourse topic: https://discourse.llvm.org/t/quadratic-scaling-of-bufferization/86122/
This patch aims to partially address this scaling issue by leveraging the `SymbolTableCollection` class, whose instance is added to the `FuncAnalysisState` extension. Later modifications are also expected to address the problem in other methods required by `BufferizableOpInterface` (e.g., `bufferize` and `getBufferType`), which suffer of the same problem but do not provide access to any bufferization state.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226..51f3c08 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -69,6 +69,9 @@
   /// analyzed.
   DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
 
+  /// A collection of cached SymbolTables used for faster function lookup.
+  mutable SymbolTableCollection symbolTables;
+
   /// This function is called right before analyzing the given FuncOp. It
   /// initializes the data structures for the FuncOp in this state object.
   void startFunctionAnalysis(FuncOp funcOp);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 0b0dcc9..3f76a44 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -76,13 +76,29 @@
 }
 
 /// Return the FuncOp called by `callOp`.
-static FuncOp getCalledFunction(CallOpInterface callOp) {
+static FuncOp getCalledFunction(CallOpInterface callOp,
+                                SymbolTableCollection &symbolTables) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+      symbolTables.lookupNearestSymbolFrom(callOp, sym));
+}
+
+/// Return the FuncOp called by `callOp`.
+static FuncOp getCalledFunction(CallOpInterface callOp,
+                                const AnalysisState &state) {
+  auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);
+
+  if (auto *funcAnalysisState =
+          oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
+    // Use the cached symbol tables.
+    return getCalledFunction(callOp, funcAnalysisState->symbolTables);
+  }
+
+  SymbolTableCollection symbolTables;
+  return getCalledFunction(callOp, symbolTables);
 }
 
 /// Get FuncAnalysisState.
@@ -135,7 +151,7 @@
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    FuncOp funcOp = getCalledFunction(callOp, state);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
@@ -150,7 +166,7 @@
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    FuncOp funcOp = getCalledFunction(callOp, state);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
@@ -165,7 +181,7 @@
   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                       const AnalysisState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+    FuncOp funcOp = getCalledFunction(callOp, state);
     assert(funcOp && "expected CallOp to a FuncOp");
     if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
       // FuncOp not analyzed yet. Any OpResult may be aliasing.
@@ -199,7 +215,11 @@
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto callOp = cast<func::CallOp>(op);
-    FuncOp funcOp = getCalledFunction(callOp);
+
+    // TODO Avoid recomputing the symbol tables every time.
+    SymbolTableCollection symbolTable;
+
+    FuncOp funcOp = getCalledFunction(callOp, symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
 
     // If the callee was already bufferized, we can directly take the type from
@@ -243,7 +263,11 @@
     // 2. Rewrite tensor operands as memrefs based on type of the already
     //    bufferized callee.
     SmallVector<Value> newOperands;
-    FuncOp funcOp = getCalledFunction(callOp);
+
+    // TODO Avoid recomputing the symbol tables every time.
+    SymbolTableCollection symbolTable;
+
+    FuncOp funcOp = getCalledFunction(callOp, symbolTable);
     assert(funcOp && "expected CallOp to a FuncOp");
     FunctionType funcType = funcOp.getFunctionType();
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index edd6bcf..a025da8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -280,13 +280,15 @@
 }
 
 /// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static func::FuncOp
+getCalledFunction(func::CallOp callOp,
+                  mlir::SymbolTableCollection &symbolTable) {
   SymbolRefAttr sym =
       llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<func::FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+      symbolTable.lookupNearestSymbolFrom(callOp, sym));
 }
 
 /// Return "true" if the given function signature has tensor semantics.
@@ -314,11 +316,15 @@
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+
+  // TODO Avoid recomputing the symbol tables every time.
+  mlir::SymbolTableCollection symbolTable;
+
   for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-      func::FuncOp calledFunction = getCalledFunction(callOp);
+      func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
       // it is not necessary to bufferize the callee before the caller.