blob: 68ef51992efeea960d4e502cb66cd6aed07cc608 [file] [log] [blame]
//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/DebugLog.h"
#include <optional>
namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_ONESHOTBUFFERIZEPASS
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir
#define DEBUG_TYPE "bufferize"
using namespace mlir;
using namespace mlir::bufferization;
namespace {
/// Translate the textual value of the `analysisHeuristic` pass option into the
/// matching OneShotBufferizationOptions::AnalysisHeuristic enumerator.
/// Accepted spellings: "bottom-up", "top-down", "bottom-up-from-terminators"
/// and "fuzzer"; any other string is treated as unreachable.
static OneShotBufferizationOptions::AnalysisHeuristic
parseHeuristicOption(const std::string &s) {
  using Heuristic = OneShotBufferizationOptions::AnalysisHeuristic;
  struct NamedHeuristic {
    const char *name;
    Heuristic value;
  };
  static constexpr NamedHeuristic kHeuristics[] = {
      {"bottom-up", Heuristic::BottomUp},
      {"top-down", Heuristic::TopDown},
      {"bottom-up-from-terminators", Heuristic::BottomUpFromTerminators},
      {"fuzzer", Heuristic::Fuzzer},
  };
  for (const NamedHeuristic &entry : kHeuristics)
    if (s == entry.name)
      return entry.value;
  llvm_unreachable("invalid analysisheuristic option");
}
/// Pass that runs One-Shot Bufferize on a ModuleOp.
///
/// Bufferization options are taken from `options` when it holds a value;
/// otherwise they are assembled from the pass's command-line options.
/// Incompatible option combinations are reported as errors and fail the pass.
/// On success, the collected bufferization statistics are copied into the
/// pass statistics.
struct OneShotBufferizePass
    : public bufferization::impl::OneShotBufferizePassBase<
          OneShotBufferizePass> {
  using Base::Base;
  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // Make new bufferization options if none were provided when creating the
      // pass.
      opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
      opt.copyBeforeWrite = copyBeforeWrite;
      opt.dumpAliasSets = dumpAliasSets;
      opt.setFunctionBoundaryTypeConversion(functionBoundaryTypeConversion);
      // The two memory-space options install conflicting
      // `defaultMemorySpaceFn`s, so at most one of them may be set.
      if (mustInferMemorySpace && useEncodingForMemorySpace) {
        emitError(getOperation()->getLoc())
            << "only one of 'must-infer-memory-space' and "
               "'use-encoding-for-memory-space' are allowed in "
            << getArgument();
        return signalPassFailure();
      }
      if (mustInferMemorySpace) {
        // Never provide a default memory space.
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          return std::nullopt;
        };
      }
      if (useEncodingForMemorySpace) {
        // Use the encoding of ranked tensors as the memory space.
        opt.defaultMemorySpaceFn =
            [](TensorType t) -> std::optional<Attribute> {
          if (auto rtt = dyn_cast<RankedTensorType>(t))
            return rtt.getEncoding();
          return std::nullopt;
        };
      }
      opt.printConflicts = printConflicts;
      opt.bufferAlignment = bufferAlignment;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
      opt.checkParallelRegions = checkParallelRegions;
      opt.noAnalysisFuncFilter = noAnalysisFuncFilter;

