| #!/usr/bin/env python |
| """A script to generate FileCheck statements for mlir unit tests. |
| |
| This script is a utility to add FileCheck patterns to an mlir file. |
| |
| NOTE: The input .mlir is expected to be the output from the parser, not a |
| stripped down variant. |
| |
| Example usage: |
| $ generate-test-checks.py foo.mlir |
| $ mlir-opt foo.mlir -transformation | generate-test-checks.py |
| |
| The script will heuristically insert CHECK/CHECK-LABEL commands for each line |
| within the file. By default this script will also try to insert string |
| substitution blocks for all SSA value names. The script is designed to make |
| adding checks to a test case fast, it is *not* designed to be authoritative |
| about what constitutes a good test! |
| """ |
| |
| # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| import argparse |
| import os # Used to advertise this file's name ("autogenerated_note"). |
| import re |
| import sys |
| |
# Prefix of the note that heads every generated check file; the script's own
# name is appended to it.
ADVERT = '// NOTE: Assertions have been autogenerated by '

# Regular expression matching the name of an SSA value (the text following a
# '%'): either a run of decimal digits, or an identifier made of letters,
# digits and '$._-' that does not start with a digit.
SSA_RE_STR = r'[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
SSA_RE = re.compile(SSA_RE_STR)
| |
| |
# Class used to generate and manage string substitution blocks for SSA value
# names.
class SSAVariableNamer:
  """Generates and tracks FileCheck substitution variables for SSA names.

  Variables are recorded in a stack of scopes. Each scope is a dict mapping an
  SSA name (the text after '%', e.g. '0' or 'arg0') to its variable name
  (e.g. 'VAL_3'). The counter is never reset, so every generated variable
  name is unique across the whole output.
  """

  def __init__(self):
    # Stack of {ssa_name: variable_name} dicts; the last entry is innermost.
    self.scopes = []
    # Running counter used to build unique 'VAL_<n>' names.
    self.name_counter = 0

  def generate_name(self, ssa_name):
    """Create a new variable for `ssa_name` in the innermost scope.

    If no scope is open (a value appears before any region was opened), an
    implicit scope is opened first rather than failing with IndexError.

    Returns:
      The new variable name, e.g. 'VAL_0'.
    """
    if not self.scopes:
      self.push_name_scope()
    variable = 'VAL_' + str(self.name_counter)
    self.name_counter += 1
    self.scopes[-1][ssa_name] = variable
    return variable

  def push_name_scope(self):
    """Open a new, empty innermost name scope."""
    self.scopes.append({})

  def pop_name_scope(self):
    """Close the innermost name scope.

    Scope nesting is inferred heuristically from '{' / '}' in the input, so an
    unbalanced '}' (e.g. in a partial snippet) is tolerated as a no-op instead
    of raising IndexError.
    """
    if self.scopes:
      self.scopes.pop()
| |
| |
# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer):
  """Rewrite the chunks that followed each '%' into FileCheck syntax.

  Args:
    line_chunks: The pieces of a line that came after each '%' once the line
      was split on '%' (the piece before the first '%' is not included).
    variable_namer: SSAVariableNamer whose scopes are searched for, and
      extended with, substitution variables.

  Returns:
    The rewritten text terminated by a newline. A name already defined in an
    open scope becomes '[[VAL_n]]'. A new name becomes '[[VAL_n:%.*]]', which
    defines the variable and matches the '%' consumed by the split. A '%' not
    followed by an SSA name is emitted back as a literal '%'.
  """
  pieces = []
  for chunk in line_chunks:
    m = SSA_RE.match(chunk)
    if m is None:
      # Not an SSA value (e.g. a '%' inside a string attribute): restore the
      # '%' removed by the split and keep the text unchanged.
      pieces.append('%' + chunk)
      continue
    ssa_name = m.group(0)

    # Reuse the variable if any open scope already defines this name.
    variable = None
    for scope in variable_namer.scopes:
      variable = scope.get(ssa_name)
      if variable is not None:
        break

    if variable is not None:
      pieces.append('[[' + variable + ']]')
    else:
      # First occurrence: define a new variable capturing the value name.
      variable = variable_namer.generate_name(ssa_name)
      pieces.append('[[' + variable + ':%.*]]')

    # Append the text following the SSA name within this chunk.
    pieces.append(chunk[len(ssa_name):])

  return ''.join(pieces) + '\n'
