| //===- Bufferize.cpp - Bufferization utilities ----------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "PassDetail.h" |
| |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" |
| #include "mlir/Dialect/Bufferization/Transforms/Passes.h" |
| #include "mlir/IR/Operation.h" |
| |
| using namespace mlir; |
| using namespace mlir::bufferization; |
| |
| //===----------------------------------------------------------------------===// |
| // BufferizeTypeConverter |
| //===----------------------------------------------------------------------===// |
| |
| static Value materializeToTensor(OpBuilder &builder, TensorType type, |
| ValueRange inputs, Location loc) { |
| assert(inputs.size() == 1); |
| assert(inputs[0].getType().isa<BaseMemRefType>()); |
| return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]); |
| } |
| |
| /// Registers conversions into BufferizeTypeConverter |
| BufferizeTypeConverter::BufferizeTypeConverter() { |
| // Keep all types unchanged. |
| addConversion([](Type type) { return type; }); |
| // Convert RankedTensorType to MemRefType. |
| addConversion([](RankedTensorType type) -> Type { |
| return MemRefType::get(type.getShape(), type.getElementType()); |
| }); |
| // Convert UnrankedTensorType to UnrankedMemRefType. |
| addConversion([](UnrankedTensorType type) -> Type { |
| return UnrankedMemRefType::get(type.getElementType(), 0); |
| }); |
| addArgumentMaterialization(materializeToTensor); |
| addSourceMaterialization(materializeToTensor); |
| addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, |
| ValueRange inputs, Location loc) -> Value { |
| assert(inputs.size() == 1); |
| assert(inputs[0].getType().isa<TensorType>()); |
| return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]); |
| }); |
| } |
| |
| void mlir::bufferization::populateBufferizeMaterializationLegality( |
| ConversionTarget &target) { |
| target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>(); |
| } |
| |
| namespace { |
| // In a finalizing bufferize conversion, we know that all tensors have been |
| // converted to memrefs, thus, this op becomes an identity. |
| class BufferizeToTensorOp |
| : public OpConversionPattern<bufferization::ToTensorOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOp(op, adaptor.memref()); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| namespace { |
| // In a finalizing bufferize conversion, we know that all tensors have been |
| // converted to memrefs, thus, this op becomes an identity. |
| class BufferizeToMemrefOp |
| : public OpConversionPattern<bufferization::ToMemrefOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOp(op, adaptor.tensor()); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns( |
| BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { |
| patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter, |
| patterns.getContext()); |
| } |
| |
| namespace { |
| struct FinalizingBufferizePass |
| : public FinalizingBufferizeBase<FinalizingBufferizePass> { |
| using FinalizingBufferizeBase< |
| FinalizingBufferizePass>::FinalizingBufferizeBase; |
| |
| void runOnFunction() override { |
| auto func = getFunction(); |
| auto *context = &getContext(); |
| |
| BufferizeTypeConverter typeConverter; |
| RewritePatternSet patterns(context); |
| ConversionTarget target(*context); |
| |
| populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns); |
| |
| // If all result types are legal, and all block arguments are legal (ensured |
| // by func conversion above), then all types in the program are legal. |
| // |
| // We also check that the operand types are legal to avoid creating invalid |
| // IR. For example, this prevents |
| // populateEliminateBufferizeMaterializationsPatterns from updating the |
| // types of the operands to a return op without updating the enclosing |
| // function. |
| target.markUnknownOpDynamicallyLegal( |
| [&](Operation *op) { return typeConverter.isLegal(op); }); |
| |
| if (failed(applyFullConversion(func, target, std::move(patterns)))) |
| signalPassFailure(); |
| } |
| }; |
| } // namespace |
| |
| std::unique_ptr<FunctionPass> |
| mlir::bufferization::createFinalizingBufferizePass() { |
| return std::make_unique<FinalizingBufferizePass>(); |
| } |