[Flang][OpenMP] Support depend clause for task construct, excluding array sections

This patch adds support for the OpenMP 4.0 depend clause for the task
construct, excluding array sections, to Flang lowering from parse-tree
to MLIR.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D146766

GitOrigin-RevId: 3373c8405c5e144567c2b25c28b54b4ac63b09b6
diff --git a/lib/Lower/OpenMP.cpp b/lib/Lower/OpenMP.cpp
index 13b7c58..94bd8fb 100644
--- a/lib/Lower/OpenMP.cpp
+++ b/lib/Lower/OpenMP.cpp
@@ -1006,6 +1006,31 @@
   return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
 }
 
+static omp::ClauseTaskDependAttr
+genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
+                  const Fortran::parser::OmpClause::Depend *dependClause) {
+  omp::ClauseTaskDepend pbKind;
+  switch (
+      std::get<Fortran::parser::OmpDependenceType>(
+          std::get<Fortran::parser::OmpDependClause::InOut>(dependClause->v.u)
+              .t)
+          .v) {
+  case Fortran::parser::OmpDependenceType::Type::In:
+    pbKind = omp::ClauseTaskDepend::taskdependin;
+    break;
+  case Fortran::parser::OmpDependenceType::Type::Out:
+    pbKind = omp::ClauseTaskDepend::taskdependout;
+    break;
+  case Fortran::parser::OmpDependenceType::Type::Inout:
+    pbKind = omp::ClauseTaskDepend::taskdependinout;
+    break;
+  default:
+    llvm_unreachable("unknown parser task dependence type");
+    break;
+  }
+  return omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(), pbKind);
+}
+
 /* When parallel is used in a combined construct, then use this function to
  * create the parallel operation. It handles the parallel specific clauses
  * and leaves the rest for handling at the inner operations.
@@ -1072,7 +1097,8 @@
   mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
       priorityClauseOperand;
   mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
-  SmallVector<Value> allocateOperands, allocatorOperands;
+  SmallVector<Value> allocateOperands, allocatorOperands, dependOperands;
+  SmallVector<Attribute> dependTypeOperands;
   mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
 
   const auto &opClauseList =
@@ -1148,6 +1174,42 @@
            "Reduction in OpenMP " +
                llvm::omp::getOpenMPDirectiveName(blockDirective.v) +
                " construct");
+    } else if (const auto &dependClause =
+                   std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u)) {
+      const std::list<Fortran::parser::Designator> &depVal =
+          std::get<std::list<Fortran::parser::Designator>>(
+              std::get<Fortran::parser::OmpDependClause::InOut>(
+                  dependClause->v.u)
+                  .t);
+      omp::ClauseTaskDependAttr dependTypeOperand =
+          genDependKindAttr(firOpBuilder, dependClause);
+      dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
+                                dependTypeOperand);
+      for (const Fortran::parser::Designator &ompObject : depVal) {
+        Fortran::semantics::Symbol *sym = nullptr;
+        std::visit(
+            Fortran::common::visitors{
+                [&](const Fortran::parser::DataRef &designator) {
+                  if (const Fortran::parser::Name *name =
+                          std::get_if<Fortran::parser::Name>(&designator.u)) {
+                    sym = name->symbol;
+                  } else if (const Fortran::common::Indirection<
+                                 Fortran::parser::ArrayElement> *a =
+                                 std::get_if<Fortran::common::Indirection<
+                                     Fortran::parser::ArrayElement>>(
+                                     &designator.u)) {
+                    TODO(converter.getCurrentLocation(),
+                         "array sections not supported for task depend");
+                  }
+                },
+                [&](const Fortran::parser::Substring &designator) {
+                  TODO(converter.getCurrentLocation(),
+                       "substring not supported for task depend");
+                }},
+            (ompObject).u);
+        const mlir::Value variable = converter.getSymbolAddress(*sym);
+        dependOperands.push_back(((variable)));
+      }
     } else {
       TODO(converter.getCurrentLocation(), "OpenMP Block construct clause");
     }
@@ -1185,8 +1247,12 @@
     auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
         currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
         mergeableAttr, /*in_reduction_vars=*/ValueRange(),
