[libc] refactor atof string parsing

Split the code for parsing hexadecimal floating point numbers from the
code for parsing the decimal floating point numbers so that the parsing
can be faster for both of them.

This decreases the time for the benchmark in release mode by about 18%
(2.299s to 1.893s), which noticeably beats glibc.

Old version: 2.299s
New version: 1.893s
GLibc: 2.133s

Timings were gathered by running the following command 10 times for each version:
time ~/llvm-project/build/bin/libc_str_to_float_comparison_test ~/parse-number-fxx-test-data/data/*

The parse-number-fxx-test-data repository is here:
https://github.com/nigeltao/parse-number-fxx-test-data/tree/fe94de252c691900982050c8e7c503d1efd1299a

It's important to build llvm-libc in Release mode for accurate
performance comparisons against glibc (set -DCMAKE_BUILD_TYPE=Release in
your cmake).
You also have to build the libc_str_to_float_comparison_test target.

Reviewed By: lntue

Differential Revision: https://reviews.llvm.org/D113036

GitOrigin-RevId: 8298424cae9b4d3d41dbe17857dc9cb247d90786
diff --git a/src/__support/str_to_float.h b/src/__support/str_to_float.h
index 59bd1ec..90408e7 100644
--- a/src/__support/str_to_float.h
+++ b/src/__support/str_to_float.h
@@ -40,7 +40,7 @@
   }
 }
 
-template <class T> uint32_t leadingZeroes(T inputNumber) {
+template <class T> uint32_t inline leadingZeroes(T inputNumber) {
   // TODO(michaelrj): investigate the portability of using something like
   // __builtin_clz for specific types.
   constexpr uint32_t bitsInT = sizeof(T) * 8;
@@ -71,6 +71,14 @@
   return bitsInT - curGuess;
 }
 
+template <> uint32_t inline leadingZeroes<uint32_t>(uint32_t inputNumber) {
+  return inputNumber != 0 ? __builtin_clz(inputNumber) : 32; // clz(0) undefined
+}
+
+template <> uint32_t inline leadingZeroes<uint64_t>(uint64_t inputNumber) {
+  return inputNumber != 0 ? __builtin_clzll(inputNumber) : 64; // clz(0) undefined
+}
+
 static inline uint64_t low64(__uint128_t num) {
   return static_cast<uint64_t>(num & 0xffffffffffffffff);
 }
@@ -442,6 +450,81 @@
   return;
 }
 
+// Converts the value mantissa * 2^exp2 into the closest value of floating
+// point type T, storing the resulting mantissa and biased exponent through the
+// output pointers. The exponent is already base 2, so this is mostly shifting
+// and rounding; hex floats use it since a base 16 exponent times 4 is base 2.
+template <class T>
+static inline void
+binaryExpToFloat(typename fputil::FPBits<T>::UIntType mantissa, int32_t exp2,
+                 typename fputil::FPBits<T>::UIntType *outputMantissa,
+                 uint32_t *outputExp2) {
+  using BitsType = typename fputil::FPBits<T>::UIntType;
+
+  // This is the number of leading zeroes a properly normalized float of type T
+  // should have.
+  constexpr uint32_t NORMALIZED_LEADING_ZEROES =
+      (sizeof(BitsType) * 8) - fputil::FloatProperties<T>::mantissaWidth - 1;
+  constexpr BitsType OVERFLOWED_MANTISSA =
+      BitsType(1) << (fputil::FloatProperties<T>::mantissaWidth + 1);
+
+  // Normalization: move the msb to bit mantissaWidth, adjusting exp2 to match.
+  int32_t amountToShift =
+      NORMALIZED_LEADING_ZEROES - leadingZeroes<BitsType>(mantissa);
+  if (amountToShift < 0) {
+    mantissa <<= -amountToShift;
+  } else {
+    mantissa = shiftRightAndRound(mantissa, amountToShift);
+    if (mantissa == OVERFLOWED_MANTISSA) { // rounding carried into a new bit
+      mantissa >>= 1;
+      exp2 += 1;
+    }
+  }
+  exp2 += amountToShift;
+
+  // Account for the fact that the mantissa represented an integer
+  // previously, but now represents the fractional part of a normalized
+  // number.
+  exp2 += fputil::FloatProperties<T>::mantissaWidth;
+
+  int32_t biasedExponent = exp2 + fputil::FPBits<T>::exponentBias;
+  // handle subnormals
+  if (biasedExponent <= 0) {
+    // The mantissa is currently normalized, meaning that the msb is one bit
+    // left of where the decimal point should go, so at least one bit is
+    // shifted out. Every further shift raises biasedExponent towards 0.
+    amountToShift = 1;
+    BitsType mantissaCopy = mantissa >> 1;
+    while (biasedExponent < 0 && mantissaCopy > 0) {
+      mantissaCopy = mantissaCopy >> 1;
+      ++amountToShift;
+      ++biasedExponent;
+    }
+    // If we cut off any set bits to fit this number into a subnormal, it's out
+    // of range. The mask must be a BitsType: a plain int shifts out of range.
+    if ((mantissa & ((BitsType(1) << amountToShift) - 1)) > 0) {
+      errno = ERANGE; // NOLINT
+    }
+    mantissa = shiftRightAndRound(mantissa, amountToShift);
+    if (mantissa == OVERFLOWED_MANTISSA) {
+      mantissa >>= 1;
+      exp2 += 1; // NOTE(review): exp2 is never read after this; confirm intent
+    } else if (mantissa == 0) {
+      biasedExponent = 0;
+    }
+  }
+  // handle numbers that're too large and get squashed to inf
+  else if (biasedExponent >
+           (1 << fputil::FloatProperties<T>::exponentWidth) - 1) {
+    // This indicates an overflow, so we make the result INF and set errno.
+    biasedExponent = (1 << fputil::FloatProperties<T>::exponentWidth) - 1;
+    mantissa = 0;
+    errno = ERANGE; // NOLINT
+  }
+  *outputMantissa = mantissa;
+  *outputExp2 = biasedExponent;
+}
+
 // checks if the next 4 characters of the string pointer are the start of a
 // hexadecimal floating point number. Does not advance the string pointer.
 static inline bool is_float_hex_start(const char *__restrict src,
@@ -456,6 +539,218 @@
   }
 }
 
+// Takes the start of a string representing a decimal float, as well as the
+// decimal point character to accept. Returns whether it succeeded in parsing
+// any digits. If it returns true, *strEnd points just past the number and the
+// outputs hold the mantissa and exponent of the closest float T, as computed
+// by decimalExpToFloat (or both 0 when every parsed digit was 0). If it
+// returns false, no outputs are written and there is no number here.
+template <class T>
+static inline bool
+decimalStringToFloat(const char *__restrict src, const char DECIMAL_POINT,
+                     char **__restrict strEnd,
+                     typename fputil::FPBits<T>::UIntType *outputMantissa,
+                     uint32_t *outputExponent) {
+  using BitsType = typename fputil::FPBits<T>::UIntType;
+  constexpr uint32_t BASE = 10;
+  constexpr char EXPONENT_MARKER = 'e';
+
+  const char *__restrict numStart = src;
+  bool truncated = false;
+  bool seenDigit = false;
+  bool afterDecimal = false;
+  BitsType mantissa = 0;
+  int32_t exponent = 0;
+
+  // The goal for the first step of parsing is to convert the number in src to
+  // the format mantissa * (base ^ exponent)
+
+  // The first loop fills the mantissa with as many digits as it can hold.
+  const BitsType BITSTYPE_MAX_DIV_BY_BASE =
+      __llvm_libc::cpp::NumericLimits<BitsType>::max() / BASE;
+  while ((isdigit(*src) || *src == DECIMAL_POINT) &&
+         mantissa < BITSTYPE_MAX_DIV_BY_BASE) {
+    if (*src == DECIMAL_POINT) {
+      if (afterDecimal) {
+        break; // this means that *src points to a second decimal point, ending
+               // the number.
+      } else {
+        afterDecimal = true;
+        ++src;
+        continue;
+      }
+    }
+    uint32_t digit = *src - '0';
+
+    mantissa = (mantissa * BASE) + digit;
+    seenDigit = true;
+    if (afterDecimal) {
+      --exponent; // a digit after the decimal point lowers the exponent
+    }
+
+    ++src;
+  }
+
+  if (!seenDigit)
+    return false;
+
+  // The second loop skips the digits that did not fit in the mantissa, noting
+  // whether any of them were nonzero.
+  while (isdigit(*src) || *src == DECIMAL_POINT) {
+    if (*src == DECIMAL_POINT) {
+      if (afterDecimal) {
+        break; // this means that *src points to a second decimal point, ending
+               // the number.
+      } else {
+        afterDecimal = true;
+        ++src;
+        continue;
+      }
+    }
+    uint32_t digit = *src - '0';
+
+    if (digit > 0)
+      truncated = true;
+
+    if (!afterDecimal)
+      ++exponent; // a dropped integer digit still scales the value by BASE
+
+    ++src;
+  }
+  // Only consume the exponent marker when a sign or a digit follows it.
+  if ((*src | 32) == EXPONENT_MARKER) { // | 32 lowercases an ASCII letter
+    if (*(src + 1) == '+' || *(src + 1) == '-' || isdigit(*(src + 1))) {
+      ++src;
+      char *tempStrEnd;
+      int32_t add_to_exponent = strtointeger<int32_t>(src, &tempStrEnd, 10);
+      if (add_to_exponent > 100000) // clamp to [-100000, 100000]
+        add_to_exponent = 100000;
+      else if (add_to_exponent < -100000)
+        add_to_exponent = -100000;
+
+      src = tempStrEnd;
+      exponent += add_to_exponent;
+    }
+  }
+
+  *strEnd = const_cast<char *>(src);
+  if (mantissa == 0) { // if we have a 0, then also 0 the exponent.
+    *outputMantissa = 0;
+    *outputExponent = 0;
+  } else {
+    decimalExpToFloat<T>(mantissa, exponent, numStart, truncated,
+                         outputMantissa, outputExponent);
+  }
+  return true;
+}
+
+// Takes the start of the digits of a hexadecimal float (presumably just past
+// the "0x" prefix, since 'x' is not a valid digit here) and the decimal point
+// character to accept. Returns whether it succeeded in parsing any digits. If
+// it returns true, *strEnd points just past the number and the outputs hold
+// the mantissa and exponent of the closest float T (both 0 when every parsed
+// digit was 0). If it returns false, no outputs are written.
+template <class T>
+static inline bool
+hexadecimalStringToFloat(const char *__restrict src, const char DECIMAL_POINT,
+                         char **__restrict strEnd,
+                         typename fputil::FPBits<T>::UIntType *outputMantissa,
+                         uint32_t *outputExponent) {
+  using BitsType = typename fputil::FPBits<T>::UIntType;
+  constexpr uint32_t BASE = 16;
+  constexpr char EXPONENT_MARKER = 'p';
+  // NOTE(review): truncated is set below but never read in this function.
+  bool truncated = false;
+  bool seenDigit = false;
+  bool afterDecimal = false;
+  BitsType mantissa = 0;
+  int32_t exponent = 0;
+
+  // The goal for the first step of parsing is to convert the number in src to
+  // the format mantissa * (base ^ exponent)
+
+  // The first loop fills the mantissa with as many digits as it can hold.
+  const BitsType BITSTYPE_MAX_DIV_BY_BASE =
+      __llvm_libc::cpp::NumericLimits<BitsType>::max() / BASE;
+  while ((isalnum(*src) || *src == DECIMAL_POINT) &&
+         mantissa < BITSTYPE_MAX_DIV_BY_BASE) {
+    if (*src == DECIMAL_POINT) {
+      if (afterDecimal) {
+        break; // this means that *src points to a second decimal point, ending
+               // the number.
+      } else {
+        afterDecimal = true;
+        ++src;
+        continue;
+      }
+    }
+    uint32_t digit = b36_char_to_int(*src);
+    if (digit >= BASE)
+      break; // not a hexadecimal digit, so the number ends here
+
+    mantissa = (mantissa * BASE) + digit;
+    seenDigit = true;
+    if (afterDecimal)
+      --exponent; // a digit after the decimal point lowers the exponent
+
+    ++src;
+  }
+
+  if (!seenDigit)
+    return false;
+
+  // The second loop skips the digits that did not fit in the mantissa, noting
+  // whether any of them were nonzero.
+  while (isalnum(*src) || *src == DECIMAL_POINT) {
+    if (*src == DECIMAL_POINT) {
+      if (afterDecimal) {
+        break; // this means that *src points to a second decimal point, ending
+               // the number.
+      } else {
+        afterDecimal = true;
+        ++src;
+        continue;
+      }
+    }
+    uint32_t digit = b36_char_to_int(*src);
+    if (digit >= BASE)
+      break; // not a hexadecimal digit, so the number ends here
+
+    if (digit > 0)
+      truncated = true;
+
+    if (!afterDecimal)
+      ++exponent; // a dropped integer digit still scales the value by BASE
+
+    ++src;
+  }
+
+  // Convert the digit-count exponent from base 16 to base 2 (4 bits a digit).
+  exponent *= 4;
+  // The marker's exponent is added below, after the scaling by 4 above.
+  if ((*src | 32) == EXPONENT_MARKER) { // | 32 lowercases an ASCII letter
+    if (*(src + 1) == '+' || *(src + 1) == '-' || isdigit(*(src + 1))) {
+      ++src;
+      char *tempStrEnd;
+      int32_t add_to_exponent = strtointeger<int32_t>(src, &tempStrEnd, 10);
+      if (add_to_exponent > 100000) // clamp to [-100000, 100000]
+        add_to_exponent = 100000;
+      else if (add_to_exponent < -100000)
+        add_to_exponent = -100000;
+      src = tempStrEnd;
+      exponent += add_to_exponent;
+    }
+  }
+  *strEnd = const_cast<char *>(src);
+  if (mantissa == 0) { // if we have a 0, then also 0 the exponent.
+    *outputMantissa = 0;
+    *outputExponent = 0;
+  } else {
+    binaryExpToFloat<T>(mantissa, exponent, outputMantissa, outputExponent);
+  }
+  return true;
+}
+
 // Takes a pointer to a string and a pointer to a string pointer. This function
 // is used as the backend for all of the string to float functions.
 template <class T>
@@ -478,7 +773,7 @@
   static const char *INF_STRING = "infinity";
   static const char *NAN_STRING = "nan";
 
-  bool truncated = false;
+  // bool truncated = false;
 
   if (isdigit(*src) || *src == DECIMAL_POINT) { // regular number
     int base = 10;
@@ -489,157 +784,23 @@
       exponentMarker = 'p';
       seenDigit = true;
     }
-    const char *__restrict numStart = src;
-    bool afterDecimal = false;
+    char *newStrEnd = nullptr;
 
-    BitsType mantissa = 0;
-    int32_t exponent = 0;
-
-    // The goal for the first step of parsing is to convert the number in src to
-    // the format mantissa * (base ^ exponent)
-
-    constexpr BitsType MANTISSA_MAX =
-        BitsType(1) << (fputil::FloatProperties<T>::mantissaWidth +
-                        1); // The extra bit is to give space for the implicit 1
-    const BitsType BITSTYPE_MAX_DIV_BY_BASE =
-        __llvm_libc::cpp::NumericLimits<BitsType>::max() / base;
-    while ((isalnum(*src) || *src == DECIMAL_POINT) &&
-           mantissa < BITSTYPE_MAX_DIV_BY_BASE) {
-      if (*src == DECIMAL_POINT && afterDecimal) {
-        break; // this means that *src points to a second decimal point, ending
-               // the number.
-      } else if (*src == DECIMAL_POINT) {
-        afterDecimal = true;
-        ++src;
-        continue;
-      }
-      int digit = b36_char_to_int(*src);
-      if (digit >= base) {
-        break;
-      }
-
-      mantissa = (mantissa * base) + digit;
-      seenDigit = true;
-      if (afterDecimal) {
-        --exponent;
-      }
-
-      ++src;
-    }
-
-    // The second loop is to run through the remaining digits after we've filled
-    // the mantissa.
-    while (isalnum(*src) || *src == DECIMAL_POINT) {
-      if (*src == DECIMAL_POINT && afterDecimal) {
-        break; // this means that *src points to a second decimal point, ending
-               // the number.
-      } else if (*src == DECIMAL_POINT) {
-        afterDecimal = true;
-        ++src;
-        continue;
-      }
-      int digit = b36_char_to_int(*src);
-      if (digit >= base) {
-        break;
-      }
-
-      if (digit > 0) {
-        truncated = true;
-      }
-
-      if (!afterDecimal) {
-        exponent++;
-      }
-
-      ++src;
-    }
-
-    // if our base is 16 then convert the exponent to base 2
+    BitsType outputMantissa = 0;
+    uint32_t outputExponent = 0;
     if (base == 16) {
-      exponent *= 4;
-    }
-
-    if ((*src | 32) == exponentMarker) {
-      if (*(src + 1) == '+' || *(src + 1) == '-' || isdigit(*(src + 1))) {
-        ++src;
-        char *tempStrEnd;
-        int32_t add_to_exponent = strtointeger<int32_t>(src, &tempStrEnd, 10);
-        if (add_to_exponent > 100000) {
-          add_to_exponent = 100000;
-        } else if (add_to_exponent < -100000) {
-          add_to_exponent = -100000;
-        }
-        src += tempStrEnd - src;
-        exponent += add_to_exponent;
-      }
-    }
-
-    if (mantissa == 0) { // if we have a 0, then also 0 the exponent.
-      exponent = 0;
-    } else if (base == 16) {
-
-      // These two loops should normalize the number if we assume the decimal
-      // point is after the bit at mantissaWidth.
-      // For example if type T is a 32 bit float, this should result in a
-      // mantissa with its most significant 1 being at bit 23.
-      while (mantissa < (MANTISSA_MAX >> 1)) {
-        mantissa = mantissa << 1;
-        --exponent;
-      }
-      BitsType mantissaCopy = mantissa;
-      unsigned int amountToShift = 0;
-      while (mantissaCopy > MANTISSA_MAX) {
-        mantissaCopy = mantissaCopy >> 1;
-        ++amountToShift;
-      }
-      exponent += amountToShift;
-      mantissa = shiftRightAndRound(mantissa, amountToShift);
-
-      // Account for the fact that the mantissa represented an integer
-      // previously, but now represents the fractional part of a normalized
-      // number.
-      exponent += fputil::FloatProperties<T>::mantissaWidth;
-
-      int32_t biasedExponent = exponent + fputil::FPBits<T>::exponentBias;
-      if (biasedExponent <= 0) {
-        // handle subnormals here
-
-        // the most mantissa is currently normalized, meaning that the msb is
-        // one bit left of where the decimal point should go.
-        amountToShift = 1;
-        mantissaCopy = mantissa >> 1;
-        while (biasedExponent < 0 && mantissaCopy > 0) {
-          mantissaCopy = mantissaCopy >> 1;
-          ++amountToShift;
-          ++biasedExponent;
-        }
-        // If we cut off any bits to fit this number into a subnormal, then it's
-        // out of range for this size of float.
-        if ((mantissa & ((1 << amountToShift) - 1)) > 0) {
-          errno = ERANGE; // NOLINT
-        }
-        mantissa = shiftRightAndRound(mantissa, amountToShift);
-        if (mantissa == 0) {
-          biasedExponent = 0;
-        }
-      } else if (biasedExponent > result.maxExponent) {
-        // This indicates an overflow, so we make the result INF and set errno.
-        biasedExponent = result.maxExponent;
-        mantissa = 0;
-        errno = ERANGE; // NOLINT
-      }
-
-      result.setUnbiasedExponent(biasedExponent);
-      result.setMantissa(mantissa);
+      seenDigit = hexadecimalStringToFloat<T>(src, DECIMAL_POINT, &newStrEnd,
+                                              &outputMantissa, &outputExponent);
     } else { // base is 10
-      BitsType outputMantissa = 0;
-      uint32_t outputExponent = 0;
-      decimalExpToFloat<T>(mantissa, exponent, numStart, truncated,
-                           &outputMantissa, &outputExponent);
+      seenDigit = decimalStringToFloat<T>(src, DECIMAL_POINT, &newStrEnd,
+                                          &outputMantissa, &outputExponent);
+    }
+
+    if (seenDigit) {
+      src += newStrEnd - src;
       result.setMantissa(outputMantissa);
       result.setUnbiasedExponent(outputExponent);
     }
-
   } else if ((*src | 32) == 'n') { // NaN
     if ((src[1] | 32) == NAN_STRING[1] && (src[2] | 32) == NAN_STRING[2]) {
       seenDigit = true;
diff --git a/test/src/stdlib/strtof_test.cpp b/test/src/stdlib/strtof_test.cpp
index f20cc0d..2109e7d 100644
--- a/test/src/stdlib/strtof_test.cpp
+++ b/test/src/stdlib/strtof_test.cpp
@@ -132,6 +132,10 @@
   runTest("0x123456700", 11, 0x4f91a2b4);
 }
 
+TEST_F(LlvmLibcStrToFTest, HexadecimalsWithRoundingProblems) {
+  runTest("0xFFFFFFFF", 10, 0x4f800000); // 0x4f800000 is the bit pattern of 2^32
+}
+
 TEST_F(LlvmLibcStrToFTest, HexadecimalOutOfRangeTests) {
   runTest("0x123456789123456789123456789123456789", 38, 0x7f800000, ERANGE);
   runTest("-0x123456789123456789123456789123456789", 39, 0xff800000, ERANGE);