[Sema][Parse][HLSL] Implement front-end rootsignature validations (#156754)

This pr implements the following validations:

1. Check that descriptor tables don't mix Sample and non-Sampler
resources
2. Ensure that descriptor ranges don't append onto an unbounded range
3. Ensure that descriptor ranges don't overflow
4. Adds a missing validation to ensure that only a single `RootFlags`
parameter is provided

Resolves: https://github.com/llvm/llvm-project/issues/153868.
GitOrigin-RevId: 6a9d43a7f6a802b8b36b60de9b729c6a4ab1f634
diff --git a/include/clang/Basic/DiagnosticSemaKinds.td b/include/clang/Basic/DiagnosticSemaKinds.td
index e69123b..b0e669c 100644
--- a/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/include/clang/Basic/DiagnosticSemaKinds.td
@@ -13202,6 +13202,9 @@
 
 def err_hlsl_invalid_rootsig_value : Error<"value must be in the range [%0, %1]">;
 def err_hlsl_invalid_rootsig_flag : Error< "invalid flags for version 1.%0">;
+def err_hlsl_invalid_mixed_resources: Error< "sampler and non-sampler resource mixed in descriptor table">;
+def err_hlsl_appending_onto_unbound: Error<"offset appends to unbounded descriptor range">;
+def err_hlsl_offset_overflow: Error<"descriptor range offset overflows [%0, %1]">;
 
 def subst_hlsl_format_ranges: TextSubstitution<
   "%select{t|u|b|s}0[%1;%select{%3]|unbounded)}2">;
diff --git a/lib/Parse/ParseHLSLRootSignature.cpp b/lib/Parse/ParseHLSLRootSignature.cpp
index 9197648..3b16efb 100644
--- a/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/lib/Parse/ParseHLSLRootSignature.cpp
@@ -38,8 +38,18 @@
   // Iterate as many RootSignatureElements as possible, until we hit the
   // end of the stream
   bool HadError = false;
+  bool HasRootFlags = false;
   while (!peekExpectedToken(TokenKind::end_of_stream)) {
     if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+      if (HasRootFlags) {
+        reportDiag(diag::err_hlsl_rootsig_repeat_param)
+            << TokenKind::kw_RootFlags;
+        HadError = true;
+        skipUntilExpectedToken(RootElementKeywords);
+        continue;
+      }
+      HasRootFlags = true;
+
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Flags = parseRootFlags();
       if (!Flags.has_value()) {
diff --git a/lib/Sema/SemaHLSL.cpp b/lib/Sema/SemaHLSL.cpp
index c14ce2a..9d5c408 100644
--- a/lib/Sema/SemaHLSL.cpp
+++ b/lib/Sema/SemaHLSL.cpp
@@ -1359,12 +1359,48 @@
                    std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
       assert(UnboundClauses.size() == Table->NumClauses &&
              "Number of unbound elements must match the number of clauses");
+      bool HasAnySampler = false;
+      bool HasAnyNonSampler = false;
+      uint32_t Offset = 0;
       for (const auto &[Clause, ClauseElem] : UnboundClauses) {
-        uint32_t LowerBound(Clause->Reg.Number);
+        SourceLocation Loc = ClauseElem->getLocation();
+        if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
+          HasAnySampler = true;
+        else
+          HasAnyNonSampler = true;
+
+        if (HasAnySampler && HasAnyNonSampler)
+          Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
+
         // Relevant error will have already been reported above and needs to be
-        // fixed before we can conduct range analysis, so shortcut error return
+        // fixed before we can conduct further analysis, so shortcut error
+        // return
         if (Clause->NumDescriptors == 0)
           return true;
+
+        if (Clause->Offset !=
+            llvm::hlsl::rootsig::DescriptorTableOffsetAppend) {
+          // Manually specified the offset
+          Offset = Clause->Offset;
+        }
+
+        uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
+            Offset, Clause->NumDescriptors);
+
+        if (!llvm::hlsl::rootsig::verifyBoundOffset(Offset)) {
+          // Trying to append onto unbound offset
+          Diag(Loc, diag::err_hlsl_appending_onto_unbound);
+        } else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound)) {
+          // Upper bound overflows maximum offset
+          Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound;
+        }
+
+        Offset = RangeBound == llvm::hlsl::rootsig::NumDescriptorsUnbounded
+                     ? uint32_t(RangeBound)
+                     : uint32_t(RangeBound + 1);
+
+        // Compute the register bounds and track resource binding
+        uint32_t LowerBound(Clause->Reg.Number);
         uint32_t UpperBound = Clause->NumDescriptors == ~0u
                                   ? ~0u
                                   : LowerBound + Clause->NumDescriptors - 1;
diff --git a/test/SemaHLSL/RootSignature-err.hlsl b/test/SemaHLSL/RootSignature-err.hlsl
index ccfa093..89c684c 100644
--- a/test/SemaHLSL/RootSignature-err.hlsl
+++ b/test/SemaHLSL/RootSignature-err.hlsl
@@ -179,7 +179,7 @@
 
 // expected-error@+2 {{value must be in the range [1, 4294967294]}}
 // expected-error@+1 {{value must be in the range [1, 4294967294]}}
-[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0), Sampler(s0, numDescriptors = 0))")]
+[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0)), DescriptorTable(Sampler(s0, numDescriptors = 0))")]
 void basic_validation_4() {}
 
 // expected-error@+2 {{value must be in the range [0, 16]}}
