blob: a1a89bb5154fb4470c18455c160d6e034a735029 [file] [log] [blame]
//===- LowerHLFIROrderedAssignments.cpp - Lower HLFIR ordered assignments -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file defines a pass to lower HLFIR ordered assignments.
// Ordered assignments are all the operations with the
// OrderedAssignmentTreeOpInterface that implement user-defined assignments,
// assignments to vector-subscripted entities, and assignments inside FORALL
// and WHERE constructs.
// The pass lowers these operations to regular hlfir.assign, loops and, if
// needed, introduces temporary storage to fulfill Fortran semantics.
//
// For each rewrite, an analysis builds an evaluation schedule, and then the
// new code is generated by following the evaluation schedule.
//===----------------------------------------------------------------------===//
#include "ScheduleOrderedAssignments.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/TemporaryStorage.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
namespace hlfir {
// Expand the definition of this pass from the generated Passes.h.inc; the
// GEN_PASS_DEF_* macro selects which pass definition is emitted.
#define GEN_PASS_DEF_LOWERHLFIRORDEREDASSIGNMENTS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir
// Debug category for this file's debug output (see llvm/Support/Debug.h).
#define DEBUG_TYPE "flang-ordered-assignment"
// Testing-only option: run the scheduling analysis alone, erasing the
// operations without generating any code. Its sole purpose is to allow
// printing and testing the scheduling debug output.
static llvm::cl::opt<bool> dbgScheduleOnly{
    "flang-dbg-order-assignment-schedule-only",
    llvm::cl::desc("Only run ordered assignment scheduling with no codegen"),
    llvm::cl::init(false)};
namespace {
/// Structure that represents a masked expression being lowered. Masked
/// expressions are any expressions inside an hlfir.where. As described in
/// Fortran 2018 section 10.2.3.2, the evaluation of the elemental parts of such
/// expressions must be masked, while the evaluation of the non-elemental parts
/// must not be masked. This structure analyzes the region evaluating the
/// expression and allows generating the non-elemental part separately from
/// the elemental part.
struct MaskedArrayExpr {
  /// \p region is the region evaluating the expression. \p isOuterMaskExpr
  /// is true only for the mask expression of the outermost WHERE.
  MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
                  bool isOuterMaskExpr);
  /// Generate the non-elemental part. Must be called outside of the
  /// loops created for the WHERE construct.
  void generateNoneElementalPart(fir::FirOpBuilder &builder,
                                 mlir::IRMapping &mapper);
  /// Methods below can only be called once generateNoneElementalPart has been
  /// called.
  /// Return the shape of the expression.
  mlir::Value generateShape(fir::FirOpBuilder &builder,
                            mlir::IRMapping &mapper);
  /// Return the value of an element of this expression given the current
  /// WHERE loop indices.
  mlir::Value generateElementalParts(fir::FirOpBuilder &builder,
                                     mlir::ValueRange oneBasedIndices,
                                     mlir::IRMapping &mapper);
  /// Generate the clean-up for the non-elemental parts, if any. This must be
  /// called after the loops created for the WHERE construct.
  void generateNoneElementalCleanupIfAny(fir::FirOpBuilder &builder,
                                         mlir::IRMapping &mapper);
  /// Helper to clone the clean-ups of the masked expression region
  /// terminator. This is called outside of the loops for the initial mask,
  /// and inside the loops for the other masked expressions.
  mlir::Operation *generateMaskedExprCleanUps(fir::FirOpBuilder &builder,
                                              mlir::IRMapping &mapper);
  mlir::Location loc;
  /// Region evaluating the expression (held by reference, not owned).
  mlir::Region &region;
  /// Set of operations that form the elemental parts of the
  /// expression evaluation. These are the hlfir.elemental and
  /// hlfir.elemental_addr that form the elemental tree producing
  /// the expression value. hlfir.elemental that produce values
  /// used inside transformational operations are not part of this set.
  llvm::SmallSet<mlir::Operation *, 4> elementalParts{};
  /// Was generateNoneElementalPart called?
  bool noneElementalPartWasGenerated = false;
  /// Is this expression the mask expression of the outer WHERE statement?
  /// It is special because its evaluation is not masked by anything yet.
  bool isOuterMaskExpr = false;
};
} // namespace
namespace {
/// Structure that visits an ordered assignment tree and generates code for
/// it according to a schedule.
/// lowerRun() generates the code for one run of the schedule.
/// cleanupSavedEntities() is meant to be called once all runs have been
/// lowered.
class OrderedAssignmentRewriter {
public:
  OrderedAssignmentRewriter(fir::FirOpBuilder &builder,
                            hlfir::OrderedAssignmentTreeOpInterface root)
      : builder{builder}, root{root} {}

  /// Generate code for the current run of the schedule. The per-run state
  /// (mapper, values saved in this run) is reset afterwards.
  void lowerRun(hlfir::Run &run) {
    currentRun = &run;
    walk(root);
    currentRun = nullptr;
    assert(constructStack.empty() && "must exit constructs after a run");
    mapper.clear();
    savedInCurrentRunBeforeUse.clear();
  }

  /// After all runs have been lowered, clean up all the temporary
  /// storage that was created (final routines are not called).
  void cleanupSavedEntities() {
    for (auto &temp : savedEntities)
      temp.second.destroy(root.getLoc(), builder);
  }

  /// Lowered value for an expression, and the original hlfir.yield if any
  /// clean-up needs to be cloned after usage.
  using ValueAndCleanUp = std::pair<mlir::Value, std::optional<hlfir::YieldOp>>;

private:
  /// Walk the part of an ordered assignment tree node that needs
  /// to be evaluated in the current run.
  void walk(hlfir::OrderedAssignmentTreeOpInterface node);

  /// Generate code when entering a given ordered assignment node.
  void pre(hlfir::ForallOp forallOp);
  void pre(hlfir::ForallIndexOp);
  void pre(hlfir::ForallMaskOp);
  void pre(hlfir::WhereOp whereOp);
  void pre(hlfir::ElseWhereOp elseWhereOp);
  void pre(hlfir::RegionAssignOp);

  /// Generate code when leaving a given ordered assignment node.
  void post(hlfir::ForallOp);
  void post(hlfir::ForallMaskOp);
  void post(hlfir::WhereOp);
  void post(hlfir::ElseWhereOp);

  /// Enter (and maybe create) the fir.if else block of an ElseWhereOp,
  /// but do not generate the elsewhere mask or the new fir.if.
  void enterElsewhere(hlfir::ElseWhereOp);

