Files
RustPython/scripts/update_lib/cmd_auto_mark.py

1045 lines
38 KiB
Python

#!/usr/bin/env python
"""
Auto-mark test failures in Python test suite.
This module provides functions to:
- Run tests with RustPython and parse results
- Extract test names from test file paths
- Mark failing tests with @unittest.expectedFailure
- Remove expectedFailure from tests that now pass
"""
import ast
import pathlib
import re
import subprocess
import sys
from dataclasses import dataclass, field
sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
from update_lib import COMMENT, PatchSpec, UtMethod, apply_patches
from update_lib.file_utils import get_test_module_name
class TestRunError(Exception):
    """Signal that a test run produced no usable results at all.

    Typical causes are an import error in the test module or an interpreter
    crash before any outcome was reported.
    """
@dataclass
class Test:
    """A single test outcome parsed from the test runner's output."""

    # Method name, e.g. "test_foo"
    name: str = field(default="")
    # Dotted id found between the parentheses, e.g. "test.test_x.Cls.test_foo"
    path: str = field(default="")
    # Lower-cased status ("fail" / "error") or "unexpected_success"
    result: str = field(default="")
    # Last non-empty line of the matching traceback block, if one was found
    error_message: str = field(default="")
@dataclass
class TestResult:
    """Aggregated outcome of one test run."""

    # First word following "== Tests result: " in the output; empty if absent
    tests_result: str = field(default="")
    # FAIL / ERROR entries parsed from the verbose listing
    tests: list[Test] = field(default_factory=list)
    # Entries reported on "UNEXPECTED SUCCESS: " lines
    unexpected_successes: list[Test] = field(default_factory=list)
    # Raw captured stdout of the run
    stdout: str = field(default="")
def run_test(test_name: str, skip_build: bool = False) -> TestResult:
    """
    Run a test module under RustPython and parse what it printed.

    Args:
        test_name: Test module name (e.g., "test_foo" or "test_ctypes.test_bar")
        skip_build: Use the pre-built release binary instead of ``cargo run``

    Returns:
        TestResult produced by parse_results from the captured stdout
    """
    if not skip_build:
        interpreter = ["cargo", "run", "--release", "--"]
    elif sys.platform == "win32":
        interpreter = ["./target/release/rustpython.exe"]
    else:
        interpreter = ["./target/release/rustpython"]

    completed = subprocess.run(
        [*interpreter, "-m", "test", "-v", "-u", "all", "--slowest", test_name],
        stdout=subprocess.PIPE,  # captured so it can be parsed
        stderr=None,  # inherited: shows up directly on the terminal
        text=True,
    )
    return parse_results(completed)
def _try_parse_test_info(test_info: str) -> tuple[str, str] | None:
"""Try to extract (name, path) from 'test_name (path)' or 'test_name (path) [subtest]'."""
first_space = test_info.find(" ")
if first_space > 0:
name = test_info[:first_space]
rest = test_info[first_space:].strip()
if rest.startswith("("):
end_paren = rest.find(")")
if end_paren > 0:
return name, rest[1:end_paren]
return None
def parse_results(result: subprocess.CompletedProcess) -> TestResult:
    """Parse the stdout of a ``-m test -v`` run into a TestResult.

    Three kinds of lines are recognised:

    * Status lines between "Run N test(s) sequentially" and the
      "== Tests result: " summary. They are either single-line,
      ``test_name (path) [subtest] ... RESULT``, or two-line, where a
      ``test_name (path)`` line is followed by ``docstring ... RESULT``.
      Only FAIL and ERROR results are recorded in ``tests``.
    * The summary line; the first word after "== Tests result: " is
      stored in ``tests_result``.
    * ``UNEXPECTED SUCCESS: test_name (path)`` lines, collected into
      ``unexpected_successes``.

    Error messages for the recorded failures are filled in afterwards by
    ``_parse_error_details``.

    Args:
        result: Completed process whose stdout was captured as text
            (a missing stdout is treated as empty output).

    Returns:
        The parsed TestResult; its ``stdout`` holds the raw output.
    """
    stdout = result.stdout or ""
    lines = stdout.splitlines()
    test_results = TestResult(stdout=stdout)

    summary_marker = "== Tests result: "
    unexpected_marker = "UNEXPECTED SUCCESS: "
    run_header = re.compile(r"Run \d+ tests? sequentially")

    in_test_results = False
    # A "test_name (path)" line that may be completed by the next
    # "docstring ... RESULT" line (multi-line format).
    pending_test_info: str | None = None

    for line in lines:
        if run_header.search(line):
            in_test_results = True
        elif summary_marker in line:
            in_test_results = False

        if summary_marker in line:
            test_results.tests_result = line.split(summary_marker)[1].split(" ")[0]

        if line.startswith(unexpected_marker):
            parsed = _try_parse_test_info(line[len(unexpected_marker) :])
            if parsed:
                name, path = parsed
                test_results.unexpected_successes.append(
                    Test(name=name, path=path, result="unexpected_success")
                )

        if not in_test_results:
            continue

        stripped = line.strip()
        if " ... " not in line:
            # Possible first line of the multi-line format:
            #   test_name (path)
            #   docstring ... RESULT
            looks_like_header = (
                bool(stripped)
                and "(" in stripped
                and stripped.endswith(")")
                and ":" not in stripped.split("(")[0]
            )
            pending_test_info = stripped if looks_like_header else None
            continue

        # Skip lines that don't look like test results
        if stripped.startswith(("tests", "[")):
            pending_test_info = None
            continue

        # Parse: "test_name (path) [subtest] ... RESULT"
        parts = stripped.split(" ... ")
        if len(parts) < 2:
            continue
        status = parts[-1].lower()
        if status in ("fail", "error"):
            parsed = _try_parse_test_info(parts[0])
            if parsed is None and pending_test_info:
                # Multi-line format: the previous line held "test_name (path)"
                parsed = _try_parse_test_info(pending_test_info)
            if parsed:
                name, path = parsed
                test_results.tests.append(Test(name=name, path=path, result=status))
        pending_test_info = None

    # Parse error details to extract error messages
    _parse_error_details(test_results, lines)
    return test_results
