[Coroutine] Relax CoroElide musttail check

As discussed in D94834, we don't really need to do complicated analysis. It's safe to just drop the tail call attribute.

Differential Revision: https://reviews.llvm.org/D96926

GitOrigin-RevId: 3bf8f162a0a922026d4c183231acb2be0dcdfcc7
diff --git a/lib/Transforms/Coroutines/CoroElide.cpp b/lib/Transforms/Coroutines/CoroElide.cpp
index 07a183c..d17650d 100644
--- a/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/lib/Transforms/Coroutines/CoroElide.cpp
@@ -79,11 +79,16 @@
 // Look for any tail calls referencing the coroutine frame and remove tail
 // attribute from them, since now coroutine frame resides on the stack and tail
 // call implies that the function does not references anything on the stack.
+// However if it's a musttail call, we cannot remove the tailcall attribute.
+// It's safe to keep it there as the musttail call is for symmetric transfer,
+// and by that point the frame should have been destroyed and hence not
+// interfering with operands.
 static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) {
   Function &F = *Frame->getFunction();
   for (Instruction &I : instructions(F))
     if (auto *Call = dyn_cast<CallInst>(&I))
-      if (Call->isTailCall() && operandReferences(Call, Frame, AA))
+      if (Call->isTailCall() && operandReferences(Call, Frame, AA) &&
+          !Call->isMustTailCall())
         Call->setTailCall(false);
 }
 
@@ -246,20 +251,7 @@
   // If size of the set is the same as total number of coro.begin, that means we
   // found a coro.free or coro.destroy referencing each coro.begin, so we can
   // perform heap elision.
