[ConstantFPRange] Add support for flushDenormals (#163074)
This patch provides a helper function to handle non-IEEE denormal
flushing behaviours. For the dynamic mode, it returns a union of all
possible results.
diff --git a/llvm/include/llvm/IR/ConstantFPRange.h b/llvm/include/llvm/IR/ConstantFPRange.h
index 39dc7c1..10467cc 100644
--- a/llvm/include/llvm/IR/ConstantFPRange.h
+++ b/llvm/include/llvm/IR/ConstantFPRange.h
@@ -230,6 +230,10 @@
/// Return a new range representing the possible values resulting
/// from a subtraction of a value in this range and a value in \p Other.
LLVM_ABI ConstantFPRange sub(const ConstantFPRange &Other) const;
+
+ /// Flush denormal values to zero according to the specified mode.
+ /// For dynamic mode, we return the union of all possible results.
+ LLVM_ABI void flushDenormals(DenormalMode::DenormalModeKind Mode);
};
inline raw_ostream &operator<<(raw_ostream &OS, const ConstantFPRange &CR) {
diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp
index 51d2e21..e9c058e 100644
--- a/llvm/lib/IR/ConstantFPRange.cpp
+++ b/llvm/lib/IR/ConstantFPRange.cpp
@@ -8,6 +8,7 @@
#include "llvm/IR/ConstantFPRange.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
@@ -506,3 +507,24 @@
// fsub X, Y = fadd X, (fneg Y)
return add(Other.negate());
}
+
+void ConstantFPRange::flushDenormals(DenormalMode::DenormalModeKind Mode) {
+ if (Mode == DenormalMode::IEEE)
+ return;
+ FPClassTest Class = classify();
+ if (!(Class & fcSubnormal))
+ return;
+
+ auto &Sem = getSemantics();
+ // PreserveSign: PosSubnormal -> PosZero, NegSubnormal -> NegZero
+ // PositiveZero: PosSubnormal -> PosZero, NegSubnormal -> PosZero
+ // Dynamic: PosSubnormal -> PosZero, NegSubnormal -> NegZero/PosZero
+ bool ZeroLowerNegative =
+ Mode != DenormalMode::PositiveZero && (Class & fcNegSubnormal);
+ bool ZeroUpperNegative =
+ Mode == DenormalMode::PreserveSign && !(Class & fcPosSubnormal);
+ assert((ZeroLowerNegative || !ZeroUpperNegative) &&
+ "ZeroLower is greater than ZeroUpper.");
+ Lower = minnum(Lower, APFloat::getZero(Sem, ZeroLowerNegative));
+ Upper = maxnum(Upper, APFloat::getZero(Sem, ZeroUpperNegative));
+}
diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp
index cf9b31c..2431db9 100644
--- a/llvm/unittests/IR/ConstantFPRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp
@@ -8,6 +8,7 @@
#include "llvm/IR/ConstantFPRange.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "gtest/gtest.h"
@@ -1065,4 +1066,70 @@
#endif
}
+TEST_F(ConstantFPRangeTest, flushDenormals) {
+ const fltSemantics &FP8Sem = APFloat::Float8E4M3();
+ APFloat NormalVal = APFloat::getSmallestNormalized(FP8Sem);
+ APFloat Subnormal1 = NormalVal;
+ Subnormal1.next(/*nextDown=*/true);
+ APFloat Subnormal2 = APFloat::getSmallest(FP8Sem);
+ APFloat ZeroVal = APFloat::getZero(FP8Sem);
+ APFloat EdgeValues[8] = {-NormalVal, -Subnormal1, -Subnormal2, -ZeroVal,
+ ZeroVal, Subnormal2, Subnormal1, NormalVal};
+ constexpr DenormalMode::DenormalModeKind Modes[4] = {
+ DenormalMode::IEEE, DenormalMode::PreserveSign,
+ DenormalMode::PositiveZero, DenormalMode::Dynamic};
+ for (uint32_t I = 0; I != 8; ++I) {
+ for (uint32_t J = I; J != 8; ++J) {
+ ConstantFPRange OriginCR =
+ ConstantFPRange::getNonNaN(EdgeValues[I], EdgeValues[J]);
+ for (auto Mode : Modes) {
+ StringRef ModeName = denormalModeKindName(Mode);
+ ConstantFPRange FlushedCR = OriginCR;
+ FlushedCR.flushDenormals(Mode);
+
+ ConstantFPRange Expected = ConstantFPRange::getEmpty(FP8Sem);
+ auto CheckFlushedV = [&](const APFloat &V, const APFloat &FlushedV) {
+ EXPECT_TRUE(FlushedCR.contains(FlushedV))
+ << "Wrong result for flushDenormal(" << V << ", " << ModeName
+ << "). The result " << FlushedCR << " should contain "
+ << FlushedV;
+ if (!Expected.contains(FlushedV))
+ Expected = Expected.unionWith(ConstantFPRange(FlushedV));
+ };
+ EnumerateValuesInConstantFPRange(
+ OriginCR,
+ [&](const APFloat &V) {
+ if (V.isDenormal()) {
+ switch (Mode) {
+ case DenormalMode::IEEE:
+ break;
+ case DenormalMode::PreserveSign:
+ CheckFlushedV(V, APFloat::getZero(FP8Sem, V.isNegative()));
+ break;
+ case DenormalMode::PositiveZero:
+ CheckFlushedV(V, APFloat::getZero(FP8Sem));
+ break;
+ case DenormalMode::Dynamic:
+ // PreserveSign
+ CheckFlushedV(V, APFloat::getZero(FP8Sem, V.isNegative()));
+ // PositiveZero
+ CheckFlushedV(V, APFloat::getZero(FP8Sem));
+ break;
+ default:
+ llvm_unreachable("unknown denormal mode");
+ }
+ }
+ // It is not mandated that flushing to zero occurs.
+ CheckFlushedV(V, V);
+ },
+ /*IgnoreNaNPayload=*/true);
+ EXPECT_EQ(FlushedCR, Expected)
+ << "Suboptimal result for flushDenormal(" << OriginCR << ", "
+ << ModeName << "). Expected " << Expected << ", but got "
+ << FlushedCR;
+ }
+ }
+ }
+}
+
} // anonymous namespace