[Sema][Parse][HLSL] Implement front-end rootsignature validations (#156754)
This pr implements the following validations:
1. Check that descriptor tables don't mix Sample and non-Sampler
resources
2. Ensure that descriptor ranges don't append onto an unbounded range
3. Ensure that descriptor ranges don't overflow
4. Adds a missing validation to ensure that only a single `RootFlags`
parameter is provided
Resolves: https://github.com/llvm/llvm-project/issues/153868.
GitOrigin-RevId: 6a9d43a7f6a802b8b36b60de9b729c6a4ab1f634
diff --git a/include/clang/Basic/DiagnosticSemaKinds.td b/include/clang/Basic/DiagnosticSemaKinds.td
index e69123b..b0e669c 100644
--- a/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/include/clang/Basic/DiagnosticSemaKinds.td
@@ -13202,6 +13202,9 @@
def err_hlsl_invalid_rootsig_value : Error<"value must be in the range [%0, %1]">;
def err_hlsl_invalid_rootsig_flag : Error< "invalid flags for version 1.%0">;
+def err_hlsl_invalid_mixed_resources: Error< "sampler and non-sampler resource mixed in descriptor table">;
+def err_hlsl_appending_onto_unbound: Error<"offset appends to unbounded descriptor range">;
+def err_hlsl_offset_overflow: Error<"descriptor range offset overflows [%0, %1]">;
def subst_hlsl_format_ranges: TextSubstitution<
"%select{t|u|b|s}0[%1;%select{%3]|unbounded)}2">;
diff --git a/lib/Parse/ParseHLSLRootSignature.cpp b/lib/Parse/ParseHLSLRootSignature.cpp
index 9197648..3b16efb 100644
--- a/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/lib/Parse/ParseHLSLRootSignature.cpp
@@ -38,8 +38,18 @@
// Iterate as many RootSignatureElements as possible, until we hit the
// end of the stream
bool HadError = false;
+ bool HasRootFlags = false;
while (!peekExpectedToken(TokenKind::end_of_stream)) {
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+ if (HasRootFlags) {
+ reportDiag(diag::err_hlsl_rootsig_repeat_param)
+ << TokenKind::kw_RootFlags;
+ HadError = true;
+ skipUntilExpectedToken(RootElementKeywords);
+ continue;
+ }
+ HasRootFlags = true;
+
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Flags = parseRootFlags();
if (!Flags.has_value()) {
diff --git a/lib/Sema/SemaHLSL.cpp b/lib/Sema/SemaHLSL.cpp
index c14ce2a..9d5c408 100644
--- a/lib/Sema/SemaHLSL.cpp
+++ b/lib/Sema/SemaHLSL.cpp
@@ -1359,12 +1359,48 @@
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
assert(UnboundClauses.size() == Table->NumClauses &&
"Number of unbound elements must match the number of clauses");
+ bool HasAnySampler = false;
+ bool HasAnyNonSampler = false;
+ uint32_t Offset = 0;
for (const auto &[Clause, ClauseElem] : UnboundClauses) {
- uint32_t LowerBound(Clause->Reg.Number);
+ SourceLocation Loc = ClauseElem->getLocation();
+ if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
+ HasAnySampler = true;
+ else
+ HasAnyNonSampler = true;
+
+ if (HasAnySampler && HasAnyNonSampler)
+ Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
+
// Relevant error will have already been reported above and needs to be
- // fixed before we can conduct range analysis, so shortcut error return
+ // fixed before we can conduct further analysis, so shortcut error
+ // return
if (Clause->NumDescriptors == 0)
return true;
+
+ if (Clause->Offset !=
+ llvm::hlsl::rootsig::DescriptorTableOffsetAppend) {
+ // Manually specified the offset
+ Offset = Clause->Offset;
+ }
+
+ uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
+ Offset, Clause->NumDescriptors);
+
+ if (!llvm::hlsl::rootsig::verifyBoundOffset(Offset)) {
+ // Trying to append onto unbound offset
+ Diag(Loc, diag::err_hlsl_appending_onto_unbound);
+ } else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound)) {
+ // Upper bound overflows maximum offset
+ Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound;
+ }
+
+ Offset = RangeBound == llvm::hlsl::rootsig::NumDescriptorsUnbounded
+ ? uint32_t(RangeBound)
+ : uint32_t(RangeBound + 1);
+
+ // Compute the register bounds and track resource binding
+ uint32_t LowerBound(Clause->Reg.Number);
uint32_t UpperBound = Clause->NumDescriptors == ~0u
? ~0u
: LowerBound + Clause->NumDescriptors - 1;
diff --git a/test/SemaHLSL/RootSignature-err.hlsl b/test/SemaHLSL/RootSignature-err.hlsl
index ccfa093..89c684c 100644
--- a/test/SemaHLSL/RootSignature-err.hlsl
+++ b/test/SemaHLSL/RootSignature-err.hlsl
@@ -179,7 +179,7 @@
// expected-error@+2 {{value must be in the range [1, 4294967294]}}
// expected-error@+1 {{value must be in the range [1, 4294967294]}}
-[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0), Sampler(s0, numDescriptors = 0))")]
+[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0)), DescriptorTable(Sampler(s0, numDescriptors = 0))")]
void basic_validation_4() {}
// expected-error@+2 {{value must be in the range [0, 16]}}
@@ -189,4 +189,8 @@
// expected-error@+1 {{value must be in the range [-16.00, 15.99]}}
[RootSignature("StaticSampler(s0, mipLODBias = 15.990001)")]
-void basic_validation_6() {}
\ No newline at end of file
+void basic_validation_6() {}
+
+// expected-error@+1 {{sampler and non-sampler resource mixed in descriptor table}}
+[RootSignature("DescriptorTable(Sampler(s0), CBV(b0))")]
+void mixed_resource_table() {}
diff --git a/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl b/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
index fd098b0..2d025d0 100644
--- a/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
+++ b/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
@@ -117,3 +117,28 @@
// expected-note@+1 {{overlapping resource range here}}
[RootSignature(DuplicatesRootSignature)]
void valid_root_signature_15() {}
+
+#define AppendingToUnbound \
+ "DescriptorTable(CBV(b1, numDescriptors = unbounded), CBV(b0))"
+
+// expected-error@+1 {{offset appends to unbounded descriptor range}}
+[RootSignature(AppendingToUnbound)]
+void append_to_unbound_signature() {}
+
+#define DirectOffsetOverflow \
+ "DescriptorTable(CBV(b0, offset = 4294967294 , numDescriptors = 6))"
+
+// expected-error@+1 {{descriptor range offset overflows [4294967294, 4294967299]}}
+[RootSignature(DirectOffsetOverflow)]
+void direct_offset_overflow_signature() {}
+
+#define AppendOffsetOverflow \
+ "DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 7))"
+
+// expected-error@+1 {{descriptor range offset overflows [4294967293, 4294967299]}}
+[RootSignature(AppendOffsetOverflow)]
+void append_offset_overflow_signature() {}
+
+// expected-error@+1 {{descriptor range offset overflows [4294967292, 4294967296]}}
+[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292, numDescriptors = 5))")]
+void offset_() {}
diff --git a/test/SemaHLSL/RootSignature-resource-ranges.hlsl b/test/SemaHLSL/RootSignature-resource-ranges.hlsl
index 09a1110..10e7215 100644
--- a/test/SemaHLSL/RootSignature-resource-ranges.hlsl
+++ b/test/SemaHLSL/RootSignature-resource-ranges.hlsl
@@ -22,3 +22,6 @@
[RootSignature("DescriptorTable(SRV(t5), UAV(u5, numDescriptors=2))")]
void valid_root_signature_6() {}
+
+[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 3))")]
+void valid_root_signature_7() {}
diff --git a/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 44c0978..9b9f5dd 100644
--- a/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -501,8 +501,6 @@
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
using llvm::dxbc::RootFlags;
const llvm::StringLiteral Source = R"cc(
- RootFlags(),
- RootFlags(0),
RootFlags(
deny_domain_shader_root_access |
deny_pixel_shader_root_access |
@@ -533,18 +531,10 @@
ASSERT_FALSE(Parser.parse());
auto Elements = Parser.getElements();
- ASSERT_EQ(Elements.size(), 3u);
+ ASSERT_EQ(Elements.size(), 1u);
RootElement Elem = Elements[0].getElement();
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
- ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
-
- Elem = Elements[1].getElement();
- ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
- ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
-
- Elem = Elements[2].getElement();
- ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
auto ValidRootFlags = RootFlags::AllowInputAssemblerInputLayout |
RootFlags::DenyVertexShaderRootAccess |
RootFlags::DenyHullShaderRootAccess |
@@ -562,6 +552,64 @@
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyRootFlagsTest) {
+ using llvm::dxbc::RootFlags;
+ const llvm::StringLiteral Source = R"cc(
+ RootFlags(),
+ )cc";
+
+ auto Ctx = createMinimalASTContext();
+ StringLiteral *Signature = wrapSource(Ctx, Source);
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+
+ hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+ // Test no diagnostics produced
+ Consumer->setNoDiag();
+
+ ASSERT_FALSE(Parser.parse());
+
+ auto Elements = Parser.getElements();
+ ASSERT_EQ(Elements.size(), 1u);
+
+ RootElement Elem = Elements[0].getElement();
+ ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+ ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, ValidParseZeroRootFlagsTest) {
+ using llvm::dxbc::RootFlags;
+ const llvm::StringLiteral Source = R"cc(
+ RootFlags(0),
+ )cc";
+
+ auto Ctx = createMinimalASTContext();
+ StringLiteral *Signature = wrapSource(Ctx, Source);
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+
+ hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+ // Test no diagnostics produced
+ Consumer->setNoDiag();
+
+ ASSERT_FALSE(Parser.parse());
+
+ auto Elements = Parser.getElements();
+ ASSERT_EQ(Elements.size(), 1u);
+
+ RootElement Elem = Elements[0].getElement();
+ ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+ ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
using llvm::dxbc::RootDescriptorFlags;
const llvm::StringLiteral Source = R"cc(
@@ -1658,4 +1706,27 @@
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, InvalidMultipleRootFlagsTest) {
+ // This test will check that an error is produced when there are multiple
+ // root flags provided
+ const llvm::StringLiteral Source = R"cc(
+ RootFlags(DENY_VERTEX_SHADER_ROOT_ACCESS),
+ RootFlags(DENY_PIXEL_SHADER_ROOT_ACCESS)
+ )cc";
+
+ auto Ctx = createMinimalASTContext();
+ StringLiteral *Signature = wrapSource(Ctx, Source);
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+
+ hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+ // Test correct diagnostic produced
+ Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
+ ASSERT_TRUE(Parser.parse());
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
} // anonymous namespace