[flang][cuda] Avoid crash when the force modifier is used (#160176)
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index 8cb2cae5..1f8d928 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -315,6 +315,7 @@ private: std::int64_t GetAssociatedLoopLevelFromClauses(const parser::AccClauseList &); + bool HasForceCollapseModifier(const parser::AccClauseList &); Symbol::Flags dataSharingAttributeFlags{Symbol::Flag::AccShared, Symbol::Flag::AccPrivate, Symbol::Flag::AccFirstPrivate, @@ -333,7 +334,7 @@ Symbol::Flag::AccDevicePtr, Symbol::Flag::AccDeviceResident, Symbol::Flag::AccLink, Symbol::Flag::AccPresent}; - void CheckAssociatedLoop(const parser::DoConstruct &); + void CheckAssociatedLoop(const parser::DoConstruct &, bool forceCollapsed); void ResolveAccObjectList(const parser::AccObjectList &, Symbol::Flag); void ResolveAccObject(const parser::AccObject &, Symbol::Flag); Symbol *ResolveAcc(const parser::Name &, Symbol::Flag, Scope &); @@ -1168,7 +1169,7 @@ ClearDataSharingAttributeObjects(); SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList)); const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)}; - CheckAssociatedLoop(*outer); + CheckAssociatedLoop(*outer, HasForceCollapseModifier(clauseList)); return true; } @@ -1366,7 +1367,7 @@ const auto &clauseList{std::get<parser::AccClauseList>(beginBlockDir.t)}; SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList)); const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)}; - CheckAssociatedLoop(*outer); + CheckAssociatedLoop(*outer, HasForceCollapseModifier(clauseList)); ClearDataSharingAttributeObjects(); return true; } @@ -1478,6 +1479,18 @@ return true; } +bool AccAttributeVisitor::HasForceCollapseModifier( + const parser::AccClauseList &x) { + for (const auto &clause : x.v) { + if (const auto *collapseClause{ + std::get_if<parser::AccClause::Collapse>(&clause.u)}) { + const parser::AccCollapseArg &arg = collapseClause->v; + return std::get<bool>(arg.t); + } + } + return false; +} + std::int64_t AccAttributeVisitor::GetAssociatedLoopLevelFromClauses( const parser::AccClauseList &x) { std::int64_t collapseLevel{0}; @@ -1499,14 +1512,14 @@ } void AccAttributeVisitor::CheckAssociatedLoop( - const parser::DoConstruct &outerDoConstruct) { + const parser::DoConstruct &outerDoConstruct, bool forceCollapsed) { std::int64_t level{GetContext().associatedLoopLevel}; if (level <= 0) { // collapse value was negative or 0 return; } const auto getNextDoConstruct = - [this](const parser::Block &block, + [this, forceCollapsed](const parser::Block &block, std::int64_t &level) -> const parser::DoConstruct * { for (const auto &entry : block) { if (const auto *doConstruct = GetDoConstructIf(entry)) { @@ -1524,7 +1537,9 @@ "LOOP directive not expected in COLLAPSE loop nest"_err_en_US); level = 0; } else { - break; + if (!forceCollapsed) { + break; + } } } return nullptr;
diff --git a/flang/test/Semantics/OpenACC/acc-collapse-force.f90 b/flang/test/Semantics/OpenACC/acc-collapse-force.f90 new file mode 100644 index 0000000..80b1060 --- /dev/null +++ b/flang/test/Semantics/OpenACC/acc-collapse-force.f90
@@ -0,0 +1,19 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenacc -fsyntax-only + +! Check that loop with force collapse do not break in the semantic step. +subroutine sub3() + integer :: i, j + integer, parameter :: n = 100, m = 200 + real, dimension(n, m) :: a + real, dimension(n) :: bb + real :: r + a = 1 + r = 0 + !$acc parallel loop collapse(force:2) copy(a) + do i = 1, n + bb(i) = r + do j = 1, m + a(i,j) = r * a(i,j) + enddo + enddo +end subroutine