def _parse_error_details(test_results: TestResult, lines: list[str]) -> None:
    """Fill in ``error_message`` for failing tests from the traceback sections.

    unittest prints one block per FAIL/ERROR::

        ======================================================================
        FAIL: test_name (path)
        First line of the test's docstring        <- only when it has one
        ----------------------------------------------------------------------
        Traceback (most recent call last):
          ...
        AssertionError: message

    The header may carry a subtest suffix (``(i=1)`` or ``[msg]``); it is
    parsed the same way as the status lines. The last non-empty line of the
    traceback is stored on the test whose (name, path) matches the header.
    """
    test_lookup: dict[tuple[str, str], Test] = {
        (test.name, test.path): test for test in test_results.tests
    }

    total = len(lines)
    i = 0
    while i < total:
        line = lines[i]
        i += 1
        if not line.startswith(("FAIL: ", "ERROR: ")):
            continue
        parsed = _try_parse_test_info(line.split(": ", 1)[1])
        if parsed is None:
            continue

        # Skip an optional description line and the "-----" separator. If no
        # separator shows up before the next block, collect from right here.
        j = i
        while j < total and not lines[j].startswith(("-----", "=====")):
            j += 1
        if j < total and lines[j].startswith("-----"):
            i = j + 1

        # Collect the traceback until the next separator or end of output
        error_lines: list[str] = []
        while i < total and not lines[i].startswith(("=====", "-----")):
            error_lines.append(lines[i])
            i += 1

        # The error message is the last non-empty traceback line
        message = next(
            (text for text in map(str.strip, reversed(error_lines)) if text), ""
        )
        test = test_lookup.get(parsed)
        if test is not None:
            test.error_message = message
def path_to_test_parts(path: str) -> list[str]:
    """
    Return the trailing [ClassName, method_name] of a dotted test path.

    Args:
        path: Test path like "test.module_name.ClassName.test_method"

    Returns:
        At most the last two dot-separated components (fewer if the path
        has fewer components).
    """
    return path.rsplit(".", 2)[-2:]
def _expand_stripped_to_children(
    contents: str,
    stripped_tests: set[tuple[str, str]],
    all_failing_tests: set[tuple[str, str]],
) -> set[tuple[str, str]]:
    """Map stripped markers to the failures that should restore them.

    A marker removed by ``strip_reasonless_expected_failures`` from a mixin
    or parent class never shows up directly in the failure list. Unittest
    reports the concrete subclasses instead. Stripped pairs that failed
    directly are kept. For the rest, every in-file class that inherits the
    method from the stripped class and is failing is included, so that
    ``_consolidate_to_parent`` can later move the marker back to the parent.

    Returns:
        ``(class, method)`` pairs from *all_failing_tests* to re-mark.
    """
    matched = stripped_tests & all_failing_tests
    leftovers = stripped_tests - all_failing_tests
    if not leftovers:
        return matched

    class_bases, class_methods = _build_inheritance_info(ast.parse(contents))
    for parent, method in leftovers:
        if method not in class_methods.get(parent, set()):
            continue
        inheritors = _find_all_inheritors(parent, method, class_bases, class_methods)
        matched.update(
            (child, method)
            for child in inheritors
            if (child, method) in all_failing_tests
        )
    return matched
