[mlir][sparse] generalize sparse tensor output implementation

Moves sparse tensor output support forward by generalizing the implementation
from injective insertions only to insertions under reductions as well. This
revision accepts the case with all parallel outer loops and all reduction inner
loops, since that case can still be handled with an injective insertion. The
next revision will allow a parallel loop to move inward of the reduction loops
(but that will require "access pattern expansion", aka a "workspace").
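
For illustration, the newly accepted shape corresponds to the sumred kernel
added to the tests below (all parallel loops outermost, the reduction loop
innermost, with a sparse output tensor that materializes into the computation);
a condensed version of that test:

  #SparseTensor = #sparse_tensor.encoding<{
    dimLevelType = [ "compressed", "compressed", "compressed" ]
  }>

  #DCSR = #sparse_tensor.encoding<{
    dimLevelType = [ "compressed", "compressed" ]
  }>

  #trait_sumred = {
    indexing_maps = [
      affine_map<(i,j,k) -> (i,j,k)>, // A
      affine_map<(i,j,k) -> (i,j,k)>, // B
      affine_map<(i,j,k) -> (i,j)>    // X (out)
    ],
    iterator_types = ["parallel", "parallel", "reduction"],
    doc = "X(i,j) = SUM_k A(i,j,k) * B(i,j,k)"
  }

  // Loops i and j are parallel and outermost, the reduction k is innermost,
  // so X(i,j) can still be filled with injective insertions in lexicographic
  // index order, once per (i,j), after the k reduction folds into a scalar.
  func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
               %argb: tensor<?x?x?xi32, #SparseTensor>) -> tensor<?x?xi32, #DCSR> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %d0 = tensor.dim %arga, %c0 : tensor<?x?x?xi32, #SparseTensor>
    %d1 = tensor.dim %arga, %c1 : tensor<?x?x?xi32, #SparseTensor>
    %xinit = sparse_tensor.init [%d0, %d1] : tensor<?x?xi32, #DCSR>
    %0 = linalg.generic #trait_sumred
      ins(%arga, %argb: tensor<?x?x?xi32, #SparseTensor>,
                        tensor<?x?x?xi32, #SparseTensor>)
      outs(%xinit: tensor<?x?xi32, #DCSR>) {
        ^bb(%a: i32, %b: i32, %x: i32):
          %p = arith.muli %a, %b : i32
          %s = arith.addi %x, %p : i32
          linalg.yield %s : i32
    } -> tensor<?x?xi32, #DCSR>
    return %0 : tensor<?x?xi32, #DCSR>
  }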

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D114399
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index d396f7a..8724ff3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -122,7 +122,7 @@
   /// invariant expressions in the kernel.
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
-        dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
+        hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
@@ -200,6 +200,9 @@
   /// Dimension setter.
   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
 
+  /// Setter for the "has sparse output tensor" flag.
+  void setHasSparseOut(bool s) { hasSparseOut = s; }
+
   /// Convenience getters to immediately access the stored nodes.
   /// Typically it is inadvisable to keep the reference around, as in
   /// "TensorExpr &te = merger.exp(e))", since insertions into the merger
@@ -230,6 +233,7 @@
                  Value v1);
 
 private:
+  /// Private helpers.
   bool maybeZero(unsigned e) const;
   bool isInvariant(unsigned e) const;
   Type inferType(unsigned e, Value src);
@@ -237,11 +241,12 @@
   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
   Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);
 
+  /// Merger data structures.
   const unsigned outTensor;
   const unsigned syntheticTensor;
   const unsigned numTensors;
   const unsigned numLoops;
-
+  bool hasSparseOut;
   std::vector<std::vector<Dim>> dims;
   llvm::SmallVector<TensorExp, 32> tensorExps;
   llvm::SmallVector<LatPoint, 16> latPoints;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 3b7a0cf..d640af0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -46,15 +46,15 @@
 // Code generation.
 struct CodeGen {
   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
-          OpOperand *op)
+          OpOperand *op, unsigned nest)
       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
         pointers(numTensors, std::vector<Value>(numLoops)),
         indices(numTensors, std::vector<Value>(numLoops)),
         highs(numTensors, std::vector<Value>(numLoops)),
         pidxs(numTensors, std::vector<Value>(numLoops)),
         idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
-        redKind(kNoReduc), sparseOut(op), lexIdx(), curVecLength(1),
-        curVecMask() {}
+        redKind(kNoReduc), sparseOut(op), outerParNest(nest), lexIdx(),
+        curVecLength(1), curVecMask() {}
   /// Sparsification options.
   SparsificationOptions options;
   /// Universal dense indices and upper bounds (by index). The loops array
@@ -79,8 +79,11 @@
   unsigned redExp;
   Value redVal;
   Reduction redKind;
-  // Sparse tensor as output.
+  // Sparse tensor as output. Implemented either through direct injective
+  // insertion in lexicographic index order (where indices are updated in the
+  // temporary array `lexIdx`), or (TODO) through access pattern expansion.
   OpOperand *sparseOut;
+  unsigned outerParNest;
   Value lexIdx;
   // Current vector length and mask.
   unsigned curVecLength;
@@ -288,10 +291,13 @@
 
 /// Returns true when the tensor expression is admissable for codegen.
 /// Since all sparse input tensors are admissable, we just need to check
-/// whether the output tensor in the tensor expression codegen is admissable.
-/// Sets `sparseOut` when a "truly dynamic" sparse tensor output occurs.
+/// whether the output tensor in the tensor expression codegen is admissable.
+/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
+/// nesting depth when a "truly dynamic" sparse tensor output occurs.
 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
-                                  unsigned exp, OpOperand **sparseOut) {
+                                  std::vector<unsigned> &topSort, unsigned exp,
+                                  OpOperand **sparseOut,
+                                  unsigned &outerParNest) {
   OpOperand *lhs = op.getOutputOperand(0);
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
@@ -302,7 +308,8 @@
   // An all-dense annotated "sparse" output tensor becomes a linearized random
   // access 1-dim memref. Also admissable since insertions cannot occur.
   bool allDense = true;
-  unsigned numLoops = op.iterator_types().getValue().size();
+  auto iteratorTypes = op.iterator_types().getValue();
+  unsigned numLoops = iteratorTypes.size();
   for (unsigned i = 0; i < numLoops; i++)
     if (merger.isDim(tensor, i, Dim::kSparse)) {
       allDense = false;
@@ -319,15 +326,20 @@
   // Accept "truly dynamic" if the output tensor materializes uninitialized
   // into the computation and insertions occur in lexicographic index order.
   if (isMaterializing(lhs->get())) {
-    // In this first sparse tensor output implementation, this is enforced by
-    // rejecting any reduction loops (since the sparse parallel loops give a
-    // lexicographically sorted and injective view into that tensor).
-    // TODO: generalize to include reductions
-    for (auto attr : op.iterator_types())
-      if (isReductionIterator(attr))
-        return false;
-    *sparseOut = lhs;
-    return true;
+    unsigned nest = 0;
+    for (unsigned i = 0; i < numLoops; i++) {
+      if (isReductionIterator(iteratorTypes[topSort[i]]))
+        break; // terminate at first reduction
+      nest++;
+    }
+    // Determine admissable dynamic insertion situations:
+    // (1) fully injective, since there are no reductions,
+    // (2) admissable 1-d expansion in innermost dimension (TODO: not yet accepted).
+    if (nest == op.getRank(lhs)) {
+      *sparseOut = lhs;
+      outerParNest = nest;
+      return true;
+    }
   }
   return false;
 }
@@ -704,9 +716,15 @@
       return genVectorInvariantValue(codegen, rewriter, val);
     return val;
   }
+  // Insertion (a sparse tensor output "loads" as zero).
+  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
+  if (t == codegen.sparseOut) {
+    Type tp = getElementTypeOrSelf(t->get().getType());
+    return rewriter.create<arith::ConstantOp>(op.getLoc(), tp,
+                                              rewriter.getZeroAttr(tp));
+  }
   // Actual load.
   SmallVector<Value, 4> args;
-  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
   Value ptr = genSubscript(codegen, rewriter, op, t, args);
   if (codegen.curVecLength > 1)
     return genVectorLoad(codegen, rewriter, ptr, args);
@@ -1515,11 +1533,14 @@
 
     // Rejects an inadmissable tensor expression.
     OpOperand *sparseOut = nullptr;
-    if (!isAdmissableTensorExp(merger, op, exp, &sparseOut))
+    unsigned outerParNest = 0;
+    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
+                               outerParNest))
       return failure();
 
     // Recursively generates code.
-    CodeGen codegen(options, numTensors, numLoops, sparseOut);
+    merger.setHasSparseOut(sparseOut != nullptr);
+    CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
     genBuffers(merger, codegen, rewriter, op);
     genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
     genResult(merger, codegen, rewriter, op);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index cd75911..466191d 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -415,9 +415,13 @@
   case kInvariant: {
     // Either the index is really used in the tensor expression, or it is
     // set to the undefined index in that dimension. An invariant expression
-    // is set to a synthetic tensor with undefined indices only.
+    // and a truly dynamic sparse output tensor are set to a synthetic tensor
+    // with undefined indices only to ensure the iteration space is not
+    // skipped as a result of their contents.
     unsigned s = addSet();
     unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
+    if (hasSparseOut && t == outTensor)
+      t = syntheticTensor;
     latSets[s].push_back(addLat(t, i, e));
     return s;
   }
@@ -593,8 +597,8 @@
     }
   }
   // Construct binary operations if subexpressions can be built.
-  // TODO: see buildLattices() for an explanation of rejecting
-  //       certain division and shift operations
+  // See buildLattices() for an explanation of rejecting certain
+  // division and shift operations.
   if (def->getNumOperands() == 2) {
     auto x = buildTensorExp(op, def->getOperand(0));
     auto y = buildTensorExp(op, def->getOperand(1));
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 90ba2ff..5481c51 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -11,6 +11,10 @@
   dimOrdering = affine_map<(i,j) -> (i,j)>
 }>
 
+#SparseTensor = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed", "compressed" ]
+}>
+
 #trait_scale_inpl = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)>   // X (out)
@@ -182,3 +186,161 @@
   } -> tensor<10x20xf32, #DCSR>
   return %0 : tensor<10x20xf32, #DCSR>
 }
+
+#trait_sumred = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j,k)>, // A
+    affine_map<(i,j,k) -> (i,j,k)>, // B
+    affine_map<(i,j,k) -> (i,j)>    // X (out)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "X(i,j) = SUM_k A(i,j,k) * B(i,j,k)"
+}
+
+// CHECK-LABEL:   func @sumred(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>)
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>>
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>>
+// CHECK:           %[[VAL_8:.*]] = sparse_tensor.init{{\[}}%[[VAL_6]], %[[VAL_7]]] : tensor<?x?xi32, #{{.*}}>>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xi32>
+// CHECK:           %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_19:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_20:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_21:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_22:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xi32, #{{.*}}>> to memref<?xi32>
+// CHECK:           %[[VAL_23:.*]] = memref.alloca(%[[VAL_4]]) : memref<?xindex>
+// CHECK:           %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_26:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:           %[[VAL_27:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:           %[[VAL_28:.*]]:2 = scf.while (%[[VAL_29:.*]] = %[[VAL_24]], %[[VAL_30:.*]] = %[[VAL_26]]) : (index, index) -> (index, index) {
+// CHECK:             %[[VAL_31:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_25]] : index
+// CHECK:             %[[VAL_32:.*]] = arith.cmpi ult, %[[VAL_30]], %[[VAL_27]] : index
+// CHECK:             %[[VAL_33:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1
+// CHECK:             scf.condition(%[[VAL_33]]) %[[VAL_29]], %[[VAL_30]] : index, index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index):
+// CHECK:             %[[VAL_36:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK:             %[[VAL_37:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK:             %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_37]], %[[VAL_36]] : index
+// CHECK:             %[[VAL_39:.*]] = select %[[VAL_38]], %[[VAL_37]], %[[VAL_36]] : index
+// CHECK:             memref.store %[[VAL_39]], %[[VAL_23]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK:             %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
+// CHECK:             %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
+// CHECK:             %[[VAL_42:.*]] = arith.andi %[[VAL_40]], %[[VAL_41]] : i1
+// CHECK:             scf.if %[[VAL_42]] {
+// CHECK:               %[[VAL_43:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_34]]] : memref<?xindex>
+// CHECK:               %[[VAL_44:.*]] = arith.addi %[[VAL_34]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_45:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_44]]] : memref<?xindex>
+// CHECK:               %[[VAL_46:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK:               %[[VAL_47:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
+// CHECK:               %[[VAL_48:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_47]]] : memref<?xindex>
+// CHECK:               %[[VAL_49:.*]]:2 = scf.while (%[[VAL_50:.*]] = %[[VAL_43]], %[[VAL_51:.*]] = %[[VAL_46]]) : (index, index) -> (index, index) {
+// CHECK:                 %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_50]], %[[VAL_45]] : index
+// CHECK:                 %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_51]], %[[VAL_48]] : index
+// CHECK:                 %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1
+// CHECK:                 scf.condition(%[[VAL_54]]) %[[VAL_50]], %[[VAL_51]] : index, index
+// CHECK:               } do {
+// CHECK:               ^bb0(%[[VAL_55:.*]]: index, %[[VAL_56:.*]]: index):
+// CHECK:                 %[[VAL_57:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK:                 %[[VAL_58:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_56]]] : memref<?xindex>
+// CHECK:                 %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_57]] : index
+// CHECK:                 %[[VAL_60:.*]] = select %[[VAL_59]], %[[VAL_58]], %[[VAL_57]] : index
+// CHECK:                 memref.store %[[VAL_60]], %[[VAL_23]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:                 %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
+// CHECK:                 %[[VAL_62:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
+// CHECK:                 %[[VAL_63:.*]] = arith.andi %[[VAL_61]], %[[VAL_62]] : i1
+// CHECK:                 scf.if %[[VAL_63]] {
+// CHECK:                   %[[VAL_64:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_55]]] : memref<?xindex>
+// CHECK:                   %[[VAL_65:.*]] = arith.addi %[[VAL_55]], %[[VAL_3]] : index
+// CHECK:                   %[[VAL_66:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_65]]] : memref<?xindex>
+// CHECK:                   %[[VAL_67:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_56]]] : memref<?xindex>
+// CHECK:                   %[[VAL_68:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index
+// CHECK:                   %[[VAL_69:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_68]]] : memref<?xindex>
+// CHECK:                   %[[VAL_70:.*]]:3 = scf.while (%[[VAL_71:.*]] = %[[VAL_64]], %[[VAL_72:.*]] = %[[VAL_67]], %[[VAL_73:.*]] = %[[VAL_5]]) : (index, index, i32) -> (index, index, i32) {
+// CHECK:                     %[[VAL_74:.*]] = arith.cmpi ult, %[[VAL_71]], %[[VAL_66]] : index
+// CHECK:                     %[[VAL_75:.*]] = arith.cmpi ult, %[[VAL_72]], %[[VAL_69]] : index
+// CHECK:                     %[[VAL_76:.*]] = arith.andi %[[VAL_74]], %[[VAL_75]] : i1
+// CHECK:                     scf.condition(%[[VAL_76]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, i32
+// CHECK:                   } do {
+// CHECK:                   ^bb0(%[[VAL_77:.*]]: index, %[[VAL_78:.*]]: index, %[[VAL_79:.*]]: i32):
+// CHECK:                     %[[VAL_80:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_77]]] : memref<?xindex>
+// CHECK:                     %[[VAL_81:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_78]]] : memref<?xindex>
+// CHECK:                     %[[VAL_82:.*]] = arith.cmpi ult, %[[VAL_81]], %[[VAL_80]] : index
+// CHECK:                     %[[VAL_83:.*]] = select %[[VAL_82]], %[[VAL_81]], %[[VAL_80]] : index
+// CHECK:                     memref.store %[[VAL_83]], %[[VAL_23]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:                     %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index
+// CHECK:                     %[[VAL_85:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index
+// CHECK:                     %[[VAL_86:.*]] = arith.andi %[[VAL_84]], %[[VAL_85]] : i1
+// CHECK:                     %[[VAL_87:.*]] = scf.if %[[VAL_86]] -> (i32) {
+// CHECK:                       %[[VAL_88:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_77]]] : memref<?xi32>
+// CHECK:                       %[[VAL_89:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_78]]] : memref<?xi32>
+// CHECK:                       %[[VAL_90:.*]] = arith.muli %[[VAL_88]], %[[VAL_89]] : i32
+// CHECK:                       %[[VAL_91:.*]] = arith.addi %[[VAL_79]], %[[VAL_90]] : i32
+// CHECK:                       scf.yield %[[VAL_91]] : i32
+// CHECK:                     } else {
+// CHECK:                       scf.yield %[[VAL_79]] : i32
+// CHECK:                     }
+// CHECK:                     %[[VAL_92:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index
+// CHECK:                     %[[VAL_93:.*]] = arith.addi %[[VAL_77]], %[[VAL_3]] : index
+// CHECK:                     %[[VAL_94:.*]] = select %[[VAL_92]], %[[VAL_93]], %[[VAL_77]] : index
+// CHECK:                     %[[VAL_95:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index
+// CHECK:                     %[[VAL_96:.*]] = arith.addi %[[VAL_78]], %[[VAL_3]] : index
+// CHECK:                     %[[VAL_97:.*]] = select %[[VAL_95]], %[[VAL_96]], %[[VAL_78]] : index
+// CHECK:                     scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_98:.*]] : index, index, i32
+// CHECK:                   }
+// CHECK:                   sparse_tensor.lex_insert %[[VAL_8]], %[[VAL_23]], %[[VAL_99:.*]]#2 : tensor<?x?xi32, #{{.*}}>, memref<?xindex>, i32
+// CHECK:                 } else {
+// CHECK:                 }
+// CHECK:                 %[[VAL_100:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
+// CHECK:                 %[[VAL_101:.*]] = arith.addi %[[VAL_55]], %[[VAL_3]] : index
+// CHECK:                 %[[VAL_102:.*]] = select %[[VAL_100]], %[[VAL_101]], %[[VAL_55]] : index
+// CHECK:                 %[[VAL_103:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
+// CHECK:                 %[[VAL_104:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index
+// CHECK:                 %[[VAL_105:.*]] = select %[[VAL_103]], %[[VAL_104]], %[[VAL_56]] : index
+// CHECK:                 scf.yield %[[VAL_102]], %[[VAL_105]] : index, index
+// CHECK:               }
+// CHECK:             } else {
+// CHECK:             }
+// CHECK:             %[[VAL_106:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
+// CHECK:             %[[VAL_107:.*]] = arith.addi %[[VAL_34]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_108:.*]] = select %[[VAL_106]], %[[VAL_107]], %[[VAL_34]] : index
+// CHECK:             %[[VAL_109:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
+// CHECK:             %[[VAL_110:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_111:.*]] = select %[[VAL_109]], %[[VAL_110]], %[[VAL_35]] : index
+// CHECK:             scf.yield %[[VAL_108]], %[[VAL_111]] : index, index
+// CHECK:           }
+// CHECK:           %[[VAL_112:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<?x?xi32, #{{.*}}>
+// CHECK:           return %[[VAL_112]] : tensor<?x?xi32, #{{.*}}>
+// CHECK:         }
+func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
+             %argb: tensor<?x?x?xi32, #SparseTensor>) -> tensor<?x?xi32, #DCSR> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arga, %c0 : tensor<?x?x?xi32, #SparseTensor>
+  %d1 = tensor.dim %arga, %c1 : tensor<?x?x?xi32, #SparseTensor>
+  %xinit = sparse_tensor.init [%d0, %d1] : tensor<?x?xi32, #DCSR>
+  %0 = linalg.generic #trait_sumred
+    ins(%arga, %argb: tensor<?x?x?xi32, #SparseTensor>,
+                      tensor<?x?x?xi32, #SparseTensor>)
+    outs(%xinit: tensor<?x?xi32, #DCSR>) {
+      ^bb(%a: i32, %b: i32, %x: i32):
+        %0 = arith.muli %a, %b : i32
+        %1 = arith.addi %x, %0 : i32
+        linalg.yield %1 : i32
+  } -> tensor<?x?xi32, #DCSR>
+  return %0 : tensor<?x?xi32, #DCSR>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
new file mode 100644
index 0000000..0834323
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
@@ -0,0 +1,99 @@
+// RUN: mlir-opt %s \
+// RUN:   --sparsification --sparse-tensor-conversion \
+// RUN:   --linalg-bufferize --convert-linalg-to-loops \
+// RUN:   --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm --convert-math-to-llvm \
+// RUN:   --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ]
+}>
+
+#SparseTensor = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed", "compressed" ]
+}>
+
+#redsum = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j,k)>, // A
+    affine_map<(i,j,k) -> (i,j,k)>, // B
+    affine_map<(i,j,k) -> (i,j)>    // X (out)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "X(i,j) = SUM_k A(i,j,k) * B(i,j,k)"
+}
+
+module {
+  func @redsum(%arga: tensor<?x?x?xi32, #SparseTensor>,
+               %argb: tensor<?x?x?xi32, #SparseTensor>)
+               -> tensor<?x?xi32, #SparseMatrix> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?x?xi32, #SparseTensor>
+    %d1 = tensor.dim %arga, %c1 : tensor<?x?x?xi32, #SparseTensor>
+    %xinit = sparse_tensor.init [%d0, %d1] : tensor<?x?xi32, #SparseMatrix>
+    %0 = linalg.generic #redsum
+      ins(%arga, %argb: tensor<?x?x?xi32, #SparseTensor>,
+                        tensor<?x?x?xi32, #SparseTensor>)
+      outs(%xinit: tensor<?x?xi32, #SparseMatrix>) {
+        ^bb(%a: i32, %b: i32, %x: i32):
+          %0 = arith.muli %a, %b : i32
+          %1 = arith.addi %x, %0 : i32
+          linalg.yield %1 : i32
+    } -> tensor<?x?xi32, #SparseMatrix>
+    return %0 : tensor<?x?xi32, #SparseMatrix>
+  }
+
+  // Driver method to call and verify tensor kernel.
+  func @entry() {
+    %c0 = arith.constant 0 : index
+    %i0 = arith.constant -1 : i32
+
+    // Setup very sparse 3-d tensors.
+    %t1 = arith.constant sparse<
+       [ [1,1,3], [2,0,0], [2,2,1], [2,2,2], [2,2,3] ], [ 1, 2, 3, 4, 5 ]
+    > : tensor<3x3x4xi32>
+    %t2 = arith.constant sparse<
+       [ [1,0,0], [1,1,3], [2,2,1], [2,2,3] ], [ 6, 7, 8, 9 ]
+    > : tensor<3x3x4xi32>
+    %st1 = sparse_tensor.convert %t1
+      : tensor<3x3x4xi32> to tensor<?x?x?xi32, #SparseTensor>
+    %st2 = sparse_tensor.convert %t2
+      : tensor<3x3x4xi32> to tensor<?x?x?xi32, #SparseTensor>
+
+    // Call kernel.
+    %0 = call @redsum(%st1, %st2)
+      : (tensor<?x?x?xi32, #SparseTensor>,
+         tensor<?x?x?xi32, #SparseTensor>) -> tensor<?x?xi32, #SparseMatrix>
+
+    //
+    // Verify results. Only two entries stored in result. Correct structure.
+    //
+    // CHECK: ( 7, 69, -1, -1 )
+    // CHECK-NEXT: ( ( 0, 0, 0 ), ( 0, 7, 0 ), ( 0, 0, 69 ) )
+    //
+    %val = sparse_tensor.values %0
+      : tensor<?x?xi32, #SparseMatrix> to memref<?xi32>
+    %vv = vector.transfer_read %val[%c0], %i0: memref<?xi32>, vector<4xi32>
+    vector.print %vv : vector<4xi32>
+    %dm = sparse_tensor.convert %0
+      : tensor<?x?xi32, #SparseMatrix> to tensor<?x?xi32>
+    %db = bufferization.to_memref %dm : memref<?x?xi32>
+    %vm = vector.transfer_read %db[%c0, %c0], %i0: memref<?x?xi32>, vector<3x3xi32>
+    vector.print %vm : vector<3x3xi32>
+
+    // Release the resources.
+    sparse_tensor.release %st1 : tensor<?x?x?xi32, #SparseTensor>
+    sparse_tensor.release %st2 : tensor<?x?x?xi32, #SparseTensor>
+    sparse_tensor.release %0 : tensor<?x?xi32, #SparseMatrix>
+    memref.dealloc %db : memref<?x?xi32>
+    return
+  }
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
index 3d2da32..08e380d 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
@@ -144,7 +144,7 @@
     return %0 : tensor<f64>
   }
 
-  // Dumps just the values array of the sparse vector.
+  // Dumps a sparse vector.
   func @dump(%arg0: tensor<?xf64, #SparseVector>) {
     // Dump the values array to verify only sparse contents are stored.
     %c0 = arith.constant 0 : index