[mlir][sparse] allow multiple COO segments in sparse encodings. (#91786)
**NOTE**: the implementation still has gaps when handling multiple COO
segments in an encoding, but the format itself should be considered
legal.
GitOrigin-RevId: 13af97a70e7202507dcca89d2f732e5126d2bbcd
diff --git a/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4cc6ee9..4adb1c1 100644
--- a/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -788,24 +788,29 @@
return emitError() << "unexpected position bitwidth: " << posWidth;
if (!acceptBitWidth(crdWidth))
return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
- if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
- it != std::end(lvlTypes)) {
+
+ // Verify every COO segment.
+ auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
+ while (it != lvlTypes.end()) {
if (it == lvlTypes.begin() ||
- (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1))))
+ !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
return emitError() << "expected compressed or loose_compressed level "
"before singleton level";
- if (!std::all_of(it, lvlTypes.end(),
+
+ auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
+ if (!std::all_of(it, curCOOEnd,
[](LevelType i) { return isSingletonLT(i); }))
return emitError() << "expected all singleton lvlTypes "
"following a singleton level";
// We can potentially support mixed SoA/AoS singleton levels.
- if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) {
+ if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
return it->isa<LevelPropNonDefault::SoA>() ==
i.isa<LevelPropNonDefault::SoA>();
})) {
return emitError() << "expected all singleton lvlTypes stored in the "
"same memory layout (SoA vs AoS).";
}
+ it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
}
auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);
diff --git a/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 7fb1c76..44710ca 100644
--- a/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -156,6 +156,17 @@
// -----
+#COO_DENSE = #sparse_tensor.encoding<{
+ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton, d2: dense)
+}>
+
+// CHECK-DAG: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton, d2 : dense) }>
+// CHECK-LABEL: func private @sparse_coo_trailing_dense(
+// CHECK-SAME: tensor<?x?x1xf32, #[[$COO]]>)
+func.func private @sparse_coo_trailing_dense(tensor<?x?x1xf32, #COO_DENSE>)
+
+// -----
+
#BCOO = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
}>