blob: 81cd3296de294602cc8a649985f2f74d65e6d294 [file] [log] [blame]
//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
namespace mlir {
#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
#define DEBUG_TYPE "sparse-space-collapse"
using namespace mlir;
using namespace sparse_tensor;
namespace {
/// One link of a collapse chain: a sparse iteration space together with the
/// `sparse_tensor.iterate` loop that iterates over it.
struct CollapseSpaceInfo {
  /// The ExtractIterSpaceOp that produces the iteration space.
  ExtractIterSpaceOp space{};
  /// The IterateOp that consumes `space` (its only user when built by
  /// legalToCollapse).
  IterateOp loop{};
};
/// Returns true if `node`, a loop nested directly inside `parent`, threads all
/// of `parent`'s loop-carried values through unchanged, so the two loops can be
/// merged into a single loop carrying the same values. Concretely:
///   - `parent`'s region iter_args are passed, in order, as `node`'s inits,
///   - each of those iter_args has no use other than that init operand, and
///   - `parent` yields exactly `node`'s results, in order.
/// The single-use requirement matters because collapsing keeps only the
/// innermost loop body and erases the outer loops: any other use of an outer
/// iter_arg would be rebound to the wrong value or left dangling.
bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
  auto pIterArgs = parent.getRegionIterArgs();
  auto nInitArgs = node.getInits();
  if (pIterArgs.size() != nInitArgs.size())
    return false;

  // Parent iter_args should be passed directly to the node's init_args...
  if (!llvm::equal(pIterArgs, nInitArgs))
    return false;

  // ...and be used nowhere else (in particular not inside the node's body).
  if (!llvm::all_of(pIterArgs,
                    [](BlockArgument arg) { return arg.hasOneUse(); }))
    return false;