| |
| |
# Pre-process a line of input to remove any character sequences that will be
# problematic with FileCheck.
def preprocess_line(line):
  """Escape character sequences that FileCheck would misinterpret.

  Three rewrites are applied, in this order:

  * '{{' opens a FileCheck regex block, so a literal '{{' is replaced by a
    regex block that matches two literal braces.
  * '[[' opens a FileCheck variable, so a literal '[[' is replaced by a regex
    block that matches two literal brackets.
  * '[%' has its bracket replaced by a regex block matching one literal '['.
    The SSA name after the '%' is later turned into a '[[VAL_n...]]' block,
    and an unescaped '[' in front of it would form '[[['.

  The '{{' escape runs first so that the regex blocks introduced by the other
  two rewrites are not themselves escaped.
  """
  output_line = line.replace('{{', '{{\\{\\{}}')
  output_line = output_line.replace('[[', '{{\\[\\[}}')
  output_line = output_line.replace('[%', '{{\\[}}%')
  return output_line
| |
| |
def _convert_input_line(input_line, variable_namer, check_prefix):
  """Return the FileCheck output strings generated for one input line.

  Args:
    input_line: A non-empty input line with trailing whitespace removed.
    variable_namer: SSAVariableNamer tracking SSA names. Its scope stack is
      popped for lines starting with '}' and pushed for lines ending in '{'.
    check_prefix: The FileCheck prefix, e.g. 'CHECK'.

  Returns:
    A list of strings to append to the output. A top-level operation yields a
    blank separator line followed by a CHECK-LABEL line, plus a CHECK-SAME line
    when it mentions SSA values. Every other line yields one CHECK line.
  """
  lstripped_input_line = input_line.lstrip()

  # Lines with blocks begin with a ^. These lines have a trailing comment
  # that needs to be stripped.
  is_block = lstripped_input_line[0] == '^'
  if is_block:
    input_line = input_line.rsplit('//', 1)[0].rstrip()

  # Top-level operations are heuristically the operations at nesting level 1,
  # i.e. indented by exactly two spaces and not a closing brace.
  is_toplevel_op = (not is_block and input_line.startswith('  ') and
                    input_line[2] != ' ' and input_line[2] != '}')

  # If the line starts with a '}', pop the last name scope.
  if lstripped_input_line[0] == '}':
    variable_namer.pop_name_scope()

  # If the line ends with a '{', push a new name scope. This happens before
  # the line is processed, so names first seen on this line (e.g. function
  # arguments) are recorded in the new scope.
  if input_line[-1] == '{':
    variable_namer.push_name_scope()

  # Escape sequences that FileCheck would misinterpret, then split the line at
  # each SSA value.
  ssa_split = preprocess_line(input_line).split('%')
  leading_text = ssa_split[0]

  # A label needs visible text before the first SSA value: FileCheck rejects
  # an all-whitespace CHECK-LABEL. Strip before testing, because a top-level
  # line always starts with indentation and would never be empty otherwise.
  if not is_toplevel_op or not leading_text.strip():
    # Pad to align with the 'LABEL' statements.
    return ['// ' + check_prefix + ': ' + ' ' * len('-LABEL') + leading_text +
            process_line(ssa_split[1:], variable_namer)]

  # Label the operation with the text preceding its first SSA value.
  output_line = '// ' + check_prefix + '-LABEL: ' + leading_text + '\n'
  if len(ssa_split) > 1:
    # Check the rest of the line on a CHECK-SAME line, padded to align with
    # the original position in the line.
    output_line += ('// ' + check_prefix + '-SAME: ' +
                    ' ' * len(leading_text) +
                    process_line(ssa_split[1:], variable_namer))
  # The leading newline separates the logical blocks in the output.
  return ['\n', output_line]


def main():
  """Convert MLIR from the input into FileCheck assertions.

  Reads the input file (stdin by default) and emits an autogenerated-note
  header. It then emits check statements for every non-empty input line and
  writes the result to the output file (stdout by default).
  """
  parser = argparse.ArgumentParser(
      description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
  parser.add_argument(
      '--check-prefix', default='CHECK', help='Prefix to use from check file.')
  # No nargs='?': a bare '-o' previously produced args.output == None and
  # crashed on the first write; argparse now reports the missing value.
  parser.add_argument(
      '-o', '--output', type=argparse.FileType('w'), default=sys.stdout)
  parser.add_argument(
      'input', nargs='?', type=argparse.FileType('r'), default=sys.stdin)
  args = parser.parse_args()

  # Read the whole input; trailing whitespace is dropped from every line.
  input_lines = [line.rstrip() for line in args.input]
  args.input.close()

  # The generated check file starts with a note naming this script.
  script_name = os.path.basename(__file__)
  output_lines = [ADVERT + 'utils/' + script_name + '\n']

  # Tracks SSA value names -> substitution variables across the input.
  variable_namer = SSAVariableNamer()
  for input_line in input_lines:
    if input_line:
      output_lines.extend(
          _convert_input_line(input_line, variable_namer, args.check_prefix))

  # Write the output.
  args.output.writelines(output_lines)
  args.output.write('\n')
  args.output.close()
| |
| |
if __name__ == '__main__':
  # Run the generator only when executed as a script, not when imported.
  main()