| //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines the Deserializer methods for SPIR-V binary instructions. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "Deserializer.h" |
| |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Location.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/Support/Debug.h" |
| |
| using namespace mlir; |
| |
| #define DEBUG_TYPE "spirv-deserialization" |
| |
| //===----------------------------------------------------------------------===// |
| // Utility Functions |
| //===----------------------------------------------------------------------===// |
| |
| /// Extracts the opcode from the given first word of a SPIR-V instruction. |
| static inline spirv::Opcode extractOpcode(uint32_t word) { |
| return static_cast<spirv::Opcode>(word & 0xffff); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Instruction |
| //===----------------------------------------------------------------------===// |
| |
| Value spirv::Deserializer::getValue(uint32_t id) { |
| if (auto constInfo = getConstant(id)) { |
| // Materialize a `spv.Constant` op at every use site. |
| return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second, |
| constInfo->first); |
| } |
| if (auto varOp = getGlobalVariable(id)) { |
| auto addressOfOp = opBuilder.create<spirv::AddressOfOp>( |
| unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation())); |
| return addressOfOp.pointer(); |
| } |
| if (auto constOp = getSpecConstant(id)) { |
| auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( |
| unknownLoc, constOp.default_value().getType(), |
| SymbolRefAttr::get(constOp.getOperation())); |
| return referenceOfOp.reference(); |
| } |
| if (auto constCompositeOp = getSpecConstantComposite(id)) { |
| auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( |
| unknownLoc, constCompositeOp.type(), |
| SymbolRefAttr::get(constCompositeOp.getOperation())); |
| return referenceOfOp.reference(); |
| } |
| if (auto specConstOperationInfo = getSpecConstantOperation(id)) { |
| return materializeSpecConstantOperation( |
| id, specConstOperationInfo->enclodesOpcode, |
| specConstOperationInfo->resultTypeID, |
| specConstOperationInfo->enclosedOpOperands); |
| } |
| if (auto undef = getUndefType(id)) { |
| return opBuilder.create<spirv::UndefOp>(unknownLoc, undef); |
| } |
| return valueMap.lookup(id); |
| } |
| |
| LogicalResult |
| spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode, |
| ArrayRef<uint32_t> &operands, |
| Optional<spirv::Opcode> expectedOpcode) { |
| auto binarySize = binary.size(); |
| if (curOffset >= binarySize) { |
| return emitError(unknownLoc, "expected ") |
| << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode) |
| : "more") |
| << " instruction"; |
| } |
| |
| // For each instruction, get its word count from the first word to slice it |
| // from the stream properly, and then dispatch to the instruction handler. |
| |
| uint32_t wordCount = binary[curOffset] >> 16; |
| |
| if (wordCount == 0) |
| return emitError(unknownLoc, "word count cannot be zero"); |
| |
| uint32_t nextOffset = curOffset + wordCount; |
| if (nextOffset > binarySize) |
| return emitError(unknownLoc, "insufficient words for the last instruction"); |
| |
| opcode = extractOpcode(binary[curOffset]); |
| operands = binary.slice(curOffset + 1, wordCount - 1); |
| curOffset = nextOffset; |
| return success(); |
| } |
| |
| LogicalResult spirv::Deserializer::processInstruction( |
| spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) { |
| LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction " |
| << spirv::stringifyOpcode(opcode) << "\n"); |
| |
| // First dispatch all the instructions whose opcode does not correspond to |
| // those that have a direct mirror in the SPIR-V dialect |
| switch (opcode) { |
| case spirv::Opcode::OpCapability: |
| return processCapability(operands); |
| case spirv::Opcode::OpExtension: |
| return processExtension(operands); |
| case spirv::Opcode::OpExtInst: |
| return processExtInst(operands); |
| case spirv::Opcode::OpExtInstImport: |
| return processExtInstImport(operands); |
| case spirv::Opcode::OpMemberName: |
| return processMemberName(operands); |
| case spirv::Opcode::OpMemoryModel: |
| return processMemoryModel(operands); |
| case spirv::Opcode::OpEntryPoint: |
| case spirv::Opcode::OpExecutionMode: |
| if (deferInstructions) { |
| deferredInstructions.emplace_back(opcode, operands); |
| return success(); |
| } |
| break; |
| case spirv::Opcode::OpVariable: |
| if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) { |
| return processGlobalVariable(operands); |
| } |
| break; |
| case spirv::Opcode::OpLine: |
| return processDebugLine(operands); |
| case spirv::Opcode::OpNoLine: |
| return clearDebugLine(); |
| case spirv::Opcode::OpName: |
| return processName(operands); |
| case spirv::Opcode::OpString: |
| return processDebugString(operands); |
| case spirv::Opcode::OpModuleProcessed: |
| case spirv::Opcode::OpSource: |
| case spirv::Opcode::OpSourceContinued: |
| case spirv::Opcode::OpSourceExtension: |
| // TODO: This is debug information embedded in the binary which should be |
| // translated into the spv.module. |
| return success(); |
| case spirv::Opcode::OpTypeVoid: |
| case spirv::Opcode::OpTypeBool: |
| case spirv::Opcode::OpTypeInt: |
| case spirv::Opcode::OpTypeFloat: |
| case spirv::Opcode::OpTypeVector: |
| case spirv::Opcode::OpTypeMatrix: |
| case spirv::Opcode::OpTypeArray: |
| case spirv::Opcode::OpTypeFunction: |
| case spirv::Opcode::OpTypeImage: |
| case spirv::Opcode::OpTypeSampledImage: |
| case spirv::Opcode::OpTypeRuntimeArray: |
| case spirv::Opcode::OpTypeStruct: |
| case spirv::Opcode::OpTypePointer: |
| case spirv::Opcode::OpTypeCooperativeMatrixNV: |
| return processType(opcode, operands); |
| case spirv::Opcode::OpTypeForwardPointer: |
| return processTypeForwardPointer(operands); |
| case spirv::Opcode::OpConstant: |
| return processConstant(operands, /*isSpec=*/false); |
| case spirv::Opcode::OpSpecConstant: |
| return processConstant(operands, /*isSpec=*/true); |
| case spirv::Opcode::OpConstantComposite: |
| return processConstantComposite(operands); |
| case spirv::Opcode::OpSpecConstantComposite: |
| return processSpecConstantComposite(operands); |
| case spirv::Opcode::OpSpecConstantOp: |
| return processSpecConstantOperation(operands); |
| case spirv::Opcode::OpConstantTrue: |
| return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); |
| case spirv::Opcode::OpSpecConstantTrue: |
| return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); |
| case spirv::Opcode::OpConstantFalse: |
| return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); |
| case spirv::Opcode::OpSpecConstantFalse: |
| return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); |
| case spirv::Opcode::OpConstantNull: |
| return processConstantNull(operands); |
| case spirv::Opcode::OpDecorate: |
| return processDecoration(operands); |
| case spirv::Opcode::OpMemberDecorate: |
| return processMemberDecoration(operands); |
| case spirv::Opcode::OpFunction: |
| return processFunction(operands); |
| case spirv::Opcode::OpLabel: |
| return processLabel(operands); |
| case spirv::Opcode::OpBranch: |
| return processBranch(operands); |
| case spirv::Opcode::OpBranchConditional: |
| return processBranchConditional(operands); |
| case spirv::Opcode::OpSelectionMerge: |
| return processSelectionMerge(operands); |
| case spirv::Opcode::OpLoopMerge: |
| return processLoopMerge(operands); |
| case spirv::Opcode::OpPhi: |
| return processPhi(operands); |
| case spirv::Opcode::OpUndef: |
| return processUndef(operands); |
| default: |
| break; |
| } |
| return dispatchToAutogenDeserialization(opcode, operands); |
| } |
| |
| LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr( |
| ArrayRef<uint32_t> words, StringRef opName, bool hasResult, |
| unsigned numOperands) { |
| SmallVector<Type, 1> resultTypes; |
| uint32_t valueID = 0; |
| |
| size_t wordIndex = 0; |
| if (hasResult) { |
| if (wordIndex >= words.size()) |
| return emitError(unknownLoc, |
| "expected result type <id> while deserializing for ") |
| << opName; |
| |
| // Decode the type <id> |
| auto type = getType(words[wordIndex]); |
| if (!type) |
| return emitError(unknownLoc, "unknown type result <id>: ") |
| << words[wordIndex]; |
| resultTypes.push_back(type); |
| ++wordIndex; |
| |
| // Decode the result <id> |
| if (wordIndex >= words.size()) |
| return emitError(unknownLoc, |
| "expected result <id> while deserializing for ") |
| << opName; |
| valueID = words[wordIndex]; |
| ++wordIndex; |
| } |
| |
| SmallVector<Value, 4> operands; |
| SmallVector<NamedAttribute, 4> attributes; |
| |
| // Decode operands |
| size_t operandIndex = 0; |
| for (; operandIndex < numOperands && wordIndex < words.size(); |
| ++operandIndex, ++wordIndex) { |
| auto arg = getValue(words[wordIndex]); |
| if (!arg) |
| return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex]; |
| operands.push_back(arg); |
| } |
| if (operandIndex != numOperands) { |
| return emitError( |
| unknownLoc, |
| "found less operands than expected when deserializing for ") |
| << opName << "; only " << operandIndex << " of " << numOperands |
| << " processed"; |
| } |
| if (wordIndex != words.size()) { |
| return emitError( |
| unknownLoc, |
| "found more operands than expected when deserializing for ") |
| << opName << "; only " << wordIndex << " of " << words.size() |
| << " processed"; |
| } |
| |
| // Attach attributes from decorations |
| if (decorations.count(valueID)) { |
| auto attrs = decorations[valueID].getAttrs(); |
| attributes.append(attrs.begin(), attrs.end()); |
| } |
| |
| // Create the op and update bookkeeping maps |
| Location loc = createFileLineColLoc(opBuilder); |
| OperationState opState(loc, opName); |
| opState.addOperands(operands); |
| if (hasResult) |
| opState.addTypes(resultTypes); |
| opState.addAttributes(attributes); |
| Operation *op = opBuilder.createOperation(opState); |
| if (hasResult) |
| valueMap[valueID] = op->getResult(0); |
| |
| if (op->hasTrait<OpTrait::IsTerminator>()) |
| (void)clearDebugLine(); |
| |
| return success(); |
| } |
| |
| LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) { |
| if (operands.size() != 2) { |
| return emitError(unknownLoc, "OpUndef instruction must have two operands"); |
| } |
| auto type = getType(operands[0]); |
| if (!type) { |
| return emitError(unknownLoc, "unknown type <id> with OpUndef instruction"); |
| } |
| undefMap[operands[1]] = type; |
| return success(); |
| } |
| |
| LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) { |
| if (operands.size() < 4) { |
| return emitError(unknownLoc, |
| "OpExtInst must have at least 4 operands, result type " |
| "<id>, result <id>, set <id> and instruction opcode"); |
| } |
| if (!extendedInstSets.count(operands[2])) { |
| return emitError(unknownLoc, "undefined set <id> in OpExtInst"); |
| } |
| SmallVector<uint32_t, 4> slicedOperands; |
| slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); |
| slicedOperands.append(std::next(operands.begin(), 4), operands.end()); |
| return dispatchToExtensionSetAutogenDeserialization( |
| extendedInstSets[operands[2]], operands[3], slicedOperands); |
| } |
| |
| namespace mlir { |
| namespace spirv { |
| |
| template <> |
| LogicalResult |
| Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) { |
| unsigned wordIndex = 0; |
| if (wordIndex >= words.size()) { |
| return emitError(unknownLoc, |
| "missing Execution Model specification in OpEntryPoint"); |
| } |
| auto execModel = spirv::ExecutionModelAttr::get( |
| context, static_cast<spirv::ExecutionModel>(words[wordIndex++])); |
| if (wordIndex >= words.size()) { |
| return emitError(unknownLoc, "missing <id> in OpEntryPoint"); |
| } |
| // Get the function <id> |
| auto fnID = words[wordIndex++]; |
| // Get the function name |
| auto fnName = decodeStringLiteral(words, wordIndex); |
| // Verify that the function <id> matches the fnName |
| auto parsedFunc = getFunction(fnID); |
| if (!parsedFunc) { |
| return emitError(unknownLoc, "no function matching <id> ") << fnID; |
| } |
| if (parsedFunc.getName() != fnName) { |
| return emitError(unknownLoc, "function name mismatch between OpEntryPoint " |
| "and OpFunction with <id> ") |
| << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); |
| } |
| SmallVector<Attribute, 4> interface; |
| while (wordIndex < words.size()) { |
| auto arg = getGlobalVariable(words[wordIndex]); |
| if (!arg) { |
| return emitError(unknownLoc, "undefined result <id> ") |
| << words[wordIndex] << " while decoding OpEntryPoint"; |
| } |
| interface.push_back(SymbolRefAttr::get(arg.getOperation())); |
| wordIndex++; |
| } |
| opBuilder.create<spirv::EntryPointOp>( |
| unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName), |
| opBuilder.getArrayAttr(interface)); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult |
| Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) { |
| unsigned wordIndex = 0; |
| if (wordIndex >= words.size()) { |
| return emitError(unknownLoc, |
| "missing function result <id> in OpExecutionMode"); |
| } |
| // Get the function <id> to get the name of the function |
| auto fnID = words[wordIndex++]; |
| auto fn = getFunction(fnID); |
| if (!fn) { |
| return emitError(unknownLoc, "no function matching <id> ") << fnID; |
| } |
| // Get the Execution mode |
| if (wordIndex >= words.size()) { |
| return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); |
| } |
| auto execMode = spirv::ExecutionModeAttr::get( |
| context, static_cast<spirv::ExecutionMode>(words[wordIndex++])); |
| |
| // Get the values |
| SmallVector<Attribute, 4> attrListElems; |
| while (wordIndex < words.size()) { |
| attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); |
| } |
| auto values = opBuilder.getArrayAttr(attrListElems); |
| opBuilder.create<spirv::ExecutionModeOp>( |
| unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), |
| execMode, values); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult |
| Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) { |
| if (operands.size() != 3) { |
| return emitError( |
| unknownLoc, |
| "OpControlBarrier must have execution scope <id>, memory scope <id> " |
| "and memory semantics <id>"); |
| } |
| |
| SmallVector<IntegerAttr, 3> argAttrs; |
| for (auto operand : operands) { |
| auto argAttr = getConstantInt(operand); |
| if (!argAttr) { |
| return emitError(unknownLoc, |
| "expected 32-bit integer constant from <id> ") |
| << operand << " for OpControlBarrier"; |
| } |
| argAttrs.push_back(argAttr); |
| } |
| |
| opBuilder.create<spirv::ControlBarrierOp>( |
| unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(), |
| argAttrs[1].cast<spirv::ScopeAttr>(), |
| argAttrs[2].cast<spirv::MemorySemanticsAttr>()); |
| |
| return success(); |
| } |
| |
| template <> |
| LogicalResult |
| Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) { |
| if (operands.size() < 3) { |
| return emitError(unknownLoc, |
| "OpFunctionCall must have at least 3 operands"); |
| } |
| |
| Type resultType = getType(operands[0]); |
| if (!resultType) { |
| return emitError(unknownLoc, "undefined result type from <id> ") |
| << operands[0]; |
| } |
| |
| // Use null type to mean no result type. |
| if (isVoidType(resultType)) |
| resultType = nullptr; |
| |
| auto resultID = operands[1]; |
| auto functionID = operands[2]; |
| |
| auto functionName = getFunctionSymbol(functionID); |
| |
| SmallVector<Value, 4> arguments; |
| for (auto operand : llvm::drop_begin(operands, 3)) { |
| auto value = getValue(operand); |
| if (!value) { |
| return emitError(unknownLoc, "unknown <id> ") |
| << operand << " used by OpFunctionCall"; |
| } |
| arguments.push_back(value); |
| } |
| |
| auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>( |
| unknownLoc, resultType, |
| SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments); |
| |
| if (resultType) |
| valueMap[resultID] = opFunctionCall.getResult(0); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult |
| Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) { |
| if (operands.size() != 2) { |
| return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> " |
| "and memory semantics <id>"); |
| } |
| |
| SmallVector<IntegerAttr, 2> argAttrs; |
| for (auto operand : operands) { |
| auto argAttr = getConstantInt(operand); |
| if (!argAttr) { |
| return emitError(unknownLoc, |
| "expected 32-bit integer constant from <id> ") |
| << operand << " for OpMemoryBarrier"; |
| } |
| argAttrs.push_back(argAttr); |
| } |
| |
| opBuilder.create<spirv::MemoryBarrierOp>( |
| unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(), |
| argAttrs[1].cast<spirv::MemorySemanticsAttr>()); |
| return success(); |
| } |
| |
| template <> |
| LogicalResult |
| Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) { |
| SmallVector<Type, 1> resultTypes; |
| size_t wordIndex = 0; |
| SmallVector<Value, 4> operands; |
| SmallVector<NamedAttribute, 4> attributes; |
| |
| if (wordIndex < words.size()) { |
| auto arg = getValue(words[wordIndex]); |
| |
| if (!arg) { |
| return emitError(unknownLoc, "unknown result <id> : ") |
| << words[wordIndex]; |
| } |
| |
| operands.push_back(arg); |
| wordIndex++; |
| } |
| |
| if (wordIndex < words.size()) { |
| auto arg = getValue(words[wordIndex]); |
| |
| if (!arg) { |
| return emitError(unknownLoc, "unknown result <id> : ") |
| << words[wordIndex]; |
| } |
| |
| operands.push_back(arg); |
| wordIndex++; |
| } |
| |
| bool isAlignedAttr = false; |
| |
| if (wordIndex < words.size()) { |
| auto attrValue = words[wordIndex++]; |
| attributes.push_back(opBuilder.getNamedAttr( |
| "memory_access", opBuilder.getI32IntegerAttr(attrValue))); |
| isAlignedAttr = (attrValue == 2); |
| } |
| |
| if (isAlignedAttr && wordIndex < words.size()) { |
| attributes.push_back(opBuilder.getNamedAttr( |
| "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); |
| } |
| |
| if (wordIndex < words.size()) { |
| attributes.push_back(opBuilder.getNamedAttr( |
| "source_memory_access", |
| opBuilder.getI32IntegerAttr(words[wordIndex++]))); |
| } |
| |
| if (wordIndex < words.size()) { |
| attributes.push_back(opBuilder.getNamedAttr( |
| "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); |
| } |
| |
| if (wordIndex != words.size()) { |
| return emitError(unknownLoc, |
| "found more operands than expected when deserializing " |
| "spirv::CopyMemoryOp, only ") |
| << wordIndex << " of " << words.size() << " processed"; |
| } |
| |
| Location loc = createFileLineColLoc(opBuilder); |
| opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes); |
| |
| return success(); |
| } |
| |
| // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and |
| // various Deserializer::processOp<...>() specializations. |
| #define GET_DESERIALIZATION_FNS |
| #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" |
| |
| } // namespace spirv |
| } // namespace mlir |