  // Two loops are collapsable if they are perfectly nested: the parent simply
  // yields whatever the nested loop produces.
  auto pYields = parent.getYieldedValues();
  auto nResult = node.getLoopResults().value();
  return llvm::equal(pYields, nResult);
}
/// Tries to append `curSpace`, together with the single
/// `sparse_tensor.iterate` op that consumes it, to the chain in `toCollapse`.
///
/// An empty chain accepts `curSpace` as its root when the extracted space has
/// exactly one use and that use is an IterateOp. A non-empty chain is only
/// extended when:
///   - `curSpace` extracts from the same tensor as the chain's last space,
///   - `curSpace` sits directly in the body of the chain's last loop, is
///     extracted at that loop's iterator, and is consumed by a single
///     IterateOp in the same block,
///   - the chain's last loop wraps nothing else: its body holds only
///     `curSpace`, the nested loop and the terminator, and its iterator has no
///     other use, and
///   - the loop-carried values are threaded through (isCollapsableLoops).
///
/// Returns true iff `curSpace` was appended; otherwise `toCollapse` is left
/// unchanged.
bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
                     ExtractIterSpaceOp curSpace) {
  // Returns the IterateOp over `space` if it is the space's only user.
  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
    Value spaceVal = space.getExtractedSpace();
    if (spaceVal.hasOneUse())
      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
    return nullptr;
  };

  if (toCollapse.empty()) {
    // Collapse root.
    if (auto itOp = getIterateOpOverSpace(curSpace)) {
      CollapseSpaceInfo &info = toCollapse.emplace_back();
      info.space = curSpace;
      info.loop = itOp;
      return true;
    }
    return false;
  }

  auto parent = toCollapse.back().space;
  auto pItOp = toCollapse.back().loop;
  auto nItOp = getIterateOpOverSpace(curSpace);

  // Can only collapse spaces extracted from the same tensor.
  if (parent.getTensor() != curSpace.getTensor()) {
    LLVM_DEBUG({
      llvm::dbgs()
          << "failed to collpase spaces extracted from different tensors.";
    });
    return false;
  }

  // Can only collapse consecutive simple iteration on one tensor (i.e., no
  // coiteration).
  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
      pItOp.getIterator() != curSpace.getParentIter() ||
      curSpace->getParentOp() != pItOp.getOperation()) {
    LLVM_DEBUG(
        { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
    return false;
  }

  // Collapsing keeps only the innermost loop body and erases every enclosing
  // loop of the chain. The parent loop must therefore be a pure wrapper around
  // `curSpace` and its loop:
  //   - any other op in its body would be silently dropped, and
  //   - any other use of its iterator (e.g. from inside the nested loop) would
  //     be left dangling.
  Block *pBody = pItOp.getBody();
  Operation *pTerminator = pBody->getTerminator();
  bool parentOnlyWrapsNode = llvm::all_of(*pBody, [&](Operation &op) {
    return &op == curSpace.getOperation() || &op == nItOp.getOperation() ||
           &op == pTerminator;
  });
  if (!parentOnlyWrapsNode || !pItOp.getIterator().hasOneUse() ||
      !isCollapsableLoops(pItOp, nItOp)) {
    LLVM_DEBUG({
      llvm::dbgs()
          << "failed to collapse IterateOps that are not perfectly nested.";
    });
    return false;
  }

  CollapseSpaceInfo &info = toCollapse.emplace_back();
  info.space = curSpace;
  info.loop = nItOp;
  return true;
}
/// Rewrites a chain of nested space/loop pairs (as built by legalToCollapse)
/// into a single iteration space covering all of their levels, plus one loop
/// over it.
///
/// The new loop is a clone of the innermost loop of the chain. The coordinates
/// of every enclosing loop become extra index-typed block arguments of the
/// clone. Its coordinate-usage bitset is the concatenation of the per-loop
/// bitsets, each shifted by the number of levels of the loops before it.
/// Finally the outermost loop (which contains the rest of the chain) and the
/// root space are erased. Chains with fewer than two links are left untouched.
void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
  // A lone space has nothing to be merged with.
  if (toCollapse.size() < 2)
    return;

  ExtractIterSpaceOp outerSpace = toCollapse.front().space;
  ExtractIterSpaceOp innerSpace = toCollapse.back().space;
  Location loc = outerSpace.getLoc();

  assert(outerSpace->hasOneUse() && innerSpace->hasOneUse());

  // New ops go right before the outermost space, i.e. into the scope that
  // encloses the whole nest.
  OpBuilder builder(outerSpace);

  // One space from the outermost space's low level up to the innermost
  // space's high level.
  auto mergedSpace = ExtractIterSpaceOp::create(
      builder, loc, outerSpace.getTensor(), outerSpace.getParentIter(),
      outerSpace.getLoLvl(), innerSpace.getHiLvl());

  auto outerLoop = llvm::cast<IterateOp>(*outerSpace->getUsers().begin());
  IterateOp innerLoop = toCollapse.back().loop;

  // The clone iterates the merged space and starts from the outermost loop's
  // initial loop-carried values.
  IRMapping mapping;
  mapping.map(innerSpace.getExtractedSpace(), mergedSpace.getExtractedSpace());
  for (auto [innerInit, outerInit] :
       llvm::zip_equal(innerLoop.getInitArgs(), outerLoop.getInitArgs()))
    mapping.map(innerInit, outerInit);

  auto merged = llvm::cast<IterateOp>(builder.clone(*innerLoop, mapping));
  builder.setInsertionPointToStart(merged.getBody());

  // Hoist the coordinates of every enclosing loop into the clone. New block
  // arguments are inserted from position 1 onwards (position 0 presumably
  // holds the iterator -- NOTE(review): confirm against IterateOp's layout).
  I64BitSet usedLvls;
  unsigned lvlOffset = 0;
  unsigned nextArgPos = 1;
  for (CollapseSpaceInfo outer : toCollapse.drop_back()) {
    I64BitSet lvls = outer.loop.getCrdUsedLvls();
    usedLvls |= lvls.lshift(lvlOffset);
    lvlOffset += outer.loop.getSpaceDim();
    for (BlockArgument oldCrd : outer.loop.getCrds()) {
      BlockArgument newCrd = merged.getBody()->insertArgument(
          nextArgPos++, builder.getIndexType(), oldCrd.getLoc());
      oldCrd.replaceAllUsesWith(newCrd);
    }
  }
  usedLvls |= innerLoop.getCrdUsedLvls().lshift(lvlOffset);

  merged.getIterator().setType(mergedSpace.getType().getIteratorType());
  merged.setCrdUsedLvls(usedLvls);

  outerLoop.replaceAllUsesWith(merged.getResults());
  // Erase collapsed loops: the outermost loop contains every other space and
  // loop of the chain.
  outerLoop.erase();
  outerSpace.erase();
}
/// Function pass that merges chains of perfectly nested sparse iteration
/// spaces over the same tensor into a single multi-level space and loop.
struct SparseSpaceCollapsePass
    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
  SparseSpaceCollapsePass() = default;

  void runOnOperation() override {
    func::FuncOp func = getOperation();

    // A naive (experimental) implementation to collapse consecutive sparse
    // spaces. It does NOT handle complex cases where multiple spaces are
    // extracted in the same basic block. E.g.,
    //
    // %space1 = extract_space %t1 ...
    // %space2 = extract_space %t2 ...
    // sparse_tensor.iterate(%space1) ...
    //
    // Chains are only *collected* during the walk. Collapsing a chain erases
    // its whole root loop nest, which may contain the op currently being
    // visited (or one of its ancestors). The walker does not support erasing
    // such ops from a callback, so all rewrites are deferred until the walk is
    // done.
    SmallVector<SmallVector<CollapseSpaceInfo>> chains;
    SmallVector<CollapseSpaceInfo> toCollapse;

    // Records the current chain if it has anything to collapse, then resets.
    auto finishChain = [&]() {
      if (toCollapse.size() >= 2)
        chains.push_back(toCollapse);
      toCollapse.clear();
    };

    func->walk([&](ExtractIterSpaceOp op) {
      if (legalToCollapse(toCollapse, op))
        return;
      // `op` cannot extend the current chain: close that chain and let `op`
      // start a new one rather than silently skipping it.
      finishChain();
      (void)legalToCollapse(toCollapse, op);
    });
    finishChain();

    // An earlier chain never lies inside the loop nest of a later one: all of
    // its ops precede the later chain's root in program order. Collapsing in
    // reverse discovery order therefore never erases ops recorded for a chain
    // that is still pending. Inner nests are rewritten before an enclosing
    // chain clones its innermost loop.
    for (SmallVector<CollapseSpaceInfo> &chain : llvm::reverse(chains))
      collapseSparseSpace(chain);
  }
};
} // namespace
/// Creates a new instance of the sparse-space-collapse pass.
std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
  std::unique_ptr<Pass> pass = std::make_unique<SparseSpaceCollapsePass>();
  return pass;
}