blob: 05b5dafb8a68219d2c399327934f046b05c94d8c [file] [log] [blame]
//===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::spirv;
//===----------------------------------------------------------------------===//
// ArrayType
//===----------------------------------------------------------------------===//
/// Storage for spirv::ArrayType. The key is the triple
/// (element type, element count, stride).
struct spirv::detail::ArrayTypeStorage : public TypeStorage {
  using KeyTy = std::tuple<Type, unsigned, unsigned>;

  /// Unpacks `key` into the three stored fields.
  ArrayTypeStorage(const KeyTy &key)
      : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
        stride(std::get<2>(key)) {}

  /// Placement-constructs a storage object in memory handed out by
  /// `allocator`.
  static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
                                     const KeyTy &key) {
    auto *memory = allocator.allocate<ArrayTypeStorage>();
    return new (memory) ArrayTypeStorage(key);
  }

  /// Returns true when every component of `key` equals the stored field.
  bool operator==(const KeyTy &key) const {
    return std::get<0>(key) == elementType &&
           std::get<1>(key) == elementCount && std::get<2>(key) == stride;
  }

  Type elementType;
  unsigned elementCount;
  unsigned stride;
};
/// Gets the array type of `elementCount` elements of `elementType`, using a
/// stride of 0. `elementCount` must be non-zero.
ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
  assert(elementCount && "ArrayType needs at least one element");
  auto *context = elementType.getContext();
  return Base::get(context, elementType, elementCount, /*stride=*/0);
}
/// Gets the array type of `elementCount` elements of `elementType` with the
/// given `stride`. `elementCount` must be non-zero.
ArrayType ArrayType::get(Type elementType, unsigned elementCount,
                         unsigned stride) {
  assert(elementCount && "ArrayType needs at least one element");
  auto *context = elementType.getContext();
  return Base::get(context, elementType, elementCount, stride);
}
/// Returns the element count stored for this array type.
unsigned ArrayType::getNumElements() const {
  const auto *impl = getImpl();
  return impl->elementCount;
}

/// Returns the element type stored for this array type.
Type ArrayType::getElementType() const {
  const auto *impl = getImpl();
  return impl->elementType;
}

