[libc++][PSTL] Implement std::count{,_if}

Reviewed By: ldionne, #libc

Spies: libcxx-commits

Differential Revision: https://reviews.llvm.org/D150128

GitOrigin-RevId: 7a3b528e1b540b4c98f4b557f917447481872749
diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt
index 910e6b1..62d6ea5 100644
--- a/include/CMakeLists.txt
+++ b/include/CMakeLists.txt
@@ -85,6 +85,7 @@
   __algorithm/pstl_backends/cpu_backends/transform.h
   __algorithm/pstl_backends/cpu_backends/transform_reduce.h
   __algorithm/pstl_copy.h
+  __algorithm/pstl_count.h
   __algorithm/pstl_fill.h
   __algorithm/pstl_find.h
   __algorithm/pstl_for_each.h
diff --git a/include/__algorithm/pstl_backend.h b/include/__algorithm/pstl_backend.h
index c25a8b1..73e2c48 100644
--- a/include/__algorithm/pstl_backend.h
+++ b/include/__algorithm/pstl_backend.h
@@ -113,6 +113,12 @@
   temlate <class _ExecutionPolicy, class _Iterator>
   __iter_value_type<_Iterator> __pstl_reduce(_Backend, _Iterator __first, _Iterator __last);
 
+  template <class _ExecuitonPolicy, class _Iterator, class _Tp>
+  __iter_diff_t<_Iterator> __pstl_count(_Backend, _Iterator __first, _Iterator __last, const _Tp& __value);
+
+  template <class _ExecutionPolicy, class _Iterator, class _Predicate>
+  __iter_diff_t<_Iterator> __pstl_count_if(_Backend, _Iterator __first, _Iterator __last, _Predicate __pred);
+
 // TODO: Complete this list
 
 */
diff --git a/include/__algorithm/pstl_count.h b/include/__algorithm/pstl_count.h
new file mode 100644
index 0000000..7f591c9
--- /dev/null
+++ b/include/__algorithm/pstl_count.h
@@ -0,0 +1,86 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ALGORITHM_PSTL_COUNT_H
+#define _LIBCPP___ALGORITHM_PSTL_COUNT_H
+
+#include <__algorithm/count.h>
+#include <__algorithm/for_each.h>
+#include <__algorithm/pstl_backend.h>
+#include <__algorithm/pstl_for_each.h>
+#include <__algorithm/pstl_frontend_dispatch.h>
+#include <__atomic/atomic.h>
+#include <__config>
+#include <__iterator/iterator_traits.h>
+#include <__numeric/pstl_transform_reduce.h>
+#include <__type_traits/is_execution_policy.h>
+#include <__type_traits/remove_cvref.h>
+#include <__utility/terminate_on_exception.h>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+#if !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+template <class>
+void __pstl_count_if(); // declaration needed for the frontend dispatch below
+
+template <class _ExecutionPolicy,
+          class _ForwardIterator,
+          class _Predicate,
+          class _RawPolicy                                    = __remove_cvref_t<_ExecutionPolicy>,
+          enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
+_LIBCPP_HIDE_FROM_ABI __iter_diff_t<_ForwardIterator>
+count_if(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, _Predicate __pred) {
+  using __diff_t = __iter_diff_t<_ForwardIterator>;
+  return std::__pstl_frontend_dispatch(
+      _LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_count_if),
+      [&](_ForwardIterator __g_first, _ForwardIterator __g_last, _Predicate __g_pred) {
+        return std::transform_reduce(
+            __policy,
+            std::move(__g_first),
+            std::move(__g_last),
+            __diff_t(),
+            std::plus{},
+            [&](__iter_reference<_ForwardIterator> __element) -> bool { return __g_pred(__element); });
+      },
+      std::move(__first),
+      std::move(__last),
+      std::move(__pred));
+}
+
+template <class>
+void __pstl_count(); // declaration needed for the frontend dispatch below
+
+template <class _ExecutionPolicy,
+          class _ForwardIterator,
+          class _Tp,
+          class _RawPolicy                                    = __remove_cvref_t<_ExecutionPolicy>,
+          enable_if_t<is_execution_policy_v<_RawPolicy>, int> = 0>
+_LIBCPP_HIDE_FROM_ABI __iter_diff_t<_ForwardIterator>
+count(_ExecutionPolicy&& __policy, _ForwardIterator __first, _ForwardIterator __last, const _Tp& __value) {
+  return std::__pstl_frontend_dispatch(
+      _LIBCPP_PSTL_CUSTOMIZATION_POINT(__pstl_count),
+      [&](_ForwardIterator __g_first, _ForwardIterator __g_last, const _Tp& __g_value) {
+        return std::count_if(__policy, __g_first, __g_last, [&](__iter_reference<_ForwardIterator> __v) {
+          return __v == __g_value;
+        });
+      },
+      std::move(__first),
+      std::move(__last),
+      __value);
+}
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // !defined(_LIBCPP_HAS_NO_INCOMPLETE_PSTL) && _LIBCPP_STD_VER >= 17
+
+#endif // _LIBCPP___ALGORITHM_PSTL_COUNT_H
diff --git a/include/algorithm b/include/algorithm
index 24d29fd..04a21c0 100644
--- a/include/algorithm
+++ b/include/algorithm
@@ -1802,6 +1802,7 @@
 #include <__algorithm/prev_permutation.h>
 #include <__algorithm/pstl_any_all_none_of.h>
 #include <__algorithm/pstl_copy.h>
