| //===- Bufferize.cpp - Bufferization for std ops --------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements bufferization of tensor-valued arith.constant ops. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "PassDetail.h" |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/StandardOps/Transforms/Passes.h" |
| #include "mlir/IR/BlockAndValueMapping.h" |
| #include "mlir/Transforms/BufferUtils.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| using namespace mlir; |
| |
| memref::GlobalOp GlobalCreator::getGlobalFor(arith::ConstantOp constantOp) { |
| auto type = constantOp.getType().cast<RankedTensorType>(); |
| |
| bufferization::BufferizeTypeConverter typeConverter; |
| |
| // If we already have a global for this constant value, no need to do |
| // anything else. |
| auto it = globals.find(constantOp.getValue()); |
| if (it != globals.end()) |
| return cast<memref::GlobalOp>(it->second); |
| |
| // Create a builder without an insertion point. We will insert using the |
| // symbol table to guarantee unique names. |
| OpBuilder globalBuilder(moduleOp.getContext()); |
| SymbolTable symbolTable(moduleOp); |
| |
| // Create a pretty name. |
| SmallString<64> buf; |
| llvm::raw_svector_ostream os(buf); |
| interleave(type.getShape(), os, "x"); |
| os << "x" << type.getElementType(); |
| |
| // Add an optional alignment to the global memref. |
| IntegerAttr memrefAlignment = |
| alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) |
| : IntegerAttr(); |
| |
| auto global = globalBuilder.create<memref::GlobalOp>( |
| constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), |
| /*sym_visibility=*/globalBuilder.getStringAttr("private"), |
| /*type=*/typeConverter.convertType(type).cast<MemRefType>(), |
| /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(), |
| /*constant=*/true, |
| /*alignment=*/memrefAlignment); |
| symbolTable.insert(global); |
| // The symbol table inserts at the end of the module, but globals are a bit |
| // nicer if they are at the beginning. |
| global->moveBefore(&moduleOp.front()); |
| globals[constantOp.getValue()] = global; |
| return global; |
| } |
| |
| namespace { |
| class BufferizeTensorConstantOp |
| : public OpConversionPattern<arith::ConstantOp> { |
| public: |
| BufferizeTensorConstantOp(GlobalCreator &globals, |
| TypeConverter &typeConverter, MLIRContext *context) |
| : OpConversionPattern<arith::ConstantOp>(typeConverter, context, |
| /*benefit=*/1), |
| globals(globals) {} |
| |
| LogicalResult |
| matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto type = op.getType().dyn_cast<RankedTensorType>(); |
| if (!type) |
| return failure(); |
| |
| auto globalMemref = globals.getGlobalFor(op); |
| rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, globalMemref.type(), |
| globalMemref.getName()); |
| return success(); |
| } |
| GlobalCreator &globals; |
| }; |
| } // namespace |
| |
| void mlir::populateTensorConstantBufferizePatterns( |
| GlobalCreator &globalCreator, |
| bufferization::BufferizeTypeConverter &typeConverter, |
| RewritePatternSet &patterns) { |
| patterns.add<BufferizeTensorConstantOp>(globalCreator, typeConverter, |
| patterns.getContext()); |
| } |
| |
| namespace { |
| class TensorConstantBufferizePass |
| : public TensorConstantBufferizeBase<TensorConstantBufferizePass> { |
| public: |
| explicit TensorConstantBufferizePass(unsigned alignment) { |
| if (alignment) |
| this->alignment = alignment; |
| } |
| |
| void runOnOperation() override { |
| auto module = getOperation(); |
| GlobalCreator globals(module, alignment); |
| |
| auto *context = &getContext(); |
| bufferization::BufferizeTypeConverter typeConverter; |
| RewritePatternSet patterns(context); |
| ConversionTarget target(*context); |
| |
| target.addLegalDialect<memref::MemRefDialect>(); |
| populateTensorConstantBufferizePatterns(globals, typeConverter, patterns); |
| target.addDynamicallyLegalOp<arith::ConstantOp>([&](arith::ConstantOp op) { |
| return typeConverter.isLegal(op.getType()); |
| }); |
| if (failed(applyPartialConversion(module, target, std::move(patterns)))) |
| signalPassFailure(); |
| } |
| }; |
| } // namespace |
| |
| std::unique_ptr<Pass> |
| mlir::createTensorConstantBufferizePass(unsigned alignment) { |
| return std::make_unique<TensorConstantBufferizePass>(alignment); |
| } |