[mlir][SPIR-V] Add support for SPV_INTEL_long_composites extension (#195685)

Add serialization and deserialization support for the
SPV_INTEL_long_composites extension, which allows splitting
composite/struct instructions that exceed the SPIR-V 16-bit word count
limit (65535 words) into a head instruction followed by one or more
continuation instructions.
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index a9a8947..fcc4be4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4686,6 +4686,10 @@
 def SPIRV_OC_OpAssumeTrueKHR                  : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
 def SPIRV_OC_OpExpectKHR                      : I32EnumAttrCase<"OpExpectKHR", 5631>;
 def SPIRV_OC_OpAtomicFAddEXT                  : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>;
+def SPIRV_OC_OpTypeStructContinuedINTEL       : I32EnumAttrCase<"OpTypeStructContinuedINTEL", 6090>;
+def SPIRV_OC_OpConstantCompositeContinuedINTEL : I32EnumAttrCase<"OpConstantCompositeContinuedINTEL", 6091>;
+def SPIRV_OC_OpSpecConstantCompositeContinuedINTEL : I32EnumAttrCase<"OpSpecConstantCompositeContinuedINTEL", 6092>;
+def SPIRV_OC_OpCompositeConstructContinuedINTEL : I32EnumAttrCase<"OpCompositeConstructContinuedINTEL", 6096>;
 def SPIRV_OC_OpConvertFToBF16INTEL            : I32EnumAttrCase<"OpConvertFToBF16INTEL", 6116>;
 def SPIRV_OC_OpConvertBF16ToFINTEL            : I32EnumAttrCase<"OpConvertBF16ToFINTEL", 6117>;
 def SPIRV_OC_OpControlBarrierArriveINTEL      : I32EnumAttrCase<"OpControlBarrierArriveINTEL", 6142>;
@@ -4810,6 +4814,10 @@
       SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
       SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpExpectKHR, SPIRV_OC_OpAtomicFAddEXT,
+      SPIRV_OC_OpTypeStructContinuedINTEL,
+      SPIRV_OC_OpConstantCompositeContinuedINTEL,
+      SPIRV_OC_OpSpecConstantCompositeContinuedINTEL,
+      SPIRV_OC_OpCompositeConstructContinuedINTEL,
       SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
       SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
       SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
diff --git a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
index 4a41163..34aa774 100644
--- a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
+++ b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
@@ -17,6 +17,7 @@
 #include "mlir/Support/LLVM.h"
 
 #include <cstdint>
+#include <optional>
 
 namespace mlir {
 namespace spirv {
@@ -58,6 +59,25 @@
   return str;
 }
 
+/// Maps a composite/struct opcode to the SPV_INTEL_long_composites opcode
+/// used for its continuation instructions. Yields std::nullopt when `parent`
+/// is none of OpTypeStruct, OpConstantComposite, OpSpecConstantComposite or
+/// OpCompositeConstruct.
+inline std::optional<spirv::Opcode>
+getContinuationOpcode(spirv::Opcode parent) {
+  using spirv::Opcode;
+
+  if (parent == Opcode::OpTypeStruct)
+    return Opcode::OpTypeStructContinuedINTEL;
+  if (parent == Opcode::OpConstantComposite)
+    return Opcode::OpConstantCompositeContinuedINTEL;
+  if (parent == Opcode::OpSpecConstantComposite)
+    return Opcode::OpSpecConstantCompositeContinuedINTEL;
+  if (parent == Opcode::OpCompositeConstruct)
+    return Opcode::OpCompositeConstructContinuedINTEL;
+  return std::nullopt;
+}
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 21a8400..f65b559 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -125,11 +125,45 @@
   return success();
 }
 
+void spirv::Deserializer::mergeLongCompositeContinuations(
+    spirv::Opcode opcode, ArrayRef<uint32_t> &operands,
+    SmallVectorImpl<uint32_t> &mergedStorage) {
+  const std::optional<spirv::Opcode> contOp = getContinuationOpcode(opcode);
+  if (!contOp)
+    return;
+
+  // Whether a complete continuation instruction starts at `curOffset`.
+  const size_t numWords = binary.size();
+  auto atContinuation = [&]() -> bool {
+    if (curOffset >= numWords)
+      return false;
+    const uint32_t header = binary[curOffset];
+    const uint32_t length = header >> 16;
+    return length != 0 && curOffset + length <= numWords &&
+           extractOpcode(header) == *contOp;
+  };
+  if (!atContinuation())
+    return;
+  // Copy the head operands, then append each consecutive continuation.
+  mergedStorage.assign(operands);
+  while (atContinuation()) {
+    spirv::Opcode slicedOpcode;
+    ArrayRef<uint32_t> slicedOperands;
+    if (failed(sliceInstruction(slicedOpcode, slicedOperands, *contOp)))
+      return;
+    llvm::append_range(mergedStorage, slicedOperands);
+  }
+  operands = mergedStorage;
+}
+
 LogicalResult spirv::Deserializer::processInstruction(
     spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
   LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
                                 << spirv::stringifyOpcode(opcode) << "\n");
 
+  SmallVector<uint32_t, 0> mergedStorage;
+  mergeLongCompositeContinuations(opcode, operands, mergedStorage);
+
   // First dispatch all the instructions whose opcode does not correspond to
   // those that have a direct mirror in the SPIR-V dialect
   switch (opcode) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index f3dab42..b2adbb5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -518,6 +518,15 @@
   sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
                    std::optional<spirv::Opcode> expectedOpcode = std::nullopt);
 
