[pstl] Replace direct use of assert() with _PSTL_ASSERT

Standard libraries may (libstdc++ in particular) forbid direct use of
assert()/<cassert> in library code.

Differential Revision: https://reviews.llvm.org/D60249

GitOrigin-RevId: 2e15f4ac572bcf429ec12e8f3efbb8ad254042c7
diff --git a/include/pstl/internal/algorithm_impl.h b/include/pstl/internal/algorithm_impl.h
index b4e1c14..46f4224 100644
--- a/include/pstl/internal/algorithm_impl.h
+++ b/include/pstl/internal/algorithm_impl.h
@@ -2757,8 +2757,8 @@
         return !__internal::__parallel_or(
             std::forward<_ExecutionPolicy>(__exec), __first2, __last2,
             [__first1, __last1, __first2, __last2, &__comp](_ForwardIterator2 __i, _ForwardIterator2 __j) {
-                assert(__j > __i);
-                //assert(__j - __i > 1);
+                _PSTL_ASSERT(__j > __i);
+                //_PSTL_ASSERT(__j - __i > 1);
 
                 //1. moving boundaries to "consume" subsequence of equal elements
                 auto __is_equal = [&__comp](_ForwardIterator2 __a, _ForwardIterator2 __b) -> bool {
@@ -2782,8 +2782,8 @@
                 //2. testing is __a subsequence of the second range included into the first range
                 auto __b = std::lower_bound(__first1, __last1, *__i, __comp);
 
-                assert(!__comp(*(__last1 - 1), *__b));
-                assert(!__comp(*(__j - 1), *__i));
+                _PSTL_ASSERT(!__comp(*(__last1 - 1), *__b));
+                _PSTL_ASSERT(!__comp(*(__j - 1), *__i));
                 return !std::includes(__b, __last1, __i, __j, __comp);
             });
     });
@@ -2971,7 +2971,7 @@
     }
 
     const auto __m2 = __left_bound_seq_2 - __first2;
