[mlir][spirv] Add support for SPV_ARM_tensors (#144667)

This patch introduces a new custom type `!spirv.arm.tensor<>` to the
MLIR SPIR-V dialect to represent
`OpTypeTensorARM` as defined in the `SPV_ARM_tensors` extension.

The type models a tensor with an element type and an optional shape,
and implements the
`ShapedType` interface to enable reuse of MLIR's existing shape-aware
infrastructure.

The type supports serialization to and from SPIR-V binary as
`OpTypeTensorARM`, and emits the
required capability (`TensorsARM`) and extension (`SPV_ARM_tensors`)
declarations automatically.

This addition lays the foundation for supporting structured tensor
values natively in SPIR-V and
will enable future support for operations defined in the
`SPV_ARM_tensors` extension, such as
`OpTensorReadARM`, `OpTensorWriteARM`, and `OpTensorQuerySizeARM`.

Reference: https://github.com/KhronosGroup/SPIRV-Registry/pull/342

---------

Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian@arm.com>
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d2ba76c..d874817 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -422,6 +422,8 @@
 
 def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
 
+def SPV_ARM_tensors                      : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+
 def SPIRV_ExtensionAttr :
     SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
       SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
@@ -445,6 +447,7 @@
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
       SPV_EXT_mesh_shader,
+      SPV_ARM_tensors,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1311,6 +1314,24 @@
 def SPIRV_C_MultiViewport                               : I32EnumAttrCase<"MultiViewport", 57> {
   list<I32EnumAttrCase> implies = [SPIRV_C_Geometry];
 }
+def SPIRV_C_TensorsARM                                  : I32EnumAttrCase<"TensorsARM", 4174> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_Int8];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
+def SPIRV_C_StorageTensorArrayDynamicIndexingEXT        : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
+def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT     : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
 def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR  : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
   list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
   list<Availability> availability = [
@@ -1523,6 +1544,8 @@
       SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize,
       SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
       SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
+      SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
+      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
       SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
       SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
       SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4179,7 +4202,7 @@
 def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
 def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
 def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
-
+def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
 
 // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
 // for the definition of the following types and type categories.
@@ -4217,6 +4240,8 @@
                                 "any SPIR-V struct type">;
 def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
                                 "any SPIR-V sampled image type">;
+def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
+                                 "any SPIR-V tensorArm type">;
 
 def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
@@ -4228,7 +4253,7 @@
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
-    SPIRV_AnyImage
+    SPIRV_AnyImage, SPIRV_AnyTensorArm
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4525,6 +4550,7 @@
 def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpTypeTensorARM                  : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
 def SPIRV_OC_OpSubgroupBallotKHR              : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSDot                           : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4638,7 +4664,9 @@
       SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
       SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
-      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+      SPIRV_OC_OpGroupNonUniformLogicalXor,
+      SPIRV_OC_OpTypeTensorARM,
+      SPIRV_OC_OpSubgroupBallotKHR,
       SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
       SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 787535d..212cba6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,6 +29,7 @@
 namespace detail {
 struct ArrayTypeStorage;
 struct CooperativeMatrixTypeStorage;
+struct TensorArmTypeStorage;
 struct ImageTypeStorage;
 struct MatrixTypeStorage;
 struct PointerTypeStorage;
@@ -96,7 +97,8 @@
   std::optional<int64_t> getSizeInBytes();
 };
 
-// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+// SPIR-V composite type: VectorType, SPIR-V ArrayType, RuntimeArrayType,
+// StructType, MatrixType, CooperativeMatrixType, or TensorArmType.
 class CompositeType : public SPIRVType {
 public:
   using SPIRVType::SPIRVType;
@@ -477,6 +479,46 @@
                        std::optional<StorageClass> storage = std::nullopt);
 };
 