      // Configure type converter.
      LayoutMapOption unknownTypeConversionOption = unknownTypeConversion;
      if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'infer-layout-map' is not a valid value for "
                  "'unknown-type-conversion'");
        return signalPassFailure();
      }
      // The option is captured by value; the (unused) options parameter is
      // left unnamed so it does not shadow the `options` member.
      opt.unknownTypeConverterFn =
          [unknownTypeConversionOption](TensorType tensorType,
                                        Attribute memorySpace,
                                        const BufferizationOptions &) {
            if (unknownTypeConversionOption ==
                LayoutMapOption::IdentityLayoutMap)
              return bufferization::getMemRefTypeWithStaticIdentityLayout(
                  tensorType, memorySpace);
            assert(unknownTypeConversionOption ==
                       LayoutMapOption::FullyDynamicLayoutMap &&
                   "invalid layout map option");
            return bufferization::getMemRefTypeWithFullyDynamicLayout(
                tensorType, memorySpace);
          };

      // Configure op filter.
      OpFilter::Entry::FilterFn filterFn = [this](Operation *op) {
        // No filter specified: all ops are allowed.
        if (!this->dialectFilter.hasValue() || (*this->dialectFilter).empty())
          return true;
        // Take the namespace from the op name instead of
        // `op->getDialect()->getNamespace()`: `getDialect()` returns null for
        // ops whose dialect is not loaded (e.g. unregistered ops), and this
        // filter is queried for every op in the module.
        return llvm::is_contained(this->dialectFilter,
                                  op->getName().getDialectNamespace());
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
      // These two flags do not make sense together: "copy-before-write"
      // indicates that copies should be inserted before every memory write,
      // but "test-analysis-only" indicates that only the analysis should be
      // tested. (I.e., no IR is bufferized.)
      emitError(UnknownLoc::get(&getContext()),
                "Invalid option: 'copy-before-write' cannot be used with "
                "'test-analysis-only'");
      return signalPassFailure();
    }
    if (opt.printConflicts && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
      return signalPassFailure();
    }
    if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
      emitError(
          UnknownLoc::get(&getContext()),
          "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
      return signalPassFailure();
    }

    BufferizationState state;
    BufferizationStatistics statistics;
    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(
              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
        signalPassFailure();
        return;
      }
    } else {
      // The function filter is only honored by module bufferization.
      if (!opt.noAnalysisFuncFilter.empty()) {
        emitError(UnknownLoc::get(&getContext()),
                  "Invalid option: 'no-analysis-func-filter' requires "
                  "'bufferize-function-boundaries'");
        return signalPassFailure();
      }
      if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
        signalPassFailure();
        return;
      }
    }

    // Set pass statistics.
    this->numBufferAlloc = statistics.numBufferAlloc;
    this->numTensorInPlace = statistics.numTensorInPlace;
    this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
  }

private:
  /// Explicitly provided bufferization options; when empty, options are built
  /// from the pass's command-line options. NOTE(review): nothing visible in
  /// this struct assigns it — confirm how it is meant to be populated.
  std::optional<OneShotBufferizationOptions> options;
};
} // namespace
//===----------------------------------------------------------------------===//
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//
namespace {
/// A rewriter that keeps track of extra information during bufferization.
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toBufferOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        BufferizationStatistics *statistics)
      : IRRewriter(ctx), erasedOps(erasedOps), toBufferOps(toBufferOps),
        worklist(worklist), analysisState(options), statistics(statistics) {
    // Route this rewriter's IR-mutation notifications to the hooks below.
    setListener(this);
  }

protected:
  /// Record the erased op and stop tracking it as a to_buffer op (if it was
  /// one).
  void notifyOperationErased(Operation *op) override {
    erasedOps.insert(op);
    toBufferOps.erase(op);
  }

  /// Track newly created ops: count allocations, remember to_buffer ops, and
  /// enqueue tensor ops that still need to be bufferized.
  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
    // A set previous insertion point means the op was moved, not created.
    if (previous.isSet())
      return;

    // Presumably guards against a new op being allocated at the address of a
    // previously erased one.
    erasedOps.erase(op);

    if (statistics)
      recordAllocation(op);

    if (isa<ToBufferOp>(op)) {
      toBufferOps.insert(op);
      return;
    }

    if (shouldEnqueue(op))
      worklist.push_back(op);
  }

