blob: ba29ebc3212ee8b35d55264f4ef60aad6b0af0a2 [file] [log] [blame]
// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef _LIBCPP_BARRIER
#define _LIBCPP_BARRIER
/*
barrier synopsis
namespace std
{
template<class CompletionFunction = see below>
class barrier // since C++20
{
public:
using arrival_token = see below;
static constexpr ptrdiff_t max() noexcept;
constexpr explicit barrier(ptrdiff_t phase_count,
CompletionFunction f = CompletionFunction());
~barrier();
barrier(const barrier&) = delete;
barrier& operator=(const barrier&) = delete;
[[nodiscard]] arrival_token arrive(ptrdiff_t update = 1);
void wait(arrival_token&& arrival) const;
void arrive_and_wait();
void arrive_and_drop();
private:
CompletionFunction completion; // exposition only
};
}
*/
#include <__config>
#if !defined(_LIBCPP_HAS_NO_THREADS)
# include <__assert>
# include <__atomic/atomic_base.h>
# include <__atomic/memory_order.h>
# include <__memory/unique_ptr.h>
# include <__thread/poll_with_backoff.h>
# include <__thread/timed_backoff_policy.h>
# include <__utility/move.h>
# include <cstddef>
# include <cstdint>
# include <limits>
# include <version>
# if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
# pragma GCC system_header
# endif
_LIBCPP_PUSH_MACROS
# include <__undef_macros>
# if _LIBCPP_STD_VER >= 20
_LIBCPP_BEGIN_NAMESPACE_STD
// Completion function used when std::barrier is instantiated without an explicit
// CompletionFunction: invoking it does nothing and cannot throw.
struct __empty_completion {
  _LIBCPP_HIDE_FROM_ABI void operator()() noexcept {}
};
# ifndef _LIBCPP_HAS_NO_TREE_BARRIER
/*
The default implementation of __barrier_base is a classic tree barrier.
It looks different from literature pseudocode for two main reasons:
1. Threads that call into std::barrier functions do not provide indices,
so a numbering step is added before the actual barrier algorithm,
appearing as an N+1 round to the N rounds of the tree barrier.
2. A great deal of attention has been paid to avoid cache line thrashing
by flattening the tree structure into cache-line sized arrays, that
are indexed in an efficient way.
*/
// Phase counter type of the tree barrier. The header-side __barrier_base stores the
// current phase in an atomic of this type and advances it by 2 whenever a phase completes.
using __barrier_phase_t = uint8_t;
// Opaque state of the tree barrier algorithm. Only declared in this header and always
// handled through a pointer.
class __barrier_algorithm_base;
// Creates the algorithm state for a barrier. The expected count is taken by non-const
// reference, so the callee is able to modify the caller's value.
// NOTE(review): whether it actually adjusts __expected is not visible from this header.
_LIBCPP_AVAILABILITY_SYNC _LIBCPP_EXPORTED_FROM_ABI __barrier_algorithm_base*
__construct_barrier_algorithm_base(ptrdiff_t& __expected);
// Records one arrival on __barrier for the phase __old_phase. When it returns true the
// caller (see __barrier_base::arrive) runs the completion function and publishes the next
// phase; when it returns false the caller does nothing further for that arrival.
_LIBCPP_AVAILABILITY_SYNC _LIBCPP_EXPORTED_FROM_ABI bool
__arrive_barrier_algorithm_base(__barrier_algorithm_base* __barrier, __barrier_phase_t __old_phase) noexcept;
// Releases state obtained from __construct_barrier_algorithm_base. __barrier_base installs
// this function as the deleter of the unique_ptr that owns the state.
_LIBCPP_AVAILABILITY_SYNC _LIBCPP_EXPORTED_FROM_ABI void
__destroy_barrier_algorithm_base(__barrier_algorithm_base* __barrier) noexcept;
// Tree-barrier flavour of __barrier_base. The arrival bookkeeping is delegated to the
// out-of-line __barrier_algorithm_base; this class owns that state, the completion
// function and the phase counter that waiters observe.
template <class _CompletionF>
class __barrier_base {
  // Declaration order matters: __expected_ must be initialized before __base_, because the
  // constructor hands a reference to it to __construct_barrier_algorithm_base.
  ptrdiff_t __expected_;
  unique_ptr<__barrier_algorithm_base, void (*)(__barrier_algorithm_base*)> __base_;
  // Pending change to __expected_, accumulated by arrive_and_drop() and folded into
  // __expected_ when a phase completes.
  __atomic_base<ptrdiff_t> __expected_adjustment_;
  _CompletionF __completion_;
  __atomic_base<__barrier_phase_t> __phase_;

public:
  using arrival_token = __barrier_phase_t;

