[mlir][SPIR-V] Add support for SPV_INTEL_long_composites extension (#195685)
Add serialization and deserialization support for the
SPV_INTEL_long_composites extension, which allows splitting
composite/struct instructions that exceed the SPIR-V 16-bit word count
limit (65535 words) into a head instruction followed by one or more
continuation instructions
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index a9a8947..fcc4be4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4686,6 +4686,10 @@
def SPIRV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
def SPIRV_OC_OpExpectKHR : I32EnumAttrCase<"OpExpectKHR", 5631>;
def SPIRV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>;
+def SPIRV_OC_OpTypeStructContinuedINTEL : I32EnumAttrCase<"OpTypeStructContinuedINTEL", 6090>;
+def SPIRV_OC_OpConstantCompositeContinuedINTEL : I32EnumAttrCase<"OpConstantCompositeContinuedINTEL", 6091>;
+def SPIRV_OC_OpSpecConstantCompositeContinuedINTEL : I32EnumAttrCase<"OpSpecConstantCompositeContinuedINTEL", 6092>;
+def SPIRV_OC_OpCompositeConstructContinuedINTEL : I32EnumAttrCase<"OpCompositeConstructContinuedINTEL", 6096>;
def SPIRV_OC_OpConvertFToBF16INTEL : I32EnumAttrCase<"OpConvertFToBF16INTEL", 6116>;
def SPIRV_OC_OpConvertBF16ToFINTEL : I32EnumAttrCase<"OpConvertBF16ToFINTEL", 6117>;
def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrierArriveINTEL", 6142>;
@@ -4810,6 +4814,10 @@
SPIRV_OC_OpEmitMeshTasksEXT, SPIRV_OC_OpSetMeshOutputsEXT,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpExpectKHR, SPIRV_OC_OpAtomicFAddEXT,
+ SPIRV_OC_OpTypeStructContinuedINTEL,
+ SPIRV_OC_OpConstantCompositeContinuedINTEL,
+ SPIRV_OC_OpSpecConstantCompositeContinuedINTEL,
+ SPIRV_OC_OpCompositeConstructContinuedINTEL,
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
diff --git a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
index 4a41163..34aa774 100644
--- a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
+++ b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h
@@ -17,6 +17,7 @@
#include "mlir/Support/LLVM.h"
#include <cstdint>
+#include <optional>
namespace mlir {
namespace spirv {
@@ -58,6 +59,25 @@
return str;
}
+/// Returns the SPV_INTEL_long_composites continuation opcode that may follow
+/// `parent`, or std::nullopt if `parent` is not a splittable composite/struct
+/// op.
+inline std::optional<spirv::Opcode>
+getContinuationOpcode(spirv::Opcode parent) {
+ switch (parent) {
+ case spirv::Opcode::OpTypeStruct:
+ return spirv::Opcode::OpTypeStructContinuedINTEL;
+ case spirv::Opcode::OpConstantComposite:
+ return spirv::Opcode::OpConstantCompositeContinuedINTEL;
+ case spirv::Opcode::OpSpecConstantComposite:
+ return spirv::Opcode::OpSpecConstantCompositeContinuedINTEL;
+ case spirv::Opcode::OpCompositeConstruct:
+ return spirv::Opcode::OpCompositeConstructContinuedINTEL;
+ default:
+ return std::nullopt;
+ }
+}
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 21a8400..f65b559 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -125,11 +125,45 @@
return success();
}
+void spirv::Deserializer::mergeLongCompositeContinuations(
+ spirv::Opcode opcode, ArrayRef<uint32_t> &operands,
+ SmallVectorImpl<uint32_t> &mergedStorage) {
+ std::optional<spirv::Opcode> continuationOp = getContinuationOpcode(opcode);
+ if (!continuationOp)
+ return;
+
+ size_t binarySize = binary.size();
+ auto isNextContinuation = [&]() {
+ if (curOffset >= binarySize)
+ return false;
+ uint32_t wordCount = binary[curOffset] >> 16;
+ if (wordCount == 0 || curOffset + wordCount > binarySize)
+ return false;
+ return extractOpcode(binary[curOffset]) == *continuationOp;
+ };
+
+ if (!isNextContinuation())
+ return;
+
+ mergedStorage.assign(operands);
+ do {
+ spirv::Opcode contOpcode;
+ ArrayRef<uint32_t> contOperands;
+ if (failed(sliceInstruction(contOpcode, contOperands, *continuationOp)))
+ return;
+ llvm::append_range(mergedStorage, contOperands);
+ } while (isNextContinuation());
+ operands = mergedStorage;
+}
+
LogicalResult spirv::Deserializer::processInstruction(
spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
<< spirv::stringifyOpcode(opcode) << "\n");
+ SmallVector<uint32_t, 0> mergedStorage;
+ mergeLongCompositeContinuations(opcode, operands, mergedStorage);
+
// First dispatch all the instructions whose opcode does not correspond to
// those that have a direct mirror in the SPIR-V dialect
switch (opcode) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index f3dab42..b2adbb5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -518,6 +518,15 @@
sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
std::optional<spirv::Opcode> expectedOpcode = std::nullopt);
+ /// If `opcode` is a SPV_INTEL_long_composites splittable opcode and the
+ /// next binary instruction(s) are matching `*ContinuedINTEL` ops, consumes
+ /// them and rebinds `operands` to a buffer (held in `mergedStorage`)
+ /// containing the parent + continuation operands concatenated.
+ void
+ mergeLongCompositeContinuations(spirv::Opcode opcode,
+ ArrayRef<uint32_t> &operands,
+ SmallVectorImpl<uint32_t> &mergedStorage);
+
/// Processes a SPIR-V instruction with the given `opcode` and `operands`.
/// This method is the main entrance for handling SPIR-V instruction; it
/// checks the instruction opcode and dispatches to the corresponding handler.
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index a2c942d..841fc55 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -121,8 +121,8 @@
operands.push_back(constituentID);
}
- encodeInstructionInto(typesGlobalValues,
- spirv::Opcode::OpSpecConstantComposite, operands);
+ encodeInstructionWithContinuationInto(
+ typesGlobalValues, spirv::Opcode::OpSpecConstantComposite, operands);
specConstIDMap[op.getSymName()] = resultID;
return processName(resultID, op.getSymName());
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 5c9e378..7a2eaf3 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -213,6 +213,48 @@
{static_cast<uint32_t>(cap)});
}
+void Serializer::addLongCompositesCapability() {
+ if (longCompositesEmitted)
+ return;
+ longCompositesEmitted = true;
+ auto vceTriple = module.getVceTriple();
+ if (!llvm::is_contained(vceTriple->getCapabilities(),
+ spirv::Capability::LongCompositesINTEL))
+ encodeInstructionInto(
+ capabilities, spirv::Opcode::OpCapability,
+ {static_cast<uint32_t>(spirv::Capability::LongCompositesINTEL)});
+ if (!llvm::is_contained(vceTriple->getExtensions(),
+ spirv::Extension::SPV_INTEL_long_composites)) {
+ SmallVector<uint32_t, 8> extName;
+ spirv::encodeStringLiteralInto(
+ extName,
+ spirv::stringifyExtension(spirv::Extension::SPV_INTEL_long_composites));
+ encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
+ }
+}
+
+void Serializer::encodeInstructionWithContinuationInto(
+ SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
+ ArrayRef<uint32_t> operands) {
+ if (1 + operands.size() <= spirv::kMaxWordCount) {
+ encodeInstructionInto(binary, op, operands);
+ return;
+ }
+
+ std::optional<spirv::Opcode> continuationOp =
+ spirv::getContinuationOpcode(op);
+ assert(continuationOp && "op is not a splittable composite/struct opcode");
+
+ const unsigned chunk = spirv::kMaxWordCount - 1;
+ encodeInstructionInto(binary, op, operands.take_front(chunk));
+ for (ArrayRef<uint32_t> rest = operands.drop_front(chunk); !rest.empty();
+ rest = rest.drop_front(std::min<size_t>(rest.size(), chunk))) {
+ encodeInstructionInto(binary, *continuationOp, rest.take_front(chunk));
+ }
+
+ addLongCompositesCapability();
+}
+
void Serializer::processDebugInfo() {
if (!options.emitDebugInfo)
return;
@@ -560,7 +602,11 @@
typeIDMap[type] = typeID;
- encodeInstructionInto(typesGlobalValues, typeEnum, operands);
+ if (typeEnum == spirv::Opcode::OpTypeStruct)
+ encodeInstructionWithContinuationInto(typesGlobalValues, typeEnum,
+ operands);
+ else
+ encodeInstructionInto(typesGlobalValues, typeEnum, operands);
if (recursiveStructInfos.count(type) != 0) {
// This recursive struct type is emitted already, now the OpTypePointer
@@ -1024,8 +1070,8 @@
return 0;
}
}
- spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
- encodeInstructionInto(typesGlobalValues, opcode, operands);
+ encodeInstructionWithContinuationInto(
+ typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
return resultID;
}
@@ -1104,8 +1150,8 @@
}
}
}
- spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
- encodeInstructionInto(typesGlobalValues, opcode, operands);
+ encodeInstructionWithContinuationInto(
+ typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
return resultID;
}
@@ -1605,6 +1651,9 @@
return processBranchConditionalOp(op);
})
.Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
+ .Case([&](spirv::CompositeConstructOp op) {
+ return processCompositeConstructOp(op);
+ })
.Case([&](spirv::EXTConstantCompositeReplicateOp op) {
return processConstantCompositeReplicateOp(op);
})
@@ -1645,6 +1694,41 @@
[&](Operation *op) { return dispatchToAutogenSerialization(op); });
}
+LogicalResult
+Serializer::processCompositeConstructOp(spirv::CompositeConstructOp op) {
+ Location loc = op.getLoc();
+
+ uint32_t resultTypeID = 0;
+ if (failed(processType(loc, op.getType(), resultTypeID)))
+ return failure();
+
+ uint32_t resultID = getNextID();
+ valueIDMap[op.getResult()] = resultID;
+
+ SmallVector<uint32_t, 8> operands;
+ operands.reserve(2 + op.getConstituents().size());
+ operands.push_back(resultTypeID);
+ operands.push_back(resultID);
+ for (Value constituent : op.getConstituents()) {
+ uint32_t id = getValueID(constituent);
+ assert(id && "use before def!");
+ operands.push_back(id);
+ }
+
+ if (failed(emitDebugLine(functionBody, loc)))
+ return failure();
+
+ encodeInstructionWithContinuationInto(
+ functionBody, spirv::Opcode::OpCompositeConstruct, operands);
+
+ for (auto attr : op->getAttrs()) {
+ if (failed(processDecoration(loc, resultID, attr)))
+ return failure();
+ }
+
+ return success();
+}
+
LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
StringRef extInstSet,
uint32_t opcode) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index eb5ac0d..e43556f 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -104,10 +104,23 @@
LogicalResult processExtension();
+ /// Encodes `op` + `operands` into `binary`, splitting via the
+ /// SPV_INTEL_long_composites continuation opcode when the total word count
+ /// would exceed kMaxWordCount. `op` must be a splittable composite/struct
+ /// opcode (see getContinuationOpcode). The capability and extension are
+ /// emitted lazily on first split.
+ void encodeInstructionWithContinuationInto(SmallVectorImpl<uint32_t> &binary,
+ spirv::Opcode op,
+ ArrayRef<uint32_t> operands);
+
+ void addLongCompositesCapability();
+
void processMemoryModel();
LogicalResult processConstantOp(spirv::ConstantOp op);
+ LogicalResult processCompositeConstructOp(spirv::CompositeConstructOp op);
+
LogicalResult processConstantCompositeReplicateOp(
spirv::EXTConstantCompositeReplicateOp op);
@@ -387,6 +400,8 @@
/// The next available result <id>.
uint32_t nextID = 1;
+ bool longCompositesEmitted = false;
+
// The following are for different SPIR-V instruction sections. They follow
// the logical layout of a SPIR-V module.
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index af55296..b0413a8 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/Target/SPIRV/Deserialization.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
@@ -229,3 +230,323 @@
};
EXPECT_FALSE(scanInstruction(hasVarName));
}
+
+//===----------------------------------------------------------------------===//
+// SPV_INTEL_long_composites: composites whose binary form would exceed the
+// SPIR-V 16-bit word-count limit are split into a parent + *ContinuedINTEL ops
+// on serialization, and merged back on deserialization. These tests build the
+// large composites programmatically so that the IR doesn't have to expand
+// thousands of operands literally.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Picked to comfortably exceed kMaxWordCount = 65535 for any of the splittable
+// composite/struct opcodes -- each one packs at most kMaxWordCount - {1,2,3}
+// operands into the parent word, so 65540 always triggers a split.
+constexpr unsigned kLongCompositeSize = 65540;
+
+bool hasOpcode(SmallVectorImpl<uint32_t> &binary, spirv::Opcode target) {
+ size_t offset = spirv::kHeaderWordCount;
+ while (offset < binary.size()) {
+ uint32_t wordCount = binary[offset] >> 16;
+ if (!wordCount || offset + wordCount > binary.size())
+ return false;
+ auto op = static_cast<spirv::Opcode>(binary[offset] & 0xffff);
+ if (op == target)
+ return true;
+ offset += wordCount;
+ }
+ return false;
+}
+
+bool hasLongCompositesCapabilityAndExtension(
+ SmallVectorImpl<uint32_t> &binary) {
+ bool foundCap = false;
+ bool foundExt = false;
+ size_t offset = spirv::kHeaderWordCount;
+ size_t binarySize = binary.size();
+ while (offset < binarySize) {
+ uint32_t wordCount = binary[offset] >> 16;
+ if (!wordCount || offset + wordCount > binarySize)
+ break;
+ auto op = static_cast<spirv::Opcode>(binary[offset] & 0xffff);
+ ArrayRef<uint32_t> operands(binary.data() + offset + 1, wordCount - 1);
+ if (op == spirv::Opcode::OpCapability && !operands.empty() &&
+ operands[0] ==
+ static_cast<uint32_t>(spirv::Capability::LongCompositesINTEL))
+ foundCap = true;
+ if (op == spirv::Opcode::OpExtension) {
+ unsigned idx = 0;
+ if (spirv::decodeStringLiteral(operands, idx) ==
+ spirv::stringifyExtension(
+ spirv::Extension::SPV_INTEL_long_composites))
+ foundExt = true;
+ }
+ offset += wordCount;
+ }
+ return foundCap && foundExt;
+}
+
+// Verifies that no instruction in the binary has a word count exceeding the
+// SPIR-V 16-bit limit (which would mean the splitting logic failed).
+bool allInstructionsWithinWordLimit(SmallVectorImpl<uint32_t> &binary) {
+ size_t offset = spirv::kHeaderWordCount;
+ size_t binarySize = binary.size();
+ while (offset < binarySize) {
+ uint32_t wordCount = binary[offset] >> 16;
+ if (!wordCount || wordCount > spirv::kMaxWordCount)
+ return false;
+ offset += wordCount;
+ }
+ return true;
+}
+
+} // namespace
+
+TEST_F(SerializationTest, LongTypeStructIsSplit) {
+ OpBuilder builder(module->getRegion());
+ Type i32Type = builder.getIntegerType(32);
+ Type f32Type = builder.getF32Type();
+ SmallVector<Type> memberTypes;
+ memberTypes.reserve(kLongCompositeSize);
+ for (unsigned i = 0; i < kLongCompositeSize; ++i)
+ memberTypes.push_back((i & 1) ? f32Type : i32Type);
+ SmallVector<spirv::StructType::OffsetInfo> offsets(kLongCompositeSize, 0);
+ auto structType = spirv::StructType::get(memberTypes, offsets);
+ addGlobalVar(structType, "var0");
+
+ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+ EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+ EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpTypeStruct));
+ EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpTypeStructContinuedINTEL));
+ EXPECT_TRUE(hasLongCompositesCapabilityAndExtension(binary));
+
+ MLIRContext freshContext;
+ freshContext.getOrLoadDialect<spirv::SPIRVDialect>();
+ OwningOpRef<spirv::ModuleOp> roundTripped =
+ spirv::deserialize(binary, &freshContext);
+ ASSERT_TRUE(roundTripped);
+ bool foundStruct = false;
+ roundTripped->walk([&](spirv::GlobalVariableOp gv) {
+ auto ptrType = dyn_cast<spirv::PointerType>(gv.getType());
+ if (!ptrType)
+ return;
+ auto rtStruct = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
+ if (!rtStruct)
+ return;
+ ASSERT_EQ(rtStruct.getNumElements(), kLongCompositeSize);
+ bool typesMatch = true;
+ for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+ Type expected = (i & 1) ? Type(Float32Type::get(&freshContext))
+ : Type(IntegerType::get(&freshContext, 32));
+ if (rtStruct.getElementType(i) != expected) {
+ typesMatch = false;
+ break;
+ }
+ }
+ EXPECT_TRUE(typesMatch);
+ foundStruct = true;
+ });
+ EXPECT_TRUE(foundStruct);
+}
+
+TEST_F(SerializationTest, LongConstantCompositeIsSplit) {
+ OpBuilder builder(module->getRegion());
+ Location loc = UnknownLoc::get(&context);
+ Type i32Type = builder.getIntegerType(32);
+ auto arrayType = spirv::ArrayType::get(i32Type, kLongCompositeSize);
+ auto funcType = builder.getFunctionType({}, {arrayType});
+
+ auto funcOp = spirv::FuncOp::create(builder, loc, "long_array_const",
+ funcType, spirv::FunctionControl::None);
+ Block *entry = funcOp.addEntryBlock();
+ OpBuilder bodyBuilder = OpBuilder::atBlockBegin(entry);
+ SmallVector<Attribute> elements;
+ elements.reserve(kLongCompositeSize);
+ for (unsigned i = 0; i < kLongCompositeSize; ++i)
+ elements.push_back(bodyBuilder.getI32IntegerAttr(i & 0xff));
+ auto arrayAttr = bodyBuilder.getArrayAttr(elements);
+ auto cst = spirv::ConstantOp::create(bodyBuilder, loc, arrayType, arrayAttr);
+ spirv::ReturnValueOp::create(bodyBuilder, loc, cst.getResult());
+
+ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+ EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+ EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpConstantComposite));
+ EXPECT_TRUE(
+ hasOpcode(binary, spirv::Opcode::OpConstantCompositeContinuedINTEL));
+ EXPECT_TRUE(hasLongCompositesCapabilityAndExtension(binary));
+
+ MLIRContext freshContext;
+ freshContext.getOrLoadDialect<spirv::SPIRVDialect>();
+ OwningOpRef<spirv::ModuleOp> roundTripped =
+ spirv::deserialize(binary, &freshContext);
+ ASSERT_TRUE(roundTripped);
+ bool foundConst = false;
+ roundTripped->walk([&](spirv::ConstantOp op) {
+ auto arr = dyn_cast<ArrayAttr>(op.getValue());
+ if (!arr)
+ return;
+ ASSERT_EQ(arr.size(), kLongCompositeSize);
+ bool valuesMatch = true;
+ for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+ auto intAttr = dyn_cast<IntegerAttr>(arr[i]);
+ if (!intAttr || intAttr.getInt() != static_cast<int64_t>(i & 0xff)) {
+ valuesMatch = false;
+ break;
+ }
+ }
+ EXPECT_TRUE(valuesMatch);
+ foundConst = true;
+ });
+ EXPECT_TRUE(foundConst);
+}
+
+TEST_F(SerializationTest, LongSpecConstantCompositeIsSplit) {
+ OpBuilder builder(module->getRegion());
+ Location loc = UnknownLoc::get(&context);
+ Type i32Type = builder.getIntegerType(32);
+ auto arrayType = spirv::ArrayType::get(i32Type, kLongCompositeSize);
+
+ SmallVector<Attribute> constituents;
+ constituents.reserve(kLongCompositeSize);
+ for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+ std::string name = ("sc" + Twine(i)).str();
+ auto sc =
+ spirv::SpecConstantOp::create(builder, loc, builder.getStringAttr(name),
+ builder.getI32IntegerAttr(0));
+ constituents.push_back(SymbolRefAttr::get(sc));
+ }
+ spirv::SpecConstantCompositeOp::create(builder, loc, TypeAttr::get(arrayType),
+ builder.getStringAttr("long_scc"),
+ builder.getArrayAttr(constituents));
+
+ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+ EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+ EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpSpecConstantComposite));
+ EXPECT_TRUE(
+ hasOpcode(binary, spirv::Opcode::OpSpecConstantCompositeContinuedINTEL));
+ EXPECT_TRUE(hasLongCompositesCapabilityAndExtension(binary));
+
+ MLIRContext freshContext;
+ freshContext.getOrLoadDialect<spirv::SPIRVDialect>();
+ OwningOpRef<spirv::ModuleOp> roundTripped =
+ spirv::deserialize(binary, &freshContext);
+ ASSERT_TRUE(roundTripped);
+ bool foundSCC = false;
+ roundTripped->walk([&](spirv::SpecConstantCompositeOp op) {
+ ArrayAttr rtConstituents = op.getConstituents();
+ ASSERT_EQ(rtConstituents.size(), kLongCompositeSize);
+ bool namesMatch = true;
+ for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+ auto symRef = dyn_cast<SymbolRefAttr>(rtConstituents[i]);
+ if (!symRef ||
+ symRef.getLeafReference().getValue() != ("sc" + Twine(i)).str()) {
+ namesMatch = false;
+ break;
+ }
+ }
+ EXPECT_TRUE(namesMatch);
+ foundSCC = true;
+ });
+ EXPECT_TRUE(foundSCC);
+}
+
+TEST_F(SerializationTest, LongCompositeConstructIsSplit) {
+ OpBuilder builder(module->getRegion());
+ Location loc = UnknownLoc::get(&context);
+ Type i32Type = builder.getIntegerType(32);
+ auto arrayType = spirv::ArrayType::get(i32Type, kLongCompositeSize);
+ auto funcType = builder.getFunctionType({}, {arrayType});
+
+ auto funcOp = spirv::FuncOp::create(builder, loc, "long_composite_construct",
+ funcType, spirv::FunctionControl::None);
+ Block *entry = funcOp.addEntryBlock();
+ OpBuilder bodyBuilder = OpBuilder::atBlockBegin(entry);
+ SmallVector<Value> constituents;
+ constituents.reserve(kLongCompositeSize);
+ for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+ auto cst = spirv::ConstantOp::create(
+ bodyBuilder, loc, i32Type, bodyBuilder.getI32IntegerAttr(i & 0xff));
+ constituents.push_back(cst.getResult());
+ }
+ auto cc = spirv::CompositeConstructOp::create(bodyBuilder, loc, arrayType,
+ constituents);
+ spirv::ReturnValueOp::create(bodyBuilder, loc, cc.getResult());
+
+ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+ EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+ EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpCompositeConstruct));
+ EXPECT_TRUE(
+ hasOpcode(binary, spirv::Opcode::OpCompositeConstructContinuedINTEL));
+ EXPECT_TRUE(hasLongCompositesCapabilityAndExtension(binary));
+
+ MLIRContext freshContext;
+ freshContext.getOrLoadDialect<spirv::SPIRVDialect>();
+ OwningOpRef<spirv::ModuleOp> roundTripped =
+ spirv::deserialize(binary, &freshContext);
+ ASSERT_TRUE(roundTripped);
+ bool foundCC = false;
+ roundTripped->walk([&](spirv::CompositeConstructOp op) {
+ auto rtConstituents = op.getConstituents();
+ ASSERT_EQ(rtConstituents.size(), kLongCompositeSize);
+ bool valuesMatch = true;
+ for (unsigned i = 0; i < kLongCompositeSize; ++i) {
+ auto definingCst = rtConstituents[i].getDefiningOp<spirv::ConstantOp>();
+ if (!definingCst) {
+ valuesMatch = false;
+ break;
+ }
+ auto intAttr = dyn_cast<IntegerAttr>(definingCst.getValue());
+ if (!intAttr || intAttr.getInt() != static_cast<int64_t>(i & 0xff)) {
+ valuesMatch = false;
+ break;
+ }
+ }
+ EXPECT_TRUE(valuesMatch);
+ foundCC = true;
+ });
+ EXPECT_TRUE(foundCC);
+}
+
+namespace {
+unsigned countOpcode(SmallVectorImpl<uint32_t> &binary, spirv::Opcode target) {
+ unsigned count = 0;
+ size_t offset = spirv::kHeaderWordCount;
+ size_t binarySize = binary.size();
+ while (offset < binarySize) {
+ uint32_t wordCount = binary[offset] >> 16;
+ if (!wordCount || offset + wordCount > binarySize)
+ break;
+ auto op = static_cast<spirv::Opcode>(binary[offset] & 0xffff);
+ if (op == target)
+ ++count;
+ offset += wordCount;
+ }
+ return count;
+}
+} // namespace
+
+TEST_F(SerializationTest, LongCompositeDoesNotDuplicateDeclaredCapability) {
+ // Pre-declare LongCompositesINTEL / SPV_INTEL_long_composites in the VCE
+ // triple. The serializer must not emit a second OpCapability/OpExtension
+ // when a long composite triggers `addLongCompositesCapability()`.
+ module->getOperation()->setAttr(
+ spirv::ModuleOp::getVCETripleAttrName(),
+ spirv::VerCapExtAttr::get(
+ spirv::Version::V_1_0, {spirv::Capability::LongCompositesINTEL},
+ {spirv::Extension::SPV_INTEL_long_composites}, &context));
+
+ OpBuilder builder(module->getRegion());
+ Type i32Type = builder.getIntegerType(32);
+ SmallVector<Type> memberTypes(kLongCompositeSize, i32Type);
+ SmallVector<spirv::StructType::OffsetInfo> offsets(kLongCompositeSize, 0);
+ auto structType = spirv::StructType::get(memberTypes, offsets);
+ addGlobalVar(structType, "var0");
+
+ ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
+ EXPECT_TRUE(allInstructionsWithinWordLimit(binary));
+ EXPECT_TRUE(hasOpcode(binary, spirv::Opcode::OpTypeStructContinuedINTEL));
+ EXPECT_EQ(countOpcode(binary, spirv::Opcode::OpCapability), 1u);
+ EXPECT_EQ(countOpcode(binary, spirv::Opcode::OpExtension), 1u);
+}