  /// Are there any leaf regions in the node that must be saved in the current
  /// run? They are appended to \p saveEntities.
  bool mustSaveRegionIn(
      hlfir::OrderedAssignmentTreeOpInterface node,
      llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const;

  /// Should this node be evaluated in the current run? Saving a region in a
  /// node does not imply the node needs to be evaluated.
  bool
  isRequiredInCurrentRun(hlfir::OrderedAssignmentTreeOpInterface node) const;

  /// Generate a scalar value yielded by an ordered assignment tree region.
  /// If the value was not saved in a previous run, this clones the region
  /// code, except the final yield, at the current execution point.
  /// If the value was saved in a previous run, this fetches the saved value
  /// from the temporary storage and returns the value.
  /// Inside Forall, the value will be hoisted outside of the forall loops if
  /// it does not depend on the forall indices.
  /// An optional type can be provided to get a value of a specific type
  /// (the cast will be hoisted if the computation is hoisted).
  mlir::Value generateYieldedScalarValue(
      mlir::Region &region,
      std::optional<mlir::Type> castToType = std::nullopt);

  /// Generate an entity yielded by an ordered assignment tree region, and
  /// optionally return the (uncloned) yield if there is any clean-up that
  /// should be done after using the entity. Like generateYieldedScalarValue,
  /// this will return the saved value if the region was saved in a previous
  /// run.
  ValueAndCleanUp
  generateYieldedEntity(mlir::Region &region,
                        std::optional<mlir::Type> castToType = std::nullopt);

  /// Lowered left-hand side together with the clean-ups and the vector
  /// subscript loop nest (and its shape) that may have been created for it.
  struct LhsValueAndCleanUp {
    mlir::Value lhs;
    std::optional<hlfir::YieldOp> elementalCleanup;
    mlir::Region *nonElementalCleanup = nullptr;
    std::optional<hlfir::LoopNest> vectorSubscriptLoopNest;
    std::optional<mlir::Value> vectorSubscriptShape;
  };

  /// Generate the left-hand side. If the left-hand side is vector
  /// subscripted (hlfir.elemental_addr), this will create a loop nest
  /// (unless it was already created by a WHERE mask) and return the
  /// element address.
  LhsValueAndCleanUp
  generateYieldedLHS(mlir::Location loc, mlir::Region &lhsRegion,
                     std::optional<hlfir::Entity> loweredRhs = std::nullopt);

  /// If \p maybeYield is present and has a clean-up, generate the clean-up
  /// at the current insertion point (by cloning).
  void generateCleanupIfAny(std::optional<hlfir::YieldOp> maybeYield);
  void generateCleanupIfAny(mlir::Region *cleanupRegion);

  /// Generate a masked entity. This can only be called when whereLoopNest was
  /// set (when an hlfir.where is being visited).
  /// This method returns the scalar element (that may have been previously
  /// saved) for the current indices inside the where loop.
  mlir::Value generateMaskedEntity(mlir::Location loc, mlir::Region &region) {
    MaskedArrayExpr maskedExpr(loc, region, /*isOuterMaskExpr=*/!whereLoopNest);
    return generateMaskedEntity(maskedExpr);
  }
  mlir::Value generateMaskedEntity(MaskedArrayExpr &maskedExpr);

  /// Create a fir.if at the current position inside the where loop nest
  /// given the element value of a mask.
  void generateMaskIfOp(mlir::Value cdt);

  /// Save a value for subsequent runs.
  void generateSaveEntity(hlfir::SaveEntity savedEntity,
                          bool willUseSavedEntityInSameRun);
  void saveLeftHandSide(hlfir::SaveEntity savedEntity,
                        hlfir::RegionAssignOp regionAssignOp);

  /// Get a value if it was saved in this run or a previous run. Returns
  /// nullopt if it has not been saved.
  std::optional<ValueAndCleanUp> getIfSaved(mlir::Region &region);

  /// Generate code before the loop nest for the current run, if any.
  /// When no construct is open, the callback runs at the current
  /// insertion point.
  void doBeforeLoopNest(const std::function<void()> &callback) {
    if (constructStack.empty()) {
      callback();
      return;
    }
    auto insertionPoint = builder.saveInsertionPoint();
    builder.setInsertionPoint(constructStack[0]);
    callback();
    builder.restoreInsertionPoint(insertionPoint);
  }

  /// Can the current loop nest iteration number be computed? For simplicity,
  /// this is true if and only if all the bounds and steps of the fir.do_loop
  /// nest dominate the outer loop. The argument is filled with the current
  /// loop nest on success.
  bool currentLoopNestIterationNumberCanBeComputed(
      llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest);

  /// Register the temporary storage \p temp for \p region. A region may be
  /// registered only once (asserted).
  template <typename T>
  fir::factory::TemporaryStorage *insertSavedEntity(mlir::Region &region,
                                                    T &&temp) {
    auto inserted =
        savedEntities.insert(std::make_pair(&region, std::forward<T>(temp)));
    assert(inserted.second && "temp must have been emplaced");
    return &inserted.first->second;
  }

  fir::FirOpBuilder &builder;

  /// Map containing the mapping between the original ordered assignment tree
  /// operations and the operations that have been cloned in the current run.
  /// It is reset between two runs.
  mlir::IRMapping mapper;

  /// Dominance info is used to determine if inner loop bounds are all computed
  /// before the outer loop for the current loop. It does not need to be reset
  /// between runs.
  mlir::DominanceInfo dominanceInfo;

  /// Construct stack in the current run. This allows setting back the
  /// insertion point correctly when leaving a node that requires a
  /// fir.do_loop or fir.if operation.
  llvm::SmallVector<mlir::Operation *> constructStack;

  /// Current where loop nest, if any.
  std::optional<hlfir::LoopNest> whereLoopNest;

  /// Map of temporary storage to keep track of saved entities once the run
  /// that saves them has been lowered. It is kept in-between runs.
  /// llvm::MapVector is used to guarantee a deterministic order
  /// when iterating through savedEntities (e.g. for generating
  /// destruction code for the temporary storages).
  llvm::MapVector<mlir::Region *, fir::factory::TemporaryStorage> savedEntities;

  /// Map holding the values that were saved in the current run and that also
  /// need to be used (because their construct will be visited). It is reset
  /// after each run. It avoids having to store and fetch in the temporary
  /// during the same run, which would require the temporary to have different
  /// fetching and storing counters.
  llvm::DenseMap<mlir::Region *, ValueAndCleanUp> savedInCurrentRunBeforeUse;

  /// Root of the ordered assignment tree being lowered.
  hlfir::OrderedAssignmentTreeOpInterface root;

  /// Pointer to the current run of the schedule being lowered.
  hlfir::Run *currentRun = nullptr;

  /// When allocating temporary storage inline, indicates whether the storage
  /// should be heap or stack allocated. Temporaries allocated with the
  /// runtime are heap allocated by the runtime.
  bool allocateOnHeap = true;
};
} // namespace
/// Visit \p node for the current run: first save the node regions that this
/// run must save, then, if the node must be evaluated in this run (or is an
/// hlfir.forall_index), generate its entry code, recurse into its sub-tree,
/// and generate its exit code.
void OrderedAssignmentRewriter::walk(
    hlfir::OrderedAssignmentTreeOpInterface node) {
  const bool visitNode =
      isRequiredInCurrentRun(node) || mlir::isa<hlfir::ForallIndexOp>(node);
  mlir::Operation *op = node.getOperation();

  llvm::SmallVector<hlfir::SaveEntity> toSave;
  if (mustSaveRegionIn(node, toSave)) {
    mlir::IRRewriter::InsertPoint savedIp;
    // An ELSEWHERE mask being saved must be evaluated inside the else branch
    // of the previous where/elsewhere fir.if, because its evaluation is
    // masked by the "pending control mask".
    auto elseWhere = mlir::dyn_cast<hlfir::ElseWhereOp>(op);
    if (elseWhere) {
      savedIp = builder.saveInsertionPoint();
      enterElsewhere(elseWhere);
    }
    for (hlfir::SaveEntity entity : toSave)
      generateSaveEntity(entity, visitNode);
    if (savedIp.isSet())
      builder.restoreInsertionPoint(savedIp);
  }

  if (!visitNode)
    return;

  llvm::TypeSwitch<mlir::Operation *, void>(op)
      .Case<hlfir::ForallOp, hlfir::ForallIndexOp, hlfir::ForallMaskOp,
            hlfir::RegionAssignOp, hlfir::WhereOp, hlfir::ElseWhereOp>(
          [&](auto concreteOp) { pre(concreteOp); })
      .Default([](auto) {});

  auto *subTree = node.getSubTreeRegion();
  if (!subTree)
    return;

  for (mlir::Operation &child : subTree->getOps()) {
    auto childNode =
        mlir::dyn_cast<hlfir::OrderedAssignmentTreeOpInterface>(child);
    if (childNode)
      walk(childNode);
  }

  llvm::TypeSwitch<mlir::Operation *, void>(op)
      .Case<hlfir::ForallOp, hlfir::ForallMaskOp, hlfir::WhereOp,
            hlfir::ElseWhereOp>([&](auto concreteOp) { post(concreteOp); })
      .Default([](auto) {});
}
/// Open a fir.do_loop driven by the hlfir.forall lower bound, upper bound
/// and step. The forall index is mapped to the converted induction variable.
void OrderedAssignmentRewriter::pre(hlfir::ForallOp forallOp) {
  mlir::Location loc = forallOp.getLoc();
  mlir::Type indexType = builder.getIndexType();
  mlir::Value lowerBound =
      generateYieldedScalarValue(forallOp.getLbRegion(), indexType);
  mlir::Value upperBound =
      generateYieldedScalarValue(forallOp.getUbRegion(), indexType);

  mlir::Value stride;
  if (!forallOp.getStepRegion().empty()) {
    stride = generateYieldedScalarValue(forallOp.getStepRegion(), indexType);
  } else {
    // Implicit step of 1, created before the outermost open construct (if
    // any) so that it is outside of all the loops of this run.
    doBeforeLoopNest(
        [&]() { stride = builder.createIntegerConstant(loc, indexType, 1); });
  }

  auto loop =
      builder.create<fir::DoLoopOp>(loc, lowerBound, upperBound, stride);
  builder.setInsertionPointToStart(loop.getBody());
  // Uses of the hlfir.forall index inside the body now refer to the loop
  // induction variable converted back to the original index type.
  mlir::Value forallIndex = forallOp.getForallIndexValue();
  mapper.map(forallIndex, builder.createConvert(loc, forallIndex.getType(),
                                                loop.getInductionVar()));
  constructStack.push_back(loop);
}
/// Leave the fir.do_loop opened by pre(hlfir::ForallOp).
void OrderedAssignmentRewriter::post(hlfir::ForallOp) {
  assert(!constructStack.empty() && "must contain a loop");
  mlir::Operation *forallLoop = constructStack.pop_back_val();
  builder.setInsertionPointAfter(forallLoop);
}
/// Materialize an hlfir.forall_index as a named temporary that holds the
/// current value of the forall index, and map the op result to it.
void OrderedAssignmentRewriter::pre(hlfir::ForallIndexOp forallIndexOp) {
  mlir::Location loc = forallIndexOp.getLoc();
  mlir::Type elementType = fir::unwrapRefType(forallIndexOp.getType());
  mlir::Value storage =
      builder.createTemporary(loc, elementType, forallIndexOp.getName());
  builder.createStoreWithConvert(
      loc, mapper.lookupOrDefault(forallIndexOp.getIndex()), storage);
  mapper.map(forallIndexOp, storage);
}
/// Evaluate the FORALL mask as an i1 and open a fir.if (without an else
/// region) guarding the masked sub-tree.
void OrderedAssignmentRewriter::pre(hlfir::ForallMaskOp forallMaskOp) {
  mlir::Location loc = forallMaskOp.getLoc();
  mlir::Value condition = generateYieldedScalarValue(
      forallMaskOp.getMaskRegion(), builder.getI1Type());
  auto maskIf = builder.create<fir::IfOp>(loc, std::nullopt, condition,
                                          /*withElseRegion=*/false);
  builder.setInsertionPointToStart(&maskIf.getThenRegion().front());
  constructStack.push_back(maskIf);
}
/// Leave the fir.if opened by pre(hlfir::ForallMaskOp).
void OrderedAssignmentRewriter::post(hlfir::ForallMaskOp) {
  assert(!constructStack.empty() && "must contain an ifop");
  mlir::Operation *maskIf = constructStack.pop_back_val();
  builder.setInsertionPointAfter(maskIf);
}
/// Convert an entity to the type of a given mold.
/// This helps when an HLFIR entity is a value while it must be used as a
/// variable, or vice versa. Such mismatches can occur between the types of
/// the user-defined assignment block arguments and the actual arguments
/// lowered for them: the actual may be an in-memory copy while the block
/// argument expects an hlfir.expr.
/// Clean-ups required by the conversion are appended to \p cleanups and must
/// be run by the caller once the converted entity is no longer used.
static hlfir::Entity
convertToMoldType(mlir::Location loc, fir::FirOpBuilder &builder,
                  hlfir::Entity input, hlfir::Entity mold,
                  llvm::SmallVectorImpl<hlfir::CleanupFunction> &cleanups) {
  mlir::Type moldType = mold.getType();
  if (input.getType() == moldType)
    return input;
  // The deferred clean-ups capture the builder through this pointer.
  fir::FirOpBuilder *builderPtr = &builder;

  if (input.isVariable() && mold.isValue()) {
    if (!fir::isa_trivial(moldType)) {
      // fir.ref<T> to hlfir.expr<T>.
      mlir::Value expr = builder.create<hlfir::AsExprOp>(loc, input);
      if (expr.getType() != moldType)
        TODO(loc, "hlfir.expr conversion");
      cleanups.emplace_back(
          [=]() { builderPtr->create<hlfir::DestroyOp>(loc, expr); });
      return hlfir::Entity{expr};
    }
    // fir.ref<T> to T.
    mlir::Value loaded = builder.create<fir::LoadOp>(loc, input);
    return hlfir::Entity{builder.createConvert(loc, moldType, loaded)};
  }

  if (input.isValue() && mold.isVariable()) {
    // T to fir.ref<T>, or hlfir.expr<T> to fir.ref<T>.
    hlfir::AssociateOp associate = hlfir::genAssociateExpr(
        loc, builder, input, mold.getFortranElementType(), ".tmp.val2ref");
    cleanups.emplace_back(
        [=]() { builderPtr->create<hlfir::EndAssociateOp>(loc, associate); });
    return hlfir::Entity{associate.getBase()};
  }

  // Remaining cases: variable-to-variable mismatch (e.g. fir.heap<T> vs
  // fir.ref<T>) or value-to-value mismatch (e.g. i1 vs fir.logical<4>).
  const bool moldIsBoxed = mlir::isa<fir::BaseBoxType>(moldType);
  const bool inputIsBoxed = mlir::isa<fir::BaseBoxType>(input.getType());
  if (moldIsBoxed && !inputIsBoxed) {
    // An entity may have been saved without a descriptor while the original
    // value had one (e.g. because it was not contiguous).
    auto emboxed = hlfir::convertToBox(loc, builder, input, moldType);
    assert(!emboxed.second && "temp should already be in memory");
    input = hlfir::Entity{fir::getBase(emboxed.first)};
  }
  return hlfir::Entity{builder.createConvert(loc, moldType, input)};
}
/// Lower an hlfir.region_assign. The RHS is evaluated (or fetched if saved)
/// first, then the LHS. The result is either a plain hlfir.assign or an
/// inlined clone of the user-defined assignment region. For a vector
/// subscripted LHS, the assignment is done element by element inside the
/// loop nest created by generateYieldedLHS. Clean-ups of the yielded entities
/// are emitted once the entities are no longer used.
void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
  mlir::Location loc = regionAssignOp.getLoc();
  auto [rhsValue, oldRhsYield] =
      generateYieldedEntity(regionAssignOp.getRhsRegion());
  hlfir::Entity rhsEntity{rhsValue};
  // The lowered RHS is passed so that a vector subscripted LHS saved in a
  // previous run can reuse the RHS shape to build its loop nest.
  LhsValueAndCleanUp loweredLhs =
      generateYieldedLHS(loc, regionAssignOp.getLhsRegion(), rhsEntity);
  hlfir::Entity lhsEntity{loweredLhs.lhs};
  // Inside a vector subscript loop nest, the LHS is an element address, so
  // the matching RHS element is needed.
  if (loweredLhs.vectorSubscriptLoopNest)
    rhsEntity = hlfir::getElementAt(
        loc, builder, rhsEntity,
        loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
  if (!regionAssignOp.getUserDefinedAssignment().empty()) {
    hlfir::Entity userAssignLhs{regionAssignOp.getUserAssignmentLhs()};
    hlfir::Entity userAssignRhs{regionAssignOp.getUserAssignmentRhs()};
    std::optional<hlfir::LoopNest> elementalLoopNest;
    if (lhsEntity.isArray() && userAssignLhs.isScalar()) {
      // Elemental assignment with array argument (the RHS cannot be an array
      // if the LHS is not).
      mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
      elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
      builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
      lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
                                      elementalLoopNest->oneBasedIndices);
      rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
                                      elementalLoopNest->oneBasedIndices);
    }

