blob: c6020f471e88d83614b857dea51f56dfc707af43 [file] [log] [blame]
//===-- Utils which wrap MPFR ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "MPFRUtils.h"
#include "utils/FPUtil/FPBits.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include <mpfr.h>
#include <stdint.h>
#include <string>
// File-scope shorthand so that code below can write FPBits<T> instead of the
// fully qualified __llvm_libc::fputil::FPBits<T>.
template <typename T> using FPBits = __llvm_libc::fputil::FPBits<T>;
namespace __llvm_libc {
namespace testing {
namespace mpfr {
// An owning wrapper around a single |mpfr_t|. Every constructor initialises
// |value| with |mpfrPrecision| bits of precision and the destructor clears
// it, so each live object owns exactly one initialised MPFR number.
class MPFRNumber {
  // A precision value which allows sufficiently large additional
  // precision even compared to quad-precision floating point values.
  static constexpr unsigned int mpfrPrecision = 128;

  mpfr_t value;

public:
  // Creates a number with |mpfrPrecision| bits of precision. No value is
  // assigned beyond what mpfr_init2 itself provides.
  MPFRNumber() { mpfr_init2(value, mpfrPrecision); }

  // We use explicit EnableIf specializations to disallow implicit
  // conversions. Implicit conversions can potentially lead to loss of
  // precision.
  template <typename XType,
            cpp::EnableIfType<cpp::IsSame<float, XType>::Value, int> = 0>
  explicit MPFRNumber(XType x) {
    mpfr_init2(value, mpfrPrecision);
    mpfr_set_flt(value, x, MPFR_RNDN);
  }

  template <typename XType,
            cpp::EnableIfType<cpp::IsSame<double, XType>::Value, int> = 0>
  explicit MPFRNumber(XType x) {
    mpfr_init2(value, mpfrPrecision);
    mpfr_set_d(value, x, MPFR_RNDN);
  }

  template <typename XType,
            cpp::EnableIfType<cpp::IsSame<long double, XType>::Value, int> = 0>
  explicit MPFRNumber(XType x) {
    mpfr_init2(value, mpfrPrecision);
    mpfr_set_ld(value, x, MPFR_RNDN);
  }

  // NOTE(review): the value is forwarded to mpfr_set_sj, i.e. it is treated
  // as a signed integer. Unsigned inputs larger than the maximum intmax_t
  // would presumably be misinterpreted -- confirm no caller passes such
  // values.
  template <typename XType,
            cpp::EnableIfType<cpp::IsIntegral<XType>::Value, int> = 0>
  explicit MPFRNumber(XType x) {
    mpfr_init2(value, mpfrPrecision);
    mpfr_set_sj(value, x, MPFR_RNDN);
  }

  // Builds the tolerance value described by |t|, scaled by 2^E where E is
  // the exponent reported by FPBits for |x|. Starting from the most
  // significant of the |t.width| low bits of |t.bits|, the i-th bit
  // (i = 1, 2, ...) contributes 2^E * 2^(-t.basePrecision - i) when set.
  template <typename XType> MPFRNumber(XType x, const Tolerance &t) {
    mpfr_init2(value, mpfrPrecision);
    mpfr_set_zero(value, 1); // Set to positive zero.

    MPFRNumber xExponent(fputil::FPBits<XType>(x).getExponent());
    // E = 2^E
    mpfr_exp2(xExponent.value, xExponent.value, MPFR_RNDN);

    // Shift an unsigned one so that a width of 32 does not overflow a signed
    // int, and produce an empty mask (no iterations) when the width is zero
    // instead of shifting by a negative amount.
    uint32_t bitMask = t.width > 0 ? (uint32_t(1) << (t.width - 1)) : 0;
    // Convert to int before negating so that the negation is done in signed
    // arithmetic even if |basePrecision| is an unsigned field.
    for (int n = -static_cast<int>(t.basePrecision); bitMask > 0;
         bitMask >>= 1) {
      --n;
      if (t.bits & bitMask) {
        // delta = n (|n| is negative here)
        MPFRNumber delta(n);
        // delta = 2^n
        mpfr_exp2(delta.value, delta.value, MPFR_RNDN);
        // delta = E * 2^n
        mpfr_mul(delta.value, delta.value, xExponent.value, MPFR_RNDN);
        // tolerance += delta
        mpfr_add(value, value, delta.value, MPFR_RNDN);
      }
    }
  }

  // Evaluates |op| on |rawValue| using MPFR and stores the result. If |op|
  // matches none of the cases below, |value| is left as initialised by
  // mpfr_init2.
  template <typename XType,
            cpp::EnableIfType<cpp::IsFloatingPointType<XType>::Value, int> = 0>
  MPFRNumber(Operation op, XType rawValue) {
    mpfr_init2(value, mpfrPrecision);
    MPFRNumber mpfrInput(rawValue);
    switch (op) {
    case Operation::Abs:
      mpfr_abs(value, mpfrInput.value, MPFR_RNDN);
      break;
    case Operation::Ceil:
      mpfr_ceil(value, mpfrInput.value);
      break;
    case Operation::Cos:
      mpfr_cos(value, mpfrInput.value, MPFR_RNDN);
      break;
    case Operation::Exp:
      mpfr_exp(value, mpfrInput.value, MPFR_RNDN);
      break;
    case Operation::Exp2:
      mpfr_exp2(value, mpfrInput.value, MPFR_RNDN);
      break;
    case Operation::Floor:
      mpfr_floor(value, mpfrInput.value);
      break;
    case Operation::Round:
      mpfr_round(value, mpfrInput.value);
      break;
    case Operation::Sin:
      mpfr_sin(value, mpfrInput.value, MPFR_RNDN);
      break;
    case Operation::Trunc:
      mpfr_trunc(value, mpfrInput.value);
      break;
    }
  }