/// Returns the stride stored for this array type (0 when the type was built
/// without an explicit stride).
unsigned ArrayType::getArrayStride() const {
  const auto *impl = getImpl();
  return impl->stride;
}
/// Appends to `extensions` whatever the element type requires for `storage`;
/// the array itself adds nothing.
void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                              Optional<StorageClass> storage) {
  auto element = getElementType().cast<SPIRVType>();
  element.getExtensions(extensions, storage);
}
/// Appends to `capabilities` whatever the element type requires for
/// `storage`; the array itself adds nothing.
void ArrayType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities,
    Optional<StorageClass> storage) {
  auto element = getElementType().cast<SPIRVType>();
  element.getCapabilities(capabilities, storage);
}
/// Returns `(element size + stride) * element count`, or llvm::None when the
/// element type has no known byte size.
// NOTE(review): the stride is added to the element size rather than
// replacing it; confirm this is the intended size model for strided arrays.
Optional<int64_t> ArrayType::getSizeInBytes() {
  Optional<int64_t> elementSize =
      getElementType().cast<SPIRVType>().getSizeInBytes();
  if (elementSize)
    return (*elementSize + getArrayStride()) * getNumElements();
  return llvm::None;
}
//===----------------------------------------------------------------------===//
// CompositeType
//===----------------------------------------------------------------------===//
/// Returns true if `type` is a composite type: either a builtin vector that
/// passes isValid(), or one of the SPIR-V dialect aggregate types.
bool CompositeType::classof(Type type) {
  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return type.isa<spirv::ArrayType, spirv::CooperativeMatrixNVType,
                    spirv::MatrixType, spirv::RuntimeArrayType,
                    spirv::StructType>();
  return isValid(vectorType);
}
/// Returns true if `type` is a rank-1 vector of 2, 3, 4, 8, or 16 elements
/// whose element type is a ScalarType.
bool CompositeType::isValid(VectorType type) {
  const auto numElements = type.getNumElements();
  const bool hasSupportedSize = numElements == 2 || numElements == 3 ||
                                numElements == 4 || numElements == 8 ||
                                numElements == 16;
  if (!hasSupportedSize)
    return false;
  return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
}
/// Returns the type of the element at `index`. Only struct types consult
/// `index`; arrays, runtime arrays, cooperative matrices and vectors return
/// their single element type, and matrices return their column type.
Type CompositeType::getElementType(unsigned index) const {
  if (auto arrayType = dyn_cast<ArrayType>())
    return arrayType.getElementType();
  if (auto coopMatrixType = dyn_cast<CooperativeMatrixNVType>())
    return coopMatrixType.getElementType();
  if (auto runtimeArrayType = dyn_cast<RuntimeArrayType>())
    return runtimeArrayType.getElementType();
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getElementType();
  if (auto matrixType = dyn_cast<MatrixType>())
    return matrixType.getColumnType();
  if (auto structType = dyn_cast<StructType>())
    return structType.getElementType(index);
  llvm_unreachable("invalid composite type");
}
unsigned CompositeType::getNumElements() const {
if (auto arrayType = dyn_cast<ArrayType>())
return arrayType.getNumElements();
if (auto matrixType = dyn_cast<MatrixType>())
return matrixType.getNumColumns();
if (auto structType = dyn_cast<StructType>())
return structType.getNumElements();
if (auto vectorType = dyn_cast<VectorType>())
return vectorType.getNumElements();
if (isa<CooperativeMatrixNVType>()) {
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
}
if (isa<RuntimeArrayType>()) {
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
}
llvm_unreachable("invalid composite type");
}
/// Returns false for cooperative matrix and runtime array types, and true
/// for every other composite type.
bool CompositeType::hasCompileTimeKnownNumElements() const {
  if (isa<CooperativeMatrixNVType, RuntimeArrayType>())
    return false;
  return true;
}
/// Dispatches to the concrete composite type's getExtensions(). Builtin
/// vectors forward to their scalar element type.
void CompositeType::getExtensions(
    SPIRVType::ExtensionArrayRefVector &extensions,
    Optional<StorageClass> storage) {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getElementType().cast<ScalarType>().getExtensions(
        extensions, storage);
  if (auto arrayType = dyn_cast<ArrayType>())
    return arrayType.getExtensions(extensions, storage);
  if (auto coopMatrixType = dyn_cast<CooperativeMatrixNVType>())
    return coopMatrixType.getExtensions(extensions, storage);
  if (auto matrixType = dyn_cast<MatrixType>())
    return matrixType.getExtensions(extensions, storage);
  if (auto runtimeArrayType = dyn_cast<RuntimeArrayType>())
    return runtimeArrayType.getExtensions(extensions, storage);
  if (auto structType = dyn_cast<StructType>())
    return structType.getExtensions(extensions, storage);
  llvm_unreachable("invalid composite type");
}
/// Dispatches to the concrete composite type's getCapabilities(). Builtin
/// vectors of 8 or 16 elements additionally require Vector16, then forward
/// to their scalar element type.
void CompositeType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities,
    Optional<StorageClass> storage) {
  if (auto vectorType = dyn_cast<VectorType>()) {
    auto numLanes = getNumElements();
    if (numLanes == 8 || numLanes == 16) {
      static const Capability caps[] = {Capability::Vector16};
      capabilities.push_back(
          ArrayRef<Capability>(caps, llvm::array_lengthof(caps)));
    }
    return vectorType.getElementType().cast<ScalarType>().getCapabilities(
        capabilities, storage);
  }
  if (auto arrayType = dyn_cast<ArrayType>())
    return arrayType.getCapabilities(capabilities, storage);
  if (auto coopMatrixType = dyn_cast<CooperativeMatrixNVType>())
    return coopMatrixType.getCapabilities(capabilities, storage);
  if (auto matrixType = dyn_cast<MatrixType>())
    return matrixType.getCapabilities(capabilities, storage);
  if (auto runtimeArrayType = dyn_cast<RuntimeArrayType>())
    return runtimeArrayType.getCapabilities(capabilities, storage);
  if (auto structType = dyn_cast<StructType>())
    return structType.getCapabilities(capabilities, storage);
  llvm_unreachable("invalid composite type");
}
/// Returns the byte size for array, struct, and vector composites, or
/// llvm::None for every other kind (and for vectors whose element type has
/// no known size).
Optional<int64_t> CompositeType::getSizeInBytes() {
  if (auto structTy = dyn_cast<StructType>())
    return structTy.getSizeInBytes();
  if (auto arrayTy = dyn_cast<ArrayType>())
    return arrayTy.getSizeInBytes();

  auto vectorTy = dyn_cast<VectorType>();
  if (!vectorTy)
    return llvm::None;

  // Vector size is element size times element count.
  Optional<int64_t> scalarSize =
      vectorTy.getElementType().cast<ScalarType>().getSizeInBytes();
  if (!scalarSize)
    return llvm::None;
  return *scalarSize * vectorTy.getNumElements();
}
//===----------------------------------------------------------------------===//
// CooperativeMatrixType
//===----------------------------------------------------------------------===//
/// Storage for spirv::CooperativeMatrixNVType. The key is the tuple
/// (element type, scope, rows, columns).
struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
  using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;

  /// Unpacks `key` into the stored fields. Note that the member order differs
  /// from the key order: scope is key component 1 but is stored last.
  CooperativeMatrixTypeStorage(const KeyTy &key)
      : elementType(std::get<0>(key)), rows(std::get<2>(key)),
        columns(std::get<3>(key)), scope(std::get<1>(key)) {}

  /// Placement-constructs a storage object in memory handed out by
  /// `allocator`.
  static CooperativeMatrixTypeStorage *
  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
    auto *memory = allocator.allocate<CooperativeMatrixTypeStorage>();
    return new (memory) CooperativeMatrixTypeStorage(key);
  }

  /// Returns true when every component of `key` equals the stored field.
  bool operator==(const KeyTy &key) const {
    return std::get<0>(key) == elementType && std::get<1>(key) == scope &&
           std::get<2>(key) == rows && std::get<3>(key) == columns;
  }

  Type elementType;
  unsigned rows;
  unsigned columns;
  Scope scope;
};
/// Gets the cooperative matrix type with the given element type, scope, and
/// dimensions, in the context of `elementType`.
CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
                                                     Scope scope, unsigned rows,
                                                     unsigned columns) {
  auto *context = elementType.getContext();
  return Base::get(context, elementType, scope, rows, columns);
}
/// Returns the stored element type.
Type CooperativeMatrixNVType::getElementType() const {
  const auto *impl = getImpl();
  return impl->elementType;
}

/// Returns the stored scope.
Scope CooperativeMatrixNVType::getScope() const {
  const auto *impl = getImpl();
  return impl->scope;
}

/// Returns the stored row count.
unsigned CooperativeMatrixNVType::getRows() const {
  const auto *impl = getImpl();
  return impl->rows;
}

/// Returns the stored column count.
unsigned CooperativeMatrixNVType::getColumns() const {
  const auto *impl = getImpl();
  return impl->columns;
}
/// Appends the element type's extension requirements, followed by
/// SPV_NV_cooperative_matrix.
void CooperativeMatrixNVType::getExtensions(
    SPIRVType::ExtensionArrayRefVector &extensions,
    Optional<StorageClass> storage) {
  auto element = getElementType().cast<SPIRVType>();
  element.getExtensions(extensions, storage);

  static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
  extensions.push_back(ArrayRef<Extension>(exts, llvm::array_lengthof(exts)));
}
/// Appends the element type's capability requirements, followed by
/// CooperativeMatrixNV.
void CooperativeMatrixNVType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities,
    Optional<StorageClass> storage) {
  auto element = getElementType().cast<SPIRVType>();
  element.getCapabilities(capabilities, storage);

  static const Capability caps[] = {Capability::CooperativeMatrixNV};
  capabilities.push_back(
      ArrayRef<Capability>(caps, llvm::array_lengthof(caps)));
}
//===----------------------------------------------------------------------===//
// ImageType
//===----------------------------------------------------------------------===//
/// Returns the bit-field width used to store an enum of type `T`. The
/// primary template yields 0; each specialization below returns a fixed
/// width and statically checks that the enum's maximum value fits in it.
template <typename T> static constexpr unsigned getNumBits() { return 0; }

template <> constexpr unsigned getNumBits<Dim>() {
  constexpr unsigned numBits = 3;
  static_assert((1 << numBits) > getMaxEnumValForDim(),
                "Not enough bits to encode Dim value");
  return numBits;
}

