[ARM] Revert WhileLoopStartLR to DoLoopStart

If a WhileLoopStartLR is reverted due to calls in the preheader, we may
still be able to instead create a DoLoopStart, preserving the low
overhead loop. This adds code for that, only reverting the
WhileLoopStartR to a Br/Cmp, leaving the rest of the low overhead loop
in place.

Differential Revision: https://reviews.llvm.org/D98413

GitOrigin-RevId: d97189600e26553fa4fcdc73bd66b22c0ea420dd
diff --git a/lib/Target/ARM/MVETPAndVPTOptimisationsPass.cpp b/lib/Target/ARM/MVETPAndVPTOptimisationsPass.cpp
index f21ea27..3fd404e 100644
--- a/lib/Target/ARM/MVETPAndVPTOptimisationsPass.cpp
+++ b/lib/Target/ARM/MVETPAndVPTOptimisationsPass.cpp
@@ -273,11 +273,28 @@
 
   // Check if there is an illegal instruction (a call) in the low overhead loop
   // and if so revert it now before we get any further. While loops also need to
-  // check the preheaders.
-  SmallPtrSet<MachineBasicBlock *, 4> MBBs(ML->block_begin(), ML->block_end());
-  if (LoopStart->getOpcode() == ARM::t2WhileLoopStartLR)
-    MBBs.insert(ML->getHeader()->pred_begin(), ML->getHeader()->pred_end());
-  for (MachineBasicBlock *MBB : MBBs) {
+  // check the preheaders, but can be reverted to a DLS loop if needed.
+  auto *PreHeader = ML->getLoopPreheader();
+  if (LoopStart->getOpcode() == ARM::t2WhileLoopStartLR && PreHeader &&
+      LoopStart->getParent() != PreHeader) {
+    for (MachineInstr &MI : *PreHeader) {
+      if (MI.isCall()) {
+        // Create a t2DoLoopStart at the end of the preheader.
+        MachineInstrBuilder MIB =
+            BuildMI(*PreHeader, PreHeader->getFirstTerminator(),
+                    LoopStart->getDebugLoc(), TII->get(ARM::t2DoLoopStart));
+        MIB.add(LoopStart->getOperand(0));
+        MIB.add(LoopStart->getOperand(1));
+
+        // Revert the t2WhileLoopStartLR to a CMP and Br.
+        RevertWhileLoopStartLR(LoopStart, TII, ARM::t2Bcc, true);
+        LoopStart = MIB;
+        break;
+      }
+    }
+  }
+
+  for (MachineBasicBlock *MBB : ML->blocks()) {
     for (MachineInstr &MI : *MBB) {
       if (MI.isCall()) {
         LLVM_DEBUG(dbgs() << "Found call in loop, reverting: " << MI);
diff --git a/lib/Target/ARM/MVETailPredUtils.h b/lib/Target/ARM/MVETailPredUtils.h
index 4798c66..b0c0031 100644
--- a/lib/Target/ARM/MVETailPredUtils.h
+++ b/lib/Target/ARM/MVETailPredUtils.h
@@ -77,24 +77,38 @@
 
 // WhileLoopStart holds the exit block, so produce a subs Op0, Op1, 0 and then a
 // beq that branches to the exit branch.
+// If UseCmp is true, this will create a t2CMP instead of a t2SUBri, meaning the
+// value of LR into the loop will not be setup. This is used if the LR setup is
+// done via another means (via a t2DoLoopStart, for example).
 inline void RevertWhileLoopStartLR(MachineInstr *MI, const TargetInstrInfo *TII,
-                                   unsigned BrOpc = ARM::t2Bcc) {
+                                   unsigned BrOpc = ARM::t2Bcc,
+                                   bool UseCmp = false) {
   MachineBasicBlock *MBB = MI->getParent();
   assert(MI->getOpcode() == ARM::t2WhileLoopStartLR &&
          "Only expected a t2WhileLoopStartLR in RevertWhileLoopStartLR!");
 
-  // Subs
-  MachineInstrBuilder MIB =
-      BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::t2SUBri));
-  MIB.add(MI->getOperand(0));
-  MIB.add(MI->getOperand(1));
-  MIB.addImm(0);
-  MIB.addImm(ARMCC::AL);
-  MIB.addReg(ARM::NoRegister);
-  MIB.addReg(ARM::CPSR, RegState::Define);
+  // Subs/Cmp
+  if (UseCmp) {
+    MachineInstrBuilder MIB =
+        BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::t2CMPri));
+    MIB.add(MI->getOperand(1));
+    MIB.addImm(0);
+    MIB.addImm(ARMCC::AL);
+    MIB.addReg(ARM::NoRegister);
+  } else {
+    MachineInstrBuilder MIB =
+        BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(ARM::t2SUBri));
+    MIB.add(MI->getOperand(0));
+    MIB.add(MI->getOperand(1));
+    MIB.addImm(0);
+    MIB.addImm(ARMCC::AL);
+    MIB.addReg(ARM::NoRegister);
+    MIB.addReg(ARM::CPSR, RegState::Define);
+  }
 
   // Branch
-  MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
+  MachineInstrBuilder MIB =
+      BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
   MIB.add(MI->getOperand(2)); // branch target
   MIB.addImm(ARMCC::EQ);      // condition code
   MIB.addReg(ARM::CPSR);
diff --git a/test/CodeGen/Thumb2/LowOverheadLoops/while-loops.ll b/test/CodeGen/Thumb2/LowOverheadLoops/while-loops.ll
index 52690fc..cfa5a34 100644
--- a/test/CodeGen/Thumb2/LowOverheadLoops/while-loops.ll
+++ b/test/CodeGen/Thumb2/LowOverheadLoops/while-loops.ll
@@ -322,21 +322,20 @@
 ; CHECK:       @ %bb.0: @ %entry
 ; CHECK-NEXT:    .save {r4, r5, r6, lr}
 ; CHECK-NEXT:    push {r4, r5, r6, lr}
-; CHECK-NEXT:    subs r6, r2, #0
 ; CHECK-NEXT:    mov r5, r0
 ; CHECK-NEXT:    mov r4, r1
-; CHECK-NEXT:    mov.w r0, #0
-; CHECK-NEXT:    beq .LBB3_3
+; CHECK-NEXT:    movs r0, #0
+; CHECK-NEXT:    cbz r2, .LBB3_3
 ; CHECK-NEXT:  @ %bb.1: @ %for.body.ph
+; CHECK-NEXT:    mov r6, r2
 ; CHECK-NEXT:    bl callee
+; CHECK-NEXT:    dls lr, r6
 ; CHECK-NEXT:    movs r0, #0
 ; CHECK-NEXT:  .LBB3_2: @ %for.body
 ; CHECK-NEXT:    @ =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ldr r1, [r5], #4
-; CHECK-NEXT:    subs r6, #1
 ; CHECK-NEXT:    add r0, r1
-; CHECK-NEXT:    cbz r6, .LBB3_3
-; CHECK-NEXT:    le .LBB3_2
+; CHECK-NEXT:    le lr, .LBB3_2
 ; CHECK-NEXT:  .LBB3_3: @ %for.cond.cleanup
 ; CHECK-NEXT:    str r0, [r4]
 ; CHECK-NEXT:    pop {r4, r5, r6, pc}