Verify the LLVMContext that an Attribute belongs to.

Attributes don't know their parent Context, adding this would make Attribute larger. Instead, we add hasParentContext that answers whether this Attribute belongs to a particular LLVMContext by checking for itself inside the context's FoldingSet. Same with AttributeSet and AttributeList. The Verifier checks them with the Module context.

Differential Revision: https://reviews.llvm.org/D99362

GitOrigin-RevId: 244d9d6e41db71e76eeb55e56d84f658b3f56681
diff --git a/include/llvm/IR/Attributes.h b/include/llvm/IR/Attributes.h
index 55f342c..8a87a56 100644
--- a/include/llvm/IR/Attributes.h
+++ b/include/llvm/IR/Attributes.h
@@ -208,6 +208,9 @@
   /// is, presumably, for writing out the mnemonics for the assembly writer.
   std::string getAsString(bool InAttrGrp = false) const;
 
+  /// Return true if this attribute belongs to the LLVMContext.
+  bool hasParentContext(LLVMContext &C) const;
+
   /// Equality and non-equality operators.
   bool operator==(Attribute A) const { return pImpl == A.pImpl; }
   bool operator!=(Attribute A) const { return pImpl != A.pImpl; }
@@ -331,6 +334,9 @@
   std::pair<unsigned, unsigned> getVScaleRangeArgs() const;
   std::string getAsString(bool InAttrGrp = false) const;
 
+  /// Return true if this attribute set belongs to the LLVMContext.
+  bool hasParentContext(LLVMContext &C) const;
+
   using iterator = const Attribute *;
 
   iterator begin() const;
@@ -724,6 +730,9 @@
   /// Return the attributes at the index as a string.
   std::string getAsString(unsigned Index, bool InAttrGrp = false) const;
 
+  /// Return true if this attribute list belongs to the LLVMContext.
+  bool hasParentContext(LLVMContext &C) const;
+
   //===--------------------------------------------------------------------===//
   // AttributeList Introspection
   //===--------------------------------------------------------------------===//
@@ -751,6 +760,8 @@
   /// Return true if there are no attributes.
   bool isEmpty() const { return pImpl == nullptr; }
 
+  void print(raw_ostream &O) const;
+
   void dump() const;
 };
 
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
index a3ad7c7..60ad3b8 100644
--- a/lib/IR/Attributes.cpp
+++ b/lib/IR/Attributes.cpp
@@ -607,6 +607,14 @@
   llvm_unreachable("Unknown attribute");
 }
 
+bool Attribute::hasParentContext(LLVMContext &C) const {
+  assert(isValid() && "invalid Attribute doesn't refer to any context");
+  FoldingSetNodeID ID;
+  pImpl->Profile(ID);
+  void *Unused;
+  return C.pImpl->AttrsSet.FindNodeOrInsertPos(ID, Unused) == pImpl;
+}
+
 bool Attribute::operator<(Attribute A) const {
   if (!pImpl && !A.pImpl) return false;
   if (!pImpl) return true;
@@ -835,6 +843,14 @@
   return SetNode ? SetNode->getAsString(InAttrGrp) : "";
 }
 
+bool AttributeSet::hasParentContext(LLVMContext &C) const {
+  assert(hasAttributes() && "empty AttributeSet doesn't refer to any context");
+  FoldingSetNodeID ID;
+  SetNode->Profile(ID);
+  void *Unused;
+  return C.pImpl->AttrsSetNodes.FindNodeOrInsertPos(ID, Unused) == SetNode;
+}
+
 AttributeSet::iterator AttributeSet::begin() const {
   return SetNode ? SetNode->begin() : nullptr;
 }
@@ -1640,6 +1656,14 @@
   return pImpl->begin()[Index];
 }
 
+bool AttributeList::hasParentContext(LLVMContext &C) const {
+  assert(!isEmpty() && "an empty attribute list has no parent context");
+  FoldingSetNodeID ID;
+  pImpl->Profile(ID);
+  void *Unused;
+  return C.pImpl->AttrsLists.FindNodeOrInsertPos(ID, Unused) == pImpl;
+}
+
 AttributeList::iterator AttributeList::begin() const {
   return pImpl ? pImpl->begin() : nullptr;
 }
