| //===- DecomposeLinalgOps.cpp - Pattern to break up Linalg ops ------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include <optional> |
| |
| using namespace mlir; |
| using namespace mlir::linalg; |
| |
| namespace { |
| |
| /// Pattern to decompose a GenericOp that has more than two statements |
| /// into one GenericOp with the first statement (i.e. peeled operation), and |
| /// a second GenericOp with the remaining statements (i.e. residual operations). |
| |
| /// - The result of the first GenericOp has the same shape as the iteration |
| /// space of the GenericOp. The body of the op yields as many values as the |
| /// original op plus all the results of the peeled operation. |
| /// - The second GenericOp has as many operands as the original operation plus |
| /// all the results of the first Generic Op. It has the same number of yields as |
| /// the original op. |
| /// - If the result of the peeled operation was yielded by the original |
| /// GenericOp the uses of the corresponding results will be replaced with the |
| /// result of the first GenericOp created. |
| /// |
| /// Example |
| /// |
| /// ```mlir |
| /// %result:2 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...) |
| /// outs(%init0, %init1 : ...) { |
| /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ...): |
| /// %0 = <s0> %b0, %b1 : ... |
| /// %1 = <s1> %0, %b2 : ... |
| /// linalg.yield %0, %1 : ... |
| /// } -> (..., ...) |
| /// return %result#0, %result#1 |
| /// ``` |
| /// |
| /// gets split into |
| /// |
| /// ```mlir |
| /// %init = tensor.empty ... |
| /// %op0:3 = linalg.generic ... ins(%arg0, %arg1, %arg2 : ...) |
| /// outs(%init0, %init1, %init : ...) |
| /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...): |
| /// %0 = <s0> %b0, %b1 : ... |
| /// linalg.yield %0, %..., %0 : ... |
| /// } -> (..., ..., ...) |
| /// %op1:2 = linalg.generic ... ins(%arg0, %arg1, %arg2, %op0#2 : ...) |
| /// outs(%init0, %init1 : ...) { |
| /// ^bb0(%b0: ... , %b1: ... , %b2: ... , %b3: ..., %b4: ..., %b5: ...): |
| /// %1 = <s1> %b3, %b2 : ... |
| /// linalg.yield %..., %1 : ... |
| /// } -> (..., ...) |
| /// return %op0#0, %op1#1 |
| /// ``` |
| /// |
| /// After canonicalization this is expected to be |
| /// |
| /// ```mlir |
| /// %init = tensor.empty ... |
| /// %op0 = linalg.generic ... ins(%arg0, %arg1, : ...) |
| /// outs(%init : ...) |
| /// ^bb0(%b0: ... , %b1: ... , %b2: ...): |
| /// %0 = <s0> %b0, %b1 : ... |
| /// linalg.yield %0 : ... |
| /// } -> ... |
| /// %op1 = linalg.generic ... ins(%arg2, %op0#2 : ...) |
| /// outs(%init1 : ...) { |
| /// ^bb0(%b0: ... , %b1: ... , %b2: ...): |
| /// %1 = <s1> %b1, %b0 : ... |
| /// linalg.yield %..., %1 : ... |
| /// } -> ... |
| /// return %op0, %op1 |
| /// ``` |
| struct DecomposeLinalgOp : public OpRewritePattern<GenericOp> { |
| using OpRewritePattern<GenericOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(GenericOp genericOp, |
| PatternRewriter &rewriter) const override; |
| |
| private: |
| /// Helper method to create a generic op for the peeled scalar operation. The |
| /// created op has an empty region. |
| GenericOp createPeeledGenericOp(GenericOp genericOp, |
| PatternRewriter &rewriter) const; |
| |
| /// Helper method to create a generic op for the residual scalar operation. |
| /// The created op has the same region as the original op. |
| GenericOp createResidualGenericOp(GenericOp genericOp, |
| GenericOp peeledGenericOp, |
| PatternRewriter &rewriter) const; |
| }; |
| } // namespace |
| |
| /// Helper method to compute the range of a generic op. |
| static SmallVector<OpFoldResult> getGenericOpLoopRange(OpBuilder &b, |
| GenericOp op) { |
| OpBuilder::InsertionGuard g(b); |
| b.setInsertionPoint(op); |
| Location loc = op.getLoc(); |
| auto allShapesSizes = |
| cast<LinalgOp>(op.getOperation()).createFlatListOfOperandDims(b, loc); |
| AffineMap map = op.getShapesToLoopsMap(); |
| IRRewriter rewriter(b); |
| return affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, |
| allShapesSizes); |
| } |
| |
| /// Helper method to permute the list of `values` based on the `map`. |
| SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values, |
| AffineMap map) { |
| assert(map.isPermutation()); |
| SmallVector<OpFoldResult> permutedValues(values.size()); |
| for (const auto &position : |
| llvm::enumerate(llvm::map_range(map.getResults(), [](AffineExpr expr) { |
| return cast<AffineDimExpr>(expr).getPosition(); |
| }))) |
| permutedValues[position.value()] = values[position.index()]; |
| return permutedValues; |
| } |
| |
| /// Get zero value for an element type. |
| static Value getZero(OpBuilder &b, Location loc, Type elementType) { |
| assert(elementType.isIntOrIndexOrFloat() && |
| "expected scalar type while computing zero value"); |
| if (isa<IntegerType>(elementType)) |
| return b.create<arith::ConstantIntOp>(loc, 0, elementType); |
| if (elementType.isIndex()) |
| return b.create<arith::ConstantIndexOp>(loc, 0); |
| // Assume float. |
| auto floatType = cast<FloatType>(elementType); |
| return b.create<arith::ConstantFloatOp>( |
| loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); |
| } |
| |
| GenericOp |
| DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, |
| PatternRewriter &rewriter) const { |
| Block *body = genericOp.getBody(); |
| Operation *peeledScalarOperation = &(*body->begin()); |
| SmallVector<AffineMap> peeledGenericOpIndexingMaps = |
| genericOp.getIndexingMapsArray(); |
| |
| /// Compute the loop ranges for operation. This is the shape of the result of |
| /// the generic op for the peeled operation. |
| Location loc = genericOp.getLoc(); |
| SmallVector<OpFoldResult> domain = getGenericOpLoopRange(rewriter, genericOp); |
| SmallVector<Value> newInitValues; |
| SmallVector<Type> newResultTypes; |
| |
| // Add as many new results as the number of results of the peeled scalar op. |
| for (auto scalarOpResult : peeledScalarOperation->getResults()) { |
| // If the result is yielded by the original op, use the operand, indexing |
| // map and result type that correspond to the yielded value. |
| |
| std::optional<unsigned> resultNumber; |
| for (auto *user : scalarOpResult.getUsers()) { |
| if (auto yieldOp = dyn_cast<YieldOp>(user)) { |
| // Find the first use of the `scalarOpResult` in the yield op. |
| for (OpOperand &yieldOperand : yieldOp->getOpOperands()) { |
| if (yieldOperand.get() == scalarOpResult) { |
| resultNumber = yieldOperand.getOperandNumber(); |
| break; |
| } |
| } |
| assert(resultNumber && "unable to find use of a value in its user"); |
| break; |
| } |
| } |
| if (resultNumber) { |
| newInitValues.push_back( |
| genericOp.getDpsInitOperand(*resultNumber)->get()); |
| OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber)); |
| newResultTypes.push_back(result.getType()); |
| peeledGenericOpIndexingMaps.push_back( |
| genericOp.getIndexingMapMatchingResult(result)); |
| continue; |
| } |
| |
| // Fall back path, use an `init_tensor` and identity indexing map. |
| AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size()); |
| Value emptyTensor = |
| rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType()); |
| newInitValues.push_back(emptyTensor); |
| newResultTypes.push_back(emptyTensor.getType()); |
| peeledGenericOpIndexingMaps.push_back(indexingMap); |
| } |
| |
| /// Create the peeled generic op with an empty body. |
| SmallVector<Value> outsOperands = genericOp.getOutputs(); |
| outsOperands.append(newInitValues.begin(), newInitValues.end()); |
| SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes()); |
| resultTypes.append(newResultTypes.begin(), newResultTypes.end()); |
| auto indexingMapAttr = |
| rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps); |
| return rewriter.create<GenericOp>( |
| loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr, |
| genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr, |
| [](OpBuilder, Location, ValueRange) {}); |
| } |
| |
| GenericOp |
| DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, |
| GenericOp peeledGenericOp, |
| PatternRewriter &rewriter) const { |
| /// Append all results from the peeledGenericOps as `ins` operand for the |
| /// residual generic op. |
| SmallVector<Value> residualGenericOpOperands = genericOp.getInputs(); |
| unsigned origNumResults = genericOp.getNumResults(); |
| unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults(); |
| SmallVector<Value> extraIns; |
| for (auto resultNum : |
| llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) |
| extraIns.push_back(peeledGenericOp->getResult(resultNum)); |
| residualGenericOpOperands.append(extraIns); |
| |
| /// Add indexing maps for the newly added operands. Use the same map |
| /// as those used for the new results of the peeledGenericOp. |
| auto indexingMaps = llvm::to_vector( |
| llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) { |
| return genericOp.getMatchingIndexingMap(operand); |
| })); |
| for (auto resultNum : |
| llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) { |
| OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum)); |
| indexingMaps.push_back( |
| peeledGenericOp.getIndexingMapMatchingResult(result)); |
| } |
| for (OpOperand &outOperand : genericOp.getDpsInitsMutable()) |
| indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand)); |
| |
| auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps); |
| return rewriter.create<GenericOp>( |
| genericOp->getLoc(), genericOp->getResultTypes(), |
| residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr, |
| genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr, |
| [](OpBuilder, Location, ValueRange) {}); |
| } |
| |
| LogicalResult |
| DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, |
| PatternRewriter &rewriter) const { |
| /// For now only match on operations where the iterator types are all parallel |
| if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) { |
| return rewriter.notifyMatchFailure(genericOp, |
| "unhandled decomposition of operation " |
| "with non-parallel iterator types"); |
| } |
| // TODO: this could be generalized to handle `linalg.generic` with buffer |
| // operands too but requires allocation for intermediates. Punt on this for |
| // now. |
| if (!genericOp.hasPureTensorSemantics()) { |
| return rewriter.notifyMatchFailure( |
| genericOp, "only operations with tensor semantics are handled"); |
| } |
| |
| if (llvm::any_of(genericOp.getDpsInitsMutable(), [&](OpOperand &outOperand) { |
| return !genericOp.getMatchingIndexingMap(&outOperand).isPermutation(); |
| })) { |
| return rewriter.notifyMatchFailure( |
| genericOp, "unhandled decomposition of generic op with out operand not " |
| "accessed using a permutation"); |
| } |
| |
| /// If the op has only a single statement (apart from the yield), do nothing. |
| Block *body = genericOp.getBody(); |
| if (body->getOperations().size() <= 2) { |
| return rewriter.notifyMatchFailure(genericOp, |
| "operation has less than 3 statements"); |
| } |
| |
| /// Check that the peeled statement has a scalar element type. |
| if (llvm::any_of(body->getOperations().begin()->getResultTypes(), |
| [](Type t) { return !t.isIntOrIndexOrFloat(); })) { |
| return rewriter.notifyMatchFailure( |
| &(*body->getOperations().begin()), |
| "expected return type to be only int, index or float"); |
| } |
| |
| GenericOp peeledGenericOp = createPeeledGenericOp(genericOp, rewriter); |
| GenericOp residualGenericOp = |
| createResidualGenericOp(genericOp, peeledGenericOp, rewriter); |
| |
| /// Move the first statement of the original operation into the body of the |
| /// generic op for the peeled operation. |
| Block *peeledGenericOpBody = peeledGenericOp.getBody(); |
| Block *residualGenericOpBody = residualGenericOp.getBody(); |
| assert(peeledGenericOpBody->empty() && residualGenericOpBody->empty() && |
| "expected split generic ops to have empty region"); |
| peeledGenericOpBody->getOperations().splice( |
| peeledGenericOpBody->begin(), body->getOperations(), body->begin()); |
| residualGenericOpBody->getOperations().splice(residualGenericOpBody->begin(), |
| body->getOperations()); |
| |
| Operation *peeledScalarOperation = &(*peeledGenericOpBody->begin()); |
| auto *yieldOp = residualGenericOpBody->getTerminator(); |
| { |
| // Yield all the result of the peeled scalar operation. |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPointToEnd(peeledGenericOpBody); |
| SmallVector<Value> yieldedVals; |
| for (auto origYield : yieldOp->getOperands()) { |
| if (origYield.getDefiningOp() == peeledScalarOperation) { |
| yieldedVals.push_back(origYield); |
| } else { |
| // Do not materialize any new ops inside of the decomposed LinalgOp, |
| // as that would trigger another application of the rewrite pattern |
| // (infinite loop). |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPoint(peeledGenericOp); |
| yieldedVals.push_back( |
| getZero(rewriter, genericOp.getLoc(), origYield.getType())); |
| } |
| } |
| yieldedVals.append(llvm::to_vector( |
| llvm::map_range(peeledScalarOperation->getResults(), |
| [](OpResult opr) -> Value { return opr; }))); |
| rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals); |
| } |
| |
| /// In the split operations, replace block arguments uses that refer to |
| /// original operation to the block arguments of the newly created operation. |
| unsigned origNumInputs = genericOp.getNumDpsInputs(); |
| for (const auto &inputBlockArg : |
| llvm::enumerate(genericOp.getBody()->getArguments())) { |
| Value residualOpReplacementArg = |
| residualGenericOpBody->getArgument(inputBlockArg.index()); |
| rewriter.replaceUsesWithIf( |
| inputBlockArg.value(), residualOpReplacementArg, [&](OpOperand &use) { |
| return use.getOwner()->getBlock() == residualGenericOpBody; |
| }); |
| |
| Value peeledOpReplacementArg = |
| peeledGenericOpBody->getArgument(inputBlockArg.index()); |
| rewriter.replaceUsesWithIf( |
| inputBlockArg.value(), peeledOpReplacementArg, [&](OpOperand &use) { |
| return use.getOwner()->getBlock() == peeledGenericOpBody; |
| }); |
| } |
| |
| /// Before fixing up the residual operation, track what values are yielded. If |
| /// any of those are from the peeled scalar operation, the uses of the |
| /// corresponding result have to be remapped to result of the generic op for |
| /// the peeled operation. |
| SmallVector<Value> replacements; |
| for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) { |
| OpResult opr = dyn_cast<OpResult>(yieldValue.value()); |
| if (!opr || opr.getOwner() != peeledScalarOperation) |
| replacements.push_back(residualGenericOp.getResult(yieldValue.index())); |
| else |
| replacements.push_back(peeledGenericOp->getResult(yieldValue.index())); |
| } |
| |
| /// Update all uses of the peeled scalar operation results in the residual op |
| /// to the newly added arguments. |
| { |
| SmallVector<Value> scalarReplacements; |
| unsigned peeledScalarOpNumResults = peeledScalarOperation->getNumResults(); |
| scalarReplacements.reserve(peeledScalarOpNumResults); |
| for (auto num : llvm::seq<unsigned>(0, peeledScalarOpNumResults)) |
| scalarReplacements.push_back( |
| residualGenericOpBody->getArgument(num + origNumInputs)); |
| bool allUsesReplaced = false; |
| rewriter.replaceOpUsesWithinBlock(peeledScalarOperation, scalarReplacements, |
| residualGenericOpBody, &allUsesReplaced); |
| assert(!allUsesReplaced && |
| "peeled scalar operation is erased when it wasnt expected to be"); |
| } |
| |
| // Replace the original operation |
| rewriter.replaceOp(genericOp, replacements); |
| return success(); |
| } |
| |
| void mlir::linalg::populateDecomposeLinalgOpsPattern( |
| RewritePatternSet &patterns, bool removeDeadArgsAndResults) { |
| patterns.insert<DecomposeLinalgOp>(patterns.getContext()); |
| // Add the patterns to clean up the dead operands and results. |
| if (removeDeadArgsAndResults) |
| populateEraseUnusedOperandsAndResultsPatterns(patterns); |
| } |