blob: 132a280c659a1c93e9d4fdc01d68d8a589570368 [file] [log] [blame] [edit]
//===- RootSignatureMetadata.h - HLSL Root Signature helpers --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file implements a library for working with HLSL Root Signatures
/// and their metadata representation.
///
//===----------------------------------------------------------------------===//
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/DXILABI.h"
#include "llvm/Support/ScopedPrinter.h"
using namespace llvm;
namespace llvm {
namespace hlsl {
namespace rootsig {
char RootSignatureValidationError::ID;
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
unsigned int OpId) {
if (auto *CI =
mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
return CI->getZExtValue();
return std::nullopt;
}
static std::optional<float> extractMdFloatValue(MDNode *Node,
unsigned int OpId) {
if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
return CI->getValueAPF().convertToFloat();
return std::nullopt;
}
static std::optional<StringRef> extractMdStringValue(MDNode *Node,
unsigned int OpId) {
MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
if (NodeText == nullptr)
return std::nullopt;
return NodeText->getString();
}
namespace {
// We use the OverloadVisit with std::visit to ensure the compiler catches if a
// new RootElement variant type is added but it's metadata generation isn't
// handled.
template <class... Ts> struct OverloadedVisit : Ts... {
using Ts::operator()...;
};
template <class... Ts> OverloadedVisit(Ts...) -> OverloadedVisit<Ts...>;
struct FmtRange {
dxil::ResourceClass Type;
uint32_t Register;
uint32_t Space;
FmtRange(const mcdxbc::DescriptorRange &Range)
: Type(Range.RangeType), Register(Range.BaseShaderRegister),
Space(Range.RegisterSpace) {}
};
raw_ostream &operator<<(llvm::raw_ostream &OS, const FmtRange &Range) {
OS << getResourceClassName(Range.Type) << "(register=" << Range.Register
<< ", space=" << Range.Space << ")";
return OS;
}
struct FmtMDNode {
const MDNode *Node;
FmtMDNode(const MDNode *Node) : Node(Node) {}
};
raw_ostream &operator<<(llvm::raw_ostream &OS, FmtMDNode Fmt) {
Fmt.Node->printTree(OS);
return OS;
}
static Error makeRSError(const Twine &Msg) {
return make_error<RootSignatureValidationError>(Msg);
}
} // namespace
template <typename T, typename = std::enable_if_t<
std::is_enum_v<T> &&
std::is_same_v<std::underlying_type_t<T>, uint32_t>>>
static Expected<T>
extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText,
llvm::function_ref<bool(uint32_t)> VerifyFn) {
if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
if (!VerifyFn(*Val))
return makeRSError(formatv("Invalid value for {0}: {1}", ErrText, Val));
return static_cast<T>(*Val);
}
return makeRSError(formatv("Invalid value for {0}:", ErrText));
}
MDNode *MetadataBuilder::BuildRootSignature() {
const auto Visitor = OverloadedVisit{
[this](const dxbc::RootFlags &Flags) -> MDNode * {
return BuildRootFlags(Flags);
},
[this](const RootConstants &Constants) -> MDNode * {
return BuildRootConstants(Constants);
},
[this](const RootDescriptor &Descriptor) -> MDNode * {
return BuildRootDescriptor(Descriptor);
},
[this](const DescriptorTableClause &Clause) -> MDNode * {
return BuildDescriptorTableClause(Clause);
},
[this](const DescriptorTable &Table) -> MDNode * {
return BuildDescriptorTable(Table);
},
[this](const StaticSampler &Sampler) -> MDNode * {
return BuildStaticSampler(Sampler);
},
};
for (const RootElement &Element : Elements) {
MDNode *ElementMD = std::visit(Visitor, Element);
assert(ElementMD != nullptr &&
"Root Element must be initialized and validated");
GeneratedMetadata.push_back(ElementMD);
}
return MDNode::get(Ctx, GeneratedMetadata);
}
MDNode *MetadataBuilder::BuildRootFlags(const dxbc::RootFlags &Flags) {
IRBuilder<> Builder(Ctx);
Metadata *Operands[] = {
MDString::get(Ctx, "RootFlags"),
ConstantAsMetadata::get(Builder.getInt32(to_underlying(Flags))),
};
return MDNode::get(Ctx, Operands);
}
MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
IRBuilder<> Builder(Ctx);
Metadata *Operands[] = {
MDString::get(Ctx, "RootConstants"),
ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Constants.Visibility))),
ConstantAsMetadata::get(Builder.getInt32(Constants.Reg.Number)),
ConstantAsMetadata::get(Builder.getInt32(Constants.Space)),
ConstantAsMetadata::get(Builder.getInt32(Constants.Num32BitConstants)),
};
return MDNode::get(Ctx, Operands);
}
MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
IRBuilder<> Builder(Ctx);
StringRef ResName = dxil::getResourceClassName(Descriptor.Type);
assert(!ResName.empty() && "Provided an invalid Resource Class");
SmallString<7> Name({"Root", ResName});
Metadata *Operands[] = {
MDString::get(Ctx, Name),
ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Descriptor.Visibility))),
ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)),
ConstantAsMetadata::get(Builder.getInt32(Descriptor.Space)),
ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Descriptor.Flags))),
};
return MDNode::get(Ctx, Operands);
}
MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
IRBuilder<> Builder(Ctx);
SmallVector<Metadata *> TableOperands;
// Set the mandatory arguments
TableOperands.push_back(MDString::get(Ctx, "DescriptorTable"));
TableOperands.push_back(ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Table.Visibility))));
// Remaining operands are references to the table's clauses. The in-memory
// representation of the Root Elements created from parsing will ensure that
// the previous N elements are the clauses for this table.
assert(Table.NumClauses <= GeneratedMetadata.size() &&
"Table expected all owned clauses to be generated already");
// So, add a refence to each clause to our operands
TableOperands.append(GeneratedMetadata.end() - Table.NumClauses,
GeneratedMetadata.end());
// Then, remove those clauses from the general list of Root Elements
GeneratedMetadata.pop_back_n(Table.NumClauses);
return MDNode::get(Ctx, TableOperands);
}
MDNode *MetadataBuilder::BuildDescriptorTableClause(
const DescriptorTableClause &Clause) {
IRBuilder<> Builder(Ctx);
StringRef ResName = dxil::getResourceClassName(Clause.Type);
assert(!ResName.empty() && "Provided an invalid Resource Class");
Metadata *Operands[] = {
MDString::get(Ctx, ResName),
ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
ConstantAsMetadata::get(Builder.getInt32(to_underlying(Clause.Flags))),
};
return MDNode::get(Ctx, Operands);
}
MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
IRBuilder<> Builder(Ctx);
Metadata *Operands[] = {
MDString::get(Ctx, "StaticSampler"),
ConstantAsMetadata::get(Builder.getInt32(to_underlying(Sampler.Filter))),
ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Sampler.AddressU))),
ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Sampler.AddressV))),
ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Sampler.AddressW))),
ConstantAsMetadata::get(
ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MipLODBias)),
ConstantAsMetadata::get(Builder.getInt32(Sampler.MaxAnisotropy)),
ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Sampler.CompFunc))),
ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Sampler.BorderColor))),
ConstantAsMetadata::get(
ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MinLOD)),
ConstantAsMetadata::get(
ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MaxLOD)),
ConstantAsMetadata::get(Builder.getInt32(Sampler.Reg.Number)),
ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)),
ConstantAsMetadata::get(
Builder.getInt32(to_underlying(Sampler.Visibility))),
ConstantAsMetadata::get(Builder.getInt32(to_underlying(Sampler.Flags))),
};
return MDNode::get(Ctx, Operands);
}
Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
MDNode *RootFlagNode) {
if (RootFlagNode->getNumOperands() != 2)
return makeRSError("Invalid format for RootFlags Element");
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
return makeRSError("Invalid value for RootFlag");
return Error::success();
}
Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
MDNode *RootConstantNode) {
if (RootConstantNode->getNumOperands() != 5)
return makeRSError("Invalid format for RootConstants Element");
Expected<dxbc::ShaderVisibility> Visibility =
extractEnumValue<dxbc::ShaderVisibility>(RootConstantNode, 1,
"ShaderVisibility",
dxbc::isValidShaderVisibility);
if (auto E = Visibility.takeError())
return Error(std::move(E));
mcdxbc::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
return makeRSError("Invalid value for ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
Constants.RegisterSpace = *Val;
else
return makeRSError("Invalid value for RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
Constants.Num32BitValues = *Val;
else
return makeRSError("Invalid value for Num32BitValues");
RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit,
*Visibility, Constants);
return Error::success();
}
Error MetadataParser::parseRootDescriptors(
mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode,
RootSignatureElementKind ElementKind) {
assert((ElementKind == RootSignatureElementKind::SRV ||
ElementKind == RootSignatureElementKind::UAV ||
ElementKind == RootSignatureElementKind::CBV) &&
"parseRootDescriptors should only be called with RootDescriptor "
"element kind.");
if (RootDescriptorNode->getNumOperands() != 5)
return makeRSError("Invalid format for Root Descriptor Element");
dxbc::RootParameterType Type;
switch (ElementKind) {
case RootSignatureElementKind::SRV:
Type = dxbc::RootParameterType::SRV;
break;
case RootSignatureElementKind::UAV:
Type = dxbc::RootParameterType::UAV;
break;
case RootSignatureElementKind::CBV:
Type = dxbc::RootParameterType::CBV;
break;
default:
llvm_unreachable("invalid Root Descriptor kind");
break;
}
Expected<dxbc::ShaderVisibility> Visibility =
extractEnumValue<dxbc::ShaderVisibility>(RootDescriptorNode, 1,
"ShaderVisibility",
dxbc::isValidShaderVisibility);
if (auto E = Visibility.takeError())
return Error(std::move(E));
mcdxbc::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
return makeRSError("Invalid value for ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
Descriptor.RegisterSpace = *Val;
else
return makeRSError("Invalid value for RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
Descriptor.Flags = *Val;
else
return makeRSError("Invalid value for Root Descriptor Flags");
RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
return Error::success();
}
Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
MDNode *RangeDescriptorNode) {
if (RangeDescriptorNode->getNumOperands() != 6)
return makeRSError("Invalid format for Descriptor Range");
mcdxbc::DescriptorRange Range;
std::optional<StringRef> ElementText =
extractMdStringValue(RangeDescriptorNode, 0);
if (!ElementText.has_value())
return makeRSError("Invalid format for Descriptor Range");
if (*ElementText == "CBV")
Range.RangeType = dxil::ResourceClass::CBuffer;
else if (*ElementText == "SRV")
Range.RangeType = dxil::ResourceClass::SRV;
else if (*ElementText == "UAV")
Range.RangeType = dxil::ResourceClass::UAV;
else if (*ElementText == "Sampler")
Range.RangeType = dxil::ResourceClass::Sampler;
else
return makeRSError(formatv("Invalid Descriptor Range type.\n{0}",
FmtMDNode{RangeDescriptorNode}));
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
return makeRSError(formatv("Invalid number of Descriptor in Range.\n{0}",
FmtMDNode{RangeDescriptorNode}));
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
return makeRSError("Invalid value for BaseShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
return makeRSError("Invalid value for RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
return makeRSError("Invalid value for OffsetInDescriptorsFromTableStart");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
return makeRSError("Invalid value for Descriptor Range Flags");
Table.Ranges.push_back(Range);
return Error::success();
}
Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
MDNode *DescriptorTableNode) {
const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
if (NumOperands < 2)
return makeRSError("Invalid format for Descriptor Table");
Expected<dxbc::ShaderVisibility> Visibility =
extractEnumValue<dxbc::ShaderVisibility>(DescriptorTableNode, 1,
"ShaderVisibility",
dxbc::isValidShaderVisibility);
if (auto E = Visibility.takeError())
return Error(std::move(E));
mcdxbc::DescriptorTable Table;
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
if (Element == nullptr)
return makeRSError(formatv("Missing Root Element Metadata Node.\n{0}",
FmtMDNode{DescriptorTableNode}));
if (auto Err = parseDescriptorRange(Table, Element))
return Err;
}
RSD.ParametersContainer.addParameter(dxbc::RootParameterType::DescriptorTable,
*Visibility, Table);
return Error::success();
}
Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
MDNode *StaticSamplerNode) {
if (StaticSamplerNode->getNumOperands() != 15)
return makeRSError("Invalid format for Static Sampler");
mcdxbc::StaticSampler Sampler;
Expected<dxbc::SamplerFilter> Filter = extractEnumValue<dxbc::SamplerFilter>(
StaticSamplerNode, 1, "Filter", dxbc::isValidSamplerFilter);
if (auto E = Filter.takeError())
return Error(std::move(E));
Sampler.Filter = *Filter;
Expected<dxbc::TextureAddressMode> AddressU =
extractEnumValue<dxbc::TextureAddressMode>(
StaticSamplerNode, 2, "AddressU", dxbc::isValidAddress);
if (auto E = AddressU.takeError())
return Error(std::move(E));
Sampler.AddressU = *AddressU;
Expected<dxbc::TextureAddressMode> AddressV =
extractEnumValue<dxbc::TextureAddressMode>(
StaticSamplerNode, 3, "AddressV", dxbc::isValidAddress);
if (auto E = AddressV.takeError())
return Error(std::move(E));
Sampler.AddressV = *AddressV;
Expected<dxbc::TextureAddressMode> AddressW =
extractEnumValue<dxbc::TextureAddressMode>(
StaticSamplerNode, 4, "AddressW", dxbc::isValidAddress);
if (auto E = AddressW.takeError())
return Error(std::move(E));
Sampler.AddressW = *AddressW;
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = *Val;
else
return makeRSError("Invalid value for MipLODBias");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
return makeRSError("Invalid value for MaxAnisotropy");
Expected<dxbc::ComparisonFunc> ComparisonFunc =
extractEnumValue<dxbc::ComparisonFunc>(
StaticSamplerNode, 7, "ComparisonFunc", dxbc::isValidComparisonFunc);
if (auto E = ComparisonFunc.takeError())
return Error(std::move(E));
Sampler.ComparisonFunc = *ComparisonFunc;
Expected<dxbc::StaticBorderColor> BorderColor =
extractEnumValue<dxbc::StaticBorderColor>(
StaticSamplerNode, 8, "BorderColor", dxbc::isValidBorderColor);
if (auto E = BorderColor.takeError())
return Error(std::move(E));
Sampler.BorderColor = *BorderColor;
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = *Val;
else
return makeRSError("Invalid value for MinLOD");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = *Val;
else
return makeRSError("Invalid value for MaxLOD");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
return makeRSError("Invalid value for ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
return makeRSError("Invalid value for RegisterSpace");
Expected<dxbc::ShaderVisibility> Visibility =
extractEnumValue<dxbc::ShaderVisibility>(StaticSamplerNode, 13,
"ShaderVisibility",
dxbc::isValidShaderVisibility);
if (auto E = Visibility.takeError())
return Error(std::move(E));
Sampler.ShaderVisibility = *Visibility;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 14))
Sampler.Flags = *Val;
else
return makeRSError("Invalid value for Static Sampler Flags");
RSD.StaticSamplers.push_back(Sampler);
return Error::success();
}
Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
MDNode *Element) {
std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
if (!ElementText.has_value())
return makeRSError("Invalid format for Root Element");
RootSignatureElementKind ElementKind =
StringSwitch<RootSignatureElementKind>(*ElementText)
.Case("RootFlags", RootSignatureElementKind::RootFlags)
.Case("RootConstants", RootSignatureElementKind::RootConstants)
.Case("RootCBV", RootSignatureElementKind::CBV)
.Case("RootSRV", RootSignatureElementKind::SRV)
.Case("RootUAV", RootSignatureElementKind::UAV)
.Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
.Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
.Default(RootSignatureElementKind::Error);
switch (ElementKind) {
case RootSignatureElementKind::RootFlags:
return parseRootFlags(RSD, Element);
case RootSignatureElementKind::RootConstants:
return parseRootConstants(RSD, Element);
case RootSignatureElementKind::CBV:
case RootSignatureElementKind::SRV:
case RootSignatureElementKind::UAV:
return parseRootDescriptors(RSD, Element, ElementKind);
case RootSignatureElementKind::DescriptorTable:
return parseDescriptorTable(RSD, Element);
case RootSignatureElementKind::StaticSamplers:
return parseStaticSampler(RSD, Element);
case RootSignatureElementKind::Error:
return makeRSError(
formatv("Invalid Root Signature Element\n{0}", FmtMDNode{Element}));
}
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
}
static Error
validateDescriptorTableSamplerMixin(const mcdxbc::DescriptorTable &Table,
uint32_t Location) {
dxil::ResourceClass CurrRC = dxil::ResourceClass::Sampler;
for (const mcdxbc::DescriptorRange &Range : Table.Ranges) {
if (Range.RangeType == dxil::ResourceClass::Sampler &&
CurrRC != dxil::ResourceClass::Sampler)
return makeRSError(
formatv("Samplers cannot be mixed with other resource types in a "
"descriptor table, {0}(location={1})",
getResourceClassName(CurrRC), Location));
CurrRC = Range.RangeType;
}
return Error::success();
}
static Error
validateDescriptorTableRegisterOverflow(const mcdxbc::DescriptorTable &Table,
uint32_t Location) {
uint64_t Offset = 0;
bool IsPrevUnbound = false;
for (const mcdxbc::DescriptorRange &Range : Table.Ranges) {
// Validation of NumDescriptors should have happened by this point.
if (Range.NumDescriptors == 0)
continue;
const uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
Range.BaseShaderRegister, Range.NumDescriptors);
if (!verifyNoOverflowedOffset(RangeBound))
return makeRSError(
formatv("Overflow for shader register range: {0}", FmtRange{Range}));
bool IsAppending =
Range.OffsetInDescriptorsFromTableStart == DescriptorTableOffsetAppend;
if (!IsAppending)
Offset = Range.OffsetInDescriptorsFromTableStart;
if (IsPrevUnbound && IsAppending)
return makeRSError(
formatv("Range {0} cannot be appended after an unbounded range",
FmtRange{Range}));
const uint64_t OffsetBound =
llvm::hlsl::rootsig::computeRangeBound(Offset, Range.NumDescriptors);
if (!verifyNoOverflowedOffset(OffsetBound))
return makeRSError(formatv("Offset overflow for descriptor range: {0}.",
FmtRange{Range}));
Offset = OffsetBound + 1;
IsPrevUnbound =
Range.NumDescriptors == llvm::hlsl::rootsig::NumDescriptorsUnbounded;
}
return Error::success();
}
Error MetadataParser::validateRootSignature(
const mcdxbc::RootSignatureDesc &RSD) {
Error DeferredErrs = Error::success();
if (!hlsl::rootsig::verifyVersion(RSD.Version)) {
DeferredErrs = joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Invalid value for Version: {0}", RSD.Version)));
}
if (!hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
DeferredErrs = joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Invalid value for RootFlags: {0}", RSD.Flags)));
}
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
switch (Info.Type) {
case dxbc::RootParameterType::Constants32Bit:
break;
case dxbc::RootParameterType::CBV:
case dxbc::RootParameterType::UAV:
case dxbc::RootParameterType::SRV: {
const mcdxbc::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
DeferredErrs = joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Invalid value for ShaderRegister: {0}",
Descriptor.ShaderRegister)));
if (!hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
DeferredErrs = joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Invalid value for RegisterSpace: {0}",
Descriptor.RegisterSpace)));
bool IsValidFlag =
dxbc::isValidRootDesciptorFlags(Descriptor.Flags) &&
hlsl::rootsig::verifyRootDescriptorFlag(
RSD.Version, dxbc::RootDescriptorFlags(Descriptor.Flags));
if (!IsValidFlag)
DeferredErrs = joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Invalid value for RootDescriptorFlag: {0}",
Descriptor.Flags)));
break;
}
case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const mcdxbc::DescriptorRange &Range : Table) {
if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
DeferredErrs = joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Invalid value for RegisterSpace: {0}",
Range.RegisterSpace)));
if (!hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
DeferredErrs = joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Invalid value for NumDescriptors: {0}",
Range.NumDescriptors)));
bool IsValidFlag = dxbc::isValidDescriptorRangeFlags(Range.Flags) &&
hlsl::rootsig::verifyDescriptorRangeFlag(
RSD.Version, Range.RangeType,
dxbc::DescriptorRangeFlags(Range.Flags));
if (!IsValidFlag)
DeferredErrs = joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Invalid value for DescriptorFlag: {0}",
Range.Flags)));
if (Error Err =
validateDescriptorTableSamplerMixin(Table, Info.Location))
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
if (Error Err =
validateDescriptorTableRegisterOverflow(Table, Info.Location))
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
}
break;
}
}
}
for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers) {
if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
makeRSError(formatv("Invalid value for MipLODBias: {0:e}",
Sampler.MipLODBias)));
if (!hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
makeRSError(formatv("Invalid value for MaxAnisotropy: {0}",
Sampler.MaxAnisotropy)));
if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
makeRSError(formatv("Invalid value for MinLOD: {0}",
Sampler.MinLOD)));
if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
makeRSError(formatv("Invalid value for MaxLOD: {0}",
Sampler.MaxLOD)));
if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
DeferredErrs = joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Invalid value for ShaderRegister: {0}",
Sampler.ShaderRegister)));
if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
DeferredErrs =
joinErrors(std::move(DeferredErrs),
makeRSError(formatv("Invalid value for RegisterSpace: {0}",
Sampler.RegisterSpace)));
bool IsValidFlag =
dxbc::isValidStaticSamplerFlags(Sampler.Flags) &&
hlsl::rootsig::verifyStaticSamplerFlags(
RSD.Version, dxbc::StaticSamplerFlags(Sampler.Flags));
if (!IsValidFlag)
DeferredErrs = joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Invalid value for Static Sampler Flag: {0}",
Sampler.Flags)));
}
return DeferredErrs;
}
Expected<mcdxbc::RootSignatureDesc>
MetadataParser::ParseRootSignature(uint32_t Version) {
Error DeferredErrs = Error::success();
mcdxbc::RootSignatureDesc RSD;
RSD.Version = Version;
for (const auto &Operand : Root->operands()) {
MDNode *Element = dyn_cast<MDNode>(Operand);
if (Element == nullptr)
return joinErrors(
std::move(DeferredErrs),
makeRSError(formatv("Missing Root Element Metadata Node.")));
if (auto Err = parseRootSignatureElement(RSD, Element))
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
}
if (auto Err = validateRootSignature(RSD))
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
if (DeferredErrs)
return std::move(DeferredErrs);
return std::move(RSD);
}
} // namespace rootsig
} // namespace hlsl
} // namespace llvm