mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
auto mark parent tests (#6778)
* don't skip auto-format * auto mark parent
This commit is contained in:
1
.github/workflows/pr-auto-commit.yaml
vendored
1
.github/workflows/pr-auto-commit.yaml
vendored
@@ -14,7 +14,6 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
auto_format:
|
||||
if: ${{ !contains(github.event.pull_request.labels.*.name, 'skip:ci') }}
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
@@ -83,7 +83,7 @@ def parse_results(result):
|
||||
test_results.stdout = result.stdout
|
||||
in_test_results = False
|
||||
for line in lines:
|
||||
if re.match(r"Run tests? sequentially", line):
|
||||
if re.search(r"Run \d+ tests? sequentially", line):
|
||||
in_test_results = True
|
||||
elif line.startswith("-----------"):
|
||||
in_test_results = False
|
||||
@@ -161,6 +161,66 @@ def is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> boo
|
||||
return True
|
||||
|
||||
|
||||
def build_inheritance_info(tree: ast.Module) -> tuple[dict, dict]:
|
||||
"""
|
||||
Build inheritance information from AST.
|
||||
|
||||
Returns:
|
||||
class_bases: dict[str, list[str]] - parent classes for each class (only those defined in the file)
|
||||
class_methods: dict[str, set[str]] - methods directly defined in each class
|
||||
"""
|
||||
all_classes = {
|
||||
node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
|
||||
}
|
||||
class_bases = {}
|
||||
class_methods = {}
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
# Collect only parent classes defined in this file
|
||||
bases = [
|
||||
base.id
|
||||
for base in node.bases
|
||||
if isinstance(base, ast.Name) and base.id in all_classes
|
||||
]
|
||||
class_bases[node.name] = bases
|
||||
|
||||
# Collect directly defined methods
|
||||
methods = {
|
||||
item.name
|
||||
for item in node.body
|
||||
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
|
||||
}
|
||||
class_methods[node.name] = methods
|
||||
|
||||
return class_bases, class_methods
|
||||
|
||||
|
||||
def find_method_definition(
|
||||
class_name: str, method_name: str, class_bases: dict, class_methods: dict
|
||||
) -> str | None:
|
||||
"""Find the class where a method is actually defined. Traverses inheritance chain (BFS)."""
|
||||
# Check current class first
|
||||
if method_name in class_methods.get(class_name, set()):
|
||||
return class_name
|
||||
|
||||
# Search parent classes
|
||||
visited = set()
|
||||
queue = list(class_bases.get(class_name, []))
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
if current in visited:
|
||||
continue
|
||||
visited.add(current)
|
||||
|
||||
if method_name in class_methods.get(current, set()):
|
||||
return current
|
||||
queue.extend(class_bases.get(current, []))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def remove_expected_failures(
|
||||
contents: str, tests_to_remove: set[tuple[str, str]]
|
||||
) -> str:
|
||||
@@ -172,6 +232,18 @@ def remove_expected_failures(
|
||||
lines = contents.splitlines()
|
||||
lines_to_remove = set()
|
||||
|
||||
# Build inheritance information
|
||||
class_bases, class_methods = build_inheritance_info(tree)
|
||||
|
||||
# Resolve to actual defining classes
|
||||
resolved_tests = set()
|
||||
for class_name, method_name in tests_to_remove:
|
||||
defining_class = find_method_definition(
|
||||
class_name, method_name, class_bases, class_methods
|
||||
)
|
||||
if defining_class:
|
||||
resolved_tests.add((defining_class, method_name))
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.ClassDef):
|
||||
continue
|
||||
@@ -180,7 +252,7 @@ def remove_expected_failures(
|
||||
if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
continue
|
||||
method_name = item.name
|
||||
if (class_name, method_name) not in tests_to_remove:
|
||||
if (class_name, method_name) not in resolved_tests:
|
||||
continue
|
||||
|
||||
# Check if we should remove the entire method (super() call only)
|
||||
|
||||
@@ -236,11 +236,14 @@ def build_patch_dict(it: "Iterator[PatchEntry]") -> Patches:
|
||||
|
||||
|
||||
def iter_patch_lines(tree: ast.Module, patches: Patches) -> "Iterator[tuple[int, str]]":
|
||||
cache = {} # Used in phase 2. Stores the end line location of a class name.
|
||||
# Build cache of all classes (for Phase 2 to find classes without methods)
|
||||
cache = {}
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.ClassDef):
|
||||
cache[node.name] = node.end_lineno
|
||||
|
||||
# Phase 1: Iterate and mark existing tests
|
||||
for cls_node, fn_node in iter_tests(tree):
|
||||
cache[cls_node.name] = cls_node.end_lineno
|
||||
specs = patches.get(cls_node.name, {}).pop(fn_node.name, None)
|
||||
if not specs:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user