def _consolidate_to_parent(
    contents: str,
    failing_tests: set[tuple[str, str]],
    error_messages: dict[tuple[str, str], str] | None = None,
) -> tuple[set[tuple[str, str]], dict[tuple[str, str], str] | None]:
    """Move failures to the parent class when ALL inheritors fail.

    Failing ``(class, method)`` pairs whose method is inherited rather than
    defined locally are grouped under the class that defines it. When every
    class inheriting that method from the parent is failing, the per-child
    entries are replaced by a single entry on the parent. This avoids a
    redundant super-call override in every child. The children's error
    messages are moved onto the parent key; the last non-empty one visited
    wins.

    Returns:
        (consolidated_failing_tests, consolidated_error_messages)
    """
    from collections import defaultdict

    class_bases, class_methods = _build_inheritance_info(ast.parse(contents))

    # (defining parent, method) -> failing classes that merely inherit it
    inherited: dict[tuple[str, str], set[str]] = defaultdict(set)
    for cls, method in failing_tests:
        owner = _find_method_definition(cls, method, class_bases, class_methods)
        if owner and owner != cls:
            inherited[(owner, method)].add(cls)

    if not inherited:
        return failing_tests, error_messages

    consolidated = set(failing_tests)
    messages = dict(error_messages) if error_messages else {}
    for (owner, method), children in inherited.items():
        inheritors = _find_all_inheritors(owner, method, class_bases, class_methods)
        if not inheritors or not children >= inheritors:
            continue
        # Every inheritor fails: record the failure on the parent instead
        consolidated.difference_update((child, method) for child in children)
        consolidated.add((owner, method))
        if not messages:
            continue
        for child in children:
            text = messages.pop((child, method), "")
            if text:
                messages[(owner, method)] = text
    return consolidated, messages or error_messages
def build_patches(
    test_parts_set: set[tuple[str, str]],
    error_messages: dict[tuple[str, str], str] | None = None,
) -> dict:
    """Turn failing ``(class, method)`` pairs into the patch mapping.

    Produces ``{class: {method: [PatchSpec(ExpectedFailure, None, reason)]}}``.
    The reason is the captured error message, or "" when none is known.
    Classes and methods are inserted in sorted order.
    """
    reasons = error_messages or {}
    patches: dict = {}
    for cls, method in sorted(test_parts_set):
        spec = PatchSpec(UtMethod.ExpectedFailure, None, reasons.get((cls, method), ""))
        patches.setdefault(cls, {})[method] = [spec]
    return patches
def _is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
"""Check if the method body is just 'return super().method_name()' or 'return await super().method_name()'."""
if len(func_node.body) != 1:
return False
stmt = func_node.body[0]
if not isinstance(stmt, ast.Return) or stmt.value is None:
return False
call = stmt.value
# Unwrap await for async methods
if isinstance(call, ast.Await):
call = call.value
if not isinstance(call, ast.Call):
return False
if not isinstance(call.func, ast.Attribute):
return False
# Verify the method name matches
if call.func.attr != func_node.name:
return False
super_call = call.func.value
if not isinstance(super_call, ast.Call):
return False
if not isinstance(super_call.func, ast.Name) or super_call.func.id != "super":
return False
return True
def _method_removal_range(
    func_node: ast.FunctionDef | ast.AsyncFunctionDef, lines: list[str]
) -> range:
    """Return the 0-based line indices covering a whole method definition.

    The span starts at the first decorator, or at the ``def`` line when there
    are none. It is widened by one line to include a directly preceding
    ``#`` comment that contains COMMENT, then by one more line to include a
    blank line above it, so that deleting the span does not leave a double
    blank. It ends at the method's last line.
    """
    anchor = func_node.decorator_list[0] if func_node.decorator_list else func_node
    start = anchor.lineno - 1

    if start > 0:
        previous = lines[start - 1]
        if previous.strip().startswith("#") and COMMENT in previous:
            start -= 1

    if start > 0 and lines[start - 1].strip() == "":
        start -= 1

    return range(start, func_node.end_lineno)
def _build_inheritance_info(tree: ast.Module) -> tuple[dict, dict]:
    """
    Collect per-class base and method information from a parsed module.

    Only bases written as plain names of classes defined somewhere in the
    same tree are kept (e.g. ``unittest.TestCase`` or ``object`` are
    dropped). If two classes share a name, the one visited later by
    ``ast.walk`` wins.

    Returns:
        class_bases: dict[str, list[str]] - in-file parent classes for each class
        class_methods: dict[str, set[str]] - methods directly defined in each class
    """
    class_nodes = [node for node in ast.walk(tree) if isinstance(node, ast.ClassDef)]
    known_names = {cls_node.name for cls_node in class_nodes}

    bases_by_class: dict[str, list[str]] = {}
    methods_by_class: dict[str, set[str]] = {}
    for cls_node in class_nodes:
        bases_by_class[cls_node.name] = [
            base.id
            for base in cls_node.bases
            if isinstance(base, ast.Name) and base.id in known_names
        ]
        methods_by_class[cls_node.name] = {
            stmt.name
            for stmt in cls_node.body
            if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
        }
    return bases_by_class, methods_by_class
