| //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines the types and operation details for the LLVM IR dialect in |
| // MLIR, and the LLVM IR dialect. It also registers the dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "TypeDetail.h" |
| #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
| #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/Interfaces/FunctionImplementation.h" |
| #include "mlir/Transforms/InliningUtils.h" |
| |
| #include "llvm/ADT/SCCIterator.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/AsmParser/Parser.h" |
| #include "llvm/Bitcode/BitcodeReader.h" |
| #include "llvm/Bitcode/BitcodeWriter.h" |
| #include "llvm/IR/Attributes.h" |
| #include "llvm/IR/Function.h" |
| #include "llvm/IR/Type.h" |
| #include "llvm/Support/Error.h" |
| #include "llvm/Support/Mutex.h" |
| #include "llvm/Support/SourceMgr.h" |
| |
| #include <numeric> |
| #include <optional> |
| |
| using namespace mlir; |
| using namespace mlir::LLVM; |
| using mlir::LLVM::cconv::getMaxEnumValForCConv; |
| using mlir::LLVM::linkage::getMaxEnumValForLinkage; |
| using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind; |
| |
| #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" |
| |
| //===----------------------------------------------------------------------===// |
| // Attribute Helpers |
| //===----------------------------------------------------------------------===// |
| |
/// Name of the attribute holding the element type of `llvm.alloca` and
/// `llvm.getelementptr`. The alloca printer elides it from the attribute
/// dictionary because its custom syntax prints the element type inline.
static constexpr const char kElemTypeAttrName[] = "elem_type";
| |
| static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { |
| SmallVector<NamedAttribute, 8> filteredAttrs( |
| llvm::make_filter_range(attrs, [&](NamedAttribute attr) { |
| if (attr.getName() == "fastmathFlags") { |
| auto defAttr = |
| FastmathFlagsAttr::get(attr.getValue().getContext(), {}); |
| return defAttr != attr.getValue(); |
| } |
| return true; |
| })); |
| return filteredAttrs; |
| } |
| |
| /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and |
| /// fully defined llvm.func. |
| static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, |
| Operation *op, |
| SymbolTableCollection &symbolTable) { |
| StringRef name = symbol.getValue(); |
| auto func = |
| symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr()); |
| if (!func) |
| return op->emitOpError("'") |
| << name << "' does not reference a valid LLVM function"; |
| if (func.isExternal()) |
| return op->emitOpError("'") << name << "' does not have a definition"; |
| return success(); |
| } |
| |
| /// Returns a boolean type that has the same shape as `type`. It supports both |
| /// fixed size vectors as well as scalable vectors. |
| static Type getI1SameShape(Type type) { |
| Type i1Type = IntegerType::get(type.getContext(), 1); |
| if (LLVM::isCompatibleVectorType(type)) |
| return LLVM::getVectorType(i1Type, LLVM::getVectorNumElements(type)); |
| return i1Type; |
| } |
| |
| // Parses one of the keywords provided in the list `keywords` and returns the |
| // position of the parsed keyword in the list. If none of the keywords from the |
| // list is parsed, returns -1. |
| static int parseOptionalKeywordAlternative(OpAsmParser &parser, |
| ArrayRef<StringRef> keywords) { |
| for (const auto &en : llvm::enumerate(keywords)) { |
| if (succeeded(parser.parseOptionalKeyword(en.value()))) |
| return en.index(); |
| } |
| return -1; |
| } |
| |
namespace {
/// Compile-time traits that map an LLVM dialect enum type to its keyword
/// stringifier and its maximum enumerator value. The primary template is
/// empty; only specializations created with REGISTER_ENUM_TYPE are usable
/// (parseOptionalLLVMKeyword relies on them to enumerate keywords).
template <typename Ty>
struct EnumTraits {};

/// Defines the EnumTraits specialization for enum `Ty` by forwarding to the
/// free functions `stringify<Ty>` and `getMaxEnumValFor<Ty>`. For CConv,
/// Linkage and TailCallKind the latter are brought into scope by the `using`
/// declarations at the top of this file.
#define REGISTER_ENUM_TYPE(Ty)                                                 \
  template <>                                                                  \
  struct EnumTraits<Ty> {                                                      \
    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
  }

REGISTER_ENUM_TYPE(Linkage);
REGISTER_ENUM_TYPE(UnnamedAddr);
REGISTER_ENUM_TYPE(CConv);
REGISTER_ENUM_TYPE(TailCallKind);
REGISTER_ENUM_TYPE(Visibility);
} // namespace
| |
| /// Parse an enum from the keyword, or default to the provided default value. |
| /// The return type is the enum type by default, unless overridden with the |
| /// second template argument. |
| template <typename EnumTy, typename RetTy = EnumTy> |
| static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, |
| OperationState &result, |
| EnumTy defaultValue) { |
| SmallVector<StringRef, 10> names; |
| for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i) |
| names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i))); |
| |
| int index = parseOptionalKeywordAlternative(parser, names); |
| if (index == -1) |
| return static_cast<RetTy>(defaultValue); |
| return static_cast<RetTy>(index); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Operand bundle helpers. |
| //===----------------------------------------------------------------------===// |
| |
| static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands, |
| TypeRange operandTypes, StringRef tag) { |
| p.printString(tag); |
| p << "("; |
| |
| if (!operands.empty()) { |
| p.printOperands(operands); |
| p << " : "; |
| llvm::interleaveComma(operandTypes, p); |
| } |
| |
| p << ")"; |
| } |
| |
| static void printOpBundles(OpAsmPrinter &p, Operation *op, |
| OperandRangeRange opBundleOperands, |
| TypeRangeRange opBundleOperandTypes, |
| std::optional<ArrayAttr> opBundleTags) { |
| if (opBundleOperands.empty()) |
| return; |
| assert(opBundleTags && "expect operand bundle tags"); |
| |
| p << "["; |
| llvm::interleaveComma( |
| llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p, |
| [&p](auto bundle) { |
| auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue(); |
| printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle), |
| bundleTag); |
| }); |
| p << "]"; |
| } |
| |
| static ParseResult parseOneOpBundle( |
| OpAsmParser &p, |
| SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands, |
| SmallVector<SmallVector<Type>> &opBundleOperandTypes, |
| SmallVector<Attribute> &opBundleTags) { |
| SMLoc currentParserLoc = p.getCurrentLocation(); |
| SmallVector<OpAsmParser::UnresolvedOperand> operands; |
| SmallVector<Type> types; |
| std::string tag; |
| |
| if (p.parseString(&tag)) |
| return p.emitError(currentParserLoc, "expect operand bundle tag"); |
| |
| if (p.parseLParen()) |
| return failure(); |
| |
| if (p.parseOptionalRParen()) { |
| if (p.parseOperandList(operands) || p.parseColon() || |
| p.parseTypeList(types) || p.parseRParen()) |
| return failure(); |
| } |
| |
| opBundleOperands.push_back(std::move(operands)); |
| opBundleOperandTypes.push_back(std::move(types)); |
| opBundleTags.push_back(StringAttr::get(p.getContext(), tag)); |
| |
| return success(); |
| } |
| |
| static std::optional<ParseResult> parseOpBundles( |
| OpAsmParser &p, |
| SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> &opBundleOperands, |
| SmallVector<SmallVector<Type>> &opBundleOperandTypes, |
| ArrayAttr &opBundleTags) { |
| if (p.parseOptionalLSquare()) |
| return std::nullopt; |
| |
| if (succeeded(p.parseOptionalRSquare())) |
| return success(); |
| |
| SmallVector<Attribute> opBundleTagAttrs; |
| auto bundleParser = [&] { |
| return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes, |
| opBundleTagAttrs); |
| }; |
| if (p.parseCommaSeparatedList(bundleParser)) |
| return failure(); |
| |
| if (p.parseRSquare()) |
| return failure(); |
| |
| opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
// Printing, parsing and folding for LLVM::ICmpOp and LLVM::FCmpOp.
| //===----------------------------------------------------------------------===// |
| |
| void ICmpOp::print(OpAsmPrinter &p) { |
| p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0) |
| << ", " << getOperand(1); |
| p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"}); |
| p << " : " << getLhs().getType(); |
| } |
| |
| void FCmpOp::print(OpAsmPrinter &p) { |
| p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0) |
| << ", " << getOperand(1); |
| p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"}); |
| p << " : " << getLhs().getType(); |
| } |
| |
| // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use |
| // attribute-dict? `:` type |
| // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use |
| // attribute-dict? `:` type |
| template <typename CmpPredicateType> |
| static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { |
| StringAttr predicateAttr; |
| OpAsmParser::UnresolvedOperand lhs, rhs; |
| Type type; |
| SMLoc predicateLoc, trailingTypeLoc; |
| if (parser.getCurrentLocation(&predicateLoc) || |
| parser.parseAttribute(predicateAttr, "predicate", result.attributes) || |
| parser.parseOperand(lhs) || parser.parseComma() || |
| parser.parseOperand(rhs) || |
| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) || |
| parser.resolveOperand(lhs, type, result.operands) || |
| parser.resolveOperand(rhs, type, result.operands)) |
| return failure(); |
| |
| // Replace the string attribute `predicate` with an integer attribute. |
| int64_t predicateValue = 0; |
| if (std::is_same<CmpPredicateType, ICmpPredicate>()) { |
| std::optional<ICmpPredicate> predicate = |
| symbolizeICmpPredicate(predicateAttr.getValue()); |
| if (!predicate) |
| return parser.emitError(predicateLoc) |
| << "'" << predicateAttr.getValue() |
| << "' is an incorrect value of the 'predicate' attribute"; |
| predicateValue = static_cast<int64_t>(*predicate); |
| } else { |
| std::optional<FCmpPredicate> predicate = |
| symbolizeFCmpPredicate(predicateAttr.getValue()); |
| if (!predicate) |
| return parser.emitError(predicateLoc) |
| << "'" << predicateAttr.getValue() |
| << "' is an incorrect value of the 'predicate' attribute"; |
| predicateValue = static_cast<int64_t>(*predicate); |
| } |
| |
| result.attributes.set("predicate", |
| parser.getBuilder().getI64IntegerAttr(predicateValue)); |
| |
| // The result type is either i1 or a vector type <? x i1> if the inputs are |
| // vectors. |
| if (!isCompatibleType(type)) |
| return parser.emitError(trailingTypeLoc, |
| "expected LLVM dialect-compatible type"); |
| result.addTypes(getI1SameShape(type)); |
| return success(); |
| } |
| |
| ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) { |
| return parseCmpOp<ICmpPredicate>(parser, result); |
| } |
| |
| ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) { |
| return parseCmpOp<FCmpPredicate>(parser, result); |
| } |
| |
| /// Returns a scalar or vector boolean attribute of the given type. |
| static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) { |
| auto boolAttr = BoolAttr::get(ctx, value); |
| ShapedType shapedType = dyn_cast<ShapedType>(type); |
| if (!shapedType) |
| return boolAttr; |
| return DenseElementsAttr::get(shapedType, boolAttr); |
| } |
| |
| OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) { |
| if (getPredicate() != ICmpPredicate::eq && |
| getPredicate() != ICmpPredicate::ne) |
| return {}; |
| |
| // cmpi(eq/ne, x, x) -> true/false |
| if (getLhs() == getRhs()) |
| return getBoolAttribute(getType(), getContext(), |
| getPredicate() == ICmpPredicate::eq); |
| |
| // cmpi(eq/ne, alloca, null) -> false/true |
| if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>()) |
| return getBoolAttribute(getType(), getContext(), |
| getPredicate() == ICmpPredicate::ne); |
| |
| // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null) |
| if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) { |
| Value lhs = getLhs(); |
| Value rhs = getRhs(); |
| getLhsMutable().assign(rhs); |
| getRhsMutable().assign(lhs); |
| return getResult(); |
| } |
| |
| return {}; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Printing, parsing and verification for LLVM::AllocaOp. |
| //===----------------------------------------------------------------------===// |
| |
| void AllocaOp::print(OpAsmPrinter &p) { |
| auto funcTy = |
| FunctionType::get(getContext(), {getArraySize().getType()}, {getType()}); |
| |
| if (getInalloca()) |
| p << " inalloca"; |
| |
| p << ' ' << getArraySize() << " x " << getElemType(); |
| if (getAlignment() && *getAlignment() != 0) |
| p.printOptionalAttrDict((*this)->getAttrs(), |
| {kElemTypeAttrName, getInallocaAttrName()}); |
| else |
| p.printOptionalAttrDict( |
| (*this)->getAttrs(), |
| {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()}); |
| p << " : " << funcTy; |
| } |
| |
| // <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type |
| // attribute-dict? `:` type `,` type |
| ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) { |
| OpAsmParser::UnresolvedOperand arraySize; |
| Type type, elemType; |
| SMLoc trailingTypeLoc; |
| |
| if (succeeded(parser.parseOptionalKeyword("inalloca"))) |
| result.addAttribute(getInallocaAttrName(result.name), |
| UnitAttr::get(parser.getContext())); |
| |
| if (parser.parseOperand(arraySize) || parser.parseKeyword("x") || |
| parser.parseType(elemType) || |
| parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type)) |
| return failure(); |
| |
| std::optional<NamedAttribute> alignmentAttr = |
| result.attributes.getNamed("alignment"); |
| if (alignmentAttr.has_value()) { |
| auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue()); |
| if (!alignmentInt) |
| return parser.emitError(parser.getNameLoc(), |
| "expected integer alignment"); |
| if (alignmentInt.getValue().isZero()) |
| result.attributes.erase("alignment"); |
| } |
| |
| // Extract the result type from the trailing function type. |
| auto funcType = llvm::dyn_cast<FunctionType>(type); |
| if (!funcType || funcType.getNumInputs() != 1 || |
| funcType.getNumResults() != 1) |
| return parser.emitError( |
| trailingTypeLoc, |
| "expected trailing function type with one argument and one result"); |
| |
| if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands)) |
| return failure(); |
| |
| Type resultType = funcType.getResult(0); |
| if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType)) |
| result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType)); |
| |
| result.addTypes({funcType.getResult(0)}); |
| return success(); |
| } |
| |
| LogicalResult AllocaOp::verify() { |
| // Only certain target extension types can be used in 'alloca'. |
| if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType()); |
| targetExtType && !targetExtType.supportsMemOps()) |
| return emitOpError() |
| << "this target extension type cannot be used in alloca"; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LLVM::BrOp |
| //===----------------------------------------------------------------------===// |
| |
| SuccessorOperands BrOp::getSuccessorOperands(unsigned index) { |
| assert(index == 0 && "invalid successor index"); |
| return SuccessorOperands(getDestOperandsMutable()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LLVM::CondBrOp |
| //===----------------------------------------------------------------------===// |
| |
| SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) { |
| assert(index < getNumSuccessors() && "invalid successor index"); |
| return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable() |
| : getFalseDestOperandsMutable()); |
| } |
| |
| void CondBrOp::build(OpBuilder &builder, OperationState &result, |
| Value condition, Block *trueDest, ValueRange trueOperands, |
| Block *falseDest, ValueRange falseOperands, |
| std::optional<std::pair<uint32_t, uint32_t>> weights) { |
| DenseI32ArrayAttr weightsAttr; |
| if (weights) |
| weightsAttr = |
| builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first), |
| static_cast<int32_t>(weights->second)}); |
| |
| build(builder, result, condition, trueOperands, falseOperands, weightsAttr, |
| /*loop_annotation=*/{}, trueDest, falseDest); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LLVM::SwitchOp |
| //===----------------------------------------------------------------------===// |
| |
| void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
| Block *defaultDestination, ValueRange defaultOperands, |
| DenseIntElementsAttr caseValues, |
| BlockRange caseDestinations, |
| ArrayRef<ValueRange> caseOperands, |
| ArrayRef<int32_t> branchWeights) { |
| DenseI32ArrayAttr weightsAttr; |
| if (!branchWeights.empty()) |
| weightsAttr = builder.getDenseI32ArrayAttr(branchWeights); |
| |
| build(builder, result, value, defaultOperands, caseOperands, caseValues, |
| weightsAttr, defaultDestination, caseDestinations); |
| } |
| |
| void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
| Block *defaultDestination, ValueRange defaultOperands, |
| ArrayRef<APInt> caseValues, BlockRange caseDestinations, |
| ArrayRef<ValueRange> caseOperands, |
| ArrayRef<int32_t> branchWeights) { |
| DenseIntElementsAttr caseValuesAttr; |
| if (!caseValues.empty()) { |
| ShapedType caseValueType = VectorType::get( |
| static_cast<int64_t>(caseValues.size()), value.getType()); |
| caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); |
| } |
| |
| build(builder, result, value, defaultDestination, defaultOperands, |
| caseValuesAttr, caseDestinations, caseOperands, branchWeights); |
| } |
| |
| void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
| Block *defaultDestination, ValueRange defaultOperands, |
| ArrayRef<int32_t> caseValues, BlockRange caseDestinations, |
| ArrayRef<ValueRange> caseOperands, |
| ArrayRef<int32_t> branchWeights) { |
| DenseIntElementsAttr caseValuesAttr; |
| if (!caseValues.empty()) { |
| ShapedType caseValueType = VectorType::get( |
| static_cast<int64_t>(caseValues.size()), value.getType()); |
| caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); |
| } |
| |
| build(builder, result, value, defaultDestination, defaultOperands, |
| caseValuesAttr, caseDestinations, caseOperands, branchWeights); |
| } |
| |
| /// <cases> ::= `[` (case (`,` case )* )? `]` |
| /// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? |
| static ParseResult parseSwitchOpCases( |
| OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues, |
| SmallVectorImpl<Block *> &caseDestinations, |
| SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, |
| SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { |
| if (failed(parser.parseLSquare())) |
| return failure(); |
| if (succeeded(parser.parseOptionalRSquare())) |
| return success(); |
| SmallVector<APInt> values; |
| unsigned bitWidth = flagType.getIntOrFloatBitWidth(); |
| auto parseCase = [&]() { |
| int64_t value = 0; |
| if (failed(parser.parseInteger(value))) |
| return failure(); |
| values.push_back(APInt(bitWidth, value, /*isSigned=*/true)); |
| |
| Block *destination; |
| SmallVector<OpAsmParser::UnresolvedOperand> operands; |
| SmallVector<Type> operandTypes; |
| if (parser.parseColon() || parser.parseSuccessor(destination)) |
| return failure(); |
| if (!parser.parseOptionalLParen()) { |
| if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None, |
| /*allowResultNumber=*/false) || |
| parser.parseColonTypeList(operandTypes) || parser.parseRParen()) |
| return failure(); |
| } |
| caseDestinations.push_back(destination); |
| caseOperands.emplace_back(operands); |
| caseOperandTypes.emplace_back(operandTypes); |
| return success(); |
| }; |
| if (failed(parser.parseCommaSeparatedList(parseCase))) |
| return failure(); |
| |
| ShapedType caseValueType = |
| VectorType::get(static_cast<int64_t>(values.size()), flagType); |
| caseValues = DenseIntElementsAttr::get(caseValueType, values); |
| return parser.parseRSquare(); |
| } |
| |
| static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, |
| DenseIntElementsAttr caseValues, |
| SuccessorRange caseDestinations, |
| OperandRangeRange caseOperands, |
| const TypeRangeRange &caseOperandTypes) { |
| p << '['; |
| p.printNewline(); |
| if (!caseValues) { |
| p << ']'; |
| return; |
| } |
| |
| size_t index = 0; |
| llvm::interleave( |
| llvm::zip(caseValues, caseDestinations), |
| [&](auto i) { |
| p << " "; |
| p << std::get<0>(i); |
| p << ": "; |
| p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); |
| }, |
| [&] { |
| p << ','; |
| p.printNewline(); |
| }); |
| p.printNewline(); |
| p << ']'; |
| } |
| |
| LogicalResult SwitchOp::verify() { |
| if ((!getCaseValues() && !getCaseDestinations().empty()) || |
| (getCaseValues() && |
| getCaseValues()->size() != |
| static_cast<int64_t>(getCaseDestinations().size()))) |
| return emitOpError("expects number of case values to match number of " |
| "case destinations"); |
| if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) |
| return emitError("expects number of branch weights to match number of " |
| "successors: ") |
| << getBranchWeights()->size() << " vs " << getNumSuccessors(); |
| if (getCaseValues() && |
| getValue().getType() != getCaseValues()->getElementType()) |
| return emitError("expects case value type to match condition value type"); |
| return success(); |
| } |
| |
| SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { |
| assert(index < getNumSuccessors() && "invalid successor index"); |
| return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() |
| : getCaseOperandsMutable(index - 1)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Code for LLVM::GEPOp. |
| //===----------------------------------------------------------------------===// |
| |
// Out-of-line definition of the static constexpr member. It is needed only for
// pre-C++17 ODR-uses; since C++17 such members are implicitly inline, making
// this definition redundant (but harmless).
constexpr int32_t GEPOp::kDynamicIndex;
| |
| GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() { |
| return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(), |
| getDynamicIndices()); |
| } |
| |
| /// Returns the elemental type of any LLVM-compatible vector type or self. |
| static Type extractVectorElementType(Type type) { |
| if (auto vectorType = llvm::dyn_cast<VectorType>(type)) |
| return vectorType.getElementType(); |
| return type; |
| } |
| |
/// Splits the GEP `indices` into `rawConstantIndices` and `dynamicIndices`.
/// Every index appends one entry to `rawConstantIndices`: its constant value,
/// or GEPOp::kDynamicIndex when the index stays dynamic. In the dynamic case
/// the SSA value is also appended to `dynamicIndices`.
/// The indexed type is tracked while walking the indices, starting from
/// `currType` (the GEP's element type). A dynamic index into an LLVM struct
/// that is defined by an integer constant fitting in kGEPConstantBitWidth
/// signed bits is turned into a constant index. Indices that cannot be
/// converted are left dynamic, and the verifier reports them. All other
/// canonicalization is done in the fold method.
static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
                               SmallVectorImpl<int32_t> &rawConstantIndices,
                               SmallVectorImpl<Value> &dynamicIndices) {
  for (const GEPArg &iter : indices) {
    // If the thing we are currently indexing into is a struct we must turn
    // any integer constants into constant indices. If this is not possible
    // we don't do anything here. The verifier will catch it and emit a proper
    // error. All other canonicalization is done in the fold method.
    // The emptiness check means the first index (a pointer offset, when
    // `rawConstantIndices` starts empty) is never subject to this rule.
    bool requiresConst = !rawConstantIndices.empty() &&
                         isa_and_nonnull<LLVMStructType>(currType);
    if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
      APInt intC;
      if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
          intC.isSignedIntN(kGEPConstantBitWidth)) {
        rawConstantIndices.push_back(intC.getSExtValue());
      } else {
        rawConstantIndices.push_back(GEPOp::kDynamicIndex);
        dynamicIndices.push_back(val);
      }
    } else {
      rawConstantIndices.push_back(cast<GEPConstantIndex>(iter));
    }

    // Skip for very first iteration of this loop. First index does not index
    // within the aggregates, but is just a pointer offset.
    if (rawConstantIndices.size() == 1 || !currType)
      continue;

    // Step into the type selected by the index just recorded:
    //  * vector/array: the element type;
    //  * struct: the indexed member type, or null when the recorded index is
    //    not a valid member position;
    //  * anything else: null, which stops further tracking.
    currType = TypeSwitch<Type, Type>(currType)
                   .Case<VectorType, LLVMArrayType>([](auto containerType) {
                     return containerType.getElementType();
                   })
                   .Case([&](LLVMStructType structType) -> Type {
                     int64_t memberIndex = rawConstantIndices.back();
                     if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
                                                 structType.getBody().size())
                       return structType.getBody()[memberIndex];
                     return nullptr;
                   })
                   .Default(Type(nullptr));
  }
}
| |
| void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, |
| Type elementType, Value basePtr, ArrayRef<GEPArg> indices, |
| bool inbounds, ArrayRef<NamedAttribute> attributes) { |
| SmallVector<int32_t> rawConstantIndices; |
| SmallVector<Value> dynamicIndices; |
| destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices); |
| |
| result.addTypes(resultType); |
| result.addAttributes(attributes); |
| result.addAttribute(getRawConstantIndicesAttrName(result.name), |
| builder.getDenseI32ArrayAttr(rawConstantIndices)); |
| if (inbounds) { |
| result.addAttribute(getInboundsAttrName(result.name), |
| builder.getUnitAttr()); |
| } |
| result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); |
| result.addOperands(basePtr); |
| result.addOperands(dynamicIndices); |
| } |
| |
| void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, |
| Type elementType, Value basePtr, ValueRange indices, |
| bool inbounds, ArrayRef<NamedAttribute> attributes) { |
| build(builder, result, resultType, elementType, basePtr, |
| SmallVector<GEPArg>(indices), inbounds, attributes); |
| } |
| |
| static ParseResult |
| parseGEPIndices(OpAsmParser &parser, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices, |
| DenseI32ArrayAttr &rawConstantIndices) { |
| SmallVector<int32_t> constantIndices; |
| |
| auto idxParser = [&]() -> ParseResult { |
| int32_t constantIndex; |
| OptionalParseResult parsedInteger = |
| parser.parseOptionalInteger(constantIndex); |
| if (parsedInteger.has_value()) { |
| if (failed(parsedInteger.value())) |
| return failure(); |
| constantIndices.push_back(constantIndex); |
| return success(); |
| } |
| |
| constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); |
| return parser.parseOperand(indices.emplace_back()); |
| }; |
| if (parser.parseCommaSeparatedList(idxParser)) |
| return failure(); |
| |
| rawConstantIndices = |
| DenseI32ArrayAttr::get(parser.getContext(), constantIndices); |
| return success(); |
| } |
| |
| static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, |
| OperandRange indices, |
| DenseI32ArrayAttr rawConstantIndices) { |
| llvm::interleaveComma( |
| GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer, |
| [&](PointerUnion<IntegerAttr, Value> cst) { |
| if (Value val = llvm::dyn_cast_if_present<Value>(cst)) |
| printer.printOperand(val); |
| else |
| printer << cast<IntegerAttr>(cst).getInt(); |
| }); |
| } |
| |
| /// For the given `indices`, check if they comply with `baseGEPType`, |
| /// especially check against LLVMStructTypes nested within. |
| static LogicalResult |
| verifyStructIndices(Type baseGEPType, unsigned indexPos, |
| GEPIndicesAdaptor<ValueRange> indices, |
| function_ref<InFlightDiagnostic()> emitOpError) { |
| if (indexPos >= indices.size()) |
| // Stop searching |
| return success(); |
| |
| return TypeSwitch<Type, LogicalResult>(baseGEPType) |
| .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult { |
| auto attr = dyn_cast<IntegerAttr>(indices[indexPos]); |
| if (!attr) |
| return emitOpError() << "expected index " << indexPos |
| << " indexing a struct to be constant"; |
| |
| int32_t gepIndex = attr.getInt(); |
| ArrayRef<Type> elementTypes = structType.getBody(); |
| if (gepIndex < 0 || |
| static_cast<size_t>(gepIndex) >= elementTypes.size()) |
| return emitOpError() << "index " << indexPos |
| << " indexing a struct is out of bounds"; |
| |
| // Instead of recursively going into every children types, we only |
| // dive into the one indexed by gepIndex. |
| return verifyStructIndices(elementTypes[gepIndex], indexPos + 1, |
| indices, emitOpError); |
| }) |
| .Case<VectorType, LLVMArrayType>( |
| [&](auto containerType) -> LogicalResult { |
| return verifyStructIndices(containerType.getElementType(), |
| indexPos + 1, indices, emitOpError); |
| }) |
| .Default([&](auto otherType) -> LogicalResult { |
| return emitOpError() |
| << "type " << otherType << " cannot be indexed (index #" |
| << indexPos << ")"; |
| }); |
| } |
| |
| /// Driver function around `verifyStructIndices`. |
| static LogicalResult |
| verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices, |
| function_ref<InFlightDiagnostic()> emitOpError) { |
| return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError); |
| } |
| |
| LogicalResult LLVM::GEPOp::verify() { |
| if (static_cast<size_t>( |
| llvm::count(getRawConstantIndices(), kDynamicIndex)) != |
| getDynamicIndices().size()) |
| return emitOpError("expected as many dynamic indices as specified in '") |
| << getRawConstantIndicesAttrName().getValue() << "'"; |
| |
| return verifyStructIndices(getElemType(), getIndices(), |
| [&] { return emitOpError(); }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LoadOp |
| //===----------------------------------------------------------------------===// |
| |
| void LoadOp::getEffects( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| &effects) { |
| effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable()); |
| // Volatile operations can have target-specific read-write effects on |
| // memory besides the one referred to by the pointer operand. |
| // Similarly, atomic operations that are monotonic or stricter cause |
| // synchronization that from a language point-of-view, are arbitrary |
| // read-writes into memory. |
| if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic && |
| getOrdering() != AtomicOrdering::unordered)) { |
| effects.emplace_back(MemoryEffects::Write::get()); |
| effects.emplace_back(MemoryEffects::Read::get()); |
| } |
| } |
| |
| /// Returns true if the given type is supported by atomic operations. All |
| /// integer, float, and pointer types with a power-of-two bitsize and a minimal |
| /// size of 8 bits are supported. |
| static bool isTypeCompatibleWithAtomicOp(Type type, |
| const DataLayout &dataLayout) { |
| if (!isa<IntegerType, LLVMPointerType>(type)) |
| if (!isCompatibleFloatingPointType(type)) |
| return false; |
| |
| llvm::TypeSize bitWidth = dataLayout.getTypeSizeInBits(type); |
| if (bitWidth.isScalable()) |
| return false; |
| // Needs to be at least 8 bits and a power of two. |
| return bitWidth >= 8 && (bitWidth & (bitWidth - 1)) == 0; |
| } |
| |
| /// Verifies the attributes and the type of atomic memory access operations. |
| template <typename OpTy> |
| LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType, |
| ArrayRef<AtomicOrdering> unsupportedOrderings) { |
| if (memOp.getOrdering() != AtomicOrdering::not_atomic) { |
| DataLayout dataLayout = DataLayout::closest(memOp); |
| if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout)) |
| return memOp.emitOpError("unsupported type ") |
| << valueType << " for atomic access"; |
| if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering())) |
| return memOp.emitOpError("unsupported ordering '") |
| << stringifyAtomicOrdering(memOp.getOrdering()) << "'"; |
| if (!memOp.getAlignment()) |
| return memOp.emitOpError("expected alignment for atomic access"); |
| return success(); |
| } |
| if (memOp.getSyncscope()) |
| return memOp.emitOpError( |
| "expected syncscope to be null for non-atomic access"); |
| return success(); |
| } |
| |
| LogicalResult LoadOp::verify() { |
| Type valueType = getResult().getType(); |
| return verifyAtomicMemOp(*this, valueType, |
| {AtomicOrdering::release, AtomicOrdering::acq_rel}); |
| } |
| |
| void LoadOp::build(OpBuilder &builder, OperationState &state, Type type, |
| Value addr, unsigned alignment, bool isVolatile, |
| bool isNonTemporal, bool isInvariant, bool isInvariantGroup, |
| AtomicOrdering ordering, StringRef syncscope) { |
| build(builder, state, type, addr, |
| alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile, |
| isNonTemporal, isInvariant, isInvariantGroup, ordering, |
| syncscope.empty() ? nullptr : builder.getStringAttr(syncscope), |
| /*dereferenceable=*/nullptr, |
| /*access_groups=*/nullptr, |
| /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, |
| /*tbaa=*/nullptr); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // StoreOp |
| //===----------------------------------------------------------------------===// |
| |
| void StoreOp::getEffects( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| &effects) { |
| effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable()); |
| // Volatile operations can have target-specific read-write effects on |
| // memory besides the one referred to by the pointer operand. |
| // Similarly, atomic operations that are monotonic or stricter cause |
| // synchronization that from a language point-of-view, are arbitrary |
| // read-writes into memory. |
| if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic && |
| getOrdering() != AtomicOrdering::unordered)) { |
| effects.emplace_back(MemoryEffects::Write::get()); |
| effects.emplace_back(MemoryEffects::Read::get()); |
| } |
| } |
| |
| LogicalResult StoreOp::verify() { |
| Type valueType = getValue().getType(); |
| return verifyAtomicMemOp(*this, valueType, |
| {AtomicOrdering::acquire, AtomicOrdering::acq_rel}); |
| } |
| |
| void StoreOp::build(OpBuilder &builder, OperationState &state, Value value, |
| Value addr, unsigned alignment, bool isVolatile, |
| bool isNonTemporal, bool isInvariantGroup, |
| AtomicOrdering ordering, StringRef syncscope) { |
| build(builder, state, value, addr, |
| alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile, |
| isNonTemporal, isInvariantGroup, ordering, |
| syncscope.empty() ? nullptr : builder.getStringAttr(syncscope), |
| /*access_groups=*/nullptr, |
| /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CallOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Gets the MLIR Op-like result types of a LLVMFunctionType. |
| static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) { |
| SmallVector<Type, 1> results; |
| Type resultType = calleeType.getReturnType(); |
| if (!isa<LLVM::LLVMVoidType>(resultType)) |
| results.push_back(resultType); |
| return results; |
| } |
| |
| /// Gets the variadic callee type for a LLVMFunctionType. |
| static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) { |
| return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr; |
| } |
| |
| /// Constructs a LLVMFunctionType from MLIR `results` and `args`. |
| static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results, |
| ValueRange args) { |
| Type resultType; |
| if (results.empty()) |
| resultType = LLVMVoidType::get(context); |
| else |
| resultType = results.front(); |
| return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()), |
| /*isVarArg=*/false); |
| } |
| |
| void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, |
| StringRef callee, ValueRange args) { |
| build(builder, state, results, builder.getStringAttr(callee), args); |
| } |
| |
| void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, |
| StringAttr callee, ValueRange args) { |
| build(builder, state, results, SymbolRefAttr::get(callee), args); |
| } |
| |
| void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, |
| FlatSymbolRefAttr callee, ValueRange args) { |
| assert(callee && "expected non-null callee in direct call builder"); |
| build(builder, state, results, |
| /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr, |
| /*branch_weights=*/nullptr, |
| /*CConv=*/nullptr, /*TailCallKind=*/nullptr, |
| /*memory_effects=*/nullptr, |
| /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, |
| /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, |
| /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, |
| /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, |
| /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, |
| /*no_inline=*/nullptr, /*always_inline=*/nullptr, |
| /*inline_hint=*/nullptr); |
| } |
| |
| void CallOp::build(OpBuilder &builder, OperationState &state, |
| LLVMFunctionType calleeType, StringRef callee, |
| ValueRange args) { |
| build(builder, state, calleeType, builder.getStringAttr(callee), args); |
| } |
| |
| void CallOp::build(OpBuilder &builder, OperationState &state, |
| LLVMFunctionType calleeType, StringAttr callee, |
| ValueRange args) { |
| build(builder, state, calleeType, SymbolRefAttr::get(callee), args); |
| } |
| |
| void CallOp::build(OpBuilder &builder, OperationState &state, |
| LLVMFunctionType calleeType, FlatSymbolRefAttr callee, |
| ValueRange args) { |
| build(builder, state, getCallOpResultTypes(calleeType), |
| getCallOpVarCalleeType(calleeType), callee, args, |
| /*fastmathFlags=*/nullptr, |
| /*branch_weights=*/nullptr, /*CConv=*/nullptr, |
| /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, |
| /*convergent=*/nullptr, |
| /*no_unwind=*/nullptr, /*will_return=*/nullptr, |
| /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, |
| /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, |
| /*access_groups=*/nullptr, |
| /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, |
| /*no_inline=*/nullptr, /*always_inline=*/nullptr, |
| /*inline_hint=*/nullptr); |
| } |
| |
| void CallOp::build(OpBuilder &builder, OperationState &state, |
| LLVMFunctionType calleeType, ValueRange args) { |
| build(builder, state, getCallOpResultTypes(calleeType), |
| getCallOpVarCalleeType(calleeType), |
| /*callee=*/nullptr, args, |
| /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, |
| /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, |
| /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, |
| /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, |
| /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, |
| /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, |
| /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, |
| /*no_inline=*/nullptr, /*always_inline=*/nullptr, |
| /*inline_hint=*/nullptr); |
| } |
| |
| void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, |
| ValueRange args) { |
| auto calleeType = func.getFunctionType(); |
| build(builder, state, getCallOpResultTypes(calleeType), |
| getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args, |
| /*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr, |
| /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, |
| /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, |
| /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, |
| /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, |
| /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, |
| /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr, |
| /*no_inline=*/nullptr, /*always_inline=*/nullptr, |
| /*inline_hint=*/nullptr); |
| } |
| |
| CallInterfaceCallable CallOp::getCallableForCallee() { |
| // Direct call. |
| if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) |
| return calleeAttr; |
| // Indirect call, callee Value is the first operand. |
| return getOperand(0); |
| } |
| |
| void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
| // Direct call. |
| if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) { |
| auto symRef = cast<SymbolRefAttr>(callee); |
| return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef)); |
| } |
| // Indirect call, callee Value is the first operand. |
| return setOperand(0, cast<Value>(callee)); |
| } |
| |
| Operation::operand_range CallOp::getArgOperands() { |
| return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1); |
| } |
| |
| MutableOperandRange CallOp::getArgOperandsMutable() { |
| return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1, |
| getCalleeOperands().size()); |
| } |
| |
| /// Verify that an inlinable callsite of a debug-info-bearing function in a |
| /// debug-info-bearing function has a debug location attached to it. This |
| /// mirrors an LLVM IR verifier. |
| static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) { |
| if (callee.isExternal()) |
| return success(); |
| auto parentFunc = callOp->getParentOfType<FunctionOpInterface>(); |
| if (!parentFunc) |
| return success(); |
| |
| auto hasSubprogram = [](Operation *op) { |
| return op->getLoc() |
| ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() != |
| nullptr; |
| }; |
| if (!hasSubprogram(parentFunc) || !hasSubprogram(callee)) |
| return success(); |
| bool containsLoc = !isa<UnknownLoc>(callOp->getLoc()); |
| if (!containsLoc) |
| return callOp.emitError() |
| << "inlinable function call in a function with a DISubprogram " |
| "location must have a debug location"; |
| return success(); |
| } |
| |
| /// Verify that the parameter and return types of the variadic callee type match |
| /// the `callOp` argument and result types. |
| template <typename OpTy> |
| LogicalResult verifyCallOpVarCalleeType(OpTy callOp) { |
| std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType(); |
| if (!varCalleeType) |
| return success(); |
| |
| // Verify the variadic callee type is a variadic function type. |
| if (!varCalleeType->isVarArg()) |
| return callOp.emitOpError( |
| "expected var_callee_type to be a variadic function type"); |
| |
| // Verify the variadic callee type has at most as many parameters as the call |
| // has argument operands. |
| if (varCalleeType->getNumParams() > callOp.getArgOperands().size()) |
| return callOp.emitOpError("expected var_callee_type to have at most ") |
| << callOp.getArgOperands().size() << " parameters"; |
| |
| // Verify the variadic callee type matches the call argument types. |
| for (auto [paramType, operand] : |
| llvm::zip(varCalleeType->getParams(), callOp.getArgOperands())) |
| if (paramType != operand.getType()) |
| return callOp.emitOpError() |
| << "var_callee_type parameter type mismatch: " << paramType |
| << " != " << operand.getType(); |
| |
| // Verify the variadic callee type matches the call result type. |
| if (!callOp.getNumResults()) { |
| if (!isa<LLVMVoidType>(varCalleeType->getReturnType())) |
| return callOp.emitOpError("expected var_callee_type to return void"); |
| } else { |
| if (callOp.getResult().getType() != varCalleeType->getReturnType()) |
| return callOp.emitOpError("var_callee_type return type mismatch: ") |
| << varCalleeType->getReturnType() |
| << " != " << callOp.getResult().getType(); |
| } |
| return success(); |
| } |
| |
| template <typename OpType> |
| static LogicalResult verifyOperandBundles(OpType &op) { |
| OperandRangeRange opBundleOperands = op.getOpBundleOperands(); |
| std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags(); |
| |
| auto isStringAttr = [](Attribute tagAttr) { |
| return isa<StringAttr>(tagAttr); |
| }; |
| if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr)) |
| return op.emitError("operand bundle tag must be a StringAttr"); |
| |
| size_t numOpBundles = opBundleOperands.size(); |
| size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0; |
| if (numOpBundles != numOpBundleTags) |
| return op.emitError("expected ") |
| << numOpBundles << " operand bundle tags, but actually got " |
| << numOpBundleTags; |
| |
| return success(); |
| } |
| |
| LogicalResult CallOp::verify() { return verifyOperandBundles(*this); } |
| |
| LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| if (failed(verifyCallOpVarCalleeType(*this))) |
| return failure(); |
| |
| // Type for the callee, we'll get it differently depending if it is a direct |
| // or indirect call. |
| Type fnType; |
| |
| bool isIndirect = false; |
| |
| // If this is an indirect call, the callee attribute is missing. |
| FlatSymbolRefAttr calleeName = getCalleeAttr(); |
| if (!calleeName) { |
| isIndirect = true; |
| if (!getNumOperands()) |
| return emitOpError( |
| "must have either a `callee` attribute or at least an operand"); |
| auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType()); |
| if (!ptrType) |
| return emitOpError("indirect call expects a pointer as callee: ") |
| << getOperand(0).getType(); |
| |
| return success(); |
| } else { |
| Operation *callee = |
| symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr()); |
| if (!callee) |
| return emitOpError() |
| << "'" << calleeName.getValue() |
| << "' does not reference a symbol in the current scope"; |
| auto fn = dyn_cast<LLVMFuncOp>(callee); |
| if (!fn) |
| return emitOpError() << "'" << calleeName.getValue() |
| << "' does not reference a valid LLVM function"; |
| |
| if (failed(verifyCallOpDebugInfo(*this, fn))) |
| return failure(); |
| fnType = fn.getFunctionType(); |
| } |
| |
| LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType); |
| if (!funcType) |
| return emitOpError("callee does not have a functional type: ") << fnType; |
| |
| if (funcType.isVarArg() && !getVarCalleeType()) |
| return emitOpError() << "missing var_callee_type attribute for vararg call"; |
| |
| // Verify that the operand and result types match the callee. |
| |
| if (!funcType.isVarArg() && |
| funcType.getNumParams() != (getCalleeOperands().size() - isIndirect)) |
| return emitOpError() << "incorrect number of operands (" |
| << (getCalleeOperands().size() - isIndirect) |
| << ") for callee (expecting: " |
| << funcType.getNumParams() << ")"; |
| |
| if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect)) |
| return emitOpError() << "incorrect number of operands (" |
| << (getCalleeOperands().size() - isIndirect) |
| << ") for varargs callee (expecting at least: " |
| << funcType.getNumParams() << ")"; |
| |
| for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) |
| if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) |
| return emitOpError() << "operand type mismatch for operand " << i << ": " |
| << getOperand(i + isIndirect).getType() |
| << " != " << funcType.getParamType(i); |
| |
| if (getNumResults() == 0 && |
| !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType())) |
| return emitOpError() << "expected function call to produce a value"; |
| |
| if (getNumResults() != 0 && |
| llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType())) |
| return emitOpError() |
| << "calling function with void result must not produce values"; |
| |
| if (getNumResults() > 1) |
| return emitOpError() |
| << "expected LLVM function call to produce 0 or 1 result"; |
| |
| if (getNumResults() && getResult().getType() != funcType.getReturnType()) |
| return emitOpError() << "result type mismatch: " << getResult().getType() |
| << " != " << funcType.getReturnType(); |
| |
| return success(); |
| } |
| |
/// Prints the custom assembly form of llvm.call:
///   (cconv)? (tailcallkind)? (@callee | %funcPtr) `(` args `)`
///   (` vararg(` var-callee-type `)`)? (op-bundles)? attr-dict
///   `:` (funcPtr-type `,`)? function-signature
/// The calling convention and tail call kind are printed only when they differ
/// from C and None respectively.
void CallOp::print(OpAsmPrinter &p) {
  auto callee = getCallee();
  bool isDirect = callee.has_value();

  p << ' ';

  // Print calling convention.
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  // Print the tail call kind unless it is None.
  if (getTailCallKind() != LLVM::TailCallKind::None)
    p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';

  // Print the direct callee if present as a function attribute, or an indirect
  // callee (first operand) otherwise.
  if (isDirect)
    p.printSymbolName(callee.value());
  else
    p << getOperand(0);

  // The argument list excludes the function pointer of an indirect call.
  auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
  p << '(' << args << ')';

  // Print the variadic callee type if the call is variadic.
  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
    p << " vararg(" << *varCalleeType << ")";

  if (!getOpBundleOperands().empty()) {
    p << " ";
    printOpBundles(p, *this, getOpBundleOperands(),
                   getOpBundleOperands().getTypes(), getOpBundleTags());
  }

  // Attributes printed with dedicated syntax above or in the trailing
  // signature are elided from the attribute dictionary. The attribute list is
  // first passed through processFMFAttr (defined elsewhere; presumably it
  // drops default fastmath flags — TODO confirm).
  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
                          {getCalleeAttrName(), getTailCallKindAttrName(),
                           getVarCalleeTypeAttrName(), getCConvAttrName(),
                           getOperandSegmentSizesAttrName(),
                           getOpBundleSizesAttrName(),
                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
                           getResAttrsAttrName()});

  p << " : ";
  // Indirect calls print the function pointer type before the signature.
  if (!isDirect)
    p << getOperand(0).getType() << ", ";

  // Reconstruct the MLIR function type from operand and result types.
  call_interface_impl::printFunctionSignature(
      p, args.getTypes(), getArgAttrsAttr(),
      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
| |
/// Parses the trailing type list of a call-like operation and resolves its
/// callee operands. The expected form is `:`, then (for indirect calls) the
/// function pointer type followed by `,`, then a function signature with
/// optional argument/result attributes. At most one result type is accepted,
/// and it must not be `!llvm.void`.
///
/// `operands` holds the unresolved callee operands in order (function pointer
/// first for indirect calls). On success they are appended to
/// `result.operands`, the result types are added to `result`, and the
/// argument/result attribute dictionaries are returned through `argAttrs` and
/// `resultAttrs`. Returns failure otherwise; the diagnostic is emitted either
/// by the failing sub-parser or here.
static ParseResult parseCallTypeAndResolveOperands(
    OpAsmParser &parser, OperationState &result, bool isDirect,
    ArrayRef<OpAsmParser::UnresolvedOperand> operands,
    SmallVectorImpl<DictionaryAttr> &argAttrs,
    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
  SMLoc trailingTypesLoc = parser.getCurrentLocation();
  SmallVector<Type> types;
  if (parser.parseColon())
    return failure();
  // Indirect calls first list the type of the function pointer operand.
  if (!isDirect) {
    types.emplace_back();
    if (parser.parseType(types.back()))
      return failure();
    if (parser.parseOptionalComma())
      return parser.emitError(
          trailingTypesLoc, "expected indirect call to have 2 trailing types");
  }
  SmallVector<Type> argTypes;
  SmallVector<Type> resTypes;
  if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
                                                  resTypes, resultAttrs)) {
    if (isDirect)
      return parser.emitError(trailingTypesLoc,
                              "expected direct call to have 1 trailing types");
    return parser.emitError(trailingTypesLoc,
                            "expected trailing function type");
  }

  if (resTypes.size() > 1)
    return parser.emitError(trailingTypesLoc,
                            "expected function with 0 or 1 result");
  if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
    return parser.emitError(trailingTypesLoc,
                            "expected a non-void result type");

  // The head element of the types list matches the callee type for
  // indirect calls, while the types list is empty for direct calls.
  // Append the function input types to resolve the call operation
  // operands.
  llvm::append_range(types, argTypes);
  if (parser.resolveOperands(operands, types, parser.getNameLoc(),
                             result.operands))
    return failure();
  if (resTypes.size() != 0)
    result.addTypes(resTypes);

  return success();
}
| |
| /// Parses an optional function pointer operand before the call argument list |
| /// for indirect calls, or stops parsing at the function identifier otherwise. |
| static ParseResult parseOptionalCallFuncPtr( |
| OpAsmParser &parser, |
| SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands) { |
| OpAsmParser::UnresolvedOperand funcPtrOperand; |
| OptionalParseResult parseResult = parser.parseOptionalOperand(funcPtrOperand); |
| if (parseResult.has_value()) { |
| if (failed(*parseResult)) |
| return *parseResult; |
| operands.push_back(funcPtrOperand); |
| } |
| return success(); |
| } |
| |
| static ParseResult resolveOpBundleOperands( |
| OpAsmParser &parser, SMLoc loc, OperationState &state, |
| ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands, |
| ArrayRef<SmallVector<Type>> opBundleOperandTypes, |
| StringAttr opBundleSizesAttrName) { |
| unsigned opBundleIndex = 0; |
| for (const auto &[operands, types] : |
| llvm::zip_equal(opBundleOperands, opBundleOperandTypes)) { |
| if (operands.size() != types.size()) |
| return parser.emitError(loc, "expected ") |
| << operands.size() |
| << " types for operand bundle operands for operand bundle #" |
| << opBundleIndex << ", but actually got " << types.size(); |
| if (parser.resolveOperands(operands, types, loc, state.operands)) |
| return failure(); |
| } |
| |
| SmallVector<int32_t> opBundleSizes; |
| opBundleSizes.reserve(opBundleOperands.size()); |
| for (const auto &operands : opBundleOperands) |
| opBundleSizes.push_back(operands.size()); |
| |
| state.addAttribute( |
| opBundleSizesAttrName, |
| DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes)); |
| |
| return success(); |
| } |
| |
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
//                             `(` ssa-use-list `)`
//                             ( `vararg(` var-callee-type `)` )?
//                             ( `[` op-bundles-list `]` )?
//                             attribute-dict? `:` (type `,`)? function-type
//
// Parses the custom assembly form of llvm.call shown above. Callee operands
// (function pointer first for indirect calls) form the first operand segment,
// and operand bundle operands form the second.
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
  SymbolRefAttr funcAttr;
  TypeAttr varCalleeType;
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands;
  SmallVector<SmallVector<Type>> opBundleOperandTypes;
  ArrayAttr opBundleTags;

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
                                              parser, result, LLVM::CConv::C)));

  // Default to the `None` tail call kind if no keyword is provided.
  result.addAttribute(
      getTailCallKindAttrName(result.name),
      TailCallKindAttr::get(parser.getContext(),
                            parseOptionalLLVMKeyword<TailCallKind>(
                                parser, result, LLVM::TailCallKind::None)));

  // Parse a function pointer for indirect calls.
  if (parseOptionalCallFuncPtr(parser, operands))
    return failure();
  bool isDirect = operands.empty();

  // Parse a function identifier for direct calls.
  if (isDirect)
    if (parser.parseAttribute(funcAttr, "callee", result.attributes))
      return failure();

  // Parse the function arguments.
  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse the optional `vararg(<type>)` clause into the var_callee_type
  // attribute.
  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
  if (isVarArg) {
    StringAttr varCalleeTypeAttrName =
        CallOp::getVarCalleeTypeAttrName(result.name);
    if (parser.parseLParen().failed() ||
        parser
            .parseAttribute(varCalleeType, varCalleeTypeAttrName,
                            result.attributes)
            .failed() ||
        parser.parseRParen().failed())
      return failure();
  }

  // Parse the optional operand bundles. Their operands are resolved only after
  // the callee operands, so that operand order matches the segment order.
  // NOTE: the `result` declared in this if-statement shadows the
  // OperationState parameter, but only inside the statement itself.
  SMLoc opBundlesLoc = parser.getCurrentLocation();
  if (std::optional<ParseResult> result = parseOpBundles(
          parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
      result && failed(*result))
    return failure();
  // Bundle tags are attached only when at least one bundle was parsed.
  if (opBundleTags && !opBundleTags.empty())
    result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(),
                        opBundleTags);

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse the trailing type list and resolve the operands.
  SmallVector<DictionaryAttr> argAttrs;
  SmallVector<DictionaryAttr> resultAttrs;
  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
                                      argAttrs, resultAttrs))
    return failure();
  call_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, argAttrs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
  if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
                              opBundleOperandTypes,
                              getOpBundleSizesAttrName(result.name)))
    return failure();

  // Record the operand segment sizes: callee operands, then all operand bundle
  // operands combined.
  int32_t numOpBundleOperands = 0;
  for (const auto &operands : opBundleOperands)
    numOpBundleOperands += operands.size();

  result.addAttribute(
      CallOp::getOperandSegmentSizeAttr(),
      parser.getBuilder().getDenseI32ArrayAttr(
          {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
  return success();
}
| |
| LLVMFunctionType CallOp::getCalleeFunctionType() { |
| if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType()) |
| return *varCalleeType; |
| return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands()); |
| } |
| |
| ///===---------------------------------------------------------------------===// |
| /// LLVM::InvokeOp |
| ///===---------------------------------------------------------------------===// |
| |
| void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, |
| ValueRange ops, Block *normal, ValueRange normalOps, |
| Block *unwind, ValueRange unwindOps) { |
| auto calleeType = func.getFunctionType(); |
| build(builder, state, getCallOpResultTypes(calleeType), |
| getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops, |
| /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps, |
| nullptr, nullptr, {}, {}, normal, unwind); |
| } |
| |
| void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, |
| FlatSymbolRefAttr callee, ValueRange ops, Block *normal, |
| ValueRange normalOps, Block *unwind, |
| ValueRange unwindOps) { |
| build(builder, state, tys, |
| /*var_callee_type=*/nullptr, callee, ops, /*arg_attrs=*/nullptr, |
| /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {}, |
| normal, unwind); |
| } |
| |
| void InvokeOp::build(OpBuilder &builder, OperationState &state, |
| LLVMFunctionType calleeType, FlatSymbolRefAttr callee, |
| ValueRange ops, Block *normal, ValueRange normalOps, |
| Block *unwind, ValueRange unwindOps) { |
| build(builder, state, getCallOpResultTypes(calleeType), |
| getCallOpVarCalleeType(calleeType), callee, ops, |
| /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps, |
| nullptr, nullptr, {}, {}, normal, unwind); |
| } |
| |
| SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) { |
| assert(index < getNumSuccessors() && "invalid successor index"); |
| return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable() |
| : getUnwindDestOperandsMutable()); |
| } |
| |
| CallInterfaceCallable InvokeOp::getCallableForCallee() { |
| // Direct call. |
| if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) |
| return calleeAttr; |
| // Indirect call, callee Value is the first operand. |
| return getOperand(0); |
| } |
| |
| void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) { |
| // Direct call. |
| if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) { |
| auto symRef = cast<SymbolRefAttr>(callee); |
| return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef)); |
| } |
| // Indirect call, callee Value is the first operand. |
| return setOperand(0, cast<Value>(callee)); |
| } |
| |
| Operation::operand_range InvokeOp::getArgOperands() { |
| return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1); |
| } |
| |
| MutableOperandRange InvokeOp::getArgOperandsMutable() { |
| return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1, |
| getCalleeOperands().size()); |
| } |
| |
| LogicalResult InvokeOp::verify() { |
| if (failed(verifyCallOpVarCalleeType(*this))) |
| return failure(); |
| |
| Block *unwindDest = getUnwindDest(); |
| if (unwindDest->empty()) |
| return emitError("must have at least one operation in unwind destination"); |
| |
| // In unwind destination, first operation must be LandingpadOp |
| if (!isa<LandingpadOp>(unwindDest->front())) |
| return emitError("first operation in unwind destination should be a " |
| "llvm.landingpad operation"); |
| |
| if (failed(verifyOperandBundles(*this))) |
| return failure(); |
| |
| return success(); |
| } |
| |
/// Prints the custom assembly form of llvm.invoke:
///   (cconv)? (@callee | %funcPtr) `(` args `)`
///   `to` normal-successor `unwind` unwind-successor
///   (` vararg(` var-callee-type `)`)? (op-bundles)? attr-dict
///   `:` (funcPtr-type `,`)? function-signature
/// The calling convention is printed only when it differs from C.
void InvokeOp::print(OpAsmPrinter &p) {
  auto callee = getCallee();
  bool isDirect = callee.has_value();

  p << ' ';

  // Print calling convention.
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  // Either function name or pointer
  if (isDirect)
    p.printSymbolName(callee.value());
  else
    p << getOperand(0);

  // The argument list excludes the function pointer of an indirect invoke.
  p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')';
  p << " to ";
  p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
  p << " unwind ";
  p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());

  // Print the variadic callee type if the invoke is variadic.
  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
    p << " vararg(" << *varCalleeType << ")";

  if (!getOpBundleOperands().empty()) {
    p << " ";
    printOpBundles(p, *this, getOpBundleOperands(),
                   getOpBundleOperands().getTypes(), getOpBundleTags());
  }

  // Attributes printed with dedicated syntax above or in the trailing
  // signature are elided from the attribute dictionary.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {getCalleeAttrName(), getOperandSegmentSizeAttr(),
                           getCConvAttrName(), getVarCalleeTypeAttrName(),
                           getOpBundleSizesAttrName(),
                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
                           getResAttrsAttrName()});

  p << " : ";
  // Indirect invokes print the function pointer type before the signature.
  if (!isDirect)
    p << getOperand(0).getType() << ", ";
  call_interface_impl::printFunctionSignature(
      p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
      getArgAttrsAttr(),
      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
| |
| // <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use) |
| // `(` ssa-use-list `)` |
| // `to` bb-id (`[` ssa-use-and-type-list `]`)? |
| // `unwind` bb-id (`[` ssa-use-and-type-list `]`)? |
| // ( `vararg(` var-callee-type `)` )? |
| // ( `[` op-bundles-list `]` )? |
| // attribute-dict? `:` (type `,`)? |
| // function-type-with-argument-attributes |
| ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::UnresolvedOperand, 8> operands; |
| SymbolRefAttr funcAttr; |
| TypeAttr varCalleeType; |
| SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands; |
| SmallVector<SmallVector<Type>> opBundleOperandTypes; |
| ArrayAttr opBundleTags; |
| Block *normalDest, *unwindDest; |
| SmallVector<Value, 4> normalOperands, unwindOperands; |
| Builder &builder = parser.getBuilder(); |
| |
| // Default to C Calling Convention if no keyword is provided. |
| result.addAttribute( |
| getCConvAttrName(result.name), |
| CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>( |
| parser, result, LLVM::CConv::C))); |
| |
| // Parse a function pointer for indirect calls. |
| if (parseOptionalCallFuncPtr(parser, operands)) |
| return failure(); |
| bool isDirect = operands.empty(); |
| |
| // Parse a function identifier for direct calls. |
| if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes)) |
| return failure(); |
| |
| // Parse the function arguments. |
| if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || |
| parser.parseKeyword("to") || |
| parser.parseSuccessorAndUseList(normalDest, normalOperands) || |
| parser.parseKeyword("unwind") || |
| parser.parseSuccessorAndUseList(unwindDest, unwindOperands)) |
| return failure(); |
| |
| bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded(); |
| if (isVarArg) { |
| StringAttr varCalleeTypeAttrName = |
| InvokeOp::getVarCalleeTypeAttrName(result.name); |
| if (parser.parseLParen().failed() || |
| parser |
| .parseAttribute(varCalleeType, varCalleeTypeAttrName, |
| result.attributes) |
| .failed() || |
| parser.parseRParen().failed()) |
| return failure(); |
| } |
| |
| SMLoc opBundlesLoc = parser.getCurrentLocation(); |
| if (std::optional<ParseResult> result = parseOpBundles( |
| parser, opBundleOperands, opBundleOperandTypes, opBundleTags); |
| result && failed(*result)) |
| return failure(); |
| if (opBundleTags && !opBundleTags.empty()) |
| result.addAttribute( |
| InvokeOp::getOpBundleTagsAttrName(result.name).getValue(), |
| opBundleTags); |
| |
| if (parser.parseOptionalAttrDict(result.attributes)) |
| return failure(); |
| |
| // Parse the trailing type list and resolve the function operands. |
| SmallVector<DictionaryAttr> argAttrs; |
| SmallVector<DictionaryAttr> resultAttrs; |
| if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands, |
| argAttrs, resultAttrs)) |
| return failure(); |
| call_interface_impl::addArgAndResultAttrs( |
| parser.getBuilder(), result, argAttrs, resultAttrs, |
| getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
| |
| if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, |
| opBundleOperandTypes, |
| getOpBundleSizesAttrName(result.name))) |
| return failure(); |
| |
| result.addSuccessors({normalDest, unwindDest}); |
| result.addOperands(normalOperands); |
| result.addOperands(unwindOperands); |
| |
| int32_t numOpBundleOperands = 0; |
| for (const auto &operands : opBundleOperands) |
| numOpBundleOperands += operands.size(); |
| |
| result.addAttribute( |
| InvokeOp::getOperandSegmentSizeAttr(), |
| builder.getDenseI32ArrayAttr({static_cast<int32_t>(operands.size()), |
| static_cast<int32_t>(normalOperands.size()), |
| static_cast<int32_t>(unwindOperands.size()), |
| numOpBundleOperands})); |
| return success(); |
| } |
| |
| LLVMFunctionType InvokeOp::getCalleeFunctionType() { |
| if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType()) |
| return *varCalleeType; |
| return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands()); |
| } |
| |
| ///===----------------------------------------------------------------------===// |
| /// Verifying/Printing/Parsing for LLVM::LandingpadOp. |
| ///===----------------------------------------------------------------------===// |
| |
| LogicalResult LandingpadOp::verify() { |
| Value value; |
| if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) { |
| if (!func.getPersonality()) |
| return emitError( |
| "llvm.landingpad needs to be in a function with a personality"); |
| } |
| |
| // Consistency of llvm.landingpad result types is checked in |
| // LLVMFuncOp::verify(). |
| |
| if (!getCleanup() && getOperands().empty()) |
| return emitError("landingpad instruction expects at least one clause or " |
| "cleanup attribute"); |
| |
| for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { |
| value = getOperand(idx); |
| bool isFilter = llvm::isa<LLVMArrayType>(value.getType()); |
| if (isFilter) { |
| // FIXME: Verify filter clauses when arrays are appropriately handled |
| } else { |
| // catch - global addresses only. |
| // Bitcast ops should have global addresses as their args. |
| if (auto bcOp = value.getDefiningOp<BitcastOp>()) { |
| if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>()) |
| continue; |
| return emitError("constant clauses expected").attachNote(bcOp.getLoc()) |
| << "global addresses expected as operand to " |
| "bitcast used in clauses for landingpad"; |
| } |
| // ZeroOp and AddressOfOp allowed |
| if (value.getDefiningOp<ZeroOp>()) |
| continue; |
| if (value.getDefiningOp<AddressOfOp>()) |
| continue; |
| return emitError("clause #") |
| << idx << " is not a known constant - null, addressof, bitcast"; |
| } |
| } |
| return success(); |
| } |
| |
| void LandingpadOp::print(OpAsmPrinter &p) { |
| p << (getCleanup() ? " cleanup " : " "); |
| |
| // Clauses |
| for (auto value : getOperands()) { |
| // Similar to llvm - if clause is an array type then it is filter |
| // clause else catch clause |
| bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType()); |
| p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " |
| << value.getType() << ") "; |
| } |
| |
| p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"}); |
| |
| p << ": " << getType(); |
| } |
| |
| // <operation> ::= `llvm.landingpad` `cleanup`? |
| // ((`catch` | `filter`) operand-type ssa-use)* attribute-dict? |
| ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) { |
| // Check for cleanup |
| if (succeeded(parser.parseOptionalKeyword("cleanup"))) |
| result.addAttribute("cleanup", parser.getBuilder().getUnitAttr()); |
| |
| // Parse clauses with types |
| while (succeeded(parser.parseOptionalLParen()) && |
| (succeeded(parser.parseOptionalKeyword("filter")) || |
| succeeded(parser.parseOptionalKeyword("catch")))) { |
| OpAsmParser::UnresolvedOperand operand; |
| Type ty; |
| if (parser.parseOperand(operand) || parser.parseColon() || |
| parser.parseType(ty) || |
| parser.resolveOperand(operand, ty, result.operands) || |
| parser.parseRParen()) |
| return failure(); |
| } |
| |
| Type type; |
| if (parser.parseColon() || parser.parseType(type)) |
| return failure(); |
| |
| result.addTypes(type); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ExtractValueOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Extract the type at `position` in the LLVM IR aggregate type |
| /// `containerType`. Each element of `position` is an index into a nested |
| /// aggregate type. Return the resulting type or emit an error. |
| static Type getInsertExtractValueElementType( |
| function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType, |
| ArrayRef<int64_t> position) { |
| Type llvmType = containerType; |
| if (!isCompatibleType(containerType)) { |
| emitError("expected LLVM IR Dialect type, got ") << containerType; |
| return {}; |
| } |
| |
| // Infer the element type from the structure type: iteratively step inside the |
| // type by taking the element type, indexed by the position attribute for |
| // structures. Check the position index before accessing, it is supposed to |
| // be in bounds. |
| for (int64_t idx : position) { |
| if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) { |
| if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) { |
| emitError("position out of bounds: ") << idx; |
| return {}; |
| } |
| llvmType = arrayType.getElementType(); |
| } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) { |
| if (idx < 0 || |
| static_cast<unsigned>(idx) >= structType.getBody().size()) { |
| emitError("position out of bounds: ") << idx; |
| return {}; |
| } |
| llvmType = structType.getBody()[idx]; |
| } else { |
| emitError("expected LLVM IR structure/array type, got: ") << llvmType; |
| return {}; |
| } |
| } |
| return llvmType; |
| } |
| |
| /// Extract the type at `position` in the wrapped LLVM IR aggregate type |
| /// `containerType`. |
| static Type getInsertExtractValueElementType(Type llvmType, |
| ArrayRef<int64_t> position) { |
| for (int64_t idx : position) { |
| if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) |
| llvmType = structType.getBody()[idx]; |
| else |
| llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType(); |
| } |
| return llvmType; |
| } |
| |
/// Fold llvm.extractvalue. Three strategies are tried in order:
///  1. The container is itself an llvm.extractvalue: concatenate the two
///     positions and rewire this op to read from the inner container. The op
///     is mutated in place, and returning its own result signals that.
///  2. The container is a constant DenseElementsAttr whose element type equals
///     the result type: fold to the splat value, or, for a single-index
///     position, to the element at that flat index.
///  3. The container comes from a chain of llvm.insertvalue ops: return the
///     inserted value on an exact position match. Skip past inserts into
///     unrelated positions by re-pointing the container in place. Stop as
///     soon as one position is a strict prefix of the other.
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
  // Strategy 1: extractvalue(extractvalue(x, p0), p1) -> extractvalue(x, p0++p1).
  if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
    SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
    newPos.append(getPosition().begin(), getPosition().end());
    setPosition(newPos);
    getContainerMutable().set(extractValueOp.getContainer());
    return getResult();
  }

  // Strategy 2: extraction from a dense constant.
  {
    DenseElementsAttr constval;
    matchPattern(getContainer(), m_Constant(&constval));
    if (constval && constval.getElementType() == getType()) {
      if (isa<SplatElementsAttr>(constval))
        return constval.getSplatValue<Attribute>();
      // NOTE(review): indexes the flattened value list; presumably only
      // meaningful for one-dimensional constants -- confirm.
      if (getPosition().size() == 1)
        return constval.getValues<Attribute>()[getPosition()[0]];
    }
  }

  // Strategy 3: look through a chain of insertvalue ops. `result` stays null
  // until the op has been mutated in place, after which it holds getResult()
  // so that the caller sees the in-place update.
  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
  OpFoldResult result = {};
  while (insertValueOp) {
    if (getPosition() == insertValueOp.getPosition())
      return insertValueOp.getValue();
    unsigned min =
        std::min(getPosition().size(), insertValueOp.getPosition().size());
    // If one is fully prefix of the other, stop propagating back as it will
    // miss dependencies. For instance, %3 should not fold to %f0 in the
    // following example:
    // ```
    //   %1 = llvm.insertvalue %f0, %0[0, 0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %2 = llvm.insertvalue %arr, %1[0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
    // ```
    if (getPosition().take_front(min) ==
        insertValueOp.getPosition().take_front(min))
      return result;

    // If neither a prefix, nor the exact position, we can extract out of the
    // value being inserted into. Moreover, we can try again if that operand
    // is itself an insertvalue expression.
    getContainerMutable().assign(insertValueOp.getContainer());
    result = getResult();
    insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
  }
  return result;
}
| |
| LogicalResult ExtractValueOp::verify() { |
| auto emitError = [this](StringRef msg) { return emitOpError(msg); }; |
| Type valueType = getInsertExtractValueElementType( |
| emitError, getContainer().getType(), getPosition()); |
| if (!valueType) |
| return failure(); |
| |
| if (getRes().getType() != valueType) |
| return emitOpError() << "Type mismatch: extracting from " |
| << getContainer().getType() << " should produce " |
| << valueType << " but this op returns " |
| << getRes().getType(); |
| return success(); |
| } |
| |
| void ExtractValueOp::build(OpBuilder &builder, OperationState &state, |
| Value container, ArrayRef<int64_t> position) { |
| build(builder, state, |
| getInsertExtractValueElementType(container.getType(), position), |
| container, builder.getAttr<DenseI64ArrayAttr>(position)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // InsertValueOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Infer the value type from the container type and position. |
| static ParseResult |
| parseInsertExtractValueElementType(AsmParser &parser, Type &valueType, |
| Type containerType, |
| DenseI64ArrayAttr position) { |
| valueType = getInsertExtractValueElementType( |
| [&](StringRef msg) { |
| return parser.emitError(parser.getCurrentLocation(), msg); |
| }, |
| containerType, position.asArrayRef()); |
| return success(!!valueType); |
| } |
| |
| /// Nothing to print for an inferred type. |
| static void printInsertExtractValueElementType(AsmPrinter &printer, |
| Operation *op, Type valueType, |
| Type containerType, |
| DenseI64ArrayAttr position) {} |
| |
| LogicalResult InsertValueOp::verify() { |
| auto emitError = [this](StringRef msg) { return emitOpError(msg); }; |
| Type valueType = getInsertExtractValueElementType( |
| emitError, getContainer().getType(), getPosition()); |
| if (!valueType) |
| return failure(); |
| |
| if (getValue().getType() != valueType) |
| return emitOpError() << "Type mismatch: cannot insert " |
| << getValue().getType() << " into " |
| << getContainer().getType(); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ReturnOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ReturnOp::verify() { |
| auto parent = (*this)->getParentOfType<LLVMFuncOp>(); |
| if (!parent) |
| return success(); |
| |
| Type expectedType = parent.getFunctionType().getReturnType(); |
| if (llvm::isa<LLVMVoidType>(expectedType)) { |
| if (!getArg()) |
| return success(); |
| InFlightDiagnostic diag = emitOpError("expected no operands"); |
| diag.attachNote(parent->getLoc()) << "when returning from function"; |
| return diag; |
| } |
| if (!getArg()) { |
| if (llvm::isa<LLVMVoidType>(expectedType)) |
| return success(); |
| InFlightDiagnostic diag = emitOpError("expected 1 operand"); |
| diag.attachNote(parent->getLoc()) << "when returning from function"; |
| return diag; |
| } |
| if (expectedType != getArg().getType()) { |
| InFlightDiagnostic diag = emitOpError("mismatching result types"); |
| diag.attachNote(parent->getLoc()) << "when returning from function"; |
| return diag; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LLVM::AddressOfOp. |
| //===----------------------------------------------------------------------===// |
| |
| static Operation *parentLLVMModule(Operation *op) { |
| Operation *module = op->getParentOp(); |
| while (module && !satisfiesLLVMModule(module)) |
| module = module->getParentOp(); |
| assert(module && "unexpected operation outside of a module"); |
| return module; |
| } |
| |
| GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) { |
| return dyn_cast_or_null<GlobalOp>( |
| symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); |
| } |
| |
| LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) { |
| return dyn_cast_or_null<LLVMFuncOp>( |
| symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); |
| } |
| |
| AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) { |
| return dyn_cast_or_null<AliasOp>( |
| symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr())); |
| } |
| |
| LogicalResult |
| AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| Operation *symbol = |
| symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()); |
| |
| auto global = dyn_cast_or_null<GlobalOp>(symbol); |
| auto function = dyn_cast_or_null<LLVMFuncOp>(symbol); |
| auto alias = dyn_cast_or_null<AliasOp>(symbol); |
| |
| if (!global && !function && !alias) |
| return emitOpError("must reference a global defined by 'llvm.mlir.global', " |
| "'llvm.mlir.alias' or 'llvm.func'"); |
| |
| LLVMPointerType type = getType(); |
| if ((global && global.getAddrSpace() != type.getAddressSpace()) || |
| (alias && alias.getAddrSpace() != type.getAddressSpace())) |
| return emitOpError("pointer address space must match address space of the " |
| "referenced global or alias"); |
| |
| return success(); |
| } |
| |
| // AddressOfOp constant-folds to the global symbol name. |
| OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) { |
| return getGlobalNameAttr(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LLVM::DSOLocalEquivalentOp |
| //===----------------------------------------------------------------------===// |
| |
| LLVMFuncOp |
| DSOLocalEquivalentOp::getFunction(SymbolTableCollection &symbolTable) { |
| return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn( |
| parentLLVMModule(*this), getFunctionNameAttr())); |
| } |
| |
| AliasOp DSOLocalEquivalentOp::getAlias(SymbolTableCollection &symbolTable) { |
| return dyn_cast_or_null<AliasOp>(symbolTable.lookupSymbolIn( |
| parentLLVMModule(*this), getFunctionNameAttr())); |
| } |
| |
| LogicalResult |
| DSOLocalEquivalentOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this), |
| getFunctionNameAttr()); |
| auto function = dyn_cast_or_null<LLVMFuncOp>(symbol); |
| auto alias = dyn_cast_or_null<AliasOp>(symbol); |
| |
| if (!function && !alias) |
| return emitOpError( |
| "must reference a global defined by 'llvm.func' or 'llvm.mlir.alias'"); |
| |
| if (alias) { |
| if (alias.getInitializer() |
| .walk([&](AddressOfOp addrOp) { |
| if (addrOp.getGlobal(symbolTable)) |
| return WalkResult::interrupt(); |
| return WalkResult::advance(); |
| }) |
| .wasInterrupted()) |
| return emitOpError("must reference an alias to a function"); |
| } |
| |
| if ((function && function.getLinkage() == LLVM::Linkage::ExternWeak) || |
| (alias && alias.getLinkage() == LLVM::Linkage::ExternWeak)) |
| return emitOpError( |
| "target function with 'extern_weak' linkage not allowed"); |
| |
| return success(); |
| } |
| |
| /// Fold a dso_local_equivalent operation to a dedicated dso_local_equivalent |
| /// attribute. |
| OpFoldResult DSOLocalEquivalentOp::fold(FoldAdaptor) { |
| return DSOLocalEquivalentAttr::get(getContext(), getFunctionNameAttr()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
// Builder and verifier for LLVM::ComdatOp.
| //===----------------------------------------------------------------------===// |
| |
| void ComdatOp::build(OpBuilder &builder, OperationState &result, |
| StringRef symName) { |
| result.addAttribute(getSymNameAttrName(result.name), |
| builder.getStringAttr(symName)); |
| Region *body = result.addRegion(); |
| body->emplaceBlock(); |
| } |
| |
| LogicalResult ComdatOp::verifyRegions() { |
| Region &body = getBody(); |
| for (Operation &op : body.getOps()) |
| if (!isa<ComdatSelectorOp>(op)) |
| return op.emitError( |
| "only comdat selector symbols can appear in a comdat region"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
// Builder, printer, parser and verifier for LLVM::GlobalOp.
| //===----------------------------------------------------------------------===// |
| |
/// Build an llvm.mlir.global.
///
/// The symbol name, global type and linkage are always recorded. The
/// constant, dso-local and thread-local flags become unit attributes only when
/// true. `value`, `comdat`, a non-zero `alignment`, a non-zero `addrSpace` and
/// a non-empty `dbgExprs` are likewise only recorded when present. `attrs` is
/// appended as given. The initializer region is created empty.
void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
                     bool isConstant, Linkage linkage, StringRef name,
                     Attribute value, uint64_t alignment, unsigned addrSpace,
                     bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
                     ArrayRef<NamedAttribute> attrs,
                     ArrayRef<Attribute> dbgExprs) {
  result.addAttribute(getSymNameAttrName(result.name),
                      builder.getStringAttr(name));
  result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
  // Optional unit flags and attributes are only materialized when set.
  if (isConstant)
    result.addAttribute(getConstantAttrName(result.name),
                        builder.getUnitAttr());
  if (value)
    result.addAttribute(getValueAttrName(result.name), value);
  if (dsoLocal)
    result.addAttribute(getDsoLocalAttrName(result.name),
                        builder.getUnitAttr());
  if (threadLocal)
    result.addAttribute(getThreadLocal_AttrName(result.name),
                        builder.getUnitAttr());
  if (comdat)
    result.addAttribute(getComdatAttrName(result.name), comdat);

  // Only add an alignment attribute if the "alignment" input
  // is different from 0. The value must also be a power of two, but
  // this is tested in GlobalOp::verify, not here.
  if (alignment != 0)
    result.addAttribute(getAlignmentAttrName(result.name),
                        builder.getI64IntegerAttr(alignment));

  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(builder.getContext(), linkage));
  // Address space 0 is not recorded explicitly.
  if (addrSpace != 0)
    result.addAttribute(getAddrSpaceAttrName(result.name),
                        builder.getI32IntegerAttr(addrSpace));
  // Caller-provided attributes go after the ones set above.
  result.attributes.append(attrs.begin(), attrs.end());

  if (!dbgExprs.empty())
    result.addAttribute(getDbgExprsAttrName(result.name),
                        ArrayAttr::get(builder.getContext(), dbgExprs));

  // Initializer region; left empty here.
  result.addRegion();
}
| |
| void GlobalOp::print(OpAsmPrinter &p) { |
| p << ' ' << stringifyLinkage(getLinkage()) << ' '; |
| StringRef visibility = stringifyVisibility(getVisibility_()); |
| if (!visibility.empty()) |
| p << visibility << ' '; |
| if (getThreadLocal_()) |
| p << "thread_local "; |
| if (auto unnamedAddr = getUnnamedAddr()) { |
| StringRef str = stringifyUnnamedAddr(*unnamedAddr); |
| if (!str.empty()) |
| p << str << ' '; |
| } |
| if (getConstant()) |
| p << "constant "; |
| p.printSymbolName(getSymName()); |
| p << '('; |
| if (auto value = getValueOrNull()) |
| p.printAttribute(value); |
| p << ')'; |
| if (auto comdat = getComdat()) |
| p << " comdat(" << *comdat << ')'; |
| |
| // Note that the alignment attribute is printed using the |
| // default syntax here, even though it is an inherent attribute |
| // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes) |
| p.printOptionalAttrDict((*this)->getAttrs(), |
| {SymbolTable::getSymbolAttrName(), |
| getGlobalTypeAttrName(), getConstantAttrName(), |
| getValueAttrName(), getLinkageAttrName(), |
| getUnnamedAddrAttrName(), getThreadLocal_AttrName(), |
| getVisibility_AttrName(), getComdatAttrName(), |
| getUnnamedAddrAttrName()}); |
| |
| // Print the trailing type unless it's a string global. |
| if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) |
| return; |
| p << " : " << getType(); |
| |
| Region &initializer = getInitializerRegion(); |
| if (!initializer.empty()) { |
| p << ' '; |
| p.printRegion(initializer, /*printEntryBlockArgs=*/false); |
| } |
| } |
| |
| static LogicalResult verifyComdat(Operation *op, |
| std::optional<SymbolRefAttr> attr) { |
| if (!attr) |
| return success(); |
| |
| auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr); |
| if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector)) |
| return op->emitError() << "expected comdat symbol"; |
| |
| return success(); |
| } |
| |
| static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) { |
| llvm::DenseSet<BlockTagAttr> blockTags; |
| BlockTagOp badBlockTagOp; |
| if (funcOp |
| .walk([&](BlockTagOp blockTagOp) { |
| if (blockTags.contains(blockTagOp.getTag())) { |
| badBlockTagOp = blockTagOp; |
| return WalkResult::interrupt(); |
| } |
| blockTags.insert(blockTagOp.getTag()); |
| return WalkResult::advance(); |
| }) |
| .wasInterrupted()) { |
| badBlockTagOp.emitError() |
| << "duplicate block tag '" << badBlockTagOp.getTag().getId() |
| << "' in the same function: "; |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| /// Parse common attributes that might show up in the same order in both |
| /// GlobalOp and AliasOp. |
| template <typename OpType> |
| static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser, |
| OperationState &result) { |
| MLIRContext *ctx = parser.getContext(); |
| // Parse optional linkage, default to External. |
| result.addAttribute(OpType::getLinkageAttrName(result.name), |
| LLVM::LinkageAttr::get( |
| ctx, parseOptionalLLVMKeyword<Linkage>( |
| parser, result, LLVM::Linkage::External))); |
| |
| // Parse optional visibility, default to Default. |
| result.addAttribute(OpType::getVisibility_AttrName(result.name), |
| parser.getBuilder().getI64IntegerAttr( |
| parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>( |
| parser, result, LLVM::Visibility::Default))); |
| |
| // Parse optional UnnamedAddr, default to None. |
| result.addAttribute(OpType::getUnnamedAddrAttrName(result.name), |
| parser.getBuilder().getI64IntegerAttr( |
| parseOptionalLLVMKeyword<UnnamedAddr, int64_t>( |
| parser, result, LLVM::UnnamedAddr::None))); |
| |
| if (succeeded(parser.parseOptionalKeyword("thread_local"))) |
| result.addAttribute(OpType::getThreadLocal_AttrName(result.name), |
| parser.getBuilder().getUnitAttr()); |
| |
| return success(); |
| } |
| |
// operation ::= `llvm.mlir.global` linkage? visibility?
//                 (`unnamed_addr` | `local_unnamed_addr`)?
//                 `thread_local`? `constant`? `@` identifier
//                 `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
//                 attribute-list? (`:` type)? region?
//
// The type can be omitted for string attributes, in which case it will be
// inferred from the value of the string as [strlen(value) x i8]. An
// initializer region is only accepted when the type is spelled explicitly.
ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
  // Call into common parsing between GlobalOp and AliasOp.
  if (parseCommonGlobalAndAlias<GlobalOp>(parser, result).failed())
    return failure();

  if (succeeded(parser.parseOptionalKeyword("constant")))
    result.addAttribute(getConstantAttrName(result.name),
                        parser.getBuilder().getUnitAttr());

  StringAttr name;
  if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
                             result.attributes) ||
      parser.parseLParen())
    return failure();

  // The parenthesized value attribute is optional: `()` means no value.
  Attribute value;
  if (parser.parseOptionalRParen()) {
    if (parser.parseAttribute(value, getValueAttrName(result.name),
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();

    result.addAttribute(getComdatAttrName(result.name), comdat);
  }

  SmallVector<Type, 1> types;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseOptionalColonTypeList(types))
    return failure();

  if (types.size() > 1)
    return parser.emitError(parser.getNameLoc(), "expected zero or one type");

  // The region is always created; it stays empty when not parsed.
  Region &initRegion = *result.addRegion();
  if (types.empty()) {
    // No explicit type: only legal for string values, typed as an i8 array of
    // the string's length.
    if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
      MLIRContext *context = parser.getContext();
      auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
                                                strAttr.getValue().size());
      types.push_back(arrayType);
    } else {
      return parser.emitError(parser.getNameLoc(),
                              "type can only be omitted for string globals");
    }
  } else {
    // Explicit type: an optional initializer region may follow.
    OptionalParseResult parseResult =
        parser.parseOptionalRegion(initRegion, /*arguments=*/{},
                                   /*argTypes=*/{});
    if (parseResult.has_value() && failed(*parseResult))
      return failure();
  }

  result.addAttribute(getGlobalTypeAttrName(result.name),
                      TypeAttr::get(types[0]));
  return success();
}
| |
| static bool isZeroAttribute(Attribute value) { |
| if (auto intValue = llvm::dyn_cast<IntegerAttr>(value)) |
| return intValue.getValue().isZero(); |
| if (auto fpValue = llvm::dyn_cast<FloatAttr>(value)) |
| return fpValue.getValue().isZero(); |
| if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(value)) |
| return isZeroAttribute(splatValue.getSplatValue<Attribute>()); |
| if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value)) |
| return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute); |
| if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value)) |
| return llvm::all_of(arrayValue.getValue(), isZeroAttribute); |
| return false; |
| } |
| |
| LogicalResult GlobalOp::verify() { |
| bool validType = isCompatibleOuterType(getType()) |
| ? !llvm::isa<LLVMVoidType, LLVMTokenType, |
| LLVMMetadataType, LLVMLabelType>(getType()) |
| : llvm::isa<PointerElementTypeInterface>(getType()); |
| if (!validType) |
| return emitOpError( |
| "expects type to be a valid element type for an LLVM global"); |
| if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) |
| return emitOpError("must appear at the module level"); |
| |
| if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) { |
| auto type = llvm::dyn_cast<LLVMArrayType>(getType()); |
| IntegerType elementType = |
| type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr; |
| if (!elementType || elementType.getWidth() != 8 || |
| type.getNumElements() != strAttr.getValue().size()) |
| return emitOpError( |
| "requires an i8 array type of the length equal to that of the string " |
| "attribute"); |
| } |
| |
| if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) { |
| if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal)) |
| return emitOpError() |
| << "this target extension type cannot be used in a global"; |
| |
| if (Attribute value = getValueOrNull()) |
| return emitOpError() << "global with target extension type can only be " |
| "initialized with zero-initializer"; |
| } |
| |
| if (getLinkage() == Linkage::Common) { |
| if (Attribute value = getValueOrNull()) { |
| if (!isZeroAttribute(value)) { |
| return emitOpError() |
| << "expected zero value for '" |
| << stringifyLinkage(Linkage::Common) << "' linkage"; |
| } |
| } |
| } |
| |
| if (getLinkage() == Linkage::Appending) { |
| if (!llvm::isa<LLVMArrayType>(getType())) { |
| return emitOpError() << "expected array type for '" |
| << stringifyLinkage(Linkage::Appending) |
| << "' linkage"; |
| } |
| } |
| |
| if (failed(verifyComdat(*this, getComdat()))) |
| return failure(); |
| |
| std::optional<uint64_t> alignAttr = getAlignment(); |
| if (alignAttr.has_value()) { |
| uint64_t value = alignAttr.value(); |
| if (!llvm::isPowerOf2_64(value)) |
| return emitError() << "alignment attribute is not a power of 2"; |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult GlobalOp::verifyRegions() { |
| if (Block *b = getInitializerBlock()) { |
| ReturnOp ret = cast<ReturnOp>(b->getTerminator()); |
| if (ret.operand_type_begin() == ret.operand_type_end()) |
| return emitOpError("initializer region cannot return void"); |
| if (*ret.operand_type_begin() != getType()) |
| return emitOpError("initializer region type ") |
| << *ret.operand_type_begin() << " does not match global type " |
| << getType(); |
| |
| for (Operation &op : *b) { |
| auto iface = dyn_cast<MemoryEffectOpInterface>(op); |
| if (!iface || !iface.hasNoEffect()) |
| return op.emitError() |
| << "ops with side effects not allowed in global initializers"; |
| } |
| |
| if (getValueOrNull()) |
| return emitOpError("cannot have both initializer value and region"); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LLVM::GlobalCtorsOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) { |
| if (data.empty()) |
| return success(); |
| |
| if (llvm::all_of(data.getAsRange<Attribute>(), [](Attribute v) { |
| return isa<FlatSymbolRefAttr, ZeroAttr>(v); |
| })) |
| return success(); |
| return op->emitError("data element must be symbol or #llvm.zero"); |
| } |
| |
| LogicalResult |
| GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| for (Attribute ctor : getCtors()) { |
| if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this, |
| symbolTable))) |
| return failure(); |
| } |
| return success(); |
| } |
| |
| LogicalResult GlobalCtorsOp::verify() { |
| if (checkGlobalXtorData(*this, getData()).failed()) |
| return failure(); |
| |
| if (getCtors().size() == getPriorities().size() && |
| getCtors().size() == getData().size()) |
| return success(); |
| return emitError( |
| "ctors, priorities, and data must have the same number of elements"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LLVM::GlobalDtorsOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| for (Attribute dtor : getDtors()) { |
| if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this, |
| symbolTable))) |
| return failure(); |
| } |
| return success(); |
| } |
| |
| LogicalResult GlobalDtorsOp::verify() { |
| if (checkGlobalXtorData(*this, getData()).failed()) |
| return failure(); |
| |
| if (getDtors().size() == getPriorities().size() && |
| getDtors().size() == getData().size()) |
| return success(); |
| return emitError( |
| "dtors, priorities, and data must have the same number of elements"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Builder, printer and verifier for LLVM::AliasOp. |
| //===----------------------------------------------------------------------===// |
| |
| void AliasOp::build(OpBuilder &builder, OperationState &result, Type type, |
| Linkage linkage, StringRef name, bool dsoLocal, |
| bool threadLocal, ArrayRef<NamedAttribute> attrs) { |
| result.addAttribute(getSymNameAttrName(result.name), |
| builder.getStringAttr(name)); |
| result.addAttribute(getAliasTypeAttrName(result.name), TypeAttr::get(type)); |
| if (dsoLocal) |
| result.addAttribute(getDsoLocalAttrName(result.name), |
| builder.getUnitAttr()); |
| if (threadLocal) |
| result.addAttribute(getThreadLocal_AttrName(result.name), |
| builder.getUnitAttr()); |
| |
| result.addAttribute(getLinkageAttrName(result.name), |
| LinkageAttr::get(builder.getContext(), linkage)); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| |
| result.addRegion(); |
| } |
| |
| void AliasOp::print(OpAsmPrinter &p) { |
| p << ' ' << stringifyLinkage(getLinkage()) << ' '; |
| StringRef visibility = stringifyVisibility(getVisibility_()); |
| if (!visibility.empty()) |
| p << visibility << ' '; |
| |
| if (std::optional<mlir::LLVM::UnnamedAddr> unnamedAddr = getUnnamedAddr()) { |
| StringRef str = stringifyUnnamedAddr(*unnamedAddr); |
| if (!str.empty()) |
| p << str << ' '; |
| } |
| |
| if (getThreadLocal_()) |
| p << "thread_local "; |
| |
| p.printSymbolName(getSymName()); |
| p.printOptionalAttrDict((*this)->getAttrs(), |
| {SymbolTable::getSymbolAttrName(), |
| getAliasTypeAttrName(), getLinkageAttrName(), |
| getUnnamedAddrAttrName(), getThreadLocal_AttrName(), |
| getVisibility_AttrName(), getUnnamedAddrAttrName()}); |
| |
| // Print the trailing type. |
| p << " : " << getType() << ' '; |
| // Print the initializer region. |
| p.printRegion(getInitializerRegion(), /*printEntryBlockArgs=*/false); |
| } |
| |
| // operation ::= `llvm.mlir.alias` linkage? visibility? |
| // (`unnamed_addr` | `local_unnamed_addr`)? |
| // `thread_local`? `@` identifier |
| // `(` attribute? `)` |
| // attribute-list? `:` type region |
| // |
| ParseResult AliasOp::parse(OpAsmParser &parser, OperationState &result) { |
| // Call into common parsing between GlobalOp and AliasOp. |
| if (parseCommonGlobalAndAlias<AliasOp>(parser, result).failed()) |
| return failure(); |
| |
| StringAttr name; |
| if (parser.parseSymbolName(name, getSymNameAttrName(result.name), |
| result.attributes)) |
| return failure(); |
| |
| SmallVector<Type, 1> types; |
| if (parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseOptionalColonTypeList(types)) |
| return failure(); |
| |
| if (types.size() > 1) |
| return parser.emitError(parser.getNameLoc(), "expected zero or one type"); |
| |
| Region &initRegion = *result.addRegion(); |
| if (parser.parseRegion(initRegion).failed()) |
| return failure(); |
| |
| result.addAttribute(getAliasTypeAttrName(result.name), |
| TypeAttr::get(types[0])); |
| return success(); |
| } |
| |
| LogicalResult AliasOp::verify() { |
| bool validType = isCompatibleOuterType(getType()) |
| ? !llvm::isa<LLVMVoidType, LLVMTokenType, |
| LLVMMetadataType, LLVMLabelType>(getType()) |
| : llvm::isa<PointerElementTypeInterface>(getType()); |
| if (!validType) |
| return emitOpError( |
| "expects type to be a valid element type for an LLVM global alias"); |
| |
| // This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp |
| switch (getLinkage()) { |
| case Linkage::External: |
| case Linkage::Internal: |
| case Linkage::Private: |
| case Linkage::Weak: |
| case Linkage::WeakODR: |
| case Linkage::Linkonce: |
| case Linkage::LinkonceODR: |
| case Linkage::AvailableExternally: |
| break; |
| default: |
| return emitOpError() |
| << "'" << stringifyLinkage(getLinkage()) |
| << "' linkage not supported in aliases, available options: private, " |
| "internal, linkonce, weak, linkonce_odr, weak_odr, external or " |
| "available_externally"; |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult AliasOp::verifyRegions() { |
| Block &b = getInitializerBlock(); |
| auto ret = cast<ReturnOp>(b.getTerminator()); |
| if (ret.getNumOperands() == 0 || |
| !isa<LLVM::LLVMPointerType>(ret.getOperand(0).getType())) |
| return emitOpError("initializer region must always return a pointer"); |
| |
| for (Operation &op : b) { |
| auto iface = dyn_cast<MemoryEffectOpInterface>(op); |
| if (!iface || !iface.hasNoEffect()) |
| return op.emitError() |
| << "ops with side effects are not allowed in alias initializers"; |
| } |
| |
| return success(); |
| } |
| |
| unsigned AliasOp::getAddrSpace() { |
| Block &initializer = getInitializerBlock(); |
| auto ret = cast<ReturnOp>(initializer.getTerminator()); |
| auto ptrTy = cast<LLVMPointerType>(ret.getOperand(0).getType()); |
| return ptrTy.getAddressSpace(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ShuffleVectorOp |
| //===----------------------------------------------------------------------===// |
| |
| void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1, |
| Value v2, DenseI32ArrayAttr mask, |
| ArrayRef<NamedAttribute> attrs) { |
| auto containerType = v1.getType(); |
| auto vType = LLVM::getVectorType( |
| cast<VectorType>(containerType).getElementType(), mask.size(), |
| LLVM::isScalableVectorType(containerType)); |
| build(builder, state, vType, v1, v2, mask); |
| state.addAttributes(attrs); |
| } |
| |
| void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1, |
| Value v2, ArrayRef<int32_t> mask) { |
| build(builder, state, v1, v2, builder.getDenseI32ArrayAttr(mask)); |
| } |
| |
| /// Build the result type of a shuffle vector operation. |
| static ParseResult parseShuffleType(AsmParser &parser, Type v1Type, |
| Type &resType, DenseI32ArrayAttr mask) { |
| if (!LLVM::isCompatibleVectorType(v1Type)) |
| return parser.emitError(parser.getCurrentLocation(), |
| "expected an LLVM compatible vector type"); |
| resType = |
| LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(), |
| mask.size(), LLVM::isScalableVectorType(v1Type)); |
| return success(); |
| } |
| |
| /// Nothing to do when the result type is inferred. |
| static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type, |
| Type resType, DenseI32ArrayAttr mask) {} |
| |
| LogicalResult ShuffleVectorOp::verify() { |
| if (LLVM::isScalableVectorType(getV1().getType()) && |
| llvm::any_of(getMask(), [](int32_t v) { return v != 0; })) |
| return emitOpError("expected a splat operation for scalable vectors"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Implementations for LLVM::LLVMFuncOp. |
| //===----------------------------------------------------------------------===// |
| |
| // Add the entry block to the function. |
| Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) { |
| assert(empty() && "function already has an entry block"); |
| OpBuilder::InsertionGuard g(builder); |
| Block *entry = builder.createBlock(&getBody()); |
| |
| // FIXME: Allow passing in proper locations for the entry arguments. |
| LLVMFunctionType type = getFunctionType(); |
| for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) |
| entry->addArgument(type.getParamType(i), getLoc()); |
| return entry; |
| } |
| |
| void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, |
| StringRef name, Type type, LLVM::Linkage linkage, |
| bool dsoLocal, CConv cconv, SymbolRefAttr comdat, |
| ArrayRef<NamedAttribute> attrs, |
| ArrayRef<DictionaryAttr> argAttrs, |
| std::optional<uint64_t> functionEntryCount) { |
| result.addRegion(); |
| result.addAttribute(SymbolTable::getSymbolAttrName(), |
| builder.getStringAttr(name)); |
| result.addAttribute(getFunctionTypeAttrName(result.name), |
| TypeAttr::get(type)); |
| result.addAttribute(getLinkageAttrName(result.name), |
| LinkageAttr::get(builder.getContext(), linkage)); |
| result.addAttribute(getCConvAttrName(result.name), |
| CConvAttr::get(builder.getContext(), cconv)); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| if (dsoLocal) |
| result.addAttribute(getDsoLocalAttrName(result.name), |
| builder.getUnitAttr()); |
| if (comdat) |
| result.addAttribute(getComdatAttrName(result.name), comdat); |
| if (functionEntryCount) |
| result.addAttribute(getFunctionEntryCountAttrName(result.name), |
| builder.getI64IntegerAttr(functionEntryCount.value())); |
| if (argAttrs.empty()) |
| return; |
| |
| assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() && |
| "expected as many argument attribute lists as arguments"); |
| call_interface_impl::addArgAndResultAttrs( |
| builder, result, argAttrs, /*resultAttrs=*/std::nullopt, |
| getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
| } |
| |
| // Builds an LLVM function type from the given lists of input and output types. |
| // Returns a null type if any of the types provided are non-LLVM types, or if |
| // there is more than one output type. |
| static Type |
| buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef<Type> inputs, |
| ArrayRef<Type> outputs, |
| function_interface_impl::VariadicFlag variadicFlag) { |
| Builder &b = parser.getBuilder(); |
| if (outputs.size() > 1) { |
| parser.emitError(loc, "failed to construct function type: expected zero or " |
| "one function result"); |
| return {}; |
| } |
| |
| // Convert inputs to LLVM types, exit early on error. |
| SmallVector<Type, 4> llvmInputs; |
| for (auto t : inputs) { |
| if (!isCompatibleType(t)) { |
| parser.emitError(loc, "failed to construct function type: expected LLVM " |
| "type for function arguments"); |
| return {}; |
| } |
| llvmInputs.push_back(t); |
| } |
| |
| // No output is denoted as "void" in LLVM type system. |
| Type llvmOutput = |
| outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); |
| if (!isCompatibleType(llvmOutput)) { |
| parser.emitError(loc, "failed to construct function type: expected LLVM " |
| "type for function results") |
| << llvmOutput; |
| return {}; |
| } |
| return LLVMFunctionType::get(llvmOutput, llvmInputs, |
| variadicFlag.isVariadic()); |
| } |
| |
// Parses an LLVM function.
//
// operation ::= `llvm.func` linkage? visibility? unnamed-addr? cconv?
//               function-signature
//               (`vscale_range(` integer `,` integer `)`)?
//               (`comdat(` symbol-ref-id `)`)?
//               function-attributes?
//               function-body?
//
// Absent linkage, visibility, unnamed_addr and cconv keywords are still
// materialized as attributes, holding the defaults passed to
// parseOptionalLLVMKeyword below. `function-attributes` is an attribute
// dictionary introduced by the `attributes` keyword.
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Default to external linkage if no keyword is provided.
  result.addAttribute(
      getLinkageAttrName(result.name),
      LinkageAttr::get(parser.getContext(),
                       parseOptionalLLVMKeyword<Linkage>(
                           parser, result, LLVM::Linkage::External)));

  // Parse optional visibility, default to Default. Stored as an i64 integer
  // attribute.
  result.addAttribute(getVisibility_AttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
                              parser, result, LLVM::Visibility::Default)));

  // Parse optional UnnamedAddr, default to None. Stored as an i64 integer
  // attribute.
  result.addAttribute(getUnnamedAddrAttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, result, LLVM::UnnamedAddr::None)));

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
                                              parser, result, LLVM::CConv::C)));

  StringAttr nameAttr;
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  // Out-parameter filled in by parseFunctionSignatureWithArguments below.
  bool isVariadic;

  // Remember where the signature starts so type-construction errors point at
  // it.
  auto signatureLocation = parser.getCurrentLocation();
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // Assemble the LLVM function type from the parsed argument/result types.
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type =
      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
                            function_interface_impl::VariadicFlag(isVariadic));
  if (!type)
    return failure();
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  // Parse the optional `vscale_range(min, max)`; both bounds become i32
  // IntegerAttrs. NOTE(review): the int64_t values are not range-checked
  // against i32 here — confirm out-of-range inputs are rejected elsewhere.
  if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
    int64_t minRange, maxRange;
    if (parser.parseLParen() || parser.parseInteger(minRange) ||
        parser.parseComma() || parser.parseInteger(maxRange) ||
        parser.parseRParen())
      return failure();
    auto intTy = IntegerType::get(parser.getContext(), 32);
    result.addAttribute(
        getVscaleRangeAttrName(result.name),
        LLVM::VScaleRangeAttr::get(parser.getContext(),
                                   IntegerAttr::get(intTy, minRange),
                                   IntegerAttr::get(intTy, maxRange)));
  }
  // Parse the optional comdat selector.
  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();

    result.addAttribute(getComdatAttrName(result.name), comdat);
  }

  // Parse `attributes {...}`, then attach per-argument and per-result
  // attribute dictionaries collected from the signature.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  call_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, entryArgs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

  // The body is optional. Fail only if a region was present and failed to
  // parse.
  auto *body = result.addRegion();
  OptionalParseResult parseResult =
      parser.parseOptionalRegion(*body, entryArgs);
  return failure(parseResult.has_value() && failed(*parseResult));
}
| |
| // Print the LLVMFuncOp. Collects argument and result types and passes them to |
| // helper functions. Drops "void" result since it cannot be parsed back. Skips |
| // the external linkage since it is the default value. |
| void LLVMFuncOp::print(OpAsmPrinter &p) { |
| p << ' '; |
| if (getLinkage() != LLVM::Linkage::External) |
| p << stringifyLinkage(getLinkage()) << ' '; |
| StringRef visibility = stringifyVisibility(getVisibility_()); |
| if (!visibility.empty()) |
| p << visibility << ' '; |
| if (auto unnamedAddr = getUnnamedAddr()) { |
| StringRef str = stringifyUnnamedAddr(*unnamedAddr); |
| if (!str.empty()) |
| p << str << ' '; |
| } |
| if (getCConv() != LLVM::CConv::C) |
| p << stringifyCConv(getCConv()) << ' '; |
| |
| p.printSymbolName(getName()); |
| |
| LLVMFunctionType fnType = getFunctionType(); |
| SmallVector<Type, 8> argTypes; |
| SmallVector<Type, 1> resTypes; |
| argTypes.reserve(fnType.getNumParams()); |
| for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) |
| argTypes.push_back(fnType.getParamType(i)); |
| |
| Type returnType = fnType.getReturnType(); |
| if (!llvm::isa<LLVMVoidType>(returnType)) |
| resTypes.push_back(returnType); |
| |
| function_interface_impl::printFunctionSignature(p, *this, argTypes, |
| isVarArg(), resTypes); |
| |
| // Print vscale range if present |
| if (std::optional<VScaleRangeAttr> vscale = getVscaleRange()) |
| p << " vscale_range(" << vscale->getMinRange().getInt() << ", " |
| << vscale->getMaxRange().getInt() << ')'; |
| |
| // Print the optional comdat selector. |
| if (auto comdat = getComdat()) |
| p << " comdat(" << *comdat << ')'; |
| |
| function_interface_impl::printFunctionAttributes( |
| p, *this, |
| {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), |
| getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(), |
| getComdatAttrName(), getUnnamedAddrAttrName(), |
| getVscaleRangeAttrName()}); |
| |
| // Print the body if this is not an external function. |
| Region &body = getBody(); |
| if (!body.empty()) { |
| p << ' '; |
| p.printRegion(body, /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/true); |
| } |
| } |
| |
// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
// - functions don't have 'common' linkage;
// - the optional comdat reference passes verifyComdat;
// - external functions have 'external' or 'extern_weak' linkage;
// - functions with a body:
//   - do not combine no_inline with always_inline,
//   - are no_inline whenever they are optimize_none,
//   - use one consistent type for every 'llvm.landingpad' result and
//     'llvm.resume' operand,
//   - pass verifyBlockTags.
LogicalResult LLVMFuncOp::verify() {
  if (getLinkage() == LLVM::Linkage::Common)
    return emitOpError() << "functions cannot have '"
                         << stringifyLinkage(LLVM::Linkage::Common)
                         << "' linkage";

  if (failed(verifyComdat(*this, getComdat())))
    return failure();

  // For declarations only the linkage is checked; everything after this block
  // applies to functions with a body.
  if (isExternal()) {
    if (getLinkage() != LLVM::Linkage::External &&
        getLinkage() != LLVM::Linkage::ExternWeak)
      return emitOpError() << "external functions must have '"
                           << stringifyLinkage(LLVM::Linkage::External)
                           << "' or '"
                           << stringifyLinkage(LLVM::Linkage::ExternWeak)
                           << "' linkage";
    return success();
  }

  // In LLVM IR, these attributes are composed by convention, not by design.
  if (isNoInline() && isAlwaysInline())
    return emitError("no_inline and always_inline attributes are incompatible");

  if (isOptimizeNone() && !isNoInline())
    return emitOpError("with optimize_none must also be no_inline");

  // Walk all nested ops. The first landingpad result / resume operand type
  // seen becomes the expected type. A later mismatch records the matching
  // diagnostic and interrupts the walk. The messages are string literals, so
  // `diagnosticMessage` stays valid after the lambdas return.
  Type landingpadResultTy;
  StringRef diagnosticMessage;
  bool isLandingpadTypeConsistent =
      !walk([&](Operation *op) {
         const auto checkType = [&](Type type, StringRef errorMessage) {
           if (!landingpadResultTy) {
             landingpadResultTy = type;
             return WalkResult::advance();
           }
           if (landingpadResultTy != type) {
             diagnosticMessage = errorMessage;
             return WalkResult::interrupt();
           }
           return WalkResult::advance();
         };
         return TypeSwitch<Operation *, WalkResult>(op)
             .Case<LandingpadOp>([&](auto landingpad) {
               constexpr StringLiteral errorMessage =
                   "'llvm.landingpad' should have a consistent result type "
                   "inside a function";
               return checkType(landingpad.getType(), errorMessage);
             })
             .Case<ResumeOp>([&](auto resume) {
               constexpr StringLiteral errorMessage =
                   "'llvm.resume' should have a consistent input type inside a "
                   "function";
               return checkType(resume.getValue().getType(), errorMessage);
             })
             .Default([](auto) { return WalkResult::skip(); });
       }).wasInterrupted();
  if (!isLandingpadTypeConsistent) {
    assert(!diagnosticMessage.empty() &&
           "Expecting a non-empty diagnostic message");
    return emitError(diagnosticMessage);
  }

  if (failed(verifyBlockTags(*this)))
    return failure();

  return success();
}
| |
| /// Verifies LLVM- and implementation-specific properties of the LLVM func Op: |
| /// - entry block arguments are of LLVM types. |
| LogicalResult LLVMFuncOp::verifyRegions() { |
| if (isExternal()) |
| return success(); |
| |
| unsigned numArguments = getFunctionType().getNumParams(); |
| Block &entryBlock = front(); |
| for (unsigned i = 0; i < numArguments; ++i) { |
| Type argType = entryBlock.getArgument(i).getType(); |
| if (!isCompatibleType(argType)) |
| return emitOpError("entry block argument #") |
| << i << " is not of LLVM type"; |
| } |
| |
| return success(); |
| } |
| |
| Region *LLVMFuncOp::getCallableRegion() { |
| if (isExternal()) |
| return nullptr; |
| return &getBody(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // UndefOp. |
| //===----------------------------------------------------------------------===// |
| |
| /// Fold an undef operation to a dedicated undef attribute. |
| OpFoldResult LLVM::UndefOp::fold(FoldAdaptor) { |
| return LLVM::UndefAttr::get(getContext()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // PoisonOp. |
| //===----------------------------------------------------------------------===// |
| |
| /// Fold a poison operation to a dedicated poison attribute. |
| OpFoldResult LLVM::PoisonOp::fold(FoldAdaptor) { |
| return LLVM::PoisonAttr::get(getContext()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ZeroOp. |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult LLVM::ZeroOp::verify() { |
| if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) |
| if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit)) |
| return emitOpError() |
| << "target extension type does not support zero-initializer"; |
| |
| return success(); |
| } |
| |
| /// Fold a zero operation to a builtin zero attribute when possible and fall |
| /// back to a dedicated zero attribute. |
| OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) { |
| OpFoldResult result = Builder(getContext()).getZeroAttr(getType()); |
| if (result) |
| return result; |
| return LLVM::ZeroAttr::get(getContext()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConstantOp. |
| //===----------------------------------------------------------------------===// |
| |
| /// Compute the total number of elements in the given type, also taking into |
| /// account nested types. Supported types are `VectorType` and `LLVMArrayType`. |
| /// Everything else is treated as a scalar. |
| static int64_t getNumElements(Type t) { |
| if (auto vecType = dyn_cast<VectorType>(t)) { |
| assert(!vecType.isScalable() && |
| "number of elements of a scalable vector type is unknown"); |
| return vecType.getNumElements() * getNumElements(vecType.getElementType()); |
| } |
| if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t)) |
| return arrayType.getNumElements() * |
| getNumElements(arrayType.getElementType()); |
| return 1; |
| } |
| |
| /// Check if the given type is a scalable vector type or a vector/array type |
| /// that contains a nested scalable vector type. |
| static bool hasScalableVectorType(Type t) { |
| if (auto vecType = dyn_cast<VectorType>(t)) { |
| if (vecType.isScalable()) |
| return true; |
| return hasScalableVectorType(vecType.getElementType()); |
| } |
| if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t)) |
| return hasScalableVectorType(arrayType.getElementType()); |
| return false; |
| } |
| |
/// Verifies that the constant's value attribute agrees with its result type.
/// Cases are checked in this order:
/// - StringAttr: the type must be an LLVM array of exactly as many i8
///   elements as the string has bytes.
/// - Struct result type: the value must be an ArrayAttr with one
///   IntegerAttr/FloatAttr per struct element. Each element type must be an
///   integer, float or LLVMPPCFP128Type, and the attribute types must match
///   the element types exactly.
/// - Target extension result types are rejected.
/// - IntegerAttr: the type must be an integer type.
/// - FloatAttr: a float type must have the attribute's bit width, and an
///   integer type must have that same width. Other types are not checked
///   here.
/// - ElementsAttr/ArrayAttr: a type containing a scalable vector requires a
///   splat attribute. Otherwise the type must be a vector or LLVM array whose
///   total (nested) element count equals the attribute's element count.
/// - Any other value attribute is rejected.
LogicalResult LLVM::ConstantOp::verify() {
  if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
    auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
    if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
        !arrayType.getElementType().isInteger(8)) {
      return emitOpError() << "expected array type of "
                           << sAttr.getValue().size()
                           << " i8 elements for the string constant";
    }
    return success();
  }
  if (auto structType = dyn_cast<LLVMStructType>(getType())) {
    auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
    if (!arrayAttr) {
      return emitOpError() << "expected array attribute for a struct constant";
    }

    ArrayRef<Type> elementTypes = structType.getBody();
    if (arrayAttr.size() != elementTypes.size()) {
      return emitOpError() << "expected array attribute of size "
                           << elementTypes.size();
    }
    // Only flat structs of scalar integer/float members are supported.
    for (auto elementTy : elementTypes) {
      if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
        return emitOpError() << "expected struct element types to be floating "
                                "point type or integer type";
      }
    }

    // Each member attribute must be a scalar whose type matches exactly.
    for (size_t i = 0; i < elementTypes.size(); ++i) {
      Attribute element = arrayAttr[i];
      if (!isa<IntegerAttr, FloatAttr>(element)) {
        return emitOpError()
               << "expected struct element attribute types to be floating "
                  "point type or integer type";
      }
      auto elementType = cast<TypedAttr>(element).getType();
      if (elementType != elementTypes[i]) {
        return emitOpError()
               << "struct element at index " << i << " is of wrong type";
      }
    }

    return success();
  }
  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
    return emitOpError() << "does not support target extension type.";
  }

  // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
  if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
    if (!llvm::isa<IntegerType>(getType()))
      return emitOpError() << "expected integer type";
  } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
    // Bit width of the attribute's float semantics; the result type must
    // agree with it.
    const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
    unsigned floatWidth = APFloat::getSizeInBits(sem);
    if (auto floatTy = dyn_cast<FloatType>(getType())) {
      if (floatTy.getWidth() != floatWidth) {
        return emitOpError() << "expected float type of width " << floatWidth;
      }
    }
    // See the comment for getLLVMConstant for more details about why 8-bit
    // floats can be represented by integers.
    if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
      return emitOpError() << "expected integer type of width " << floatWidth;
    }
  } else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
    if (hasScalableVectorType(getType())) {
      // The exact number of elements of a scalable vector is unknown, so we
      // allow only splat attributes.
      auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue());
      if (!splatElementsAttr)
        return emitOpError()
               << "scalable vector type requires a splat attribute";
      return success();
    }
    if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
      return emitOpError() << "expected vector or array type";
    // The number of elements of the attribute and the type must match.
    int64_t attrNumElements;
    if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
      attrNumElements = elementsAttr.getNumElements();
    else
      attrNumElements = cast<ArrayAttr>(getValue()).size();
    if (getNumElements(getType()) != attrNumElements)
      return emitOpError()
             << "type and attribute have a different number of elements: "
             << getNumElements(getType()) << " vs. " << attrNumElements;
  } else {
    return emitOpError()
           << "only supports integer, float, string or elements attributes";
  }

  return success();
}
| |
| bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) { |
| // The value's type must be the same as the provided type. |
| auto typedAttr = dyn_cast<TypedAttr>(value); |
| if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type)) |
| return false; |
| // The value's type must be an LLVM compatible type. |
| if (!isCompatibleType(type)) |
| return false; |
| // TODO: Add support for additional attributes kinds once needed. |
| return isa<IntegerAttr, FloatAttr, ElementsAttr>(value); |
| } |
| |
| ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value, |
| Type type, Location loc) { |
| if (isBuildableWith(value, type)) |
| return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value)); |
| return nullptr; |
| } |
| |
| // Constant op constant-folds to its value. |
| OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); } |
| |
| //===----------------------------------------------------------------------===// |
| // AtomicRMWOp |
| //===----------------------------------------------------------------------===// |
| |
| void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, |
| AtomicBinOp binOp, Value ptr, Value val, |
| AtomicOrdering ordering, StringRef syncscope, |
| unsigned alignment, bool isVolatile) { |
| build(builder, state, val.getType(), binOp, ptr, val, ordering, |
| !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr, |
| alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile, |
| /*access_groups=*/nullptr, |
| /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); |
| } |
| |
| LogicalResult AtomicRMWOp::verify() { |
| auto valType = getVal().getType(); |
| if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub || |
| getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) { |
| if (isCompatibleVectorType(valType)) { |
| if (isScalableVectorType(valType)) |
| return emitOpError("expected LLVM IR fixed vector type"); |
| Type elemType = llvm::cast<VectorType>(valType).getElementType(); |
| if (!isCompatibleFloatingPointType(elemType)) |
| return emitOpError( |
| "expected LLVM IR floating point type for vector element"); |
| } else if (!isCompatibleFloatingPointType(valType)) { |
| return emitOpError("expected LLVM IR floating point type"); |
| } |
| } else if (getBinOp() == AtomicBinOp::xchg) { |
| DataLayout dataLayout = DataLayout::closest(*this); |
| if (!isTypeCompatibleWithAtomicOp(valType, dataLayout)) |
| return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); |
| } else { |
| auto intType = llvm::dyn_cast<IntegerType>(valType); |
| unsigned intBitWidth = intType ? intType.getWidth() : 0; |
| if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && |
| intBitWidth != 64) |
| return emitOpError("expected LLVM IR integer type"); |
| } |
| |
| if (static_cast<unsigned>(getOrdering()) < |
| static_cast<unsigned>(AtomicOrdering::monotonic)) |
| return emitOpError() << "expected at least '" |
| << stringifyAtomicOrdering(AtomicOrdering::monotonic) |
| << "' ordering"; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AtomicCmpXchgOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns an LLVM struct type that contains a value type and a boolean type. |
| static LLVMStructType getValAndBoolStructType(Type valType) { |
| auto boolType = IntegerType::get(valType.getContext(), 1); |
| return LLVMStructType::getLiteral(valType.getContext(), {valType, boolType}); |
| } |
| |
| void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state, |
| Value ptr, Value cmp, Value val, |
| AtomicOrdering successOrdering, |
| AtomicOrdering failureOrdering, StringRef syncscope, |
| unsigned alignment, bool isWeak, bool isVolatile) { |
| build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val, |
| successOrdering, failureOrdering, |
| !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr, |
| alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak, |
| isVolatile, /*access_groups=*/nullptr, |
| /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); |
| } |
| |
| LogicalResult AtomicCmpXchgOp::verify() { |
| auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()); |
| if (!ptrType) |
| return emitOpError("expected LLVM IR pointer type for operand #0"); |
| auto valType = getVal().getType(); |
| DataLayout dataLayout = DataLayout::closest(*this); |
| if (!isTypeCompatibleWithAtomicOp(valType, dataLayout)) |
| return emitOpError("unexpected LLVM IR type"); |
| if (getSuccessOrdering() < AtomicOrdering::monotonic || |
| getFailureOrdering() < AtomicOrdering::monotonic) |
| return emitOpError("ordering must be at least 'monotonic'"); |
| if (getFailureOrdering() == AtomicOrdering::release || |
| getFailureOrdering() == AtomicOrdering::acq_rel) |
| return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FenceOp |
| //===----------------------------------------------------------------------===// |
| |
| void FenceOp::build(OpBuilder &builder, OperationState &state, |
| AtomicOrdering ordering, StringRef syncscope) { |
| build(builder, state, ordering, |
| syncscope.empty() ? nullptr : builder.getStringAttr(syncscope)); |
| } |
| |
| LogicalResult FenceOp::verify() { |
| if (getOrdering() == AtomicOrdering::not_atomic || |
| getOrdering() == AtomicOrdering::unordered || |
| getOrdering() == AtomicOrdering::monotonic) |
| return emitOpError("can be given only acquire, release, acq_rel, " |
| "and seq_cst orderings"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Verifier for extension ops |
| //===----------------------------------------------------------------------===// |
| |
| /// Verifies that the given extension operation operates on consistent scalars |
| /// or vectors, and that the target width is larger than the input width. |
| template <class ExtOp> |
| static LogicalResult verifyExtOp(ExtOp op) { |
| IntegerType inputType, outputType; |
| if (isCompatibleVectorType(op.getArg().getType())) { |
| if (!isCompatibleVectorType(op.getResult().getType())) |
| return op.emitError( |
| "input type is a vector but output type is an integer"); |
| if (getVectorNumElements(op.getArg().getType()) != |
| getVectorNumElements(op.getResult().getType())) |
| return op.emitError("input and output vectors are of incompatible shape"); |
| // Because this is a CastOp, the element of vectors is guaranteed to be an |
| // integer. |
| inputType = cast<IntegerType>( |
| cast<VectorType>(op.getArg().getType()).getElementType()); |
| outputType = cast<IntegerType>( |
| cast<VectorType>(op.getResult().getType()).getElementType()); |
| } else { |
| // Because this is a CastOp and arg is not a vector, arg is guaranteed to be |
| // an integer. |
| inputType = cast<IntegerType>(op.getArg().getType()); |
| outputType = dyn_cast<IntegerType>(op.getResult().getType()); |
| if (!outputType) |
| return op.emitError( |
| "input type is an integer but output type is a vector"); |
| } |
| |
| if (outputType.getWidth() <= inputType.getWidth()) |
| return op.emitError("integer width of the output type is smaller or " |
| "equal to the integer width of the input type"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ZExtOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); } |
| |
| OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) { |
| auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg()); |
| if (!arg) |
| return {}; |
| |
| size_t targetSize = cast<IntegerType>(getType()).getWidth(); |
| return IntegerAttr::get(getType(), arg.getValue().zext(targetSize)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SExtOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); } |
| |
| //===----------------------------------------------------------------------===// |
| // Folder and verifier for LLVM::BitcastOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Folds a cast op that can be chained. |
| template <typename T> |
| static OpFoldResult foldChainableCast(T castOp, |
| typename T::FoldAdaptor adaptor) { |
| // cast(x : T0, T0) -> x |
| if (castOp.getArg().getType() == castOp.getType()) |
| return castOp.getArg(); |
| if (auto prev = castOp.getArg().template getDefiningOp<T>()) { |
| // cast(cast(x : T0, T1), T0) -> x |
| if (prev.getArg().getType() == castOp.getType()) |
| return prev.getArg(); |
| // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2) |
| castOp.getArgMutable().set(prev.getArg()); |
| return Value{castOp}; |
| } |
| return {}; |
| } |
| |
| OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) { |
| return foldChainableCast(*this, adaptor); |
| } |
| |
| LogicalResult LLVM::BitcastOp::verify() { |
| auto resultType = llvm::dyn_cast<LLVMPointerType>( |
| extractVectorElementType(getResult().getType())); |
| auto sourceType = llvm::dyn_cast<LLVMPointerType>( |
| extractVectorElementType(getArg().getType())); |
| |
| // If one of the types is a pointer (or vector of pointers), then |
| // both source and result type have to be pointers. |
| if (static_cast<bool>(resultType) != static_cast<bool>(sourceType)) |
| return emitOpError("can only cast pointers from and to pointers"); |
| |
| if (!resultType) |
| return success(); |
| |
| auto isVector = llvm::IsaPred<VectorType>; |
| |
| // Due to bitcast requiring both operands to be of the same size, it is not |
| // possible for only one of the two to be a pointer of vectors. |
| if (isVector(getResult().getType()) && !isVector(getArg().getType())) |
| return emitOpError("cannot cast pointer to vector of pointers"); |
| |
| if (!isVector(getResult().getType()) && isVector(getArg().getType())) |
| return emitOpError("cannot cast vector of pointers to pointer"); |
| |
| // Bitcast cannot cast between pointers of different address spaces. |
| // 'llvm.addrspacecast' must be used for this purpose instead. |
| if (resultType.getAddressSpace() != sourceType.getAddressSpace()) |
| return emitOpError("cannot cast pointers of different address spaces, " |
| "use 'llvm.addrspacecast' instead"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Folder for LLVM::AddrSpaceCastOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) { |
| return foldChainableCast(*this, adaptor); |
| } |
| |
| Value LLVM::AddrSpaceCastOp::getViewSource() { return getArg(); } |
| |
| //===----------------------------------------------------------------------===// |
| // Folder for LLVM::GEPOp |
| //===----------------------------------------------------------------------===// |
| |
/// Folds an `llvm.getelementptr`. Two folds are attempted:
///  1. A single zero index with a result type equal to the base type folds
///     to the base pointer.
///  2. Dynamic indices whose operand folded to an integer constant that fits
///     in kGEPConstantBitWidth signed bits are moved into the raw constant
///     indices. The op is updated in place and its own result is returned to
///     signal the in-place update.
OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
  // Uniform view over all indices: constant indices come from the raw
  // constant index attribute, dynamic ones from the folded operand
  // attributes. Entries may be null (hence dyn_cast_or_null below).
  GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
                                                 adaptor.getDynamicIndices());

  // gep %x:T, 0 -> %x
  if (getBase().getType() == getType() && indices.size() == 1)
    if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
      if (integer.getValue().isZero())
        return getBase();

  // Canonicalize any dynamic indices of constant value to constant indices.
  bool changed = false;
  SmallVector<GEPArg> gepArgs;
  for (auto iter : llvm::enumerate(indices)) {
    auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
    // Constant indices can only be int32_t, so if integer does not fit we
    // are forced to keep it dynamic, despite being a constant.
    if (!indices.isDynamicIndex(iter.index()) || !integer ||
        !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {

      // Keep this index as it currently is on the op: either the original
      // dynamic SSA value or the existing constant.
      PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
      if (Value val = llvm::dyn_cast_if_present<Value>(existing))
        gepArgs.emplace_back(val);
      else
        gepArgs.emplace_back(cast<IntegerAttr>(existing).getInt());

      continue;
    }

    // A dynamic index with a small enough constant value: promote it.
    changed = true;
    gepArgs.emplace_back(integer.getInt());
  }
  if (changed) {
    // Re-split the combined index list into raw constant indices and
    // remaining dynamic operands, then update the op in place.
    SmallVector<int32_t> rawConstantIndices;
    SmallVector<Value> dynamicIndices;
    destructureIndices(getElemType(), gepArgs, rawConstantIndices,
                       dynamicIndices);

    getDynamicIndicesMutable().assign(dynamicIndices);
    setRawConstantIndices(rawConstantIndices);
    return Value{*this};
  }

  return {};
}
| |
| Value LLVM::GEPOp::getViewSource() { return getBase(); } |
| |
| //===----------------------------------------------------------------------===// |
| // ShlOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) { |
| auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()); |
| if (!rhs) |
| return {}; |
| |
| if (rhs.getValue().getZExtValue() >= |
| getLhs().getType().getIntOrFloatBitWidth()) |
| return {}; // TODO: Fold into poison. |
| |
| auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs()); |
| if (!lhs) |
| return {}; |
| |
| return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue())); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OrOp |
| //===----------------------------------------------------------------------===// |
| |
| OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) { |
| auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs()); |
| if (!lhs) |
| return {}; |
| |
| auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()); |
| if (!rhs) |
| return {}; |
| |
| return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CallIntrinsicOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult CallIntrinsicOp::verify() { |
| if (!getIntrin().starts_with("llvm.")) |
| return emitOpError() << "intrinsic name must start with 'llvm.'"; |
| if (failed(verifyOperandBundles(*this))) |
| return failure(); |
| return success(); |
| } |
| |
| void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, |
| mlir::StringAttr intrin, mlir::ValueRange args) { |
| build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args, |
| FastmathFlagsAttr{}, |
| /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{}, |
| /*res_attrs=*/{}); |
| } |
| |
| void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, |
| mlir::StringAttr intrin, mlir::ValueRange args, |
| mlir::LLVM::FastmathFlagsAttr fastMathFlags) { |
| build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args, |
| fastMathFlags, |
| /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{}, |
| /*res_attrs=*/{}); |
| } |
| |
| void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, |
| mlir::Type resultType, mlir::StringAttr intrin, |
| mlir::ValueRange args) { |
| build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{}, |
| /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{}, |
| /*res_attrs=*/{}); |
| } |
| |
| void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state, |
| mlir::TypeRange resultTypes, |
| mlir::StringAttr intrin, mlir::ValueRange args, |
| mlir::LLVM::FastmathFlagsAttr fastMathFlags) { |
| build(builder, state, resultTypes, intrin, args, fastMathFlags, |
| /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{}, |
| /*res_attrs=*/{}); |
| } |
| |
| ParseResult CallIntrinsicOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| StringAttr intrinAttr; |
| SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; |
| SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> opBundleOperands; |
| SmallVector<SmallVector<Type>> opBundleOperandTypes; |
| ArrayAttr opBundleTags; |
| |
| // Parse intrinsic name. |
| if (parser.parseCustomAttributeWithFallback( |
| intrinAttr, parser.getBuilder().getType<NoneType>())) |
| return failure(); |
| result.addAttribute(CallIntrinsicOp::getIntrinAttrName(result.name), |
| intrinAttr); |
| |
| if (parser.parseLParen()) |
| return failure(); |
| |
| // Parse the function arguments. |
| if (parser.parseOperandList(operands)) |
| return mlir::failure(); |
| |
| if (parser.parseRParen()) |
| return mlir::failure(); |
| |
| // Handle bundles. |
| SMLoc opBundlesLoc = parser.getCurrentLocation(); |
| if (std::optional<ParseResult> result = parseOpBundles( |
| parser, opBundleOperands, opBundleOperandTypes, opBundleTags); |
| result && failed(*result)) |
| return failure(); |
| if (opBundleTags && !opBundleTags.empty()) |
| result.addAttribute( |
| CallIntrinsicOp::getOpBundleTagsAttrName(result.name).getValue(), |
| opBundleTags); |
| |
| if (parser.parseOptionalAttrDict(result.attributes)) |
| return mlir::failure(); |
| |
| SmallVector<DictionaryAttr> argAttrs; |
| SmallVector<DictionaryAttr> resultAttrs; |
| if (parseCallTypeAndResolveOperands(parser, result, /*isDirect=*/true, |
| operands, argAttrs, resultAttrs)) |
| return failure(); |
| call_interface_impl::addArgAndResultAttrs( |
| parser.getBuilder(), result, argAttrs, resultAttrs, |
| getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
| |
| if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands, |
| opBundleOperandTypes, |
| getOpBundleSizesAttrName(result.name))) |
| return failure(); |
| |
| int32_t numOpBundleOperands = 0; |
| for (const auto &operands : opBundleOperands) |
| numOpBundleOperands += operands.size(); |
| |
| result.addAttribute( |
| CallIntrinsicOp::getOperandSegmentSizeAttr(), |
| parser.getBuilder().getDenseI32ArrayAttr( |
| {static_cast<int32_t>(operands.size()), numOpBundleOperands})); |
| |
| return mlir::success(); |
| } |
| |
| void CallIntrinsicOp::print(OpAsmPrinter &p) { |
| p << ' '; |
| p.printAttributeWithoutType(getIntrinAttr()); |
| |
| OperandRange args = getArgs(); |
| p << "(" << args << ")"; |
| |
| // Operand bundles. |
| if (!getOpBundleOperands().empty()) { |
| p << ' '; |
| printOpBundles(p, *this, getOpBundleOperands(), |
| getOpBundleOperands().getTypes(), getOpBundleTagsAttr()); |
| } |
| |
| p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), |
| {getOperandSegmentSizesAttrName(), |
| getOpBundleSizesAttrName(), getIntrinAttrName(), |
| getOpBundleTagsAttrName(), getArgAttrsAttrName(), |
| getResAttrsAttrName()}); |
| |
| p << " : "; |
| |
| // Reconstruct the MLIR function type from operand and result types. |
| call_interface_impl::printFunctionSignature( |
| p, args.getTypes(), getArgAttrsAttr(), |
| /*isVariadic=*/false, getResultTypes(), getResAttrsAttr()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpAsmDialectInterface |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface { |
| using OpAsmDialectInterface::OpAsmDialectInterface; |
| |
| AliasResult getAlias(Attribute attr, raw_ostream &os) const override { |
| return TypeSwitch<Attribute, AliasResult>(attr) |
| .Case<AccessGroupAttr, AliasScopeAttr, AliasScopeDomainAttr, |
| DIBasicTypeAttr, DICommonBlockAttr, DICompileUnitAttr, |
| DICompositeTypeAttr, DIDerivedTypeAttr, DIFileAttr, |
| DIGlobalVariableAttr, DIGlobalVariableExpressionAttr, |
| DIImportedEntityAttr, DILabelAttr, DILexicalBlockAttr, |
| DILexicalBlockFileAttr, DILocalVariableAttr, DIModuleAttr, |
| DINamespaceAttr, DINullTypeAttr, DIStringTypeAttr, |
| DISubprogramAttr, DISubroutineTypeAttr, LoopAnnotationAttr, |
| LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr, |
| LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr, |
| LoopPipelineAttr, LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr, |
| TBAATagAttr, TBAATypeDescriptorAttr>([&](auto attr) { |
| os << decltype(attr)::getMnemonic(); |
| return AliasResult::OverridableAlias; |
| }) |
| .Default([](Attribute) { return AliasResult::NoAlias; }); |
| } |
| }; |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // LinkerOptionsOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult LinkerOptionsOp::verify() { |
| if (mlir::Operation *parentOp = (*this)->getParentOp(); |
| parentOp && !satisfiesLLVMModule(parentOp)) |
| return emitOpError("must appear at the module level"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ModuleFlagsOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ModuleFlagsOp::verify() { |
| if (Operation *parentOp = (*this)->getParentOp(); |
| parentOp && !satisfiesLLVMModule(parentOp)) |
| return emitOpError("must appear at the module level"); |
| for (Attribute flag : getFlags()) |
| if (!isa<ModuleFlagAttr>(flag)) |
| return emitOpError("expected a module flag attribute"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // InlineAsmOp |
| //===----------------------------------------------------------------------===// |
| |
| void InlineAsmOp::getEffects( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| &effects) { |
| if (getHasSideEffects()) { |
| effects.emplace_back(MemoryEffects::Write::get()); |
| effects.emplace_back(MemoryEffects::Read::get()); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BlockAddressOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| BlockAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
| Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this), |
| getBlockAddr().getFunction()); |
| auto function = dyn_cast_or_null<LLVMFuncOp>(symbol); |
| |
| if (!function) |
| return emitOpError("must reference a function defined by 'llvm.func'"); |
| |
| return success(); |
| } |
| |
| LLVMFuncOp BlockAddressOp::getFunction(SymbolTableCollection &symbolTable) { |
| return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn( |
| parentLLVMModule(*this), getBlockAddr().getFunction())); |
| } |
| |
| BlockTagOp BlockAddressOp::getBlockTagOp() { |
| auto funcOp = dyn_cast<LLVMFuncOp>(mlir::SymbolTable::lookupNearestSymbolFrom( |
| parentLLVMModule(*this), getBlockAddr().getFunction())); |
| if (!funcOp) |
| return nullptr; |
| |
| BlockTagOp blockTagOp = nullptr; |
| funcOp.walk([&](LLVM::BlockTagOp labelOp) { |
| if (labelOp.getTag() == getBlockAddr().getTag()) { |
| blockTagOp = labelOp; |
| return WalkResult::interrupt(); |
| } |
| return WalkResult::advance(); |
| }); |
| return blockTagOp; |
| } |
| |
| LogicalResult BlockAddressOp::verify() { |
| if (!getBlockTagOp()) |
| return emitOpError( |
| "expects an existing block label target in the referenced function"); |
| |
| return success(); |
| } |
| |
| /// Fold a blockaddress operation to a dedicated blockaddress |
| /// attribute. |
| OpFoldResult BlockAddressOp::fold(FoldAdaptor) { return getBlockAddr(); } |
| |
| //===----------------------------------------------------------------------===// |
| // AssumeOp (intrinsic) |
| //===----------------------------------------------------------------------===// |
| |
| void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, |
| mlir::Value cond) { |
| return build(builder, state, cond, /*op_bundle_operands=*/{}, |
| /*op_bundle_tags=*/ArrayAttr{}); |
| } |
| |
| void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, |
| Value cond, |
| ArrayRef<llvm::OperandBundleDefT<Value>> opBundles) { |
| SmallVector<ValueRange> opBundleOperands; |
| SmallVector<Attribute> opBundleTags; |
| opBundleOperands.reserve(opBundles.size()); |
| opBundleTags.reserve(opBundles.size()); |
| |
| for (const llvm::OperandBundleDefT<Value> &bundle : opBundles) { |
| opBundleOperands.emplace_back(bundle.inputs()); |
| opBundleTags.push_back( |
| StringAttr::get(builder.getContext(), bundle.getTag())); |
| } |
| |
| auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags); |
| return build(builder, state, cond, opBundleOperands, opBundleTagsAttr); |
| } |
| |
| void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, |
| Value cond, llvm::StringRef tag, ValueRange args) { |
| llvm::OperandBundleDefT<Value> opBundle( |
| tag.str(), SmallVector<Value>(args.begin(), args.end())); |
| return build(builder, state, cond, opBundle); |
| } |
| |
| void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, |
| Value cond, AssumeAlignTag, Value ptr, Value align) { |
| return build(builder, state, cond, "align", ValueRange{ptr, align}); |
| } |
| |
| void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, |
| Value cond, AssumeSeparateStorageTag, Value ptr1, |
| Value ptr2) { |
| return build(builder, state, cond, "separate_storage", |
| ValueRange{ptr1, ptr2}); |
| } |
| |
| LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); } |
| |
| //===----------------------------------------------------------------------===// |
| // masked_gather (intrinsic) |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult LLVM::masked_gather::verify() { |
| auto ptrsVectorType = getPtrs().getType(); |
| Type expectedPtrsVectorType = |
| LLVM::getVectorType(extractVectorElementType(ptrsVectorType), |
| LLVM::getVectorNumElements(getRes().getType())); |
| // Vector of pointers type should match result vector type, other than the |
| // element type. |
| if (ptrsVectorType != expectedPtrsVectorType) |
| return emitOpError("expected operand #1 type to be ") |
| << expectedPtrsVectorType; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // masked_scatter (intrinsic) |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult LLVM::masked_scatter::verify() { |
| auto ptrsVectorType = getPtrs().getType(); |
| Type expectedPtrsVectorType = |
| LLVM::getVectorType(extractVectorElementType(ptrsVectorType), |
| LLVM::getVectorNumElements(getValue().getType())); |
| // Vector of pointers type should match value vector type, other than the |
| // element type. |
| if (ptrsVectorType != expectedPtrsVectorType) |
| return emitOpError("expected operand #2 type to be ") |
| << expectedPtrsVectorType; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // LLVMDialect initialization, type parsing, and registration. |
| //===----------------------------------------------------------------------===// |
| |
/// Initializes the LLVM dialect. It registers the dialect's attributes,
/// types, and operations (core LLVM ops plus intrinsic ops, both taken from
/// TableGen-generated lists) and its assembly interface. It also declares
/// the promised inliner interface.
void LLVMDialect::initialize() {
  registerAttributes();

  // These four types are added directly here; the remaining types are
  // registered by registerTypes().
  // clang-format off
  addTypes<LLVMVoidType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType>();
  // clang-format on
  registerTypes();

  // Each .inc expands (under GET_OP_LIST) to a list of op classes; the comma
  // between the two includes joins both lists into one template argument
  // list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      ,
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
  // clang-format off
  addInterfaces<LLVMOpAsmDialectInterface>();
  // clang-format on
  // The inliner interface implementation is not attached here; it is only
  // promised, and presumably registered separately — confirm at the
  // registration site.
  declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
}
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc" |
| |
| LogicalResult LLVMDialect::verifyDataLayoutString( |
| StringRef descr, llvm::function_ref<void(const Twine &)> reportError) { |
| llvm::Expected<llvm::DataLayout> maybeDataLayout = |
| llvm::DataLayout::parse(descr); |
| if (maybeDataLayout) |
| return success(); |
| |
| std::string message; |
| llvm::raw_string_ostream messageStream(message); |
| llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); |
| reportError("invalid data layout descriptor: " + message); |
| return failure(); |
| } |
| |
| /// Verify LLVM dialect attributes. |
| LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, |
| NamedAttribute attr) { |
| // If the data layout attribute is present, it must use the LLVM data layout |
| // syntax. Try parsing it and report errors in case of failure. Users of this |
| // attribute may assume it is well-formed and can pass it to the (asserting) |
| // llvm::DataLayout constructor. |
| if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) |
| return success(); |
| if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue())) |
| return verifyDataLayoutString( |
| stringAttr.getValue(), |
| [op](const Twine &message) { op->emitOpError() << message.str(); }); |
| |
| return op->emitOpError() << "expected '" |
| << LLVM::LLVMDialect::getDataLayoutAttrName() |
| << "' to be a string attributes"; |
| } |
| |
| LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op, |
| Type paramType, |
| NamedAttribute paramAttr) { |
| // LLVM attribute may be attached to a result of operation that has not been |
| // converted to LLVM dialect yet, so the result may have a type with unknown |
| // representation in LLVM dialect type space. In this case we cannot verify |
| // whether the attribute may be |
| bool verifyValueType = isCompatibleType(paramType); |
| StringAttr name = paramAttr.getName(); |
| |
| auto checkUnitAttrType = [&]() -> LogicalResult { |
| if (!llvm::isa<UnitAttr>(paramAttr.getValue())) |
| return op->emitError() << name << " should be a unit attribute"; |
| return success(); |
| }; |
| auto checkTypeAttrType = [&]() -> LogicalResult { |
| if (!llvm::isa<TypeAttr>(paramAttr.getValue())) |
| return op->emitError() << name << " should be a type attribute"; |
| return success(); |
| }; |
| auto checkIntegerAttrType = [&]() -> LogicalResult { |
| if (!llvm::isa<IntegerAttr>(paramAttr.getValue())) |
| return op->emitError() << name << " should be an integer attribute"; |
| return success(); |
| }; |
| auto checkPointerType = [&]() -> LogicalResult { |
| if (!llvm::isa<LLVMPointerType>(paramType)) |
| return op->emitError() |
| << name << " attribute attached to non-pointer LLVM type"; |
| return success(); |
| }; |
| auto checkIntegerType = [&]() -> LogicalResult { |
| if (!llvm::isa<IntegerType>(paramType)) |
| return op->emitError() |
| << name << " attribute attached to non-integer LLVM type"; |
| return success(); |
| }; |
| auto checkPointerTypeMatches = [&]() -> LogicalResult { |
| if (failed(checkPointerType())) |
| return failure(); |
| |
| return success(); |
| }; |
| |
| // Check a unit attribute that is attached to a pointer value. |
| if (name == LLVMDialect::getNoAliasAttrName() || |
| name == LLVMDialect::getReadonlyAttrName() || |
| name == LLVMDialect::getReadnoneAttrName() || |
| name == LLVMDialect::getWriteOnlyAttrName() || |
| name == LLVMDialect::getNestAttrName() || |
| name == LLVMDialect::getNoCaptureAttrName() || |
| name == LLVMDialect::getNoFreeAttrName() || |
| name == LLVMDialect::getNonNullAttrName()) { |
| if (failed(checkUnitAttrType())) |
| return failure(); |
| if (verifyValueType && failed(checkPointerType())) |
| return failure(); |
| return success(); |
| } |
| |
| // Check a type attribute that is attached to a pointer value. |
| if (name == LLVMDialect::getStructRetAttrName() || |
| name == LLVMDialect::getByValAttrName() || |
| name == LLVMDialect::getByRefAttrName() || |
| name == LLVMDialect::getElementTypeAttrName() || |
| name == LLVMDialect::getInAllocaAttrName() || |
| name == LLVMDialect::getPreallocatedAttrName()) { |
| if (failed(checkTypeAttrType())) |
| return failure(); |
| if (verifyValueType && failed(checkPointerTypeMatches())) |
| return failure(); |
| return success(); |
| } |
| |
| // Check a unit attribute that is attached to an integer value. |
| if (name == LLVMDialect::getSExtAttrName() || |
| name == LLVMDialect::getZExtAttrName()) { |
| if (failed(checkUnitAttrType())) |
| return failure(); |
| if (verifyValueType && failed(checkIntegerType())) |
| return failure(); |
| return success(); |
| } |
| |
| // Check an integer attribute that is attached to a pointer value. |
| if (name == LLVMDialect::getAlignAttrName() || |
| name == LLVMDialect::getDereferenceableAttrName() || |
| name == LLVMDialect::getDereferenceableOrNullAttrName()) { |
| if (failed(checkIntegerAttrType())) |
| return failure(); |
| if (verifyValueType && failed(checkPointerType())) |
| return failure(); |
| return success(); |
| } |
| |
| // Check an integer attribute that is attached to a pointer value. |
| if (name == LLVMDialect::getStackAlignmentAttrName()) { |
| if (failed(checkIntegerAttrType())) |
| return failure(); |
| return success(); |
| } |
| |
| // Check a unit attribute that can be attached to arbitrary types. |
| if (name == LLVMDialect::getNoUndefAttrName() || |
| name == LLVMDialect::getInRegAttrName() || |
| name == LLVMDialect::getReturnedAttrName()) |
| return checkUnitAttrType(); |
| |
| return success(); |
| } |
| |
| /// Verify LLVMIR function argument attributes. |
| LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, |
| unsigned regionIdx, |
| unsigned argIdx, |
| NamedAttribute argAttr) { |
| auto funcOp = dyn_cast<FunctionOpInterface>(op); |
| if (!funcOp) |
| return success(); |
| Type argType = funcOp.getArgumentTypes()[argIdx]; |
| |
| return verifyParameterAttribute(op, argType, argAttr); |
| } |
| |
| LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op, |
| unsigned regionIdx, |
| unsigned resIdx, |
| NamedAttribute resAttr) { |
| auto funcOp = dyn_cast<FunctionOpInterface>(op); |
| if (!funcOp) |
| return success(); |
| Type resType = funcOp.getResultTypes()[resIdx]; |
| |
| // Check to see if this function has a void return with a result attribute |
| // to it. It isn't clear what semantics we would assign to that. |
| if (llvm::isa<LLVMVoidType>(resType)) |
| return op->emitError() << "cannot attach result attributes to functions " |
| "with a void return"; |
| |
| // Check to see if this attribute is allowed as a result attribute. Only |
| // explicitly forbidden LLVM attributes will cause an error. |
| auto name = resAttr.getName(); |
| if (name == LLVMDialect::getAllocAlignAttrName() || |
| name == LLVMDialect::getAllocatedPointerAttrName() || |
| name == LLVMDialect::getByValAttrName() || |
| name == LLVMDialect::getByRefAttrName() || |
| name == LLVMDialect::getInAllocaAttrName() || |
| name == LLVMDialect::getNestAttrName() || |
| name == LLVMDialect::getNoCaptureAttrName() || |
| name == LLVMDialect::getNoFreeAttrName() || |
| name == LLVMDialect::getPreallocatedAttrName() || |
| name == LLVMDialect::getReadnoneAttrName() || |
| name == LLVMDialect::getReadonlyAttrName() || |
| name == LLVMDialect::getReturnedAttrName() || |
| name == LLVMDialect::getStackAlignmentAttrName() || |
| name == LLVMDialect::getStructRetAttrName() || |
| name == LLVMDialect::getWriteOnlyAttrName()) |
| return op->emitError() << name << " is not a valid result attribute"; |
| return verifyParameterAttribute(op, resType, resAttr); |
| } |
| |
| Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value, |
| Type type, Location loc) { |
| // If this was folded from an operation other than llvm.mlir.constant, it |
| // should be materialized as such. Note that an llvm.mlir.zero may fold into |
| // a builtin zero attribute and thus will materialize as a llvm.mlir.constant. |
| if (auto symbol = dyn_cast<FlatSymbolRefAttr>(value)) |
| if (isa<LLVM::LLVMPointerType>(type)) |
| return builder.create<LLVM::AddressOfOp>(loc, type, symbol); |
| if (isa<LLVM::UndefAttr>(value)) |
| return builder.create<LLVM::UndefOp>(loc, type); |
| if (isa<LLVM::PoisonAttr>(value)) |
| return builder.create<LLVM::PoisonOp>(loc, type); |
| if (isa<LLVM::ZeroAttr>(value)) |
| return builder.create<LLVM::ZeroOp>(loc, type); |
| // Otherwise try materializing it as a regular llvm.mlir.constant op. |
| return LLVM::ConstantOp::materialize(builder, value, type, loc); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Utility functions. |
| //===----------------------------------------------------------------------===// |
| |
| Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, |
| StringRef name, StringRef value, |
| LLVM::Linkage linkage) { |
| assert(builder.getInsertionBlock() && |
| builder.getInsertionBlock()->getParentOp() && |
| "expected builder to point to a block constrained in an op"); |
| auto module = |
| builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); |
| assert(module && "builder points to an op outside of a module"); |
| |
| // Create the global at the entry of the module. |
| OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener()); |
| MLIRContext *ctx = builder.getContext(); |
| auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size()); |
| auto global = moduleBuilder.create<LLVM::GlobalOp>( |
| loc, type, /*isConstant=*/true, linkage, name, |
| builder.getStringAttr(value), /*alignment=*/0); |
| |
| LLVMPointerType ptrType = LLVMPointerType::get(ctx); |
| // Get the pointer to the first character in the global string. |
| Value globalPtr = |
| builder.create<LLVM::AddressOfOp>(loc, ptrType, global.getSymNameAttr()); |
| return builder.create<LLVM::GEPOp>(loc, ptrType, type, globalPtr, |
| ArrayRef<GEPArg>{0, 0}); |
| } |
| |
| bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { |
| return op->hasTrait<OpTrait::SymbolTable>() && |
| op->hasTrait<OpTrait::IsIsolatedFromAbove>(); |
| } |