  static _LIBCPP_HIDE_FROM_ABI constexpr ptrdiff_t max() noexcept { return numeric_limits<ptrdiff_t>::max(); }

  // Builds a barrier expecting __expected arrivals per phase, starting at phase 0 with no
  // pending adjustment.
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI
  __barrier_base(ptrdiff_t __expected, _CompletionF __completion = _CompletionF())
      : __expected_(__expected),
        __base_(std::__construct_barrier_algorithm_base(this->__expected_), &__destroy_barrier_algorithm_base),
        __expected_adjustment_(0),
        __completion_(std::move(__completion)),
        __phase_(0) {}

  // Registers __update arrivals, one at a time, against the phase observed on entry and
  // returns that phase as the arrival token. The arrival for which the algorithm reports
  // completion runs the completion function, applies the pending adjustment to the
  // expected count, publishes the next phase (old phase + 2) and notifies waiters.
  _LIBCPP_NODISCARD _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI arrival_token arrive(ptrdiff_t __update) {
    _LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN(
        __update <= __expected_, "update is greater than the expected count for the current barrier phase");

    const __barrier_phase_t __phase_on_entry = __phase_.load(memory_order_relaxed);
    while (__update != 0) {
      const bool __phase_completed = __arrive_barrier_algorithm_base(__base_.get(), __phase_on_entry);
      if (__phase_completed) {
        __completion_();
        __expected_ += __expected_adjustment_.load(memory_order_relaxed);
        __expected_adjustment_.store(0, memory_order_relaxed);
        // The release store pairs with the acquire load performed in wait().
        __phase_.store(__phase_on_entry + 2, memory_order_release);
        __phase_.notify_all();
      }
      --__update;
    }
    return __phase_on_entry;
  }

  // Blocks (polling with a timed back-off policy) until the current phase differs from the
  // phase recorded in the token.
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void wait(arrival_token&& __old_phase) const {
    auto const __phase_advanced = [this, __old_phase]() -> bool {
      return __old_phase != __phase_.load(memory_order_acquire);
    };
    std::__libcpp_thread_poll_with_backoff(__phase_advanced, __libcpp_timed_backoff_policy());
  }

  // Arrives once and records that one fewer arrival is expected from the next phase on.
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void arrive_and_drop() {
    __expected_adjustment_.fetch_sub(1, memory_order_relaxed);
    static_cast<void>(arrive(1));
  }
};
# else
/*
The alternative implementation of __barrier_base is a central barrier.
Two versions of this algorithm are provided:
1. A fairly straightforward implementation of the algorithm from the literature for the
general case where the completion function is not empty.
2. An optimized implementation that exploits 2's complement arithmetic
and well-defined overflow in atomic arithmetic, to handle the phase
roll-over for free.
*/
// Central-barrier flavour of __barrier_base for a non-empty completion function.
// A single atomic counter (__arrived) counts down the arrivals still outstanding in the
// current phase; the thread that brings it to zero runs the completion function, re-arms
// the counter and flips the boolean phase that waiters block on.
template <class _CompletionF>
class __barrier_base {
  // Number of arrivals expected per phase; decremented by arrive_and_drop().
  __atomic_base<ptrdiff_t> __expected;
  // Arrivals still outstanding in the current phase.
  __atomic_base<ptrdiff_t> __arrived;
  _CompletionF __completion;
  // Current phase; toggled each time a phase completes.
  __atomic_base<bool> __phase;

public:
  using arrival_token = bool;

