[ConstantFPRange] Add support for flushDenormals (#163074)

This patch provides a helper function to handle non-IEEE denormal
flushing behaviours. For the dynamic mode, it returns a union of all
possible results.
diff --git a/llvm/include/llvm/IR/ConstantFPRange.h b/llvm/include/llvm/IR/ConstantFPRange.h
index 39dc7c1..10467cc 100644
--- a/llvm/include/llvm/IR/ConstantFPRange.h
+++ b/llvm/include/llvm/IR/ConstantFPRange.h
@@ -230,6 +230,10 @@
   /// Return a new range representing the possible values resulting
   /// from a subtraction of a value in this range and a value in \p Other.
   LLVM_ABI ConstantFPRange sub(const ConstantFPRange &Other) const;
+
+  /// Flush denormal values to zero according to the specified mode.
+  /// For dynamic mode, we return the union of all possible results.
+  LLVM_ABI void flushDenormals(DenormalMode::DenormalModeKind Mode);
 };
 
 inline raw_ostream &operator<<(raw_ostream &OS, const ConstantFPRange &CR) {
diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp
index 51d2e21..e9c058e 100644
--- a/llvm/lib/IR/ConstantFPRange.cpp
+++ b/llvm/lib/IR/ConstantFPRange.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/IR/ConstantFPRange.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cassert>
@@ -506,3 +507,24 @@
   // fsub X, Y = fadd X, (fneg Y)
   return add(Other.negate());
 }
+
+void ConstantFPRange::flushDenormals(DenormalMode::DenormalModeKind Mode) {
+  if (Mode == DenormalMode::IEEE)
+    return;
+  FPClassTest Class = classify();
+  if (!(Class & fcSubnormal))
+    return;
+
+  auto &Sem = getSemantics();
+  // PreserveSign: PosSubnormal -> PosZero, NegSubnormal -> NegZero
+  // PositiveZero: PosSubnormal -> PosZero, NegSubnormal -> PosZero
+  // Dynamic:      PosSubnormal -> PosZero, NegSubnormal -> NegZero/PosZero
+  bool ZeroLowerNegative =
+      Mode != DenormalMode::PositiveZero && (Class & fcNegSubnormal);
+  bool ZeroUpperNegative =
+      Mode == DenormalMode::PreserveSign && !(Class & fcPosSubnormal);
+  assert((ZeroLowerNegative || !ZeroUpperNegative) &&
+         "ZeroLower is greater than ZeroUpper.");
+  Lower = minnum(Lower, APFloat::getZero(Sem, ZeroLowerNegative));
+  Upper = maxnum(Upper, APFloat::getZero(Sem, ZeroUpperNegative));
+}
diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp
index cf9b31c..2431db9 100644
--- a/llvm/unittests/IR/ConstantFPRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/IR/ConstantFPRange.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Operator.h"
 #include "gtest/gtest.h"
@@ -1065,4 +1066,70 @@
 #endif
 }
 
+TEST_F(ConstantFPRangeTest, flushDenormals) {
+  const fltSemantics &FP8Sem = APFloat::Float8E4M3();
+  APFloat NormalVal = APFloat::getSmallestNormalized(FP8Sem);
+  APFloat Subnormal1 = NormalVal;
+  Subnormal1.next(/*nextDown=*/true);
+  APFloat Subnormal2 = APFloat::getSmallest(FP8Sem);
+  APFloat ZeroVal = APFloat::getZero(FP8Sem);
+  APFloat EdgeValues[8] = {-NormalVal, -Subnormal1, -Subnormal2, -ZeroVal,
+                           ZeroVal,    Subnormal2,  Subnormal1,  NormalVal};
+  constexpr DenormalMode::DenormalModeKind Modes[4] = {
+      DenormalMode::IEEE, DenormalMode::PreserveSign,
+      DenormalMode::PositiveZero, DenormalMode::Dynamic};
+  for (uint32_t I = 0; I != 8; ++I) {
+    for (uint32_t J = I; J != 8; ++J) {
+      ConstantFPRange OriginCR =
+          ConstantFPRange::getNonNaN(EdgeValues[I], EdgeValues[J]);
+      for (auto Mode : Modes) {
+        StringRef ModeName = denormalModeKindName(Mode);
+        ConstantFPRange FlushedCR = OriginCR;
+        FlushedCR.flushDenormals(Mode);
+
+        ConstantFPRange Expected = ConstantFPRange::getEmpty(FP8Sem);
+        auto CheckFlushedV = [&](const APFloat &V, const APFloat &FlushedV) {
+          EXPECT_TRUE(FlushedCR.contains(FlushedV))
+              << "Wrong result for flushDenormal(" << V << ", " << ModeName
+              << "). The result " << FlushedCR << " should contain "
+              << FlushedV;
+          if (!Expected.contains(FlushedV))
+            Expected = Expected.unionWith(ConstantFPRange(FlushedV));
+        };
+        EnumerateValuesInConstantFPRange(
+            OriginCR,
+            [&](const APFloat &V) {
+              if (V.isDenormal()) {
+                switch (Mode) {
+                case DenormalMode::IEEE:
+                  break;
+                case DenormalMode::PreserveSign:
+                  CheckFlushedV(V, APFloat::getZero(FP8Sem, V.isNegative()));
+                  break;
+                case DenormalMode::PositiveZero:
+                  CheckFlushedV(V, APFloat::getZero(FP8Sem));
+                  break;
+                case DenormalMode::Dynamic:
+                  // PreserveSign
+                  CheckFlushedV(V, APFloat::getZero(FP8Sem, V.isNegative()));
+                  // PositiveZero
+                  CheckFlushedV(V, APFloat::getZero(FP8Sem));
+                  break;
+                default:
+                  llvm_unreachable("unknown denormal mode");
+                }
+              }
+              // It is not mandated that flushing to zero occurs.
+              CheckFlushedV(V, V);
+            },
+            /*IgnoreNaNPayload=*/true);
+        EXPECT_EQ(FlushedCR, Expected)
+            << "Suboptimal result for flushDenormal(" << OriginCR << ", "
+            << ModeName << "). Expected " << Expected << ", but got "
+            << FlushedCR;
+      }
+    }
+  }
+}
+
 } // anonymous namespace