[libc][NFC] Allow memcmp to be inlined

Similar to D113097 although not strictly necessary for now. It helps
keeping the same structure for all memory functions.

Differential Revision: https://reviews.llvm.org/D113103

GitOrigin-RevId: 4f3511e28fc49faa4d2af743a37ca1007f5be30c
diff --git a/src/string/CMakeLists.txt b/src/string/CMakeLists.txt
index c7cdcaa..4631818 100644
--- a/src/string/CMakeLists.txt
+++ b/src/string/CMakeLists.txt
@@ -322,7 +322,7 @@
 
 function(add_memcmp memcmp_name)
   add_implementation(memcmp ${memcmp_name}
-    SRCS ${LIBC_MEMCMP_SRC}
+    SRCS ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp
     HDRS ${LIBC_SOURCE_DIR}/src/string/memcmp.h
     DEPENDS
       .memory_utils.memory_utils
@@ -334,7 +334,6 @@
 endfunction()
 
 if(${LIBC_TARGET_ARCHITECTURE_IS_X86})
-  set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp)
   add_memcmp(memcmp_x86_64_opt_sse2   COMPILE_OPTIONS -march=k8             REQUIRE SSE2)
   add_memcmp(memcmp_x86_64_opt_sse4   COMPILE_OPTIONS -march=nehalem        REQUIRE SSE4_2)
   add_memcmp(memcmp_x86_64_opt_avx2   COMPILE_OPTIONS -march=haswell        REQUIRE AVX2)
