[NFC][Cloning] Replace IdentityMD set with a predicate in ValueMapper (#129147)


Summary:
We used the set only to check if it contains certain metadata nodes.
Replacing the set with a predicate makes the intention clearer and the
API more general.

Test Plan:
ninja check-all
diff --git a/llvm/include/llvm/Transforms/Utils/Cloning.h b/llvm/include/llvm/Transforms/Utils/Cloning.h
index d36f914..2252dda 100644
--- a/llvm/include/llvm/Transforms/Utils/Cloning.h
+++ b/llvm/include/llvm/Transforms/Utils/Cloning.h
@@ -194,7 +194,7 @@
                                ValueToValueMapTy &VMap, RemapFlags RemapFlag,
                                ValueMapTypeRemapper *TypeMapper = nullptr,
                                ValueMaterializer *Materializer = nullptr,
-                               const MetadataSetTy *IdentityMD = nullptr);
+                               const MetadataPredicate *IdentityMD = nullptr);
 
 /// Clone OldFunc's body into NewFunc.
 void CloneFunctionBodyInto(Function &NewFunc, const Function &OldFunc,
@@ -204,7 +204,7 @@
                            ClonedCodeInfo *CodeInfo = nullptr,
                            ValueMapTypeRemapper *TypeMapper = nullptr,
                            ValueMaterializer *Materializer = nullptr,
-                           const MetadataSetTy *IdentityMD = nullptr);
+                           const MetadataPredicate *IdentityMD = nullptr);
 
 void CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc,
                                const Instruction *StartingInst,
diff --git a/llvm/include/llvm/Transforms/Utils/ValueMapper.h b/llvm/include/llvm/Transforms/Utils/ValueMapper.h
index 852d709..560df1d 100644
--- a/llvm/include/llvm/Transforms/Utils/ValueMapper.h
+++ b/llvm/include/llvm/Transforms/Utils/ValueMapper.h
@@ -37,6 +37,7 @@
 using ValueToValueMapTy = ValueMap<const Value *, WeakTrackingVH>;
 using DbgRecordIterator = simple_ilist<DbgRecord>::iterator;
 using MetadataSetTy = SmallPtrSet<const Metadata *, 16>;
+using MetadataPredicate = std::function<bool(const Metadata *)>;
 
 /// This is a class that can be implemented by clients to remap types when
 /// cloning constants and instructions.
@@ -138,8 +139,8 @@
 /// alternate \a ValueToValueMapTy and \a ValueMaterializer and returns a ID to
 /// pass into the schedule*() functions.
 ///
-/// If an \a IdentityMD set is optionally provided, \a Metadata inside this set
-/// will be mapped onto itself in \a VM on first use.
+/// If an \a IdentityMD predicate is optionally provided, \a Metadata for which
+/// the predicate returns true will be mapped onto itself in \a VM on first use.
 ///
 /// TODO: lib/Linker really doesn't need the \a ValueHandle in the \a
 /// ValueToValueMapTy.  We should template \a ValueMapper (and its
@@ -158,7 +159,7 @@
   ValueMapper(ValueToValueMapTy &VM, RemapFlags Flags = RF_None,
               ValueMapTypeRemapper *TypeMapper = nullptr,
               ValueMaterializer *Materializer = nullptr,
-              const MetadataSetTy *IdentityMD = nullptr);
+              const MetadataPredicate *IdentityMD = nullptr);
   ValueMapper(ValueMapper &&) = delete;
   ValueMapper(const ValueMapper &) = delete;
   ValueMapper &operator=(ValueMapper &&) = delete;
@@ -225,7 +226,7 @@
                        RemapFlags Flags = RF_None,
                        ValueMapTypeRemapper *TypeMapper = nullptr,
                        ValueMaterializer *Materializer = nullptr,
-                       const MetadataSetTy *IdentityMD = nullptr) {
+                       const MetadataPredicate *IdentityMD = nullptr) {
   return ValueMapper(VM, Flags, TypeMapper, Materializer, IdentityMD)
       .mapValue(*V);
 }
@@ -239,8 +240,8 @@
 ///     \c MD.
 ///  3. Else if \c MD is a \a ConstantAsMetadata, call \a MapValue() and
 ///     re-wrap its return (returning nullptr on nullptr).