+#include <__algorithm/pstl_count.h>
 #include <__algorithm/pstl_fill.h>
 #include <__algorithm/pstl_find.h>
 #include <__algorithm/pstl_for_each.h>
diff --git a/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp b/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
index 76188e7..2a634c3 100644
--- a/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
+++ b/test/libcxx/algorithms/pstl.robust_against_customization_points_not_working.pass.cpp
@@ -42,6 +42,26 @@
   return true;
 }
 
+bool pstl_count_called = false;
+
+template <class, class ForwardIterator, class T>
+typename std::iterator_traits<ForwardIterator>::difference_type
+__pstl_count(TestBackend, ForwardIterator, ForwardIterator, const T&) {
+  assert(!pstl_count_called);
+  pstl_count_called = true;
+  return 0;
+}
+
+bool pstl_count_if_called = false;
+
+template <class, class ForwardIterator, class Pred>
+typename std::iterator_traits<ForwardIterator>::difference_type
+__pstl_count_if(TestBackend, ForwardIterator, ForwardIterator, Pred) {
+  assert(!pstl_count_if_called);
+  pstl_count_if_called = true;
+  return 0;
+}
+
 bool pstl_none_of_called = false;
 
 template <class, class ForwardIterator, class Pred>
@@ -197,6 +217,10 @@
   assert(std::pstl_all_of_called);
   (void)std::none_of(TestPolicy{}, std::begin(a), std::end(a), pred);
   assert(std::pstl_none_of_called);
+  (void)std::count(TestPolicy{}, std::begin(a), std::end(a), 0);
+  assert(std::pstl_count_called);
+  (void)std::count_if(TestPolicy{}, std::begin(a), std::end(a), pred);
+  assert(std::pstl_count_if_called);
   (void)std::fill(TestPolicy{}, std::begin(a), std::end(a), 0);
   assert(std::pstl_fill_called);
   (void)std::fill_n(TestPolicy{}, std::begin(a), std::size(a), 0);
