blob: 434501b030e4a6572af3844e48efe645dd94602d [file]
//===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
namespace mlir {
namespace bufferization {
// Defining this macro before including the generated `Passes.h.inc` presumably
// selects the tablegen-generated definitions for this pass (the
// `impl::BufferResultsToOutParamsPassBase` class that the pass at the end of
// this file derives from) -- TODO confirm against the generated file.
#define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir
using namespace mlir;
// Shorthands for the callback types declared in BufferResultsToOutParamsOpts.
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
// Maps a function to a list with one entry per hoisted, dynamically shaped
// allocation that the function returned. Each entry holds the values that
// provide the dynamic sizes of that allocation (see `updateReturnOps`, which
// fills the map, and `updateCalls`, which reads it).
using AllocDynamicSizesMap =
llvm::DenseMap<func::FuncOp, SmallVector<SmallVector<Value>>>;
/// Returns `true` iff the strides and offset of `type` can be extracted and
/// every stride as well as the offset is dynamic. Returns `false` when the
/// strides/offset query fails.
static bool hasFullyDynamicLayoutMap(MemRefType type) {
  SmallVector<int64_t, 4> layoutStrides;
  int64_t layoutOffset;
  if (failed(type.getStridesAndOffset(layoutStrides, layoutOffset)))
    return false;

  const bool allStridesDynamic =
      llvm::all_of(layoutStrides, ShapedType::isDynamic);
  return allStridesDynamic && !ShapedType::isStatic(layoutOffset);
}
/// Returns `true` iff the layout attached to `type` is the identity layout
/// (i.e., the memref effectively has no layout).
static bool hasStaticIdentityLayout(MemRefType type) {
  auto layout = type.getLayout();
  return layout.isIdentity();
}
/// Collects the values that provide the dynamic sizes of `memref` by looking
/// at the index-typed operands of its defining op.
///
/// Only sizes that are arguments of `funcOp` are supported. An empty vector is
/// returned when `memref` has no defining op, or when any index-typed operand
/// is not a block argument found among the arguments of `funcOp`. (An empty
/// vector is also the result when the defining op has no index-typed operand.)
static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
  Operation *producer = memref.getDefiningOp();
  if (!producer)
    return {};