@@ -189,4 +189,8 @@
 
 // expected-error@+1 {{value must be in the range [-16.00, 15.99]}}
 [RootSignature("StaticSampler(s0, mipLODBias = 15.990001)")]
-void basic_validation_6() {}
\ No newline at end of file
+void basic_validation_6() {}
+
+// expected-error@+1 {{sampler and non-sampler resource mixed in descriptor table}}
+[RootSignature("DescriptorTable(Sampler(s0), CBV(b0))")]
+void mixed_resource_table() {}
diff --git a/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl b/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
index fd098b0..2d025d0 100644
--- a/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
+++ b/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
@@ -117,3 +117,28 @@
 // expected-note@+1 {{overlapping resource range here}}
 [RootSignature(DuplicatesRootSignature)]
 void valid_root_signature_15() {}
+
+#define AppendingToUnbound \
+  "DescriptorTable(CBV(b1, numDescriptors = unbounded), CBV(b0))"
+
+// expected-error@+1 {{offset appends to unbounded descriptor range}}
+[RootSignature(AppendingToUnbound)]
+void append_to_unbound_signature() {}
+
+#define DirectOffsetOverflow \
+  "DescriptorTable(CBV(b0, offset = 4294967294 , numDescriptors = 6))"
+
+// expected-error@+1 {{descriptor range offset overflows [4294967294, 4294967299]}}
+[RootSignature(DirectOffsetOverflow)]
+void direct_offset_overflow_signature() {}
+
+#define AppendOffsetOverflow \
+  "DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 7))"
+
+// expected-error@+1 {{descriptor range offset overflows [4294967293, 4294967299]}}
+[RootSignature(AppendOffsetOverflow)]
+void append_offset_overflow_signature() {}
+
+// expected-error@+1 {{descriptor range offset overflows [4294967292, 4294967296]}}
+[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292, numDescriptors = 5))")]
+void offset_() {}
diff --git a/test/SemaHLSL/RootSignature-resource-ranges.hlsl b/test/SemaHLSL/RootSignature-resource-ranges.hlsl
index 09a1110..10e7215 100644
--- a/test/SemaHLSL/RootSignature-resource-ranges.hlsl
+++ b/test/SemaHLSL/RootSignature-resource-ranges.hlsl
@@ -22,3 +22,6 @@
 
 [RootSignature("DescriptorTable(SRV(t5), UAV(u5, numDescriptors=2))")]
 void valid_root_signature_6() {}
+
+[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 3))")]
+void valid_root_signature_7() {}
diff --git a/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 44c0978..9b9f5dd 100644
--- a/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -501,8 +501,6 @@
 TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
   using llvm::dxbc::RootFlags;
   const llvm::StringLiteral Source = R"cc(
-    RootFlags(),
-    RootFlags(0),
     RootFlags(
       deny_domain_shader_root_access |
       deny_pixel_shader_root_access |
@@ -533,18 +531,10 @@
   ASSERT_FALSE(Parser.parse());
 
   auto Elements = Parser.getElements();
-  ASSERT_EQ(Elements.size(), 3u);
+  ASSERT_EQ(Elements.size(), 1u);
 
   RootElement Elem = Elements[0].getElement();
   ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
-  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
-
-  Elem = Elements[1].getElement();
-  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
-  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
-
-  Elem = Elements[2].getElement();
-  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
   auto ValidRootFlags = RootFlags::AllowInputAssemblerInputLayout |
                         RootFlags::DenyVertexShaderRootAccess |
                         RootFlags::DenyHullShaderRootAccess |
@@ -562,6 +552,64 @@
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyRootFlagsTest) {
+  using llvm::dxbc::RootFlags;
+  const llvm::StringLiteral Source = R"cc(
+    RootFlags(),
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
+  auto Elements = Parser.getElements();
+  ASSERT_EQ(Elements.size(), 1u);
+
+  RootElement Elem = Elements[0].getElement();
+  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, ValidParseZeroRootFlagsTest) {
+  using llvm::dxbc::RootFlags;
+  const llvm::StringLiteral Source = R"cc(
+    RootFlags(0),
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+  // Test no diagnostics produced
+  Consumer->setNoDiag();
+
+  ASSERT_FALSE(Parser.parse());
+
+  auto Elements = Parser.getElements();
+  ASSERT_EQ(Elements.size(), 1u);
+
+  RootElement Elem = Elements[0].getElement();
+  ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+  ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
   using llvm::dxbc::RootDescriptorFlags;
   const llvm::StringLiteral Source = R"cc(
@@ -1658,4 +1706,27 @@
   ASSERT_TRUE(Consumer->isSatisfied());
 }
 
+TEST_F(ParseHLSLRootSignatureTest, InvalidMultipleRootFlagsTest) {
+  // This test will check that an error is produced when there are multiple
+  // root flags provided
+  const llvm::StringLiteral Source = R"cc(
+    RootFlags(DENY_VERTEX_SHADER_ROOT_ACCESS),
+    RootFlags(DENY_PIXEL_SHADER_ROOT_ACCESS)
+  )cc";
+
+  auto Ctx = createMinimalASTContext();
+  StringLiteral *Signature = wrapSource(Ctx, Source);
+
+  TrivialModuleLoader ModLoader;
+  auto PP = createPP(Source, ModLoader);
+
+  hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+  // Test correct diagnostic produced
+  Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
+  ASSERT_TRUE(Parser.parse());
+
+  ASSERT_TRUE(Consumer->isSatisfied());
+}
+
 } // anonymous namespace