blob: 42c629eb415bb1016b7501edd3d8a23d5400bc50 [file] [log] [blame]
// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
#ifndef _LIBCPP_TEST_EXPERIMENTAL_TASK_SYNC_WAIT
#define _LIBCPP_TEST_EXPERIMENTAL_TASK_SYNC_WAIT
#include <experimental/__config>
#include <experimental/coroutine>
#include <type_traits>
#include <mutex>
#include <condition_variable>
#include "awaitable_traits.hpp"
_LIBCPP_BEGIN_NAMESPACE_EXPERIMENTAL_COROUTINES
// One-shot event for handing a completion signal between threads: a thread
// calling wait() blocks until some thread has called set(). Once set, the
// event stays set, so later wait() calls return without blocking.
class __oneshot_event
{
public:
  __oneshot_event() : __isSet_(false) {}

  // Marks the event as set and wakes every thread blocked in wait().
  void set() noexcept
  {
    lock_guard<mutex> __guard(__mutex_);
    __isSet_ = true;
    // The notification is issued before the mutex is released.
    __cv_.notify_all();
  }

  // Blocks until the event has been set; returns at once if it already is.
  void wait() noexcept
  {
    unique_lock<mutex> __guard(__mutex_);
    // Re-check the flag after every wake-up to ignore spurious wake-ups.
    while (!__isSet_)
    {
      __cv_.wait(__guard);
    }
  }

private:
  mutex __mutex_;
  condition_variable __cv_;
  bool __isSet_;  // Read and written only while __mutex_ is held.
};
// CRTP base shared by the sync_wait promise types.
//
// Owns the event the waiting thread blocks on and provides the promise
// members that do not depend on the awaited result type: the coroutine
// starts suspended, and its final awaiter signals the event and leaves the
// coroutine suspended (await_suspend returns void).
//
// _Derived is the concrete promise type; it determines the coroutine handle
// type and is used to recover the handle from the promise object.
template<typename _Derived>
class __sync_wait_promise_base
{
public:
  using __handle_t = coroutine_handle<_Derived>;

private:
  // Awaiter that never completes synchronously and, on suspension, signals
  // the promise's event so that a thread blocked in run() can continue.
  struct _FinalAwaiter
  {
    bool await_ready() noexcept { return false; }

    void await_suspend(__handle_t __coro) noexcept
    {
      __sync_wait_promise_base& __promise = __coro.promise();
      __promise.__event_.set();
    }

