blob: 31475ec7b4f1638b0051499883a3faa7b7bd1ec0 [file] [log] [blame]
//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the OpenMP dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <cstddef>
#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
using namespace mlir;
using namespace mlir::omp;
namespace {
/// External model that attaches the `PointerLikeType` interface to the type
/// `T`. `T` must already provide a `getElementType()` method; this model only
/// forwards to it.
template <typename T>
struct PointerLikeModel
    : PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
  /// Returns the element type of `pointer`, which must be an instance of `T`.
  Type getElementType(Type pointer) const {
    T concretePointer = pointer.cast<T>();
    return concretePointer.getElementType();
  }
};
} // end namespace
/// Registers the OpenMP dialect operations and attaches the `PointerLikeType`
/// interface (via `PointerLikeModel`) to the LLVM pointer type and the builtin
/// memref type.
void OpenMPDialect::initialize() {
// The operation list comes from the generated OpenMPOps.cpp.inc, selected by
// defining GET_OP_LIST before including it.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
>();
// Both types provide getElementType(), which PointerLikeModel forwards to.
LLVM::LLVMPointerType::attachInterface<
PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
/// Builds an `omp.parallel` op with no clauses and then appends the
/// caller-provided `attributes` to `state`.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<NamedAttribute> attributes) {
  // Every optional operand/attribute is null and every variadic operand
  // group is empty.
  ValueRange noVars;
  ParallelOp::build(builder, state, /*if_expr_var=*/nullptr,
                    /*num_threads_var=*/nullptr, /*default_val=*/nullptr,
                    /*private_vars=*/noVars, /*firstprivate_vars=*/noVars,
                    /*shared_vars=*/noVars, /*copyin_vars=*/noVars,
                    /*allocate_vars=*/noVars, /*allocators_vars=*/noVars,
                    /*proc_bind_val=*/nullptr);
  state.addAttributes(attributes);
}
//===----------------------------------------------------------------------===//
// Parser and printer for Operand and type list
//===----------------------------------------------------------------------===//
/// Parses a parenthesized, comma-separated list of `ssa-id : type` pairs.
///
/// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
/// ssa-id-and-type-list ::= ssa-id-and-type |
///                          ssa-id-and-type `,` ssa-id-and-type-list
/// ssa-id-and-type ::= ssa-id `:` type
///
/// Each entry's operand is appended to `operands` and its type to `types`, so
/// the two vectors grow in lockstep.
static ParseResult
parseOperandAndTypeList(OpAsmParser &parser,
                        SmallVectorImpl<OpAsmParser::OperandType> &operands,
                        SmallVectorImpl<Type> &types) {
  auto parseOneEntry = [&]() -> ParseResult {
    OpAsmParser::OperandType entryOperand;
    Type entryType;
    if (failed(parser.parseOperand(entryOperand)) ||
        failed(parser.parseColonType(entryType)))
      return failure();
    operands.push_back(entryOperand);
    types.push_back(entryType);
    return success();
  };
  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseOneEntry);
}
/// Prints `operands` as a parenthesized list of `%v : type` entries separated
/// by `, ` and followed by a trailing space.
static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
  p << "(";
  bool isFirst = true;
  for (Value operand : operands) {
    if (!isFirst)
      p << ", ";
    isFirst = false;
    p << operand << " : " << operand.getType();
  }
  p << ") ";
}
/// Prints the data-sharing clause `name` followed by its parenthesized
/// operand-and-type list. Prints nothing when `operands` is empty.
static void printDataVars(OpAsmPrinter &p, OperandRange operands,
                          StringRef name) {
  if (operands.empty())
    return;
  p << name;
  printOperandAndTypeList(p, operands);
}
//===----------------------------------------------------------------------===//
// Parser and printer for Allocate Clause
//===----------------------------------------------------------------------===//
/// Parses the parenthesized operand list of an allocate clause (the
/// `allocate` keyword itself is consumed by the caller). Each entry pairs an
/// allocator with the variable it allocates.
///
/// allocate ::= `allocate` `(` allocate-operand-list `)`
/// allocate-operand-list ::= allocate-operand |
///                           allocate-operand `,` allocate-operand-list
/// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
/// ssa-id-and-type ::= ssa-id `:` type
///
/// The operand/type before `->` is appended to `operandsAllocator` /
/// `typesAllocator`; the operand/type after `->` is appended to
/// `operandsAllocate` / `typesAllocate`.
static ParseResult parseAllocateAndAllocator(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
    SmallVectorImpl<Type> &typesAllocate,
    SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
    SmallVectorImpl<Type> &typesAllocator) {
  // Parses one `ssa-id : type` pair and appends it to the given vectors.
  auto parseTypedOperand =
      [&](SmallVectorImpl<OpAsmParser::OperandType> &operandDest,
          SmallVectorImpl<Type> &typeDest) -> ParseResult {
    OpAsmParser::OperandType operand;
    Type type;
    if (parser.parseOperand(operand) || parser.parseColonType(type))
      return failure();
    operandDest.push_back(operand);
    typeDest.push_back(type);
    return success();
  };
  return parser.parseCommaSeparatedList(
      OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
        if (parseTypedOperand(operandsAllocator, typesAllocator) ||
            parser.parseArrow())
          return failure();
        return parseTypedOperand(operandsAllocate, typesAllocate);
      });
}
/// Prints an allocate clause as
/// `allocate(%allocator : type -> %var : type, ...) `.
///
/// `varsAllocator[i]` is printed as the allocator of `varsAllocate[i]`. The op
/// verifiers require both ranges to have the same length; pairing them with
/// `llvm::zip` means a mismatch can never index past the end of either range.
/// The clause is always closed, even when the ranges are empty.
static void printAllocateAndAllocator(OpAsmPrinter &p,
                                      OperandRange varsAllocate,
                                      OperandRange varsAllocator) {
  p << "allocate(";
  llvm::interleaveComma(llvm::zip(varsAllocator, varsAllocate), p,
                        [&](auto allocPair) {
                          Value allocator = std::get<0>(allocPair);
                          Value allocated = std::get<1>(allocPair);
                          p << allocator << " : " << allocator.getType()
                            << " -> " << allocated << " : "
                            << allocated.getType();
                        });
  p << ") ";
}
/// Verifies an `omp.parallel` op: the allocate and allocator variable lists
/// must have the same length.
static LogicalResult verifyParallelOp(ParallelOp op) {
  size_t numAllocateVars = op.allocate_vars().size();
  size_t numAllocatorVars = op.allocators_vars().size();
  if (numAllocateVars == numAllocatorVars)
    return success();
  return op.emitError(
      "expected equal sizes for allocate and allocator variables");
}
/// Prints an `omp.parallel` op.
///
/// Clauses are printed in a fixed order (if, num_threads, private,
/// firstprivate, shared, copyin, allocate, default, proc_bind), each followed
/// by a space. The body region is printed last.
static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
  p << " ";

  // Prints `<prefix>%v : type) ` for an optional single-operand clause.
  auto printOptionalOperand = [&](StringRef prefix, Value operand) {
    if (!operand)
      return;
    p << prefix << operand << " : " << operand.getType() << ") ";
  };
  printOptionalOperand("if(", op.if_expr_var());
  printOptionalOperand("num_threads(", op.num_threads_var());

  printDataVars(p, op.private_vars(), "private");
  printDataVars(p, op.firstprivate_vars(), "firstprivate");
  printDataVars(p, op.shared_vars(), "shared");
  printDataVars(p, op.copyin_vars(), "copyin");

  OperandRange allocateVars = op.allocate_vars();
  if (!allocateVars.empty())
    printAllocateAndAllocator(p, allocateVars, op.allocators_vars());

  // The stored default kind carries a 3-character "def" prefix (added by the
  // parser because `private` is a C++ keyword); strip it when printing.
  if (auto defaultKind = op.default_val())
    p << "default(" << defaultKind->drop_front(3) << ") ";
  if (auto procBind = op.proc_bind_val())
    p << "proc_bind(" << procBind << ") ";

  p.printRegion(op.getRegion());
}
//===----------------------------------------------------------------------===//
// Parser and printer for Linear Clause
//===----------------------------------------------------------------------===//
/// Parses the parenthesized body of a linear clause (the `linear` keyword is
/// consumed by the caller).
///
/// linear ::= `linear` `(` linear-list `)`
/// linear-list ::= linear-val | linear-val `,` linear-list
/// linear-val ::= ssa-id `=` ssa-id `:` type
///
/// The first ssa-id is the linear variable and is appended to `vars`; the
/// type after `:` is that variable's type and is appended to `types`. The
/// ssa-id after `=` is the step and is appended to `stepVars` (the caller
/// resolves steps with an i32 type).
static ParseResult
parseLinearClause(OpAsmParser &parser,
                  SmallVectorImpl<OpAsmParser::OperandType> &vars,
                  SmallVectorImpl<Type> &types,
                  SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
  if (parser.parseLParen())
    return failure();
  for (;;) {
    OpAsmParser::OperandType linearVar;
    OpAsmParser::OperandType stepVar;
    Type linearType;
    if (parser.parseOperand(linearVar) || parser.parseEqual() ||
        parser.parseOperand(stepVar) || parser.parseColonType(linearType))
      return failure();
    vars.push_back(linearVar);
    types.push_back(linearType);
    stepVars.push_back(stepVar);
    if (failed(parser.parseOptionalComma()))
      break;
  }
  return parser.parseRParen();
}
/// Prints a linear clause as `linear(%var = %step : type, ...) `.
///
/// The ` = %step` part is printed only for entries that have a matching
/// element in `linearStepVars`. The clause is always closed, even when
/// `linearVars` is empty. The original code left `linear(` unterminated in
/// that case.
static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
                              OperandRange linearStepVars) {
  size_t numSteps = linearStepVars.size();
  p << "linear(";
  for (size_t i = 0, e = linearVars.size(); i < e; ++i) {
    if (i != 0)
      p << ", ";
    p << linearVars[i];
    if (i < numSteps)
      p << " = " << linearStepVars[i];
    p << " : " << linearVars[i].getType();
  }
  p << ") ";
}
//===----------------------------------------------------------------------===//
// Parser and printer for Schedule Clause
//===----------------------------------------------------------------------===//
/// Validates the schedule modifiers collected by `parseScheduleClause` and
/// normalizes them in place.
///
/// At most two modifiers are accepted, and each must name a valid
/// `ScheduleModifier`. With two modifiers, the first must not be `simd` and
/// the second must be `simd`. A lone `simd` modifier is rewritten to
/// `none, simd`. After this, index 0 always holds the non-simd modifier and
/// index 1, if present, holds `simd`.
static ParseResult
verifyScheduleModifiers(OpAsmParser &parser,
                        SmallVectorImpl<SmallString<12>> &modifiers) {
  if (modifiers.size() > 2)
    return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
  // Translate each string. If it has no value, it was not a valid modifier.
  for (const auto &mod : modifiers) {
    if (!symbolizeScheduleModifier(mod))
      return parser.emitError(parser.getNameLoc())
             << " unknown modifier type: " << mod;
  }
  auto isSimd = [](StringRef mod) {
    return symbolizeScheduleModifier(mod) ==
           mlir::omp::ScheduleModifier::simd;
  };
  if (modifiers.size() == 1) {
    // A lone "simd" modifier: move it to index 1 and put a "none" modifier
    // at index 0.
    if (isSimd(modifiers[0])) {
      modifiers.push_back(modifiers[0]);
      modifiers[0] =
          stringifyScheduleModifier(mlir::omp::ScheduleModifier::none);
    }
  } else if (modifiers.size() == 2) {
    // With two modifiers, the first must not be simd and the second must be.
    if (isSimd(modifiers[0]) || !isSimd(modifiers[1]))
      return parser.emitError(parser.getNameLoc())
             << " incorrect modifier order";
  }
  return success();
}
/// Parses the parenthesized body of a schedule clause (the `schedule` keyword
/// is consumed by the caller).
///
/// schedule ::= `schedule` `(` sched-list `)`
/// sched-list ::= sched-val | sched-val `,` sched-modifier
/// sched-val ::= sched-with-chunk | sched-wo-chunk
/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id)?
/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
/// sched-wo-chunk ::= `auto` | `runtime`
/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
///
/// The schedule kind is stored in `schedule`. The optional chunk-size operand
/// goes in `chunkSize`, and `None` is stored when there is no chunk. The raw
/// modifier keywords are appended to `modifiers` and then validated and
/// normalized by `verifyScheduleModifiers`.
static ParseResult
parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
                    SmallVectorImpl<SmallString<12>> &modifiers,
                    Optional<OpAsmParser::OperandType> &chunkSize) {
  StringRef kind;
  if (parser.parseLParen() || parser.parseKeyword(&kind))
    return failure();
  schedule = kind;

  bool takesChunk = kind == "static" || kind == "dynamic" || kind == "guided";
  bool noChunk = kind == "auto" || kind == "runtime";
  if (!takesChunk && !noChunk)
    return parser.emitError(parser.getNameLoc()) << " expected schedule kind";

  chunkSize = llvm::NoneType::None;
  if (takesChunk && succeeded(parser.parseOptionalEqual())) {
    chunkSize = OpAsmParser::OperandType{};
    if (parser.parseOperand(*chunkSize))
      return failure();
  }

  // Any further comma-separated keywords are schedule modifiers.
  while (succeeded(parser.parseOptionalComma())) {
    StringRef mod;
    if (parser.parseKeyword(&mod))
      return failure();
    modifiers.push_back(mod);
  }
  if (parser.parseRParen() || verifyScheduleModifiers(parser, modifiers))
    return failure();
  return success();
}
/// Prints a schedule clause as
/// `schedule(kind[ = %chunk][, modifier][, simd]) `.
///
/// `sched` is printed lowercased (the parser capitalizes its first letter
/// before storing it). `modifier` is printed when present, and `simd`
/// appends the `simd` modifier.
static void printScheduleClause(OpAsmPrinter &p, StringRef &sched,
                                llvm::Optional<StringRef> modifier, bool simd,
                                Value scheduleChunkVar) {
  p << "schedule(" << sched.lower();
  if (scheduleChunkVar)
    p << " = " << scheduleChunkVar;
  // Print the contained string itself, not the Optional wrapper.
  if (modifier)
    p << ", " << *modifier;
  if (simd)
    p << ", simd";
  p << ") ";
}
//===----------------------------------------------------------------------===//
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//
/// Parses the parenthesized entries of a reduction clause (the `reduction`
/// keyword is consumed by the caller).
///
/// reduction ::= `reduction` `(` reduction-entry-list `)`
/// reduction-entry-list ::= reduction-entry
///                        | reduction-entry-list `,` reduction-entry
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
///
/// Each entry's symbol, operand and type are appended to `symbols`,
/// `operands` and `types` respectively.
static ParseResult
parseReductionVarList(OpAsmParser &parser,
                      SmallVectorImpl<SymbolRefAttr> &symbols,
                      SmallVectorImpl<OpAsmParser::OperandType> &operands,
                      SmallVectorImpl<Type> &types) {
  if (failed(parser.parseLParen()))
    return failure();
  for (;;) {
    if (parser.parseAttribute(symbols.emplace_back()))
      return failure();
    if (parser.parseArrow())
      return failure();
    if (parser.parseOperand(operands.emplace_back()))
      return failure();
    if (parser.parseColonType(types.emplace_back()))
      return failure();
    if (failed(parser.parseOptionalComma()))
      break;
  }
  return parser.parseRParen();
}
/// Prints a reduction clause as `reduction(@sym -> %var : type, ...) `.
///
/// Iterates over the symbols in `reductions`. `reductions` is assumed to be
/// present and to hold exactly one symbol per element of `reduction_vars`
/// (as checked by verifyReductionVarList).
static void printReductionVarList(OpAsmPrinter &p,
                                  Optional<ArrayAttr> reductions,
                                  OperandRange reduction_vars) {
  p << "reduction(";
  ArrayAttr symbols = *reductions;
  for (unsigned idx = 0, count = symbols.size(); idx != count; ++idx) {
    if (idx)
      p << ", ";
    Value var = reduction_vars[idx];
    p << symbols[idx] << " -> " << var << " : " << var.getType();
  }
  p << ") ";
}
/// Verifies the reduction clause of `op`.
///
/// - If there are no reduction variables, there must be no reduction symbols.
/// - Otherwise there must be exactly one symbol per variable.
/// - No accumulator variable may appear twice.
/// - Each symbol must resolve to a `ReductionDeclareOp`.
/// - If that declaration has an accumulator type, it must match the
///   variable's type.
static LogicalResult verifyReductionVarList(Operation *op,
                                            Optional<ArrayAttr> reductions,
                                            OperandRange reduction_vars) {
  if (reduction_vars.empty()) {
    if (reductions)
      return op->emitOpError() << "unexpected reduction symbol references";
    return success();
  }
  if (!reductions || reductions->size() != reduction_vars.size())
    return op->emitOpError()
           << "expected as many reduction symbol references "
              "as reduction variables";

  DenseSet<Value> seenAccumulators;
  for (auto entry : llvm::zip(reduction_vars, *reductions)) {
    Value accumulator = std::get<0>(entry);
    if (!seenAccumulators.insert(accumulator).second)
      return op->emitOpError() << "accumulator variable used more than once";
    // cast<> asserts that the accumulator's type is pointer-like.
    Type accumulatorType = accumulator.getType().cast<PointerLikeType>();
    auto symbolRef = std::get<1>(entry).cast<SymbolRefAttr>();
    auto decl =
        SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
    if (!decl)
      return op->emitOpError() << "expected symbol reference " << symbolRef
                               << " to point to a reduction declaration";
    auto declaredType = decl.getAccumulatorType();
    if (declaredType && declaredType != accumulatorType)
      return op->emitOpError()
             << "expected accumulator (" << accumulatorType
             << ") to be the same type as reduction declaration ("
             << declaredType << ")";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// Parser, printer and verifier for Synchronization Hint (2.17.12)
//===----------------------------------------------------------------------===//
/// Parses a Synchronization Hint clause. The value of hint is an integer
/// which is a combination of different hints from `omp_sync_hint_t`.
///
/// hint-clause = `hint` `(` hint-value (`,` hint-value)* `)`
/// hint-value = `uncontended` | `contended` | `nonspeculative` | `speculative`
///
/// When `parseKeyword` is true, the leading `hint` keyword is parsed here and
/// its absence produces a hint of 0. When it is false, the caller has already
/// consumed the keyword. The hint values map to bits 1, 2, 4 and 8 and are
/// OR-ed together into an i64 `hintAttr`.
static ParseResult parseSynchronizationHint(OpAsmParser &parser,
                                            IntegerAttr &hintAttr,
                                            bool parseKeyword = true) {
  auto i64Type = parser.getBuilder().getI64Type();
  if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
    hintAttr = IntegerAttr::get(i64Type, 0);
    return success();
  }
  if (failed(parser.parseLParen()))
    return failure();

  struct HintKeyword {
    StringRef name;
    int64_t bit;
  };
  const HintKeyword knownHints[] = {{"uncontended", 1},
                                    {"contended", 2},
                                    {"nonspeculative", 4},
                                    {"speculative", 8}};

  int64_t hintBits = 0;
  do {
    StringRef keyword;
    if (failed(parser.parseKeyword(&keyword)))
      return failure();
    const HintKeyword *match = nullptr;
    for (const HintKeyword &candidate : knownHints)
      if (candidate.name == keyword)
        match = &candidate;
    if (!match)
      return parser.emitError(parser.getCurrentLocation())
             << keyword << " is not a valid hint";
    hintBits |= match->bit;
  } while (succeeded(parser.parseOptionalComma()));

  if (failed(parser.parseRParen()))
    return failure();
  hintAttr = IntegerAttr::get(i64Type, hintBits);
  return success();
}
/// Prints a Synchronization Hint clause as `hint(name, ...) `.
///
/// Bits 0-3 of `hintAttr` map, in order, to uncontended, contended,
/// nonspeculative and speculative. Nothing is printed when the hint is 0.
/// `op` is currently unused.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
                                     IntegerAttr hintAttr) {
  int64_t hint = hintAttr.getInt();
  if (hint == 0)
    return;

  static const char *const kHintNames[] = {"uncontended", "contended",
                                           "nonspeculative", "speculative"};
  SmallVector<StringRef> hints;
  for (unsigned bit = 0; bit < 4; ++bit)
    if (hint & (int64_t(1) << bit))
      hints.push_back(kHintNames[bit]);

  p << "hint(";
  llvm::interleaveComma(hints, p);
  p << ") ";
}
/// Verifies that `hint` does not combine mutually exclusive hints. It rejects
/// uncontended (bit 0) together with contended (bit 1), and nonspeculative
/// (bit 2) together with speculative (bit 3).
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
  // True when every bit of `mask` is set in `hint`.
  auto hasAllBits = [hint](uint64_t mask) { return (hint & mask) == mask; };
  if (hasAllBits(0b0011))
    return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
                                "omp_sync_hint_contended cannot be combined";
  if (hasAllBits(0b1100))
    return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
                                "omp_sync_hint_speculative cannot be combined.";
  return success();
}
/// Identifies the clauses that `parseClauses` can parse. Callers pass a list
/// of these to select which clauses an operation accepts; the enumerator
/// value is also used as an index into parseClauses' per-clause tables.
enum ClauseType {
ifClause,
numThreadsClause,
privateClause,
firstprivateClause,
lastprivateClause,
sharedClause,
copyinClause,
allocateClause,
defaultClause,
procBindClause,
reductionClause,
nowaitClause,
linearClause,
scheduleClause,
collapseClause,
orderClause,
orderedClause,
memoryOrderClause,
hintClause,
// Number of clause kinds; sizes the per-clause bookkeeping in parseClauses.
COUNT
};
//===----------------------------------------------------------------------===//
// Parser for Clause List
//===----------------------------------------------------------------------===//
/// Parse a list of clauses. The clauses can appear in any order, but their
/// operand segment indices are in the same order that they are passed in the
/// `clauses` list, and their operands are appended to `result.operands` in
/// that same order. The operand segment sizes of all clauses in `clauses` are
/// appended to `segments`. Attribute-only clauses (default, proc_bind,
/// nowait, collapse, order, ordered, memory_order, hint) take no segment.
///
/// clause-list ::= clause clause-list | empty
/// clause ::= if | num-threads | private | firstprivate | lastprivate |
///            shared | copyin | allocate | default | proc-bind | reduction |
///            nowait | linear | schedule | collapse | order | ordered |
///            memory-order | hint
/// if ::= `if` `(` ssa-id-and-type `)`
/// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
/// private ::= `private` operand-and-type-list
/// firstprivate ::= `firstprivate` operand-and-type-list
/// lastprivate ::= `lastprivate` operand-and-type-list
/// shared ::= `shared` operand-and-type-list
/// copyin ::= `copyin` operand-and-type-list
/// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
///             `)`
/// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
/// reduction ::= `reduction` `(` reduction-entry-list `)`
/// nowait ::= `nowait`
/// linear ::= `linear` `(` linear-list `)`
/// schedule ::= `schedule` `(` sched-list `)`
/// collapse ::= `collapse` `(` integer-literal `)`
/// order ::= `order` `(` `concurrent` `)`
/// ordered ::= `ordered` (`(` integer-literal `)`)?
/// memory-order ::= `memory_order` `(` bare-id `)`
/// hint ::= `hint` `(` hint-value (`,` hint-value)* `)`
///
/// Note that each clause can only appear once in the clause-list.
static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
                                SmallVectorImpl<ClauseType> &clauses,
                                SmallVectorImpl<int> &segments) {
  // Check done[clause] to see if it has been parsed already
  llvm::BitVector done(ClauseType::COUNT, false);
  // See pos[clause] to get position of clause in operand segments
  SmallVector<int> pos(ClauseType::COUNT, -1);
  // Stores the last parsed clause keyword
  StringRef clauseKeyword;
  StringRef opName = result.name.getStringRef();

  // Containers for storing operands, types and attributes for various clauses
  std::pair<OpAsmParser::OperandType, Type> ifCond;
  std::pair<OpAsmParser::OperandType, Type> numThreads;
  SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
      shareds, copyins;
  SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
      sharedTypes, copyinTypes;
  SmallVector<OpAsmParser::OperandType> allocates, allocators;
  SmallVector<Type> allocateTypes, allocatorTypes;
  SmallVector<SymbolRefAttr> reductionSymbols;
  SmallVector<OpAsmParser::OperandType> reductionVars;
  SmallVector<Type> reductionVarTypes;
  SmallVector<OpAsmParser::OperandType> linears;
  SmallVector<Type> linearTypes;
  SmallVector<OpAsmParser::OperandType> linearSteps;
  SmallString<8> schedule;
  SmallVector<SmallString<12>> modifiers;
  Optional<OpAsmParser::OperandType> scheduleChunkSize;

  // Number of operand segments a clause occupies. Clauses that are modeled
  // purely as attributes take no position in the operand segments.
  auto numSegments = [](ClauseType clause) -> int {
    switch (clause) {
    case defaultClause:
    case procBindClause:
    case nowaitClause:
    case collapseClause:
    case orderClause:
    case orderedClause:
    case memoryOrderClause:
    case hintClause:
      return 0;
    // allocate: allocate vars + allocator vars; linear: vars + step vars.
    case allocateClause:
    case linearClause:
      return 2;
    default:
      return 1;
    }
  };

  // Compute the position of clauses in operand segments
  int currPos = 0;
  for (ClauseType clause : clauses) {
    int count = numSegments(clause);
    if (count == 0)
      continue;
    pos[clause] = currPos;
    currPos += count;
  }
  SmallVector<int> clauseSegments(currPos);

  // Helper function to check if a clause is allowed/repeated or not
  auto checkAllowed = [&](ClauseType clause,
                          bool allowRepeat = false) -> ParseResult {
    if (!llvm::is_contained(clauses, clause))
      return parser.emitError(parser.getCurrentLocation())
             << clauseKeyword << " is not a valid clause for the " << opName
             << " operation";
    if (done[clause] && !allowRepeat)
      return parser.emitError(parser.getCurrentLocation())
             << "at most one " << clauseKeyword << " clause can appear on the "
             << opName << " operation";
    done[clause] = true;
    return success();
  };

  while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
    if (clauseKeyword == "if") {
      if (checkAllowed(ifClause) || parser.parseLParen() ||
          parser.parseOperand(ifCond.first) ||
          parser.parseColonType(ifCond.second) || parser.parseRParen())
        return failure();
      clauseSegments[pos[ifClause]] = 1;
    } else if (clauseKeyword == "num_threads") {
      if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
          parser.parseOperand(numThreads.first) ||
          parser.parseColonType(numThreads.second) || parser.parseRParen())
        return failure();
      clauseSegments[pos[numThreadsClause]] = 1;
    } else if (clauseKeyword == "private") {
      if (checkAllowed(privateClause) ||
          parseOperandAndTypeList(parser, privates, privateTypes))
        return failure();
      clauseSegments[pos[privateClause]] = privates.size();
    } else if (clauseKeyword == "firstprivate") {
      if (checkAllowed(firstprivateClause) ||
          parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
        return failure();
      clauseSegments[pos[firstprivateClause]] = firstprivates.size();
    } else if (clauseKeyword == "lastprivate") {
      if (checkAllowed(lastprivateClause) ||
          parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
        return failure();
      clauseSegments[pos[lastprivateClause]] = lastprivates.size();
    } else if (clauseKeyword == "shared") {
      if (checkAllowed(sharedClause) ||
          parseOperandAndTypeList(parser, shareds, sharedTypes))
        return failure();
      clauseSegments[pos[sharedClause]] = shareds.size();
    } else if (clauseKeyword == "copyin") {
      if (checkAllowed(copyinClause) ||
          parseOperandAndTypeList(parser, copyins, copyinTypes))
        return failure();
      clauseSegments[pos[copyinClause]] = copyins.size();
    } else if (clauseKeyword == "allocate") {
      if (checkAllowed(allocateClause) ||
          parseAllocateAndAllocator(parser, allocates, allocateTypes,
                                    allocators, allocatorTypes))
        return failure();
      clauseSegments[pos[allocateClause]] = allocates.size();
      clauseSegments[pos[allocateClause] + 1] = allocators.size();
    } else if (clauseKeyword == "default") {
      StringRef defval;
      if (checkAllowed(defaultClause) || parser.parseLParen() ||
          parser.parseKeyword(&defval) || parser.parseRParen())
        return failure();
      // The def prefix is required for the attribute as "private" is a keyword
      // in C++.
      auto attr = parser.getBuilder().getStringAttr("def" + defval);
      result.addAttribute("default_val", attr);
    } else if (clauseKeyword == "proc_bind") {
      StringRef bind;
      if (checkAllowed(procBindClause) || parser.parseLParen() ||
          parser.parseKeyword(&bind) || parser.parseRParen())
        return failure();
      auto attr = parser.getBuilder().getStringAttr(bind);
      result.addAttribute("proc_bind_val", attr);
    } else if (clauseKeyword == "reduction") {
      if (checkAllowed(reductionClause) ||
          parseReductionVarList(parser, reductionSymbols, reductionVars,
                                reductionVarTypes))
        return failure();
      clauseSegments[pos[reductionClause]] = reductionVars.size();
    } else if (clauseKeyword == "nowait") {
      if (checkAllowed(nowaitClause))
        return failure();
      auto attr = UnitAttr::get(parser.getBuilder().getContext());
      result.addAttribute("nowait", attr);
    } else if (clauseKeyword == "linear") {
      if (checkAllowed(linearClause) ||
          parseLinearClause(parser, linears, linearTypes, linearSteps))
        return failure();
      clauseSegments[pos[linearClause]] = linears.size();
      clauseSegments[pos[linearClause] + 1] = linearSteps.size();
    } else if (clauseKeyword == "schedule") {
      if (checkAllowed(scheduleClause) ||
          parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize))
        return failure();
      if (scheduleChunkSize) {
        clauseSegments[pos[scheduleClause]] = 1;
      }
    } else if (clauseKeyword == "collapse") {
      auto type = parser.getBuilder().getI64Type();
      mlir::IntegerAttr attr;
      if (checkAllowed(collapseClause) || parser.parseLParen() ||
          parser.parseAttribute(attr, type) || parser.parseRParen())
        return failure();
      result.addAttribute("collapse_val", attr);
    } else if (clauseKeyword == "ordered") {
      mlir::IntegerAttr attr;
      if (checkAllowed(orderedClause))
        return failure();
      if (succeeded(parser.parseOptionalLParen())) {
        auto type = parser.getBuilder().getI64Type();
        if (parser.parseAttribute(attr, type) || parser.parseRParen())
          return failure();
      } else {
        // Use 0 to represent no ordered parameter was specified
        attr = parser.getBuilder().getI64IntegerAttr(0);
      }
      result.addAttribute("ordered_val", attr);
    } else if (clauseKeyword == "order") {
      StringRef order;
      if (checkAllowed(orderClause) || parser.parseLParen() ||
          parser.parseKeyword(&order) || parser.parseRParen())
        return failure();
      auto attr = parser.getBuilder().getStringAttr(order);
      result.addAttribute("order_val", attr);
    } else if (clauseKeyword == "memory_order") {
      StringRef memoryOrder;
      if (checkAllowed(memoryOrderClause) || parser.parseLParen() ||
          parser.parseKeyword(&memoryOrder) || parser.parseRParen())
        return failure();
      result.addAttribute("memory_order",
                          parser.getBuilder().getStringAttr(memoryOrder));
    } else if (clauseKeyword == "hint") {
      IntegerAttr hint;
      if (checkAllowed(hintClause) ||
          parseSynchronizationHint(parser, hint, false))
        return failure();
      result.addAttribute("hint", hint);
    } else {
      return parser.emitError(parser.getNameLoc())
             << clauseKeyword << " is not a valid clause";
    }
  }

  // Resolves a (possibly empty) operand list against its parsed types,
  // reporting errors at the first operand's location.
  auto resolveList = [&](ArrayRef<OpAsmParser::OperandType> operands,
                         ArrayRef<Type> types) -> ParseResult {
    if (operands.empty())
      return success();
    return parser.resolveOperands(operands, types, operands.front().location,
                                  result.operands);
  };

  // Resolve operands, and add the attributes derived from them, in the order
  // in which clauses appear in `clauses`. The operand segment sizes above
  // follow that same order, so `result.operands` must too. Resolving in a
  // fixed order would misassign operands whenever an op's clause order
  // differs, e.g. reduction before allocate.
  for (ClauseType clause : clauses) {
    if (!done[clause])
      continue;
    switch (clause) {
    case ifClause:
      if (failed(parser.resolveOperand(ifCond.first, ifCond.second,
                                       result.operands)))
        return failure();
      break;
    case numThreadsClause:
      if (failed(parser.resolveOperand(numThreads.first, numThreads.second,
                                       result.operands)))
        return failure();
      break;
    case privateClause:
      if (failed(resolveList(privates, privateTypes)))
        return failure();
      break;
    case firstprivateClause:
      if (failed(resolveList(firstprivates, firstprivateTypes)))
        return failure();
      break;
    case lastprivateClause:
      if (failed(resolveList(lastprivates, lastprivateTypes)))
        return failure();
      break;
    case sharedClause:
      if (failed(resolveList(shareds, sharedTypes)))
        return failure();
      break;
    case copyinClause:
      if (failed(resolveList(copyins, copyinTypes)))
        return failure();
      break;
    case allocateClause:
      // Allocate vars come first, then allocators (matching pos, pos + 1).
      if (failed(resolveList(allocates, allocateTypes)) ||
          failed(resolveList(allocators, allocatorTypes)))
        return failure();
      break;
    case reductionClause: {
      if (reductionVars.empty())
        break;
      if (failed(resolveList(reductionVars, reductionVarTypes)))
        return failure();
      SmallVector<Attribute> reductions(reductionSymbols.begin(),
                                        reductionSymbols.end());
      result.addAttribute("reductions",
                          parser.getBuilder().getArrayAttr(reductions));
      break;
    }
    case linearClause: {
      if (linears.empty())
        break;
      // Linear steps are always resolved as i32.
      Type linearStepType = parser.getBuilder().getI32Type();
      SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
      if (failed(resolveList(linears, linearTypes)) ||
          failed(resolveList(linearSteps, linearStepTypes)))
        return failure();
      break;
    }
    case scheduleClause: {
      if (schedule.empty())
        break;
      schedule[0] = llvm::toUpper(schedule[0]);
      result.addAttribute("schedule_val",
                          parser.getBuilder().getStringAttr(schedule));
      if (!modifiers.empty()) {
        result.addAttribute("schedule_modifier",
                            parser.getBuilder().getStringAttr(modifiers[0]));
        // Only SIMD attribute is allowed here!
        if (modifiers.size() > 1) {
          assert(symbolizeScheduleModifier(modifiers[1]) ==
                 mlir::omp::ScheduleModifier::simd);
          result.addAttribute("simd_modifier",
                              UnitAttr::get(parser.getBuilder().getContext()));
        }
      }
      // The chunk size is resolved as i32; propagate resolution failures
      // instead of silently dropping them.
      if (scheduleChunkSize &&
          failed(parser.resolveOperand(*scheduleChunkSize,
                                       parser.getBuilder().getI32Type(),
                                       result.operands)))
        return failure();
      break;
    }
    default:
      // Attribute-only clauses were fully handled while parsing.
      break;
    }
  }

  segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
  return success();
}
/// Parses an `omp.parallel` operation.
///
/// operation ::= `omp.parallel` clause-list region
/// clause-list ::= clause clause-list | empty
/// clause ::= if | num-threads | private | firstprivate | shared | copyin |
///            allocate | default | proc-bind
///
/// The operand segment sizes produced by `parseClauses` are recorded in the
/// `operand_segment_sizes` attribute, and the body region follows the
/// clauses.
static ParseResult parseParallelOp(OpAsmParser &parser,
                                   OperationState &result) {
  SmallVector<ClauseType> allowedClauses = {
      ifClause,       numThreadsClause, privateClause,
      firstprivateClause, sharedClause, copyinClause,
      allocateClause, defaultClause,    procBindClause};
  SmallVector<int> segmentSizes;
  if (failed(parseClauses(parser, result, allowedClauses, segmentSizes)))
    return failure();
  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(segmentSizes));

  SmallVector<OpAsmParser::OperandType> regionArgs;
  SmallVector<Type> regionArgTypes;
  Region *body = result.addRegion();
  return parser.parseRegion(*body, regionArgs, regionArgTypes);
}
//===----------------------------------------------------------------------===//
// Parser, printer and verifier for SectionsOp
//===----------------------------------------------------------------------===//
/// Parses an OpenMP Sections operation.
///
/// sections ::= `omp.sections` clause-list region
/// clause-list ::= clause clause-list | empty
/// clause ::= private | firstprivate | lastprivate | reduction | allocate |
///            nowait
///
/// The operand segment sizes produced by `parseClauses` are recorded in the
/// `operand_segment_sizes` attribute, and the body region follows the
/// clauses.
static ParseResult parseSectionsOp(OpAsmParser &parser,
                                   OperationState &result) {
  SmallVector<ClauseType> allowedClauses = {
      privateClause,   firstprivateClause, lastprivateClause,
      reductionClause, allocateClause,     nowaitClause};
  SmallVector<int> segmentSizes;
  if (failed(parseClauses(parser, result, allowedClauses, segmentSizes)))
    return failure();
  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(segmentSizes));
  return parser.parseRegion(*result.addRegion());
}
/// Prints an `omp.sections` operation: its clauses followed by the body region.
static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) {
  p << " ";

  // Data-sharing clauses; each helper receives the clause keyword to emit.
  printDataVars(p, op.private_vars(), "private");
  printDataVars(p, op.firstprivate_vars(), "firstprivate");
  printDataVars(p, op.lastprivate_vars(), "lastprivate");

  auto reductionVars = op.reduction_vars();
  if (!reductionVars.empty())
    printReductionVarList(p, op.reductions(), reductionVars);

  auto allocateVars = op.allocate_vars();
  if (!allocateVars.empty())
    printAllocateAndAllocator(p, allocateVars, op.allocators_vars());

  if (op.nowait())
    p << "nowait ";

  p.printRegion(op.region());
}
/// Verifies an `omp.sections` operation.
///
/// Checks that:
///  - no operand is both private and firstprivate, or both private and
///    lastprivate;
///  - the allocate and allocator operand lists have equal length;
///  - the body contains only omp.section ops and the terminator;
///  - the reduction clause is well formed.
static LogicalResult verifySectionsOp(SectionsOp op) {
  // A list item may not appear in more than one clause on the same directive,
  // except that it may be specified in both firstprivate and lastprivate
  // clauses.
  for (auto var : op.private_vars()) {
    if (llvm::is_contained(op.firstprivate_vars(), var))
      return op.emitOpError()
             << "operand used in both private and firstprivate clauses";
    if (llvm::is_contained(op.lastprivate_vars(), var))
      return op.emitOpError()
             << "operand used in both private and lastprivate clauses";
  }

  if (op.allocate_vars().size() != op.allocators_vars().size())
    return op.emitError(
        "expected equal sizes for allocate and allocator variables");

  // Any other op directly in the body is invalid. The diagnostic must be
  // returned: previously it was emitted but verification still succeeded.
  for (Operation &inst : op.region().front()) {
    if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst)))
      return op.emitOpError()
             << "expected omp.section op or terminator op inside region";
  }

  return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
}
/// Parses an OpenMP Workshare Loop operation.
///
/// wsloop ::= `omp.wsloop` loop-control clause-list
/// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
/// steps := `step` `(`ssa-id-list`)`
/// clause-list ::= clause clause-list | empty
/// clause ::= private | firstprivate | lastprivate | linear | schedule |
///            collapse | nowait | ordered | order | reduction
///
/// Lower bounds, upper bounds and steps are each required to contain exactly
/// one value per induction variable. All of them are resolved to the single
/// loop-variable type and appended to the operand list in that order. The
/// induction variables become the body's entry-block arguments.
static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
  Builder &builder = parser.getBuilder();

  // `(` induction variables `)` `:` type
  SmallVector<OpAsmParser::OperandType> inductionVars;
  if (parser.parseRegionArgumentList(inductionVars,
                                     /*requiredOperandCount=*/-1,
                                     OpAsmParser::Delimiter::Paren))
    return failure();
  const int ivCount = static_cast<int>(inductionVars.size());

  Type ivType;
  if (parser.parseColonType(ivType))
    return failure();

  // Parses one parenthesized group of `ivCount` operands and resolves them to
  // the induction-variable type, appending them to the op's operands.
  auto parseBoundGroup = [&]() -> ParseResult {
    SmallVector<OpAsmParser::OperandType> group;
    return failure(
        parser.parseOperandList(group, ivCount,
                                OpAsmParser::Delimiter::Paren) ||
        parser.resolveOperands(group, ivType, result.operands));
  };

  // `=` lower-bounds `to` upper-bounds
  if (parser.parseEqual() || parseBoundGroup())
    return failure();
  if (parser.parseKeyword("to") || parseBoundGroup())
    return failure();

  if (succeeded(parser.parseOptionalKeyword("inclusive")))
    result.addAttribute("inclusive", builder.getUnitAttr());

  // `step` steps
  if (parser.parseKeyword("step") || parseBoundGroup())
    return failure();

  SmallVector<ClauseType> allowedClauses = {
      privateClause,   firstprivateClause, lastprivateClause, linearClause,
      reductionClause, collapseClause,     orderClause,       orderedClause,
      nowaitClause,    scheduleClause};
  // The first three segments are the bounds and steps parsed above.
  SmallVector<int> segmentSizes{ivCount, ivCount, ivCount};
  if (failed(parseClauses(parser, result, allowedClauses, segmentSizes)))
    return failure();
  result.addAttribute("operand_segment_sizes",
                      builder.getI32VectorAttr(segmentSizes));

  // Body: the induction variables are bound as entry-block arguments.
  SmallVector<Type> ivTypes(ivCount, ivType);
  return parser.parseRegion(*result.addRegion(), inductionVars, ivTypes);
}
/// Prints an `omp.wsloop` operation in the form accepted by parseWsLoopOp:
/// loop control (induction variables, bounds, optional `inclusive`, steps),
/// then any present clauses, then the body without its entry-block arguments
/// (they were already printed as the induction variables).
static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
  auto ivs = op.getRegion().front().getArguments();
  Type ivType = ivs[0].getType();

  p << " (" << ivs << ") : " << ivType;
  p << " = (" << op.lowerBound() << ") to (" << op.upperBound() << ") ";
  if (op.inclusive())
    p << "inclusive ";
  p << "step (" << op.step() << ") ";

  printDataVars(p, op.private_vars(), "private");
  printDataVars(p, op.firstprivate_vars(), "firstprivate");
  printDataVars(p, op.lastprivate_vars(), "lastprivate");

  if (!op.linear_vars().empty())
    printLinearClause(p, op.linear_vars(), op.linear_step_vars());

  if (auto sched = op.schedule_val())
    printScheduleClause(p, sched.getValue(), op.schedule_modifier(),
                        op.simd_modifier(), op.schedule_chunk_var());

  if (auto collapse = op.collapse_val())
    p << "collapse(" << collapse << ") ";
  if (op.nowait())
    p << "nowait ";
  if (auto ordered = op.ordered_val())
    p << "ordered(" << ordered << ") ";
  if (auto order = op.order_val())
    p << "order(" << order << ") ";

  auto reductionVars = op.reduction_vars();
  if (!reductionVars.empty())
    printReductionVarList(p, op.reductions(), reductionVars);

  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