-///  4. Else if \c MD is in \c IdentityMD then add an identity mapping for it
-///     and return it.
+///  4. Else if \c IdentityMD predicate returns true for \c MD then add an
+///     identity mapping for it and return it.
 ///  5. Else, \c MD is an \a MDNode.  These are remapped, along with their
 ///     transitive operands.  Distinct nodes are duplicated or moved depending
 ///     on \a RF_MoveDistinctNodes.  Uniqued nodes are remapped like constants.
@@ -251,7 +252,7 @@
                              RemapFlags Flags = RF_None,
                              ValueMapTypeRemapper *TypeMapper = nullptr,
                              ValueMaterializer *Materializer = nullptr,
-                             const MetadataSetTy *IdentityMD = nullptr) {
+                             const MetadataPredicate *IdentityMD = nullptr) {
   return ValueMapper(VM, Flags, TypeMapper, Materializer, IdentityMD)
       .mapMetadata(*MD);
 }
@@ -261,7 +262,7 @@
                            RemapFlags Flags = RF_None,
                            ValueMapTypeRemapper *TypeMapper = nullptr,
                            ValueMaterializer *Materializer = nullptr,
-                           const MetadataSetTy *IdentityMD = nullptr) {
+                           const MetadataPredicate *IdentityMD = nullptr) {
   return ValueMapper(VM, Flags, TypeMapper, Materializer, IdentityMD)
       .mapMDNode(*MD);
 }
@@ -278,7 +279,7 @@
                              RemapFlags Flags = RF_None,
                              ValueMapTypeRemapper *TypeMapper = nullptr,
                              ValueMaterializer *Materializer = nullptr,
-                             const MetadataSetTy *IdentityMD = nullptr) {
+                             const MetadataPredicate *IdentityMD = nullptr) {
   ValueMapper(VM, Flags, TypeMapper, Materializer, IdentityMD)
       .remapInstruction(*I);
 }
@@ -289,7 +290,7 @@
                            RemapFlags Flags = RF_None,
                            ValueMapTypeRemapper *TypeMapper = nullptr,
                            ValueMaterializer *Materializer = nullptr,
-                           const MetadataSetTy *IdentityMD = nullptr) {
+                           const MetadataPredicate *IdentityMD = nullptr) {
   ValueMapper(VM, Flags, TypeMapper, Materializer, IdentityMD)
       .remapDbgRecord(M, *DR);
 }
@@ -302,7 +303,7 @@
                                 RemapFlags Flags = RF_None,
                                 ValueMapTypeRemapper *TypeMapper = nullptr,
                                 ValueMaterializer *Materializer = nullptr,
-                                const MetadataSetTy *IdentityMD = nullptr) {
+                                const MetadataPredicate *IdentityMD = nullptr) {
   ValueMapper(VM, Flags, TypeMapper, Materializer, IdentityMD)
       .remapDbgRecordRange(M, Range);
 }
@@ -317,7 +318,7 @@
                           RemapFlags Flags = RF_None,
                           ValueMapTypeRemapper *TypeMapper = nullptr,
                           ValueMaterializer *Materializer = nullptr,
-                          const MetadataSetTy *IdentityMD = nullptr) {
+                          const MetadataPredicate *IdentityMD = nullptr) {
   ValueMapper(VM, Flags, TypeMapper, Materializer, IdentityMD).remapFunction(F);
 }
 
@@ -326,7 +327,7 @@
                           RemapFlags Flags = RF_None,
                           ValueMapTypeRemapper *TypeMapper = nullptr,
                           ValueMaterializer *Materializer = nullptr,
-                          const MetadataSetTy *IdentityMD = nullptr) {
+                          const MetadataPredicate *IdentityMD = nullptr) {
   return ValueMapper(VM, Flags, TypeMapper, Materializer, IdentityMD)
       .mapConstant(*V);
 }
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
index be35ef9..b2c4e64 100644
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -921,11 +921,14 @@
   auto savedLinkage = NewF->getLinkage();
   NewF->setLinkage(llvm::GlobalValue::ExternalLinkage);
 
+  MetadataPredicate IdentityMD = [&](const Metadata *MD) {
+    return CommonDebugInfo.contains(MD);
+  };
   CloneFunctionAttributesInto(NewF, &OrigF, VMap, false);
   CloneFunctionMetadataInto(*NewF, OrigF, VMap, RF_None, nullptr, nullptr,
-                            &CommonDebugInfo);
+                            &IdentityMD);
   CloneFunctionBodyInto(*NewF, OrigF, VMap, RF_None, Returns, "", nullptr,
-                        nullptr, nullptr, &CommonDebugInfo);
+                        nullptr, nullptr, &IdentityMD);
 
   auto &Context = NewF->getContext();
 