+  /// If `opcode` is a SPV_INTEL_long_composites splittable opcode and the
+  /// next binary instruction(s) are matching `*ContinuedINTEL` ops, consumes
+  /// them and rebinds `operands` to `mergedStorage`, filled with the parent +
+  /// continuation operands. With no continuation, both are left untouched.
+  void
+  mergeLongCompositeContinuations(spirv::Opcode opcode,
+                                  ArrayRef<uint32_t> &operands,
+                                  SmallVectorImpl<uint32_t> &mergedStorage);
+
   /// Processes a SPIR-V instruction with the given `opcode` and `operands`.
   /// This method is the main entrance for handling SPIR-V instruction; it
   /// checks the instruction opcode and dispatches to the corresponding handler.
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index a2c942d..841fc55 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -121,8 +121,8 @@
     operands.push_back(constituentID);
   }
 
-  encodeInstructionInto(typesGlobalValues,
-                        spirv::Opcode::OpSpecConstantComposite, operands);
+  encodeInstructionWithContinuationInto(
+      typesGlobalValues, spirv::Opcode::OpSpecConstantComposite, operands);
   specConstIDMap[op.getSymName()] = resultID;
 
   return processName(resultID, op.getSymName());
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 5c9e378..7a2eaf3 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -213,6 +213,48 @@
                           {static_cast<uint32_t>(cap)});
 }
 
+void Serializer::addLongCompositesCapability() {
+  // Emit at most once, however many composites end up being split.
+  if (longCompositesEmitted)
+    return;
+  longCompositesEmitted = true;
+
+  // Skip whatever the module's VCE triple already declares.
+  constexpr auto cap = spirv::Capability::LongCompositesINTEL;
+  constexpr auto ext = spirv::Extension::SPV_INTEL_long_composites;
+  auto vce = module.getVceTriple();
+  if (!llvm::is_contained(vce->getCapabilities(), cap))
+    encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
+                          {static_cast<uint32_t>(cap)});
+  if (llvm::is_contained(vce->getExtensions(), ext))
+    return;
+  SmallVector<uint32_t, 8> extWords;
+  spirv::encodeStringLiteralInto(extWords, spirv::stringifyExtension(ext));
+  encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extWords);
+}
+
+void Serializer::encodeInstructionWithContinuationInto(
+    SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
+    ArrayRef<uint32_t> operands) {
+  const size_t maxOperands = spirv::kMaxWordCount - 1; // minus header word
+  if (operands.size() <= maxOperands) {
+    encodeInstructionInto(binary, op, operands);
+    return;
+  }
+
+  std::optional<spirv::Opcode> contOp = spirv::getContinuationOpcode(op);
+  assert(contOp && "op is not a splittable composite/struct opcode");
+  // Head instruction first, then continuations of up to `maxOperands` words.
+  encodeInstructionInto(binary, op, operands.take_front(maxOperands));
+  ArrayRef<uint32_t> pending = operands.drop_front(maxOperands);
+  while (!pending.empty()) {
+    ArrayRef<uint32_t> piece = pending.take_front(maxOperands);
+    encodeInstructionInto(binary, *contOp, piece);
+    pending = pending.drop_front(piece.size());
+  }
+  addLongCompositesCapability();
+}
+
 void Serializer::processDebugInfo() {
   if (!options.emitDebugInfo)
     return;
@@ -560,7 +602,11 @@
 
     typeIDMap[type] = typeID;
 
-    encodeInstructionInto(typesGlobalValues, typeEnum, operands);
+    if (typeEnum == spirv::Opcode::OpTypeStruct)
+      encodeInstructionWithContinuationInto(typesGlobalValues, typeEnum,
+                                            operands);
+    else
+      encodeInstructionInto(typesGlobalValues, typeEnum, operands);
 
     if (recursiveStructInfos.count(type) != 0) {
       // This recursive struct type is emitted already, now the OpTypePointer
@@ -1024,8 +1070,8 @@
       return 0;
     }
   }
