blob: eea46e714b75601736242ea07bc6973bd9858b26 [file] [log] [blame]
//===- DXILRootSignature.cpp - DXIL Root Signature helper objects -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains helper objects and APIs for working with DXIL
/// Root Signatures.
///
//===----------------------------------------------------------------------===//
#include "DXILRootSignature.h"
#include "DirectX.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/BinaryFormat/DXContainer.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <cmath>
#include <cstdint>
#include <optional>
#include <utility>
using namespace llvm;
using namespace llvm::dxil;
static bool reportError(LLVMContext *Ctx, Twine Message,
DiagnosticSeverity Severity = DS_Error) {
Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
return true;
}
static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
uint32_t Value) {
Ctx->diagnose(DiagnosticInfoGeneric(
"Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
return true;
}
// Template function to get formatted type string based on C++ type
template <typename T> std::string getTypeFormatted() {
if constexpr (std::is_same_v<T, MDString>) {
return "string";
} else if constexpr (std::is_same_v<T, MDNode *> ||
std::is_same_v<T, const MDNode *>) {
return "metadata";
} else if constexpr (std::is_same_v<T, ConstantAsMetadata *> ||
std::is_same_v<T, const ConstantAsMetadata *>) {
return "constant";
} else if constexpr (std::is_same_v<T, ConstantAsMetadata>) {
return "constant";
} else if constexpr (std::is_same_v<T, ConstantInt *> ||
std::is_same_v<T, const ConstantInt *>) {
return "constant int";
} else if constexpr (std::is_same_v<T, ConstantInt>) {
return "constant int";
}
return "unknown";
}
// Helper function to get the actual type of a metadata operand
std::string getActualMDType(const MDNode *Node, unsigned Index) {
if (!Node || Index >= Node->getNumOperands())
return "null";
Metadata *Op = Node->getOperand(Index);
if (!Op)
return "null";
if (isa<MDString>(Op))
return getTypeFormatted<MDString>();
if (isa<ConstantAsMetadata>(Op)) {
if (auto *CAM = dyn_cast<ConstantAsMetadata>(Op)) {
Type *T = CAM->getValue()->getType();
if (T->isIntegerTy())
return (Twine("i") + Twine(T->getIntegerBitWidth())).str();
if (T->isFloatingPointTy())
return T->isFloatTy() ? getTypeFormatted<float>()
: T->isDoubleTy() ? getTypeFormatted<double>()
: "fp";
return getTypeFormatted<ConstantAsMetadata>();
}
}
if (isa<MDNode>(Op))
return getTypeFormatted<MDNode *>();
return "unknown";
}
// Helper function to simplify error reporting for invalid metadata values
template <typename ET>
auto reportInvalidTypeError(LLVMContext *Ctx, Twine ParamName,
const MDNode *Node, unsigned Index) {
std::string ExpectedType = getTypeFormatted<ET>();
std::string ActualType = getActualMDType(Node, Index);
return reportError(Ctx, "Root Signature Node: " + ParamName +
" expected metadata node of type " +
ExpectedType + " at index " + Twine(Index) +
" but got " + ActualType);
}
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
unsigned int OpId) {
if (auto *CI =
mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
return CI->getZExtValue();
return std::nullopt;
}
static std::optional<APFloat> extractMdFloatValue(MDNode *Node,
unsigned int OpId) {
if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
return CI->getValue();
return std::nullopt;
}
static std::optional<StringRef> extractMdStringValue(MDNode *Node,
unsigned int OpId) {
MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
if (NodeText == nullptr)
return std::nullopt;
return NodeText->getString();
}
static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
MDNode *RootFlagNode) {
if (RootFlagNode->getNumOperands() != 2)
return reportError(Ctx, "Invalid format for RootFlag Element");
bool HasError = false;
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootFlagNode",
RootFlagNode, 1) ||
HasError;
return HasError;
}
static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
MDNode *RootConstantNode) {
if (RootConstantNode->getNumOperands() != 5)
return reportError(Ctx, "Invalid format for RootConstants Element");
bool HasError = false;
dxbc::RTS0::v1::RootParameterHeader Header;
// The parameter offset doesn't matter here - we recalculate it during
// serialization Header.ParameterOffset = 0;
Header.ParameterType =
llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
Header.ShaderVisibility = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
RootConstantNode, 1) ||
HasError;
dxbc::RTS0::v1::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
RootConstantNode, 2) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
Constants.RegisterSpace = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
RootConstantNode, 3) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
Constants.Num32BitValues = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootConstantNode",
RootConstantNode, 4) ||
HasError;
if (!HasError)
RSD.ParametersContainer.addParameter(Header, Constants);
return HasError;
}
static bool parseRootDescriptors(LLVMContext *Ctx,
mcdxbc::RootSignatureDesc &RSD,
MDNode *RootDescriptorNode,
RootSignatureElementKind ElementKind) {
assert(ElementKind == RootSignatureElementKind::SRV ||
ElementKind == RootSignatureElementKind::UAV ||
ElementKind == RootSignatureElementKind::CBV &&
"parseRootDescriptors should only be called with RootDescriptor "
"element kind.");
if (RootDescriptorNode->getNumOperands() != 5)
return reportError(Ctx, "Invalid format for Root Descriptor Element");
bool HasError = false;
dxbc::RTS0::v1::RootParameterHeader Header;
switch (ElementKind) {
case RootSignatureElementKind::SRV:
Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
break;
case RootSignatureElementKind::UAV:
Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
break;
case RootSignatureElementKind::CBV:
Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
break;
default:
llvm_unreachable("invalid Root Descriptor kind");
break;
}
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
Header.ShaderVisibility = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
RootDescriptorNode, 1) ||
HasError;
dxbc::RTS0::v2::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
RootDescriptorNode, 2) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
Descriptor.RegisterSpace = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
RootDescriptorNode, 3) ||
HasError;
if (RSD.Version == 1) {
if (!HasError)
RSD.ParametersContainer.addParameter(Header, Descriptor);
return HasError;
}
assert(RSD.Version > 1);
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
Descriptor.Flags = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RootDescriptorNode",
RootDescriptorNode, 4) ||
HasError;
if (!HasError)
RSD.ParametersContainer.addParameter(Header, Descriptor);
return HasError;
}
static bool parseDescriptorRange(LLVMContext *Ctx,
mcdxbc::RootSignatureDesc &RSD,
mcdxbc::DescriptorTable &Table,
MDNode *RangeDescriptorNode) {
if (RangeDescriptorNode->getNumOperands() != 6)
return reportError(Ctx, "Invalid format for Descriptor Range");
bool HasError = false;
dxbc::RTS0::v2::DescriptorRange Range;
std::optional<StringRef> ElementText =
extractMdStringValue(RangeDescriptorNode, 0);
if (!ElementText.has_value())
HasError = reportInvalidTypeError<MDString>(Ctx, "RangeDescriptorNode",
RangeDescriptorNode, 0) ||
HasError;
Range.RangeType =
StringSwitch<uint32_t>(*ElementText)
.Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
.Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
.Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
.Case("Sampler",
llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
.Default(-1u);
if (Range.RangeType == -1u)
HasError =
reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
RangeDescriptorNode, 1) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
RangeDescriptorNode, 2) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
RangeDescriptorNode, 3) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
RangeDescriptorNode, 4) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "RangeDescriptorNode",
RangeDescriptorNode, 5) ||
HasError;
if (!HasError)
Table.Ranges.push_back(Range);
return HasError;
}
static bool parseDescriptorTable(LLVMContext *Ctx,
mcdxbc::RootSignatureDesc &RSD,
MDNode *DescriptorTableNode) {
const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
if (NumOperands < 2)
return reportError(Ctx, "Invalid format for Descriptor Table");
bool HasError = false;
dxbc::RTS0::v1::RootParameterHeader Header;
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
Header.ShaderVisibility = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "DescriptorTableNode",
DescriptorTableNode, 1) ||
HasError;
mcdxbc::DescriptorTable Table;
Header.ParameterType =
llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
if (Element == nullptr)
HasError = reportInvalidTypeError<MDNode>(Ctx, "DescriptorTableNode",
DescriptorTableNode, I) ||
HasError;
if (parseDescriptorRange(Ctx, RSD, Table, Element))
HasError = true || HasError;
}
if (!HasError)
RSD.ParametersContainer.addParameter(Header, Table);
return HasError;
}
static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
MDNode *StaticSamplerNode) {
if (StaticSamplerNode->getNumOperands() != 14)
return reportError(Ctx, "Invalid format for Static Sampler");
bool HasError = false;
dxbc::RTS0::v1::StaticSampler Sampler;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
Sampler.Filter = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 1) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
Sampler.AddressU = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 2) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
Sampler.AddressV = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 3) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
Sampler.AddressW = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 4) ||
HasError;
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = Val->convertToFloat();
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 5) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 6) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
Sampler.ComparisonFunc = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 7) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
Sampler.BorderColor = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 8) ||
HasError;
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = Val->convertToFloat();
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 9) ||
HasError;
if (std::optional<APFloat> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = Val->convertToFloat();
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 10) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 11) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 12) ||
HasError;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
Sampler.ShaderVisibility = *Val;
else
HasError = reportInvalidTypeError<ConstantInt>(Ctx, "StaticSamplerNode",
StaticSamplerNode, 13) ||
HasError;
if (!HasError)
RSD.StaticSamplers.push_back(Sampler);
return HasError;
}
static bool parseRootSignatureElement(LLVMContext *Ctx,
mcdxbc::RootSignatureDesc &RSD,
MDNode *Element) {
std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
if (!ElementText.has_value())
return reportError(Ctx, "Invalid format for Root Element");
RootSignatureElementKind ElementKind =
StringSwitch<RootSignatureElementKind>(*ElementText)
.Case("RootFlags", RootSignatureElementKind::RootFlags)
.Case("RootConstants", RootSignatureElementKind::RootConstants)
.Case("RootCBV", RootSignatureElementKind::CBV)
.Case("RootSRV", RootSignatureElementKind::SRV)
.Case("RootUAV", RootSignatureElementKind::UAV)
.Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
.Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
.Default(RootSignatureElementKind::Error);
switch (ElementKind) {
case RootSignatureElementKind::RootFlags:
return parseRootFlags(Ctx, RSD, Element);
case RootSignatureElementKind::RootConstants:
return parseRootConstants(Ctx, RSD, Element);
case RootSignatureElementKind::CBV:
case RootSignatureElementKind::SRV:
case RootSignatureElementKind::UAV:
return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
case RootSignatureElementKind::DescriptorTable:
return parseDescriptorTable(Ctx, RSD, Element);
case RootSignatureElementKind::StaticSamplers:
return parseStaticSampler(Ctx, RSD, Element);
case RootSignatureElementKind::Error:
return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
}
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
}
static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
MDNode *Node) {
bool HasError = false;
// Loop through the Root Elements of the root signature.
for (const auto &Operand : Node->operands()) {
MDNode *Element = dyn_cast<MDNode>(Operand);
if (Element == nullptr)
return reportError(Ctx, "Missing Root Element Metadata Node.");
HasError = parseRootSignatureElement(Ctx, RSD, Element) || HasError;
}
return HasError;
}
static bool verifyRootFlag(uint32_t Flags) { return (Flags & ~0xfff) == 0; }
static bool verifyVersion(uint32_t Version) {
return (Version == 1 || Version == 2);
}
static bool verifyRegisterValue(uint32_t RegisterValue) {
return RegisterValue != ~0U;
}
// This Range is reserverved, therefore invalid, according to the spec
// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
static bool verifyRegisterSpace(uint32_t RegisterSpace) {
return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace <= 0xFFFFFFFF);
}
static bool verifyDescriptorFlag(uint32_t Flags) { return (Flags & ~0xE) == 0; }
static bool verifyRangeType(uint32_t Type) {
switch (Type) {
case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
return true;
};
return false;
}
template <typename... FlagTypes>
static bool isFlagSet(uint32_t Flags, FlagTypes... FlagsToCheck) {
return ((Flags & llvm::to_underlying(FlagsToCheck)) | ...) == Flags;
}
static bool verifyDescriptorRangeFlag(uint32_t Version, uint32_t Type,
uint32_t FlagsVal) {
using FlagT = dxbc::DescriptorRangeFlag;
FlagT Flags = FlagT(FlagsVal);
const bool IsSampler =
(Type == llvm::to_underlying(dxbc::DescriptorRangeType::Sampler));
if (Version == 1) {
if (IsSampler)
return Flags == FlagT::NONE;
return Flags == FlagT::DESCRIPTORS_VOLATILE;
}
// The data-specific flags are mutually exclusive.
FlagT DataFlags = FlagT::DATA_VOLATILE | FlagT::DATA_STATIC |
FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
if (popcount(llvm::to_underlying(Flags & DataFlags)) > 1)
return false;
// For volatile descriptors, DATA_STATIC is never valid.
if ((Flags & FlagT::DESCRIPTORS_VOLATILE) == FlagT::DESCRIPTORS_VOLATILE) {
FlagT Mask = FlagT::DESCRIPTORS_VOLATILE;
if (!IsSampler) {
Mask |= FlagT::DATA_VOLATILE;
Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
}
return (Flags & ~Mask) == FlagT::NONE;
}
// For "STATIC_KEEPING_BUFFER_BOUNDS_CHECKS" descriptors,
// the other data-specific flags may all be set.
if ((Flags & FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS) ==
FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS) {
FlagT Mask = FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS;
if (!IsSampler) {
Mask |= FlagT::DATA_VOLATILE;
Mask |= FlagT::DATA_STATIC;
Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
}
return (Flags & ~Mask) == FlagT::NONE;
}
// When no descriptor flag is set, any data flag is allowed.
return (Flags & ~DataFlags) == FlagT::NONE;
}
static bool verifySamplerFilter(uint32_t Filter) {
switch (Filter) {
case llvm::to_underlying(dxbc::StaticSamplerFilter::MIN_MAG_MIP_POINT):
case llvm::to_underlying(dxbc::StaticSamplerFilter::MIN_MAG_POINT_MIP_LINEAR):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MIN_POINT_MAG_LINEAR_MIP_POINT):
case llvm::to_underlying(dxbc::StaticSamplerFilter::MIN_POINT_MAG_MIP_LINEAR):
case llvm::to_underlying(dxbc::StaticSamplerFilter::MIN_LINEAR_MAG_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MIN_LINEAR_MAG_POINT_MIP_LINEAR):
case llvm::to_underlying(dxbc::StaticSamplerFilter::MIN_MAG_LINEAR_MIP_POINT):
case llvm::to_underlying(dxbc::StaticSamplerFilter::MIN_MAG_MIP_LINEAR):
case llvm::to_underlying(dxbc::StaticSamplerFilter::ANISOTROPIC):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::COMPARISON_MIN_MAG_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::COMPARISON_MIN_MAG_POINT_MIP_LINEAR):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::COMPARISON_MIN_POINT_MAG_LINEAR_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::COMPARISON_MIN_POINT_MAG_MIP_LINEAR):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::COMPARISON_MIN_LINEAR_MAG_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::COMPARISON_MIN_LINEAR_MAG_POINT_MIP_LINEAR):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::COMPARISON_MIN_MAG_LINEAR_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::COMPARISON_MIN_MAG_MIP_LINEAR):
case llvm::to_underlying(dxbc::StaticSamplerFilter::COMPARISON_ANISOTROPIC):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MINIMUM_MIN_MAG_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MINIMUM_MIN_MAG_POINT_MIP_LINEAR):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MINIMUM_MIN_POINT_MAG_LINEAR_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MINIMUM_MIN_POINT_MAG_MIP_LINEAR):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MINIMUM_MIN_LINEAR_MAG_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MINIMUM_MIN_LINEAR_MAG_POINT_MIP_LINEAR):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MINIMUM_MIN_MAG_LINEAR_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MINIMUM_MIN_MAG_MIP_LINEAR):
case llvm::to_underlying(dxbc::StaticSamplerFilter::MINIMUM_ANISOTROPIC):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MAXIMUM_MIN_MAG_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MAXIMUM_MIN_MAG_POINT_MIP_LINEAR):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MAXIMUM_MIN_POINT_MAG_LINEAR_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MAXIMUM_MIN_POINT_MAG_MIP_LINEAR):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MAXIMUM_MIN_LINEAR_MAG_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MAXIMUM_MIN_LINEAR_MAG_POINT_MIP_LINEAR):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MAXIMUM_MIN_MAG_LINEAR_MIP_POINT):
case llvm::to_underlying(
dxbc::StaticSamplerFilter::MAXIMUM_MIN_MAG_MIP_LINEAR):
case llvm::to_underlying(dxbc::StaticSamplerFilter::MAXIMUM_ANISOTROPIC):
return true;
}
return false;
}
// Values allowed here:
// https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_texture_address_mode#syntax
static bool verifyAddress(uint32_t Address) {
switch (Address) {
case llvm::to_underlying(dxbc::TextureAddressMode::Border):
case llvm::to_underlying(dxbc::TextureAddressMode::Clamp):
case llvm::to_underlying(dxbc::TextureAddressMode::Mirror):
case llvm::to_underlying(dxbc::TextureAddressMode::MirrorOnce):
case llvm::to_underlying(dxbc::TextureAddressMode::Wrap):
return true;
}
return false;
}
static bool verifyMipLODBias(float MipLODBias) {
return MipLODBias >= -16.f && MipLODBias <= 16.f;
}
static bool verifyMaxAnisotropy(uint32_t MaxAnisotropy) {
return MaxAnisotropy <= 16u;
}
static bool verifyComparisonFunc(uint32_t ComparisonFunc) {
switch (ComparisonFunc) {
case llvm::to_underlying(dxbc::SamplersComparisonFunction::Never):
case llvm::to_underlying(dxbc::SamplersComparisonFunction::Less):
case llvm::to_underlying(dxbc::SamplersComparisonFunction::Equal):
case llvm::to_underlying(dxbc::SamplersComparisonFunction::LessEqual):
case llvm::to_underlying(dxbc::SamplersComparisonFunction::Greater):
case llvm::to_underlying(dxbc::SamplersComparisonFunction::NotEqual):
case llvm::to_underlying(dxbc::SamplersComparisonFunction::GreaterEqual):
case llvm::to_underlying(dxbc::SamplersComparisonFunction::Always):
return true;
}
return false;
}
static bool verifyBorderColor(uint32_t BorderColor) {
switch (BorderColor) {
case llvm::to_underlying(dxbc::SamplersBorderColor::TransparentBlack):
case llvm::to_underlying(dxbc::SamplersBorderColor::OpaqueBlack):
case llvm::to_underlying(dxbc::SamplersBorderColor::OpaqueWhite):
case llvm::to_underlying(dxbc::SamplersBorderColor::OpaqueBlackUint):
case llvm::to_underlying(dxbc::SamplersBorderColor::OpaqueWhiteUint):
return true;
}
return false;
}
static bool verifyLOD(float LOD) { return !std::isnan(LOD); }
static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
bool HasError = false;
if (!verifyVersion(RSD.Version)) {
HasError = reportValueError(Ctx, "Version", RSD.Version) || HasError;
}
if (!verifyRootFlag(RSD.Flags)) {
HasError = reportValueError(Ctx, "RootFlags", RSD.Flags) || HasError;
}
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
HasError = reportValueError(Ctx, "ShaderVisibility",
Info.Header.ShaderVisibility) ||
HasError;
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
"Invalid value for ParameterType");
switch (Info.Header.ParameterType) {
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!verifyRegisterValue(Descriptor.ShaderRegister))
HasError = reportValueError(Ctx, "ShaderRegister",
Descriptor.ShaderRegister) ||
HasError;
if (!verifyRegisterSpace(Descriptor.RegisterSpace))
HasError =
reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace) ||
HasError;
if (RSD.Version > 1) {
if (!verifyDescriptorFlag(Descriptor.Flags))
HasError =
reportValueError(Ctx, "DescriptorFlag", Descriptor.Flags) ||
HasError;
}
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
if (!verifyRangeType(Range.RangeType))
HasError =
reportValueError(Ctx, "RangeType", Range.RangeType) || HasError;
if (!verifyRegisterSpace(Range.RegisterSpace))
HasError =
reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace) ||
HasError;
if (!verifyDescriptorRangeFlag(RSD.Version, Range.RangeType,
Range.Flags))
HasError =
reportValueError(Ctx, "DescriptorFlag", Range.Flags) || HasError;
}
break;
}
}
}
for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
if (!verifySamplerFilter(Sampler.Filter))
HasError = reportValueError(Ctx, "Filter", Sampler.Filter) || HasError;
if (!verifyAddress(Sampler.AddressU))
HasError =
reportValueError(Ctx, "AddressU", Sampler.AddressU) || HasError;
if (!verifyAddress(Sampler.AddressV))
HasError =
reportValueError(Ctx, "AddressV", Sampler.AddressV) || HasError;
if (!verifyAddress(Sampler.AddressW))
HasError =
reportValueError(Ctx, "AddressW", Sampler.AddressW) || HasError;
if (!verifyMipLODBias(Sampler.MipLODBias))
HasError =
reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias) || HasError;
if (!verifyMaxAnisotropy(Sampler.MaxAnisotropy))
HasError =
reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy) ||
HasError;
if (!verifyComparisonFunc(Sampler.ComparisonFunc))
HasError =
reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc) ||
HasError;
if (!verifyBorderColor(Sampler.BorderColor))
HasError =
reportValueError(Ctx, "BorderColor", Sampler.BorderColor) || HasError;
if (!verifyLOD(Sampler.MinLOD))
HasError = reportValueError(Ctx, "MinLOD", Sampler.MinLOD) || HasError;
if (!verifyLOD(Sampler.MaxLOD))
HasError = reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD) || HasError;
if (!verifyRegisterValue(Sampler.ShaderRegister))
HasError =
reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister) ||
HasError;
if (!verifyRegisterSpace(Sampler.RegisterSpace))
HasError =
reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace) ||
HasError;
if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
HasError =
reportValueError(Ctx, "ShaderVisibility", Sampler.ShaderVisibility) ||
HasError;
}
return HasError;
}
static SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
analyzeModule(Module &M) {
/** Root Signature are specified as following in the metadata:
!dx.rootsignatures = !{!2} ; list of function/root signature pairs
!2 = !{ ptr @main, !3 } ; function, root signature
!3 = !{ !4, !5, !6, !7 } ; list of root signature elements
So for each MDNode inside dx.rootsignatures NamedMDNode
(the Root parameter of this function), the parsing process needs
to loop through each of its operands and process the function,
signature pair.
*/
LLVMContext *Ctx = &M.getContext();
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> RSDMap;
NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
if (RootSignatureNode == nullptr)
return RSDMap;
for (const auto &RSDefNode : RootSignatureNode->operands()) {
if (RSDefNode->getNumOperands() != 2) {
reportError(Ctx, "Invalid format for Root Signature Definition. Pairs "
"of function, root signature expected.");
continue;
}
// Function was pruned during compilation.
const MDOperand &FunctionPointerMdNode = RSDefNode->getOperand(0);
if (FunctionPointerMdNode == nullptr) {
reportError(
Ctx, "Function associated with Root Signature definition is null.");
continue;
}
ValueAsMetadata *VAM =
llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get());
if (VAM == nullptr) {
reportError(Ctx, "First element of root signature is not a Value");
continue;
}
Function *F = dyn_cast<Function>(VAM->getValue());
if (F == nullptr) {
reportError(Ctx, "First element of root signature is not a Function");
continue;
}
Metadata *RootElementListOperand = RSDefNode->getOperand(1).get();
if (RootElementListOperand == nullptr) {
reportError(Ctx, "Root Element mdnode is null.");
continue;
}
MDNode *RootElementListNode = dyn_cast<MDNode>(RootElementListOperand);
if (RootElementListNode == nullptr) {
reportError(Ctx, "Root Element is not a metadata node.");
continue;
}
mcdxbc::RootSignatureDesc RSD;
// Clang emits the root signature data in dxcontainer following a specific
// sequence. First the header, then the root parameters. So the header
// offset will always equal to the header size.
RSD.RootParameterOffset = sizeof(dxbc::RTS0::v1::RootSignatureHeader);
// static sampler offset is calculated when writting dxcontainer.
RSD.StaticSamplersOffset = 0u;
if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) {
return RSDMap;
}
RSDMap.insert(std::make_pair(F, RSD));
}
return RSDMap;
}
AnalysisKey RootSignatureAnalysis::Key;
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
return analyzeModule(M);
}
//===----------------------------------------------------------------------===//
PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
ModuleAnalysisManager &AM) {
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> &RSDMap =
AM.getResult<RootSignatureAnalysis>(M);
OS << "Root Signature Definitions"
<< "\n";
for (const Function &F : M) {
auto It = RSDMap.find(&F);
if (It == RSDMap.end())
continue;
const auto &RS = It->second;
OS << "Definition for '" << F.getName() << "':\n";
// start root signature header
OS << "Flags: " << format_hex(RS.Flags, 8) << "\n";
OS << "Version: " << RS.Version << "\n";
OS << "RootParametersOffset: " << RS.RootParameterOffset << "\n";
OS << "NumParameters: " << RS.ParametersContainer.size() << "\n";
for (size_t I = 0; I < RS.ParametersContainer.size(); I++) {
const auto &[Type, Loc] =
RS.ParametersContainer.getTypeAndLocForParameter(I);
const dxbc::RTS0::v1::RootParameterHeader Header =
RS.ParametersContainer.getHeader(I);
OS << "- Parameter Type: " << Type << "\n";
OS << indent(2) << "Shader Visibility: " << Header.ShaderVisibility
<< "\n";
switch (Type) {
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
const dxbc::RTS0::v1::RootConstants &Constants =
RS.ParametersContainer.getConstant(Loc);
OS << indent(2) << "Register Space: " << Constants.RegisterSpace
<< "\n";
OS << indent(2) << "Shader Register: " << Constants.ShaderRegister
<< "\n";
OS << indent(2) << "Num 32 Bit Values: " << Constants.Num32BitValues
<< "\n";
break;
}
case llvm::to_underlying(dxbc::RootParameterType::CBV):
case llvm::to_underlying(dxbc::RootParameterType::UAV):
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
RS.ParametersContainer.getRootDescriptor(Loc);
OS << indent(2) << "Register Space: " << Descriptor.RegisterSpace
<< "\n";
OS << indent(2) << "Shader Register: " << Descriptor.ShaderRegister
<< "\n";
if (RS.Version > 1)
OS << indent(2) << "Flags: " << Descriptor.Flags << "\n";
break;
}
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
const mcdxbc::DescriptorTable &Table =
RS.ParametersContainer.getDescriptorTable(Loc);
OS << indent(2) << "NumRanges: " << Table.Ranges.size() << "\n";
for (const dxbc::RTS0::v2::DescriptorRange Range : Table) {
OS << indent(2) << "- Range Type: " << Range.RangeType << "\n";
OS << indent(4) << "Register Space: " << Range.RegisterSpace << "\n";
OS << indent(4)
<< "Base Shader Register: " << Range.BaseShaderRegister << "\n";
OS << indent(4) << "Num Descriptors: " << Range.NumDescriptors
<< "\n";
OS << indent(4) << "Offset In Descriptors From Table Start: "
<< Range.OffsetInDescriptorsFromTableStart << "\n";
if (RS.Version > 1)
OS << indent(4) << "Flags: " << Range.Flags << "\n";
}
break;
}
}
}
OS << "NumStaticSamplers: " << 0 << "\n";
OS << "StaticSamplersOffset: " << RS.StaticSamplersOffset << "\n";
}
return PreservedAnalyses::all();
}
//===----------------------------------------------------------------------===//
bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {
FuncToRsMap = analyzeModule(M);
return false;
}
void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
}
char RootSignatureAnalysisWrapper::ID = 0;
INITIALIZE_PASS_BEGIN(RootSignatureAnalysisWrapper,
"dxil-root-signature-analysis",
"DXIL Root Signature Analysis", true, true)
INITIALIZE_PASS_END(RootSignatureAnalysisWrapper,
"dxil-root-signature-analysis",
"DXIL Root Signature Analysis", true, true)