//===-- lib/Utils/OpenMP.cpp ------------------------------------*- C++ -*-===//
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "flang/Utils/OpenMP.h" |
| |
| #include "flang/Lower/ConvertExprToHLFIR.h" |
| #include "flang/Optimizer/Builder/DirectivesCommon.h" |
| #include "flang/Optimizer/Builder/FIRBuilder.h" |
| #include "flang/Optimizer/Dialect/FIROps.h" |
| #include "flang/Optimizer/Dialect/FIRType.h" |
| |
| #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
| #include "mlir/Transforms/RegionUtils.h" |
| |
| namespace Fortran::utils::openmp { |
| mlir::omp::MapInfoOp createMapInfoOp(mlir::OpBuilder &builder, |
| mlir::Location loc, mlir::Value baseAddr, mlir::Value varPtrPtr, |
| llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds, |
| llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex, |
| mlir::omp::ClauseMapFlags mapType, |
| mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, |
| bool partialMap, mlir::FlatSymbolRefAttr mapperId) { |
| |
| auto getPtrVarType = [](mlir::Type ptrType) { |
| mlir::TypeAttr varType = mlir::TypeAttr::get( |
| llvm::cast<mlir::omp::PointerLikeType>(ptrType).getElementType()); |
| |
| // For types with unknown extents such as <2x?xi32> we discard the |
| // incomplete type info and only retain the base type. The correct |
| // dimensions are later recovered through the bounds info. |
| if (auto seqType = llvm::dyn_cast<fir::SequenceType>(varType.getValue())) |
| if (seqType.hasDynamicExtents()) |
| varType = mlir::TypeAttr::get(seqType.getEleTy()); |
| return varType; |
| }; |
| |
| if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) { |
| baseAddr = fir::BoxAddrOp::create(builder, loc, baseAddr); |
| retTy = baseAddr.getType(); |
| } |
| |
| auto varPtrType = getPtrVarType(retTy); |
| auto varPtrPtrTy = |
| varPtrPtr ? getPtrVarType(varPtrPtr.getType()) : mlir::TypeAttr{}; |
| |
| mlir::omp::MapInfoOp op = |
| mlir::omp::MapInfoOp::create(builder, loc, retTy, baseAddr, varPtrType, |
| builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(mapType), |
| builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType), |
| varPtrPtr, varPtrPtrTy, members, membersIndex, bounds, mapperId, |
| builder.getStringAttr(name), builder.getBoolAttr(partialMap)); |
| return op; |
| } |
| |
| mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder, |
| mlir::omp::TargetOp targetOp, mlir::Value val, llvm::StringRef name) { |
| mlir::OpBuilder::InsertionGuard guard(firOpBuilder); |
| mlir::Operation *valOp = val.getDefiningOp(); |
| |
| if (valOp) |
| firOpBuilder.setInsertionPointAfter(valOp); |
| else |
| // This means val is a block argument |
| firOpBuilder.setInsertionPoint(targetOp); |
| |
| auto copyVal = firOpBuilder.createTemporary(val.getLoc(), val.getType()); |
| firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal); |
| |
| fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( |
| firOpBuilder, val, /*isOptional=*/false, val.getLoc()); |
| llvm::SmallVector<mlir::Value> bounds = |
| fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, |
| mlir::omp::MapBoundsType>(firOpBuilder, info, |
| hlfir::translateToExtendedValue( |
| val.getLoc(), firOpBuilder, hlfir::Entity{val}) |
| .first, |
| /*dataExvIsAssumedSize=*/false, val.getLoc()); |
| |
| firOpBuilder.setInsertionPoint(targetOp); |
| |
| mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::implicit; |
| mlir::omp::VariableCaptureKind captureKind = |
| mlir::omp::VariableCaptureKind::ByRef; |
| |
| mlir::Type eleType = copyVal.getType(); |
| if (auto refType = mlir::dyn_cast<fir::ReferenceType>(copyVal.getType())) { |
| eleType = refType.getElementType(); |
| } |
| |
| if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) { |
| captureKind = mlir::omp::VariableCaptureKind::ByCopy; |
| } else if (!fir::isa_builtin_cptr_type(eleType)) { |
| mapFlag |= mlir::omp::ClauseMapFlags::to; |
| } |
| |
| mlir::Value mapOp = createMapInfoOp(firOpBuilder, copyVal.getLoc(), copyVal, |
| /*varPtrPtr=*/mlir::Value{}, name.str(), bounds, |
| /*members=*/llvm::SmallVector<mlir::Value>{}, |
| /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, |
| copyVal.getType()); |
| |
| auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp); |
| mlir::Region ®ion = targetOp.getRegion(); |
| |
| // Get the index of the first non-map argument before modifying mapVars, |
| // then append an element to mapVars and an associated entry block |
| // argument at that index. |
| unsigned insertIndex = |
| argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs(); |
| targetOp.getMapVarsMutable().append(mapOp); |
| mlir::Value clonedValArg = |
| region.insertArgument(insertIndex, copyVal.getType(), copyVal.getLoc()); |
| |
| mlir::Block *entryBlock = ®ion.getBlocks().front(); |
| firOpBuilder.setInsertionPointToStart(entryBlock); |
| auto loadOp = |
| fir::LoadOp::create(firOpBuilder, clonedValArg.getLoc(), clonedValArg); |
| return loadOp.getResult(); |
| } |
| |
| void cloneOrMapRegionOutsiders( |
| fir::FirOpBuilder &firOpBuilder, mlir::omp::TargetOp targetOp) { |
| mlir::Region ®ion = targetOp.getRegion(); |
| mlir::Block *entryBlock = ®ion.getBlocks().front(); |
| |
| llvm::SetVector<mlir::Value> valuesDefinedAbove; |
| mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove); |
| while (!valuesDefinedAbove.empty()) { |
| for (mlir::Value val : valuesDefinedAbove) { |
| mlir::Operation *valOp = val.getDefiningOp(); |
| |
| // NOTE: We skip BoxDimsOp's as the lesser of two evils is to map the |
| // indices separately, as the alternative is to eventually map the Box, |
| // which comes with a fairly large overhead comparatively. We could be |
| // more robust about this and check using a BackwardsSlice to see if we |
| // run the risk of mapping a box. |
| if (valOp && mlir::isMemoryEffectFree(valOp) && |
| !mlir::isa<fir::BoxDimsOp>(valOp)) { |
| mlir::Operation *clonedOp = valOp->clone(); |
| entryBlock->push_front(clonedOp); |
| |
| auto replace = [entryBlock](mlir::OpOperand &use) { |
| return use.getOwner()->getBlock() == entryBlock; |
| }; |
| |
| valOp->getResults().replaceUsesWithIf(clonedOp->getResults(), replace); |
| valOp->replaceUsesWithIf(clonedOp, replace); |
| } else { |
| mlir::Value mappedTemp = mapTemporaryValue(firOpBuilder, targetOp, val, |
| /*name=*/{}); |
| val.replaceUsesWithIf(mappedTemp, [entryBlock](mlir::OpOperand &use) { |
| return use.getOwner()->getBlock() == entryBlock; |
| }); |
| } |
| } |
| valuesDefinedAbove.clear(); |
| mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove); |
| } |
| } |
| |
| /// Gets or generates a default declare mapper for a given record type. |
| /// |
| /// \param firOpBuilder The builder to use for generating the mapper. |
| /// \param loc The location to use for the generated operations. |
| /// \param recordType The record type to generate the mapper for. |
| /// \param mapperNameStr The name of the mapper to generate. |
| /// \param mangler A function to mangle the mapper name for nested types. |
| mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper( |
| fir::FirOpBuilder &firOpBuilder, mlir::Location loc, |
| fir::RecordType recordType, llvm::StringRef mapperNameStr, |
| RecordMemberMapperMangler mangler) { |
| if (mapperNameStr.empty()) |
| return {}; |
| |
| mlir::ModuleOp moduleOp = firOpBuilder.getModule(); |
| if (moduleOp.lookupSymbol(mapperNameStr)) |
| return mlir::FlatSymbolRefAttr::get( |
| firOpBuilder.getContext(), mapperNameStr); |
| |
| mlir::OpBuilder::InsertionGuard guard(firOpBuilder); |
| |
| firOpBuilder.setInsertionPointToStart(moduleOp.getBody()); |
| auto declMapperOp = mlir::omp::DeclareMapperOp::create( |
| firOpBuilder, loc, mapperNameStr, recordType); |
| auto ®ion = declMapperOp.getRegion(); |
| firOpBuilder.createBlock(®ion); |
| auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc); |
| |
| auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg, |
| /*uniq_name=*/""); |
| |
| const auto genBoundsOps = [&](mlir::Value mapVal, |
| llvm::SmallVectorImpl<mlir::Value> &bounds) { |
| fir::ExtendedValue extVal = hlfir::translateToExtendedValue(mapVal.getLoc(), |
| firOpBuilder, hlfir::Entity{mapVal}, |
| /*contiguousHint=*/true) |
| .first; |
| fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr( |
| firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc()); |
| bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp, |
| mlir::omp::MapBoundsType>(firOpBuilder, info, extVal, |
| /*dataExvIsAssumedSize=*/false, mapVal.getLoc()); |
| }; |
| |
| const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName, |
| mlir::Type fieldTy, mlir::Type recType) { |
| mlir::Value field = fir::FieldIndexOp::create(firOpBuilder, loc, |
| fir::FieldType::get(recType.getContext()), fieldName, recType, |
| fir::getTypeParams(rec)); |
| return fir::CoordinateOp::create( |
| firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field); |
| }; |
| |
| llvm::SmallVector<mlir::Value> clauseMapVars; |
| llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices; |
| llvm::SmallVector<mlir::Value> memberMapOps; |
| |
| mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to | |
| mlir::omp::ClauseMapFlags::from | mlir::omp::ClauseMapFlags::implicit; |
| mlir::omp::VariableCaptureKind captureKind = |
| mlir::omp::VariableCaptureKind::ByRef; |
| |
| for (const auto &entry : llvm::enumerate(recordType.getTypeList())) { |
| const auto &memberName = entry.value().first; |
| const auto &memberType = entry.value().second; |
| mlir::FlatSymbolRefAttr mapperId; |
| if (auto recType = mlir::dyn_cast<fir::RecordType>( |
| fir::getFortranElementType(memberType))) { |
| std::string mapperIdName = |
| recType.getName().str() + llvm::omp::OmpDefaultMapperName; |
| mangler(mapperIdName, memberName); |
| mapperId = getOrGenImplicitDefaultDeclareMapper( |
| firOpBuilder, loc, recType, mapperIdName, mangler); |
| } |
| |
| auto ref = |
| getFieldRef(declareOp.getBase(), memberName, memberType, recordType); |
| llvm::SmallVector<mlir::Value> bounds; |
| genBoundsOps(ref, bounds); |
| mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(firOpBuilder, |
| loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", bounds, |
| /*members=*/{}, |
| /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(), |
| /*partialMap=*/false, mapperId); |
| memberMapOps.emplace_back(mapOp); |
| memberPlacementIndices.emplace_back( |
| llvm::SmallVector<int64_t>{(int64_t)entry.index()}); |
| } |
| |
| llvm::SmallVector<mlir::Value> bounds; |
| genBoundsOps(declareOp.getOriginalBase(), bounds); |
| mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit; |
| mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp( |
| firOpBuilder, loc, declareOp.getOriginalBase(), |
| /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps, |
| firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag, |
| captureKind, declareOp.getType(0), |
| /*partialMap=*/true); |
| |
| clauseMapVars.emplace_back(mapOp); |
| mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars); |
| return mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), mapperNameStr); |
| } |
| } // namespace Fortran::utils::openmp |