-  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
-  encodeInstructionInto(typesGlobalValues, opcode, operands);
+  encodeInstructionWithContinuationInto(
+      typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
 
   return resultID;
 }
@@ -1104,8 +1150,8 @@
       }
     }
   }
-  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
-  encodeInstructionInto(typesGlobalValues, opcode, operands);
+  encodeInstructionWithContinuationInto(
+      typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
 
   return resultID;
 }
@@ -1605,6 +1651,9 @@
         return processBranchConditionalOp(op);
       })
       .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
+      .Case([&](spirv::CompositeConstructOp op) {
+        return processCompositeConstructOp(op);
+      })
       .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
         return processConstantCompositeReplicateOp(op);
       })
@@ -1645,6 +1694,41 @@
           [&](Operation *op) { return dispatchToAutogenSerialization(op); });
 }
 
+LogicalResult
+Serializer::processCompositeConstructOp(spirv::CompositeConstructOp op) {
+  const Location opLoc = op.getLoc();
+  uint32_t typeID = 0;
+  if (failed(processType(opLoc, op.getType(), typeID)))
+    return failure();
+
+  // Register the result <id> before resolving the constituents.
+  const uint32_t newID = getNextID();
+  valueIDMap[op.getResult()] = newID;
+
+  // Operand words: result type <id>, result <id>, then one <id> per
+  // constituent, in order.
+  auto constituents = op.getConstituents();
+  SmallVector<uint32_t, 8> words;
+  words.reserve(2 + constituents.size());
+  words.append({typeID, newID});
+  for (Value element : constituents) {
+    const uint32_t elementID = getValueID(element);
+    assert(elementID && "use before def!");
+    words.push_back(elementID);
+  }
+
+  if (failed(emitDebugLine(functionBody, opLoc)))
+    return failure();
+  // Splits into OpCompositeConstructContinuedINTEL pieces when too long.
+  encodeInstructionWithContinuationInto(
+      functionBody, spirv::Opcode::OpCompositeConstruct, words);
+
+  for (NamedAttribute attr : op->getAttrs())
+    if (failed(processDecoration(opLoc, newID, attr)))
+      return failure();
+  return success();
+}
+
 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
                                                       StringRef extInstSet,
                                                       uint32_t opcode) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index eb5ac0d..e43556f 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -104,10 +104,23 @@
 
   LogicalResult processExtension();
 
+  /// Encodes `op` + `operands` into `binary`. If that needs more than
+  /// kMaxWordCount words, splits it into a head plus SPV_INTEL_long_composites
+  /// continuations and requests the matching capability and extension; `op`
+  /// must then be a splittable opcode (see getContinuationOpcode).
+  void encodeInstructionWithContinuationInto(SmallVectorImpl<uint32_t> &binary,
+                                             spirv::Opcode op,
+                                             ArrayRef<uint32_t> operands);
+
+  /// Emits the long-composites capability and extension at most once.
+  void addLongCompositesCapability();
+
   void processMemoryModel();
 
   LogicalResult processConstantOp(spirv::ConstantOp op);
 