def _find_method_definition(
class_name: str, method_name: str, class_bases: dict, class_methods: dict
) -> str | None:
"""Find the class where a method is actually defined (BFS)."""
if method_name in class_methods.get(class_name, set()):
return class_name
visited = set()
queue = list(class_bases.get(class_name, []))
while queue:
current = queue.pop(0)
if current in visited:
continue
visited.add(current)
if method_name in class_methods.get(current, set()):
return current
queue.extend(class_bases.get(current, []))
return None
def _find_all_inheritors(
    parent: str, method_name: str, class_bases: dict, class_methods: dict
) -> set[str]:
    """Return every class that gets *method_name* from *parent* by inheritance.

    *parent* itself and classes that define their own override are excluded.
    """
    inheritors: set[str] = set()
    for candidate in class_bases:
        if candidate == parent:
            continue
        if method_name in class_methods.get(candidate, set()):
            continue  # overrides the method itself
        resolved = _find_method_definition(
            candidate, method_name, class_bases, class_methods
        )
        if resolved == parent:
            inheritors.add(candidate)
    return inheritors
def remove_expected_failures(
    contents: str, tests_to_remove: set[tuple[str, str]]
) -> str:
    """Remove @unittest.expectedFailure decorators from tests that now pass.

    Each ``(class, method)`` pair is first resolved to the class that actually
    defines the method, so a success reported on a subclass removes the
    marker carried by the parent. Methods that are only a
    ``return super().<name>()`` override are deleted entirely, together with
    their decorators and marker comment. For other methods only a decorator
    line tied to a COMMENT marker is removed: the marker may sit on the same
    line or on the line before. An old-format reason comment on the line
    after is removed too.

    If deleting such overrides would leave a class body empty, the last
    deleted override's ``def`` line is replaced by ``pass`` so the file stays
    syntactically valid.
    """
    if not tests_to_remove:
        return contents

    tree = ast.parse(contents)
    lines = contents.splitlines()
    lines_to_remove: set[int] = set()
    replacements: dict[int, str] = {}

    class_bases, class_methods = _build_inheritance_info(tree)
    resolved_tests = set()
    for class_name, method_name in tests_to_remove:
        defining_class = _find_method_definition(
            class_name, method_name, class_bases, class_methods
        )
        if defining_class:
            resolved_tests.add((defining_class, method_name))

    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        removed_whole: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
        for item in node.body:
            if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            if (node.name, item.name) not in resolved_tests:
                continue

            if _is_super_call_only(item):
                lines_to_remove.update(_method_removal_range(item, lines))
                removed_whole.append(item)
                continue

            for dec in item.decorator_list:
                dec_line = dec.lineno - 1
                line_content = lines[dec_line]
                if "expectedFailure" not in line_content:
                    continue
                has_comment_on_line = COMMENT in line_content
                has_comment_before = (
                    dec_line > 0
                    and lines[dec_line - 1].strip().startswith("#")
                    and COMMENT in lines[dec_line - 1]
                )
                has_comment_after = (
                    dec_line + 1 < len(lines)
                    and lines[dec_line + 1].strip().startswith("#")
                    and COMMENT not in lines[dec_line + 1]
                )
                if has_comment_on_line or has_comment_before:
                    lines_to_remove.add(dec_line)
                if has_comment_before:
                    lines_to_remove.add(dec_line - 1)
                if has_comment_after and has_comment_on_line:
                    lines_to_remove.add(dec_line + 1)

        # Never leave "class X(...):" without a body
        if removed_whole and len(removed_whole) == len(node.body):
            def_idx = removed_whole[-1].lineno - 1
            def_text = lines[def_idx]
            indent = def_text[: len(def_text) - len(def_text.lstrip())]
            replacements[def_idx] = f"{indent}pass"
            lines_to_remove.discard(def_idx)

    # Replacements use original indices, so apply them before any deletion
    for line_idx, text in replacements.items():
        lines[line_idx] = text
    for line_idx in sorted(lines_to_remove, reverse=True):
        del lines[line_idx]

    return "\n".join(lines) + "\n" if lines else ""
def collect_test_changes(
    results: TestResult,
    module_prefix: str | None = None,
) -> tuple[set[tuple[str, str]], set[tuple[str, str]], dict[tuple[str, str], str]]:
    """
    Reduce parsed test results to ``(class, method)`` sets.

    Args:
        results: TestResult from run_test()
        module_prefix: If set, only tests whose path starts with this prefix
            are considered.

    Returns:
        (failing_tests, unexpected_successes, error_messages)
        - failing_tests: set of (class_name, method_name) tuples
        - unexpected_successes: set of (class_name, method_name) tuples
        - error_messages: dict mapping (class_name, method_name) to error message

    Paths that do not yield two components are ignored.
    """

    def selected(test: Test) -> bool:
        return not module_prefix or test.path.startswith(module_prefix)

    failing: set[tuple[str, str]] = set()
    messages: dict[tuple[str, str], str] = {}
    for test in results.tests:
        if test.result not in ("fail", "error") or not selected(test):
            continue
        parts = path_to_test_parts(test.path)
        if len(parts) != 2:
            continue
        key = (parts[0], parts[1])
        failing.add(key)
        if test.error_message:
            messages[key] = test.error_message

    successes = {
        (parts[0], parts[1])
        for parts in (
            path_to_test_parts(test.path)
            for test in results.unexpected_successes
            if selected(test)
        )
        if len(parts) == 2
    }
    return failing, successes, messages
