| //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "ReductionProcessor.h" |
| |
| #include "PrivateReductionUtils.h" |
| #include "flang/Lower/AbstractConverter.h" |
| #include "flang/Lower/ConvertType.h" |
| #include "flang/Lower/SymbolMap.h" |
| #include "flang/Optimizer/Builder/Complex.h" |
| #include "flang/Optimizer/Builder/HLFIRTools.h" |
| #include "flang/Optimizer/Builder/Todo.h" |
| #include "flang/Optimizer/Dialect/FIRType.h" |
| #include "flang/Optimizer/HLFIR/HLFIROps.h" |
| #include "flang/Optimizer/Support/FatalError.h" |
| #include "flang/Parser/tools.h" |
| #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
| #include "llvm/Support/CommandLine.h" |
| |
| static llvm::cl::opt<bool> forceByrefReduction( |
| "force-byref-reduction", |
| llvm::cl::desc("Pass all reduction arguments by reference"), |
| llvm::cl::Hidden); |
| |
| using ReductionModifier = |
| Fortran::lower::omp::clause::Reduction::ReductionModifier; |
| |
| namespace Fortran { |
| namespace lower { |
| namespace omp { |
| |
| ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( |
| const omp::clause::ProcedureDesignator &pd) { |
| auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( |
| getRealName(pd.v.sym()).ToString()) |
| .Case("max", ReductionIdentifier::MAX) |
| .Case("min", ReductionIdentifier::MIN) |
| .Case("iand", ReductionIdentifier::IAND) |
| .Case("ior", ReductionIdentifier::IOR) |
| .Case("ieor", ReductionIdentifier::IEOR) |
| .Default(std::nullopt); |
| assert(redType && "Invalid Reduction"); |
| return *redType; |
| } |
| |
| ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( |
| omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { |
| switch (intrinsicOp) { |
| case omp::clause::DefinedOperator::IntrinsicOperator::Add: |
| return ReductionIdentifier::ADD; |
| case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: |
| return ReductionIdentifier::SUBTRACT; |
| case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: |
| return ReductionIdentifier::MULTIPLY; |
| case omp::clause::DefinedOperator::IntrinsicOperator::AND: |
| return ReductionIdentifier::AND; |
| case omp::clause::DefinedOperator::IntrinsicOperator::EQV: |
| return ReductionIdentifier::EQV; |
| case omp::clause::DefinedOperator::IntrinsicOperator::OR: |
| return ReductionIdentifier::OR; |
| case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: |
| return ReductionIdentifier::NEQV; |
| default: |
| llvm_unreachable("unexpected intrinsic operator in reduction"); |
| } |
| } |
| |
| bool ReductionProcessor::supportedIntrinsicProcReduction( |
| const omp::clause::ProcedureDesignator &pd) { |
| semantics::Symbol *sym = pd.v.sym(); |
| if (!sym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC)) |
| return false; |
| auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) |
| .Case("max", true) |
| .Case("min", true) |
| .Case("iand", true) |
| .Case("ior", true) |
| .Case("ieor", true) |
| .Default(false); |
| return redType; |
| } |
| |
| std::string |
| ReductionProcessor::getReductionName(llvm::StringRef name, |
| const fir::KindMapping &kindMap, |
| mlir::Type ty, bool isByRef) { |
| ty = fir::unwrapRefType(ty); |
| |
| // extra string to distinguish reduction functions for variables passed by |
| // reference |
| llvm::StringRef byrefAddition{""}; |
| if (isByRef) |
| byrefAddition = "_byref"; |
| |
| return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str()); |
| } |
| |
| std::string ReductionProcessor::getReductionName( |
| omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, |
| const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) { |
| std::string reductionName; |
| |
| switch (intrinsicOp) { |
| case omp::clause::DefinedOperator::IntrinsicOperator::Add: |
| reductionName = "add_reduction"; |
| break; |
| case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: |
| reductionName = "multiply_reduction"; |
| break; |
| case omp::clause::DefinedOperator::IntrinsicOperator::AND: |
| return "and_reduction"; |
| case omp::clause::DefinedOperator::IntrinsicOperator::EQV: |
| return "eqv_reduction"; |
| case omp::clause::DefinedOperator::IntrinsicOperator::OR: |
| return "or_reduction"; |
| case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: |
| return "neqv_reduction"; |
| default: |
| reductionName = "other_reduction"; |
| break; |
| } |
| |
| return getReductionName(reductionName, kindMap, ty, isByRef); |
| } |
| |
| mlir::Value |
| ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, |
| ReductionIdentifier redId, |
| fir::FirOpBuilder &builder) { |
| type = fir::unwrapRefType(type); |
| if (!fir::isa_integer(type) && !fir::isa_real(type) && |
| !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type)) |
| TODO(loc, "Reduction of some types is not supported"); |
| switch (redId) { |
| case ReductionIdentifier::MAX: { |
| if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) { |
| const llvm::fltSemantics &sem = ty.getFloatSemantics(); |
| return builder.createRealConstant( |
| loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); |
| } |
| unsigned bits = type.getIntOrFloatBitWidth(); |
| int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); |
| return builder.createIntegerConstant(loc, type, minInt); |
| } |
| case ReductionIdentifier::MIN: { |
| if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) { |
| const llvm::fltSemantics &sem = ty.getFloatSemantics(); |
| return builder.createRealConstant( |
| loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); |
| } |
| unsigned bits = type.getIntOrFloatBitWidth(); |
| int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); |
| return builder.createIntegerConstant(loc, type, maxInt); |
| } |
| case ReductionIdentifier::IOR: { |
| unsigned bits = type.getIntOrFloatBitWidth(); |
| int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); |
| return builder.createIntegerConstant(loc, type, zeroInt); |
| } |
| case ReductionIdentifier::IEOR: { |
| unsigned bits = type.getIntOrFloatBitWidth(); |
| int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); |
| return builder.createIntegerConstant(loc, type, zeroInt); |
| } |
| case ReductionIdentifier::IAND: { |
| unsigned bits = type.getIntOrFloatBitWidth(); |
| int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); |
| return builder.createIntegerConstant(loc, type, allOnInt); |
| } |
| case ReductionIdentifier::ADD: |
| case ReductionIdentifier::MULTIPLY: |
| case ReductionIdentifier::AND: |
| case ReductionIdentifier::OR: |
| case ReductionIdentifier::EQV: |
| case ReductionIdentifier::NEQV: |
| if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) { |
| mlir::Type realTy = cplxTy.getElementType(); |
| mlir::Value initRe = builder.createRealConstant( |
| loc, realTy, getOperationIdentity(redId, loc)); |
| mlir::Value initIm = builder.createRealConstant(loc, realTy, 0); |
| |
| return fir::factory::Complex{builder, loc}.createComplex(type, initRe, |
| initIm); |
| } |
| if (mlir::isa<mlir::FloatType>(type)) |
| return builder.create<mlir::arith::ConstantOp>( |
| loc, type, |
| builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); |
| |
| if (mlir::isa<fir::LogicalType>(type)) { |
| mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( |
| loc, builder.getI1Type(), |
| builder.getIntegerAttr(builder.getI1Type(), |
| getOperationIdentity(redId, loc))); |
| return builder.createConvert(loc, type, intConst); |
| } |
| |
| return builder.create<mlir::arith::ConstantOp>( |
| loc, type, |
| builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); |
| case ReductionIdentifier::ID: |
| case ReductionIdentifier::USER_DEF_OP: |
| case ReductionIdentifier::SUBTRACT: |
| TODO(loc, "Reduction of some identifier types is not supported"); |
| } |
| llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); |
| } |
| |
| mlir::Value ReductionProcessor::createScalarCombiner( |
| fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, |
| mlir::Type type, mlir::Value op1, mlir::Value op2) { |
| mlir::Value reductionOp; |
| type = fir::unwrapRefType(type); |
| switch (redId) { |
| case ReductionIdentifier::MAX: |
| reductionOp = |
| getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>( |
| builder, type, loc, op1, op2); |
| break; |
| case ReductionIdentifier::MIN: |
| reductionOp = |
| getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>( |
| builder, type, loc, op1, op2); |
| break; |
| case ReductionIdentifier::IOR: |
| assert((type.isIntOrIndex()) && "only integer is expected"); |
| reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); |
| break; |
| case ReductionIdentifier::IEOR: |
| assert((type.isIntOrIndex()) && "only integer is expected"); |
| reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); |
| break; |
| case ReductionIdentifier::IAND: |
| assert((type.isIntOrIndex()) && "only integer is expected"); |
| reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); |
| break; |
| case ReductionIdentifier::ADD: |
| reductionOp = |
| getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp, |
| fir::AddcOp>(builder, type, loc, op1, op2); |
| break; |
| case ReductionIdentifier::MULTIPLY: |
| reductionOp = |
| getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp, |
| fir::MulcOp>(builder, type, loc, op1, op2); |
| break; |
| case ReductionIdentifier::AND: { |
| mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
| mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
| |
| mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); |
| |
| reductionOp = builder.createConvert(loc, type, andiOp); |
| break; |
| } |
| case ReductionIdentifier::OR: { |
| mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
| mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
| |
| mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); |
| |
| reductionOp = builder.createConvert(loc, type, oriOp); |
| break; |
| } |
| case ReductionIdentifier::EQV: { |
| mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
| mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
| |
| mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( |
| loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); |
| |
| reductionOp = builder.createConvert(loc, type, cmpiOp); |
| break; |
| } |
| case ReductionIdentifier::NEQV: { |
| mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
| mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
| |
| mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( |
| loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); |
| |
| reductionOp = builder.createConvert(loc, type, cmpiOp); |
| break; |
| } |
| default: |
| TODO(loc, "Reduction of some intrinsic operators is not supported"); |
| } |
| |
| return reductionOp; |
| } |
| |
| /// Create reduction combiner region for reduction variables which are boxed |
| /// arrays |
| static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc, |
| ReductionProcessor::ReductionIdentifier redId, |
| fir::BaseBoxType boxTy, mlir::Value lhs, |
| mlir::Value rhs) { |
| fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>( |
| fir::unwrapRefType(boxTy.getEleTy())); |
| fir::HeapType heapTy = |
| mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy()); |
| fir::PointerType ptrTy = |
| mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy()); |
| if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy) |
| TODO(loc, "Unsupported boxed type in OpenMP reduction"); |
| |
| // load fir.ref<fir.box<...>> |
| mlir::Value lhsAddr = lhs; |
| lhs = builder.create<fir::LoadOp>(loc, lhs); |
| rhs = builder.create<fir::LoadOp>(loc, rhs); |
| |
| if ((heapTy || ptrTy) && !seqTy) { |
| // get box contents (heap pointers) |
| lhs = builder.create<fir::BoxAddrOp>(loc, lhs); |
| rhs = builder.create<fir::BoxAddrOp>(loc, rhs); |
| mlir::Value lhsValAddr = lhs; |
| |
| // load heap pointers |
| lhs = builder.create<fir::LoadOp>(loc, lhs); |
| rhs = builder.create<fir::LoadOp>(loc, rhs); |
| |
| mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy(); |
| |
| mlir::Value result = ReductionProcessor::createScalarCombiner( |
| builder, loc, redId, eleTy, lhs, rhs); |
| builder.create<fir::StoreOp>(loc, result, lhsValAddr); |
| builder.create<mlir::omp::YieldOp>(loc, lhsAddr); |
| return; |
| } |
| |
| // Get ShapeShift with default lower bounds. This makes it possible to use |
| // unmodified LoopNest's indices with ArrayCoorOp. |
| fir::ShapeShiftOp shapeShift = |
| getShapeShift(builder, loc, lhs, |
| /*cannotHaveNonDefaultLowerBounds=*/false, |
| /*useDefaultLowerBounds=*/true); |
| |
| // Iterate over array elements, applying the equivalent scalar reduction: |
| |
| // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced |
| // and so no null check is needed here before indexing into the (possibly |
| // allocatable) arrays. |
| |
| // A hlfir::elemental here gets inlined with a temporary so create the |
| // loop nest directly. |
| // This function already controls all of the code in this region so we |
| // know this won't miss any opportuinties for clever elemental inlining |
| hlfir::LoopNest nest = hlfir::genLoopNest( |
| loc, builder, shapeShift.getExtents(), /*isUnordered=*/true); |
| builder.setInsertionPointToStart(nest.body); |
| mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy()); |
| auto lhsEleAddr = builder.create<fir::ArrayCoorOp>( |
| loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{}, |
| nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); |
| auto rhsEleAddr = builder.create<fir::ArrayCoorOp>( |
| loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{}, |
| nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{}); |
| auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr); |
| auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr); |
| mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner( |
| builder, loc, redId, refTy, lhsEle, rhsEle); |
| builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr); |
| |
| builder.setInsertionPointAfter(nest.outerOp); |
| builder.create<mlir::omp::YieldOp>(loc, lhsAddr); |
| } |
| |
| // generate combiner region for reduction operations |
| static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, |
| ReductionProcessor::ReductionIdentifier redId, |
| mlir::Type ty, mlir::Value lhs, mlir::Value rhs, |
| bool isByRef) { |
| ty = fir::unwrapRefType(ty); |
| |
| if (fir::isa_trivial(ty)) { |
| mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); |
| mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); |
| |
| mlir::Value result = ReductionProcessor::createScalarCombiner( |
| builder, loc, redId, ty, lhsLoaded, rhsLoaded); |
| if (isByRef) { |
| builder.create<fir::StoreOp>(loc, result, lhs); |
| builder.create<mlir::omp::YieldOp>(loc, lhs); |
| } else { |
| builder.create<mlir::omp::YieldOp>(loc, result); |
| } |
| return; |
| } |
| // all arrays should have been boxed |
| if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { |
| genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs); |
| return; |
| } |
| |
| TODO(loc, "OpenMP genCombiner for unsupported reduction variable type"); |
| } |
| |
| // like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes |
| static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) { |
| if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) |
| return seqTy.getEleTy(); |
| if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { |
| auto eleTy = fir::unwrapRefType(boxTy.getEleTy()); |
| if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) |
| return seqTy.getEleTy(); |
| return eleTy; |
| } |
| return ty; |
| } |
| |
| static void createReductionAllocAndInitRegions( |
| AbstractConverter &converter, mlir::Location loc, |
| mlir::omp::DeclareReductionOp &reductionDecl, |
| const ReductionProcessor::ReductionIdentifier redId, mlir::Type type, |
| bool isByRef) { |
| fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
| auto yield = [&](mlir::Value ret) { |
| builder.create<mlir::omp::YieldOp>(loc, ret); |
| }; |
| |
| mlir::Block *allocBlock = nullptr; |
| mlir::Block *initBlock = nullptr; |
| if (isByRef) { |
| allocBlock = |
| builder.createBlock(&reductionDecl.getAllocRegion(), |
| reductionDecl.getAllocRegion().end(), {}, {}); |
| initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(), |
| reductionDecl.getInitializerRegion().end(), |
| {type, type}, {loc, loc}); |
| } else { |
| initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(), |
| reductionDecl.getInitializerRegion().end(), |
| {type}, {loc}); |
| } |
| |
| mlir::Type ty = fir::unwrapRefType(type); |
| builder.setInsertionPointToEnd(initBlock); |
| mlir::Value initValue = ReductionProcessor::getReductionInitValue( |
| loc, unwrapSeqOrBoxedType(ty), redId, builder); |
| |
| if (isByRef) { |
| populateByRefInitAndCleanupRegions( |
| converter, loc, type, initValue, initBlock, |
| reductionDecl.getInitializerAllocArg(), |
| reductionDecl.getInitializerMoldArg(), reductionDecl.getCleanupRegion(), |
| DeclOperationKind::Reduction); |
| } |
| |
| if (fir::isa_trivial(ty)) { |
| if (isByRef) { |
| // alloc region |
| builder.setInsertionPointToEnd(allocBlock); |
| mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty); |
| yield(alloca); |
| return; |
| } |
| // by val |
| yield(initValue); |
| return; |
| } |
| assert(isByRef && "passing non-trivial types by val is unsupported"); |
| |
| // alloc region |
| builder.setInsertionPointToEnd(allocBlock); |
| mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty); |
| yield(boxAlloca); |
| } |
| |
| mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction( |
| AbstractConverter &converter, llvm::StringRef reductionOpName, |
| const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, |
| bool isByRef) { |
| fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
| mlir::OpBuilder::InsertionGuard guard(builder); |
| mlir::ModuleOp module = builder.getModule(); |
| |
| assert(!reductionOpName.empty()); |
| |
| auto decl = |
| module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName); |
| if (decl) |
| return decl; |
| |
| mlir::OpBuilder modBuilder(module.getBodyRegion()); |
| mlir::Type valTy = fir::unwrapRefType(type); |
| if (!isByRef) |
| type = valTy; |
| |
| decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName, |
| type); |
| createReductionAllocAndInitRegions(converter, loc, decl, redId, type, |
| isByRef); |
| |
| builder.createBlock(&decl.getReductionRegion(), |
| decl.getReductionRegion().end(), {type, type}, |
| {loc, loc}); |
| |
| builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); |
| mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); |
| mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); |
| genCombiner(builder, loc, redId, type, op1, op2, isByRef); |
| |
| return decl; |
| } |
| |
| static bool doReductionByRef(mlir::Value reductionVar) { |
| if (forceByrefReduction) |
| return true; |
| |
| if (auto declare = |
| mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) |
| reductionVar = declare.getMemref(); |
| |
| if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) |
| return true; |
| |
| return false; |
| } |
| |
| mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) { |
| switch (mod) { |
| case ReductionModifier::Default: |
| return mlir::omp::ReductionModifier::defaultmod; |
| case ReductionModifier::Inscan: |
| return mlir::omp::ReductionModifier::inscan; |
| case ReductionModifier::Task: |
| return mlir::omp::ReductionModifier::task; |
| } |
| return mlir::omp::ReductionModifier::defaultmod; |
| } |
| |
| void ReductionProcessor::processReductionArguments( |
| mlir::Location currentLocation, lower::AbstractConverter &converter, |
| const omp::clause::Reduction &reduction, |
| llvm::SmallVectorImpl<mlir::Value> &reductionVars, |
| llvm::SmallVectorImpl<bool> &reduceVarByRef, |
| llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, |
| llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols, |
| mlir::omp::ReductionModifierAttr &reductionMod) { |
| fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
| |
| auto mod = std::get<std::optional<ReductionModifier>>(reduction.t); |
| if (mod.has_value()) { |
| if (mod.value() == ReductionModifier::Task) |
| TODO(currentLocation, "Reduction modifier `task` is not supported"); |
| else |
| reductionMod = mlir::omp::ReductionModifierAttr::get( |
| firOpBuilder.getContext(), translateReductionModifier(mod.value())); |
| } |
| |
| mlir::omp::DeclareReductionOp decl; |
| const auto &redOperatorList{ |
| std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)}; |
| assert(redOperatorList.size() == 1 && "Expecting single operator"); |
| const auto &redOperator = redOperatorList.front(); |
| const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; |
| |
| if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) { |
| if (const auto *reductionIntrinsic = |
| std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { |
| if (!ReductionProcessor::supportedIntrinsicProcReduction( |
| *reductionIntrinsic)) { |
| return; |
| } |
| } else { |
| return; |
| } |
| } |
| |
| // Reduction variable processing common to both intrinsic operators and |
| // procedure designators |
| fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
| for (const Object &object : objectList) { |
| const semantics::Symbol *symbol = object.sym(); |
| reductionSymbols.push_back(symbol); |
| mlir::Value symVal = converter.getSymbolAddress(*symbol); |
| mlir::Type eleType; |
| auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType()); |
| if (refType) |
| eleType = refType.getEleTy(); |
| else |
| eleType = symVal.getType(); |
| |
| // all arrays must be boxed so that we have convenient access to all the |
| // information needed to iterate over the array |
| if (mlir::isa<fir::SequenceType>(eleType)) { |
| // For Host associated symbols, use `SymbolBox` instead |
| lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol); |
| hlfir::Entity entity{symBox.getAddr()}; |
| entity = genVariableBox(currentLocation, builder, entity); |
| mlir::Value box = entity.getBase(); |
| |
| // Always pass the box by reference so that the OpenMP dialect |
| // verifiers don't need to know anything about fir.box |
| auto alloca = |
| builder.create<fir::AllocaOp>(currentLocation, box.getType()); |
| builder.create<fir::StoreOp>(currentLocation, box, alloca); |
| |
| symVal = alloca; |
| } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) { |
| // boxed arrays are passed as values not by reference. Unfortunately, |
| // we can't pass a box by value to omp.redution_declare, so turn it |
| // into a reference |
| |
| auto alloca = |
| builder.create<fir::AllocaOp>(currentLocation, symVal.getType()); |
| builder.create<fir::StoreOp>(currentLocation, symVal, alloca); |
| symVal = alloca; |
| } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) { |
| symVal = declOp.getBase(); |
| } |
| |
| // this isn't the same as the by-val and by-ref passing later in the |
| // pipeline. Both styles assume that the variable is a reference at |
| // this point |
| assert(fir::isa_ref_type(symVal.getType()) && |
| "reduction input var is passed by reference"); |
| mlir::Type elementType = fir::dyn_cast_ptrEleTy(symVal.getType()); |
| mlir::Type refTy = fir::ReferenceType::get(elementType); |
| |
| reductionVars.push_back( |
| builder.createConvert(currentLocation, refTy, symVal)); |
| reduceVarByRef.push_back(doReductionByRef(symVal)); |
| } |
| |
| for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) { |
| auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); |
| const auto &kindMap = firOpBuilder.getKindMap(); |
| std::string reductionName; |
| ReductionIdentifier redId; |
| mlir::Type redNameTy = redType; |
| if (mlir::isa<fir::LogicalType>(redType.getEleTy())) |
| redNameTy = builder.getI1Type(); |
| |
| if (const auto &redDefinedOp = |
| std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { |
| const auto &intrinsicOp{ |
| std::get<omp::clause::DefinedOperator::IntrinsicOperator>( |
| redDefinedOp->u)}; |
| redId = getReductionType(intrinsicOp); |
| switch (redId) { |
| case ReductionIdentifier::ADD: |
| case ReductionIdentifier::MULTIPLY: |
| case ReductionIdentifier::AND: |
| case ReductionIdentifier::EQV: |
| case ReductionIdentifier::OR: |
| case ReductionIdentifier::NEQV: |
| break; |
| default: |
| TODO(currentLocation, |
| "Reduction of some intrinsic operators is not supported"); |
| break; |
| } |
| |
| reductionName = |
| getReductionName(intrinsicOp, kindMap, redNameTy, isByRef); |
| } else if (const auto *reductionIntrinsic = |
| std::get_if<omp::clause::ProcedureDesignator>( |
| &redOperator.u)) { |
| if (!ReductionProcessor::supportedIntrinsicProcReduction( |
| *reductionIntrinsic)) { |
| TODO(currentLocation, "Unsupported intrinsic proc reduction"); |
| } |
| redId = getReductionType(*reductionIntrinsic); |
| reductionName = |
| getReductionName(getRealName(*reductionIntrinsic).ToString(), kindMap, |
| redNameTy, isByRef); |
| } else { |
| TODO(currentLocation, "Unexpected reduction type"); |
| } |
| |
| decl = createDeclareReduction(converter, reductionName, redId, redType, |
| currentLocation, isByRef); |
| reductionDeclSymbols.push_back( |
| mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName())); |
| } |
| } |
| |
| const semantics::SourceName |
| ReductionProcessor::getRealName(const semantics::Symbol *symbol) { |
| return symbol->GetUltimate().name(); |
| } |
| |
| const semantics::SourceName |
| ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { |
| return getRealName(pd.v.sym()); |
| } |
| |
| int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, |
| mlir::Location loc) { |
| switch (redId) { |
| case ReductionIdentifier::ADD: |
| case ReductionIdentifier::OR: |
| case ReductionIdentifier::NEQV: |
| return 0; |
| case ReductionIdentifier::MULTIPLY: |
| case ReductionIdentifier::AND: |
| case ReductionIdentifier::EQV: |
| return 1; |
| default: |
| TODO(loc, "Reduction of some intrinsic operators is not supported"); |
| } |
| } |
| |
| } // namespace omp |
| } // namespace lower |
| } // namespace Fortran |