[Offload] Erase entries from JIT cache when program is destroyed (#148847)
When `unloadBinary` is called, any entries in the JITEngine's cache
for that binary are now erased. This fixes a nasty issue with
liboffload program handles. If two handles happened to have the same
address (for example, after one was freed), the cache would be hit and
return the wrong program.
diff --git a/offload/plugins-nextgen/common/include/JIT.h b/offload/plugins-nextgen/common/include/JIT.h
index 8c53043..d62516d 100644
--- a/offload/plugins-nextgen/common/include/JIT.h
+++ b/offload/plugins-nextgen/common/include/JIT.h
@@ -55,6 +55,10 @@
process(const __tgt_device_image &Image,
target::plugin::GenericDeviceTy &Device);
+ /// Remove \p Image from the JIT cache for \p Device's compute unit kind.
+ void erase(const __tgt_device_image &Image,
+ target::plugin::GenericDeviceTy &Device);
+
private:
/// Compile the bitcode image \p Image and generate the binary image that can
/// be loaded to the target device of the triple \p Triple architecture \p
@@ -89,11 +93,13 @@
/// LLVM Context in which the modules will be constructed.
LLVMContext Context;
- /// Output images generated from LLVM backend.
- SmallVector<std::unique_ptr<MemoryBuffer>, 4> JITImages;
+ /// A map of embedded IR images to the buffers that store their JITed code.
+ DenseMap<const __tgt_device_image *, std::unique_ptr<MemoryBuffer>>
+ JITImages;
/// A map of embedded IR images to JITed images.
- DenseMap<const __tgt_device_image *, __tgt_device_image *> TgtImageMap;
+ DenseMap<const __tgt_device_image *, std::unique_ptr<__tgt_device_image>>
+ TgtImageMap;
};
/// Map from (march) "CPUs" (e.g., sm_80, or gfx90a), which we call compute
diff --git a/offload/plugins-nextgen/common/src/JIT.cpp b/offload/plugins-nextgen/common/src/JIT.cpp
index c82a06e..00720fa 100644
--- a/offload/plugins-nextgen/common/src/JIT.cpp
+++ b/offload/plugins-nextgen/common/src/JIT.cpp
@@ -285,8 +285,8 @@
// Check if we JITed this image for the given compute unit kind before.
ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
- if (__tgt_device_image *JITedImage = CUI.TgtImageMap.lookup(&Image))
- return JITedImage;
+ if (CUI.TgtImageMap.contains(&Image))
+ return CUI.TgtImageMap[&Image].get();
auto ObjMBOrErr = getOrCreateObjFile(Image, CUI.Context, ComputeUnitKind);
if (!ObjMBOrErr)
@@ -296,17 +296,15 @@
if (!ImageMBOrErr)
return ImageMBOrErr.takeError();
- CUI.JITImages.push_back(std::move(*ImageMBOrErr));
- __tgt_device_image *&JITedImage = CUI.TgtImageMap[&Image];
- JITedImage = new __tgt_device_image();
+ CUI.JITImages.insert({&Image, std::move(*ImageMBOrErr)});
+ auto &ImageMB = CUI.JITImages[&Image];
+ CUI.TgtImageMap.insert({&Image, std::make_unique<__tgt_device_image>()});
+ auto &JITedImage = CUI.TgtImageMap[&Image];
*JITedImage = Image;
-
- auto &ImageMB = CUI.JITImages.back();
-
JITedImage->ImageStart = const_cast<char *>(ImageMB->getBufferStart());
JITedImage->ImageEnd = const_cast<char *>(ImageMB->getBufferEnd());
- return JITedImage;
+ return JITedImage.get();
}
Expected<const __tgt_device_image *>
@@ -324,3 +322,13 @@
return &Image;
}
+
+void JITEngine::erase(const __tgt_device_image &Image,
+ target::plugin::GenericDeviceTy &Device) {
+ std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex); // Held until return.
+ const std::string &ComputeUnitKind = Device.getComputeUnitKind();
+ ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
+ // Erase the descriptor before the buffer its ImageStart/ImageEnd point into.
+ CUI.TgtImageMap.erase(&Image);
+ CUI.JITImages.erase(&Image);
+}
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 81b9d42..94a050b 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -854,6 +854,9 @@
return Err;
}
+ if (Image->getTgtImageBitcode())
+ Plugin.getJIT().erase(*Image->getTgtImageBitcode(), Image->getDevice());
+
return unloadBinaryImpl(Image);
}