[libc++] Introduce __is_pointer_in_range

This introduces std::__is_pointer_in_range(begin, end, ptr), which checks whether a pointer lies within the half-open range [begin, end), even during constant evaluation. This allows optimized code paths to run during constant evaluation instead of always falling back to the general-purpose implementation. It also provides a single, central place for comparing unrelated pointers, which is technically UB.

Reviewed By: ldionne, #libc

Spies: libcxx-commits

Differential Revision: https://reviews.llvm.org/D143327

GitOrigin-RevId: 7949ee0d4f10f4fa3b3021fb6269e14175c3b0ae
diff --git a/.clang-format b/.clang-format
index acf987d..4fd10e8 100644
--- a/.clang-format
+++ b/.clang-format
@@ -38,6 +38,7 @@
                   '_LIBCPP_HIDE_FROM_ABI_AFTER_V1',
                   '_LIBCPP_INLINE_VISIBILITY',
                   '_LIBCPP_NOALIAS',
+                  '_LIBCPP_NO_SANITIZE',
                   '_LIBCPP_USING_IF_EXISTS',
                   '_LIBCPP_DEPRECATED',
                   '_LIBCPP_DEPRECATED_IN_CXX11',
diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt
index f676805..4b5aef4 100644
--- a/include/CMakeLists.txt
+++ b/include/CMakeLists.txt
@@ -831,6 +831,7 @@
   __utility/forward_like.h
   __utility/in_place.h
   __utility/integer_sequence.h
+  __utility/is_pointer_in_range.h
   __utility/move.h
   __utility/pair.h
   __utility/piecewise_construct.h
diff --git a/include/__config b/include/__config
index 1fc45b5..17229e2 100644
--- a/include/__config
+++ b/include/__config
@@ -35,6 +35,8 @@
 
 #ifdef __cplusplus
 
+// The attributes supported by clang are documented at https://clang.llvm.org/docs/AttributeReference.html
+
 // _LIBCPP_VERSION represents the version of libc++, which matches the version of LLVM.
 // Given a LLVM release LLVM XX.YY.ZZ (e.g. LLVM 17.0.1 == 17.00.01), _LIBCPP_VERSION is
 // defined to XXYYZZ.
@@ -1099,6 +1101,12 @@
 #    define _LIBCPP_PREFERRED_NAME(x)
 #  endif
 
+#  if __has_attribute(__no_sanitize__) // forward the arguments to the __no_sanitize__ attribute when it exists
+#    define _LIBCPP_NO_SANITIZE(...) __attribute__((__no_sanitize__(__VA_ARGS__)))
+#  else // attribute not available: the macro expands to nothing
+#    define _LIBCPP_NO_SANITIZE(...)
+#  endif
+
 // We often repeat things just for handling wide characters in the library.
 // When wide characters are disabled, it can be useful to have a quick way of
 // disabling it without having to resort to #if-#endif, which has a larger
