//===- FunctionInterfaces.cpp - Function Interfaces -----------------------===//
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Interfaces/FunctionInterfaces.h" |
| |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| // Tablegen Interface Definitions |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Interfaces/FunctionInterfaces.cpp.inc" |
| |
| //===----------------------------------------------------------------------===// |
| // Function Arguments and Results. |
| //===----------------------------------------------------------------------===// |
| |
| static bool isEmptyAttrDict(Attribute attr) { |
| return llvm::cast<DictionaryAttr>(attr).empty(); |
| } |
| |
| DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, |
| unsigned index) { |
| ArrayAttr attrs = op.getArgAttrsAttr(); |
| DictionaryAttr argAttrs = |
| attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr(); |
| return argAttrs; |
| } |
| |
| DictionaryAttr |
| function_interface_impl::getResultAttrDict(FunctionOpInterface op, |
| unsigned index) { |
| ArrayAttr attrs = op.getResAttrsAttr(); |
| DictionaryAttr resAttrs = |
| attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr(); |
| return resAttrs; |
| } |
| |
| ArrayRef<NamedAttribute> |
| function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { |
| auto argDict = getArgAttrDict(op, index); |
| return argDict ? argDict.getValue() : std::nullopt; |
| } |
| |
| ArrayRef<NamedAttribute> |
| function_interface_impl::getResultAttrs(FunctionOpInterface op, |
| unsigned index) { |
| auto resultDict = getResultAttrDict(op, index); |
| return resultDict ? resultDict.getValue() : std::nullopt; |
| } |
| |
| /// Get either the argument or result attributes array. |
| template <bool isArg> |
| static ArrayAttr getArgResAttrs(FunctionOpInterface op) { |
| if constexpr (isArg) |
| return op.getArgAttrsAttr(); |
| else |
| return op.getResAttrsAttr(); |
| } |
| |
| /// Set either the argument or result attributes array. |
| template <bool isArg> |
| static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { |
| if constexpr (isArg) |
| op.setArgAttrsAttr(attrs); |
| else |
| op.setResAttrsAttr(attrs); |
| } |
| |
| /// Erase either the argument or result attributes array. |
| template <bool isArg> |
| static void removeArgResAttrs(FunctionOpInterface op) { |
| if constexpr (isArg) |
| op.removeArgAttrsAttr(); |
| else |
| op.removeResAttrsAttr(); |
| } |
| |
| /// Set all of the argument or result attribute dictionaries for a function. |
| template <bool isArg> |
| static void setAllArgResAttrDicts(FunctionOpInterface op, |
| ArrayRef<Attribute> attrs) { |
| if (llvm::all_of(attrs, isEmptyAttrDict)) |
| removeArgResAttrs<isArg>(op); |
| else |
| setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs)); |
| } |
| |
| void function_interface_impl::setAllArgAttrDicts( |
| FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) { |
| setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); |
| } |
| |
| void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, |
| ArrayRef<Attribute> attrs) { |
| auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { |
| return !attr ? DictionaryAttr::get(op->getContext()) : attr; |
| }); |
| setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs)); |
| } |
| |
| void function_interface_impl::setAllResultAttrDicts( |
| FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) { |
| setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); |
| } |
| |
| void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, |
| ArrayRef<Attribute> attrs) { |
| auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { |
| return !attr ? DictionaryAttr::get(op->getContext()) : attr; |
| }); |
| setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs)); |
| } |
| |
| /// Update the given index into an argument or result attribute dictionary. |
| template <bool isArg> |
| static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, |
| unsigned index, DictionaryAttr attrs) { |
| ArrayAttr allAttrs = getArgResAttrs<isArg>(op); |
| if (!allAttrs) { |
| if (attrs.empty()) |
| return; |
| |
| // If this attribute is not empty, we need to create a new attribute array. |
| SmallVector<Attribute, 8> newAttrs(numTotalIndices, |
| DictionaryAttr::get(op->getContext())); |
| newAttrs[index] = attrs; |
| setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs)); |
| return; |
| } |
| // Check to see if the attribute is different from what we already have. |
| if (allAttrs[index] == attrs) |
| return; |
| |
| // If it is, check to see if the attribute array would now contain only empty |
| // dictionaries. |
| ArrayRef<Attribute> rawAttrArray = allAttrs.getValue(); |
| if (attrs.empty() && |
| llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && |
| llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) |
| return removeArgResAttrs<isArg>(op); |
| |
| // Otherwise, create a new attribute array with the updated dictionary. |
| SmallVector<Attribute, 8> newAttrs(rawAttrArray); |
| newAttrs[index] = attrs; |
| setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs)); |
| } |
| |
| void function_interface_impl::setArgAttrs(FunctionOpInterface op, |
| unsigned index, |
| ArrayRef<NamedAttribute> attributes) { |
| assert(index < op.getNumArguments() && "invalid argument number"); |
| return setArgResAttrDict</*isArg=*/true>( |
| op, op.getNumArguments(), index, |
| DictionaryAttr::get(op->getContext(), attributes)); |
| } |
| |
| void function_interface_impl::setArgAttrs(FunctionOpInterface op, |
| unsigned index, |
| DictionaryAttr attributes) { |
| return setArgResAttrDict</*isArg=*/true>( |
| op, op.getNumArguments(), index, |
| attributes ? attributes : DictionaryAttr::get(op->getContext())); |
| } |
| |
| void function_interface_impl::setResultAttrs( |
| FunctionOpInterface op, unsigned index, |
| ArrayRef<NamedAttribute> attributes) { |
| assert(index < op.getNumResults() && "invalid result number"); |
| return setArgResAttrDict</*isArg=*/false>( |
| op, op.getNumResults(), index, |
| DictionaryAttr::get(op->getContext(), attributes)); |
| } |
| |
| void function_interface_impl::setResultAttrs(FunctionOpInterface op, |
| unsigned index, |
| DictionaryAttr attributes) { |
| assert(index < op.getNumResults() && "invalid result number"); |
| return setArgResAttrDict</*isArg=*/false>( |
| op, op.getNumResults(), index, |
| attributes ? attributes : DictionaryAttr::get(op->getContext())); |
| } |
| |
/// Inserts new arguments into `op`, updating the argument attribute array,
/// the function type (set to `newType`), and the entry block arguments when
/// the function has a body.
///
/// `argIndices[i]` is the position, in the ORIGINAL argument numbering, at
/// which the i-th new argument (type `argTypes[i]`, location `argLocs[i]`,
/// attributes `argAttrs[i]`) is inserted. `argAttrs` may be empty, meaning
/// "no attributes" for every new argument. `originalNumArgs` is the argument
/// count before insertion.
/// NOTE(review): the merge below appears to require `argIndices` sorted in
/// non-decreasing order; otherwise `untilIdx - oldIdx` underflows. Confirm
/// with callers.
void function_interface_impl::insertFunctionArguments(
    FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
    ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
    unsigned originalNumArgs, Type newType) {
  assert(argIndices.size() == argTypes.size());
  assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
  assert(argIndices.size() == argLocs.size());
  if (argIndices.empty())
    return;

  // There are 3 things that need to be updated:
  // - Function type.
  // - Arg attrs.
  // - Block arguments of entry block, if not empty.

  // Update the argument attributes of the function. Skipped entirely when
  // neither the old array nor the new arguments carry any attributes.
  ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
  if (oldArgAttrs || !argAttrs.empty()) {
    SmallVector<DictionaryAttr, 4> newArgAttrs;
    newArgAttrs.reserve(originalNumArgs + argIndices.size());
    // Index into the old attribute array of the next entry not yet copied.
    unsigned oldIdx = 0;
    // Copies old entries [oldIdx, untilIdx) into `newArgAttrs`. When there was
    // no old array, null placeholders are appended instead. The generic
    // setAllArgAttrDicts overload later turns them into empty dictionaries.
    auto migrate = [&](unsigned untilIdx) {
      if (!oldArgAttrs) {
        newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
      } else {
        auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
        newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
                           oldArgAttrRange.begin() + untilIdx);
      }
      oldIdx = untilIdx;
    };
    // Interleave: old entries up to each insertion point, then the new entry.
    for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
      migrate(argIndices[i]);
      newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
    }
    // Copy the remaining old entries after the last insertion point.
    migrate(originalNumArgs);
    setAllArgAttrDicts(op, newArgAttrs);
  }

  // Update the function type.
  op.setFunctionTypeAttr(TypeAttr::get(newType));

  // Update entry block arguments, if not empty.
  if (!op.isExternal()) {
    Block &entry = op->getRegion(0).front();
    // Each earlier insertion shifts later positions by one, hence `+ i` to
    // translate original-numbering indices into current block positions.
    for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
      entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
  }
}
| |
/// Inserts new results into `op`, updating the result attribute array and
/// setting the function type to `newType`. No block or terminator is touched
/// here.
///
/// `resultIndices[i]` is the position, in the ORIGINAL result numbering, at
/// which the i-th new result (attributes `resultAttrs[i]`) is inserted.
/// `resultAttrs` may be empty, meaning "no attributes" for every new result.
/// `resultTypes` is only used for the size assertion. `originalNumResults`
/// is the result count before insertion.
/// NOTE(review): the merge below appears to require `resultIndices` sorted in
/// non-decreasing order; otherwise `untilIdx - oldIdx` underflows. Confirm
/// with callers.
void function_interface_impl::insertFunctionResults(
    FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
    TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
    unsigned originalNumResults, Type newType) {
  assert(resultIndices.size() == resultTypes.size());
  assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
  if (resultIndices.empty())
    return;

  // There are 2 things that need to be updated:
  // - Function type.
  // - Result attrs.

  // Update the result attributes of the function. Skipped entirely when
  // neither the old array nor the new results carry any attributes.
  ArrayAttr oldResultAttrs = op.getResAttrsAttr();
  if (oldResultAttrs || !resultAttrs.empty()) {
    SmallVector<DictionaryAttr, 4> newResultAttrs;
    newResultAttrs.reserve(originalNumResults + resultIndices.size());
    // Index into the old attribute array of the next entry not yet copied.
    unsigned oldIdx = 0;
    // Copies old entries [oldIdx, untilIdx) into `newResultAttrs`. When there
    // was no old array, null placeholders are appended instead. The generic
    // setAllResultAttrDicts overload later turns them into empty dictionaries.
    auto migrate = [&](unsigned untilIdx) {
      if (!oldResultAttrs) {
        newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
      } else {
        auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
        newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
                              oldResultAttrsRange.begin() + untilIdx);
      }
      oldIdx = untilIdx;
    };
    // Interleave: old entries up to each insertion point, then the new entry.
    for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
      migrate(resultIndices[i]);
      newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
                                                   : resultAttrs[i]);
    }
    // Copy the remaining old entries after the last insertion point.
    migrate(originalNumResults);
    setAllResultAttrDicts(op, newResultAttrs);
  }

  // Update the function type.
  op.setFunctionTypeAttr(TypeAttr::get(newType));
}
| |
| void function_interface_impl::eraseFunctionArguments( |
| FunctionOpInterface op, const BitVector &argIndices, Type newType) { |
| // There are 3 things that need to be updated: |
| // - Function type. |
| // - Arg attrs. |
| // - Block arguments of entry block, if not empty. |
| |
| // Update the argument attributes of the function. |
| if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { |
| SmallVector<DictionaryAttr, 4> newArgAttrs; |
| newArgAttrs.reserve(argAttrs.size()); |
| for (unsigned i = 0, e = argIndices.size(); i < e; ++i) |
| if (!argIndices[i]) |
| newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i])); |
| setAllArgAttrDicts(op, newArgAttrs); |
| } |
| |
| // Update the function type. |
| op.setFunctionTypeAttr(TypeAttr::get(newType)); |
| |
| // Update entry block arguments, if not empty. |
| if (!op.isExternal()) { |
| Block &entry = op->getRegion(0).front(); |
| entry.eraseArguments(argIndices); |
| } |
| } |
| |
| void function_interface_impl::eraseFunctionResults( |
| FunctionOpInterface op, const BitVector &resultIndices, Type newType) { |
| // There are 2 things that need to be updated: |
| // - Function type. |
| // - Result attrs. |
| |
| // Update the result attributes of the function. |
| if (ArrayAttr resAttrs = op.getResAttrsAttr()) { |
| SmallVector<DictionaryAttr, 4> newResultAttrs; |
| newResultAttrs.reserve(resAttrs.size()); |
| for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) |
| if (!resultIndices[i]) |
| newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i])); |
| setAllResultAttrDicts(op, newResultAttrs); |
| } |
| |
| // Update the function type. |
| op.setFunctionTypeAttr(TypeAttr::get(newType)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Function type signature. |
| //===----------------------------------------------------------------------===// |
| |
/// Sets the function type of `op` to `newType` and resizes the argument and
/// result attribute arrays to match the new argument/result counts:
/// - Unchanged count: the attribute array is left untouched, even if the
///   individual types changed.
/// - New count of zero: the attribute array is removed.
/// - Fewer entries: the array is truncated to the first `newCount` entries.
/// - More entries: the array is padded with empty dictionaries.
/// If no attribute array existed (and the new count is nonzero), nothing is
/// created. Storing goes through setAllArgResAttrDicts, which removes the
/// array when every remaining entry is an empty dictionary.
void function_interface_impl::setFunctionType(FunctionOpInterface op,
                                              Type newType) {
  unsigned oldNumArgs = op.getNumArguments();
  unsigned oldNumResults = op.getNumResults();
  op.setFunctionTypeAttr(TypeAttr::get(newType));
  unsigned newNumArgs = op.getNumArguments();
  unsigned newNumResults = op.getNumResults();

  // Functor used to update the argument and result attributes of the function.
  // `isArg` is passed as std::true_type / std::false_type so it can be turned
  // into a compile-time constant for the `isArg` template parameter below.
  auto emptyDict = DictionaryAttr::get(op.getContext());
  auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
    constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;

    if (oldCount == newCount)
      return;
    // The new type has no arguments/results, just drop the attribute.
    if (newCount == 0)
      return removeArgResAttrs<isArgVal>(op);
    ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
    if (!attrs)
      return;

    // The new type has less arguments/results, take the first N attributes.
    if (newCount < oldCount)
      return setAllArgResAttrDicts<isArgVal>(
          op, attrs.getValue().take_front(newCount));

    // Otherwise, the new type has more arguments/results. Initialize the new
    // arguments/results with empty dictionary attributes.
    SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
    newAttrs.resize(newCount, emptyDict);
    setAllArgResAttrDicts<isArgVal>(op, newAttrs);
  };

  // Update the argument and result attributes.
  updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
  updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
}