blob: 9adbe6a497214c2640f047966483029cf8d59c8a [file] [log] [blame]
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Emit OpenACC clause nodes as CIR code.
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
namespace clang {
// Compile-time membership test: true when `ToTest` is exactly the same type as
// `T` or as any of `Tys...`, false otherwise. At least one candidate type must
// be supplied. Used to drive the `if constexpr` dispatch on operation type.
template <typename ToTest, typename T, typename... Tys>
constexpr bool isOneOfTypes =
    (std::is_same_v<ToTest, T> || ... || std::is_same_v<ToTest, Tys>);
// Holds information for emitting clauses for a combined construct: the pair of
// operations (a compute operation and a loop operation) that the clauses are
// distributed over. We instantiate the clause emitter with this type so that
// it can use if-constexpr to specially handle these, forwarding each clause to
// whichever of the two operations it applies to.
template <typename CompOpTy> struct CombinedConstructClauseInfo {
// Type of the compute operation. The clause emitter reads this alias to
// instantiate the nested emitter for the compute half.
using ComputeOpTy = CompOpTy;
// The compute operation of the combined construct.
ComputeOpTy computeOp;
// The loop operation of the combined construct.
mlir::acc::LoopOp loopOp;
};
// Type-trait that is true only when `ToTest` is an instantiation of
// CombinedConstructClauseInfo (selected via the partial specialization below),
// and false for every other type.
template <typename ToTest> constexpr bool isCombinedType = false;
template <typename T>
constexpr bool isCombinedType<CombinedConstructClauseInfo<T>> = true;
// Visitor that lowers the clauses of an OpenACC directive onto `operation`.
//
// `OpTy` is either one of the `mlir::acc` operations or a
// `CombinedConstructClauseInfo`. Every `Visit*Clause` member selects its
// lowering with `if constexpr` on `OpTy`:
//   - for operations the clause is lowered for, the matching attribute or
//     operand is added to `operation`;
//   - for a combined construct, the clause is forwarded to a nested emitter
//     for either the compute operation or the loop operation;
//   - otherwise the clause is reported as not-yet-implemented, or the branch
//     is marked unreachable.
//
// The emitter is stateful: visiting a 'device_type' clause records its
// architectures in `lastDeviceTypeValues`, and those values are passed along
// by the device-type-aware clauses visited afterwards.
template <typename OpTy>
class OpenACCClauseCIREmitter final
    : public OpenACCClauseVisitor<OpenACCClauseCIREmitter<OpTy>> {
  // Necessary for combined constructs: the forwarding helpers below write the
  // private `lastDeviceTypeValues` of an emitter for a different `OpTy`.
  template <typename FriendOpTy> friend class OpenACCClauseCIREmitter;

  // The operation (or combined-construct info) the clauses are applied to.
  OpTy &operation;
  CIRGen::CIRGenFunction &cgf;
  CIRGen::CIRGenBuilderTy &builder;

  // This is necessary since a few of the clauses emit differently based on the
  // directive kind they are attached to.
  OpenACCDirectiveKind dirKind;
  // TODO(cir): This source location should be able to go away once the NYI
  // diagnostics are gone.
  SourceLocation dirLoc;

  // Architectures named by the most recently visited 'device_type' clause;
  // empty until such a clause has been visited.
  llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;

  // Replaces `lastDeviceTypeValues` with the decoded architectures of
  // `clause`.
  void setLastDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
    lastDeviceTypeValues.clear();

    for (const DeviceTypeArgument &arg : clause.getArchitectures())
      lastDeviceTypeValues.push_back(
          decodeDeviceType(arg.getIdentifierInfo()));
  }

  // Reports `c` as a not-yet-implemented OpenACC clause.
  void clauseNotImplemented(const OpenACCClause &c) {
    cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
  }