+/// SPIR-V TensorARM Type
+class TensorArmType
+    : public Type::TypeBase<TensorArmType, CompositeType,
+                            detail::TensorArmTypeStorage, ShapedType::Trait> {
+public:
+  using Base::Base;
+
+  using ShapedTypeTraits = ShapedType::Trait<TensorArmType>;
+  using ShapedTypeTraits::getDimSize;
+  using ShapedTypeTraits::getDynamicDimIndex;
+  using ShapedTypeTraits::getElementTypeBitWidth;
+  using ShapedTypeTraits::getNumDynamicDims;
+  using ShapedTypeTraits::getNumElements;
+  using ShapedTypeTraits::getRank;
+  using ShapedTypeTraits::hasStaticShape;
+  using ShapedTypeTraits::isDynamicDim;
+
+  static constexpr StringLiteral name = "spirv.arm.tensor";
+
+  // TensorArm supports minimum rank of 1, hence an empty shape here means
+  // unranked.
+  static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
+  TensorArmType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                          Type elementType) const;
+
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<int64_t> shape, Type elementType);
+
+  Type getElementType() const;
+  ArrayRef<int64_t> getShape() const;
+  bool hasRank() const { return !getShape().empty(); }
+  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     std::optional<StorageClass> storage = std::nullopt);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       std::optional<StorageClass> storage = std::nullopt);
+};
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a21acef..88c7adf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -194,6 +194,13 @@
           << t.getNumElements();
       return Type();
     }
+  } else if (auto t = dyn_cast<TensorArmType>(type)) {
+    if (!isa<ScalarType>(t.getElementType())) {
+      parser.emitError(
+          typeLoc, "only scalar element type allowed in tensor type but found ")
+          << t.getElementType();
+      return Type();
+    }
   } else {
     parser.emitError(typeLoc, "cannot use ")
         << type << " to compose SPIR-V types";
@@ -363,6 +370,52 @@
   return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
 }
 
+// tensor-arm-type ::=
+//   `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
+static Type parseTensorArmType(SPIRVDialect const &dialect,
+                               DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return {};
+
+  bool unranked = false;
+  SmallVector<int64_t, 4> dims;
+  SMLoc countLoc = parser.getCurrentLocation();
+
+  if (parser.parseOptionalStar().succeeded()) {
+    unranked = true;
+    if (parser.parseXInDimensionList())
+      return {};
+  } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
+    return {};
+  }
+
+  if (!unranked && dims.empty()) {
+    parser.emitError(countLoc, "arm.tensors do not support rank zero");
+    return {};
+  }
+
+  if (llvm::is_contained(dims, 0)) {
+    parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
+    return {};
+  }
+
+  if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
+      llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
+    parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
+                               "fully dynamic or completed shaped");
+    return {};
+  }
+
+  auto elementTy = parseAndVerifyType(dialect, parser);
+  if (!elementTy)
+    return {};
+
+  if (parser.parseGreater())
+    return {};
+
+  return TensorArmType::get(dims, elementTy);
+}
+
 // TODO: Reorder methods to be utilities first and parse*Type
 // methods in alphabetical order
 //
@@ -759,6 +812,8 @@
     return parseStructType(*this, parser);
   if (keyword == "matrix")
     return parseMatrixType(*this, parser);
+  if (keyword == "arm.tensor")
+    return parseTensorArmType(*this, parser);
   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
   return Type();
 }
@@ -855,10 +910,28 @@
   os << ">";
 }
 
+static void print(TensorArmType type, DialectAsmPrinter &os) {
+  os << "arm.tensor<";
+
+  llvm::interleave(
+      type.getShape(), os,
+      [&](int64_t dim) {
+        if (ShapedType::isDynamic(dim))
+          os << '?';
+        else
+          os << dim;
+      },
+      "x");
+  if (!type.hasRank()) {
+    os << "*";
+  }
+  os << "x" << type.getElementType() << ">";
+}
+
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
-            ImageType, SampledImageType, StructType, MatrixType>(
+            ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
           [&](auto type) { print(type, os); })
       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 7148027..eb2974d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -547,6 +547,12 @@
       return failure();
   }
 
+  if (llvm::isa<TensorArmType>(type)) {
+    if (parser.parseOptionalColon().succeeded())
+      if (parser.parseType(type))
+        return failure();
+  }
+
   return parser.addTypeToList(type, result.types);
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 93e0c9b..2b90df4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -18,8 +18,9 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <algorithm>
 #include <cstdint>
