[Attributor][FIX] Prevent alignment breakage wrt. must-tail calls

If we have a must-tail call the callee and caller need to have matching
ABIs. Part of that is alignment which we might modify when we deduce
alignment of arguments of either. Since we would need to keep them in
sync, which is not as simple, we simply avoid deducing alignment for
arguments of the must-tail caller or callee.

Reviewed By: rnk

Differential Revision: https://reviews.llvm.org/D76673
diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index e5c50bd..3b0dad0 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -602,6 +602,13 @@
     return AG.getAnalysis<AAManager>(F);
   }
 
+  /// Return true if \p Arg is involved in a must-tail call, thus the argument
+  /// of the caller or callee.
+  bool isInvolvedInMustTailCall(const Argument &Arg) const {
+    return FunctionsCalledViaMustTail.count(Arg.getParent()) ||
+           FunctionsWithMustTailCall.count(Arg.getParent());
+  }
+
   /// Return the analysis result from a pass \p AP for function \p F.
   template <typename AP>
   typename AP::Result *getAnalysisResultForFunction(const Function &F) {
@@ -635,6 +642,12 @@
   /// A map from functions to their instructions that may read or write memory.
   FuncRWInstsMapTy FuncRWInstsMap;
 
+  /// Functions called by a `musttail` call.
+  SmallPtrSet<Function *, 8> FunctionsCalledViaMustTail;
+
+  /// Functions containing a `musttail` call.
+  SmallPtrSet<Function *, 8> FunctionsWithMustTailCall;
+
   /// The datalayout used in the module.
   const DataLayout &DL;
 
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index b0b9af9..5a9a926 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -3076,9 +3076,8 @@
 
 static AttrBuilder getParameterABIAttributes(int I, AttributeList Attrs) {
   static const Attribute::AttrKind ABIAttrs[] = {
-      Attribute::StructRet, Attribute::ByVal, Attribute::InAlloca,
-      Attribute::InReg, Attribute::Returned, Attribute::SwiftSelf,
-      Attribute::SwiftError};
+      Attribute::StructRet, Attribute::ByVal,     Attribute::InAlloca,
+      Attribute::InReg,     Attribute::SwiftSelf, Attribute::SwiftError};
   AttrBuilder Copy;
   for (auto AK : ABIAttrs) {
     if (Attrs.hasParamAttribute(I, AK))
diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 1e5591f..60be591 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -4095,10 +4095,20 @@
 struct AAAlignArgument final
     : AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext<AAAlign,
                                                               AAAlignImpl> {
-  AAAlignArgument(const IRPosition &IRP)
-      : AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext<AAAlign,
-                                                                AAAlignImpl>(
-            IRP) {}
+  using Base =
+      AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext<AAAlign,
+                                                              AAAlignImpl>;
+  AAAlignArgument(const IRPosition &IRP) : Base(IRP) {}
+
+  /// See AbstractAttribute::manifest(...).
+  ChangeStatus manifest(Attributor &A) override {
+    // If the associated argument is involved in a must-tail call we give up
+    // because we would need to keep the argument alignments of caller and
+    // callee in-sync. Just does not seem worth the trouble right now.
+    if (A.getInfoCache().isInvolvedInMustTailCall(*getAssociatedArgument()))
+      return ChangeStatus::UNCHANGED;
+    return Base::manifest(A);
+  }
 
   /// See AbstractAttribute::trackStatistics()
   void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(aligned) }
@@ -4109,6 +4119,12 @@
 
   /// See AbstractAttribute::manifest(...).
   ChangeStatus manifest(Attributor &A) override {
+    // If the associated argument is involved in a must-tail call we give up
+    // because we would need to keep the argument alignments of caller and
+    // callee in-sync. Just does not seem worth the trouble right now.
+    if (Argument *Arg = getAssociatedArgument())
+      if (A.getInfoCache().isInvolvedInMustTailCall(*Arg))
+        return ChangeStatus::UNCHANGED;
     ChangeStatus Changed = AAAlignImpl::manifest(A);
     MaybeAlign InheritAlign =
         getAssociatedValue().getPointerAlignment(A.getDataLayout());
@@ -8379,17 +8395,18 @@
              "Attributor.");
       break;
     case Instruction::Call:
-      // Calls are interesting but for `llvm.assume` calls we also fill the
-      // KnowledgeMap as we find them.
+      // Calls are interesting on their own, additionally:
+      // For `llvm.assume` calls we also fill the KnowledgeMap as we find them.
+      // For `must-tail` calls we remember the caller and callee.
       if (IntrinsicInst *Assume = dyn_cast<IntrinsicInst>(&I)) {
         if (Assume->getIntrinsicID() == Intrinsic::assume)
           fillMapFromAssume(*Assume, InfoCache.KnowledgeMap);
+      } else if (cast<CallInst>(I).isMustTailCall()) {
+        InfoCache.FunctionsWithMustTailCall.insert(&F);
+        InfoCache.FunctionsCalledViaMustTail.insert(
+            cast<CallInst>(I).getCalledFunction());
       }
       LLVM_FALLTHROUGH;
-    case Instruction::Load:
-      // The alignment of a pointer is interesting for loads.
-    case Instruction::Store:
-      // The alignment of a pointer is interesting for stores.
     case Instruction::CallBr:
     case Instruction::Invoke:
     case Instruction::CleanupRet:
@@ -8399,6 +8416,10 @@
     case Instruction::Br:
     case Instruction::Resume:
     case Instruction::Ret:
+    case Instruction::Load:
+      // The alignment of a pointer is interesting for loads.
+    case Instruction::Store:
+      // The alignment of a pointer is interesting for stores.
       IsInterestingOpcode = true;
     }
     if (IsInterestingOpcode)
@@ -8433,6 +8454,15 @@
   if (F.isDeclaration())
     return;
 
+  // In non-module runs we need to look at the call sites of a function to
+  // determine if it is part of a must-tail call edge. This will influence what
+  // attributes we can derive.
+  if (!isModulePass() && !InfoCache.FunctionsCalledViaMustTail.count(&F))
+    for (const Use &U : F.uses())
+      if (ImmutableCallSite ICS = ImmutableCallSite(U.getUser()))
+        if (ICS.isCallee(&U) && ICS.isMustTailCall())
+          InfoCache.FunctionsCalledViaMustTail.insert(&F);
+
   IRPosition FPos = IRPosition::function(F);
 
   // Check for dead BasicBlocks in every function.
diff --git a/llvm/test/Transforms/Attributor/align.ll b/llvm/test/Transforms/Attributor/align.ll
index ae8a3db..882d0e8 100644
--- a/llvm/test/Transforms/Attributor/align.ll
+++ b/llvm/test/Transforms/Attributor/align.ll
@@ -552,6 +552,23 @@
   store i8 0, i8* %bc
   ret void
 }
+
+; Make sure we do not annotate the callee of a must-tail call with an alignment
+; we cannot also put on the caller.
+@cnd = external global i1
+define i32 @musttail_callee_1(i32* %p) {
+  %v = load i32, i32* %p, align 32
+  ret i32 %v
+}
+define i32 @musttail_caller_1(i32* %p) {
+  %c = load i1, i1* @cnd
+  br i1 %c, label %mt, label %exit
+mt:
+  %v = musttail call i32 @musttail_callee_1(i32* %p)
+  ret i32 %v
+exit:
+  ret i32 0
+}
 ; UTC_ARGS: --disable
 
 attributes #0 = { nounwind uwtable noinline }