// blob: ce45f847ccaedf9fc3ff76ab4bfb287153edd212 [file]
//===----------- MultiBuffering.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the multi-buffering transformation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
#define DEBUG_TYPE "memref-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
/// Returns true when `op` is treated as fully overwriting `buffer`.
///
/// The only pattern recognized is a `memref.copy` whose target is exactly
/// `buffer`; every other operation yields false.
static bool overrideBuffer(Operation *op, Value buffer) {
  if (auto copy = dyn_cast<memref::CopyOp>(op))
    return copy.getTarget() == buffer;
  return false;
}
/// Redirects every use of `oldOp` to `val`, rebuilding view-like users so that
/// their result types stay consistent with the (possibly different) memref
/// type of `val`.
///
/// Users that are `memref.subview`, `memref.expand_shape`,
/// `memref.collapse_shape` or `memref.cast` are re-created on top of `val`
/// with a result type recomputed from `val`'s type and the user's existing
/// attributes/operands. The function then recurses into the uses of the old
/// view op and erases it once all uses of `oldOp` have been visited. Any other
/// user simply has its operand swapped for `val` in place.
///
/// Returns failure (after emitting an error on the offending user) when the
/// result type of a view-like user cannot be recomputed, or when the recursive
/// propagation fails.
static LogicalResult replaceUsesAndPropagateType(RewriterBase &rewriter,
                                                 Operation *oldOp, Value val) {
  // View ops that have been superseded; erased only after the use walk.
  SmallVector<Operation *> staleViews;

  // early_inc: the current use is rewired (or its owner is emptied of uses)
  // while the use list is being walked.
  for (OpOperand &operand : llvm::make_early_inc_range(oldOp->getUses())) {
    Operation *owner = operand.getOwner();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(owner);
    auto newSourceType = cast<MemRefType>(val.getType());

    // `rebuiltView` stays null when `owner` is not one of the handled
    // view-like ops; `inferenceFailed` is raised when it is one but its new
    // result type cannot be computed.
    Value rebuiltView;
    bool inferenceFailed = false;

    if (auto subview = dyn_cast<memref::SubViewOp>(owner)) {
      MemRefType resultType = memref::SubViewOp::inferRankReducedResultType(
          subview.getType().getShape(), newSourceType,
          subview.getStaticOffsets(), subview.getStaticSizes(),
          subview.getStaticStrides());
      rebuiltView = memref::SubViewOp::create(
                        rewriter, subview->getLoc(), resultType, val,
                        subview.getMixedOffsets(), subview.getMixedSizes(),
                        subview.getMixedStrides())
                        .getResult();
    } else if (auto expand = dyn_cast<memref::ExpandShapeOp>(owner)) {
      FailureOr<MemRefType> resultType =
          memref::ExpandShapeOp::computeExpandedType(
              newSourceType, expand.getResultType().getShape(),
              expand.getReassociationIndices());
      if (failed(resultType)) {
        inferenceFailed = true;
      } else {
        rebuiltView = memref::ExpandShapeOp::create(
                          rewriter, expand->getLoc(), *resultType, val,
                          expand.getReassociationIndices(),
                          expand.getMixedOutputShape())
                          .getResult();
      }
    } else if (auto collapse = dyn_cast<memref::CollapseShapeOp>(owner)) {
      FailureOr<MemRefType> resultType =
          memref::CollapseShapeOp::computeCollapsedType(
              newSourceType, collapse.getReassociationIndices());
      if (failed(resultType)) {
        inferenceFailed = true;
      } else {
        rebuiltView = memref::CollapseShapeOp::create(
                          rewriter, collapse->getLoc(), *resultType, val,
                          collapse.getReassociationIndices())
                          .getResult();
      }
    } else if (auto castOp = dyn_cast<memref::CastOp>(owner)) {
      // A cast keeps its declared result type, so only compatibility with the
      // new source type has to be checked.
      if (!memref::CastOp::areCastCompatible(newSourceType, castOp.getType())) {
        inferenceFailed = true;
      } else {
        rebuiltView = memref::CastOp::create(rewriter, castOp->getLoc(),
                                             castOp.getType(), val)
                          .getResult();
      }
    }

    if (inferenceFailed) {
      owner->emitOpError(
          "failed to compute view-like result type after multi-buffering");
      return failure();
    }

    if (!rebuiltView) {
      // Not a view-like op: swap the operand in place, notifying the rewriter.
      rewriter.startOpModification(owner);
      operand.set(val);
      rewriter.finalizeOpModification(owner);
      continue;
    }

    // View-like op: push the type change down through its own users, then
    // schedule the old view for erasure.
    if (failed(replaceUsesAndPropagateType(rewriter, owner, rebuiltView)))
      return failure();
    staleViews.push_back(owner);
  }

  for (Operation *stale : staleViews)
    rewriter.eraseOp(stale);
  return success();
}
// Transformation to do multi-buffering/array expansion to remove dependencies
// on the temporary allocation between consecutive loop iterations.
//
// `allocOp` is replaced by an allocation with an extra leading dimension of
// size `multiBufferingFactor`; inside the candidate loop every use of the
// original buffer is redirected to a rank-reduced subview selecting slice
// ((%iv - %lb) floordiv %step) mod `multiBufferingFactor`.
//
// When `skipOverrideAnalysis` is false, a candidate loop is only accepted if
// it contains a `memref.copy` targeting the buffer that dominates all other
// users. When it is true, the loop is accepted as soon as every non-dealloc
// user is nested inside it.
//
// Returns the new allocation if the transformation happened and failure
// otherwise. Failure is also returned, before touching the IR, when
// `multiBufferingFactor` is zero or when the allocation is not statically
// shaped.
//
// This is not a pattern as it requires propagating the new memref type to its
// uses and requires updating subview ops.
FailureOr<memref::AllocOp>
mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
                          unsigned multiBufferingFactor,
                          bool skipOverrideAnalysis) {
  LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");

  // A factor of zero would produce a zero-sized leading dimension and a
  // `mod 0` in the slice index computation below.
  if (multiBufferingFactor == 0) {
    LLVM_DEBUG(DBGS() << "Skip alloc: multi-buffering factor is zero\n");
    return failure();
  }
  // The new alloc is created without dynamic size operands and the subview
  // sizes are built as static index attributes from the original shape, so
  // only statically shaped allocations can be handled.
  if (!allocOp.getType().hasStaticShape()) {
    LLVM_DEBUG(DBGS() << "Skip alloc: not statically shaped\n");
    return failure();
  }

  DominanceInfo dom(allocOp->getParentOp());
  LoopLikeOpInterface candidateLoop;
  for (Operation *user : allocOp->getUsers()) {
    auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
    if (!parentLoop) {
      if (isa<memref::DeallocOp>(user)) {
        // Allow dealloc outside of any loop.
        // TODO: The whole precondition function here is very brittle and will
        // need to be rethought and isolated into a cleaner analysis.
        continue;
      }
      LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
      LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
      return failure();
    }
    if (!skipOverrideAnalysis) {
      // Make sure there is no loop-carried dependency on the allocation.
      if (!overrideBuffer(user, allocOp.getResult())) {
        LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
        continue;
      }
      // If this user doesn't dominate all the other users keep looking.
      if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
            return !dom.dominates(user, otherUser);
          })) {
        LLVM_DEBUG(
            DBGS() << "--Skip user: does not dominate all other users\n");
        continue;
      }
    } else {
      // Without the override analysis, only require that every user other
      // than a dealloc is nested inside this user's parent loop.
      if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
            return !isa<memref::DeallocOp>(otherUser) &&
                   !parentLoop->isProperAncestor(otherUser);
          })) {
        LLVM_DEBUG(
            DBGS()
            << "--Skip user: not all other users are in the parent loop\n");
        continue;
      }
    }
    candidateLoop = parentLoop;
    break;
  }

  if (!candidateLoop) {
    LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
    return failure();
  }

  // The slice index is derived from a single induction variable, lower bound
  // and step, and is materialized in the loop's single region.
  std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
  std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
  std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
  if (!inductionVar || !lowerBound || !singleStep ||
      !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
    LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
    return failure();
  }

  if (!dom.dominates(allocOp.getOperation(), candidateLoop)) {
    LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n");
    return failure();
  }

  LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");

  // 1. Construct the multi-buffered memref type: the original shape prefixed
  // with the multi-buffering factor, with the layout reset to the default.
  ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
  SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor};
  llvm::append_range(multiBufferedShape, originalShape);
  LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
  MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
                                .setShape(multiBufferedShape)
                                .setLayout(MemRefLayoutAttrInterface());
  LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");

  // 2. Create the multi-buffered alloc in place of the original one, carrying
  // over the original op's attributes.
  Location loc = allocOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(allocOp);
  auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType,
                                         ValueRange{}, allocOp->getAttrs());
  LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");

  // 3. Within the loop, build the modular leading index (i.e. each loop
  // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
  rewriter.setInsertionPointToStart(
      &candidateLoop.getLoopRegions().front()->front());
  Value ivVal = *inductionVar;
  Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
  Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
  AffineExpr iv, lb, step;
  bindDims(rewriter.getContext(), iv, lb, step);
  Value bufferIndex = affine::makeComposedAffineApply(
      rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
      {ivVal, lbVal, stepVal});
  LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");

  // 4. Build the subview accessing the particular slice, taking modular
  // rotation into account.
  int64_t mbMemRefTypeRank = mbMemRefType.getRank();
  IntegerAttr zero = rewriter.getIndexAttr(0);
  IntegerAttr one = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
  SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
  SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
  // Offset is [bufferIndex, 0 ... 0 ].
  offsets.front() = bufferIndex;
  // Sizes is [1, original_size_0 ... original_size_n ].
  for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
    sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
  // Strides is [1, 1 ... 1 ].
  // The result type is rank-reduced back to the original shape so that users
  // keep seeing a buffer of the original rank.
  MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
      originalShape, mbMemRefType, offsets, sizes, strides);
  Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc,
                                            offsets, sizes, strides);
  LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");

  // 5. Due to the recursive nature of replaceUsesAndPropagateType, we need to
  // handle dealloc uses separately: a dealloc must free the whole
  // multi-buffered allocation rather than the per-iteration slice.
  for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
    auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
    if (!deallocOp)
      continue;
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(deallocOp);
    auto newDeallocOp =
        memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc);
    (void)newDeallocOp;
    LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
    rewriter.eraseOp(deallocOp);
  }

  // 6. RAUW with the particular slice, taking modular rotation into account.
  // NOTE(review): on failure here the ops created in steps 2-5 are not rolled
  // back, so the IR is left modified.
  if (failed(replaceUsesAndPropagateType(rewriter, allocOp, subview)))
    return failure();

  // 7. Finally, erase the old allocOp.
  rewriter.eraseOp(allocOp);

  return mbAlloc;
}
// Convenience overload of `multiBuffer` for callers that do not have a
// rewriter at hand: builds an `IRRewriter` on `allocOp`'s context and forwards
// all arguments unchanged to the rewriter-based overload.
FailureOr<memref::AllocOp>
mlir::memref::multiBuffer(memref::AllocOp allocOp,
                          unsigned multiBufferingFactor,
                          bool skipOverrideAnalysis) {
  IRRewriter localRewriter(allocOp->getContext());
  FailureOr<memref::AllocOp> multiBuffered = multiBuffer(
      localRewriter, allocOp, multiBufferingFactor, skipOverrideAnalysis);
  return multiBuffered;
}