| //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "ReductionProcessor.h" |
| |
| #include "flang/Lower/AbstractConverter.h" |
| #include "flang/Lower/ConvertType.h" |
| #include "flang/Lower/SymbolMap.h" |
| #include "flang/Optimizer/Builder/Complex.h" |
| #include "flang/Optimizer/Builder/HLFIRTools.h" |
| #include "flang/Optimizer/Builder/Todo.h" |
| #include "flang/Optimizer/Dialect/FIRType.h" |
| #include "flang/Optimizer/HLFIR/HLFIROps.h" |
| #include "flang/Parser/tools.h" |
| #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
| #include "llvm/Support/CommandLine.h" |
| |
| static llvm::cl::opt<bool> forceByrefReduction( |
| "force-byref-reduction", |
| llvm::cl::desc("Pass all reduction arguments by reference"), |
| llvm::cl::Hidden); |
| |
| namespace Fortran { |
| namespace lower { |
| namespace omp { |
| |
| ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( |
| const omp::clause::ProcedureDesignator &pd) { |
| auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( |
| getRealName(pd.v.id()).ToString()) |
| .Case("max", ReductionIdentifier::MAX) |
| .Case("min", ReductionIdentifier::MIN) |
| .Case("iand", ReductionIdentifier::IAND) |
| .Case("ior", ReductionIdentifier::IOR) |
| .Case("ieor", ReductionIdentifier::IEOR) |
| .Default(std::nullopt); |
| assert(redType && "Invalid Reduction"); |
| return *redType; |
| } |
| |
| ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( |
| omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { |
| switch (intrinsicOp) { |
| case omp::clause::DefinedOperator::IntrinsicOperator::Add: |
| return ReductionIdentifier::ADD; |
| case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: |
| return ReductionIdentifier::SUBTRACT; |
| case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: |
| return ReductionIdentifier::MULTIPLY; |
| case omp::clause::DefinedOperator::IntrinsicOperator::AND: |
| return ReductionIdentifier::AND; |
| case omp::clause::DefinedOperator::IntrinsicOperator::EQV: |
| return ReductionIdentifier::EQV; |
| case omp::clause::DefinedOperator::IntrinsicOperator::OR: |
| return ReductionIdentifier::OR; |
| case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: |
| return ReductionIdentifier::NEQV; |
| default: |
| llvm_unreachable("unexpected intrinsic operator in reduction"); |
| } |
| } |
| |
| bool ReductionProcessor::supportedIntrinsicProcReduction( |
| const omp::clause::ProcedureDesignator &pd) { |
| Fortran::semantics::Symbol *sym = pd.v.id(); |
| if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) |
| return false; |
| auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) |
| .Case("max", true) |
| .Case("min", true) |
| .Case("iand", true) |
| .Case("ior", true) |
| .Case("ieor", true) |
| .Default(false); |
| return redType; |
| } |
| |
| std::string |
| ReductionProcessor::getReductionName(llvm::StringRef name, |
| const fir::KindMapping &kindMap, |
| mlir::Type ty, bool isByRef) { |
| ty = fir::unwrapRefType(ty); |
| |
| // extra string to distinguish reduction functions for variables passed by |
| // reference |
| llvm::StringRef byrefAddition{""}; |
| if (isByRef) |
| byrefAddition = "_byref"; |
| |
| return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str()); |
| } |
| |
| std::string ReductionProcessor::getReductionName( |
| omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, |
| const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) { |
| std::string reductionName; |
| |
| switch (intrinsicOp) { |
| case omp::clause::DefinedOperator::IntrinsicOperator::Add: |
| reductionName = "add_reduction"; |
| break; |
| case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: |
| reductionName = "multiply_reduction"; |
| break; |
| case omp::clause::DefinedOperator::IntrinsicOperator::AND: |
| return "and_reduction"; |
| case omp::clause::DefinedOperator::IntrinsicOperator::EQV: |
| return "eqv_reduction"; |
| case omp::clause::DefinedOperator::IntrinsicOperator::OR: |
| return "or_reduction"; |
| case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: |
| return "neqv_reduction"; |
| default: |
| reductionName = "other_reduction"; |
| break; |
| } |
| |
| return getReductionName(reductionName, kindMap, ty, isByRef); |
| } |
| |
| mlir::Value |
| ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, |
| ReductionIdentifier redId, |
| fir::FirOpBuilder &builder) { |
| type = fir::unwrapRefType(type); |
| if (!fir::isa_integer(type) && !fir::isa_real(type) && |
| !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type)) |
| TODO(loc, "Reduction of some types is not supported"); |
| switch (redId) { |
| case ReductionIdentifier::MAX: { |
| if (auto ty = type.dyn_cast<mlir::FloatType>()) { |
| const llvm::fltSemantics &sem = ty.getFloatSemantics(); |
| return builder.createRealConstant( |
| loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); |
| } |
| unsigned bits = type.getIntOrFloatBitWidth(); |
| int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); |
| return builder.createIntegerConstant(loc, type, minInt); |
| } |
| case ReductionIdentifier::MIN: { |
| if (auto ty = type.dyn_cast<mlir::FloatType>()) { |
| const llvm::fltSemantics &sem = ty.getFloatSemantics(); |
| return builder.createRealConstant( |
| loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); |
| } |
| unsigned bits = type.getIntOrFloatBitWidth(); |
| int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); |
| return builder.createIntegerConstant(loc, type, maxInt); |
| } |
| case ReductionIdentifier::IOR: { |
| unsigned bits = type.getIntOrFloatBitWidth(); |
| int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); |
| return builder.createIntegerConstant(loc, type, zeroInt); |
| } |
| case ReductionIdentifier::IEOR: { |
| unsigned bits = type.getIntOrFloatBitWidth(); |
| int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); |
| return builder.createIntegerConstant(loc, type, zeroInt); |
| } |
| case ReductionIdentifier::IAND: { |
| unsigned bits = type.getIntOrFloatBitWidth(); |
| int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); |
| return builder.createIntegerConstant(loc, type, allOnInt); |
| } |
| case ReductionIdentifier::ADD: |
| case ReductionIdentifier::MULTIPLY: |
| case ReductionIdentifier::AND: |
| case ReductionIdentifier::OR: |
| case ReductionIdentifier::EQV: |
| case ReductionIdentifier::NEQV: |
| if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) { |
| mlir::Type realTy = |
| Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind()); |
| mlir::Value initRe = builder.createRealConstant( |
| loc, realTy, getOperationIdentity(redId, loc)); |
| mlir::Value initIm = builder.createRealConstant(loc, realTy, 0); |
| |
| return fir::factory::Complex{builder, loc}.createComplex(type, initRe, |
| initIm); |
| } |
| if (type.isa<mlir::FloatType>()) |
| return builder.create<mlir::arith::ConstantOp>( |
| loc, type, |
| builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); |
| |
| if (type.isa<fir::LogicalType>()) { |
| mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( |
| loc, builder.getI1Type(), |
| builder.getIntegerAttr(builder.getI1Type(), |
| getOperationIdentity(redId, loc))); |
| return builder.createConvert(loc, type, intConst); |
| } |
| |
| return builder.create<mlir::arith::ConstantOp>( |
| loc, type, |
| builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); |
| case ReductionIdentifier::ID: |
| case ReductionIdentifier::USER_DEF_OP: |
| case ReductionIdentifier::SUBTRACT: |
| TODO(loc, "Reduction of some identifier types is not supported"); |
| } |
| llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue"); |
| } |
| |
| mlir::Value ReductionProcessor::createScalarCombiner( |
| fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, |
| mlir::Type type, mlir::Value op1, mlir::Value op2) { |
| mlir::Value reductionOp; |
| type = fir::unwrapRefType(type); |
| switch (redId) { |
| case ReductionIdentifier::MAX: |
| reductionOp = |
| getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>( |
| builder, type, loc, op1, op2); |
| break; |
| case ReductionIdentifier::MIN: |
| reductionOp = |
| getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>( |
| builder, type, loc, op1, op2); |
| break; |
| case ReductionIdentifier::IOR: |
| assert((type.isIntOrIndex()) && "only integer is expected"); |
| reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); |
| break; |
| case ReductionIdentifier::IEOR: |
| assert((type.isIntOrIndex()) && "only integer is expected"); |
| reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); |
| break; |
| case ReductionIdentifier::IAND: |
| assert((type.isIntOrIndex()) && "only integer is expected"); |
| reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); |
| break; |
| case ReductionIdentifier::ADD: |
| reductionOp = |
| getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp, |
| fir::AddcOp>(builder, type, loc, op1, op2); |
| break; |
| case ReductionIdentifier::MULTIPLY: |
| reductionOp = |
| getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp, |
| fir::MulcOp>(builder, type, loc, op1, op2); |
| break; |
| case ReductionIdentifier::AND: { |
| mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
| mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
| |
| mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); |
| |
| reductionOp = builder.createConvert(loc, type, andiOp); |
| break; |
| } |
| case ReductionIdentifier::OR: { |
| mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
| mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
| |
| mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); |
| |
| reductionOp = builder.createConvert(loc, type, oriOp); |
| break; |
| } |
| case ReductionIdentifier::EQV: { |
| mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
| mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
| |
| mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( |
| loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); |
| |
| reductionOp = builder.createConvert(loc, type, cmpiOp); |
| break; |
| } |
| case ReductionIdentifier::NEQV: { |
| mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
| mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
| |
| mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( |
| loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); |
| |
| reductionOp = builder.createConvert(loc, type, cmpiOp); |
| break; |
| } |
| default: |
| TODO(loc, "Reduction of some intrinsic operators is not supported"); |
| } |
| |
| return reductionOp; |
| } |
| |
/// Create reduction combiner region for reduction variables which are boxed
/// arrays.
///
/// \p lhs and \p rhs are references to the two boxes being combined
/// (fir.ref<fir.box<...>>). Both boxes are loaded. The lower bound and extent
/// of every dimension are read from the lhs box and used to address the
/// elements of both arrays (NOTE(review): this assumes rhs has the same shape
/// as lhs). Each element pair is combined with
/// ReductionProcessor::createScalarCombiner, the result is stored back into the
/// lhs element, and finally the lhs box reference is yielded. Non-array boxes
/// and arrays of unknown shape emit a TODO.
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                           ReductionProcessor::ReductionIdentifier redId,
                           fir::BaseBoxType boxTy, mlir::Value lhs,
                           mlir::Value rhs) {
  fir::SequenceType seqTy =
      mlir::dyn_cast_or_null<fir::SequenceType>(boxTy.getEleTy());
  // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
  if (!seqTy || seqTy.hasUnknownShape())
    TODO(loc, "Unsupported boxed type in OpenMP reduction");

  // load fir.ref<fir.box<...>>
  // Keep the reference to the lhs box: it is what the region yields.
  mlir::Value lhsAddr = lhs;
  lhs = builder.create<fir::LoadOp>(loc, lhs);
  rhs = builder.create<fir::LoadOp>(loc, rhs);

  const unsigned rank = seqTy.getDimension();
  llvm::SmallVector<mlir::Value> extents;
  extents.reserve(rank);
  // Interleaved (lower bound, extent) pairs, one per dimension, as passed to
  // fir.shape_shift below.
  llvm::SmallVector<mlir::Value> lbAndExtents;
  lbAndExtents.reserve(rank * 2);

  // Get box lowerbounds and extents:
  mlir::Type idxTy = builder.getIndexType();
  for (unsigned i = 0; i < rank; ++i) {
    // TODO: ideally we want to hoist box reads out of the critical section.
    // We could do this by having box dimensions in block arguments like
    // OpenACC does
    mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
    auto dimInfo =
        builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, lhs, dim);
    extents.push_back(dimInfo.getExtent());
    lbAndExtents.push_back(dimInfo.getLowerBound());
    lbAndExtents.push_back(dimInfo.getExtent());
  }

  auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
  auto shapeShift =
      builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);

  // Iterate over array elements, applying the equivalent scalar reduction:

  // A hlfir::elemental here gets inlined with a temporary so create the
  // loop nest directly.
  // This function already controls all of the code in this region so we
  // know this won't miss any opportunities for clever elemental inlining
  hlfir::LoopNest nest =
      hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
  builder.setInsertionPointToStart(nest.innerLoop.getBody());
  mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
  auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
  auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
  mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
      builder, loc, redId, refTy, lhsEle, rhsEle);
  builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);

  // Leave the loop nest and terminate the region by yielding the lhs box ref.
  builder.setInsertionPointAfter(nest.outerLoop);
  builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
}
| |
| // generate combiner region for reduction operations |
| static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, |
| ReductionProcessor::ReductionIdentifier redId, |
| mlir::Type ty, mlir::Value lhs, mlir::Value rhs, |
| bool isByRef) { |
| ty = fir::unwrapRefType(ty); |
| |
| if (fir::isa_trivial(ty)) { |
| mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); |
| mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); |
| |
| mlir::Value result = ReductionProcessor::createScalarCombiner( |
| builder, loc, redId, ty, lhsLoaded, rhsLoaded); |
| if (isByRef) { |
| builder.create<fir::StoreOp>(loc, result, lhs); |
| builder.create<mlir::omp::YieldOp>(loc, lhs); |
| } else { |
| builder.create<mlir::omp::YieldOp>(loc, result); |
| } |
| return; |
| } |
| // all arrays should have been boxed |
| if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { |
| genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs); |
| return; |
| } |
| |
| TODO(loc, "OpenMP genCombiner for unsupported reduction variable type"); |
| } |
| |
/// Emit the contents of an omp.declare_reduction initializer region for the
/// reduction \p redId over \p type, and return the value the region should
/// yield.
/// - Trivial (scalar) types: the reduction's init value when passed by value;
///   with \p isByRef, a new fir.alloca initialized with that value.
/// - Boxed arrays (only supported by reference): a stack temporary is created
///   with createStackTempFromMold from the current block's first argument
///   (the initializer region's argument when called from
///   createDeclareReduction). The init value is assigned to it with
///   hlfir.assign, it is wrapped in a box, and a fir.alloca holding that box
///   is returned.
/// - Any other type emits a TODO.
static mlir::Value
createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
                          const ReductionProcessor::ReductionIdentifier redId,
                          mlir::Type type, bool isByRef) {
  mlir::Type ty = fir::unwrapRefType(type);
  // For (boxed) arrays the init value is computed for the element type.
  mlir::Value initValue = ReductionProcessor::getReductionInitValue(
      loc, fir::unwrapSeqOrBoxedSeqType(ty), redId, builder);

  if (fir::isa_trivial(ty)) {
    if (isByRef) {
      mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
      builder.createStoreWithConvert(loc, initValue, alloca);
      return alloca;
    }
    // by val
    return initValue;
  }

  // all arrays are boxed
  if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
    assert(isByRef && "passing arrays by value is unsupported");
    // TODO: support allocatable arrays: !fir.box<!fir.heap<!fir.array<...>>>
    mlir::Type innerTy = fir::extractSequenceType(boxTy);
    if (!mlir::isa<fir::SequenceType>(innerTy))
      TODO(loc, "Unsupported boxed type for reduction");
    // Create the private copy from the initial fir.box:
    hlfir::Entity source = hlfir::Entity{builder.getBlock()->getArgument(0)};

    // TODO: if the whole reduction is nested inside of a loop, this alloca
    // could lead to a stack overflow (the memory is only freed at the end of
    // the stack frame). The reduction declare operation needs a deallocation
    // region to undo the init region.
    hlfir::Entity temp = createStackTempFromMold(loc, builder, source);

    // Put the temporary inside of a box:
    hlfir::Entity box = hlfir::genVariableBox(loc, builder, temp);
    builder.create<hlfir::AssignOp>(loc, initValue, box);
    mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
    builder.create<fir::StoreOp>(loc, box, boxAlloca);
    return boxAlloca;
  }

  TODO(loc, "createReductionInitRegion for unsupported type");
}
| |
/// Return the omp.declare_reduction named \p reductionOpName, creating it in
/// the module body if no declaration with that name exists yet. An existing
/// declaration is returned as is, without checking that it matches \p type,
/// \p redId or \p isByRef, so the name must distinguish all of those.
///
/// The declaration's type is \p type when \p isByRef, and \p type with any
/// reference unwrapped otherwise. Its initializer region takes one argument of
/// that type and yields the value built by createReductionInitRegion. Its
/// reduction (combiner) region takes two arguments of that type and is filled
/// by genCombiner.
mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
    fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
    bool isByRef) {
  // Restore the caller's insertion point on return.
  mlir::OpBuilder::InsertionGuard guard(builder);
  mlir::ModuleOp module = builder.getModule();

  assert(!reductionOpName.empty());

  // Reuse a declaration created earlier in this module.
  auto decl =
      module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
  if (decl)
    return decl;

  mlir::OpBuilder modBuilder(module.getBodyRegion());
  mlir::Type valTy = fir::unwrapRefType(type);
  if (!isByRef)
    type = valTy;

  decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
                                                          type);
  builder.createBlock(&decl.getInitializerRegion(),
                      decl.getInitializerRegion().end(), {type}, {loc});
  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());

  mlir::Value init =
      createReductionInitRegion(builder, loc, redId, type, isByRef);
  builder.create<mlir::omp::YieldOp>(loc, init);

  builder.createBlock(&decl.getReductionRegion(),
                      decl.getReductionRegion().end(), {type, type},
                      {loc, loc});

  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
  genCombiner(builder, loc, redId, type, op1, op2, isByRef);

  return decl;
}
| |
| // TODO: By-ref vs by-val reductions are currently toggled for the whole |
| // operation (possibly effecting multiple reduction variables). |
| // This could cause a problem with openmp target reductions because |
| // by-ref trivial types may not be supported. |
| bool ReductionProcessor::doReductionByRef( |
| const llvm::SmallVectorImpl<mlir::Value> &reductionVars) { |
| if (reductionVars.empty()) |
| return false; |
| if (forceByrefReduction) |
| return true; |
| |
| for (mlir::Value reductionVar : reductionVars) { |
| if (auto declare = |
| mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) |
| reductionVar = declare.getMemref(); |
| |
| if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) |
| return true; |
| } |
| return false; |
| } |
| |
| void ReductionProcessor::addDeclareReduction( |
| mlir::Location currentLocation, |
| Fortran::lower::AbstractConverter &converter, |
| const omp::clause::Reduction &reduction, |
| llvm::SmallVectorImpl<mlir::Value> &reductionVars, |
| llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, |
| llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
| *reductionSymbols) { |
| fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
| mlir::omp::DeclareReductionOp decl; |
| const auto &redOperatorList{ |
| std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)}; |
| assert(redOperatorList.size() == 1 && "Expecting single operator"); |
| const auto &redOperator = redOperatorList.front(); |
| const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; |
| |
| if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) { |
| if (const auto *reductionIntrinsic = |
| std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { |
| if (!ReductionProcessor::supportedIntrinsicProcReduction( |
| *reductionIntrinsic)) { |
| return; |
| } |
| } else { |
| return; |
| } |
| } |
| |
| // initial pass to collect all reduction vars so we can figure out if this |
| // should happen byref |
| fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
| for (const Object &object : objectList) { |
| const Fortran::semantics::Symbol *symbol = object.id(); |
| if (reductionSymbols) |
| reductionSymbols->push_back(symbol); |
| mlir::Value symVal = converter.getSymbolAddress(*symbol); |
| mlir::Type eleType; |
| auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType()); |
| if (refType) |
| eleType = refType.getEleTy(); |
| else |
| eleType = symVal.getType(); |
| |
| // all arrays must be boxed so that we have convenient access to all the |
| // information needed to iterate over the array |
| if (mlir::isa<fir::SequenceType>(eleType)) { |
| // For Host associated symbols, use `SymbolBox` instead |
| Fortran::lower::SymbolBox symBox = |
| converter.lookupOneLevelUpSymbol(*symbol); |
| hlfir::Entity entity{symBox.getAddr()}; |
| entity = genVariableBox(currentLocation, builder, entity); |
| mlir::Value box = entity.getBase(); |
| |
| // Always pass the box by reference so that the OpenMP dialect |
| // verifiers don't need to know anything about fir.box |
| auto alloca = |
| builder.create<fir::AllocaOp>(currentLocation, box.getType()); |
| builder.create<fir::StoreOp>(currentLocation, box, alloca); |
| |
| symVal = alloca; |
| } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) { |
| // boxed arrays are passed as values not by reference. Unfortunately, |
| // we can't pass a box by value to omp.redution_declare, so turn it |
| // into a reference |
| |
| auto alloca = |
| builder.create<fir::AllocaOp>(currentLocation, symVal.getType()); |
| builder.create<fir::StoreOp>(currentLocation, symVal, alloca); |
| symVal = alloca; |
| } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) { |
| symVal = declOp.getBase(); |
| } |
| |
| // this isn't the same as the by-val and by-ref passing later in the |
| // pipeline. Both styles assume that the variable is a reference at |
| // this point |
| assert(mlir::isa<fir::ReferenceType>(symVal.getType()) && |
| "reduction input var is a reference"); |
| |
| reductionVars.push_back(symVal); |
| } |
| const bool isByRef = doReductionByRef(reductionVars); |
| |
| if (const auto &redDefinedOp = |
| std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { |
| const auto &intrinsicOp{ |
| std::get<omp::clause::DefinedOperator::IntrinsicOperator>( |
| redDefinedOp->u)}; |
| ReductionIdentifier redId = getReductionType(intrinsicOp); |
| switch (redId) { |
| case ReductionIdentifier::ADD: |
| case ReductionIdentifier::MULTIPLY: |
| case ReductionIdentifier::AND: |
| case ReductionIdentifier::EQV: |
| case ReductionIdentifier::OR: |
| case ReductionIdentifier::NEQV: |
| break; |
| default: |
| TODO(currentLocation, |
| "Reduction of some intrinsic operators is not supported"); |
| break; |
| } |
| |
| for (mlir::Value symVal : reductionVars) { |
| auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); |
| const auto &kindMap = firOpBuilder.getKindMap(); |
| if (redType.getEleTy().isa<fir::LogicalType>()) |
| decl = createDeclareReduction(firOpBuilder, |
| getReductionName(intrinsicOp, kindMap, |
| firOpBuilder.getI1Type(), |
| isByRef), |
| redId, redType, currentLocation, isByRef); |
| else |
| decl = createDeclareReduction( |
| firOpBuilder, |
| getReductionName(intrinsicOp, kindMap, redType, isByRef), redId, |
| redType, currentLocation, isByRef); |
| reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( |
| firOpBuilder.getContext(), decl.getSymName())); |
| } |
| } else if (const auto *reductionIntrinsic = |
| std::get_if<omp::clause::ProcedureDesignator>( |
| &redOperator.u)) { |
| if (ReductionProcessor::supportedIntrinsicProcReduction( |
| *reductionIntrinsic)) { |
| ReductionProcessor::ReductionIdentifier redId = |
| ReductionProcessor::getReductionType(*reductionIntrinsic); |
| for (const Object &object : objectList) { |
| const Fortran::semantics::Symbol *symbol = object.id(); |
| mlir::Value symVal = converter.getSymbolAddress(*symbol); |
| if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) |
| symVal = declOp.getBase(); |
| auto redType = symVal.getType().cast<fir::ReferenceType>(); |
| if (!redType.getEleTy().isIntOrIndexOrFloat()) |
| TODO(currentLocation, "User Defined Reduction on non-trivial type"); |
| decl = createDeclareReduction( |
| firOpBuilder, |
| getReductionName(getRealName(*reductionIntrinsic).ToString(), |
| firOpBuilder.getKindMap(), redType, isByRef), |
| redId, redType, currentLocation, isByRef); |
| reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( |
| firOpBuilder.getContext(), decl.getSymName())); |
| } |
| } |
| } |
| } |
| |
| const Fortran::semantics::SourceName |
| ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) { |
| return symbol->GetUltimate().name(); |
| } |
| |
| const Fortran::semantics::SourceName |
| ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { |
| return getRealName(pd.v.id()); |
| } |
| |
| int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, |
| mlir::Location loc) { |
| switch (redId) { |
| case ReductionIdentifier::ADD: |
| case ReductionIdentifier::OR: |
| case ReductionIdentifier::NEQV: |
| return 0; |
| case ReductionIdentifier::MULTIPLY: |
| case ReductionIdentifier::AND: |
| case ReductionIdentifier::EQV: |
| return 1; |
| default: |
| TODO(loc, "Reduction of some intrinsic operators is not supported"); |
| } |
| } |
| |
| } // namespace omp |
| } // namespace lower |
| } // namespace Fortran |