-  if (ReferencedCoroBegins.size() != CoroBegins.size())
-    return false;
-
-  // If any call in the function is a musttail call, it usually won't work
-  // because we cannot drop the tailcall attribute, and a tail call will reuse
-  // the entire stack where we are going to put the new frame. In theory a more
-  // precise analysis can be done to check whether the new frame aliases with
-  // the call, however it's challenging to do so before the elision actually
-  // happened.
-  for (BasicBlock &BB : *F)
-    if (BB.getTerminatingMustTailCall())
-      return false;
-
-  return true;
+  return ReferencedCoroBegins.size() == CoroBegins.size();
 }
 
 void Lowerer::collectPostSplitCoroIds(Function *F) {
diff --git a/test/Transforms/Coroutines/coro-elide-musttail.ll b/test/Transforms/Coroutines/coro-elide-musttail.ll
index 2920bac..f04a953 100644
--- a/test/Transforms/Coroutines/coro-elide-musttail.ll
+++ b/test/Transforms/Coroutines/coro-elide-musttail.ll
@@ -17,17 +17,28 @@
 declare dso_local void @"bar"() align 2
 declare dso_local fastcc void @"bar.resume"(%"bar.Frame"*) align 2
 
-; There is a musttail call. CoroElide won't happen.
+; There is a musttail call.
+; With alias analysis, we can tell that the frame does not interfere with CALL34, and hence we can keep the tailcalls.
+; Without alias analysis, we have to keep the tailcalls.
 define internal fastcc void @foo.resume_musttail(%"foo.Frame"* %FramePtr) {
 ; CHECK-LABEL: @foo.resume_musttail(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = tail call token @llvm.coro.id(i32 16, i8* null, i8* bitcast (void ()* @bar to i8*), i8* bitcast ([3 x void (%bar.Frame*)*]* @bar.resumers to i8*))
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call i1 @llvm.coro.alloc(token [[TMP0]])
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call i8* @llvm.coro.begin(token [[TMP0]], i8* null)
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca [24 x i8], align 8
+; CHECK-NEXT:    [[VFRAME:%.*]] = bitcast [24 x i8]* [[TMP0]] to i8*
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call token @llvm.coro.id(i32 16, i8* null, i8* bitcast (void ()* @bar to i8*), i8* bitcast ([3 x void (%bar.Frame*)*]* @bar.resumers to i8*))
 ; CHECK-NEXT:    [[CALL34:%.*]] = call i8* undef()
 ; CHECK-NEXT:    musttail call fastcc void undef(i8* [[CALL34]])
 ; CHECK-NEXT:    ret void
 ;
+; NOAA-LABEL: @foo.resume_musttail(
+; NOAA-NEXT:  entry:
+; NOAA-NEXT:    [[TMP0:%.*]] = alloca [24 x i8], align 8
+; NOAA-NEXT:    [[VFRAME:%.*]] = bitcast [24 x i8]* [[TMP0]] to i8*
+; NOAA-NEXT:    [[TMP1:%.*]] = call token @llvm.coro.id(i32 16, i8* null, i8* bitcast (void ()* @bar to i8*), i8* bitcast ([3 x void (%bar.Frame*)*]* @bar.resumers to i8*))
+; NOAA-NEXT:    [[CALL34:%.*]] = call i8* undef()
+; NOAA-NEXT:    musttail call fastcc void undef(i8* [[CALL34]])
+; NOAA-NEXT:    ret void
+;
 entry:
   %0 = tail call token @llvm.coro.id(i32 16, i8* null, i8* bitcast (void ()* @"bar" to i8*), i8* bitcast ([3 x void (%"bar.Frame"*)*]* @"bar.resumers" to i8*))
   %1 = tail call i1 @llvm.coro.alloc(token %0)
@@ -38,60 +49,6 @@
   ret void
 }
 
-; The new frame (TMP0) could potentially alias CALL34, the tailcall attribute on that call must be removed
-define internal fastcc void @foo.resume_no_musttail_with_alias(%"foo.Frame"* %FramePtr) {
-; CHECK-LABEL: @foo.resume_no_musttail_with_alias(
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca [24 x i8], align 8
-; CHECK-NEXT:    [[VFRAME:%.*]] = bitcast [24 x i8]* [[TMP0]] to i8*
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call token @llvm.coro.id(i32 16, i8* null, i8* bitcast (void ()* @bar to i8*), i8* bitcast ([3 x void (%bar.Frame*)*]* @bar.resumers to i8*))
-; CHECK-NEXT:    call fastcc void undef(i8* [[VFRAME]])
-; CHECK-NEXT:    [[CALL34:%.*]] = call i8* undef()
-; CHECK-NEXT:    call fastcc void undef(i8* [[CALL34]])
-; CHECK-NEXT:    ret void
-;
-entry:
-  %0 = tail call token @llvm.coro.id(i32 16, i8* null, i8* bitcast (void ()* @"bar" to i8*), i8* bitcast ([3 x void (%"bar.Frame"*)*]* @"bar.resumers" to i8*))
-  %1 = tail call i1 @llvm.coro.alloc(token %0)
-  %2 = tail call i8* @llvm.coro.begin(token %0, i8* null)
-  call i8* @llvm.coro.subfn.addr(i8* %2, i8 1)
-  call fastcc void undef(i8* %2)
-  %call34 = call i8* undef()
-  tail call fastcc void undef(i8* %call34)
-  ret void
-}
-
-; The new frame (TMP0) does not alias CALL34, tailcall attribute can reimain. This analysis is only available when alias analysis is enabled.
-define internal fastcc void @foo.resume_no_musttail_no_alias(%"foo.Frame"* %FramePtr) {
-; CHECK-LABEL: @foo.resume_no_musttail_no_alias(
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca [24 x i8], align 8
-; CHECK-NEXT:    [[VFRAME:%.*]] = bitcast [24 x i8]* [[TMP0]] to i8*
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call token @llvm.coro.id(i32 16, i8* null, i8* bitcast (void ()* @bar to i8*), i8* bitcast ([3 x void (%bar.Frame*)*]* @bar.resumers to i8*))
-; CHECK-NEXT:    [[CALL34:%.*]] = call i8* undef()
-; CHECK-NEXT:    tail call fastcc void undef(i8* [[CALL34]])
-; CHECK-NEXT:    ret void
-;
-; NOAA-LABEL: @foo.resume_no_musttail_no_alias(
-; NOAA-NEXT:  entry:
-; NOAA-NEXT:    [[TMP0:%.*]] = alloca [24 x i8], align 8
-; NOAA-NEXT:    [[VFRAME:%.*]] = bitcast [24 x i8]* [[TMP0]] to i8*
-; NOAA-NEXT:    [[TMP1:%.*]] = call token @llvm.coro.id(i32 16, i8* null, i8* bitcast (void ()* @bar to i8*), i8* bitcast ([3 x void (%bar.Frame*)*]* @bar.resumers to i8*))
-; NOAA-NEXT:    [[CALL34:%.*]] = call i8* undef()
-; NOAA-NEXT:    call fastcc void undef(i8* [[CALL34]])
-; NOAA-NEXT:    ret void
-;
-entry:
-  %0 = tail call token @llvm.coro.id(i32 16, i8* null, i8* bitcast (void ()* @"bar" to i8*), i8* bitcast ([3 x void (%"bar.Frame"*)*]* @"bar.resumers" to i8*))
-  %1 = tail call i1 @llvm.coro.alloc(token %0)
-  %2 = tail call i8* @llvm.coro.begin(token %0, i8* null)
-  call i8* @llvm.coro.subfn.addr(i8* %2, i8 1)
-  %call34 = call i8* undef()
-  tail call fastcc void undef(i8* %call34)
-  ret void
-}
-
-
 ; Function Attrs: argmemonly nofree nosync nounwind willreturn
 declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #0