[MemCpyOpt] memset->memcpy forwarding with undef tail

Currently memcpyopt optimizes cases like

    memset(a, byte, N);
    memcpy(b, a, M);

to

    memset(a, byte, N);
    memset(b, byte, M);

if M <= N. Often this allows further simplifications down the line,
which drop the first memset entirely.

This patch extends this optimization for the case where M > N, but we
know that the bytes a[N..M] are undef due to alloca/lifetime.start.

This situation arises relatively often for Rust code, because Rust does
not initialize trailing structure padding and loves to insert redundant
memcpys. This also fixes https://bugs.llvm.org/show_bug.cgi?id=39844.

For the implementation, I'm reusing a bit of code for a similar existing
optimization (direct memcpy of undef). I've also added memset support to
MemDepAnalysis GetLocation -- Instead, getPointerDependencyFrom could be
used, but it seems to make more sense to add this to GetLocation and thus
make the computation cachable.

Differential Revision: https://reviews.llvm.org/D55120

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@348645 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Analysis/MemoryDependenceAnalysis.cpp b/lib/Analysis/MemoryDependenceAnalysis.cpp
index 2fe012d..0907559 100644
--- a/lib/Analysis/MemoryDependenceAnalysis.cpp
+++ b/lib/Analysis/MemoryDependenceAnalysis.cpp
@@ -154,6 +154,12 @@
     return ModRefInfo::Mod;
   }
 
+  if (const MemSetInst *MI = dyn_cast<MemSetInst>(Inst)) {
+    Loc = MemoryLocation::getForDest(MI);
+    // Conversatively assume ModRef for volatile memset.
+    return MI->isVolatile() ? ModRefInfo::ModRef : ModRefInfo::Mod;
+  }
+
   if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
     switch (II->getIntrinsicID()) {
     case Intrinsic::lifetime_start:
diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index 4e82e2b..fa44cd9 100644
--- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -1144,6 +1144,21 @@
   return true;
 }
 
+/// Determine whether the instruction has undefined content for the given Size,
+/// either because it was freshly alloca'd or started its lifetime.
+static bool hasUndefContents(Instruction *I, ConstantInt *Size) {
+  if (isa<AllocaInst>(I))
+    return true;
+
+  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
+    if (II->getIntrinsicID() == Intrinsic::lifetime_start)
+      if (ConstantInt *LTSize = dyn_cast<ConstantInt>(II->getArgOperand(0)))
+        if (LTSize->getZExtValue() >= Size->getZExtValue())
+          return true;
+
+  return false;
+}
+
 /// Transform memcpy to memset when its source was just memset.
 /// In other words, turn:
 /// \code
@@ -1167,12 +1182,23 @@
   if (!AA.isMustAlias(MemSet->getRawDest(), MemCpy->getRawSource()))
     return false;
 
-  ConstantInt *CopySize = cast<ConstantInt>(MemCpy->getLength());
+  // A known memset size is required.
   ConstantInt *MemSetSize = dyn_cast<ConstantInt>(MemSet->getLength());
+  if (!MemSetSize)
+    return false;
+
   // Make sure the memcpy doesn't read any more than what the memset wrote.
   // Don't worry about sizes larger than i64.
-  if (!MemSetSize || CopySize->getZExtValue() > MemSetSize->getZExtValue())
-    return false;
+  ConstantInt *CopySize = cast<ConstantInt>(MemCpy->getLength());
+  if (CopySize->getZExtValue() > MemSetSize->getZExtValue()) {
+    // If the memcpy is larger than the memset, but the memory was undef prior
+    // to the memset, we can just ignore the tail.
+    MemDepResult DepInfo = MD->getDependency(MemSet);
+    if (DepInfo.isDef() && hasUndefContents(DepInfo.getInst(), CopySize))
+      CopySize = MemSetSize;
+    else
+      return false;
+  }
 
   IRBuilder<> Builder(MemCpy);
   Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1),
@@ -1252,19 +1278,7 @@
     if (MemCpyInst *MDep = dyn_cast<MemCpyInst>(SrcDepInfo.getInst()))
       return processMemCpyMemCpyDependence(M, MDep);
   } else if (SrcDepInfo.isDef()) {
-    Instruction *I = SrcDepInfo.getInst();
-    bool hasUndefContents = false;
-
-    if (isa<AllocaInst>(I)) {
-      hasUndefContents = true;
-    } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
-      if (II->getIntrinsicID() == Intrinsic::lifetime_start)
-        if (ConstantInt *LTSize = dyn_cast<ConstantInt>(II->getArgOperand(0)))
-          if (LTSize->getZExtValue() >= CopySize->getZExtValue())
-            hasUndefContents = true;
-    }
-
-    if (hasUndefContents) {
+    if (hasUndefContents(SrcDepInfo.getInst(), CopySize)) {
       MD->removeInstruction(M);
       M->eraseFromParent();
       ++NumMemCpyInstr;
diff --git a/test/Transforms/MemCpyOpt/memset-memcpy-oversized.ll b/test/Transforms/MemCpyOpt/memset-memcpy-oversized.ll
index 39538be..7495400 100644
--- a/test/Transforms/MemCpyOpt/memset-memcpy-oversized.ll
+++ b/test/Transforms/MemCpyOpt/memset-memcpy-oversized.ll
@@ -12,7 +12,7 @@
 ; CHECK-NEXT:    [[A:%.*]] = alloca [[T:%.*]], align 8
 ; CHECK-NEXT:    [[B:%.*]] = bitcast %T* [[A]] to i8*
 ; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 8 [[B]], i8 0, i64 12, i1 false)
-; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[RESULT:%.*]], i8* align 8 [[B]], i64 16, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* [[RESULT:%.*]], i8 0, i64 12, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %a = alloca %T, align 8
@@ -28,7 +28,7 @@
 ; CHECK-NEXT:    [[B:%.*]] = bitcast %T* [[A]] to i8*
 ; CHECK-NEXT:    call void @llvm.lifetime.start.p0i8(i64 16, i8* [[B]])
 ; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 8 [[B]], i8 0, i64 12, i1 false)
-; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[RESULT:%.*]], i8* align 8 [[B]], i64 16, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* [[RESULT:%.*]], i8 0, i64 12, i1 false)
 ; CHECK-NEXT:    call void @llvm.lifetime.end.p0i8(i64 16, i8* [[B]])
 ; CHECK-NEXT:    ret void
 ;
@@ -46,7 +46,7 @@
 ; CHECK-NEXT:    [[A:%.*]] = call i8* @malloc(i64 16)
 ; CHECK-NEXT:    call void @llvm.lifetime.start.p0i8(i64 16, i8* [[A]])
 ; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 8 [[A]], i8 0, i64 12, i1 false)
-; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[RESULT:%.*]], i8* align 8 [[A]], i64 16, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* [[RESULT:%.*]], i8 0, i64 12, i1 false)
 ; CHECK-NEXT:    call void @llvm.lifetime.end.p0i8(i64 16, i8* [[A]])
 ; CHECK-NEXT:    call void @free(i8* [[A]])
 ; CHECK-NEXT:    ret void
@@ -98,7 +98,7 @@
 ; CHECK-NEXT:    [[A:%.*]] = alloca [[T:%.*]], align 8
 ; CHECK-NEXT:    [[B:%.*]] = bitcast %T* [[A]] to i8*
 ; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 8 [[B]], i8 0, i64 12, i1 true)
-; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* [[RESULT:%.*]], i8* align 8 [[B]], i64 16, i1 false)
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* [[RESULT:%.*]], i8 0, i64 12, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %a = alloca %T, align 8