[mlir][sparse] support sparsification to coiterate operations. (#102546)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 6e17f80..2803223 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1749,6 +1749,10 @@ let results = (outs Variadic<AnyType>:$results); let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions); + let builders = [ + OpBuilder<(ins "ValueRange":$iterSpace, "ValueRange":$initArgs, "unsigned":$numCases)>, + ]; + let extraClassDeclaration = [{ unsigned getSpaceDim() { return llvm::cast<::mlir::sparse_tensor::IterSpaceType>( @@ -1765,18 +1769,18 @@ }); } - // The block arguments starts with referenced coordinates, follows by - // user-provided iteration arguments and ends with iterators. + // The block arguments start with user-provided iteration arguments, + // followed by referenced coordinates, and end with iterators. Block::BlockArgListType getCrds(unsigned regionIdx) { return getRegion(regionIdx).getArguments() - .take_front(getCrdUsedLvls().count()); + .slice(getNumRegionIterArgs(), getCrdUsedLvls().count()); } - unsigned getNumRegionIterArgs(unsigned regionIdx) { + unsigned getNumRegionIterArgs() { return getInitArgs().size(); } Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) { return getRegion(regionIdx).getArguments() - .slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx)); + .take_front(getNumRegionIterArgs()); } Block::BlockArgListType getRegionIterators(unsigned regionIdx) { return getRegion(regionIdx).getArguments()
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index a284aa2..a143189 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2293,9 +2293,10 @@ if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren)) return failure(); - if (failed(parseUsedCoordList(parser, state, blockArgs))) + SmallVector<OpAsmParser::Argument> coords; + if (failed(parseUsedCoordList(parser, state, coords))) return failure(); - size_t numCrds = blockArgs.size(); + size_t numCrds = coords.size(); // Parse "iter_args(%arg = %init, ...)" SmallVector<OpAsmParser::UnresolvedOperand> initArgs; @@ -2303,6 +2304,7 @@ if (hasIterArgs) if (parser.parseAssignmentList(blockArgs, initArgs)) return failure(); + blockArgs.append(coords); SmallVector<Type> iterSpaceTps; // parse ": (sparse_tensor.iter_space, ...) -> ret" @@ -2326,8 +2328,8 @@ state.operands.append(spacesVals); if (hasIterArgs) { - // Strip off leading args that used for coordinates. - MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds); + // Strip off trailing args that are used for coordinates. + MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); if (args.size() != initArgs.size() || args.size() != state.types.size()) { return parser.emitError( parser.getNameLoc(), @@ -2602,6 +2604,24 @@ regions.push_back(RegionSuccessor(getResults())); } +void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, + ValueRange iterSpaces, ValueRange initArgs, + unsigned numCases) { + unsigned rank = + cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim(); + // All ones; shift in 64 bits so that ranks >= 31 do not overflow `int`. + I64BitSet set((1ULL << rank) - 1); + // Generates all-zero case bits (they only serve as placeholders), which are + // supposed to be overridden later. We need to preallocate all the regions as + // mlir::Region cannot be dynamically added later after the operation is + // created. 
+ SmallVector<int64_t> caseBits(numCases, 0); + ArrayAttr cases = builder.getI64ArrayAttr(caseBits); + return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces, + initArgs, set, cases, + /*caseRegionsCount=*/numCases); +} + ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector<Value> spaces; @@ -2685,7 +2705,7 @@ LogicalResult CoIterateOp::verifyRegions() { for (unsigned r = 0, e = getNumRegions(); r < e; r++) { - if (getNumRegionIterArgs(r) != getNumResults()) + if (getNumRegionIterArgs() != getNumResults()) return emitOpError( "mismatch in number of basic block args and defined values");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 5fb009e..cc372ed 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1395,7 +1395,7 @@ loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls); // Note that reduc will be taken care of by loop emitter and get updated // in place. - loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, + loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1, reduc); }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 08fc104..bf12dc8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -842,11 +842,13 @@ /// one sparse level in the list. static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder, ArrayRef<TensorLevel> tidLvls, - bool tryParallel, bool needsUniv) { + unsigned numCases, bool tryParallel, + bool needsUniv) { Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) { // Construct while-loop with a parameter for each index. return env.emitter().enterCoIterationOverTensorsAtLvls( - builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv); + builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel, + needsUniv); }); assert(loop); return loop; @@ -855,9 +857,11 @@ /// Generates a for-loop or a while-loop, depending on whether it implements /// singleton iteration or co-iteration over the given conjunction. static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, - bool needsUniv, ArrayRef<TensorLevel> tidLvls) { + unsigned numCases, bool needsUniv, + ArrayRef<TensorLevel> tidLvls) { bool tryParallel = shouldTryParallize(env, curr, tidLvls); - return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv); + return genCoIteration(env, builder, tidLvls, numCases, tryParallel, + needsUniv); } /// Generates the induction structure for a while-loop. @@ -900,6 +904,26 @@ // basic block where scf::Yield should be inserted. } +/// Generates a case region in the coiterate operation. +static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder, + unsigned caseIdx, LatPointId allCase, + LatPointId curCase, + MutableArrayRef<Value> reduc) { + assert(allCase == curCase || env.merger().latGT(allCase, curCase)); + const BitVector &allCaseBits = env.merger().lat(allCase).simple; + const BitVector &curCaseBits = env.merger().lat(curCase).simple; + + // Computes the subset of iterators that are valid in the current case being + // generated. 
+ I64BitSet caseBit(0); + for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits())) + if (curCaseBits.test(set)) + caseBit.set(idx); + + env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit, + caseIdx, reduc); +} + /// Generates a single if-statement within a while-loop. static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, LatPointId p) { @@ -1175,7 +1199,10 @@ /// Starts a single loop in current sequence. static std::pair<Operation *, bool> startLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr, - LatPointId li, bool needsUniv) { + LatPointId li, unsigned numCases, + bool needsUniv) { + // TODO: numCases is only used when generating iterator-based loops. Clean up + // after full migration. // The set of tensors + lvls to generate loops on SmallVector<TensorLevel> tidLvls; @@ -1186,7 +1213,7 @@ translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls); // Emit the for/while-loop control. - Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls); + Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls); Location loc = env.op().getLoc(); for (auto [tidLvl, exp] : affineTidLvls) { env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp); @@ -1259,42 +1286,73 @@ // Start a loop sequence. bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts); - // Emit a loop for every lattice point L0 >= Li in this loop sequence. - // We cannot change this to `for (const LatPointId li : env.set(lts))` - // because the loop body causes data-movement which invalidates - // the iterator. + // When using sparse-iterator-based loops, we only need one loop, as + // opposed to a loop sequence, to cover all the iterator spaces. const unsigned lsize = env.set(lts).size(); - for (unsigned i = 0; i < lsize; i++) { - const LatPointId li = env.set(lts)[i]; - // Start a loop. 
- auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv); - - // Visit all lattices points with Li >= Lj to generate the - // loop-body, possibly with if statements for coiteration. - Value redInput = env.getReduc(); - Value cntInput = env.getExpandCount(); - Value insInput = env.getInsertionChain(); - Value validIns = env.getValidLexInsert(); - // We cannot change this to `for (const LatPointId lj : env.set(lts))` - // because the loop body causes data-movement which invalidates the - // iterator. + if (env.generatingSparseIterator()) { + // Get the largest lattice point and start a loop. + const LatPointId li = env.set(lts)[0]; + auto [loop, isSingleCond] = + startLoop(env, rewriter, curr, li, lsize, needsUniv); + assert(isSingleCond == llvm::isa<IterateOp>(loop)); + // We cannot change this to `for (const LatPointId lj : env.set(lts))` + // because the loop body causes data-movement which invalidates + // the iterator. for (unsigned j = 0; j < lsize; j++) { const LatPointId lj = env.set(lts)[j]; const ExprId ej = env.lat(lj).exp; - if (li == lj || env.merger().latGT(li, lj)) { - // Recurse into body of each branch. - if (!isSingleCond) { - scf::IfOp ifOp = genIf(env, rewriter, curr, lj); + // Recurse into body of each branch. + if (!isSingleCond) { + env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) { + genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc); genStmt(env, rewriter, ej, curr + 1); - endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns); - } else { - genStmt(env, rewriter, ej, curr + 1); - } + // TODO: handle yield values. + assert(reduc.empty() && "Not Implemented"); + rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc()); + return std::nullopt; + }); + // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns); + } else { + genStmt(env, rewriter, ej, curr + 1); } } - // End a loop. 
needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond); + } else { + // Emit a loop for every lattice point L0 >= Li in this loop sequence. + for (unsigned i = 0; i < lsize; i++) { + const LatPointId li = env.set(lts)[i]; + // Start a loop. + auto [loop, isSingleCond] = + startLoop(env, rewriter, curr, li, lsize, needsUniv); + + // Visit all lattices points with Li >= Lj to generate the + // loop-body, possibly with if statements for coiteration. + Value redInput = env.getReduc(); + Value cntInput = env.getExpandCount(); + Value insInput = env.getInsertionChain(); + Value validIns = env.getValidLexInsert(); + // We cannot change this to `for (const LatPointId lj : env.set(lts))` + // because the loop body causes data-movement which invalidates the + // iterator. + for (unsigned j = 0; j < lsize; j++) { + const LatPointId lj = env.set(lts)[j]; + const ExprId ej = env.lat(lj).exp; + if (li == lj || env.merger().latGT(li, lj)) { + // Recurse into body of each branch. + if (!isSingleCond) { + scf::IfOp ifOp = genIf(env, rewriter, curr, lj); + genStmt(env, rewriter, ej, curr + 1); + endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns); + } else { + genStmt(env, rewriter, ej, curr + 1); + } + } + } + + // End a loop. + needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond); + } } // End a loop sequence.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h index d69ae53..34b793ee 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -49,6 +49,10 @@ linalg::GenericOp op() const { return linalgOp; } const SparsificationOptions &options() const { return sparseOptions; } + bool generatingSparseIterator() const { + const auto strategy = options().sparseEmitStrategy; + return strategy == SparseEmitStrategy::kSparseIterator; + } Merger &merger() { return latticeMerger; } LoopEmitter &emitter() { return loopEmitter; }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp index 2be0193..efb3295 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -615,33 +615,106 @@ return true; } +Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder, + Location loc, + I64BitSet caseBit, + unsigned caseIdx, + MutableArrayRef<Value> reduc) { + auto coIterOp = cast<CoIterateOp>(loopStack.back().loop); + SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>()); + cases[caseIdx] = builder.getI64IntegerAttr(caseBit); + + coIterOp.setCasesAttr(builder.getArrayAttr(cases)); + Region &caseRegion = coIterOp.getRegion(caseIdx); + assert(caseRegion.getBlocks().empty() && + "re-initialize the same coiteration case region."); + + // Block arguments must follow the layout assumed by `getRegionIterArgs` and + // `getCrds`: user-provided iteration arguments come first ... + TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes(); + SmallVector<Type> blockArgTps(iterArgsTps.begin(), iterArgsTps.end()); + // ... followed by the used coordinates, which are of index type ... + blockArgTps.append(coIterOp.getCrdUsedLvls().count(), + builder.getIndexType()); + // ... and finally the iterators that define the actual iteration space. + for (auto i : caseBit.bits()) { + blockArgTps.push_back( + cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType()) + .getIteratorType()); + } + SmallVector<Location> locs(blockArgTps.size(), loc); + caseRegion.emplaceBlock().addArguments(blockArgTps, locs); + + // Entering the new region scope, updating the SSA chain. + builder.setInsertionPointToStart(&caseRegion.front()); + // Update the coordinates. + loopStack.back().iv = coIterOp.getCrds(caseIdx).front(); + // Updates loop iteration arguments. + ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx); + llvm::copy(iterArgs, reduc.begin()); + // Updates sparse iterator values. 
+ ValueRange iters = coIterOp.getRegionIterators(caseIdx); + ArrayRef<TensorLevel> tidLvls = loopStack.back().tidLvls; + for (auto [i, tl] : llvm::enumerate(unpackTensorLevelRange(tidLvls))) { + if (caseBit[i]) { + spIterVals[tl.first][tl.second] = iters.front(); + iters = iters.drop_front(); + } else { + spIterVals[tl.first][tl.second] = nullptr; + } + } + // Must have consumed all iterator SSA values. + assert(iters.empty()); + return &caseRegion; +} + Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls( OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls, - MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) { - + unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel, + bool needsUniv) { + // TODO: Argument `numCases` only used when generating iterator-based sparse + // loops. Simplify the code upon feature complete. // TODO: handle coiteration with sparse iterator. if (emitStrategy == SparseEmitStrategy::kSparseIterator) { - assert(tidLvls.size() == 1); - auto [tid, lvl] = unpackTensorLevel(tidLvls.front()); - Value t = tensors[tid]; + if (tidLvls.size() == 1) { + auto [tid, lvl] = unpackTensorLevel(tidLvls.front()); + Value t = tensors[tid]; - // Extract and iterate over the iteration space. - ExtractIterSpaceOp extractSpaceOp = - lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t) - : builder.create<ExtractIterSpaceOp>( - loc, t, spIterVals[tid][lvl - 1], lvl); + // Extract and iterate over the iteration space. + ExtractIterSpaceOp extractSpaceOp = + lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t) + : builder.create<ExtractIterSpaceOp>( + loc, t, spIterVals[tid][lvl - 1], lvl); - IterateOp iterOp = builder.create<IterateOp>( - loc, extractSpaceOp.getExtractedSpace(), reduc); - spIterVals[tid][lvl] = iterOp.getIterator(); + IterateOp iterOp = builder.create<IterateOp>( + loc, extractSpaceOp.getExtractedSpace(), reduc); + spIterVals[tid][lvl] = iterOp.getIterator(); - // Update the reduction varaibles. 
- llvm::copy(iterOp.getRegionIterArgs(), reduc.begin()); - // Set the insertion point to loop body. - builder.setInsertionPointToStart(iterOp.getBody()); - loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(), - iterOp.getIterator(), loopTag); - return iterOp; + // Update the reduction variables. + llvm::copy(iterOp.getRegionIterArgs(), reduc.begin()); + // Set the insertion point to loop body. + builder.setInsertionPointToStart(iterOp.getBody()); + loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(), + iterOp.getCrds().front(), loopTag); + return iterOp; + } + + // CoIteration Loops. + SmallVector<Value> spaces; + for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) { + Value t = tensors[tid]; + ExtractIterSpaceOp extractSpaceOp = + lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t) + : builder.create<ExtractIterSpaceOp>( + loc, t, spIterVals[tid][lvl - 1], lvl); + spaces.push_back(extractSpaceOp.getExtractedSpace()); + } + auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases); + // The CoIterateOp has neither an insertion block nor an induction variable. + // TODO: the `struct LoopInfo` should be simplified after full migration. + loopStack.emplace_back(tidLvls, coIterOp, /*insertion block*/ nullptr, + /*induction variable*/ nullptr, loopTag); + return coIterOp; } // TODO: support multiple return on parallel for? @@ -866,6 +939,18 @@ // Clean up the values, it would help use to discover potential bug at a // earlier stage (instead of silently using a wrong value). const LoopInfo &loopInfo = loopStack.back(); + if (emitStrategy == SparseEmitStrategy::kSparseIterator) { + Operation *p = loopInfo.loop; + if (isa<IterateOp>(p)) + rewriter.create<sparse_tensor::YieldOp>(loc, reduc); + + // Exit the loop. + rewriter.setInsertionPointAfter(p); + // In-place update reduction variables. + llvm::copy(p->getResults(), reduc.begin()); + loopStack.pop_back(); + return; + } // Sets the insertion point to the right position. 
rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h index f3e73e4..a9eb888 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -145,8 +145,12 @@ /// return the reduction variable used inside the generated loop. Operation *enterCoIterationOverTensorsAtLvls( OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls, - MutableArrayRef<Value> reduc = {}, bool isParallel = false, - bool needsUniv = false); + unsigned numCases, MutableArrayRef<Value> reduc = {}, + bool isParallel = false, bool needsUniv = false); + + Region *enterCurrentCoIterationCase(OpBuilder &builder, Location loc, + I64BitSet caseBit, unsigned caseIdx, + MutableArrayRef<Value> reduc); /// Generates code to exit the current loop (e.g., generates yields, forwards /// loop induction variables, etc). @@ -260,9 +264,9 @@ // required for levels with non-tivial index expressions, which is // maintained by the sliceDrivenInfo array below. const llvm::SmallVector<TensorLevel> tidLvls; - const Operation *loop; // the loop operation + Operation *loop; // the loop operation Block *const userCodeBlock; // the block holding users' generated code. - const Value iv; // the induction variable for the loop + Value iv; // the induction variable for the loop }; void categorizeIterators(ArrayRef<TensorLevel> tidLvls,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir index 268b394..2487156 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
@@ -1,4 +1,8 @@ -// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s +// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse | FileCheck %s --check-prefix="ITER" + +// TODO: temporarily disabled since there are no lowering rules from `coiterate` to `scf`. +// R_U_N: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s + #COO = #sparse_tensor.encoding<{ @@ -10,13 +14,18 @@ ) }> +#VEC = #sparse_tensor.encoding<{ + map = (d0) -> (d0 : compressed) +}> + + // CHECK-LABEL: func.func @sqsum( // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse> to memref<?xindex> +// CHECK-DAG: %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse{{.*}}> to memref<?xindex> // CHECK: %[[POS_LO:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C0]]] : memref<?xindex> // CHECK: %[[POS_HI:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C1]]] : memref<?xindex> -// CHECK: %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse> to memref<?xi32> +// CHECK: %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse{{.*}}> to memref<?xi32> // CHECK: %[[SQ_SUM:.*]] = scf.for %[[POS:.*]] = %[[POS_LO]] to %[[POS_HI]] step %[[C1]] {{.*}} { // CHECK: %[[VAL:.*]] = memref.load %[[VAL_BUF]]{{\[}}%[[POS]]] : memref<?xi32> // CHECK: %[[MUL:.*]] = arith.muli %[[VAL]], %[[VAL]] : i32 @@ -27,6 +36,12 @@ // CHECK: %[[RET:.*]] = bufferization.to_tensor // CHECK: return %[[RET]] : tensor<i32> // CHECK: } + +// ITER-LABEL: func.func @sqsum( 
+// ITER: sparse_tensor.iterate +// ITER: sparse_tensor.iterate +// ITER: sparse_tensor.iterate +// ITER: } func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> { %cst = arith.constant dense<0> : tensor<i32> %0 = linalg.generic { @@ -43,3 +58,42 @@ } -> tensor<i32> return %0 : tensor<i32> } + + +// ITER-LABEL: func.func @add( +// ITER: sparse_tensor.coiterate +// ITER: case %[[IT_1:.*]], %[[IT_2:.*]] { +// ITER: %[[LHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_1]] +// ITER: %[[RHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_2]] +// ITER: %[[SUM:.*]] = arith.addi %[[LHS]], %[[RHS]] : i32 +// ITER: memref.store %[[SUM]] +// ITER: } +// ITER: case %[[IT_1:.*]], _ { +// ITER: %[[LHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_1]] +// ITER: memref.store %[[LHS]] +// ITER: } +// ITER: case _, %[[IT_2:.*]] { +// ITER: %[[RHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_2]] +// ITER: memref.store %[[RHS]] +// ITER: } +// ITER: bufferization.to_tensor +// ITER: return +// ITER: } +func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> { + %cst = arith.constant dense<0> : tensor<10xi32> + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)> + ], + iterator_types = ["parallel"] + } + ins(%arg0, %arg1 : tensor<10xi32, #VEC>, tensor<10xi32, #VEC>) + outs(%cst : tensor<10xi32>) { + ^bb0(%in1: i32, %in2: i32, %out: i32): + %2 = arith.addi %in1, %in2 : i32 + linalg.yield %2 : i32 + } -> tensor<10xi32> + return %0 : tensor<10xi32> +}