diff --git a/scripts/fix_test.py b/scripts/fix_test.py index 53b10d638..49a2e1870 100644 --- a/scripts/fix_test.py +++ b/scripts/fix_test.py @@ -15,12 +15,11 @@ How to use: """ import argparse -import ast -import itertools -import platform import sys from pathlib import Path +from lib_updater import apply_patches, PatchSpec, UtMethod + def parse_args(): parser = argparse.ArgumentParser(description="Fix test.") @@ -102,39 +101,16 @@ def path_to_test(path) -> list[str]: return parts[-2:] # Get class name and method name -def find_test_lineno(file: str, test: list[str]) -> tuple[int, int] | None: - """Find the line number and column offset of a test function. - Returns (lineno, col_offset) or None if not found. - """ - a = ast.parse(file) - for key, node in ast.iter_fields(a): - if key == "body": - for n in node: - match n: - case ast.ClassDef(): - if len(test) == 2 and test[0] == n.name: - for fn in n.body: - match fn: - case ast.FunctionDef() | ast.AsyncFunctionDef(): - if fn.name == test[-1]: - return (fn.lineno, fn.col_offset) - case ast.FunctionDef() | ast.AsyncFunctionDef(): - if n.name == test[0] and len(test) == 1: - return (n.lineno, n.col_offset) - return None - - -def apply_modifications(file: str, modifications: list[tuple[int, int]]) -> str: - """Apply all modifications in reverse order to avoid line number offset issues.""" - lines = file.splitlines() - fixture = "@unittest.expectedFailure" - # Sort by line number in descending order - modifications.sort(key=lambda x: x[0], reverse=True) - for lineno, col_offset in modifications: - indent = " " * col_offset - lines.insert(lineno - 1, indent + fixture) - lines.insert(lineno - 1, indent + "# TODO: RUSTPYTHON") - return "\n".join(lines) +def build_patches(test_parts_set: set[tuple[str, str]]) -> dict: + """Convert failing tests to lib_updater patch format.""" + patches = {} + for class_name, method_name in test_parts_set: + if class_name not in patches: + patches[class_name] = {} + patches[class_name][method_name] = [ + PatchSpec(UtMethod.ExpectedFailure, None, "") + ] + return patches def run_test(test_name): @@ -167,26 +143,21 @@ if __name__ == "__main__": tests = run_test(test_name) f = test_path.read_text(encoding="utf-8") - # Collect all modifications first (with deduplication for subtests) - modifications = [] + # Collect failing tests (with deduplication for subtests) seen_tests = set() # Track (class_name, method_name) to avoid duplicates for test in tests.tests: if test.result == "fail" or test.result == "error": test_parts = path_to_test(test.path) - test_key = tuple(test_parts) - if test_key in seen_tests: - continue # Skip duplicate (same test, different subtest) - seen_tests.add(test_key) - location = find_test_lineno(f, test_parts) - if location: - print(f"Modifying test: {test.name} at line {location[0]}") - modifications.append(location) - else: - print(f"Warning: Could not find test: {test.name} ({test_parts})") + if len(test_parts) == 2: + test_key = tuple(test_parts) + if test_key not in seen_tests: + seen_tests.add(test_key) + print(f"Marking test: {test_parts[0]}.{test_parts[1]}") - # Apply all modifications in reverse order - if modifications: - f = apply_modifications(f, modifications) + # Apply patches using lib_updater + if seen_tests: + patches = build_patches(seen_tests) + f = apply_patches(f, patches) test_path.write_text(f, encoding="utf-8") - print(f"Modified {len(modifications)} tests") + print(f"Modified {len(seen_tests)} tests")