blob: ab7529682950df93570e71d28f459b5d58db617e [file]
//===-- Implementation header for rsqrtf16 ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_LIBC_SRC___SUPPORT_MATH_RSQRTF16_H
#define LLVM_LIBC_SRC___SUPPORT_MATH_RSQRTF16_H
#include "include/llvm-libc-macros/float16-macros.h"
#ifdef LIBC_TYPES_HAS_FLOAT16
#include "src/__support/FPUtil/FEnvImpl.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/ManipulationFunctions.h"
#include "src/__support/FPUtil/cast.h"
#include "src/__support/FPUtil/multiply_add.h"
#include "src/__support/FPUtil/sqrt.h"
#include "src/__support/macros/optimization.h"
namespace LIBC_NAMESPACE_DECL {
namespace math {
// Reciprocal square root of a half-precision value, i.e. 1 / sqrt(x).
//
// Special inputs are resolved from the raw bit pattern:
//   - NaN       -> a signaling NaN raises FE_INVALID and yields a quiet NaN;
//                  any other NaN is returned unchanged.
//   - +0 / -0   -> raises FE_DIVBYZERO, sets errno to ERANGE, and returns an
//                  infinity carrying the sign of the input.
//   - +inf      -> returns +0.
//   - x < 0     -> (including -inf) raises FE_INVALID, sets errno to EDOM, and
//                  returns a quiet NaN.
// Every remaining input (positive, finite, non-zero) is evaluated in single
// precision and then converted back to float16.
LIBC_INLINE constexpr float16 rsqrtf16(float16 x) {
  using HalfBits = fputil::FPBits<float16>;

  // Bit pattern of +inf; every finite positive half compares below it.
  constexpr uint16_t POS_INF = HalfBits::inf().uintval();

  HalfBits in_bits(x);
  uint16_t raw = in_bits.uintval();
  uint16_t magnitude = raw & 0x7fff; // raw bits with the sign bit cleared

  // The fast path only accepts strictly positive, finite inputs. Negative
  // values have the top bit set, so as unsigned integers they also land at or
  // above POS_INF.
  const bool positive_finite = raw != 0 && raw < POS_INF;
  if (LIBC_UNLIKELY(!positive_finite)) {
    // NaN of either sign.
    if (magnitude > POS_INF) {
      if (!in_bits.is_signaling_nan())
        return x;
      fputil::raise_except_if_required(FE_INVALID);
      return HalfBits::quiet_nan().get_val();
    }

    // +0 or -0: pole error, the result keeps the sign of the zero.
    if (magnitude == 0) {
      fputil::raise_except_if_required(FE_DIVBYZERO);
      fputil::set_errno_if_required(ERANGE);
      return HalfBits::inf(in_bits.sign()).get_val();
    }

    // With NaN and zero excluded, the only non-negative value left is +inf.
    if (raw == POS_INF)
      return HalfBits::zero().get_val();

    // -inf <= x < 0: domain error.
    fputil::raise_except_if_required(FE_INVALID);
    fputil::set_errno_if_required(EDOM);
    return HalfBits::quiet_nan().get_val();
  }

  // TODO: add integer based implementation when LIBC_TARGET_CPU_HAS_FPU_FLOAT
  // is not defined
  float inv_root = 1.0f / fputil::sqrt<float>(fputil::cast<float>(x));

  // Two mantissa patterns need a targeted post-correction so that the final
  // conversion to half rounds correctly: the float result is scaled by
  // (1 + 2^-21) for 0x011F and by (1 - 2^-21) for 0x0313.
  const uint16_t frac = magnitude & 0x3ff;
  const bool scale_up = frac == 0x011F;
  const bool scale_down = frac == 0x0313;
  if (LIBC_UNLIKELY(scale_up || scale_down)) {
    const float nudge = scale_up ? 0x1.0p-21f : -0x1.0p-21f;
    inv_root = fputil::multiply_add(inv_root, nudge, inv_root);
  }

  return fputil::cast<float16>(inv_root);
}
} // namespace math
} // namespace LIBC_NAMESPACE_DECL
#endif // LIBC_TYPES_HAS_FLOAT16
#endif // LLVM_LIBC_SRC___SUPPORT_MATH_RSQRTF16_H