    // Adapt the actual arguments to the block argument types of the
    // user-defined assignment region, then inline that region.
    llvm::SmallVector<hlfir::CleanupFunction, 2> argConversionCleanups;
    lhsEntity = convertToMoldType(loc, builder, lhsEntity, userAssignLhs,
                                  argConversionCleanups);
    rhsEntity = convertToMoldType(loc, builder, rhsEntity, userAssignRhs,
                                  argConversionCleanups);
    mapper.map(userAssignLhs, lhsEntity);
    mapper.map(userAssignRhs, rhsEntity);
    for (auto &op :
         regionAssignOp.getUserDefinedAssignment().front().without_terminator())
      (void)builder.clone(op, mapper);
    for (auto &cleanupConversion : argConversionCleanups)
      cleanupConversion();
    if (elementalLoopNest)
      builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
  } else {
    // TODO: preserve allocatable assignment aspects for forall once
    // they are conveyed in hlfir.region_assign.
    builder.create<hlfir::AssignOp>(loc, rhsEntity, lhsEntity);
  }
  // The elemental LHS clean-up belongs inside the vector subscript loops.
  // The other clean-ups are done after leaving them.
  generateCleanupIfAny(loweredLhs.elementalCleanup);
  if (loweredLhs.vectorSubscriptLoopNest)
    builder.setInsertionPointAfter(
        loweredLhs.vectorSubscriptLoopNest->outerLoop);
  generateCleanupIfAny(oldRhsYield);
  generateCleanupIfAny(loweredLhs.nonElementalCleanup);
}
/// Convert the mask element \p cdt to i1, after hlfir::loadTrivialScalar, and
/// open a fir.if (without an else region) guarding the masked code. The
/// fir.if is pushed on the construct stack.
void OrderedAssignmentRewriter::generateMaskIfOp(mlir::Value cdt) {
  mlir::Location loc = cdt.getLoc();
  mlir::Value loaded =
      hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{cdt});
  mlir::Value condition =
      builder.createConvert(loc, builder.getI1Type(), loaded);
  auto maskIf = builder.create<fir::IfOp>(condition.getLoc(), std::nullopt,
                                          condition,
                                          /*withElseRegion=*/false);
  constructStack.push_back(maskIf.getOperation());
  builder.setInsertionPointToStart(&maskIf.getThenRegion().front());
}
/// Enter an hlfir.where. The outermost WHERE creates the loop nest iterating
/// over the mask shape. Every WHERE then opens a fir.if on the current mask
/// element.
void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
  mlir::Location loc = whereOp.getLoc();

  if (whereLoopNest) {
    // Nested WHERE: the loops were already created by a parent WHERE, so only
    // a fir.if on the current mask element is needed. A saved mask is
    // handled inside generateYieldedScalarValue.
    mlir::Value cdt = generateYieldedScalarValue(whereOp.getMaskRegion());
    generateMaskIfOp(cdt);
    return;
  }

  // Top-level WHERE: open the loop nest over the mask shape and move inside.
  auto openWhereLoops = [&](mlir::Value shape) {
    whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
    constructStack.push_back(whereLoopNest->outerLoop.getOperation());
    builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
  };

  std::optional<ValueAndCleanUp> saved = getIfSaved(whereOp.getMaskRegion());
  if (!saved) {
    // The mask was not evaluated yet, or can be safely re-evaluated.
    MaskedArrayExpr mask(loc, whereOp.getMaskRegion(),
                         /*isOuterMaskExpr=*/true);
    mask.generateNoneElementalPart(builder, mapper);
    openWhereLoops(mask.generateShape(builder, mapper));
    generateMaskIfOp(generateMaskedEntity(mask));
    return;
  }

  // The mask was saved: use it for both the shape and the mask elements.
  hlfir::Entity savedMask{saved->first};
  openWhereLoops(hlfir::genShape(loc, builder, savedMask));
  generateMaskIfOp(hlfir::getElementAt(loc, builder, savedMask,
                                       whereLoopNest->oneBasedIndices));
  if (saved->second) {
    // The mask was saved earlier in this same run; its clean-up was deferred
    // and is emitted now, after the where loops.
    auto ip = builder.saveInsertionPoint();
    builder.setInsertionPointAfter(whereLoopNest->outerLoop);
    generateCleanupIfAny(saved->second);
    builder.restoreInsertionPoint(ip);
  }
}
/// Leave the fir.if of an hlfir.where. If the construct below it is the
/// where fir.do_loop, this was the outermost WHERE, so the loop nest is
/// left too and forgotten.
void OrderedAssignmentRewriter::post(hlfir::WhereOp) {
  assert(!constructStack.empty() && "must contain a fir.if");
  builder.setInsertionPointAfter(constructStack.pop_back_val());

  assert(!constructStack.empty() && "must contain a fir.do_loop or fir.if");
  if (!mlir::isa<fir::DoLoopOp>(constructStack.back()))
    return;
  builder.setInsertionPointAfter(constructStack.pop_back_val());
  whereLoopNest.reset();
}
/// Position the builder in the else branch of the innermost where/elsewhere
/// fir.if. The else region, terminated by a fir.result, is created on first
/// use. Otherwise, insertion happens before its existing terminator.
void OrderedAssignmentRewriter::enterElsewhere(hlfir::ElseWhereOp elseWhereOp) {
  auto ifOp = mlir::dyn_cast<fir::IfOp>(constructStack.back());
  assert(ifOp && "must be an if");
  mlir::Region &elseRegion = ifOp.getElseRegion();
  if (!elseRegion.empty()) {
    builder.setInsertionPoint(&elseRegion.back().back());
    return;
  }
  builder.createBlock(&elseRegion);
  auto terminator = builder.create<fir::ResultOp>(elseWhereOp.getLoc());
  builder.setInsertionPoint(terminator);
}
/// Enter an hlfir.elsewhere. The builder moves into the else branch of the
/// previous where/elsewhere fir.if. If the ELSEWHERE has a mask, a nested
/// fir.if guarded by that mask is opened.
void OrderedAssignmentRewriter::pre(hlfir::ElseWhereOp elseWhereOp) {
  enterElsewhere(elseWhereOp);
  mlir::Region &maskRegion = elseWhereOp.getMaskRegion();
  if (!maskRegion.empty())
    generateMaskIfOp(generateYieldedScalarValue(maskRegion));
}
/// Leave an hlfir.elsewhere. Only a masked ELSEWHERE opened its own fir.if,
/// so only that case pops a construct.
void OrderedAssignmentRewriter::post(hlfir::ElseWhereOp elseWhereOp) {
  if (!elseWhereOp.getMaskRegion().empty()) {
    assert(!constructStack.empty() && "must contain a fir.if");
    builder.setInsertionPointAfter(constructStack.pop_back_val());
  }
}
/// Is this value a Forall index?
/// Forall indices are either block arguments of the entry block of an
/// hlfir.forall body, or results of hlfir.forall_index.
static bool isForallIndex(mlir::Value value) {
  auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value);
  if (!blockArg)
    return value.getDefiningOp<hlfir::ForallIndexOp>();
  mlir::Block *owner = blockArg.getOwner();
  return owner && owner->isEntryBlock() &&
         mlir::isa_and_nonnull<hlfir::ForallOp>(owner->getParentOp());
}
/// Convert the value of \p valueAndCleanUp to \p castToType when a type is
/// requested. The pending clean-up is forwarded unchanged.
static OrderedAssignmentRewriter::ValueAndCleanUp
castIfNeeded(mlir::Location loc, fir::FirOpBuilder &builder,
             OrderedAssignmentRewriter::ValueAndCleanUp valueAndCleanUp,
             std::optional<mlir::Type> castToType) {
  if (castToType)
    valueAndCleanUp.first =
        builder.createConvert(loc, *castToType, valueAndCleanUp.first);
  return valueAndCleanUp;
}
/// Return the value of \p region if it was saved in this run or a previous
/// one. Returns std::nullopt otherwise.
std::optional<OrderedAssignmentRewriter::ValueAndCleanUp>
OrderedAssignmentRewriter::getIfSaved(mlir::Region &region) {
  mlir::Location loc = region.getParentOp()->getLoc();

  // Saved earlier in this same run: reuse the evaluated value and its
  // delayed clean-up, if any, instead of going through the temporary. This
  // avoids the temporary needing separate fetch and store counters, and
  // yields slightly better code.
  auto sameRunIt = savedInCurrentRunBeforeUse.find(&region);
  if (sameRunIt != savedInCurrentRunBeforeUse.end())
    return sameRunIt->second;

  // Saved in a previous run: fetch the value from the temporary storage.
  auto storageIt = savedEntities.find(&region);
  if (storageIt == savedEntities.end())
    return std::nullopt;
  fir::factory::TemporaryStorage &storage = storageIt->second;
  doBeforeLoopNest([&]() { storage.resetFetchPosition(loc, builder); });
  return ValueAndCleanUp{storage.fetch(loc, builder), std::nullopt};
}
/// Return the hlfir.yield terminating the last block of \p region. Asserts
/// if the last operation is not an hlfir.yield.
static hlfir::YieldOp getYield(mlir::Region &region) {
  mlir::Operation &lastOp = region.back().back();
  hlfir::YieldOp yieldOp = mlir::dyn_cast_or_null<hlfir::YieldOp>(lastOp);
  assert(yieldOp && "region computing entities must end with a YieldOp");
  return yieldOp;
}
/// Produce the entity yielded by \p region for the current run, optionally
/// converted to \p castToType. The sources are tried in order:
///  1. a value saved in this run or a previous run (getIfSaved),
///  2. inside a WHERE loop nest, the masked element (generateMaskedEntity),
///  3. otherwise, a clone of the region operations (except the yield).
/// The returned YieldOp is set only when a clean-up remains to be emitted
/// by the caller.
OrderedAssignmentRewriter::ValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedEntity(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  if (auto maybeValueAndCleanUp = getIfSaved(region))
    return castIfNeeded(loc, builder, *maybeValueAndCleanUp, castToType);
  // Otherwise, evaluate the region now.
  // Masked expressions must not evaluate the elemental parts that are masked;
  // they have custom code generation.
  if (whereLoopNest.has_value()) {
    mlir::Value maskedValue = generateMaskedEntity(loc, region);
    return castIfNeeded(loc, builder, {maskedValue, std::nullopt}, castToType);
  }
  assert(region.hasOneBlock() && "region must contain one block");
  auto oldYield = getYield(region);
  mlir::Block::OpListType &ops = region.back().getOperations();
  // Inside Forall, scalars that do not depend on forall indices can be hoisted
  // here because their evaluation is required to only call pure procedures, and
  // if they depend on a variable previously assigned to in a forall assignment,
  // this assignment must have been scheduled in a previous run. Hoisting of
  // scalars is done here to help creating simple temporary storage if needed.
  // Inner forall bounds can often be hoisted, and this allows computing the
  // total number of iterations to create temporary storages.
  // NOTE(review): only the direct operands of the region's top-level
  // operations are checked below; a forall index used only inside a nested
  // region of one of these operations would not prevent hoisting. Confirm
  // this cannot happen for trivial scalar yields.
  bool hoistComputation = false;
  if (fir::isa_trivial(oldYield.getEntity().getType()) &&
      !constructStack.empty()) {
    hoistComputation = true;
    for (mlir::Operation &op : ops)
      if (llvm::any_of(op.getOperands(), [](mlir::Value value) {
            return isForallIndex(value);
          })) {
        hoistComputation = false;
        break;
      }
  }
  auto insertionPoint = builder.saveInsertionPoint();
  // Hoisting places the clones before the outermost open construct.
  if (hoistComputation)
    builder.setInsertionPoint(constructStack[0]);
  // Clone all operations except the final hlfir.yield.
  assert(!ops.empty() && "yield block cannot be empty");
  auto end = ops.end();
  for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
    (void)builder.clone(*opIt, mapper);
  // Get the value for the yielded entity, it may be the result of an operation
  // that was cloned, or it may be the same as the previous value if the yield
  // operand was created before the ordered assignment tree.
  mlir::Value newEntity = mapper.lookupOrDefault(oldYield.getEntity());
  if (castToType.has_value())
    newEntity =
        builder.createConvert(newEntity.getLoc(), *castToType, newEntity);
  if (hoistComputation) {
    // Hoisted trivial scalars clean-up can be done right away, the value is
    // in registers.
    generateCleanupIfAny(oldYield);
    builder.restoreInsertionPoint(insertionPoint);
    return {newEntity, std::nullopt};
  }
  // Non-hoisted case: the caller is responsible for emitting the yield
  // clean-up, if there is one, after using the entity.
  if (oldYield.getCleanup().empty())
    return {newEntity, std::nullopt};
  return {newEntity, oldYield};
}
/// Produce the trivial scalar yielded by \p region (see
/// generateYieldedEntity), loading it if needed. Any pending clean-up is
/// emitted right after the value is obtained.
mlir::Value OrderedAssignmentRewriter::generateYieldedScalarValue(
    mlir::Region &region, std::optional<mlir::Type> castToType) {
  mlir::Location loc = region.getParentOp()->getLoc();
  ValueAndCleanUp entityAndCleanUp = generateYieldedEntity(region, castToType);
  mlir::Value scalar = hlfir::loadTrivialScalar(
      loc, builder, hlfir::Entity{entityAndCleanUp.first});
  assert(fir::isa_trivial(scalar.getType()) && "not a trivial scalar value");
  generateCleanupIfAny(entityAndCleanUp.second);
  return scalar;
}
/// Produce the left-hand side of an assignment for the current run.
/// If the LHS region ends with an hlfir.elemental_addr (vector subscripted
/// designator) and no WHERE loop nest is open, a loop nest is created and the
/// builder is left inside it; the returned `lhs` is then an element address.
/// \p loweredRhs, when it is an array, lets a saved vector subscripted LHS
/// use the RHS shape instead of fetching the saved LHS shape.
OrderedAssignmentRewriter::LhsValueAndCleanUp
OrderedAssignmentRewriter::generateYieldedLHS(
    mlir::Location loc, mlir::Region &lhsRegion,
    std::optional<hlfir::Entity> loweredRhs) {
  LhsValueAndCleanUp loweredLhs;
  // Null unless the last operation of the LHS region is an
  // hlfir.elemental_addr.
  hlfir::ElementalAddrOp elementalAddrLhs =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(lhsRegion.back().back());
  if (auto temp = savedEntities.find(&lhsRegion); temp != savedEntities.end()) {
    // The LHS address was computed and saved in a previous run. Fetch it.
    doBeforeLoopNest([&]() { temp->second.resetFetchPosition(loc, builder); });
    if (elementalAddrLhs && !whereLoopNest) {
      // Vector subscripted designator addresses are saved element by element.
      // If no "elemental" loops have been created yet, the shape of the
      // RHS, if it is an array, can be used, or the shape of the vector
      // subscripted designator must be retrieved to generate the "elemental"
      // loop nest.
      if (loweredRhs && loweredRhs->isArray()) {
        // The RHS shape can be used to create the elemental loops and avoid
        // saving the LHS shape.
        loweredLhs.vectorSubscriptShape =
            hlfir::genShape(loc, builder, *loweredRhs);
      } else {
        // If the shape cannot be retrieved from the RHS, it must have been
        // saved. Get it from the temporary.
        auto &vectorTmp =
            temp->second.cast<fir::factory::AnyVectorSubscriptStack>();
        loweredLhs.vectorSubscriptShape = vectorTmp.fetchShape(loc, builder);
      }
      loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
          loc, builder, loweredLhs.vectorSubscriptShape.value());
      builder.setInsertionPointToStart(
          loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
    }
    // Fetch at the current insertion point (inside the loop nest, if one
    // was just created).
    loweredLhs.lhs = temp->second.fetch(loc, builder);
    return loweredLhs;
  }
  // The LHS has not yet been evaluated and saved. Evaluate it now.
  if (elementalAddrLhs && !whereLoopNest) {
    // This is a vector subscripted entity. The address of elements must
    // be returned. If no "elemental" loops have been created for a WHERE,
    // create them now based on the vector subscripted designator shape.
    for (auto &op : lhsRegion.front().without_terminator())
      (void)builder.clone(op, mapper);
    loweredLhs.vectorSubscriptShape =
        mapper.lookupOrDefault(elementalAddrLhs.getShape());
    loweredLhs.vectorSubscriptLoopNest =
        hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
                           !elementalAddrLhs.isOrdered());
    builder.setInsertionPointToStart(
        loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
    // Clone the hlfir.elemental_addr body with its indices mapped to the
    // loop nest one-based indices.
    mapper.map(elementalAddrLhs.getIndices(),
               loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
    for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
      (void)builder.clone(op, mapper);
    loweredLhs.elementalCleanup = elementalAddrLhs.getYieldOp();
    loweredLhs.lhs =
        mapper.lookupOrDefault(loweredLhs.elementalCleanup->getEntity());
  } else {
    // This is a designator without vector subscripts. Generate it as
    // it is done for other entities.
    auto [lhs, yield] = generateYieldedEntity(lhsRegion);
    loweredLhs.lhs = lhs;
    if (yield && !yield->getCleanup().empty())
      loweredLhs.nonElementalCleanup = &yield->getCleanup();
  }
  return loweredLhs;
}
/// Generate the value of the current element of \p maskedExpr inside the
/// WHERE loop nest (which must exist).
/// The non elemental part of the expression is generated before the WHERE
/// loops if it was not generated yet, and its cleanup is generated after the
/// WHERE loops. The elemental part is generated at the current insertion
/// point.
/// For masked expressions other than the outer mask, the elemental cleanups
/// are also generated at the current insertion point. The builder is then
/// moved before the first of them, so that the code generated next (e.g. the
/// assignment) is inserted before the cleanups.
/// NOTE(review): the non elemental cleanup is generated on every call, while
/// the non elemental part is only generated once. This assumes each
/// MaskedArrayExpr is passed here a single time. Confirm against the callers.
mlir::Value
OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
assert(whereLoopNest.has_value() && "must be inside WHERE loop nest");
auto insertionPoint = builder.saveInsertionPoint();
if (!maskedExpr.noneElementalPartWasGenerated) {
// Generate none elemental part before the where loops (but inside the
// current forall loops if any).
builder.setInsertionPoint(whereLoopNest->outerLoop);
maskedExpr.generateNoneElementalPart(builder, mapper);
}
// Generate the none elemental part cleanup after the where loops.
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
// Generate the value of the current element for the masked expression
// at the current insertion point (inside the where loops, and any fir.if
// generated for previous masks).
builder.restoreInsertionPoint(insertionPoint);
mlir::Value scalar = maskedExpr.generateElementalParts(
builder, whereLoopNest->oneBasedIndices, mapper);
// Generate cleanups for the elemental parts inside the loops (setting the
// location so that the assignment will be generated before the cleanups).
if (!maskedExpr.isOuterMaskExpr)
if (mlir::Operation *firstCleanup =
maskedExpr.generateMaskedExprCleanUps(builder, mapper))
builder.setInsertionPoint(firstCleanup);
return scalar;
}
/// Clone the cleanup region of \p maybeYield at the current insertion point.
/// Nothing is generated when no hlfir.yield is provided.
void OrderedAssignmentRewriter::generateCleanupIfAny(
    std::optional<hlfir::YieldOp> maybeYield) {
  if (!maybeYield)
    return;
  generateCleanupIfAny(&maybeYield->getCleanup());
}
/// Clone the operations of \p cleanupRegion (except its terminator) at the
/// current insertion point, remapping their operands through the rewriter
/// mapper. A null or empty region generates nothing.
void OrderedAssignmentRewriter::generateCleanupIfAny(
    mlir::Region *cleanupRegion) {
  if (!cleanupRegion || cleanupRegion->empty())
    return;
  assert(cleanupRegion->hasOneBlock() && "region must contain one block");
  mlir::Block &cleanupBlock = cleanupRegion->back();
  for (mlir::Operation &cleanupOp : cleanupBlock.without_terminator())
    (void)builder.clone(cleanupOp, mapper);
}
/// Append to \p saveEntities every SaveEntity action of the current run
/// whose saved region is directly owned by \p node.
/// Returns true if \p saveEntities is non-empty afterwards.
bool OrderedAssignmentRewriter::mustSaveRegionIn(
    hlfir::OrderedAssignmentTreeOpInterface node,
    llvm::SmallVectorImpl<hlfir::SaveEntity> &saveEntities) const {
  mlir::Operation *nodeOp = node.getOperation();
  for (auto &action : currentRun->actions) {
    auto *save = std::get_if<hlfir::SaveEntity>(&action);
    if (save && save->yieldRegion->getParentOp() == nodeOp)
      saveEntities.push_back(*save);
  }
  return !saveEntities.empty();
}
/// Must \p node be visited to lower the current run?
/// This is true for an hlfir.forall_index. It is also true for a node that
/// is a proper ancestor of a region saved in this run, or that is (or
/// contains) an assignment of this run.
bool OrderedAssignmentRewriter::isRequiredInCurrentRun(
    hlfir::OrderedAssignmentTreeOpInterface node) const {
  // hlfir.forall_index do not contain saved regions/assignments, but if their
  // hlfir.forall parent was required, they are required too (the forall
  // indices need to be mapped).
  if (mlir::isa<hlfir::ForallIndexOp>(node))
    return true;
  for (auto &action : currentRun->actions) {
    if (auto *savedEntity = std::get_if<hlfir::SaveEntity>(&action)) {
      // Saving an entity does not require evaluating the node that contains
      // it, only all of that node's parents. For instance, saving a bound of
      // hlfir.forall B does not require creating the loops of B, but it does
      // require creating the loops of any hlfir.forall A enclosing B.
      if (node->isProperAncestor(savedEntity->yieldRegion->getParentOp()))
        return true;
      continue;
    }
    auto assign = std::get<hlfir::RegionAssignOp>(action);
    if (node->isAncestor(assign.getOperation()))
      return true;
  }
  return false;
}
/// Is \p apply addressing \p elemental's result with exactly the elemental
/// block indices, in the same order and with the same count?
static bool isInOrderApply(hlfir::ApplyOp apply,
                           hlfir::ElementalOpInterface elemental) {
  // llvm::equal compares both the lengths and the elements pairwise.
  return llvm::equal(elemental.getIndices(), apply.getIndices());
}
/// Gather the tree of hlfir::ElementalOpInterface use-def, if any, starting
/// from \p elemental, which may be a nullptr.
/// \p isOutOfOrder is true when \p elemental is applied, somewhere up the
/// tree, with indices that are not the in-order indices of its user.
/// Ordered elementals reached this way are not gathered, and neither is
/// anything they apply.
static void
gatherElementalTree(hlfir::ElementalOpInterface elemental,
                    llvm::SmallPtrSetImpl<mlir::Operation *> &elementalOps,
                    bool isOutOfOrder) {
  if (!elemental)
    return;
  // Only inline an applied elemental that must be executed in order if the
  // applying indices are in order. An hlfir::Elemental may have been created
  // for a transformational like transpose, and Fortran 2018 standard
  // section 10.2.3.2, point 10 implies that impure elemental sub-expression
  // evaluations should not be masked if they are the arguments of
  // transformational expressions.
  if (isOutOfOrder && elemental.isOrdered())
    return;
  elementalOps.insert(elemental.getOperation());
  for (mlir::Operation &op : elemental.getElementalRegion().getOps()) {
    auto apply = mlir::dyn_cast<hlfir::ApplyOp>(op);
    if (!apply)
      continue;
    const bool appliedOutOfOrder =
        isOutOfOrder || !isInOrderApply(apply, elemental);
    auto appliedElemental =
        mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
            apply.getExpr().getDefiningOp());
    gatherElementalTree(appliedElemental, elementalOps, appliedOutOfOrder);
  }
}
/// Record the masked expression held by \p region.
/// The elemental operations making up its elemental part are collected into
/// elementalParts. The tree starts either at the hlfir.elemental_addr
/// terminator, or at the elemental producing the yielded entity, if any.
MaskedArrayExpr::MaskedArrayExpr(mlir::Location loc, mlir::Region &region,
                                 bool isOuterMaskExpr)
    : loc{loc}, region{region}, isOuterMaskExpr{isOuterMaskExpr} {
  mlir::Operation &terminator = region.back().back();
  // Vector subscripted designator (hlfir.elemental_addr terminator).
  hlfir::ElementalOpInterface treeRoot =
      mlir::dyn_cast<hlfir::ElementalOpInterface>(terminator);
  if (!treeRoot) {
    // Otherwise, the region yields an entity that may have been produced by
    // an elemental expression.
    mlir::Value yielded = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
    treeRoot = mlir::dyn_cast_or_null<hlfir::ElementalOpInterface>(
        yielded.getDefiningOp());
  }
  gatherElementalTree(treeRoot, elementalParts, /*isOutOfOrder=*/false);
}
/// Generate, at the current insertion point, the part of the masked
/// expression that is evaluated before the WHERE loops.
/// For the outer mask, this is every operation of the region except its
/// elemental parts. For other masked expressions, only the bodies of the
/// hlfir.exactly_once operations are generated, and the results of those
/// operations are mapped to the clones of the entities they yield.
void MaskedArrayExpr::generateNoneElementalPart(fir::FirOpBuilder &builder,
                                                mlir::IRMapping &mapper) {
  assert(!noneElementalPartWasGenerated &&
         "none elemental parts already generated");
  mlir::Block &body = region.back();
  if (isOuterMaskExpr) {
    // The outer mask expression is not actually masked. It is only handled
    // as a masked expression so that its elemental part, if any, can be
    // inlined in the WHERE loops. Everything outside of hlfir.elemental/
    // hlfir.elemental_addr must be emitted now because it may be required to
    // deduce the mask shape and the WHERE loop bounds.
    for (mlir::Operation &op : body.without_terminator()) {
      if (elementalParts.contains(&op))
        continue;
      (void)builder.clone(op, mapper);
    }
  } else {
    // Fortran requires elemental expressions, even the scalar ones that are
    // not encoded with hlfir.elemental, to be evaluated only where the mask
    // is true (F'2023 10.2.3.2 point 10). Blindly hoisting scalar
    // computations with side effects (e.g. a division by zero) out of a
    // fully false mask would be wrong.
    // Only the bodies of the hlfir.exactly_once operations are cloned here.
    // They hold sub-expression trees whose root was a non elemental function
    // call at the Fortran level (the call itself may have been inlined since),
    // and these must be evaluated only once (F'2023 10.2.3.2 point 9).
    for (mlir::Operation &op : body.without_terminator()) {
      auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op);
      if (!exactlyOnce)
        continue;
      mlir::Region &onceBody = exactlyOnce.getBody();
      for (mlir::Operation &subOp : onceBody.back().without_terminator())
        (void)builder.clone(subOp, mapper);
      mlir::Value yielded = getYield(onceBody).getEntity();
      mapper.map(exactlyOnce.getResult(), mapper.lookupOrDefault(yielded));
    }
  }
  noneElementalPartWasGenerated = true;
}
/// Return the shape of the masked expression. This is valid only once its
/// non elemental part has been generated.
/// An hlfir.elemental_addr terminator, or an hlfir.elemental producing the
/// yielded entity, is not cloned, so the clone of its shape operand is
/// returned. Otherwise, the shape is generated from the clone of the yielded
/// entity.
mlir::Value MaskedArrayExpr::generateShape(fir::FirOpBuilder &builder,
                                           mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non elemental part must have been generated");
  mlir::Operation &terminator = region.back().back();
  if (auto elementalAddr = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator))
    return mapper.lookupOrDefault(elementalAddr.getShape());
  mlir::Value yielded = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
  if (auto elemental = yielded.getDefiningOp<hlfir::ElementalOp>())
    return mapper.lookupOrDefault(elemental.getShape());
  hlfir::Entity clonedYielded{mapper.lookupOrDefault(yielded)};
  return hlfir::genShape(loc, builder, clonedYielded);
}
/// Generate the element of the masked expression at \p oneBasedIndices, at
/// the current insertion point.
/// For masked expressions other than the outer mask, the operations that
/// are neither hlfir.exactly_once nor elemental parts are cloned first. For
/// the outer mask, they were already cloned with the non elemental part,
/// outside of the loops. Then the body of the hlfir.elemental_addr
/// terminator, or of the hlfir.elemental producing the yielded entity, is
/// inlined and indexed. Without such an elemental, the element is addressed
/// in the already cloned yielded entity.
mlir::Value
MaskedArrayExpr::generateElementalParts(fir::FirOpBuilder &builder,
                                        mlir::ValueRange oneBasedIndices,
                                        mlir::IRMapping &mapper) {
  assert(noneElementalPartWasGenerated &&
         "non elemental part must have been generated");
  if (!isOuterMaskExpr) {
    for (mlir::Operation &op : region.back().without_terminator()) {
      if (mlir::isa<hlfir::ExactlyOnceOp>(op) || elementalParts.contains(&op))
        continue;
      (void)builder.clone(op, mapper);
    }
  }
  mlir::Operation &terminator = region.back().back();
  hlfir::ElementalOpInterface elemental =
      mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator);
  if (!elemental) {
    mlir::Value yielded = mlir::cast<hlfir::YieldOp>(terminator).getEntity();
    elemental = yielded.getDefiningOp<hlfir::ElementalOp>();
    if (!elemental) {
      // Not produced by an elemental operation: address the element in the
      // clone of the yielded entity.
      hlfir::Entity clonedYielded{mapper.lookupOrDefault(yielded)};
      return hlfir::getElementAt(loc, builder, clonedYielded, oneBasedIndices);
    }
  }
  // Applied elementals are inlined recursively only if they were gathered
  // as part of the elemental tree.
  auto isGatheredElemental = [&](hlfir::ElementalOp appliedElemental) -> bool {
    return elementalParts.contains(appliedElemental.getOperation());
  };
  return inlineElementalOp(loc, builder, elemental, oneBasedIndices, mapper,
                           isGatheredElemental);
}
/// Clone, at the current insertion point, the cleanup region of the masked
/// expression. The cleanup region is the one of its hlfir.elemental_addr
/// terminator or of its hlfir.yield. The hlfir.destroy of elementals that
/// were inlined are skipped.
/// Returns the first cloned operation, or nullptr if nothing was cloned.
mlir::Operation *
MaskedArrayExpr::generateMaskedExprCleanUps(fir::FirOpBuilder &builder,
                                            mlir::IRMapping &mapper) {
  mlir::Operation &terminator = region.back().back();
  auto elementalAddr = mlir::dyn_cast<hlfir::ElementalAddrOp>(terminator);
  mlir::Region &cleanupRegion =
      elementalAddr ? elementalAddr.getCleanup()
                    : mlir::cast<hlfir::YieldOp>(terminator).getCleanup();
  if (cleanupRegion.empty())
    return nullptr;
  mlir::Operation *firstNewCleanup = nullptr;
  for (mlir::Operation &op : cleanupRegion.front().without_terminator()) {
    // Elementals in elementalParts are inlined rather than cloned, so their
    // hlfir.destroy must not be generated.
    auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(op);
    if (destroy && elementalParts.contains(destroy.getExpr().getDefiningOp()))
      continue;
    mlir::Operation *clonedOp = builder.clone(op, mapper);
    if (!firstNewCleanup)
      firstNewCleanup = clonedOp;
  }
  return firstNewCleanup;
}
/// Generate, at the current insertion point, the cleanups matching the non
/// elemental part of the masked expression.
/// For the outer mask, these are the region clean-ups. For other masked
/// expressions, these are the clean-ups of the hlfir.exactly_once
/// operations.
void MaskedArrayExpr::generateNoneElementalCleanupIfAny(
    fir::FirOpBuilder &builder, mlir::IRMapping &mapper) {
  if (isOuterMaskExpr) {
    // The outer mask non hlfir.elemental part is generated before the loops,
    // so the region clean-ups must be generated outside of the loops too.
    generateMaskedExprCleanUps(builder, mapper);
    return;
  }
  // Clone clean-ups of hlfir.exactly_once operations in reverse order to
  // properly deal with stack restores.
  for (mlir::Operation &op :
       llvm::reverse(region.back().without_terminator())) {
    auto exactlyOnce = mlir::dyn_cast<hlfir::ExactlyOnceOp>(op);
    if (!exactlyOnce)
      continue;
    mlir::Region &onceCleanup = getYield(exactlyOnce.getBody()).getCleanup();
    if (onceCleanup.empty())
      continue;
    for (mlir::Operation &cleanupOp : onceCleanup.front().without_terminator())
      (void)builder.clone(cleanupOp, mapper);
  }
}
/// Return the hlfir.region_assign whose LHS region is \p region. Returns a
/// null op when \p region is not such a region.
static hlfir::RegionAssignOp
getAssignIfLeftHandSideRegion(mlir::Region &region) {
  auto assign = mlir::dyn_cast<hlfir::RegionAssignOp>(region.getParentOp());
  if (!assign || &assign.getLhsRegion() != &region)
    return nullptr;
  return assign;
}
/// Can the total iteration count of the current construct nest be computed
/// before its outermost construct?
/// Walks from the innermost construct up to the outermost one and appends
/// each fir.do_loop met to \p loopNest, innermost first. Returns false as
/// soon as a fir.do_loop has an operand that does not properly dominate the
/// outermost construct. Returns true when there is no construct at all.
bool OrderedAssignmentRewriter::currentLoopNestIterationNumberCanBeComputed(
    llvm::SmallVectorImpl<fir::DoLoopOp> &loopNest) {
  if (constructStack.empty())
    return true;
  mlir::Operation *outerConstruct = constructStack[0];
  auto notAvailableBeforeNest = [&](mlir::Value operand) {
    return !dominanceInfo.properlyDominates(operand, outerConstruct);
  };
  for (mlir::Operation *construct = constructStack.back(); construct;
       construct = (construct == outerConstruct) ? nullptr
                                                 : construct->getParentOp()) {
    auto doLoop = mlir::dyn_cast<fir::DoLoopOp>(construct);
    if (!doLoop)
      continue;
    if (llvm::any_of(doLoop->getOperands(), notAvailableBeforeNest))
      return false;
    loopNest.push_back(doLoop);
  }
  return true;
}
/// Generate the product of the iteration counts of the loops in \p loopNest,
/// as an index value. \p loopNest must not be empty.
static mlir::Value
computeLoopNestIterationNumber(mlir::Location loc, fir::FirOpBuilder &builder,
                               llvm::ArrayRef<fir::DoLoopOp> loopNest) {
  mlir::Value product;
  for (fir::DoLoopOp doLoop : loopNest) {
    mlir::Value tripCount = builder.genExtentFromTriplet(
        loc, doLoop.getLowerBound(), doLoop.getUpperBound(), doLoop.getStep(),
        builder.getIndexType());
    if (product)
      product = builder.create<mlir::arith::MulIOp>(loc, product, tripCount);
    else
      product = tripCount;
  }
  assert(product && "loopNest must not be empty");
  return product;
}
/// Return a name for temporary storage that tells which kind of ordered
/// assignment tree root (\p root) it was created for: ".tmp.forall",
/// ".tmp.where", or ".tmp.assign" for any other root.
static llvm::StringRef
getTempName(hlfir::OrderedAssignmentTreeOpInterface root) {
  mlir::Operation *rootOp = root.getOperation();
  return mlir::isa<hlfir::ForallOp>(rootOp)  ? ".tmp.forall"
         : mlir::isa<hlfir::WhereOp>(rootOp) ? ".tmp.where"
                                             : ".tmp.assign";
}
/// Generate the code saving the value of the region of \p savedEntity so
/// that it can be fetched later.
/// An LHS region of hlfir.region_assign is delegated to saveLeftHandSide,
/// which saves the address rather than the value. In that case
/// \p willUseSavedEntityInSameRun must be false.
/// Otherwise the region is evaluated at the current insertion point, and its
/// value is pushed into a temporary registered with insertSavedEntity:
///  - outside of any construct: a SimpleCopy;
///  - inside constructs: a rank-1 HomogeneousScalarStack for trivial scalars
///    in a loop nest whose iteration count can be pre-computed, or else an
///    AnyValueStack. Both are created before the loop nest.
/// When \p willUseSavedEntityInSameRun is true, the evaluated entity and its
/// yield are recorded in savedInCurrentRunBeforeUse, and the region cleanup
/// is not generated here. The exception is an hlfir.expr whose temporary can
/// be fetched after the push. In all other cases the cleanup is generated:
/// after the hlfir.region_assign for a standalone hlfir.region_assign outside
/// of any construct, and at the current insertion point otherwise.
void OrderedAssignmentRewriter::generateSaveEntity(
hlfir::SaveEntity savedEntity, bool willUseSavedEntityInSameRun) {
mlir::Region &region = *savedEntity.yieldRegion;
if (hlfir::RegionAssignOp regionAssignOp =
getAssignIfLeftHandSideRegion(region)) {
// Need to save the address, not the values.
assert(!willUseSavedEntityInSameRun &&
"lhs cannot be used in the loop nest where it is saved");
return saveLeftHandSide(savedEntity, regionAssignOp);
}
mlir::Location loc = region.getParentOp()->getLoc();
// Evaluate the region inside the loop nest (if any).
auto [clonedValue, oldYield] = generateYieldedEntity(region);
hlfir::Entity entity{clonedValue};
entity = hlfir::loadTrivialScalar(loc, builder, entity);
mlir::Type entityType = entity.getType();
llvm::StringRef tempName = getTempName(root);
fir::factory::TemporaryStorage *temp = nullptr;
if (constructStack.empty()) {
// Value evaluated outside of any loops (this may be the first MASK of a
// WHERE construct, or an LHS/RHS temp of hlfir.region_assign outside of
// WHERE/FORALL).
temp = insertSavedEntity(
region, fir::factory::SimpleCopy(loc, builder, entity, tempName));
} else {
// Need to create a temporary for values computed inside loops.
// Create temporary storage outside of the loop nest given the entity
// type (and the loop context).
llvm::SmallVector<fir::DoLoopOp> loopNest;
bool loopShapeCanBePreComputed =
currentLoopNestIterationNumberCanBeComputed(loopNest);
// NOTE(review): temp is dereferenced right after this call, so this relies
// on doBeforeLoopNest invoking its callback before returning. Confirm.
doBeforeLoopNest([&] {
// For simple scalars inside loops whose total iteration number can be
// pre-computed, create a rank-1 array outside of the loops. It will be
// assigned/fetched inside the loops like a normal Fortran array given
// the iteration count.
if (loopShapeCanBePreComputed && fir::isa_trivial(entityType)) {
mlir::Value loopExtent =
computeLoopNestIterationNumber(loc, builder, loopNest);
auto sequenceType =
mlir::cast<fir::SequenceType>(builder.getVarLenSeqTy(entityType));
temp = insertSavedEntity(region,
fir::factory::HomogeneousScalarStack{
loc, builder, sequenceType, loopExtent,
/*lenParams=*/{}, allocateOnHeap,
/*stackThroughLoops=*/true, tempName});
} else {
// If the number of iterations is not known, or if the values at each
// iteration may have different shapes, type parameters or dynamic
// types, use the runtime to create and manage a stack-like temporary.
temp = insertSavedEntity(
region, fir::factory::AnyValueStack{loc, builder, entityType});
}
});
// Inside the loop nest (and any fir.if if there are active masks), copy
// the value to the temp and do clean-ups for the value if any.
temp->pushValue(loc, builder, entity);
}
// Delay the clean-up if the entity will be used in the same run (i.e., the
// parent construct will be visited and needs to be lowered). When possible,
// this is not done for hlfir.expr because this use would prevent the
// hlfir.expr storage from being moved when creating the temporary in
// bufferization, and that would lead to an extra copy.
if (willUseSavedEntityInSameRun &&
(!temp->canBeFetchedAfterPush() ||
!mlir::isa<hlfir::ExprType>(entity.getType()))) {
auto inserted =
savedInCurrentRunBeforeUse.try_emplace(&region, entity, oldYield);
assert(inserted.second && "entity must have been emplaced");
(void)inserted;
} else {
if (constructStack.empty() &&
mlir::isa<hlfir::RegionAssignOp>(region.getParentOp())) {
// Here the clean-up code is inserted after the original
// RegionAssignOp, so that the assignment code happens
// before the cleanup. We do this only for standalone
// operations, because the clean-up is handled specially
// during lowering of the parent constructs if any
// (e.g. see generateNoneElementalCleanupIfAny for
// WhereOp).
auto insertionPoint = builder.saveInsertionPoint();
builder.setInsertionPointAfter(region.getParentOp());
generateCleanupIfAny(oldYield);
builder.restoreInsertionPoint(insertionPoint);
} else {
generateCleanupIfAny(oldYield);
}
}
}
/// Is the RHS region of \p regionAssignOp terminated by an hlfir.yield of an
/// array entity? Returns false when the terminator is not an hlfir.yield.
static bool rhsIsArray(hlfir::RegionAssignOp regionAssignOp) {
  mlir::Operation &rhsTerminator = regionAssignOp.getRhsRegion().back().back();
  auto rhsYield = mlir::dyn_cast<hlfir::YieldOp>(rhsTerminator);
  if (!rhsYield)
    return false;
  return hlfir::Entity{rhsYield.getEntity()}.isArray();
}
/// Save the address of the LHS of \p regionAssignOp (the region of
/// \p savedEntity) so that the assignment can be done in a later run.
/// A vector subscripted LHS is evaluated inside its own elemental loop nest.
/// That loop nest is on constructStack while the temporary is created and
/// filled, and the insertion point is left after it. The LHS shape is saved
/// as well, unless the RHS is an array; in that case the RHS shape can be
/// used instead when the LHS is fetched.
/// Other LHS addresses are tracked as plain SSA values (SSARegister) when
/// they properly dominate the outermost construct, or when there is no
/// construct. Otherwise they are stored in an AnyVariableStack.
void OrderedAssignmentRewriter::saveLeftHandSide(
hlfir::SaveEntity savedEntity, hlfir::RegionAssignOp regionAssignOp) {
mlir::Region &region = *savedEntity.yieldRegion;
mlir::Location loc = region.getParentOp()->getLoc();
LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
fir::factory::TemporaryStorage *temp = nullptr;
// Treat the elemental loops of a vector subscripted LHS as an extra
// construct while saving its addresses.
if (loweredLhs.vectorSubscriptLoopNest)
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
// Vector subscripted entity for which the shape must also be saved on top
// of the element addresses (e.g. the shape may change in each forall
// iteration and is needed to create the elemental loops).
mlir::Value shape = loweredLhs.vectorSubscriptShape.value();
int rank = mlir::cast<fir::ShapeType>(shape.getType()).getRank();
// NOTE(review): constructStack cannot be empty here because the vector
// subscript loop was pushed above, so the empty() check is only defensive.
const bool shapeIsInvariant =
constructStack.empty() ||
dominanceInfo.properlyDominates(shape, constructStack[0]);
doBeforeLoopNest([&] {
// Outside of any forall/where/elemental loops, create a temporary that
// will both be able to save the vector subscripted designator shape(s)
// and element addresses.
temp =
insertSavedEntity(region, fir::factory::AnyVectorSubscriptStack{
loc, builder, loweredLhs.lhs.getType(),
shapeIsInvariant, rank});
});
// Save shape before the elemental loop nest created by the vector
// subscripted LHS.
auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
auto insertionPoint = builder.saveInsertionPoint();
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
vectorTmp.pushShape(loc, builder, shape);
builder.restoreInsertionPoint(insertionPoint);
} else {
// Otherwise, only save the LHS address.
// If the LHS address dominates the constructs, its SSA value can
// simply be tracked and there is no need to save the address in memory.
// Otherwise, the addresses are stored at each iteration in memory with
// a descriptor stack.
if (constructStack.empty() ||
dominanceInfo.properlyDominates(loweredLhs.lhs, constructStack[0]))
doBeforeLoopNest([&] {
temp = insertSavedEntity(region, fir::factory::SSARegister{});
});
else
doBeforeLoopNest([&] {
temp = insertSavedEntity(
region, fir::factory::AnyVariableStack{loc, builder,
loweredLhs.lhs.getType()});
});
}
// Push the address (inside the elemental loops, if any), then clean up the
// vector subscripted designator element evaluation.
temp->pushValue(loc, builder, loweredLhs.lhs);
generateCleanupIfAny(loweredLhs.elementalCleanup);
if (loweredLhs.vectorSubscriptLoopNest) {
constructStack.pop_back();
builder.setInsertionPointAfter(
loweredLhs.vectorSubscriptLoopNest->outerLoop);
}
}
/// Lower an ordered assignment tree to fir.do_loop and hlfir.assign given
/// a schedule. Every run of \p schedule is lowered in order, then the
/// rewriter's cleanupSavedEntities is called.
static void lower(hlfir::OrderedAssignmentTreeOpInterface root,
                  mlir::PatternRewriter &rewriter, hlfir::Schedule &schedule) {
  mlir::ModuleOp enclosingModule = root->getParentOfType<mlir::ModuleOp>();
  fir::FirOpBuilder firBuilder(rewriter, enclosingModule);
  OrderedAssignmentRewriter treeRewriter(firBuilder, root);
  for (auto &scheduledRun : schedule)
    treeRewriter.lowerRun(scheduledRun);
  treeRewriter.cleanupSavedEntities();
}
/// Shared rewrite entry point for all the ordered assignment tree root
/// operations. It builds the evaluation schedule, lowers the tree by
/// following it, and erases the root. Always returns success.
static mlir::LogicalResult rewrite(hlfir::OrderedAssignmentTreeOpInterface root,
                                   bool tryFusingAssignments,
                                   mlir::PatternRewriter &rewriter) {
  hlfir::Schedule schedule =
      hlfir::buildEvaluationSchedule(root, tryFusingAssignments);
  // Debug option to print the scheduling debug info without doing any code
  // generation, so that scheduling can be tested without testing generated
  // code. The root is still erased so that the rewrite patterns are not
  // called on nested operations.
  bool generateCode = true;
  LLVM_DEBUG(generateCode = !dbgScheduleOnly);
  if (generateCode)
    lower(root, rewriter, schedule);
  rewriter.eraseOp(root);
  return mlir::success();
}
namespace {
/// Rewrites an hlfir.forall root together with its ordered assignment tree.
class ForallOpConversion : public mlir::OpRewritePattern<hlfir::ForallOp> {
public:
  explicit ForallOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
      : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}

