| //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements rewriting rules that are specific to sparse tensors. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "Utils/CodegenUtils.h" |
| #include "Utils/LoopEmitter.h" |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
| #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/Support/LLVM.h" |
| |
| using namespace mlir; |
| using namespace mlir::bufferization; |
| using namespace mlir::linalg; |
| using namespace mlir::sparse_tensor; |
| |
| //===---------------------------------------------------------------------===// |
| // Helper methods for the actual rewriting rules. |
| //===---------------------------------------------------------------------===// |
| |
| // Helper method to match any typed zero. |
| static bool isZeroValue(Value val) { |
| return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()); |
| } |
| |
| // Helper to detect a sparse tensor type operand. |
| static bool isSparseTensor(Value v) { |
| auto enc = getSparseTensorEncoding(v.getType()); |
| return enc && !llvm::all_of(enc.getLvlTypes(), |
| [](auto lt) { return lt == LevelFormat::Dense; }); |
| } |
| static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); } |
| |
| // Helper method to find zero/uninitialized tensor materialization. |
| static bool isMaterializing(OpOperand *op, bool isZero) { |
| Value val = op->get(); |
| // Check allocation, with zero alloc when required. |
| if (auto alloc = val.getDefiningOp<AllocTensorOp>()) { |
| Value copy = alloc.getCopy(); |
| if (isZero) |
| return copy && isZeroValue(copy); |
| return !copy; |
| } |
| // Check for empty tensor materialization. |
| if (auto empty = val.getDefiningOp<tensor::EmptyOp>()) |
| return !isZero; |
| // Last resort for zero alloc: the whole value is zero. |
| return isZero && isZeroValue(val); |
| } |
| |
| // Helper to detect sampling operation. |
| static bool isSampling(GenericOp op) { |
| auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); |
| if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { |
| if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) { |
      // Both scalar input arguments are used exactly once.
| Value s1 = op.getBlock()->getArgument(0); |
| Value s2 = op.getBlock()->getArgument(1); |
| return (def->getOperand(0) == s1 && def->getOperand(1) == s2) || |
| (def->getOperand(1) == s1 && def->getOperand(0) == s2); |
| } |
| } |
| return false; |
| } |
| |
| // Helper to detect chain of multiplications that do not involve x. |
| static bool isMulChain(Value val, Value x) { |
| if (auto arg = dyn_cast<BlockArgument>(val)) |
| return arg != x; |
| if (auto *def = val.getDefiningOp()) { |
| if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) |
| return isMulChain(def->getOperand(0), x) && |
| isMulChain(def->getOperand(1), x); |
| } |
| return false; |
| } |
| |
| // Helper to detect x = x + <multiplications>. |
| static bool isSumOfMul(GenericOp op) { |
| auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); |
| if (auto *def = yieldOp.getOperand(0).getDefiningOp()) { |
| if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) { |
| Value x = op.getBlock()->getArguments().back(); |
| return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) || |
| (def->getOperand(1) == x && isMulChain(def->getOperand(0), x)); |
| } |
| } |
| return false; |
| } |
| |
| // Helper to detect direct yield of a zero value. |
| static bool isZeroYield(GenericOp op) { |
| auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()); |
| if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) { |
| if (arg.getOwner()->getParentOp() == op) { |
| return isZeroValue(op->getOperand(arg.getArgNumber())); |
| } |
| } |
| return isZeroValue(yieldOp.getOperand(0)); |
| } |
| |
/// Populates the given sizes array from the type (for static sizes) and
/// from the tensor (for dynamic sizes).
| static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes, |
| Location loc, ShapedType stp, Value tensor) { |
| for (const auto &d : enumerate(stp.getShape())) { |
| Value dim; |
| if (d.value() == ShapedType::kDynamic) |
| dim = tensor::DimOp::create(builder, loc, tensor, d.index()); |
| else |
| dim = constantIndex(builder, loc, d.value()); |
| sizes.push_back(dim); |
| } |
| } |
| |
| static RankedTensorType getBufferType(const SparseTensorType &stt, |
| bool needTmpCOO) { |
| return needTmpCOO ? stt.getCOOType(/*ordered=*/false) |
| : stt.getRankedTensorType(); |
| } |
| |
/// Collects the dynamic dimension sizes for `tp` with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
/// sizes into `dynSizes`.
| static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, |
| SmallVectorImpl<Value> &dynSizes) { |
| for (const auto &d : enumerate(tp.getShape())) { |
| if (d.value() == ShapedType::kDynamic) |
| dynSizes.push_back(sizes[d.index()]); |
| } |
| } |
| |
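/// Expands a sparse_tensor.foreach over a sparse constant by "unrolling" the
/// foreach body once per stored element of the constant, threading the
/// reduction values through each inlined copy of the body.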
| static LogicalResult genForeachOnSparseConstant(ForeachOp op, |
| RewriterBase &rewriter, |
| SparseElementsAttr attr) { |
| auto loc = op.getLoc(); |
| SmallVector<Value> reduc = op.getInitArgs(); |
| |
| // Foreach on constant. |
| foreachInSparseConstant( |
| rewriter, loc, attr, op.getOrder().value_or(AffineMap()), |
| [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable { |
| SmallVector<Value> args; |
| args.append(cvs.begin(), cvs.end()); |
| args.push_back(v); |
| args.append(reduc); |
| // Clones the foreach op to get a copy of the loop body. |
| auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation())); |
| assert(args.size() == cloned.getBody()->getNumArguments()); |
| Operation *yield = cloned.getBody()->getTerminator(); |
| rewriter.inlineBlockBefore(cloned.getBody(), op, args); |
          // Clean up.
| rewriter.eraseOp(cloned); |
| reduc = yield->getOperands(); |
| rewriter.eraseOp(yield); |
| }); |
| |
| rewriter.replaceOp(op, reduc); |
| return success(); |
| } |
| |
| /// Populates the given sizes array for concatenation from types (for static |
| /// sizes) and from the source tensors (for dynamic sizes). |
| static void concatSizesFromInputs(OpBuilder &builder, |
| SmallVectorImpl<Value> &sizes, Location loc, |
| ShapedType dstTp, ValueRange srcs, |
| unsigned dim) { |
| auto dstShape = dstTp.getShape(); |
| sizesFromSrc(builder, sizes, loc, srcs[0]); |
| |
  // Sum up the sizes along `dim` if that dimension is dynamic.
| if (dstShape[dim] != ShapedType::kDynamic) { |
| // Faithfully take the static size. |
| sizes[dim] = constantIndex(builder, loc, dstShape[dim]); |
| } else { |
| // Else, compute the shape dynamically. |
| for (const auto &src : srcs.drop_front()) { |
| Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim); |
| // Sum up all the sizes. |
| sizes[dim] = arith::AddIOp::create(builder, loc, sizes[dim], srcSz); |
| } |
| } |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // The actual sparse tensor rewriting rules. |
| //===---------------------------------------------------------------------===// |
| |
| namespace { |
| |
/// TODO: move this to the tensor dialect instead.
| /// |
| /// Fold `tensor.concat` and `tensor.extract_slice` |
| /// |
| /// %concat = tensor.concat dim(2) %t0, %t1 |
| /// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32> |
| /// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1] |
| /// : tensor<1x64x2xf32> to tensor<1x64x1xf32> |
| /// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1] |
| /// : tensor<1x64x2xf32> to tensor<1x64x1xf32> |
| /// |
| /// Becomes |
| /// |
| /// %extract0, %extract1 = %t0, %t1 |
| struct FuseExtractSliceWithConcat |
| : public OpRewritePattern<tensor::ExtractSliceOp> { |
| using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp, |
| PatternRewriter &rewriter) const override { |
| auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>(); |
| if (!concatOp) |
| return failure(); |
| |
| Location loc = extractOp.getLoc(); |
| int64_t dim = concatOp.getDim(); |
| int64_t rank = extractOp.getResultType().getRank(); |
| |
| SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1)); |
| SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0)); |
| |
| // Compute the partial sums for the slice offsets. |
| AffineExpr sum = rewriter.getAffineDimExpr(0); |
| SmallVector<AffineExpr> partialSums = {sum}; |
| SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)}; |
| for (auto [idx, input] : |
| llvm::enumerate(concatOp.getInputs().drop_back())) { |
| sum = sum + rewriter.getAffineDimExpr(idx + 1); |
| partialSums.push_back(sum); |
| offsetStrides.push_back( |
| rewriter.createOrFold<tensor::DimOp>(loc, input, dim)); |
| } |
| auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0, |
| partialSums, rewriter.getContext()); |
| SmallVector<OpFoldResult> dimOffsets = |
| affine::makeComposedFoldedMultiResultAffineApply( |
| rewriter, loc, partialSumMap, offsetStrides); |
| |
| auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) { |
| for (auto [l, r] : llvm::zip(lhs, rhs)) { |
| std::optional<int64_t> staticVal = getConstantIntValue(l); |
| if (!staticVal.has_value() || staticVal != getConstantIntValue(r)) |
| return false; |
| } |
| return lhs.size() == rhs.size(); |
| }; |
| |
| for (auto [i, input, offset] : |
| llvm::enumerate(concatOp.getInputs(), dimOffsets)) { |
| SmallVector<OpFoldResult> srcSizes = |
| tensor::getMixedSizes(rewriter, loc, input); |
| srcOffsets[dim] = offset; |
| |
| SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes(); |
| SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets(); |
| SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides(); |
| |
| if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) && |
| allEqual(srcStrides, dstStrides)) { |
| Value operand = concatOp.getOperand(i); |
| if (operand.getType() == extractOp.getResultType()) |
| rewriter.replaceOp(extractOp, operand); |
| break; |
| } |
| } |
| |
| return success(); |
| } |
| }; |
| |
| /// Rewriting rule that fuses sparse_tensor.convert into producer. |
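///
/// For example (a hypothetical sketch, with #SV some sparse encoding),
///
///   %0 = tensor.empty() : tensor<10xf32>
///   %1 = linalg.generic ... outs(%0 : tensor<10xf32>) -> tensor<10xf32>
///   %2 = sparse_tensor.convert %1 : tensor<10xf32> to tensor<10xf32, #SV>
///
/// becomes
///
///   %0 = tensor.empty() : tensor<10xf32, #SV>
///   %2 = linalg.generic ... outs(%0 : tensor<10xf32, #SV>)
///        -> tensor<10xf32, #SV>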
| struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ConvertOp op, |
| PatternRewriter &rewriter) const override { |
| auto producer = op.getSource().getDefiningOp<GenericOp>(); |
| if (!producer || producer.getDpsInits().size() != 1 || |
| !isMaterializing(producer.getDpsInitOperand(0), false) || |
| !producer.getResult(0).hasOneUse()) { |
| return failure(); |
| } |
| // Clone the materialization operation, but update the result to sparse. |
| rewriter.setInsertionPoint(producer); |
| Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp(); |
| Operation *cloned = rewriter.clone(*init); |
| cloned->getResult(0).setType(op.getResult().getType()); |
| |
| rewriter.modifyOpInPlace(producer, [&]() { |
| producer.getDpsInitsMutable().assign(cloned->getResults()); |
| producer.getResult(0).setType(op.getResult().getType()); |
| }); |
| |
| rewriter.replaceAllOpUsesWith(op, producer); |
| op->erase(); |
| |
| return success(); |
| } |
| }; |
| |
/// Rewriting rule that folds a direct yield of zero into the initial
/// allocation itself (for sparse outputs) or into a constant zero (for
/// statically shaped dense outputs).
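///
/// For example (sketch), with #SV some sparse encoding,
///
///   %init = bufferization.alloc_tensor() : tensor<8xf32, #SV>
///   %0 = linalg.generic ... outs(%init) { linalg.yield %cst_0 }
///
/// is replaced by %init itself, since a newly materialized sparse tensor
/// is all-zero by construction.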
| struct FoldInvariantYield : public OpRewritePattern<GenericOp> { |
| public: |
| using OpRewritePattern<GenericOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(GenericOp op, |
| PatternRewriter &rewriter) const override { |
| if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 || |
| !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) || |
| !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse()) |
| return failure(); |
| auto outputType = getRankedTensorType(op.getResult(0)); |
    // Yielding zero on a newly materialized sparse tensor can be
    // optimized away directly (regardless of dynamic or static size).
| if (getSparseTensorEncoding(outputType)) { |
| rewriter.replaceOp(op, op.getDpsInitOperand(0)->get()); |
| return success(); |
| } |
| // Use static zero value directly instead of materialization. |
| if (!outputType.hasStaticShape()) |
| return failure(); |
| Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp(); |
| rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType)); |
| rewriter.eraseOp(def); |
| return success(); |
| } |
| }; |
| |
| /// Rewriting rule that converts two kernels: |
| /// |
| /// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... ) |
| /// X(i,j) = S(i,j) * T(i,j) |
| /// |
| /// into a single kernel, using distributive law: |
| /// |
| /// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... ) |
| /// |
| /// This kind of fusion (merging two ops into one but using arithmetic |
| /// equalities that may not hold for floating-point computations) would |
| /// be undesirable in the dense case, since we distribute the multiplication |
| /// into the reduction loop. However, for sparse sampling tensor S, such |
| /// a fusion may actually reduce the asymptotic complexity of the kernel, |
| /// since intermediate results may be nullified. |
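///
/// For example, with a sparse sampling matrix S, this pattern realizes a
/// sampled dense-dense matrix multiplication (SDDMM) as a single kernel
/// that only computes products at the stored positions of S.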
| struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> { |
| public: |
| using OpRewritePattern<GenericOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(GenericOp op, |
| PatternRewriter &rewriter) const override { |
| // Check consumer. |
| if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 || |
| op.getNumResults() != 1 || |
| op.getNumParallelLoops() != op.getNumLoops() || |
| !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() || |
| !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() || |
| !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity()) |
| return failure(); |
| // Find consuming OP2(sparse, other) or OP2(other, sparse). The other |
| // operand can be sparse or dense, since the point of this rewriting rule |
| // is detecting a situation in which *more* sparsity is introduced into |
| // a computation, be it already sparse or still dense. |
| unsigned other = 0; |
| if (isSparseTensor(op.getDpsInputOperand(0))) |
| other = 1; |
| else if (!isSparseTensor(op.getDpsInputOperand(1))) |
| return failure(); |
| // Check producer. |
| auto prod = dyn_cast_or_null<GenericOp>( |
| op.getDpsInputOperand(other)->get().getDefiningOp()); |
| if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 || |
| !prod.getResult(0).hasOneUse()) |
| return failure(); |
| // Sampling consumer and sum of multiplication chain producer. |
| if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) || |
| !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) || |
| !isSampling(op) || !isSumOfMul(prod)) |
| return failure(); |
| // Modify operand structure of producer and consumer. |
| Location loc = prod.getLoc(); |
| SmallVector<Value> inputOps = prod.getInputs(); |
| SmallVector<Value> outputOps = op.getOutputs(); |
| SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray(); |
| inputOps.push_back(op.getDpsInputOperand(1 - other)->get()); |
| fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other |
| // Fuse producer and consumer into a new generic op. |
| auto fusedOp = GenericOp::create( |
| rewriter, loc, op.getResult(0).getType(), inputOps, outputOps, |
| rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(), |
| /*doc=*/nullptr, /*library_call=*/nullptr); |
| Block &prodBlock = prod.getRegion().front(); |
| Block &consBlock = op.getRegion().front(); |
| IRMapping mapper; |
| Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion()); |
| unsigned num = prodBlock.getNumArguments(); |
| for (unsigned i = 0; i < num - 1; i++) |
| addArg(mapper, fusedBlock, prodBlock.getArgument(i)); |
| addArg(mapper, fusedBlock, consBlock.getArgument(1 - other)); |
| addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1)); |
| // Clone bodies of the producer and consumer in new evaluation order. |
| auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp(); |
| auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp(); |
| Value last; |
| for (auto &op : prodBlock.without_terminator()) |
| if (&op != acc) { |
| last = op.getResult(0); |
| rewriter.clone(op, mapper); |
| } |
| mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0)); |
| mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0)); |
| last = rewriter.clone(*acc, mapper)->getResult(0); |
| linalg::YieldOp::create(rewriter, loc, last); |
    // Force initial value on merged allocation for dense outputs.
    // TODO: deal with non-alloc-tensor initializations here one day
| if (!getSparseTensorEncoding(op.getResult(0).getType())) { |
| Value init = prod.getDpsInitOperand(0) |
| ->get() |
| .getDefiningOp<AllocTensorOp>() |
| .getCopy(); |
| AllocTensorOp a = |
| op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>(); |
| rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); }); |
| } |
| // Replace consumer with fused operation. Old producer |
| // and consumer ops will be removed by DCE. |
| rewriter.replaceOp(op, fusedOp->getResults()); |
| return success(); |
| } |
| |
| private: |
| // Helper to add argument and record the mapping. |
| static void addArg(IRMapping &mapper, Block *b, BlockArgument a) { |
| mapper.map(a, b->addArgument(a.getType(), a.getLoc())); |
| } |
| }; |
| |
// Fuse a tensor cast into producing operation. Note that a tensor.cast
// should really not be used to convert between sparse encodings. Since
// the pattern currently appears as a result of some prior rewriting,
// we make an attempt to repair very obvious cases.
| // TODO: audit the pure tensor dialect rewriting rules |
| struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> { |
| public: |
| using OpRewritePattern<tensor::CastOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tensor::CastOp op, |
| PatternRewriter &rewriter) const override { |
| Type srcType = op.getSource().getType(); |
| Type dstType = op.getDest().getType(); |
| // A nop cast simply folds away. |
| if (srcType == dstType) { |
| rewriter.replaceOp(op, op->getResults()); |
| return success(); |
| } |
| // See if a sparsity changing cast can be fused into producer. |
| if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) { |
| if (Operation *def = op.getSource().getDefiningOp()) { |
| if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) { |
| rewriter.modifyOpInPlace(def, [&]() { |
| def->getResult(0).setType(op->getResultTypes()[0]); |
| }); |
| rewriter.replaceOp(op, def->getResult(0)); |
| return success(); |
| } |
| } |
| } |
    // Repair tensor casts with at least one sparse operand into the
    // properly supported sparse_tensor.convert.
| if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) { |
| rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource()); |
| return success(); |
| } |
| // Fail otherwise. |
| return failure(); |
| } |
| }; |
| |
/// Rewrites a sequence of operations for sparse tensor selections into
/// semi-ring operations such that they can be compiled correctly by the
/// sparsifier. E.g., transforming the following sequence
| /// |
| /// %sel = arith.select %cond, %sp1, %sp2 |
| /// |
| /// to |
| /// |
| /// %sel = binary %sp1, %sp2: |
| /// both (%l, %r) {yield select %cond, %l, %r} |
| /// left (%l) {yield select %cond, %l, 0} |
| /// right (%r) {yield select %cond, 0, %r} |
| /// |
/// TODO: We require the tensor from which the conditions are extracted to be
/// dense in order to sparsify the code. Supporting a sparse condition tensor
/// would require a ternary operation.
| struct GenSemiRingSelect : public OpRewritePattern<GenericOp> { |
| public: |
| using OpRewritePattern<GenericOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(GenericOp op, |
| PatternRewriter &rewriter) const override { |
    // Reject non-sparse kernels.
| if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op)) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings; |
| for (Operation &inst : *op.getBody()) { |
| // Matches pattern. |
| auto matched = isRewritablePattern(op, &inst); |
| if (!matched.has_value()) |
| continue; |
| |
| rewriter.setInsertionPoint(&inst); |
| auto [c, t, f] = matched.value(); |
| assert(t.getType() == f.getType()); |
| auto selTp = t.getType(); |
| auto c0 = constantZero(rewriter, loc, selTp); |
| auto binOp = sparse_tensor::BinaryOp::create(rewriter, loc, selTp, t, f); |
| // Initializes all the blocks. |
| rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp}, |
| {t.getLoc(), f.getLoc()}); |
| rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc()); |
| rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc()); |
| |
| for (auto *r : binOp.getRegions()) { |
| Block *b = &r->front(); |
| rewriter.setInsertionPointToStart(b); |
| |
| IRMapping irMap; |
| // Clones the cmp operations into the region to make the binary op |
| // admissible. |
| Value newC = c; |
| if (auto *def = c.getDefiningOp()) |
| newC = rewriter.clone(*def, irMap)->getResult(0); |
| |
| irMap.map(c, newC); |
| if (r == &binOp.getLeftRegion()) { |
| irMap.map(t, b->getArgument(0)); |
| irMap.map(f, c0); |
| } else if (r == &binOp.getRightRegion()) { |
| irMap.map(t, c0); |
| irMap.map(f, b->getArgument(0)); |
| } else { |
| irMap.map(t, b->getArgument(0)); |
| irMap.map(f, b->getArgument(1)); |
| } |
| auto y = rewriter.clone(inst, irMap)->getResult(0); |
| sparse_tensor::YieldOp::create(rewriter, loc, y); |
| } |
| |
      // We successfully rewrote an operation. We cannot do the replacement
      // here because it would invalidate the iterator used by the current
      // loop to traverse the instructions.
| semiRings.emplace_back(&inst, binOp); |
| } |
| |
| // Finalizes the replacement. |
| for (auto [sel, semi] : semiRings) |
| rewriter.replaceOp(sel, semi->getResults()); |
| |
| return success(!semiRings.empty()); |
| } |
| |
| private: |
| static std::optional<std::tuple<Value, BlockArgument, BlockArgument>> |
| isRewritablePattern(GenericOp op, Operation *v) { |
| auto sel = dyn_cast<arith::SelectOp>(v); |
| if (!sel) |
| return std::nullopt; |
| |
| auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue()); |
| auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue()); |
    // TODO: For simplicity, we only handle cases where both the true/false
    // values are directly loaded from the input tensor. We can probably
    // admit more cases in theory.
| if (!tVal || !fVal) |
| return std::nullopt; |
| |
| // Helper lambda to determine whether the value is loaded from a dense input |
| // or is a loop invariant. |
| auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool { |
| if (auto bArg = dyn_cast<BlockArgument>(v); |
| bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber()))) |
| return true; |
| // If the value is defined outside the loop, it is a loop invariant. |
| return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody(); |
| }; |
| |
    // If the condition value is loaded directly from a dense tensor, or is a
    // loop invariant, we can sparsify the kernel.
| auto cond = sel.getCondition(); |
| if (isValFromDenseInputOrInvariant(cond)) |
| return std::make_tuple(cond, tVal, fVal); |
| |
| Value cmpL, cmpR; |
| if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL), |
| matchers::m_Any(&cmpR))) || |
| matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL), |
| matchers::m_Any(&cmpR)))) { |
| // TODO: we can do it recursively to check whether all the leaf values are |
| // loaded from dense tensors or are loop invariants. |
| if (isValFromDenseInputOrInvariant(cmpL) || |
| isValFromDenseInputOrInvariant(cmpR)) |
| return std::make_tuple(cond, tVal, fVal); |
| } |
| |
| return std::nullopt; |
| }; |
| }; |
| |
| /// Rewrites a sparse reduction that would not sparsify directly since |
| /// doing so would only iterate over the stored elements, ignoring the |
| /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max |
| /// (note that reductions like add/sub/or/xor can directly be sparsified |
| /// since the implicit zeros do not contribute to the final result). |
| /// Note that prod/and are still included since, even though they often |
| /// are nullified in sparse data, they may still occur for special |
| /// situations in which e.g. some rows in a sparse matrix are fully |
| /// dense. For min/max, including the implicit zeros is a much more |
| /// common situation. |
| /// |
/// TODO: this essentially "densifies" the operation; we want to implement
/// this much more efficiently by performing the reduction over the
/// stored values, and feeding in the zero once if there were *any*
/// implicit zeros as well; but for now, at least we provide
/// the functionality.
| /// |
| struct GenSemiRingReduction : public OpRewritePattern<GenericOp> { |
| public: |
| using OpRewritePattern<GenericOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(GenericOp op, |
| PatternRewriter &rewriter) const override { |
| // Reject non-reductions. |
| if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 || |
| op.getNumReductionLoops() == 0 || op.getNumResults() != 1) |
| return failure(); |
| auto *inp = op.getDpsInputOperand(0); |
| auto *init = op.getDpsInitOperand(0); |
| if (!isSparseTensor(inp)) |
| return failure(); |
| // Look for direct x = x OP y for semi-ring ready reductions. |
| auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator()) |
| .getOperand(0) |
| .getDefiningOp(); |
| if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp, |
| arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp, |
| arith::MaxUIOp>(red)) |
| return failure(); |
| Value s0 = op.getBlock()->getArgument(0); |
| Value s1 = op.getBlock()->getArgument(1); |
| if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) && |
| (red->getOperand(0) != s1 || red->getOperand(1) != s0)) |
| return failure(); |
| // Identity. |
| Location loc = op.getLoc(); |
| Value identity = |
| tensor::ExtractOp::create(rewriter, loc, init->get(), ValueRange()); |
| // Unary { |
| // present -> value |
| // absent -> zero. |
| // } |
| Type rtp = s0.getType(); |
| rewriter.setInsertionPointToStart(&op.getRegion().front()); |
| auto semiring = sparse_tensor::UnaryOp::create(rewriter, loc, rtp, s0); |
| Block *present = |
| rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc); |
| rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front()); |
| sparse_tensor::YieldOp::create(rewriter, loc, present->getArgument(0)); |
| rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {}); |
| rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front()); |
| auto zero = |
| arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(rtp)); |
| sparse_tensor::YieldOp::create(rewriter, loc, zero); |
| rewriter.setInsertionPointAfter(semiring); |
| // CustomReduce { |
| // x = x REDUC y, identity |
| // } |
| auto custom = sparse_tensor::ReduceOp::create( |
| rewriter, loc, rtp, semiring.getResult(), s1, identity); |
| Block *region = |
| rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc}); |
| rewriter.setInsertionPointToStart(&custom.getRegion().front()); |
| IRMapping irMap; |
| irMap.map(red->getOperand(0), region->getArgument(0)); |
| irMap.map(red->getOperand(1), region->getArgument(1)); |
| auto *cloned = rewriter.clone(*red, irMap); |
| sparse_tensor::YieldOp::create(rewriter, loc, cloned->getResult(0)); |
| rewriter.setInsertionPointAfter(custom); |
| rewriter.replaceOp(red, custom.getResult()); |
| return success(); |
| } |
| }; |
| |
/// Sparse rewriting rule for the print operator. This operation is mainly
/// used for debugging and testing. As such, it lowers to the vector.print
/// operation, which only requires very lightweight runtime support.
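///
/// For a CSR matrix, for instance, the generated code prints output of the
/// following general form (the concrete values are just an illustrative
/// sketch):
///
///   ---- Sparse Tensor ----
///   nse = 5
///   dim = ( 4, 8 )
///   lvl = ( 4, 8 )
///   pos[1] : ( 0, 2, 3, 3, 5 )
///   crd[1] : ( 2, 7, 1, 0, 4 )
///   values : ( 1, 2, 3, 4, 5 )
///   ----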
| struct PrintRewriter : public OpRewritePattern<PrintOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(PrintOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| auto tensor = op.getTensor(); |
| auto stt = getSparseTensorType(tensor); |
| // Header with NSE. |
| auto nse = NumberOfEntriesOp::create(rewriter, loc, tensor); |
| vector::PrintOp::create( |
| rewriter, loc, |
| rewriter.getStringAttr("---- Sparse Tensor ----\nnse = ")); |
| vector::PrintOp::create(rewriter, loc, nse); |
| // Print run-time contents for dim/lvl sizes. |
| vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("dim = ")); |
| printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true); |
| vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("lvl = ")); |
| printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false); |
| // Use the "codegen" foreach loop construct to iterate over |
| // all typical sparse tensor components for printing. |
| foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor, |
| &stt](Type, FieldIndex, |
| SparseTensorFieldKind kind, |
| Level l, LevelType) { |
| switch (kind) { |
| case SparseTensorFieldKind::StorageSpec: { |
| break; |
| } |
| case SparseTensorFieldKind::PosMemRef: { |
| auto lvl = constantIndex(rewriter, loc, l); |
| vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("pos[")); |
| vector::PrintOp::create(rewriter, loc, lvl, |
| vector::PrintPunctuation::NoPunctuation); |
| vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : ")); |
| auto pos = ToPositionsOp::create(rewriter, loc, tensor, l); |
| printContents(rewriter, loc, pos); |
| break; |
| } |
| case SparseTensorFieldKind::CrdMemRef: { |
| auto lvl = constantIndex(rewriter, loc, l); |
| vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("crd[")); |
| vector::PrintOp::create(rewriter, loc, lvl, |
| vector::PrintPunctuation::NoPunctuation); |
| vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("] : ")); |
| Value crd = nullptr; |
        // For COO AoS storage, we want to print a single, linear view of
        // the full coordinate storage at this level. For any other storage,
        // we show the coordinate storage for every individual level.
| if (stt.getAoSCOOStart() == l) |
| crd = ToCoordinatesBufferOp::create(rewriter, loc, tensor); |
| else |
| crd = ToCoordinatesOp::create(rewriter, loc, tensor, l); |
| printContents(rewriter, loc, crd); |
| break; |
| } |
| case SparseTensorFieldKind::ValMemRef: { |
| vector::PrintOp::create(rewriter, loc, |
| rewriter.getStringAttr("values : ")); |
| auto val = ToValuesOp::create(rewriter, loc, tensor); |
| printContents(rewriter, loc, val); |
| break; |
| } |
| } |
| return true; |
| }); |
| vector::PrintOp::create(rewriter, loc, rewriter.getStringAttr("----\n")); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| |
| private: |
| // Helper to print contents of a single memref. For "push_back" vectors, |
| // we assume that the previous getters for pos/crd/val have added a |
| // slice-to-size view to make sure we just print the size and not the |
| // full capacity. |
| // |
| // Generates code to print (1-dim or higher): |
| // ( a0, a1, ... ) |
| static void printContents(PatternRewriter &rewriter, Location loc, |
| Value vec) { |
| auto shape = cast<ShapedType>(vec.getType()).getShape(); |
| SmallVector<Value> idxs; |
| printContentsLevel(rewriter, loc, vec, 0, shape, idxs); |
| vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine); |
| } |
| |
  // Recursive helper that prints the contents of one memref level.
| static void printContentsLevel(PatternRewriter &rewriter, Location loc, |
| Value vec, unsigned i, ArrayRef<int64_t> shape, |
| SmallVectorImpl<Value> &idxs) { |
| // Open bracket. |
| vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open); |
| // Generate for loop. |
| auto zero = constantIndex(rewriter, loc, 0); |
| auto index = constantIndex(rewriter, loc, i); |
| auto size = memref::DimOp::create(rewriter, loc, vec, index); |
| auto step = constantIndex(rewriter, loc, 1); |
| auto forOp = scf::ForOp::create(rewriter, loc, zero, size, step); |
| idxs.push_back(forOp.getInductionVar()); |
| rewriter.setInsertionPointToStart(forOp.getBody()); |
| if (i < shape.size() - 1) { |
| // Enter deeper loop nest. |
| printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs); |
| } else { |
| // Actual contents printing. |
| auto val = memref::LoadOp::create(rewriter, loc, vec, idxs); |
| if (llvm::isa<ComplexType>(val.getType())) { |
| // Since the vector dialect does not support complex types in any op, |
| // we split those into (real, imag) pairs here. |
| Value real = complex::ReOp::create(rewriter, loc, val); |
| Value imag = complex::ImOp::create(rewriter, loc, val); |
| vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open); |
| vector::PrintOp::create(rewriter, loc, real, |
| vector::PrintPunctuation::Comma); |
| vector::PrintOp::create(rewriter, loc, imag, |
| vector::PrintPunctuation::Close); |
| } else { |
| vector::PrintOp::create(rewriter, loc, val, |
| vector::PrintPunctuation::NoPunctuation); |
| } |
| // Terminating comma (except at end). |
| auto bound = arith::AddIOp::create(rewriter, loc, idxs.back(), step); |
| Value cond = arith::CmpIOp::create(rewriter, loc, |
| arith::CmpIPredicate::ne, bound, size); |
| scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, cond, /*else*/ false); |
| rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Comma); |
| } |
| idxs.pop_back(); |
| rewriter.setInsertionPointAfter(forOp); |
| // Close bracket. |
| vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close); |
| } |
| |
| // Helper method to print run-time lvl/dim sizes. |
| static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor, |
| unsigned size, bool isDim) { |
| // Open bracket. |
| vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open); |
    // Print unrolled contents (the dim op requires a constant value).
| for (unsigned i = 0; i < size; i++) { |
| auto idx = constantIndex(rewriter, loc, i); |
| Value val; |
| if (isDim) |
| val = tensor::DimOp::create(rewriter, loc, tensor, idx); |
| else |
| val = LvlOp::create(rewriter, loc, tensor, idx); |
| vector::PrintOp::create(rewriter, loc, val, |
| i != size - 1 |
| ? vector::PrintPunctuation::Comma |
| : vector::PrintPunctuation::NoPunctuation); |
| } |
| // Close bracket and end of line. |
| vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Close); |
| vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::NewLine); |
| } |
| }; |
| |
/// Sparse rewriting rule for the sparse-to-sparse reshape operator
/// (tensor.reshape).
| struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> { |
| public: |
| using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tensor::ReshapeOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Value srcTensor = op.getSource(); |
| const auto srcTp = tryGetSparseTensorType(srcTensor); |
| const auto dstTp = tryGetSparseTensorType(op.getResult()); |
| if (!srcTp || !dstTp) |
| return failure(); |
| |
| if (!srcTp->hasEncoding() || !dstTp->hasEncoding() || |
| !dstTp->hasStaticDimShape()) |
| return failure(); |
| |
| SmallVector<Value> srcSizes; |
| sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor); |
| SmallVector<Value> dstSizes; |
| for (Dimension d : dstTp->getDimShape()) |
| dstSizes.push_back(constantIndex(rewriter, loc, d)); |
| |
| Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor); |
| // Only need an unordered COO buffer if input and output are not sorted |
| // in the same way. |
| Type bufferTp = getBufferType( |
| dstTp->withoutDimToLvl(), |
| !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity()); |
| SmallVector<Value> dynSizes; |
| Value buffer = AllocTensorOp::create(rewriter, loc, bufferTp, dynSizes, |
| Value(), nnz, Attribute()) |
| .getResult(); |
| |
    // Convert src coordinates to dst coordinates by first collapsing them to
    // 1D and then expanding them to match the rank of the destination tensor.
| // Implemented as follows: |
| // foreach srcCoords %srcTensor |
| // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank]) |
| // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank]) |
| // insert expandedCoords, %buffer |
| // |
| // followed by an optional |
| // %t = sparse_tensor.cast %tmp |
| // depending on whether the input/output are sorted in the same way. |
| const auto encSrc = srcTp->getEncoding(); |
| ForeachOp foreachOp = ForeachOp::create( |
| rewriter, loc, srcTensor, buffer, |
| [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v, |
| ValueRange reduc) { |
| const Dimension srcRank = srcTp->getDimRank(); |
| SmallVector<Value> srcDcvs; |
| srcDcvs.reserve(srcRank); |
| for (Dimension d = 0; d < srcRank; d++) { |
| Level lvl = toLvl(encSrc, d); |
| srcDcvs.push_back(srcLcvs[lvl]); |
| } |
| |
| Value collapseSize = constantIndex(builder, loc, 1); |
| for (Dimension d = 0; d < srcRank; d++) |
| collapseSize = |
| arith::MulIOp::create(builder, loc, collapseSize, srcSizes[d]); |
| SmallVector<Value, 1> collapsedSizes = {collapseSize}; |
| |
| ReassociationIndices collapseIdx; |
| for (Dimension i = 0; i < srcRank; i++) |
| collapseIdx.push_back(i); |
| SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx}; |
| SmallVector<Value, 1> collapsedDcvs; |
| reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs, |
| collapsedSizes, collapsedDcvs); |
| |
| ReassociationIndices expandIdx; |
| for (Dimension i = 0; i < dstTp->getDimRank(); i++) |
| expandIdx.push_back(i); |
| SmallVector<ReassociationIndices, 1> expandReass = {expandIdx}; |
| SmallVector<Value> dstDcvs; |
| reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs, |
| dstSizes, dstDcvs); |
| |
| auto t = |
| tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs); |
| sparse_tensor::YieldOp::create(builder, loc, t); |
| }); |
| |
| Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0), true); |
| if (bufferTp != *dstTp) { |
| auto dstRTT = dstTp->getRankedTensorType(); |
| Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult(); |
| DeallocTensorOp::create(rewriter, loc, t); |
| t = converted; |
| } |
| rewriter.replaceOp(op, t); |
| return success(); |
| } |
| }; |
| |
/// Sparse rewriting rule for the sparse-to-sparse expand/collapse shape
/// operators.
| template <typename ReshapeOp> |
| struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> { |
| public: |
| using OpRewritePattern<ReshapeOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ReshapeOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Value srcTensor = op.getSrc(); |
| const auto srcTp = getSparseTensorType(srcTensor); |
| const auto dstTp = getSparseTensorType(op.getResult()); |
| if (!srcTp.hasEncoding() || !dstTp.hasEncoding()) |
| return failure(); |
| |
| // Generate code to represent the static dimension constants or compute |
| // the dynamic dimension values. |
| SmallVector<Value> srcSizes; |
| sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor); |
| SmallVector<Value> dstSizes; |
| SmallVector<Value> dstDynSizes; |
| if (dstTp.hasStaticDimShape()) { |
| for (Dimension d : dstTp.getDimShape()) |
| dstSizes.push_back(constantIndex(rewriter, loc, d)); |
| } else { |
| ArrayRef<Size> dstShape = dstTp.getDimShape(); |
| genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape, |
| op.getReassociationIndices()); |
| for (auto [idx, shape] : llvm::enumerate(dstShape)) { |
| if (shape == ShapedType::kDynamic) |
| dstDynSizes.push_back(dstSizes[idx]); |
| } |
| } |
| Value nnz = NumberOfEntriesOp::create(rewriter, loc, srcTensor); |
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
| Type bufferTp = getBufferType( |
| dstTp.withoutDimToLvl(), |
| !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity()); |
| |
| Value buffer = |
| AllocTensorOp::create(rewriter, loc, bufferTp, dstDynSizes, Value(), |
| /*sizeHint=*/nnz, Attribute()) |
| .getResult(); |
| |
| // Implement the sparse2sparse reshape as follows: |
| // foreach srcCoords %srcTensor |
| // insert reshapeCvs(srcCoords), %buffer |
| // |
| // followed by an optional |
| // %t = sparse_tensor.cast %tmp |
| // depending on whether the input/output are sorted in the same way. |
| const auto encSrc = srcTp.getEncoding(); |
| ForeachOp foreachOp = ForeachOp::create( |
| rewriter, loc, srcTensor, buffer, |
| [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v, |
| ValueRange reduc) { |
| const Dimension dimRank = srcTp.getDimRank(); |
| SmallVector<Value> srcDcvs; |
| srcDcvs.reserve(dimRank); |
| for (Dimension d = 0; d < dimRank; d++) { |
| Level lvl = toLvl(encSrc, d); |
| srcDcvs.push_back(srcLcvs[lvl]); |
| } |
| SmallVector<Value> dstDcvs; |
| reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes, |
| srcDcvs, dstSizes, dstDcvs); |
| auto t = |
| tensor::InsertOp::create(builder, loc, v, reduc.front(), dstDcvs); |
| sparse_tensor::YieldOp::create(builder, loc, t); |
| }); |
| |
| Value t = LoadOp::create(rewriter, loc, foreachOp.getResult(0), true); |
| if (bufferTp != dstTp) { |
| auto dstRTT = dstTp.getRankedTensorType(); |
| Value converted = ConvertOp::create(rewriter, loc, dstRTT, t).getResult(); |
| DeallocTensorOp::create(rewriter, loc, t); |
| t = converted; |
| } |
| rewriter.replaceOp(op, t); |
| return success(); |
| } |
| }; |
| |
| /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape |
| /// operator. |
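///
/// For example (sketch), a dense-to-sparse expansion
///
///   %1 = tensor.expand_shape %0 ...
///        : tensor<?xf64> into tensor<?x10xf64, #CSR>
///
/// is unfused into a dense expansion followed by a sparse conversion:
///
///   %d = tensor.expand_shape %0 ... : tensor<?xf64> into tensor<?x10xf64>
///   %1 = sparse_tensor.convert %d
///        : tensor<?x10xf64> to tensor<?x10xf64, #CSR>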
| template <typename ReshapeOp> |
| struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> { |
| public: |
| using OpRewritePattern<ReshapeOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ReshapeOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op->getLoc(); |
| auto encDst = getSparseTensorEncoding(op.getResult().getType()); |
| auto encSrc = getSparseTensorEncoding(op.getSrc().getType()); |
| // Since a pure dense expansion is very cheap (change of view), for |
| // a sparse2dense or dense2sparse, we can simply unfuse a sparse |
| // conversion from the reshape operation itself. |
| // All other cases are handled elsewhere. |
| if (encDst && encSrc) { |
| return failure(); |
| } |
| if (encSrc) { |
| auto rtp = getRankedTensorType(op.getSrc()); |
| auto denseTp = |
| RankedTensorType::get(rtp.getShape(), rtp.getElementType()); |
| auto convert = ConvertOp::create(rewriter, loc, denseTp, op.getSrc()); |
| rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); }); |
| return success(); |
| } |
| if (encDst) { |
| auto rtp = getRankedTensorType(op.getResult()); |
| auto denseTp = |
| RankedTensorType::get(rtp.getShape(), rtp.getElementType()); |
| ReshapeOp reshape; |
| if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) { |
| reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(), |
| op.getReassociation(), op.getOutputShape(), |
| op.getStaticOutputShape()); |
| } else { |
| reshape = ReshapeOp::create(rewriter, loc, denseTp, op.getSrc(), |
| op.getReassociation()); |
| } |
| Value convert = ConvertOp::create(rewriter, loc, rtp, reshape); |
| rewriter.replaceOp(op, convert); |
| return success(); |
| } |
| return failure(); |
| } |
| }; |
| |
| // A trivial wrapper to help generate different operations for dense/sparse |
| // tensors. |
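//
// Typical usage (sketch):
//
//   TensorLike dst(builder, loc, rtt, sizes);   // alloc; zero-fill if dense
//   dst.insert(builder, loc, v, crds);          // chained tensor.insert
//   Value t = dst.finalize(builder, loc, rtt);  // sparse_tensor.load if sparse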
| struct TensorLike { |
| TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt, |
| ValueRange sizes) { |
| SmallVector<Value> dynSzs; |
| getDynamicSizes(rtt, sizes, dynSzs); |
| |
| val = AllocTensorOp::create(builder, loc, rtt, dynSzs); |
| if (!isSparse()) { |
| Value c0 = constantZero(builder, loc, rtt.getElementType()); |
| val = linalg::FillOp::create(builder, loc, c0, val).getResult(0); |
| } |
| } |
| |
| void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) { |
| val = tensor::InsertOp::create(builder, loc, v, val, crds); |
| } |
| |
| Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const { |
| if (isSparse()) |
| return LoadOp::create(builder, loc, val, true); |
| return val; |
| } |
| |
| bool isSparse() const { |
| return getSparseTensorEncoding(val.getType()) != nullptr; |
| } |
| |
| Value val; |
| }; |
| |
| struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(tensor::DimOp op, |
| PatternRewriter &rewriter) const override { |
| std::optional<int64_t> dim = op.getConstantIndex(); |
| auto stt = tryGetSparseTensorType(op.getSource()); |
| if (!dim || !stt || !stt->hasEncoding()) |
| return failure(); |
| |
| if (stt->isPermutation()) { |
| rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(), |
| toLvl(stt->getEncoding(), *dim)); |
| return success(); |
| } |
| |
| // Non-permutation dim2lvl/lvl2dim maps. |
| // Compute as follows: |
| // affine.apply #map (l0 - 1, l1 - 1, ...) + 1 |
    // Note that this is not the most efficient way (but a more general one)
    // to do the lvl-to-dim translation, e.g., for BSR, the dimension size
    // can be computed simply as lvl_size * block_size.
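    //
    // For example (sketch), for BSR with 2x2 blocks and lvl2dim expression
    //   d0 = l0 * 2 + l2,
    // the code below computes
    //   dim_size(0) = ((lvl_size(0) - 1) * 2 + (lvl_size(2) - 1)) + 1,
    // which indeed equals lvl_size(0) * 2, since lvl_size(2) == 2.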
| Location loc = op.getLoc(); |
| SmallVector<Value> maxLvlCrds; |
| for (Level l = 0; l < stt->getLvlRank(); l++) { |
| Value lvlSz = LvlOp::create(rewriter, loc, op.getSource(), l); |
| Value maxLvlCrd = arith::SubIOp::create( |
| rewriter, loc, lvlSz, |
| constantOne(rewriter, loc, rewriter.getIndexType())); |
| maxLvlCrds.push_back(maxLvlCrd); |
| } |
| |
| AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim); |
| Value maxDimCrd = affine::AffineApplyOp::create( |
| rewriter, op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp), |
| maxLvlCrds); |
| |
| Value dimSz = arith::AddIOp::create( |
| rewriter, loc, maxDimCrd, |
| constantOne(rewriter, loc, rewriter.getIndexType())); |
| rewriter.replaceOp(op, dimSz); |
| return success(); |
| } |
| }; |
| |
| struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(ConcatenateOp op, |
| PatternRewriter &rewriter) const override { |
    if (op.needsExtraSort())
      return op.emitError("ConcatenateOp not staged");
| |
| const Location loc = op.getLoc(); |
| const auto dstTp = getSparseTensorType(op); |
| const Dimension conDim = op.getDimension(); |
| SmallVector<Value> sizes; |
| concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim); |
| |
| // %t = concatenate %s1, %s2, %s3 {dim = 1} |
| // ==> |
| // if (isSparseDst) |
| // if (allDense) |
| // %tmp = bufferization.alloc_tensor dstTp |
| // else |
| // %tmp = bufferization.alloc_tensor : unordered COO |
| // else |
| // %tmp = memref.alloc : dense tensor |
| // foreach in %s1 : insert d0, d1, %tmp |
| // foreach in %s2 : insert d0, d1 + size(s1), %tmp |
| // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp |
| |
| TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes); |
| Value offset = constantIndex(rewriter, loc, 0); |
| Value iterArg = dstBuf.val; |
| |
| ForeachOp foreachOp; |
| for (Value input : op.getInputs()) { |
| // Builds a for op for each input tensor to append new values into the |
| // output tensor. |
| foreachOp = ForeachOp::create( |
| rewriter, loc, input, iterArg, |
| [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v, |
| ValueRange reduc) { |
| SmallVector<Value> offDimCrd(dcvs); |
| offDimCrd[conDim] = |
| arith::AddIOp::create(builder, loc, offDimCrd[conDim], offset); |
| |
| // Enters foreach, updates the SSA chain. |
| dstBuf.val = reduc.front(); |
| if (!dstTp.isAllDense()) { |
| Value cond = genIsNonzero(builder, loc, v); |
| auto ifOp = |
| scf::IfOp::create(builder, loc, reduc.getTypes(), cond, |
| /*else*/ true); |
| builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| scf::YieldOp::create(builder, loc, dstBuf.val); |
| |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| dstBuf.insert(builder, loc, v, offDimCrd); |
| scf::YieldOp::create(builder, loc, dstBuf.val); |
| |
              // Exits the ifOp, updates the sparse tensor SSA value.
| builder.setInsertionPointAfter(ifOp); |
| dstBuf.val = ifOp.getResult(0); |
| } else { |
| dstBuf.insert(builder, loc, v, offDimCrd); |
| } |
| sparse_tensor::YieldOp::create(builder, loc, dstBuf.val); |
| }); |
      // Accumulates the offset. Note that only static-shaped inputs are
      // allowed by the concatenate op verifier, which saves us from
      // computing the offset dynamically.
| const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim); |
| assert(ShapedType::isStatic(sz)); |
| offset = arith::AddIOp::create(rewriter, loc, offset, |
| constantIndex(rewriter, loc, sz)); |
| iterArg = foreachOp.getResult(0); |
| dstBuf.val = iterArg; |
| } |
| |
| dstBuf.val = iterArg; |
| Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType()); |
| rewriter.replaceOp(op, ret); |
| return success(); |
| } |
| }; |
| |
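/// Sparse rewriting rule for the convert operator. Lowers the conversion to
/// a foreach traversal over the source that inserts each (nonzero) element
/// into a freshly allocated destination tensor.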
| struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(ConvertOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.needsExtraSort()) |
| return op.emitError("ConvertOp not staged."); |
| |
| // TODO: Maybe we want a different operation for this too. |
| auto encDst = getSparseTensorEncoding(op.getType()); |
| auto encSrc = getSparseTensorEncoding(op.getSource().getType()); |
| if (encDst && encSrc && !encSrc.isSlice() && |
| encSrc.withoutBitWidths() == encDst.withoutBitWidths()) { |
      // Trivial tensor conversion and simple element type conversion are
      // handled in codegen.
| return failure(); |
| } |
| |
| Location loc = op.getLoc(); |
| Value src = op.getSource(); |
| |
| SparseTensorType srcStt = getSparseTensorType(op.getSource()); |
| SparseTensorType dstStt = getSparseTensorType(op.getDest()); |
| |
| bool fromSparseConst = false; |
| if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) |
| if (isa<SparseElementsAttr>(constOp.getValue())) |
| fromSparseConst = true; |
| |
| const AffineMapAttr foreachOrder = |
| (!dstStt.isIdentity() && fromSparseConst) |
| ? AffineMapAttr::get(dstStt.getExpandedDimToLvl()) |
| : nullptr; |
| |
| bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst; |
| |
| SmallVector<Value> sizes; |
| sizesFromSrc(rewriter, sizes, loc, src); |
| TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes); |
| |
| auto foreachOp = ForeachOp::create( |
| rewriter, loc, src, dstBuf.val, foreachOrder, |
| [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v, |
| ValueRange reduc) { |
          // Enters the loop, updates the SSA value for the insertion chain.
| dstBuf.val = reduc.front(); |
| if (!skipZeroCheck) { |
| Value cond = genIsNonzero(builder, loc, v); |
| auto ifOp = scf::IfOp::create(builder, loc, reduc.getTypes(), cond, |
| /*else*/ true); |
| builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
| scf::YieldOp::create(builder, loc, dstBuf.val); |
| |
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); |
| dstBuf.insert(builder, loc, v, dcvs); |
| scf::YieldOp::create(builder, loc, dstBuf.val); |
| |
            // Exits the ifOp, updates the sparse tensor SSA value.
| builder.setInsertionPointAfter(ifOp); |
| dstBuf.val = ifOp.getResult(0); |
| } else { |
| dstBuf.insert(builder, loc, v, dcvs); |
| } |
| sparse_tensor::YieldOp::create(builder, loc, dstBuf.val); |
| }); |
| |
| rewriter.setInsertionPointAfter(foreachOp); |
| |
| // Exits the for loop, links the SSA chain. |
| dstBuf.val = foreachOp.getResult(0); |
| |
| Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType()); |
| rewriter.replaceOp(op, ret); |
| return success(); |
| } |
| }; |
| |
| struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(CrdTranslateOp op, |
| PatternRewriter &rewriter) const override { |
| AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl |
| ? op.getEncoder().getDimToLvl() |
| : op.getEncoder().getLvlToDim(); |
| |
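    // For example (sketch), a dim2lvl translation for BSR with 2x2 blocks,
    //   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2),
    // unfolds into one affine.apply per result expression below.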
| SmallVector<Value> outCrds; |
| for (AffineExpr result : map.getResults()) { |
      // TODO: we should probably expand the affine map to IR using our own
      // rules, since affine.apply assumes signed values, while the
      // coordinates we provide must always be signless.
| Value trans = affine::AffineApplyOp::create( |
| rewriter, op.getLoc(), AffineMap::get(map.getNumDims(), 0, result), |
| op.getInCrds()); |
| outCrds.push_back(trans); |
| } |
| rewriter.replaceOp(op, outCrds); |
| return success(); |
| } |
| }; |
| |
| /// Sparse rewriting rule for the foreach operator. |
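/// For example (sketch), a foreach over a CSR matrix lowers to a two-deep
/// loop nest generated by the loop emitter, with the foreach body inlined
/// at the innermost insertion point and the reduction values threaded
/// through the loops as iteration arguments.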
| struct ForeachRewriter : public OpRewritePattern<ForeachOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ForeachOp op, |
| PatternRewriter &rewriter) const override { |
| |
| auto loc = op.getLoc(); |
| Value input = op.getTensor(); |
| SmallVector<Value> reduc = op.getInitArgs(); |
| const auto stt = getSparseTensorType(input); |
| const Level lvlRank = stt.getLvlRank(); |
| |
    // Special case: foreach over a sparse constant uses its own rewriting
    // rule.
| if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) { |
| if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) { |
| return genForeachOnSparseConstant(op, rewriter, attr); |
| } |
| } |
| |
| // Otherwise, use loop emitter to generate loops. |
| const auto enc = stt.getEncoding(); |
| |
| // 1. Generates loop for the sparse input. |
| LoopEmitter loopEmitter( |
| ValueRange{input}, |
| StringAttr::get(getContext(), ForeachOp::getOperationName())); |
| loopEmitter.initializeLoopEmit(rewriter, loc); |
| for (Level l = 0; l < lvlRank; l++) { |
| // TODO: provide utility function for loop sequences that only contains |
| // one for loop? |
| const SmallVector<TensorLevel, 1> tidLvls{ |
| loopEmitter.makeTensorLevel(0, l)}; |
| loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls); |
      // Note that reduc will be taken care of by the loop emitter and gets
      // updated in place.
| loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1, |
| reduc); |
| } |
| |
| SmallVector<Value> lcvs = loopEmitter.getLoopIVs(); |
| if (op.getOrder()) { |
| // TODO: Support it so that we can do direct conversion from CSR->BSR. |
| llvm_unreachable( |
| "Level order not yet implemented on non-constant input tensors."); |
| } |
| |
| Value vals = loopEmitter.getValBuffer()[0]; |
| SmallVector<Value> pos = loopEmitter.getValPosits(0); |
    // Loads the value from the sparse tensor using the position index;
    // loads the value from the dense tensor using the coordinates.
| Value val = enc ? memref::LoadOp::create(rewriter, loc, vals, pos) |
| : memref::LoadOp::create(rewriter, loc, vals, lcvs); |
| |
| // 2. Inline the block in the foreach operator. |
| Block *srcBlock = op.getBody(); |
| |
| // Remap coordinates. |
| SmallVector<Value> args = |
| enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim); |
| |
| // Remap value. |
| args.push_back(val); |
| // Remap reduction variables. |
| args.append(reduc); |
| |
| // Remove sparse_tensor.yield. |
| SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands(); |
| rewriter.eraseOp(srcBlock->getTerminator()); |
| |
| Operation &last = rewriter.getBlock()->back(); |
| if (llvm::isa<scf::YieldOp>(last)) { |
      // Because `scf.for` inserts an implicit yield op when there is no
      // reduction variable upon creation, we reset the insertion point such
      // that the block is inlined *before* the yield op.
| rewriter.setInsertionPoint(&last); |
| } |
| |
| rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(), |
| rewriter.getInsertionPoint(), args); |
| rewriter.setInsertionPointToEnd(rewriter.getBlock()); |
| for (Level l = 0; l < lvlRank; l++) { |
      // Link the reduction chain. Note that the loop emitter updates
      // reducValue in place.
| loopEmitter.exitCurrentLoop(rewriter, loc, reducValue); |
| loopEmitter.exitCurrentLoopSeq(rewriter, loc); |
| } |
| |
    // Replace the foreach operator with the value returned by the outermost
    // for loop.
| rewriter.replaceOp(op, reducValue); |
| return success(); |
| } |
| }; |
| |
| /// Sparse rewriting rule for the new operator. |
| struct NewRewriter : public OpRewritePattern<NewOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(NewOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| auto stt = getSparseTensorType(op.getResult()); |
| if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0) |
| return failure(); |
| |
| // Implement the NewOp as follows: |
| // %orderedCoo = sparse_tensor.new %filename |
| // %t = sparse_tensor.convert %orderedCoo |
| // with enveloping reinterpreted_map ops for non-permutations. |
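    // For a non-permutation dimToLvl (e.g., block sparsity), the expansion
    // is roughly (sketch):
    //   %coo = sparse_tensor.new %filename
    //   %demap = sparse_tensor.reinterpret_map %coo    // drop dimToLvl
    //   %conv = sparse_tensor.convert %demap
    //   %t = sparse_tensor.reinterpret_map %conv       // restore dimToLvl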
| RankedTensorType dstTp = stt.getRankedTensorType(); |
| RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true); |
| Value cooTensor = NewOp::create(rewriter, loc, cooTp, op.getSource()); |
| Value convert = cooTensor; |
| auto enc = stt.getEncoding(); |
| if (!stt.isPermutation()) { // demap coo, demap dstTp |
| auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl(); |
| convert = ReinterpretMapOp::create(rewriter, loc, coo, convert); |
| dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl()); |
| } |
| convert = ConvertOp::create(rewriter, loc, dstTp, convert); |
| if (!stt.isPermutation()) // remap to original enc |
| convert = ReinterpretMapOp::create(rewriter, loc, enc, convert); |
| rewriter.replaceOp(op, convert); |
| |
| // Release the temporary ordered COO tensor. |
| rewriter.setInsertionPointAfterValue(convert); |
| DeallocTensorOp::create(rewriter, loc, cooTensor); |
| |
| return success(); |
| } |
| }; |
| |
| /// Sparse rewriting rule for the out operator. |
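/// Lowers the output to a sequence of calls into the sparse runtime library,
/// roughly (sketch; %w is the opaque writer handle):
///   %w = call @createSparseTensorWriter(%dest)
///   call @outSparseTensorWriterMetaData(%w, %rank, %nnz, %dimSizes)
///   sparse_tensor.foreach in %src ... {  // once per stored entry
///     call @outSparseTensorWriterNext{suffix}(%w, %rank, %dimCoords, %value)
///   }
///   call @delSparseTensorWriter(%w)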
| struct OutRewriter : public OpRewritePattern<OutOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(OutOp op, |
| PatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| // Calculate NNZ. |
| Value src = op.getTensor(); |
| Value nnz = NumberOfEntriesOp::create(rewriter, loc, src); |
| |
| // Allocate a temporary buffer for storing dimension-sizes/coordinates. |
| const auto srcTp = getSparseTensorType(src); |
| const Dimension dimRank = srcTp.getDimRank(); |
| Type indexTp = rewriter.getIndexType(); |
| Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp); |
| |
    // Generate code to calculate the dimension size values and store them
    // into the buffer.
| SmallVector<Value> dims; |
| sizesForTensor(rewriter, dims, loc, srcTp, src); |
| for (Dimension d = 0; d < dimRank; d++) { |
| memref::StoreOp::create(rewriter, loc, dims[d], dimSizes, |
| constantIndex(rewriter, loc, d)); |
| } |
| |
    // Create a sparse tensor writer and output metadata.
| Type opaqueTp = getOpaquePointerType(rewriter); |
| Value writer = |
| createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp}, |
| {op.getDest()}, EmitCInterface::Off) |
| .getResult(0); |
| Value rankValue = constantIndex(rewriter, loc, dimRank); |
| createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {}, |
| {writer, rankValue, nnz, dimSizes}, EmitCInterface::On); |
| |
| Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords. |
| Type eltTp = srcTp.getElementType(); |
| SmallString<29> outNextFuncName{"outSparseTensorWriterNext", |
| primaryTypeFunctionSuffix(eltTp)}; |
| Value value = genAllocaScalar(rewriter, loc, eltTp); |
| ModuleOp module = op->getParentOfType<ModuleOp>(); |
| |
| // For each element in the source tensor, output the element. |
| ForeachOp::create( |
| rewriter, loc, src, ValueRange(), |
| [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v, |
| ValueRange reduc) { |
| for (Dimension d = 0; d < dimRank; d++) { |
            memref::StoreOp::create(builder, loc, dcvs[d], dimCoords,
                                    constantIndex(builder, loc, d));
          }
          memref::StoreOp::create(builder, loc, v, value);
| SmallVector<Value> operands{writer, rankValue, dimCoords, value}; |
| FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands, |
| EmitCInterface::On); |
| func::CallOp::create(builder, loc, TypeRange(), fn, operands); |
| sparse_tensor::YieldOp::create(builder, loc); |
| }); |
| |
| // Release the writer. |
| createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer}, |
| EmitCInterface::Off); |
| |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| //===---------------------------------------------------------------------===// |
| // Methods that add patterns described in this file to a pattern list. |
| //===---------------------------------------------------------------------===// |
| |
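// Note that these functions only collect patterns; a pass is expected to
// drive them through a rewrite driver. A minimal sketch (assuming a pass
// operation `op`; the driver call is illustrative, not part of this file):
//   RewritePatternSet patterns(op->getContext());
//   populateLowerSparseOpsToForeachPatterns(patterns, /*enableRT=*/false,
//                                           /*enableConvert=*/true);
//   (void)applyPatternsGreedily(op, std::move(patterns));
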
| void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) { |
| patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer, |
| FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast, |
| GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>( |
| patterns.getContext()); |
| } |
| |
| void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns, |
| bool enableRT, |
| bool enableConvert) { |
| patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>, |
| ReshapeRewriter<tensor::CollapseShapeOp>, |
| Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>, |
| Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>, |
| SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>( |
| patterns.getContext()); |
| |
| if (enableConvert) |
| patterns.add<DirectConvertRewriter>(patterns.getContext()); |
| if (!enableRT) |
| patterns.add<NewRewriter>(patterns.getContext()); |
| } |
| |
| void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) { |
  // Run CrdTranslateRewriter later in the pipeline so that the operation can
  // be folded before it is lowered to affine.apply.
| patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext()); |
| } |