[mlir] fix MemRefToLLVM lowering of atomic operations (#139045)
We have been confusingly, and arguably incorrectly, lowering `m**imumf`
atomic RMW operations in the MemRef dialect to `fm**` atomic RMW
operations in the LLVM dialect, which have different NaN-propagation
semantics: `m**imumf` propagates NaNs from either operand whereas
`fm**`, which lowers to the `fm**num` intrinsic returns the non-NaN
operand. This also contradicts the lowering of `arith.m**imumf` and
`arith.m**numf` operations.
Change the lowering to match the terminology in arith.
Add tests for these lowerings.
Keep a debug message in case of surprising behavior downstream (the code
may be producing more NaNs now).
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index ade4e4d..8ccf1bf 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -28,6 +28,9 @@
#include "llvm/Support/MathExtras.h"
#include <optional>
+#define DEBUG_TYPE "memref-to-llvm"
+#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
+
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -1782,12 +1785,22 @@
case arith::AtomicRMWKind::assign:
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
+ // TODO: remove this by end of 2025.
+ LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
+ "from fmax to fmaximum, expect more NaNs");
+ return LLVM::AtomicBinOp::fmaximum;
+ case arith::AtomicRMWKind::maxnumf:
return LLVM::AtomicBinOp::fmax;
case arith::AtomicRMWKind::maxs:
return LLVM::AtomicBinOp::max;
case arith::AtomicRMWKind::maxu:
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
+ // TODO: remove this by end of 2025.
+ LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed "
+ "from fmin to fminimum, expect more NaNs");
+ return LLVM::AtomicBinOp::fminimum;
+ case arith::AtomicRMWKind::minnumf:
return LLVM::AtomicBinOp::fmin;
case arith::AtomicRMWKind::mins:
return LLVM::AtomicBinOp::min;
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index acfc188..51d5638 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -452,11 +452,19 @@
// CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmaximum %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw maxnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmax %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw minimumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fminimum %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw minnumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
// CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
- // CHECK-INTERFACE-COUNT-9: llvm.atomicrmw
+ // CHECK-INTERFACE-COUNT-13: llvm.atomicrmw
return
}