  // Emits `intExpr` as a scalar and returns it wrapped in an unrealized
  // conversion cast to an MLIR integer type that has the width of the
  // expression's type and its signedness (signed or unsigned).
  mlir::Value createIntExpr(const Expr *intExpr) {
    mlir::Value expr = cgf.emitScalarExpr(intExpr);
    mlir::Location exprLoc = cgf.cgm.getLoc(intExpr->getBeginLoc());

    mlir::IntegerType targetType = mlir::IntegerType::get(
        &cgf.getMLIRContext(), cgf.getContext().getIntWidth(intExpr->getType()),
        intExpr->getType()->isSignedIntegerOrEnumerationType()
            ? mlir::IntegerType::SignednessSemantics::Signed
            : mlir::IntegerType::SignednessSemantics::Unsigned);

    auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
        exprLoc, targetType, expr);
    return conversionOp.getResult(0);
  }

  // 'condition' as an OpenACC grammar production is used for 'if' and (some
  // variants of) 'self'. It needs to be emitted as a signless-1-bit value, so
  // this function emits the expression, then sets the unrealized conversion
  // cast correctly, and returns the completed value.
  mlir::Value createCondition(const Expr *condExpr) {
    mlir::Value condition = cgf.evaluateExprAsBool(condExpr);
    mlir::Location exprLoc = cgf.cgm.getLoc(condExpr->getBeginLoc());

    mlir::IntegerType targetType = mlir::IntegerType::get(
        &cgf.getMLIRContext(), /*width=*/1,
        mlir::IntegerType::SignednessSemantics::Signless);

    auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
        exprLoc, targetType, condition);
    return conversionOp.getResult(0);
  }

  // Emits an `arith` constant holding `value` as a signless integer that is
  // `width` bits wide, and returns its result.
  mlir::Value createConstantInt(mlir::Location loc, unsigned width,
                                int64_t value) {
    mlir::IntegerType ty = mlir::IntegerType::get(
        &cgf.getMLIRContext(), width,
        mlir::IntegerType::SignednessSemantics::Signless);
    auto constOp = builder.create<mlir::arith::ConstantOp>(
        loc, builder.getIntegerAttr(ty, value));

    return constOp.getResult();
  }

  // Maps a 'device_type' architecture identifier to the dialect's enumerator.
  // The names are matched case-insensitively.
  // NOTE(review): there is no fallback case for an unrecognized name;
  // presumably such names have been rejected before reaching codegen --
  // confirm.
  mlir::acc::DeviceType decodeDeviceType(const IdentifierInfo *ii) {
    // '*' case leaves no identifier-info, just a nullptr.
    if (!ii)
      return mlir::acc::DeviceType::Star;
    return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName())
        .CaseLower("default", mlir::acc::DeviceType::Default)
        .CaseLower("host", mlir::acc::DeviceType::Host)
        .CaseLower("multicore", mlir::acc::DeviceType::Multicore)
        .CasesLower("nvidia", "acc_device_nvidia",
                    mlir::acc::DeviceType::Nvidia)
        .CaseLower("radeon", mlir::acc::DeviceType::Radeon);
  }

  // Maps the AST gang-argument kind to the dialect's gang-argument kind.
  mlir::acc::GangArgType decodeGangType(OpenACCGangKind gk) {
    switch (gk) {
    case OpenACCGangKind::Num:
      return mlir::acc::GangArgType::Num;
    case OpenACCGangKind::Dim:
      return mlir::acc::GangArgType::Dim;
    case OpenACCGangKind::Static:
      return mlir::acc::GangArgType::Static;
    }
    llvm_unreachable("unknown gang kind");
  }

  // Combined constructs only: visits `c` with a nested emitter for the loop
  // operation. The builder is positioned at the loop operation for the
  // duration of the call (the previous insertion point is restored by the
  // guard), and the nested emitter starts with a copy of the current
  // `lastDeviceTypeValues`.
  template <typename U = void,
            typename = std::enable_if_t<isCombinedType<OpTy>, U>>
  void applyToLoopOp(const OpenACCClause &c) {
    mlir::OpBuilder::InsertionGuard guardCase(builder);
    builder.setInsertionPoint(operation.loopOp);
    OpenACCClauseCIREmitter<mlir::acc::LoopOp> loopEmitter{
        operation.loopOp, cgf, builder, dirKind, dirLoc};
    loopEmitter.lastDeviceTypeValues = lastDeviceTypeValues;
    loopEmitter.Visit(&c);
  }

  // Combined constructs only: visits `c` with a nested emitter for the compute
  // operation. The builder is positioned at the compute operation for the
  // duration of the call (the previous insertion point is restored by the
  // guard), and the nested emitter starts with a copy of the current
  // `lastDeviceTypeValues`.
  template <typename U = void,
            typename = std::enable_if_t<isCombinedType<OpTy>, U>>
  void applyToComputeOp(const OpenACCClause &c) {
    mlir::OpBuilder::InsertionGuard guardCase(builder);
    builder.setInsertionPoint(operation.computeOp);
    OpenACCClauseCIREmitter<typename OpTy::ComputeOpTy> computeEmitter{
        operation.computeOp, cgf, builder, dirKind, dirLoc};
    computeEmitter.lastDeviceTypeValues = lastDeviceTypeValues;
    computeEmitter.Visit(&c);
  }