  mlir::LogicalResult
  matchAndRewrite(hlfir::ForallOp forallOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto treeRoot = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
        forallOp.getOperation());
    const mlir::LogicalResult rewriteStatus =
        ::rewrite(treeRoot, tryFusingAssignments, rewriter);
    if (mlir::failed(rewriteStatus))
      TODO(forallOp.getLoc(), "FORALL construct or statement in HLFIR");
    return mlir::success();
  }

  /// Forwarded to hlfir::buildEvaluationSchedule through ::rewrite.
  const bool tryFusingAssignments;
};

/// Rewrites an hlfir.where root together with its ordered assignment tree.
class WhereOpConversion : public mlir::OpRewritePattern<hlfir::WhereOp> {
public:
  explicit WhereOpConversion(mlir::MLIRContext *ctx, bool tryFusingAssignments)
      : OpRewritePattern{ctx}, tryFusingAssignments{tryFusingAssignments} {}

  mlir::LogicalResult
  matchAndRewrite(hlfir::WhereOp whereOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto treeRoot = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
        whereOp.getOperation());
    return ::rewrite(treeRoot, tryFusingAssignments, rewriter);
  }

  /// Forwarded to hlfir::buildEvaluationSchedule through ::rewrite.
  const bool tryFusingAssignments;
};

