| //===- BuiltinAttributeInterfaces.td - Attr interfaces -----*- tablegen -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file contains the definitions of the builtin attribute interfaces: |
| // the ElementsAttr interface and the MemRefLayoutAttr interface. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_ |
| #define MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_ |
| |
| include "mlir/IR/OpBase.td" |
| |
| //===----------------------------------------------------------------------===// |
| // ElementsAttrInterface |
| //===----------------------------------------------------------------------===// |
| |
| def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { |
| let cppNamespace = "::mlir"; |
| let description = [{ |
| This interface is used for attributes that contain the constant elements of |
| a tensor or vector type. It allows for opaquely interacting with the |
| elements of the underlying attribute, and most importantly allows for |
| accessing the element values (including iteration) in any of the C++ data |
| types supported by the underlying attribute. |
| |
| An attribute implementing this interface can expose the supported data types |
| in two steps: |
| |
| * Define the set of iterable C++ data types: |
| |
| An attribute may define the set of iterable types by providing a definition |
| of tuples `ContiguousIterableTypesT` and/or `NonContiguousIterableTypesT`. |
| |
| - `ContiguousIterableTypesT` should contain types which can be iterated |
| contiguously. A contiguous range is an array-like range, such as |
| ArrayRef, where all of the elements are laid out sequentially in memory. |
| |
| - `NonContiguousIterableTypesT` should contain types which cannot be |
| iterated contiguously. A non-contiguous range makes no contiguity |
| guarantee; its elements may even be materialized when indexing, as is the |
| case for a mapped_range. |
| |
| As an example, consider an attribute that only contains i64 elements, with |
| the elements being stored within an ArrayRef. This attribute could |
| potentially define the iterable types as so: |
| |
| ```c++ |
| using ContiguousIterableTypesT = std::tuple<uint64_t>; |
| using NonContiguousIterableTypesT = std::tuple<APInt, Attribute>; |
| ``` |
| |
| * Provide an `iterator value_begin_impl(OverloadToken<T>) const` overload for |
| each iterable type: |
| |
| These overloads should return an iterator to the start of the range for the |
| respective iterable type. Consider the example i64 elements attribute |
| described in the previous section. This attribute may define the |
| value_begin_impl overloads like so: |
| |
| ```c++ |
| /// Provide begin iterators for the various iterable types. |
| /// * uint64_t |
| auto value_begin_impl(OverloadToken<uint64_t>) const { |
| return getElements().begin(); |
| } |
| /// * APInt |
| auto value_begin_impl(OverloadToken<llvm::APInt>) const { |
| return llvm::map_range(getElements(), [=](uint64_t value) { |
| return llvm::APInt(/*numBits=*/64, value); |
| }).begin(); |
| } |
| /// * Attribute |
| auto value_begin_impl(OverloadToken<mlir::Attribute>) const { |
| mlir::Type elementType = getType().getElementType(); |
| return llvm::map_range(getElements(), [=](uint64_t value) { |
| return mlir::IntegerAttr::get(elementType, |
| llvm::APInt(/*numBits=*/64, value)); |
| }).begin(); |
| } |
| ``` |
| |
| After the above, ElementsAttr will now be able to iterate over elements |
| using each of the registered iterable data types: |
| |
| ```c++ |
| ElementsAttr attr = myI64ElementsAttr; |
| |
| // We can access value ranges for the data types via `getValues<T>`. |
| for (uint64_t value : attr.getValues<uint64_t>()) |
| ...; |
| for (llvm::APInt value : attr.getValues<llvm::APInt>()) |
| ...; |
| for (mlir::IntegerAttr value : attr.getValues<mlir::IntegerAttr>()) |
| ...; |
| |
| // We can also access the value iterators directly. |
| auto it = attr.value_begin<uint64_t>(), e = attr.value_end<uint64_t>(); |
| for (; it != e; ++it) { |
| uint64_t value = *it; |
| ... |
| } |
| ``` |
| |
| ElementsAttr also supports failable access to iterators and ranges. This |
| allows for safely checking if the attribute supports the data type, and can |
| also allow for code to have fast paths for native data types. |
| |
| ```c++ |
| // Using `tryGetValues<T>`, we can also safely handle when the attribute |
| // doesn't support the data type. |
| if (auto range = attr.tryGetValues<uint64_t>()) { |
| for (uint64_t value : *range) |
| ...; |
| return; |
| } |
| |
| // We can also access the begin iterator safely, by using `try_value_begin`. |
| if (auto safeIt = attr.try_value_begin<uint64_t>()) { |
| auto it = *safeIt, e = attr.value_end<uint64_t>(); |
| for (; it != e; ++it) { |
| uint64_t value = *it; |
| ... |
| } |
| return; |
| } |
| ``` |
| }]; |
| let methods = [ |
| InterfaceMethod<[{ |
| This method returns an opaque range indexer for the given elementID, which |
| corresponds to a desired C++ element data type. Returns the indexer if the |
| attribute supports the given data type, failure otherwise. |
| }], |
| "::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer>", "getValuesImpl", |
| (ins "::mlir::TypeID":$elementID), /*methodBody=*/[{}], |
| /*defaultImplementation=*/[{ |
| // The tuple pointers passed below are null: they only hand the list of |
| // candidate element types over to `getValueImpl`. |
| using ContiguousTypesT = typename ConcreteAttr::ContiguousIterableTypesT; |
| using NonContiguousTypesT = |
| typename ConcreteAttr::NonContiguousIterableTypesT; |
| |
| // Look for the requested type among the contiguously iterable types first. |
| auto contiguousIndexer = getValueImpl( |
| static_cast<ContiguousTypesT *>(nullptr), elementID, |
| /*isContiguous=*/std::true_type()); |
| if (succeeded(contiguousIndexer)) |
| return std::move(contiguousIndexer); |
| |
| // Otherwise fall back to the non-contiguously iterable types; the result of |
| // that lookup (indexer or failure) is returned as is. |
| return getValueImpl( |
| static_cast<NonContiguousTypesT *>(nullptr), elementID, |
| /*isContiguous=*/std::false_type()); |
| }]>, |
| InterfaceMethod<[{ |
| Returns true if the attribute elements correspond to a splat, i.e. that |
| all elements of the attribute are the same value. |
| }], "bool", "isSplat", (ins), /*methodBody=*/[{}], /*defaultImplementation=*/[{ |
| // By default, only check for a single element splat: an attribute with |
| // exactly one element is reported as a splat, anything else is not. |
| return $_attr.getNumElements() == 1; |
| }]> |
| ]; |
| |
| // C++ accessors shared by both generated classes: this snippet is appended |
| // to `extraTraitClassDeclaration` and to `extraClassDeclaration` below, so |
| // it may only rely on `getNumElements()`, which both of them declare. |
| string ElementsAttrInterfaceAccessors = [{ |
| /// Return the number of elements held by this attribute. |
| int64_t size() const { return getNumElements(); } |
| |
| /// Return if the attribute holds no elements. |
| bool empty() const { return size() == 0; } |
| }]; |
| |
| let extraTraitClassDeclaration = [{ |
| // By default, no types are iterable. An attribute registers iterable types |
| // by providing its own definitions of these two tuples. |
| using ContiguousIterableTypesT = std::tuple<>; |
| using NonContiguousIterableTypesT = std::tuple<>; |
| |
| //===------------------------------------------------------------------===// |
| // Accessors |
| //===------------------------------------------------------------------===// |
| // Each accessor below forwards to the static helper of the same name on |
| // `::mlir::ElementsAttr`, passing the concrete attribute. |
| |
| /// Return the element type of this ElementsAttr. |
| Type getElementType() const { |
| return ::mlir::ElementsAttr::getElementType($_attr); |
| } |
| |
| /// Returns the number of elements held by this attribute. |
| int64_t getNumElements() const { |
| return ::mlir::ElementsAttr::getNumElements($_attr); |
| } |
| |
| /// Return if the given 'index' refers to a valid element in this attribute. |
| bool isValidIndex(ArrayRef<uint64_t> index) const { |
| return ::mlir::ElementsAttr::isValidIndex($_attr, index); |
| } |
| |
| protected: |
| /// Returns the 1-dimensional flattened row-major index from the given |
| /// multi-dimensional index. Unlike the accessors above, this helper is |
| /// protected in the trait. |
| uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const { |
| return ::mlir::ElementsAttr::getFlattenedIndex($_attr, index); |
| } |
| |
| //===------------------------------------------------------------------===// |
| // Value Iteration Internals |
| //===------------------------------------------------------------------===// |
| protected: |
| /// This class is used to allow specifying function overloads for different |
| /// types, without actually taking the types as parameters. This avoids the |
| /// need to build complicated SFINAE to select specific overloads. |
| template <typename T> |
| struct OverloadToken {}; |
| |
| private: |
| /// This function unpacks the types within a given tuple and then forwards |
| /// on to the unwrapped variant. The tuple pointer is unnamed and never |
| /// read; it only carries the candidate element types `Ts`. |
| template <typename... Ts, typename IsContiguousT> |
| auto getValueImpl(std::tuple<Ts...> *, ::mlir::TypeID elementID, |
| IsContiguousT isContiguous) const { |
| return getValueImpl<Ts...>(elementID, isContiguous); |
| } |
| /// Check to see if the given `elementID` matches the current type `T`. If |
| /// it does, build a value result using the current type. If it doesn't, |
| /// keep looking for the desired type. |
| template <typename T, typename... Ts, typename IsContiguousT> |
| auto getValueImpl(::mlir::TypeID elementID, |
| IsContiguousT isContiguous) const { |
| if (::mlir::TypeID::get<T>() == elementID) |
| return buildValueResult<T>(isContiguous); |
| // Not a match: continue with the remaining candidate types. Once `Ts` is |
| // exhausted, this call selects the bottom out overload below. |
| return getValueImpl<Ts...>(elementID, isContiguous); |
| } |
| /// Bottom out case for no matching type: no candidate type is left, so the |
| /// lookup reports failure. |
| template <typename IsContiguousT> |
| ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> |
| getValueImpl(::mlir::TypeID, IsContiguousT) const { |
| return failure(); |
| } |
| |
| /// Create the indexer used to access elements of type `T`, where `T` is |
| /// iterated via a contiguous range. |
| template <typename T> |
| ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult( |
| /*isContiguous*/std::true_type) const { |
| using IndexerT = ::mlir::detail::ElementsAttrIndexer; |
| |
| // An empty attribute gets a null element pointer rather than the address |
| // of a first element. `T` is spelled out here because it cannot be deduced |
| // from `nullptr`. |
| if ($_attr.empty()) |
| return IndexerT::contiguous<T>(/*isSplat=*/false, nullptr); |
| |
| // Otherwise, hand the indexer the address of the first element. |
| auto firstEltIt = $_attr.value_begin_impl(OverloadToken<T>()); |
| return IndexerT::contiguous($_attr.isSplat(), &*firstEltIt); |
| } |
| /// Create the indexer used to access elements of type `T`, where `T` is |
| /// iterated via a non-contiguous range. |
| template <typename T> |
| ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult( |
| /*isContiguous*/std::false_type) const { |
| using IndexerT = ::mlir::detail::ElementsAttrIndexer; |
| |
| // The indexer is built from the begin iterator itself. |
| auto firstEltIt = $_attr.value_begin_impl(OverloadToken<T>()); |
| return IndexerT::nonContiguous($_attr.isSplat(), firstEltIt); |
| } |
| |
| public: |
| //===------------------------------------------------------------------===// |
| // Value Iteration |
| //===------------------------------------------------------------------===// |
| |
| /// The iterator for the given element type T, i.e. the type returned by |
| /// `value_begin<T>()` on `AttrT` (the concrete attribute by default). |
| template <typename T, typename AttrT = ConcreteAttr> |
| using iterator = decltype(std::declval<AttrT>().template value_begin<T>()); |
| /// The iterator range over the given element T, i.e. the type returned by |
| /// `getValues<T>()` on `AttrT` (the concrete attribute by default). |
| template <typename T, typename AttrT = ConcreteAttr> |
| using iterator_range = |
| decltype(std::declval<AttrT>().template getValues<T>()); |
| |
| /// Return an iterator to the first element of this attribute as a value of |
| /// type `T`. This dispatches to the `value_begin_impl(OverloadToken<T>)` |
| /// overload of the concrete attribute. |
| template <typename T> |
| auto value_begin() const { |
| return $_attr.value_begin_impl(OverloadToken<T>()); |
| } |
| |
| /// Return the elements of this attribute as a range of values of type 'T'. |
| template <typename T> |
| auto getValues() const { |
| auto beginIt = $_attr.template value_begin<T>(); |
| // The end iterator is the begin iterator advanced by the element count. |
| // NOTE(review): the explicit conversion to `Attribute` presumably makes |
| // sure the base `getType()` is the one called -- confirm. |
| return detail::ElementsAttrRange<decltype(beginIt)>( |
| Attribute($_attr).getType(), beginIt, std::next(beginIt, size())); |
| } |
| }] # ElementsAttrInterfaceAccessors; |
| |
| let extraClassDeclaration = [{ |
| /// The iterator type used by the interface class for values of type `T`. |
| template <typename T> |
| using iterator = detail::ElementsAttrIterator<T>; |
| /// The range type used by the interface class for values of type `T`. |
| template <typename T> |
| using iterator_range = detail::ElementsAttrRange<iterator<T>>; |
| |
| //===------------------------------------------------------------------===// |
| // Accessors |
| //===------------------------------------------------------------------===// |
| // The member accessors forward to the static overloads taking an |
| // `Attribute`, passing `*this`. |
| |
| /// Return the type of this attribute. |
| ShapedType getType() const; |
| |
| /// Return the element type of this ElementsAttr. |
| Type getElementType() const { return getElementType(*this); } |
| static Type getElementType(Attribute elementsAttr); |
| |
| /// Return if the given 'index' refers to a valid element in this attribute. |
| bool isValidIndex(ArrayRef<uint64_t> index) const { |
| return isValidIndex(*this, index); |
| } |
| static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index); |
| static bool isValidIndex(Attribute elementsAttr, ArrayRef<uint64_t> index); |
| |
| /// Return the 1 dimensional flattened row-major index from the given |
| /// multi-dimensional index. |
| uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const { |
| return getFlattenedIndex(*this, index); |
| } |
| static uint64_t getFlattenedIndex(Type type, |
| ArrayRef<uint64_t> index); |
| /// Attribute overload: flattens `index` against the type of `elementsAttr`. |
| static uint64_t getFlattenedIndex(Attribute elementsAttr, |
| ArrayRef<uint64_t> index) { |
| return getFlattenedIndex(elementsAttr.getType(), index); |
| } |
| |
| /// Returns the number of elements held by this attribute. |
| int64_t getNumElements() const { return getNumElements(*this); } |
| static int64_t getNumElements(Attribute elementsAttr); |
| |
| //===------------------------------------------------------------------===// |
| // Value Iteration |
| //===------------------------------------------------------------------===// |
| |
| /// SFINAE helper: only valid when `T` derives from `Attribute` but is not |
| /// `Attribute` itself. |
| template <typename T> |
| using DerivedAttrValueCheckT = |
| typename std::enable_if_t<std::is_base_of<Attribute, T>::value && |
| !std::is_same<Attribute, T>::value>; |
| /// SFINAE helper: resolves to `ResultT` in exactly the complementary case, |
| /// i.e. when `T` is `Attribute` itself or does not derive from `Attribute`. |
| template <typename T, typename ResultT> |
| using DefaultValueCheckT = |
| typename std::enable_if_t<std::is_same<Attribute, T>::value || |
| !std::is_base_of<Attribute, T>::value, |
| ResultT>; |
| |
| /// Return the splat value for this attribute. This asserts that the |
| /// attribute corresponds to a splat. |
| template <typename T> |
| T getSplatValue() const { |
| assert(isSplat() && "expected splat attribute"); |
| return *value_begin<T>(); |
| } |
| |
| /// Return the elements of this attribute as a range of values of type 'T'. |
| template <typename T> |
| DefaultValueCheckT<T, iterator_range<T>> getValues() const { |
| return {Attribute::getType(), value_begin<T>(), value_end<T>()}; |
| } |
| /// Return an iterator to the first element as a value of type 'T'. Only |
| /// declared here; the definition is not part of this declaration block. |
| template <typename T> |
| DefaultValueCheckT<T, iterator<T>> value_begin() const; |
| /// Return the iterator one past the last element for values of type 'T'. |
| template <typename T> |
| DefaultValueCheckT<T, iterator<T>> value_end() const { |
| // NOTE(review): the first argument is default-constructed; presumably |
| // only the index (`size()`) matters for an end iterator -- confirm. |
| return iterator<T>({}, size()); |
| } |
| |
| /// Return the held element values as a range of T, where T is a derived |
| /// attribute type. |
| template <typename T> |
| using DerivedAttrValueIterator = |
| llvm::mapped_iterator<iterator<Attribute>, T (*)(Attribute)>; |
| template <typename T> |
| using DerivedAttrValueIteratorRange = |
| detail::ElementsAttrRange<DerivedAttrValueIterator<T>>; |
| template <typename T, typename = DerivedAttrValueCheckT<T>> |
| DerivedAttrValueIteratorRange<T> getValues() const { |
| // The captureless lambda is converted to a plain function pointer so that |
| // it matches the `T (*)(Attribute)` mapping type of |
| // `DerivedAttrValueIterator`. |
| auto castFn = [](Attribute attr) { return attr.template cast<T>(); }; |
| return {Attribute::getType(), llvm::map_range(getValues<Attribute>(), |
| static_cast<T (*)(Attribute)>(castFn))}; |
| } |
| template <typename T, typename = DerivedAttrValueCheckT<T>> |
| DerivedAttrValueIterator<T> value_begin() const { |
| return getValues<T>().begin(); |
| } |
| template <typename T, typename = DerivedAttrValueCheckT<T>> |
| DerivedAttrValueIterator<T> value_end() const { |
| // NOTE(review): the mapping function is null here; presumably fine as an |
| // end iterator is only compared, never dereferenced -- confirm. |
| return {value_end<Attribute>(), nullptr}; |
| } |
| |
| //===------------------------------------------------------------------===// |
| // Failable Value Iteration |
| //===------------------------------------------------------------------===// |
| |
| /// If this attribute supports iterating over element values of type `T`, |
| /// return the iterable range. Otherwise, return llvm::None. |
| template <typename T> |
| DefaultValueCheckT<T, Optional<iterator_range<T>>> tryGetValues() const { |
| // The range exists exactly when a begin iterator can be obtained. |
| if (Optional<iterator<T>> beginIt = try_value_begin<T>()) { |
| return iterator_range<T>(Attribute::getType(), *beginIt, |
| value_end<T>()); |
| } |
| return llvm::None; |
| } |
| /// Failable counterpart of `value_begin<T>()`: yields an optional begin |
| /// iterator. Only declared here; the definition is not part of this |
| /// declaration block. |
| template <typename T> |
| DefaultValueCheckT<T, Optional<iterator<T>>> try_value_begin() const; |
| |
| /// If this attribute supports iterating over its elements as `Attribute`, |
| /// return the range of those elements cast to the derived attribute type |
| /// `T`. Otherwise, return llvm::None. |
| template <typename T, typename = DerivedAttrValueCheckT<T>> |
| Optional<DerivedAttrValueIteratorRange<T>> tryGetValues() const { |
| auto values = tryGetValues<Attribute>(); |
| if (!values) |
| return llvm::None; |
| |
| // The captureless lambda is converted to a plain function pointer so that |
| // it matches the `T (*)(Attribute)` mapping type of |
| // `DerivedAttrValueIterator`. |
| auto castFn = [](Attribute attr) { return attr.template cast<T>(); }; |
| return DerivedAttrValueIteratorRange<T>( |
| Attribute::getType(), |
| llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn)) |
| ); |
| } |
| /// If the derived attribute range for `T` is available, return its begin |
| /// iterator. Otherwise, return llvm::None. |
| template <typename T, typename = DerivedAttrValueCheckT<T>> |
| Optional<DerivedAttrValueIterator<T>> try_value_begin() const { |
| if (auto values = tryGetValues<T>()) |
| return values->begin(); |
| return llvm::None; |
| } |
| }] # ElementsAttrInterfaceAccessors; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MemRefLayoutAttrInterface |
| //===----------------------------------------------------------------------===// |
| |
| def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> { |
| let cppNamespace = "::mlir"; |
| |
| let description = [{ |
| This interface is used for attributes that can represent the MemRef type's |
| layout semantics, such as dimension order in the memory, strides and offsets. |
| Such a layout attribute should be representable as a |
| [semi-affine map](Affine.md/#semi-affine-maps). |
| |
| Note: the MemRef type's layout is assumed to represent a simple strided |
| buffer layout. For more complicated cases, like sparse storage buffers, it |
| is preferable to use a separate type with a more specific layout, rather |
| than introducing extra complexity to the builtin MemRef type. |
| }]; |
| |
| let methods = [ |
| // No method body or default implementation is given for this method. |
| InterfaceMethod< |
| "Get the MemRef layout as an AffineMap, the method must not return NULL", |
| "::mlir::AffineMap", "getAffineMap", (ins) |
| >, |
| |
| // Default: the layout is the identity iff its affine map is the identity. |
| InterfaceMethod< |
| "Return true if this attribute represents the identity layout", |
| "bool", "isIdentity", (ins), |
| /*methodBody=*/[{}], |
| /*defaultImplementation=*/[{ |
| return $_attr.getAffineMap().isIdentity(); |
| }] |
| >, |
| |
| // Default: verify the affine map of the layout against `shape` via |
| // `verifyAffineMapAsLayout`, forwarding `emitError` to it. |
| InterfaceMethod< |
| "Check if the current layout is applicable to the provided shape", |
| "::mlir::LogicalResult", "verifyLayout", |
| (ins "::llvm::ArrayRef<int64_t>":$shape, |
| "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError), |
| /*methodBody=*/[{}], |
| /*defaultImplementation=*/[{ |
| return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(), |
| shape, emitError); |
| }] |
| > |
| ]; |
| } |
| |
| #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_ |