-#include <iterator>
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::spirv;
@@ -96,7 +97,7 @@
     return isValid(vectorType);
   return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
                    spirv::MatrixType, spirv::RuntimeArrayType,
-                   spirv::StructType>(type);
+                   spirv::StructType, spirv::TensorArmType>(type);
 }
 
 bool CompositeType::isValid(VectorType type) {
@@ -107,8 +108,8 @@
 
 Type CompositeType::getElementType(unsigned index) const {
   return TypeSwitch<Type, Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
-          [](auto type) { return type.getElementType(); })
+      .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
+            TensorArmType>([](auto type) { return type.getElementType(); })
       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
       .Case<StructType>(
           [index](StructType type) { return type.getElementType(index); })
@@ -125,6 +126,8 @@
     return structType.getNumElements();
   if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
     return vectorType.getNumElements();
+  if (auto tensorArmType = dyn_cast<TensorArmType>(*this))
+    return tensorArmType.getNumElements();
   if (llvm::isa<CooperativeMatrixType>(*this)) {
     llvm_unreachable(
         "invalid to query number of elements of spirv Cooperative Matrix type");
@@ -151,6 +154,13 @@
         return llvm::cast<ScalarType>(type.getElementType())
             .getExtensions(extensions, storage);
       })
+      .Case<TensorArmType>([&](TensorArmType type) {
+        static constexpr Extension ext{Extension::SPV_ARM_tensors};
+        extensions.push_back(ext);
+        return llvm::cast<ScalarType>(type.getElementType())
+            .getExtensions(extensions, storage);
+      })
+
       .Default([](Type) { llvm_unreachable("invalid composite type"); });
 }
 
@@ -171,6 +181,12 @@
         return llvm::cast<ScalarType>(type.getElementType())
             .getCapabilities(capabilities, storage);
       })
+      .Case<TensorArmType>([&](TensorArmType type) {
+        static constexpr Capability cap{Capability::TensorsARM};
+        capabilities.push_back(cap);
+        return llvm::cast<ScalarType>(type.getElementType())
+            .getCapabilities(capabilities, storage);
+      })
       .Default([](Type) { llvm_unreachable("invalid composite type"); });
 }
 
@@ -186,6 +202,13 @@
       return std::nullopt;
     return *elementSize * vectorType.getNumElements();
   }
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    std::optional<int64_t> elementSize =
+        llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
+    if (!elementSize)
+      return std::nullopt;
+    return *elementSize * tensorArmType.getNumElements();
+  }
   return std::nullopt;
 }
 
@@ -691,6 +714,8 @@
     return true;
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return CompositeType::isValid(vectorType);
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
+    return llvm::isa<ScalarType>(tensorArmType.getElementType());
   return false;
 }
 
@@ -712,6 +737,8 @@
     matrixType.getExtensions(extensions, storage);
   } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
     ptrType.getExtensions(extensions, storage);
+  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    tensorArmType.getExtensions(extensions, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getExtensions");
   }
@@ -732,6 +759,8 @@
     matrixType.getCapabilities(capabilities, storage);
   } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
     ptrType.getCapabilities(capabilities, storage);
+  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    tensorArmType.getCapabilities(capabilities, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getCapabilities");
   }
@@ -1204,10 +1233,84 @@
 }
 
 //===----------------------------------------------------------------------===//
