mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
* Make auto-mark output deterministic and fix blank line leak Sort set iteration in build_patches and dict iteration in _iter_patch_lines Phase 2 so expectedFailure markers are always added in alphabetical order. Include preceding blank line in _method_removal_range so removing a super-call override doesn't leave behind the blank line that was added with the method. * deps code * auto-mark handles crash better * Auto-format: ruff format --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
398 lines
12 KiB
Python
398 lines
12 KiB
Python
"""
|
|
Low-level module for converting between test files and JSON patches.
|
|
|
|
This module handles:
|
|
- Extracting patches from test files (file -> JSON)
|
|
- Applying patches to test files (JSON -> file)
|
|
"""
|
|
|
|
import ast
|
|
import collections
|
|
import enum
|
|
import textwrap
|
|
import typing
|
|
|
|
if typing.TYPE_CHECKING:
|
|
from collections.abc import Iterator
|
|
|
|
type Patches = dict[str, dict[str, list["PatchSpec"]]]
|
|
|
|
DEFAULT_INDENT = " " * 4
|
|
COMMENT = "TODO: RUSTPYTHON"
|
|
UT = "unittest"
|
|
|
|
|
|
@enum.unique
|
|
class UtMethod(enum.StrEnum):
|
|
"""
|
|
UnitTest Method.
|
|
"""
|
|
|
|
def _generate_next_value_(name, start, count, last_values) -> str:
|
|
return name[0].lower() + name[1:]
|
|
|
|
def has_args(self) -> bool:
|
|
return self != self.ExpectedFailure
|
|
|
|
def has_cond(self) -> bool:
|
|
return self.endswith(("If", "Unless"))
|
|
|
|
ExpectedFailure = enum.auto()
|
|
ExpectedFailureIf = enum.auto()
|
|
ExpectedFailureIfWindows = enum.auto()
|
|
Skip = enum.auto()
|
|
SkipIf = enum.auto()
|
|
SkipUnless = enum.auto()
|
|
|
|
|
|
class PatchSpec(typing.NamedTuple):
|
|
"""
|
|
Attributes
|
|
----------
|
|
ut_method : UtMethod
|
|
unittest method.
|
|
cond : str, optional
|
|
`ut_method` condition. Relevant only for some of `ut_method` types.
|
|
reason : str, optional
|
|
Reason for why the test is patched in this way.
|
|
"""
|
|
|
|
ut_method: UtMethod
|
|
cond: str | None = None
|
|
reason: str = ""
|
|
|
|
@property
|
|
def _reason(self) -> str:
|
|
return f"{COMMENT}; {self.reason}".strip(" ;")
|
|
|
|
@property
|
|
def _attr_node(self) -> ast.Attribute:
|
|
return ast.Attribute(value=ast.Name(id=UT), attr=self.ut_method)
|
|
|
|
def as_ast_node(self) -> ast.Attribute | ast.Call:
|
|
if not self.ut_method.has_args():
|
|
return self._attr_node
|
|
|
|
args = []
|
|
if self.cond:
|
|
args.append(ast.parse(self.cond).body[0].value)
|
|
args.append(ast.Constant(value=self._reason))
|
|
|
|
return ast.Call(func=self._attr_node, args=args, keywords=[])
|
|
|
|
def as_decorator(self) -> str:
|
|
unparsed = ast.unparse(self.as_ast_node())
|
|
# ast.unparse uses single quotes; convert to double quotes for ruff compatibility
|
|
unparsed = _single_to_double_quotes(unparsed)
|
|
|
|
if not self.ut_method.has_args():
|
|
unparsed = f"{unparsed} # {self._reason}"
|
|
|
|
return f"@{unparsed}"
|
|
|
|
|
|
def _single_to_double_quotes(s: str) -> str:
|
|
"""Convert single-quoted strings to double-quoted strings.
|
|
|
|
Falls back to original if conversion breaks the AST equivalence.
|
|
"""
|
|
import re
|
|
|
|
def replace_string(match: re.Match) -> str:
|
|
content = match.group(1)
|
|
# Unescape single quotes and escape double quotes
|
|
content = content.replace("\\'", "'").replace('"', '\\"')
|
|
return f'"{content}"'
|
|
|
|
# Match single-quoted strings (handles escaped single quotes inside)
|
|
converted = re.sub(r"'((?:[^'\\]|\\.)*)'", replace_string, s)
|
|
|
|
# Verify: parse converted and unparse should equal original
|
|
try:
|
|
converted_ast = ast.parse(converted, mode="eval")
|
|
if ast.unparse(converted_ast) == s:
|
|
return converted
|
|
except SyntaxError:
|
|
pass
|
|
|
|
# Fall back to original if conversion failed
|
|
return s
|
|
|
|
|
|
class PatchEntry(typing.NamedTuple):
|
|
"""
|
|
Stores patch metadata.
|
|
|
|
Attributes
|
|
----------
|
|
parent_class : str
|
|
Parent class of test.
|
|
test_name : str
|
|
Test name.
|
|
spec : PatchSpec
|
|
Patch spec.
|
|
"""
|
|
|
|
parent_class: str
|
|
test_name: str
|
|
spec: PatchSpec
|
|
|
|
@classmethod
|
|
def iter_patch_entries(
|
|
cls, tree: ast.Module, lines: list[str]
|
|
) -> "Iterator[typing.Self]":
|
|
import re
|
|
import sys
|
|
|
|
for cls_node, fn_node in iter_tests(tree):
|
|
parent_class = cls_node.name
|
|
for dec_node in fn_node.decorator_list:
|
|
if not isinstance(dec_node, (ast.Attribute, ast.Call)):
|
|
continue
|
|
|
|
attr_node = (
|
|
dec_node if isinstance(dec_node, ast.Attribute) else dec_node.func
|
|
)
|
|
|
|
if (
|
|
isinstance(attr_node, ast.Name)
|
|
or getattr(attr_node.value, "id", None) != UT
|
|
):
|
|
continue
|
|
|
|
cond = None
|
|
try:
|
|
ut_method = UtMethod(attr_node.attr)
|
|
except ValueError:
|
|
continue
|
|
|
|
# If our ut_method has args then,
|
|
# we need to search for a constant that contains our `COMMENT`.
|
|
# Otherwise we need to search it in the raw source code :/
|
|
if ut_method.has_args():
|
|
reason = next(
|
|
(
|
|
node.value
|
|
for node in ast.walk(dec_node)
|
|
if isinstance(node, ast.Constant)
|
|
and isinstance(node.value, str)
|
|
and COMMENT in node.value
|
|
),
|
|
None,
|
|
)
|
|
|
|
# If we didn't find a constant containing <COMMENT>,
|
|
# then we didn't put this decorator
|
|
if not reason:
|
|
continue
|
|
|
|
if ut_method.has_cond():
|
|
cond = ast.unparse(dec_node.args[0])
|
|
else:
|
|
# Search first on decorator line, then in the line before
|
|
for line in lines[dec_node.lineno - 1 : dec_node.lineno - 3 : -1]:
|
|
if found := re.search(rf"{COMMENT}.?(.*)", line):
|
|
reason = found.group()
|
|
break
|
|
else:
|
|
# Didn't find our `COMMENT` :)
|
|
continue
|
|
|
|
reason = reason.removeprefix(COMMENT).strip(";:, ")
|
|
spec = PatchSpec(ut_method, cond, reason)
|
|
yield cls(parent_class, fn_node.name, spec)
|
|
|
|
|
|
def iter_tests(
|
|
tree: ast.Module,
|
|
) -> "Iterator[tuple[ast.ClassDef, ast.FunctionDef | ast.AsyncFunctionDef]]":
|
|
for key, nodes in ast.iter_fields(tree):
|
|
if key != "body":
|
|
continue
|
|
|
|
for cls_node in nodes:
|
|
if not isinstance(cls_node, ast.ClassDef):
|
|
continue
|
|
|
|
for fn_node in cls_node.body:
|
|
if not isinstance(fn_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
|
continue
|
|
|
|
yield (cls_node, fn_node)
|
|
|
|
|
|
def iter_patches(contents: str) -> "Iterator[PatchEntry]":
|
|
lines = contents.splitlines()
|
|
tree = ast.parse(contents)
|
|
yield from PatchEntry.iter_patch_entries(tree, lines)
|
|
|
|
|
|
def build_patch_dict(it: "Iterator[PatchEntry]") -> Patches:
|
|
patches = collections.defaultdict(lambda: collections.defaultdict(list))
|
|
for entry in it:
|
|
patches[entry.parent_class][entry.test_name].append(entry.spec)
|
|
|
|
return {k: dict(v) for k, v in patches.items()}
|
|
|
|
|
|
def extract_patches(contents: str) -> Patches:
|
|
"""Extract patches from file contents and return as dict."""
|
|
return build_patch_dict(iter_patches(contents))
|
|
|
|
|
|
def _iter_patch_lines(
|
|
tree: ast.Module, patches: Patches
|
|
) -> "Iterator[tuple[int, str]]":
|
|
import sys
|
|
|
|
# Build cache of all classes (for Phase 2 to find classes without methods)
|
|
cache = {}
|
|
# Build per-class set of async method names (for Phase 2 to generate correct override)
|
|
async_methods: dict[str, set[str]] = {}
|
|
# Track class bases for inherited async method lookup
|
|
class_bases: dict[str, list[str]] = {}
|
|
all_classes = {node.name for node in tree.body if isinstance(node, ast.ClassDef)}
|
|
for node in tree.body:
|
|
if isinstance(node, ast.ClassDef):
|
|
cache[node.name] = node.end_lineno
|
|
class_bases[node.name] = [
|
|
base.id
|
|
for base in node.bases
|
|
if isinstance(base, ast.Name) and base.id in all_classes
|
|
]
|
|
cls_async: set[str] = set()
|
|
for item in node.body:
|
|
if isinstance(item, ast.AsyncFunctionDef):
|
|
cls_async.add(item.name)
|
|
if cls_async:
|
|
async_methods[node.name] = cls_async
|
|
|
|
# Phase 1: Iterate and mark existing tests
|
|
for cls_node, fn_node in iter_tests(tree):
|
|
specs = patches.get(cls_node.name, {}).pop(fn_node.name, None)
|
|
if not specs:
|
|
continue
|
|
|
|
lineno = min(
|
|
(dec_node.lineno for dec_node in fn_node.decorator_list),
|
|
default=fn_node.lineno,
|
|
)
|
|
indent = " " * fn_node.col_offset
|
|
patch_lines = "\n".join(spec.as_decorator() for spec in specs)
|
|
yield (lineno - 1, textwrap.indent(patch_lines, indent))
|
|
|
|
# Phase 2: Iterate and mark inherited tests
|
|
for cls_name, tests in sorted(patches.items()):
|
|
lineno = cache.get(cls_name)
|
|
if not lineno:
|
|
print(f"WARNING: {cls_name} does not exist in remote file", file=sys.stderr)
|
|
continue
|
|
|
|
for test_name, specs in sorted(tests.items()):
|
|
decorators = "\n".join(spec.as_decorator() for spec in specs)
|
|
# Check current class and ancestors for async method
|
|
is_async = False
|
|
queue = [cls_name]
|
|
visited: set[str] = set()
|
|
while queue:
|
|
cur = queue.pop(0)
|
|
if cur in visited:
|
|
continue
|
|
visited.add(cur)
|
|
if test_name in async_methods.get(cur, set()):
|
|
is_async = True
|
|
break
|
|
queue.extend(class_bases.get(cur, []))
|
|
if is_async:
|
|
patch_lines = f"""
|
|
{decorators}
|
|
async def {test_name}(self):
|
|
{DEFAULT_INDENT}return await super().{test_name}()
|
|
""".rstrip()
|
|
else:
|
|
patch_lines = f"""
|
|
{decorators}
|
|
def {test_name}(self):
|
|
{DEFAULT_INDENT}return super().{test_name}()
|
|
""".rstrip()
|
|
yield (lineno, textwrap.indent(patch_lines, DEFAULT_INDENT))
|
|
|
|
|
|
def _has_unittest_import(tree: ast.Module) -> bool:
|
|
"""Check if 'import unittest' is already present in the file."""
|
|
for node in tree.body:
|
|
if isinstance(node, ast.Import):
|
|
for alias in node.names:
|
|
if alias.name == UT and alias.asname is None:
|
|
return True
|
|
return False
|
|
|
|
|
|
def _find_import_insert_line(tree: ast.Module) -> int:
|
|
"""Find the line number after the last import statement."""
|
|
last_import_line = None
|
|
for node in tree.body:
|
|
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
|
last_import_line = node.end_lineno or node.lineno
|
|
if last_import_line is not None:
|
|
return last_import_line
|
|
# No imports found - insert after module docstring if present, else at top
|
|
if (
|
|
tree.body
|
|
and isinstance(tree.body[0], ast.Expr)
|
|
and isinstance(tree.body[0].value, ast.Constant)
|
|
and isinstance(tree.body[0].value.value, str)
|
|
):
|
|
return tree.body[0].end_lineno or tree.body[0].lineno
|
|
return 0
|
|
|
|
|
|
def apply_patches(contents: str, patches: Patches) -> str:
|
|
"""Apply patches to file contents and return modified contents."""
|
|
tree = ast.parse(contents)
|
|
lines = contents.splitlines()
|
|
|
|
modifications = list(_iter_patch_lines(tree, patches))
|
|
|
|
# If we have modifications and unittest is not imported, add it
|
|
if modifications and not _has_unittest_import(tree):
|
|
import_line = _find_import_insert_line(tree)
|
|
modifications.append(
|
|
(
|
|
import_line,
|
|
"\nimport unittest # XXX: RUSTPYTHON; importing to be able to skip tests",
|
|
)
|
|
)
|
|
|
|
# Going in reverse to not disrupt the line offset
|
|
for lineno, patch in sorted(modifications, reverse=True):
|
|
lines.insert(lineno, patch)
|
|
|
|
joined = "\n".join(lines)
|
|
return f"{joined}\n"
|
|
|
|
|
|
def patches_to_json(patches: Patches) -> dict:
|
|
"""Convert patches to JSON-serializable dict."""
|
|
return {
|
|
cls_name: {
|
|
test_name: [spec._asdict() for spec in specs]
|
|
for test_name, specs in tests.items()
|
|
}
|
|
for cls_name, tests in patches.items()
|
|
}
|
|
|
|
|
|
def patches_from_json(data: dict) -> Patches:
|
|
"""Convert JSON dict back to Patches."""
|
|
return {
|
|
cls_name: {
|
|
test_name: [
|
|
PatchSpec(**spec)._replace(ut_method=UtMethod(spec["ut_method"]))
|
|
for spec in specs
|
|
]
|
|
for test_name, specs in tests.items()
|
|
}
|
|
for cls_name, tests in data.items()
|
|
}
|