[flang][hlfir] Lower user defined assignment

Lower user-defined assignment inside the "userDefinedAssignment" MLIR
region of hlfir.region_assign.

This is done by adding an entry point to ConvertCall.h
(convertUserDefinedAssignmentToHLFIR) that calls genUserCall with the
region block arguments as the actual arguments.

The codegen for hlfir.region_assign with user-defined assignment
will be added in a later patch.

Differential Revision: https://reviews.llvm.org/D153404
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index f3efbfa..23a58eb 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3189,6 +3189,21 @@
     builder.restoreInsertionPoint(insertPt);
   }
 
+  bool firstDummyIsPointerOrAllocatable(
+      const Fortran::evaluate::ProcedureRef &userDefinedAssignment) {
+    using DummyAttr = Fortran::evaluate::characteristics::DummyDataObject::Attr;
+    if (auto procedure =
+            Fortran::evaluate::characteristics::Procedure::Characterize(
+                userDefinedAssignment.proc(), getFoldingContext()))
+      if (!procedure->dummyArguments.empty())
+        if (const auto *dataArg = std::get_if<
+                Fortran::evaluate::characteristics::DummyDataObject>(
+                &procedure->dummyArguments[0].u))
+          return dataArg->attrs.test(DummyAttr::Pointer) ||
+                 dataArg->attrs.test(DummyAttr::Allocatable);
+    return false;
+  }
+
   void genDataAssignment(
       const Fortran::evaluate::Assignment &assign,
       const Fortran::evaluate::ProcedureRef *userDefinedAssignment) {
@@ -3199,6 +3214,9 @@
     const bool isWholeAllocatableAssignment =
         !userDefinedAssignment && !isInsideHlfirWhere() &&
         Fortran::lower::isWholeAllocatable(assign.lhs);
+    const bool isUserDefAssignToPointerOrAllocatable =
+        userDefinedAssignment &&
+        firstDummyIsPointerOrAllocatable(*userDefinedAssignment);
     std::optional<Fortran::evaluate::DynamicType> lhsType =
         assign.lhs.GetType();
     const bool keepLhsLengthInAllocatableAssignment =
@@ -3233,7 +3251,8 @@
           loc, *this, assign.lhs, localSymbols, stmtCtx);
       // Dereference pointer LHS: the target is being assigned to.
       // Same for allocatables outside of whole allocatable assignments.
-      if (!isWholeAllocatableAssignment)
+      if (!isWholeAllocatableAssignment &&
+          !isUserDefAssignToPointerOrAllocatable)
         lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs);
       return lhs;
     };
@@ -3263,26 +3282,46 @@
     // Lower LHS in its own region.
     builder.createBlock(&regionAssignOp.getLhsRegion());
     Fortran::lower::StatementContext lhsContext;
+    mlir::Value lhsYield = nullptr;
     if (!lhsHasVectorSubscripts) {
       hlfir::Entity lhs = evaluateLhs(lhsContext);
       auto lhsYieldOp = builder.create<hlfir::YieldOp>(loc, lhs);
       genCleanUpInRegionIfAny(loc, builder, lhsYieldOp.getCleanup(),
                               lhsContext);
+      lhsYield = lhs;
     } else {
       hlfir::ElementalAddrOp elementalAddr =
           Fortran::lower::convertVectorSubscriptedExprToElementalAddr(
               loc, *this, assign.lhs, localSymbols, lhsContext);
       genCleanUpInRegionIfAny(loc, builder, elementalAddr.getCleanup(),
                               lhsContext);
+      lhsYield = elementalAddr.getYieldOp().getEntity();
     }
+    assert(lhsYield && "must have been set");
 
     // Add "realloc" flag to hlfir.region_assign.
     if (isWholeAllocatableAssignment)
       TODO(loc, "assignment to a whole allocatable inside FORALL");