-        /*in_reductions=*/nullptr, priorityClauseOperand, /*depends=*/nullptr,
-        /*depend_vars=*/ValueRange(), allocateOperands, allocatorOperands);
+        /*in_reductions=*/nullptr, priorityClauseOperand,
+        dependTypeOperands.empty()
+            ? nullptr
+            : mlir::ArrayAttr::get(firOpBuilder.getContext(),
+                                   dependTypeOperands),
+        dependOperands, allocateOperands, allocatorOperands);
     createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
   } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) {
     // TODO: Add task_reduction support
diff --git a/test/Lower/OpenMP/task.f90 b/test/Lower/OpenMP/task.f90
index 810e3b5..d7419bd 100644
--- a/test/Lower/OpenMP/task.f90
+++ b/test/Lower/OpenMP/task.f90
@@ -100,6 +100,79 @@
 end subroutine task_allocate
 
 !===============================================================================
+! `depend` clause
+!===============================================================================
+
+!CHECK-LABEL: func @_QPtask_depend
+subroutine task_depend()
+  integer :: x
+  !CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<i32>) {
+  !$omp task depend(in : x)
+  !CHECK: arith.addi
+  x = x + 12
+  !CHECK: omp.terminator
+  !$omp end task
+end subroutine task_depend
+
+!CHECK-LABEL: func @_QPtask_depend_non_int
+subroutine task_depend_non_int()
+  character(len = 15) :: x
+  integer, allocatable :: y
+  complex :: z
+  !CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<!fir.char<1,15>>, taskdependin -> %{{.+}} : !fir.ref<!fir.box<!fir.heap<i32>>>, taskdependin ->  %{{.+}} : !fir.ref<!fir.complex<4>>) {
+  !$omp task depend(in : x, y, z)
+  !CHECK: omp.terminator
+  !$omp end task
+end subroutine task_depend_non_int
+
+!CHECK-LABEL: func @_QPtask_depend_all_kinds_one_task
+subroutine task_depend_all_kinds_one_task()
+  integer :: x
+  !CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<i32>, taskdependout -> %{{.+}} : !fir.ref<i32>, taskdependinout -> %{{.+}} : !fir.ref<i32>) {
+  !$omp task depend(in : x) depend(out : x) depend(inout : x)
+  !CHECK: arith.addi
+  x = x + 12
+  !CHECK: omp.terminator
+  !$omp end task
+end subroutine task_depend_all_kinds_one_task
+
+!CHECK-LABEL: func @_QPtask_depend_multi_var
+subroutine task_depend_multi_var()
+  integer :: x
+  integer :: y
+  !CHECK: omp.task depend(taskdependin -> %{{.*}} : !fir.ref<i32>, taskdependin -> %{{.+}} : !fir.ref<i32>) {
+  !$omp task depend(in :x,y)
+  !CHECK: arith.addi
+  x = x + 12
+  y = y + 12
+  !CHECK: omp.terminator
+  !$omp end task
+end subroutine task_depend_multi_var
+
+!CHECK-LABEL: func @_QPtask_depend_multi_task
+subroutine task_depend_multi_task()
+  integer :: x
+  !CHECK: omp.task depend(taskdependout -> %{{.+}} : !fir.ref<i32>)
+  !$omp task depend(out : x)
+  !CHECK: arith.addi
+  x = x + 12
+  !CHECK: omp.terminator
+  !$omp end task
+  !CHECK: omp.task depend(taskdependinout -> %{{.+}} : !fir.ref<i32>)
+  !$omp task depend(inout : x)
+  !CHECK: arith.addi
+  x = x + 12
+  !CHECK: omp.terminator
+  !$omp end task
+  !CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<i32>)
+  !$omp task depend(in : x)
+  !CHECK: arith.addi
+  x = x + 12
+  !CHECK: omp.terminator
+  !$omp end task
+end subroutine task_depend_multi_task
+
+!===============================================================================
 ! `private` clause
 !===============================================================================
 !CHECK-LABEL: func @_QPtask_private