[mlir][spirv] Add support for SPV_ARM_tensors (#144667)
This patch introduces a new custom type `!spirv.arm.tensor<>` to the
MLIR SPIR-V dialect to represent
`OpTypeTensorARM` as defined in the `SPV_ARM_tensors` extension.
The type models a shaped tensor with element type and optional shape,
and implements the
`ShapedType` interface to enable reuse of MLIR's existing shape-aware
infrastructure.
The type supports serialization to and from SPIR-V binary as
`OpTypeTensorARM`, and emits the
required capability (`TensorsARM`) and extension (`SPV_ARM_tensors`)
declarations automatically.
This addition lays the foundation for supporting structured tensor
values natively in SPIR-V and
will enable future support for operations defined in the
`SPV_ARM_tensors` extension, such as
`OpTensorReadARM`, `OpTensorWriteARM`, and `OpTensorQuerySizeARM`.
Reference: https://github.com/KhronosGroup/SPIRV-Registry/pull/342
---------
Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian@arm.com>
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d2ba76c..d874817 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -422,6 +422,8 @@
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
+def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+
def SPIRV_ExtensionAttr :
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
@@ -445,6 +447,7 @@
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
SPV_EXT_mesh_shader,
+ SPV_ARM_tensors,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1311,6 +1314,24 @@
def SPIRV_C_MultiViewport : I32EnumAttrCase<"MultiViewport", 57> {
list<I32EnumAttrCase> implies = [SPIRV_C_Geometry];
}
+def SPIRV_C_TensorsARM : I32EnumAttrCase<"TensorsARM", 4174> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_Int8];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_tensors]>
+ ];
+}
+def SPIRV_C_StorageTensorArrayDynamicIndexingEXT : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_tensors]>
+ ];
+}
+def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_tensors]>
+ ];
+}
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
list<Availability> availability = [
@@ -1523,6 +1544,8 @@
SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize,
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
+ SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
+ SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4179,7 +4202,7 @@
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
-
+def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories.
@@ -4217,6 +4240,8 @@
"any SPIR-V struct type">;
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
"any SPIR-V sampled image type">;
+def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
+ "any SPIR-V tensorArm type">;
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
@@ -4228,7 +4253,7 @@
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
- SPIRV_AnyImage
+ SPIRV_AnyImage, SPIRV_AnyTensorArm
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4525,6 +4550,7 @@
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4638,7 +4664,9 @@
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
- SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+ SPIRV_OC_OpGroupNonUniformLogicalXor,
+ SPIRV_OC_OpTypeTensorARM,
+ SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 787535d..212cba6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,6 +29,7 @@
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
+struct TensorArmTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
@@ -96,7 +97,8 @@
std::optional<int64_t> getSizeInBytes();
};
-// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
+// StructType, or SPIR-V TensorArmType.
class CompositeType : public SPIRVType {
public:
using SPIRVType::SPIRVType;
@@ -477,6 +479,46 @@
std::optional<StorageClass> storage = std::nullopt);
};
+/// SPIR-V TensorARM Type
+class TensorArmType
+ : public Type::TypeBase<TensorArmType, CompositeType,
+ detail::TensorArmTypeStorage, ShapedType::Trait> {
+public:
+ using Base::Base;
+
+ using ShapedTypeTraits = ShapedType::Trait<TensorArmType>;
+ using ShapedTypeTraits::getDimSize;
+ using ShapedTypeTraits::getDynamicDimIndex;
+ using ShapedTypeTraits::getElementTypeBitWidth;
+ using ShapedTypeTraits::getNumDynamicDims;
+ using ShapedTypeTraits::getNumElements;
+ using ShapedTypeTraits::getRank;
+ using ShapedTypeTraits::hasStaticShape;
+ using ShapedTypeTraits::isDynamicDim;
+
+ static constexpr StringLiteral name = "spirv.arm.tensor";
+
+ // TensorArm supports minimum rank of 1, hence an empty shape here means
+ // unranked.
+ static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
+ TensorArmType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const;
+
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType);
+
+ Type getElementType() const;
+ ArrayRef<int64_t> getShape() const;
+ bool hasRank() const { return !getShape().empty(); }
+ operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage = std::nullopt);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ std::optional<StorageClass> storage = std::nullopt);
+};
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a21acef..88c7adf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -194,6 +194,13 @@
<< t.getNumElements();
return Type();
}
+ } else if (auto t = dyn_cast<TensorArmType>(type)) {
+ if (!isa<ScalarType>(t.getElementType())) {
+ parser.emitError(
+ typeLoc, "only scalar element type allowed in tensor type but found ")
+ << t.getElementType();
+ return Type();
+ }
} else {
parser.emitError(typeLoc, "cannot use ")
<< type << " to compose SPIR-V types";
@@ -363,6 +370,52 @@
return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
}
+// tensor-arm-type ::=
+// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
+static Type parseTensorArmType(SPIRVDialect const &dialect,
+ DialectAsmParser &parser) {
+ if (parser.parseLess())
+ return {};
+
+ bool unranked = false;
+ SmallVector<int64_t, 4> dims;
+ SMLoc countLoc = parser.getCurrentLocation();
+
+ if (parser.parseOptionalStar().succeeded()) {
+ unranked = true;
+ if (parser.parseXInDimensionList())
+ return {};
+ } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
+ return {};
+ }
+
+ if (!unranked && dims.empty()) {
+ parser.emitError(countLoc, "arm.tensors do not support rank zero");
+ return {};
+ }
+
+ if (llvm::is_contained(dims, 0)) {
+ parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
+ return {};
+ }
+
+ if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
+ llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
+ parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
+ "fully dynamic or completed shaped");
+ return {};
+ }
+
+ auto elementTy = parseAndVerifyType(dialect, parser);
+ if (!elementTy)
+ return {};
+
+ if (parser.parseGreater())
+ return {};
+
+ return TensorArmType::get(dims, elementTy);
+}
+
// TODO: Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
@@ -759,6 +812,8 @@
return parseStructType(*this, parser);
if (keyword == "matrix")
return parseMatrixType(*this, parser);
+ if (keyword == "arm.tensor")
+ return parseTensorArmType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
return Type();
}
@@ -855,10 +910,28 @@
os << ">";
}
+static void print(TensorArmType type, DialectAsmPrinter &os) {
+ os << "arm.tensor<";
+
+ llvm::interleave(
+ type.getShape(), os,
+ [&](int64_t dim) {
+ if (ShapedType::isDynamic(dim))
+ os << '?';
+ else
+ os << dim;
+ },
+ "x");
+ if (!type.hasRank()) {
+ os << "*";
+ }
+ os << "x" << type.getElementType() << ">";
+}
+
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
- ImageType, SampledImageType, StructType, MatrixType>(
+ ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
[&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 7148027..eb2974d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -547,6 +547,12 @@
return failure();
}
+ if (llvm::isa<TensorArmType>(type)) {
+ if (parser.parseOptionalColon().succeeded())
+ if (parser.parseType(type))
+ return failure();
+ }
+
return parser.addTypeToList(type, result.types);
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 93e0c9b..2b90df4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -18,8 +18,9 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
#include <cstdint>
-#include <iterator>
+#include <numeric>
using namespace mlir;
using namespace mlir::spirv;
@@ -96,7 +97,7 @@
return isValid(vectorType);
return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
spirv::MatrixType, spirv::RuntimeArrayType,
- spirv::StructType>(type);
+ spirv::StructType, spirv::TensorArmType>(type);
}
bool CompositeType::isValid(VectorType type) {
@@ -107,8 +108,8 @@
Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
- [](auto type) { return type.getElementType(); })
+ .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
+ TensorArmType>([](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
[index](StructType type) { return type.getElementType(index); })
@@ -125,6 +126,8 @@
return structType.getNumElements();
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
+ if (auto tensorArmType = dyn_cast<TensorArmType>(*this))
+ return tensorArmType.getNumElements();
if (llvm::isa<CooperativeMatrixType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv Cooperative Matrix type");
@@ -151,6 +154,13 @@
return llvm::cast<ScalarType>(type.getElementType())
.getExtensions(extensions, storage);
})
+ .Case<TensorArmType>([&](TensorArmType type) {
+ static constexpr Extension ext{Extension::SPV_ARM_tensors};
+ extensions.push_back(ext);
+ return llvm::cast<ScalarType>(type.getElementType())
+ .getExtensions(extensions, storage);
+ })
+
.Default([](Type) { llvm_unreachable("invalid composite type"); });
}
@@ -171,6 +181,12 @@
return llvm::cast<ScalarType>(type.getElementType())
.getCapabilities(capabilities, storage);
})
+ .Case<TensorArmType>([&](TensorArmType type) {
+ static constexpr Capability cap{Capability::TensorsARM};
+ capabilities.push_back(cap);
+ return llvm::cast<ScalarType>(type.getElementType())
+ .getCapabilities(capabilities, storage);
+ })
.Default([](Type) { llvm_unreachable("invalid composite type"); });
}
@@ -186,6 +202,13 @@
return std::nullopt;
return *elementSize * vectorType.getNumElements();
}
+ if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+ std::optional<int64_t> elementSize =
+ llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
+ if (!elementSize)
+ return std::nullopt;
+ return *elementSize * tensorArmType.getNumElements();
+ }
return std::nullopt;
}
@@ -691,6 +714,8 @@
return true;
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return CompositeType::isValid(vectorType);
+ if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
+ return llvm::isa<ScalarType>(tensorArmType.getElementType());
return false;
}
@@ -712,6 +737,8 @@
matrixType.getExtensions(extensions, storage);
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getExtensions(extensions, storage);
+ } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+ tensorArmType.getExtensions(extensions, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getExtensions");
}
@@ -732,6 +759,8 @@
matrixType.getCapabilities(capabilities, storage);
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getCapabilities(capabilities, storage);
+ } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+ tensorArmType.getCapabilities(capabilities, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
}
@@ -1204,10 +1233,84 @@
}
//===----------------------------------------------------------------------===//
+// TensorArmType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
+ using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
+
+ static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ auto [shape, elementType] = key;
+ shape = allocator.copyInto(shape);
+ return new (allocator.allocate<TensorArmTypeStorage>())
+ TensorArmTypeStorage(shape, elementType);
+ }
+
+ static llvm::hash_code hashKey(const KeyTy &key) {
+ auto [shape, elementType] = key;
+ return llvm::hash_combine(shape, elementType);
+ }
+
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(shape, elementType);
+ }
+
+ TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
+ : shape(std::move(shape)), elementType(std::move(elementType)) {}
+
+ ArrayRef<int64_t> shape;
+ Type elementType;
+};
+
+TensorArmType TensorArmType::get(ArrayRef<int64_t> shape, Type elementType) {
+ return Base::get(elementType.getContext(), shape, elementType);
+}
+
+TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ return TensorArmType::get(shape.value_or(getShape()), elementType);
+}
+
+Type TensorArmType::getElementType() const { return getImpl()->elementType; }
+ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
+
+void TensorArmType::getExtensions(
+ SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage) {
+
+ llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+ static constexpr Extension ext{Extension::SPV_ARM_tensors};
+ extensions.push_back(ext);
+}
+
+void TensorArmType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ std::optional<StorageClass> storage) {
+ llvm::cast<SPIRVType>(getElementType())
+ .getCapabilities(capabilities, storage);
+ static constexpr Capability cap{Capability::TensorsARM};
+ capabilities.push_back(cap);
+}
+
+LogicalResult
+TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType) {
+ if (llvm::is_contained(shape, 0))
+ return emitError() << "arm.tensor do not support dimensions = 0";
+ if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
+ llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
+ return emitError()
+ << "arm.tensor shape dimensions must be either fully dynamic or "
+ "completed shaped";
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// SPIR-V Dialect
//===----------------------------------------------------------------------===//
void SPIRVDialect::registerTypes() {
addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
- RuntimeArrayType, SampledImageType, StructType>();
+ RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index b30da77..55d6a38 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -164,6 +164,7 @@
case spirv::Opcode::OpTypeRuntimeArray:
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
+ case spirv::Opcode::OpTypeTensorARM:
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b9d9a90..b1abd8b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -935,6 +935,8 @@
return processStructType(operands);
case spirv::Opcode::OpTypeMatrix:
return processMatrixType(operands);
+ case spirv::Opcode::OpTypeTensorARM:
+ return processTensorARMType(operands);
default:
return emitError(unknownLoc, "unhandled type instruction");
}
@@ -1239,6 +1241,55 @@
}
LogicalResult
+spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
+ unsigned size = operands.size();
+ if (size < 2 || size > 4)
+ return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
+ "(result_id, element_type, (rank), (shape)) ")
+ << size;
+
+ Type elementTy = getType(operands[1]);
+ if (!elementTy)
+ return emitError(unknownLoc,
+ "OpTypeTensorARM references undefined element type ")
+ << operands[1];
+
+ if (size == 2) {
+ typeMap[operands[0]] = TensorArmType::get({}, elementTy);
+ return success();
+ }
+
+ IntegerAttr rankAttr = getConstantInt(operands[2]);
+ if (!rankAttr)
+ return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
+ "scalar integer constant instruction");
+ unsigned rank = rankAttr.getValue().getZExtValue();
+ if (size == 3) {
+ SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
+ typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
+ return success();
+ }
+
+ std::optional<std::pair<Attribute, Type>> shapeInfo =
+ getConstant(operands[3]);
+ if (!shapeInfo)
+ return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
+ "constant instruction of type OpTypeArray");
+
+ ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
+ SmallVector<int64_t, 1> shape;
+ for (auto dimAttr : shapeArrayAttr.getValue()) {
+ auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
+ if (!dimIntAttr)
+ return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
+ "dimension size");
+ shape.push_back(dimIntAttr.getValue().getSExtValue());
+ }
+ typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
+ return success();
+}
+
+LogicalResult
spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
return emitError(unknownLoc,
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index e4556e7..1bc9e4a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -291,6 +291,8 @@
LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
+ LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
+
LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index d258bfd..ebebd2d 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -729,6 +729,54 @@
return success();
}
+ if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+ uint32_t elementTypeID = 0;
+ uint32_t rank = 0;
+ uint32_t shapeID = 0;
+ uint32_t rankID = 0;
+ if (failed(processTypeImpl(loc, tensorArmType.getElementType(),
+ elementTypeID, serializationCtx))) {
+ return failure();
+ }
+ if (tensorArmType.hasRank()) {
+ ArrayRef<int64_t> dims = tensorArmType.getShape();
+ rank = dims.size();
+ rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
+ if (rankID == 0) {
+ return failure();
+ }
+
+ bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; });
+ if (rank > 0 && shaped) {
+ auto I32Type = IntegerType::get(type.getContext(), 32);
+ auto shapeType = ArrayType::get(I32Type, rank);
+ if (rank == 1) {
+ SmallVector<uint64_t, 1> index(rank);
+ shapeID = prepareDenseElementsConstant(
+ loc, shapeType,
+ mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
+ index);
+ } else {
+ shapeID = prepareArrayConstant(
+ loc, shapeType,
+ mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
+ }
+ if (shapeID == 0) {
+ return failure();
+ }
+ }
+ }
+ typeEnum = spirv::Opcode::OpTypeTensorARM;
+ operands.push_back(elementTypeID);
+ if (rankID == 0)
+ return success();
+ operands.push_back(rankID);
+ if (shapeID == 0)
+ return success();
+ operands.push_back(shapeID);
+ return success();
+ }
+
// TODO: Handle other types.
return emitError(loc, "unhandled type in serialization: ") << type;
}
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index c23894c..7d45b5e 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -564,3 +564,54 @@
func.func private @matrix_size_type(!spirv.matrix<2.0 x vector<3xi32>>) -> ()
// -----
+
+//===----------------------------------------------------------------------===//
+// TensorArm
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>)
+func.func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>)
+func.func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>)
+func.func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>)
+func.func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<?xi32>)
+func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<?xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor<?x?xi32>)
+func.func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor<?x?xi32>) -> ()
+// -----
+
+// expected-error @+1 {{arm.tensor shape dimensions must be either fully dynamic or completed shaped}}
+func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<1x?xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensors do not support rank zero}}
+func.func private @arm_tensor_rank_zero(!spirv.arm.tensor<i32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>)
+func.func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensors do not support zero dimensions}}
+func.func private @arm_tensor_type_zero_dim(!spirv.arm.tensor<0xi32>) -> ()
diff --git a/mlir/test/Target/SPIRV/tensorARM.mlir b/mlir/test/Target/SPIRV/tensorARM.mlir
new file mode 100644
index 0000000..75b648e
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tensorARM.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, TensorsARM], [SPV_ARM_tensors]> {
+ // CHECK: spirv.func @shaped_int_arm_tensor(%arg0: !spirv.arm.tensor<2xi32>) "None" {
+ spirv.func @shaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<2xi32>) "None" {
+ spirv.Return
+ }
+
+// -----
+
+ // CHECK: spirv.func @shaped_rank2_int_arm_tensor(%arg0: !spirv.arm.tensor<2x3xi32>) "None" {
+ spirv.func @shaped_rank2_int_arm_tensor(%arg0 : !spirv.arm.tensor<2x3xi32>) "None" {
+ spirv.Return
+ }
+
+// -----
+
+ // CHECK: spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xi64> "None" {
+ spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xui64> "None" {
+ // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xi64>
+ %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xui64>
+
+ spirv.ReturnValue %0: !spirv.arm.tensor<3xui64>
+ }
+
+// -----
+
+ // CHECK: spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" {
+ spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" {
+ // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32>
+ %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32>
+
+ spirv.ReturnValue %0 : !spirv.arm.tensor<3xsi32>
+ }
+
+// -----
+
+ // CHECK: spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" {
+ spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" {
+ // CHECK: spirv.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : !spirv.arm.tensor<3xf32>
+ %0 = spirv.Constant dense<[3., 4., 5.]> : !spirv.arm.tensor<3xf32>
+
+ spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+ }
+
+// -----
+
+ // CHECK: spirv.func @unranked_int_arm_tensor(%arg0: !spirv.arm.tensor<*xi32>) "None" {
+ spirv.func @unranked_int_arm_tensor(%arg0 : !spirv.arm.tensor<*xi32>) "None" {
+ spirv.Return
+ }
+
+// -----
+
+ // CHECK: spirv.func @unshaped_int_arm_tensor(%arg0: !spirv.arm.tensor<?xi32>) "None" {
+ spirv.func @unshaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<?xi32>) "None" {
+ spirv.Return
+ }
+
+// -----
+
+ // CHECK: spirv.func @unshaped_int_arm_tensor_2(%arg0: !spirv.arm.tensor<?x?xi32>) "None" {
+ spirv.func @unshaped_int_arm_tensor_2(%arg0 : !spirv.arm.tensor<?x?xi32>) "None" {
+ spirv.Return
+ }
+}