| //===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "../PassDetail.h" |
| #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/OpenACC/OpenACC.h" |
| #include "mlir/IR/Builders.h" |
| |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| // DataDescriptor implementation |
| //===----------------------------------------------------------------------===// |
| |
| constexpr StringRef getStructName() { return "openacc_data"; } |
| |
| /// Construct a helper for the given descriptor value. |
| DataDescriptor::DataDescriptor(Value descriptor) : StructBuilder(descriptor) { |
| assert(value != nullptr && "value cannot be null"); |
| } |
| |
| /// Builds IR creating an `undef` value of the data descriptor. |
| DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc, |
| Type basePtrTy, Type ptrTy) { |
| Type descriptorType = LLVM::LLVMStructType::getNewIdentified( |
| builder.getContext(), getStructName(), |
| {basePtrTy, ptrTy, builder.getI64Type()}); |
| Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType); |
| return DataDescriptor(descriptor); |
| } |
| |
| /// Check whether the type is a valid data descriptor. |
| bool DataDescriptor::isValid(Value descriptor) { |
| if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) { |
| if (type.isIdentified() && type.getName().startswith(getStructName()) && |
| type.getBody().size() == 3 && |
| (type.getBody()[kPtrBasePosInDataDescriptor] |
| .isa<LLVM::LLVMPointerType>() || |
| type.getBody()[kPtrBasePosInDataDescriptor] |
| .isa<LLVM::LLVMStructType>()) && |
| type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() && |
| type.getBody()[kSizePosInDataDescriptor].isInteger(64)) |
| return true; |
| } |
| return false; |
| } |
| |
| /// Builds IR inserting the base pointer value into the descriptor. |
| void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc, |
| Value basePtr) { |
| setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr); |
| } |
| |
| /// Builds IR inserting the pointer value into the descriptor. |
| void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) { |
| setPtr(builder, loc, kPtrPosInDataDescriptor, ptr); |
| } |
| |
| /// Builds IR inserting the size value into the descriptor. |
| void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) { |
| setPtr(builder, loc, kSizePosInDataDescriptor, size); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion patterns |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| template <typename Op> |
| class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> { |
| using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &builder) const override { |
| Location loc = op.getLoc(); |
| TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); |
| |
| unsigned numDataOperand = op.getNumDataOperands(); |
| |
| // Keep the non data operands without modification. |
| auto nonDataOperands = adaptor.getOperands().take_front( |
| adaptor.getOperands().size() - numDataOperand); |
| SmallVector<Value> convertedOperands; |
| convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end()); |
| |
| // Go over the data operand and legalize them for translation. |
| for (unsigned idx = 0; idx < numDataOperand; ++idx) { |
| Value originalDataOperand = op.getDataOperand(idx); |
| |
| // Traverse operands that were converted to MemRefDescriptors. |
| if (auto memRefType = |
| originalDataOperand.getType().dyn_cast<MemRefType>()) { |
| Type structType = converter->convertType(memRefType); |
| Value memRefDescriptor = builder |
| .create<UnrealizedConversionCastOp>( |
| loc, structType, originalDataOperand) |
| .getResult(0); |
| |
| // Calculate the size of the memref and get the pointer to the allocated |
| // buffer. |
| SmallVector<Value> sizes; |
| SmallVector<Value> strides; |
| Value sizeBytes; |
| ConvertToLLVMPattern::getMemRefDescriptorSizes( |
| loc, memRefType, {}, builder, sizes, strides, sizeBytes); |
| MemRefDescriptor descriptor(memRefDescriptor); |
| Value dataPtr = descriptor.alignedPtr(builder, loc); |
| auto ptrType = descriptor.getElementPtrType(); |
| |
| auto descr = DataDescriptor::undef(builder, loc, structType, ptrType); |
| descr.setBasePointer(builder, loc, memRefDescriptor); |
| descr.setPointer(builder, loc, dataPtr); |
| descr.setSize(builder, loc, sizeBytes); |
| convertedOperands.push_back(descr); |
| } else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) { |
| convertedOperands.push_back(originalDataOperand); |
| } else { |
| // Type not supported. |
| return builder.notifyMatchFailure(op, "unsupported type"); |
| } |
| } |
| |
| builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands, |
| op.getOperation()->getAttrs()); |
| |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void mlir::populateOpenACCToLLVMConversionPatterns( |
| LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { |
| patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>>(converter); |
| patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter); |
| patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter); |
| patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter); |
| patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter); |
| } |
| |
| namespace { |
| struct ConvertOpenACCToLLVMPass |
| : public ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> { |
| void runOnOperation() override; |
| }; |
| } // namespace |
| |
| void ConvertOpenACCToLLVMPass::runOnOperation() { |
| auto op = getOperation(); |
| auto *context = op.getContext(); |
| |
| // Convert to OpenACC operations with LLVM IR dialect |
| RewritePatternSet patterns(context); |
| LLVMTypeConverter converter(context); |
| populateOpenACCToLLVMConversionPatterns(converter, patterns); |
| |
| ConversionTarget target(*context); |
| target.addLegalDialect<LLVM::LLVMDialect>(); |
| target.addLegalOp<UnrealizedConversionCastOp>(); |
| |
| auto allDataOperandsAreConverted = [](ValueRange operands) { |
| for (Value operand : operands) { |
| if (!DataDescriptor::isValid(operand) && |
| !operand.getType().isa<LLVM::LLVMPointerType>()) |
| return false; |
| } |
| return true; |
| }; |
| |
| target.addDynamicallyLegalOp<acc::DataOp>( |
| [allDataOperandsAreConverted](acc::DataOp op) { |
| return allDataOperandsAreConverted(op.copyOperands()) && |
| allDataOperandsAreConverted(op.copyinOperands()) && |
| allDataOperandsAreConverted(op.copyinReadonlyOperands()) && |
| allDataOperandsAreConverted(op.copyoutOperands()) && |
| allDataOperandsAreConverted(op.copyoutZeroOperands()) && |
| allDataOperandsAreConverted(op.createOperands()) && |
| allDataOperandsAreConverted(op.createZeroOperands()) && |
| allDataOperandsAreConverted(op.noCreateOperands()) && |
| allDataOperandsAreConverted(op.presentOperands()) && |
| allDataOperandsAreConverted(op.deviceptrOperands()) && |
| allDataOperandsAreConverted(op.attachOperands()); |
| }); |
| |
| target.addDynamicallyLegalOp<acc::EnterDataOp>( |
| [allDataOperandsAreConverted](acc::EnterDataOp op) { |
| return allDataOperandsAreConverted(op.copyinOperands()) && |
| allDataOperandsAreConverted(op.createOperands()) && |
| allDataOperandsAreConverted(op.createZeroOperands()) && |
| allDataOperandsAreConverted(op.attachOperands()); |
| }); |
| |
| target.addDynamicallyLegalOp<acc::ExitDataOp>( |
| [allDataOperandsAreConverted](acc::ExitDataOp op) { |
| return allDataOperandsAreConverted(op.copyoutOperands()) && |
| allDataOperandsAreConverted(op.deleteOperands()) && |
| allDataOperandsAreConverted(op.detachOperands()); |
| }); |
| |
| target.addDynamicallyLegalOp<acc::ParallelOp>( |
| [allDataOperandsAreConverted](acc::ParallelOp op) { |
| return allDataOperandsAreConverted(op.reductionOperands()) && |
| allDataOperandsAreConverted(op.copyOperands()) && |
| allDataOperandsAreConverted(op.copyinOperands()) && |
| allDataOperandsAreConverted(op.copyinReadonlyOperands()) && |
| allDataOperandsAreConverted(op.copyoutOperands()) && |
| allDataOperandsAreConverted(op.copyoutZeroOperands()) && |
| allDataOperandsAreConverted(op.createOperands()) && |
| allDataOperandsAreConverted(op.createZeroOperands()) && |
| allDataOperandsAreConverted(op.noCreateOperands()) && |
| allDataOperandsAreConverted(op.presentOperands()) && |
| allDataOperandsAreConverted(op.devicePtrOperands()) && |
| allDataOperandsAreConverted(op.attachOperands()) && |
| allDataOperandsAreConverted(op.gangPrivateOperands()) && |
| allDataOperandsAreConverted(op.gangFirstPrivateOperands()); |
| }); |
| |
| target.addDynamicallyLegalOp<acc::UpdateOp>( |
| [allDataOperandsAreConverted](acc::UpdateOp op) { |
| return allDataOperandsAreConverted(op.hostOperands()) && |
| allDataOperandsAreConverted(op.deviceOperands()); |
| }); |
| |
| if (failed(applyPartialConversion(op, target, std::move(patterns)))) |
| signalPassFailure(); |
| } |
| |
| std::unique_ptr<OperationPass<ModuleOp>> |
| mlir::createConvertOpenACCToLLVMPass() { |
| return std::make_unique<ConvertOpenACCToLLVMPass>(); |
| } |