[mlir][spirv] Add support for SPV_ARM_graph extension - part 2 (#156665)
This is the second patch to add support for the `SPV_ARM_graph` SPIR-V
extension to MLIR’s SPIR-V dialect. The extension introduces a new
`Graph` abstraction for expressing dataflow computations over full
resources.
The part 2 implementation includes:
- Serialization and deserialization support for:
- `OpGraphARM`, `OpGraphInputARM`, `OpGraphSetOutputARM`,
`OpGraphEndARM`
- `OpGraphEntryPointARM`, `OpGraphConstantARM`, `OpTypeGraphARM`
- Tests covering binary round-tripping.
Graphs currently support only `SPV_ARM_tensors`, but are designed to
generalize to other resource types, such as images.
Spec: https://github.com/KhronosGroup/SPIRV-Registry/pull/346
RFC:
https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947
---------
Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
GitOrigin-RevId: 1a746b6ca3862165360c48fff5d807d5b400b541
diff --git a/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index ee18cf8..c27f9aa 100644
--- a/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -86,6 +86,13 @@
if (auto undef = getUndefType(id)) {
return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
}
+ if (std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+ graphConstantARMInfo = getGraphConstantARM(id)) {
+ IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
+ Type resultType = graphConstantARMInfo->resultType;
+ return spirv::GraphConstantARMOp::create(opBuilder, unknownLoc, resultType,
+ graphConstantID);
+ }
return valueMap.lookup(id);
}
@@ -180,6 +187,7 @@
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
case spirv::Opcode::OpTypeTensorARM:
+ case spirv::Opcode::OpTypeGraphARM:
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
@@ -208,12 +216,26 @@
return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantNull:
return processConstantNull(operands);
+ case spirv::Opcode::OpGraphConstantARM:
+ return processGraphConstantARM(operands);
case spirv::Opcode::OpDecorate:
return processDecoration(operands);
case spirv::Opcode::OpMemberDecorate:
return processMemberDecoration(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
+ case spirv::Opcode::OpGraphEntryPointARM:
+ if (deferInstructions) {
+ deferredInstructions.emplace_back(opcode, operands);
+ return success();
+ }
+ return processGraphEntryPointARM(operands);
+ case spirv::Opcode::OpGraphARM:
+ return processGraphARM(operands);
+ case spirv::Opcode::OpGraphSetOutputARM:
+ return processOpGraphSetOutputARM(operands);
+ case spirv::Opcode::OpGraphEndARM:
+ return processGraphEndARM(operands);
case spirv::Opcode::OpLabel:
return processLabel(operands);
case spirv::Opcode::OpBranch:
diff --git a/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 3625dd2..0c3e87a 100644
--- a/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -669,6 +669,200 @@
return success();
}
+LogicalResult
+spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 2) {
+ return emitError(unknownLoc,
+ "missing graph defintion in OpGraphEntryPointARM");
+ }
+
+ unsigned wordIndex = 0;
+ uint32_t graphID = operands[wordIndex++];
+ if (!graphMap.contains(graphID)) {
+ return emitError(unknownLoc,
+ "missing graph definition/declaration with id ")
+ << graphID;
+ }
+
+ spirv::GraphARMOp graphARM = graphMap[graphID];
+ StringRef name = decodeStringLiteral(operands, wordIndex);
+ graphARM.setSymName(name);
+ graphARM.setEntryPoint(true);
+
+ SmallVector<Attribute, 4> interface;
+ for (int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
+ if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
+ interface.push_back(SymbolRefAttr::get(arg.getOperation()));
+ } else {
+ return emitError(unknownLoc, "undefined result <id> ")
+ << operands[wordIndex] << " while decoding OpGraphEntryPoint";
+ }
+ }
+
+ // RAII guard to reset the insertion point to previous value when done.
+ OpBuilder::InsertionGuard insertionGuard(opBuilder);
+ opBuilder.setInsertionPoint(graphARM);
+ opBuilder.create<spirv::GraphEntryPointARMOp>(
+ unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
+ opBuilder.getArrayAttr(interface));
+
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
+ if (curGraph) {
+ return emitError(unknownLoc, "found graph inside graph");
+ }
+ // Get the result type.
+ if (operands.size() < 2) {
+ return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
+ }
+
+ Type type = getType(operands[0]);
+ if (!type || !isa<GraphType>(type)) {
+ return emitError(unknownLoc, "unknown graph type from <id> ")
+ << operands[0];
+ }
+ auto graphType = cast<GraphType>(type);
+ if (graphType.getNumResults() <= 0) {
+ return emitError(unknownLoc, "expected at least one result");
+ }
+
+ uint32_t graphID = operands[1];
+ if (graphMap.count(graphID)) {
+ return emitError(unknownLoc, "duplicate graph definition/declaration");
+ }
+
+ std::string graphName = getGraphSymbol(graphID);
+ auto graphOp =
+ opBuilder.create<spirv::GraphARMOp>(unknownLoc, graphName, graphType);
+ curGraph = graphMap[graphID] = graphOp;
+ Block *entryBlock = graphOp.addEntryBlock();
+ LLVM_DEBUG({
+ logger.startLine()
+ << "//===-------------------------------------------===//\n";
+ logger.startLine() << "[graph] name: " << graphName << "\n";
+ logger.startLine() << "[graph] type: " << graphType << "\n";
+ logger.startLine() << "[graph] ID: " << graphID << "\n";
+ logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
+ logger.indent();
+ });
+
+ // Parse the op argument instructions.
+ for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
+ spirv::Opcode opcode;
+ ArrayRef<uint32_t> operands;
+ if (failed(sliceInstruction(opcode, operands,
+ spirv::Opcode::OpGraphInputARM))) {
+ return failure();
+ }
+ if (operands.size() != 3) {
+ return emitError(unknownLoc, "expected result type, result <id> and "
+ "input index for OpGraphInputARM");
+ }
+
+ Type argDefinedType = getType(operands[0]);
+ if (!argDefinedType) {
+ return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
+ }
+
+ if (argDefinedType != argType) {
+ return emitError(unknownLoc,
+ "mismatch in argument type between graph type "
+ "definition ")
+ << graphType << " and argument type definition " << argDefinedType
+ << " at argument " << index;
+ }
+ if (getValue(operands[1])) {
+ return emitError(unknownLoc, "duplicate definition of result <id> ")
+ << operands[1];
+ }
+
+ IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
+ if (!inputIndexAttr) {
+ return emitError(unknownLoc,
+ "unable to read inputIndex value from constant op ")
+ << operands[2];
+ }
+ BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
+ valueMap[operands[1]] = argValue;
+ }
+
+ graphOutputs.resize(graphType.getNumResults());
+
+ // RAII guard to reset the insertion point to the module's region after
+ // deserializing the body of this function.
+ OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
+
+ blockMap[graphID] = entryBlock;
+ if (failed(createGraphBlock(graphID))) {
+ return failure();
+ }
+
+ // Process all the instructions in the graph until and including
+ // OpGraphEndARM.
+ spirv::Opcode opcode;
+ ArrayRef<uint32_t> instOperands;
+ do {
+ if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
+ return failure();
+ }
+
+ if (failed(processInstruction(opcode, instOperands))) {
+ return failure();
+ }
+ } while (opcode != spirv::Opcode::OpGraphEndARM);
+
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2) {
+ return emitError(
+ unknownLoc,
+ "expected value id and output index for OpGraphSetOutputARM");
+ }
+
+ uint32_t id = operands[0];
+ Value value = getValue(id);
+ if (!value) {
+ return emitError(unknownLoc, "could not find result <id> ") << id;
+ }
+
+ IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
+ if (!outputIndexAttr) {
+ return emitError(unknownLoc,
+ "unable to read outputIndex value from constant op ")
+ << operands[1];
+ }
+ graphOutputs[outputIndexAttr.getInt()] = value;
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
+ // Create GraphOutputsARM instruction.
+ opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
+
+ // Process OpGraphEndARM.
+ if (!operands.empty()) {
+ return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
+ }
+
+ curBlock = nullptr;
+ curGraph = std::nullopt;
+ graphOutputs.clear();
+
+ LLVM_DEBUG({
+ logger.unindent();
+ logger.startLine()
+ << "//===-------------------------------------------===//\n";
+ });
+ return success();
+}
+
std::optional<std::pair<Attribute, Type>>
spirv::Deserializer::getConstant(uint32_t id) {
auto constIt = constantMap.find(id);
@@ -701,6 +895,14 @@
return funcName;
}
+std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
+ std::string graphName = nameMap.lookup(id).str();
+ if (graphName.empty()) {
+ graphName = "spirv_graph_" + std::to_string(id);
+ }
+ return graphName;
+}
+
std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
auto constName = nameMap.lookup(id).str();
if (constName.empty()) {
@@ -723,6 +925,14 @@
return op;
}
+std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+spirv::Deserializer::getGraphConstantARM(uint32_t id) {
+ auto graphConstIt = graphConstantMap.find(id);
+ if (graphConstIt == graphConstantMap.end())
+ return std::nullopt;
+ return graphConstIt->getSecond();
+}
+
LogicalResult
spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
unsigned wordIndex = 0;
@@ -944,6 +1154,8 @@
return processMatrixType(operands);
case spirv::Opcode::OpTypeTensorARM:
return processTensorARMType(operands);
+ case spirv::Opcode::OpTypeGraphARM:
+ return processGraphTypeARM(operands);
default:
return emitError(unknownLoc, "unhandled type instruction");
}
@@ -1312,6 +1524,35 @@
}
LogicalResult
+spirv::Deserializer::processGraphTypeARM(ArrayRef<uint32_t> operands) {
+ unsigned size = operands.size();
+ if (size < 2) {
+ return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
+ "(result_id, num_inputs, (inout0_type, "
+ "inout1_type, ...))")
+ << size;
+ }
+ uint32_t numInputs = operands[1];
+ SmallVector<Type, 1> argTypes;
+ SmallVector<Type, 1> returnTypes;
+ for (unsigned i = 2; i < size; ++i) {
+ Type inOutTy = getType(operands[i]);
+ if (!inOutTy) {
+ return emitError(unknownLoc,
+ "OpTypeGraphARM references undefined element type.")
+ << operands[i];
+ }
+ if (i - 2 >= numInputs) {
+ returnTypes.push_back(inOutTy);
+ } else {
+ argTypes.push_back(inOutTy);
+ }
+ }
+ typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
+ return success();
+}
+
+LogicalResult
spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
return emitError(unknownLoc,
@@ -1823,6 +2064,34 @@
<< resultType;
}
+LogicalResult
+spirv::Deserializer::processGraphConstantARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 3) {
+ return emitError(unknownLoc)
+ << "OpGraphConstantARM must have at least 2 operands";
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ uint32_t resultID = operands[1];
+
+ if (!dyn_cast<spirv::TensorArmType>(resultType)) {
+ return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
+ }
+
+ APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
+ Type i32Ty = opBuilder.getIntegerType(32);
+ IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
+ graphConstantMap.try_emplace(
+ resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
@@ -1920,6 +2189,24 @@
return success();
}
+LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
+ if (!curGraph) {
+ return emitError(unknownLoc, "a graph block must appear inside a graph");
+ }
+
+ // We may have forward declared this block.
+ Block *block = getOrCreateBlock(graphID);
+ LLVM_DEBUG(logger.startLine()
+ << "[block] populating block " << block << "\n");
+ // If we have seen this block, make sure it was just a forward declaration.
+ assert(block->empty() && "re-deserialize the same block!");
+
+ opBuilder.setInsertionPointToStart(block);
+ blockMap[graphID] = curBlock = block;
+
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
if (!curBlock) {
diff --git a/lib/Target/SPIRV/Deserialization/Deserializer.h b/lib/Target/SPIRV/Deserialization/Deserializer.h
index db1cc3f..6027f1a 100644
--- a/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -106,6 +106,13 @@
SmallVector<uint32_t> enclosedOpOperands;
};
+/// A struct that collects the info needed to materialize/emit a
+/// GraphConstantARMOp.
+struct GraphConstantARMOpMaterializationInfo {
+ Type resultType;
+ IntegerAttr graphConstantID;
+};
+
//===----------------------------------------------------------------------===//
// Deserializer Declaration
//===----------------------------------------------------------------------===//
@@ -211,9 +218,14 @@
/// exists; otherwise creates one based on the <id>.
std::string getFunctionSymbol(uint32_t id);
- /// Returns a symbol to be used for the specialization constant with the given
- /// result <id>. This tries to use the specialization constant's OpName if
+ /// Returns a symbol to be used for the graph name with the given
+ /// result <id>. This tries to use the graph's OpName if
/// exists; otherwise creates one based on the <id>.
+ std::string getGraphSymbol(uint32_t id);
+
+ /// Returns a symbol to be used for the specialization constant with the
+ /// given result <id>. This tries to use the specialization constant's
+ /// OpName if exists; otherwise creates one based on the <id>.
std::string getSpecConstantSymbol(uint32_t id);
/// Gets the specialization constant with the given result <id>.
@@ -237,6 +249,11 @@
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
TypedAttr defaultValue);
+ /// Gets the GraphConstantARM ID attribute and result type with the given
+ /// result <id>.
+ std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+ getGraphConstantARM(uint32_t id);
+
/// Processes the OpVariable instructions at current `offset` into `binary`.
/// It is expected that this method is used for variables that are to be
/// defined at module scope and will be deserialized into a
@@ -306,6 +323,16 @@
LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
+ LogicalResult processGraphTypeARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphEntryPointARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processOpGraphSetOutputARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphEndARM(ArrayRef<uint32_t> operands);
+
LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
@@ -353,6 +380,10 @@
/// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
+ /// Processes a SPIR-V OpGraphConstantARM instruction with the given
+ /// `operands`.
+ LogicalResult processGraphConstantARM(ArrayRef<uint32_t> operands);
+
//===--------------------------------------------------------------------===//
// Debug
//===--------------------------------------------------------------------===//
@@ -450,6 +481,9 @@
/// blocks declared as selection/loop headers are handled.
LogicalResult structurizeControlFlow();
+ /// Creates a block for graph with the given graphID.
+ LogicalResult createGraphBlock(uint32_t graphID);
+
//===--------------------------------------------------------------------===//
// Instruction
//===--------------------------------------------------------------------===//
@@ -546,6 +580,9 @@
/// The current function under construction.
std::optional<spirv::FuncOp> curFunction;
+ /// The current graph under construction.
+ std::optional<spirv::GraphARMOp> curGraph;
+
/// The current block under construction.
Block *curBlock = nullptr;
@@ -599,12 +636,19 @@
DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
specConstOperationMap;
+ // Result <id> to GraphConstantARM ID attribute and result type.
+ DenseMap<uint32_t, spirv::GraphConstantARMOpMaterializationInfo>
+ graphConstantMap;
+
// Result <id> to variable mapping.
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
// Result <id> to function mapping.
DenseMap<uint32_t, spirv::FuncOp> funcMap;
+ // Result <id> to function mapping.
+ DenseMap<uint32_t, spirv::GraphARMOp> graphMap;
+
// Result <id> to block mapping.
DenseMap<uint32_t, Block *> blockMap;
@@ -668,6 +712,9 @@
/// Deserialization options.
DeserializationOptions options;
+ /// List of IDs assigned to graph outputs.
+ SmallVector<Value> graphOutputs;
+
#ifndef NDEBUG
/// A logger used to emit information during the deserialzation process.
llvm::ScopedPrinter logger;
diff --git a/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index d62529b..e9b180a 100644
--- a/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -203,6 +203,16 @@
return success();
}
+LogicalResult
+Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) {
+ if (uint32_t resultID = prepareGraphConstantId(op.getLoc(), op.getType(),
+ op.getGraphConstantIdAttr())) {
+ valueIDMap[op.getResult()] = resultID;
+ return success();
+ }
+ return failure();
+}
+
LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
auto undefType = op.getType();
auto &id = undefValIDMap[undefType];
@@ -368,6 +378,118 @@
return success();
}
+LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
+ if (op.getNumResults() < 1) {
+ return op.emitError("cannot serialize graph with no return types");
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "-- start graph '" << op.getName() << "' --\n");
+ assert(functionHeader.empty() && functionBody.empty());
+
+ uint32_t funcID = getOrCreateFunctionID(op.getName());
+ uint32_t fnTypeID = 0;
+ // Generate type of the function.
+ if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
+ return failure();
+ encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphARM,
+ {fnTypeID, funcID});
+
+ // Declare the parameters.
+ for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
+ uint32_t argTypeID = 0;
+ SmallVector<uint32_t, 3> inputOperands;
+
+ if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
+ return failure();
+ }
+
+ uint32_t argValueID = getNextID();
+ valueIDMap[arg] = argValueID;
+
+ auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
+ uint32_t indexID = prepareConstantInt(op.getLoc(), attr, false);
+
+ inputOperands.push_back(argTypeID);
+ inputOperands.push_back(argValueID);
+ inputOperands.push_back(indexID);
+
+ encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphInputARM,
+ inputOperands);
+ }
+
+ if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
+ return failure();
+ if (failed(visitInPrettyBlockOrder(
+ &op.front(), [&](Block *block) { return processBlock(block); },
+ /*skipHeader=*/true))) {
+ return failure();
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "-- completed graph '" << op.getName()
+ << "' --\n");
+ // Insert OpGraphEndARM.
+ encodeInstructionInto(functionBody, spirv::Opcode::OpGraphEndARM, {});
+
+ llvm::append_range(graphs, functionHeader);
+ llvm::append_range(graphs, functionBody);
+ functionHeader.clear();
+ functionBody.clear();
+
+ return success();
+}
+
+LogicalResult
+Serializer::processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op) {
+ SmallVector<uint32_t, 4> operands;
+ StringRef graph = op.getFn();
+ // Add the graph <id>.
+ uint32_t graphID = getOrCreateFunctionID(graph);
+ operands.push_back(graphID);
+ // Add the name of the graph.
+ spirv::encodeStringLiteralInto(operands, graph);
+
+ // Add the interface values.
+ if (ArrayAttr interface = op.getInterface()) {
+ for (Attribute var : interface.getValue()) {
+ StringRef value = cast<FlatSymbolRefAttr>(var).getValue();
+ if (uint32_t id = getVariableID(value)) {
+ operands.push_back(id);
+ } else {
+ return op.emitError(
+ "referencing undefined global variable."
+ "spirv.GraphEntryPointARM is at the end of spirv.module. All "
+ "referenced variables should already be defined");
+ }
+ }
+ }
+ encodeInstructionInto(graphs, spirv::Opcode::OpGraphEntryPointARM, operands);
+ return success();
+}
+
+LogicalResult
+Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) {
+ for (auto [idx, value] : llvm::enumerate(op->getOperands())) {
+ SmallVector<uint32_t, 2> outputOperands;
+
+ Type resType = value.getType();
+ uint32_t resTypeID = 0;
+ if (failed(processType(op.getLoc(), resType, resTypeID))) {
+ return failure();
+ }
+
+ uint32_t outputID = getValueID(value);
+ auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
+ uint32_t indexID = prepareConstantInt(op.getLoc(), attr, false);
+
+ outputOperands.push_back(outputID);
+ outputOperands.push_back(indexID);
+
+ encodeInstructionInto(functionBody, spirv::Opcode::OpGraphSetOutputARM,
+ outputOperands);
+ }
+ return success();
+}
+
LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
SmallVector<uint32_t, 4> operands;
SmallVector<StringRef, 2> elidedAttrs;
diff --git a/lib/Target/SPIRV/Serialization/Serializer.cpp b/lib/Target/SPIRV/Serialization/Serializer.cpp
index 7fc7795..b56e778 100644
--- a/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -136,7 +136,7 @@
extensions.size() + extendedSets.size() +
memoryModel.size() + entryPoints.size() +
executionModes.size() + decorations.size() +
- typesGlobalValues.size() + functions.size();
+ typesGlobalValues.size() + functions.size() + graphs.size();
binary.clear();
binary.reserve(moduleSize);
@@ -154,6 +154,7 @@
binary.append(decorations.begin(), decorations.end());
binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
binary.append(functions.begin(), functions.end());
+ binary.append(graphs.begin(), graphs.end());
}
#ifndef NDEBUG
@@ -509,6 +510,9 @@
if ((isa<FunctionType>(type) &&
succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
operands))) ||
+ (isa<GraphType>(type) &&
+ succeeded(
+ prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
deferSerialization, serializationCtx))) {
if (deferSerialization)
@@ -539,7 +543,7 @@
return success();
}
- return failure();
+ return emitError(loc, "failed to process type: ") << type;
}
LogicalResult Serializer::prepareBasicType(
@@ -875,6 +879,33 @@
return success();
}
+LogicalResult
+Serializer::prepareGraphType(Location loc, GraphType type,
+ spirv::Opcode &typeEnum,
+ SmallVectorImpl<uint32_t> &operands) {
+ typeEnum = spirv::Opcode::OpTypeGraphARM;
+ assert(type.getNumResults() >= 1 &&
+ "serialization requires at least a return value");
+
+ operands.push_back(type.getNumInputs());
+
+ for (Type argType : type.getInputs()) {
+ uint32_t argTypeID = 0;
+ if (failed(processType(loc, argType, argTypeID)))
+ return failure();
+ operands.push_back(argTypeID);
+ }
+
+ for (Type resType : type.getResults()) {
+ uint32_t resTypeID = 0;
+ if (failed(processType(loc, resType, resTypeID)))
+ return failure();
+ operands.push_back(resTypeID);
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
@@ -1135,6 +1166,41 @@
return resultID;
}
+uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
+ IntegerAttr intAttr) {
+ // De-duplicate graph constants.
+ if (uint32_t id = getGraphConstantARMId(intAttr)) {
+ return id;
+ }
+
+ // Process the type for this graph constant.
+ uint32_t typeID = 0;
+ if (failed(processType(loc, graphConstType, typeID))) {
+ return 0;
+ }
+
+ uint32_t resultID = getNextID();
+ APInt value = intAttr.getValue();
+ unsigned bitwidth = value.getBitWidth();
+ if (bitwidth > 32) {
+ emitError(loc, "Too wide attribute for OpGraphConstantARM: ")
+ << bitwidth << " bits";
+ return 0;
+ }
+ bool isSigned = value.isSignedIntN(bitwidth);
+
+ uint32_t word = 0;
+ if (isSigned) {
+ word = static_cast<int32_t>(value.getSExtValue());
+ } else {
+ word = static_cast<uint32_t>(value.getZExtValue());
+ }
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpGraphConstantARM,
+ {typeID, resultID, word});
+ graphConstIDMap[intAttr] = resultID;
+ return resultID;
+}
+
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
bool isSpec) {
if (!isSpec) {
@@ -1469,9 +1535,19 @@
return processConstantCompositeReplicateOp(op);
})
.Case([&](spirv::FuncOp op) { return processFuncOp(op); })
+ .Case([&](spirv::GraphARMOp op) { return processGraphARMOp(op); })
+ .Case([&](spirv::GraphEntryPointARMOp op) {
+ return processGraphEntryPointARMOp(op);
+ })
+ .Case([&](spirv::GraphOutputsARMOp op) {
+ return processGraphOutputsARMOp(op);
+ })
.Case([&](spirv::GlobalVariableOp op) {
return processGlobalVariableOp(op);
})
+ .Case([&](spirv::GraphConstantARMOp op) {
+ return processGraphConstantARMOp(op);
+ })
.Case([&](spirv::LoopOp op) { return processLoopOp(op); })
.Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
.Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
diff --git a/lib/Target/SPIRV/Serialization/Serializer.h b/lib/Target/SPIRV/Serialization/Serializer.h
index fb2cecd..add372b 100644
--- a/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/lib/Target/SPIRV/Serialization/Serializer.h
@@ -122,6 +122,8 @@
LogicalResult
processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
+ LogicalResult processGraphConstantARMOp(spirv::GraphConstantARMOp op);
+
/// SPIR-V dialect supports OpUndef using spirv.UndefOp that produces a SSA
/// value to use with other operations. The SPIR-V spec recommends that
/// OpUndef be generated at module level. The serialization generates an
@@ -135,6 +137,15 @@
LogicalResult processFuncOp(spirv::FuncOp op);
LogicalResult processFuncParameter(spirv::FuncOp op);
+ /// Processes a SPIR-V GraphARM op.
+ LogicalResult processGraphARMOp(spirv::GraphARMOp op);
+
+ /// Processes a SPIR-V GraphEntryPointARM op.
+ LogicalResult processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op);
+
+ /// Processes a SPIR-V GraphOutputsARMOp op.
+ LogicalResult processGraphOutputsARMOp(spirv::GraphOutputsARMOp op);
+
LogicalResult processVariableOp(spirv::VariableOp op);
/// Process a SPIR-V GlobalVariableOp
@@ -189,6 +200,10 @@
spirv::Opcode &typeEnum,
SmallVectorImpl<uint32_t> &operands);
+ LogicalResult prepareGraphType(Location loc, GraphType type,
+ spirv::Opcode &typeEnum,
+ SmallVectorImpl<uint32_t> &operands);
+
//===--------------------------------------------------------------------===//
// Constant
//===--------------------------------------------------------------------===//
@@ -238,6 +253,13 @@
uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
bool isSpec = false);
+ uint32_t getGraphConstantARMId(Attribute value) const {
+ return graphConstIDMap.lookup(value);
+ }
+
+ uint32_t prepareGraphConstantId(Location loc, Type graphConstType,
+ IntegerAttr intAttr);
+
uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
bool isSpec = false);
@@ -372,6 +394,7 @@
SmallVector<uint32_t, 0> decorations;
SmallVector<uint32_t, 0> typesGlobalValues;
SmallVector<uint32_t, 0> functions;
+ SmallVector<uint32_t, 0> graphs;
/// Recursive struct references are serialized as OpTypePointer instructions
/// to the recursive struct type. However, the OpTypePointer instruction
@@ -388,15 +411,22 @@
recursiveStructInfos;
/// `functionHeader` contains all the instructions that must be in the first
- /// block in the function, and `functionBody` contains the rest. After
- /// processing FuncOp, the encoded instructions of a function are appended to
- /// `functions`. An example of instructions in `functionHeader` in order:
+ /// block in the function or graph, and `functionBody` contains the rest.
+ /// After processing FuncOp/GraphARMOp, the encoded instructions of a function
+ /// or graph are appended to `functions` or `graphs` respectively. Examples of
+ /// instructions in `functionHeader` in order:
+ ///
+ /// For a FuncOp:
/// OpFunction ...
/// OpFunctionParameter ...
/// OpFunctionParameter ...
/// OpLabel ...
/// OpVariable ...
/// OpVariable ...
+ ///
+ /// For a GraphARMOp
+ /// OpGraphARM ...
+ /// OpGraphInputARM ...
SmallVector<uint32_t, 0> functionHeader;
SmallVector<uint32_t, 0> functionBody;
@@ -412,6 +442,9 @@
/// Map from specialization constant names to their <id>s.
llvm::StringMap<uint32_t> specConstIDMap;
+ /// Map from graph constant ID value to their <id>s.
+ DenseMap<Attribute, uint32_t> graphConstIDMap;
+
/// Map from GlobalVariableOps name to <id>s.
llvm::StringMap<uint32_t> globalVarIDMap;
diff --git a/test/Target/SPIRV/graph-ops.mlir b/test/Target/SPIRV/graph-ops.mlir
new file mode 100644
index 0000000..c956157
--- /dev/null
+++ b/test/Target/SPIRV/graph-ops.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+ // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+ spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+ // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+ spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+ // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+ spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0
+ // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+ spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+ // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+ %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+ // CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
+ }
+
+ // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = false} {
+ spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
+ spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+ }
+}