[Flang][OpenMP] Properly bind arguments of composite operations (#113682)
When composite constructs are lowered, clauses for each leaf construct
are lowered before creating the set of loop wrapper operations, using
these outside values to populate their operand lists. Then, when the
loop nest associated to that composite construct is lowered, the binding
of Fortran symbols to the entry block arguments defined by these loop
wrappers is performed, resulting in the creation of `hlfir.declare`
operations in the entry block of the `omp.loop_nest`.
This approach prevents `hlfir.declare` operations related to the binding
and other operations resulting from the evaluation of the clauses from
being inserted between loop wrapper operations, which would be an
illegal MLIR representation. However, this introduces the problem of
entry block arguments defined by a wrapper that then should be used by
one of its nested wrappers, because the corresponding Fortran symbol
would still be mapped to an outside value at the time of gathering the
list of operands for the nested wrapper.
This patch adds operand re-mapping logic to update wrappers without
changing when clauses are evaluated or where the `hlfir.declare`
creation is performed.
GitOrigin-RevId: 6c28530ed082204a1b6d20b45482e81d4cd5ead4
diff --git a/lib/Lower/OpenMP/OpenMP.cpp b/lib/Lower/OpenMP/OpenMP.cpp
index 84985b8..329cbf3 100644
--- a/lib/Lower/OpenMP/OpenMP.cpp
+++ b/lib/Lower/OpenMP/OpenMP.cpp
@@ -589,10 +589,27 @@
llvm::SmallVector<mlir::Location> locs(args.size(), loc);
firOpBuilder.createBlock(®ion, {}, tiv, locs);
+ // Update nested wrapper operands if parent wrappers have mapped these values
+ // to block arguments.
+ //
+ // Binding these values earlier would take care of this, but we cannot rely on
+ // that approach because binding in between the creation of a wrapper and the
+ // next one would result in 'hlfir.declare' operations being introduced inside
+ // of a wrapper, which is illegal.
+ mlir::IRMapping mapper;
+ for (auto [argGeneratingOp, blockArgs] : wrapperArgs) {
+ for (mlir::OpOperand &operand : argGeneratingOp->getOpOperands())
+ operand.set(mapper.lookupOrDefault(operand.get()));
+
+ for (const auto [arg, var] : llvm::zip_equal(
+ argGeneratingOp->getRegion(0).getArguments(), blockArgs.getVars()))
+ mapper.map(var, arg);
+ }
+
// Bind the entry block arguments of parent wrappers to the corresponding
// symbols.
- for (auto [argGeneratingOp, args] : wrapperArgs)
- bindEntryBlockArgs(converter, argGeneratingOp, args);
+ for (auto [argGeneratingOp, blockArgs] : wrapperArgs)
+ bindEntryBlockArgs(converter, argGeneratingOp, blockArgs);
// The argument is not currently in memory, so make a temporary for the
// argument, and store it there, then bind that location to the argument.