integrate to lib_updater

This commit is contained in:
Jeong YunWon
2026-01-11 00:18:18 +09:00
parent 3909b18eac
commit 2a1faf4265

View File

@@ -15,12 +15,11 @@ How to use:
"""
import argparse
import ast
import itertools
import platform
import sys
from pathlib import Path
from lib_updater import apply_patches, PatchSpec, UtMethod
def parse_args():
parser = argparse.ArgumentParser(description="Fix test.")
@@ -102,39 +101,16 @@ def path_to_test(path) -> list[str]:
return parts[-2:] # Get class name and method name
def find_test_lineno(file: str, test: list[str]) -> tuple[int, int] | None:
"""Find the line number and column offset of a test function.
Returns (lineno, col_offset) or None if not found.
"""
a = ast.parse(file)
for key, node in ast.iter_fields(a):
if key == "body":
for n in node:
match n:
case ast.ClassDef():
if len(test) == 2 and test[0] == n.name:
for fn in n.body:
match fn:
case ast.FunctionDef() | ast.AsyncFunctionDef():
if fn.name == test[-1]:
return (fn.lineno, fn.col_offset)
case ast.FunctionDef() | ast.AsyncFunctionDef():
if n.name == test[0] and len(test) == 1:
return (n.lineno, n.col_offset)
return None
def apply_modifications(file: str, modifications: list[tuple[int, int]]) -> str:
"""Apply all modifications in reverse order to avoid line number offset issues."""
lines = file.splitlines()
fixture = "@unittest.expectedFailure"
# Sort by line number in descending order
modifications.sort(key=lambda x: x[0], reverse=True)
for lineno, col_offset in modifications:
indent = " " * col_offset
lines.insert(lineno - 1, indent + fixture)
lines.insert(lineno - 1, indent + "# TODO: RUSTPYTHON")
return "\n".join(lines)
def build_patches(test_parts_set: set[tuple[str, str]]) -> dict:
"""Convert failing tests to lib_updater patch format."""
patches = {}
for class_name, method_name in test_parts_set:
if class_name not in patches:
patches[class_name] = {}
patches[class_name][method_name] = [
PatchSpec(UtMethod.ExpectedFailure, None, "")
]
return patches
def run_test(test_name):
@@ -167,26 +143,21 @@ if __name__ == "__main__":
tests = run_test(test_name)
f = test_path.read_text(encoding="utf-8")
# Collect all modifications first (with deduplication for subtests)
modifications = []
# Collect failing tests (with deduplication for subtests)
seen_tests = set() # Track (class_name, method_name) to avoid duplicates
for test in tests.tests:
if test.result == "fail" or test.result == "error":
test_parts = path_to_test(test.path)
test_key = tuple(test_parts)
if test_key in seen_tests:
continue # Skip duplicate (same test, different subtest)
seen_tests.add(test_key)
location = find_test_lineno(f, test_parts)
if location:
print(f"Modifying test: {test.name} at line {location[0]}")
modifications.append(location)
else:
print(f"Warning: Could not find test: {test.name} ({test_parts})")
if len(test_parts) == 2:
test_key = tuple(test_parts)
if test_key not in seen_tests:
seen_tests.add(test_key)
print(f"Marking test: {test_parts[0]}.{test_parts[1]}")
# Apply all modifications in reverse order
if modifications:
f = apply_modifications(f, modifications)
# Apply patches using lib_updater
if seen_tests:
patches = build_patches(seen_tests)
f = apply_patches(f, patches)
test_path.write_text(f, encoding="utf-8")
print(f"Modified {len(modifications)} tests")
print(f"Modified {len(seen_tests)} tests")