[MLIR] Cache symbol tables during OneShotBufferization analyses (#138125)
During bufferization, the callee of each `func::CallOp` / `CallableOpInterface` operation is retrieved by means of a symbol table that is temporarily built for the lookup purpose. The creation of the symbol table requires a linear scan of the operation body (e.g., a linear scan of the `ModuleOp` body). Considering that functions are typically called at least once, this leads to a scaling behavior that is quadratic with respect to the number of symbols. The problem is described in the following Discourse topic: https://discourse.llvm.org/t/quadratic-scaling-of-bufferization/86122/
This patch aims to partially address this scaling issue by leveraging the `SymbolTableCollection` class, whose instance is added to the `FuncAnalysisState` extension. Later modifications are also expected to address the problem in other methods required by `BufferizableOpInterface` (e.g., `bufferize` and `getBufferType`), which suffer of the same problem but do not provide access to any bufferization state.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index e8e6226..51f3c08 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -69,6 +69,9 @@
/// analyzed.
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+ /// A collection of cached SymbolTables used for faster function lookup.
+ mutable SymbolTableCollection symbolTables;
+
/// This function is called right before analyzing the given FuncOp. It
/// initializes the data structures for the FuncOp in this state object.
void startFunctionAnalysis(FuncOp funcOp);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 0b0dcc9..3f76a44 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -76,13 +76,29 @@
}
/// Return the FuncOp called by `callOp`.
-static FuncOp getCalledFunction(CallOpInterface callOp) {
+static FuncOp getCalledFunction(CallOpInterface callOp,
+ SymbolTableCollection &symbolTables) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ symbolTables.lookupNearestSymbolFrom(callOp, sym));
+}
+
+/// Return the FuncOp called by `callOp`.
+static FuncOp getCalledFunction(CallOpInterface callOp,
+ const AnalysisState &state) {
+ auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);
+
+ if (auto *funcAnalysisState =
+ oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
+ // Use the cached symbol tables.
+ return getCalledFunction(callOp, funcAnalysisState->symbolTables);
+ }
+
+ SymbolTableCollection symbolTables;
+ return getCalledFunction(callOp, symbolTables);
}
/// Get FuncAnalysisState.
@@ -135,7 +151,7 @@
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ FuncOp funcOp = getCalledFunction(callOp, state);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
@@ -150,7 +166,7 @@
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ FuncOp funcOp = getCalledFunction(callOp, state);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
@@ -165,7 +181,7 @@
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+ FuncOp funcOp = getCalledFunction(callOp, state);
assert(funcOp && "expected CallOp to a FuncOp");
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Any OpResult may be aliasing.
@@ -199,7 +215,11 @@
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto callOp = cast<func::CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
+
+ // TODO Avoid recomputing the symbol tables every time.
+ SymbolTableCollection symbolTable;
+
+ FuncOp funcOp = getCalledFunction(callOp, symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
// If the callee was already bufferized, we can directly take the type from
@@ -243,7 +263,11 @@
// 2. Rewrite tensor operands as memrefs based on type of the already
// bufferized callee.
SmallVector<Value> newOperands;
- FuncOp funcOp = getCalledFunction(callOp);
+
+ // TODO Avoid recomputing the symbol tables every time.
+ SymbolTableCollection symbolTable;
+
+ FuncOp funcOp = getCalledFunction(callOp, symbolTable);
assert(funcOp && "expected CallOp to a FuncOp");
FunctionType funcType = funcOp.getFunctionType();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index edd6bcf..a025da8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -280,13 +280,15 @@
}
/// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(func::CallOp callOp) {
+static func::FuncOp
+getCalledFunction(func::CallOp callOp,
+ mlir::SymbolTableCollection &symbolTable) {
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ symbolTable.lookupNearestSymbolFrom(callOp, sym));
}
/// Return "true" if the given function signature has tensor semantics.
@@ -314,11 +316,15 @@
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+
+ // TODO Avoid recomputing the symbol tables every time.
+ mlir::SymbolTableCollection symbolTable;
+
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp);
+ func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.