[libc] Extend MPFRMatcher to handle 2-input-1-output and support hypot function.

Differential Revision: https://reviews.llvm.org/D87514

GitOrigin-RevId: abf1c82dcc5c54f2bbd65eb7b30cc40de2bd7147
diff --git a/utils/MPFRWrapper/MPFRUtils.cpp b/utils/MPFRWrapper/MPFRUtils.cpp
index 0520d8a..56764e9 100644
--- a/utils/MPFRWrapper/MPFRUtils.cpp
+++ b/utils/MPFRWrapper/MPFRUtils.cpp
@@ -133,6 +133,12 @@
     return result;
   }
 
+  MPFRNumber hypot(const MPFRNumber &b) {
+    MPFRNumber result;
+    mpfr_hypot(result.value, value, b.value, MPFR_RNDN);
+    return result;
+  }
+
   MPFRNumber remquo(const MPFRNumber &divisor, int &quotient) {
     MPFRNumber remainder;
     long q;
@@ -278,6 +284,18 @@
 
 template <typename InputType>
 cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
+binaryOperationOneOutput(Operation op, InputType x, InputType y) {
+  MPFRNumber inputX(x), inputY(y);
+  switch (op) {
+  case Operation::Hypot:
+    return inputX.hypot(inputY);
+  default:
+    __builtin_unreachable();
+  }
+}
+
+template <typename InputType>
+cpp::EnableIfType<cpp::IsFloatingPointType<InputType>::Value, MPFRNumber>
 binaryOperationTwoOutputs(Operation op, InputType x, InputType y, int &output) {
   MPFRNumber inputX(x), inputY(y);
   switch (op) {
@@ -402,6 +420,41 @@
     const BinaryOutput<long double> &, testutils::StreamWrapper &);
 
 template <typename T>
+void explainBinaryOperationOneOutputError(Operation op,
+                                          const BinaryInput<T> &input,
+                                          T libcResult,
+                                          testutils::StreamWrapper &OS) {
+  MPFRNumber mpfrX(input.x);
+  MPFRNumber mpfrY(input.y);
+  FPBits<T> xbits(input.x);
+  FPBits<T> ybits(input.y);
+  MPFRNumber mpfrResult = binaryOperationOneOutput(op, input.x, input.y);
+  MPFRNumber mpfrMatchValue(libcResult);
+
+  OS << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue("First input bits: ", input.x,
+                                              OS);
+  __llvm_libc::fputil::testing::describeValue("Second input bits: ", input.y,
+                                              OS);
+
+  OS << "Libc result: " << mpfrMatchValue.str() << '\n'
+     << "MPFR result: " << mpfrResult.str() << '\n';
+  __llvm_libc::fputil::testing::describeValue(
+      "Libc floating point result bits: ", libcResult, OS);
+  __llvm_libc::fputil::testing::describeValue(
+      "              MPFR rounded bits: ", mpfrResult.as<T>(), OS);
+  OS << "ULP error: " << std::to_string(mpfrResult.ulp(libcResult)) << '\n';
+}
+
+template void explainBinaryOperationOneOutputError<float>(
+    Operation, const BinaryInput<float> &, float, testutils::StreamWrapper &);
+template void explainBinaryOperationOneOutputError<double>(
+    Operation, const BinaryInput<double> &, double, testutils::StreamWrapper &);
+template void explainBinaryOperationOneOutputError<long double>(
+    Operation, const BinaryInput<long double> &, long double,
+    testutils::StreamWrapper &);
+
+template <typename T>
 bool compareUnaryOperationSingleOutput(Operation op, T input, T libcResult,
                                        double ulpError) {
   // If the ulp error is exactly 0.5 (i.e a tie), we would check that the result
@@ -480,6 +533,26 @@
     Operation, const BinaryInput<long double> &,
     const BinaryOutput<long double> &, double);
 
+template <typename T>
+bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input,
+                                     T libcResult, double ulpError) {
+  MPFRNumber mpfrResult = binaryOperationOneOutput(op, input.x, input.y);
+  double ulp = mpfrResult.ulp(libcResult);
+
+  bool bitsAreEven = ((FPBits<T>(libcResult).bitsAsUInt() & 1) == 0);
+  return (ulp < ulpError) ||
+         ((ulp == ulpError) && ((ulp != 0.5) || bitsAreEven));
+}
+
+template bool compareBinaryOperationOneOutput<float>(Operation,
+                                                     const BinaryInput<float> &,
+                                                     float, double);
+template bool
+compareBinaryOperationOneOutput<double>(Operation, const BinaryInput<double> &,
+                                        double, double);
+template bool compareBinaryOperationOneOutput<long double>(
+    Operation, const BinaryInput<long double> &, long double, double);
+
 } // namespace internal
 
 } // namespace mpfr
diff --git a/utils/MPFRWrapper/MPFRUtils.h b/utils/MPFRWrapper/MPFRUtils.h
index b46f09d..6fb9fe5 100644
--- a/utils/MPFRWrapper/MPFRUtils.h
+++ b/utils/MPFRWrapper/MPFRUtils.h
@@ -47,7 +47,7 @@
   // input and produce a single floating point number of the same type as
   // output.
   BeginBinaryOperationsSingleOutput,
-  // TODO: Add operations like hypot.
+  Hypot,
   EndBinaryOperationsSingleOutput,
 
   // Operations which take two floating point numbers of the same type as
@@ -110,6 +110,10 @@
                                       double t);
 
 template <typename T>
+bool compareBinaryOperationOneOutput(Operation op, const BinaryInput<T> &input,
+                                     T libcOutput, double t);
+
+template <typename T>
 void explainUnaryOperationSingleOutputError(Operation op, T input, T matchValue,
                                             testutils::StreamWrapper &OS);
 template <typename T>
@@ -122,6 +126,12 @@
                                            const BinaryOutput<T> &matchValue,
                                            testutils::StreamWrapper &OS);
 
+template <typename T>
+void explainBinaryOperationOneOutputError(Operation op,
+                                          const BinaryInput<T> &input,
+                                          T matchValue,
+                                          testutils::StreamWrapper &OS);
+
 template <Operation op, typename InputType, typename OutputType>
 class MPFRMatcher : public testing::Matcher<OutputType> {
   InputType input;
@@ -153,7 +163,7 @@
 
   template <typename T>
   static bool match(const BinaryInput<T> &in, T out, double tolerance) {
-    // TODO: Implement the comparision function and error reporter.
+    return compareBinaryOperationOneOutput(op, in, out, tolerance);
   }
 
   template <typename T>
@@ -183,6 +193,12 @@
                            testutils::StreamWrapper &OS) {
     explainBinaryOperationTwoOutputsError(op, in, out, OS);
   }
+
+  template <typename T>
+  static void explainError(const BinaryInput<T> &in, T out,
+                           testutils::StreamWrapper &OS) {
+    explainBinaryOperationOneOutputError(op, in, out, OS);
+  }
 };
 
 } // namespace internal