template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
  constexpr unsigned numBits = 2;
  static_assert((1 << numBits) > getMaxEnumValForImageDepthInfo(),
                "Not enough bits to encode ImageDepthInfo value");
  return numBits;
}

template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
  constexpr unsigned numBits = 1;
  static_assert((1 << numBits) > getMaxEnumValForImageArrayedInfo(),
                "Not enough bits to encode ImageArrayedInfo value");
  return numBits;
}

template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
  constexpr unsigned numBits = 1;
  static_assert((1 << numBits) > getMaxEnumValForImageSamplingInfo(),
                "Not enough bits to encode ImageSamplingInfo value");
  return numBits;
}

template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
  constexpr unsigned numBits = 2;
  static_assert((1 << numBits) > getMaxEnumValForImageSamplerUseInfo(),
                "Not enough bits to encode ImageSamplerUseInfo value");
  return numBits;
}

template <> constexpr unsigned getNumBits<ImageFormat>() {
  constexpr unsigned numBits = 6;
  static_assert((1 << numBits) > getMaxEnumValForImageFormat(),
                "Not enough bits to encode ImageFormat value");
  return numBits;
}
/// Storage for spirv::ImageType. The key is the tuple (element type, dim,
/// depth info, arrayed info, sampling info, sampler-use info, format); the
/// enum members are packed into bit-fields sized by getNumBits<>().
struct spirv::detail::ImageTypeStorage : public TypeStorage {
  using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                           ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;

  /// Unpacks `key` into the stored fields, in key order.
  ImageTypeStorage(const KeyTy &key)
      : elementType(std::get<0>(key)), dim(std::get<1>(key)),
        depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
        samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
        format(std::get<6>(key)) {}

  /// Placement-constructs a storage object in memory handed out by
  /// `allocator`.
  static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
                                     const KeyTy &key) {
    auto *memory = allocator.allocate<ImageTypeStorage>();
    return new (memory) ImageTypeStorage(key);
  }

  /// Returns true when every component of `key` equals the stored field.
  bool operator==(const KeyTy &key) const {
    return std::get<0>(key) == elementType && std::get<1>(key) == dim &&
           std::get<2>(key) == depthInfo && std::get<3>(key) == arrayedInfo &&
           std::get<4>(key) == samplingInfo &&
           std::get<5>(key) == samplerUseInfo && std::get<6>(key) == format;
  }

  Type elementType;
  Dim dim : getNumBits<Dim>();
  ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
  ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
  ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
  ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
  ImageFormat format : getNumBits<ImageFormat>();
};
/// Gets the image type described by `value`, in the context of the tuple's
/// first component (the element type).
ImageType
ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                          ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
                   value) {
  auto *context = std::get<0>(value).getContext();
  return Base::get(context, value);
}
/// Returns the stored element type.
Type ImageType::getElementType() const {
  const auto *impl = getImpl();
  return impl->elementType;
}

/// Returns the stored dimensionality.
Dim ImageType::getDim() const {
  const auto *impl = getImpl();
  return impl->dim;
}

/// Returns the stored depth info.
ImageDepthInfo ImageType::getDepthInfo() const {
  const auto *impl = getImpl();
  return impl->depthInfo;
}

/// Returns the stored arrayed info.
ImageArrayedInfo ImageType::getArrayedInfo() const {
  const auto *impl = getImpl();
  return impl->arrayedInfo;
}

/// Returns the stored sampling info.
ImageSamplingInfo ImageType::getSamplingInfo() const {
  const auto *impl = getImpl();
  return impl->samplingInfo;
}

/// Returns the stored sampler-use info.
ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
  const auto *impl = getImpl();
  return impl->samplerUseInfo;
}

/// Returns the stored image format.
ImageFormat ImageType::getImageFormat() const {
  const auto *impl = getImpl();
  return impl->format;
}
/// Intentionally a no-op: both parameters are unnamed and unused, and nothing
/// is appended to the extension list.
void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
Optional<StorageClass>) {
// Image types do not require extra extensions thus far.
}
/// Appends the capabilities implied by this image's dimensionality and then
/// by its format, when spirv::getCapabilities() reports any. The storage
/// class argument is ignored.
void ImageType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities, Optional<StorageClass>) {
  auto dimCaps = spirv::getCapabilities(getDim());
  if (dimCaps)
    capabilities.push_back(*dimCaps);

  auto formatCaps = spirv::getCapabilities(getImageFormat());
  if (formatCaps)
    capabilities.push_back(*formatCaps);
}
//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
/// Storage for spirv::PointerType, keyed on (pointee type, storage class).
/// Both components are stored as members of this struct.
struct spirv::detail::PointerTypeStorage : public TypeStorage {
  using KeyTy = std::pair<Type, StorageClass>;

  /// Unpacks `key` into the stored fields.
  PointerTypeStorage(const KeyTy &key)
      : pointeeType(key.first), storageClass(key.second) {}

  /// Placement-constructs a storage object in memory handed out by
  /// `allocator`.
  static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
                                       const KeyTy &key) {
    auto *memory = allocator.allocate<PointerTypeStorage>();
    return new (memory) PointerTypeStorage(key);
  }

  /// Returns true when both components of `key` equal the stored fields.
  bool operator==(const KeyTy &key) const {
    return key.first == pointeeType && key.second == storageClass;
  }

  Type pointeeType;
  StorageClass storageClass;
};
/// Gets the pointer type to `pointeeType` in `storageClass`, in the context
/// of `pointeeType`.
PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
  auto *context = pointeeType.getContext();
  return Base::get(context, pointeeType, storageClass);
}
/// Returns the stored pointee type.
Type PointerType::getPointeeType() const {
  const auto *impl = getImpl();
  return impl->pointeeType;
}

