| //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H |
| #define LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H |
| |
| #include "FPBits.h" |
| #include "PlatformDefs.h" |
| |
| #include "src/__support/CPP/TypeTraits.h" |
| |
| namespace __llvm_libc { |
| namespace fputil { |
| |
| namespace internal { |
| |
| template <typename T> |
| static inline void normalize(int &exponent, |
| typename FPBits<T>::UIntType &mantissa); |
| |
| template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) { |
| // Use binary search to shift the leading 1 bit. |
| // With MantissaWidth<float> = 23, it will take |
| // ceil(log2(23)) = 5 steps checking the mantissa bits as followed: |
| // Step 1: 0000 0000 0000 XXXX XXXX XXXX |
| // Step 2: 0000 00XX XXXX XXXX XXXX XXXX |
| // Step 3: 000X XXXX XXXX XXXX XXXX XXXX |
| // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX |
| // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX |
| constexpr int NSTEPS = 5; // = ceil(log2(MantissaWidth)) |
| constexpr uint32_t BOUNDS[NSTEPS] = {1 << 12, 1 << 18, 1 << 21, 1 << 22, |
| 1 << 23}; |
| constexpr int SHIFTS[NSTEPS] = {12, 6, 3, 2, 1}; |
| |
| for (int i = 0; i < NSTEPS; ++i) { |
| if (mantissa < BOUNDS[i]) { |
| exponent -= SHIFTS[i]; |
| mantissa <<= SHIFTS[i]; |
| } |
| } |
| } |
| |
| template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) { |
| // Use binary search to shift the leading 1 bit similar to float. |
| // With MantissaWidth<double> = 52, it will take |
| // ceil(log2(52)) = 6 steps checking the mantissa bits. |
| constexpr int NSTEPS = 6; // = ceil(log2(MantissaWidth)) |
| constexpr uint64_t BOUNDS[NSTEPS] = {1ULL << 26, 1ULL << 39, 1ULL << 46, |
| 1ULL << 49, 1ULL << 51, 1ULL << 52}; |
| constexpr int SHIFTS[NSTEPS] = {27, 14, 7, 4, 2, 1}; |
| |
| for (int i = 0; i < NSTEPS; ++i) { |
| if (mantissa < BOUNDS[i]) { |
| exponent -= SHIFTS[i]; |
| mantissa <<= SHIFTS[i]; |
| } |
| } |
| } |
| |
| #ifdef LONG_DOUBLE_IS_DOUBLE |
| template <> |
| inline void normalize<long double>(int &exponent, uint64_t &mantissa) { |
| normalize<double>(exponent, mantissa); |
| } |
| #elif !defined(SPECIAL_X86_LONG_DOUBLE) |
| template <> |
| inline void normalize<long double>(int &exponent, __uint128_t &mantissa) { |
| // Use binary search to shift the leading 1 bit similar to float. |
| // With MantissaWidth<long double> = 112, it will take |
| // ceil(log2(112)) = 7 steps checking the mantissa bits. |
| constexpr int NSTEPS = 7; // = ceil(log2(MantissaWidth)) |
| constexpr __uint128_t BOUNDS[NSTEPS] = { |
| __uint128_t(1) << 56, __uint128_t(1) << 84, __uint128_t(1) << 98, |
| __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111, |
| __uint128_t(1) << 112}; |
| constexpr int SHIFTS[NSTEPS] = {57, 29, 15, 8, 4, 2, 1}; |
| |
| for (int i = 0; i < NSTEPS; ++i) { |
| if (mantissa < BOUNDS[i]) { |
| exponent -= SHIFTS[i]; |
| mantissa <<= SHIFTS[i]; |
| } |
| } |
| } |
| #endif |
| |
| } // namespace internal |
| |
| // Correctly rounded IEEE 754 SQRT with round to nearest, ties to even. |
| // Shift-and-add algorithm. |
| template <typename T, |
| cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0> |
| static inline T sqrt(T x) { |
| using UIntType = typename FPBits<T>::UIntType; |
| constexpr UIntType ONE = UIntType(1) << MantissaWidth<T>::VALUE; |
| |
| FPBits<T> bits(x); |
| |
| if (bits.is_inf_or_nan()) { |
| if (bits.get_sign() && (bits.get_mantissa() == 0)) { |
| // sqrt(-Inf) = NaN |
| return FPBits<T>::build_nan(ONE >> 1); |
| } else { |
| // sqrt(NaN) = NaN |
| // sqrt(+Inf) = +Inf |
| return x; |
| } |
| } else if (bits.is_zero()) { |
| // sqrt(+0) = +0 |
| // sqrt(-0) = -0 |
| return x; |
| } else if (bits.get_sign()) { |
| // sqrt( negative numbers ) = NaN |
| return FPBits<T>::build_nan(ONE >> 1); |
| } else { |
| int x_exp = bits.get_exponent(); |
| UIntType x_mant = bits.get_mantissa(); |
| |
| // Step 1a: Normalize denormal input and append hidden bit to the mantissa |
| if (bits.get_unbiased_exponent() == 0) { |
| ++x_exp; // let x_exp be the correct exponent of ONE bit. |
| internal::normalize<T>(x_exp, x_mant); |
| } else { |
| x_mant |= ONE; |
| } |
| |
| // Step 1b: Make sure the exponent is even. |
| if (x_exp & 1) { |
| --x_exp; |
| x_mant <<= 1; |
| } |
| |
| // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and |
| // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2. |
| // Notice that the output of sqrt is always in the normal range. |
| // To perform shift-and-add algorithm to find y, let denote: |
| // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be: |
| // r(n) = 2^n ( x_mant - y(n)^2 ). |
| // That leads to the following recurrence formula: |
| // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ] |
| // with the initial conditions: y(0) = 1, and r(0) = x - 1. |
| // So the nth digit y_n of the mantissa of sqrt(x) can be found by: |
| // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1) |
| // 0 otherwise. |
| UIntType y = ONE; |
| UIntType r = x_mant - ONE; |
| |
| for (UIntType current_bit = ONE >> 1; current_bit; current_bit >>= 1) { |
| r <<= 1; |
| UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1) |
| if (r >= tmp) { |
| r -= tmp; |
| y += current_bit; |
| } |
| } |
| |
| // We compute one more iteration in order to round correctly. |
| bool lsb = y & 1; // Least significant bit |
| bool rb = false; // Round bit |
| r <<= 2; |
| UIntType tmp = (y << 2) + 1; |
| if (r >= tmp) { |
| r -= tmp; |
| rb = true; |
| } |
| |
| // Remove hidden bit and append the exponent field. |
| x_exp = ((x_exp >> 1) + FPBits<T>::EXPONENT_BIAS); |
| |
| y = (y - ONE) | (static_cast<UIntType>(x_exp) << MantissaWidth<T>::VALUE); |
| // Round to nearest, ties to even |
| if (rb && (lsb || (r != 0))) { |
| ++y; |
| } |
| |
| return *reinterpret_cast<T *>(&y); |
| } |
| } |
| |
| } // namespace fputil |
| } // namespace __llvm_libc |
| |
| #ifdef SPECIAL_X86_LONG_DOUBLE |
| #include "x86_64/SqrtLongDouble.h" |
| #endif // SPECIAL_X86_LONG_DOUBLE |
| |
| #endif // LLVM_LIBC_SRC_SUPPORT_FPUTIL_SQRT_H |