//===----------- MultiBuffering.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the multi-buffering transformation.
//
//===----------------------------------------------------------------------===//
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/Utils/Utils.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/Dominance.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/ValueRange.h" |
| #include "mlir/Interfaces/LoopLikeInterface.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/Support/Debug.h" |
| |
| using namespace mlir; |
| |
| #define DEBUG_TYPE "memref-transforms" |
| #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| #define DBGSNL() (llvm::dbgs() << "\n") |
| |
| /// Return true if the op fully overwrite the given `buffer` value. |
| static bool overrideBuffer(Operation *op, Value buffer) { |
| auto copyOp = dyn_cast<memref::CopyOp>(op); |
| if (!copyOp) |
| return false; |
| return copyOp.getTarget() == buffer; |
| } |
| |
| /// Replace the uses of `oldOp` with the given `val` and for view-like uses |
| /// propagate the type change. Changing the memref type may require propagating |
| /// it through view-like ops (subview, expand_shape, collapse_shape, cast) so |
| /// we need to propagate the type change and erase old view ops. |
| /// |
| /// Only view-like ops whose result type can be recomputed from the new source |
| /// type and existing op attributes are handled here. Other ops fall back to |
| /// operand replacement without type propagation. |
| static LogicalResult replaceUsesAndPropagateType(RewriterBase &rewriter, |
| Operation *oldOp, Value val) { |
| SmallVector<Operation *> opsToErase; |
| // Iterate with early_inc to erase current user inside the loop. |
| for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) { |
| Operation *user = use.getOwner(); |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPoint(user); |
| MemRefType srcType = cast<MemRefType>(val.getType()); |
| |
| // Try to create a new view-like op with updated result type. |
| // Each view-like op has its own method to compute the result type. |
| bool typeInferenceFailed = false; |
| Value replacement = |
| llvm::TypeSwitch<Operation *, Value>(user) |
| .Case([&](memref::SubViewOp subview) -> Value { |
| MemRefType newType = |
| memref::SubViewOp::inferRankReducedResultType( |
| subview.getType().getShape(), srcType, |
| subview.getStaticOffsets(), subview.getStaticSizes(), |
| subview.getStaticStrides()); |
| return memref::SubViewOp::create( |
| rewriter, subview->getLoc(), newType, val, |
| subview.getMixedOffsets(), subview.getMixedSizes(), |
| subview.getMixedStrides()); |
| }) |
| .Case([&](memref::ExpandShapeOp expand) -> Value { |
| FailureOr<MemRefType> newType = |
| memref::ExpandShapeOp::computeExpandedType( |
| srcType, expand.getResultType().getShape(), |
| expand.getReassociationIndices()); |
| if (failed(newType)) { |
| typeInferenceFailed = true; |
| return Value(); |
| } |
| return memref::ExpandShapeOp::create( |
| rewriter, expand->getLoc(), *newType, val, |
| expand.getReassociationIndices(), |
| expand.getMixedOutputShape()); |
| }) |
| .Case([&](memref::CollapseShapeOp collapse) -> Value { |
| FailureOr<MemRefType> newType = |
| memref::CollapseShapeOp::computeCollapsedType( |
| srcType, collapse.getReassociationIndices()); |
| if (failed(newType)) { |
| typeInferenceFailed = true; |
| return Value(); |
| } |
| return memref::CollapseShapeOp::create( |
| rewriter, collapse->getLoc(), *newType, val, |
| collapse.getReassociationIndices()); |
| }) |
| .Case([&](memref::CastOp cast) -> Value { |
| if (!memref::CastOp::areCastCompatible(srcType, cast.getType())) { |
| typeInferenceFailed = true; |
| return Value(); |
| } |
| return memref::CastOp::create(rewriter, cast->getLoc(), |
| cast.getType(), val); |
| }) |
| .Default([&](Operation *) -> Value { return Value(); }); |
| |
| if (typeInferenceFailed) { |
| user->emitOpError( |
| "failed to compute view-like result type after multi-buffering"); |
| return failure(); |
| } |
| |
| if (replacement) { |
| // Recursively propagate through view-like ops and mark old op for |
| // erasure. |
| if (failed(replaceUsesAndPropagateType(rewriter, user, replacement))) |
| return failure(); |
| opsToErase.push_back(user); |
| } else { |
| // Not a view-like op: just replace operand. |
| rewriter.startOpModification(user); |
| use.set(val); |
| rewriter.finalizeOpModification(user); |
| } |
| } |
| |
| for (Operation *op : opsToErase) { |
| rewriter.eraseOp(op); |
| } |
| |
| return success(); |
| } |
| |
| // Transformation to do multi-buffering/array expansion to remove dependencies |
| // on the temporary allocation between consecutive loop iterations. |
| // Returns success if the transformation happened and failure otherwise. |
| // This is not a pattern as it requires propagating the new memref type to its |
| // uses and requires updating subview ops. |
| FailureOr<memref::AllocOp> |
| mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, |
| unsigned multiBufferingFactor, |
| bool skipOverrideAnalysis) { |
| LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n"); |
| DominanceInfo dom(allocOp->getParentOp()); |
| LoopLikeOpInterface candidateLoop; |
| for (Operation *user : allocOp->getUsers()) { |
| auto parentLoop = user->getParentOfType<LoopLikeOpInterface>(); |
| if (!parentLoop) { |
| if (isa<memref::DeallocOp>(user)) { |
| // Allow dealloc outside of any loop. |
| // TODO: The whole precondition function here is very brittle and will |
| // need to rethought an isolated into a cleaner analysis. |
| continue; |
| } |
| LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n"); |
| LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n"); |
| return failure(); |
| } |
| if (!skipOverrideAnalysis) { |
| /// Make sure there is no loop-carried dependency on the allocation. |
| if (!overrideBuffer(user, allocOp.getResult())) { |
| LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n"); |
| continue; |
| } |
| // If this user doesn't dominate all the other users keep looking. |
| if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { |
| return !dom.dominates(user, otherUser); |
| })) { |
| LLVM_DEBUG( |
| DBGS() << "--Skip user: does not dominate all other users\n"); |
| continue; |
| } |
| } else { |
| if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { |
| return !isa<memref::DeallocOp>(otherUser) && |
| !parentLoop->isProperAncestor(otherUser); |
| })) { |
| LLVM_DEBUG( |
| DBGS() |
| << "--Skip user: not all other users are in the parent loop\n"); |
| continue; |
| } |
| } |
| candidateLoop = parentLoop; |
| break; |
| } |
| |
| if (!candidateLoop) { |
| LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n"); |
| return failure(); |
| } |
| |
| std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar(); |
| std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound(); |
| std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep(); |
| if (!inductionVar || !lowerBound || !singleStep || |
| !llvm::hasSingleElement(candidateLoop.getLoopRegions())) { |
| LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n"); |
| return failure(); |
| } |
| |
| if (!dom.dominates(allocOp.getOperation(), candidateLoop)) { |
| LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n"); |
| return failure(); |
| } |
| |
| LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n"); |
| |
| // 1. Construct the multi-buffered memref type. |
| ArrayRef<int64_t> originalShape = allocOp.getType().getShape(); |
| SmallVector<int64_t, 4> multiBufferedShape{multiBufferingFactor}; |
| llvm::append_range(multiBufferedShape, originalShape); |
| LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n"); |
| MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType()) |
| .setShape(multiBufferedShape) |
| .setLayout(MemRefLayoutAttrInterface()); |
| LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n"); |
| |
| // 2. Create the multi-buffered alloc. |
| Location loc = allocOp->getLoc(); |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPoint(allocOp); |
| auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType, |
| ValueRange{}, allocOp->getAttrs()); |
| LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n"); |
| |
| // 3. Within the loop, build the modular leading index (i.e. each loop |
| // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor). |
| rewriter.setInsertionPointToStart( |
| &candidateLoop.getLoopRegions().front()->front()); |
| Value ivVal = *inductionVar; |
| Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound); |
| Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep); |
| AffineExpr iv, lb, step; |
| bindDims(rewriter.getContext(), iv, lb, step); |
| Value bufferIndex = affine::makeComposedAffineApply( |
| rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor, |
| {ivVal, lbVal, stepVal}); |
| LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n"); |
| |
| // 4. Build the subview accessing the particular slice, taking modular |
| // rotation into account. |
| int64_t mbMemRefTypeRank = mbMemRefType.getRank(); |
| IntegerAttr zero = rewriter.getIndexAttr(0); |
| IntegerAttr one = rewriter.getIndexAttr(1); |
| SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero); |
| SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one); |
| SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one); |
| // Offset is [bufferIndex, 0 ... 0 ]. |
| offsets.front() = bufferIndex; |
| // Sizes is [1, original_size_0 ... original_size_n ]. |
| for (int64_t i = 0, e = originalShape.size(); i != e; ++i) |
| sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]); |
| // Strides is [1, 1 ... 1 ]. |
| MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType( |
| originalShape, mbMemRefType, offsets, sizes, strides); |
| Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc, |
| offsets, sizes, strides); |
| LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); |
| |
| // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need |
| // to handle dealloc uses separately.. |
| for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) { |
| auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner()); |
| if (!deallocOp) |
| continue; |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPoint(deallocOp); |
| auto newDeallocOp = |
| memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc); |
| (void)newDeallocOp; |
| LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n"); |
| rewriter.eraseOp(deallocOp); |
| } |
| |
| // 6. RAUW with the particular slice, taking modular rotation into account. |
| if (failed(replaceUsesAndPropagateType(rewriter, allocOp, subview))) |
| return failure(); |
| |
| // 7. Finally, erase the old allocOp. |
| rewriter.eraseOp(allocOp); |
| |
| return mbAlloc; |
| } |
| |
| FailureOr<memref::AllocOp> |
| mlir::memref::multiBuffer(memref::AllocOp allocOp, |
| unsigned multiBufferingFactor, |
| bool skipOverrideAnalysis) { |
| IRRewriter rewriter(allocOp->getContext()); |
| return multiBuffer(rewriter, allocOp, multiBufferingFactor, |
| skipOverrideAnalysis); |
| } |