blob: 665290552d73b3d43778df847ee15cb7ab285cfe [file] [log] [blame] [edit]
//===-- lib/Semantics/openmp-utils.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Common utilities used in OpenMP semantic checks.
//
//===----------------------------------------------------------------------===//
#include "flang/Semantics/openmp-utils.h"
#include "flang/Common/Fortran-consts.h"
#include "flang/Common/idioms.h"
#include "flang/Common/indirection.h"
#include "flang/Common/reference.h"
#include "flang/Common/visit.h"
#include "flang/Evaluate/check-expression.h"
#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/match.h"
#include "flang/Evaluate/tools.h"
#include "flang/Evaluate/traverse.h"
#include "flang/Evaluate/type.h"
#include "flang/Evaluate/variable.h"
#include "flang/Parser/openmp-utils.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/expression.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/semantics.h"
#include "flang/Semantics/symbol.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include <cinttypes>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
namespace Fortran::semantics::omp {
using namespace Fortran::parser::omp;
// Starting at `scope`, walk towards the root and return the first scope that
// is a block construct, block data, derived type, main program, module, or
// subprogram. If a top-level scope is reached first, that scope is returned.
const Scope &GetScopingUnit(const Scope &scope) {
  auto isScopingUnit{[](Scope::Kind kind) {
    return kind == Scope::Kind::BlockConstruct ||
        kind == Scope::Kind::BlockData || kind == Scope::Kind::DerivedType ||
        kind == Scope::Kind::MainProgram || kind == Scope::Kind::Module ||
        kind == Scope::Kind::Subprogram;
  }};
  const Scope *current{&scope};
  while (!current->IsTopLevel() && !isScopingUnit(current->kind())) {
    current = &current->parent();
  }
  return *current;
}
// Return the program unit that contains `scope`: the nearest enclosing
// block data, main program, or module; otherwise the outermost enclosing
// subprogram (nested subprograms are skipped over). Asserts that such a
// unit exists.
const Scope &GetProgramUnit(const Scope &scope) {
  const Scope *outermostSubprogram{nullptr};
  const Scope *current{&scope};
  while (!current->IsTopLevel()) {
    const Scope::Kind kind{current->kind()};
    if (kind == Scope::Kind::BlockData || kind == Scope::Kind::MainProgram ||
        kind == Scope::Kind::Module) {
      return *current;
    }
    if (kind == Scope::Kind::Subprogram) {
      // Keep overwriting so that the last one seen (the outermost) wins.
      outermostSubprogram = current;
    }
    current = &current->parent();
  }
  assert(outermostSubprogram && "Scope not in a program unit");
  return *outermostSubprogram;
}
// If `x` is an executable construct holding a Statement<ActionStmt>, return
// that action statement together with its source. In every other case
// (including a null `x`) return an empty SourcedActionStmt.
SourcedActionStmt GetActionStmt(const parser::ExecutionPartConstruct *x) {
  using ActionStatement = parser::Statement<parser::ActionStmt>;
  if (!x) {
    return SourcedActionStmt{};
  }
  const auto *construct{std::get_if<parser::ExecutableConstruct>(&x->u)};
  if (!construct) {
    return SourcedActionStmt{};
  }
  const auto *action{std::get_if<ActionStatement>(&construct->u)};
  if (!action) {
    return SourcedActionStmt{};
  }
  return SourcedActionStmt{&action->statement, action->source};
}
// A block yields an action statement only when it consists of exactly one
// construct; otherwise the result is empty.
SourcedActionStmt GetActionStmt(const parser::Block &block) {
  if (block.size() != 1) {
    return SourcedActionStmt{};
  }
  return GetActionStmt(&block.front());
}
// Format a numeric version as "OpenMP v<major>.<minor>", where the major part
// is version / 10 and the minor part is version % 10 (e.g. 52 -> "5.2").
std::string ThisVersion(unsigned version) {
  std::string text{"OpenMP v"};
  text += std::to_string(version / 10);
  text += ".";
  text += std::to_string(version % 10);
  return text;
}
// Build the hint text suggesting the -fopenmp-version flag with `version`.
std::string TryVersion(unsigned version) {
  std::string hint{"try -fopenmp-version="};
  hint += std::to_string(version);
  return hint;
}
// Return the Designator held by `object`, or nullptr if `object` holds a
// different alternative.
const parser::Designator *GetDesignatorFromObj(
    const parser::OmpObject &object) {
  const auto *designator{std::get_if<parser::Designator>(&object.u)};
  return designator;
}
// Return the DataRef inside the designator of `object`, or nullptr when the
// object is not a designator or the designator does not hold a DataRef.
const parser::DataRef *GetDataRefFromObj(const parser::OmpObject &object) {
  const parser::Designator *designator{GetDesignatorFromObj(object)};
  if (!designator) {
    return nullptr;
  }
  return std::get_if<parser::DataRef>(&designator->u);
}
// Return the ArrayElement referenced by `object` if its data-ref is an array
// element; nullptr otherwise.
const parser::ArrayElement *GetArrayElementFromObj(
    const parser::OmpObject &object) {
  const parser::DataRef *ref{GetDataRefFromObj(object)};
  if (!ref) {
    return nullptr;
  }
  const auto *element{
      std::get_if<common::Indirection<parser::ArrayElement>>(&ref->u)};
  if (!element) {
    return nullptr;
  }
  return &element->value();
}
// Return the ultimate symbol associated with `object`:
//  - for a Name, the symbol attached to the name;
//  - for a Designator, the symbol attached to its last name.
// Returns nullptr when no symbol is attached, or when `object` holds neither
// alternative.
const Symbol *GetObjectSymbol(const parser::OmpObject &object) {
  // Some symbols may be missing if the resolution failed, e.g. when an
  // undeclared name is used with implicit none.
  if (auto *name{std::get_if<parser::Name>(&object.u)}) {
    return name->symbol ? &name->symbol->GetUltimate() : nullptr;
  } else if (auto *desg{std::get_if<parser::Designator>(&object.u)}) {
    // Look the last name up once and reuse it for both the null check and
    // the dereference.
    auto &last{GetLastName(*desg)};
    return last.symbol ? &last.symbol->GetUltimate() : nullptr;
  }
  return nullptr;
}
// Return the source of the name that identifies `object`: the last name of a
// designator, or the name itself. Yields std::nullopt for any other
// alternative.
std::optional<parser::CharBlock> GetObjectSource(
    const parser::OmpObject &object) {
  if (const auto *designator{std::get_if<parser::Designator>(&object.u)}) {
    return GetLastName(*designator).source;
  }
  if (const auto *name{std::get_if<parser::Name>(&object.u)}) {
    return name->source;
  }
  return std::nullopt;
}
// If `argument` is a locator that holds an OmpObject, return that object's
// symbol (see GetObjectSymbol); otherwise return nullptr.
const Symbol *GetArgumentSymbol(const parser::OmpArgument &argument) {
  const auto *locator{std::get_if<parser::OmpLocator>(&argument.u)};
  if (!locator) {
    return nullptr;
  }
  const auto *object{std::get_if<parser::OmpObject>(&locator->u)};
  if (!object) {
    return nullptr;
  }
  return GetObjectSymbol(*object);
}
// If `argument` is a locator that holds an OmpObject, return a pointer to
// that object; otherwise return nullptr.
const parser::OmpObject *GetArgumentObject(
    const parser::OmpArgument &argument) {
  const auto *locator{std::get_if<parser::OmpLocator>(&argument.u)};
  if (!locator) {
    return nullptr;
  }
  return std::get_if<parser::OmpObject>(&locator->u);
}
// True when the symbol's details are CommonBlockDetails.
bool IsCommonBlock(const Symbol &sym) {
  const auto *details{sym.detailsIf<CommonBlockDetails>()};
  return details != nullptr;
}
// A symbol qualifies as a variable list item if it is a variable or has the
// POINTER attribute.
bool IsVariableListItem(const Symbol &sym) {
  if (evaluate::IsVariable(sym)) {
    return true;
  }
  return sym.attrs().test(Attr::POINTER);
}
// An extended list item is either a variable list item or a subprogram.
bool IsExtendedListItem(const Symbol &sym) {
  if (IsVariableListItem(sym)) {
    return true;
  }
  return sym.IsSubprogram();
}
bool IsTypeParamInquiry(const Symbol &sym) {
return common::visit( //
common::visitors{
[&](const MiscDetails &d) {
return d.kind() == MiscDetails::Kind::KindParamInquiry ||
d.kind() == MiscDetails::Kind::LenParamInquiry;
},
[&](const TypeParamDetails &s) { return true; },
[&](auto &&) { return false; },
},
sym.details());
}
bool IsStructureComponent(const Symbol &sym) {
return sym.owner().kind() == Scope::Kind::DerivedType;
}
bool IsPrivatizable(const Symbol &sym) {
auto *misc{sym.detailsIf<MiscDetails>()};
return IsVariableName(sym) && !IsProcedure(sym) && !IsStmtFunction(sym) &&
!IsNamedConstant(sym) &&
( // OpenMP 5.2, 5.1.1: Assumed-size arrays are shared
!semantics::IsAssumedSizeArray(sym) ||
// If CrayPointer is among the DSA list then the
// CrayPointee is Privatizable
sym.test(Symbol::Flag::CrayPointee)) &&
!sym.owner().IsDerivedType() &&
sym.owner().kind() != Scope::Kind::ImpliedDos &&
sym.owner().kind() != Scope::Kind::Forall &&
!sym.detailsIf<semantics::AssocEntityDetails>() &&
!sym.detailsIf<semantics::NamelistDetails>() &&
(!misc ||
(misc->kind() != MiscDetails::Kind::ComplexPartRe &&
misc->kind() != MiscDetails::Kind::ComplexPartIm &&
misc->kind() != MiscDetails::Kind::KindParamInquiry &&
misc->kind() != MiscDetails::Kind::LenParamInquiry &&
misc->kind() != MiscDetails::Kind::ConstructName));
}
// True when `expr` is present and is either a procedure reference or a
// variable.
bool IsVarOrFunctionRef(const MaybeExpr &expr) {
  if (!expr) {
    return false;
  }
  if (evaluate::UnwrapProcedureRef(*expr) != nullptr) {
    return true;
  }
  return evaluate::IsVariable(*expr);
}
// True when `object` names an assumed-size array and is not an array element
// reference.
bool IsWholeAssumedSizeArray(const parser::OmpObject &object) {
  const Symbol *sym{GetObjectSymbol(object)};
  if (!sym || !IsAssumedSizeArray(*sym)) {
    return false;
  }
  return GetArrayElementFromObj(object) == nullptr;
}
// Map types accepted here: ALLOC, STORAGE, TO, TOFROM.
bool IsMapEnteringType(parser::OmpMapType::Value type) {
  using Value = parser::OmpMapType::Value;
  return type == Value::Alloc || type == Value::Storage ||
      type == Value::To || type == Value::Tofrom;
}
// Map types accepted here: DELETE, FROM, RELEASE, STORAGE, TOFROM.
bool IsMapExitingType(parser::OmpMapType::Value type) {
  using Value = parser::OmpMapType::Value;
  return type == Value::Delete || type == Value::From ||
      type == Value::Release || type == Value::Storage ||
      type == Value::Tofrom;
}
// Extract the evaluate::Expr stored in a parser::TypedExpr, if any.
// Layout being unwrapped:
//   ForwardOwningPointer            typedExpr
//   `- GenericExprWrapper           typedExpr.get()
//      `- std::optional<Expr>       wrapper->v
static MaybeExpr GetEvaluateExprFromTyped(const parser::TypedExpr &typedExpr) {
  const auto *wrapper{typedExpr.get()};
  if (wrapper == nullptr) {
    return std::nullopt;
  }
  return wrapper->v;
}
// Return the evaluate::Expr attached to a parser::Expr, if one is present.
MaybeExpr GetEvaluateExpr(const parser::Expr &parserExpr) {
  const parser::TypedExpr &typed{parserExpr.typedExpr};
  return GetEvaluateExprFromTyped(typed);
}
// Return the type reported by the evaluate::Expr attached to `parserExpr`,
// or std::nullopt when no such expression is attached.
std::optional<evaluate::DynamicType> GetDynamicType(
    const parser::Expr &parserExpr) {
  MaybeExpr expr{GetEvaluateExpr(parserExpr)};
  if (!expr) {
    return std::nullopt;
  }
  return expr->GetType();
}
namespace {
// Expression traversal that yields the value of a scalar logical constant
// found in the expression. Results from sub-expressions are combined by
// taking the first one that has a value.
struct LogicalConstantVistor : public evaluate::Traverse<LogicalConstantVistor,
                                   std::optional<bool>, false> {
  using Result = std::optional<bool>;
  using Base = evaluate::Traverse<LogicalConstantVistor, Result, false>;
  using Base::operator();

  LogicalConstantVistor() : Base(*this) {}

  // Result for nodes that contribute nothing.
  Result Default() const { return std::nullopt; }

  // Constants: only logical ones produce a value, and only when a scalar
  // value is available.
  template <typename T> //
  Result operator()(const evaluate::Constant<T> &x) const {
    if constexpr (T::category != common::TypeCategory::Logical) {
      return std::nullopt;
    } else {
      return llvm::transformOptional(
          x.GetScalarValue(), [](auto &&scalar) { return scalar.IsTrue(); });
    }
  }

  // Return the first result that holds a value; if none does, the last
  // result is returned as is.
  template <typename... Rs> //
  Result Combine(Result &&result, Rs &&...results) const {
    if constexpr (sizeof...(results) != 0) {
      if (!result.has_value()) {
        return Combine(std::move(results)...);
      }
    }
    return result;
  }
};
} // namespace
std::optional<bool> GetLogicalValue(const SomeExpr &expr) {
return LogicalConstantVistor{}(expr);
}
namespace {
// Helper that peels wrappers (Indirection, Reference, Expr) off an
// expression until it reaches a Designator, then asks
// evaluate::IsContiguous about the designator's contents. Anything that is
// not one of these forms produces std::nullopt.
struct ContiguousHelper {
  ContiguousHelper(SemanticsContext &context)
      : fctx_(context.foldingContext()) {}