diff --git a/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count.pass.cpp b/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count.pass.cpp
new file mode 100644
index 0000000..f00861f
--- /dev/null
+++ b/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count.pass.cpp
@@ -0,0 +1,86 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++03, c++11, c++14
+
+// UNSUPPORTED: libcpp-has-no-incomplete-pstl
+
+// <algorithm>
+
+// template<class ExecutionPolicy, class ForwardIterator, class T>
+//   typename iterator_traits<ForwardIterator>::difference_type
+//     count(ExecutionPolicy&& exec,
+//           ForwardIterator first, ForwardIterator last, const T& value);
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <vector>
+
+#include "test_macros.h"
+#include "test_execution_policies.h"
+#include "test_iterators.h"
+
+EXECUTION_POLICY_SFINAE_TEST(count);
+
+static_assert(sfinae_test_count<int, int*, int*, bool (*)(int)>);
+static_assert(!sfinae_test_count<std::execution::parallel_policy, int*, int*, int>);
+
+template <class Iter>
+struct Test {
+  template <class Policy>
+  void operator()(Policy&& policy) {
+    { // simple test
+      int a[]            = {1, 2, 3, 4, 5};
+      decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 1);
+    }
+
+    { // test that an empty range works
+      std::array<int, 0> a;
+      decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 0);
+    }
+
+    { // test that a single-element range works
+      int a[] = {1};
+      decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 1);
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 1);
+    }
+
+    { // test that a two-element range works
+      int a[] = {1, 3};
+      decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 1);
+    }
+
+    { // test that a three-element range works
+      int a[] = {3, 1, 3};
+      decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 3);
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 2);
+    }
+
+    { // test that a large range works
+      std::vector<int> a(100, 2);
+      decltype(auto) ret = std::count(policy, std::begin(a), std::end(a), 2);
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 100);
+    }
+  }
+};
+
+int main(int, char**) {
+  types::for_each(types::forward_iterator_list<int*>{}, TestIteratorWithPolicies<Test>{});
+
+  return 0;
+}
diff --git a/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count_if.pass.cpp b/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count_if.pass.cpp
new file mode 100644
index 0000000..489c7a7
--- /dev/null
+++ b/test/std/algorithms/alg.nonmodifying/alg.count/pstl.count_if.pass.cpp
@@ -0,0 +1,86 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// UNSUPPORTED: c++03, c++11, c++14
+
+// UNSUPPORTED: libcpp-has-no-incomplete-pstl
+
+// <algorithm>
+
+// template<class ExecutionPolicy, class ForwardIterator, class Predicate>
+//   typename iterator_traits<ForwardIterator>::difference_type
+//     count_if(ExecutionPolicy&& exec,
+//              ForwardIterator first, ForwardIterator last, Predicate pred);
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <vector>
+
+#include "test_macros.h"
+#include "test_execution_policies.h"
+#include "test_iterators.h"
+
+EXECUTION_POLICY_SFINAE_TEST(count_if);
+
+static_assert(sfinae_test_count_if<int, int*, int*, bool (*)(int)>);
+static_assert(!sfinae_test_count_if<std::execution::parallel_policy, int*, int*, int>);
+
+template <class Iter>
+struct Test {
+  template <class Policy>
+  void operator()(Policy&& policy) {
+    { // simple test
+      int a[]            = {1, 2, 3, 4, 5};
+      decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 2);
+    }
+
+    { // test that an empty range works
+      std::array<int, 0> a;
+      decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 0);
+    }
+
+    { // test that a single-element range works
+      int a[] = {1};
+      decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 1);
+    }
+
+    { // test that a two-element range works
+      int a[] = {1, 3};
+      decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 1);
+    }
+
+    { // test that a three-element range works
+      int a[] = {2, 3, 2};
+      decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 2);
+    }
+
+    { // test that a large range works
+      std::vector<int> a(100, 2);
+      decltype(auto) ret = std::count_if(policy, std::begin(a), std::end(a), [](int i) { return i < 3; });
+      static_assert(std::is_same_v<decltype(ret), typename std::iterator_traits<Iter>::difference_type>);
+      assert(ret == 100);
+    }
+  }
+};
+
+int main(int, char**) {
+  types::for_each(types::forward_iterator_list<int*>{}, TestIteratorWithPolicies<Test>{});
+
+  return 0;
+}