| //===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/IR/IRMapping.h" |
| #include "mlir/Transforms/Passes.h" |
| |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| |
| namespace mlir { |
| #define GEN_PASS_DEF_SPARSESPACECOLLAPSE |
| #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" |
| } // namespace mlir |
| |
| #define DEBUG_TYPE "sparse-space-collapse" |
| |
| using namespace mlir; |
| using namespace sparse_tensor; |
| |
| namespace { |
| |
| struct CollapseSpaceInfo { |
| ExtractIterSpaceOp space; |
| IterateOp loop; |
| }; |
| |
| bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) { |
| auto pIterArgs = parent.getRegionIterArgs(); |
| auto nInitArgs = node.getInits(); |
| if (pIterArgs.size() != nInitArgs.size()) |
| return false; |
| |
| // Two loops are collapsable if they are perfectly nested. |
| auto pYields = parent.getYieldedValues(); |
| auto nResult = node.getLoopResults().value(); |
| |
| bool yieldEq = |
| llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) { |
| return std::get<0>(zipped) == std::get<1>(zipped); |
| }); |
| |
| // Parent iter_args should be passed directly to the node's init_args. |
| bool iterArgEq = |
| llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) { |
| return std::get<0>(zipped) == std::get<1>(zipped); |
| }); |
| |
| return yieldEq && iterArgEq; |
| } |
| |
| bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse, |
| ExtractIterSpaceOp curSpace) { |
| |
| auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp { |
| Value spaceVal = space.getExtractedSpace(); |
| if (spaceVal.hasOneUse()) |
| return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin()); |
| return nullptr; |
| }; |
| |
| if (toCollapse.empty()) { |
| // Collapse root. |
| if (auto itOp = getIterateOpOverSpace(curSpace)) { |
| CollapseSpaceInfo &info = toCollapse.emplace_back(); |
| info.space = curSpace; |
| info.loop = itOp; |
| return true; |
| } |
| return false; |
| } |
| |
| auto parent = toCollapse.back().space; |
| auto pItOp = toCollapse.back().loop; |
| auto nItOp = getIterateOpOverSpace(curSpace); |
| |
| // Can only collapse spaces extracted from the same tensor. |
| if (parent.getTensor() != curSpace.getTensor()) { |
| LLVM_DEBUG({ |
| llvm::dbgs() |
| << "failed to collpase spaces extracted from different tensors."; |
| }); |
| return false; |
| } |
| |
| // Can only collapse consecutive simple iteration on one tensor (i.e., no |
| // coiteration). |
| if (!nItOp || nItOp->getBlock() != curSpace->getBlock() || |
| pItOp.getIterator() != curSpace.getParentIter() || |
| curSpace->getParentOp() != pItOp.getOperation()) { |
| LLVM_DEBUG( |
| { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; }); |
| return false; |
| } |
| |
| if (pItOp && !isCollapsableLoops(pItOp, nItOp)) { |
| LLVM_DEBUG({ |
| llvm::dbgs() |
| << "failed to collapse IterateOps that are not perfectly nested."; |
| }); |
| return false; |
| } |
| |
| CollapseSpaceInfo &info = toCollapse.emplace_back(); |
| info.space = curSpace; |
| info.loop = nItOp; |
| return true; |
| } |
| |
| void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) { |
| if (toCollapse.size() < 2) |
| return; |
| |
| ExtractIterSpaceOp root = toCollapse.front().space; |
| ExtractIterSpaceOp leaf = toCollapse.back().space; |
| Location loc = root.getLoc(); |
| |
| assert(root->hasOneUse() && leaf->hasOneUse()); |
| |
| // Insert collapsed operation at the same scope as root operation. |
| OpBuilder builder(root); |
| |
| // Construct the collapsed iteration space. |
| auto collapsedSpace = ExtractIterSpaceOp::create( |
| builder, loc, root.getTensor(), root.getParentIter(), root.getLoLvl(), |
| leaf.getHiLvl()); |
| |
| auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin()); |
| auto innermost = toCollapse.back().loop; |
| |
| IRMapping mapper; |
| mapper.map(leaf, collapsedSpace.getExtractedSpace()); |
| for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs())) |
| mapper.map(std::get<0>(z), std::get<1>(z)); |
| |
| auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper)); |
| builder.setInsertionPointToStart(cloned.getBody()); |
| |
| I64BitSet crdUsedLvls; |
| unsigned shift = 0, argIdx = 1; |
| for (auto info : toCollapse.drop_back()) { |
| I64BitSet set = info.loop.getCrdUsedLvls(); |
| crdUsedLvls |= set.lshift(shift); |
| shift += info.loop.getSpaceDim(); |
| for (BlockArgument crd : info.loop.getCrds()) { |
| BlockArgument collapsedCrd = cloned.getBody()->insertArgument( |
| argIdx++, builder.getIndexType(), crd.getLoc()); |
| crd.replaceAllUsesWith(collapsedCrd); |
| } |
| } |
| crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift); |
| cloned.getIterator().setType(collapsedSpace.getType().getIteratorType()); |
| cloned.setCrdUsedLvls(crdUsedLvls); |
| |
| rItOp.replaceAllUsesWith(cloned.getResults()); |
| // Erase collapsed loops. |
| rItOp.erase(); |
| root.erase(); |
| } |
| |
| struct SparseSpaceCollapsePass |
| : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> { |
| SparseSpaceCollapsePass() = default; |
| |
| void runOnOperation() override { |
| func::FuncOp func = getOperation(); |
| |
| // A naive (experimental) implementation to collapse consecutive sparse |
| // spaces. It does NOT handle complex cases where multiple spaces are |
| // extracted in the same basic block. E.g., |
| // |
| // %space1 = extract_space %t1 ... |
| // %space2 = extract_space %t2 ... |
| // sparse_tensor.iterate(%sp1) ... |
| // |
| SmallVector<CollapseSpaceInfo> toCollapse; |
| func->walk([&](ExtractIterSpaceOp op) { |
| if (!legalToCollapse(toCollapse, op)) { |
| // if not legal to collapse one more space, collapse the existing ones |
| // and clear. |
| collapseSparseSpace(toCollapse); |
| toCollapse.clear(); |
| } |
| }); |
| |
| collapseSparseSpace(toCollapse); |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() { |
| return std::make_unique<SparseSpaceCollapsePass>(); |
| } |