  // Indirection: look at the pointed-to value.
  template <typename Contained>
  std::optional<bool> Visit(const common::Indirection<Contained> &wrapped) {
    return Visit(wrapped.value());
  }

  // Reference: look at the referenced object.
  template <typename Contained>
  std::optional<bool> Visit(const common::Reference<Contained> &wrapped) {
    return Visit(wrapped.get());
  }

  // Expr: dispatch on the alternative it currently holds.
  template <typename T>
  std::optional<bool> Visit(const evaluate::Expr<T> &expr) {
    return common::visit(
        [this](const auto &alternative) { return Visit(alternative); },
        expr.u);
  }

  // Designator: delegate to the evaluate-level contiguity query.
  template <typename T>
  std::optional<bool> Visit(const evaluate::Designator<T> &designator) {
    return common::visit(
        [this](auto &&alternative) {
          return evaluate::IsContiguous(alternative, fctx_);
        },
        designator.u);
  }

  // Everything else: contiguity is unknown.
  template <typename T> std::optional<bool> Visit(const T &) {
    return std::nullopt;
  }

private:
  evaluate::FoldingContext &fctx_;
};
} // namespace
// Determine whether `object` is contiguous. The result is
//  - true         if the object is known to be contiguous,
//  - false        if the object is known not to be contiguous,
//  - std::nullopt if contiguity cannot be determined.
std::optional<bool> IsContiguous(
    SemanticsContext &semaCtx, const parser::OmpObject &object) {
  if (std::holds_alternative<parser::Name>(object.u)) {
    // Any member of a common block must be contiguous.
    return true;
  }
  if (const auto *designator{std::get_if<parser::Designator>(&object.u)}) {
    evaluate::ExpressionAnalyzer analyzer{semaCtx};
    if (MaybeExpr analyzed{analyzer.Analyze(*designator)}) {
      return ContiguousHelper{semaCtx}.Visit(*analyzed);
    }
  }
  // Either the designator could not be analyzed, or the object is invalid.
  return std::nullopt;
}
// Expression traversal that gathers the top-level designators appearing in
// an expression, each converted to a generic expression.
struct DesignatorCollector : public evaluate::Traverse<DesignatorCollector,
                                 std::vector<SomeExpr>, false> {
  using Result = std::vector<SomeExpr>;
  using Base = evaluate::Traverse<DesignatorCollector, Result, false>;
  using Base::operator();

  DesignatorCollector() : Base(*this) {}

  // Nodes that are not designators contribute nothing by themselves.
  Result Default() const { return {}; }

  // A designator is recorded as a whole; the traversal does not descend into
  // it, so designators nested inside it are not collected separately.
  template <typename T> //
  Result operator()(const evaluate::Designator<T> &x) const {
    return Result{AsGenericExpr(evaluate::Designator<T>(x))};
  }

  // Concatenate all partial results, preserving their order.
  template <typename... Rs> //
  Result Combine(Result &&result, Rs &&...results) const {
    Result merged(std::move(result));
    auto appendAll{[&merged](auto &&items) {
      for (auto &&item : items) {
        merged.push_back(std::move(item));
      }
    }};
    (appendAll(std::move(results)), ...);
    return merged;
  }
};
std::vector<SomeExpr> GetAllDesignators(const SomeExpr &expr) {
return DesignatorCollector{}(expr);
}
// First, cheap stage of the storage-overlap check.
//
// `baseSyms` is the symbol vector of some expression; e.g. for the array
// element x%y(w%z) it is {x, y, w, z}. This function reports whether that
// exact sequence occurs as a contiguous run inside the symbol vector of
// `other`. The test is necessary but not sufficient: x(y) and x(y+1) have
// identical symbol vectors, so a positive answer only means that a more
// precise comparison is worth doing.
//
// If any symbol in `baseSyms` is a subprogram, the answer is "no overlap":
// in general it is impossible to tell what storage a function accesses.
// (Two calls to a pure function with the same arguments would access the
// same storage; that case is not handled yet.)
static bool HasCommonDesignatorSymbols(
    const evaluate::SymbolVector &baseSyms, const SomeExpr &other) {
  for (const SymbolRef &sym : baseSyms) {
    if (sym->IsSubprogram()) {
      return false;
    }
  }

  evaluate::SymbolVector otherSyms{evaluate::GetSymbolVector(other)};
  const size_t baseLen{baseSyms.size()};
  const size_t otherLen{otherSyms.size()};
  if (baseLen > otherLen) {
    return false;
  }

  // Slide a window of baseLen symbols over otherSyms and compare.
  const size_t lastStart{otherLen - baseLen};
  for (size_t start{0}; start <= lastStart; ++start) {
    bool mismatch{false};
    for (size_t i{0}; i != baseLen && !mismatch; ++i) {
      mismatch = baseSyms[i] != otherSyms[start + i];
    }
    if (!mismatch) {
      return true;
    }
  }
  return false;
}
static bool HasCommonTopLevelDesignators(
const std::vector<SomeExpr> &baseDsgs, const SomeExpr &other) {
// Compare designators directly as expressions. This will ensure
// that x(y) and x(y+1) are not flagged as overlapping, whereas
// the symbol vectors for both of these would be identical.
std::vector<SomeExpr> otherDsgs{GetAllDesignators(other)};
for (auto &s : baseDsgs) {
if (llvm::any_of(otherDsgs, [&](auto &&t) { return s == t; })) {
return true;
}
}
return false;
}
// Return a pointer to the first expression in `exprs` that passes both
// overlap checks against `base` (matching symbol sequence and a shared
// top-level designator), or nullptr if there is none.
const SomeExpr *HasStorageOverlap(
    const SomeExpr &base, llvm::ArrayRef<SomeExpr> exprs) {
  evaluate::SymbolVector baseSyms{evaluate::GetSymbolVector(base)};
  std::vector<SomeExpr> baseDsgs{GetAllDesignators(base)};
  for (const SomeExpr &candidate : exprs) {
    if (HasCommonDesignatorSymbols(baseSyms, candidate) &&
        HasCommonTopLevelDesignators(baseDsgs, candidate)) {
      return &candidate;
    }
  }
  return nullptr;
}
// Report whether the ActionStmt is syntactically an assignment or a pointer
// assignment. This distinguishes two situations: the source contains
// something that parses as an assignment but is semantically wrong (which is
// diagnosed by the general semantic checks), versus the source contains some
// other kind of statement (which should be reported as "should be an
// assignment"). A null pointer is not an assignment.
bool IsAssignment(const parser::ActionStmt *x) {
  if (!x) {
    return false;
  }
  using AssignmentStmt = common::Indirection<parser::AssignmentStmt>;
  using PointerAssignmentStmt =
      common::Indirection<parser::PointerAssignmentStmt>;
  if (std::holds_alternative<AssignmentStmt>(x->u)) {
    return true;
  }
  return std::holds_alternative<PointerAssignmentStmt>(x->u);
}
// An evaluate::Assignment is treated as a pointer assignment when it carries
// either a bounds-spec or a bounds-remapping.
bool IsPointerAssignment(const evaluate::Assignment &x) {
  using Assignment = evaluate::Assignment;
  if (std::holds_alternative<Assignment::BoundsSpec>(x.u)) {
    return true;
  }
  return std::holds_alternative<Assignment::BoundsRemapping>(x.u);
}
// Produce an evaluate expression for a stylized instance:
//  - assignment statement: the expression attached to its parser::Expr
//  - call statement:       an expression built from the typed call
//  - plain expression:     the expression attached to it
MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp) {
  const auto &instance{
      std::get<parser::OmpStylizedInstance::Instance>(inp.t)};

  auto fromAssignment{[](const parser::AssignmentStmt &stmt) -> MaybeExpr {
    return GetEvaluateExpr(std::get<parser::Expr>(stmt.t));
  }};
  auto fromCall{[](const parser::CallStmt &call) -> MaybeExpr {
    assert(call.typedCall && "Expecting typedCall");
    const auto &procRef{*call.typedCall};
    return SomeExpr(procRef);
  }};
  auto fromExpr{
      [](const common::Indirection<parser::Expr> &expr) -> MaybeExpr {
        return GetEvaluateExpr(expr.value());
      }};

  return common::visit(
      common::visitors{fromAssignment, fromCall, fromExpr}, instance.u);
}
// Copy constructor: start with default-initialized members, then copy the
// messages from `other` one by one.
Reason::Reason(const Reason &other) {
  CopyFrom(other);
}
// Copy assignment: replace the current messages with copies of those held by
// `other`. Self-assignment leaves the object untouched.
Reason &Reason::operator=(const Reason &other) {
  if (this == &other) {
    return *this;
  }
  msgs.clear();
  CopyFrom(other);
  return *this;
}
void Reason::CopyFrom(const Reason &other) {
for (auto &msg : other.msgs.messages()) {
msgs.Say(parser::Message(msg));
}
}
// Attach the messages held by this Reason to `msg` (delegating to
// msgs.AttachTo), and return `msg` so that calls can be chained.
parser::Message &Reason::AttachTo(parser::Message &msg) {
msgs.AttachTo(msg);
return msg;
}
// Look up clause `clauseId` on `spec` and, if it carries an expression with
// a known integer value, return that value together with a message pointing
// at the clause. An empty result is returned when the clause is absent, has
// no expression, or the expression has no known integer value.
WithReason<int64_t> GetArgumentValueWithReason(
    const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
    unsigned version) {
  const auto *clause{parser::omp::FindClause(spec, clauseId)};
  if (!clause) {
    return {};
  }
  const auto *expr{parser::Unwrap<parser::Expr>(clause->u)};
  if (!expr) {
    return {};
  }
  auto value{GetIntValue(*expr)};
  if (!value) {
    return {};
  }
  std::string name{GetUpperName(clauseId, version)};
  Reason reason;
  reason.Say(clause->source,
      "%s clause was specified with argument %" PRId64 ""_because_en_US,
      name.c_str(), *value);
  return {*value, std::move(reason)};
}
// If `clause` holds a std::list<T>, return the number of elements in it
// together with a message pointing at the clause; otherwise return an empty
// result. `name` is the clause name used in the message.
template <typename T>
static WithReason<int64_t> GetNumArgumentsWithReasonForType(
    const parser::OmpClause &clause, const std::string &name) {
  const auto *args{parser::Unwrap<std::list<T>>(clause.u)};
  if (!args) {
    return {};
  }
  auto num{static_cast<int64_t>(args->size())};
  Reason reason;
  reason.Say(clause.source,
      "%s clause was specified with %" PRId64 " arguments"_because_en_US,
      name.c_str(), num);
  return {num, std::move(reason)};
}
// Return the number of list items given to clause `clauseId` on `spec`,
// together with a message pointing at the clause. The element types tried
// are ScalarIntExpr and then ScalarIntConstantExpr. An empty result is
// returned when the clause is absent or holds neither kind of list.
WithReason<int64_t> GetNumArgumentsWithReason(
    const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
    unsigned version) {
  const auto *clause{parser::omp::FindClause(spec, clauseId)};
  if (!clause) {
    return {};
  }
  std::string name{GetUpperName(clauseId, version)};
  if (auto n{GetNumArgumentsWithReasonForType<parser::ScalarIntExpr>(
          *clause, name)}) {
    return n;
  }
  if (auto n{GetNumArgumentsWithReasonForType<parser::ScalarIntConstantExpr>(
          *clause, name)}) {
    return n;
  }
  return {};
}
// Directives recognized here as loop-transforming: FUSE, INTERCHANGE,
// NOTHING, REVERSE, STRIPE, TILE, UNROLL.
// TODO: OMPD_flatten, OMPD_split.
bool IsLoopTransforming(llvm::omp::Directive dir) {
  using Directive = llvm::omp::Directive;
  return dir == Directive::OMPD_fuse || dir == Directive::OMPD_interchange ||
      dir == Directive::OMPD_nothing || dir == Directive::OMPD_reverse ||
      dir == Directive::OMPD_stripe || dir == Directive::OMPD_tile ||
      dir == Directive::OMPD_unroll;
}
// True when the construct is an UNROLL whose begin directive carries no
// PARTIAL clause.
bool IsFullUnroll(const parser::OpenMPLoopConstruct &x) {
  const parser::OmpDirectiveSpecification &beginSpec{x.BeginDir()};
  if (beginSpec.DirName().v != llvm::omp::Directive::OMPD_unroll) {
    return false;
  }
  const auto *partial{
      parser::omp::FindClause(beginSpec, llvm::omp::Clause::OMPC_partial)};
  return partial == nullptr;
}
namespace {
// Recognizes whether a given evaluate::Expr is itself an array expression.
// Only the expression at hand is inspected: apart from looking through
// parentheses, its proper subexpressions are not examined.
struct ArrayExpressionRecognizer {
  // Fallback: anything not handled by a more specific overload is not an
  // array expression.
  template <typename T> static bool isArrayExpression(const T &) {
    return false;
  }