def apply_test_changes(
    contents: str,
    failing_tests: set[tuple[str, str]],
    unexpected_successes: set[tuple[str, str]],
    error_messages: dict[tuple[str, str], str] | None = None,
) -> str:
    """
    Apply marker removals and additions to a test file's source.

    Args:
        contents: File content
        failing_tests: (class_name, method_name) pairs to mark as expectedFailure
        unexpected_successes: (class_name, method_name) pairs to unmark
        error_messages: Dict mapping (class_name, method_name) to error message

    Returns:
        Modified content. Removals happen first; failures are then
        consolidated onto parent classes where possible and patched in.
    """
    updated = (
        remove_expected_failures(contents, unexpected_successes)
        if unexpected_successes
        else contents
    )
    if not failing_tests:
        return updated

    merged_failures, merged_messages = _consolidate_to_parent(
        updated, failing_tests, error_messages
    )
    return apply_patches(updated, build_patches(merged_failures, merged_messages))
def strip_reasonless_expected_failures(
    contents: str,
) -> tuple[str, set[tuple[str, str]]]:
    """Strip @expectedFailure decorators that have no failure reason.

    A decorator counts as "ours" when COMMENT appears on its line or on a
    ``#`` comment line directly above it. It counts as reason-less when
    nothing but ``;:,`` and spaces follows COMMENT on either of those lines.
    Such markers are removed so the tests fail normally during the next test
    run and their error messages can be captured.

    Super-call-only overrides that carry such a marker are deleted entirely.
    If that would empty a class body, the last deleted override's ``def`` line
    is replaced by ``pass`` so the file stays syntactically valid.

    Returns:
        (modified_contents, stripped_tests) where stripped_tests is a set of
        (class_name, method_name) tuples whose markers were removed.
    """
    tree = ast.parse(contents)
    lines = contents.splitlines()
    stripped_tests: set[tuple[str, str]] = set()
    lines_to_remove: set[int] = set()
    replacements: dict[int, str] = {}

    def has_reason(text: str) -> bool:
        # COMMENT is matched literally; anything after it besides separators
        # is a reason.
        _, found, after = text.partition(COMMENT)
        return bool(found) and bool(after.strip(";:, "))

    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        removed_whole: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
        for item in node.body:
            if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            whole_method_removed = False
            for dec in item.decorator_list:
                dec_line = dec.lineno - 1
                line_content = lines[dec_line]
                if "expectedFailure" not in line_content:
                    continue
                has_comment_on_line = COMMENT in line_content
                has_comment_before = (
                    dec_line > 0
                    and lines[dec_line - 1].strip().startswith("#")
                    and COMMENT in lines[dec_line - 1]
                )
                if not has_comment_on_line and not has_comment_before:
                    continue  # not our marker
                if has_reason(line_content) or (
                    has_comment_before and has_reason(lines[dec_line - 1])
                ):
                    continue  # has a reason, keep it

                # No reason found: strip this decorator
                stripped_tests.add((node.name, item.name))
                if _is_super_call_only(item):
                    # The override exists only to carry the decorator. Without
                    # it the override is pointless and blocks parent
                    # consolidation.
                    lines_to_remove.update(_method_removal_range(item, lines))
                    whole_method_removed = True
                    continue
                lines_to_remove.add(dec_line)
                if has_comment_before:
                    lines_to_remove.add(dec_line - 1)
                # Also remove a reason-comment on the line after (old format)
                if (
                    has_comment_on_line
                    and dec_line + 1 < len(lines)
                    and lines[dec_line + 1].strip().startswith("#")
                    and COMMENT not in lines[dec_line + 1]
                ):
                    lines_to_remove.add(dec_line + 1)
            if whole_method_removed:
                removed_whole.append(item)

        # Never leave "class X(...):" without a body
        if removed_whole and len(removed_whole) == len(node.body):
            def_idx = removed_whole[-1].lineno - 1
            def_text = lines[def_idx]
            indent = def_text[: len(def_text) - len(def_text.lstrip())]
            replacements[def_idx] = f"{indent}pass"
            lines_to_remove.discard(def_idx)

    if not lines_to_remove:
        return contents, stripped_tests

    # Replacements use original indices, so apply them before any deletion
    for idx, text in replacements.items():
        lines[idx] = text
    for idx in sorted(lines_to_remove, reverse=True):
        del lines[idx]
    return ("\n".join(lines) + "\n" if lines else ""), stripped_tests