@@ -1656,17 +1680,19 @@
   return pImpl ? pImpl->NumAttrSets : 0;
 }
 
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-LLVM_DUMP_METHOD void AttributeList::dump() const {
-  dbgs() << "PAL[\n";
+void AttributeList::print(raw_ostream &O) const {
+  O << "PAL[\n";
 
   for (unsigned i = index_begin(), e = index_end(); i != e; ++i) {
     if (getAttributes(i).hasAttributes())
-      dbgs() << "  { " << i << " => " << getAsString(i) << " }\n";
+      O << "  { " << i << " => " << getAsString(i) << " }\n";
   }
 
-  dbgs() << "]\n";
+  O << "]\n";
 }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void AttributeList::dump() const { print(dbgs()); }
 #endif
 
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index cc7ab86..d0bfa7e 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -199,6 +199,27 @@
 
   void Write(const unsigned i) { *OS << i << '\n'; }
 
+  // NOLINTNEXTLINE(readability-identifier-naming)
+  void Write(const Attribute *A) {
+    if (!A)
+      return;
+    *OS << A->getAsString() << '\n';
+  }
+
+  // NOLINTNEXTLINE(readability-identifier-naming)
+  void Write(const AttributeSet *AS) {
+    if (!AS)
+      return;
+    *OS << AS->getAsString() << '\n';
+  }
+
+  // NOLINTNEXTLINE(readability-identifier-naming)
+  void Write(const AttributeList *AL) {
+    if (!AL)
+      return;
+    AL->print(*OS);
+  }
+
   template <typename T> void Write(ArrayRef<T> Vs) {
     for (const T &V : Vs)
       Write(V);
@@ -1856,6 +1877,17 @@
   if (Attrs.isEmpty())
     return;
 
+  Assert(Attrs.hasParentContext(Context),
+         "Attribute list does not match Module context!", &Attrs);
+  for (const auto &AttrSet : Attrs) {
+    Assert(!AttrSet.hasAttributes() || AttrSet.hasParentContext(Context),
+           "Attribute set does not match Module context!", &AttrSet);
+    for (const auto &A : AttrSet) {
+      Assert(A.hasParentContext(Context),
+             "Attribute does not match Module context!", &A);
+    }
+  }
+
   bool SawNest = false;
   bool SawReturned = false;
   bool SawSRet = false;
diff --git a/unittests/IR/AttributesTest.cpp b/unittests/IR/AttributesTest.cpp
index 11b1598..03b4cef 100644
--- a/unittests/IR/AttributesTest.cpp
+++ b/unittests/IR/AttributesTest.cpp
@@ -184,4 +184,37 @@
   EXPECT_EQ(A.getAsString(), "byval(i32)");
 }
 
+TEST(Attributes, HasParentContext) {
+  LLVMContext C1, C2;
+
+  {
+    Attribute Attr1 = Attribute::get(C1, Attribute::AlwaysInline);
+    Attribute Attr2 = Attribute::get(C2, Attribute::AlwaysInline);
+    EXPECT_TRUE(Attr1.hasParentContext(C1));
+    EXPECT_FALSE(Attr1.hasParentContext(C2));
+    EXPECT_FALSE(Attr2.hasParentContext(C1));
+    EXPECT_TRUE(Attr2.hasParentContext(C2));
+  }
+
+  {
+    AttributeSet AS1 = AttributeSet::get(
+        C1, makeArrayRef(Attribute::get(C1, Attribute::NoReturn)));
+    AttributeSet AS2 = AttributeSet::get(
+        C2, makeArrayRef(Attribute::get(C2, Attribute::NoReturn)));
+    EXPECT_TRUE(AS1.hasParentContext(C1));
+    EXPECT_FALSE(AS1.hasParentContext(C2));
+    EXPECT_FALSE(AS2.hasParentContext(C1));
+    EXPECT_TRUE(AS2.hasParentContext(C2));
+  }
+
+  {
+    AttributeList AL1 = AttributeList::get(C1, 1, Attribute::ZExt);
+    AttributeList AL2 = AttributeList::get(C2, 1, Attribute::ZExt);
+    EXPECT_TRUE(AL1.hasParentContext(C1));
+    EXPECT_FALSE(AL1.hasParentContext(C2));
+    EXPECT_FALSE(AL2.hasParentContext(C1));
+    EXPECT_TRUE(AL2.hasParentContext(C2));
+  }
+}
+
 } // end anonymous namespace
diff --git a/unittests/IR/VerifierTest.cpp b/unittests/IR/VerifierTest.cpp
index 6b1217f..af2dd99 100644
--- a/unittests/IR/VerifierTest.cpp
+++ b/unittests/IR/VerifierTest.cpp
@@ -253,5 +253,22 @@
                   .startswith("MDNode context does not match Module context!"));
 }
 
+TEST(VerifierTest, AttributesWrongContext) {
+  LLVMContext C1, C2;
+  Module M1("M", C1);
+  FunctionType *FTy1 =
+      FunctionType::get(Type::getVoidTy(C1), /*isVarArg=*/false);
+  Function *F1 = Function::Create(FTy1, Function::ExternalLinkage, "foo", M1);
+  F1->setDoesNotReturn();
+
+  Module M2("M", C2);
+  FunctionType *FTy2 =
+      FunctionType::get(Type::getVoidTy(C2), /*isVarArg=*/false);
+  Function *F2 = Function::Create(FTy2, Function::ExternalLinkage, "foo", M2);
+  F2->copyAttributesFrom(F1);
+
+  EXPECT_TRUE(verifyFunction(*F2));
+}
+
 } // end anonymous namespace
 } // end namespace llvm