diff --git a/include/__utility/is_pointer_in_range.h b/include/__utility/is_pointer_in_range.h
new file mode 100644
index 0000000..84e833d
--- /dev/null
+++ b/include/__utility/is_pointer_in_range.h
@@ -0,0 +1,52 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___UTILITY_IS_POINTER_IN_RANGE_H
+#define _LIBCPP___UTILITY_IS_POINTER_IN_RANGE_H
+
+#include <__algorithm/comp.h>
+#include <__assert>
+#include <__config>
+#include <__type_traits/enable_if.h>
+#include <__type_traits/integral_constant.h>
+#include <__type_traits/is_constant_evaluated.h>
+#include <__type_traits/is_function.h>
+#include <__type_traits/is_member_pointer.h>
+#include <__type_traits/void_t.h>
+#include <__utility/declval.h>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+template <class _Tp, class _Up>
+_LIBCPP_CONSTEXPR_SINCE_CXX14 _LIBCPP_HIDE_FROM_ABI _LIBCPP_NO_SANITIZE("address") bool __is_pointer_in_range(
+    const _Tp* __begin, const _Tp* __end, const _Up* __ptr) { // returns whether __ptr lies in [__begin, __end)
+  static_assert(!is_function<_Tp>::value && !is_function<_Up>::value,
+                "__is_pointer_in_range should not be called with function pointers");
+  static_assert(!is_member_pointer<_Tp>::value && !is_member_pointer<_Up>::value,
+                "__is_pointer_in_range should not be called with member pointers");
+
+  if (__libcpp_is_constant_evaluated()) {
+    _LIBCPP_ASSERT(__builtin_constant_p(__begin <= __end), "__begin and __end do not form a range");
+
+    // If the comparison below is not a constant during constant evaluation, __ptr is not part of the allocation
+    // that contains [__begin, __end), so it cannot be in that range: answer without comparing the pointers.
+    if (!__builtin_constant_p(__begin <= __ptr && __ptr < __end))
+      return false;
+  }
+
+  // Checking this for unrelated pointers is technically UB, but no compiler optimizes based on it (currently).
+  return !__less<>()(__ptr, __begin) && __less<>()(__ptr, __end); // i.e. __ptr is in [__begin, __end)
+}
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP___UTILITY_IS_POINTER_IN_RANGE_H
diff --git a/include/module.modulemap.in b/include/module.modulemap.in
index d9948a7..ba47450 100644
--- a/include/module.modulemap.in
+++ b/include/module.modulemap.in
@@ -1755,6 +1755,7 @@
       module forward_like           { private header "__utility/forward_like.h" }
       module in_place               { private header "__utility/in_place.h" }
       module integer_sequence       { private header "__utility/integer_sequence.h" }
+      module is_pointer_in_range    { private header "__utility/is_pointer_in_range.h" }
       module move                   { private header "__utility/move.h" }
       module pair                   { private header "__utility/pair.h" }
       module pair_fwd               { private header "__fwd/pair.h" }
diff --git a/include/string b/include/string
index 4f4a119..d116bac 100644
--- a/include/string
+++ b/include/string
@@ -571,7 +571,10 @@
 #include <__type_traits/is_trivial.h>
 #include <__type_traits/noexcept_move_assign_container.h>
 #include <__type_traits/remove_cvref.h>
+#include <__type_traits/void_t.h>
 #include <__utility/auto_cast.h>
+#include <__utility/declval.h>
+#include <__utility/is_pointer_in_range.h>
 #include <__utility/move.h>
 #include <__utility/swap.h>
 #include <__utility/unreachable.h>
@@ -660,6 +663,13 @@
 
 struct __uninitialized_size_tag {};
 
