// blob: 4ec6cc52db7aab3dc0bceca486a73ce59e9d3b19 [file] [log] [blame]
//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the types in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
#define MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include <tuple>
namespace mlir {
namespace spirv {
// Forward declarations of the storage structs named as the storage parameter
// of the Type::TypeBase instantiations further down in this header. Only the
// names are needed here; the struct definitions are not part of this file.
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct SampledImageTypeStorage;
struct StructTypeStorage;
} // namespace detail
// Base SPIR-V type for providing availability queries.
//
// All member functions below are declarations only; their definitions are not
// part of this header. None of them is declared virtual, so the same-named
// members redeclared by the derived classes in this file hide these ones
// rather than override them.
class SPIRVType : public Type {
public:
using Type::Type;
/// LLVM-style RTTI hook (the function consulted by isa/cast/dyn_cast).
static bool classof(Type type);
/// Going by the name, reports whether this type is a scalar or a vector type;
/// the exact set of accepted types is decided by the out-of-line definition.
bool isScalarOrVector();
/// The extension requirements for each type are following the
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
/// convention.
using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<Extension>>;
/// Appends to `extensions` the extensions needed for this type to appear in
/// the given `storage` class. This method does not guarantee the uniqueness
/// of extensions; the same extension may be appended multiple times.
void getExtensions(ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
/// The capability requirements for each type are following the
/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
/// convention.
using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<Capability>>;
/// Appends to `capabilities` the capabilities needed for this type to appear
/// in the given `storage` class. This method does not guarantee the
/// uniqueness of capabilities; the same capability may be appended multiple
/// times.
void getCapabilities(CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = llvm::None);
/// Returns the size in bytes for each type. If no size can be calculated,
/// returns `llvm::None`. Note that if the type has explicit layout, it is
/// also taken into account in calculation.
Optional<int64_t> getSizeInBytes();
};
// SPIR-V scalar type: bool type, integer type, floating point type.
class ScalarType : public SPIRVType {
public:
using SPIRVType::SPIRVType;
/// LLVM-style RTTI hook (the function consulted by isa/cast/dyn_cast).
static bool classof(Type type);
/// Returns true if the given float type is valid for the SPIR-V dialect.
static bool isValid(FloatType);
/// Returns true if the given integer type is valid for the SPIR-V dialect.
static bool isValid(IntegerType);
/// Same signature as SPIRVType::getExtensions, which this declaration hides;
/// see the contract documented there.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getCapabilities, which this declaration
/// hides; see the contract documented there.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getSizeInBytes, which this declaration hides.
Optional<int64_t> getSizeInBytes();
};
// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
//
// Within this header, ArrayType, StructType, CooperativeMatrixNVType and
// MatrixType all name CompositeType as their base class. Which types
// `classof` actually accepts is decided by its out-of-line definition.
class CompositeType : public SPIRVType {
public:
using SPIRVType::SPIRVType;
/// LLVM-style RTTI hook (the function consulted by isa/cast/dyn_cast).
static bool classof(Type type);
/// Returns true if the given vector type is valid for the SPIR-V dialect.
static bool isValid(VectorType);
/// Return the number of elements of the type. This should only be called if
/// hasCompileTimeKnownNumElements is true.
unsigned getNumElements() const;
/// Returns the element type selected by the given argument. The parameter is
/// unnamed here; presumably it is a zero-based element index -- confirm
/// against the definition.
Type getElementType(unsigned) const;
/// Return true if the number of elements is known at compile time and is not
/// implementation dependent.
bool hasCompileTimeKnownNumElements() const;
/// Same signature as SPIRVType::getExtensions, which this declaration hides;
/// see the contract documented there.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getCapabilities, which this declaration
/// hides; see the contract documented there.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getSizeInBytes, which this declaration hides.
Optional<int64_t> getSizeInBytes();
};
// SPIR-V array type
//
// Declared through Type::TypeBase with CompositeType as the base class and
// detail::ArrayTypeStorage as the storage.
class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
detail::ArrayTypeStorage> {
public:
using Base::Base;
/// Returns an array type with the given element type and element count. No
/// stride is passed; presumably the result carries no stride decoration
/// (getArrayStride() == 0) -- confirm against the definition.
static ArrayType get(Type elementType, unsigned elementCount);
/// Returns an array type with the given stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount,
unsigned stride);
/// Returns the number of elements in the array.
unsigned getNumElements() const;
/// Returns the type of the array elements.
Type getElementType() const;
/// Returns the array stride in bytes. 0 means no stride decorated on this
/// type.
unsigned getArrayStride() const;
/// Same signature as SPIRVType::getExtensions; see the contract documented
/// there.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getCapabilities; see the contract documented
/// there.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = llvm::None);
/// Returns the array size in bytes. Since array type may have an explicit
/// stride declaration (in bytes), we also include it in the calculation.
Optional<int64_t> getSizeInBytes();
};
// SPIR-V image type.
//
// Two `get` overloads are offered: one taking the individual image parameters
// (all but the first two defaulted) and one taking the same seven values
// packed into a single std::tuple.
class ImageType
    : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
public:
  using Base::Base;