/// Parses the optional `atomic` region of a reduction declaration.
///
/// When the `atomic` keyword is absent, nothing is consumed and `region` is
/// left empty; otherwise the region following the keyword is parsed.
static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
                                              Region &region) {
  if (succeeded(parser.parseOptionalKeyword("atomic")))
    return parser.parseRegion(region);
  return success();
}
/// Prints the optional `atomic` region of a reduction declaration.
/// An empty region prints nothing, mirroring parseAtomicReductionRegion.
static void printAtomicReductionRegion(OpAsmPrinter &printer,
                                       ReductionDeclareOp op, Region &region) {
  if (!region.empty()) {
    printer << "atomic ";
    printer.printRegion(region);
  }
}
/// Verifies an `omp.reduction.declare` operation.
///
/// - The initializer region must be non-empty, take one argument of the
///   reduction type, and yield one value of that type.
/// - The reduction region must be non-empty, take two arguments of the
///   reduction type, and yield one value of that type.
/// - The atomic region is optional. If present, it must take two arguments of
///   the same pointer-like type whose element type is the reduction type.
static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
  Type reductionType = op.type();

  // True when every omp.yield in the blocks of `region` yields exactly one
  // value of the reduction type.
  auto yieldsReductionType = [&](Region &region) {
    return llvm::all_of(region.getOps<YieldOp>(), [&](YieldOp yieldOp) {
      return yieldOp.results().size() == 1 &&
             yieldOp.results().getTypes()[0] == reductionType;
    });
  };

  // Initializer region.
  Region &initRegion = op.initializerRegion();
  if (initRegion.empty())
    return op.emitOpError() << "expects non-empty initializer region";
  Block &initEntry = initRegion.front();
  bool initArgsOk = initEntry.getNumArguments() == 1 &&
                    initEntry.getArgument(0).getType() == reductionType;
  if (!initArgsOk)
    return op.emitOpError() << "expects initializer region with one argument "
                               "of the reduction type";
  if (!yieldsReductionType(initRegion))
    return op.emitOpError() << "expects initializer region to yield a value "
                               "of the reduction type";

  // Combiner (reduction) region.
  Region &combinerRegion = op.reductionRegion();
  if (combinerRegion.empty())
    return op.emitOpError() << "expects non-empty reduction region";
  Block &combinerEntry = combinerRegion.front();
  auto combinerArgTypes = combinerEntry.getArgumentTypes();
  if (combinerEntry.getNumArguments() != 2 ||
      combinerArgTypes[0] != reductionType ||
      combinerArgTypes[1] != reductionType)
    return op.emitOpError() << "expects reduction region with two arguments of "
                               "the reduction type";
  if (!yieldsReductionType(combinerRegion))
    return op.emitOpError() << "expects reduction region to yield a value "
                               "of the reduction type";

  // Optional atomic region.
  Region &atomicRegion = op.atomicReductionRegion();
  if (atomicRegion.empty())
    return success();
  Block &atomicEntry = atomicRegion.front();
  auto atomicArgTypes = atomicEntry.getArgumentTypes();
  if (atomicEntry.getNumArguments() != 2 ||
      atomicArgTypes[0] != atomicArgTypes[1])
    return op.emitOpError() << "expects atomic reduction region with two "
                               "arguments of the same type";
  auto ptrType = atomicArgTypes[0].dyn_cast<PointerLikeType>();
  if (!ptrType || ptrType.getElementType() != reductionType)
    return op.emitOpError() << "expects atomic reduction region arguments to "
                               "be accumulators containing the reduction type";
  return success();
}
/// Verifies an `omp.reduction` operation.
///
/// The op must be nested inside an omp.wsloop. Its accumulator must be one of
/// that loop's reduction variables.
static LogicalResult verifyReductionOp(ReductionOp op) {
  // TODO: generalize this to an op interface when there is more than one op
  // that supports reductions.
  auto container = op->getParentOfType<WsLoopOp>();
  // Without a WsLoopOp ancestor there is nothing to check the accumulator
  // against. Previously this case dereferenced a null op.
  if (!container)
    return op.emitOpError()
           << "expects to be nested inside an omp.wsloop operation";
  if (llvm::is_contained(container.reduction_vars(), op.accumulator()))
    return success();
  return op.emitOpError() << "the accumulator is not used by the parent";
}
//===----------------------------------------------------------------------===//
// WsLoopOp
//===----------------------------------------------------------------------===//
/// Builds an `omp.wsloop` from its bounds and steps only.
///
/// Every clause is left unset and no body block is created. The extra
/// `attributes` are appended afterwards.
void WsLoopOp::build(OpBuilder &builder, OperationState &state,
                     ValueRange lowerBound, ValueRange upperBound,
                     ValueRange step, ArrayRef<NamedAttribute> attributes) {
  const ValueRange noValues;
  build(builder, state, /*typeRange=*/TypeRange(), lowerBound, upperBound,
        step,
        /*private_vars=*/noValues, /*firstprivate_vars=*/noValues,
        /*lastprivate_vars=*/noValues, /*linear_vars=*/noValues,
        /*linear_step_vars=*/noValues, /*reduction_vars=*/noValues,
        /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr,
        /*collapse_val=*/nullptr, /*nowait=*/nullptr,
        /*ordered_val=*/nullptr, /*order_val=*/nullptr,
        /*inclusive=*/nullptr, /*buildBody=*/false);
  state.addAttributes(attributes);
}
/// Generic builder: forwards `operands` and `attributes` verbatim and creates
/// an empty body region. `resultTypes` is asserted to be empty.
void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
                     ValueRange operands, ArrayRef<NamedAttribute> attributes) {
  assert(resultTypes.empty() && "mismatched number of return types");
  state.addOperands(operands);
  state.addAttributes(attributes);
  (void)state.addRegion();
  state.addTypes(resultTypes);
}
/// Builds an `omp.wsloop` with every operand group and clause attribute.
///
/// Operands are appended in the same order as the entries of the
/// `operand_segment_sizes` attribute: lower bounds, upper bounds, steps,
/// private, firstprivate, lastprivate, linear, linear steps, reduction and
/// the optional schedule chunk.
///
/// Null attributes and a null `scheduleChunkVar` are omitted. When `buildBody`
/// is set, an entry block with one argument per step (typed like the first
/// step) is created.
///
/// NOTE(review): this builder does not set the `reductions` attribute, which
/// verifyWsLoopOp checks against `reduction_vars`. Callers passing
/// `reductionVars` presumably must attach it separately; confirm.
void WsLoopOp::build(OpBuilder &builder, OperationState &result,
                     TypeRange typeRange, ValueRange lowerBounds,
                     ValueRange upperBounds, ValueRange steps,
                     ValueRange privateVars, ValueRange firstprivateVars,
                     ValueRange lastprivateVars, ValueRange linearVars,
                     ValueRange linearStepVars, ValueRange reductionVars,
                     StringAttr scheduleVal, Value scheduleChunkVar,
                     IntegerAttr collapseVal, UnitAttr nowait,
                     IntegerAttr orderedVal, StringAttr orderVal,
                     UnitAttr inclusive, bool buildBody) {
  // The operand order must match the segment sizes below exactly. Previously
  // lastprivate and reduction operands were counted but never added, which
  // misaligned every later segment.
  result.addOperands(lowerBounds);
  result.addOperands(upperBounds);
  result.addOperands(steps);
  result.addOperands(privateVars);
  result.addOperands(firstprivateVars);
  result.addOperands(lastprivateVars);
  result.addOperands(linearVars);
  result.addOperands(linearStepVars);
  result.addOperands(reductionVars);
  if (scheduleChunkVar)
    result.addOperands(scheduleChunkVar);

  if (scheduleVal)
    result.addAttribute("schedule_val", scheduleVal);
  if (collapseVal)
    result.addAttribute("collapse_val", collapseVal);
  if (nowait)
    result.addAttribute("nowait", nowait);
  if (orderedVal)
    result.addAttribute("ordered_val", orderedVal);
  // The attribute is read back through `order_val()`, so it must use that
  // name (it was previously stored as "order").
  if (orderVal)
    result.addAttribute("order_val", orderVal);
  if (inclusive)
    result.addAttribute("inclusive", inclusive);

  result.addAttribute(
      WsLoopOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr(
          {static_cast<int32_t>(lowerBounds.size()),
           static_cast<int32_t>(upperBounds.size()),
           static_cast<int32_t>(steps.size()),
           static_cast<int32_t>(privateVars.size()),
           static_cast<int32_t>(firstprivateVars.size()),
           static_cast<int32_t>(lastprivateVars.size()),
           static_cast<int32_t>(linearVars.size()),
           static_cast<int32_t>(linearStepVars.size()),
           static_cast<int32_t>(reductionVars.size()),
           static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));

  Region *bodyRegion = result.addRegion();
  if (buildBody) {
    OpBuilder::InsertionGuard guard(builder);
    unsigned numIVs = steps.size();
    // Guard against an empty `steps`: `front()` on an empty range is invalid.
    SmallVector<Type, 8> argTypes;
    if (numIVs != 0)
      argTypes.assign(numIVs, steps.getType().front());
    builder.createBlock(bodyRegion, {}, argTypes);
  }
}
/// Verifies an `omp.wsloop`. Only the reduction clause is checked here, by
/// matching the `reductions` attribute against the reduction operands.
static LogicalResult verifyWsLoopOp(WsLoopOp op) {
  auto reductionVars = op.reduction_vars();
  auto reductionDecls = op.reductions();
  return verifyReductionVarList(op, reductionDecls, reductionVars);
}
//===----------------------------------------------------------------------===//
// Verifier for critical construct (2.17.1)
//===----------------------------------------------------------------------===//
/// Verifies an `omp.critical.declare`: the only custom check is its
/// synchronization hint.
static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
  auto hint = op.hint();
  return verifySynchronizationHint(op, hint);
}
/// Verifies an `omp.critical`. If the op is named, its name must resolve
/// (via the nearest symbol table) to a CriticalDeclareOp.
static LogicalResult verifyCriticalOp(CriticalOp op) {
  // Unnamed critical regions need no lookup.
  if (!op.nameAttr())
    return success();

  auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
  if (SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef))
    return success();

  return op.emitOpError() << "expected symbol reference " << symbolRef
                          << " to point to a critical declaration";
}
//===----------------------------------------------------------------------===//
// Verifier for ordered construct
//===----------------------------------------------------------------------===//
/// Verifies an `omp.ordered` (depend) operation.
///
/// The op must be nested inside an omp.wsloop whose `ordered_val` is present
/// and non-zero. That value must equal this op's `num_loops_val`.
static LogicalResult verifyOrderedOp(OrderedOp op) {
  auto container = op->getParentOfType<WsLoopOp>();
  bool hasOrderedParam = container && container.ordered_valAttr() &&
                         container.ordered_valAttr().getInt() != 0;
  if (!hasOrderedParam)
    return op.emitOpError() << "ordered depend directive must be closely "
                            << "nested inside a worksharing-loop with ordered "
                            << "clause with parameter present";

  int64_t loopDepth = container.ordered_valAttr().getInt();
  int64_t dependDepth = static_cast<int64_t>(op.num_loops_val().getValue());
  if (loopDepth == dependDepth)
    return success();

  return op.emitOpError() << "number of variables in depend clause does not "
                          << "match number of iteration variables in the "
                          << "doacross loop";
}
/// Verifies an `omp.ordered_region` operation.
///
/// The `simd` form is rejected. If the op is inside an omp.wsloop, that loop
/// must carry an `ordered` clause whose parameter is 0 (i.e. no parameter).
static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
  // TODO: The code generation for ordered simd directive is not supported yet.
  // Emit a diagnostic rather than returning a bare failure(), which left the
  // user with no explanation of why verification failed.
  if (op.simd())
    return op.emitOpError()
           << "the simd clause on an ordered region is not supported yet";

  if (auto container = op->getParentOfType<WsLoopOp>()) {
    if (!container.ordered_valAttr() ||
        container.ordered_valAttr().getInt() != 0)
      return op.emitOpError() << "ordered region must be closely nested inside "
                              << "a worksharing-loop region with an ordered "
                              << "clause without parameter present";
  }

  return success();
}
//===----------------------------------------------------------------------===//
// AtomicReadOp
//===----------------------------------------------------------------------===//
/// Parser for AtomicReadOp
///
/// operation ::= `omp.atomic.read` address atomic-clause-list
///               `:` address-type `->` result-type
/// atomic-clause-list ::= (memory-order | hint)*
///
/// The address operand is resolved against `address-type`. The types after
/// `->` become the op's results.
static ParseResult parseAtomicReadOp(OpAsmParser &parser,
                                     OperationState &result) {
  OpAsmParser::OperandType addressOperand;
  if (parser.parseOperand(addressOperand))
    return failure();

  SmallVector<ClauseType> allowedClauses = {memoryOrderClause, hintClause};
  SmallVector<int> clauseSegments;
  if (failed(parseClauses(parser, result, allowedClauses, clauseSegments)))
    return failure();

  Type addressType;
  if (parser.parseColonType(addressType) ||
      parser.resolveOperand(addressOperand, addressType, result.operands))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseArrowTypeList(resultTypes))
    return failure();
  result.addTypes(resultTypes);
  return success();
}
/// Printer for AtomicReadOp
///
/// Emits `address [memory_order(...)] [hint clause] : address-type ->
/// result-type`, the token order that parseAtomicReadOp consumes.
static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
  p << " " << op.address() << " ";
  if (op.memory_order())
    p << "memory_order(" << op.memory_order().getValue() << ") ";
  // Every preceding token already ends in a space, so the hint is printed
  // directly. Passing `p << " "` here produced a double space.
  if (op.hintAttr())
    printSynchronizationHint(p, op, op.hintAttr());
  p << ": " << op.address().getType() << " -> " << op.getType();
}
/// Verifier for AtomicReadOp
///
/// Rejects the `acq_rel` and `release` memory orders for reads, then checks
/// the synchronization hint.
static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
  if (auto memOrder = op.memory_order()) {
    StringRef order = memOrder.getValue();
    bool hasReleaseSemantics = order == "acq_rel" || order == "release";
    if (hasReleaseSemantics)
      return op.emitError(
          "memory-order must not be acq_rel or release for atomic reads");
  }
  return verifySynchronizationHint(op, op.hint());
}
//===----------------------------------------------------------------------===//
// AtomicWriteOp
//===----------------------------------------------------------------------===//
/// Parser for AtomicWriteOp
///
/// operation ::= `omp.atomic.write` address `,` value atomic-clause-list
///               `:` address-type `,` value-type
/// atomic-clause-list ::= (memory-order | hint)*
///
/// The address is resolved against `address-type` and added first, followed
/// by the value resolved against `value-type`.
static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType addressOperand, valueOperand;
  if (parser.parseOperand(addressOperand) || parser.parseComma() ||
      parser.parseOperand(valueOperand))
    return failure();

  SmallVector<ClauseType> allowedClauses = {memoryOrderClause, hintClause};
  SmallVector<int> clauseSegments;
  if (failed(parseClauses(parser, result, allowedClauses, clauseSegments)))
    return failure();

  Type addressType, valueType;
  if (parser.parseColonType(addressType) || parser.parseComma() ||
      parser.parseType(valueType))
    return failure();

  // Operand order: address first, then the value to store.
  if (parser.resolveOperand(addressOperand, addressType, result.operands) ||
      parser.resolveOperand(valueOperand, valueType, result.operands))
    return failure();
  return success();
}
/// Printer for AtomicWriteOp
///
/// Emits `address, value [memory_order(...)] [hint clause] : address-type,
/// value-type`.
static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
  Value address = op.address();
  Value value = op.value();

  p << " " << address << ", " << value << " ";
  if (auto memOrder = op.memory_order())
    p << "memory_order(" << memOrder.getValue() << ") ";
  if (auto hint = op.hintAttr())
    printSynchronizationHint(p, op, hint);
  p << ": " << address.getType() << ", " << value.getType();
}
/// Verifier for AtomicWriteOp
///
/// Rejects the `acq_rel` and `acquire` memory orders for writes, then checks
/// the synchronization hint.
static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
  if (auto memOrder = op.memory_order()) {
    StringRef order = memOrder.getValue();
    bool hasAcquireSemantics = order == "acq_rel" || order == "acquire";
    if (hasAcquireSemantics)
      return op.emitError(
          "memory-order must not be acq_rel or acquire for atomic writes");
  }
  return verifySynchronizationHint(op, op.hint());
}
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"