[libcxx] Add strict weak ordering checks to sorting algorithms
This is the implementation of the first proposal of strict weak ordering checks described in https://discourse.llvm.org/t/rfc-strict-weak-ordering-checks-in-the-debug-libc/70217
This targets the most vulnerable algorithms like std::sort
Reviewed By: philnik, #libc
Differential Revision: https://reviews.llvm.org/D150264
GitOrigin-RevId: 7e1ee1e10dc0b77914de714b8f420c487e5705c6
diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt
index 02eb636..1e91e18 100644
--- a/include/CMakeLists.txt
+++ b/include/CMakeLists.txt
@@ -324,6 +324,7 @@
__coroutine/trivial_awaitables.h
__debug
__debug_utils/randomize_range.h
+ __debug_utils/strict_weak_ordering_check.h
__exception/exception.h
__exception/exception_ptr.h
__exception/nested_exception.h
diff --git a/include/__algorithm/sort.h b/include/__algorithm/sort.h
index 77e0b2e..3215c52 100644
--- a/include/__algorithm/sort.h
+++ b/include/__algorithm/sort.h
@@ -23,6 +23,7 @@
#include <__config>
#include <__debug>
#include <__debug_utils/randomize_range.h>
+#include <__debug_utils/strict_weak_ordering_check.h>
#include <__functional/operations.h>
#include <__functional/ranges_operations.h>
#include <__iterator/iterator_traits.h>
@@ -921,6 +922,7 @@
} else {
std::__sort_dispatch<_AlgPolicy>(std::__unwrap_iter(__first), std::__unwrap_iter(__last), __comp);
}
+ std::__check_strict_weak_ordering_sorted(std::__unwrap_iter(__first), std::__unwrap_iter(__last), __comp);
}
template <class _RandomAccessIterator, class _Comp>
diff --git a/include/__algorithm/sort_heap.h b/include/__algorithm/sort_heap.h
index 0dc9acc..ed72ff9 100644
--- a/include/__algorithm/sort_heap.h
+++ b/include/__algorithm/sort_heap.h
@@ -14,6 +14,7 @@
#include <__algorithm/iterator_operations.h>
#include <__algorithm/pop_heap.h>
#include <__config>
+#include <__debug_utils/strict_weak_ordering_check.h>
#include <__iterator/iterator_traits.h>
#include <__type_traits/is_copy_assignable.h>
#include <__type_traits/is_copy_constructible.h>
@@ -28,11 +29,13 @@
template <class _AlgPolicy, class _Compare, class _RandomAccessIterator>
inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14
void __sort_heap(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compare&& __comp) {
+ _RandomAccessIterator __saved_last = __last;
__comp_ref_type<_Compare> __comp_ref = __comp;
using difference_type = typename iterator_traits<_RandomAccessIterator>::difference_type;
for (difference_type __n = __last - __first; __n > 1; --__last, (void) --__n)
std::__pop_heap<_AlgPolicy>(__first, __last, __comp_ref, __n);
+ std::__check_strict_weak_ordering_sorted(__first, __saved_last, __comp_ref);
}
template <class _RandomAccessIterator, class _Compare>
diff --git a/include/__algorithm/stable_sort.h b/include/__algorithm/stable_sort.h
index 0c9daa2..38fd2be 100644
--- a/include/__algorithm/stable_sort.h
+++ b/include/__algorithm/stable_sort.h
@@ -15,6 +15,7 @@
#include <__algorithm/iterator_operations.h>
#include <__algorithm/sort.h>
#include <__config>
+#include <__debug_utils/strict_weak_ordering_check.h>
#include <__iterator/iterator_traits.h>
#include <__memory/destruct_n.h>
#include <__memory/temporary_buffer.h>
@@ -259,6 +260,7 @@
}
std::__stable_sort<_AlgPolicy, __comp_ref_type<_Compare> >(__first, __last, __comp, __len, __buf.first, __buf.second);
+ std::__check_strict_weak_ordering_sorted(__first, __last, __comp);
}
template <class _RandomAccessIterator, class _Compare>
diff --git a/include/__debug b/include/__debug
index 19ed474..1a080fd 100644
--- a/include/__debug
+++ b/include/__debug
@@ -23,6 +23,10 @@
# define _LIBCPP_DEBUG_RANDOMIZE_UNSPECIFIED_STABILITY
#endif
+#if defined(_LIBCPP_ENABLE_DEBUG_MODE) && !defined(_LIBCPP_DEBUG_STRICT_WEAK_ORDERING_CHECK)
+# define _LIBCPP_DEBUG_STRICT_WEAK_ORDERING_CHECK
+#endif
+
#if defined(_LIBCPP_ENABLE_DEBUG_MODE) && !defined(_LIBCPP_DEBUG_ITERATOR_BOUNDS_CHECKING)
# define _LIBCPP_DEBUG_ITERATOR_BOUNDS_CHECKING
#endif
diff --git a/include/__debug_utils/strict_weak_ordering_check.h b/include/__debug_utils/strict_weak_ordering_check.h
new file mode 100644
index 0000000..cfdc434
--- /dev/null
+++ b/include/__debug_utils/strict_weak_ordering_check.h
@@ -0,0 +1,76 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___LIBCXX_DEBUG_STRICT_WEAK_ORDERING_CHECK
+#define _LIBCPP___LIBCXX_DEBUG_STRICT_WEAK_ORDERING_CHECK
+
+#include <__config>
+
+#include <__algorithm/comp_ref_type.h>
+#include <__algorithm/is_sorted.h>
+#include <__assert>
+#include <__iterator/iterator_traits.h>
+#include <__type_traits/is_constant_evaluated.h>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+# pragma GCC system_header
+#endif
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+template <class _RandomAccessIterator, class _Comp>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 void
+__check_strict_weak_ordering_sorted(_RandomAccessIterator __first, _RandomAccessIterator __last, _Comp& __comp) {
+#ifdef _LIBCPP_DEBUG_STRICT_WEAK_ORDERING_CHECK
+ using __diff_t = __iter_diff_t<_RandomAccessIterator>;
+ using _Comp_ref = __comp_ref_type<_Comp>;
+ if (!__libcpp_is_constant_evaluated()) {
+ // Check if the range is actually sorted.
+ _LIBCPP_ASSERT((std::is_sorted<_RandomAccessIterator, _Comp_ref>(__first, __last, _Comp_ref(__comp))),
+ "The range is not sorted after the sort, your comparator is not a valid strict-weak ordering");
+ // Limit the number of elements we need to check.
+ __diff_t __size = __last - __first > __diff_t(100) ? __diff_t(100) : __last - __first;
+ __diff_t __p = 0;
+ while (__p < __size) {
+ __diff_t __q = __p + __diff_t(1);
+ // Find first element that is greater than *(__first+__p).
+ while (__q < __size && !__comp(*(__first + __p), *(__first + __q))) {
+ ++__q;
+ }
+ // Check that the elements from __p to __q are equal between each other.
+ for (__diff_t __b = __p; __b < __q; ++__b) {
+ for (__diff_t __a = __p; __a <= __b; ++__a) {
+ _LIBCPP_ASSERT(
+ !__comp(*(__first + __a), *(__first + __b)), "Your comparator is not a valid strict-weak ordering");
+ _LIBCPP_ASSERT(
+ !__comp(*(__first + __b), *(__first + __a)), "Your comparator is not a valid strict-weak ordering");
+ }
+ }
+ // Check that elements between __p and __q are less than between __q and __size.
+ for (__diff_t __a = __p; __a < __q; ++__a) {
+ for (__diff_t __b = __q; __b < __size; ++__b) {
+ _LIBCPP_ASSERT(
+ __comp(*(__first + __a), *(__first + __b)), "Your comparator is not a valid strict-weak ordering");
+ _LIBCPP_ASSERT(
+ !__comp(*(__first + __b), *(__first + __a)), "Your comparator is not a valid strict-weak ordering");
+ }
+ }
+ // Skip these equal elements.
+ __p = __q;
+ }
+ }
+#else
+ (void)__first;
+ (void)__last;
+ (void)__comp;
+#endif
+}
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP___LIBCXX_DEBUG_STRICT_WEAK_ORDERING_CHECK
diff --git a/include/module.modulemap.in b/include/module.modulemap.in
index 46dd028..38216f4 100644
--- a/include/module.modulemap.in
+++ b/include/module.modulemap.in
@@ -1141,7 +1141,8 @@
}
module __debug_utils {
- module randomize_range { private header "__debug_utils/randomize_range.h" }
+ module randomize_range { private header "__debug_utils/randomize_range.h" }
+ module strict_weak_ordering_check { private header "__debug_utils/strict_weak_ordering_check.h" }
}
module limits {
diff --git a/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp b/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
index a175890..06f0854 100644
--- a/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
+++ b/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp
@@ -11,23 +11,26 @@
// REQUIRES: has-unix-headers
// UNSUPPORTED: c++03, c++11, c++14, c++17
// XFAIL: availability-verbose_abort-missing
-// ADDITIONAL_COMPILE_FLAGS: -D_LIBCPP_ENABLE_ASSERTIONS=1
+// ADDITIONAL_COMPILE_FLAGS: -D_LIBCPP_ENABLE_ASSERTIONS=1 -D_LIBCPP_DEBUG_STRICT_WEAK_ORDERING_CHECK
// This test uses a specific combination of an invalid comparator and sequence of values to
-// ensure that our sorting functions do not go out-of-bounds in that case. Instead, we should
-// fail loud with an assertion. The specific issue we're looking for here is when the comparator
-// does not satisfy the following property:
+// ensure that our sorting functions do not go out-of-bounds and satisfy strict weak ordering in that case.
+// Instead, we should fail loud with an assertion. The specific issue we're looking for here is when the comparator
+// does not satisfy the strict weak ordering:
//
-// comp(a, b) implies that !comp(b, a)
-//
-// In other words,
-//
-// a < b implies that !(b < a)
+// Irreflexivity: comp(a, a) is false
+// Antisymmetry: comp(a, b) implies that !comp(b, a)
+// Transitivity: comp(a, b), comp(b, c) imply comp(a, c)
+// Transitivity of equivalence: !comp(a, b), !comp(b, a), !comp(b, c), !comp(c, b) imply !comp(a, c), !comp(c, a)
//
// If this is not satisfied, we have seen issues in the past where the std::sort implementation
-// would proceed to do OOB reads (rdar://106897934).
+// would proceed to do OOB reads. (rdar://106897934).
+// Other algorithms like std::stable_sort, std::sort_heap do not go out of bounds but can produce
+// incorrect results, we also want to assert on that.
+// Sometimes std::sort does not go out of bounds as well, for example, right now if transitivity
+// of equivalence is not met, std::sort can only produce incorrect result but would not fail.
-// When the debug mode is enabled, this test fails because we actually catch that the comparator
+// When the debug mode is enabled, this test fails because we actually catch on the fly that the comparator
// is not a strict-weak ordering before we catch that we'd dereference out-of-bounds inside std::sort,
// which leads to different errors than the ones tested below.
// XFAIL: libcpp-has-debug-mode
@@ -35,9 +38,11 @@
#include <algorithm>
#include <cassert>
#include <cstddef>
+#include <limits>
#include <map>
#include <memory>
#include <ranges>
+#include <random>
#include <set>
#include <string>
#include <vector>
@@ -45,7 +50,7 @@
#include "bad_comparator_values.h"
#include "check_assertion.h"
-int main(int, char**) {
+void check_oob_sort_read() {
std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
for (auto line : std::views::split(DATA, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
auto values = std::views::split(line, ' ');
@@ -90,20 +95,27 @@
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
- std::stable_sort(copy.begin(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
+ TEST_LIBCPP_ASSERT_FAILURE(std::stable_sort(copy.begin(), copy.end(), checked_predicate), "not a valid strict-weak ordering");
}
{
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
- std::partial_sort(copy.begin(), copy.begin(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
+ std::make_heap(copy.begin(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
+ TEST_LIBCPP_ASSERT_FAILURE(std::sort_heap(copy.begin(), copy.end(), checked_predicate), "not a valid strict-weak ordering");
+ }
+ {
+ std::vector<std::size_t*> copy;
+ for (auto const& e : elements)
+ copy.push_back(e.get());
+ TEST_LIBCPP_ASSERT_FAILURE(std::partial_sort(copy.begin(), copy.end(), copy.end(), checked_predicate), "not a valid strict-weak ordering");
}
{
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
std::vector<std::size_t*> results(copy.size(), nullptr);
- std::partial_sort_copy(copy.begin(), copy.end(), results.begin(), results.end(), checked_predicate); // doesn't go OOB even with invalid comparator
+ TEST_LIBCPP_ASSERT_FAILURE(std::partial_sort_copy(copy.begin(), copy.end(), results.begin(), results.end(), checked_predicate), "not a valid strict-weak ordering");
}
{
std::vector<std::size_t*> copy;
@@ -123,20 +135,27 @@
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
- std::ranges::stable_sort(copy, checked_predicate); // doesn't go OOB even with invalid comparator
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::stable_sort(copy, checked_predicate), "not a valid strict-weak ordering");
}
{
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
- std::ranges::partial_sort(copy, copy.begin(), checked_predicate); // doesn't go OOB even with invalid comparator
+ std::ranges::make_heap(copy, checked_predicate); // doesn't go OOB even with invalid comparator
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::sort_heap(copy, checked_predicate), "not a valid strict-weak ordering");
+ }
+ {
+ std::vector<std::size_t*> copy;
+ for (auto const& e : elements)
+ copy.push_back(e.get());
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::partial_sort(copy, copy.end(), checked_predicate), "not a valid strict-weak ordering");
}
{
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
std::vector<std::size_t*> results(copy.size(), nullptr);
- std::ranges::partial_sort_copy(copy, results, checked_predicate); // doesn't go OOB even with invalid comparator
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::partial_sort_copy(copy, results, checked_predicate), "not a valid strict-weak ordering");
}
{
std::vector<std::size_t*> copy;
@@ -144,6 +163,60 @@
copy.push_back(e.get());
std::ranges::nth_element(copy, copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
}
+}
+
+struct FloatContainer {
+ float value;
+ bool operator<(const FloatContainer& other) const {
+ return value < other.value;
+ }
+};
+
+// Nans in floats do not satisfy strict weak ordering by breaking transitivity of equivalence.
+std::vector<FloatContainer> generate_float_data() {
+ std::vector<FloatContainer> floats(50);
+ for (int i = 0; i < 50; ++i) {
+ floats[i].value = static_cast<float>(i);
+ }
+ floats.push_back(FloatContainer{std::numeric_limits<float>::quiet_NaN()});
+ std::shuffle(floats.begin(), floats.end(), std::default_random_engine());
+ return floats;
+}
+
+void check_nan_floats() {
+ auto floats = generate_float_data();
+ TEST_LIBCPP_ASSERT_FAILURE(std::sort(floats.begin(), floats.end()), "not a valid strict-weak ordering");
+ floats = generate_float_data();
+ TEST_LIBCPP_ASSERT_FAILURE(std::stable_sort(floats.begin(), floats.end()), "not a valid strict-weak ordering");
+ floats = generate_float_data();
+ std::make_heap(floats.begin(), floats.end());
+ TEST_LIBCPP_ASSERT_FAILURE(std::sort_heap(floats.begin(), floats.end()), "not a valid strict-weak ordering");
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::sort(generate_float_data(), std::less()), "not a valid strict-weak ordering");
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::stable_sort(generate_float_data(), std::less()), "not a valid strict-weak ordering");
+ floats = generate_float_data();
+ std::ranges::make_heap(floats, std::less());
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::sort_heap(floats, std::less()), "not a valid strict-weak ordering");
+}
+
+void check_irreflexive() {
+ std::vector<int> v(1);
+ TEST_LIBCPP_ASSERT_FAILURE(std::sort(v.begin(), v.end(), std::greater_equal<int>()), "not a valid strict-weak ordering");
+ TEST_LIBCPP_ASSERT_FAILURE(std::stable_sort(v.begin(), v.end(), std::greater_equal<int>()), "not a valid strict-weak ordering");
+ std::make_heap(v.begin(), v.end(), std::greater_equal<int>());
+ TEST_LIBCPP_ASSERT_FAILURE(std::sort_heap(v.begin(), v.end(), std::greater_equal<int>()), "not a valid strict-weak ordering");
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::sort(v, std::greater_equal<int>()), "not a valid strict-weak ordering");
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::stable_sort(v, std::greater_equal<int>()), "not a valid strict-weak ordering");
+ std::ranges::make_heap(v, std::greater_equal<int>());
+ TEST_LIBCPP_ASSERT_FAILURE(std::ranges::sort_heap(v, std::greater_equal<int>()), "not a valid strict-weak ordering");
+}
+
+int main(int, char**) {
+
+ check_oob_sort_read();
+
+ check_nan_floats();
+
+ check_irreflexive();
return 0;
}
diff --git a/test/std/algorithms/alg.sorting/alg.heap.operations/sort.heap/complexity.pass.cpp b/test/std/algorithms/alg.sorting/alg.heap.operations/sort.heap/complexity.pass.cpp
index e3cb233..58d53e0 100644
--- a/test/std/algorithms/alg.sorting/alg.heap.operations/sort.heap/complexity.pass.cpp
+++ b/test/std/algorithms/alg.sorting/alg.heap.operations/sort.heap/complexity.pass.cpp
@@ -58,6 +58,9 @@
const int n = (1 << logn);
auto first = v.begin();
auto last = v.begin() + n;
+ const int debug_elements = std::min(100, n);
+ // Multiplier 2 because of comp(a,b) comp(b, a) checks.
+ const int debug_comparisons = 2 * (debug_elements + 1) * debug_elements;
std::shuffle(first, last, g);
std::make_heap(first, last);
// The exact stats of our current implementation are recorded here.
@@ -69,7 +72,7 @@
LIBCPP_ASSERT(stats.compared <= n * logn);
#endif
LIBCPP_ASSERT(std::is_sorted(first, last));
- LIBCPP_ASSERT(stats.compared <= 2 * n * logn);
+ LIBCPP_ASSERT(stats.compared <= 2 * n * logn + debug_comparisons);
}
return 0;
}
diff --git a/test/std/algorithms/alg.sorting/alg.heap.operations/sort.heap/ranges_sort_heap.pass.cpp b/test/std/algorithms/alg.sorting/alg.heap.operations/sort.heap/ranges_sort_heap.pass.cpp
index 128ff80..ed149e3 100644
--- a/test/std/algorithms/alg.sorting/alg.heap.operations/sort.heap/ranges_sort_heap.pass.cpp
+++ b/test/std/algorithms/alg.sorting/alg.heap.operations/sort.heap/ranges_sort_heap.pass.cpp
@@ -207,7 +207,7 @@
{ // `std::ranges::dangling` is returned.
[[maybe_unused]] std::same_as<std::ranges::dangling> decltype(auto) result =
- std::ranges::sort_heap(std::array{2, 1, 3});
+ std::ranges::sort_heap(std::array{3, 1, 2});
}
return true;
@@ -252,6 +252,9 @@
const int n = (1 << logn);
auto first = v.begin();
auto last = v.begin() + n;
+ const int debug_elements = std::min(100, n);
+ // Multiplier 2 because of comp(a,b) comp(b, a) checks.
+ const int debug_comparisons = 2 * (debug_elements + 1) * debug_elements;
std::shuffle(first, last, g);
std::make_heap(first, last, &MyInt::Comp);
// The exact stats of our current implementation are recorded here.
@@ -263,7 +266,7 @@
LIBCPP_ASSERT(stats.compared <= n * logn);
#endif
LIBCPP_ASSERT(std::is_sorted(first, last, &MyInt::Comp));
- LIBCPP_ASSERT(stats.compared <= 2 * n * logn);
+ LIBCPP_ASSERT(stats.compared <= 2 * n * logn + debug_comparisons);
}
}
diff --git a/test/std/algorithms/ranges_robust_against_dangling.pass.cpp b/test/std/algorithms/ranges_robust_against_dangling.pass.cpp
index c71b57e..1057c74 100644
--- a/test/std/algorithms/ranges_robust_against_dangling.pass.cpp
+++ b/test/std/algorithms/ranges_robust_against_dangling.pass.cpp
@@ -201,6 +201,7 @@
dangling_1st(std::ranges::make_heap, in);
dangling_1st(std::ranges::push_heap, in);
dangling_1st(std::ranges::pop_heap, in);
+ dangling_1st(std::ranges::make_heap, in);
dangling_1st(std::ranges::sort_heap, in);
dangling_1st<prev_permutation_result<dangling>>(std::ranges::prev_permutation, in);
dangling_1st<next_permutation_result<dangling>>(std::ranges::next_permutation, in);
diff --git a/test/std/algorithms/robust_against_proxy_iterators_lifetime_bugs.pass.cpp b/test/std/algorithms/robust_against_proxy_iterators_lifetime_bugs.pass.cpp
index 3a15d25..3a335c4 100644
--- a/test/std/algorithms/robust_against_proxy_iterators_lifetime_bugs.pass.cpp
+++ b/test/std/algorithms/robust_against_proxy_iterators_lifetime_bugs.pass.cpp
@@ -145,7 +145,7 @@
assert(lifetime_cache.contains(this) && lifetime_cache.contains(&rhs));
assert(!rhs.moved_from_);
- v_ = rhs.v_;
+ *v_ = *rhs.v_;
moved_from_ = false;
return *this;
@@ -157,7 +157,7 @@
assert(!rhs.moved_from_);
rhs.moved_from_ = true;
- v_ = rhs.v_;
+ *v_ = *rhs.v_;
moved_from_ = false;
return *this;