public:
  // Creates an emitter that applies clauses to `operation`. All reference
  // arguments are stored by reference, so they must outlive the emitter.
  OpenACCClauseCIREmitter(OpTy &operation, CIRGen::CIRGenFunction &cgf,
                          CIRGen::CIRGenBuilderTy &builder,
                          OpenACCDirectiveKind dirKind, SourceLocation dirLoc)
      : operation(operation), cgf(cgf), builder(builder), dirKind(dirKind),
        dirLoc(dirLoc) {}

  // Fallback for clauses without a dedicated visitor below: report them as
  // not yet implemented.
  void VisitClause(const OpenACCClause &clause) {
    clauseNotImplemented(clause);
  }

  // 'default': sets the default attribute on the compute/data operation. The
  // 'Invalid' kind leaves the operation untouched.
  void VisitDefaultClause(const OpenACCDefaultClause &clause) {
    // This type-trait checks if 'op'(the first arg) is one of the mlir::acc
    // operations listed in the rest of the arguments.
    if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
                               mlir::acc::KernelsOp, mlir::acc::DataOp>) {
      switch (clause.getDefaultClauseKind()) {
      case OpenACCDefaultClauseKind::None:
        operation.setDefaultAttr(mlir::acc::ClauseDefaultValue::None);
        break;
      case OpenACCDefaultClauseKind::Present:
        operation.setDefaultAttr(mlir::acc::ClauseDefaultValue::Present);
        break;
      case OpenACCDefaultClauseKind::Invalid:
        break;
      }
    } else if constexpr (isCombinedType<OpTy>) {
      applyToComputeOp(clause);
    } else {
      llvm_unreachable("Unknown construct kind in VisitDefaultClause");
    }
  }

  // 'device_type': always records the architectures for the clauses that
  // follow; additionally sets them directly on init/shutdown/set operations.
  void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
    setLastDeviceTypeClause(clause);

    if constexpr (isOneOfTypes<OpTy, mlir::acc::InitOp,
                               mlir::acc::ShutdownOp>) {
      for (const DeviceTypeArgument &arg : clause.getArchitectures())
        operation.addDeviceType(builder.getContext(),
                                decodeDeviceType(arg.getIdentifierInfo()));
    } else if constexpr (isOneOfTypes<OpTy, mlir::acc::SetOp>) {
      assert(!operation.getDeviceTypeAttr() && "already have device-type?");
      assert(clause.getArchitectures().size() <= 1);

      if (!clause.getArchitectures().empty())
        operation.setDeviceType(
            decodeDeviceType(clause.getArchitectures()[0].getIdentifierInfo()));
    } else if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
                                      mlir::acc::SerialOp, mlir::acc::KernelsOp,
                                      mlir::acc::DataOp, mlir::acc::LoopOp>) {
      // Nothing to do here, these constructs don't have any IR for these, as
      // they just modify the other clauses IR.  So setting of
      // `lastDeviceTypeValues` (done above) is all we need.
    } else if constexpr (isCombinedType<OpTy>) {
      // Nothing to do here either, combined constructs are just going to use
      // 'lastDeviceTypeValues' to set the value for the child visitor.
    } else {
      // TODO: When we've implemented this for everything, switch this to an
      // unreachable. update, data, routine constructs remain.
      return clauseNotImplemented(clause);
    }
  }

  // 'num_workers': adds the integer operand for the active device types.
  void VisitNumWorkersClause(const OpenACCNumWorkersClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
                               mlir::acc::KernelsOp>) {
      operation.addNumWorkersOperand(builder.getContext(),
                                     createIntExpr(clause.getIntExpr()),
                                     lastDeviceTypeValues);
    } else if constexpr (isCombinedType<OpTy>) {
      applyToComputeOp(clause);
    } else {
      llvm_unreachable("Unknown construct kind in VisitNumWorkersClause");
    }
  }

  // 'vector_length': adds the integer operand for the active device types.
  void VisitVectorLengthClause(const OpenACCVectorLengthClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
                               mlir::acc::KernelsOp>) {
      operation.addVectorLengthOperand(builder.getContext(),
                                       createIntExpr(clause.getIntExpr()),
                                       lastDeviceTypeValues);
    } else if constexpr (isCombinedType<OpTy>) {
      applyToComputeOp(clause);
    } else {
      llvm_unreachable("Unknown construct kind in VisitVectorLengthClause");
    }
  }

  // 'async': with an argument, adds it as an operand; without one, marks the
  // operation async-only.
  void VisitAsyncClause(const OpenACCAsyncClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
                               mlir::acc::KernelsOp, mlir::acc::DataOp>) {
      if (!clause.hasIntExpr())
        operation.addAsyncOnly(builder.getContext(), lastDeviceTypeValues);
      else
        operation.addAsyncOperand(builder.getContext(),
                                  createIntExpr(clause.getIntExpr()),
                                  lastDeviceTypeValues);
    } else if constexpr (isOneOfTypes<OpTy, mlir::acc::WaitOp>) {
      // Wait doesn't have a device_type, so its handling here is slightly
      // different.
      if (!clause.hasIntExpr())
        operation.setAsync(true);
      else
        operation.getAsyncOperandMutable().append(
            createIntExpr(clause.getIntExpr()));
    } else if constexpr (isCombinedType<OpTy>) {
      applyToComputeOp(clause);
    } else {
      // TODO: When we've implemented this for everything, switch this to an
      // unreachable. Enter data, exit data, update constructs remain.
      return clauseNotImplemented(clause);
    }
  }

  // 'self': the empty form sets the self attribute, the condition form adds
  // the condition operand. The var-list form is not expected here.
  void VisitSelfClause(const OpenACCSelfClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
                               mlir::acc::KernelsOp>) {
      if (clause.isEmptySelfClause()) {
        operation.setSelfAttr(true);
      } else if (clause.isConditionExprClause()) {
        assert(clause.hasConditionExpr());
        operation.getSelfCondMutable().append(
            createCondition(clause.getConditionExpr()));
      } else {
        llvm_unreachable("var-list version of self shouldn't get here");
      }
    } else if constexpr (isCombinedType<OpTy>) {
      applyToComputeOp(clause);
    } else {
      // TODO: When we've implemented this for everything, switch this to an
      // unreachable. update construct remains.
      return clauseNotImplemented(clause);
    }
  }

  // 'if': adds the condition operand.
  void VisitIfClause(const OpenACCIfClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
                               mlir::acc::KernelsOp, mlir::acc::InitOp,
                               mlir::acc::ShutdownOp, mlir::acc::SetOp,
                               mlir::acc::DataOp, mlir::acc::WaitOp>) {
      operation.getIfCondMutable().append(
          createCondition(clause.getConditionExpr()));
    } else if constexpr (isCombinedType<OpTy>) {
      applyToComputeOp(clause);
    } else {
      // 'if' applies to most of the constructs, but hold off on lowering them
      // until we can write tests/know what we're doing with codegen to make
      // sure we get it right.
      // TODO: When we've implemented this for everything, switch this to an
      // unreachable. Enter data, exit data, host_data, update constructs
      // remain.
      return clauseNotImplemented(clause);
    }
  }

  // 'device_num': adds the integer operand.
  void VisitDeviceNumClause(const OpenACCDeviceNumClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::InitOp, mlir::acc::ShutdownOp,
                               mlir::acc::SetOp>) {
      operation.getDeviceNumMutable().append(
          createIntExpr(clause.getIntExpr()));
    } else {
      llvm_unreachable(
          "init, shutdown, set, are only valid device_num constructs");
    }
  }

  // 'num_gangs': adds every integer argument, in order, for the active device
  // types.
  void VisitNumGangsClause(const OpenACCNumGangsClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp,
                               mlir::acc::KernelsOp>) {
      llvm::SmallVector<mlir::Value> values;
      for (const Expr *intExpr : clause.getIntExprs())
        values.push_back(createIntExpr(intExpr));

      operation.addNumGangsOperands(builder.getContext(), values,
                                    lastDeviceTypeValues);
    } else if constexpr (isCombinedType<OpTy>) {
      applyToComputeOp(clause);
    } else {
      llvm_unreachable("Unknown construct kind in VisitNumGangsClause");
    }
  }

  // 'wait': without arguments marks the operation wait-only; otherwise adds
  // the optional devnum value followed by the queue-id values.
  void VisitWaitClause(const OpenACCWaitClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
                               mlir::acc::KernelsOp, mlir::acc::DataOp>) {
      if (!clause.hasExprs()) {
        operation.addWaitOnly(builder.getContext(), lastDeviceTypeValues);
      } else {
        llvm::SmallVector<mlir::Value> values;
        // The devnum value, when present, goes first; the operation is told
        // about its presence through the flag passed below.
        if (clause.hasDevNumExpr())
          values.push_back(createIntExpr(clause.getDevNumExpr()));
        for (const Expr *queueIdExpr : clause.getQueueIdExprs())
          values.push_back(createIntExpr(queueIdExpr));
        operation.addWaitOperands(builder.getContext(), clause.hasDevNumExpr(),
                                  values, lastDeviceTypeValues);
      }
    } else if constexpr (isCombinedType<OpTy>) {
      applyToComputeOp(clause);
    } else {
      // TODO: When we've implemented this for everything, switch this to an
      // unreachable. Enter data, exit data, update constructs remain.
      return clauseNotImplemented(clause);
    }
  }

  // 'default_async': adds the integer operand.
  void VisitDefaultAsyncClause(const OpenACCDefaultAsyncClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::SetOp>) {
      operation.getDefaultAsyncMutable().append(
          createIntExpr(clause.getIntExpr()));
    } else {
      llvm_unreachable("set is the only valid default_async construct");
    }
  }

  // 'seq': marks the loop as seq for the active device types.
  void VisitSeqClause(const OpenACCSeqClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
      operation.addSeq(builder.getContext(), lastDeviceTypeValues);
    } else if constexpr (isCombinedType<OpTy>) {
      applyToLoopOp(clause);
    } else {
      // TODO: When we've implemented this for everything, switch this to an
      // unreachable. Routine construct remains.
      return clauseNotImplemented(clause);
    }
  }

  // 'auto': marks the loop as auto for the active device types.
  void VisitAutoClause(const OpenACCAutoClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
      operation.addAuto(builder.getContext(), lastDeviceTypeValues);
    } else if constexpr (isCombinedType<OpTy>) {
      applyToLoopOp(clause);
    } else {
      // TODO: When we've implemented this for everything, switch this to an
      // unreachable. Routine construct remains.
      return clauseNotImplemented(clause);
    }
  }

  // 'independent': marks the loop as independent for the active device types.
  void VisitIndependentClause(const OpenACCIndependentClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
      operation.addIndependent(builder.getContext(), lastDeviceTypeValues);
    } else if constexpr (isCombinedType<OpTy>) {
      applyToLoopOp(clause);
    } else {
      // TODO: When we've implemented this for everything, switch this to an
      // unreachable. Routine construct remains.
      return clauseNotImplemented(clause);
    }
  }

  // 'collapse': the argument is a constant, so it is evaluated here and set
  // on the loop as a 64-bit value for the active device types.
  void VisitCollapseClause(const OpenACCCollapseClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
      llvm::APInt value =
          clause.getIntExpr()->EvaluateKnownConstInt(cgf.cgm.getASTContext());

      value = value.sextOrTrunc(64);
      operation.setCollapseForDeviceTypes(builder.getContext(),
                                          lastDeviceTypeValues, value);
    } else if constexpr (isCombinedType<OpTy>) {
      applyToLoopOp(clause);
    } else {
      llvm_unreachable("Unknown construct kind in VisitCollapseClause");
    }
  }

  // 'tile': emits each size as a 64-bit constant and sets the list on the
  // loop for the active device types.
  void VisitTileClause(const OpenACCTileClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
      llvm::SmallVector<mlir::Value> values;

      for (const Expr *e : clause.getSizeExprs()) {
        mlir::Location exprLoc = cgf.cgm.getLoc(e->getBeginLoc());

        // We represent the * as -1.  Additionally, this is a constant, so we
        // can always just emit it as 64 bits to avoid having to do any more
        // work to determine signedness or size.
        if (isa<OpenACCAsteriskSizeExpr>(e)) {
          values.push_back(createConstantInt(exprLoc, 64, -1));
        } else {
          llvm::APInt curValue =
              e->EvaluateKnownConstInt(cgf.cgm.getASTContext());
          values.push_back(createConstantInt(
              exprLoc, 64, curValue.sextOrTrunc(64).getSExtValue()));
        }
      }

      operation.setTileForDeviceTypes(builder.getContext(),
                                      lastDeviceTypeValues, values);
    } else if constexpr (isCombinedType<OpTy>) {
      applyToLoopOp(clause);
    } else {
      llvm_unreachable("Unknown construct kind in VisitTileClause");
    }
  }

  // 'worker': with an argument, adds it as the worker-num operand; without
  // one, marks the loop with an empty worker clause.
  void VisitWorkerClause(const OpenACCWorkerClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
      if (clause.hasIntExpr())
        operation.addWorkerNumOperand(builder.getContext(),
                                      createIntExpr(clause.getIntExpr()),
                                      lastDeviceTypeValues);
      else
        operation.addEmptyWorker(builder.getContext(), lastDeviceTypeValues);
    } else if constexpr (isCombinedType<OpTy>) {
      applyToLoopOp(clause);
    } else {
      // TODO: When we've implemented this for everything, switch this to an
      // unreachable.
      // NOTE(review): combined constructs are handled above; presumably only
      // the routine construct remains -- confirm.
      return clauseNotImplemented(clause);
    }
  }

  // 'vector': with an argument, adds it as the vector operand; without one,
  // marks the loop with an empty vector clause.
  void VisitVectorClause(const OpenACCVectorClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
      if (clause.hasIntExpr())
        operation.addVectorOperand(builder.getContext(),
                                   createIntExpr(clause.getIntExpr()),
                                   lastDeviceTypeValues);
      else
        operation.addEmptyVector(builder.getContext(), lastDeviceTypeValues);
    } else if constexpr (isCombinedType<OpTy>) {
      applyToLoopOp(clause);
    } else {
      // TODO: When we've implemented this for everything, switch this to an
      // unreachable.
      // NOTE(review): combined constructs are handled above; presumably only
      // the routine construct remains -- confirm.
      return clauseNotImplemented(clause);
    }
  }

  // 'gang': without arguments marks the loop with an empty gang clause;
  // otherwise adds one value plus one argument kind per gang argument.
  void VisitGangClause(const OpenACCGangClause &clause) {
    if constexpr (isOneOfTypes<OpTy, mlir::acc::LoopOp>) {
      if (clause.getNumExprs() == 0) {
        operation.addEmptyGang(builder.getContext(), lastDeviceTypeValues);
      } else {
        // `values` and `argTypes` are parallel lists: entry i of each
        // describes gang argument i.
        llvm::SmallVector<mlir::Value> values;
        llvm::SmallVector<mlir::acc::GangArgType> argTypes;
        for (unsigned i : llvm::index_range(0u, clause.getNumExprs())) {
          auto [kind, expr] = clause.getExpr(i);
          mlir::Location exprLoc = cgf.cgm.getLoc(expr->getBeginLoc());
          argTypes.push_back(decodeGangType(kind));
          if (kind == OpenACCGangKind::Dim) {
            llvm::APInt curValue =
                expr->EvaluateKnownConstInt(cgf.cgm.getASTContext());
            // The value is 1, 2, or 3, but the type isn't necessarily smaller
            // than 64.
            curValue = curValue.sextOrTrunc(64);
            values.push_back(
                createConstantInt(exprLoc, 64, curValue.getSExtValue()));
          } else if (isa<OpenACCAsteriskSizeExpr>(expr)) {
            // As with 'tile', the '*' argument is represented as -1.
            values.push_back(createConstantInt(exprLoc, 64, -1));
          } else {
            values.push_back(createIntExpr(expr));
          }
        }

        operation.addGangOperands(builder.getContext(), lastDeviceTypeValues,
                                  argTypes, values);
      }
    } else if constexpr (isCombinedType<OpTy>) {
      applyToLoopOp(clause);
    } else {
      llvm_unreachable("Unknown construct kind in VisitGangClause");
    }
  }
};
template <typename OpTy>
auto makeClauseEmitter(OpTy &op, CIRGen::CIRGenFunction &cgf,
CIRGen::CIRGenBuilderTy &builder,
OpenACCDirectiveKind dirKind, SourceLocation dirLoc) {
return OpenACCClauseCIREmitter<OpTy>(op, cgf, builder, dirKind, dirLoc);
}
} // namespace clang