[mlir][utils] Update generate-test-checks.py (use SSA names) (#136819)
This patch updates generate-test-checks.py to preserve original SSA
names (capitalized) when generating LIT variable names for function
arguments (i.e. for `CHECK-SAME` lines). This improves readability and
helps maintain consistency between the input MLIR and the expected
FileCheck/LIT output.
For example, given the following function:
```mlir
func.func @conv1d_nwc_4x2x8_memref(
%input: memref<4x6x3xf32>,
%filter: memref<1x3x8xf32>,
%output: memref<4x2x8xf32>) {
linalg.conv_1d_nwc_wcf
{dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
ins(%input, %filter : memref<4x6x3xf32>, memref<1x3x8xf32>)
outs(%output : memref<4x2x8xf32>)
return
}
```
The generated output becomes:
```mlir
// CHECK-LABEL: func.func @conv1d_nwc_4x2x8_memref(
// CHECK-SAME: %[[INPUT:.*]]: memref<4x6x3xf32>,
// CHECK-SAME: %[[FILTER:.*]]: memref<1x3x8xf32>,
// CHECK-SAME: %[[OUTPUT:.*]]: memref<4x2x8xf32>) {
// CHECK: linalg.conv_1d_nwc_wcf
// CHECK: {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
// CHECK: ins(%[[INPUT]], %[[FILTER]] : memref<4x6x3xf32>, memref<1x3x8xf32>)
// CHECK: outs(%[[OUTPUT]] : memref<4x2x8xf32>)
// CHECK: return
// CHECK: }
```
By contrast, before this patch the script would generate:
```mlir
// CHECK-LABEL: func.func @conv1d_nwc_4x2x8_memref(
// CHECK-SAME: %[[VAL_0:.*]]: memref<4x6x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<1x3x8xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<4x2x8xf32>) {
// CHECK: linalg.conv_1d_nwc_wcf
// CHECK: {dilations = dense<1> : tensor<1xi64>, strides = dense<3> : tensor<1xi64>}
// CHECK: ins(%[[VAL_0]], %[[VAL_1]] : memref<4x6x3xf32>, memref<1x3x8xf32>)
// CHECK: outs(%[[VAL_2]] : memref<4x2x8xf32>)
// CHECK: return
// CHECK: }
```
diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py
index d157af9..11fb4e4 100755
--- a/mlir/utils/generate-test-checks.py
+++ b/mlir/utils/generate-test-checks.py
@@ -77,13 +77,20 @@
self.generate_in_parent_scope_left = n
# Generate a substitution name for the given ssa value name.
- def generate_name(self, source_variable_name):
+ def generate_name(self, source_variable_name, use_ssa_name):
# Compute variable name
variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
if variable_name == '':
- variable_name = "VAL_" + str(self.name_counter)
- self.name_counter += 1
+ # If `use_ssa_name` is set, use the MLIR SSA value name to generate
+ # a FileCheck substitution string. As FileCheck requires these
+ # strings to start with a letter, skip MLIR variables starting
+ # with a digit (e.g. `%0`).
+ if use_ssa_name and source_variable_name[0].isalpha():
+ variable_name = source_variable_name.upper()
+ else:
+ variable_name = "VAL_" + str(self.name_counter)
+ self.name_counter += 1
# Scope where variable name is saved
scope = len(self.scopes) - 1
@@ -158,7 +165,7 @@
# Process a line of input that has been split at each SSA identifier '%'.
-def process_line(line_chunks, variable_namer, strict_name_re=False):
+def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re=False):
output_line = ""
# Process the rest that contained an SSA value name.
@@ -178,7 +185,7 @@
output_line += "%[[" + variable + "]]"
else:
# Otherwise, generate a new variable.
- variable = variable_namer.generate_name(ssa_name)
+ variable = variable_namer.generate_name(ssa_name, use_ssa_name)
if strict_name_re:
# Use stricter regexp for the variable name, if requested.
# Greedy matching may cause issues with the generic '.*'
@@ -415,9 +422,11 @@
pad_depth = label_length if label_length < 21 else 4
output_line += " " * pad_depth
- # Process the rest of the line.
+ # Process the rest of the line. Use the original SSA name to generate the LIT
+ # variable names.
+ use_ssa_names = True
output_line += process_line(
- [argument], variable_namer, args.strict_name_re
+ [argument], variable_namer, use_ssa_names, args.strict_name_re
)
# Append the output line.