/// Returns the stored storage class.
StorageClass PointerType::getStorageClass() const {
  const auto *impl = getImpl();
  return impl->storageClass;
}
/// Appends the pointee type's extension requirements, followed by those of
/// the storage class itself. The incoming `storage` argument is ignored: the
/// pointer's own storage class is what determines how the pointee is used.
void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                Optional<StorageClass> storage) {
  const StorageClass ownStorage = getStorageClass();
  auto pointee = getPointeeType().cast<SPIRVType>();
  pointee.getExtensions(extensions, ownStorage);

  auto storageClassExts = spirv::getExtensions(ownStorage);
  if (storageClassExts)
    extensions.push_back(*storageClassExts);
}
/// Appends the pointee type's capability requirements, followed by those of
/// the storage class itself. The incoming `storage` argument is ignored: the
/// pointer's own storage class is what determines how the pointee is used.
void PointerType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities,
    Optional<StorageClass> storage) {
  const StorageClass ownStorage = getStorageClass();
  auto pointee = getPointeeType().cast<SPIRVType>();
  pointee.getCapabilities(capabilities, ownStorage);

  auto storageClassCaps = spirv::getCapabilities(ownStorage);
  if (storageClassCaps)
    capabilities.push_back(*storageClassCaps);
}
//===----------------------------------------------------------------------===//
// RuntimeArrayType
//===----------------------------------------------------------------------===//
/// Storage for spirv::RuntimeArrayType, keyed on (element type, stride).
struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
  using KeyTy = std::pair<Type, unsigned>;

  /// Unpacks `key` into the stored fields.
  RuntimeArrayTypeStorage(const KeyTy &key)
      : elementType(key.first), stride(key.second) {}

  /// Placement-constructs a storage object in memory handed out by
  /// `allocator`.
  static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
                                            const KeyTy &key) {
    auto *memory = allocator.allocate<RuntimeArrayTypeStorage>();
    return new (memory) RuntimeArrayTypeStorage(key);
  }

  /// Returns true when both components of `key` equal the stored fields.
  bool operator==(const KeyTy &key) const {
    return key.first == elementType && key.second == stride;
  }

  Type elementType;
  unsigned stride;
};
/// Gets the runtime array type of `elementType` with a stride of 0.
RuntimeArrayType RuntimeArrayType::get(Type elementType) {
  auto *context = elementType.getContext();
  return Base::get(context, elementType, /*stride=*/0);
}
/// Gets the runtime array type of `elementType` with the given `stride`.
RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
  auto *context = elementType.getContext();
  return Base::get(context, elementType, stride);
}
/// Returns the stored element type.
Type RuntimeArrayType::getElementType() const {
  const auto *impl = getImpl();
  return impl->elementType;
}

