[mlir][SparseTensor] add `numSymbols` information to simplify affine expressions (#191649)
Previously, the `translateShape` function hard-coded the `numSymbols`
parameter to 0. This caused affine expression simplification to fail when
the sparse tensor encoding contains symbols.
This PR fixes the issue by extracting the `numSymbols` information from the
`dimToLvl` map and passing it through during translation. A regression test
has also been added to ensure this behavior remains supported.
Closes #191209
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b77a536..eab2d14 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -525,10 +525,15 @@
}
};
+ // The number of symbols is recorded in the `dimToLvl` map during parsing.
+ // Extract it here so it can be passed along when simplifying the affine
+ // expression below.
+ unsigned numSymbols = getDimToLvl().getNumSymbols();
+
for (AffineExpr exp : transMap.getResults()) {
// Do constant propagation on the affine map.
- AffineExpr evalExp =
- simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
+ AffineExpr evalExp = simplifyAffineExpr(exp.replaceDims(dimRep),
+ srcShape.size(), numSymbols);
// use llvm namespace here to avoid ambiguity
if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
ret.push_back(c.getValue() + 1);
diff --git a/mlir/test/Dialect/SparseTensor/encoding_with_symbols.mlir b/mlir/test/Dialect/SparseTensor/encoding_with_symbols.mlir
new file mode 100644
index 0000000..7cd68ee
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/encoding_with_symbols.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -sparsification-and-bufferization | FileCheck %s
+
+// Tests that mlir-opt does not crash when parsing sparse tensor encodings with symbols.
+
+// CHECK-DAG: #[[$SPARSE_0:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed) }>
+// CHECK-DAG: #[[$SPARSE_1:.*]] = #sparse_tensor.encoding<{ map = [s0](d0, d1) -> (d0 * (s0 * 3) : dense, d0 : dense, d1 : compressed) }>
+
+#Sparse = #sparse_tensor.encoding<{
+ map = [c](i, j) -> (c * 3 * i : dense, i : dense, j : compressed)
+}>
+
+// CHECK-LABEL: func.func @tensor_add(
+// CHECK-SAME: %{{.*}}: memref<?xindex>, %{{.*}}: memref<?xindex>, %{{.*}}: memref<?xf32>,
+// CHECK-SAME: %{{.*}}: !sparse_tensor.storage_specifier<#[[$SPARSE_0]]>) -> memref<8x8xf32> {
+func.func @tensor_add(%arg0: tensor<8x8xf32, #Sparse>) -> tensor<8x8xf32> {
+ %result_out = tensor.empty() : tensor<8x8xf32>
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32>
+ // CHECK: %[[RES:.*]] = linalg.add ins(%{{.*}}, %{{.*}} : tensor<8x8xf32, #[[$SPARSE_1]]>, tensor<8x8xf32, #[[$SPARSE_1]]>)
+ %result = linalg.add
+ ins(%arg0, %arg0 : tensor<8x8xf32, #Sparse>, tensor<8x8xf32, #Sparse>)
+ outs(%result_out : tensor<8x8xf32>) -> tensor<8x8xf32>
+
+ // CHECK: return %{{.*}} : memref<8x8xf32>
+ return %result : tensor<8x8xf32>
+}