| //======- BufferViewFlowAnalysis.cpp - Buffer alias analysis -*- C++ -*-======// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" |
| |
| #include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" |
| #include "mlir/Interfaces/CallInterfaces.h" |
| #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| #include "mlir/Interfaces/FunctionInterfaces.h" |
| #include "mlir/Interfaces/ViewLikeInterface.h" |
| #include "llvm/ADT/SetOperations.h" |
| #include "llvm/ADT/SetVector.h" |
| |
| using namespace mlir; |
| using namespace mlir::bufferization; |
| |
| //===----------------------------------------------------------------------===// |
| // BufferViewFlowAnalysis |
| //===----------------------------------------------------------------------===// |
| |
| /// Constructs a new alias analysis using the op provided. |
| BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); } |
| |
| static BufferViewFlowAnalysis::ValueSetT |
| resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) { |
| BufferViewFlowAnalysis::ValueSetT result; |
| SmallVector<Value, 8> queue; |
| queue.push_back(value); |
| while (!queue.empty()) { |
| Value currentValue = queue.pop_back_val(); |
| if (result.insert(currentValue).second) { |
| auto it = map.find(currentValue); |
| if (it != map.end()) { |
| for (Value aliasValue : it->second) |
| queue.push_back(aliasValue); |
| } |
| } |
| } |
| return result; |
| } |
| |
| /// Find all immediate and indirect dependent buffers this value could |
| /// potentially have. Note that the resulting set will also contain the value |
| /// provided as it is a dependent alias of itself. |
| BufferViewFlowAnalysis::ValueSetT |
| BufferViewFlowAnalysis::resolve(Value rootValue) const { |
| return resolveValues(dependencies, rootValue); |
| } |
| |
| BufferViewFlowAnalysis::ValueSetT |
| BufferViewFlowAnalysis::resolveReverse(Value rootValue) const { |
| return resolveValues(reverseDependencies, rootValue); |
| } |
| |
| /// Removes the given values from all alias sets. |
| void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) { |
| for (auto &entry : dependencies) |
| llvm::set_subtract(entry.second, aliasValues); |
| } |
| |
| void BufferViewFlowAnalysis::rename(Value from, Value to) { |
| dependencies[to] = dependencies[from]; |
| dependencies.erase(from); |
| |
| for (auto &[_, value] : dependencies) { |
| if (value.contains(from)) { |
| value.insert(to); |
| value.erase(from); |
| } |
| } |
| } |
| |
| /// This function constructs a mapping from values to its immediate |
| /// dependencies. It iterates over all blocks, gets their predecessors, |
| /// determines the values that will be passed to the corresponding block |
| /// arguments and inserts them into the underlying map. Furthermore, it wires |
| /// successor regions and branch-like return operations from nested regions. |
| void BufferViewFlowAnalysis::build(Operation *op) { |
| // Registers all dependencies of the given values. |
| auto registerDependencies = [&](ValueRange values, ValueRange dependencies) { |
| for (auto [value, dep] : llvm::zip_equal(values, dependencies)) { |
| this->dependencies[value].insert(dep); |
| this->reverseDependencies[dep].insert(value); |
| } |
| }; |
| |
| // Mark all buffer results and buffer region entry block arguments of the |
| // given op as terminals. |
| auto populateTerminalValues = [&](Operation *op) { |
| for (Value v : op->getResults()) |
| if (isa<BaseMemRefType>(v.getType())) |
| this->terminals.insert(v); |
| for (Region &r : op->getRegions()) |
| for (BlockArgument v : r.getArguments()) |
| if (isa<BaseMemRefType>(v.getType())) |
| this->terminals.insert(v); |
| }; |
| |
| op->walk([&](Operation *op) { |
| // Query BufferViewFlowOpInterface. If the op does not implement that |
| // interface, try to infer the dependencies from other interfaces that the |
| // op may implement. |
| if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) { |
| bufferViewFlowOp.populateDependencies(registerDependencies); |
| for (Value v : op->getResults()) |
| if (isa<BaseMemRefType>(v.getType()) && |
| bufferViewFlowOp.mayBeTerminalBuffer(v)) |
| this->terminals.insert(v); |
| for (Region &r : op->getRegions()) |
| for (BlockArgument v : r.getArguments()) |
| if (isa<BaseMemRefType>(v.getType()) && |
| bufferViewFlowOp.mayBeTerminalBuffer(v)) |
| this->terminals.insert(v); |
| return WalkResult::advance(); |
| } |
| |
| // Add additional dependencies created by view changes to the alias list. |
| if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) { |
| registerDependencies(viewInterface.getViewSource(), |
| viewInterface->getResult(0)); |
| return WalkResult::advance(); |
| } |
| |
| if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) { |
| // Query all branch interfaces to link block argument dependencies. |
| Block *parentBlock = branchInterface->getBlock(); |
| for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end(); |
| it != e; ++it) { |
| // Query the branch op interface to get the successor operands. |
| auto successorOperands = |
| branchInterface.getSuccessorOperands(it.getIndex()); |
| // Build the actual mapping of values to their immediate dependencies. |
| registerDependencies(successorOperands.getForwardedOperands(), |
| (*it)->getArguments().drop_front( |
| successorOperands.getProducedOperandCount())); |
| } |
| return WalkResult::advance(); |
| } |
| |
| if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(op)) { |
| // Query the RegionBranchOpInterface to find potential successor regions. |
| // Extract all entry regions and wire all initial entry successor inputs. |
| SmallVector<RegionSuccessor, 2> entrySuccessors; |
| regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(), |
| entrySuccessors); |
| for (RegionSuccessor &entrySuccessor : entrySuccessors) { |
| // Wire the entry region's successor arguments with the initial |
| // successor inputs. |
| registerDependencies( |
| regionInterface.getEntrySuccessorOperands(entrySuccessor), |
| entrySuccessor.getSuccessorInputs()); |
| } |
| |
| // Wire flow between regions and from region exits. |
| for (Region ®ion : regionInterface->getRegions()) { |
| // Iterate over all successor region entries that are reachable from the |
| // current region. |
| SmallVector<RegionSuccessor, 2> successorRegions; |
| regionInterface.getSuccessorRegions(region, successorRegions); |
| for (RegionSuccessor &successorRegion : successorRegions) { |
| // Iterate over all immediate terminator operations and wire the |
| // successor inputs with the successor operands of each terminator. |
| for (Block &block : region) |
| if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>( |
| block.getTerminator())) |
| registerDependencies( |
| terminator.getSuccessorOperands(successorRegion), |
| successorRegion.getSuccessorInputs()); |
| } |
| } |
| |
| return WalkResult::advance(); |
| } |
| |
| // Region terminators are handled together with RegionBranchOpInterface. |
| if (isa<RegionBranchTerminatorOpInterface>(op)) |
| return WalkResult::advance(); |
| |
| if (isa<CallOpInterface>(op)) { |
| // This is an intra-function analysis. We have no information about other |
| // functions. Conservatively assume that each operand may alias with each |
| // result. Also mark the results are terminals because the function could |
| // return newly allocated buffers. |
| populateTerminalValues(op); |
| for (Value operand : op->getOperands()) |
| for (Value result : op->getResults()) |
| registerDependencies({operand}, {result}); |
| return WalkResult::advance(); |
| } |
| |
| // We have no information about unknown ops. |
| populateTerminalValues(op); |
| |
| return WalkResult::advance(); |
| }); |
| } |
| |
| bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const { |
| assert(isa<BaseMemRefType>(value.getType()) && "expected memref"); |
| return terminals.contains(value); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BufferOriginAnalysis |
| //===----------------------------------------------------------------------===// |
| |
| /// Return "true" if the given value is the result of a memory allocation. |
| static bool hasAllocateSideEffect(Value v) { |
| Operation *op = v.getDefiningOp(); |
| if (!op) |
| return false; |
| return hasEffect<MemoryEffects::Allocate>(op, v); |
| } |
| |
| /// Return "true" if the given value is a function block argument. |
| static bool isFunctionArgument(Value v) { |
| auto bbArg = dyn_cast<BlockArgument>(v); |
| if (!bbArg) |
| return false; |
| Block *b = bbArg.getOwner(); |
| auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp()); |
| if (!funcOp) |
| return false; |
| return bbArg.getOwner() == &funcOp.getFunctionBody().front(); |
| } |
| |
| /// Given a memref value, return the "base" value by skipping over all |
| /// ViewLikeOpInterface ops (if any) in the reverse use-def chain. |
| static Value getViewBase(Value value) { |
| while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) |
| value = viewLikeOp.getViewSource(); |
| return value; |
| } |
| |
| BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {} |
| |
| std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) { |
| assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer"); |
| assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer"); |
| |
| // Skip over all view-like ops. |
| v1 = getViewBase(v1); |
| v2 = getViewBase(v2); |
| |
| // Fast path: If both buffers are the same SSA value, we can be sure that |
| // they originate from the same allocation. |
| if (v1 == v2) |
| return true; |
| |
| // Compute the SSA values from which the buffers `v1` and `v2` originate. |
| SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1); |
| SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2); |
| |
| // Originating buffers are "terminal" if they could not be traced back any |
| // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers: |
| // - function block arguments |
| // - values defined by allocation ops such as "memref.alloc" |
| // - values defined by ops that are unknown to the buffer view flow analysis |
| // - values that are marked as "terminal" in the `BufferViewFlowOpInterface` |
| SmallPtrSet<Value, 16> terminal1, terminal2; |
| |
| // While gathering terminal buffers, keep track of whether all terminal |
| // buffers are newly allocated buffer or function entry arguments. |
| bool allAllocs1 = true, allAllocs2 = true; |
| bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true; |
| |
| // Helper function that gathers terminal buffers among `origin`. |
| auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin, |
| SmallPtrSet<Value, 16> &terminal, |
| bool &allAllocs, |
| bool &allAllocsOrFuncEntryArgs) { |
| for (Value v : origin) { |
| if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) { |
| terminal.insert(v); |
| allAllocs &= hasAllocateSideEffect(v); |
| allAllocsOrFuncEntryArgs &= |
| isFunctionArgument(v) || hasAllocateSideEffect(v); |
| } |
| } |
| assert(!terminal.empty() && "expected non-empty terminal set"); |
| }; |
| |
| // Gather terminal buffers for `v1` and `v2`. |
| gatherTerminalBuffers(origin1, terminal1, allAllocs1, |
| allAllocsOrFuncEntryArgs1); |
| gatherTerminalBuffers(origin2, terminal2, allAllocs2, |
| allAllocsOrFuncEntryArgs2); |
| |
| // If both `v1` and `v2` have a single matching terminal buffer, they are |
| // guaranteed to originate from the same buffer allocation. |
| if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) && |
| *terminal1.begin() == *terminal2.begin()) |
| return true; |
| |
| // At least one of the two values has multiple terminals. |
| |
| // Check if there is overlap between the terminal buffers of `v1` and `v2`. |
| bool distinctTerminalSets = true; |
| for (Value v : terminal1) |
| distinctTerminalSets &= !terminal2.contains(v); |
| // If there is overlap between the terminal buffers of `v1` and `v2`, we |
| // cannot make an accurate decision without further analysis. |
| if (!distinctTerminalSets) |
| return std::nullopt; |
| |
| // If `v1` originates from only allocs, and `v2` is guaranteed to originate |
| // from different allocations (that is guaranteed if `v2` originates from |
| // only distinct allocs or function entry arguments), we can be sure that |
| // `v1` and `v2` originate from different allocations. The same argument can |
| // be made when swapping `v1` and `v2`. |
| bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2); |
| bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2; |
| if (isolatedAlloc1 || isolatedAlloc2) |
| return false; |
| |
| // Otherwise: We do not know whether `v1` and `v2` originate from the same |
| // allocation or not. |
| // TODO: Function arguments are currently handled conservatively. We assume |
| // that they could be the same allocation. |
| // TODO: Terminals other than allocations and function arguments are |
| // currently handled conservatively. We assume that they could be the same |
| // allocation. E.g., we currently return "nullopt" for values that originate |
| // from different "memref.get_global" ops (with different symbols). |
| return std::nullopt; |
| } |