diff --git a/llvm/lib/Transforms/Utils/CloneFunction.cpp b/llvm/lib/Transforms/Utils/CloneFunction.cpp
index dd1b4fe..502c489 100644
--- a/llvm/lib/Transforms/Utils/CloneFunction.cpp
+++ b/llvm/lib/Transforms/Utils/CloneFunction.cpp
@@ -204,7 +204,7 @@
                                      RemapFlags RemapFlag,
                                      ValueMapTypeRemapper *TypeMapper,
                                      ValueMaterializer *Materializer,
-                                     const MetadataSetTy *IdentityMD) {
+                                     const MetadataPredicate *IdentityMD) {
   SmallVector<std::pair<unsigned, MDNode *>, 1> MDs;
   OldFunc.getAllMetadata(MDs);
   for (auto MD : MDs) {
@@ -221,7 +221,7 @@
                                  ClonedCodeInfo *CodeInfo,
                                  ValueMapTypeRemapper *TypeMapper,
                                  ValueMaterializer *Materializer,
-                                 const MetadataSetTy *IdentityMD) {
+                                 const MetadataPredicate *IdentityMD) {
   if (OldFunc.isDeclaration())
     return;
 
@@ -328,8 +328,10 @@
   DISubprogram *SPClonedWithinModule =
       CollectDebugInfoForCloning(*OldFunc, Changes, DIFinder);
 
-  MetadataSetTy IdentityMD =
-      FindDebugInfoToIdentityMap(Changes, DIFinder, SPClonedWithinModule);
+  MetadataPredicate IdentityMD =
+      [MDSet =
+           FindDebugInfoToIdentityMap(Changes, DIFinder, SPClonedWithinModule)](
+          const Metadata *MD) { return MDSet.contains(MD); };
 
   // Cloning is always a Module level operation, since Metadata needs to be
   // cloned.
diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp
index 9882f81..5e50536 100644
--- a/llvm/lib/Transforms/Utils/ValueMapper.cpp
+++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp
@@ -120,12 +120,12 @@
   SmallVector<WorklistEntry, 4> Worklist;
   SmallVector<DelayedBasicBlock, 1> DelayedBBs;
   SmallVector<Constant *, 16> AppendingInits;
-  const MetadataSetTy *IdentityMD;
+  const MetadataPredicate *IdentityMD;
 
 public:
   Mapper(ValueToValueMapTy &VM, RemapFlags Flags,
          ValueMapTypeRemapper *TypeMapper, ValueMaterializer *Materializer,
-         const MetadataSetTy *IdentityMD)
+         const MetadataPredicate *IdentityMD)
       : Flags(Flags), TypeMapper(TypeMapper),
         MCs(1, MappingContext(VM, Materializer)), IdentityMD(IdentityMD) {}
 
@@ -904,11 +904,10 @@
     return wrapConstantAsMetadata(*CMD, mapValue(CMD->getValue()));
   }
 
-  // Map metadata from IdentityMD on first use. We need to add these nodes to
-  // the mapping as otherwise metadata nodes numbering gets messed up. This is
-  // still economical because the amount of data in IdentityMD may be a lot
-  // larger than what will actually get used.
-  if (IdentityMD && IdentityMD->contains(MD))
+  // Map metadata matching IdentityMD predicate on first use. We need to add
+  // these nodes to the mapping as otherwise metadata nodes numbering gets
+  // messed up.
+  if (IdentityMD && (*IdentityMD)(MD))
     return getVM().MD()[MD] = TrackingMDRef(const_cast<Metadata *>(MD));
 
   assert(isa<MDNode>(MD) && "Expected a metadata node");
@@ -1211,7 +1210,7 @@
 ValueMapper::ValueMapper(ValueToValueMapTy &VM, RemapFlags Flags,
                          ValueMapTypeRemapper *TypeMapper,
                          ValueMaterializer *Materializer,
-                         const MetadataSetTy *IdentityMD)
+                         const MetadataPredicate *IdentityMD)
     : pImpl(new Mapper(VM, Flags, TypeMapper, Materializer, IdentityMD)) {}
 
 ValueMapper::~ValueMapper() { delete getAsMapper(pImpl); }