[Offload] Check for initialization (#144370)
All entry points (except olInit) now check that offload has been
initialized. If not, a new `OL_ERRC_UNINITIALIZED` error is returned.
diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td
index 8a2ecd6..cd8c3c6 100644
--- a/offload/liboffload/API/Common.td
+++ b/offload/liboffload/API/Common.td
@@ -106,6 +106,7 @@
Etor<"ASSEMBLE_FAILURE", "assembler failure while processing binary image">,
Etor<"LINK_FAILURE", "linker failure while processing binary image">,
Etor<"BACKEND_FAILURE", "the plugin backend is in an invalid or unsupported state">,
+ Etor<"UNINITIALIZED", "not initialized">,
// Handle related errors - only makes sense for liboffload
Etor<"INVALID_NULL_HANDLE", "a handle argument is null when it should not be">,
diff --git a/offload/liboffload/include/OffloadImpl.hpp b/offload/liboffload/include/OffloadImpl.hpp
index a12d8c4..f98164d 100644
--- a/offload/liboffload/include/OffloadImpl.hpp
+++ b/offload/liboffload/include/OffloadImpl.hpp
@@ -26,6 +26,7 @@
namespace offload {
bool isTracingEnabled();
bool isValidationEnabled();
+bool isOffloadInitialized();
} // namespace offload
} // namespace llvm
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index f02497c..8a48756 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -120,9 +120,10 @@
// If the context is uninited, then we assume tracing is disabled
bool isTracingEnabled() {
- return OffloadContextVal && OffloadContext::get().TracingEnabled;
+ return isOffloadInitialized() && OffloadContext::get().TracingEnabled;
}
bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; }
+bool isOffloadInitialized() { return OffloadContextVal != nullptr; }
template <typename HandleT> Error olDestroy(HandleT Handle) {
delete Handle;
diff --git a/offload/tools/offload-tblgen/EntryPointGen.cpp b/offload/tools/offload-tblgen/EntryPointGen.cpp
index 13aa0d1..4e42e49 100644
--- a/offload/tools/offload-tblgen/EntryPointGen.cpp
+++ b/offload/tools/offload-tblgen/EntryPointGen.cpp
@@ -82,6 +82,10 @@
}
OS << ") {\n";
+ // Check offload is initialized
+ if (F.getName() != "olInit")
+ OS << "if (!llvm::offload::isOffloadInitialized()) return &UninitError;";
+
// Emit pre-call prints
OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName());
@@ -143,6 +147,14 @@
void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) {
OS << GenericHeader;
+
+ constexpr const char *UninitMessage =
+ "liboffload has not been initialized - please call olInit before using "
+ "this API";
+ OS << formatv("static {0}_error_struct_t UninitError = "
+ "{{{1}_ERRC_UNINITIALIZED, \"{2}\"};",
+ PrefixLower, PrefixUpper, UninitMessage);
+
for (auto *R : Records.getAllDerivedDefinitions("Function")) {
EmitValidationFunc(FunctionRec{R}, OS);
EmitEntryPointFunc(FunctionRec{R}, OS);
diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt
index 2844b67..05e8628 100644
--- a/offload/unittests/OffloadAPI/CMakeLists.txt
+++ b/offload/unittests/OffloadAPI/CMakeLists.txt
@@ -12,6 +12,10 @@
event/olDestroyEvent.cpp
event/olWaitEvent.cpp)
+add_offload_unittest("init"
+ init/olInit.cpp)
+target_compile_definitions("init.unittests" PRIVATE DISABLE_WRAPPER)
+
add_offload_unittest("kernel"
kernel/olGetKernel.cpp
kernel/olLaunchKernel.cpp)
diff --git a/offload/unittests/OffloadAPI/common/Environment.cpp b/offload/unittests/OffloadAPI/common/Environment.cpp
index 9433472..ef092cd 100644
--- a/offload/unittests/OffloadAPI/common/Environment.cpp
+++ b/offload/unittests/OffloadAPI/common/Environment.cpp
@@ -17,11 +17,13 @@
// Wrapper so we don't have to constantly init and shutdown Offload in every
// test, while having sensible lifetime for the platform environment
+#ifndef DISABLE_WRAPPER
struct OffloadInitWrapper {
OffloadInitWrapper() { olInit(); }
~OffloadInitWrapper() { olShutDown(); }
};
static OffloadInitWrapper Wrapper{};
+#endif
static cl::opt<std::string>
SelectedPlatform("platform", cl::desc("Only test the specified platform"),
diff --git a/offload/unittests/OffloadAPI/init/olInit.cpp b/offload/unittests/OffloadAPI/init/olInit.cpp
new file mode 100644
index 0000000..8e27e77
--- /dev/null
+++ b/offload/unittests/OffloadAPI/init/olInit.cpp
@@ -0,0 +1,22 @@
+//===------- Offload API tests - olInit -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// NOTE: For this test suite, the implicit olInit/olShutDown doesn't happen, so
+// tests have to do it themselves
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+struct olInitTest : ::testing::Test {};
+
+TEST_F(olInitTest, Uninitialized) {
+ ASSERT_ERROR(OL_ERRC_UNINITIALIZED,
+ olIterateDevices(
+ [](ol_device_handle_t, void *) { return false; }, nullptr));
+}