| //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" |
| #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Dialect.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
| |
| using namespace mlir; |
| using namespace linalg; |
| using namespace mlir::bufferization; |
| |
| namespace { |
| |
| /// Generic conversion for any DestinationStyleOpInterface on tensors. |
| static LogicalResult |
| bufferizeDestinationStyleOpInterface(RewriterBase &rewriter, |
| DestinationStyleOpInterface op, |
| const BufferizationOptions &options) { |
| // Take a guard before anything else. |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPoint(op); |
| |
| // Nothing to do. This op is already bufferized. |
| if (op.hasBufferSemantics()) |
| return success(); |
| |
| // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need |
| // basis. |
| if (!op.hasTensorSemantics()) |
| return op->emitError() << "op does not have tensor semantics"; |
| |
| // New input operands for the cloned op. |
| SmallVector<Value> newInputBuffers; |
| newInputBuffers.reserve(op.getNumDpsInputs()); |
| for (OpOperand *opOperand : op.getDpsInputOperands()) { |
| if (op.isScalar(opOperand)) { |
| newInputBuffers.push_back(opOperand->get()); |
| continue; |
| } |
| FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options); |
| if (failed(buffer)) |
| return failure(); |
| newInputBuffers.push_back(*buffer); |
| } |
| |
| // New output operands for the cloned op. |
| SmallVector<Value> newOutputBuffers; |
| for (OpResult opResult : op->getOpResults()) { |
| OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber()); |
| FailureOr<Value> resultBuffer = |
| getBuffer(rewriter, opOperand->get(), options); |
| if (failed(resultBuffer)) |
| return failure(); |
| newOutputBuffers.push_back(*resultBuffer); |
| } |
| |
| // Merge input/output operands. |
| SmallVector<Value> newOperands = newInputBuffers; |
| newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end()); |
| |
| // Set insertion point now that potential alloc/dealloc are introduced. |
| rewriter.setInsertionPoint(op); |
| // Clone the op, but use the new operands. Move the existing block into the |
| // new op. Since the new op does not have any tensor results, it does not |
| // return anything. |
| assert(op->getNumRegions() == 1 && "expected that op has 1 region"); |
| auto newOp = cast<DestinationStyleOpInterface>(op.cloneWithoutRegions( |
| rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); |
| rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0), |
| newOp->getRegion(0).begin()); |
| |
| // Replace the results of the old op with the new output buffers. |
| replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers); |
| |
| return success(); |
| } |
| |
| /// Bufferization of linalg.generic. Replace with a new linalg.generic that |
| /// operates entirely on memrefs. |
| template <typename OpTy> |
| struct LinalgOpInterface |
| : public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>, |
| OpTy> { |
| bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| // Operand is read if it is used in the computation. |
| auto genericOp = cast<linalg::LinalgOp>(op); |
| return genericOp.payloadUsesValueFromOperand(&opOperand); |
| } |
| |
| bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| // Operand is written to if it has an aliasing OpResult. |
| auto bufferizableOp = cast<BufferizableOpInterface>(op); |
| return !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); |
| } |
| |
| SmallVector<OpOperand *> |
| getAliasingOpOperand(Operation *op, OpResult opResult, |
| const AnalysisState &state) const { |
| auto genericOp = cast<DestinationStyleOpInterface>(op); |
| |
| // The i-th OpResult may alias with the i-th "out" tensor. |
| return {genericOp.getDpsInitOperand(opResult.getResultNumber())}; |
| } |
| |
| SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| auto genericOp = cast<DestinationStyleOpInterface>(op); |
| |
| // The i-th "out" tensor may alias with the i-th OpResult. |
| if (genericOp.isDpsInit(&opOperand)) |
| return {genericOp.getTiedOpResult(&opOperand)}; |
| return {}; |
| } |
| |
| BufferRelation bufferRelation(Operation *op, OpResult opResult, |
| const AnalysisState &state) const { |
| return BufferRelation::Equivalent; |
| } |
| |
| LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| const BufferizationOptions &options) const { |
| return bufferizeDestinationStyleOpInterface( |
| rewriter, cast<DestinationStyleOpInterface>(op), options); |
| } |
| }; |
| |
| /// Helper structure that iterates over all LinalgOps in `OpTys` and registers |
| /// the `BufferizableOpInterface` with each of them. |
| template <typename... Ops> |
| struct LinalgOpInterfaceHelper { |
| static void registerOpInterface(MLIRContext *ctx) { |
| (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...); |
| } |
| }; |
| } // namespace |
| |
| void mlir::linalg::registerBufferizableOpInterfaceExternalModels( |
| DialectRegistry ®istry) { |
| registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { |
| // Register all Linalg structured ops. `LinalgOp` is an interface and it is |
| // not possible to attach an external interface to an existing interface. |
| // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one. |
| LinalgOpInterfaceHelper< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" |
| >::registerOpInterface(ctx); |
| }); |
| } |