[AArch64][SME] Support agnostic ZA functions in the MachineSMEABIPass (#149064)

This extends the MachineSMEABIPass to handle agnostic ZA functions.
Agnostic ZA functions are currently handled like shared ZA functions,
except that ZA state need not be reloaded before agnostic ZA calls.

Note: This patch does not yet fully handle agnostic ZA functions that
can catch exceptions. E.g.:

```
__arm_agnostic("sme_za_state") void try_catch_agnostic_za_callee()
{
  try {
    agnostic_za_call();
  } catch(...) {
    noexcept_agnostic_za_call();
  }
}
```

In this case, we won't commit a ZA save before `agnostic_za_call()`,
even though a save is needed to restore ZA in the catch block. This
will be handled in a later patch.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9569508..5e11145 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -8489,13 +8489,22 @@
   if (Subtarget->hasCustomCallingConv())
     Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
 
-  if (getTM().useNewSMEABILowering() && !Attrs.hasAgnosticZAInterface()) {
+  if (getTM().useNewSMEABILowering()) {
     if (Subtarget->isTargetWindows() || hasInlineStackProbe(MF)) {
       SDValue Size;
       if (Attrs.hasZAState()) {
         SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
                                   DAG.getConstant(1, DL, MVT::i32));
         Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
+      } else if (Attrs.hasAgnosticZAInterface()) {
+        RTLIB::Libcall LC = RTLIB::SMEABI_SME_STATE_SIZE;
+        SDValue Callee = DAG.getExternalSymbol(
+            getLibcallName(LC), getPointerTy(DAG.getDataLayout()));
+        auto *RetTy = EVT(MVT::i64).getTypeForEVT(*DAG.getContext());
+        TargetLowering::CallLoweringInfo CLI(DAG);
+        CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+            getLibcallCallingConv(LC), RetTy, Callee, {});
+        std::tie(Size, Chain) = LowerCallTo(CLI);
       }
       if (Size) {
         SDValue Buffer = DAG.getNode(
@@ -8561,7 +8570,7 @@
       Register BufferPtr =
           MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
       FuncInfo->setSMESaveBufferAddr(BufferPtr);
-      Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
+      Chain = DAG.getCopyToReg(Buffer.getValue(1), DL, BufferPtr, Buffer);
     }
   }
 
@@ -9300,17 +9309,17 @@
 
   // Determine whether we need any streaming mode changes.
   SMECallAttrs CallAttrs = getSMECallAttrs(MF.getFunction(), *this, CLI);
+
+  std::optional<unsigned> ZAMarkerNode;
   bool UseNewSMEABILowering = getTM().useNewSMEABILowering();
-  bool IsAgnosticZAFunction = CallAttrs.caller().hasAgnosticZAInterface();
-  auto ZAMarkerNode = [&]() -> std::optional<unsigned> {
-    // TODO: Handle agnostic ZA functions.
-    if (!UseNewSMEABILowering || IsAgnosticZAFunction)
-      return std::nullopt;
-    if (!CallAttrs.caller().hasZAState() && !CallAttrs.caller().hasZT0State())
-      return std::nullopt;
-    return CallAttrs.requiresLazySave() ? AArch64ISD::REQUIRES_ZA_SAVE
-                                        : AArch64ISD::INOUT_ZA_USE;
-  }();
+  if (UseNewSMEABILowering) {
+    if (CallAttrs.requiresLazySave() ||
+        CallAttrs.requiresPreservingAllZAState())
+      ZAMarkerNode = AArch64ISD::REQUIRES_ZA_SAVE;
+    else if (CallAttrs.caller().hasZAState() ||
+             CallAttrs.caller().hasZT0State())
+      ZAMarkerNode = AArch64ISD::INOUT_ZA_USE;
+  }
 
   if (IsTailCall) {
     // Check if it's really possible to do a tail call.
@@ -9385,7 +9394,8 @@
   };
 
   bool RequiresLazySave = !UseNewSMEABILowering && CallAttrs.requiresLazySave();
-  bool RequiresSaveAllZA = CallAttrs.requiresPreservingAllZAState();
+  bool RequiresSaveAllZA =
+      !UseNewSMEABILowering && CallAttrs.requiresPreservingAllZAState();
   if (RequiresLazySave) {
     TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index 5dca186..993cff1 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -261,7 +261,9 @@
     EarlyAllocSMESaveBuffer = Ptr;
   }
 
-  Register getEarlyAllocSMESaveBuffer() { return EarlyAllocSMESaveBuffer; }
+  Register getEarlyAllocSMESaveBuffer() const {
+    return EarlyAllocSMESaveBuffer;
+  }
 
   // Old SME ABI lowering state getters/setters:
   Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
diff --git a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
index d95d170..c39a5cc 100644
--- a/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/MachineSMEABIPass.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This pass implements the SME ABI requirements for ZA state. This includes
-// implementing the lazy ZA state save schemes around calls.
+// implementing the lazy (and agnostic) ZA state save schemes around calls.
 //
 //===----------------------------------------------------------------------===//
 //
@@ -215,9 +215,44 @@
   void emitZAOff(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                  bool ClearTPIDR2);
 
+  // Emission routines for agnostic ZA functions.
+  void emitSetupFullZASave(MachineBasicBlock &MBB,
+                           MachineBasicBlock::iterator MBBI,
+                           LiveRegs PhysLiveRegs);
+  // Emit a "full" ZA save or restore. It is "full" in the sense that this
+  // function will emit a call to __arm_sme_save or __arm_sme_restore, which
+  // handles saving and restoring both ZA and ZT0.
+  void emitFullZASaveRestore(MachineBasicBlock &MBB,
+                             MachineBasicBlock::iterator MBBI,
+                             LiveRegs PhysLiveRegs, bool IsSave);
+  void emitAllocateFullZASaveBuffer(MachineBasicBlock &MBB,
+                                    MachineBasicBlock::iterator MBBI,
+                                    LiveRegs PhysLiveRegs);
+
   void emitStateChange(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                        ZAState From, ZAState To, LiveRegs PhysLiveRegs);
 
+  // Helpers for switching between lazy/full ZA save/restore routines.
+  void emitZASave(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+                  LiveRegs PhysLiveRegs) {
+    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
+      return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/true);
+    return emitSetupLazySave(MBB, MBBI);
+  }
+  void emitZARestore(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+                     LiveRegs PhysLiveRegs) {
+    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
+      return emitFullZASaveRestore(MBB, MBBI, PhysLiveRegs, /*IsSave=*/false);
+    return emitRestoreLazySave(MBB, MBBI, PhysLiveRegs);
+  }
+  void emitAllocateZASaveBuffer(MachineBasicBlock &MBB,
+                                MachineBasicBlock::iterator MBBI,
+                                LiveRegs PhysLiveRegs) {
+    if (AFI->getSMEFnAttrs().hasAgnosticZAInterface())
+      return emitAllocateFullZASaveBuffer(MBB, MBBI, PhysLiveRegs);
+    return emitAllocateLazySaveBuffer(MBB, MBBI);
+  }
+
   /// Save live physical registers to virtual registers.
   PhysRegSave createPhysRegSave(LiveRegs PhysLiveRegs, MachineBasicBlock &MBB,
                                 MachineBasicBlock::iterator MBBI, DebugLoc DL);
@@ -228,6 +263,8 @@
   /// Get or create a TPIDR2 block in this function.
   TPIDR2State getTPIDR2Block();
 
+  Register getAgnosticZABufferPtr();
+
 private:
   /// Contains the needed ZA state (and live registers) at an instruction.
   struct InstInfo {
@@ -241,6 +278,7 @@
   struct BlockInfo {
     ZAState FixedEntryState{ZAState::ANY};
     SmallVector<InstInfo> Insts;
+    LiveRegs PhysLiveRegsAtEntry = LiveRegs::None;
     LiveRegs PhysLiveRegsAtExit = LiveRegs::None;
   };
 
@@ -250,18 +288,22 @@
     SmallVector<ZAState> BundleStates;
     std::optional<TPIDR2State> TPIDR2Block;
     std::optional<MachineBasicBlock::iterator> AfterSMEProloguePt;
+    Register AgnosticZABufferPtr = AArch64::NoRegister;
+    LiveRegs PhysLiveRegsAfterSMEPrologue = LiveRegs::None;
   } State;
 
   MachineFunction *MF = nullptr;
   EdgeBundles *Bundles = nullptr;
   const AArch64Subtarget *Subtarget = nullptr;
   const AArch64RegisterInfo *TRI = nullptr;
+  const AArch64FunctionInfo *AFI = nullptr;
   const TargetInstrInfo *TII = nullptr;
   MachineRegisterInfo *MRI = nullptr;
 };
 
 void MachineSMEABI::collectNeededZAStates(SMEAttrs SMEFnAttrs) {
-  assert((SMEFnAttrs.hasZT0State() || SMEFnAttrs.hasZAState()) &&
+  assert((SMEFnAttrs.hasAgnosticZAInterface() || SMEFnAttrs.hasZT0State() ||
+          SMEFnAttrs.hasZAState()) &&
          "Expected function to have ZA/ZT0 state!");
 
   State.Blocks.resize(MF->getNumBlockIDs());
@@ -295,6 +337,7 @@
 
     Block.PhysLiveRegsAtExit = GetPhysLiveRegs();
     auto FirstTerminatorInsertPt = MBB.getFirstTerminator();
+    auto FirstNonPhiInsertPt = MBB.getFirstNonPHI();
     for (MachineInstr &MI : reverse(MBB)) {
       MachineBasicBlock::iterator MBBI(MI);
       LiveUnits.stepBackward(MI);
@@ -303,8 +346,11 @@
       // buffer was allocated in SelectionDAG. It marks the end of the
       // allocation -- which is a safe point for this pass to insert any TPIDR2
       // block setup.
-      if (MI.getOpcode() == AArch64::SMEStateAllocPseudo)
+      if (MI.getOpcode() == AArch64::SMEStateAllocPseudo) {
         State.AfterSMEProloguePt = MBBI;
+        State.PhysLiveRegsAfterSMEPrologue = PhysLiveRegs;
+      }
+      // Note: We treat Agnostic ZA as inout_za with an alternate save/restore.
       auto [NeededState, InsertPt] = getZAStateBeforeInst(
           *TRI, MI, /*ZAOffAtReturn=*/SMEFnAttrs.hasPrivateZAInterface());
       assert((InsertPt == MBBI ||
@@ -313,6 +359,8 @@
       // TODO: Do something to avoid state changes where NZCV is live.
       if (MBBI == FirstTerminatorInsertPt)
         Block.PhysLiveRegsAtExit = PhysLiveRegs;
+      if (MBBI == FirstNonPhiInsertPt)
+        Block.PhysLiveRegsAtEntry = PhysLiveRegs;
       if (NeededState != ZAState::ANY)
         Block.Insts.push_back({NeededState, InsertPt, PhysLiveRegs});
     }
@@ -536,8 +584,6 @@
 void MachineSMEABI::emitAllocateLazySaveBuffer(
     MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
   MachineFrameInfo &MFI = MF->getFrameInfo();
-  auto *AFI = MF->getInfo<AArch64FunctionInfo>();
-
   DebugLoc DL = getDebugLoc(MBB, MBBI);
   Register SP = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
   Register SVL = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
@@ -601,8 +647,7 @@
       .addImm(AArch64SysReg::TPIDR2_EL0);
   // If TPIDR2_EL0 is non-zero, commit the lazy save.
   // NOTE: Functions that only use ZT0 don't need to zero ZA.
-  bool ZeroZA =
-      MF->getInfo<AArch64FunctionInfo>()->getSMEFnAttrs().hasZAState();
+  bool ZeroZA = AFI->getSMEFnAttrs().hasZAState();
   auto CommitZASave =
       BuildMI(MBB, MBBI, DL, TII->get(AArch64::CommitZASavePseudo))
           .addReg(TPIDR2EL0)
@@ -617,6 +662,86 @@
       .addImm(1);
 }
 
+Register MachineSMEABI::getAgnosticZABufferPtr() {
+  if (State.AgnosticZABufferPtr != AArch64::NoRegister)
+    return State.AgnosticZABufferPtr;
+  Register BufferPtr = AFI->getEarlyAllocSMESaveBuffer();
+  State.AgnosticZABufferPtr =
+      BufferPtr != AArch64::NoRegister
+          ? BufferPtr
+          : MF->getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+  return State.AgnosticZABufferPtr;
+}
+
+void MachineSMEABI::emitFullZASaveRestore(MachineBasicBlock &MBB,
+                                          MachineBasicBlock::iterator MBBI,
+                                          LiveRegs PhysLiveRegs, bool IsSave) {
+  auto *TLI = Subtarget->getTargetLowering();
+  DebugLoc DL = getDebugLoc(MBB, MBBI);
+  Register BufferPtr = AArch64::X0;
+
+  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
+
+  // Copy the buffer pointer into X0.
+  BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
+      .addReg(getAgnosticZABufferPtr());
+
+  // Call __arm_sme_save/__arm_sme_restore.
+  BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
+      .addReg(BufferPtr, RegState::Implicit)
+      .addExternalSymbol(TLI->getLibcallName(
+          IsSave ? RTLIB::SMEABI_SME_SAVE : RTLIB::SMEABI_SME_RESTORE))
+      .addRegMask(TRI->getCallPreservedMask(
+          *MF,
+          CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+
+  restorePhyRegSave(RegSave, MBB, MBBI, DL);
+}
+
+void MachineSMEABI::emitAllocateFullZASaveBuffer(
+    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
+    LiveRegs PhysLiveRegs) {
+  // Buffer already allocated in SelectionDAG.
+  if (AFI->getEarlyAllocSMESaveBuffer())
+    return;
+
+  DebugLoc DL = getDebugLoc(MBB, MBBI);
+  Register BufferPtr = getAgnosticZABufferPtr();
+  Register BufferSize = MRI->createVirtualRegister(&AArch64::GPR64RegClass);
+
+  PhysRegSave RegSave = createPhysRegSave(PhysLiveRegs, MBB, MBBI, DL);
+
+  // Calculate the SME state size.
+  {
+    auto *TLI = Subtarget->getTargetLowering();
+    const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+    BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL))
+        .addExternalSymbol(TLI->getLibcallName(RTLIB::SMEABI_SME_STATE_SIZE))
+        .addReg(AArch64::X0, RegState::ImplicitDefine)
+        .addRegMask(TRI->getCallPreservedMask(
+            *MF, CallingConv::
+                     AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferSize)
+        .addReg(AArch64::X0);
+  }
+
+  // Allocate a buffer object of the size given by __arm_sme_state_size.
+  {
+    MachineFrameInfo &MFI = MF->getFrameInfo();
+    BuildMI(MBB, MBBI, DL, TII->get(AArch64::SUBXrx64), AArch64::SP)
+        .addReg(AArch64::SP)
+        .addReg(BufferSize)
+        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
+    BuildMI(MBB, MBBI, DL, TII->get(TargetOpcode::COPY), BufferPtr)
+        .addReg(AArch64::SP);
+
+    // We have just allocated a variable sized object, tell this to PEI.
+    MFI.CreateVariableSizedObject(Align(16), nullptr);
+  }
+
+  restorePhyRegSave(RegSave, MBB, MBBI, DL);
+}
+
 void MachineSMEABI::emitStateChange(MachineBasicBlock &MBB,
                                     MachineBasicBlock::iterator InsertPt,
                                     ZAState From, ZAState To,
@@ -634,10 +759,7 @@
   // TODO: Avoid setting up the save buffer if there's no transition to
   // LOCAL_SAVED.
   if (From == ZAState::CALLER_DORMANT) {
-    assert(MBB.getParent()
-               ->getInfo<AArch64FunctionInfo>()
-               ->getSMEFnAttrs()
-               .hasPrivateZAInterface() &&
+    assert(AFI->getSMEFnAttrs().hasPrivateZAInterface() &&
            "CALLER_DORMANT state requires private ZA interface");
     assert(&MBB == &MBB.getParent()->front() &&
            "CALLER_DORMANT state only valid in entry block");
@@ -652,12 +774,14 @@
   }
 
   if (From == ZAState::ACTIVE && To == ZAState::LOCAL_SAVED)
-    emitSetupLazySave(MBB, InsertPt);
+    emitZASave(MBB, InsertPt, PhysLiveRegs);
   else if (From == ZAState::LOCAL_SAVED && To == ZAState::ACTIVE)
-    emitRestoreLazySave(MBB, InsertPt, PhysLiveRegs);
+    emitZARestore(MBB, InsertPt, PhysLiveRegs);
   else if (To == ZAState::OFF) {
     assert(From != ZAState::CALLER_DORMANT &&
            "CALLER_DORMANT to OFF should have already been handled");
+    assert(!AFI->getSMEFnAttrs().hasAgnosticZAInterface() &&
+           "Should not turn ZA off in agnostic ZA function");
     emitZAOff(MBB, InsertPt, /*ClearTPIDR2=*/From == ZAState::LOCAL_SAVED);
   } else {
     dbgs() << "Error: Transition from " << getZAStateString(From) << " to "
@@ -675,9 +799,10 @@
   if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
     return false;
 
-  auto *AFI = MF.getInfo<AArch64FunctionInfo>();
+  AFI = MF.getInfo<AArch64FunctionInfo>();
   SMEAttrs SMEFnAttrs = AFI->getSMEFnAttrs();
-  if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State())
+  if (!SMEFnAttrs.hasZAState() && !SMEFnAttrs.hasZT0State() &&
+      !SMEFnAttrs.hasAgnosticZAInterface())
     return false;
 
   assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
@@ -696,15 +821,18 @@
   insertStateChanges();
 
   // Allocate save buffer (if needed).
-  if (State.TPIDR2Block) {
+  if (State.AgnosticZABufferPtr != AArch64::NoRegister || State.TPIDR2Block) {
     if (State.AfterSMEProloguePt) {
       // Note: With inline stack probes the AfterSMEProloguePt may not be in the
       // entry block (due to the probing loop).
-      emitAllocateLazySaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
-                                 *State.AfterSMEProloguePt);
+      emitAllocateZASaveBuffer(*(*State.AfterSMEProloguePt)->getParent(),
+                               *State.AfterSMEProloguePt,
+                               State.PhysLiveRegsAfterSMEPrologue);
     } else {
       MachineBasicBlock &EntryBlock = MF.front();
-      emitAllocateLazySaveBuffer(EntryBlock, EntryBlock.getFirstNonPHI());
+      emitAllocateZASaveBuffer(
+          EntryBlock, EntryBlock.getFirstNonPHI(),
+          State.Blocks[EntryBlock.getNumber()].PhysLiveRegsAtEntry);
     }
   }
 
diff --git a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
index b31ae68e..a0a14f2 100644
--- a/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
+++ b/llvm/test/CodeGen/AArch64/sme-agnostic-za.ll
@@ -1,18 +1,19 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mattr=+sme2 < %s | FileCheck %s
-; RUN: llc -mattr=+sme2 < %s -aarch64-new-sme-abi | FileCheck %s
+; RUN: llc -mattr=+sme2 < %s | FileCheck %s --check-prefixes=CHECK-COMMON,CHECK
+; RUN: llc -mattr=+sme2 < %s -aarch64-new-sme-abi | FileCheck %s --check-prefixes=CHECK-COMMON,CHECK-NEWLOWERING
 
 target triple = "aarch64"
 
 declare i64 @private_za_decl(i64)
+declare void @private_za()
 declare i64 @agnostic_decl(i64) "aarch64_za_state_agnostic"
 
 ; No calls. Test that no buffer is allocated.
 define i64 @agnostic_caller_no_callees(ptr %ptr) nounwind "aarch64_za_state_agnostic" {
-; CHECK-LABEL: agnostic_caller_no_callees:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldr x0, [x0]
-; CHECK-NEXT:    ret
+; CHECK-COMMON-LABEL: agnostic_caller_no_callees:
+; CHECK-COMMON:       // %bb.0:
+; CHECK-COMMON-NEXT:    ldr x0, [x0]
+; CHECK-COMMON-NEXT:    ret
   %v = load i64, ptr %ptr
   ret i64 %v
 }
@@ -51,6 +52,29 @@
 ; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: agnostic_caller_private_za_callee:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    mov x29, sp
+; CHECK-NEWLOWERING-NEXT:    mov x8, x0
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_state_size
+; CHECK-NEWLOWERING-NEXT:    sub sp, sp, x0
+; CHECK-NEWLOWERING-NEXT:    mov x19, sp
+; CHECK-NEWLOWERING-NEXT:    mov x0, x19
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_save
+; CHECK-NEWLOWERING-NEXT:    mov x0, x8
+; CHECK-NEWLOWERING-NEXT:    bl private_za_decl
+; CHECK-NEWLOWERING-NEXT:    bl private_za_decl
+; CHECK-NEWLOWERING-NEXT:    mov x8, x0
+; CHECK-NEWLOWERING-NEXT:    mov x0, x19
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_restore
+; CHECK-NEWLOWERING-NEXT:    mov x0, x8
+; CHECK-NEWLOWERING-NEXT:    mov sp, x29
+; CHECK-NEWLOWERING-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ret
   %res = call i64 @private_za_decl(i64 %v)
   %res2 = call i64 @private_za_decl(i64 %res)
   ret i64 %res2
@@ -60,12 +84,12 @@
 ;
 ; Should not result in save/restore code.
 define i64 @agnostic_caller_agnostic_callee(i64 %v) nounwind "aarch64_za_state_agnostic" {
-; CHECK-LABEL: agnostic_caller_agnostic_callee:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    bl agnostic_decl
-; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
-; CHECK-NEXT:    ret
+; CHECK-COMMON-LABEL: agnostic_caller_agnostic_callee:
+; CHECK-COMMON:       // %bb.0:
+; CHECK-COMMON-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-COMMON-NEXT:    bl agnostic_decl
+; CHECK-COMMON-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-COMMON-NEXT:    ret
   %res = call i64 @agnostic_decl(i64 %v)
   ret i64 %res
 }
@@ -74,12 +98,12 @@
 ;
 ; Should not result in lazy-save or save of ZT0
 define i64 @shared_caller_agnostic_callee(i64 %v) nounwind "aarch64_inout_za" "aarch64_inout_zt0" {
-; CHECK-LABEL: shared_caller_agnostic_callee:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
-; CHECK-NEXT:    bl agnostic_decl
-; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
-; CHECK-NEXT:    ret
+; CHECK-COMMON-LABEL: shared_caller_agnostic_callee:
+; CHECK-COMMON:       // %bb.0:
+; CHECK-COMMON-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-COMMON-NEXT:    bl agnostic_decl
+; CHECK-COMMON-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-COMMON-NEXT:    ret
   %res = call i64 @agnostic_decl(i64 %v)
   ret i64 %res
 }
@@ -126,6 +150,41 @@
 ; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: streaming_agnostic_caller_nonstreaming_private_za_callee:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    mov x8, x0
+; CHECK-NEWLOWERING-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    add x29, sp, #64
+; CHECK-NEWLOWERING-NEXT:    stp x20, x19, [sp, #80] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_state_size
+; CHECK-NEWLOWERING-NEXT:    sub sp, sp, x0
+; CHECK-NEWLOWERING-NEXT:    mov x20, sp
+; CHECK-NEWLOWERING-NEXT:    mov x0, x20
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_save
+; CHECK-NEWLOWERING-NEXT:    smstop sm
+; CHECK-NEWLOWERING-NEXT:    mov x0, x8
+; CHECK-NEWLOWERING-NEXT:    bl private_za_decl
+; CHECK-NEWLOWERING-NEXT:    smstart sm
+; CHECK-NEWLOWERING-NEXT:    smstop sm
+; CHECK-NEWLOWERING-NEXT:    bl private_za_decl
+; CHECK-NEWLOWERING-NEXT:    smstart sm
+; CHECK-NEWLOWERING-NEXT:    mov x8, x0
+; CHECK-NEWLOWERING-NEXT:    mov x0, x20
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_restore
+; CHECK-NEWLOWERING-NEXT:    mov x0, x8
+; CHECK-NEWLOWERING-NEXT:    sub sp, x29, #64
+; CHECK-NEWLOWERING-NEXT:    ldp x20, x19, [sp, #80] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp x29, x30, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ret
   %res = call i64 @private_za_decl(i64 %v)
   %res2 = call i64 @private_za_decl(i64 %res)
   ret i64 %res2
@@ -186,6 +245,54 @@
 ; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
 ; CHECK-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: streaming_compatible_agnostic_caller_nonstreaming_private_za_callee:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    mov x8, x0
+; CHECK-NEWLOWERING-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    stp x29, x30, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    add x29, sp, #64
+; CHECK-NEWLOWERING-NEXT:    stp x20, x19, [sp, #80] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_state_size
+; CHECK-NEWLOWERING-NEXT:    sub sp, sp, x0
+; CHECK-NEWLOWERING-NEXT:    mov x19, sp
+; CHECK-NEWLOWERING-NEXT:    mrs x20, SVCR
+; CHECK-NEWLOWERING-NEXT:    mov x0, x19
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_save
+; CHECK-NEWLOWERING-NEXT:    tbz w20, #0, .LBB5_2
+; CHECK-NEWLOWERING-NEXT:  // %bb.1:
+; CHECK-NEWLOWERING-NEXT:    smstop sm
+; CHECK-NEWLOWERING-NEXT:  .LBB5_2:
+; CHECK-NEWLOWERING-NEXT:    mov x0, x8
+; CHECK-NEWLOWERING-NEXT:    bl private_za_decl
+; CHECK-NEWLOWERING-NEXT:    tbz w20, #0, .LBB5_4
+; CHECK-NEWLOWERING-NEXT:  // %bb.3:
+; CHECK-NEWLOWERING-NEXT:    smstart sm
+; CHECK-NEWLOWERING-NEXT:  .LBB5_4:
+; CHECK-NEWLOWERING-NEXT:    tbz w20, #0, .LBB5_6
+; CHECK-NEWLOWERING-NEXT:  // %bb.5:
+; CHECK-NEWLOWERING-NEXT:    smstop sm
+; CHECK-NEWLOWERING-NEXT:  .LBB5_6:
+; CHECK-NEWLOWERING-NEXT:    bl private_za_decl
+; CHECK-NEWLOWERING-NEXT:    tbz w20, #0, .LBB5_8
+; CHECK-NEWLOWERING-NEXT:  // %bb.7:
+; CHECK-NEWLOWERING-NEXT:    smstart sm
+; CHECK-NEWLOWERING-NEXT:  .LBB5_8:
+; CHECK-NEWLOWERING-NEXT:    mov x8, x0
+; CHECK-NEWLOWERING-NEXT:    mov x0, x19
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_restore
+; CHECK-NEWLOWERING-NEXT:    mov x0, x8
+; CHECK-NEWLOWERING-NEXT:    sub sp, x29, #64
+; CHECK-NEWLOWERING-NEXT:    ldp x20, x19, [sp, #80] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp x29, x30, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ret
   %res = call i64 @private_za_decl(i64 %v)
   %res2 = call i64 @private_za_decl(i64 %res)
   ret i64 %res2
@@ -222,9 +329,99 @@
 ; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
 ; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
 ; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: test_many_callee_arguments:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    mov x29, sp
+; CHECK-NEWLOWERING-NEXT:    mov x8, x0
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_state_size
+; CHECK-NEWLOWERING-NEXT:    sub sp, sp, x0
+; CHECK-NEWLOWERING-NEXT:    mov x19, sp
+; CHECK-NEWLOWERING-NEXT:    ldp x9, x10, [x29, #32]
+; CHECK-NEWLOWERING-NEXT:    mov x0, x19
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_save
+; CHECK-NEWLOWERING-NEXT:    stp x9, x10, [sp, #-16]!
+; CHECK-NEWLOWERING-NEXT:    mov x0, x8
+; CHECK-NEWLOWERING-NEXT:    bl many_args_private_za_callee
+; CHECK-NEWLOWERING-NEXT:    add sp, sp, #16
+; CHECK-NEWLOWERING-NEXT:    mov x8, x0
+; CHECK-NEWLOWERING-NEXT:    mov x0, x19
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_restore
+; CHECK-NEWLOWERING-NEXT:    mov x0, x8
+; CHECK-NEWLOWERING-NEXT:    mov sp, x29
+; CHECK-NEWLOWERING-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ret
   i64 %0, i64 %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i64 %7, i64 %8, i64 %9
 ) nounwind "aarch64_za_state_agnostic" {
   %ret = call i64 @many_args_private_za_callee(
     i64 %0, i64 %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i64 %7, i64 %8, i64 %9)
   ret i64 %ret
 }
+
+; FIXME: The new lowering should avoid saves/restores in the probing loop.
+define void @agnostic_za_buffer_alloc_with_stack_probes() nounwind "aarch64_za_state_agnostic" "probe-stack"="inline-asm" "stack-probe-size"="65536"{
+; CHECK-LABEL: agnostic_za_buffer_alloc_with_stack_probes:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEXT:    mov x29, sp
+; CHECK-NEXT:    bl __arm_sme_state_size
+; CHECK-NEXT:    mov x8, sp
+; CHECK-NEXT:    sub x19, x8, x0
+; CHECK-NEXT:  .LBB7_1: // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    sub sp, sp, #16, lsl #12 // =65536
+; CHECK-NEXT:    cmp sp, x19
+; CHECK-NEXT:    b.le .LBB7_3
+; CHECK-NEXT:  // %bb.2: // in Loop: Header=BB7_1 Depth=1
+; CHECK-NEXT:    str xzr, [sp]
+; CHECK-NEXT:    b .LBB7_1
+; CHECK-NEXT:  .LBB7_3:
+; CHECK-NEXT:    mov sp, x19
+; CHECK-NEXT:    ldr xzr, [sp]
+; CHECK-NEXT:    mov x0, x19
+; CHECK-NEXT:    bl __arm_sme_save
+; CHECK-NEXT:    bl private_za
+; CHECK-NEXT:    mov x0, x19
+; CHECK-NEXT:    bl __arm_sme_restore
+; CHECK-NEXT:    mov sp, x29
+; CHECK-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-NEWLOWERING-LABEL: agnostic_za_buffer_alloc_with_stack_probes:
+; CHECK-NEWLOWERING:       // %bb.0:
+; CHECK-NEWLOWERING-NEXT:    stp x29, x30, [sp, #-32]! // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    str x19, [sp, #16] // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT:    mov x29, sp
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_state_size
+; CHECK-NEWLOWERING-NEXT:    mov x8, sp
+; CHECK-NEWLOWERING-NEXT:    sub x19, x8, x0
+; CHECK-NEWLOWERING-NEXT:  .LBB7_1: // =>This Inner Loop Header: Depth=1
+; CHECK-NEWLOWERING-NEXT:    sub sp, sp, #16, lsl #12 // =65536
+; CHECK-NEWLOWERING-NEXT:    cmp sp, x19
+; CHECK-NEWLOWERING-NEXT:    mov x0, x19
+; CHECK-NEWLOWERING-NEXT:    mrs x8, NZCV
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_save
+; CHECK-NEWLOWERING-NEXT:    msr NZCV, x8
+; CHECK-NEWLOWERING-NEXT:    b.le .LBB7_3
+; CHECK-NEWLOWERING-NEXT:  // %bb.2: // in Loop: Header=BB7_1 Depth=1
+; CHECK-NEWLOWERING-NEXT:    mov x0, x19
+; CHECK-NEWLOWERING-NEXT:    str xzr, [sp]
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_restore
+; CHECK-NEWLOWERING-NEXT:    b .LBB7_1
+; CHECK-NEWLOWERING-NEXT:  .LBB7_3:
+; CHECK-NEWLOWERING-NEXT:    mov sp, x19
+; CHECK-NEWLOWERING-NEXT:    ldr xzr, [sp]
+; CHECK-NEWLOWERING-NEXT:    bl private_za
+; CHECK-NEWLOWERING-NEXT:    mov x0, x19
+; CHECK-NEWLOWERING-NEXT:    bl __arm_sme_restore
+; CHECK-NEWLOWERING-NEXT:    mov sp, x29
+; CHECK-NEWLOWERING-NEXT:    ldr x19, [sp, #16] // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ldp x29, x30, [sp], #32 // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    ret
+  call void @private_za()
+  ret void
+}