+// TensorArmType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
+  using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
+
+  static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    auto [shape, elementType] = key;
+    shape = allocator.copyInto(shape);
+    return new (allocator.allocate<TensorArmTypeStorage>())
+        TensorArmTypeStorage(shape, elementType);
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    auto [shape, elementType] = key;
+    return llvm::hash_combine(shape, elementType);
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(shape, elementType);
+  }
+
+  TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
+      : shape(std::move(shape)), elementType(std::move(elementType)) {}
+
+  ArrayRef<int64_t> shape;
+  Type elementType;
+};
+
+TensorArmType TensorArmType::get(ArrayRef<int64_t> shape, Type elementType) {
+  return Base::get(elementType.getContext(), shape, elementType);
+}
+
+TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                       Type elementType) const {
+  return TensorArmType::get(shape.value_or(getShape()), elementType);
+}
+
+Type TensorArmType::getElementType() const { return getImpl()->elementType; }
+ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
+
+void TensorArmType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    std::optional<StorageClass> storage) {
+
+  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  static constexpr Extension ext{Extension::SPV_ARM_tensors};
+  extensions.push_back(ext);
+}
+
+void TensorArmType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    std::optional<StorageClass> storage) {
+  llvm::cast<SPIRVType>(getElementType())
+      .getCapabilities(capabilities, storage);
+  static constexpr Capability cap{Capability::TensorsARM};
+  capabilities.push_back(cap);
+}
+
+LogicalResult
+TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                ArrayRef<int64_t> shape, Type elementType) {
+  if (llvm::is_contained(shape, 0))
+    return emitError() << "arm.tensors do not support zero dimensions";
+  if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
+      llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
+    return emitError()
+           << "arm.tensor shape dimensions must be either fully dynamic or "
+              "completed shaped";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // SPIR-V Dialect
 //===----------------------------------------------------------------------===//
 
 void SPIRVDialect::registerTypes() {
   addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
-           RuntimeArrayType, SampledImageType, StructType>();
+           RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index b30da77..55d6a38 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -164,6 +164,7 @@
   case spirv::Opcode::OpTypeRuntimeArray:
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
+  case spirv::Opcode::OpTypeTensorARM:
   case spirv::Opcode::OpTypeCooperativeMatrixKHR:
     return processType(opcode, operands);
   case spirv::Opcode::OpTypeForwardPointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b9d9a90..b1abd8b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -935,6 +935,8 @@
     return processStructType(operands);
   case spirv::Opcode::OpTypeMatrix:
     return processMatrixType(operands);
+  case spirv::Opcode::OpTypeTensorARM:
+    return processTensorARMType(operands);
   default:
     return emitError(unknownLoc, "unhandled type instruction");
   }
@@ -1239,6 +1241,55 @@
 }
 
 LogicalResult
+spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
+  unsigned size = operands.size();
+  if (size < 2 || size > 4)
+    return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
+                                 "(result_id, element_type, (rank), (shape)) ")
+           << size;
+
+  Type elementTy = getType(operands[1]);
+  if (!elementTy)
+    return emitError(unknownLoc,
+                     "OpTypeTensorARM references undefined element type ")
+           << operands[1];
+
+  if (size == 2) {
+    typeMap[operands[0]] = TensorArmType::get({}, elementTy);
+    return success();
+  }
+
+  IntegerAttr rankAttr = getConstantInt(operands[2]);
+  if (!rankAttr)
+    return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
+                                 "scalar integer constant instruction");
+  unsigned rank = rankAttr.getValue().getZExtValue();
+  if (size == 3) {
+    SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
+    typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
+    return success();
+  }
+
+  std::optional<std::pair<Attribute, Type>> shapeInfo =
+      getConstant(operands[3]);
+  if (!shapeInfo)
+    return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
+                                 "constant instruction of type OpTypeArray");
+
+  ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
+  SmallVector<int64_t, 1> shape;
+  for (auto dimAttr : shapeArrayAttr.getValue()) {
+    auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
+    if (!dimIntAttr)
+      return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
+                                   "dimension size");
+    shape.push_back(dimIntAttr.getValue().getSExtValue());
+  }
+  typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
+  return success();
+}
+
+LogicalResult
 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2)
     return emitError(unknownLoc,
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index e4556e7..1bc9e4a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -291,6 +291,8 @@
 
   LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
+
   LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index d258bfd..ebebd2d 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -729,6 +729,54 @@
     return success();
   }
 
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+    uint32_t elementTypeID = 0;
+    uint32_t rank = 0;
+    uint32_t shapeID = 0;
+    uint32_t rankID = 0;
+    if (failed(processTypeImpl(loc, tensorArmType.getElementType(),
+                               elementTypeID, serializationCtx))) {
+      return failure();
+    }
+    if (tensorArmType.hasRank()) {
+      ArrayRef<int64_t> dims = tensorArmType.getShape();
+      rank = dims.size();
+      rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
+      if (rankID == 0) {
+        return failure();
+      }
+
+      bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; });
+      if (rank > 0 && shaped) {
+        auto I32Type = IntegerType::get(type.getContext(), 32);
+        auto shapeType = ArrayType::get(I32Type, rank);
+        if (rank == 1) {
+          SmallVector<uint64_t, 1> index(rank);
+          shapeID = prepareDenseElementsConstant(
+              loc, shapeType,
+              mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
+              index);
+        } else {
+          shapeID = prepareArrayConstant(
+              loc, shapeType,
+              mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
+        }
+        if (shapeID == 0) {
+          return failure();
+        }
+      }
+    }
+    typeEnum = spirv::Opcode::OpTypeTensorARM;
+    operands.push_back(elementTypeID);
+    if (rankID == 0)
+      return success();
+    operands.push_back(rankID);
+    if (shapeID == 0)
+      return success();
+    operands.push_back(shapeID);
+    return success();
+  }
+
   // TODO: Handle other types.
   return emitError(loc, "unhandled type in serialization: ") << type;
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index c23894c..7d45b5e 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -564,3 +564,54 @@
 func.func private @matrix_size_type(!spirv.matrix<2.0 x vector<3xi32>>) -> ()
 
 // -----