@@ -342,11 +341,9 @@
   add_memcmp(memcmp_opt_host          COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
   add_memcmp(memcmp)
 elseif(${LIBC_TARGET_ARCHITECTURE_IS_AARCH64})
-  set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/aarch64/memcmp.cpp)
-  add_memcmp(memcmp)
   add_memcmp(memcmp_opt_host          COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
+  add_memcmp(memcmp)
 else()
-  set(LIBC_MEMCMP_SRC ${LIBC_SOURCE_DIR}/src/string/memcmp.cpp)
   add_memcmp(memcmp_opt_host          COMPILE_OPTIONS ${LIBC_COMPILE_OPTIONS_NATIVE})
   add_memcmp(memcmp)
 endif()
diff --git a/src/string/aarch64/memcmp.cpp b/src/string/aarch64/memcmp.cpp
deleted file mode 100644
index 7ef3004..0000000
--- a/src/string/aarch64/memcmp.cpp
+++ /dev/null
@@ -1,52 +0,0 @@
-//===-- Implementation of memcmp ------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "src/string/memcmp.h"
-#include "src/__support/common.h"
-#include "src/string/memory_utils/elements.h"
-#include <stddef.h> // size_t
-
-namespace __llvm_libc {
-
-static int memcmp_aarch64(const char *lhs, const char *rhs, size_t count) {
-  // Use aarch64 strategies (_1, _2, _3 ...)
-  using namespace __llvm_libc::aarch64;
-
-  if (count == 0) // [0, 0]
-    return 0;
-  if (count == 1) // [1, 1]
-    return ThreeWayCompare<_1>(lhs, rhs);
-  if (count == 2) // [2, 2]
-    return ThreeWayCompare<_2>(lhs, rhs);
-  if (count == 3) // [3, 3]
-    return ThreeWayCompare<_3>(lhs, rhs);
-  if (count < 8) // [4, 7]
-    return ThreeWayCompare<HeadTail<_4>>(lhs, rhs, count);
-  if (count < 16) // [8, 15]
-    return ThreeWayCompare<HeadTail<_8>>(lhs, rhs, count);
-  if (unlikely(count >= 128)) // [128, ∞]
-    return ThreeWayCompare<Align<_16>::Then<Loop<_32>>>(lhs, rhs, count);
-  if (!Equals<_16>(lhs, rhs)) // [16, 16]
-    return ThreeWayCompare<_16>(lhs, rhs);
-  if (count < 32) // [17, 31]
-    return ThreeWayCompare<Tail<_16>>(lhs, rhs, count);
-  if (!Equals<Skip<16>::Then<_16>>(lhs, rhs)) // [32, 32]
-    return ThreeWayCompare<Skip<16>::Then<_16>>(lhs, rhs);
-  if (count < 64) // [33, 63]
-    return ThreeWayCompare<Tail<_32>>(lhs, rhs, count);
-  // [64, 127]
-  return ThreeWayCompare<Skip<32>::Then<Loop<_16>>>(lhs, rhs, count);
-}
-
-LLVM_LIBC_FUNCTION(int, memcmp,
-                   (const void *lhs, const void *rhs, size_t count)) {
-  return memcmp_aarch64(reinterpret_cast<const char *>(lhs),
-                        reinterpret_cast<const char *>(rhs), count);
-}
-
-} // namespace __llvm_libc
diff --git a/src/string/memcmp.cpp b/src/string/memcmp.cpp
index 0f2dae2..292525e 100644
--- a/src/string/memcmp.cpp
+++ b/src/string/memcmp.cpp
@@ -7,46 +7,16 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/string/memcmp.h"
-#include "src/__support/architectures.h"
-#include "src/__support/common.h"
-#include "src/string/memory_utils/elements.h"
+#include "src/string/memory_utils/memcmp_implementations.h"
 
 #include <stddef.h> // size_t
 
 namespace __llvm_libc {
 
-static int memcmp_impl(const char *lhs, const char *rhs, size_t count) {
-#if defined(LLVM_LIBC_ARCH_X86)
-  using namespace ::__llvm_libc::x86;
-#else
-  using namespace ::__llvm_libc::scalar;
-#endif
-
-  if (count == 0)
-    return 0;
-  if (count == 1)
-    return ThreeWayCompare<_1>(lhs, rhs);
-  if (count == 2)
-    return ThreeWayCompare<_2>(lhs, rhs);
-  if (count == 3)
-    return ThreeWayCompare<_3>(lhs, rhs);
-  if (count <= 8)
-    return ThreeWayCompare<HeadTail<_4>>(lhs, rhs, count);
-  if (count <= 16)
-    return ThreeWayCompare<HeadTail<_8>>(lhs, rhs, count);
-  if (count <= 32)
-    return ThreeWayCompare<HeadTail<_16>>(lhs, rhs, count);
-  if (count <= 64)
-    return ThreeWayCompare<HeadTail<_32>>(lhs, rhs, count);
-  if (count <= 128)
-    return ThreeWayCompare<HeadTail<_64>>(lhs, rhs, count);
-  return ThreeWayCompare<Align<_32>::Then<Loop<_32>>>(lhs, rhs, count);
-}
-
 LLVM_LIBC_FUNCTION(int, memcmp,
                    (const void *lhs, const void *rhs, size_t count)) {
-  return memcmp_impl(static_cast<const char *>(lhs),
-                     static_cast<const char *>(rhs), count);
+  return inline_memcmp(static_cast<const char *>(lhs),
+                       static_cast<const char *>(rhs), count);
 }
 
 } // namespace __llvm_libc
diff --git a/src/string/memory_utils/memcmp_implementations.h b/src/string/memory_utils/memcmp_implementations.h
new file mode 100644
index 0000000..a2934ce
--- /dev/null
+++ b/src/string/memory_utils/memcmp_implementations.h
@@ -0,0 +1,105 @@
+//===-- Implementation of memcmp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_STRING_MEMORY_UTILS_MEMCMP_IMPLEMENTATIONS_H
+#define LLVM_LIBC_SRC_STRING_MEMORY_UTILS_MEMCMP_IMPLEMENTATIONS_H
+
+#include "src/__support/architectures.h"
+#include "src/__support/common.h"
+#include "src/string/memory_utils/elements.h"
+
+#include <stddef.h> // size_t
+
+namespace __llvm_libc {
+
+static inline int inline_memcmp(const char *lhs, const char *rhs,
+                                size_t count) {
+#if defined(LLVM_LIBC_ARCH_X86)
+  /////////////////////////////////////////////////////////////////////////////
+  // LLVM_LIBC_ARCH_X86
+  /////////////////////////////////////////////////////////////////////////////
+  using namespace __llvm_libc::x86;
+  if (count == 0)
+    return 0;
+  if (count == 1)
+    return ThreeWayCompare<_1>(lhs, rhs);
+  if (count == 2)
+    return ThreeWayCompare<_2>(lhs, rhs);
+  if (count == 3)
+    return ThreeWayCompare<_3>(lhs, rhs);
+  if (count <= 8)
+    return ThreeWayCompare<HeadTail<_4>>(lhs, rhs, count);
+  if (count <= 16)
+    return ThreeWayCompare<HeadTail<_8>>(lhs, rhs, count);
+  if (count <= 32)
+    return ThreeWayCompare<HeadTail<_16>>(lhs, rhs, count);
+  if (count <= 64)
+    return ThreeWayCompare<HeadTail<_32>>(lhs, rhs, count);
+  if (count <= 128)
+    return ThreeWayCompare<HeadTail<_64>>(lhs, rhs, count);
+  return ThreeWayCompare<Align<_32>::Then<Loop<_32>>>(lhs, rhs, count);
+#elif defined(LLVM_LIBC_ARCH_AARCH64)
+  /////////////////////////////////////////////////////////////////////////////
+  // LLVM_LIBC_ARCH_AARCH64
+  /////////////////////////////////////////////////////////////////////////////
+  using namespace ::__llvm_libc::aarch64;
+  if (count == 0) // [0, 0]
+    return 0;
+  if (count == 1) // [1, 1]
+    return ThreeWayCompare<_1>(lhs, rhs);
+  if (count == 2) // [2, 2]
+    return ThreeWayCompare<_2>(lhs, rhs);
+  if (count == 3) // [3, 3]
+    return ThreeWayCompare<_3>(lhs, rhs);
+  if (count < 8) // [4, 7]
+    return ThreeWayCompare<HeadTail<_4>>(lhs, rhs, count);
+  if (count < 16) // [8, 15]
+    return ThreeWayCompare<HeadTail<_8>>(lhs, rhs, count);
+  if (unlikely(count >= 128)) // [128, ∞]
+    return ThreeWayCompare<Align<_16>::Then<Loop<_32>>>(lhs, rhs, count);
+  if (!Equals<_16>(lhs, rhs)) // [16, 16]
+    return ThreeWayCompare<_16>(lhs, rhs);
+  if (count < 32) // [17, 31]
+    return ThreeWayCompare<Tail<_16>>(lhs, rhs, count);
+  if (!Equals<Skip<16>::Then<_16>>(lhs, rhs)) // [32, 32]
+    return ThreeWayCompare<Skip<16>::Then<_16>>(lhs, rhs);
+  if (count < 64) // [33, 63]
+    return ThreeWayCompare<Tail<_32>>(lhs, rhs, count);
+  // [64, 127]
+  return ThreeWayCompare<Skip<32>::Then<Loop<_16>>>(lhs, rhs, count);
+#else
+  /////////////////////////////////////////////////////////////////////////////
+  // Default
+  /////////////////////////////////////////////////////////////////////////////
+  using namespace ::__llvm_libc::scalar;
+
+  if (count == 0)
+    return 0;
+  if (count == 1)
+    return ThreeWayCompare<_1>(lhs, rhs);
+  if (count == 2)
+    return ThreeWayCompare<_2>(lhs, rhs);
+  if (count == 3)
+    return ThreeWayCompare<_3>(lhs, rhs);
+  if (count <= 8)
+    return ThreeWayCompare<HeadTail<_4>>(lhs, rhs, count);
+  if (count <= 16)
+    return ThreeWayCompare<HeadTail<_8>>(lhs, rhs, count);
+  if (count <= 32)
+    return ThreeWayCompare<HeadTail<_16>>(lhs, rhs, count);
+  if (count <= 64)
+    return ThreeWayCompare<HeadTail<_32>>(lhs, rhs, count);
+  if (count <= 128)
+    return ThreeWayCompare<HeadTail<_64>>(lhs, rhs, count);
+  return ThreeWayCompare<Align<_32>::Then<Loop<_32>>>(lhs, rhs, count);
+#endif
+}
+
+} // namespace __llvm_libc
+
+#endif // LLVM_LIBC_SRC_STRING_MEMORY_UTILS_MEMCMP_IMPLEMENTATIONS_H