  static _LIBCPP_HIDE_FROM_ABI constexpr ptrdiff_t max() noexcept { return numeric_limits<ptrdiff_t>::max(); }

  // Builds a barrier expecting __count arrivals per phase, starting in phase `false`.
  // __fn is the completion function invoked by the last arrival of every phase.
  _LIBCPP_HIDE_FROM_ABI __barrier_base(ptrdiff_t __count, _CompletionF __fn = _CompletionF())
      : __expected(__count), __arrived(__count), __completion(std::move(__fn)), __phase(false) {}

  // Registers __update arrivals for the current phase and returns the phase observed on
  // entry as the arrival token. If this call brings the outstanding count to zero it runs
  // the completion function, re-arms the counter with the current expected count, flips
  // the phase and wakes all waiters.
  [[nodiscard]] _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI arrival_token arrive(ptrdiff_t __update) {
    auto const __old_phase = __phase.load(memory_order_relaxed);
    auto const __result    = __arrived.fetch_sub(__update, memory_order_acq_rel) - __update;
    // Must stay after the acq_rel fetch_sub above: that is what makes decrements performed
    // by arrive_and_drop() in threads that arrived earlier visible to the last arriver.
    auto const __new_expected = __expected.load(memory_order_relaxed);

    _LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN(
        __update <= __new_expected, "update is greater than the expected count for the current barrier phase");

    if (0 == __result) {
      __completion();
      __arrived.store(__new_expected, memory_order_relaxed);
      // The release store pairs with the acquire wait in wait().
      __phase.store(!__old_phase, memory_order_release);
      __phase.notify_all();
    }
    return __old_phase;
  }

  // Blocks until the phase differs from the one recorded in the token.
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void wait(arrival_token&& __old_phase) const {
    __phase.wait(__old_phase, memory_order_acquire);
  }

  // Arrives once and lowers the expected count used when the counter is next re-armed.
  _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void arrive_and_drop() {
    __expected.fetch_sub(1, memory_order_relaxed);
    (void)arrive(1);
  }
};
// Central-barrier specialization for the empty completion function. The whole barrier
// state lives in one 64-bit atomic word:
//
//   bit  63      : phase
//   bits 32..62  : arrived counter
//   bits  0..31  : expected counter
//
// Both counters are stored as 2^31 minus the corresponding count (see __init), so the
// arrival that completes a phase makes the arrived counter carry into bit 63 and thereby
// flips the phase as part of the same atomic addition.
template <>
class __barrier_base<__empty_completion> {
  static constexpr uint64_t __expected_unit = 1ull;
  static constexpr uint64_t __arrived_unit  = 1ull << 32;
  static constexpr uint64_t __expected_mask = __arrived_unit - 1;
  static constexpr uint64_t __phase_bit     = 1ull << 63;
  static constexpr uint64_t __arrived_mask  = (__phase_bit - 1) & ~__expected_mask;

  __atomic_base<uint64_t> __phase_arrived_expected;

  // Packs the initial state: both the arrived and the expected field start at 2^31 - __count.
  static _LIBCPP_HIDE_FROM_ABI constexpr uint64_t __init(ptrdiff_t __count) _NOEXCEPT {
    return ((uint64_t(1u << 31) - __count) << 32) | (uint64_t(1u << 31) - __count);
  }

public:
  // The token is the packed word masked down to its phase bit.
  using arrival_token = uint64_t;

  static _LIBCPP_HIDE_FROM_ABI constexpr ptrdiff_t max() noexcept { return ptrdiff_t(1u << 31) - 1; }

  _LIBCPP_HIDE_FROM_ABI explicit inline __barrier_base(ptrdiff_t __count, __empty_completion = __empty_completion())
      : __phase_arrived_expected(__init(__count)) {}

