[EarlyCSE] Support memset loads (#194268)

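This PR addresses the zero-`memset` case in EarlyCSE as discussed in
#194080. If a non-volatile `memset` writes zero with a constant length
and a simple (non-volatile, unordered) load then reads back from the
same base pointer, we can fold that load to the null value of the loaded
type. The fold only applies when the `memset` length covers at least the
load's store size and nothing clobbers memory between the two.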
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index ff3d55e..3ea18b7 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -821,6 +821,12 @@
             Info.IsVolatile = false;
             break;
           }
+        } else if (auto *MI = dyn_cast<MemSetInst>(Inst)) {
+          Info.PtrVal = MI->getDest();
+          Info.MatchingId = 0;
+          Info.ReadMem = false;
+          Info.WriteMem = true;
+          Info.IsVolatile = MI->isVolatile();
         }
       }
     }
@@ -1226,6 +1232,26 @@
                                   unsigned CurrentGeneration) {
   if (InVal.DefInst == nullptr)
     return nullptr;
+  if (auto *MSI = dyn_cast<MemSetInst>(InVal.DefInst)) {
+    if (!MemInst.isLoad() || MemInst.isVolatile() || !MemInst.isUnordered())
+      return nullptr;
+    if (MSI->isVolatile())
+      return nullptr;
+    auto *Val = dyn_cast<ConstantInt>(MSI->getValue());
+    if (!Val || !Val->isZero())
+      return nullptr;
+    auto Len = MSI->getLengthInBytes();
+    if (!Len)
+      return nullptr;
+    TypeSize LoadSize = SQ.DL.getTypeStoreSize(MemInst.getValueType());
+    if (LoadSize.isScalable() || Len->ult(LoadSize.getFixedValue()))
+      return nullptr;
+    if (!isOperatingOnInvariantMemAt(MemInst.get(), InVal.Generation) &&
+        !isSameMemGeneration(InVal.Generation, CurrentGeneration, InVal.DefInst,
+                             MemInst.get()))
+      return nullptr;
+    return Constant::getNullValue(MemInst.getValueType());
+  }
   if (InVal.MatchingId != MemInst.getMatchingId())
     return nullptr;
   // We don't yet handle removing loads with ordering of any kind.
diff --git a/llvm/test/Transforms/EarlyCSE/memset-load.ll b/llvm/test/Transforms/EarlyCSE/memset-load.ll
new file mode 100644
index 0000000..f563d29
--- /dev/null
+++ b/llvm/test/Transforms/EarlyCSE/memset-load.ll
@@ -0,0 +1,92 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt -S -passes='early-cse<memssa>' < %s | FileCheck %s
+
+target datalayout = "pe1:64:64:64:32"
+
+define ptr @load_from_zero_memset(ptr %p) {
+; CHECK-LABEL: define ptr @load_from_zero_memset(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P]], i8 0, i64 8, i1 false)
+; CHECK-NEXT:    ret ptr null
+;
+entry:
+  call void @llvm.memset.p0.i64(ptr %p, i8 0, i64 8, i1 false)
+  %v = load ptr, ptr %p, align 8
+  ret ptr %v
+}
+
+define ptr @load_from_nonzero_memset(ptr %p) {
+; CHECK-LABEL: define ptr @load_from_nonzero_memset(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P]], i8 1, i64 8, i1 false)
+; CHECK-NEXT:    [[V:%.*]] = load ptr, ptr [[P]], align 8
+; CHECK-NEXT:    ret ptr [[V]]
+;
+entry:
+  call void @llvm.memset.p0.i64(ptr %p, i8 1, i64 8, i1 false)
+  %v = load ptr, ptr %p, align 8
+  ret ptr %v
+}
+
+define ptr @load_from_zero_memset_with_clobber(ptr %p) {
+; CHECK-LABEL: define ptr @load_from_zero_memset_with_clobber(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P]], i8 0, i64 8, i1 false)
+; CHECK-NEXT:    call void @clobber(ptr [[P]])
+; CHECK-NEXT:    [[V:%.*]] = load ptr, ptr [[P]], align 8
+; CHECK-NEXT:    ret ptr [[V]]
+;
+entry:
+  call void @llvm.memset.p0.i64(ptr %p, i8 0, i64 8, i1 false)
+  call void @clobber(ptr %p)
+  %v = load ptr, ptr %p, align 8
+  ret ptr %v
+}
+
+define ptr @load_from_volatile_zero_memset(ptr %p) {
+; CHECK-LABEL: define ptr @load_from_volatile_zero_memset(
+; CHECK-SAME: ptr [[P:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P]], i8 0, i64 8, i1 true)
+; CHECK-NEXT:    [[V:%.*]] = load ptr, ptr [[P]], align 8
+; CHECK-NEXT:    ret ptr [[V]]
+;
+entry:
+  call void @llvm.memset.p0.i64(ptr %p, i8 0, i64 8, i1 true)
+  %v = load ptr, ptr %p, align 8
+  ret ptr %v
+}
+
+define ptr @load_from_zero_memset_unknown_length(ptr %p, i64 %n) {
+; CHECK-LABEL: define ptr @load_from_zero_memset_unknown_length(
+; CHECK-SAME: ptr [[P:%.*]], i64 [[N:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    call void @llvm.memset.p0.i64(ptr [[P]], i8 0, i64 [[N]], i1 false)
+; CHECK-NEXT:    [[V:%.*]] = load ptr, ptr [[P]], align 8
+; CHECK-NEXT:    ret ptr [[V]]
+;
+entry:
+  call void @llvm.memset.p0.i64(ptr %p, i8 0, i64 %n, i1 false)
+  %v = load ptr, ptr %p, align 8
+  ret ptr %v
+}
+
+define ptr addrspace(1) @load_from_zero_memset_external(ptr addrspace(1) %p) {
+; CHECK-LABEL: define ptr addrspace(1) @load_from_zero_memset_external(
+; CHECK-SAME: ptr addrspace(1) [[P:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    call void @llvm.memset.p1.i64(ptr addrspace(1) [[P]], i8 0, i64 8, i1 false)
+; CHECK-NEXT:    ret ptr addrspace(1) null
+;
+entry:
+  call void @llvm.memset.p1.i64(ptr addrspace(1) %p, i8 0, i64 8, i1 false)
+  %v = load ptr addrspace(1), ptr addrspace(1) %p, align 8
+  ret ptr addrspace(1) %v
+}
+
+declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg)
+declare void @llvm.memset.p1.i64(ptr addrspace(1) nocapture writeonly, i8, i64, i1 immarg)
+declare void @clobber(ptr)