private:
  /// Increment the allocation counter if `op` declares an Allocate effect.
  void recordAllocation(Operation *op) {
    auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
    if (effectInterface &&
        effectInterface.hasEffect<MemoryEffects::Allocate>())
      ++statistics->numBufferAlloc;
  }

  /// A new op is queued for bufferization if it is not a to_tensor op, has
  /// tensor semantics, and is allowed by the bufferization options.
  bool shouldEnqueue(Operation *op) const {
    if (isa<ToTensorOp>(op))
      return false;
    if (!hasTensorSemantics(op))
      return false;
    return analysisState.getOptions().isOpAllowed(op);
  }

  /// Every op erased through this rewriter.
  DenseSet<Operation *> &erasedOps;
  /// Every live to_buffer op.
  DenseSet<Operation *> &toBufferOps;
  /// Ops that remain to be bufferized.
  SmallVector<Operation *> &worklist;
  /// Analysis state; here it only provides access to the bufferization
  /// options.
  const AnalysisState analysisState;
  /// Optional statistics sink; null when statistics are not collected.
  BufferizationStatistics *statistics;
};
} // namespace
/// Bufferize `op` and all ops nested in it that are allowed by `options` and
/// implement BufferizableOpInterface.
///
/// Afterwards, to_buffer(to_tensor(x)) pairs are folded and dead to_tensor ops
/// are erased. Unless `options.allowUnknownOps` is set, an error is emitted for
/// the first scheduled op that still has tensor semantics and will not fold
/// away.
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         BufferizationState &bufferizationState,
                                         BufferizationStatistics *statistics) {
  // In copy-before-write mode, run insertTensorCopies on the whole op before
  // anything is rewritten.
  if (options.copyBeforeWrite) {
    AnalysisState copyAnalysis(options);
    if (failed(insertTensorCopies(op, copyAnalysis, bufferizationState)))
      return failure();
  }

  // to_buffer ops present up front; the rewriter listener keeps this set in
  // sync as ops are created and erased.
  DenseSet<Operation *> toBufferOps;
  op->walk([&](ToBufferOp existing) { toBufferOps.insert(existing); });

  // Schedule every allowed op with tensor semantics in post-order (nested ops
  // before their parent, sibling ops top-to-bottom).
  //
  // Ideally the exact memref type of every operand is known when an op is
  // bufferized (which is the case when bufferizing top-to-bottom). Otherwise a
  // memref type with a fully dynamic layout map must be used to avoid copies,
  // and patterns to canonicalize such layouts away (or to more precise
  // layouts) are currently missing.
  SmallVector<Operation *> worklist;
  op->walk<WalkOrder::PostOrder>([&](Operation *candidate) {
    if (options.isOpAllowed(candidate) && hasTensorSemantics(candidate))
      worklist.push_back(candidate);
  });

  // The listener records erasures here and appends newly created tensor ops
  // to `worklist`.
  DenseSet<Operation *> erasedOps;
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toBufferOps,
                                 worklist, options, statistics);

  // `worklist` can grow while it is being processed, so index into it rather
  // than holding iterators.
  for (size_t idx = 0; idx < worklist.size(); ++idx) {
    Operation *current = worklist[idx];
    if (erasedOps.contains(current))
      continue;

    // Only bufferizable, allowed ops that still have tensor semantics.
    auto bufferizableOp = options.dynCastBufferizableOp(current);
    if (!bufferizableOp || !hasTensorSemantics(current))
      continue;

    bool hasMultiBlockRegion =
        llvm::any_of(current->getRegions(), [](Region &region) {
          return region.getBlocks().size() > 1;
        });
    if (!bufferizableOp.supportsUnstructuredControlFlow() &&
        hasMultiBlockRegion)
      return current->emitOpError(
          "op or BufferizableOpInterface implementation does not support "
          "unstructured control flow, but at least one region has multiple "
          "blocks");

    LDBG(3) << "//===-------------------------------------------===//\n"
            << "IR after bufferizing: " << current->getName();
    rewriter.setInsertionPoint(current);
    if (failed(
            bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
      LDBG(2) << "failed to bufferize\n"
              << "//===-------------------------------------------===//";
      return current->emitError("failed to bufferize op");
    }
    LDBG(3) << *op << "\n//===-------------------------------------------===//";
  }

  // Nothing left to clean up if the top-level op itself was erased.
  if (erasedOps.contains(op))
    return success();

  // Fold all to_buffer(to_tensor(x)) pairs.
  for (Operation *toBuffer : toBufferOps) {
    rewriter.setInsertionPoint(toBuffer);
    (void)bufferization::foldToBufferToTensorPair(
        rewriter, cast<ToBufferOp>(toBuffer), options);
  }

  // Erase to_tensor ops that no longer have any uses.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (!toTensorOp->getUses().empty())
      return WalkResult::advance();
    rewriter.eraseOp(toTensorOp);
    return WalkResult::skip();
  });

  // Unless partial bufferization is allowed, verify that every scheduled op
  // was handled.
  if (options.allowUnknownOps)
    return success();
  for (Operation *scheduled : worklist) {
    // The erased check must come first: the remaining predicates dereference
    // the op.
    bool acceptable =
        erasedOps.contains(scheduled) ||         // entirely gone
        !hasTensorSemantics(scheduled) ||        // updated in place
        !options.isOpAllowed(scheduled) ||       // not allowed to bufferize
        (scheduled->getUses().empty() &&         // unused and side-effect
         isMemoryEffectFree(scheduled)) ||       //   free: will fold away
        isa<ToTensorOp, ToBufferOp>(scheduled);  // allowed in the output
    if (!acceptable)
      return scheduled->emitError("op was not bufferized");
  }
  return success();
}
/// Convert the tensor-typed arguments of `block` to their buffer types and
/// update every branch into `block` accordingly.
///
/// Each converted argument's existing uses are redirected to a to_tensor op
/// created at the start of the block. For every predecessor successor slot,
/// tensor operands forwarded to converted arguments are wrapped in a to_buffer
/// op, plus a memref.cast when the operand's buffer type differs from the
/// argument's new type.
///
/// Fails if the parent op is not bufferizable, if a buffer type cannot be
/// computed, or if a user of the block does not implement BranchOpInterface.
LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
                                       const BufferizationOptions &options,
                                       BufferizationState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
  if (!bufferizableOp)
    return failure();

  // Compute the new signature.
  SmallVector<Type> newTypes;
  newTypes.reserve(block->getNumArguments());
  for (BlockArgument &bbArg : block->getArguments()) {
    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
    if (!tensorType) {
      newTypes.push_back(bbArg.getType());
      continue;
    }
    FailureOr<BufferLikeType> bufferType =
        bufferization::getBufferType(bbArg, options, state);
    if (failed(bufferType))
      return failure();
    newTypes.push_back(*bufferType);
  }

  // Change the type of all block arguments.
  for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
    if (bbArg.getType() == type)
      continue;

    // Collect all uses of the bbArg before it is retyped.
    SmallVector<OpOperand *> bbArgUses;
    for (OpOperand &use : bbArg.getUses())
      bbArgUses.push_back(&use);

    Type tensorType = bbArg.getType();
    // Change the bbArg type to memref.
    bbArg.setType(type);

    // Replace all uses of the original tensor bbArg.
    rewriter.setInsertionPointToStart(block);
    if (!bbArgUses.empty()) {
      Value toTensorOp = bufferization::ToTensorOp::create(
          rewriter, bbArg.getLoc(), tensorType, bbArg);
      for (OpOperand *use : bbArgUses)
        use->set(toTensorOp);
    }
  }

  // Collect one (branch op, successor index) entry per use of the block.
  // Iterating uses rather than users matters: an op may name `block` in
  // several successor slots (e.g. `cf.cond_br %c, ^bb(..), ^bb(..)`), and each
  // slot carries its own operands that must be updated.
  SmallVector<std::pair<Operation *, unsigned>> predecessorSlots;
  for (BlockOperand &blockUse : block->getUses())
    predecessorSlots.emplace_back(blockUse.getOwner(),
                                  blockUse.getOperandNumber());

  // Bufferize callers of the block.
  for (auto [op, successorIdx] : predecessorSlots) {
    auto branchOp = dyn_cast<BranchOpInterface>(op);
    if (!branchOp)
      return op->emitOpError("cannot bufferize ops with block references that "
                             "do not implement BranchOpInterface");
    SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);

    // Operands produced by the branch op itself occupy the leading block
    // arguments; the forwarded operands map onto the arguments after them.
    ArrayRef<Type> forwardedTypes =
        ArrayRef<Type>(newTypes).drop_front(operands.getProducedOperandCount());

    SmallVector<Value> newOperands;
    for (auto [operand, type] :
         llvm::zip(operands.getForwardedOperands(), forwardedTypes)) {
      if (operand.getType() == type) {
        // Not a tensor type. Nothing to do for this operand.
        newOperands.push_back(operand);
        continue;
      }
      FailureOr<BufferLikeType> operandBufferType =
          bufferization::getBufferType(operand, options, state);
      if (failed(operandBufferType))
        return failure();
      rewriter.setInsertionPointAfterValue(operand);
      Value bufferizedOperand = bufferization::ToBufferOp::create(
          rewriter, operand.getLoc(), *operandBufferType, operand);
      // A cast is needed if the operand and the block argument have different
      // bufferized types.
      if (type != *operandBufferType)
        bufferizedOperand = memref::CastOp::create(rewriter, operand.getLoc(),
                                                   type, bufferizedOperand);
      newOperands.push_back(bufferizedOperand);
    }
    operands.getMutableForwardedOperands().assign(newOperands);
  }

  return success();
}