  // Copies |other|. |value| has to be initialised before mpfr_set may write
  // to it, and the destructor will later call mpfr_clear on it.
  MPFRNumber(const MPFRNumber &other) {
    mpfr_init2(value, mpfrPrecision);
    mpfr_set(value, other.value, MPFR_RNDN);
  }

  // Copy assignment. Both operands already own an initialised |value|, so
  // only the numeric value is copied; a member-wise copy of the |mpfr_t|
  // would make two objects clear the same MPFR storage.
  MPFRNumber &operator=(const MPFRNumber &other) {
    mpfr_set(value, other.value, MPFR_RNDN);
    return *this;
  }

  ~MPFRNumber() { mpfr_clear(value); }

  // Returns true if |other| is within the |tolerance| value of this
  // number, i.e. if the absolute difference between the two numbers is less
  // than or equal to |tolerance|.
  bool isEqual(const MPFRNumber &other, const MPFRNumber &tolerance) const {
    MPFRNumber difference;
    // Always subtract the smaller from the larger operand so that
    // |difference| holds the absolute difference.
    if (mpfr_cmp(value, other.value) >= 0)
      mpfr_sub(difference.value, value, other.value, MPFR_RNDN);
    else
      mpfr_sub(difference.value, other.value, value, MPFR_RNDN);

    return mpfr_lessequal_p(difference.value, tolerance.value);
  }

  // Returns the number formatted with "%100.50Rf" and with the surrounding
  // whitespace trimmed.
  std::string str() const {
    // 200 bytes should be more than sufficient to hold a 100-digit number
    // plus additional bytes for the decimal point, '-' sign etc.
    constexpr size_t printBufSize = 200;
    char buffer[printBufSize];
    mpfr_snprintf(buffer, printBufSize, "%100.50Rf", value);
    llvm::StringRef ref(buffer);
    ref = ref.trim();
    return ref.str();
  }

  // These functions are useful for debugging.
  float asFloat() const { return mpfr_get_flt(value, MPFR_RNDN); }
  double asDouble() const { return mpfr_get_d(value, MPFR_RNDN); }
  void dump(const char *msg) const { mpfr_printf("%s%.128Rf\n", msg, value); }
};
namespace internal {
// Writes a description of a failed match to |OS|: the input and the match
// value (each as a decimal string and as raw bits in hex), the result MPFR
// computes for |operation| on |input|, and the tolerance value derived from
// |matchValue| and |tolerance|.
template <typename T>
void MPFRMatcher<T>::explainError(testutils::StreamWrapper &OS) {
  // Decimal renderings of every quantity that appears in the report.
  const std::string inputDecimal = MPFRNumber(input).str();
  const std::string matchDecimal = MPFRNumber(matchValue).str();
  const std::string resultDecimal = MPFRNumber(operation, input).str();
  const std::string toleranceDecimal = MPFRNumber(matchValue, tolerance).str();

  // TODO: Call to llvm::utohexstr implicitly converts __uint128_t values to
  // uint64_t values. This can be fixed using a custom wrapper for
  // llvm::utohexstr to handle __uint128_t values correctly.
  const std::string inputHex =
      llvm::utohexstr(FPBits<T>(input).bitsAsUInt());
  const std::string matchHex =
      llvm::utohexstr(FPBits<T>(matchValue).bitsAsUInt());

  OS << "Match value not within tolerance value of MPFR result:\n";
  OS << " Input decimal: " << inputDecimal << '\n';
  OS << " Input bits: 0x" << inputHex << '\n';
  OS << " Match decimal: " << matchDecimal << '\n';
  OS << " Match bits: 0x" << matchHex << '\n';
  OS << " MPFR result: " << resultDecimal << '\n';
  OS << "Tolerance value: " << toleranceDecimal << '\n';
}
// Explicit instantiations of explainError for float, double and long double.
// The template is defined in this file, so these are presumably what makes
// the definitions available to other translation units -- confirm against
// the header.
template void MPFRMatcher<float>::explainError(testutils::StreamWrapper &);
template void MPFRMatcher<double>::explainError(testutils::StreamWrapper &);
template void
MPFRMatcher<long double>::explainError(testutils::StreamWrapper &);
// Returns true if |libcResult| is within tolerance of the value MPFR computes
// for |op| applied to |input|. The tolerance value is built from |t| and
// |libcResult|.
template <typename T>
bool compare(Operation op, T input, T libcResult, const Tolerance &t) {
  const MPFRNumber reference(op, input);
  const MPFRNumber actual(libcResult);
  const MPFRNumber allowedError(libcResult, t);
  const bool withinTolerance = reference.isEqual(actual, allowedError);
  return withinTolerance;
}
// Explicit instantiations of compare for float, double and long double.
// The template is defined in this file, so these are presumably what makes
// the definitions available to other translation units -- confirm against
// the header.
template bool compare<float>(Operation, float, float, const Tolerance &);
template bool compare<double>(Operation, double, double, const Tolerance &);
template bool compare<long double>(Operation, long double, long double,
const Tolerance &);
} // namespace internal
} // namespace mpfr
} // namespace testing
} // namespace __llvm_libc