| //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// |
| // |
| // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| // ============================================================================= |
| |
| #include "mlir/Dialect/OpenACC/OpenACC.h" |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
| #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| using namespace mlir; |
| using namespace acc; |
| |
| #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc" |
| |
| //===----------------------------------------------------------------------===// |
| // OpenACC operations |
| //===----------------------------------------------------------------------===// |
| |
| void OpenACCDialect::initialize() { |
| addOperations< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" |
| >(); |
| } |
| |
| template <typename StructureOp> |
| static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, |
| unsigned nRegions = 1) { |
| |
| SmallVector<Region *, 2> regions; |
| for (unsigned i = 0; i < nRegions; ++i) |
| regions.push_back(state.addRegion()); |
| |
| for (Region *region : regions) { |
| if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{})) |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| static ParseResult |
| parseOperandList(OpAsmParser &parser, StringRef keyword, |
| SmallVectorImpl<OpAsmParser::OperandType> &args, |
| SmallVectorImpl<Type> &argTypes, OperationState &result) { |
| if (failed(parser.parseOptionalKeyword(keyword))) |
| return success(); |
| |
| if (failed(parser.parseLParen())) |
| return failure(); |
| |
| // Exit early if the list is empty. |
| if (succeeded(parser.parseOptionalRParen())) |
| return success(); |
| |
| do { |
| OpAsmParser::OperandType arg; |
| Type type; |
| |
| if (parser.parseRegionArgument(arg) || parser.parseColonType(type)) |
| return failure(); |
| |
| args.push_back(arg); |
| argTypes.push_back(type); |
| } while (succeeded(parser.parseOptionalComma())); |
| |
| if (failed(parser.parseRParen())) |
| return failure(); |
| |
| return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(), |
| result.operands); |
| } |
| |
| static void printOperandList(Operation::operand_range operands, |
| StringRef listName, OpAsmPrinter &printer) { |
| |
| if (!operands.empty()) { |
| printer << " " << listName << "("; |
| llvm::interleaveComma(operands, printer, [&](Value op) { |
| printer << op << ": " << op.getType(); |
| }); |
| printer << ")"; |
| } |
| } |
| |
| static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword, |
| OpAsmParser::OperandType &operand, |
| Type type, bool &hasOptional, |
| OperationState &result) { |
| hasOptional = false; |
| if (succeeded(parser.parseOptionalKeyword(keyword))) { |
| hasOptional = true; |
| if (parser.parseLParen() || parser.parseOperand(operand) || |
| parser.resolveOperand(operand, type, result.operands) || |
| parser.parseRParen()) |
| return failure(); |
| } |
| return success(); |
| } |
| |
| static ParseResult parseOperandAndType(OpAsmParser &parser, |
| OperationState &result) { |
| OpAsmParser::OperandType operand; |
| Type type; |
| if (parser.parseOperand(operand) || parser.parseColonType(type) || |
| parser.resolveOperand(operand, type, result.operands)) |
| return failure(); |
| return success(); |
| } |
| |
| /// Parse optional operand and its type wrapped in parenthesis prefixed with |
| /// a keyword. |
| /// Example: |
| /// keyword `(` %vectorLength: i64 `)` |
| static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, |
| StringRef keyword, |
| OperationState &result) { |
| OpAsmParser::OperandType operand; |
| if (succeeded(parser.parseOptionalKeyword(keyword))) { |
| return failure(parser.parseLParen() || |
| parseOperandAndType(parser, result) || parser.parseRParen()); |
| } |
| return llvm::None; |
| } |
| |
| /// Parse optional operand and its type wrapped in parenthesis. |
| /// Example: |
| /// `(` %vectorLength: i64 `)` |
| static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, |
| OperationState &result) { |
| if (succeeded(parser.parseOptionalLParen())) { |
| return failure(parseOperandAndType(parser, result) || parser.parseRParen()); |
| } |
| return llvm::None; |
| } |
| |
| /// Parse optional operand with its type prefixed with prefixKeyword `=`. |
| /// Example: |
| /// num=%gangNum: i32 |
| static OptionalParseResult parserOptionalOperandAndTypeWithPrefix( |
| OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) { |
| if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) { |
| parser.parseEqual(); |
| return parseOperandAndType(parser, result); |
| } |
| return llvm::None; |
| } |
| |
| static bool isComputeOperation(Operation *op) { |
| return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op); |
| } |
| |
| namespace { |
| /// Pattern to remove operation without region that have constant false `ifCond` |
| /// and remove the condition from the operation if the `ifCond` is a true |
| /// constant. |
| template <typename OpTy> |
| struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> { |
| using OpRewritePattern<OpTy>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(OpTy op, |
| PatternRewriter &rewriter) const override { |
| // Early return if there is no condition. |
| if (!op.ifCond()) |
| return success(); |
| |
| auto constOp = op.ifCond().template getDefiningOp<arith::ConstantOp>(); |
| if (constOp && constOp.getValue().template cast<IntegerAttr>().getInt()) |
| rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); |
| else if (constOp) |
| rewriter.eraseOp(op); |
| |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // ParallelOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Parse acc.parallel operation |
| /// operation := `acc.parallel` `async` `(` index `)`? |
| /// `wait` `(` index-list `)`? |
| /// `num_gangs` `(` value `)`? |
| /// `num_workers` `(` value `)`? |
| /// `vector_length` `(` value `)`? |
| /// `if` `(` value `)`? |
| /// `self` `(` value `)`? |
| /// `reduction` `(` value-list `)`? |
| /// `copy` `(` value-list `)`? |
| /// `copyin` `(` value-list `)`? |
| /// `copyin_readonly` `(` value-list `)`? |
| /// `copyout` `(` value-list `)`? |
| /// `copyout_zero` `(` value-list `)`? |
| /// `create` `(` value-list `)`? |
| /// `create_zero` `(` value-list `)`? |
| /// `no_create` `(` value-list `)`? |
| /// `present` `(` value-list `)`? |
| /// `deviceptr` `(` value-list `)`? |
| /// `attach` `(` value-list `)`? |
| /// `private` `(` value-list `)`? |
| /// `firstprivate` `(` value-list `)`? |
| /// region attr-dict? |
| static ParseResult parseParallelOp(OpAsmParser &parser, |
| OperationState &result) { |
| Builder &builder = parser.getBuilder(); |
| SmallVector<OpAsmParser::OperandType, 8> privateOperands, |
| firstprivateOperands, copyOperands, copyinOperands, |
| copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands, |
| createOperands, createZeroOperands, noCreateOperands, presentOperands, |
| devicePtrOperands, attachOperands, waitOperands, reductionOperands; |
| SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes, |
| copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes, |
| copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes, |
| createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes, |
| deviceptrOperandTypes, attachOperandTypes, privateOperandTypes, |
| firstprivateOperandTypes; |
| |
| SmallVector<Type, 8> operandTypes; |
| OpAsmParser::OperandType ifCond, selfCond; |
| bool hasIfCond = false, hasSelfCond = false; |
| OptionalParseResult async, numGangs, numWorkers, vectorLength; |
| Type i1Type = builder.getI1Type(); |
| |
| // async()? |
| async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(), |
| result); |
| if (async.hasValue() && failed(*async)) |
| return failure(); |
| |
| // wait()? |
| if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(), |
| waitOperands, waitOperandTypes, result))) |
| return failure(); |
| |
| // num_gangs(value)? |
| numGangs = parseOptionalOperandAndType( |
| parser, ParallelOp::getNumGangsKeyword(), result); |
| if (numGangs.hasValue() && failed(*numGangs)) |
| return failure(); |
| |
| // num_workers(value)? |
| numWorkers = parseOptionalOperandAndType( |
| parser, ParallelOp::getNumWorkersKeyword(), result); |
| if (numWorkers.hasValue() && failed(*numWorkers)) |
| return failure(); |
| |
| // vector_length(value)? |
| vectorLength = parseOptionalOperandAndType( |
| parser, ParallelOp::getVectorLengthKeyword(), result); |
| if (vectorLength.hasValue() && failed(*vectorLength)) |
| return failure(); |
| |
| // if()? |
| if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond, |
| i1Type, hasIfCond, result))) |
| return failure(); |
| |
| // self()? |
| if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(), |
| selfCond, i1Type, hasSelfCond, result))) |
| return failure(); |
| |
| // reduction()? |
| if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(), |
| reductionOperands, reductionOperandTypes, |
| result))) |
| return failure(); |
| |
| // copy()? |
| if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(), |
| copyOperands, copyOperandTypes, result))) |
| return failure(); |
| |
| // copyin()? |
| if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(), |
| copyinOperands, copyinOperandTypes, result))) |
| return failure(); |
| |
| // copyin_readonly()? |
| if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(), |
| copyinReadonlyOperands, |
| copyinReadonlyOperandTypes, result))) |
| return failure(); |
| |
| // copyout()? |
| if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(), |
| copyoutOperands, copyoutOperandTypes, result))) |
| return failure(); |
| |
| // copyout_zero()? |
| if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(), |
| copyoutZeroOperands, copyoutZeroOperandTypes, |
| result))) |
| return failure(); |
| |
| // create()? |
| if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(), |
| createOperands, createOperandTypes, result))) |
| return failure(); |
| |
| // create_zero()? |
| if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(), |
| createZeroOperands, createZeroOperandTypes, |
| result))) |
| return failure(); |
| |
| // no_create()? |
| if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(), |
| noCreateOperands, noCreateOperandTypes, result))) |
| return failure(); |
| |
| // present()? |
| if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(), |
| presentOperands, presentOperandTypes, result))) |
| return failure(); |
| |
| // deviceptr()? |
| if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(), |
| devicePtrOperands, deviceptrOperandTypes, |
| result))) |
| return failure(); |
| |
| // attach()? |
| if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(), |
| attachOperands, attachOperandTypes, result))) |
| return failure(); |
| |
| // private()? |
| if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(), |
| privateOperands, privateOperandTypes, result))) |
| return failure(); |
| |
| // firstprivate()? |
| if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(), |
| firstprivateOperands, firstprivateOperandTypes, |
| result))) |
| return failure(); |
| |
| // Parallel op region |
| if (failed(parseRegions<ParallelOp>(parser, result))) |
| return failure(); |
| |
| result.addAttribute( |
| ParallelOp::getOperandSegmentSizeAttr(), |
| builder.getI32VectorAttr( |
| {static_cast<int32_t>(async.hasValue() ? 1 : 0), |
| static_cast<int32_t>(waitOperands.size()), |
| static_cast<int32_t>(numGangs.hasValue() ? 1 : 0), |
| static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0), |
| static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0), |
| static_cast<int32_t>(hasIfCond ? 1 : 0), |
| static_cast<int32_t>(hasSelfCond ? 1 : 0), |
| static_cast<int32_t>(reductionOperands.size()), |
| static_cast<int32_t>(copyOperands.size()), |
| static_cast<int32_t>(copyinOperands.size()), |
| static_cast<int32_t>(copyinReadonlyOperands.size()), |
| static_cast<int32_t>(copyoutOperands.size()), |
| static_cast<int32_t>(copyoutZeroOperands.size()), |
| static_cast<int32_t>(createOperands.size()), |
| static_cast<int32_t>(createZeroOperands.size()), |
| static_cast<int32_t>(noCreateOperands.size()), |
| static_cast<int32_t>(presentOperands.size()), |
| static_cast<int32_t>(devicePtrOperands.size()), |
| static_cast<int32_t>(attachOperands.size()), |
| static_cast<int32_t>(privateOperands.size()), |
| static_cast<int32_t>(firstprivateOperands.size())})); |
| |
| // Additional attributes |
| if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| static void print(OpAsmPrinter &printer, ParallelOp &op) { |
| // async()? |
| if (Value async = op.async()) |
| printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": " |
| << async.getType() << ")"; |
| |
| // wait()? |
| printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer); |
| |
| // num_gangs()? |
| if (Value numGangs = op.numGangs()) |
| printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs |
| << ": " << numGangs.getType() << ")"; |
| |
| // num_workers()? |
| if (Value numWorkers = op.numWorkers()) |
| printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers |
| << ": " << numWorkers.getType() << ")"; |
| |
| // vector_length()? |
| if (Value vectorLength = op.vectorLength()) |
| printer << " " << ParallelOp::getVectorLengthKeyword() << "(" |
| << vectorLength << ": " << vectorLength.getType() << ")"; |
| |
| // if()? |
| if (Value ifCond = op.ifCond()) |
| printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")"; |
| |
| // self()? |
| if (Value selfCond = op.selfCond()) |
| printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")"; |
| |
| // reduction()? |
| printOperandList(op.reductionOperands(), ParallelOp::getReductionKeyword(), |
| printer); |
| |
| // copy()? |
| printOperandList(op.copyOperands(), ParallelOp::getCopyKeyword(), printer); |
| |
| // copyin()? |
| printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(), |
| printer); |
| |
| // copyin_readonly()? |
| printOperandList(op.copyinReadonlyOperands(), |
| ParallelOp::getCopyinReadonlyKeyword(), printer); |
| |
| // copyout()? |
| printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(), |
| printer); |
| |
| // copyout_zero()? |
| printOperandList(op.copyoutZeroOperands(), |
| ParallelOp::getCopyoutZeroKeyword(), printer); |
| |
| // create()? |
| printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(), |
| printer); |
| |
| // create_zero()? |
| printOperandList(op.createZeroOperands(), ParallelOp::getCreateZeroKeyword(), |
| printer); |
| |
| // no_create()? |
| printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(), |
| printer); |
| |
| // present()? |
| printOperandList(op.presentOperands(), ParallelOp::getPresentKeyword(), |
| printer); |
| |
| // deviceptr()? |
| printOperandList(op.devicePtrOperands(), ParallelOp::getDevicePtrKeyword(), |
| printer); |
| |
| // attach()? |
| printOperandList(op.attachOperands(), ParallelOp::getAttachKeyword(), |
| printer); |
| |
| // private()? |
| printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(), |
| printer); |
| |
| // firstprivate()? |
| printOperandList(op.gangFirstPrivateOperands(), |
| ParallelOp::getFirstPrivateKeyword(), printer); |
| |
| printer.printRegion(op.region(), |
| /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/true); |
| printer.printOptionalAttrDictWithKeyword( |
| op->getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); |
| } |
| |
| unsigned ParallelOp::getNumDataOperands() { |
| return reductionOperands().size() + copyOperands().size() + |
| copyinOperands().size() + copyinReadonlyOperands().size() + |
| copyoutOperands().size() + copyoutZeroOperands().size() + |
| createOperands().size() + createZeroOperands().size() + |
| noCreateOperands().size() + presentOperands().size() + |
| devicePtrOperands().size() + attachOperands().size() + |
| gangPrivateOperands().size() + gangFirstPrivateOperands().size(); |
| } |
| |
| Value ParallelOp::getDataOperand(unsigned i) { |
| unsigned numOptional = async() ? 1 : 0; |
| numOptional += numGangs() ? 1 : 0; |
| numOptional += numWorkers() ? 1 : 0; |
| numOptional += vectorLength() ? 1 : 0; |
| numOptional += ifCond() ? 1 : 0; |
| numOptional += selfCond() ? 1 : 0; |
| return getOperand(waitOperands().size() + numOptional + i); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LoopOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Parse acc.loop operation |
| /// operation := `acc.loop` |
| /// (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )? |
| /// (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )? |
| /// (`vector_length` `(` value `)`)? |
| /// (`tile` `(` value-list `)`)? |
| /// (`private` `(` value-list `)`)? |
| /// (`reduction` `(` value-list `)`)? |
| /// region attr-dict? |
| static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) { |
| Builder &builder = parser.getBuilder(); |
| unsigned executionMapping = OpenACCExecMapping::NONE; |
| SmallVector<Type, 8> operandTypes; |
| SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands; |
| SmallVector<OpAsmParser::OperandType, 8> tileOperands; |
| OptionalParseResult gangNum, gangStatic, worker, vector; |
| |
| // gang? |
| if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword()))) |
| executionMapping |= OpenACCExecMapping::GANG; |
| |
| // optional gang operand |
| if (succeeded(parser.parseOptionalLParen())) { |
| gangNum = parserOptionalOperandAndTypeWithPrefix( |
| parser, result, LoopOp::getGangNumKeyword()); |
| if (gangNum.hasValue() && failed(*gangNum)) |
| return failure(); |
| parser.parseOptionalComma(); |
| gangStatic = parserOptionalOperandAndTypeWithPrefix( |
| parser, result, LoopOp::getGangStaticKeyword()); |
| if (gangStatic.hasValue() && failed(*gangStatic)) |
| return failure(); |
| parser.parseOptionalComma(); |
| if (failed(parser.parseRParen())) |
| return failure(); |
| } |
| |
| // worker? |
| if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword()))) |
| executionMapping |= OpenACCExecMapping::WORKER; |
| |
| // optional worker operand |
| worker = parseOptionalOperandAndType(parser, result); |
| if (worker.hasValue() && failed(*worker)) |
| return failure(); |
| |
| // vector? |
| if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword()))) |
| executionMapping |= OpenACCExecMapping::VECTOR; |
| |
| // optional vector operand |
| vector = parseOptionalOperandAndType(parser, result); |
| if (vector.hasValue() && failed(*vector)) |
| return failure(); |
| |
| // tile()? |
| if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands, |
| operandTypes, result))) |
| return failure(); |
| |
| // private()? |
| if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(), |
| privateOperands, operandTypes, result))) |
| return failure(); |
| |
| // reduction()? |
| if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(), |
| reductionOperands, operandTypes, result))) |
| return failure(); |
| |
| if (executionMapping != acc::OpenACCExecMapping::NONE) |
| result.addAttribute(LoopOp::getExecutionMappingAttrName(), |
| builder.getI64IntegerAttr(executionMapping)); |
| |
| // Parse optional results in case there is a reduce. |
| if (parser.parseOptionalArrowTypeList(result.types)) |
| return failure(); |
| |
| if (failed(parseRegions<LoopOp>(parser, result))) |
| return failure(); |
| |
| result.addAttribute(LoopOp::getOperandSegmentSizeAttr(), |
| builder.getI32VectorAttr( |
| {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0), |
| static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0), |
| static_cast<int32_t>(worker.hasValue() ? 1 : 0), |
| static_cast<int32_t>(vector.hasValue() ? 1 : 0), |
| static_cast<int32_t>(tileOperands.size()), |
| static_cast<int32_t>(privateOperands.size()), |
| static_cast<int32_t>(reductionOperands.size())})); |
| |
| if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| static void print(OpAsmPrinter &printer, LoopOp &op) { |
| unsigned execMapping = op.exec_mapping(); |
| if (execMapping & OpenACCExecMapping::GANG) { |
| printer << " " << LoopOp::getGangKeyword(); |
| Value gangNum = op.gangNum(); |
| Value gangStatic = op.gangStatic(); |
| |
| // Print optional gang operands |
| if (gangNum || gangStatic) { |
| printer << "("; |
| if (gangNum) { |
| printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": " |
| << gangNum.getType(); |
| if (gangStatic) |
| printer << ", "; |
| } |
| if (gangStatic) |
| printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": " |
| << gangStatic.getType(); |
| printer << ")"; |
| } |
| } |
| |
| if (execMapping & OpenACCExecMapping::WORKER) { |
| printer << " " << LoopOp::getWorkerKeyword(); |
| |
| // Print optional worker operand if present |
| if (Value workerNum = op.workerNum()) |
| printer << "(" << workerNum << ": " << workerNum.getType() << ")"; |
| } |
| |
| if (execMapping & OpenACCExecMapping::VECTOR) { |
| printer << " " << LoopOp::getVectorKeyword(); |
| |
| // Print optional vector operand if present |
| if (Value vectorLength = op.vectorLength()) |
| printer << "(" << vectorLength << ": " << vectorLength.getType() << ")"; |
| } |
| |
| // tile()? |
| printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer); |
| |
| // private()? |
| printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer); |
| |
| // reduction()? |
| printOperandList(op.reductionOperands(), LoopOp::getReductionKeyword(), |
| printer); |
| |
| if (op.getNumResults() > 0) |
| printer << " -> (" << op.getResultTypes() << ")"; |
| |
| printer.printRegion(op.region(), |
| /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/true); |
| |
| printer.printOptionalAttrDictWithKeyword( |
| op->getAttrs(), {LoopOp::getExecutionMappingAttrName(), |
| LoopOp::getOperandSegmentSizeAttr()}); |
| } |
| |
| static LogicalResult verifyLoopOp(acc::LoopOp loopOp) { |
| // auto, independent and seq attribute are mutually exclusive. |
| if ((loopOp.auto_() && (loopOp.independent() || loopOp.seq())) || |
| (loopOp.independent() && loopOp.seq())) { |
| loopOp.emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " + |
| acc::LoopOp::getIndependentAttrName() + ", " + |
| acc::LoopOp::getSeqAttrName() + |
| " can be present at the same time"); |
| return failure(); |
| } |
| |
| // Gang, worker and vector are incompatible with seq. |
| if (loopOp.seq() && loopOp.exec_mapping() != OpenACCExecMapping::NONE) { |
| loopOp.emitError("gang, worker or vector cannot appear with the seq attr"); |
| return failure(); |
| } |
| |
| // Check non-empty body(). |
| if (loopOp.region().empty()) { |
| loopOp.emitError("expected non-empty body."); |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DataOp |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(acc::DataOp dataOp) { |
| // 2.6.5. Data Construct restriction |
| // At least one copy, copyin, copyout, create, no_create, present, deviceptr, |
| // attach, or default clause must appear on a data construct. |
| if (dataOp.getOperands().empty() && !dataOp.defaultAttr()) |
| return dataOp.emitError("at least one operand or the default attribute " |
| "must appear on the data operation"); |
| return success(); |
| } |
| |
| unsigned DataOp::getNumDataOperands() { |
| return copyOperands().size() + copyinOperands().size() + |
| copyinReadonlyOperands().size() + copyoutOperands().size() + |
| copyoutZeroOperands().size() + createOperands().size() + |
| createZeroOperands().size() + noCreateOperands().size() + |
| presentOperands().size() + deviceptrOperands().size() + |
| attachOperands().size(); |
| } |
| |
| Value DataOp::getDataOperand(unsigned i) { |
| unsigned numOptional = ifCond() ? 1 : 0; |
| return getOperand(numOptional + i); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ExitDataOp |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(acc::ExitDataOp op) { |
| // 2.6.6. Data Exit Directive restriction |
| // At least one copyout, delete, or detach clause must appear on an exit data |
| // directive. |
| if (op.copyoutOperands().empty() && op.deleteOperands().empty() && |
| op.detachOperands().empty()) |
| return op.emitError( |
| "at least one operand in copyout, delete or detach must appear on the " |
| "exit data operation"); |
| |
| // The async attribute represent the async clause without value. Therefore the |
| // attribute and operand cannot appear at the same time. |
| if (op.asyncOperand() && op.async()) |
| return op.emitError("async attribute cannot appear with asyncOperand"); |
| |
| // The wait attribute represent the wait clause without values. Therefore the |
| // attribute and operands cannot appear at the same time. |
| if (!op.waitOperands().empty() && op.wait()) |
| return op.emitError("wait attribute cannot appear with waitOperands"); |
| |
| if (op.waitDevnum() && op.waitOperands().empty()) |
| return op.emitError("wait_devnum cannot appear without waitOperands"); |
| |
| return success(); |
| } |
| |
| unsigned ExitDataOp::getNumDataOperands() { |
| return copyoutOperands().size() + deleteOperands().size() + |
| detachOperands().size(); |
| } |
| |
| Value ExitDataOp::getDataOperand(unsigned i) { |
| unsigned numOptional = ifCond() ? 1 : 0; |
| numOptional += asyncOperand() ? 1 : 0; |
| numOptional += waitDevnum() ? 1 : 0; |
| return getOperand(waitOperands().size() + numOptional + i); |
| } |
| |
| void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<RemoveConstantIfCondition<ExitDataOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // EnterDataOp |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(acc::EnterDataOp op) { |
| // 2.6.6. Data Enter Directive restriction |
| // At least one copyin, create, or attach clause must appear on an enter data |
| // directive. |
| if (op.copyinOperands().empty() && op.createOperands().empty() && |
| op.createZeroOperands().empty() && op.attachOperands().empty()) |
| return op.emitError( |
| "at least one operand in copyin, create, " |
| "create_zero or attach must appear on the enter data operation"); |
| |
| // The async attribute represent the async clause without value. Therefore the |
| // attribute and operand cannot appear at the same time. |
| if (op.asyncOperand() && op.async()) |
| return op.emitError("async attribute cannot appear with asyncOperand"); |
| |
| // The wait attribute represent the wait clause without values. Therefore the |
| // attribute and operands cannot appear at the same time. |
| if (!op.waitOperands().empty() && op.wait()) |
| return op.emitError("wait attribute cannot appear with waitOperands"); |
| |
| if (op.waitDevnum() && op.waitOperands().empty()) |
| return op.emitError("wait_devnum cannot appear without waitOperands"); |
| |
| return success(); |
| } |
| |
| unsigned EnterDataOp::getNumDataOperands() { |
| return copyinOperands().size() + createOperands().size() + |
| createZeroOperands().size() + attachOperands().size(); |
| } |
| |
| Value EnterDataOp::getDataOperand(unsigned i) { |
| unsigned numOptional = ifCond() ? 1 : 0; |
| numOptional += asyncOperand() ? 1 : 0; |
| numOptional += waitDevnum() ? 1 : 0; |
| return getOperand(waitOperands().size() + numOptional + i); |
| } |
| |
| void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<RemoveConstantIfCondition<EnterDataOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // InitOp |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(acc::InitOp initOp) { |
| Operation *currOp = initOp; |
| while ((currOp = currOp->getParentOp())) { |
| if (isComputeOperation(currOp)) |
| return initOp.emitOpError("cannot be nested in a compute operation"); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ShutdownOp |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(acc::ShutdownOp op) { |
| Operation *currOp = op; |
| while ((currOp = currOp->getParentOp())) { |
| if (isComputeOperation(currOp)) |
| return op.emitOpError("cannot be nested in a compute operation"); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // UpdateOp |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(acc::UpdateOp updateOp) { |
| // At least one of host or device should have a value. |
| if (updateOp.hostOperands().empty() && updateOp.deviceOperands().empty()) |
| return updateOp.emitError("at least one value must be present in" |
| " hostOperands or deviceOperands"); |
| |
| // The async attribute represent the async clause without value. Therefore the |
| // attribute and operand cannot appear at the same time. |
| if (updateOp.asyncOperand() && updateOp.async()) |
| return updateOp.emitError("async attribute cannot appear with " |
| " asyncOperand"); |
| |
| // The wait attribute represent the wait clause without values. Therefore the |
| // attribute and operands cannot appear at the same time. |
| if (!updateOp.waitOperands().empty() && updateOp.wait()) |
| return updateOp.emitError("wait attribute cannot appear with waitOperands"); |
| |
| if (updateOp.waitDevnum() && updateOp.waitOperands().empty()) |
| return updateOp.emitError("wait_devnum cannot appear without waitOperands"); |
| |
| return success(); |
| } |
| |
| unsigned UpdateOp::getNumDataOperands() { |
| return hostOperands().size() + deviceOperands().size(); |
| } |
| |
| Value UpdateOp::getDataOperand(unsigned i) { |
| unsigned numOptional = asyncOperand() ? 1 : 0; |
| numOptional += waitDevnum() ? 1 : 0; |
| numOptional += ifCond() ? 1 : 0; |
| return getOperand(waitOperands().size() + deviceTypeOperands().size() + |
| numOptional + i); |
| } |
| |
| void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<RemoveConstantIfCondition<UpdateOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // WaitOp |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(acc::WaitOp waitOp) { |
| // The async attribute represent the async clause without value. Therefore the |
| // attribute and operand cannot appear at the same time. |
| if (waitOp.asyncOperand() && waitOp.async()) |
| return waitOp.emitError("async attribute cannot appear with asyncOperand"); |
| |
| if (waitOp.waitDevnum() && waitOp.waitOperands().empty()) |
| return waitOp.emitError("wait_devnum cannot appear without waitOperands"); |
| |
| return success(); |
| } |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" |