Fix Bug in RemoveDeadValues Pass (#148437)
This patch fixes a bug in the RemoveDeadValues pass where unused
function arguments were not removed from the function signature in an
edge case where the function returns void.
A corresponding test was added to the MLIR LIT test suite to cover this
case.
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 608bdcb..ddd5f2b 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -345,8 +345,6 @@
// since it forwards only to non-live value(s) (%1#1).
Operation *lastReturnOp = funcOp.back().getTerminator();
size_t numReturns = lastReturnOp->getNumOperands();
- if (numReturns == 0)
- return;
BitVector nonLiveRets(numReturns, true);
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
@@ -368,6 +366,8 @@
cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
// Do (5) and (6).
+ if (numReturns == 0)
+ return;
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 3af95db..9ded6a3 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -548,3 +548,26 @@
func.return
}
+// -----
+
+// CHECK-LABEL: module @return_void_with_unused_argument
+module @return_void_with_unused_argument {
+ // CHECK-LABEL: func.func private @fn_return_void_with_unused_argument
+ // CHECK-SAME: (%[[ARG0_FN:.*]]: i32)
+ func.func private @fn_return_void_with_unused_argument(%arg0: i32, %arg1: memref<4xi32>) -> () {
+ %sum = arith.addi %arg0, %arg0 : i32
+ %c0 = arith.constant 0 : index
+ %buf = memref.alloc() : memref<1xi32>
+ memref.store %sum, %buf[%c0] : memref<1xi32>
+ return
+ }
+ // CHECK-LABEL: func.func @main
+ // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32)
+ // CHECK: call @fn_return_void_with_unused_argument(%[[ARG0_MAIN]]) : (i32) -> ()
+ func.func @main(%arg0: i32) -> memref<4xi32> {
+ %unused = memref.alloc() : memref<4xi32>
+ call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
+ return %unused : memref<4xi32>
+ }
+}
+