blob: e354195dfca0bc749ff6778cc8d30b04005f794f [file] [log] [blame]
//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "comprehensive-module-bufferize"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
using namespace mlir;
using namespace linalg;
using namespace tensor;
using namespace comprehensive_bufferize;
namespace {
/// A specialization of BufferizationState that keeps track of additional
/// state required for bufferization of function boundaries.
struct ModuleBufferizationState : public BufferizationState {
using BufferizationState::BufferizationState;
/// A map for looking up bufferized function types.
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
/// A mapping of return values to equivalent BlockArguments.
DenseMap<Value, BlockArgument> equivalentReturnValToBBArg;
};
} // namespace
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
/// If `value` is a memref::CastOp, return its source. Otherwise, return
/// `value` directly.
static Value getNonCastedValue(Value value) {
while (auto castOp = value.getDefiningOp<memref::CastOp>())
value = castOp.source();
return value;
}
/// Remove the attribute that triggers inplace bufferization on a FuncOp
/// argument `bbArg`.
static void removeBufferizationFuncArguments(BlockArgument bbArg) {
auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
funcOp.removeArgAttr(bbArg.getArgNumber(),
BufferizableOpInterface::kBufferLayoutAttrName);
funcOp.removeArgAttr(bbArg.getArgNumber(),
BufferizableOpInterface::kInplaceableAttrName);
}
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
ReturnOp returnOp;
for (Block &b : funcOp.body()) {
if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
if (returnOp)
return nullptr;
returnOp = candidateOp;
}
}
return returnOp;
}
/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
/// tensor is replaced by the corresponding buffer type.
/// In order for all the callers to agree, this *must* bufferize to the most
/// dynamic buffer type supported.
/// A later pass across all CallOps in the module can decide whether to simplify
/// the types of to version according to some cost model.
static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
TypeRange argumentTypes,
TypeRange resultTypes) {
auto rewrite = [](Type t) -> Type {
// TODO: non-zero address space.
// TODO: layout information if relevant.
if (auto rankedTensorType = t.dyn_cast<RankedTensorType>())
return getDynamicMemRefType(rankedTensorType);
if (auto tensorType = t.dyn_cast<TensorType>())
return getContiguousOrUnrankedMemRefType(tensorType);
return t;
};
auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));
auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite));
return FunctionType::get(ctx, argTypes, retTypes);
}
/// If an entry for `funcOp` is available in `bufferizedFunctionTypes`, return
/// it. Otherwise, construct a new entry based on `argumentTypes` and
/// `resultTypes`.
// TODO: improve the layering.
static FunctionType getOrCreateBufferizedFunctionType(
FuncOp funcOp, TypeRange argumentTypes, TypeRange resultTypes,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
auto it = bufferizedFunctionTypes.find(funcOp);
if (it != bufferizedFunctionTypes.end())
return it->second;
auto it2 = bufferizedFunctionTypes.try_emplace(
funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
resultTypes));
LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n");
return it2.first->second;
}
/// Store function BlockArguments that are equivalent to a returned value in
/// the given ModuleBufferizationState.
static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
ModuleBufferizationState &state) {
// Support only single return-terminated block in the function.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
for (Value returnVal : returnOp.operands())
if (returnVal.getType().isa<RankedTensorType>())
for (BlockArgument bbArg : funcOp.getArguments())
if (bbArg.getType().isa<RankedTensorType>())
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
state.equivalentReturnValToBBArg[returnVal] = bbArg;
}
/// Rewrite the `funcOp` arguments analysis return values and terminator into
/// buffer form (using the canonical memref layout for now), according to the
/// inPlace-bufferizable information of the function arguments.
///
/// This relies on a buffer equivalence analysis of each return operand. When a
/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be
/// dropped from the return values and becomes inplaceable at all callers. This
/// assumes all CallOp perform the necessary work to clone operands so as to
/// make them inplaceable. Reliance on this logic will need to be relaxed in the
/// future.
///
/// Note: Returning a memref currently fails bufferization. If such memrefs
/// originate from an op with an Alloc effect, they could be hoisted in the
/// future.
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
ModuleBufferizationState &state) {
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// If nothing to do then we are done.
if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
!llvm::any_of(funcOp.getType().getResults(), isaTensor))
return success();
// Get the bufferized FunctionType for funcOp or construct it if not yet
// available.
// TODO: Atm we have 3 cases:
// 1. if a function is called from within the Module, it must have bufferized
// to inplaceable tensor results.
// 2. if it is bodiless, it must have bufferized and is not allowed to have
// result tensors.
// 3. if it is not called internally, it still must bufferize to inplaceable
// tensor results and we construct it now (e.g. top-level function called
// externally).
// -> Figure out a better layering.
TypeRange resultTypes;
// Corner case: Bodiless FuncOp
// ============================
// The body of such functions is assumed opaque and we can't know the
// bufferization contract they want to enforce atm.
// As a consequence, only support functions that don't return any tensor atm.
if (funcOp.getBody().empty()) {
if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
return funcOp->emitError() << "cannot bufferize bodiless function that "
<< "returns a tensor";
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
funcOp, funcOp.getType().getInputs(), TypeRange{},
state.bufferizedFunctionTypes);
funcOp.setType(bufferizedFuncType);
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
return success();
}
// Support only single return-terminated block in the function.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
// 1. For each FuncOp result, keep track of which inplace argument it reuses.
SmallVector<Value> returnValues;
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
Value returnVal = returnOperand.get();
// If not a renturn tensor type just forward it.
if (!returnVal.getType().isa<RankedTensorType>()) {
returnValues.push_back(returnVal);
continue;
}
// If return operand is equivalent to some bbArg, no need to return it.
if (state.equivalentReturnValToBBArg.count(returnVal))
continue;
// Cast values at the call site if necessary.
returnValues.push_back(getNonCastedValue(state.lookupBuffer(returnVal)));
}
// 2. Rewrite the terminator without the inPlace bufferizable values.
ValueRange retValues{returnValues};
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
state.bufferizedFunctionTypes);
OpBuilder b(returnOp);
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
returnOp->erase();
// 3. Rewrite the bbArgs.
// Iterate on the original `numArgs` and replace them in order.
// This guarantees the argument order still matches after the rewrite.
Block &frontBlock = funcOp.body().front();
unsigned numArgs = frontBlock.getNumArguments();
for (unsigned idx = 0; idx < numArgs; ++idx) {
auto bbArg = frontBlock.getArgument(0);
auto tensorType = bbArg.getType().dyn_cast<TensorType>();
// Non-tensor types are just forwarded.
if (!tensorType) {
frontBlock.addArgument(bbArg.getType());
bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
frontBlock.eraseArgument(0);
continue;
}
// Get the buffer type from the bufferized function type.
Type memrefType = bufferizedFuncType.getInput(idx);
Value memref = frontBlock.addArgument(memrefType);
OpBuilder b(funcOp->getContext());
b.setInsertionPointToStart(&frontBlock);
// Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
if (auto toMemrefOp =
dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
auto castOp = b.create<memref::CastOp>(
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
toMemrefOp.memref().replaceAllUsesWith(castOp);
aliasInfo.insertNewBufferEquivalence(castOp.dest(),
toMemrefOp.memref());
}
}
// Replace all remaining uses by a to_tensor.
if (!bbArg.use_empty()) {
auto toTensorOp =
b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
aliasInfo.insertNewBufferEquivalence(toTensorOp, bbArg);
bbArg.replaceAllUsesWith(toTensorOp);
}
frontBlock.eraseArgument(0);
// TODO: add support to erase aliasInfo entries if deemed necessary.
}
// 4. Rewrite the FuncOp type to buffer form.
funcOp.setType(bufferizedFuncType);
LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp);
return success();
}
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
/// Return `failure()` if a cycle of calls is detected or if we are unable to
/// retrieve the called FuncOp from any CallOpInterface.
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
SmallVectorImpl<FuncOp> &orderedFuncOps,
DenseMap<FuncOp, DenseSet<Operation *>> &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested CallOpInterfaceOp).
DenseMap<FuncOp, DenseSet<FuncOp>> calledBy;
// For each FuncOp, the number of CallOpInterface it contains.
DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult {
if (!funcOp.body().empty()) {
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
if (!returnOp)
return funcOp->emitError()
<< "cannot bufferize a FuncOp with tensors and "
"without a unique ReturnOp";
}
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
// Only support CallOp for now.
if (!isa<CallOp>(callOp.getOperation()))
return callOp->emitError() << "expected a CallOp";
FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called FuncOp");
auto it = callerMap.try_emplace(calledFunction, DenseSet<Operation *>{});
it.first->getSecond().insert(callOp);
if (calledBy[calledFunction].count(funcOp) == 0) {
calledBy[calledFunction].insert(funcOp);
numberCallOpsContainedInFuncOp[funcOp]++;
}
return WalkResult::advance();
});
});
if (res.wasInterrupted())
return failure();
// Iteratively remove function operation that do not call any of the
// functions remaining in the callCounter map and add them to the worklist.
while (!numberCallOpsContainedInFuncOp.empty()) {
auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
[](auto entry) { return entry.getSecond() == 0; });
if (it == numberCallOpsContainedInFuncOp.end())
return moduleOp.emitOpError(
"expected callgraph to be free of circular dependencies.");
orderedFuncOps.push_back(it->getFirst());
for (auto callee : calledBy[it->getFirst()])
numberCallOpsContainedInFuncOp[callee]--;
numberCallOpsContainedInFuncOp.erase(it);
}
return success();
}
static void
foreachCaller(const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
FuncOp callee, llvm::function_ref<void(Operation *)> doit) {
auto itCallers = callerMap.find(callee);
if (itCallers == callerMap.end())
return;
for (Operation *caller : itCallers->second)
doit(caller);
}
/// Postprocess the linalg.buffer_layout annotation across function boundaries.
/// This is a purely mechanical process that may later become part of a
/// separate pass with its own layout assignment heuristic.
static void layoutPostProcessing(ModuleOp moduleOp) {
SmallVector<FuncOp> orderedFuncOps;
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
(void)res;
assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
for (FuncOp funcOp : orderedFuncOps) {
DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
foreachCaller(callerMap, funcOp, [&](Operation *caller) {
operandsPerCaller.try_emplace(caller, SmallVector<Value>());
});
SmallVector<Type> argumentTypes;
// Iterate on each function argument and check it it was marked with a
// desired layout.
for (auto it : llvm::enumerate(funcOp.getType().getInputs())) {
int argNumber = it.index();
Type inputType = it.value();
auto memrefType = inputType.dyn_cast<MemRefType>();
auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
argNumber, BufferizableOpInterface::kBufferLayoutAttrName);
AffineMap desiredLayoutMap =
layoutAttr ? layoutAttr.getValue() : AffineMap();
AffineMap currentLayoutMap =
memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
argumentTypes.push_back(inputType);
foreachCaller(callerMap, funcOp, [&](Operation *caller) {
operandsPerCaller.find(caller)->getSecond().push_back(
caller->getOperand(argNumber));
});
continue;
}
// Compute the buffer type with desired layout and add to input argument
// types.
MemRefType desiredMemrefType = MemRefType::get(
memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
argumentTypes.push_back(desiredMemrefType);
// If funcOp's body is not empty, change the bbArg type and propagate.
if (!funcOp.body().empty()) {
BlockArgument bbArg = funcOp.getArgument(argNumber);
bbArg.setType(desiredMemrefType);
OpBuilder b(bbArg.getContext());
b.setInsertionPointToStart(bbArg.getOwner());
// Cast back to the original memrefType and let it canonicalize.
Value cast =
b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
}
// Cast to desired buffer type on all callers to `funcOp`.
// TODO: on the callee side, this may even have to trigger a copy to
// change the layout. For now let the memref::CastOp fail to verify in
// such cases.
auto castArg = [&](Operation *caller) {
OpBuilder b(caller);
Value newOperand = b.create<memref::CastOp>(
funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
};
foreachCaller(callerMap, funcOp, castArg);
}
// Set operands with cast buffer on all callers to `funcOp`.
foreachCaller(callerMap, funcOp, [&](Operation *caller) {
caller->setOperands(operandsPerCaller.lookup(caller));
});
// Finally set the funcOp type to update the arguments.
auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
funcOp.getType().getResults());
funcOp.setType(newFuncType);
}
}
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {
struct CallOpInterface
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
// CallOpInterface alone doesn't bufferize to a memory read, one of the uses
// of the matching bbArg may. It is the responsibility of the caller to
// inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
// conservative.
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
// CallOpInterface alone doesn't bufferize to a memory write, one of the
// uses of the matching bbArg may. It is the responsibility of the caller to
// inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
// conservative.
return true;
}
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
// TODO: Can we do better?
return {};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
// CallOpInterface is special, it needs to wait for the callee to be
// bufferized and needs to inspect the BufferAliasInfo object. It can't
// make a proper determination by itself and needs to be conservative.
return OpResult();
}
BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
return BufferRelation::Equivalent;
}
/// In a first approximation, all the function arguments of a FuncOp are
/// marked inplaceable. For now, it is the responsibility of the `callOp`
/// bufferization to allow FuncOp that are inplaceable to write inPlace.
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
CallOp callOp = cast<CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
"expected Callop to a FuncOp");
auto &moduleState = static_cast<ModuleBufferizationState &>(state);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(callOp);
// 1. Filter return types:
// - if the callee is bodiless / external, we cannot inspect it and we
// cannot assume anything. We can just assert that it does not return a
// tensor as this would have to bufferize to "return a memref", whose
// semantics is ill-defined.
// - if the callee has a body, we perform inter-procedural equivalence
// analysis. When successful, a result folds onto an operand. When
// unsuccessful, additional work is needed (TODO) to either:
// * hoist a result into an inplaceable operand or
// * devise a better representation to truly return a buffer.
SmallVector<Type> resultTypes;
if (funcOp.body().empty()) {
if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
return callOp->emitError()
<< "cannot bufferize bodiless function that returns a tensor";
} else {
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
// For each FuncOp result, keep track of which inplace argument it reuses.
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
Type returnType = returnOperand.get().getType();
if (!isaTensor(returnType)) {
resultTypes.push_back(returnType);
continue;
}
// If return operand is equivalent to some bbArg, no need to return it.
Value returnVal = returnOperand.get();
if (moduleState.equivalentReturnValToBBArg.count(returnVal)) {
BlockArgument bbArg =
moduleState.equivalentReturnValToBBArg[returnVal];
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
int64_t idx = bbArg.getArgNumber();
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
// Add CallOp operand/result equivalence: this is interprocedural
// info.
state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
state.mapBuffer(oldRes, buffer);
// Add a ToTensorOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
// This ToTensorOp must fold/DCE away or bufferization should be
// considered failed.
Value toTensorOp =
b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(toTensorOp);
// Add new op equivalence info.
state.aliasInfo.insertNewBufferEquivalence(toTensorOp, buffer);
state.mapBuffer(toTensorOp, buffer);
continue;
}
resultTypes.push_back(returnType);
}
}
// 2. Compute bufferized FunctionType.
SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
// Get the bufferized FunctionType for funcOp or construct it if not yet
// available.
FunctionType bufferizedFuncType =
getOrCreateBufferizedFunctionType(funcOp, argumentTypes, resultTypes,
moduleState.bufferizedFunctionTypes);
// 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
SmallVector<Value> newOperands;
newOperands.reserve(callOp->getNumOperands());
for (OpOperand &opOperand : callOp->getOpOperands()) {
Value tensorOperand = opOperand.get();
// Non-tensor operands are just copied.
if (!tensorOperand.getType().isa<TensorType>()) {
newOperands.push_back(tensorOperand);
continue;
}
// Tensor operands are guaranteed to have been buferized.
int64_t idx = opOperand.getOperandNumber();
Value buffer = state.lookupBuffer(tensorOperand);
// Caller / callee type mistmatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx);
// Since we don't yet have a clear layout story, buffer_cast may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
// that will either canonicalize away or fail compilation until we can do
// something better.
if (buffer.getType() != memRefType) {
Value castBuffer =
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
// Add new op equivalence info.
state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
state.mapBuffer(tensorOperand, castBuffer);
buffer = castBuffer;
}
newOperands.push_back(buffer);
}
// 4. Create the new CallOp.
Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());
// 5. Delete the op at the end of bufferization.
state.markOpObsolete(callOp);
return success();
}
};
struct ReturnOpInterface
: public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
ReturnOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto returnOp = cast<ReturnOp>(op);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
// Cannot insert after returnOp.
b.setInsertionPoint(returnOp);
assert(isa<FuncOp>(returnOp->getParentOp()) &&
"only support FuncOp parent for ReturnOp");
for (OpOperand &operand : returnOp->getOpOperands()) {
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
Value v = state.lookupBuffer(operand.get());
Value returnTensor = b.create<bufferization::ToTensorOp>(
returnOp.getLoc(), v);
operand.set(returnTensor);
state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
state.mapBuffer(returnTensor, v);
}
return success();
}
};
} // namespace std_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
void mlir::linalg::comprehensive_bufferize::std_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
registry.addOpInterface<FuncOp, AllocationHoistingBarrierOnly<FuncOp>>();
}
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
ModuleOp moduleOp, const BufferizationOptions &options) {
SmallVector<FuncOp> orderedFuncOps;
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return failure();
ModuleBufferizationState state(moduleOp, *options.allocationFns);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// Interestingly, all function args that are not visible outside of a module
// can be fully bufferized inplace by guaranteeing the CallOp is bufferized
// inplace. Therefore, we just bufferize funcOp as if none of its results were
// inplaceable, detect which operands are cloned internally and decide what to
// do at call sites.
for (FuncOp funcOp : orderedFuncOps) {
// No body => no analysis.
if (funcOp.body().empty())
continue;
// In a first approximation:
// =========================
// If the function is called, we can allocate on the caller side which lets
// us force inplace arguments at function boundaries.
// TODO: do not rely on this behavior.
if (callerMap.find(funcOp) != callerMap.end())
for (BlockArgument bbArg : funcOp.getArguments())
if (bbArg.getType().isa<TensorType>())
aliasInfo.setBufferizesToWritableMemory(bbArg);
// Analyze and bufferize funcOp.
if (failed(runComprehensiveBufferize(funcOp, options, state)))
return failure();
populateEquivalentFuncOpBBArgs(funcOp, state);
}
if (options.testAnalysisOnly)
return success();
for (FuncOp funcOp : orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
if (failed(bufferizeFuncOpBoundary(funcOp, state)))
return failure();
if (!options.allowReturnMemref &&
llvm::any_of(funcOp.getType().getResults(), [](Type t) {
return t.isa<MemRefType, UnrankedMemRefType>();
})) {
funcOp->emitError("memref return type is unsupported");
return failure();
}
}
// Perform a post-processing pass of layout modification at function boundary
// according to the kBufferLayoutAttrName.
layoutPostProcessing(moduleOp);
// Post-pass cleanup of inplaceable and buffer_layout attributes.
moduleOp.walk([&](FuncOp op) {
for (BlockArgument bbArg : op.getArguments())
removeBufferizationFuncArguments(bbArg);
});
return success();
}