| //===- TestOpsSyntax.cpp - Operations for testing syntax ------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "TestOpsSyntax.h" |
| #include "TestDialect.h" |
| #include "TestOps.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "llvm/Support/Base64.h" |
| |
| using namespace mlir; |
| using namespace test; |
| |
| //===----------------------------------------------------------------------===// |
| // Test Format* operations |
| //===----------------------------------------------------------------------===// |
| |
| //===----------------------------------------------------------------------===// |
| // Parsing |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult parseCustomOptionalOperand( |
| OpAsmParser &parser, |
| std::optional<OpAsmParser::UnresolvedOperand> &optOperand) { |
| if (succeeded(parser.parseOptionalLParen())) { |
| optOperand.emplace(); |
| if (parser.parseOperand(*optOperand) || parser.parseRParen()) |
| return failure(); |
| } |
| return success(); |
| } |
| |
| static ParseResult parseCustomDirectiveOperands( |
| OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, |
| std::optional<OpAsmParser::UnresolvedOperand> &optOperand, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { |
| if (parser.parseOperand(operand)) |
| return failure(); |
| if (succeeded(parser.parseOptionalComma())) { |
| optOperand.emplace(); |
| if (parser.parseOperand(*optOperand)) |
| return failure(); |
| } |
| if (parser.parseArrow() || parser.parseLParen() || |
| parser.parseOperandList(varOperands) || parser.parseRParen()) |
| return failure(); |
| return success(); |
| } |
| static ParseResult |
| parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, |
| Type &optOperandType, |
| SmallVectorImpl<Type> &varOperandTypes) { |
| if (parser.parseColon()) |
| return failure(); |
| |
| if (parser.parseType(operandType)) |
| return failure(); |
| if (succeeded(parser.parseOptionalComma())) { |
| if (parser.parseType(optOperandType)) |
| return failure(); |
| } |
| if (parser.parseArrow() || parser.parseLParen() || |
| parser.parseTypeList(varOperandTypes) || parser.parseRParen()) |
| return failure(); |
| return success(); |
| } |
| static ParseResult |
| parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, |
| Type optOperandType, |
| const SmallVectorImpl<Type> &varOperandTypes) { |
| if (parser.parseKeyword("type_refs_capture")) |
| return failure(); |
| |
| Type operandType2, optOperandType2; |
| SmallVector<Type, 1> varOperandTypes2; |
| if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, |
| varOperandTypes2)) |
| return failure(); |
| |
| if (operandType != operandType2 || optOperandType != optOperandType2 || |
| varOperandTypes != varOperandTypes2) |
| return failure(); |
| |
| return success(); |
| } |
| static ParseResult parseCustomDirectiveOperandsAndTypes( |
| OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, |
| std::optional<OpAsmParser::UnresolvedOperand> &optOperand, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, |
| Type &operandType, Type &optOperandType, |
| SmallVectorImpl<Type> &varOperandTypes) { |
| if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || |
| parseCustomDirectiveResults(parser, operandType, optOperandType, |
| varOperandTypes)) |
| return failure(); |
| return success(); |
| } |
| static ParseResult parseCustomDirectiveRegions( |
| OpAsmParser &parser, Region ®ion, |
| SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { |
| if (parser.parseRegion(region)) |
| return failure(); |
| if (failed(parser.parseOptionalComma())) |
| return success(); |
| std::unique_ptr<Region> varRegion = std::make_unique<Region>(); |
| if (parser.parseRegion(*varRegion)) |
| return failure(); |
| varRegions.emplace_back(std::move(varRegion)); |
| return success(); |
| } |
| static ParseResult |
| parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, |
| SmallVectorImpl<Block *> &varSuccessors) { |
| if (parser.parseSuccessor(successor)) |
| return failure(); |
| if (failed(parser.parseOptionalComma())) |
| return success(); |
| Block *varSuccessor; |
| if (parser.parseSuccessor(varSuccessor)) |
| return failure(); |
| varSuccessors.append(2, varSuccessor); |
| return success(); |
| } |
| static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, |
| IntegerAttr &attr, |
| IntegerAttr &optAttr) { |
| if (parser.parseAttribute(attr)) |
| return failure(); |
| if (succeeded(parser.parseOptionalComma())) { |
| if (parser.parseAttribute(optAttr)) |
| return failure(); |
| } |
| return success(); |
| } |
| static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser, |
| mlir::StringAttr &attr) { |
| return parser.parseAttribute(attr); |
| } |
| static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, |
| NamedAttrList &attrs) { |
| return parser.parseOptionalAttrDict(attrs); |
| } |
| static ParseResult parseCustomDirectiveOptionalOperandRef( |
| OpAsmParser &parser, |
| std::optional<OpAsmParser::UnresolvedOperand> &optOperand) { |
| int64_t operandCount = 0; |
| if (parser.parseInteger(operandCount)) |
| return failure(); |
| bool expectedOptionalOperand = operandCount == 0; |
| return success(expectedOptionalOperand != optOperand.has_value()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing |
| //===----------------------------------------------------------------------===// |
| |
| static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, |
| Value optOperand) { |
| if (optOperand) |
| printer << "(" << optOperand << ") "; |
| } |
| |
| static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, |
| Value operand, Value optOperand, |
| OperandRange varOperands) { |
| printer << operand; |
| if (optOperand) |
| printer << ", " << optOperand; |
| printer << " -> (" << varOperands << ")"; |
| } |
| static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, |
| Type operandType, Type optOperandType, |
| TypeRange varOperandTypes) { |
| printer << " : " << operandType; |
| if (optOperandType) |
| printer << ", " << optOperandType; |
| printer << " -> (" << varOperandTypes << ")"; |
| } |
| static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, |
| Operation *op, Type operandType, |
| Type optOperandType, |
| TypeRange varOperandTypes) { |
| printer << " type_refs_capture "; |
| printCustomDirectiveResults(printer, op, operandType, optOperandType, |
| varOperandTypes); |
| } |
| static void printCustomDirectiveOperandsAndTypes( |
| OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, |
| OperandRange varOperands, Type operandType, Type optOperandType, |
| TypeRange varOperandTypes) { |
| printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); |
| printCustomDirectiveResults(printer, op, operandType, optOperandType, |
| varOperandTypes); |
| } |
| static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, |
| Region ®ion, |
| MutableArrayRef<Region> varRegions) { |
| printer.printRegion(region); |
| if (!varRegions.empty()) { |
| printer << ", "; |
| for (Region ®ion : varRegions) |
| printer.printRegion(region); |
| } |
| } |
| static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, |
| Block *successor, |
| SuccessorRange varSuccessors) { |
| printer << successor; |
| if (!varSuccessors.empty()) |
| printer << ", " << varSuccessors.front(); |
| } |
| static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, |
| Attribute attribute, |
| Attribute optAttribute) { |
| printer << attribute; |
| if (optAttribute) |
| printer << ", " << optAttribute; |
| } |
| static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op, |
| Attribute attribute) { |
| printer << attribute; |
| } |
| static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, |
| DictionaryAttr attrs) { |
| printer.printOptionalAttrDict(attrs.getValue()); |
| } |
| |
| static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, |
| Operation *op, |
| Value optOperand) { |
| printer << (optOperand ? "1" : "0"); |
| } |
| //===----------------------------------------------------------------------===// |
| // Test parser. |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| if (parser.parseOptionalColon()) |
| return success(); |
| uint64_t numResults; |
| if (parser.parseInteger(numResults)) |
| return failure(); |
| |
| IndexType type = parser.getBuilder().getIndexType(); |
| for (unsigned i = 0; i < numResults; ++i) |
| result.addTypes(type); |
| return success(); |
| } |
| |
| void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { |
| if (unsigned numResults = getNumResults()) |
| p << " : " << numResults; |
| } |
| |
| ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| StringRef keyword; |
| if (parser.parseKeyword(&keyword)) |
| return failure(); |
| result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); |
| return success(); |
| } |
| |
| void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } |
| |
| ParseResult ParseB64BytesOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| std::vector<char> bytes; |
| if (parser.parseBase64Bytes(&bytes)) |
| return failure(); |
| result.addAttribute("b64", parser.getBuilder().getStringAttr( |
| StringRef(&bytes.front(), bytes.size()))); |
| return success(); |
| } |
| |
| void ParseB64BytesOp::print(OpAsmPrinter &p) { |
| p << " \"" << llvm::encodeBase64(getB64()) << "\""; |
| } |
| |
| ::llvm::LogicalResult FormatInferType2Op::inferReturnTypes( |
| ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, |
| ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, |
| OpaqueProperties properties, ::mlir::RegionRange regions, |
| ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { |
| inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); |
| return ::mlir::success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
| //===----------------------------------------------------------------------===// |
| |
| ParseResult WrappingRegionOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| if (parser.parseKeyword("wraps")) |
| return failure(); |
| |
| // Parse the wrapped op in a region |
| Region &body = *result.addRegion(); |
| body.push_back(new Block); |
| Block &block = body.back(); |
| Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); |
| if (!wrappedOp) |
| return failure(); |
| |
| // Create a return terminator in the inner region, pass as operand to the |
| // terminator the returned values from the wrapped operation. |
| SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); |
| OpBuilder builder(parser.getContext()); |
| builder.setInsertionPointToEnd(&block); |
| builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); |
| |
| // Get the results type for the wrapping op from the terminator operands. |
| Operation &returnOp = body.back().back(); |
| result.types.append(returnOp.operand_type_begin(), |
| returnOp.operand_type_end()); |
| |
| // Use the location of the wrapped op for the "test.wrapping_region" op. |
| result.location = wrappedOp->getLoc(); |
| |
| return success(); |
| } |
| |
| void WrappingRegionOp::print(OpAsmPrinter &p) { |
| p << " wraps "; |
| p.printGenericOp(&getRegion().front().front()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Test PrettyPrintedRegionOp - exercising the following parser APIs |
| // parseGenericOperationAfterOpName |
| // parseCustomOperationName |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| |
| SMLoc loc = parser.getCurrentLocation(); |
| Location currLocation = parser.getEncodedSourceLoc(loc); |
| |
| // Parse the operands. |
| SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; |
| if (parser.parseOperandList(operands)) |
| return failure(); |
| |
| // Check if we are parsing the pretty-printed version |
| // test.pretty_printed_region start <inner-op> end : <functional-type> |
| // Else fallback to parsing the "non pretty-printed" version. |
| if (!succeeded(parser.parseOptionalKeyword("start"))) |
| return parser.parseGenericOperationAfterOpName(result, |
| llvm::ArrayRef(operands)); |
| |
| FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); |
| if (failed(parseOpNameInfo)) |
| return failure(); |
| |
| StringAttr innerOpName = parseOpNameInfo->getIdentifier(); |
| |
| FunctionType opFntype; |
| std::optional<Location> explicitLoc; |
| if (parser.parseKeyword("end") || parser.parseColon() || |
| parser.parseType(opFntype) || |
| parser.parseOptionalLocationSpecifier(explicitLoc)) |
| return failure(); |
| |
| // If location of the op is explicitly provided, then use it; Else use |
| // the parser's current location. |
| Location opLoc = explicitLoc.value_or(currLocation); |
| |
| // Derive the SSA-values for op's operands. |
| if (parser.resolveOperands(operands, opFntype.getInputs(), loc, |
| result.operands)) |
| return failure(); |
| |
| // Add a region for op. |
| Region ®ion = *result.addRegion(); |
| |
| // Create a basic-block inside op's region. |
| Block &block = region.emplaceBlock(); |
| |
| // Create and insert an "inner-op" operation in the block. |
| // Just for testing purposes, we can assume that inner op is a binary op with |
| // result and operand types all same as the test-op's first operand. |
| Type innerOpType = opFntype.getInput(0); |
| Value lhs = block.addArgument(innerOpType, opLoc); |
| Value rhs = block.addArgument(innerOpType, opLoc); |
| |
| OpBuilder builder(parser.getBuilder().getContext()); |
| builder.setInsertionPointToStart(&block); |
| |
| Operation *innerOp = |
| builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); |
| |
| // Insert a return statement in the block returning the inner-op's result. |
| builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); |
| |
| // Populate the op operation-state with result-type and location. |
| result.addTypes(opFntype.getResults()); |
| result.location = innerOp->getLoc(); |
| |
| return success(); |
| } |
| |
| void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { |
| p << ' '; |
| p.printOperands(getOperands()); |
| |
| Operation &innerOp = getRegion().front().front(); |
| // Assuming that region has a single non-terminator inner-op, if the inner-op |
| // meets some criteria (which in this case is a simple one based on the name |
| // of inner-op), then we can print the entire region in a succinct way. |
| // Here we assume that the prototype of "test.special.op" can be trivially |
| // derived while parsing it back. |
| if (innerOp.getName().getStringRef() == "test.special.op") { |
| p << " start test.special.op end"; |
| } else { |
| p << " ("; |
| p.printRegion(getRegion()); |
| p << ")"; |
| } |
| |
| p << " : "; |
| p.printFunctionalType(*this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Test PolyForOp - parse list of region arguments. |
| //===----------------------------------------------------------------------===// |
| |
| ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::Argument, 4> ivsInfo; |
| // Parse list of region arguments without a delimiter. |
| if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None)) |
| return failure(); |
| |
| // Parse the body region. |
| Region *body = result.addRegion(); |
| for (auto &iv : ivsInfo) |
| iv.type = parser.getBuilder().getIndexType(); |
| return parser.parseRegion(*body, ivsInfo); |
| } |
| |
| void PolyForOp::print(OpAsmPrinter &p) { |
| p << " "; |
| llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) { |
| p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true); |
| }); |
| p << " "; |
| p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
| } |
| |
| void PolyForOp::getAsmBlockArgumentNames(Region ®ion, |
| OpAsmSetValueNameFn setNameFn) { |
| auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); |
| if (!arrayAttr) |
| return; |
| auto args = getRegion().front().getArguments(); |
| auto e = std::min(arrayAttr.size(), args.size()); |
| for (unsigned i = 0; i < e; ++i) { |
| if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i])) |
| setNameFn(args[i], strAttr.getValue()); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TestAttrWithLoc - parse/printOptionalLocationSpecifier |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) { |
| std::optional<Location> result; |
| SMLoc sourceLoc = p.getCurrentLocation(); |
| if (p.parseOptionalLocationSpecifier(result)) |
| return failure(); |
| if (result) |
| loc = *result; |
| else |
| loc = p.getEncodedSourceLoc(sourceLoc); |
| return success(); |
| } |
| |
| static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) { |
| p.printOptionalLocationSpecifier(cast<LocationAttr>(loc)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ParseCustomOperationNameAPI |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult parseCustomOperationNameEntry(OpAsmParser &p, |
| Attribute &name) { |
| FailureOr<OperationName> opName = p.parseCustomOperationName(); |
| if (failed(opName)) |
| return ParseResult::failure(); |
| |
| name = p.getBuilder().getStringAttr(opName->getStringRef()); |
| return ParseResult::success(); |
| } |
| |
| static void printCustomOperationNameEntry(OpAsmPrinter &p, Operation *op, |
| Attribute name) { |
| p << cast<StringAttr>(name).getValue(); |
| } |
| |
| #define GET_OP_CLASSES |
| #include "TestOpsSyntax.cpp.inc" |
| |
/// Registers the operations listed in the generated "TestOpsSyntax.cpp.inc"
/// (the GET_OP_LIST section) with the test dialect.
void TestDialect::registerOpsSyntax() {
  // The template argument list is expanded from the generated include below.
  addOperations<
#define GET_OP_LIST
#include "TestOpsSyntax.cpp.inc"
      >();
}