  auto funcArgs = funcOp.getArguments();
  SmallVector<Value> sizes;
  for (Value operand : producer->getOperands()) {
    // Operands that are not of index type do not describe a size.
    if (!isa<IndexType>(operand.getType()))
      continue;

    // Give up entirely as soon as one size cannot be traced back to an
    // argument of the function.
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || llvm::find(funcArgs, blockArg) == funcArgs.end())
      return {};
    sizes.push_back(blockArg);
  }
  return sizes;
}
/// Translates `dynamicSizes`, given as arguments of `callee`, into the
/// operands of `call` that are passed for those arguments, i.e. the dynamic
/// sizes as seen at the caller.
///
/// Asserts that the result has as many entries as `dynamicSizes`.
static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
                                                 func::FuncOp callee,
                                                 ValueRange dynamicSizes) {
  SmallVector<Value> callerSizes;
  for (Value calleeSize : dynamicSizes) {
    // Pair each call operand with the callee argument it is passed for.
    for (auto [callOperand, calleeArg] :
         llvm::zip_first(call.getOperands(), callee.getArguments())) {
      if (calleeSize == calleeArg)
        callerSizes.push_back(callOperand);
    }
  }
  assert(callerSizes.size() == dynamicSizes.size() &&
         "could not find all dynamic sizes");
  return callerSizes;
}
// Rewrites the signature of `func` so that every memref result becomes a
// trailing function argument, and updates the entry block accordingly.
//
// - Fails with a diagnostic if a memref result has a layout that is neither
//   the identity nor fully dynamic.
// - Result attributes of each promoted result are transferred to the new
//   argument. If `addResultAttribute` is true, the unit attribute
//   `bufferize.result` is additionally set on each new argument.
// - For non-external functions, the block arguments appended to the entry
//   block are pushed to `appendedEntryArgs`.
static LogicalResult
updateFuncOp(func::FuncOp func,
             SmallVectorImpl<BlockArgument> &appendedEntryArgs,
             bool addResultAttribute) {
  auto oldFuncType = func.getFunctionType();
  auto *ctx = func.getContext();

  // Determine which results are memrefs and therefore turn into out params.
  SmallVector<Type, 6> promotedTypes;
  BitVector promotedResults(oldFuncType.getNumResults());
  for (auto [resultIdx, resultType] :
       llvm::enumerate(oldFuncType.getResults())) {
    auto memrefType = dyn_cast<MemRefType>(resultType);
    if (!memrefType)
      continue;

    // Only buffers with static identity layout can be allocated. These can be
    // casted to memrefs with fully dynamic layout map. Other layout maps are
    // not supported.
    bool supportedLayout = hasStaticIdentityLayout(memrefType) ||
                           hasFullyDynamicLayoutMap(memrefType);
    if (!supportedLayout)
      return func->emitError()
             << "cannot create out param for result with unsupported layout";

    promotedResults.set(resultIdx);
    promotedTypes.push_back(memrefType);
  }

  // Extend the argument list; the results are erased in a later step.
  auto newInputTypes = llvm::to_vector<6>(oldFuncType.getInputs());
  newInputTypes.append(promotedTypes.begin(), promotedTypes.end());
  func.setType(
      FunctionType::get(ctx, newInputTypes, oldFuncType.getResults()));

  // Move the attributes of each promoted result over to its new argument.
  unsigned outArgIdx = oldFuncType.getNumInputs();
  for (unsigned resultIdx : promotedResults.set_bits()) {
    func.setArgAttrs(outArgIdx, func.getResultAttrs(resultIdx));
    if (addResultAttribute)
      func.setArgAttr(outArgIdx, StringAttr::get(ctx, "bufferize.result"),
                      UnitAttr::get(ctx));
    ++outArgIdx;
  }

  // Drop the promoted results from the signature.
  if (failed(func.eraseResults(promotedResults)))
    return failure();

  // External functions have no entry block to update.
  if (func.isExternal())
    return success();

  Block &entryBlock = func.front();
  Location loc = func.getLoc();
  for (Type type : promotedTypes)
    appendedEntryArgs.push_back(entryBlock.addArgument(type, loc));
  return success();
}
// Rewrites every func::ReturnOp nested in `func` so that it no longer returns
// memref operands. Each memref operand is paired with one of
// `appendedEntryArgs` and is either
//   - replaced by that out-param, erasing its defining allocation (when
//     hoisting is enabled in `options` for its kind of shape and it is defined
//     by an AllocationOpInterface op), or
//   - copied into the out-param using `options.memCpyFn`.
// For hoisted, dynamically shaped allocations, the values providing the
// dynamic sizes are recorded in `map` under `func`.
//
// Returns failure if `options.memCpyFn` fails.
static LogicalResult
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
                AllocDynamicSizesMap &map,
                const bufferization::BufferResultsToOutParamsOpts &options) {
  WalkResult walkResult = func.walk([&](func::ReturnOp returnOp) -> WalkResult {
    // Split the operands into memrefs (to be promoted) and everything else.
    SmallVector<Value, 6> memrefOperands;
    SmallVector<Value, 6> remainingOperands;
    for (Value operand : returnOp.getOperands()) {
      auto &bucket = isa<MemRefType>(operand.getType()) ? memrefOperands
                                                        : remainingOperands;
      bucket.push_back(operand);
    }

    OpBuilder builder(returnOp);
    SmallVector<SmallVector<Value>> dynamicSizes;
    for (auto [returned, outParam] :
         llvm::zip(memrefOperands, appendedEntryArgs)) {
      bool isStaticShape =
          cast<MemRefType>(returned.getType()).hasStaticShape();
      bool hoistStaticAllocs = options.hoistStaticAllocs && isStaticShape;
      bool hoistDynamicAllocs = options.hoistDynamicAllocs && !isStaticShape;
      Operation *producer = returned.getDefiningOp();
      bool hoistable =
          (hoistStaticAllocs || hoistDynamicAllocs) &&
          isa_and_nonnull<bufferization::AllocationOpInterface>(producer);

      // Not hoistable: copy the buffer contents into the out-param instead.
      if (!hoistable) {
        if (failed(options.memCpyFn(builder, returnOp.getLoc(), returned,
                                    outParam)))
          return WalkResult::interrupt();
        continue;
      }

      // Hoistable: use the out-param in place of the local allocation.
      returned.replaceAllUsesWith(outParam);
      // The sizes must be captured while the defining op still exists.
      if (hoistDynamicAllocs)
        dynamicSizes.push_back(getDynamicSize(returned, func));
      producer->erase();
    }

    func::ReturnOp::create(builder, returnOp.getLoc(), remainingOperands);
    returnOp.erase();

    // Insert-if-absent: an entry that already exists for `func` is left
    // untouched.
    map.try_emplace(func, std::move(dynamicSizes));
    return WalkResult::advance();
  });
  return failure(walkResult.wasInterrupted());
}
// Rewrites every func::CallOp nested in `module` whose callee was converted:
// a buffer is created with `options.allocationFn` for each memref result, the
// buffer is passed as an additional trailing operand of a new call, and all
// uses of the old memref result are redirected to the buffer.
//
// Calls are left untouched when the callee is rejected by `options.filterFn`,
// is public while `options.modifyPublicFunctions` is off, or is external.
//
// `map` provides, per callee, the dynamic sizes recorded for its hoisted
// dynamically shaped results. Returns failure (after emitting a diagnostic on
// the call) if the callee cannot be found, if no sizes are known for a
// dynamically shaped result, or if the allocation callback fails.
static LogicalResult
updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
            const bufferization::BufferResultsToOutParamsOpts &options) {
  bool didFail = false;
  SymbolTable symtab(module);
  module.walk([&](func::CallOp op) {
    auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
    if (!callee) {
      op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
                     << "symbol table";
      didFail = true;
      return;
    }

    // Skip calls to functions whose signature was not rewritten.
    bool calleeUntouched =
        !options.filterFn(&callee) ||
        (callee.isPublic() && !options.modifyPublicFunctions) ||
        callee.isExternal();
    if (calleeUntouched)
      return;

    // Split the results into memrefs (become out-params) and everything else
    // (stay results of the new call).
    SmallVector<Value, 6> memrefResults;
    SmallVector<Value, 6> otherResults;
    for (OpResult result : op.getResults()) {
      auto &bucket =
          isa<MemRefType>(result.getType()) ? memrefResults : otherResults;
      bucket.push_back(result);
    }

    SmallVector<Value, 6> outParams;
    OpBuilder builder(op);
    SmallVector<SmallVector<Value>> calleeSizes = map.lookup(callee);
    // Index of the next unconsumed entry in `calleeSizes`; only dynamically
    // shaped results consume an entry.
    size_t nextSizesIdx = 0;
    for (Value memref : memrefResults) {
      auto memrefType = cast<MemRefType>(memref.getType());

      // Static shapes need no sizes. Dynamic shapes need a non-empty entry
      // recorded for the callee, translated to values at the call site.
      SmallVector<Value> callerSizes;
      if (!memrefType.hasStaticShape()) {
        bool sizesKnown = nextSizesIdx < calleeSizes.size() &&
                          !calleeSizes[nextSizesIdx].empty();
        if (!sizesKnown) {
          op.emitError()
              << "cannot create out param for dynamically shaped result";
          didFail = true;
          return;
        }
        callerSizes =
            mapDynamicSizeAtCaller(op, callee, calleeSizes[nextSizesIdx]);
        ++nextSizesIdx;
      }

      // The buffer is always allocated with an identity layout.
      auto allocType =
          MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
                          AffineMap(), memrefType.getMemorySpace());
      auto maybeOutParam =
          options.allocationFn(builder, op.getLoc(), allocType, callerSizes);
      if (failed(maybeOutParam)) {
        op.emitError() << "failed to create allocation op";
        didFail = true;
        return;
      }

      Value outParam = maybeOutParam.value();
      if (!hasStaticIdentityLayout(memrefType)) {
        // Layout maps are already checked in `updateFuncOp`.
        assert(hasFullyDynamicLayoutMap(memrefType) &&
               "layout map not supported");
        outParam =
            memref::CastOp::create(builder, op.getLoc(), memrefType, outParam);
      }
      memref.replaceAllUsesWith(outParam);
      outParams.push_back(outParam);
    }

    // Build the replacement call: original operands followed by out-params,
    // returning only the non-memref results.
    auto newOperands = llvm::to_vector<6>(op.getOperands());
    newOperands.append(outParams.begin(), outParams.end());
    SmallVector<Type, 6> newResultTypes;
    for (Value kept : otherResults)
      newResultTypes.push_back(kept.getType());
    auto newCall = func::CallOp::create(
        builder, op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands);
    for (auto [oldResult, newResult] :
         llvm::zip(otherResults, newCall.getResults()))
      oldResult.replaceAllUsesWith(newResult);
    op.erase();
  });
  return failure(didFail);
}
/// Converts memref results of the functions in `module` into out-params.
///
/// Each eligible function first has its signature and its return ops
/// rewritten; once all functions are processed, the call sites in the module
/// are updated. Functions are skipped when they are public and
/// `options.modifyPublicFunctions` is off, when they are external, or when
/// `options.filterFn` rejects them. Returns failure as soon as a step fails.
LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
    ModuleOp module,
    const bufferization::BufferResultsToOutParamsOpts &options) {
  // For each function, the sources of the dynamic sizes of the dynamically
  // shaped memrefs it returned.
  AllocDynamicSizesMap dynamicSizesPerFunc;

  for (func::FuncOp func : module.getOps<func::FuncOp>()) {
    bool skipFunc = (func.isPublic() && !options.modifyPublicFunctions) ||
                    func.isExternal() || !options.filterFn(&func);
    if (skipFunc)
      continue;

    SmallVector<BlockArgument, 6> outParamArgs;
    if (failed(updateFuncOp(func, outParamArgs, options.addResultAttribute)) ||
        failed(
            updateReturnOps(func, outParamArgs, dynamicSizesPerFunc, options)))
      return failure();
  }

  return updateCalls(module, dynamicSizesPerFunc, options);
}
namespace {
/// Pass wrapper around `promoteBufferResultsToOutParams`.
struct BufferResultsToOutParamsPass
    : bufferization::impl::BufferResultsToOutParamsPassBase<
          BufferResultsToOutParamsPass> {
  using Base::Base;

  /// Runs the conversion on the operation this pass is applied to and signals
  /// pass failure if the conversion fails.
  void runOnOperation() override {
    // Mirror the pass options that are switched on into `options`; flags that
    // are off leave the corresponding field of `options` as it is.
    if (addResultAttribute)
      options.addResultAttribute = true;
    if (hoistStaticAllocs)
      options.hoistStaticAllocs = true;
    if (hoistDynamicAllocs)
      options.hoistDynamicAllocs = true;
    if (modifyPublicFunctions)
      options.modifyPublicFunctions = true;

    LogicalResult status = bufferization::promoteBufferResultsToOutParams(
        getOperation(), options);
    if (failed(status))
      signalPassFailure();
  }

private:
  /// Options handed to `promoteBufferResultsToOutParams`.
  bufferization::BufferResultsToOutParamsOpts options;
};
} // namespace