def extract_test_methods(contents: str) -> set[tuple[str, str]]:
    """
    List the test methods defined in a file's source.

    Parsing goes through update_lib's ``safe_parse_ast`` and enumeration
    through ``iter_tests``. An empty set is returned when ``safe_parse_ast``
    gives back None (presumably unparseable input).

    Returns:
        Set of (class_name, method_name) tuples
    """
    from update_lib.file_utils import safe_parse_ast
    from update_lib.patch_spec import iter_tests

    tree = safe_parse_ast(contents)
    methods: set[tuple[str, str]] = set()
    if tree is not None:
        for cls_node, fn_node in iter_tests(tree):
            methods.add((cls_node.name, fn_node.name))
    return methods
def auto_mark_file(
    test_path: pathlib.Path,
    mark_failure: bool = False,
    verbose: bool = True,
    original_methods: set[tuple[str, str]] | None = None,
    skip_build: bool = False,
) -> tuple[int, int, int]:
    """
    Run tests and auto-mark failures in a test file.

    Reason-less RUSTPYTHON expectedFailure markers are stripped before the
    run so the affected tests fail normally and their error messages are
    captured. Markers are then re-added, with the message as reason, for
    stripped tests that still fail. If anything interrupts the run, the file
    is restored to its original contents.

    Args:
        test_path: Path to the test file
        mark_failure: If True, add @expectedFailure to ALL failing tests
        verbose: Print progress messages
        original_methods: If provided, only auto-mark failures for NEW methods
            (methods not in original_methods) even without mark_failure.
            Failures in existing methods are treated as regressions.
        skip_build: Use the pre-built release binary instead of ``cargo run``

    Returns:
        (num_failures_added, num_successes_removed, num_regressions)

    Raises:
        FileNotFoundError: if *test_path* does not exist.
        TestRunError: if the run produced no parseable results at all.
    """
    test_path = pathlib.Path(test_path).resolve()
    if not test_path.exists():
        raise FileNotFoundError(f"File not found: {test_path}")

    original_contents = test_path.read_text(encoding="utf-8")
    contents, stripped_tests = strip_reasonless_expected_failures(original_contents)
    if stripped_tests:
        test_path.write_text(contents, encoding="utf-8")

    try:
        test_name = get_test_module_name(test_path)
        if verbose:
            print(f"Running test: {test_name}")
        results = run_test(test_name, skip_build=skip_build)
    except BaseException:
        # Ctrl-C, missing cargo/binary, ...: don't leave the markers stripped
        if stripped_tests:
            test_path.write_text(original_contents, encoding="utf-8")
        raise

    # Check if test run failed entirely (e.g., import error, crash)
    if (
        not results.tests_result
        and not results.tests
        and not results.unexpected_successes
    ):
        if stripped_tests:
            test_path.write_text(original_contents, encoding="utf-8")
        raise TestRunError(
            f"Test run failed for {test_name}. "
            f"Output: {results.stdout[-500:] if results.stdout else '(no output)'}"
        )

    # If the run crashed (incomplete), restore the original file so markers
    # for tests that never ran are preserved. Only observed results are
    # re-applied below.
    if not results.tests_result and stripped_tests:
        test_path.write_text(original_contents, encoding="utf-8")
        stripped_tests = set()
        contents = original_contents

    all_failing_tests, unexpected_successes, error_messages = collect_test_changes(
        results
    )

    # Determine which failures to mark (copy so all_failing_tests stays intact)
    if mark_failure:
        failing_tests = set(all_failing_tests)
    elif original_methods is not None:
        # Smart mode: only mark NEW test failures (not regressions)
        new_methods = extract_test_methods(contents) - original_methods
        failing_tests = all_failing_tests & new_methods
    else:
        failing_tests = set()

    # Re-mark stripped tests that still fail (to restore markers with reasons).
    # Inheritance expansion: if a parent marker was stripped, child failures
    # are included so _consolidate_to_parent can re-mark the parent.
    failing_tests |= _expand_stripped_to_children(
        contents, stripped_tests, all_failing_tests
    )
    regressions = all_failing_tests - failing_tests

    if verbose:
        label = " (new test)" if original_methods is not None else ""
        for class_name, method_name in sorted(failing_tests):
            err_msg = error_messages.get((class_name, method_name), "")
            err_hint = f" - {err_msg}" if err_msg else ""
            print(f"Marking as failing{label}: {class_name}.{method_name}{err_hint}")
        for class_name, method_name in sorted(unexpected_successes):
            print(f"Removing expectedFailure: {class_name}.{method_name}")

    if failing_tests or unexpected_successes:
        contents = apply_test_changes(
            contents, failing_tests, unexpected_successes, error_messages
        )
        test_path.write_text(contents, encoding="utf-8")

    # Show hints about unmarked failures
    if verbose and regressions:
        print(
            f"Hint: {len(regressions)} failing tests can be marked with --mark-failure; "
            "but review first and do not blindly mark them all"
        )
        for class_name, method_name in sorted(regressions):
            err_msg = error_messages.get((class_name, method_name), "")
            err_hint = f" - {err_msg}" if err_msg else ""
            print(f" {class_name}.{method_name}{err_hint}")

    return len(failing_tests), len(unexpected_successes), len(regressions)
