| //===-- Lower/OpenMP/ClauseProcessor.h --------------------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
| // |
| //===----------------------------------------------------------------------===// |
| #ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H |
| #define FORTRAN_LOWER_CLAUASEPROCESSOR_H |
| |
| #include "Clauses.h" |
| #include "DirectivesCommon.h" |
| #include "ReductionProcessor.h" |
| #include "Utils.h" |
| #include "flang/Lower/AbstractConverter.h" |
| #include "flang/Lower/Bridge.h" |
| #include "flang/Optimizer/Builder/Todo.h" |
| #include "flang/Parser/dump-parse-tree.h" |
| #include "flang/Parser/parse-tree.h" |
| #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
| |
| namespace fir { |
| class FirOpBuilder; |
| } // namespace fir |
| |
| namespace Fortran { |
| namespace lower { |
| namespace omp { |
| |
| /// Class that handles the processing of OpenMP clauses. |
| /// |
| /// Its `process<ClauseName>()` methods perform MLIR code generation for their |
| /// corresponding clause if it is present in the clause list. Otherwise, they |
| /// will return `false` to signal that the clause was not found. |
| /// |
| /// The intended use is of this class is to move clause processing outside of |
| /// construct processing, since the same clauses can appear attached to |
| /// different constructs and constructs can be combined, so that code |
| /// duplication is minimized. |
| /// |
| /// Each construct-lowering function only calls the `process<ClauseName>()` |
| /// methods that relate to clauses that can impact the lowering of that |
| /// construct. |
| class ClauseProcessor { |
| using ClauseTy = Fortran::parser::OmpClause; |
| |
| public: |
| ClauseProcessor(Fortran::lower::AbstractConverter &converter, |
| Fortran::semantics::SemanticsContext &semaCtx, |
| const Fortran::parser::OmpClauseList &clauses) |
| : converter(converter), semaCtx(semaCtx), clauses2(clauses), |
| clauses(makeList(clauses, semaCtx)) {} |
| |
| // 'Unique' clauses: They can appear at most once in the clause list. |
| bool processCollapse( |
| mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval, |
| llvm::SmallVectorImpl<mlir::Value> &lowerBound, |
| llvm::SmallVectorImpl<mlir::Value> &upperBound, |
| llvm::SmallVectorImpl<mlir::Value> &step, |
| llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const; |
| bool processDefault() const; |
| bool processDevice(Fortran::lower::StatementContext &stmtCtx, |
| mlir::Value &result) const; |
| bool processDeviceType(mlir::omp::DeclareTargetDeviceType &result) const; |
| bool processFinal(Fortran::lower::StatementContext &stmtCtx, |
| mlir::Value &result) const; |
| bool processHint(mlir::IntegerAttr &result) const; |
| bool processMergeable(mlir::UnitAttr &result) const; |
| bool processNowait(mlir::UnitAttr &result) const; |
| bool processNumTeams(Fortran::lower::StatementContext &stmtCtx, |
| mlir::Value &result) const; |
| bool processNumThreads(Fortran::lower::StatementContext &stmtCtx, |
| mlir::Value &result) const; |
| bool processOrdered(mlir::IntegerAttr &result) const; |
| bool processPriority(Fortran::lower::StatementContext &stmtCtx, |
| mlir::Value &result) const; |
| bool processProcBind(mlir::omp::ClauseProcBindKindAttr &result) const; |
| bool processSafelen(mlir::IntegerAttr &result) const; |
| bool processSchedule(mlir::omp::ClauseScheduleKindAttr &valAttr, |
| mlir::omp::ScheduleModifierAttr &modifierAttr, |
| mlir::UnitAttr &simdModifierAttr) const; |
| bool processScheduleChunk(Fortran::lower::StatementContext &stmtCtx, |
| mlir::Value &result) const; |
| bool processSimdlen(mlir::IntegerAttr &result) const; |
| bool processThreadLimit(Fortran::lower::StatementContext &stmtCtx, |
| mlir::Value &result) const; |
| bool processUntied(mlir::UnitAttr &result) const; |
| |
| // 'Repeatable' clauses: They can appear multiple times in the clause list. |
| bool |
| processAllocate(llvm::SmallVectorImpl<mlir::Value> &allocatorOperands, |
| llvm::SmallVectorImpl<mlir::Value> &allocateOperands) const; |
| bool processCopyin() const; |
| bool processCopyPrivate( |
| mlir::Location currentLocation, |
| llvm::SmallVectorImpl<mlir::Value> ©PrivateVars, |
| llvm::SmallVectorImpl<mlir::Attribute> ©PrivateFuncs) const; |
| bool processDepend(llvm::SmallVectorImpl<mlir::Attribute> &dependTypeOperands, |
| llvm::SmallVectorImpl<mlir::Value> &dependOperands) const; |
| bool |
| processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
| bool processIf(omp::clause::If::DirectiveNameModifier directiveName, |
| mlir::Value &result) const; |
| bool |
| processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
| |
| // This method is used to process a map clause. |
| // The optional parameters - mapSymTypes, mapSymLocs & mapSymbols are used to |
| // store the original type, location and Fortran symbol for the map operands. |
| // They may be used later on to create the block_arguments for some of the |
| // target directives that require it. |
| bool processMap(mlir::Location currentLocation, |
| const llvm::omp::Directive &directive, |
| Fortran::lower::StatementContext &stmtCtx, |
| llvm::SmallVectorImpl<mlir::Value> &mapOperands, |
| llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr, |
| llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr, |
| llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
| *mapSymbols = nullptr) const; |
| bool |
| processReduction(mlir::Location currentLocation, |
| llvm::SmallVectorImpl<mlir::Value> &reductionVars, |
| llvm::SmallVectorImpl<mlir::Type> &reductionTypes, |
| llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, |
| llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
| *reductionSymbols = nullptr) const; |
| bool processSectionsReduction(mlir::Location currentLocation) const; |
| bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const; |
| bool |
| processUseDeviceAddr(llvm::SmallVectorImpl<mlir::Value> &operands, |
| llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, |
| llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, |
| llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
| &useDeviceSymbols) const; |
| bool |
| processUseDevicePtr(llvm::SmallVectorImpl<mlir::Value> &operands, |
| llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, |
| llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, |
| llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
| &useDeviceSymbols) const; |
| |
| template <typename T> |
| bool processMotionClauses(Fortran::lower::StatementContext &stmtCtx, |
| llvm::SmallVectorImpl<mlir::Value> &mapOperands); |
| |
| // Call this method for these clauses that should be supported but are not |
| // implemented yet. It triggers a compilation error if any of the given |
| // clauses is found. |
| template <typename... Ts> |
| void processTODO(mlir::Location currentLocation, |
| llvm::omp::Directive directive) const; |
| |
| private: |
| using ClauseIterator = List<Clause>::const_iterator; |
| using ClauseIterator2 = std::list<ClauseTy>::const_iterator; |
| |
| /// Utility to find a clause within a range in the clause list. |
| template <typename T> |
| static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end); |
| template <typename T> |
| static ClauseIterator2 findClause2(ClauseIterator2 begin, |
| ClauseIterator2 end); |
| |
| /// Return the first instance of the given clause found in the clause list or |
| /// `nullptr` if not present. If more than one instance is expected, use |
| /// `findRepeatableClause` instead. |
| template <typename T> |
| const T * |
| findUniqueClause(const Fortran::parser::CharBlock **source = nullptr) const; |
| |
| /// Call `callbackFn` for each occurrence of the given clause. Return `true` |
| /// if at least one instance was found. |
| template <typename T> |
| bool findRepeatableClause( |
| std::function<void(const T &, const Fortran::parser::CharBlock &source)> |
| callbackFn) const; |
| template <typename T> |
| bool findRepeatableClause2( |
| std::function<void(const T *, const Fortran::parser::CharBlock &source)> |
| callbackFn) const; |
| |
| /// Set the `result` to a new `mlir::UnitAttr` if the clause is present. |
| template <typename T> |
| bool markClauseOccurrence(mlir::UnitAttr &result) const; |
| |
| Fortran::lower::AbstractConverter &converter; |
| Fortran::semantics::SemanticsContext &semaCtx; |
| const Fortran::parser::OmpClauseList &clauses2; |
| List<Clause> clauses; |
| }; |
| |
| template <typename T> |
| bool ClauseProcessor::processMotionClauses( |
| Fortran::lower::StatementContext &stmtCtx, |
| llvm::SmallVectorImpl<mlir::Value> &mapOperands) { |
| return findRepeatableClause2<T>( |
| [&](const T *motionClause, const Fortran::parser::CharBlock &source) { |
| mlir::Location clauseLocation = converter.genLocation(source); |
| fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
| |
| static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> || |
| std::is_same_v<T, ClauseProcessor::ClauseTy::From>); |
| |
| // TODO Support motion modifiers: present, mapper, iterator. |
| constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = |
| std::is_same_v<T, ClauseProcessor::ClauseTy::To> |
| ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
| : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; |
| |
| for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) { |
| llvm::SmallVector<mlir::Value> bounds; |
| std::stringstream asFortran; |
| Fortran::lower::AddrAndBoundsInfo info = |
| Fortran::lower::gatherDataOperandAddrAndBounds< |
| Fortran::parser::OmpObject, mlir::omp::DataBoundsOp, |
| mlir::omp::DataBoundsType>( |
| converter, firOpBuilder, semaCtx, stmtCtx, ompObject, |
| clauseLocation, asFortran, bounds, treatIndexAsSection); |
| |
| auto origSymbol = |
| converter.getSymbolAddress(*getOmpObjectSymbol(ompObject)); |
| mlir::Value symAddr = info.addr; |
| if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) |
| symAddr = origSymbol; |
| |
| // Explicit map captures are captured ByRef by default, |
| // optimisation passes may alter this to ByCopy or other capture |
| // types to optimise |
| mlir::Value mapOp = createMapInfoOp( |
| firOpBuilder, clauseLocation, symAddr, mlir::Value{}, |
| asFortran.str(), bounds, {}, |
| static_cast< |
| std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>( |
| mapTypeBits), |
| mlir::omp::VariableCaptureKind::ByRef, symAddr.getType()); |
| |
| mapOperands.push_back(mapOp); |
| } |
| }); |
| } |
| |
| template <typename... Ts> |
| void ClauseProcessor::processTODO(mlir::Location currentLocation, |
| llvm::omp::Directive directive) const { |
| auto checkUnhandledClause = [&](const auto *x) { |
| if (!x) |
| return; |
| TODO(currentLocation, |
| "Unhandled clause " + |
| llvm::StringRef(Fortran::parser::ParseTreeDumper::GetNodeName(*x)) |
| .upper() + |
| " in " + llvm::omp::getOpenMPDirectiveName(directive).upper() + |
| " construct"); |
| }; |
| |
| for (ClauseIterator2 it = clauses2.v.begin(); it != clauses2.v.end(); ++it) |
| (checkUnhandledClause(std::get_if<Ts>(&it->u)), ...); |
| } |
| |
| template <typename T> |
| ClauseProcessor::ClauseIterator |
| ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) { |
| for (ClauseIterator it = begin; it != end; ++it) { |
| if (std::get_if<T>(&it->u)) |
| return it; |
| } |
| |
| return end; |
| } |
| |
| template <typename T> |
| ClauseProcessor::ClauseIterator2 |
| ClauseProcessor::findClause2(ClauseIterator2 begin, ClauseIterator2 end) { |
| for (ClauseIterator2 it = begin; it != end; ++it) { |
| if (std::get_if<T>(&it->u)) |
| return it; |
| } |
| |
| return end; |
| } |
| |
| template <typename T> |
| const T *ClauseProcessor::findUniqueClause( |
| const Fortran::parser::CharBlock **source) const { |
| ClauseIterator it = findClause<T>(clauses.begin(), clauses.end()); |
| if (it != clauses.end()) { |
| if (source) |
| *source = &it->source; |
| return &std::get<T>(it->u); |
| } |
| return nullptr; |
| } |
| |
| template <typename T> |
| bool ClauseProcessor::findRepeatableClause( |
| std::function<void(const T &, const Fortran::parser::CharBlock &source)> |
| callbackFn) const { |
| bool found = false; |
| ClauseIterator nextIt, endIt = clauses.end(); |
| for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) { |
| nextIt = findClause<T>(it, endIt); |
| |
| if (nextIt != endIt) { |
| callbackFn(std::get<T>(nextIt->u), nextIt->source); |
| found = true; |
| ++nextIt; |
| } |
| } |
| return found; |
| } |
| |
| template <typename T> |
| bool ClauseProcessor::findRepeatableClause2( |
| std::function<void(const T *, const Fortran::parser::CharBlock &source)> |
| callbackFn) const { |
| bool found = false; |
| ClauseIterator2 nextIt, endIt = clauses2.v.end(); |
| for (ClauseIterator2 it = clauses2.v.begin(); it != endIt; it = nextIt) { |
| nextIt = findClause2<T>(it, endIt); |
| |
| if (nextIt != endIt) { |
| callbackFn(&std::get<T>(nextIt->u), nextIt->source); |
| found = true; |
| ++nextIt; |
| } |
| } |
| return found; |
| } |
| |
| template <typename T> |
| bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const { |
| if (findUniqueClause<T>()) { |
| result = converter.getFirOpBuilder().getUnitAttr(); |
| return true; |
| } |
| return false; |
| } |
| |
| } // namespace omp |
| } // namespace lower |
| } // namespace Fortran |
| |
| #endif // FORTRAN_LOWER_CLAUASEPROCESSOR_H |