  // Designators: a whole symbol with non-zero rank, or an array reference
  // in which some subscript is a triplet or has rank > 0.
  template <typename T>
  static bool isArrayExpression(const evaluate::Designator<T> &x) {
    if (const auto *sym{std::get_if<SymbolRef>(&x.u)}) {
      return (*sym)->Rank() != 0;
    }
    if (const auto *array{std::get_if<evaluate::ArrayRef>(&x.u)}) {
      for (const evaluate::Subscript &sub : array->subscript()) {
        // A vector subscript will not be a Triplet, but will have rank > 0.
        if (std::holds_alternative<evaluate::Triplet>(sub.u) ||
            sub.Rank() > 0) {
          return true;
        }
      }
    }
    return false;
  }

  // Expression of a specific type and kind: strip parentheses, then
  // dispatch on the operation.
  template <TypeCategory C, int K>
  static bool isArrayExpression(const evaluate::Expr<evaluate::Type<C, K>> &x) {
    return common::visit(
        [](auto &&alternative) { return isArrayExpression(alternative); },
        evaluate::match::deparen(x).u);
  }

  // Expression of a given category (any kind): dispatch on the kind.
  template <TypeCategory C>
  static bool isArrayExpression(
      const evaluate::Expr<evaluate::SomeKind<C>> &x) {
    return common::visit(
        [](auto &&alternative) { return isArrayExpression(alternative); },
        x.u);
  }