+
+//===----------------------------------------------------------------------===//
+// TensorArm
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>)
+func.func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>)
+func.func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>)
+func.func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>)
+func.func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<?xi32>)
+func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<?xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor<?x?xi32>)
+func.func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor<?x?xi32>) -> ()
+// -----
+
+// expected-error @+1 {{arm.tensor shape dimensions must be either fully dynamic or completed shaped}}
+func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<1x?xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensors do not support rank zero}}
+func.func private @arm_tensor_rank_zero(!spirv.arm.tensor<i32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>)
+func.func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensors do not support zero dimensions}}
+func.func private @arm_tensor_type_zero_dim(!spirv.arm.tensor<0xi32>) -> ()
diff --git a/mlir/test/Target/SPIRV/tensorARM.mlir b/mlir/test/Target/SPIRV/tensorARM.mlir
new file mode 100644
index 0000000..75b648e
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tensorARM.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, TensorsARM], [SPV_ARM_tensors]> {
+  // CHECK: spirv.func @shaped_int_arm_tensor(%arg0: !spirv.arm.tensor<2xi32>) "None" {
+  spirv.func @shaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<2xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @shaped_rank2_int_arm_tensor(%arg0: !spirv.arm.tensor<2x3xi32>) "None" {
+  spirv.func @shaped_rank2_int_arm_tensor(%arg0 : !spirv.arm.tensor<2x3xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xi64> "None" {
+  spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xui64> "None" {
+    // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xi64>
+    %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xui64>
+
+    spirv.ReturnValue %0: !spirv.arm.tensor<3xui64>
+  }
+
+// -----
+
+  // CHECK: spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" {
+  spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" {
+    // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32>
+    %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32>
+
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xsi32>
+  }
+
+// -----
+
+  // CHECK: spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" {
+  spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" {
+    // CHECK: spirv.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : !spirv.arm.tensor<3xf32>
+    %0 = spirv.Constant dense<[3., 4., 5.]> : !spirv.arm.tensor<3xf32>
+
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+  }
+
+// -----
+
+  // CHECK: spirv.func @unranked_int_arm_tensor(%arg0: !spirv.arm.tensor<*xi32>) "None" {
+  spirv.func @unranked_int_arm_tensor(%arg0 : !spirv.arm.tensor<*xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @unshaped_int_arm_tensor(%arg0: !spirv.arm.tensor<?xi32>) "None" {
+  spirv.func @unshaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<?xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @unshaped_int_arm_tensor_2(%arg0: !spirv.arm.tensor<?x?xi32>) "None" {
+  spirv.func @unshaped_int_arm_tensor_2(%arg0 : !spirv.arm.tensor<?x?xi32>) "None" {
+    spirv.Return
+  }
+}