def _directory_test_prefix(test_file: pathlib.Path) -> str:
    """Return the "test.<module>." id prefix of tests defined in *test_file*.

    A package's ``__init__.py`` maps to the package module itself.
    """
    module_name = get_test_module_name(test_file)
    if module_name.endswith(".__init__"):
        module_name = module_name[: -len(".__init__")]
    return "test." + module_name + "."


def _owns_test_path(prefix: str, path: str, all_prefixes: set[str]) -> bool:
    """True if test *path* belongs to the module with *prefix*.

    The longest matching prefix wins, so a package ``__init__.py``
    ("test.pkg.") does not also claim tests of its submodules
    ("test.pkg.sub.").
    """
    return path.startswith(prefix) and not any(
        len(other) > len(prefix) and path.startswith(other) for other in all_prefixes
    )


def auto_mark_directory(
    test_dir: pathlib.Path,
    mark_failure: bool = False,
    verbose: bool = True,
    original_methods_per_file: dict[pathlib.Path, set[tuple[str, str]]] | None = None,
    skip_build: bool = False,
) -> tuple[int, int, int]:
    """
    Run tests and auto-mark failures in a test directory.

    Runs the test once for the whole directory, then applies results to each
    file. Each reported test is attributed only to the module whose id
    prefix matches it most specifically. Files that cannot be read as UTF-8
    or parsed (e.g. intentionally malformed test data) are skipped. If
    anything interrupts the run, stripped files are restored.

    Args:
        test_dir: Path to the test directory
        mark_failure: If True, add @expectedFailure to ALL failing tests
        verbose: Print progress messages
        original_methods_per_file: If provided, only auto-mark failures for NEW
            methods even without mark_failure. Dict maps file path to set of
            (class_name, method_name) tuples.
        skip_build: Use the pre-built release binary instead of ``cargo run``

    Returns:
        (num_failures_added, num_successes_removed, num_regressions)

    Raises:
        FileNotFoundError: if *test_dir* does not exist.
        ValueError: if *test_dir* is not a directory.
        TestRunError: if the run produced no parseable results at all.
    """
    test_dir = pathlib.Path(test_dir).resolve()
    if not test_dir.exists():
        raise FileNotFoundError(f"Directory not found: {test_dir}")
    if not test_dir.is_dir():
        raise ValueError(f"Not a directory: {test_dir}")

    # Get all .py files in directory
    test_files = sorted(test_dir.glob("**/*.py"))

    stripped_per_file: dict[pathlib.Path, set[tuple[str, str]]] = {}
    original_per_file: dict[pathlib.Path, str] = {}
    editable_files: list[pathlib.Path] = []

    def restore_originals() -> None:
        for fpath, original in original_per_file.items():
            fpath.write_text(original, encoding="utf-8")

    try:
        # Strip reason-less markers from ALL files before running tests so
        # those tests fail normally and we capture their error messages.
        for test_file in test_files:
            try:
                contents = test_file.read_text(encoding="utf-8")
                stripped_contents, stripped = strip_reasonless_expected_failures(
                    contents
                )
            except (SyntaxError, ValueError) as exc:
                # Non-UTF-8 or unparseable file (e.g. deliberately broken data)
                if verbose:
                    print(f"Skipping {test_file.relative_to(test_dir)}: {exc}")
                continue
            editable_files.append(test_file)
            if stripped:
                original_per_file[test_file] = contents
                test_file.write_text(stripped_contents, encoding="utf-8")
                stripped_per_file[test_file] = stripped

        test_name = get_test_module_name(test_dir)
        if verbose:
            print(f"Running test: {test_name}")
        results = run_test(test_name, skip_build=skip_build)
    except BaseException:
        # Ctrl-C, missing cargo/binary, ...: don't leave markers stripped
        restore_originals()
        raise

    # Check if test run failed entirely (e.g., import error, crash)
    if (
        not results.tests_result
        and not results.tests
        and not results.unexpected_successes
    ):
        restore_originals()
        raise TestRunError(
            f"Test run failed for {test_name}. "
            f"Output: {results.stdout[-500:] if results.stdout else '(no output)'}"
        )

    # If the run crashed (incomplete), restore original files so that markers
    # for tests that never ran are preserved.
    if not results.tests_result and original_per_file:
        restore_originals()
        stripped_per_file.clear()

    prefix_per_file = {f: _directory_test_prefix(f) for f in test_files}
    all_prefixes = set(prefix_per_file.values())
    label = " (new test)" if original_methods_per_file is not None else ""

    total_added = 0
    total_removed = 0
    total_regressions = 0
    all_regressions: list[tuple[str, str, str, str]] = []

    for test_file in editable_files:
        # e.g. "test.test_inspect.test_inspect."
        prefix = prefix_per_file[test_file]
        file_results = TestResult(
            tests_result=results.tests_result,
            tests=[
                t for t in results.tests if _owns_test_path(prefix, t.path, all_prefixes)
            ],
            unexpected_successes=[
                t
                for t in results.unexpected_successes
                if _owns_test_path(prefix, t.path, all_prefixes)
            ],
        )
        all_failing_tests, unexpected_successes, error_messages = collect_test_changes(
            file_results, module_prefix=prefix
        )

        # Determine which failures to mark
        if mark_failure:
            failing_tests = set(all_failing_tests)
        elif original_methods_per_file is not None:
            # Smart mode: only mark NEW test failures
            current_methods = extract_test_methods(
                test_file.read_text(encoding="utf-8")
            )
            original_methods = original_methods_per_file.get(test_file, set())
            failing_tests = all_failing_tests & (current_methods - original_methods)
        else:
            failing_tests = set()

        # Re-mark stripped tests that still fail (restore markers with reasons).
        # Uses inheritance expansion for parent->child mapping.
        stripped = stripped_per_file.get(test_file, set())
        if stripped:
            failing_tests |= _expand_stripped_to_children(
                test_file.read_text(encoding="utf-8"), stripped, all_failing_tests
            )

        regressions = all_failing_tests - failing_tests

        if failing_tests or unexpected_successes:
            if verbose:
                for class_name, method_name in sorted(failing_tests):
                    err_msg = error_messages.get((class_name, method_name), "")
                    err_hint = f" - {err_msg}" if err_msg else ""
                    print(
                        f" {test_file.name}: Marking as failing{label}: "
                        f"{class_name}.{method_name}{err_hint}"
                    )
                for class_name, method_name in sorted(unexpected_successes):
                    print(
                        f" {test_file.name}: Removing expectedFailure: {class_name}.{method_name}"
                    )
            contents = apply_test_changes(
                test_file.read_text(encoding="utf-8"),
                failing_tests,
                unexpected_successes,
                error_messages,
            )
            test_file.write_text(contents, encoding="utf-8")

        # Collect regressions with error messages for later reporting
        for class_name, method_name in regressions:
            err_msg = error_messages.get((class_name, method_name), "")
            all_regressions.append((test_file.name, class_name, method_name, err_msg))

        total_added += len(failing_tests)
        total_removed += len(unexpected_successes)
        total_regressions += len(regressions)

    # Show hints about unmarked failures
    if verbose and total_regressions > 0:
        print(
            f"Hint: {total_regressions} failing tests can be marked with --mark-failure; "
            "but review first and do not blindly mark them all"
        )
        for file_name, class_name, method_name, err_msg in sorted(all_regressions):
            err_hint = f" - {err_msg}" if err_msg else ""
            print(f" {file_name}: {class_name}.{method_name}{err_hint}")

    return total_added, total_removed, total_regressions
