[flang][fir] Lower `do concurrent` loop nests to `fir.do_concurrent` (#132904)
Adds support for lowering `do concurrent` nests from the PFT to the new
`fir.do_concurrent` MLIR op as well as its special terminator,
`fir.do_concurrent.loop`, which models the actual loop nest.
To that end, this PR emits the allocations for the iteration variables
within the block of the `fir.do_concurrent` op and creates a region for
the `fir.do_concurrent.loop` op whose block takes one `index` argument
per iteration range of the input `do concurrent` header.
For example, given the following input:
```fortran
do concurrent(i=1:10, j=11:20)
end do
```
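the changes in this PR emit the following MLIR (the loop bounds and steps
`%18`-`%21`, `%c1`, and `%c1_0` are computed before the `fir.do_concurrent`
op and are omitted here):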
```mlir
fir.do_concurrent {
%22 = fir.alloca i32 {bindc_name = "i"}
%23:2 = hlfir.declare %22 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%24 = fir.alloca i32 {bindc_name = "j"}
%25:2 = hlfir.declare %24 {uniq_name = "_QFsub1Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
fir.do_concurrent.loop (%arg1, %arg2) = (%18, %20) to (%19, %21) step (%c1, %c1_0) {
%26 = fir.convert %arg1 : (index) -> i32
fir.store %26 to %23#0 : !fir.ref<i32>
%27 = fir.convert %arg2 : (index) -> i32
fir.store %27 to %25#0 : !fir.ref<i32>
}
}
```
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index b4d1197..625dd11 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -94,10 +94,11 @@
template <typename T>
explicit IncrementLoopInfo(Fortran::semantics::Symbol &sym, const T &lower,
const T &upper, const std::optional<T> &step,
- bool isUnordered = false)
+ bool isConcurrent = false)
: loopVariableSym{&sym}, lowerExpr{Fortran::semantics::GetExpr(lower)},
upperExpr{Fortran::semantics::GetExpr(upper)},
- stepExpr{Fortran::semantics::GetExpr(step)}, isUnordered{isUnordered} {}
+ stepExpr{Fortran::semantics::GetExpr(step)},
+ isConcurrent{isConcurrent} {}
IncrementLoopInfo(IncrementLoopInfo &&) = default;
IncrementLoopInfo &operator=(IncrementLoopInfo &&x) = default;
@@ -120,7 +121,7 @@
const Fortran::lower::SomeExpr *upperExpr;
const Fortran::lower::SomeExpr *stepExpr;
const Fortran::lower::SomeExpr *maskExpr = nullptr;
- bool isUnordered; // do concurrent, forall
+ bool isConcurrent;
llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
llvm::SmallVector<
@@ -130,7 +131,7 @@
mlir::Value loopVariable = nullptr;
// Data members for structured loops.
- fir::DoLoopOp doLoop = nullptr;
+ mlir::Operation *loopOp = nullptr;
// Data members for unstructured loops.
bool hasRealControl = false;
@@ -1980,7 +1981,7 @@
llvm_unreachable("illegal reduction operator");
}
- /// Collect DO CONCURRENT or FORALL loop control information.
+ /// Collect DO CONCURRENT loop control information.
IncrementLoopNestInfo getConcurrentControl(
const Fortran::parser::ConcurrentHeader &header,
const std::list<Fortran::parser::LocalitySpec> &localityList = {}) {
@@ -2291,8 +2292,14 @@
mlir::LLVM::LoopAnnotationAttr la = mlir::LLVM::LoopAnnotationAttr::get(
builder->getContext(), {}, /*vectorize=*/va, {}, /*unroll*/ ua,
/*unroll_and_jam*/ uja, {}, {}, {}, {}, {}, {}, {}, {}, {}, {});
- if (has_attrs)
- info.doLoop.setLoopAnnotationAttr(la);
+ if (has_attrs) {
+ if (auto loopOp = mlir::dyn_cast<fir::DoLoopOp>(info.loopOp))
+ loopOp.setLoopAnnotationAttr(la);
+
+ if (auto doConcurrentOp =
+ mlir::dyn_cast<fir::DoConcurrentLoopOp>(info.loopOp))
+ doConcurrentOp.setLoopAnnotationAttr(la);
+ }
}
/// Generate FIR to begin a structured or unstructured increment loop nest.
@@ -2301,96 +2308,77 @@
llvm::SmallVectorImpl<const Fortran::parser::CompilerDirective *> &dirs) {
assert(!incrementLoopNestInfo.empty() && "empty loop nest");
mlir::Location loc = toLocation();
- mlir::Operation *boundsAndStepIP = nullptr;
mlir::arith::IntegerOverflowFlags iofBackup{};
+ llvm::SmallVector<mlir::Value> nestLBs;
+ llvm::SmallVector<mlir::Value> nestUBs;
+ llvm::SmallVector<mlir::Value> nestSts;
+ llvm::SmallVector<mlir::Value> nestReduceOperands;
+ llvm::SmallVector<mlir::Attribute> nestReduceAttrs;
+ bool genDoConcurrent = false;
+
for (IncrementLoopInfo &info : incrementLoopNestInfo) {
- mlir::Value lowerValue;
- mlir::Value upperValue;
- mlir::Value stepValue;
+ genDoConcurrent = info.isStructured() && info.isConcurrent;
- {
- mlir::OpBuilder::InsertionGuard guard(*builder);
-
- // Set the IP before the first loop in the nest so that all nest bounds
- // and step values are created outside the nest.
- if (boundsAndStepIP)
- builder->setInsertionPointAfter(boundsAndStepIP);
-
+ if (!genDoConcurrent)
info.loopVariable = genLoopVariableAddress(loc, *info.loopVariableSym,
- info.isUnordered);
- if (!getLoweringOptions().getIntegerWrapAround()) {
- iofBackup = builder->getIntegerOverflowFlags();
- builder->setIntegerOverflowFlags(
- mlir::arith::IntegerOverflowFlags::nsw);
- }
- lowerValue = genControlValue(info.lowerExpr, info);
- upperValue = genControlValue(info.upperExpr, info);
- bool isConst = true;
- stepValue = genControlValue(info.stepExpr, info,
- info.isStructured() ? nullptr : &isConst);
- if (!getLoweringOptions().getIntegerWrapAround())
- builder->setIntegerOverflowFlags(iofBackup);
- boundsAndStepIP = stepValue.getDefiningOp();
+ info.isConcurrent);
- // Use a temp variable for unstructured loops with non-const step.
- if (!isConst) {
- info.stepVariable =
- builder->createTemporary(loc, stepValue.getType());
- boundsAndStepIP =
- builder->create<fir::StoreOp>(loc, stepValue, info.stepVariable);
- }
+ if (!getLoweringOptions().getIntegerWrapAround()) {
+ iofBackup = builder->getIntegerOverflowFlags();
+ builder->setIntegerOverflowFlags(
+ mlir::arith::IntegerOverflowFlags::nsw);
}
+ nestLBs.push_back(genControlValue(info.lowerExpr, info));
+ nestUBs.push_back(genControlValue(info.upperExpr, info));
+ bool isConst = true;
+ nestSts.push_back(genControlValue(
+ info.stepExpr, info, info.isStructured() ? nullptr : &isConst));
+
+ if (!getLoweringOptions().getIntegerWrapAround())
+ builder->setIntegerOverflowFlags(iofBackup);
+
+ // Use a temp variable for unstructured loops with non-const step.
+ if (!isConst) {
+ mlir::Value stepValue = nestSts.back();
+ info.stepVariable = builder->createTemporary(loc, stepValue.getType());
+ builder->create<fir::StoreOp>(loc, stepValue, info.stepVariable);
+ }
+
+ if (genDoConcurrent && nestReduceOperands.empty()) {
+ // Create DO CONCURRENT reduce operands and attributes
+ for (const auto &reduceSym : info.reduceSymList) {
+ const fir::ReduceOperationEnum reduceOperation = reduceSym.first;
+ const Fortran::semantics::Symbol *sym = reduceSym.second;
+ fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr);
+ nestReduceOperands.push_back(fir::getBase(exv));
+ auto reduceAttr =
+ fir::ReduceAttr::get(builder->getContext(), reduceOperation);
+ nestReduceAttrs.push_back(reduceAttr);
+ }
+ }
+ }
+
+ for (auto [info, lowerValue, upperValue, stepValue] :
+ llvm::zip_equal(incrementLoopNestInfo, nestLBs, nestUBs, nestSts)) {
// Structured loop - generate fir.do_loop.
if (info.isStructured()) {
+ if (genDoConcurrent)
+ continue;
+
+ // The loop variable is a doLoop op argument.
mlir::Type loopVarType = info.getLoopVariableType();
- mlir::Value loopValue;
- if (info.isUnordered) {
- llvm::SmallVector<mlir::Value> reduceOperands;
- llvm::SmallVector<mlir::Attribute> reduceAttrs;
- // Create DO CONCURRENT reduce operands and attributes
- for (const auto &reduceSym : info.reduceSymList) {
- const fir::ReduceOperationEnum reduce_operation = reduceSym.first;
- const Fortran::semantics::Symbol *sym = reduceSym.second;
- fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr);
- reduceOperands.push_back(fir::getBase(exv));
- auto reduce_attr =
- fir::ReduceAttr::get(builder->getContext(), reduce_operation);
- reduceAttrs.push_back(reduce_attr);
- }
- // The loop variable value is explicitly updated.
- info.doLoop = builder->create<fir::DoLoopOp>(
- loc, lowerValue, upperValue, stepValue, /*unordered=*/true,
- /*finalCountValue=*/false, /*iterArgs=*/std::nullopt,
- llvm::ArrayRef<mlir::Value>(reduceOperands), reduceAttrs);
- builder->setInsertionPointToStart(info.doLoop.getBody());
- loopValue = builder->createConvert(loc, loopVarType,
- info.doLoop.getInductionVar());
- } else {
- // The loop variable is a doLoop op argument.
- info.doLoop = builder->create<fir::DoLoopOp>(
- loc, lowerValue, upperValue, stepValue, /*unordered=*/false,
- /*finalCountValue=*/true,
- builder->createConvert(loc, loopVarType, lowerValue));
- builder->setInsertionPointToStart(info.doLoop.getBody());
- loopValue = info.doLoop.getRegionIterArgs()[0];
- }
+ auto loopOp = builder->create<fir::DoLoopOp>(
+ loc, lowerValue, upperValue, stepValue, /*unordered=*/false,
+ /*finalCountValue=*/true,
+ builder->createConvert(loc, loopVarType, lowerValue));
+ info.loopOp = loopOp;
+ builder->setInsertionPointToStart(loopOp.getBody());
+ mlir::Value loopValue = loopOp.getRegionIterArgs()[0];
+
// Update the loop variable value in case it has non-index references.
builder->create<fir::StoreOp>(loc, loopValue, info.loopVariable);
- if (info.maskExpr) {
- Fortran::lower::StatementContext stmtCtx;
- mlir::Value maskCond = createFIRExpr(loc, info.maskExpr, stmtCtx);
- stmtCtx.finalizeAndReset();
- mlir::Value maskCondCast =
- builder->createConvert(loc, builder->getI1Type(), maskCond);
- auto ifOp = builder->create<fir::IfOp>(loc, maskCondCast,
- /*withElseRegion=*/false);
- builder->setInsertionPointToStart(&ifOp.getThenRegion().front());
- }
- if (info.hasLocalitySpecs())
- handleLocalitySpecs(info);
-
addLoopAnnotationAttr(info, dirs);
continue;
}
@@ -2454,6 +2442,60 @@
builder->restoreInsertionPoint(insertPt);
}
}
+
+ if (genDoConcurrent) {
+ auto loopWrapperOp = builder->create<fir::DoConcurrentOp>(loc);
+ builder->setInsertionPointToStart(
+ builder->createBlock(&loopWrapperOp.getRegion()));
+
+ for (IncrementLoopInfo &info : llvm::reverse(incrementLoopNestInfo)) {
+ info.loopVariable = genLoopVariableAddress(loc, *info.loopVariableSym,
+ info.isConcurrent);
+ }
+
+ builder->setInsertionPointToEnd(loopWrapperOp.getBody());
+ auto loopOp = builder->create<fir::DoConcurrentLoopOp>(
+ loc, nestLBs, nestUBs, nestSts, nestReduceOperands,
+ nestReduceAttrs.empty()
+ ? nullptr
+ : mlir::ArrayAttr::get(builder->getContext(), nestReduceAttrs),
+ nullptr);
+
+ llvm::SmallVector<mlir::Type> loopBlockArgTypes(
+ incrementLoopNestInfo.size(), builder->getIndexType());
+ llvm::SmallVector<mlir::Location> loopBlockArgLocs(
+ incrementLoopNestInfo.size(), loc);
+ mlir::Region &loopRegion = loopOp.getRegion();
+ mlir::Block *loopBlock = builder->createBlock(
+ &loopRegion, loopRegion.begin(), loopBlockArgTypes, loopBlockArgLocs);
+ builder->setInsertionPointToStart(loopBlock);
+
+ for (auto [info, blockArg] :
+ llvm::zip_equal(incrementLoopNestInfo, loopBlock->getArguments())) {
+ info.loopOp = loopOp;
+ mlir::Value loopValue =
+ builder->createConvert(loc, info.getLoopVariableType(), blockArg);
+ builder->create<fir::StoreOp>(loc, loopValue, info.loopVariable);
+
+ if (info.maskExpr) {
+ Fortran::lower::StatementContext stmtCtx;
+ mlir::Value maskCond = createFIRExpr(loc, info.maskExpr, stmtCtx);
+ stmtCtx.finalizeAndReset();
+ mlir::Value maskCondCast =
+ builder->createConvert(loc, builder->getI1Type(), maskCond);
+ auto ifOp = builder->create<fir::IfOp>(loc, maskCondCast,
+ /*withElseRegion=*/false);
+ builder->setInsertionPointToStart(&ifOp.getThenRegion().front());
+ }
+ }
+
+ IncrementLoopInfo &innermostInfo = incrementLoopNestInfo.back();
+
+ if (innermostInfo.hasLocalitySpecs())
+ handleLocalitySpecs(innermostInfo);
+
+ addLoopAnnotationAttr(innermostInfo, dirs);
+ }
}
/// Generate FIR to end a structured or unstructured increment loop nest.
@@ -2470,29 +2512,31 @@
it != rend; ++it) {
IncrementLoopInfo &info = *it;
if (info.isStructured()) {
- // End fir.do_loop.
- if (info.isUnordered) {
- builder->setInsertionPointAfter(info.doLoop);
+ // End fir.do_concurrent.loop.
+ if (info.isConcurrent) {
+ builder->setInsertionPointAfter(info.loopOp->getParentOp());
continue;
}
+
+ // End fir.do_loop.
// Decrement tripVariable.
- builder->setInsertionPointToEnd(info.doLoop.getBody());
+ auto doLoopOp = mlir::cast<fir::DoLoopOp>(info.loopOp);
+ builder->setInsertionPointToEnd(doLoopOp.getBody());
llvm::SmallVector<mlir::Value, 2> results;
results.push_back(builder->create<mlir::arith::AddIOp>(
- loc, info.doLoop.getInductionVar(), info.doLoop.getStep(),
- iofAttr));
+ loc, doLoopOp.getInductionVar(), doLoopOp.getStep(), iofAttr));
// Step loopVariable to help optimizations such as vectorization.
// Induction variable elimination will clean up as necessary.
mlir::Value step = builder->createConvert(
- loc, info.getLoopVariableType(), info.doLoop.getStep());
+ loc, info.getLoopVariableType(), doLoopOp.getStep());
mlir::Value loopVar =
builder->create<fir::LoadOp>(loc, info.loopVariable);
results.push_back(
builder->create<mlir::arith::AddIOp>(loc, loopVar, step, iofAttr));
builder->create<fir::ResultOp>(loc, results);
- builder->setInsertionPointAfter(info.doLoop);
+ builder->setInsertionPointAfter(doLoopOp);
// The loop control variable may be used after the loop.
- builder->create<fir::StoreOp>(loc, info.doLoop.getResult(1),
+ builder->create<fir::StoreOp>(loc, doLoopOp.getResult(1),
info.loopVariable);
continue;
}