[MLIR] Removing dead values for branches (#117501)

Fixing RemoveDeadValues to properly remove arguments from
BranchOpInterface operations.
This is a follow-up for:
https://github.com/llvm/llvm-project/pull/117405
cc: @joker-eph @codemzs

---------

Co-authored-by: Renat Idrisov <parsifal-47@users.noreply.github.com>
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index dbce4a54..3429008 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -172,7 +172,7 @@
 /// iff it has no memory effects and none of its results are live.
 ///
 /// It is assumed that `op` is simple. Here, a simple op is one which isn't a
-/// symbol op, a symbol-user op, a region branch op, a branch op, a region
+/// function-like op, a call-like op, a region branch op, a branch op, a region
 /// branch terminator op, or return-like.
 static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
   if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la))
@@ -563,6 +563,51 @@
   dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
 }
 
+// 1. Iterate over each successor block of the given BranchOpInterface
+//    operation.
+// 2. For each successor block:
+//    a. Retrieve the operands passed to the successor.
+//    b. Use the provided liveness analysis (`RunLivenessAnalysis`) to determine
+//       which operands are live in the successor block.
+//    c. Mark each operand as live or dead based on the analysis.
+// 3. Remove dead operands from the branch operation and arguments accordingly
+
+static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) {
+  unsigned numSuccessors = branchOp->getNumSuccessors();
+
+  // Do (1)
+  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
+    Block *successorBlock = branchOp->getSuccessor(succIdx);
+
+    // Do (2)
+    SuccessorOperands successorOperands =
+        branchOp.getSuccessorOperands(succIdx);
+    SmallVector<Value> operandValues;
+    for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
+         ++operandIdx) {
+      operandValues.push_back(successorOperands[operandIdx]);
+    }
+
+    BitVector successorLiveOperands = markLives(operandValues, la);
+
+    // Do (3)
+    for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) {
+      if (!successorLiveOperands[argIdx]) {
+        if (successorBlock->getNumArguments() < successorOperands.size()) {
+          // if block was cleaned through a different code path
+          // we only need to remove operands from the invokation
+          successorOperands.erase(argIdx);
+          continue;
+        }
+
+        successorBlock->getArgument(argIdx).dropAllUses();
+        successorOperands.erase(argIdx);
+        successorBlock->eraseArgument(argIdx);
+      }
+    }
+  }
+}
+
 struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
   void runOnOperation() override;
 };
@@ -572,26 +617,13 @@
   auto &la = getAnalysis<RunLivenessAnalysis>();
   Operation *module = getOperation();
 
-  // The removal of non-live values is performed iff there are no branch ops,
-  // and all symbol user ops present in the IR are call-like.
-  WalkResult acceptableIR = module->walk([&](Operation *op) {
-    if (op == module)
-      return WalkResult::advance();
-    if (isa<BranchOpInterface>(op)) {
-      op->emitError() << "cannot optimize an IR with branch ops\n";
-      return WalkResult::interrupt();
-    }
-    return WalkResult::advance();
-  });
-
-  if (acceptableIR.wasInterrupted())
-    return signalPassFailure();
-
   module->walk([&](Operation *op) {
     if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
       cleanFuncOp(funcOp, module, la);
     } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
       cleanRegionBranchOp(regionBranchOp, la);
+    } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+      cleanBranchOp(branchOp, la);
     } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
       // Nothing to do here because this is a terminator op and it should be
       // honored with respect to its parent
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 5387552..9273ac0 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -28,22 +28,51 @@
 
 // -----
 
-// The IR remains untouched because of the presence of a branch op `cf.cond_br`.
+// The IR contains both conditional and unconditional branches with a loop
+// in which the last cf.cond_br is referncing the first cf.br
 //
-func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
+func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: i1) {
   %non_live = arith.constant 0 : i32
-  // expected-error @+1 {{cannot optimize an IR with branch ops}}
-  cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)
-^bb1(%non_live_0 : i32):
-  cf.br ^bb3
-^bb2(%non_live_1 : i32):
-  cf.br ^bb3
-^bb3:
+  // CHECK-NOT: arith.constant
+  cf.br ^bb1(%non_live : i32)
+  // CHECK: cf.br ^[[BB1:bb[0-9]+]]
+^bb1(%non_live_1 : i32):
+  // CHECK: ^[[BB1]]:
+  %non_live_5 = arith.constant 1 : i32
+  cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32)
+  // CHECK: cf.br ^[[BB3:bb[0-9]+]]
+  // CHECK-NOT: i32
+^bb3(%non_live_2 : i32, %non_live_6 : i32):
+  // CHECK: ^[[BB3]]:
+  cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32)
+  // CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]]
+^bb4(%non_live_4 : i32):
+  // CHECK: ^[[BB4]]:
   return
 }
 
 // -----
 
+// Checking that iter_args are properly handled
+//
+func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %non_live = arith.constant 0 : index
+  // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
+  %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
+    // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
+    %new_live = arith.addi %live_arg, %i : index
+    // CHECK: scf.yield [[SUM:%.+]]
+    scf.yield %new_live, %non_live_arg : index, index
+  }
+  // CHECK: return [[RESULT]] : index
+  return %result : index
+}
+
+// -----
+
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
 // CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {