| //===- LoopAnnotationImporter.cpp - Loop annotation import ----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "LoopAnnotationImporter.h" |
| #include "llvm/IR/Constants.h" |
| |
| using namespace mlir; |
| using namespace mlir::LLVM; |
| using namespace mlir::LLVM::detail; |
| |
| namespace { |
| /// Helper class that keeps the state of one metadata to attribute conversion. |
| struct LoopMetadataConversion { |
| LoopMetadataConversion(const llvm::MDNode *node, Location loc, |
| LoopAnnotationImporter &loopAnnotationImporter) |
| : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter), |
| ctx(loc->getContext()){}; |
| /// Converts this structs loop metadata node into a LoopAnnotationAttr. |
| LoopAnnotationAttr convert(); |
| |
| /// Initializes the shared state for the conversion member functions. |
| LogicalResult initConversionState(); |
| |
| /// Helper function to get and erase a property. |
| const llvm::MDNode *lookupAndEraseProperty(StringRef name); |
| |
| /// Helper functions to lookup and convert MDNodes into a specifc attribute |
| /// kind. These functions return null-attributes if there is no node with the |
| /// specified name, or failure, if the node is ill-formatted. |
| FailureOr<BoolAttr> lookupUnitNode(StringRef name); |
| FailureOr<BoolAttr> lookupBoolNode(StringRef name, bool negated = false); |
| FailureOr<BoolAttr> lookupIntNodeAsBoolAttr(StringRef name); |
| FailureOr<IntegerAttr> lookupIntNode(StringRef name); |
| FailureOr<llvm::MDNode *> lookupMDNode(StringRef name); |
| FailureOr<SmallVector<llvm::MDNode *>> lookupMDNodes(StringRef name); |
| FailureOr<LoopAnnotationAttr> lookupFollowupNode(StringRef name); |
| FailureOr<BoolAttr> lookupBooleanUnitNode(StringRef enableName, |
| StringRef disableName, |
| bool negated = false); |
| |
| /// Conversion functions for sub-attributes. |
| FailureOr<LoopVectorizeAttr> convertVectorizeAttr(); |
| FailureOr<LoopInterleaveAttr> convertInterleaveAttr(); |
| FailureOr<LoopUnrollAttr> convertUnrollAttr(); |
| FailureOr<LoopUnrollAndJamAttr> convertUnrollAndJamAttr(); |
| FailureOr<LoopLICMAttr> convertLICMAttr(); |
| FailureOr<LoopDistributeAttr> convertDistributeAttr(); |
| FailureOr<LoopPipelineAttr> convertPipelineAttr(); |
| FailureOr<LoopPeeledAttr> convertPeeledAttr(); |
| FailureOr<LoopUnswitchAttr> convertUnswitchAttr(); |
| FailureOr<SmallVector<AccessGroupAttr>> convertParallelAccesses(); |
| FusedLoc convertStartLoc(); |
| FailureOr<FusedLoc> convertEndLoc(); |
| |
| llvm::SmallVector<llvm::DILocation *, 2> locations; |
| llvm::StringMap<const llvm::MDNode *> propertyMap; |
| const llvm::MDNode *node; |
| Location loc; |
| LoopAnnotationImporter &loopAnnotationImporter; |
| MLIRContext *ctx; |
| }; |
| } // namespace |
| |
| LogicalResult LoopMetadataConversion::initConversionState() { |
| // Check if it's a valid node. |
| if (node->getNumOperands() == 0 || |
| dyn_cast<llvm::MDNode>(node->getOperand(0)) != node) |
| return emitWarning(loc) << "invalid loop node"; |
| |
| for (const llvm::MDOperand &operand : llvm::drop_begin(node->operands())) { |
| if (auto *diLoc = dyn_cast<llvm::DILocation>(operand)) { |
| locations.push_back(diLoc); |
| continue; |
| } |
| |
| auto *property = dyn_cast<llvm::MDNode>(operand); |
| if (!property) |
| return emitWarning(loc) << "expected all loop properties to be either " |
| "debug locations or metadata nodes"; |
| |
| if (property->getNumOperands() == 0) |
| return emitWarning(loc) << "cannot import empty loop property"; |
| |
| auto *nameNode = dyn_cast<llvm::MDString>(property->getOperand(0)); |
| if (!nameNode) |
| return emitWarning(loc) << "cannot import loop property without a name"; |
| StringRef name = nameNode->getString(); |
| |
| bool succ = propertyMap.try_emplace(name, property).second; |
| if (!succ) |
| return emitWarning(loc) |
| << "cannot import loop properties with duplicated names " << name; |
| } |
| |
| return success(); |
| } |
| |
| const llvm::MDNode * |
| LoopMetadataConversion::lookupAndEraseProperty(StringRef name) { |
| auto it = propertyMap.find(name); |
| if (it == propertyMap.end()) |
| return nullptr; |
| const llvm::MDNode *property = it->getValue(); |
| propertyMap.erase(it); |
| return property; |
| } |
| |
| FailureOr<BoolAttr> LoopMetadataConversion::lookupUnitNode(StringRef name) { |
| const llvm::MDNode *property = lookupAndEraseProperty(name); |
| if (!property) |
| return BoolAttr(nullptr); |
| |
| if (property->getNumOperands() != 1) |
| return emitWarning(loc) |
| << "expected metadata node " << name << " to hold no value"; |
| |
| return BoolAttr::get(ctx, true); |
| } |
| |
| FailureOr<BoolAttr> LoopMetadataConversion::lookupBooleanUnitNode( |
| StringRef enableName, StringRef disableName, bool negated) { |
| auto enable = lookupUnitNode(enableName); |
| auto disable = lookupUnitNode(disableName); |
| if (failed(enable) || failed(disable)) |
| return failure(); |
| |
| if (*enable && *disable) |
| return emitWarning(loc) |
| << "expected metadata nodes " << enableName << " and " << disableName |
| << " to be mutually exclusive."; |
| |
| if (*enable) |
| return BoolAttr::get(ctx, !negated); |
| |
| if (*disable) |
| return BoolAttr::get(ctx, negated); |
| return BoolAttr(nullptr); |
| } |
| |
| FailureOr<BoolAttr> LoopMetadataConversion::lookupBoolNode(StringRef name, |
| bool negated) { |
| const llvm::MDNode *property = lookupAndEraseProperty(name); |
| if (!property) |
| return BoolAttr(nullptr); |
| |
| auto emitNodeWarning = [&]() { |
| return emitWarning(loc) |
| << "expected metadata node " << name << " to hold a boolean value"; |
| }; |
| |
| if (property->getNumOperands() != 2) |
| return emitNodeWarning(); |
| llvm::ConstantInt *val = |
| llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1)); |
| if (!val || val->getBitWidth() != 1) |
| return emitNodeWarning(); |
| |
| return BoolAttr::get(ctx, val->getValue().getLimitedValue(1) ^ negated); |
| } |
| |
| FailureOr<BoolAttr> |
| LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) { |
| const llvm::MDNode *property = lookupAndEraseProperty(name); |
| if (!property) |
| return BoolAttr(nullptr); |
| |
| auto emitNodeWarning = [&]() { |
| return emitWarning(loc) |
| << "expected metadata node " << name << " to hold an integer value"; |
| }; |
| |
| if (property->getNumOperands() != 2) |
| return emitNodeWarning(); |
| llvm::ConstantInt *val = |
| llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1)); |
| if (!val || val->getBitWidth() != 32) |
| return emitNodeWarning(); |
| |
| return BoolAttr::get(ctx, val->getValue().getLimitedValue(1)); |
| } |
| |
| FailureOr<IntegerAttr> LoopMetadataConversion::lookupIntNode(StringRef name) { |
| const llvm::MDNode *property = lookupAndEraseProperty(name); |
| if (!property) |
| return IntegerAttr(nullptr); |
| |
| auto emitNodeWarning = [&]() { |
| return emitWarning(loc) |
| << "expected metadata node " << name << " to hold an i32 value"; |
| }; |
| |
| if (property->getNumOperands() != 2) |
| return emitNodeWarning(); |
| |
| llvm::ConstantInt *val = |
| llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1)); |
| if (!val || val->getBitWidth() != 32) |
| return emitNodeWarning(); |
| |
| return IntegerAttr::get(IntegerType::get(ctx, 32), |
| val->getValue().getLimitedValue()); |
| } |
| |
| FailureOr<llvm::MDNode *> LoopMetadataConversion::lookupMDNode(StringRef name) { |
| const llvm::MDNode *property = lookupAndEraseProperty(name); |
| if (!property) |
| return nullptr; |
| |
| auto emitNodeWarning = [&]() { |
| return emitWarning(loc) |
| << "expected metadata node " << name << " to hold an MDNode"; |
| }; |
| |
| if (property->getNumOperands() != 2) |
| return emitNodeWarning(); |
| |
| auto *node = dyn_cast<llvm::MDNode>(property->getOperand(1)); |
| if (!node) |
| return emitNodeWarning(); |
| |
| return node; |
| } |
| |
| FailureOr<SmallVector<llvm::MDNode *>> |
| LoopMetadataConversion::lookupMDNodes(StringRef name) { |
| const llvm::MDNode *property = lookupAndEraseProperty(name); |
| SmallVector<llvm::MDNode *> res; |
| if (!property) |
| return res; |
| |
| auto emitNodeWarning = [&]() { |
| return emitWarning(loc) << "expected metadata node " << name |
| << " to hold one or multiple MDNodes"; |
| }; |
| |
| if (property->getNumOperands() < 2) |
| return emitNodeWarning(); |
| |
| for (unsigned i = 1, e = property->getNumOperands(); i < e; ++i) { |
| auto *node = dyn_cast<llvm::MDNode>(property->getOperand(i)); |
| if (!node) |
| return emitNodeWarning(); |
| res.push_back(node); |
| } |
| |
| return res; |
| } |
| |
| FailureOr<LoopAnnotationAttr> |
| LoopMetadataConversion::lookupFollowupNode(StringRef name) { |
| auto node = lookupMDNode(name); |
| if (failed(node)) |
| return failure(); |
| if (*node == nullptr) |
| return LoopAnnotationAttr(nullptr); |
| |
| return loopAnnotationImporter.translateLoopAnnotation(*node, loc); |
| } |
| |
| static bool isEmptyOrNull(const Attribute attr) { return !attr; } |
| |
| template <typename T> |
| static bool isEmptyOrNull(const SmallVectorImpl<T> &vec) { |
| return vec.empty(); |
| } |
| |
| /// Helper function that only creates and attribute of type T if all argument |
| /// conversion were successfull and at least one of them holds a non-null value. |
| template <typename T, typename... P> |
| static T createIfNonNull(MLIRContext *ctx, const P &...args) { |
| bool anyFailed = (failed(args) || ...); |
| if (anyFailed) |
| return {}; |
| |
| bool allEmpty = (isEmptyOrNull(*args) && ...); |
| if (allEmpty) |
| return {}; |
| |
| return T::get(ctx, *args...); |
| } |
| |
| FailureOr<LoopVectorizeAttr> LoopMetadataConversion::convertVectorizeAttr() { |
| FailureOr<BoolAttr> enable = |
| lookupBoolNode("llvm.loop.vectorize.enable", true); |
| FailureOr<BoolAttr> predicateEnable = |
| lookupBoolNode("llvm.loop.vectorize.predicate.enable"); |
| FailureOr<BoolAttr> scalableEnable = |
| lookupBoolNode("llvm.loop.vectorize.scalable.enable"); |
| FailureOr<IntegerAttr> width = lookupIntNode("llvm.loop.vectorize.width"); |
| FailureOr<LoopAnnotationAttr> followupVec = |
| lookupFollowupNode("llvm.loop.vectorize.followup_vectorized"); |
| FailureOr<LoopAnnotationAttr> followupEpi = |
| lookupFollowupNode("llvm.loop.vectorize.followup_epilogue"); |
| FailureOr<LoopAnnotationAttr> followupAll = |
| lookupFollowupNode("llvm.loop.vectorize.followup_all"); |
| |
| return createIfNonNull<LoopVectorizeAttr>(ctx, enable, predicateEnable, |
| scalableEnable, width, followupVec, |
| followupEpi, followupAll); |
| } |
| |
| FailureOr<LoopInterleaveAttr> LoopMetadataConversion::convertInterleaveAttr() { |
| FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.interleave.count"); |
| return createIfNonNull<LoopInterleaveAttr>(ctx, count); |
| } |
| |
| FailureOr<LoopUnrollAttr> LoopMetadataConversion::convertUnrollAttr() { |
| FailureOr<BoolAttr> disable = lookupBooleanUnitNode( |
| "llvm.loop.unroll.enable", "llvm.loop.unroll.disable", /*negated=*/true); |
| FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.unroll.count"); |
| FailureOr<BoolAttr> runtimeDisable = |
| lookupUnitNode("llvm.loop.unroll.runtime.disable"); |
| FailureOr<BoolAttr> full = lookupUnitNode("llvm.loop.unroll.full"); |
| FailureOr<LoopAnnotationAttr> followupUnrolled = |
| lookupFollowupNode("llvm.loop.unroll.followup_unrolled"); |
| FailureOr<LoopAnnotationAttr> followupRemainder = |
| lookupFollowupNode("llvm.loop.unroll.followup_remainder"); |
| FailureOr<LoopAnnotationAttr> followupAll = |
| lookupFollowupNode("llvm.loop.unroll.followup_all"); |
| |
| return createIfNonNull<LoopUnrollAttr>(ctx, disable, count, runtimeDisable, |
| full, followupUnrolled, |
| followupRemainder, followupAll); |
| } |
| |
| FailureOr<LoopUnrollAndJamAttr> |
| LoopMetadataConversion::convertUnrollAndJamAttr() { |
| FailureOr<BoolAttr> disable = lookupBooleanUnitNode( |
| "llvm.loop.unroll_and_jam.enable", "llvm.loop.unroll_and_jam.disable", |
| /*negated=*/true); |
| FailureOr<IntegerAttr> count = |
| lookupIntNode("llvm.loop.unroll_and_jam.count"); |
| FailureOr<LoopAnnotationAttr> followupOuter = |
| lookupFollowupNode("llvm.loop.unroll_and_jam.followup_outer"); |
| FailureOr<LoopAnnotationAttr> followupInner = |
| lookupFollowupNode("llvm.loop.unroll_and_jam.followup_inner"); |
| FailureOr<LoopAnnotationAttr> followupRemainderOuter = |
| lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer"); |
| FailureOr<LoopAnnotationAttr> followupRemainderInner = |
| lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner"); |
| FailureOr<LoopAnnotationAttr> followupAll = |
| lookupFollowupNode("llvm.loop.unroll_and_jam.followup_all"); |
| return createIfNonNull<LoopUnrollAndJamAttr>( |
| ctx, disable, count, followupOuter, followupInner, followupRemainderOuter, |
| followupRemainderInner, followupAll); |
| } |
| |
| FailureOr<LoopLICMAttr> LoopMetadataConversion::convertLICMAttr() { |
| FailureOr<BoolAttr> disable = lookupUnitNode("llvm.licm.disable"); |
| FailureOr<BoolAttr> versioningDisable = |
| lookupUnitNode("llvm.loop.licm_versioning.disable"); |
| return createIfNonNull<LoopLICMAttr>(ctx, disable, versioningDisable); |
| } |
| |
| FailureOr<LoopDistributeAttr> LoopMetadataConversion::convertDistributeAttr() { |
| FailureOr<BoolAttr> disable = |
| lookupBoolNode("llvm.loop.distribute.enable", true); |
| FailureOr<LoopAnnotationAttr> followupCoincident = |
| lookupFollowupNode("llvm.loop.distribute.followup_coincident"); |
| FailureOr<LoopAnnotationAttr> followupSequential = |
| lookupFollowupNode("llvm.loop.distribute.followup_sequential"); |
| FailureOr<LoopAnnotationAttr> followupFallback = |
| lookupFollowupNode("llvm.loop.distribute.followup_fallback"); |
| FailureOr<LoopAnnotationAttr> followupAll = |
| lookupFollowupNode("llvm.loop.distribute.followup_all"); |
| return createIfNonNull<LoopDistributeAttr>(ctx, disable, followupCoincident, |
| followupSequential, |
| followupFallback, followupAll); |
| } |
| |
| FailureOr<LoopPipelineAttr> LoopMetadataConversion::convertPipelineAttr() { |
| FailureOr<BoolAttr> disable = lookupBoolNode("llvm.loop.pipeline.disable"); |
| FailureOr<IntegerAttr> initiationinterval = |
| lookupIntNode("llvm.loop.pipeline.initiationinterval"); |
| return createIfNonNull<LoopPipelineAttr>(ctx, disable, initiationinterval); |
| } |
| |
| FailureOr<LoopPeeledAttr> LoopMetadataConversion::convertPeeledAttr() { |
| FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.peeled.count"); |
| return createIfNonNull<LoopPeeledAttr>(ctx, count); |
| } |
| |
| FailureOr<LoopUnswitchAttr> LoopMetadataConversion::convertUnswitchAttr() { |
| FailureOr<BoolAttr> partialDisable = |
| lookupUnitNode("llvm.loop.unswitch.partial.disable"); |
| return createIfNonNull<LoopUnswitchAttr>(ctx, partialDisable); |
| } |
| |
| FailureOr<SmallVector<AccessGroupAttr>> |
| LoopMetadataConversion::convertParallelAccesses() { |
| FailureOr<SmallVector<llvm::MDNode *>> nodes = |
| lookupMDNodes("llvm.loop.parallel_accesses"); |
| if (failed(nodes)) |
| return failure(); |
| SmallVector<AccessGroupAttr> refs; |
| for (llvm::MDNode *node : *nodes) { |
| FailureOr<SmallVector<AccessGroupAttr>> accessGroups = |
| loopAnnotationImporter.lookupAccessGroupAttrs(node); |
| if (failed(accessGroups)) { |
| emitWarning(loc) << "could not lookup access group"; |
| continue; |
| } |
| llvm::append_range(refs, *accessGroups); |
| } |
| return refs; |
| } |
| |
| FusedLoc LoopMetadataConversion::convertStartLoc() { |
| if (locations.empty()) |
| return {}; |
| return dyn_cast<FusedLoc>( |
| loopAnnotationImporter.moduleImport.translateLoc(locations[0])); |
| } |
| |
| FailureOr<FusedLoc> LoopMetadataConversion::convertEndLoc() { |
| if (locations.size() < 2) |
| return FusedLoc(); |
| if (locations.size() > 2) |
| return emitError(loc) |
| << "expected loop metadata to have at most two DILocations"; |
| return dyn_cast<FusedLoc>( |
| loopAnnotationImporter.moduleImport.translateLoc(locations[1])); |
| } |
| |
| LoopAnnotationAttr LoopMetadataConversion::convert() { |
| if (failed(initConversionState())) |
| return {}; |
| |
| FailureOr<BoolAttr> disableNonForced = |
| lookupUnitNode("llvm.loop.disable_nonforced"); |
| FailureOr<LoopVectorizeAttr> vecAttr = convertVectorizeAttr(); |
| FailureOr<LoopInterleaveAttr> interleaveAttr = convertInterleaveAttr(); |
| FailureOr<LoopUnrollAttr> unrollAttr = convertUnrollAttr(); |
| FailureOr<LoopUnrollAndJamAttr> unrollAndJamAttr = convertUnrollAndJamAttr(); |
| FailureOr<LoopLICMAttr> licmAttr = convertLICMAttr(); |
| FailureOr<LoopDistributeAttr> distributeAttr = convertDistributeAttr(); |
| FailureOr<LoopPipelineAttr> pipelineAttr = convertPipelineAttr(); |
| FailureOr<LoopPeeledAttr> peeledAttr = convertPeeledAttr(); |
| FailureOr<LoopUnswitchAttr> unswitchAttr = convertUnswitchAttr(); |
| FailureOr<BoolAttr> mustProgress = lookupUnitNode("llvm.loop.mustprogress"); |
| FailureOr<BoolAttr> isVectorized = |
| lookupIntNodeAsBoolAttr("llvm.loop.isvectorized"); |
| FailureOr<SmallVector<AccessGroupAttr>> parallelAccesses = |
| convertParallelAccesses(); |
| |
| // Drop the metadata if there are parts that cannot be imported. |
| if (!propertyMap.empty()) { |
| for (auto name : propertyMap.keys()) |
| emitWarning(loc) << "unknown loop annotation " << name; |
| return {}; |
| } |
| |
| FailureOr<FusedLoc> startLoc = convertStartLoc(); |
| FailureOr<FusedLoc> endLoc = convertEndLoc(); |
| |
| return createIfNonNull<LoopAnnotationAttr>( |
| ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr, |
| unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr, |
| unswitchAttr, mustProgress, isVectorized, startLoc, endLoc, |
| parallelAccesses); |
| } |
| |
| LoopAnnotationAttr |
| LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node, |
| Location loc) { |
| if (!node) |
| return {}; |
| |
| // Note: This check is necessary to distinguish between failed translations |
| // and not yet attempted translations. |
| auto it = loopMetadataMapping.find(node); |
| if (it != loopMetadataMapping.end()) |
| return it->getSecond(); |
| |
| LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert(); |
| |
| mapLoopMetadata(node, attr); |
| return attr; |
| } |
| |
| LogicalResult |
| LoopAnnotationImporter::translateAccessGroup(const llvm::MDNode *node, |
| Location loc) { |
| SmallVector<const llvm::MDNode *> accessGroups; |
| if (!node->getNumOperands()) |
| accessGroups.push_back(node); |
| for (const llvm::MDOperand &operand : node->operands()) { |
| auto *childNode = dyn_cast<llvm::MDNode>(operand); |
| if (!childNode) |
| return failure(); |
| accessGroups.push_back(cast<llvm::MDNode>(operand.get())); |
| } |
| |
| // Convert all entries of the access group list to access group operations. |
| for (const llvm::MDNode *accessGroup : accessGroups) { |
| if (accessGroupMapping.count(accessGroup)) |
| continue; |
| // Verify the access group node is distinct and empty. |
| if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct()) |
| return emitWarning(loc) |
| << "expected an access group node to be empty and distinct"; |
| |
| // Add a mapping from the access group node to the newly created attribute. |
| accessGroupMapping[accessGroup] = builder.getAttr<AccessGroupAttr>(); |
| } |
| return success(); |
| } |
| |
| FailureOr<SmallVector<AccessGroupAttr>> |
| LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const { |
| // An access group node is either a single access group or an access group |
| // list. |
| SmallVector<AccessGroupAttr> accessGroups; |
| if (!node->getNumOperands()) |
| accessGroups.push_back(accessGroupMapping.lookup(node)); |
| for (const llvm::MDOperand &operand : node->operands()) { |
| auto *node = cast<llvm::MDNode>(operand.get()); |
| accessGroups.push_back(accessGroupMapping.lookup(node)); |
| } |
| // Exit if one of the access group node lookups failed. |
| if (llvm::is_contained(accessGroups, nullptr)) |
| return failure(); |
| return accessGroups; |
| } |