/// Returns the stored stride.
unsigned RuntimeArrayType::getArrayStride() const {
  const auto *impl = getImpl();
  return impl->stride;
}
/// Appends whatever the element type requires for `storage`; the runtime
/// array itself adds no extension.
void RuntimeArrayType::getExtensions(
    SPIRVType::ExtensionArrayRefVector &extensions,
    Optional<StorageClass> storage) {
  auto element = getElementType().cast<SPIRVType>();
  element.getExtensions(extensions, storage);
}
/// Appends the Shader capability first, then whatever the element type
/// requires for `storage`.
void RuntimeArrayType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities,
    Optional<StorageClass> storage) {
  static const Capability caps[] = {Capability::Shader};
  capabilities.push_back(
      ArrayRef<Capability>(caps, llvm::array_lengthof(caps)));

  auto element = getElementType().cast<SPIRVType>();
  element.getCapabilities(capabilities, storage);
}
//===----------------------------------------------------------------------===//
// ScalarType
//===----------------------------------------------------------------------===//
/// Returns true if `type` is a builtin integer or float type accepted by the
/// corresponding isValid() overload.
bool ScalarType::classof(Type type) {
  if (auto intType = type.dyn_cast<IntegerType>())
    return isValid(intType);
  if (auto floatType = type.dyn_cast<FloatType>())
    return isValid(floatType);
  return false;
}
/// Accepts every float type except bf16.
bool ScalarType::isValid(FloatType type) {
  if (type.isBF16())
    return false;
  return true;
}
/// Accepts integer types of width 1, 8, 16, 32, or 64 bits.
bool ScalarType::isValid(IntegerType type) {
  const unsigned width = type.getWidth();
  return width == 1 || width == 8 || width == 16 || width == 32 ||
         width == 64;
}
/// Appends the storage extensions needed when an 8- or 16-bit scalar appears
/// in an interface storage class. See SPV_KHR_16bit_storage and
/// SPV_KHR_8bit_storage for more details. Nothing is appended when `storage`
/// is unset.
void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                               Optional<StorageClass> storage) {
  if (!storage)
    return;

  const StorageClass storageClass = *storage;
  // PushConstant, StorageBuffer and Uniform are subject to both the 8-bit and
  // the 16-bit checks; Input and Output only to the 16-bit one.
  const bool isBufferLike = storageClass == StorageClass::PushConstant ||
                            storageClass == StorageClass::StorageBuffer ||
                            storageClass == StorageClass::Uniform;
  const bool isInOut = storageClass == StorageClass::Input ||
                       storageClass == StorageClass::Output;
  const auto bitwidth = getIntOrFloatBitWidth();

  if (isBufferLike && bitwidth == 8) {
    static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
    extensions.push_back(
        ArrayRef<Extension>(exts, llvm::array_lengthof(exts)));
  }
  if ((isBufferLike || isInOut) && bitwidth == 16) {
    static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
    extensions.push_back(
        ArrayRef<Extension>(exts, llvm::array_lengthof(exts)));
  }
}
/// Appends the capabilities this scalar type requires.
///
/// When `storage` names PushConstant, StorageBuffer, Uniform, Input, or
/// Output, only the storage-specific 8-/16-bit capability (if any) is
/// appended and the function returns immediately. For any other or unset
/// storage class, a width-based capability (Int8/Int16/Int64 or
/// Float16/Float64) is appended; 1- and 32-bit integers and 32-bit floats
/// need none.
void ScalarType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
Optional<StorageClass> storage) {
unsigned bitwidth = getIntOrFloatBitWidth();
// 8- or 16-bit integer/floating-point numbers will require extra capabilities
// to appear in interface storage classes. See SPV_KHR_16bit_storage and
// SPV_KHR_8bit_storage for more details.
// STORAGE_CASE expands to one `case` that appends `cap8` for 8-bit scalars or
// `cap16` for 16-bit scalars, and then returns from the enclosing function.
#define STORAGE_CASE(storage, cap8, cap16) \
case StorageClass::storage: { \
if (bitwidth == 8) { \
static const Capability caps[] = {Capability::cap8}; \
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
capabilities.push_back(ref); \
} else if (bitwidth == 16) { \
static const Capability caps[] = {Capability::cap16}; \
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
capabilities.push_back(ref); \
} \
/* No requirements for other bitwidths */ \
return; \
}
// This part only handles the cases where special bitwidths appearing in
// interface storage classes.
if (storage) {
switch (*storage) {
STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
StorageBuffer16BitAccess);
STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
StorageUniform16);
// Input/Output only have a 16-bit storage capability; other widths add
// nothing, but the function still returns here.
case StorageClass::Input:
case StorageClass::Output: {
if (bitwidth == 16) {
static const Capability caps[] = {Capability::StorageInputOutput16};
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
capabilities.push_back(ref);
}
return;
}
default:
break;
}
}
#undef STORAGE_CASE
// For other non-interface storage classes, require a different set of
// capabilities for special bitwidths.
// WIDTH_CASE(type, width) expands to a `case width:` that appends
// Capability::<type><width> (e.g. Int8, Float16) and then breaks.
#define WIDTH_CASE(type, width) \
case width: { \
static const Capability caps[] = {Capability::type##width}; \
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
capabilities.push_back(ref); \
} break
if (auto intType = dyn_cast<IntegerType>()) {
switch (bitwidth) {
// 32-bit and 1-bit integers need no extra capability.
case 32:
case 1:
break;
WIDTH_CASE(Int, 8);
WIDTH_CASE(Int, 16);
WIDTH_CASE(Int, 64);
default:
llvm_unreachable("invalid bitwidth to getCapabilities");
}
} else {
assert(isa<FloatType>());
switch (bitwidth) {
// 32-bit floats need no extra capability.
case 32:
break;
WIDTH_CASE(Float, 16);
WIDTH_CASE(Float, 64);
default:
llvm_unreachable("invalid bitwidth to getCapabilities");
}
}
#undef WIDTH_CASE
}
/// Returns the bit width divided by 8, or llvm::None for 1-bit (boolean)
/// scalars.
Optional<int64_t> ScalarType::getSizeInBytes() {
  auto widthInBits = getIntOrFloatBitWidth();
  // Booleans have no size. According to the SPIR-V spec:
  // "There is no physical size or bit pattern defined for values with boolean
  // type. If they are stored (in conjunction with OpVariable), they can only
  // be used with logical addressing operations, not physical, and only with
  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
  // Private, Function, Input, and Output."
  if (widthInBits != 1)
    return widthInBits / 8;
  return llvm::None;
}
//===----------------------------------------------------------------------===//
// SPIRVType
//===----------------------------------------------------------------------===//
/// Returns true if `type` belongs to the SPIR-V dialect, is a valid
/// ScalarType, or is a builtin vector accepted by CompositeType::isValid().
bool SPIRVType::classof(Type type) {
  const bool isDialectType = llvm::isa<SPIRVDialect>(type.getDialect());
  if (isDialectType || type.isa<ScalarType>())
    return true;

  auto vectorType = type.dyn_cast<VectorType>();
  if (!vectorType)
    return false;
  return CompositeType::isValid(vectorType);
}
/// Returns true if this type is an integer/float scalar or a builtin vector.
bool SPIRVType::isScalarOrVector() {
  if (isIntOrFloat())
    return true;
  return isa<VectorType>();
}
/// Forwards to getExtensions() of the first matching concrete type category,
/// tested in this order: scalar, composite, image, sampled image, matrix,
/// pointer. Any other type is a programming error.
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                              Optional<StorageClass> storage) {
  if (auto scalarType = dyn_cast<ScalarType>())
    return scalarType.getExtensions(extensions, storage);
  if (auto compositeType = dyn_cast<CompositeType>())
    return compositeType.getExtensions(extensions, storage);
  if (auto imageType = dyn_cast<ImageType>())
    return imageType.getExtensions(extensions, storage);
  if (auto sampledImageType = dyn_cast<SampledImageType>())
    return sampledImageType.getExtensions(extensions, storage);
  if (auto matrixType = dyn_cast<MatrixType>())
    return matrixType.getExtensions(extensions, storage);
  if (auto ptrType = dyn_cast<PointerType>())
    return ptrType.getExtensions(extensions, storage);
  llvm_unreachable("invalid SPIR-V Type to getExtensions");
}
/// Forwards to getCapabilities() of the first matching concrete type
/// category, tested in this order: scalar, composite, image, sampled image,
/// matrix, pointer. Any other type is a programming error.
void SPIRVType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities,
    Optional<StorageClass> storage) {
  if (auto scalarType = dyn_cast<ScalarType>())
    return scalarType.getCapabilities(capabilities, storage);
  if (auto compositeType = dyn_cast<CompositeType>())
    return compositeType.getCapabilities(capabilities, storage);
  if (auto imageType = dyn_cast<ImageType>())
    return imageType.getCapabilities(capabilities, storage);
  if (auto sampledImageType = dyn_cast<SampledImageType>())
    return sampledImageType.getCapabilities(capabilities, storage);
  if (auto matrixType = dyn_cast<MatrixType>())
    return matrixType.getCapabilities(capabilities, storage);
  if (auto ptrType = dyn_cast<PointerType>())
    return ptrType.getCapabilities(capabilities, storage);
  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
}
/// Returns the byte size reported by the scalar or composite view of this
/// type, or llvm::None for any other kind of type.
Optional<int64_t> SPIRVType::getSizeInBytes() {
  auto scalarType = dyn_cast<ScalarType>();
  if (scalarType)
    return scalarType.getSizeInBytes();

  auto compositeType = dyn_cast<CompositeType>();
  if (compositeType)
    return compositeType.getSizeInBytes();

  return llvm::None;
}
//===----------------------------------------------------------------------===//
// SampledImageType
//===----------------------------------------------------------------------===//
/// Storage for spirv::SampledImageType, keyed on the wrapped image type.
struct spirv::detail::SampledImageTypeStorage : public TypeStorage {
  using KeyTy = Type;

  SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}

  /// Placement-constructs a storage object in memory handed out by
  /// `allocator`.
  static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator,
                                            const KeyTy &key) {
    auto *memory = allocator.allocate<SampledImageTypeStorage>();
    return new (memory) SampledImageTypeStorage(key);
  }

  /// Returns true when `key` equals the stored image type.
  bool operator==(const KeyTy &key) const { return key == imageType; }

  Type imageType;
};
/// Gets the sampled image type wrapping `imageType`, in its context.
SampledImageType SampledImageType::get(Type imageType) {
  auto *context = imageType.getContext();
  return Base::get(context, imageType);
}
/// Same as get(), but goes through Base::getChecked with `emitError` so
/// diagnostics can be reported through it.
SampledImageType
SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                             Type imageType) {
  auto *context = imageType.getContext();
  return Base::getChecked(emitError, context, imageType);
}
/// Returns the stored image type.
Type SampledImageType::getImageType() const {
  const auto *impl = getImpl();
  return impl->imageType;
}
/// Succeeds only when `imageType` is a spirv::ImageType; otherwise emits
/// "expected image type" through `emitError`.
LogicalResult
SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
                         Type imageType) {
  if (imageType.isa<ImageType>())
    return success();
  return emitError() << "expected image type";
}
/// Forwards to the wrapped image type's getExtensions().
void SampledImageType::getExtensions(
    SPIRVType::ExtensionArrayRefVector &extensions,
    Optional<StorageClass> storage) {
  auto image = getImageType().cast<ImageType>();
  image.getExtensions(extensions, storage);
}
/// Forwards to the wrapped image type's getCapabilities().
void SampledImageType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities,
    Optional<StorageClass> storage) {
  auto image = getImageType().cast<ImageType>();
  image.getCapabilities(capabilities, storage);
}
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//
/// Type storage for SPIR-V structure types:
///
/// Structures are uniqued using:
/// - for identified structs:
///   - a string identifier;
/// - for literal structs:
///   - a list of member types;
///   - a list of member offset info;
///   - a list of member decoration info.
///
/// Identified structures only have a mutable component consisting of:
/// - a list of member types;
/// - a list of member offset info;
/// - a list of member decoration info.
struct spirv::detail::StructTypeStorage : public TypeStorage {
  /// Construct a storage object for an identified struct type. A struct type
  /// associated with such storage must call StructType::trySetBody(...) later
  /// in order to mutate the storage object providing the actual content.
  ///
  /// Every field is given a defined value here. In particular `numMembers`
  /// starts at 0 so that getMemberTypes(), getOffsetInfo() and
  /// StructType::getNumElements() are well-defined before the body is set.
  StructTypeStorage(StringRef identifier)
      : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
        numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
        identifier(identifier) {}

