blob: 0b65515120b7f56cb38815707de9bacd6c1818ad [file] [log] [blame]
//===-- SPSWrapperFunctionTest.cpp ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Test SPSWrapperFunction and associated utilities.
//
//===----------------------------------------------------------------------===//
#include "orc-rt/SPSWrapperFunction.h"
#include "orc-rt/WrapperFunction.h"
#include "orc-rt/move_only_function.h"
#include "gtest/gtest.h"
using namespace orc_rt;
/// Make calls and call result handlers directly on the current thread.
class DirectCaller {
private:
class DirectResultSender {
public:
virtual ~DirectResultSender() {}
virtual void send(orc_rt_SessionRef Session,
WrapperFunctionBuffer ResultBytes) = 0;
static void send(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionBuffer ResultBytes) {
std::unique_ptr<DirectResultSender>(
reinterpret_cast<DirectResultSender *>(CallCtx))
->send(Session, ResultBytes);
}
};
template <typename ImplFn>
class DirectResultSenderImpl : public DirectResultSender {
public:
DirectResultSenderImpl(ImplFn &&Fn) : Fn(std::forward<ImplFn>(Fn)) {}
void send(orc_rt_SessionRef Session,
WrapperFunctionBuffer ResultBytes) override {
Fn(Session, std::move(ResultBytes));
}
private:
std::decay_t<ImplFn> Fn;
};
template <typename ImplFn>
static std::unique_ptr<DirectResultSender>
makeDirectResultSender(ImplFn &&Fn) {
return std::make_unique<DirectResultSenderImpl<ImplFn>>(
std::forward<ImplFn>(Fn));
}
public:
DirectCaller(orc_rt_SessionRef Session, orc_rt_WrapperFunction Fn)
: Session(Session), Fn(Fn) {}
template <typename HandleResultFn>
void operator()(HandleResultFn &&HandleResult,
WrapperFunctionBuffer ArgBytes) {
auto DR =
makeDirectResultSender(std::forward<HandleResultFn>(HandleResult));
Fn(Session, reinterpret_cast<void *>(DR.release()),
DirectResultSender::send, ArgBytes.release());
}
private:
orc_rt_SessionRef Session;
orc_rt_WrapperFunction Fn;
};
static void void_noop_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<void()>::handle(
Session, CallCtx, Return, ArgBytes,
[](move_only_function<void()> Return) { Return(); });
}
TEST(SPSWrapperFunctionUtilsTest, TestVoidNoop) {
bool Ran = false;
SPSWrapperFunction<void()>::call(DirectCaller(nullptr, void_noop_sps_wrapper),
[&](Error Err) {
cantFail(std::move(Err));
Ran = true;
});
EXPECT_TRUE(Ran);
}
static void add_via_lambda_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<int32_t(int32_t, int32_t)>::handle(
Session, CallCtx, Return, ArgBytes,
[](move_only_function<void(int32_t)> Return, int32_t X, int32_t Y) {
Return(X + Y);
});
}
TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaLambda) {
int32_t Result = 0;
SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
DirectCaller(nullptr, add_via_lambda_sps_wrapper),
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
EXPECT_EQ(Result, 42);
}
static void add_via_function(move_only_function<void(int32_t)> Return,
int32_t X, int32_t Y) {
Return(X + Y);
}
static void
add_via_function_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<int32_t(int32_t, int32_t)>::handle(
Session, CallCtx, Return, ArgBytes, add_via_function);
}
TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunction) {
int32_t Result = 0;
SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
DirectCaller(nullptr, add_via_function_sps_wrapper),
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
EXPECT_EQ(Result, 42);
}
static void
add_via_function_pointer_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<int32_t(int32_t, int32_t)>::handle(
Session, CallCtx, Return, ArgBytes, &add_via_function);
}
TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunctionPointer) {
int32_t Result = 0;
SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
DirectCaller(nullptr, add_via_function_pointer_sps_wrapper),
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
EXPECT_EQ(Result, 42);
}