auto mark parent tests (#6778)

* don't skip auto-format

* auto mark parent
This commit is contained in:
Jeong, YunWon
2026-01-18 22:06:07 +09:00
committed by GitHub
parent 252fa816d6
commit 2b8fac3af3
3 changed files with 79 additions and 5 deletions

View File

@@ -14,7 +14,6 @@ concurrency:
jobs:
auto_format:
if: ${{ !contains(github.event.pull_request.labels.*.name, 'skip:ci') }}
permissions:
contents: write
pull-requests: write

View File

@@ -83,7 +83,7 @@ def parse_results(result):
test_results.stdout = result.stdout
in_test_results = False
for line in lines:
if re.match(r"Run tests? sequentially", line):
if re.search(r"Run \d+ tests? sequentially", line):
in_test_results = True
elif line.startswith("-----------"):
in_test_results = False
@@ -161,6 +161,66 @@ def is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> boo
return True
def build_inheritance_info(tree: ast.Module) -> tuple[dict, dict]:
"""
Build inheritance information from AST.
Returns:
class_bases: dict[str, list[str]] - parent classes for each class (only those defined in the file)
class_methods: dict[str, set[str]] - methods directly defined in each class
"""
all_classes = {
node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
}
class_bases = {}
class_methods = {}
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
# Collect only parent classes defined in this file
bases = [
base.id
for base in node.bases
if isinstance(base, ast.Name) and base.id in all_classes
]
class_bases[node.name] = bases
# Collect directly defined methods
methods = {
item.name
for item in node.body
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
}
class_methods[node.name] = methods
return class_bases, class_methods
def find_method_definition(
class_name: str, method_name: str, class_bases: dict, class_methods: dict
) -> str | None:
"""Find the class where a method is actually defined. Traverses inheritance chain (BFS)."""
# Check current class first
if method_name in class_methods.get(class_name, set()):
return class_name
# Search parent classes
visited = set()
queue = list(class_bases.get(class_name, []))
while queue:
current = queue.pop(0)
if current in visited:
continue
visited.add(current)
if method_name in class_methods.get(current, set()):
return current
queue.extend(class_bases.get(current, []))
return None
def remove_expected_failures(
contents: str, tests_to_remove: set[tuple[str, str]]
) -> str:
@@ -172,6 +232,18 @@ def remove_expected_failures(
lines = contents.splitlines()
lines_to_remove = set()
# Build inheritance information
class_bases, class_methods = build_inheritance_info(tree)
# Resolve to actual defining classes
resolved_tests = set()
for class_name, method_name in tests_to_remove:
defining_class = find_method_definition(
class_name, method_name, class_bases, class_methods
)
if defining_class:
resolved_tests.add((defining_class, method_name))
for node in ast.walk(tree):
if not isinstance(node, ast.ClassDef):
continue
@@ -180,7 +252,7 @@ def remove_expected_failures(
if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
method_name = item.name
if (class_name, method_name) not in tests_to_remove:
if (class_name, method_name) not in resolved_tests:
continue
# Check if we should remove the entire method (super() call only)

View File

@@ -236,11 +236,14 @@ def build_patch_dict(it: "Iterator[PatchEntry]") -> Patches:
def iter_patch_lines(tree: ast.Module, patches: Patches) -> "Iterator[tuple[int, str]]":
cache = {} # Used in phase 2. Stores the end line location of a class name.
# Build cache of all classes (for Phase 2 to find classes without methods)
cache = {}
for node in tree.body:
if isinstance(node, ast.ClassDef):
cache[node.name] = node.end_lineno
# Phase 1: Iterate and mark existing tests
for cls_node, fn_node in iter_tests(tree):
cache[cls_node.name] = cls_node.end_lineno
specs = patches.get(cls_node.name, {}).pop(fn_node.name, None)
if not specs:
continue