  // Fully generic expression: dispatch on the category.
  static bool isArrayExpression(const evaluate::Expr<evaluate::SomeType> &x) {
    return common::visit(
        [](auto &&alternative) { return isArrayExpression(alternative); },
        x.u);
  }
};
/// Traversal that reports whether an evaluate::Expr contains a subexpression
/// (possibly the expression itself) that is an array expression. Designators
/// are classified by ArrayExpressionRecognizer; everything else is handled by
/// the AnyTraverse base.
struct ArrayExpressionFinder
    : public evaluate::AnyTraverse<ArrayExpressionFinder> {
  using Base = evaluate::AnyTraverse<ArrayExpressionFinder>;

  ArrayExpressionFinder() : Base(*this) {}

  using Base::operator();

  template <typename T>
  bool operator()(const evaluate::Designator<T> &designator) const {
    return ArrayExpressionRecognizer::isArrayExpression(designator);
  }
};
/// Parse-tree visitor that records, in `rejected`, whether any parser::Expr
/// it visits has an attached evaluate::Expr containing an array expression
/// (as determined by ArrayExpressionFinder). Used when checking the criteria
/// for "intervening code".
struct ArrayExpressionChecker {
  bool rejected{false};

  // Nodes other than parser::Expr are simply walked through.
  template <typename T> bool Pre(const T &) { return true; }
  template <typename T> void Post(const T &) {}