-    assert(__m1 == 0 || __m2 == 0);
+    _PSTL_ASSERT(__m1 == 0 || __m2 == 0);
     if (__m2 > __set_algo_cut_off)
     {
         auto __res_or = __result;
diff --git a/include/pstl/internal/numeric_impl.h b/include/pstl/internal/numeric_impl.h
index c288111..3e653bb 100644
--- a/include/pstl/internal/numeric_impl.h
+++ b/include/pstl/internal/numeric_impl.h
@@ -304,7 +304,7 @@
 __brick_adjacent_difference(_ForwardIterator1 __first, _ForwardIterator1 __last, _ForwardIterator2 __d_first,
                             BinaryOperation __op, /*is_vector=*/std::true_type) noexcept
 {
-    assert(__first != __last);
+    _PSTL_ASSERT(__first != __last);
 
     typedef typename std::iterator_traits<_ForwardIterator1>::reference _ReferenceType1;
     typedef typename std::iterator_traits<_ForwardIterator2>::reference _ReferenceType2;
@@ -333,7 +333,7 @@
                               _ForwardIterator2 __d_first, _BinaryOperation __op, _IsVector __is_vector,
                               /*is_parallel=*/std::true_type)
 {
-    assert(__first != __last);
+    _PSTL_ASSERT(__first != __last);
     typedef typename std::iterator_traits<_ForwardIterator1>::reference _ReferenceType1;
     typedef typename std::iterator_traits<_ForwardIterator2>::reference _ReferenceType2;
 
diff --git a/include/pstl/internal/parallel_backend_tbb.h b/include/pstl/internal/parallel_backend_tbb.h
index f1836aa..8f1e395 100644
--- a/include/pstl/internal/parallel_backend_tbb.h
+++ b/include/pstl/internal/parallel_backend_tbb.h
@@ -10,7 +10,6 @@
 #ifndef _PSTL_PARALLEL_BACKEND_TBB_H
 #define _PSTL_PARALLEL_BACKEND_TBB_H
 
-#include <cassert>
 #include <algorithm>
 #include <type_traits>
 
@@ -529,7 +528,7 @@
     __task*
     allocate_func_task(_Fn&& __f)
     {
-        assert(_M_execute_data != nullptr);
+        _PSTL_ASSERT(_M_execute_data != nullptr);
         tbb::detail::d1::small_object_allocator __alloc{};
         auto __t =
             __alloc.new_object<__func_task<typename std::decay<_Fn>::type>>(*_M_execute_data, std::forward<_Fn>(__f));
@@ -574,7 +573,7 @@
     make_additional_child_of(__task* __parent, _Fn&& __f)
     {
         auto __t = make_child_of(__parent, std::forward<_Fn>(__f));
-        assert(__parent->_M_refcount.load(std::memory_order_relaxed) > 0);
+        _PSTL_ASSERT(__parent->_M_refcount.load(std::memory_order_relaxed) > 0);
         ++__parent->_M_refcount;
         return __t;
     }
@@ -595,7 +594,7 @@
     inline void
     spawn(__task* __t)
     {
-        assert(_M_execute_data != nullptr);
+        _PSTL_ASSERT(_M_execute_data != nullptr);
         tbb::detail::d1::spawn(*__t, *_M_execute_data->context);
     }
 
@@ -648,11 +647,11 @@
 
         this->~__func_task();
 
-        assert(__parent != nullptr);
-        assert(__parent->_M_refcount.load(std::memory_order_relaxed) > 0);
+        _PSTL_ASSERT(__parent != nullptr);
+        _PSTL_ASSERT(__parent->_M_refcount.load(std::memory_order_relaxed) > 0);
         if (--__parent->_M_refcount == 0)
         {
-            assert(__next == nullptr);
+            _PSTL_ASSERT(__next == nullptr);
             __alloc.deallocate(this, *__ed);
             return __parent;
         }
@@ -864,20 +863,20 @@
     {
         const auto __nx = (_M_xe - _M_xs);
         const auto __ny = (_M_ye - _M_ys);
-        assert(__nx > 0 && __ny > 0);
+        _PSTL_ASSERT(__nx > 0 && __ny > 0);
 
-        assert(_x_orig == _y_orig);
-        assert(!is_partial());
+        _PSTL_ASSERT(_x_orig == _y_orig);
+        _PSTL_ASSERT(!is_partial());
 
         if (_x_orig)
         {
-            assert(std::is_sorted(_M_x_beg + _M_xs, _M_x_beg + _M_xe, _M_comp));
-            assert(std::is_sorted(_M_x_beg + _M_ys, _M_x_beg + _M_ye, _M_comp));
+            _PSTL_ASSERT(std::is_sorted(_M_x_beg + _M_xs, _M_x_beg + _M_xe, _M_comp));
+            _PSTL_ASSERT(std::is_sorted(_M_x_beg + _M_ys, _M_x_beg + _M_ye, _M_comp));
             return !_M_comp(*(_M_x_beg + _M_ys), *(_M_x_beg + _M_xe - 1));
         }
 
-        assert(std::is_sorted(_M_z_beg + _M_xs, _M_z_beg + _M_xe, _M_comp));
-        assert(std::is_sorted(_M_z_beg + _M_ys, _M_z_beg + _M_ye, _M_comp));
+        _PSTL_ASSERT(std::is_sorted(_M_z_beg + _M_xs, _M_z_beg + _M_xe, _M_comp));
+        _PSTL_ASSERT(std::is_sorted(_M_z_beg + _M_ys, _M_z_beg + _M_ye, _M_comp));
         return !_M_comp(*(_M_z_beg + _M_zs + __nx), *(_M_z_beg + _M_zs + __nx - 1));
     }
     void
@@ -885,7 +884,7 @@
     {
         const auto __nx = (_M_xe - _M_xs);
         const auto __ny = (_M_ye - _M_ys);
-        assert(__nx > 0 && __ny > 0);
+        _PSTL_ASSERT(__nx > 0 && __ny > 0);
 
         if (_x_orig)
             __move_range_construct()(_M_x_beg + _M_xs, _M_x_beg + _M_xe, _M_z_beg + _M_zs);
@@ -916,7 +915,7 @@
     __task*
     merge_ranges(__task* __self)
     {
-        assert(_x_orig == _y_orig); //two merged subrange must be lie into the same buffer
+        _PSTL_ASSERT(_x_orig == _y_orig); //two merged subrange must be lie into the same buffer
 
         const auto __nx = (_M_xe - _M_xs);
         const auto __ny = (_M_ye - _M_ys);
@@ -932,15 +931,15 @@
             _M_leaf_merge(_M_x_beg + _M_xs, _M_x_beg + _M_xe, _M_x_beg + _M_ys, _M_x_beg + _M_ye, _M_z_beg + _M_zs,
                           _M_comp, __move_value_construct(), __move_value_construct(), __move_range_construct(),
                           __move_range_construct());
-            assert(parent_merge(__self)); //not root merging task
+            _PSTL_ASSERT(parent_merge(__self)); //not root merging task
         }
         //merge to "origin"
         else
         {
-            assert(_x_orig == _y_orig);
+            _PSTL_ASSERT(_x_orig == _y_orig);
 
-            assert(is_partial() || std::is_sorted(_M_z_beg + _M_xs, _M_z_beg + _M_xe, _M_comp));
-            assert(is_partial() || std::is_sorted(_M_z_beg + _M_ys, _M_z_beg + _M_ye, _M_comp));
+            _PSTL_ASSERT(is_partial() || std::is_sorted(_M_z_beg + _M_xs, _M_z_beg + _M_xe, _M_comp));
+            _PSTL_ASSERT(is_partial() || std::is_sorted(_M_z_beg + _M_ys, _M_z_beg + _M_ye, _M_comp));
 
             const auto __nx = (_M_xe - _M_xs);
             const auto __ny = (_M_ye - _M_ys);
@@ -957,8 +956,8 @@
     __task*
     process_ranges(__task* __self)
     {
-        assert(_x_orig == _y_orig);
-        assert(!_split);
+        _PSTL_ASSERT(_x_orig == _y_orig);
+        _PSTL_ASSERT(!_split);
 
         auto p = parent_merge(__self);
 
@@ -1004,7 +1003,7 @@
     __task*
     split_merging(__task* __self)
     {
-        assert(_x_orig == _y_orig);
+        _PSTL_ASSERT(_x_orig == _y_orig);
         const auto __nx = (_M_xe - _M_xs);
         const auto __ny = (_M_ye - _M_ys);
 
@@ -1076,8 +1075,8 @@
     {
         const _SizeType __nx = (_M_xe - _M_xs);
         const _SizeType __ny = (_M_ye - _M_ys);
-        assert(__nx > 0);
-        assert(__nx > 0);
+        _PSTL_ASSERT(__nx > 0);
+        _PSTL_ASSERT(__nx > 0);
 
         if (__nx < __ny)
             move_x_range();
@@ -1133,7 +1132,7 @@
     if (__n <= __sort_cut_off)
     {
         _M_leaf_sort(_M_xs, _M_xe, _M_comp);
-        assert(!_M_root);
+        _PSTL_ASSERT(!_M_root);
         return nullptr;
     }
 
diff --git a/include/pstl/internal/parallel_backend_utils.h b/include/pstl/internal/parallel_backend_utils.h
index 448f924..e176d7e 100644
--- a/include/pstl/internal/parallel_backend_utils.h
+++ b/include/pstl/internal/parallel_backend_utils.h
@@ -12,7 +12,6 @@
 
 #include <iterator>
 #include <utility>
-#include <cassert>
 #include "utils.h"
 
 #include "pstl_config.h"
@@ -58,7 +57,7 @@
         constexpr bool __same_move_seq = std::is_same<_MoveSequenceX, _MoveSequenceY>::value;
 
         auto __n = _M_nmerge;
-        assert(__n > 0);
+        _PSTL_ASSERT(__n > 0);
 
         auto __nx = __xe - __xs;
         //auto __ny = __ye - __ys;
diff --git a/include/pstl/internal/pstl_config.h b/include/pstl/internal/pstl_config.h
index 7137a37..fc04b6d 100644
--- a/include/pstl/internal/pstl_config.h
+++ b/include/pstl/internal/pstl_config.h
@@ -31,6 +31,11 @@
 #    define _PSTL_USAGE_WARNINGS 0
 #endif
 
+#if !defined(_PSTL_ASSERT)
+#    include <cassert>
+#    define _PSTL_ASSERT(pred) (assert((pred)))
+#endif
+
 // Portability "#pragma" definition
 #ifdef _MSC_VER
 #    define _PSTL_PRAGMA(x) __pragma(x)
diff --git a/test/std/algorithms/alg.sorting/alg.min.max/minmax_element.pass.cpp b/test/std/algorithms/alg.sorting/alg.min.max/minmax_element.pass.cpp
index e1f5051..715b250 100644
--- a/test/std/algorithms/alg.sorting/alg.min.max/minmax_element.pass.cpp
+++ b/test/std/algorithms/alg.sorting/alg.min.max/minmax_element.pass.cpp
@@ -14,7 +14,6 @@
 #include <execution>
 #include <algorithm>
 #include <set>
-#include <cassert>
 #include <cmath>
 
 #include "support/utils.h"