/// Rewrites a standalone hlfir.region_assign (assignment fusion is never
/// attempted for it).
class RegionAssignConversion
    : public mlir::OpRewritePattern<hlfir::RegionAssignOp> {
public:
  explicit RegionAssignConversion(mlir::MLIRContext *ctx)
      : OpRewritePattern{ctx} {}

  mlir::LogicalResult
  matchAndRewrite(hlfir::RegionAssignOp regionAssignOp,
                  mlir::PatternRewriter &rewriter) const override {
    auto treeRoot = mlir::cast<hlfir::OrderedAssignmentTreeOpInterface>(
        regionAssignOp.getOperation());
    return ::rewrite(treeRoot, /*tryFusingAssignments=*/false, rewriter);
  }
};

/// Pass rewriting every ordered assignment tree root into loops and
/// hlfir.assign operations.
class LowerHLFIROrderedAssignments
    : public hlfir::impl::LowerHLFIROrderedAssignmentsBase<
          LowerHLFIROrderedAssignments> {
public:
  using LowerHLFIROrderedAssignmentsBase<
      LowerHLFIROrderedAssignments>::LowerHLFIROrderedAssignmentsBase;

  void runOnOperation() override {
    // Running on a ModuleOp because this pass may generate FuncOp declaration
    // for runtime calls. This could be a FuncOp pass otherwise.
    mlir::MLIRContext *ctx = &getContext();
    // Patterns are only defined for the OrderedAssignmentTreeOpInterface
    // operations that can be the root of ordered assignments. The other
    // operations are taken care of while rewriting these trees (they cannot
    // exist outside of these operations given their verifiers/traits).
    mlir::RewritePatternSet patterns(ctx);
    const bool fuseAssignments = this->tryFusingAssignments.getValue();
    patterns.insert<ForallOpConversion, WhereOpConversion>(ctx,
                                                           fuseAssignments);
    patterns.insert<RegionAssignConversion>(ctx);
    // Any OrderedAssignmentTreeOpInterface operation left is illegal.
    mlir::ConversionTarget target(*ctx);
    target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) {
      return !mlir::isa<hlfir::OrderedAssignmentTreeOpInterface>(op);
    });
    const mlir::LogicalResult conversionStatus = mlir::applyPartialConversion(
        this->getOperation(), target, std::move(patterns));
    if (mlir::failed(conversionStatus)) {
      mlir::emitError(mlir::UnknownLoc::get(ctx),
                      "failure in HLFIR ordered assignments lowering pass");
      signalPassFailure();
    }
  }
};
} // namespace