Upstream support for POINTER assignment in FORALL.

Reviewed By: vdonaldson, PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D125140
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index cb3b8ce..a5d456d 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -373,16 +373,12 @@
     return Fortran::lower::createSomeExtendedExpression(
         loc ? *loc : toLocation(), *this, expr, localSymbols, context);
   }
-  fir::MutableBoxValue
-  genExprMutableBox(mlir::Location loc,
-                    const Fortran::lower::SomeExpr &expr) override final {
-    return Fortran::lower::createMutableBox(loc, *this, expr, localSymbols);
-  }
-  fir::ExtendedValue genExprBox(const Fortran::lower::SomeExpr &expr,
-                                Fortran::lower::StatementContext &context,
-                                mlir::Location loc) override final {
+
+  fir::ExtendedValue
+  genExprBox(mlir::Location loc, const Fortran::lower::SomeExpr &expr,
+             Fortran::lower::StatementContext &stmtCtx) override final {
     return Fortran::lower::createBoxValue(loc, *this, expr, localSymbols,
-                                          context);
+                                          stmtCtx);
   }
 
   Fortran::evaluate::FoldingContext &getFoldingContext() override final {
@@ -441,8 +437,8 @@
           // Create a contiguous temp with the same shape and length as
           // the original variable described by a fir.box.
           llvm::SmallVector<mlir::Value> extents =
-              fir::factory::getExtents(*builder, loc, hexv);
-          if (box.isDerivedWithLengthParameters())
+              fir::factory::getExtents(loc, *builder, hexv);
+          if (box.isDerivedWithLenParameters())
             TODO(loc, "get length parameters from derived type BoxValue");
           if (box.isCharacter()) {
             mlir::Value len = fir::factory::readCharLen(*builder, loc, box);
@@ -459,7 +455,7 @@
         },
         [&](const auto &) -> fir::ExtendedValue {
           mlir::Value temp =
-              allocate(fir::factory::getExtents(*builder, loc, hexv),
+              allocate(fir::factory::getExtents(loc, *builder, hexv),
                        fir::getTypeParams(hexv));
           return fir::substBase(hexv, temp);
         });
@@ -1598,7 +1594,7 @@
   fir::ExtendedValue
   genAssociateSelector(const Fortran::lower::SomeExpr &selector,
                        Fortran::lower::StatementContext &stmtCtx) {
-    return isArraySectionWithoutVectorSubscript(selector)
+    return Fortran::lower::isArraySectionWithoutVectorSubscript(selector)
                ? Fortran::lower::createSomeArrayBox(*this, selector,
                                                     localSymbols, stmtCtx)
                : genExprAddr(selector, stmtCtx);
@@ -1850,9 +1846,16 @@
   /// Generate an array assignment.
   /// This is an assignment expression with rank > 0. The assignment may or may
   /// not be in a WHERE and/or FORALL context.
-  void genArrayAssignment(const Fortran::evaluate::Assignment &assign,
-                          Fortran::lower::StatementContext &stmtCtx) {
-    if (isWholeAllocatable(assign.lhs)) {
+  /// In a FORALL context, the assignment may be a pointer assignment. Such a
+  /// pointer assignment is indicated by \p lbounds having a value (possibly an
+  /// empty bounds-spec); \p ubounds is additionally provided for bounds
+  /// remapping. When \p lbounds is None, this is an ordinary array assignment.
+  void genArrayAssignment(
+      const Fortran::evaluate::Assignment &assign,
+      Fortran::lower::StatementContext &stmtCtx,
+      llvm::Optional<llvm::SmallVector<mlir::Value>> lbounds = llvm::None,
+      llvm::Optional<llvm::SmallVector<mlir::Value>> ubounds = llvm::None) {
+    if (Fortran::lower::isWholeAllocatable(assign.lhs)) {
       // Assignment to allocatables may require the lhs to be
       // deallocated/reallocated. See Fortran 2018 10.2.1.3 p3
       Fortran::lower::createAllocatableArrayAssignment(
@@ -1861,6 +1864,17 @@
       return;
     }
 
+    if (lbounds.hasValue()) {
+      // Array of POINTER entities, with elemental assignment.
+      if (!Fortran::lower::isWholePointer(assign.lhs))
+        fir::emitFatalError(toLocation(), "pointer assignment to non-pointer");
+
+      Fortran::lower::createArrayOfPointerAssignment(
+          *this, assign.lhs, assign.rhs, explicitIterSpace, implicitIterSpace,
+          lbounds.getValue(), ubounds, localSymbols, stmtCtx);
+      return;
+    }
+
     if (!implicitIterationSpace() && !explicitIterationSpace()) {
       // No masks and the iteration space is implied by the array, so create a
       // simple array assignment.
@@ -1885,13 +1899,6 @@
                                  : implicitIterSpace.stmtContext());
   }
 
-  static bool
-  isArraySectionWithoutVectorSubscript(const Fortran::lower::SomeExpr &expr) {
-    return expr.Rank() > 0 && Fortran::evaluate::IsVariable(expr) &&
-           !Fortran::evaluate::UnwrapWholeSymbolDataRef(expr) &&
-           !Fortran::evaluate::HasVectorSubscript(expr);
-  }
-
 #if !defined(NDEBUG)
   static bool isFuncResultDesignator(const Fortran::lower::SomeExpr &expr) {
     const Fortran::semantics::Symbol *sym =
@@ -1900,10 +1907,10 @@
   }
 #endif
 
-  static bool isWholeAllocatable(const Fortran::lower::SomeExpr &expr) {
-    const Fortran::semantics::Symbol *sym =
-        Fortran::evaluate::UnwrapWholeSymbolOrComponentDataRef(expr);
-    return sym && Fortran::semantics::IsAllocatable(*sym);
+  /// Lower \p expr to a fir::MutableBoxValue via createMutableBox.
+  fir::MutableBoxValue genExprMutableBox(
+      mlir::Location loc, const Fortran::lower::SomeExpr &expr) override final {
+    return Fortran::lower::createMutableBox(loc, *this, expr, localSymbols);
   }
 
   /// Shared for both assignments and pointer assignments.
@@ -1929,7 +1936,8 @@
               assert(lhsType && "lhs cannot be typeless");
               // Assignment to polymorphic allocatables may require changing the
               // variable dynamic type (See Fortran 2018 10.2.1.3 p3).
-              if (lhsType->IsPolymorphic() && isWholeAllocatable(assign.lhs))
+              if (lhsType->IsPolymorphic() &&
+                  Fortran::lower::isWholeAllocatable(assign.lhs))
                 TODO(loc, "assignment to polymorphic allocatable");
 
               // Note: No ad-hoc handling for pointers is required here. The
@@ -1950,7 +1958,8 @@
               fir::ExtendedValue rhs = isNumericScalar
                                            ? genExprValue(assign.rhs, stmtCtx)
                                            : genExprAddr(assign.rhs, stmtCtx);
-              bool lhsIsWholeAllocatable = isWholeAllocatable(assign.lhs);
+              const bool lhsIsWholeAllocatable =
+                  Fortran::lower::isWholeAllocatable(assign.lhs);
               llvm::Optional<fir::factory::MutableBoxReallocation> lhsRealloc;
               llvm::Optional<fir::MutableBoxValue> lhsMutableBox;
               auto lhs = [&]() -> fir::ExtendedValue {
@@ -1959,7 +1968,7 @@
                   llvm::SmallVector<mlir::Value> lengthParams;
                   if (const fir::CharBoxValue *charBox = rhs.getCharBox())
                     lengthParams.push_back(charBox->getLen());
-                  else if (fir::isDerivedWithLengthParameters(rhs))
+                  else if (fir::isDerivedWithLenParameters(rhs))
                     TODO(loc, "assignment to derived type allocatable with "
                               "length parameters");
                   lhsRealloc = fir::factory::genReallocIfNeeded(
@@ -2023,7 +2032,7 @@
             // [3] Pointer assignment with possibly empty bounds-spec. R1035: a
             // bounds-spec is a lower bound value.
             [&](const Fortran::evaluate::Assignment::BoundsSpec &lbExprs) {
-              if (IsProcedure(assign.rhs))
+              if (Fortran::evaluate::IsProcedure(assign.rhs))
                 TODO(loc, "procedure pointer assignment");
               std::optional<Fortran::evaluate::DynamicType> lhsType =
                   assign.lhs.GetType();
@@ -2034,23 +2043,19 @@
                   (rhsType && rhsType->IsPolymorphic()))
                 TODO(loc, "pointer assignment involving polymorphic entity");
 
-              // FIXME: in the explicit space context, we want to use
-              // ScalarArrayExprLowering here.
-              fir::MutableBoxValue lhs = genExprMutableBox(loc, assign.lhs);
               llvm::SmallVector<mlir::Value> lbounds;
               for (const Fortran::evaluate::ExtentExpr &lbExpr : lbExprs)
                 lbounds.push_back(
                     fir::getBase(genExprValue(toEvExpr(lbExpr), stmtCtx)));
+              if (explicitIterationSpace()) {
+                // Pointer assignment in FORALL context. Copy the rhs box value
+                // into the lhs box variable.
+                genArrayAssignment(assign, stmtCtx, lbounds);
+                return;
+              }
+              fir::MutableBoxValue lhs = genExprMutableBox(loc, assign.lhs);
               Fortran::lower::associateMutableBox(*this, loc, lhs, assign.rhs,
                                                   lbounds, stmtCtx);
-              if (explicitIterationSpace()) {
-                mlir::ValueRange inners = explicitIterSpace.getInnerArgs();
-                if (!inners.empty()) {
-                  // TODO: should force a copy-in/copy-out here.
-                  // e.g., obj%ptr(i+1) => obj%ptr(i)
-                  builder->create<fir::ResultOp>(loc, inners);
-                }
-              }
             },
 
             // [4] Pointer assignment with bounds-remapping. R1036: a
@@ -2066,14 +2071,6 @@
                   (rhsType && rhsType->IsPolymorphic()))
                 TODO(loc, "pointer assignment involving polymorphic entity");
 
-              // FIXME: in the explicit space context, we want to use
-              // ScalarArrayExprLowering here.
-              fir::MutableBoxValue lhs = genExprMutableBox(loc, assign.lhs);
-              if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
-                      assign.rhs)) {
-                fir::factory::disassociateMutableBox(*builder, loc, lhs);
-                return;
-              }
               llvm::SmallVector<mlir::Value> lbounds;
               llvm::SmallVector<mlir::Value> ubounds;
               for (const std::pair<Fortran::evaluate::ExtentExpr,
@@ -2086,9 +2083,22 @@
                 ubounds.push_back(
                     fir::getBase(genExprValue(toEvExpr(ubExpr), stmtCtx)));
               }
+              if (explicitIterationSpace()) {
+                // Pointer assignment in FORALL context. Copy the rhs box value
+                // into the lhs box variable.
+                genArrayAssignment(assign, stmtCtx, lbounds, ubounds);
+                return;
+              }
+              fir::MutableBoxValue lhs = genExprMutableBox(loc, assign.lhs);
+              if (Fortran::evaluate::UnwrapExpr<Fortran::evaluate::NullPointer>(
+                      assign.rhs)) {
+                fir::factory::disassociateMutableBox(*builder, loc, lhs);
+                return;
+              }
               // Do not generate a temp in case rhs is an array section.
               fir::ExtendedValue rhs =
-                  isArraySectionWithoutVectorSubscript(assign.rhs)
+                  Fortran::lower::isArraySectionWithoutVectorSubscript(
+                      assign.rhs)
                       ? Fortran::lower::createSomeArrayBox(
                             *this, assign.rhs, localSymbols, stmtCtx)
                       : genExprAddr(assign.rhs, stmtCtx);
@@ -2096,11 +2106,3 @@
                                                          rhs, lbounds, ubounds);
-              if (explicitIterationSpace()) {
-                mlir::ValueRange inners = explicitIterSpace.getInnerArgs();
-                if (!inners.empty()) {
-                  // TODO: should force a copy-in/copy-out here.
-                  // e.g., obj%ptr(i+1) => obj%ptr(i)
-                  builder->create<fir::ResultOp>(loc, inners);
-                }
-              }
             },
         },
@@ -2349,7 +2351,7 @@
                             const Fortran::lower::CalleeInterface &callee) {
     assert(builder && "require a builder object at this point");
     using PassBy = Fortran::lower::CalleeInterface::PassEntityBy;
-    auto mapPassedEntity = [&](const auto arg) -> void {
+    auto mapPassedEntity = [&](const auto arg) {
       if (arg.passBy == PassBy::AddressAndLength) {
         // TODO: now that fir call has some attributes regarding character
         // return, PassBy::AddressAndLength should be retired.