-    // Generate the hlfir.region_assign userDefinedAssignment region.
-    if (userDefinedAssignment)
-      TODO(loc, "HLFIR user defined assignment");
 
+    // Generate the hlfir.region_assign userDefinedAssignment region.
+    if (userDefinedAssignment) {
+      mlir::Type rhsType = rhs.getType();
+      mlir::Type lhsType = lhsYield.getType();
+      if (userDefinedAssignment->IsElemental()) {
+        rhsType = hlfir::getEntityElementType(rhs);
+        lhsType = hlfir::getEntityElementType(hlfir::Entity{lhsYield});
+      }
+      builder.createBlock(&regionAssignOp.getUserDefinedAssignment(),
+                          mlir::Region::iterator{}, {rhsType, lhsType},
+                          {loc, loc});
+      auto end = builder.create<fir::FirEndOp>(loc);
+      builder.setInsertionPoint(end);
+      hlfir::Entity lhsBlockArg{regionAssignOp.getUserAssignmentLhs()};
+      hlfir::Entity rhsBlockArg{regionAssignOp.getUserAssignmentRhs()};
+      Fortran::lower::convertUserDefinedAssignmentToHLFIR(
+          loc, *this, *userDefinedAssignment, lhsBlockArg, rhsBlockArg,
+          localSymbols);
+    }
     builder.setInsertionPointAfter(regionAssignOp);
   }
 
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 4f4505c..0564468 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -1919,3 +1919,19 @@
   CallContext callContext(procRef, resultType, loc, converter, symMap, stmtCtx);
   return genProcedureRef(callContext);
 }
+
+void Fortran::lower::convertUserDefinedAssignmentToHLFIR(
+    mlir::Location loc, Fortran::lower::AbstractConverter &converter,
+    const evaluate::ProcedureRef &procRef, hlfir::Entity lhs, hlfir::Entity rhs,
+    Fortran::lower::SymMap &symMap) {
+  Fortran::lower::StatementContext definedAssignmentContext;
+  CallContext callContext(procRef, /*resultType=*/std::nullopt, loc, converter,
+                          symMap, definedAssignmentContext);
+  Fortran::lower::CallerInterface caller(procRef, converter);
+  mlir::FunctionType callSiteType = caller.genFunctionType();
+  PreparedActualArgument preparedLhs{lhs, /*isPresent=*/std::nullopt};
+  PreparedActualArgument preparedRhs{rhs, /*isPresent=*/std::nullopt};
+  PreparedActualArguments loweredActuals{preparedLhs, preparedRhs};
+  genUserCall(loweredActuals, caller, callSiteType, callContext);
+  return;
+}
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 99a699b..b801058 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -710,6 +710,16 @@
   return fir::ReferenceType::get(eleTy);
 }
 
+mlir::Type hlfir::getEntityElementType(hlfir::Entity entity) {
+  if (entity.isVariable())
+    return getVariableElementType(entity);
+  if (entity.isScalar())
+    return entity.getType();
+  auto exprType = mlir::dyn_cast<hlfir::ExprType>(entity.getType());
+  assert(exprType && "array value must be an hlfir.expr");
+  return exprType.getElementExprType();
+}
+
 static hlfir::ExprType getArrayExprType(mlir::Type elementType,
                                         mlir::Value shape, bool isPolymorphic) {
   unsigned rank = shape.getType().cast<fir::ShapeType>().getRank();
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 21a44c0..e34123f 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -1287,6 +1287,13 @@
   return mlir::success();
 }
 
+hlfir::YieldOp hlfir::ElementalAddrOp::getYieldOp() {
+  hlfir::YieldOp yieldOp =
+      mlir::dyn_cast_or_null<hlfir::YieldOp>(getTerminator(getBody()));
+  assert(yieldOp && "element_addr is ill-formed");
+  return yieldOp;
+}
+
 //===----------------------------------------------------------------------===//
 // OrderedAssignmentTreeOpInterface
 //===----------------------------------------------------------------------===//