  bool Pre(const parser::Expr &parserExpr) {
    if (rejected) {
      // A prohibited expression was already found: skip the rest of the
      // traversal.
      return false;
    }
    if (auto expr{GetEvaluateExpr(parserExpr)}) {
      rejected = ArrayExpressionFinder{}(*expr);
    }
    return !rejected;
  }
};
} // namespace
// Walk `x` with an ArrayExpressionChecker and report whether it rejected
// any expression.
static bool ContainsInvalidArrayExpression(
    const parser::ExecutionPartConstruct &x) {
  ArrayExpressionChecker visitor;
  parser::Walk(x, visitor);
  const bool found{visitor.rejected};
  return found;
}
/// Checks whether the construct `x` satisfies the OpenMP requirements for
/// intervening-code. CYCLE/EXIT statements are rejected, as are constructs
/// likely to result in a runtime loop (FORALL, WHERE, DO, non-scalar
/// assignment, array expressions) and executable OpenMP constructs.
bool IsValidInterveningCode(const parser::ExecutionPartConstruct &x) {
  const auto *exec{parser::Unwrap<parser::ExecutableConstruct>(x)};
  if (!exec) {
    // DATA, ENTRY, FORMAT, NAMELIST are not explicitly prohibited in a CLN
    // although they are likely disallowed due to other requirements.
    // Accept them here; they should be rejected elsewhere if necessary.
    return true;
  }

  if (auto *action{parser::Unwrap<parser::ActionStmt>(exec->u)}) {
    const bool prohibitedStmt{parser::Unwrap<parser::CycleStmt>(action->u) ||
        parser::Unwrap<parser::ExitStmt>(action->u) ||
        parser::Unwrap<parser::ForallStmt>(action->u) ||
        parser::Unwrap<parser::WhereStmt>(action->u)};
    if (prohibitedStmt) {
      return false;
    }
    if (auto *assign{parser::Unwrap<parser::AssignmentStmt>(&action->u)}) {
      // The assigned variable must be known to be scalar.
      auto lhs{GetEvaluateExprFromTyped(
          std::get<parser::Variable>(assign->t).typedExpr)};
      if (!lhs || lhs->Rank() != 0) {
        return false;
      }
    }
  } else {
    const bool prohibitedConstruct{
        parser::Unwrap<parser::LabelDoStmt>(exec->u) ||
        parser::Unwrap<parser::DoConstruct>(exec->u) ||
        parser::Unwrap<parser::ForallConstruct>(exec->u) ||
        parser::Unwrap<parser::WhereConstruct>(exec->u)};
    if (prohibitedConstruct) {
      return false;
    }
    if (auto *omp{parser::Unwrap<parser::OpenMPConstruct>(exec->u)}) {
      auto dirName{GetOmpDirectiveName(*omp)};
      if (llvm::omp::getDirectiveCategory(dirName.v) ==
          llvm::omp::Category::Executable) {
        return false;
      }
    }
  }

  return !ContainsInvalidArrayExpression(x);
}
/// Checks whether the construct `x`, when placed next to a loop inside the
/// enclosing (parent) loop, keeps the nest perfect. Two kinds of code are
/// treated as transparent:
///  - CONTINUE statements, since they are no-ops;
///  - non-OpenMP compiler directives, so that legacy applications still pass
///    the semantic checks.
bool IsTransparentInterveningCode(const parser::ExecutionPartConstruct &x) {
  if (parser::Unwrap<parser::CompilerDirective>(x)) {
    return true;
  }
  return parser::Unwrap<parser::ContinueStmt>(x) != nullptr;
}
// A DO construct is transformable exactly when IsDoNormal() holds for it.
bool IsTransformableLoop(const parser::DoConstruct &loop) {
  const bool isNormalDo{loop.IsDoNormal()};
  return isNormalDo;
}
// An OpenMP loop construct is transformable when its directive is
// loop-transforming and it is not a full UNROLL.
bool IsTransformableLoop(const parser::OpenMPLoopConstruct &omp) {
  return !IsFullUnroll(omp) && IsLoopTransforming(omp.BeginDir().DirId());
}
// Dispatch on what `epc` contains: a DO construct or an OpenMP loop
// construct is checked by the matching overload; anything else is not a
// transformable loop.
bool IsTransformableLoop(const parser::ExecutionPartConstruct &epc) {
  const auto *doLoop{parser::Unwrap<parser::DoConstruct>(epc)};
  if (doLoop) {
    return IsTransformableLoop(*doLoop);
  }
  const auto *ompLoop{parser::Unwrap<parser::OpenMPLoopConstruct>(epc)};
  if (ompLoop) {
    return IsTransformableLoop(*ompLoop);
  }
  return false;
}
// Add two WithReason values of arithmetic type. When both operands hold a
// value, the result holds their sum and a Reason with both operands'
// reasons appended in order. Otherwise the result is empty.
template <typename T,
    typename = std::enable_if_t<std::is_arithmetic_v<llvm::remove_cvref_t<T>>>>
WithReason<T> operator+(const WithReason<T> &a, const WithReason<T> &b) {
  if (!a.value || !b.value) {
    return WithReason<T>();
  }
  return WithReason<T>{
      *a.value + *b.value, Reason().Append(a.reason).Append(b.reason)};
}
// Add a plain arithmetic value to a WithReason: the plain value is wrapped
// with an empty Reason and the WithReason + WithReason overload is used.
template <typename T,
    typename = std::enable_if_t<std::is_arithmetic_v<llvm::remove_cvref_t<T>>>>