+template <class _Tp, class _Up, class = void>
+struct __is_less_than_comparable : false_type {}; // fallback: `_Tp < _Up` is not a valid expression
+
+template <class _Tp, class _Up>
+struct __is_less_than_comparable<_Tp, _Up, __void_t<decltype(std::declval<_Tp>() < std::declval<_Up>())> > : true_type {
+}; // selected when `std::declval<_Tp>() < std::declval<_Up>()` is a valid expression
+
 template<class _CharT, class _Traits, class _Allocator>
 class basic_string
 {
@@ -1942,14 +1952,23 @@
 
     _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 void __invalidate_iterators_past(size_type);
 
-    template<class _Tp>
-    _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20
-    bool __addr_in_range(_Tp&& __t) const {
-        // assume that the ranges overlap, because we can't check during constant evaluation
-        if (__libcpp_is_constant_evaluated())
-          return true;
-        const volatile void *__p = std::addressof(__t);
-        return data() <= __p && __p <= data() + size();
+    template <
+        class _Tp, // overload for types whose address can be compared with a const value_type* using <
+        __enable_if_t<__is_less_than_comparable<const __remove_cvref_t<_Tp>*, const value_type*>::value, int> = 0>
+    _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __addr_in_range(const _Tp& __v) const {
+      return std::__is_pointer_in_range(data(), data() + size() + 1, std::addressof(__v)); // incl. data() + size()
+    }
+
+    template <
+        class _Tp, // overload for types whose address cannot be compared with a const value_type* using <
+        __enable_if_t<!__is_less_than_comparable<const __remove_cvref_t<_Tp>*, const value_type*>::value, int> = 0>
+    _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __addr_in_range(const _Tp& __v) const {
+      if (__libcpp_is_constant_evaluated()) // the reinterpret_casts below are not usable in constant expressions
+        return false;
+
+      auto __t_ptr = reinterpret_cast<const char*>(std::addressof(__v)); // compare the addresses as bytes instead
+      return reinterpret_cast<const char*>(data()) <= __t_ptr &&
+             __t_ptr < reinterpret_cast<const char*>(data() + size() + 1);
     }
 
     _LIBCPP_NORETURN _LIBCPP_HIDE_FROM_ABI
@@ -2886,10 +2905,6 @@
     size_type __cap = capacity();
     if (__cap - __sz + __n1 >= __n2)
     {
-        if (__libcpp_is_constant_evaluated()) {
-            __grow_by_and_replace(__cap, 0, __sz, __pos, __n1, __n2, __s);
-            return *this;
-        }
         value_type* __p = std::__to_address(__get_pointer());
         if (__n1 != __n2)
         {
@@ -2902,7 +2917,7 @@
                     traits_type::move(__p + __pos + __n2, __p + __pos + __n1, __n_move);
                     return __null_terminate_at(__p, __sz + (__n2 - __n1));
                 }
-                if (__p + __pos < __s && __s < __p + __sz)
+                if (std::__is_pointer_in_range(__p + __pos + 1, __p + __sz, __s))
                 {
                     if (__p + __pos + __n1 <= __s)
                         __s += __n2 - __n1;
diff --git a/test/libcxx/utilities/is_pointer_in_range.pass.cpp b/test/libcxx/utilities/is_pointer_in_range.pass.cpp
new file mode 100644
index 0000000..a1b54ff
--- /dev/null
+++ b/test/libcxx/utilities/is_pointer_in_range.pass.cpp
@@ -0,0 +1,58 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// ADDITIONAL_COMPILE_FLAGS: -Wno-private-header
+
+#include <__utility/is_pointer_in_range.h>
+#include <cassert>
+
+#include "test_macros.h"
+
+template <class T, class U>
+TEST_CONSTEXPR_CXX14 void test_cv_quals() {
+  T i = 0;
+  U j = 0;
+  assert(!std::__is_pointer_in_range(&i, &i, &i)); // an empty range contains nothing
+  assert(std::__is_pointer_in_range(&i, &i + 1, &i)); // a one-element range contains its element
+  assert(!std::__is_pointer_in_range(&i, &i + 1, &j)); // a different object is not in the range
+
+#if TEST_STD_VER >= 20 // dynamic allocation in constant expressions needs C++20
+  {
+    T* arr1 = new int[4]{1, 2, 3, 4};
+    U* arr2 = new int[4]{5, 6, 7, 8};
+
+    assert(!std::__is_pointer_in_range(arr1, arr1 + 4, arr2)); // pointer into a separate allocation
+    assert(std::__is_pointer_in_range(arr1, arr1 + 4, arr1 + 3)); // last element of the range
+    assert(!std::__is_pointer_in_range(arr1, arr1, arr1 + 3)); // empty range within the same allocation
+
+    delete[] arr1;
+    delete[] arr2;
+  }
+#endif
+}
+
+TEST_CONSTEXPR_CXX14 bool test() { // exercises several cv-qualifier combinations for T and U
+  test_cv_quals<int, int>();
+  test_cv_quals<const int, int>();
+  test_cv_quals<int, const int>();
+  test_cv_quals<const int, const int>();
+  test_cv_quals<volatile int, int>();
+  test_cv_quals<int, volatile int>();
+  test_cv_quals<volatile int, volatile int>();
+
+  return true; // lets the caller use test() inside a static_assert
+}
+
+int main(int, char**) {
+  test(); // run-time evaluation
+#if TEST_STD_VER >= 14
+  static_assert(test(), ""); // constant evaluation of the same checks
+#endif
+
+  return 0;
+}