[mlir][linalg][bufferize][NFC] Extract func boundary bufferization
Bufferization of function boundaries is extracted from ComprehensiveBufferize into a separate file. This will become its own build target in the future.
Differential Revision: https://reviews.llvm.org/D114226
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index ee01a99..f52a9aa 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -284,10 +284,6 @@
/// Obsolete ops that should be deleted after bufferization.
SmallVector<Operation *> obsoleteOps;
-
- /// A map for looking up bufferized function types.
- // TODO: Entangle function calls and FuncOps from the remaining bufferization.
- DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
};
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index a569eac..a1fd04d 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -24,9 +24,6 @@
/// Return default allocation callbacks.
std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
-/// Register external models implemented for the `BufferizableOpInterface`.
-void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
-
/// Options for ComprehensiveBufferize.
struct BufferizationOptions {
BufferizationOptions();
@@ -61,8 +58,12 @@
std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
};
-LogicalResult runComprehensiveBufferize(ModuleOp moduleOp,
- const BufferizationOptions &options);
+/// Bufferize the given function. Does not bufferize the function boundary.
+// TODO: This function is meant to be called from ModuleBufferize and not can
+// not yet be called standalone.
+LogicalResult runComprehensiveBufferize(FuncOp funcOp,
+ const BufferizationOptions &options,
+ BufferizationState &state);
} // namespace comprehensive_bufferize
} // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
new file mode 100644
index 0000000..01f687e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
@@ -0,0 +1,37 @@
+//===- ModuleBufferization.h - Bufferization across Func. Boundaries ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_MODULE_BUFFERIZATION_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_MODULE_BUFFERIZATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+struct LogicalResult;
+class ModuleOp;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+
+struct BufferizationOptions;
+
+/// Bufferize the given module. This bufferizations performs a simple function
+/// call analysis to determine which function arguments are inplaceable.
+LogicalResult runComprehensiveBufferize(ModuleOp moduleOp,
+ const BufferizationOptions &options);
+
+namespace std_ext {
+
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
+
+} // namespace std_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_MODULE_BUFFERIZATION_H
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index d323b94..4d7c445 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -459,11 +459,13 @@
op = op->getParentOp();
}
- // FuncOp is an allocation hoisting barrier, so the above loop should never
- // run out of parents.
- assert(
- (op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
- "expected traversal to end at allocation hoisting barrier");
+ if (!op) {
+ // No allocation hoisting barrier found. Hoist to FuncOp.
+ op = b.getInsertionBlock()->getParentOp();
+ if (!isa<FuncOp>(op))
+ op = op->getParentOfType<FuncOp>();
+ assert(op && "could not find enclosing FuncOp");
+ }
// TODO: Handle cases where allocation hoisting barrier has more than one
// region or block.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index c367c00..68d5d03 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -4,6 +4,7 @@
BufferizableOpInterface.cpp
ComprehensiveBufferize.cpp
LinalgInterfaceImpl.cpp
+ ModuleBufferization.cpp
SCFInterfaceImpl.cpp
TensorInterfaceImpl.cpp
VectorInterfaceImpl.cpp
@@ -80,6 +81,7 @@
add_mlir_dialect_library(MLIRComprehensiveBufferize
ComprehensiveBufferize.cpp
+ ModuleBufferization.cpp
LINK_LIBS PUBLIC
MLIRBufferizableOpInterface
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index c10168f..53eaab5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -112,20 +112,12 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AsmState.h"
-#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/InferTypeOpInterface.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
@@ -145,35 +137,8 @@
static std::string printValueInfo(Value, bool prefix = true);
#endif
-//===----------------------------------------------------------------------===//
-// Generic helpers.
-//===----------------------------------------------------------------------===//
-
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
-/// Return the FuncOp called by `callOp`.
-static FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
- if (!sym)
- return nullptr;
- return dyn_cast_or_null<FuncOp>(
- SymbolTable::lookupNearestSymbolFrom(callOp, sym));
-}
-
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
- ReturnOp returnOp;
- for (Block &b : funcOp.body()) {
- if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
- }
- }
- return returnOp;
-}
-
//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is
@@ -216,16 +181,6 @@
BoolAttr::get(bbArg.getContext(), inPlace));
}
-/// Remove the attribute that triggers inplace bufferization on a FuncOp
-/// argument `bbArg`.
-static void removeBufferizationFuncArguments(BlockArgument bbArg) {
- auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
- funcOp.removeArgAttr(bbArg.getArgNumber(),
- BufferizableOpInterface::kBufferLayoutAttrName);
- funcOp.removeArgAttr(bbArg.getArgNumber(),
- BufferizableOpInterface::kInplaceableAttrName);
-}
-
//===----------------------------------------------------------------------===//
// Printing helpers.
//===----------------------------------------------------------------------===//
@@ -568,66 +523,6 @@
}
//===----------------------------------------------------------------------===//
-// Forward declarations.
-//===----------------------------------------------------------------------===//
-
-/// Return the op with Allocate MemoryEffect if `v` is equivalent to an such
-/// an op. Return null otherwise.
-static Operation *getEquivalentAlloc(Value value,
- const BufferizationAliasInfo &aliasInfo);
-
-/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
-/// Return null if no such bbArg can be found.
-static BlockArgument
-getEquivalentEnclosingFuncBBArg(Value v,
- const BufferizationAliasInfo &aliasInfo);
-
-//===----------------------------------------------------------------------===//
-// Bufferization-specific MemRefType support.
-//===----------------------------------------------------------------------===//
-
-/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
-/// tensor is replaced by the corresponding buffer type.
-/// In order for all the callers to agree, this *must* bufferize to the most
-/// dynamic buffer type supported.
-/// A later pass across all CallOps in the module can decide whether to simplify
-/// the types of to version according to some cost model.
-static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
- TypeRange argumentTypes,
- TypeRange resultTypes) {
- auto rewrite = [](Type t) -> Type {
- // TODO: non-zero address space.
- // TODO: layout information if relevant.
- if (auto rankedTensorType = t.dyn_cast<RankedTensorType>())
- return getDynamicMemRefType(rankedTensorType);
- if (auto tensorType = t.dyn_cast<TensorType>())
- return getContiguousOrUnrankedMemRefType(tensorType);
- return t;
- };
- auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));
- auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite));
- return FunctionType::get(ctx, argTypes, retTypes);
-}
-
-/// If an entry for `funcOp` is available in `bufferizedFunctionTypes`, return
-/// it. Otherwise, construct a new entry based on `argumentTypes` and
-/// `resultTypes`.
-// TODO: improve the layering.
-static FunctionType getOrCreateBufferizedFunctionType(
- FuncOp funcOp, TypeRange argumentTypes, TypeRange resultTypes,
- DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
- auto it = bufferizedFunctionTypes.find(funcOp);
- if (it != bufferizedFunctionTypes.end())
- return it->second;
-
- auto it2 = bufferizedFunctionTypes.try_emplace(
- funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
- resultTypes));
- LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n");
- return it2.first->second;
-}
-
-//===----------------------------------------------------------------------===//
// Bufferization as simple BlockAndValueMapping rewrites.
//===----------------------------------------------------------------------===//
@@ -774,343 +669,6 @@
return res;
}
-//===----------------------------------------------------------------------===//
-// Bufferization entry-point for modules.
-//===----------------------------------------------------------------------===//
-
-/// Return the op with Allocate MemoryEffect if `v` is equivalent to such an
-/// an op. Return null otherwise.
-static Operation *getEquivalentAlloc(Value value,
- const BufferizationAliasInfo &aliasInfo) {
- Operation *res = nullptr;
- aliasInfo.applyOnEquivalenceClass(value, [&](Value v) {
- if (!res)
- if (auto interface =
- dyn_cast_or_null<MemoryEffectOpInterface>(v.getDefiningOp()))
- if (auto effect =
- interface.getEffectOnValue<MemoryEffects::Allocate>(v))
- res = v.getDefiningOp();
- });
- return res;
-}
-
-/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
-/// Return null if no such bbArg can be found.
-static BlockArgument
-getEquivalentEnclosingFuncBBArg(Value v,
- const BufferizationAliasInfo &aliasInfo) {
- if (!v.getType().isa<RankedTensorType>())
- return nullptr;
- Operation *op = v.getParentBlock()->getParentOp();
- FuncOp funcOp = dyn_cast<FuncOp>(op);
- if (!funcOp)
- funcOp = op->getParentOfType<FuncOp>();
- assert(funcOp && "expected non-null FuncOp");
- for (BlockArgument bbArg : funcOp.getArguments()) {
- if (!bbArg.getType().isa<RankedTensorType>())
- continue;
- if (aliasInfo.areEquivalentBufferizedValues(v, bbArg))
- return bbArg;
- }
- return nullptr;
-}
-
-/// Rewrite the `funcOp` arguments analysis return values and terminator into
-/// buffer form (using the canonical memref layout for now), according to the
-/// inPlace-bufferizable information of the function arguments.
-/// This relies on a buffer equivalence analysis of each return operand. When a
-/// result buffer is equivalent to:
-/// 1. a BlockArgument of `funcOp`, it can be dropped from the return values
-/// and becomes inplaceable at all callers. This assumes all CallOp perform
-/// the necessary work to clone operands so as to make them inplaceable.
-// Reliance on this logic will need to be relaxed in thefuture.
-/// 2. an op with an Alloc effect, this currently fails bufferization but is a
-/// candidate for hoisting and creating a new inplace operand at all caller
-/// sites.
-/// 3. if such a hoisting for 2. is not possible (e.g. data-dependent that
-/// prevents hoisting), this is currently unsupported and will require a
-/// refcounted buffer type.
-static LogicalResult bufferizeFuncOpBoundary(
- FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
- DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
- LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
-
- // If nothing to do then we are done.
- if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
- !llvm::any_of(funcOp.getType().getResults(), isaTensor))
- return success();
-
- // Get the bufferized FunctionType for funcOp or construct it if not yet
- // available.
- // TODO: Atm we have 3 cases:
- // 1. if a function is called from within the Module, it must have bufferized
- // to inplaceable tensor results.
- // 2. if it is bodiless, it must have bufferized and is not allowed to have
- // result tensors.
- // 3. if it is not called internally, it still must bufferize to inplaceable
- // tensor results and we construct it now (e.g. top-level function called
- // externally).
- // -> Figure out a better layering.
- TypeRange resultTypes;
-
- // Corner case: Bodiless FuncOp
- // ============================
- // The body of such functions is assumed opaque and we can't know the
- // bufferization contract they want to enforce atm.
- // As a consequence, only support functions that don't return any tensor atm.
- if (funcOp.getBody().empty()) {
- if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
- return funcOp->emitError() << "cannot bufferize bodiless function that "
- << "returns a tensor";
- FunctionType bufferizedFuncType =
- getOrCreateBufferizedFunctionType(funcOp, funcOp.getType().getInputs(),
- TypeRange{}, bufferizedFunctionTypes);
- funcOp.setType(bufferizedFuncType);
- LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
- return success();
- }
-
- // Support only single return-terminated block in the function.
- ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
-
- // 1. For each FuncOp result, keep track of which inplace argument it reuses.
- SmallVector<Value> returnValues;
- for (OpOperand &returnOperand : returnOp->getOpOperands()) {
- // If not a renturn tensor type just forward it.
- if (!returnOperand.get().getType().isa<RankedTensorType>()) {
- returnValues.push_back(returnOperand.get());
- continue;
- }
-
- // If return operand is equivalent to some bbArg, no need to return it.
- Value returnVal = returnOperand.get();
- if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo))
- continue;
-
- // TODO: Need to hoist above function boundary.
- if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) {
- returnValues.push_back(allocOp->getResult(0));
- continue;
- }
-
- // Other cases legitimately need to return a tensor, this is currently not
- // supported. For instance, if hoisting across function boundary has
- // failed, it may be due to e.g. data-dependent sizes. In such a case, we
- // would need a better type than memref.
- int64_t returnIdx = returnOperand.getOperandNumber();
- return returnOp->emitError()
- << "buffer result #" << returnIdx << " not produced by an alloc\n";
- }
-
- // 2. Rewrite the terminator without the inPlace bufferizable values.
- ValueRange retValues{returnValues};
- FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
- funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
- bufferizedFunctionTypes);
- OpBuilder b(returnOp);
- b.create<ReturnOp>(returnOp.getLoc(), returnValues);
- returnOp->erase();
-
- // 3. Rewrite the bbArgs.
- // Iterate on the original `numArgs` and replace them in order.
- // This guarantees the argument order still matches after the rewrite.
- Block &frontBlock = funcOp.body().front();
- unsigned numArgs = frontBlock.getNumArguments();
- for (unsigned idx = 0; idx < numArgs; ++idx) {
- auto bbArg = frontBlock.getArgument(0);
- auto tensorType = bbArg.getType().dyn_cast<TensorType>();
- // Non-tensor types are just forwarded.
- if (!tensorType) {
- frontBlock.addArgument(bbArg.getType());
- bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
- frontBlock.eraseArgument(0);
- continue;
- }
-
- // Get the buffer type from the bufferized function type.
- Type memrefType = bufferizedFuncType.getInput(idx);
- Value memref = frontBlock.addArgument(memrefType);
- OpBuilder b(funcOp->getContext());
- b.setInsertionPointToStart(&frontBlock);
- // Replace all uses of bbArg through a ToMemrefOp by a memref::CastOp.
- for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
- if (auto toMemrefOp =
- dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
- auto castOp = b.create<memref::CastOp>(
- funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
- toMemrefOp.memref().replaceAllUsesWith(castOp);
- aliasInfo.insertNewBufferEquivalence(castOp.dest(),
- toMemrefOp.memref());
- }
- }
- // Replace all remaining uses by a tensor_load.
- if (!bbArg.use_empty()) {
- auto toTensorOp =
- b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
- aliasInfo.insertNewBufferEquivalence(toTensorOp, bbArg);
- bbArg.replaceAllUsesWith(toTensorOp);
- }
- frontBlock.eraseArgument(0);
- // TODO: add support to erase aliasInfo entries if deemed necessary.
- }
-
- // 4. Rewrite the FuncOp type to buffer form.
- funcOp.setType(bufferizedFuncType);
-
- LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp);
-
- return success();
-}
-
-/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
-/// callee-caller order (i.e. callees without callers first).
-/// Store the map of FuncOp to all its callers in `callerMap`.
-/// Return `failure()` if a cycle of calls is detected or if we are unable to
-/// retrieve the called FuncOp from any CallOpInterface.
-static LogicalResult
-getFuncOpsOrderedByCalls(ModuleOp moduleOp,
- SmallVectorImpl<FuncOp> &orderedFuncOps,
- DenseMap<FuncOp, DenseSet<Operation *>> &callerMap) {
- // For each FuncOp, the set of functions called by it (i.e. the union of
- // symbols of all nested CallOpInterfaceOp).
- DenseMap<FuncOp, DenseSet<FuncOp>> calledBy;
- // For each FuncOp, the number of CallOpInterface it contains.
- DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
- WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult {
- if (!funcOp.body().empty()) {
- ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp)
- return funcOp->emitError()
- << "cannot bufferize a FuncOp with tensors and "
- "without a unique ReturnOp";
- }
-
- numberCallOpsContainedInFuncOp[funcOp] = 0;
- return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
- // Only support CallOp for now.
- if (!isa<CallOp>(callOp.getOperation()))
- return callOp->emitError() << "expected a CallOp";
- FuncOp calledFunction = getCalledFunction(callOp);
- assert(calledFunction && "could not retrieved called FuncOp");
- auto it = callerMap.try_emplace(calledFunction, DenseSet<Operation *>{});
- it.first->getSecond().insert(callOp);
- if (calledBy[calledFunction].count(funcOp) == 0) {
- calledBy[calledFunction].insert(funcOp);
- numberCallOpsContainedInFuncOp[funcOp]++;
- }
- return WalkResult::advance();
- });
- });
- if (res.wasInterrupted())
- return failure();
- // Iteratively remove function operation that do not call any of the
- // functions remaining in the callCounter map and add them to the worklist.
- while (!numberCallOpsContainedInFuncOp.empty()) {
- auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
- [](auto entry) { return entry.getSecond() == 0; });
- if (it == numberCallOpsContainedInFuncOp.end())
- return moduleOp.emitOpError(
- "expected callgraph to be free of circular dependencies.");
- orderedFuncOps.push_back(it->getFirst());
- for (auto callee : calledBy[it->getFirst()])
- numberCallOpsContainedInFuncOp[callee]--;
- numberCallOpsContainedInFuncOp.erase(it);
- }
- return success();
-}
-
-static void
-foreachCaller(const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
- FuncOp callee, llvm::function_ref<void(Operation *)> doit) {
- auto itCallers = callerMap.find(callee);
- if (itCallers == callerMap.end())
- return;
- for (Operation *caller : itCallers->second)
- doit(caller);
-}
-
-/// Postprocess the linalg.buffer_layout annotation across function boundaries.
-/// This is a purely mechanical process that may later become part of a
-/// separate pass with its own layout assignment heuristic.
-static void layoutPostProcessing(ModuleOp moduleOp) {
- SmallVector<FuncOp> orderedFuncOps;
- DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
- auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
- (void)res;
- assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
-
- for (FuncOp funcOp : orderedFuncOps) {
- DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
- foreachCaller(callerMap, funcOp, [&](Operation *caller) {
- operandsPerCaller.try_emplace(caller, SmallVector<Value>());
- });
-
- SmallVector<Type> argumentTypes;
- // Iterate on each function argument and check it it was marked with a
- // desired layout.
- for (auto it : llvm::enumerate(funcOp.getType().getInputs())) {
- int argNumber = it.index();
- Type inputType = it.value();
- auto memrefType = inputType.dyn_cast<MemRefType>();
- auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
- argNumber, BufferizableOpInterface::kBufferLayoutAttrName);
- AffineMap desiredLayoutMap =
- layoutAttr ? layoutAttr.getValue() : AffineMap();
- AffineMap currentLayoutMap =
- memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
- if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
- argumentTypes.push_back(inputType);
- foreachCaller(callerMap, funcOp, [&](Operation *caller) {
- operandsPerCaller.find(caller)->getSecond().push_back(
- caller->getOperand(argNumber));
- });
- continue;
- }
-
- // Compute the buffer type with desired layout and add to input argument
- // types.
- MemRefType desiredMemrefType = MemRefType::get(
- memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
- argumentTypes.push_back(desiredMemrefType);
-
- // If funcOp's body is not empty, change the bbArg type and propagate.
- if (!funcOp.body().empty()) {
- BlockArgument bbArg = funcOp.getArgument(argNumber);
- bbArg.setType(desiredMemrefType);
- OpBuilder b(bbArg.getContext());
- b.setInsertionPointToStart(bbArg.getOwner());
- // Cast back to the original memrefType and let it canonicalize.
- Value cast =
- b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
- bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
- }
-
- // Cast to desired buffer type on all callers to `funcOp`.
- // TODO: on the callee side, this may even have to trigger a copy to
- // change the layout. For now let the memref::CastOp fail to verify in
- // such cases.
- auto castArg = [&](Operation *caller) {
- OpBuilder b(caller);
- Value newOperand = b.create<memref::CastOp>(
- funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
- operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
- };
- foreachCaller(callerMap, funcOp, castArg);
- }
-
- // Set operands with cast buffer on all callers to `funcOp`.
- foreachCaller(callerMap, funcOp, [&](Operation *caller) {
- caller->setOperands(operandsPerCaller.lookup(caller));
- });
-
- // Finally set the funcOp type to update the arguments.
- auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
- funcOp.getType().getResults());
- funcOp.setType(newFuncType);
- }
-}
-
#ifndef NDEBUG
/// Assert that the current bufferization decisions are consistent.
static void checkAliasInfoConsistency(FuncOp funcOp,
@@ -1149,96 +707,46 @@
}
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
- ModuleOp moduleOp, const BufferizationOptions &options) {
- SmallVector<FuncOp> orderedFuncOps;
- DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
- return failure();
+ FuncOp funcOp, const BufferizationOptions &options,
+ BufferizationState &state) {
- DominanceInfo domInfo(moduleOp);
- BufferizationState state(moduleOp, *options.allocationFns);
+ DominanceInfo domInfo(funcOp);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
- // Interestingly, all function args that are not visible outside of a module
- // can be fully bufferized inplace by guaranteeing the CallOp is bufferized
- // inplace. Therefore, we just bufferize funcOp as if none of its results were
- // inplaceable, detect which operands are cloned internally and decide what to
- // do at call sites.
- for (FuncOp funcOp : orderedFuncOps) {
- // No body => no analysis.
- if (funcOp.body().empty())
- continue;
-
- // In a first approximation:
- // =========================
- // If the function is called, we can allocate on the caller side which lets
- // us force inplace arguments at function boundaries.
- // TODO: do not rely on this behavior.
- if (callerMap.find(funcOp) != callerMap.end())
- for (BlockArgument bbArg : funcOp.getArguments())
- if (bbArg.getType().isa<TensorType>())
- aliasInfo.setBufferizesToWritableMemory(bbArg);
+ if (funcOp.body().empty())
+ return success();
#ifndef NDEBUG
- checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
+ checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
#endif // NDEBUG
- // If the analysis fails, just return.
- if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
- options.analysisFuzzerSeed)))
+ // If the analysis fails, just return.
+ if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
+ options.analysisFuzzerSeed)))
+ return failure();
+
+ for (const std::unique_ptr<PostAnalysisStep> &step :
+ options.postAnalysisSteps) {
+ SmallVector<Operation *> newOps;
+ if (failed(step->run(funcOp, aliasInfo, domInfo, newOps)))
return failure();
-
- for (const std::unique_ptr<PostAnalysisStep> &step :
- options.postAnalysisSteps) {
- SmallVector<Operation *> newOps;
- if (failed(step->run(funcOp, aliasInfo, domInfo, newOps)))
- return failure();
- // Analyze ops that were created by the PostAnalysisStep.
- if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
- return failure();
- }
-
- // Bufferization phase.
- if (!options.testAnalysisOnly) {
- // Bufferize all ops in funcOp.
- if (failed(bufferizeFuncOp(funcOp, state)))
- return failure();
-
- // Erase all obsolete ops.
- state.eraseObsoleteOps();
- }
+ // Analyze ops that were created by the PostAnalysisStep.
+ if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
+ return failure();
}
+
// Annotate operations if we only want to report the analysis.
if (options.testAnalysisOnly) {
- annotateOpsWithBufferizationMarkers(moduleOp, aliasInfo);
+ annotateOpsWithBufferizationMarkers(funcOp, aliasInfo);
return success();
}
- for (FuncOp funcOp : orderedFuncOps) {
- // Note: It would be good to apply cleanups here but we cannot as aliasInfo
- // would be invalidated.
- if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
- state.bufferizedFunctionTypes)))
- return failure();
+ // Bufferize all ops in funcOp.
+ if (failed(bufferizeFuncOp(funcOp, state)))
+ return failure();
- if (!options.allowReturnMemref &&
- llvm::any_of(funcOp.getType().getResults(), [](Type t) {
- return t.isa<MemRefType, UnrankedMemRefType>();
- })) {
- funcOp->emitError("memref return type is unsupported");
- return failure();
- }
- }
-
- // Perform a post-processing pass of layout modification at function boundary
- // according to the kBufferLayoutAttrName.
- layoutPostProcessing(moduleOp);
-
- // Post-pass cleanup of inplaceable and buffer_layout attributes.
- moduleOp.walk([&](FuncOp op) {
- for (BlockArgument bbArg : op.getArguments())
- removeBufferizationFuncArguments(bbArg);
- });
+ // Erase all obsolete ops.
+ state.eraseObsoleteOps();
return success();
}
@@ -1278,243 +786,3 @@
BufferizationOptions::BufferizationOptions()
: allocationFns(defaultAllocationCallbacks()) {}
-//===----------------------------------------------------------------------===//
-// BufferizableOpInterface Implementations
-//===----------------------------------------------------------------------===//
-
-// TODO: Move these to a different file and BUILD target, so that they are
-// decoupled from ComprehensiveBufferize.
-
-namespace mlir {
-namespace linalg {
-namespace comprehensive_bufferize {
-namespace std_ext {
-
-struct CallOpInterface
- : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
- // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
- // of the matching bbArg may. It is the responsibility of the caller to
- // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
- // conservative.
- return true;
- }
-
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
- // CallOpInterface alone doesn't bufferize to a memory write, one of the
- // uses of the matching bbArg may. It is the responsibility of the caller to
- // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
- // conservative.
- return true;
- }
-
- SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
- OpResult opResult) const {
- // TODO: Can we do better?
- return {};
- }
-
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
- // CallOpInterface is special, it needs to wait for the callee to be
- // bufferized and needs to inspect the BufferAliasInfo object. It can't
- // make a proper determination by itself and needs to be conservative.
- return OpResult();
- }
-
- BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
- return BufferRelation::Equivalent;
- }
-
- /// In a first approximation, all the function arguments of a FuncOp are
- /// marked inplaceable. For now, it is the responsibility of the `callOp`
- /// bufferization to allow FuncOp that are inplaceable to write inPlace.
- LogicalResult bufferize(Operation *op, OpBuilder &b,
- BufferizationState &state) const {
- CallOp callOp = cast<CallOp>(op);
- FuncOp funcOp = getCalledFunction(callOp);
- assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
- "expected Callop to a FuncOp");
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(callOp);
-
- // 1. Filter return types:
- // - if the callee is bodiless / external, we cannot inspect it and we
- // cannot assume anything. We can just assert that it does not return a
- // tensor as this would have to bufferize to "return a memref", whose
- // semantics is ill-defined.
- // - if the callee has a body, we perform inter-procedural equivalence
- // analysis. When successful, a result folds onto an operand. When
- // unsuccessful, additional work is needed to either:
- // * hoist a result into an inplaceable operand or
- // * devise a better representation to truly return a buffer.
- SmallVector<Type> resultTypes;
- SmallVector<Value> hoistedArguments;
- if (funcOp.body().empty()) {
- if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
- return callOp->emitError()
- << "cannot bufferize bodiless function that returns a tensor";
- } else {
- ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
-
- // For each FuncOp result, keep track of which inplace argument it reuses.
- for (OpOperand &returnOperand : returnOp->getOpOperands()) {
- Type returnType = returnOperand.get().getType();
- if (!isaTensor(returnType)) {
- resultTypes.push_back(returnType);
- continue;
- }
-
- // If return operand is equivalent to some bbArg, no need to return it.
- Value returnVal = returnOperand.get();
- if (BlockArgument bbArg =
- getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
- Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
- int64_t idx = bbArg.getArgNumber();
- Value buffer = state.lookupBuffer(callOp->getOperand(idx));
- // Add CallOp operand/result equivalence: this is interprocedural
- // info.
- state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
- state.mapBuffer(oldRes, buffer);
- // Add a ToTensorOp to kill all uses of the CallOp return.
- // Replace all uses of the CallOp results so we can erase the CallOp.
- // This ToTensorOp must fold/DCE away or bufferization should be
- // considered failed.
- Value toTensor =
- b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
- oldRes.replaceAllUsesWith(toTensor);
- // Add new op equivalence info.
- state.aliasInfo.insertNewBufferEquivalence(toTensor, buffer);
- state.mapBuffer(toTensor, buffer);
- continue;
- }
-
- // TODO: Need to hoist above function boundary.
- if (Operation *allocOp =
- getEquivalentAlloc(returnVal, state.aliasInfo)) {
- hoistedArguments.push_back(allocOp->getResult(0));
- continue;
- }
-
- // Other cases legitimately need to return a tensor, this is currently
- // not supported. For instance, if hoisting across function boundary has
- // failed, it may be due to e.g. data-dependent sizes. In such a case,
- // we would we need a better type than memref.
- resultTypes.push_back(returnType);
-
- int64_t returnIdx = returnOperand.getOperandNumber();
- return returnOp->emitError() << "buffer result #" << returnIdx
- << " not produced by an alloc\n";
- }
- }
-
- // 2. Compute bufferized FunctionType.
- SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
- ValueRange hoistedArgs{hoistedArguments};
- llvm::append_range(argumentTypes, hoistedArgs.getTypes());
- // Get the bufferized FunctionType for funcOp or construct it if not yet
- // available.
- FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
- funcOp, argumentTypes, resultTypes, state.bufferizedFunctionTypes);
-
- // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
- SmallVector<Value> newOperands;
- newOperands.reserve(callOp->getNumOperands());
- for (OpOperand &opOperand : callOp->getOpOperands()) {
- Value tensorOperand = opOperand.get();
- // Non-tensor operands are just copied.
- if (!tensorOperand.getType().isa<TensorType>()) {
- newOperands.push_back(tensorOperand);
- continue;
- }
-
- // Tensor operands are guaranteed to have been buferized.
- int64_t idx = opOperand.getOperandNumber();
- Value buffer = state.lookupBuffer(tensorOperand);
-
- // Caller / callee type mistmatch is handled with a CastOp.
- auto memRefType = bufferizedFuncType.getInput(idx);
- // Since we don't yet have a clear layout story, buffer_cast may
- // conservatively turn tensors into more dynamic memref than necessary.
- // If the memref type of the callee fails, introduce an extra memref.cast
- // that will either canonicalize away or fail compilation until we can do
- // something better.
- if (buffer.getType() != memRefType) {
- Value castBuffer =
- b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
- // Add new op equivalence info.
- state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
- state.mapBuffer(tensorOperand, castBuffer);
- buffer = castBuffer;
- }
- newOperands.push_back(buffer);
- }
-
- // 4. Create the new CallOp.
- Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
- resultTypes, newOperands);
- newCallOp->setAttrs(callOp->getAttrs());
-
- // 5. Delete the op at the end of bufferization.
- state.markOpObsolete(callOp);
-
- return success();
- }
-};
-
-struct ReturnOpInterface
- : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
- ReturnOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
- return true;
- }
-
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
- return false;
- }
-
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
- return OpResult();
- }
-
- LogicalResult bufferize(Operation *op, OpBuilder &b,
- BufferizationState &state) const {
- auto returnOp = cast<ReturnOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- // Cannot insert after returnOp.
- b.setInsertionPoint(returnOp);
-
- assert(isa<FuncOp>(returnOp->getParentOp()) &&
- "only support FuncOp parent for ReturnOp");
- for (OpOperand &operand : returnOp->getOpOperands()) {
- auto tensorType = operand.get().getType().dyn_cast<TensorType>();
- if (!tensorType)
- continue;
- Value v = state.lookupBuffer(operand.get());
- Value returnTensor =
- b.create<bufferization::ToTensorOp>(returnOp.getLoc(), v);
- operand.set(returnTensor);
- state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
- state.mapBuffer(returnTensor, v);
- }
- return success();
- }
-};
-
-} // namespace std_ext
-
-void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
- registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
- registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
-
- // Ops that are not bufferizable but are allocation hoisting barriers.
- registry.addOpInterface<FuncOp, AllocationHoistingBarrierOnly<FuncOp>>();
-}
-
-} // namespace comprehensive_bufferize
-} // namespace linalg
-} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
new file mode 100644
index 0000000..03d1cab
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -0,0 +1,748 @@
+//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#define DEBUG_TYPE "comprehensive-module-bufferize"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X)
+
+using namespace mlir;
+using namespace linalg;
+using namespace tensor;
+using namespace comprehensive_bufferize;
+
+namespace {
+/// A specialization of BufferizationState that keeps track of additional
+/// state required for bufferization of function boundaries.
+struct ModuleBufferizationState : public BufferizationState {
+ using BufferizationState::BufferizationState;
+
+ /// A map for looking up bufferized function types.
+ DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
+};
+} // namespace
+
+static bool isaTensor(Type t) { return t.isa<TensorType>(); }
+
+/// Remove the attribute that triggers inplace bufferization on a FuncOp
+/// argument `bbArg`.
+static void removeBufferizationFuncArguments(BlockArgument bbArg) {
+ auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
+ funcOp.removeArgAttr(bbArg.getArgNumber(),
+ BufferizableOpInterface::kBufferLayoutAttrName);
+ funcOp.removeArgAttr(bbArg.getArgNumber(),
+ BufferizableOpInterface::kInplaceableAttrName);
+}
+
+/// Return the FuncOp called by `callOp`.
+static FuncOp getCalledFunction(CallOpInterface callOp) {
+ SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+ if (!sym)
+ return nullptr;
+ return dyn_cast_or_null<FuncOp>(
+ SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+}
+
+/// Return the unique ReturnOp that terminates `funcOp`.
+/// Return nullptr if there is no such unique ReturnOp.
+static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
+ ReturnOp returnOp;
+ for (Block &b : funcOp.body()) {
+ if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
+ if (returnOp)
+ return nullptr;
+ returnOp = candidateOp;
+ }
+ }
+ return returnOp;
+}
+
+/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
+/// tensor is replaced by the corresponding buffer type.
+/// In order for all the callers to agree, this *must* bufferize to the most
+/// dynamic buffer type supported.
+/// A later pass across all CallOps in the module can decide whether to simplify
+/// the types of to version according to some cost model.
+static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
+ TypeRange argumentTypes,
+ TypeRange resultTypes) {
+ auto rewrite = [](Type t) -> Type {
+ // TODO: non-zero address space.
+ // TODO: layout information if relevant.
+ if (auto rankedTensorType = t.dyn_cast<RankedTensorType>())
+ return getDynamicMemRefType(rankedTensorType);
+ if (auto tensorType = t.dyn_cast<TensorType>())
+ return getContiguousOrUnrankedMemRefType(tensorType);
+ return t;
+ };
+ auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));
+ auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite));
+ return FunctionType::get(ctx, argTypes, retTypes);
+}
+
+/// If an entry for `funcOp` is available in `bufferizedFunctionTypes`, return
+/// it. Otherwise, construct a new entry based on `argumentTypes` and
+/// `resultTypes`.
+// TODO: improve the layering.
+static FunctionType getOrCreateBufferizedFunctionType(
+ FuncOp funcOp, TypeRange argumentTypes, TypeRange resultTypes,
+ DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+ auto it = bufferizedFunctionTypes.find(funcOp);
+ if (it != bufferizedFunctionTypes.end())
+ return it->second;
+
+ auto it2 = bufferizedFunctionTypes.try_emplace(
+ funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
+ resultTypes));
+ LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n");
+ return it2.first->second;
+}
+
+/// Return the op with Allocate MemoryEffect if `v` is equivalent to such an
+/// an op. Return null otherwise.
+static Operation *getEquivalentAlloc(Value value,
+ const BufferizationAliasInfo &aliasInfo) {
+ Operation *res = nullptr;
+ aliasInfo.applyOnEquivalenceClass(value, [&](Value v) {
+ if (!res)
+ if (auto interface =
+ dyn_cast_or_null<MemoryEffectOpInterface>(v.getDefiningOp()))
+ if (auto effect =
+ interface.getEffectOnValue<MemoryEffects::Allocate>(v))
+ res = v.getDefiningOp();
+ });
+ return res;
+}
+
+/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
+/// Return null if no such bbArg can be found.
+static BlockArgument
+getEquivalentEnclosingFuncBBArg(Value v,
+ const BufferizationAliasInfo &aliasInfo) {
+ if (!v.getType().isa<RankedTensorType>())
+ return nullptr;
+ Operation *op = v.getParentBlock()->getParentOp();
+ FuncOp funcOp = dyn_cast<FuncOp>(op);
+ if (!funcOp)
+ funcOp = op->getParentOfType<FuncOp>();
+ assert(funcOp && "expected non-null FuncOp");
+ for (BlockArgument bbArg : funcOp.getArguments()) {
+ if (!bbArg.getType().isa<RankedTensorType>())
+ continue;
+ if (aliasInfo.areEquivalentBufferizedValues(v, bbArg))
+ return bbArg;
+ }
+ return nullptr;
+}
+
+/// Rewrite the `funcOp` arguments analysis return values and terminator into
+/// buffer form (using the canonical memref layout for now), according to the
+/// inPlace-bufferizable information of the function arguments.
+/// This relies on a buffer equivalence analysis of each return operand. When a
+/// result buffer is equivalent to:
+/// 1. a BlockArgument of `funcOp`, it can be dropped from the return values
+/// and becomes inplaceable at all callers. This assumes all CallOp perform
+/// the necessary work to clone operands so as to make them inplaceable.
+// Reliance on this logic will need to be relaxed in thefuture.
+/// 2. an op with an Alloc effect, this currently fails bufferization but is a
+/// candidate for hoisting and creating a new inplace operand at all caller
+/// sites.
+/// 3. if such a hoisting for 2. is not possible (e.g. data-dependent that
+/// prevents hoisting), this is currently unsupported and will require a
+/// refcounted buffer type.
+static LogicalResult bufferizeFuncOpBoundary(
+ FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+ DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+ LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
+
+ // If nothing to do then we are done.
+ if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
+ !llvm::any_of(funcOp.getType().getResults(), isaTensor))
+ return success();
+
+ // Get the bufferized FunctionType for funcOp or construct it if not yet
+ // available.
+ // TODO: Atm we have 3 cases:
+ // 1. if a function is called from within the Module, it must have bufferized
+ // to inplaceable tensor results.
+ // 2. if it is bodiless, it must have bufferized and is not allowed to have
+ // result tensors.
+ // 3. if it is not called internally, it still must bufferize to inplaceable
+ // tensor results and we construct it now (e.g. top-level function called
+ // externally).
+ // -> Figure out a better layering.
+ TypeRange resultTypes;
+
+ // Corner case: Bodiless FuncOp
+ // ============================
+ // The body of such functions is assumed opaque and we can't know the
+ // bufferization contract they want to enforce atm.
+ // As a consequence, only support functions that don't return any tensor atm.
+ if (funcOp.getBody().empty()) {
+ if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+ return funcOp->emitError() << "cannot bufferize bodiless function that "
+ << "returns a tensor";
+ FunctionType bufferizedFuncType =
+ getOrCreateBufferizedFunctionType(funcOp, funcOp.getType().getInputs(),
+ TypeRange{}, bufferizedFunctionTypes);
+ funcOp.setType(bufferizedFuncType);
+ LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
+ return success();
+ }
+
+ // Support only single return-terminated block in the function.
+ ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ assert(returnOp && "expected func with single return op");
+
+ // 1. For each FuncOp result, keep track of which inplace argument it reuses.
+ SmallVector<Value> returnValues;
+ for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+ // If not a renturn tensor type just forward it.
+ if (!returnOperand.get().getType().isa<RankedTensorType>()) {
+ returnValues.push_back(returnOperand.get());
+ continue;
+ }
+
+ // If return operand is equivalent to some bbArg, no need to return it.
+ Value returnVal = returnOperand.get();
+ if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo))
+ continue;
+
+ // TODO: Need to hoist above function boundary.
+ if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) {
+ returnValues.push_back(allocOp->getResult(0));
+ continue;
+ }
+
+ // Other cases legitimately need to return a tensor, this is currently not
+ // supported. For instance, if hoisting across function boundary has
+ // failed, it may be due to e.g. data-dependent sizes. In such a case, we
+ // would need a better type than memref.
+ int64_t returnIdx = returnOperand.getOperandNumber();
+ return returnOp->emitError()
+ << "buffer result #" << returnIdx << " not produced by an alloc\n";
+ }
+
+ // 2. Rewrite the terminator without the inPlace bufferizable values.
+ ValueRange retValues{returnValues};
+ FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
+ funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
+ bufferizedFunctionTypes);
+ OpBuilder b(returnOp);
+ b.create<ReturnOp>(returnOp.getLoc(), returnValues);
+ returnOp->erase();
+
+ // 3. Rewrite the bbArgs.
+ // Iterate on the original `numArgs` and replace them in order.
+ // This guarantees the argument order still matches after the rewrite.
+ Block &frontBlock = funcOp.body().front();
+ unsigned numArgs = frontBlock.getNumArguments();
+ for (unsigned idx = 0; idx < numArgs; ++idx) {
+ auto bbArg = frontBlock.getArgument(0);
+ auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+ // Non-tensor types are just forwarded.
+ if (!tensorType) {
+ frontBlock.addArgument(bbArg.getType());
+ bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
+ frontBlock.eraseArgument(0);
+ continue;
+ }
+
+ // Get the buffer type from the bufferized function type.
+ Type memrefType = bufferizedFuncType.getInput(idx);
+ Value memref = frontBlock.addArgument(memrefType);
+ OpBuilder b(funcOp->getContext());
+ b.setInsertionPointToStart(&frontBlock);
+ // Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
+ for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
+ if (auto toMemrefOp =
+ dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
+ auto castOp = b.create<memref::CastOp>(
+ funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
+ toMemrefOp.memref().replaceAllUsesWith(castOp);
+ aliasInfo.insertNewBufferEquivalence(castOp.dest(),
+ toMemrefOp.memref());
+ }
+ }
+ // Replace all remaining uses by a to_tensor.
+ if (!bbArg.use_empty()) {
+ auto toTensorOp =
+ b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
+ aliasInfo.insertNewBufferEquivalence(toTensorOp, bbArg);
+ bbArg.replaceAllUsesWith(toTensorOp);
+ }
+ frontBlock.eraseArgument(0);
+ // TODO: add support to erase aliasInfo entries if deemed necessary.
+ }
+
+ // 4. Rewrite the FuncOp type to buffer form.
+ funcOp.setType(bufferizedFuncType);
+
+ LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp);
+
+ return success();
+}
+
+/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
+/// callee-caller order (i.e. callees without callers first).
+/// Store the map of FuncOp to all its callers in `callerMap`.
+/// Return `failure()` if a cycle of calls is detected or if we are unable to
+/// retrieve the called FuncOp from any CallOpInterface.
+static LogicalResult
+getFuncOpsOrderedByCalls(ModuleOp moduleOp,
+ SmallVectorImpl<FuncOp> &orderedFuncOps,
+ DenseMap<FuncOp, DenseSet<Operation *>> &callerMap) {
+ // For each FuncOp, the set of functions called by it (i.e. the union of
+ // symbols of all nested CallOpInterfaceOp).
+ DenseMap<FuncOp, DenseSet<FuncOp>> calledBy;
+ // For each FuncOp, the number of CallOpInterface it contains.
+ DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
+ WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult {
+ if (!funcOp.body().empty()) {
+ ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ if (!returnOp)
+ return funcOp->emitError()
+ << "cannot bufferize a FuncOp with tensors and "
+ "without a unique ReturnOp";
+ }
+
+ numberCallOpsContainedInFuncOp[funcOp] = 0;
+ return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+ // Only support CallOp for now.
+ if (!isa<CallOp>(callOp.getOperation()))
+ return callOp->emitError() << "expected a CallOp";
+ FuncOp calledFunction = getCalledFunction(callOp);
+ assert(calledFunction && "could not retrieved called FuncOp");
+ auto it = callerMap.try_emplace(calledFunction, DenseSet<Operation *>{});
+ it.first->getSecond().insert(callOp);
+ if (calledBy[calledFunction].count(funcOp) == 0) {
+ calledBy[calledFunction].insert(funcOp);
+ numberCallOpsContainedInFuncOp[funcOp]++;
+ }
+ return WalkResult::advance();
+ });
+ });
+ if (res.wasInterrupted())
+ return failure();
+ // Iteratively remove function operation that do not call any of the
+ // functions remaining in the callCounter map and add them to the worklist.
+ while (!numberCallOpsContainedInFuncOp.empty()) {
+ auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
+ [](auto entry) { return entry.getSecond() == 0; });
+ if (it == numberCallOpsContainedInFuncOp.end())
+ return moduleOp.emitOpError(
+ "expected callgraph to be free of circular dependencies.");
+ orderedFuncOps.push_back(it->getFirst());
+ for (auto callee : calledBy[it->getFirst()])
+ numberCallOpsContainedInFuncOp[callee]--;
+ numberCallOpsContainedInFuncOp.erase(it);
+ }
+ return success();
+}
+
+static void
+foreachCaller(const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
+ FuncOp callee, llvm::function_ref<void(Operation *)> doit) {
+ auto itCallers = callerMap.find(callee);
+ if (itCallers == callerMap.end())
+ return;
+ for (Operation *caller : itCallers->second)
+ doit(caller);
+}
+
+/// Postprocess the linalg.buffer_layout annotation across function boundaries.
+/// This is a purely mechanical process that may later become part of a
+/// separate pass with its own layout assignment heuristic.
+static void layoutPostProcessing(ModuleOp moduleOp) {
+ SmallVector<FuncOp> orderedFuncOps;
+ DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
+ auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
+ (void)res;
+ assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
+
+ for (FuncOp funcOp : orderedFuncOps) {
+ DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
+ foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+ operandsPerCaller.try_emplace(caller, SmallVector<Value>());
+ });
+
+ SmallVector<Type> argumentTypes;
+ // Iterate on each function argument and check it it was marked with a
+ // desired layout.
+ for (auto it : llvm::enumerate(funcOp.getType().getInputs())) {
+ int argNumber = it.index();
+ Type inputType = it.value();
+ auto memrefType = inputType.dyn_cast<MemRefType>();
+ auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+ argNumber, BufferizableOpInterface::kBufferLayoutAttrName);
+ AffineMap desiredLayoutMap =
+ layoutAttr ? layoutAttr.getValue() : AffineMap();
+ AffineMap currentLayoutMap =
+ memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
+ if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
+ argumentTypes.push_back(inputType);
+ foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+ operandsPerCaller.find(caller)->getSecond().push_back(
+ caller->getOperand(argNumber));
+ });
+ continue;
+ }
+
+ // Compute the buffer type with desired layout and add to input argument
+ // types.
+ MemRefType desiredMemrefType = MemRefType::get(
+ memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
+ argumentTypes.push_back(desiredMemrefType);
+
+ // If funcOp's body is not empty, change the bbArg type and propagate.
+ if (!funcOp.body().empty()) {
+ BlockArgument bbArg = funcOp.getArgument(argNumber);
+ bbArg.setType(desiredMemrefType);
+ OpBuilder b(bbArg.getContext());
+ b.setInsertionPointToStart(bbArg.getOwner());
+ // Cast back to the original memrefType and let it canonicalize.
+ Value cast =
+ b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
+ bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
+ }
+
+ // Cast to desired buffer type on all callers to `funcOp`.
+ // TODO: on the callee side, this may even have to trigger a copy to
+ // change the layout. For now let the memref::CastOp fail to verify in
+ // such cases.
+ auto castArg = [&](Operation *caller) {
+ OpBuilder b(caller);
+ Value newOperand = b.create<memref::CastOp>(
+ funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
+ operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
+ };
+ foreachCaller(callerMap, funcOp, castArg);
+ }
+
+ // Set operands with cast buffer on all callers to `funcOp`.
+ foreachCaller(callerMap, funcOp, [&](Operation *caller) {
+ caller->setOperands(operandsPerCaller.lookup(caller));
+ });
+
+ // Finally set the funcOp type to update the arguments.
+ auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
+ funcOp.getType().getResults());
+ funcOp.setType(newFuncType);
+ }
+}
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace std_ext {
+
+struct CallOpInterface
+ : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
+ // of the matching bbArg may. It is the responsibility of the caller to
+ // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
+ // conservative.
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ // CallOpInterface alone doesn't bufferize to a memory write, one of the
+ // uses of the matching bbArg may. It is the responsibility of the caller to
+ // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
+ // conservative.
+ return true;
+ }
+
+ SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+ OpResult opResult) const {
+ // TODO: Can we do better?
+ return {};
+ }
+
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ // CallOpInterface is special, it needs to wait for the callee to be
+ // bufferized and needs to inspect the BufferAliasInfo object. It can't
+ // make a proper determination by itself and needs to be conservative.
+ return OpResult();
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+ return BufferRelation::Equivalent;
+ }
+
+ /// In a first approximation, all the function arguments of a FuncOp are
+ /// marked inplaceable. For now, it is the responsibility of the `callOp`
+ /// bufferization to allow FuncOp that are inplaceable to write inPlace.
+ LogicalResult bufferize(Operation *op, OpBuilder &b,
+ BufferizationState &state) const {
+ CallOp callOp = cast<CallOp>(op);
+ FuncOp funcOp = getCalledFunction(callOp);
+ assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
+ "expected Callop to a FuncOp");
+
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(callOp);
+
+ // 1. Filter return types:
+ // - if the callee is bodiless / external, we cannot inspect it and we
+ // cannot assume anything. We can just assert that it does not return a
+ // tensor as this would have to bufferize to "return a memref", whose
+ // semantics is ill-defined.
+ // - if the callee has a body, we perform inter-procedural equivalence
+ // analysis. When successful, a result folds onto an operand. When
+ // unsuccessful, additional work is needed to either:
+ // * hoist a result into an inplaceable operand or
+ // * devise a better representation to truly return a buffer.
+ SmallVector<Type> resultTypes;
+ SmallVector<Value> hoistedArguments;
+ if (funcOp.body().empty()) {
+ if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+ return callOp->emitError()
+ << "cannot bufferize bodiless function that returns a tensor";
+ } else {
+ ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ assert(returnOp && "expected func with single return op");
+
+ // For each FuncOp result, keep track of which inplace argument it reuses.
+ for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+ Type returnType = returnOperand.get().getType();
+ if (!isaTensor(returnType)) {
+ resultTypes.push_back(returnType);
+ continue;
+ }
+
+ // If return operand is equivalent to some bbArg, no need to return it.
+ Value returnVal = returnOperand.get();
+ if (BlockArgument bbArg =
+ getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
+ Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
+ int64_t idx = bbArg.getArgNumber();
+ Value buffer = state.lookupBuffer(callOp->getOperand(idx));
+ // Add CallOp operand/result equivalence: this is interprocedural
+ // info.
+ state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
+ state.mapBuffer(oldRes, buffer);
+ // Add a ToTensorOp to kill all uses of the CallOp return.
+ // Replace all uses of the CallOp results so we can erase the CallOp.
+ // This ToTensorOp must fold/DCE away or bufferization should be
+ // considered failed.
+ Value toTensorOp =
+ b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
+ oldRes.replaceAllUsesWith(toTensorOp);
+ // Add new op equivalence info.
+ state.aliasInfo.insertNewBufferEquivalence(toTensorOp, buffer);
+ state.mapBuffer(toTensorOp, buffer);
+ continue;
+ }
+
+ // TODO: Need to hoist above function boundary.
+ if (Operation *allocOp =
+ getEquivalentAlloc(returnVal, state.aliasInfo)) {
+ hoistedArguments.push_back(allocOp->getResult(0));
+ continue;
+ }
+
+ // Other cases legitimately need to return a tensor, this is currently
+ // not supported. For instance, if hoisting across function boundary has
+ // failed, it may be due to e.g. data-dependent sizes. In such a case,
+ // we would we need a better type than memref.
+ resultTypes.push_back(returnType);
+
+ int64_t returnIdx = returnOperand.getOperandNumber();
+ return returnOp->emitError() << "buffer result #" << returnIdx
+ << " not produced by an alloc\n";
+ }
+ }
+
+ // 2. Compute bufferized FunctionType.
+ SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
+ ValueRange hoistedArgs{hoistedArguments};
+ llvm::append_range(argumentTypes, hoistedArgs.getTypes());
+ // Get the bufferized FunctionType for funcOp or construct it if not yet
+ // available.
+ // TODO: Assert that `state` is a ModuleBufferizationState.
+ FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
+ funcOp, argumentTypes, resultTypes,
+ static_cast<ModuleBufferizationState &>(state).bufferizedFunctionTypes);
+
+ // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
+ SmallVector<Value> newOperands;
+ newOperands.reserve(callOp->getNumOperands());
+ for (OpOperand &opOperand : callOp->getOpOperands()) {
+ Value tensorOperand = opOperand.get();
+ // Non-tensor operands are just copied.
+ if (!tensorOperand.getType().isa<TensorType>()) {
+ newOperands.push_back(tensorOperand);
+ continue;
+ }
+
+ // Tensor operands are guaranteed to have been buferized.
+ int64_t idx = opOperand.getOperandNumber();
+ Value buffer = state.lookupBuffer(tensorOperand);
+
+ // Caller / callee type mistmatch is handled with a CastOp.
+ auto memRefType = bufferizedFuncType.getInput(idx);
+ // Since we don't yet have a clear layout story, buffer_cast may
+ // conservatively turn tensors into more dynamic memref than necessary.
+ // If the memref type of the callee fails, introduce an extra memref.cast
+ // that will either canonicalize away or fail compilation until we can do
+ // something better.
+ if (buffer.getType() != memRefType) {
+ Value castBuffer =
+ b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
+ // Add new op equivalence info.
+ state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
+ state.mapBuffer(tensorOperand, castBuffer);
+ buffer = castBuffer;
+ }
+ newOperands.push_back(buffer);
+ }
+
+ // 4. Create the new CallOp.
+ Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
+ resultTypes, newOperands);
+ newCallOp->setAttrs(callOp->getAttrs());
+
+ // 5. Delete the op at the end of bufferization.
+ state.markOpObsolete(callOp);
+
+ return success();
+ }
+};
+
+struct ReturnOpInterface
+ : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
+ ReturnOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ return false;
+ }
+
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ return OpResult();
+ }
+
+ LogicalResult bufferize(Operation *op, OpBuilder &b,
+ BufferizationState &state) const {
+ auto returnOp = cast<ReturnOp>(op);
+
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ // Cannot insert after returnOp.
+ b.setInsertionPoint(returnOp);
+
+ assert(isa<FuncOp>(returnOp->getParentOp()) &&
+ "only support FuncOp parent for ReturnOp");
+ for (OpOperand &operand : returnOp->getOpOperands()) {
+ auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+ if (!tensorType)
+ continue;
+ Value v = state.lookupBuffer(operand.get());
+ Value returnTensor = b.create<bufferization::ToTensorOp>(
+ returnOp.getLoc(), v);
+ operand.set(returnTensor);
+ state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
+ state.mapBuffer(returnTensor, v);
+ }
+ return success();
+ }
+};
+
+} // namespace std_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+void mlir::linalg::comprehensive_bufferize::std_ext::
+ registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
+ registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
+ registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
+ registry.addOpInterface<FuncOp, AllocationHoistingBarrierOnly<FuncOp>>();
+}
+
+LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
+ ModuleOp moduleOp, const BufferizationOptions &options) {
+ SmallVector<FuncOp> orderedFuncOps;
+ DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
+ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+ return failure();
+
+ ModuleBufferizationState state(moduleOp, *options.allocationFns);
+ BufferizationAliasInfo &aliasInfo = state.aliasInfo;
+
+ // Interestingly, all function args that are not visible outside of a module
+ // can be fully bufferized inplace by guaranteeing the CallOp is bufferized
+ // inplace. Therefore, we just bufferize funcOp as if none of its results were
+ // inplaceable, detect which operands are cloned internally and decide what to
+ // do at call sites.
+ for (FuncOp funcOp : orderedFuncOps) {
+ // No body => no analysis.
+ if (funcOp.body().empty())
+ continue;
+
+ // In a first approximation:
+ // =========================
+ // If the function is called, we can allocate on the caller side which lets
+ // us force inplace arguments at function boundaries.
+ // TODO: do not rely on this behavior.
+ if (callerMap.find(funcOp) != callerMap.end())
+ for (BlockArgument bbArg : funcOp.getArguments())
+ if (bbArg.getType().isa<TensorType>())
+ aliasInfo.setBufferizesToWritableMemory(bbArg);
+
+ // Analyze and bufferize funcOp.
+ if (failed(runComprehensiveBufferize(funcOp, options, state)))
+ return failure();
+ }
+
+ if (options.testAnalysisOnly)
+ return success();
+
+ for (FuncOp funcOp : orderedFuncOps) {
+ // Note: It would be good to apply cleanups here but we cannot as aliasInfo
+ // would be invalidated.
+ if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
+ state.bufferizedFunctionTypes)))
+ return failure();
+
+ if (!options.allowReturnMemref &&
+ llvm::any_of(funcOp.getType().getResults(), [](Type t) {
+ return t.isa<MemRefType, UnrankedMemRefType>();
+ })) {
+ funcOp->emitError("memref return type is unsupported");
+ return failure();
+ }
+ }
+
+ // Perform a post-processing pass of layout modification at function boundary
+ // according to the kBufferLayoutAttrName.
+ layoutPostProcessing(moduleOp);
+
+ // Post-pass cleanup of inplaceable and buffer_layout attributes.
+ moduleOp.walk([&](FuncOp op) {
+ for (BlockArgument bbArg : op.getArguments())
+ removeBufferizationFuncArguments(bbArg);
+ });
+
+ return success();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index d864bfe..31c7e4c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
@@ -44,11 +45,11 @@
memref::MemRefDialect, tensor::TensorDialect,
vector::VectorDialect, scf::SCFDialect,
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
- registerBufferizableOpInterfaceExternalModels(registry);
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ std_ext::registerBufferizableOpInterfaceExternalModels(registry);
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 907e332..cd7f97f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6667,9 +6667,11 @@
name = "ComprehensiveBufferize",
srcs = [
"lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp",
+ "lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h",
+ "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h",
],
includes = ["include"],
deps = [