  // Registers __update arrivals and returns the phase bit observed before the update.
  // If the addition toggled the phase bit, this call re-arms the arrived field from the
  // expected field and wakes all waiters.
  [[nodiscard]] inline _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI arrival_token arrive(ptrdiff_t __update) {
    auto const __inc = __arrived_unit * __update;
    auto const __old = __phase_arrived_expected.fetch_add(__inc, memory_order_acq_rel);

    // The arrived field holds 2^31 minus the arrivals still outstanding, so the number of
    // arrivals the observed phase can still accept is 2^31 minus that field. Comparing
    // against the whole packed word would never detect an oversized update.
    _LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN(
        static_cast<uint64_t>(__update) <= (1ull << 31) - ((__old & __arrived_mask) >> 32),
        "update is greater than the expected count for the current barrier phase");

    if ((__old ^ (__old + __inc)) & __phase_bit) {
      // Phase completed: reload the arrived field with the encoded expected count.
      __phase_arrived_expected.fetch_add((__old & __expected_mask) << 32, memory_order_relaxed);
      __phase_arrived_expected.notify_all();
    }
    return __old & __phase_bit;
  }

  // Blocks (polling with a timed back-off policy) until the phase bit differs from the
  // one recorded in the token.
  inline _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void wait(arrival_token&& __phase) const {
    // Explicit captures: implicit capture of `this` through [=] is deprecated in C++20.
    auto const __test_fn = [this, __phase]() -> bool {
      uint64_t const __current = __phase_arrived_expected.load(memory_order_acquire);
      return ((__current & __phase_bit) != __phase);
    };
    std::__libcpp_thread_poll_with_backoff(__test_fn, __libcpp_timed_backoff_policy());
  }

  // Arrives once and lowers the expected count: adding one unit to the expected field
  // (stored as 2^31 - count) decreases the count by one for subsequent re-arms.
  inline _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void arrive_and_drop() {
    __phase_arrived_expected.fetch_add(__expected_unit, memory_order_relaxed);
    (void)arrive(1);
  }
};
# endif // !_LIBCPP_HAS_NO_TREE_BARRIER
template <class _CompletionF = __empty_completion>
class barrier {
__barrier_base<_CompletionF> __b_;
public:
using arrival_token = typename __barrier_base<_CompletionF>::arrival_token;
static _LIBCPP_HIDE_FROM_ABI constexpr ptrdiff_t max() noexcept { return __barrier_base<_CompletionF>::max(); }
_LIBCPP_AVAILABILITY_SYNC
_LIBCPP_HIDE_FROM_ABI explicit barrier(ptrdiff_t __count, _CompletionF __completion = _CompletionF())
: __b_(__count, std::move(__completion)) {
_LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN(
__count >= 0,
"barrier::barrier(ptrdiff_t, CompletionFunction): barrier cannot be initialized with a negative value");
_LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN(
__count <= max(),
"barrier::barrier(ptrdiff_t, CompletionFunction): barrier cannot be initialized with "
"a value greater than max()");
}
barrier(barrier const&) = delete;
barrier& operator=(barrier const&) = delete;
_LIBCPP_NODISCARD _LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI arrival_token arrive(ptrdiff_t __update = 1) {
_LIBCPP_ASSERT_ARGUMENT_WITHIN_DOMAIN(__update > 0, "barrier:arrive must be called with a value greater than 0");
return __b_.arrive(__update);
}
_LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void wait(arrival_token&& __phase) const {
__b_.wait(std::move(__phase));
}
_LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void arrive_and_wait() { wait(arrive()); }
_LIBCPP_AVAILABILITY_SYNC _LIBCPP_HIDE_FROM_ABI void arrive_and_drop() { __b_.arrive_and_drop(); }
};
_LIBCPP_END_NAMESPACE_STD
# endif // _LIBCPP_STD_VER >= 20
_LIBCPP_POP_MACROS
#endif // !defined(_LIBCPP_HAS_NO_THREADS)
#if !defined(_LIBCPP_REMOVE_TRANSITIVE_INCLUDES) && _LIBCPP_STD_VER <= 20
# include <atomic>
# include <concepts>
# include <iterator>
# include <memory>
# include <stdexcept>
# include <variant>
#endif
#endif // _LIBCPP_BARRIER