  /// Convenience overload: packs the seven parameters, in declaration order,
  /// into a tuple and forwards it to the tuple-taking `get` declared below.
  static ImageType
  get(Type elementType, Dim dim,
      ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
      ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
      ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
      ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
      ImageFormat format = ImageFormat::Unknown) {
    // Must match the parameter type of the tuple overload exactly.
    using PackedParams =
        std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                   ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
    return ImageType::get(PackedParams(elementType, dim, depth, arrayed,
                                       samplingInfo, samplerUse, format));
  }

  /// Tuple-based overload; only declared here, defined out of line.
  static ImageType
  get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);

  // Accessors, one per image parameter accepted by `get`.
  Type getElementType() const;
  Dim getDim() const;
  ImageDepthInfo getDepthInfo() const;
  ImageArrayedInfo getArrayedInfo() const;
  ImageSamplingInfo getSamplingInfo() const;
  ImageSamplerUseInfo getSamplerUseInfo() const;
  ImageFormat getImageFormat() const;
  // TODO: Add support for Access qualifier

  /// Same signature as SPIRVType::getExtensions; see the contract documented
  /// there.
  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                     Optional<StorageClass> storage = llvm::None);
  /// Same signature as SPIRVType::getCapabilities; see the contract
  /// documented there.
  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                       Optional<StorageClass> storage = llvm::None);
};
// SPIR-V pointer type
//
// Declared through Type::TypeBase with SPIRVType as the base class and
// detail::PointerTypeStorage as the storage.
class PointerType : public Type::TypeBase<PointerType, SPIRVType,
detail::PointerTypeStorage> {
public:
using Base::Base;
/// Returns the pointer type for the given pointee type and storage class.
static PointerType get(Type pointeeType, StorageClass storageClass);
/// Returns the type being pointed to.
Type getPointeeType() const;
/// Returns the storage class of the pointer.
StorageClass getStorageClass() const;
/// Same signature as SPIRVType::getExtensions; see the contract documented
/// there.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getCapabilities; see the contract documented
/// there.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = llvm::None);
};
// SPIR-V run-time array type
//
// Unlike ArrayType, this class names SPIRVType (not CompositeType) as its base
// and declares no element-count accessor.
class RuntimeArrayType
: public Type::TypeBase<RuntimeArrayType, SPIRVType,
detail::RuntimeArrayTypeStorage> {
public:
using Base::Base;
/// Returns a runtime array type with the given element type. No stride is
/// passed; presumably the result carries no stride decoration
/// (getArrayStride() == 0) -- confirm against the definition.
static RuntimeArrayType get(Type elementType);
/// Returns a runtime array type with the given stride in bytes.
static RuntimeArrayType get(Type elementType, unsigned stride);
/// Returns the type of the array elements.
Type getElementType() const;
/// Returns the array stride in bytes. 0 means no stride decorated on this
/// type.
unsigned getArrayStride() const;
/// Same signature as SPIRVType::getExtensions; see the contract documented
/// there.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getCapabilities; see the contract documented
/// there.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = llvm::None);
};
// SPIR-V sampled image type
//
// Declared through Type::TypeBase with SPIRVType as the base class and
// detail::SampledImageTypeStorage as the storage.
class SampledImageType
: public Type::TypeBase<SampledImageType, SPIRVType,
detail::SampledImageTypeStorage> {
public:
using Base::Base;
/// Returns the sampled image type wrapping the given image type.
static SampledImageType get(Type imageType);
/// Variant of `get` that additionally receives an `emitError` callback
/// producing an InFlightDiagnostic. Presumably follows the usual MLIR
/// get/getChecked/verify pattern, i.e. reports through `emitError` and returns
/// a null type when `verify` fails -- confirm against the definition.
static SampledImageType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
/// Checks the construction parameter `imageType` and returns a LogicalResult;
/// `emitError` is available for reporting problems. The concrete constraints
/// are in the out-of-line definition.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type imageType);
/// Returns the wrapped image type.
Type getImageType() const;
/// Same signature as SPIRVType::getExtensions; see the contract documented
/// there.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<spirv::StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getCapabilities; see the contract documented
/// there.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<spirv::StorageClass> storage = llvm::None);
};
/// SPIR-V struct type. Two kinds of struct types are supported:
/// - Literal: a literal struct type is uniqued by its fields (types + offset
/// info + decoration info).
/// - Identified: an identified struct type is uniqued by its string identifier
/// (name). This is useful in representing recursive structs. For example, the
/// following C struct:
///
/// struct A {
/// A* next;
/// };
///
/// would be represented in MLIR as:
///
/// !spv.struct<A, (!spv.ptr<!spv.struct<A>, Generic>)>
///
/// In the above, expressing recursive struct types is accomplished by giving a
/// recursive struct a unique identifier and using that identifier in the struct
/// definition for recursive references.
class StructType : public Type::TypeBase<StructType, CompositeType,
detail::StructTypeStorage> {
public:
using Base::Base;
// Type for specifying the offset of the struct members
using OffsetInfo = uint32_t;
// Type for specifying the decoration(s) on struct members
struct MemberDecorationInfo {
// Index of the decorated member; shares one 32-bit bit-field unit with
// `hasValue` (31 bits + 1 bit).
uint32_t memberIndex : 31;
// Single-bit flag; going by the name, tells whether `decorationValue` is
// meaningful for this decoration.
uint32_t hasValue : 1;
Decoration decoration;
uint32_t decorationValue;
// Stores the arguments as given. `index` and `hasValue` are narrowed to their
// bit-field widths (31 bits and 1 bit respectively).
MemberDecorationInfo(uint32_t index, uint32_t hasValue,
Decoration decoration, uint32_t decorationValue)
: memberIndex(index), hasValue(hasValue), decoration(decoration),
decorationValue(decorationValue) {}
// Equality over memberIndex, decoration and decorationValue.
// NOTE(review): `hasValue` is not compared, so an entry without a value and
// an entry with an explicit value compare equal whenever their
// decorationValue fields hold the same number -- confirm this is intended.
bool operator==(const MemberDecorationInfo &other) const {
return (this->memberIndex == other.memberIndex) &&
(this->decoration == other.decoration) &&
(this->decorationValue == other.decorationValue);
}
// Orders by memberIndex first, then by the numeric value of the decoration.
// decorationValue and hasValue do not take part in the ordering.
bool operator<(const MemberDecorationInfo &other) const {
return this->memberIndex < other.memberIndex ||
(this->memberIndex == other.memberIndex &&
static_cast<uint32_t>(this->decoration) <
static_cast<uint32_t>(other.decoration));
}
};
/// Construct a literal StructType with at least one member.
static StructType get(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo = {},
ArrayRef<MemberDecorationInfo> memberDecorations = {});
/// Construct an identified StructType. This creates a StructType whose body
/// (member types, offset info, and decorations) is not set yet. A call to
/// StructType::trySetBody(...) must follow when the StructType contents are
/// available (e.g. parsed or deserialized).
///
/// Note: If another thread creates (or had already created) a struct with the
/// same identifier, that struct will be returned as a result.
static StructType getIdentified(MLIRContext *context, StringRef identifier);
/// Construct a (possibly identified) StructType with no members.
///
/// Note: this method might fail in a multi-threaded setup if another thread
/// created an identified struct with the same identifier but with different
/// contents before returning. In which case, an empty (default-constructed)
/// StructType is returned.
static StructType getEmpty(MLIRContext *context, StringRef identifier = "");
/// For literal structs, return an empty string.
/// For identified structs, return the struct's identifier.
StringRef getIdentifier() const;
/// Returns true if the StructType is identified.
bool isIdentified() const;
/// Returns the number of struct members.
unsigned getNumElements() const;
/// Returns the type of the member selected by the given argument. The
/// parameter is unnamed here; presumably it is a zero-based member index --
/// confirm against the definition.
Type getElementType(unsigned) const;
/// Range class for element types.
class ElementTypeRange
: public ::llvm::detail::indexed_accessor_range_base<
ElementTypeRange, const Type *, Type, Type, Type> {
private:
using RangeBaseT::RangeBaseT;
/// See `llvm::detail::indexed_accessor_range_base` for details.
/// Plain pointer arithmetic: advances `object` by `index` elements.
static const Type *offset_base(const Type *object, ptrdiff_t index) {
return object + index;
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
/// Returns a copy of the `index`-th Type counted from `object`.
static Type dereference_iterator(const Type *object, ptrdiff_t index) {
return object[index];
}
/// Allow base class access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
};
/// Returns the member types as an ElementTypeRange.
ElementTypeRange getElementTypes() const;
/// Going by the name, reports whether offset information is present on this
/// struct; the exact criterion is in the out-of-line definition.
bool hasOffset() const;
/// Returns the offset of the member selected by the given argument. Note that
/// the return type is uint64_t while OffsetInfo is declared as uint32_t.
uint64_t getMemberOffset(unsigned) const;
// Returns in `memberDecorations` the Decorations (apart from Offset)
// associated with all members of the StructType.
void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo>
&memberDecorations) const;
// Returns in `decorationsInfo` all the Decorations (apart from Offset)
// associated with the `i`-th member of the StructType.
void getMemberDecorations(
unsigned i,
SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
/// Sets the contents of an incomplete identified StructType. This method must
/// be called only for identified StructTypes and it must be called only once
/// per instance. Otherwise, failure() is returned.
LogicalResult
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
ArrayRef<MemberDecorationInfo> memberDecorations = {});
/// Same signature as SPIRVType::getExtensions; see the contract documented
/// there.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getCapabilities; see the contract documented
/// there.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = llvm::None);
};
// Hashing overload for StructType::MemberDecorationInfo, declared in the
// spirv namespace next to the type it hashes. Only declared here; which fields
// feed the hash is decided by the out-of-line definition and should stay
// consistent with MemberDecorationInfo::operator==.
llvm::hash_code
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
// SPIR-V cooperative matrix type
//
// Declared through Type::TypeBase with CompositeType as the base class and
// detail::CooperativeMatrixTypeStorage as the storage.
class CooperativeMatrixNVType
: public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
detail::CooperativeMatrixTypeStorage> {
public:
using Base::Base;
/// Returns the cooperative matrix type with the given element type, scope,
/// and row/column counts.
static CooperativeMatrixNVType get(Type elementType, Scope scope,
unsigned rows, unsigned columns);
/// Returns the element type of the matrix.
Type getElementType() const;
/// Return the scope of the cooperative matrix.
Scope getScope() const;
/// Return the number of rows of the matrix.
unsigned getRows() const;
/// Return the number of columns of the matrix.
unsigned getColumns() const;
/// Same signature as SPIRVType::getExtensions; see the contract documented
/// there.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getCapabilities; see the contract documented
/// there.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = llvm::None);
};
// SPIR-V matrix type
//
// Constructed from a column type and a column count. Declared through
// Type::TypeBase with CompositeType as the base class and
// detail::MatrixTypeStorage as the storage.
class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
detail::MatrixTypeStorage> {
public:
using Base::Base;
/// Returns the matrix type with `columnCount` columns of type `columnType`.
static MatrixType get(Type columnType, uint32_t columnCount);
/// Variant of `get` that additionally receives an `emitError` callback
/// producing an InFlightDiagnostic. Presumably follows the usual MLIR
/// get/getChecked/verify pattern, i.e. reports through `emitError` and returns
/// a null type when `verify` fails -- confirm against the definition.
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);
/// Checks the construction parameters and returns a LogicalResult;
/// `emitError` is available for reporting problems. The concrete constraints
/// are in the out-of-line definition.
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);
/// Returns true if the matrix elements are vectors of float elements.
static bool isValidColumnType(Type columnType);
/// Returns the type of a single column.
Type getColumnType() const;
/// Returns the number of rows.
unsigned getNumRows() const;
/// Returns the number of columns.
unsigned getNumColumns() const;
/// Returns total number of elements (rows*columns).
unsigned getNumElements() const;
/// Returns the elements' type (i.e, single element type).
Type getElementType() const;
/// Same signature as SPIRVType::getExtensions; see the contract documented
/// there.
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
/// Same signature as SPIRVType::getCapabilities; see the contract documented
/// there.
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage = llvm::None);
};
} // end namespace spirv
} // end namespace mlir
#endif // MLIR_DIALECT_SPIRV_IR_SPIRVTYPES_H_