    void await_resume() noexcept {}
  };
  friend struct _FinalAwaiter;

public:
  // None of these can throw; final_suspend() in particular has to be
  // noexcept for the promise type to be usable with C++20 coroutines.
  __handle_t get_return_object() noexcept { return __handle(); }
  suspend_always initial_suspend() noexcept { return {}; }
  _FinalAwaiter final_suspend() noexcept { return {}; }

private:
  // Recovers the coroutine handle from this promise object.
  __handle_t __handle() noexcept
  {
    return __handle_t::from_promise(static_cast<_Derived&>(*this));
  }

protected:
  // Start the coroutine and then block waiting for it to finish.
  void run() noexcept
  {
    __handle().resume();
    __event_.wait();
  }

private:
  __oneshot_event __event_;
};
// Promise type for a sync_wait coroutine that produces a non-void result.
//
// The result is delivered through co_yield: yield_value() records the
// address of the yielded object and returns the final awaiter, so the
// coroutine signals completion while the yielded object is still alive.
// An exception escaping the coroutine body is captured instead and
// rethrown from get().
template<typename _Tp>
class __sync_wait_promise final
  : public __sync_wait_promise_base<__sync_wait_promise<_Tp>>
{
public:
  __sync_wait_promise() : __state_(_State::__empty) {}

  ~__sync_wait_promise()
  {
    // Only the exception_ptr alternative of the union needs explicit
    // destruction; the pointer alternative is trivial. Written as an `if`
    // so that no enumerator is left unhandled when exceptions are disabled.
#ifndef _LIBCPP_NO_EXCEPTIONS
    if (__state_ == _State::__exception)
    {
      __exception_.~exception_ptr();
    }
#endif
  }

  void return_void() noexcept
  {
    // Should be unreachable since coroutine should always
    // suspend at `co_yield` point where it will be destroyed
    // or will fail with an exception and bypass return_void()
    // and call unhandled_exception() instead.
    _VSTD::abort();
  }

  // Captures the in-flight exception so get() can rethrow it. Aborts when
  // exceptions are disabled.
  void unhandled_exception() noexcept
  {
#ifndef _LIBCPP_NO_EXCEPTIONS
    ::new (static_cast<void*>(&__exception_)) exception_ptr(
        std::current_exception());
    __state_ = _State::__exception;
#else
    _VSTD::abort();
#endif
  }

  // Records where the yielded result lives and suspends the coroutine via
  // the final awaiter, which signals the waiting thread.
  auto yield_value(_Tp&& __value) noexcept
  {
    __valuePtr_ = std::addressof(__value);
    __state_ = _State::__value;
    return this->final_suspend();
  }

  // Runs the coroutine to completion, then either rethrows the captured
  // exception or returns a reference to the yielded result. The reference
  // refers to an object owned by the still-suspended coroutine.
  _Tp&& get()
  {
    this->run();
#ifndef _LIBCPP_NO_EXCEPTIONS
    if (__state_ == _State::__exception)
    {
      std::rethrow_exception(_VSTD::move(__exception_));
    }
#endif
    return static_cast<_Tp&&>(*__valuePtr_);
  }

private:
  // Identifies which union member, if any, currently holds a value.
  enum class _State {
    __empty,
    __value,
    __exception
  };

  _State __state_;
  union {
    std::add_pointer_t<_Tp> __valuePtr_;
    std::exception_ptr __exception_;
  };
};
// Promise type for a sync_wait coroutine whose awaited operation yields no
// value. Declared with the `class` key to match the primary template.
template<>
class __sync_wait_promise<void> final
  : public __sync_wait_promise_base<__sync_wait_promise<void>>
{
public:
  // Captures the in-flight exception so get() can rethrow it. Does nothing
  // when exceptions are disabled.
  void unhandled_exception() noexcept
  {
#ifndef _LIBCPP_NO_EXCEPTIONS
    __exception_ = std::current_exception();
#endif
  }

  void return_void() noexcept {}

  // Runs the coroutine to completion and rethrows any exception that
  // escaped its body.
  void get()
  {
    this->run();
#ifndef _LIBCPP_NO_EXCEPTIONS
    if (__exception_)
    {
      std::rethrow_exception(_VSTD::move(__exception_));
    }
#endif
  }

private:
  std::exception_ptr __exception_;  // Null unless an exception was captured.
};
template<typename _Tp>
class __sync_wait_task final
{
public:
using promise_type = __sync_wait_promise<_Tp>;
private:
using __handle_t = typename promise_type::__handle_t;
public:
__sync_wait_task(__handle_t __coro) noexcept : __coro_(__coro) {}
~__sync_wait_task()
{
_LIBCPP_ASSERT(__coro_, "Should always have a valid coroutine handle");
__coro_.destroy();
}
decltype(auto) get()
{
return __coro_.promise().get();
}
private:
__handle_t __coro_;
};
// Type transformation that strips an rvalue reference: `T&&` maps to `T`,
// while every other type (including lvalue references) is left unchanged.
template<class _Up>
struct __remove_rvalue_reference
{
  typedef _Up type;
};

// Rvalue-reference case: drop the `&&`.
template<class _Up>
struct __remove_rvalue_reference<_Up&&>
{
  typedef _Up type;
};

// Convenience alias for `__remove_rvalue_reference<T>::type`.
template<class _Up>
using __remove_rvalue_reference_t =
    typename __remove_rvalue_reference<_Up>::type;
// Overload selected when awaiting _Awaitable produces void: the coroutine
// simply awaits the operation and then runs off the end of its body.
template<
    typename _Awaitable,
    typename _AwaitResult = await_result_t<_Awaitable>,
    typename std::enable_if<std::is_void<_AwaitResult>::value, int>::type = 0>
__sync_wait_task<_AwaitResult> __make_sync_wait_task(_Awaitable&& __awaitable)
{
  // Preserve the value category the caller passed in.
  co_await static_cast<_Awaitable&&>(__awaitable);
}
// Overload selected when awaiting _Awaitable produces a non-void result:
// the result of the co_await expression is handed to the promise through
// co_yield.
template<
    typename _Awaitable,
    typename _AwaitResult = await_result_t<_Awaitable>,
    typename std::enable_if<!std::is_void<_AwaitResult>::value, int>::type = 0>
__sync_wait_task<_AwaitResult> __make_sync_wait_task(_Awaitable&& __awaitable)
{
  // Preserve the value category the caller passed in.
  co_yield co_await static_cast<_Awaitable&&>(__awaitable);
}
// Awaits __awaitable from a helper coroutine, blocks the calling thread
// until that coroutine signals completion, and returns the awaited result.
// A result of rvalue-reference type is returned by value; any other result
// type is returned as-is.
template<typename _Awaitable>
auto sync_wait(_Awaitable&& __awaitable)
    -> __remove_rvalue_reference_t<await_result_t<_Awaitable>>
{
  // Bind the task to a reference so it stays alive, without being copied or
  // moved, until the return value has been produced from get().
  auto&& __task = _VSTD_CORO::__make_sync_wait_task(
      static_cast<_Awaitable&&>(__awaitable));
  return __task.get();
}
_LIBCPP_END_NAMESPACE_EXPERIMENTAL_COROUTINES
#endif