WithReason<T> operator+(T a, const WithReason<T> &b) {
  const WithReason<T> wrapped{a, Reason()};
  return wrapped + b;
}
// Compute how deep a loop nest the directive in `spec` affects.
// The result is the pair {affected-depth, must-be-perfect-nest}; the depth
// comes with a message explaining where the number came from, and is empty
// when it cannot be determined.
std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
    const parser::OmpDirectiveSpecification &spec, unsigned version) {
  using Clause = llvm::omp::Clause;
  using Directive = llvm::omp::Directive;
  const Directive dir{spec.DirId()};

  // Directives that accept COLLAPSE and/or ORDERED: the depth is the larger
  // of the two clause arguments (whichever are present).
  const bool takesCollapse{llvm::omp::isAllowedClauseForDirective(
      dir, Clause::OMPC_collapse, version)};
  const bool takesOrdered{llvm::omp::isAllowedClauseForDirective(
      dir, Clause::OMPC_ordered, version)};
  if (takesCollapse || takesOrdered) {
    auto [depth, depthReason]{
        GetArgumentValueWithReason(spec, Clause::OMPC_collapse, version)};
    auto [ordered, orderedReason]{
        GetArgumentValueWithReason(spec, Clause::OMPC_ordered, version)};
    if (ordered && (!depth || *depth < *ordered)) {
      depth = ordered;
      depthReason = std::move(orderedReason);
    }
    return {{depth, std::move(depthReason)}, true};
  }

  if (!IsLoopTransforming(dir)) {
    return {{}, false};
  }

  switch (dir) {
  case Directive::OMPD_interchange: {
    // The depth is the length of the argument list to PERMUTATION.
    if (parser::omp::FindClause(spec, Clause::OMPC_permutation)) {
      auto [length, lengthReason]{
          GetNumArgumentsWithReason(spec, Clause::OMPC_permutation, version)};
      return {{length, std::move(lengthReason)}, true};
    }
    // PERMUTATION not specified, assume PERMUTATION(2, 1).
    std::string clauseName{
        parser::omp::GetUpperName(Clause::OMPC_permutation, version)};
    Reason assumed;
    assumed.Say(spec.source,
        "%s clause was not specified, %s(2, 1) was assumed"_because_en_US,
        clauseName.c_str(), clauseName.c_str());
    return {{2, std::move(assumed)}, true};
  }
  case Directive::OMPD_stripe:
  case Directive::OMPD_tile: {
    // The depth is the length of the argument list to SIZES.
    auto [length, lengthReason]{
        GetNumArgumentsWithReason(spec, Clause::OMPC_sizes, version)};
    return {{length, std::move(lengthReason)}, true};
  }
  case Directive::OMPD_fuse: {
    // The depth is the argument to DEPTH.
    if (parser::omp::FindClause(spec, Clause::OMPC_depth)) {
      auto [depth, depthReason]{
          GetArgumentValueWithReason(spec, Clause::OMPC_depth, version)};
      return {{depth, std::move(depthReason)}, true};
    }
    std::string clauseName{
        parser::omp::GetUpperName(Clause::OMPC_depth, version)};
    Reason assumed;
    assumed.Say(spec.source,
        "%s clause was not specified, a value of 1 was assumed"_because_en_US,
        clauseName.c_str());
    return {{1, std::move(assumed)}, true};
  }
  case Directive::OMPD_reverse:
  case Directive::OMPD_unroll:
    return {WithReason<int64_t>(1), false};
  // TODO: case llvm::omp::Directive::OMPD_flatten:
  // TODO: case llvm::omp::Directive::OMPD_split:
  default:
    break;
  }
  return {{}, false};
}
// Return the range of the affected nests in the sequence, as the pair
// {first, count}, together with a message explaining where it came from.
//  - FUSE with LOOPRANGE: the clause's two arguments, provided both have
//    known, positive integer values; otherwise the result is empty.
//  - FUSE without LOOPRANGE: {1, -1}, where -1 means "the whole associated
//    sequence".
//  - Any other directive (asserted to be loop-nest-associated): {1, 1}.
WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
    const parser::OmpDirectiveSpecification &spec, unsigned version) {
  llvm::omp::Directive dir{spec.DirId()};
  if (dir == llvm::omp::Directive::OMPD_fuse) {
    // The clause name is needed by both messages below, so compute it once.
    std::string name{parser::omp::GetUpperName(
        llvm::omp::Clause::OMPC_looprange, version)};
    if (auto *clause{
            parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_looprange)}) {
      auto &range{DEREF(parser::Unwrap<parser::OmpLooprangeClause>(clause->u))};
      std::optional<int64_t> first{GetIntValue(std::get<0>(range.t))};
      std::optional<int64_t> count{GetIntValue(std::get<1>(range.t))};
      if (!first || !count || *first <= 0 || *count <= 0) {
        return {};
      }
      Reason reason;
      reason.Say(clause->source,
          "%s clause was specified with a count of %" PRId64
          " starting at loop %" PRId64 ""_because_en_US,
          name.c_str(), *count, *first);
      return {std::make_pair(*first, *count), std::move(reason)};
    }
    // If LOOPRANGE was not found, return {1, -1}, where -1 means "the whole
    // associated sequence".
    Reason reason;
    reason.Say(spec.source,
        "%s clause was not specified, a value of 1 was assumed"_because_en_US,
        name.c_str());
    return {std::make_pair(1, -1), std::move(reason)};
  }
  assert(llvm::omp::getDirectiveAssociation(dir) ==
          llvm::omp::Association::LoopNest &&
      "Expecting loop-nest-associated construct");
  // For loop-nest constructs, a single loop-nest is affected.
  return {std::make_pair(1, 1), Reason()};
}
// Given the 1-based index of the first affected item and the number of
// affected items, return how many items are required in total:
//  - first + count - 1 when count is positive,
//  - -1 when count is -1,
//  - std::nullopt otherwise (missing input, non-positive first, or any other
//    count).
std::optional<int64_t> GetRequiredCount(
    std::optional<int64_t> first, std::optional<int64_t> count) {
  if (!first || !count || *first <= 0) {
    return std::nullopt;
  }
  if (*count > 0) {
    return *first + *count - 1;
  }
  if (*count == -1) {
    return -1;
  }
  return std::nullopt;
}
// Convenience overload taking an optional {first, count} pair. A missing
// pair is forwarded as two missing values.
std::optional<int64_t> GetRequiredCount(
    std::optional<std::pair<int64_t, int64_t>> range) {
  if (!range) {
    return GetRequiredCount(std::nullopt, std::nullopt);
  }
  return GetRequiredCount(range->first, range->second);
}
#ifdef EXPENSIVE_CHECKS
namespace {
/// Check that for every value x of type T, there will be a "source" member
/// somewhere in x. This is to specifically make sure that parser::GetSource
/// will return something for any parser::ExecutionPartConstruct.
///
/// The trait is only compiled under EXPENSIVE_CHECKS and is used solely by
/// the static_assert at the end of this group of definitions.
// Primary template: used when no specialization below matches; the answer
// is "no source".
template <typename...> struct HasSourceT {
static constexpr bool value{false};
};
// Single type: strip cv/ref qualifiers, then either the type itself has a
// source (parser::HasSource), or the check recurses into the member that
// corresponds to the trait the type satisfies (thing / v / t / u).
template <typename T> struct HasSourceT<T> {
private:
using U = llvm::remove_cvref_t<T>;
static constexpr bool check() {
if constexpr (parser::HasSource<U>::value) {
return true;
} else if constexpr (ConstraintTrait<U>) {
return HasSourceT<decltype(U::thing)>::value;
} else if constexpr (WrapperTrait<U>) {
return HasSourceT<decltype(U::v)>::value;
} else if constexpr (TupleTrait<U>) {
return HasSourceT<decltype(U::t)>::value;
} else if constexpr (UnionTrait<U>) {
return HasSourceT<decltype(U::u)>::value;
} else {
return false;
}
}
public:
static constexpr bool value{check()};
};
// parser::ErrorRecovery is unconditionally treated as having a source.
template <> struct HasSourceT<parser::ErrorRecovery> {
static constexpr bool value{true};
};
// Indirection: same answer as for the contained type.
template <typename T> struct HasSourceT<common::Indirection<T>> {
static constexpr bool value{HasSourceT<T>::value};
};
// Tuple: it is enough for at least one element to have a source.
template <typename... Ts> struct HasSourceT<std::tuple<Ts...>> {
static constexpr bool value{(HasSourceT<Ts>::value || ...)};
};
// Variant: every alternative must have a source, since any one of them may
// be the one that is held.
template <typename... Ts> struct HasSourceT<std::variant<Ts...>> {
static constexpr bool value{(HasSourceT<Ts>::value && ...)};
};
static_assert(HasSourceT<parser::ExecutionPartConstruct>::value);
} // namespace
#endif // EXPENSIVE_CHECKS
// Build a LoopSequence rooted at the construct `root`.
// `version` and `allowAllLoops` are stored first because
// createConstructEntry reads allowAllLoops_. The root must yield an entry
// (asserted); the children are then created from the entry's location and
// the cached values are computed.
LoopSequence::LoopSequence(const parser::ExecutionPartConstruct &root,
unsigned version, bool allowAllLoops)
: version_(version), allowAllLoops_(allowAllLoops) {
entry_ = createConstructEntry(root);
assert(entry_ && "Expecting loop like code");
// NOTE(review): this calls a one-argument overload of
// createChildrenFromRange that is not defined in this part of the file.
createChildrenFromRange(entry_->location);
precalculate();
}
// Build a LoopSequence from an already created entry, taking ownership of
// it. The children are created from the entry's location and the cached
// values are computed.
// NOTE(review): `entry` is dereferenced without a null check, so callers are
// presumably expected to pass a non-null entry - confirm.
LoopSequence::LoopSequence(
std::unique_ptr<Construct> entry, unsigned version, bool allowAllLoops)
: version_(version), allowAllLoops_(allowAllLoops),
entry_(std::move(entry)) {
createChildrenFromRange(entry_->location);
precalculate();
}
// Create a Construct entry for `code` if it is loop-like:
//  - a DO construct, when it is transformable or when all loops are allowed;
//  - any OpenMP loop construct. All of them are accepted because this gives
//    better diagnostics, e.g. "this is not a loop-transforming construct"
//    instead of just "this is not a valid intervening code".
// Returns nullptr for everything else.
std::unique_ptr<LoopSequence::Construct> LoopSequence::createConstructEntry(
    const parser::ExecutionPartConstruct &code) {
  const parser::Block *body{nullptr};
  if (auto *loop{parser::Unwrap<parser::DoConstruct>(code)}) {
    if (allowAllLoops_ || IsTransformableLoop(*loop)) {
      body = &std::get<parser::Block>(loop->t);
    }
  } else if (auto *omp{parser::Unwrap<parser::OpenMPLoopConstruct>(code)}) {
    body = &std::get<parser::Block>(omp->t);
  }
  if (!body) {
    return nullptr;
  }
  return std::make_unique<Construct>(*body, &code);
}
// Walk the constructs in [begin, end) without descending into them, and
// classify each one:
// - Constructs that createConstructEntry() accepts become children of this
//   sequence. With zero or one child this LoopSequence could be a nest; with
//   more it could be a proper sequence, in which case the code between
//   consecutive children must be "transparent".
// - Everything else is intervening code. The first piece that is not valid
//   intervening code is remembered in invalidIC_, and the first piece that is
//   not transparent in opaqueIC_.
void LoopSequence::createChildrenFromRange(
    ExecutionPartIterator::IteratorType begin,
    ExecutionPartIterator::IteratorType end) {
  // Set once invalidIC_ refers to a child that is not a transformable loop.
  bool invalidChildRecorded{false};

  for (auto &code : BlockRange(begin, end, BlockRange::Step::Over)) {
    auto entry{createConstructEntry(code)};
    if (!entry) {
      // Plain intervening code: only the first offender of each kind is kept.
      if (!invalidIC_ && !IsValidInterveningCode(code)) {
        invalidIC_ = &code;
      }
      if (!opaqueIC_ && !IsTransparentInterveningCode(code)) {
        opaqueIC_ = &code;
      }
      continue;
    }

    children_.push_back(
        LoopSequence(std::move(entry), version_, allowAllLoops_));
    // Constructs such as DO WHILE may be given an entry, but they still count
    // as invalid intervening code. The first such child takes precedence over
    // any other invalid intervening code recorded so far, hence the
    // unconditional overwrite of invalidIC_.
    if (!invalidChildRecorded && !IsTransformableLoop(code)) {
      invalidIC_ = &code;
      invalidChildRecorded = true;
    }
  }
}
// Compute and cache the length and the depths of this sequence. Both
// constructors call this after the children have been created.
void LoopSequence::precalculate() {
// Calculate length before depths: calculateDepths() starts by asking
// isNest(), which relies on the length being known already.
length_ = calculateLength();
depth_ = calculateDepths();
}
// Compute the length of this sequence, i.e. the number of loops it
// contributes to the enclosing construct.
//
// Returns:
// - the sum of the children's lengths when the entry has no owner,
// - 1 for a DO construct,
// - 0 for an OpenMP loop construct that is not loop-transforming,
// - std::nullopt for a full unroll, and for a FUSE whose result cannot be
//   determined (see below),
// - the nested length for NOTHING,
// - 1 for every other loop-transforming construct.
std::optional<int64_t> LoopSequence::calculateLength() const {
  if (!entry_->owner) {
    return getNestedLength();
  }
  if (parser::Unwrap<parser::DoConstruct>(entry_->owner)) {
    return 1;
  }

  auto &omp{DEREF(parser::Unwrap<parser::OpenMPLoopConstruct>(*entry_->owner))};
  const parser::OmpDirectiveSpecification &beginSpec{omp.BeginDir()};
  llvm::omp::Directive dir{beginSpec.DirId()};
  if (!IsLoopTransforming(dir)) {
    return 0;
  }
  // TODO: Handle split, apply.
  if (IsFullUnroll(omp)) {
    return std::nullopt;
  }

  switch (dir) {
  case llvm::omp::Directive::OMPD_nothing:
    // NOTHING leaves the nested sequence as it is.
    return getNestedLength();
  case llvm::omp::Directive::OMPD_fuse:
    // Handled below.
    break;
  default:
    // Every other loop-transforming construct yields a single loop.
    return 1;
  }

  // FUSE without any loops nested in it is invalid. That is diagnosed when
  // the body of the FUSE construct itself is analyzed, not when checking a
  // construct in which the FUSE is nested. Returning std::nullopt here keeps
  // the same problem from being reported again for every enclosing loop
  // construct, for example:
  //   !$omp do    ! error: this should contain a loop (superfluous)
  //   !$omp fuse  ! error: this should contain a loop
  //   !$omp end fuse
  const std::optional<int64_t> available{getNestedLength()};
  if (available.value_or(0) == 0) {
    return std::nullopt;
  }

  auto *clause{
      parser::omp::FindClause(beginSpec, llvm::omp::Clause::OMPC_looprange)};
  if (!clause) {
    // Without LOOPRANGE the whole nested sequence is fused into one loop.
    return 1;
  }

  // With LOOPRANGE, `count` loops are replaced by a single fused loop. The
  // count must be a known positive value not exceeding the number of loops
  // available; otherwise no length is reported.
  auto *loopRange{parser::Unwrap<parser::OmpLooprangeClause>(*clause)};
  std::optional<int64_t> count{GetIntValue(std::get<1>(loopRange->t))};
  if (count && *count > 0 && *count <= *available) {
    return 1 + *available - *count;
  }
  return std::nullopt;
}
// Sum up the lengths of all children. The result is std::nullopt as soon as
// one child has no known length, and 0 when there are no children.
std::optional<int64_t> LoopSequence::getNestedLength() const {
  int64_t total{0};
  for (auto &child : children_) {
    const auto childLength{child.length()};
    if (!childLength) {
      return std::nullopt;
    }
    total += *childLength;
  }
  return total;
}
// Compute the pair of depths (semaDepth, perfDepth) of this sequence.
//
// Must be called after the length has been calculated, since isNest() is
// consulted first: a sequence that is not a nest has depths {0, 0}.
//
// The depths of the code nested in this sequence are taken as the basis and
// then adjusted for the construct that owns the entry:
// - no owner:            the nested depths are returned as they are,
// - DO construct:        one more than the nested depths,
// - OpenMP construct:    depends on the directive, see the switch below.
// A std::nullopt depth means that the depth could not be determined.
LoopSequence::Depth LoopSequence::calculateDepths() const {
  // Add two optional values; the sum is unknown if either operand is.
  auto plus{[](std::optional<int64_t> a,
                std::optional<int64_t> b) -> std::optional<int64_t> {
    if (a && b) {
      return *a + *b;
    }
    return std::nullopt;
  }};

  // The sequence length is calculated first, so we already know if this
  // sequence is a nest or not.
  if (!isNest()) {
    return Depth{0, 0};
  }

  // Get the length of the nested sequence. The invalidIC_ and opaqueIC_
  // members do not count canonical loop nests, but there can only be one
  // for depth to make sense.
  std::optional<int64_t> length{getNestedLength()};
  // Get the depths of the code nested in this sequence (e.g. contained in
  // entry_), and use it as the basis for the depths of entry_->owner.
  auto [semaDepth, perfDepth]{getNestedDepths()};
  if (invalidIC_ || length.value_or(0) != 1) {
    // Invalid intervening code, or not exactly one nested loop: nothing
    // nested contributes to either depth.
    semaDepth = perfDepth = 0;
  } else if (opaqueIC_) {
    // Exactly one nested loop (established by the branch above), but with
    // non-transparent intervening code: only perfDepth is reset.
    perfDepth = 0;
  }

  if (!entry_->owner) {
    return Depth{semaDepth, perfDepth};
  }
  if (parser::Unwrap<parser::DoConstruct>(entry_->owner)) {
    // The DO loop itself adds one level on top of what is nested in it.
    return Depth{plus(1, semaDepth), plus(1, perfDepth)};
  }

  auto &omp{DEREF(parser::Unwrap<parser::OpenMPLoopConstruct>(*entry_->owner))};
  const parser::OmpDirectiveSpecification &beginSpec{omp.BeginDir()};
  llvm::omp::Directive dir{beginSpec.DirId()};
  if (!IsTransformableLoop(omp)) {
    return Depth{0, 0};
  }

  switch (dir) {
  // TODO: case llvm::omp::Directive::OMPD_split:
  // TODO: case llvm::omp::Directive::OMPD_flatten:
  case llvm::omp::Directive::OMPD_fuse:
    if (auto *clause{parser::omp::FindClause(
            beginSpec, llvm::omp::Clause::OMPC_depth)}) {
      auto &expr{parser::UnwrapRef<parser::Expr>(clause->u)};
      auto value{GetIntValue(expr)};
      // The result is a perfect nest only if all loops in the sequence
      // are fused. `length` is the number of loops nested in this FUSE.
      if (value && length) {
        auto range{GetAffectedLoopRangeWithReason(beginSpec, version_)};
        if (auto required{GetRequiredCount(range.value)}) {
          if (*required == -1 || *required == *length) {
            return Depth{value, value};
          }
          return Depth{1, 1};
        }
      }
      return Depth{std::nullopt, std::nullopt};
    }
    // FUSE cannot create a nest of depth > 1 without DEPTH clause.
    return Depth{1, 1};
  case llvm::omp::Directive::OMPD_interchange:
  case llvm::omp::Directive::OMPD_nothing:
  case llvm::omp::Directive::OMPD_reverse:
    // These constructs leave the depths of the nested code unchanged.
    return {semaDepth, perfDepth};
  case llvm::omp::Directive::OMPD_stripe:
  case llvm::omp::Directive::OMPD_tile:
    // Look for SIZES clause.
    if (auto *clause{parser::omp::FindClause(
            beginSpec, llvm::omp::Clause::OMPC_sizes)}) {
      // Each argument of the SIZES clause adds one level to both depths.
      size_t num{
          parser::UnwrapRef<parser::OmpClause::Sizes>(clause->u).v.size()};
      return Depth{plus(num, semaDepth), plus(num, perfDepth)};
    }
    // The SIZES clause is mandatory, if it's missing the result is unknown.
    return {};
  case llvm::omp::Directive::OMPD_unroll:
    if (IsFullUnroll(omp)) {
      return Depth{0, 0};
    }
    // If this is not a full unroll then look for a PARTIAL clause.
    if (auto *clause{parser::omp::FindClause(
            beginSpec, llvm::omp::Clause::OMPC_partial)}) {
      std::optional<int64_t> factor;
      if (auto *expr{parser::Unwrap<parser::Expr>(clause->u)}) {
        factor = GetIntValue(*expr);
      }
      // If it's a partial unroll, and the unroll count is 1, then this
      // construct is a no-op.
      if (factor && *factor == 1) {
        return Depth{semaDepth, perfDepth};
      }
      // If it's a proper partial unroll, then the resulting loop cannot
      // have either depth greater than 1: if it had a loop nested in it,
      // then after unroll it will have at least two copies of it, making
      // it a final loop.
      return {1, 1};
    }
    return Depth{std::nullopt, std::nullopt};
  default:
    llvm_unreachable("Expecting loop-transforming construct");
  }
}
// Return the depths of the code nested in this sequence:
// - {std::nullopt, std::nullopt} if this sequence is not a nest,
// - the depths of the first child if there are any children,
// - {0, 0} if there are no children at all.
LoopSequence::Depth LoopSequence::getNestedDepths() const {
  if (!isNest()) {
    return {std::nullopt, std::nullopt};
  }
  if (!children_.empty()) {
    return children_.front().depth_;
  }
  // No children, but length == 1.
  // NOTE(review): calculateLength() also yields 1 for loop-transforming
  // OpenMP constructs other than FUSE and NOTHING regardless of their
  // children; confirm that such a construct without any child entries cannot
  // reach this point.
  assert(entry_->owner &&
      parser::Unwrap<parser::DoConstruct>(entry_->owner) &&
      "Expecting DO construct");
  return Depth{0, 0};
}
} // namespace Fortran::semantics::omp