| //===-- Utils which wrap MPFR ---------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "MPFRUtils.h" |
| |
| #include "utils/FPUtil/FPBits.h" |
| |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringRef.h" |
| |
| #include <mpfr.h> |
| #include <stdint.h> |
| #include <string> |
| |
| template <typename T> using FPBits = __llvm_libc::fputil::FPBits<T>; |
| |
| namespace __llvm_libc { |
| namespace testing { |
| namespace mpfr { |
| |
| class MPFRNumber { |
| // A precision value which allows sufficiently large additional |
| // precision even compared to quad-precision floating point values. |
| static constexpr unsigned int mpfrPrecision = 128; |
| |
| mpfr_t value; |
| |
| public: |
| MPFRNumber() { mpfr_init2(value, mpfrPrecision); } |
| |
| // We use explicit EnableIf specializations to disallow implicit |
| // conversions. Implicit conversions can potentially lead to loss of |
| // precision. |
| template <typename XType, |
| cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0> |
| explicit MPFRNumber(XType x) { |
| mpfr_init2(value, mpfrPrecision); |
| mpfr_set_flt(value, x, MPFR_RNDN); |
| } |
| |
| template <typename XType, |
| cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0> |
| explicit MPFRNumber(XType x) { |
| mpfr_init2(value, mpfrPrecision); |
| mpfr_set_d(value, x, MPFR_RNDN); |
| } |
| |
| template <typename XType, |
| cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0> |
| explicit MPFRNumber(XType x) { |
| mpfr_init2(value, mpfrPrecision); |
| mpfr_set_ld(value, x, MPFR_RNDN); |
| } |
| |
| template <typename XType, |
| cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0> |
| explicit MPFRNumber(XType x) { |
| mpfr_init2(value, mpfrPrecision); |
| mpfr_set_sj(value, x, MPFR_RNDN); |
| } |
| |
| template <typename XType> MPFRNumber(XType x, const Tolerance &t) { |
| mpfr_init2(value, mpfrPrecision); |
| mpfr_set_zero(value, 1); // Set to positive zero. |
| MPFRNumber xExponent(fputil::FPBits<XType>(x).getExponent()); |
| // E = 2^E |
| mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN); |
| uint32_t bitMask = 1 << (t.width - 1); |
| for (int n = -t.basePrecision; bitMask > 0; bitMask >>= 1) { |
| --n; |
| if (t.bits & bitMask) { |
| // delta = -n |
| MPFRNumber delta(n); |
| |
| // delta = 2^(-n) |
| mpfr_exp2(delta.value, delta.value, MPFR_RNDN); |
| |
| // delta = E * 2^(-n) |
| mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN); |
| |
| // tolerance += delta |
| mpfr_add(value, value, delta.value, MPFR_RNDN); |
| } |
| } |
| } |
| |
| template <typename XType, |
| cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0> |
| MPFRNumber(Operation op, XType rawValue) { |
| mpfr_init2(value, mpfrPrecision); |
| MPFRNumber mpfrInput(rawValue); |
| switch (op) { |
| case Operation::Abs: |
| mpfr_abs(value, mpfrInput.value, MPFR_RNDN); |
| break; |
| case Operation::Ceil: |
| mpfr_ceil(value, mpfrInput.value); |
| break; |
| case Operation::Cos: |
| mpfr_cos(value, mpfrInput.value, MPFR_RNDN); |
| break; |
| case Operation::Exp: |
| mpfr_exp(value, mpfrInput.value, MPFR_RNDN); |
| break; |
| case Operation::Exp2: |
| mpfr_exp2(value, mpfrInput.value, MPFR_RNDN); |
| break; |
| case Operation::Floor: |
| mpfr_floor(value, mpfrInput.value); |
| break; |
| case Operation::Round: |
| mpfr_round(value, mpfrInput.value); |
| break; |
| case Operation::Sin: |
| mpfr_sin(value, mpfrInput.value, MPFR_RNDN); |
| break; |
| case Operation::Trunc: |
| mpfr_trunc(value, mpfrInput.value); |
| break; |
| } |
| } |
| |
| MPFRNumber(const MPFRNumber &other) { |
| mpfr_set(value, other.value, MPFR_RNDN); |
| } |
| |
| ~MPFRNumber() { mpfr_clear(value); } |
| |
| // Returns true if |other| is within the |tolerance| value of this |
| // number. |
| bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const { |
| MPFRNumber difference; |
| if (mpfr_cmp(value, other.value) >= 0) |
| mpfr_sub(difference.value, value, other.value, MPFR_RNDN); |
| else |
| mpfr_sub(difference.value, other.value, value, MPFR_RNDN); |
| |
| return mpfr_lessequal_p(difference.value, tolerance.value); |
| } |
| |
| std::string str() const { |
| // 200 bytes should be more than sufficient to hold a 100-digit number |
| // plus additional bytes for the decimal point, '-' sign etc. |
| constexpr size_t printBufSize = 200; |
| char buffer[printBufSize]; |
| mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value); |
| llvm::StringRef ref(buffer); |
| ref = ref.trim(); |
| return ref.str(); |
| } |
| |
| // These functions are useful for debugging. |
| float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); } |
| double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); } |
| void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); } |
| }; |
| |
| namespace internal { |
| |
| template <typename T> |
| void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) { |
| MPFRNumber mpfrResult(operation, input); |
| MPFRNumber mpfrInput(input); |
| MPFRNumber mpfrMatchValue(matchValue); |
| MPFRNumber mpfrToleranceValue(matchValue, tolerance); |
| FPBits<T> inputBits(input); |
| FPBits<T> matchBits(matchValue); |
| // TODO: Call to llvm::utohexstr implicitly converts __uint128_t values to |
| // uint64_t values. This can be fixed using a custom wrapper for |
| // llvm::utohexstr to handle __uint128_t values correctly. |
| OS << "Match value not within tolerance value of MPFR result:\n" |
| << " Input decimal: " << mpfrInput.str() << '\n' |
| << " Input bits: 0x" << llvm::utohexstr(inputBits.bitsAsUInt()) << '\n' |
| << " Match decimal: " << mpfrMatchValue.str() << '\n' |
| << " Match bits: 0x" << llvm::utohexstr(matchBits.bitsAsUInt()) << '\n' |
| << " MPFR result: " << mpfrResult.str() << '\n' |
| << "Tolerance value: " << mpfrToleranceValue.str() << '\n'; |
| } |
| |
| template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &); |
| template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &); |
| template void |
| MPFRMatcher<long double>::explainError(testutils::StreamWrapper &); |
| |
| template <typename T> |
| bool compare(Operation op, T input, T libcResult, const Tolerance &t) { |
| MPFRNumber mpfrResult(op, input); |
| MPFRNumber mpfrLibcResult(libcResult); |
| MPFRNumber mpfrToleranceValue(libcResult, t); |
| |
| return mpfrResult.isEqual(mpfrLibcResult, mpfrToleranceValue); |
| }; |
| |
| template bool compare<float>(Operation, float, float, const Tolerance &); |
| template bool compare<double>(Operation, double, double, const Tolerance &); |
| template bool compare<long double>(Operation, long double, long double, |
| const Tolerance &); |
| |
| } // namespace internal |
| |
| } // namespace mpfr |
| } // namespace testing |
| } // namespace __llvm_libc |