blob: ef299c17baf767b22f242894a2410cce3aaf1478 [file] [log] [blame]
//===- DXILRootSignature.cpp - DXIL Root Signature helper objects -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains helper objects and APIs for working with DXIL
/// Root Signatures.
///
//===----------------------------------------------------------------------===//
#include "DXILRootSignature.h"
#include "DirectX.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <optional>
#include <utility>
using namespace llvm;
using namespace llvm::dxil;
static bool reportError(LLVMContext *Ctx, Twine Message,
DiagnosticSeverity Severity = DS_Error) {
Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
return true;
}
static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
uint32_t Value) {
Ctx->diagnose(DiagnosticInfoGeneric(
"Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
return true;
}
/// Read operand \p OpId of \p Node as a constant integer.
///
/// \param Node Metadata node to read from.
/// \param OpId Index of the operand to read.
/// \returns the zero-extended value of the operand when it wraps a
/// ConstantInt; std::nullopt when the index is out of range, the operand is
/// null, or the operand is not a constant integer.
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
                                                 unsigned int OpId) {
  // Callers validate the operand count, but keep the helper total so that a
  // bad index becomes "no value" instead of an out-of-range access.
  if (OpId >= Node->getNumOperands())
    return std::nullopt;

  // A metadata operand may legitimately be null (e.g. `!{!"RootFlags", null}`),
  // so use the null-tolerant extractor rather than dyn_extract.
  if (auto *CI =
          mdconst::dyn_extract_or_null<ConstantInt>(Node->getOperand(OpId).get()))
    return CI->getZExtValue();
  return std::nullopt;
}
/// Parse a "RootFlags" element and store its value in \p RSD.
///
/// The node must have exactly two operands; the second one is read as the
/// integer flag value and written to `RSD.Flags`.
///
/// \returns true if a diagnostic was emitted (malformed node), false on
/// success.
static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
                           MDNode *RootFlagNode) {
  const bool HasExpectedShape = RootFlagNode->getNumOperands() == 2;
  if (!HasExpectedShape)
    return reportError(Ctx, "Invalid format for RootFlag Element");

  std::optional<uint32_t> FlagBits = extractMdIntValue(RootFlagNode, 1);
  if (!FlagBits.has_value())
    return reportError(Ctx, "Invalid value for RootFlag");

  RSD.Flags = FlagBits.value();
  return false;
}
/// Parse a "RootConstants" element and append the resulting root parameter to
/// `RSD.Parameters`.
///
/// The node must have exactly five operands. Operands 1 through 4 are read, in
/// order, as ShaderVisibility, ShaderRegister, RegisterSpace and
/// Num32BitValues. The first operand that is not a constant integer stops
/// parsing with a diagnostic and nothing is appended.
///
/// \returns true if a diagnostic was emitted, false on success.
static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
                               MDNode *RootConstantNode) {
  if (RootConstantNode->getNumOperands() != 5)
    return reportError(Ctx, "Invalid format for RootConstants Element");

  // Read the fields one at a time so that the first bad operand is the one
  // that gets reported.
  std::optional<uint32_t> Visibility = extractMdIntValue(RootConstantNode, 1);
  if (!Visibility)
    return reportError(Ctx, "Invalid value for ShaderVisibility");

  std::optional<uint32_t> Register = extractMdIntValue(RootConstantNode, 2);
  if (!Register)
    return reportError(Ctx, "Invalid value for ShaderRegister");

  std::optional<uint32_t> Space = extractMdIntValue(RootConstantNode, 3);
  if (!Space)
    return reportError(Ctx, "Invalid value for RegisterSpace");

  std::optional<uint32_t> NumValues = extractMdIntValue(RootConstantNode, 4);
  if (!NumValues)
    return reportError(Ctx, "Invalid value for Num32BitValues");

  // All fields are valid: assemble the parameter and record it.
  mcdxbc::RootParameter Constants;
  Constants.Header.ParameterType =
      llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
  Constants.Header.ShaderVisibility = *Visibility;
  Constants.Constants.ShaderRegister = *Register;
  Constants.Constants.RegisterSpace = *Space;
  Constants.Constants.Num32BitValues = *NumValues;
  RSD.Parameters.push_back(Constants);
  return false;
}
/// Parse a single root signature element and merge it into \p RSD.
///
/// The first operand of \p Element must be an MDString naming the element
/// kind ("RootFlags" or "RootConstants"); the matching parser is then invoked
/// on the whole node.
///
/// \returns true if a diagnostic was emitted (missing/ill-typed kind string,
/// unknown kind, or a failure in the kind-specific parser), false on success.
static bool parseRootSignatureElement(LLVMContext *Ctx,
                                      mcdxbc::RootSignatureDesc &RSD,
                                      MDNode *Element) {
  // An empty node has no kind operand to inspect.
  if (Element->getNumOperands() == 0)
    return reportError(Ctx, "Invalid format for Root Element");

  // Use a checked, null-tolerant cast: `cast<>` would assert on malformed
  // metadata and can never yield nullptr, which made the check below dead.
  MDString *ElementText =
      dyn_cast_or_null<MDString>(Element->getOperand(0).get());
  if (ElementText == nullptr)
    return reportError(Ctx, "Invalid format for Root Element");

  RootSignatureElementKind ElementKind =
      StringSwitch<RootSignatureElementKind>(ElementText->getString())
          .Case("RootFlags", RootSignatureElementKind::RootFlags)
          .Case("RootConstants", RootSignatureElementKind::RootConstants)
          .Default(RootSignatureElementKind::Error);

  switch (ElementKind) {
  case RootSignatureElementKind::RootFlags:
    return parseRootFlags(Ctx, RSD, Element);
  case RootSignatureElementKind::RootConstants:
    return parseRootConstants(Ctx, RSD, Element);
  case RootSignatureElementKind::Error:
    return reportError(Ctx, "Invalid Root Signature Element: " +
                                ElementText->getString());
  }

  llvm_unreachable("Unhandled RootSignatureElementKind enum.");
}
/// Parse every root element listed under \p Node into \p RSD.
///
/// Each operand of \p Node must be a metadata node describing one element.
/// A null or non-node operand aborts parsing immediately with a diagnostic.
/// Once an element has failed to parse, the remaining elements are no longer
/// parsed (the `||` below short-circuits), but the operand-is-a-node check
/// still runs for them.
///
/// \returns true if any diagnostic was emitted, false on success.
static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
                  MDNode *Node) {
  bool HasError = false;

  // Loop through the Root Elements of the root signature.
  for (const auto &Operand : Node->operands()) {
    // Operands may be null; the null-tolerant cast turns that into the
    // diagnostic below instead of an invalid cast on a null pointer.
    MDNode *Element = dyn_cast_or_null<MDNode>(Operand.get());
    if (Element == nullptr)
      return reportError(Ctx, "Missing Root Element Metadata Node.");

    HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element);
  }

  return HasError;
}
/// Check that \p Flags has no bits set outside the low 12 bits.
///
/// \returns true when every set bit lies within the mask 0xfff.
static bool verifyRootFlag(uint32_t Flags) {
  constexpr uint32_t AllowedBits = 0xfff;
  const uint32_t UnexpectedBits = Flags & ~AllowedBits;
  return UnexpectedBits == 0;
}
/// Check that \p Version is one of the accepted root signature versions.
///
/// \returns true only for the values 1 and 2.
static bool verifyVersion(uint32_t Version) {
  switch (Version) {
  case 1:
  case 2:
    return true;
  default:
    return false;
  }
}
/// Validate a fully parsed root signature description.
///
/// Checks, in order: the version, the root flags, and the shader visibility
/// of every root parameter. The first failing check emits a diagnostic and
/// stops validation. The parameter type of each parameter is only asserted,
/// not diagnosed.
///
/// \returns true if a diagnostic was emitted, false if \p RSD is valid.
static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
  const bool VersionOk = verifyVersion(RSD.Version);
  if (!VersionOk)
    return reportValueError(Ctx, "Version", RSD.Version);

  const bool FlagsOk = verifyRootFlag(RSD.Flags);
  if (!FlagsOk)
    return reportValueError(Ctx, "RootFlags", RSD.Flags);

  for (const mcdxbc::RootParameter &Param : RSD.Parameters) {
    const auto &Header = Param.Header;
    if (!dxbc::isValidShaderVisibility(Header.ShaderVisibility))
      return reportValueError(Ctx, "ShaderVisibility",
                              Header.ShaderVisibility);

    // The parameter type is written by the parser itself, so an invalid
    // value here is treated as an internal error rather than bad input.
    assert(dxbc::isValidParameterType(Header.ParameterType) &&
           "Invalid value for ParameterType");
  }

  return false;
}
/// Collect the root signature attached to each function of \p M.
///
/// Root signatures are specified as follows in the metadata:
///
///   !dx.rootsignatures = !{!2} ; list of function/root signature pairs
///   !2 = !{ ptr @main, !3 }    ; function, root signature
///   !3 = !{ !4, !5, !6, !7 }   ; list of root signature elements
///
/// So for each MDNode inside the dx.rootsignatures NamedMDNode, the parsing
/// process loops through its operands and processes the (function, signature)
/// pair.
///
/// A structurally malformed pair is diagnosed and skipped. A pair whose
/// element list fails to parse or validate stops the analysis and returns the
/// entries collected so far.
///
/// \returns a map from function to its parsed root signature description;
/// empty when the module has no "dx.rootsignatures" named metadata.
static SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
analyzeModule(Module &M) {
  SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> RSDMap;
  LLVMContext *Ctx = &M.getContext();

  NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
  if (!RootSignatureNode)
    return RSDMap;

  for (const MDNode *PairNode : RootSignatureNode->operands()) {
    if (PairNode->getNumOperands() != 2) {
      reportError(Ctx, "Invalid format for Root Signature Definition. Pairs "
                       "of function, root signature expected.");
      continue;
    }

    // Operand 0: the function. It is null when the function was pruned
    // during compilation.
    Metadata *FunctionMD = PairNode->getOperand(0).get();
    if (!FunctionMD) {
      reportError(
          Ctx, "Function associated with Root Signature definition is null.");
      continue;
    }

    auto *FunctionValueMD = llvm::dyn_cast<ValueAsMetadata>(FunctionMD);
    if (!FunctionValueMD) {
      reportError(Ctx, "First element of root signature is not a Value");
      continue;
    }

    auto *F = dyn_cast<Function>(FunctionValueMD->getValue());
    if (!F) {
      reportError(Ctx, "First element of root signature is not a Function");
      continue;
    }

    // Operand 1: the list of root signature elements.
    Metadata *ElementListMD = PairNode->getOperand(1).get();
    if (!ElementListMD) {
      reportError(Ctx, "Root Element mdnode is null.");
      continue;
    }

    auto *ElementList = dyn_cast<MDNode>(ElementListMD);
    if (!ElementList) {
      reportError(Ctx, "Root Element is not a metadata node.");
      continue;
    }

    mcdxbc::RootSignatureDesc RSD;
    // Clang emits the root signature data in dxcontainer following a specific
    // sequence. First the header, then the root parameters. So the header
    // offset will always equal to the header size.
    RSD.RootParameterOffset = sizeof(dxbc::RootSignatureHeader);

    // Validation only runs when parsing succeeded; either failure ends the
    // whole analysis.
    const bool Failed = parse(Ctx, RSD, ElementList) || validate(Ctx, RSD);
    if (Failed)
      return RSDMap;

    RSDMap.insert(std::make_pair(F, RSD));
  }

  return RSDMap;
}
// Out-of-line definition of the static `Key` member of RootSignatureAnalysis.
AnalysisKey RootSignatureAnalysis::Key;
/// Compute the per-function root signature descriptions for \p M by
/// delegating to analyzeModule(). \p AM is not consulted.
///
/// \returns a map from function to its parsed root signature description.
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
return analyzeModule(M);
}
//===----------------------------------------------------------------------===//
/// Print the root signature description of every function in \p M that has
/// one, using the result of RootSignatureAnalysis.
///
/// Functions are visited in module order; functions without an entry in the
/// analysis result are skipped. For each description the header fields are
/// printed first, followed by one entry per root parameter, followed by the
/// static sampler fields.
///
/// \returns PreservedAnalyses::all(), since nothing is modified.
PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
                                                    ModuleAnalysisManager &AM) {
  // Fixed indentation widths for the three nesting levels of the dump.
  constexpr unsigned HeaderIndent = 1;
  constexpr unsigned ParamIndent = 2;
  constexpr unsigned FieldIndent = ParamIndent + 2;

  auto &RSDMap = AM.getResult<RootSignatureAnalysis>(M);

  OS << "Root Signature Definitions"
     << "\n";

  for (const Function &F : M) {
    auto Entry = RSDMap.find(&F);
    if (Entry == RSDMap.end())
      continue;
    const auto &Desc = Entry->second;

    OS << "Definition for '" << F.getName() << "':\n";

    // Root signature header.
    OS << indent(HeaderIndent) << "Flags: " << format_hex(Desc.Flags, 8)
       << "\n";
    OS << indent(HeaderIndent) << "Version: " << Desc.Version << "\n";
    OS << indent(HeaderIndent)
       << "RootParametersOffset: " << Desc.RootParameterOffset << "\n";
    OS << indent(HeaderIndent) << "NumParameters: " << Desc.Parameters.size()
       << "\n";

    // Root parameters, one list entry each.
    for (const auto &Param : Desc.Parameters) {
      OS << indent(ParamIndent)
         << "- Parameter Type: " << Param.Header.ParameterType << "\n";
      OS << indent(FieldIndent)
         << "Shader Visibility: " << Param.Header.ShaderVisibility << "\n";

      // Only 32-bit root constants carry extra fields to print.
      if (Param.Header.ParameterType ==
          llvm::to_underlying(dxbc::RootParameterType::Constants32Bit)) {
        OS << indent(FieldIndent)
           << "Register Space: " << Param.Constants.RegisterSpace << "\n";
        OS << indent(FieldIndent)
           << "Shader Register: " << Param.Constants.ShaderRegister << "\n";
        OS << indent(FieldIndent)
           << "Num 32 Bit Values: " << Param.Constants.Num32BitValues << "\n";
      }
    }

    // Static samplers: the count is always printed as zero here.
    OS << indent(HeaderIndent) << "NumStaticSamplers: " << 0 << "\n";
    OS << indent(HeaderIndent)
       << "StaticSamplersOffset: " << Desc.StaticSamplersOffset << "\n";
  }

  return PreservedAnalyses::all();
}
//===----------------------------------------------------------------------===//
/// Run the root signature analysis on \p M and cache the result in
/// `FuncToRsMap`.
///
/// \returns false; this function never reports a modification of the module.
bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {
FuncToRsMap = analyzeModule(M);
return false;
}
/// Declare this pass's analysis usage: it preserves all analyses and requires
/// DXILMetadataAnalysisWrapperPass.
// NOTE(review): runOnModule() in this file does not visibly query the
// required DXILMetadataAnalysisWrapperPass result — confirm the requirement
// is still needed.
void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
}
// Out-of-line definition of the static `ID` member of
// RootSignatureAnalysisWrapper.
char RootSignatureAnalysisWrapper::ID = 0;
// Pass registration under the command-line name
// "dxil-root-signature-analysis" with the description
// "DXIL Root Signature Analysis".
// NOTE(review): the trailing `true, true` are presumably the macro's
// "CFG only" and "is analysis" flags — confirm against llvm/PassSupport.h.
// NOTE(review): getAnalysisUsage() requires DXILMetadataAnalysisWrapperPass,
// but no INITIALIZE_PASS_DEPENDENCY is listed between BEGIN and END — confirm
// whether one is needed.
INITIALIZE_PASS_BEGIN(RootSignatureAnalysisWrapper,
"dxil-root-signature-analysis",
"DXIL Root Signature Analysis", true, true)
INITIALIZE_PASS_END(RootSignatureAnalysisWrapper,
"dxil-root-signature-analysis",
"DXIL Root Signature Analysis", true, true)