def main(argv: list[str] | None = None) -> int:
    """Command-line entry point.

    Args:
        argv: Argument list; None lets argparse read ``sys.argv[1:]``.

    Returns:
        Exit status: 0 on success, 1 when the path is missing or invalid or
        the test run produced no results (the error is printed to stderr).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "path",
        type=pathlib.Path,
        help="Path to test file or directory (e.g., Lib/test/test_foo.py or Lib/test/test_foo/)",
    )
    parser.add_argument(
        "--mark-failure",
        action="store_true",
        help="Also add @expectedFailure to failing tests (default: only remove unexpected successes)",
    )
    parser.add_argument(
        "--build",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Build with cargo (default: enabled)",
    )
    args = parser.parse_args(argv)

    auto_mark = auto_mark_directory if args.path.is_dir() else auto_mark_file
    try:
        num_added, num_removed, _ = auto_mark(
            args.path, mark_failure=args.mark_failure, skip_build=not args.build
        )
    except (FileNotFoundError, ValueError, TestRunError) as e:
        # TestRunError: the run crashed before reporting anything
        print(f"Error: {e}", file=sys.stderr)
        return 1

    if args.mark_failure:
        print(f"Added expectedFailure to {num_added} tests")
    print(f"Removed expectedFailure from {num_removed} tests")
    return 0
if __name__ == "__main__":
    # Run as a script: main()'s return value becomes the process exit status
    raise SystemExit(main())