+  LogicalResult processCompositeConstructOp(spirv::CompositeConstructOp op);
+
   LogicalResult processConstantCompositeReplicateOp(
       spirv::EXTConstantCompositeReplicateOp op);
 
@@ -387,6 +400,8 @@
   /// The next available result <id>.
   uint32_t nextID = 1;
 
+  bool longCompositesEmitted = false;
+
   // The following are for different SPIR-V instruction sections. They follow
   // the logical layout of a SPIR-V module.
 
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index af55296..b0413a8 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/Target/SPIRV/Deserialization.h"
 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
@@ -229,3 +230,323 @@
   };
   EXPECT_FALSE(scanInstruction(hasVarName));
 }
+
+//===----------------------------------------------------------------------===//
+// SPV_INTEL_long_composites: composites whose binary form would exceed the
+// SPIR-V 16-bit word-count limit are split into a parent + *ContinuedINTEL ops
+// on serialization, and merged back on deserialization. These tests build the
+// large composites programmatically so that the IR doesn't have to expand
+// thousands of operands literally.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Picked to comfortably exceed kMaxWordCount = 65535 for any of the splittable
+// composite/struct opcodes -- each one packs at most kMaxWordCount - {1,2,3}
+// operands into the parent instruction, so 65540 always triggers a split.
+constexpr unsigned kLongCompositeSize = 65540;
+
+// Returns true if `binary` contains an instruction whose opcode is `target`.
+bool hasOpcode(SmallVectorImpl<uint32_t> &binary, spirv::Opcode target) {
+  for (size_t pos = spirv::kHeaderWordCount; pos < binary.size();) {
+    const uint32_t header = binary[pos];
+    const uint32_t length = header >> 16;
+    if (length == 0 || pos + length > binary.size())
+      return false; // Zero or overrunning word count: stop scanning.
+    if (static_cast<spirv::Opcode>(header & 0xffff) == target)
+      return true;
+    pos += length;
+  }
+  return false;
+}
+
+// Returns true if `binary` declares both the LongCompositesINTEL capability
+// and the SPV_INTEL_long_composites extension.
+bool hasLongCompositesCapabilityAndExtension(
+    SmallVectorImpl<uint32_t> &binary) {
+  const auto wantedCap =
+      static_cast<uint32_t>(spirv::Capability::LongCompositesINTEL);
+  const auto wantedExt =
+      spirv::stringifyExtension(spirv::Extension::SPV_INTEL_long_composites);
+  bool sawCap = false;
+  bool sawExt = false;
+  const size_t end = binary.size();
+  for (size_t pos = spirv::kHeaderWordCount; pos < end;) {
+    const uint32_t length = binary[pos] >> 16;
+    if (length == 0 || pos + length > end)
+      break;
+    const auto opcode = static_cast<spirv::Opcode>(binary[pos] & 0xffff);
+    ArrayRef<uint32_t> args(binary.data() + pos + 1, length - 1);
+    if (opcode == spirv::Opcode::OpCapability) {
+      sawCap |= !args.empty() && args[0] == wantedCap;
+    } else if (opcode == spirv::Opcode::OpExtension) {
+      unsigned wordIndex = 0;
+      sawExt |= spirv::decodeStringLiteral(args, wordIndex) == wantedExt;
+    }
+    pos += length;
+  }
+  return sawCap && sawExt;
+}
+
+// Verifies that the instruction stream is well formed: each word count is
+// non-zero and stays inside the binary. The 16-bit count field can never
+// exceed kMaxWordCount, so a bad split would presumably surface as a desync.
+bool allInstructionsWithinWordLimit(SmallVectorImpl<uint32_t> &binary) {
+  size_t offset = spirv::kHeaderWordCount;
+  while (offset < binary.size()) {
+    uint32_t wordCount = binary[offset] >> 16;
+    if (!wordCount || offset + wordCount > binary.size())
+      return false;
+    offset += wordCount;
+  }
+  return true;
+}
+
+} // namespace
+
+TEST_F(SerializationTest, LongTypeStructIsSplit) {
+  OpBuilder builder(module->getRegion());
+  Type i32Type = builder.getIntegerType(32);
+  Type f32Type = builder.getF32Type();
+  SmallVector<Type> memberTypes;
+  memberTypes.reserve(kLongCompositeSize);
+  for (unsigned i = 0; i < kLongCompositeSize; ++i)
+    memberTypes.push_back((i & 1) ? f32Type : i32Type);
+  SmallVector<spirv::StructType::OffsetInfo> offsets(kLongCompositeSize, 0);
+  auto structType = spirv::StructType::get(memberTypes, offsets);
+  addGlobalVar(structType, "var0");
+  // Serialize: the struct must come out as a head plus continuation(s).
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+  EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+  EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpTypeStruct));
+  EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpTypeStructContinuedINTEL));
+  EXPECT_TRUE(hasLongCompositesCapabilityAndExtension(binary));
+  // Round-trip: the merged struct must keep every member type, in order.
+  MLIRContext freshContext;
+  freshContext.getOrLoadDialect<spirv::SPIRVDialect>();
+  OwningOpRef<spirv::ModuleOp> roundTripped =
+      spirv::deserialize(binary, &freshContext);
+  ASSERT_TRUE(roundTripped);
+  bool foundStruct = false;
+  roundTripped->walk([&](spirv::GlobalVariableOp gv) {
+    auto ptrType = dyn_cast<spirv::PointerType>(gv.getType());
+    if (!ptrType)
+      return;
+    auto rtStruct = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
+    if (!rtStruct)
+      return;
+    ASSERT_EQ(rtStruct.getNumElements(), kLongCompositeSize);
+    bool typesMatch = true;
+    for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+      Type expected = (i & 1) ? Type(Float32Type::get(&freshContext))
+                              : Type(IntegerType::get(&freshContext, 32));
+      if (rtStruct.getElementType(i) != expected) {
+        typesMatch = false;
+        break;
+      }
+    }
+    EXPECT_TRUE(typesMatch);
+    foundStruct = true;
+  });
+  EXPECT_TRUE(foundStruct);
+}
+
+TEST_F(SerializationTest, LongConstantCompositeIsSplit) {
+  OpBuilder builder(module->getRegion());
+  Location loc = UnknownLoc::get(&context);
+  Type i32Type = builder.getIntegerType(32);
+  auto arrayType = spirv::ArrayType::get(i32Type, kLongCompositeSize);
+  auto funcType = builder.getFunctionType({}, {arrayType});
+  // Function returning a kLongCompositeSize-element array constant.
+  auto funcOp = spirv::FuncOp::create(builder, loc, "long_array_const",
+                                      funcType, spirv::FunctionControl::None);
+  Block *entry = funcOp.addEntryBlock();
+  OpBuilder bodyBuilder = OpBuilder::atBlockBegin(entry);
+  SmallVector<Attribute> elements;
+  elements.reserve(kLongCompositeSize);
+  for (unsigned i = 0; i < kLongCompositeSize; ++i)
+    elements.push_back(bodyBuilder.getI32IntegerAttr(i & 0xff));
+  auto arrayAttr = bodyBuilder.getArrayAttr(elements);
+  auto cst = spirv::ConstantOp::create(bodyBuilder, loc, arrayType, arrayAttr);
+  spirv::ReturnValueOp::create(bodyBuilder, loc, cst.getResult());
+  // Serialize: the constant must come out as a head plus continuation(s).
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+  EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+  EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpConstantComposite));
+  EXPECT_TRUE(
+      hasOpcode(binary, spirv::Opcode::OpConstantCompositeContinuedINTEL));
+  EXPECT_TRUE(hasLongCompositesCapabilityAndExtension(binary));
+  // Round-trip: the merged constant must hold every element, in order.
+  MLIRContext freshContext;
+  freshContext.getOrLoadDialect<spirv::SPIRVDialect>();
+  OwningOpRef<spirv::ModuleOp> roundTripped =
+      spirv::deserialize(binary, &freshContext);
+  ASSERT_TRUE(roundTripped);
+  bool foundConst = false;
+  roundTripped->walk([&](spirv::ConstantOp op) {
+    auto arr = dyn_cast<ArrayAttr>(op.getValue());
+    if (!arr)
+      return;
+    ASSERT_EQ(arr.size(), kLongCompositeSize);
+    bool valuesMatch = true;
+    for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+      auto intAttr = dyn_cast<IntegerAttr>(arr[i]);
+      if (!intAttr || intAttr.getInt() != static_cast<int64_t>(i & 0xff)) {
+        valuesMatch = false;
+        break;
+      }
+    }
+    EXPECT_TRUE(valuesMatch);
+    foundConst = true;
+  });
+  EXPECT_TRUE(foundConst);
+}
+
+TEST_F(SerializationTest, LongSpecConstantCompositeIsSplit) {
+  OpBuilder builder(module->getRegion());
+  Location loc = UnknownLoc::get(&context);
+  Type i32Type = builder.getIntegerType(32);
+  auto arrayType = spirv::ArrayType::get(i32Type, kLongCompositeSize);
+  // One spec constant per element, all referenced by a single composite.
+  SmallVector<Attribute> constituents;
+  constituents.reserve(kLongCompositeSize);
+  for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+    std::string name = ("sc" + Twine(i)).str();
+    auto sc =
+        spirv::SpecConstantOp::create(builder, loc, builder.getStringAttr(name),
+                                      builder.getI32IntegerAttr(0));
+    constituents.push_back(SymbolRefAttr::get(sc));
+  }
+  spirv::SpecConstantCompositeOp::create(builder, loc, TypeAttr::get(arrayType),
+                                         builder.getStringAttr("long_scc"),
+                                         builder.getArrayAttr(constituents));
+  // Serialize: the composite must come out as a head plus continuation(s).
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+  EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+  EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpSpecConstantComposite));
+  EXPECT_TRUE(
+      hasOpcode(binary, spirv::Opcode::OpSpecConstantCompositeContinuedINTEL));
+  EXPECT_TRUE(hasLongCompositesCapabilityAndExtension(binary));
+  // Round-trip: constituents must reference sc0, sc1, ... in order.
+  MLIRContext freshContext;
+  freshContext.getOrLoadDialect<spirv::SPIRVDialect>();
+  OwningOpRef<spirv::ModuleOp> roundTripped =
+      spirv::deserialize(binary, &freshContext);
+  ASSERT_TRUE(roundTripped);
+  bool foundSCC = false;
+  roundTripped->walk([&](spirv::SpecConstantCompositeOp op) {
+    ArrayAttr rtConstituents = op.getConstituents();
+    ASSERT_EQ(rtConstituents.size(), kLongCompositeSize);
+    bool namesMatch = true;
+    for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+      auto symRef = dyn_cast<SymbolRefAttr>(rtConstituents[i]);
+      if (!symRef ||
+          symRef.getLeafReference().getValue() != ("sc" + Twine(i)).str()) {
+        namesMatch = false;
+        break;
+      }
+    }
+    EXPECT_TRUE(namesMatch);
+    foundSCC = true;
+  });
+  EXPECT_TRUE(foundSCC);
+}
+
+TEST_F(SerializationTest, LongCompositeConstructIsSplit) {
+  OpBuilder builder(module->getRegion());
+  Location loc = UnknownLoc::get(&context);
+  Type i32Type = builder.getIntegerType(32);
+  auto arrayType = spirv::ArrayType::get(i32Type, kLongCompositeSize);
+  auto funcType = builder.getFunctionType({}, {arrayType});
+  // Function building the array from kLongCompositeSize constant values.
+  auto funcOp = spirv::FuncOp::create(builder, loc, "long_composite_construct",
+                                      funcType, spirv::FunctionControl::None);
+  Block *entry = funcOp.addEntryBlock();
+  OpBuilder bodyBuilder = OpBuilder::atBlockBegin(entry);
+  SmallVector<Value> constituents;
+  constituents.reserve(kLongCompositeSize);
+  for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+    auto cst = spirv::ConstantOp::create(
+        bodyBuilder, loc, i32Type, bodyBuilder.getI32IntegerAttr(i & 0xff));
+    constituents.push_back(cst.getResult());
+  }
+  auto cc = spirv::CompositeConstructOp::create(bodyBuilder, loc, arrayType,
+                                                constituents);
+  spirv::ReturnValueOp::create(bodyBuilder, loc, cc.getResult());
+  // Serialize: the op must come out as a head plus continuation(s).
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+  EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+  EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpCompositeConstruct));
+  EXPECT_TRUE(
+      hasOpcode(binary, spirv::Opcode::OpCompositeConstructContinuedINTEL));
+  EXPECT_TRUE(hasLongCompositesCapabilityAndExtension(binary));
+  // Round-trip: each constituent must be a constant with the original value.
+  MLIRContext freshContext;
+  freshContext.getOrLoadDialect<spirv::SPIRVDialect>();
+  OwningOpRef<spirv::ModuleOp> roundTripped =
+      spirv::deserialize(binary, &freshContext);
+  ASSERT_TRUE(roundTripped);
+  bool foundCC = false;
+  roundTripped->walk([&](spirv::CompositeConstructOp op) {
+    auto rtConstituents = op.getConstituents();
+    ASSERT_EQ(rtConstituents.size(), kLongCompositeSize);
+    bool valuesMatch = true;
+    for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+      auto definingCst = rtConstituents[i].getDefiningOp<spirv::ConstantOp>();
+      if (!definingCst) {
+        valuesMatch = false;
+        break;
+      }
+      auto intAttr = dyn_cast<IntegerAttr>(definingCst.getValue());
+      if (!intAttr || intAttr.getInt() != static_cast<int64_t>(i & 0xff)) {
+        valuesMatch = false;
+        break;
+      }
+    }
+    EXPECT_TRUE(valuesMatch);
+    foundCC = true;
+  });
+  EXPECT_TRUE(foundCC);
+}
+
+namespace {
+// Counts the instructions in `binary` whose opcode is `target`; scanning stops
+// at the first zero or overrunning word count.
+unsigned countOpcode(SmallVectorImpl<uint32_t> &binary, spirv::Opcode target) {
+  unsigned matches = 0;
+  const size_t end = binary.size();
+  for (size_t pos = spirv::kHeaderWordCount; pos < end;) {
+    const uint32_t header = binary[pos];
+    const uint32_t length = header >> 16;
+    if (length == 0 || pos + length > end)
+      break;
+    matches += static_cast<spirv::Opcode>(header & 0xffff) == target;
+    pos += length;
+  }
+  return matches;
+}
+} // namespace
+
+TEST_F(SerializationTest, LongCompositeDoesNotDuplicateDeclaredCapability) {
+  // Pre-declare LongCompositesINTEL / SPV_INTEL_long_composites in the VCE
+  // triple. The serializer must not emit a second OpCapability/OpExtension
+  // when a long composite triggers `addLongCompositesCapability()`.
+  module->getOperation()->setAttr(
+      spirv::ModuleOp::getVCETripleAttrName(),
+      spirv::VerCapExtAttr::get(
+          spirv::Version::V_1_0, {spirv::Capability::LongCompositesINTEL},
+          {spirv::Extension::SPV_INTEL_long_composites}, &context));
+  // An oversized all-i32 struct forces the serializer to split.
+  OpBuilder builder(module->getRegion());
+  Type i32Type = builder.getIntegerType(32);
+  SmallVector<Type> memberTypes(kLongCompositeSize, i32Type);
+  SmallVector<spirv::StructType::OffsetInfo> offsets(kLongCompositeSize, 0);
+  auto structType = spirv::StructType::get(memberTypes, offsets);
+  addGlobalVar(structType, "var0");
+  // Exactly one of each: the pre-declared ones, with no lazy duplicate.
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+  EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+  EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpTypeStructContinuedINTEL));
+  EXPECT_EQ(countOpcode(binary, spirv::Opcode::OpCapability), 1u);
+  EXPECT_EQ(countOpcode(binary, spirv::Opcode::OpExtension), 1u);
+}