blob: 45295e874f3bdb7c49e49a193039b0fa55e291c0 [file] [log] [blame]
//===- BuiltinAttributeInterfaces.td - Attr interfaces -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definition of the ElementsAttr interface.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
#define MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// ElementsAttrInterface
//===----------------------------------------------------------------------===//
def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
  let cppNamespace = "::mlir";
  let description = [{
    This interface is used for attributes that contain the constant elements of
    a tensor or vector type. It allows for opaquely interacting with the
    elements of the underlying attribute, and most importantly allows for
    accessing the element values (including iteration) in any of the C++ data
    types supported by the underlying attribute.

    An attribute implementing this interface can expose the supported data types
    in two steps:

    * Define the set of iterable C++ data types:

    An attribute may define the set of iterable types by providing a definition
    of tuples `ContiguousIterableTypesT` and/or `NonContiguousIterableTypesT`.

    -  `ContiguousIterableTypesT` should contain types which can be iterated
       contiguously. A contiguous range is an array-like range, such as
       ArrayRef, where all of the elements are laid out sequentially in memory.

    -  `NonContiguousIterableTypesT` should contain types which can not be
       iterated contiguously. A non-contiguous range makes no guarantee about
       how its elements are stored; its elements may even be materialized on
       demand when indexed, as is the case for a mapped_range.

    As an example, consider an attribute that only contains i64 elements, with
    the elements being stored within an ArrayRef. This attribute could
    potentially define the iterable types like so:

    ```c++
    using ContiguousIterableTypesT = std::tuple<uint64_t>;
    using NonContiguousIterableTypesT = std::tuple<APInt, Attribute>;
    ```

    * Provide an `iterator value_begin_impl(OverloadToken<T>) const` overload
      for each iterable type:

    These overloads should return an iterator to the start of the range for the
    respective iterable type. Consider the example i64 elements attribute
    described in the previous section. This attribute may define the
    value_begin_impl overloads like so:

    ```c++
    /// Provide begin iterators for the various iterable types.
    /// * uint64_t
    auto value_begin_impl(OverloadToken<uint64_t>) const {
      return getElements().begin();
    }
    /// * APInt
    auto value_begin_impl(OverloadToken<llvm::APInt>) const {
      return llvm::map_range(getElements(), [=](uint64_t value) {
        return llvm::APInt(/*numBits=*/64, value);
      }).begin();
    }
    /// * Attribute
    auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
      mlir::Type elementType = getType().getElementType();
      return llvm::map_range(getElements(), [=](uint64_t value) {
        return mlir::IntegerAttr::get(elementType,
                                      llvm::APInt(/*numBits=*/64, value));
      }).begin();
    }
    ```

    After the above, ElementsAttr will now be able to iterate over elements
    using each of the registered iterable data types:

    ```c++
    ElementsAttr attr = myI64ElementsAttr;

    // We can access value ranges for the data types via `getValues<T>`.
    for (uint64_t value : attr.getValues<uint64_t>())
      ...;
    for (llvm::APInt value : attr.getValues<llvm::APInt>())
      ...;
    for (mlir::IntegerAttr value : attr.getValues<mlir::IntegerAttr>())
      ...;

    // We can also access the value iterators directly.
    auto it = attr.value_begin<uint64_t>(), e = attr.value_end<uint64_t>();
    for (; it != e; ++it) {
      uint64_t value = *it;
      ...
    }
    ```

    ElementsAttr also supports failable access to iterators and ranges. This
    allows for safely checking if the attribute supports the data type, and can
    also allow for code to have fast paths for native data types.

    ```c++
    // Using `tryGetValues<T>`, we can also safely handle when the attribute
    // doesn't support the data type.
    if (auto range = attr.tryGetValues<uint64_t>()) {
      for (uint64_t value : *range)
        ...;
      return;
    }

    // We can also access the begin iterator safely, by using `try_value_begin`.
    if (auto safeIt = attr.try_value_begin<uint64_t>()) {
      auto it = *safeIt, e = attr.value_end<uint64_t>();
      for (; it != e; ++it) {
        uint64_t value = *it;
        ...
      }
      return;
    }
    ```
  }];

  let methods = [
    InterfaceMethod<[{
      This method returns an opaque range indexer for the given elementID, which
      corresponds to a desired C++ element data type. Returns the indexer if the
      attribute supports the given data type, failure otherwise.
    }],
    "::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer>", "getValuesImpl",
    (ins "::mlir::TypeID":$elementID), /*methodBody=*/[{}],
    /*defaultImplementation=*/[{
      // The contiguous types are searched first; the non-contiguous types are
      // only consulted when no contiguous type matched `elementID`.
      auto result = getValueImpl(
        static_cast<typename ConcreteAttr::ContiguousIterableTypesT *>(nullptr),
        elementID, /*isContiguous=*/std::true_type());
      if (succeeded(result))
        // `result` has the declared return type, so it is implicitly moved;
        // an explicit std::move here would only trigger -Wpessimizing-move.
        return result;
      return getValueImpl(
        static_cast<typename ConcreteAttr::NonContiguousIterableTypesT *>(
            nullptr),
        elementID, /*isContiguous=*/std::false_type());
    }]>,
    InterfaceMethod<[{
      Returns true if the attribute elements correspond to a splat, i.e. that
      all elements of the attribute are the same value.
    }], "bool", "isSplat", (ins), /*methodBody=*/[{}],
    /*defaultImplementation=*/[{
      // By default, only check for a single element splat.
      return $_attr.getNumElements() == 1;
    }]>
  ];

  // Accessors shared verbatim by the trait class and the interface class.
  string ElementsAttrInterfaceAccessors = [{
    /// Return the number of elements held by this attribute.
    int64_t size() const { return getNumElements(); }

    /// Return true if the attribute holds no elements.
    bool empty() const { return size() == 0; }
  }];

  let extraTraitClassDeclaration = [{
    // By default, no types are iterable.
    using ContiguousIterableTypesT = std::tuple<>;
    using NonContiguousIterableTypesT = std::tuple<>;

    //===------------------------------------------------------------------===//
    // Accessors
    //===------------------------------------------------------------------===//

    /// Return the element type of this ElementsAttr.
    Type getElementType() const {
      return ::mlir::ElementsAttr::getElementType($_attr);
    }

    /// Returns the number of elements held by this attribute.
    int64_t getNumElements() const {
      return ::mlir::ElementsAttr::getNumElements($_attr);
    }

    /// Return true if the given 'index' refers to a valid element in this
    /// attribute.
    bool isValidIndex(ArrayRef<uint64_t> index) const {
      return ::mlir::ElementsAttr::isValidIndex($_attr, index);
    }

  protected:
    /// Returns the 1-dimensional flattened row-major index from the given
    /// multi-dimensional index.
    uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
      return ::mlir::ElementsAttr::getFlattenedIndex($_attr, index);
    }

    //===------------------------------------------------------------------===//
    // Value Iteration Internals
    //===------------------------------------------------------------------===//
  protected:
    /// This class is used to allow specifying function overloads for different
    /// types, without actually taking the types as parameters. This avoids the
    /// need to build complicated SFINAE to select specific overloads.
    template <typename T>
    struct OverloadToken {};

  private:
    /// This function unpacks the types within a given tuple and then forwards
    /// on to the unwrapped variant. The tuple pointer is only used to deduce
    /// `Ts...`; it is never dereferenced.
    template <typename... Ts, typename IsContiguousT>
    auto getValueImpl(std::tuple<Ts...> *, ::mlir::TypeID elementID,
                      IsContiguousT isContiguous) const {
      return getValueImpl<Ts...>(elementID, isContiguous);
    }

    /// Check to see if the given `elementID` matches the current type `T`. If
    /// it does, build a value result using the current type. If it doesn't,
    /// keep looking for the desired type.
    template <typename T, typename... Ts, typename IsContiguousT>
    auto getValueImpl(::mlir::TypeID elementID,
                      IsContiguousT isContiguous) const {
      if (::mlir::TypeID::get<T>() == elementID)
        return buildValueResult<T>(isContiguous);
      return getValueImpl<Ts...>(elementID, isContiguous);
    }

    /// Bottom out case for no matching type.
    template <typename IsContiguousT>
    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer>
    getValueImpl(::mlir::TypeID, IsContiguousT) const {
      return failure();
    }

    /// Build an indexer for the given type `T`, which is represented via a
    /// contiguous range.
    template <typename T>
    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
        /*isContiguous*/std::true_type) const {
      // An empty attribute has no first element whose address could be taken,
      // so build the indexer from a null data pointer instead of dereferencing
      // the begin iterator below.
      if ($_attr.empty()) {
        return ::mlir::detail::ElementsAttrIndexer::contiguous<T>(
          /*isSplat=*/false, nullptr);
      }
      auto valueIt = $_attr.value_begin_impl(OverloadToken<T>());
      return ::mlir::detail::ElementsAttrIndexer::contiguous(
        $_attr.isSplat(), &*valueIt);
    }

    /// Build an indexer for the given type `T`, which is represented via a
    /// non-contiguous range.
    template <typename T>
    ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
        /*isContiguous*/std::false_type) const {
      auto valueIt = $_attr.value_begin_impl(OverloadToken<T>());
      return ::mlir::detail::ElementsAttrIndexer::nonContiguous(
        $_attr.isSplat(), valueIt);
    }

  public:
    //===------------------------------------------------------------------===//
    // Value Iteration
    //===------------------------------------------------------------------===//

    /// The iterator for the given element type T.
    template <typename T, typename AttrT = ConcreteAttr>
    using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
    /// The iterator range over the given element T.
    template <typename T, typename AttrT = ConcreteAttr>
    using iterator_range =
        decltype(std::declval<AttrT>().template getValues<T>());

    /// Return an iterator to the first element of this attribute as a value of
    /// type `T`.
    template <typename T>
    auto value_begin() const {
      return $_attr.value_begin_impl(OverloadToken<T>());
    }

    /// Return the elements of this attribute as a value of type 'T'.
    template <typename T>
    auto getValues() const {
      auto beginIt = $_attr.template value_begin<T>();
      return detail::ElementsAttrRange<decltype(beginIt)>(
        Attribute($_attr).getType(), beginIt, std::next(beginIt, size()));
    }
  }] # ElementsAttrInterfaceAccessors;

  let extraClassDeclaration = [{
    template <typename T>
    using iterator = detail::ElementsAttrIterator<T>;
    template <typename T>
    using iterator_range = detail::ElementsAttrRange<iterator<T>>;

    //===------------------------------------------------------------------===//
    // Accessors
    //===------------------------------------------------------------------===//

    /// Return the type of this attribute.
    ShapedType getType() const;

    /// Return the element type of this ElementsAttr.
    Type getElementType() const { return getElementType(*this); }
    static Type getElementType(Attribute elementsAttr);

    /// Return true if the given 'index' refers to a valid element in this
    /// attribute.
    bool isValidIndex(ArrayRef<uint64_t> index) const {
      return isValidIndex(*this, index);
    }
    static bool isValidIndex(ShapedType type, ArrayRef<uint64_t> index);
    static bool isValidIndex(Attribute elementsAttr, ArrayRef<uint64_t> index);

    /// Return the 1 dimensional flattened row-major index from the given
    /// multi-dimensional index.
    uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
      return getFlattenedIndex(*this, index);
    }
    static uint64_t getFlattenedIndex(Type type,
                                      ArrayRef<uint64_t> index);
    static uint64_t getFlattenedIndex(Attribute elementsAttr,
                                      ArrayRef<uint64_t> index) {
      return getFlattenedIndex(elementsAttr.getType(), index);
    }

    /// Returns the number of elements held by this attribute.
    int64_t getNumElements() const { return getNumElements(*this); }
    static int64_t getNumElements(Attribute elementsAttr);

    //===------------------------------------------------------------------===//
    // Value Iteration
    //===------------------------------------------------------------------===//

    /// Enabled only for strict subclasses of `Attribute` (e.g. IntegerAttr).
    template <typename T>
    using DerivedAttrValueCheckT =
        typename std::enable_if_t<std::is_base_of<Attribute, T>::value &&
                                  !std::is_same<Attribute, T>::value>;
    /// Enabled (yielding `ResultT`) for `Attribute` itself and for any type
    /// that is not derived from `Attribute`.
    template <typename T, typename ResultT>
    using DefaultValueCheckT =
        typename std::enable_if_t<std::is_same<Attribute, T>::value ||
                                      !std::is_base_of<Attribute, T>::value,
                                  ResultT>;

    /// Return the splat value for this attribute. This asserts that the
    /// attribute corresponds to a splat.
    template <typename T>
    T getSplatValue() const {
      assert(isSplat() && "expected splat attribute");
      return *value_begin<T>();
    }

    /// Return the elements of this attribute as a value of type 'T'.
    template <typename T>
    DefaultValueCheckT<T, iterator_range<T>> getValues() const {
      return {Attribute::getType(), value_begin<T>(), value_end<T>()};
    }
    /// Return an iterator to the first element of this attribute as a value of
    /// type `T`.
    template <typename T>
    DefaultValueCheckT<T, iterator<T>> value_begin() const;
    /// Return the past-the-end iterator for element values of type `T`.
    template <typename T>
    DefaultValueCheckT<T, iterator<T>> value_end() const {
      return iterator<T>({}, size());
    }

    /// Return the held element values a range of T, where T is a derived
    /// attribute type.
    template <typename T>
    using DerivedAttrValueIterator =
        llvm::mapped_iterator<iterator<Attribute>, T (*)(Attribute)>;
    template <typename T>
    using DerivedAttrValueIteratorRange =
        detail::ElementsAttrRange<DerivedAttrValueIterator<T>>;
    template <typename T, typename = DerivedAttrValueCheckT<T>>
    DerivedAttrValueIteratorRange<T> getValues() const {
      auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
      return {Attribute::getType(), llvm::map_range(getValues<Attribute>(),
              static_cast<T (*)(Attribute)>(castFn))};
    }
    template <typename T, typename = DerivedAttrValueCheckT<T>>
    DerivedAttrValueIterator<T> value_begin() const {
      return getValues<T>().begin();
    }
    template <typename T, typename = DerivedAttrValueCheckT<T>>
    DerivedAttrValueIterator<T> value_end() const {
      // An end iterator must never be dereferenced, so no mapping function is
      // attached to it.
      return {value_end<Attribute>(), nullptr};
    }

    //===------------------------------------------------------------------===//
    // Failable Value Iteration
    //===------------------------------------------------------------------===//

    /// If this attribute supports iterating over element values of type `T`,
    /// return the iterable range. Otherwise, return llvm::None.
    template <typename T>
    DefaultValueCheckT<T, Optional<iterator_range<T>>> tryGetValues() const {
      if (Optional<iterator<T>> beginIt = try_value_begin<T>()) {
        return iterator_range<T>(Attribute::getType(), *beginIt,
                                 value_end<T>());
      }
      return llvm::None;
    }
    /// If this attribute supports iterating over element values of type `T`,
    /// return an iterator to the first element. Otherwise, return llvm::None.
    template <typename T>
    DefaultValueCheckT<T, Optional<iterator<T>>> try_value_begin() const;

    /// If this attribute supports iterating over element values of type `T`,
    /// return the iterable range. Otherwise, return llvm::None.
    template <typename T, typename = DerivedAttrValueCheckT<T>>
    Optional<DerivedAttrValueIteratorRange<T>> tryGetValues() const {
      auto values = tryGetValues<Attribute>();
      if (!values)
        return llvm::None;

      auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
      return DerivedAttrValueIteratorRange<T>(
        Attribute::getType(),
        llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
      );
    }
    /// If this attribute supports iterating over element values of the derived
    /// attribute type `T`, return an iterator to the first element. Otherwise,
    /// return llvm::None.
    template <typename T, typename = DerivedAttrValueCheckT<T>>
    Optional<DerivedAttrValueIterator<T>> try_value_begin() const {
      if (auto values = tryGetValues<T>())
        return values->begin();
      return llvm::None;
    }
  }] # ElementsAttrInterfaceAccessors;
}
//===----------------------------------------------------------------------===//
// MemRefLayoutAttrInterface
//===----------------------------------------------------------------------===//
def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
  let cppNamespace = "::mlir";

  let description = [{
    This interface is used for attributes that can represent the MemRef type's
    layout semantics, such as dimension order in memory, strides and offsets.
    Such a layout attribute should be representable as a
    [semi-affine map](Affine.md/#semi-affine-maps).

    Note: the MemRef type's layout is assumed to represent a simple strided
    buffer layout. For more complicated cases, like sparse storage buffers,
    it is preferable to use a separate type with a more specific layout, rather
    than introducing extra complexity to the builtin MemRef type.
  }];

  let methods = [
    InterfaceMethod<
      "Get the MemRef layout as an AffineMap; the method must not return NULL",
      "::mlir::AffineMap", "getAffineMap", (ins)
    >,

    InterfaceMethod<
      "Return true if this attribute represents the identity layout",
      "bool", "isIdentity", (ins),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
        return $_attr.getAffineMap().isIdentity();
      }]
    >,

    InterfaceMethod<
      "Check if the current layout is applicable to the provided shape",
      "::mlir::LogicalResult", "verifyLayout",
      (ins "::llvm::ArrayRef<int64_t>":$shape,
           "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
        // By default, verify the layout through its affine-map form.
        return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(),
                                                       shape, emitError);
      }]
    >
  ];
}
#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_