  /// Construct a storage object for a literal struct type. A struct type
  /// associated with such storage is immutable.
  StructTypeStorage(
      unsigned numMembers, Type const *memberTypes,
      StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
      StructType::MemberDecorationInfo const *memberDecorationsInfo)
      : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
        numMembers(numMembers), numMemberDecorations(numMemberDecorations),
        memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()) {}

  /// A storage key is divided into 2 parts:
  /// - for identified structs:
  ///   - a StringRef representing the struct identifier;
  /// - for literal structs:
  ///   - an ArrayRef<Type> for member types;
  ///   - an ArrayRef<StructType::OffsetInfo> for member offset info;
  ///   - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
  ///     info.
  ///
  /// An identified struct type is uniqued only by the first part (field 0)
  /// of the key.
  ///
  /// A literal struct type is uniqued only by the second part (fields 1, 2, and
  /// 3) of the key. The identifier field (field 0) must be empty.
  using KeyTy =
      std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
                 ArrayRef<StructType::MemberDecorationInfo>>;

  /// For identified structs, return true if the given key contains the same
  /// identifier.
  ///
  /// For literal structs, return true if the given key contains a matching list
  /// of member types + offset info + decoration info.
  bool operator==(const KeyTy &key) const {
    if (isIdentified()) {
      // Identified types are uniqued by their identifier.
      return getIdentifier() == std::get<0>(key);
    }
    return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
                        getMemberDecorationsInfo());
  }

  /// If the given key contains a non-empty identifier, this method constructs
  /// an identified struct and leaves the rest of the struct type data to be set
  /// through a later call to StructType::trySetBody(...).
  ///
  /// If, on the other hand, the key contains an empty identifier, a literal
  /// struct is constructed using the other fields of the key.
  static StructTypeStorage *construct(TypeStorageAllocator &allocator,
                                      const KeyTy &key) {
    StringRef keyIdentifier = std::get<0>(key);
    if (!keyIdentifier.empty()) {
      StringRef identifier = allocator.copyInto(keyIdentifier);
      // Identified StructType body/members will be set through trySetBody(...)
      // later.
      return new (allocator.allocate<StructTypeStorage>())
          StructTypeStorage(identifier);
    }

    ArrayRef<Type> keyTypes = std::get<1>(key);
    // Copy the member type and layout information into the bump pointer
    const Type *typesList = nullptr;
    if (!keyTypes.empty()) {
      typesList = allocator.copyInto(keyTypes).data();
    }

    const StructType::OffsetInfo *offsetInfoList = nullptr;
    if (!std::get<2>(key).empty()) {
      ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
      assert(keyOffsetInfo.size() == keyTypes.size() &&
             "size of offset information must be same as the size of number of "
             "elements");
      offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
    }

    const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
    unsigned numMemberDecorations = 0;
    if (!std::get<3>(key).empty()) {
      auto keyMemberDecorations = std::get<3>(key);
      numMemberDecorations = keyMemberDecorations.size();
      memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
    }

    return new (allocator.allocate<StructTypeStorage>())
        StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
                          numMemberDecorations, memberDecorationList);
  }

  /// Returns the member types as a view of `numMembers` entries.
  ArrayRef<Type> getMemberTypes() const {
    return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
  }

  /// Returns the per-member offsets, or an empty list when none were provided.
  ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
    if (offsetInfo) {
      return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
    }
    return {};
  }

  /// Returns the member decorations, or an empty list when none were provided.
  ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
    if (memberDecorationsInfo) {
      return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
                                                        numMemberDecorations);
    }
    return {};
  }

  StringRef getIdentifier() const { return identifier; }

  /// A struct is identified exactly when its identifier is non-empty.
  bool isIdentified() const { return !identifier.empty(); }

  /// Sets the struct type content for identified structs. Calling this method
  /// is only valid for identified structs.
  ///
  /// Fails under the following conditions:
  /// - If called for a literal struct;
  /// - If called for an identified struct whose body was set before (through a
  ///   call to this method) but with different contents from the passed
  ///   arguments.
  LogicalResult mutate(
      TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
      ArrayRef<StructType::OffsetInfo> structOffsetInfo,
      ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
    if (!isIdentified())
      return failure();

    if (memberTypesAndIsBodySet.getInt() &&
        (getMemberTypes() != structMemberTypes ||
         getOffsetInfo() != structOffsetInfo ||
         getMemberDecorationsInfo() != structMemberDecorationInfo))
      return failure();

    memberTypesAndIsBodySet.setInt(true);
    numMembers = structMemberTypes.size();

    // Copy the member type and layout information into the bump pointer.
    if (!structMemberTypes.empty())
      memberTypesAndIsBodySet.setPointer(
          allocator.copyInto(structMemberTypes).data());

    if (!structOffsetInfo.empty()) {
      assert(structOffsetInfo.size() == structMemberTypes.size() &&
             "size of offset information must be same as the size of number of "
             "elements");
      offsetInfo = allocator.copyInto(structOffsetInfo).data();
    }

    if (!structMemberDecorationInfo.empty()) {
      numMemberDecorations = structMemberDecorationInfo.size();
      memberDecorationsInfo =
          allocator.copyInto(structMemberDecorationInfo).data();
    }

    return success();
  }

  /// Pointer to the member type list, plus a flag recording whether mutate()
  /// has set the body.
  llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
  StructType::OffsetInfo const *offsetInfo;
  unsigned numMembers;
  unsigned numMemberDecorations;
  StructType::MemberDecorationInfo const *memberDecorationsInfo;
  StringRef identifier;
};
/// Gets the literal struct type with the given member types, offsets, and
/// member decorations. `memberTypes` must be non-empty; its first entry
/// supplies the context. The decorations are sorted before being used as
/// part of the key, so callers may pass them in any order.
StructType
StructType::get(ArrayRef<Type> memberTypes,
                ArrayRef<StructType::OffsetInfo> offsetInfo,
                ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
  assert(!memberTypes.empty() && "Struct needs at least one member type");
  // Sort the decorations.
  SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
      memberDecorations.begin(), memberDecorations.end());
  llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
  // Read the first member directly from the ArrayRef; copying the whole list
  // into a std::vector just to call front() is unnecessary.
  return Base::get(memberTypes.front().getContext(),
                   /*identifier=*/StringRef(), memberTypes, offsetInfo,
                   sortedDecorations);
}
/// Gets the identified struct type named `identifier` in `context`. The key
/// carries empty member, offset, and decoration lists. `identifier` must be
/// non-empty.
StructType StructType::getIdentified(MLIRContext *context,
                                     StringRef identifier) {
  assert(!identifier.empty() &&
         "StructType identifier must be non-empty string");

  return Base::get(context, identifier, ArrayRef<Type>(),
                   ArrayRef<StructType::OffsetInfo>(),
                   ArrayRef<StructType::MemberDecorationInfo>());
}
/// Gets a struct type with no members. When `identifier` is non-empty the
/// resulting identified struct also has an empty body set on it; if setting
/// that body fails, a null StructType is returned.
StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
  StructType newStructType = Base::get(
      context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
      ArrayRef<StructType::MemberDecorationInfo>());

  // Literal structs are complete as constructed.
  if (!newStructType.isIdentified())
    return newStructType;

  // Set an empty body in case this is a identified struct.
  LogicalResult bodyResult = newStructType.trySetBody(
      ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
      ArrayRef<StructType::MemberDecorationInfo>());
  if (failed(bodyResult))
    return StructType();
  return newStructType;
}
/// Returns the identifier recorded in this struct type's storage.
StringRef StructType::getIdentifier() const {
  auto *storage = getImpl();
  return storage->identifier;
}
/// Returns whatever the underlying storage reports from `isIdentified()`.
bool StructType::isIdentified() const {
  auto *storage = getImpl();
  return storage->isIdentified();
}
/// Returns the member count recorded in this struct type's storage.
unsigned StructType::getNumElements() const {
  auto *storage = getImpl();
  return storage->numMembers;
}
/// Returns the type of the member at position `index`, which must be smaller
/// than `getNumElements()`.
Type StructType::getElementType(unsigned index) const {
  assert(getNumElements() > index && "member index out of range");

  // The member types live in an array addressed by the pointer half of the
  // pointer/int pair held in the storage.
  Type const *memberTypes = getImpl()->memberTypesAndIsBodySet.getPointer();
  return memberTypes[index];
}
/// Returns a range over the member types, built from the stored member-type
/// pointer and `getNumElements()`.
StructType::ElementTypeRange StructType::getElementTypes() const {
  Type const *firstMemberType =
      getImpl()->memberTypesAndIsBodySet.getPointer();
  const unsigned memberCount = getNumElements();
  return ElementTypeRange(firstMemberType, memberCount);
}
/// Returns true if the storage holds a non-null offset array.
bool StructType::hasOffset() const {
  return getImpl()->offsetInfo != nullptr;
}
uint64_t StructType::getMemberOffset(unsigned index) const {
assert(getNumElements() > index && "member index out of range");
return getImpl()->offsetInfo[index];
}
/// Replaces the contents of `memberDecorations` with every member decoration
/// stored for this struct type.
void StructType::getMemberDecorations(
    SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
    const {
  auto storedDecorations = getImpl()->getMemberDecorationsInfo();
  // `assign` drops the previous contents and then copies in the new range.
  memberDecorations.assign(storedDecorations.begin(),
                           storedDecorations.end());
}
/// Replaces the contents of `decorationsInfo` with the stored decorations
/// whose `memberIndex` equals `index`. `index` must be smaller than
/// `getNumElements()`.
void StructType::getMemberDecorations(
    unsigned index,
    SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
  assert(getNumElements() > index && "member index out of range");

  decorationsInfo.clear();
  for (const auto &candidate : getImpl()->getMemberDecorationsInfo()) {
    // The decorations are stored sorted, so once a later member is reached
    // there is nothing more to collect.
    if (candidate.memberIndex > index)
      break;
    if (candidate.memberIndex == index)
      decorationsInfo.push_back(candidate);
  }
}
/// Forwards the member types, offsets, and decorations to `Base::mutate` and
/// returns its result.
LogicalResult
StructType::trySetBody(ArrayRef<Type> memberTypes,
                       ArrayRef<OffsetInfo> offsetInfo,
                       ArrayRef<MemberDecorationInfo> memberDecorations) {
  LogicalResult mutateResult =
      Base::mutate(memberTypes, offsetInfo, memberDecorations);
  return mutateResult;
}
/// Forwards `extensions` and `storage` to `getExtensions` of every member
/// type, visiting the members in order.
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                               Optional<StorageClass> storage) {
  for (Type memberType : getElementTypes()) {
    auto spirvMember = memberType.cast<SPIRVType>();
    spirvMember.getExtensions(extensions, storage);
  }
}
/// Forwards `capabilities` and `storage` to `getCapabilities` of every member
/// type, visiting the members in order.
void StructType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities,
    Optional<StorageClass> storage) {
  for (Type memberType : getElementTypes()) {
    auto spirvMember = memberType.cast<SPIRVType>();
    spirvMember.getCapabilities(capabilities, storage);
  }
}
/// Hashes a member decoration by combining its member index and decoration.
llvm::hash_code spirv::hash_value(
    const StructType::MemberDecorationInfo &memberDecorationInfo) {
  // `auto` keeps the exact field types, so the combined hash is unchanged.
  const auto memberIndex = memberDecorationInfo.memberIndex;
  const auto decoration = memberDecorationInfo.decoration;
  return llvm::hash_combine(memberIndex, decoration);
}
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
/// Storage for a matrix type: the column type together with the number of
/// columns.
struct spirv::detail::MatrixTypeStorage : public TypeStorage {
  /// Lookup key: (column type, column count).
  using KeyTy = std::tuple<Type, uint32_t>;

  MatrixTypeStorage(Type colType, uint32_t numColumns)
      : TypeStorage(), columnType(colType), columnCount(numColumns) {}

  /// Builds a storage instance for `key` in memory obtained from `allocator`.
  static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
                                      const KeyTy &key) {
    Type colType = std::get<0>(key);
    uint32_t numColumns = std::get<1>(key);
    // Placement-new into the allocator-provided memory.
    auto *rawStorage = allocator.allocate<MatrixTypeStorage>();
    return new (rawStorage) MatrixTypeStorage(colType, numColumns);
  }

  /// Returns true if `key` names the same column type and column count.
  bool operator==(const KeyTy &key) const {
    return std::get<0>(key) == columnType && std::get<1>(key) == columnCount;
  }

  Type columnType;
  const uint32_t columnCount;
};
/// Returns the matrix type with the given column type and column count; the
/// context is taken from `columnType`.
MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
  MLIRContext *context = columnType.getContext();
  return Base::get(context, columnType, columnCount);
}
/// Same as `get`, but goes through `Base::getChecked` with `emitError`; the
/// context is taken from `columnType`.
MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                                  Type columnType, uint32_t columnCount) {
  MLIRContext *context = columnType.getContext();
  return Base::getChecked(emitError, context, columnType, columnCount);
}
/// Checks the invariants of a matrix type and reports the first violation
/// through `emitError`: the column count must be 2, 3, or 4; the column type
/// must be a vector of floats; and that vector must be 1-D with 2, 3, or 4
/// elements.
LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Type columnType, uint32_t columnCount) {
  const bool hasValidColumnCount = columnCount >= 2 && columnCount <= 4;
  if (!hasValidColumnCount)
    return emitError() << "matrix can have 2, 3, or 4 columns only";

  if (!isValidColumnType(columnType))
    return emitError() << "matrix columns must be vectors of floats";

  // The column type is known to be a vector at this point, so the cast below
  // is safe. The underlying vectors (columns) must be 1-D and of size 2, 3,
  // or 4.
  auto columnVector = columnType.cast<VectorType>();
  ArrayRef<int64_t> columnShape = columnVector.getShape();
  if (columnShape.size() != 1)
    return emitError() << "matrix columns must be 1D vectors";

  const int64_t rowCount = columnShape[0];
  if (rowCount < 2 || rowCount > 4)
    return emitError() << "matrix columns must be of size 2, 3, or 4";

  return success();
}
/// Returns true if `columnType` is a vector type whose element type is a float
/// type, i.e. it is acceptable as the column type of a matrix.
bool MatrixType::isValidColumnType(Type columnType) {
  auto columnVector = columnType.dyn_cast<VectorType>();
  return columnVector && columnVector.getElementType().isa<FloatType>();
}
/// Returns the column type recorded in this matrix type's storage.
Type MatrixType::getColumnType() const {
  auto *storage = getImpl();
  return storage->columnType;
}
/// Returns the element type of the column vector type.
Type MatrixType::getElementType() const {
  auto columnVector = getImpl()->columnType.cast<VectorType>();
  return columnVector.getElementType();
}
/// Returns the column count recorded in this matrix type's storage.
unsigned MatrixType::getNumColumns() const {
  auto *storage = getImpl();
  return storage->columnCount;
}
/// Returns the number of rows, i.e. the first (leading) dimension of the
/// column vector type's shape.
unsigned MatrixType::getNumRows() const {
  auto columnVector = getImpl()->columnType.cast<VectorType>();
  return columnVector.getShape()[0];
}
/// Returns the total element count: number of columns times number of rows.
unsigned MatrixType::getNumElements() const {
  const uint32_t columns = getImpl()->columnCount;
  const unsigned rows = getNumRows();
  return columns * rows;
}
/// Forwards `extensions` and `storage` to `getExtensions` of the column type.
void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                               Optional<StorageClass> storage) {
  auto spirvColumn = getColumnType().cast<SPIRVType>();
  spirvColumn.getExtensions(extensions, storage);
}
/// Pushes the `Matrix` capability onto `capabilities` and then forwards to
/// `getCapabilities` of the column type.
void MatrixType::getCapabilities(
    SPIRVType::CapabilityArrayRefVector &capabilities,
    Optional<StorageClass> storage) {
  // The array has static storage duration, so the ArrayRef pushed below stays
  // valid after this function returns.
  static const Capability matrixCaps[] = {Capability::Matrix};
  capabilities.push_back(ArrayRef<Capability>(matrixCaps));

  // Add any capabilities associated with the underlying vectors (i.e., columns)
  auto spirvColumn = getColumnType().cast<SPIRVType>();
  spirvColumn.getCapabilities(capabilities, storage);
}
//===----------------------------------------------------------------------===//
// SPIR-V Dialect
//===----------------------------------------------------------------------===//
/// Registers the SPIR-V dialect's types by passing the full list of type
/// classes to `addTypes`.
void SPIRVDialect::registerTypes() {
addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
PointerType, RuntimeArrayType, SampledImageType, StructType>();
}