mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Update annotationlib.py to 3.14.5 (#7867)
This commit is contained in:
36
Lib/annotationlib.py
vendored
36
Lib/annotationlib.py
vendored
@@ -47,6 +47,7 @@ _SLOTS = (
|
||||
"__cell__",
|
||||
"__owner__",
|
||||
"__stringifier_dict__",
|
||||
"__resolved_str_cache__",
|
||||
)
|
||||
|
||||
|
||||
@@ -94,6 +95,7 @@ class ForwardRef:
|
||||
# value later.
|
||||
self.__code__ = None
|
||||
self.__ast_node__ = None
|
||||
self.__resolved_str_cache__ = None
|
||||
|
||||
def __init_subclass__(cls, /, *args, **kwds):
|
||||
raise TypeError("Cannot subclass ForwardRef")
|
||||
@@ -113,7 +115,7 @@ class ForwardRef:
|
||||
"""
|
||||
match format:
|
||||
case Format.STRING:
|
||||
return self.__forward_arg__
|
||||
return self.__resolved_str__
|
||||
case Format.VALUE:
|
||||
is_forwardref_format = False
|
||||
case Format.FORWARDREF:
|
||||
@@ -258,6 +260,24 @@ class ForwardRef:
|
||||
"Attempted to access '__forward_arg__' on an uninitialized ForwardRef"
|
||||
)
|
||||
|
||||
@property
|
||||
def __resolved_str__(self):
|
||||
# __forward_arg__ with any names from __extra_names__ replaced
|
||||
# with the type_repr of the value they represent
|
||||
if self.__resolved_str_cache__ is None:
|
||||
resolved_str = self.__forward_arg__
|
||||
names = self.__extra_names__
|
||||
|
||||
if names:
|
||||
visitor = _ExtraNameFixer(names)
|
||||
ast_expr = ast.parse(resolved_str, mode="eval").body
|
||||
node = visitor.visit(ast_expr)
|
||||
resolved_str = ast.unparse(node)
|
||||
|
||||
self.__resolved_str_cache__ = resolved_str
|
||||
|
||||
return self.__resolved_str_cache__
|
||||
|
||||
@property
|
||||
def __forward_code__(self):
|
||||
if self.__code__ is not None:
|
||||
@@ -321,7 +341,7 @@ class ForwardRef:
|
||||
extra.append(", is_class=True")
|
||||
if self.__owner__ is not None:
|
||||
extra.append(f", owner={self.__owner__!r}")
|
||||
return f"ForwardRef({self.__forward_arg__!r}{''.join(extra)})"
|
||||
return f"ForwardRef({self.__resolved_str__!r}{''.join(extra)})"
|
||||
|
||||
|
||||
_Template = type(t"")
|
||||
@@ -357,6 +377,7 @@ class _Stringifier:
|
||||
self.__cell__ = cell
|
||||
self.__owner__ = owner
|
||||
self.__stringifier_dict__ = stringifier_dict
|
||||
self.__resolved_str_cache__ = None # Needed for ForwardRef
|
||||
|
||||
def __convert_to_ast(self, other):
|
||||
if isinstance(other, _Stringifier):
|
||||
@@ -1163,3 +1184,14 @@ def _get_dunder_annotations(obj):
|
||||
if not isinstance(ann, dict):
|
||||
raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
|
||||
return ann
|
||||
|
||||
|
||||
class _ExtraNameFixer(ast.NodeTransformer):
|
||||
"""Fixer for __extra_names__ items in ForwardRef __repr__ and string evaluation"""
|
||||
def __init__(self, extra_names):
|
||||
self.extra_names = extra_names
|
||||
|
||||
def visit_Name(self, node: ast.Name):
|
||||
if (new_name := self.extra_names.get(node.id, _sentinel)) is not _sentinel:
|
||||
node = ast.Name(id=type_repr(new_name))
|
||||
return node
|
||||
|
||||
98
Lib/test/test_annotationlib.py
vendored
98
Lib/test/test_annotationlib.py
vendored
@@ -1619,6 +1619,84 @@ class TestCallAnnotateFunction(unittest.TestCase):
|
||||
# Some non-Format value
|
||||
annotationlib.call_annotate_function(annotate, 7)
|
||||
|
||||
def test_basic_non_function_annotate(self):
|
||||
class Annotate:
|
||||
def __call__(self, format, /, __Format=Format,
|
||||
__NotImplementedError=NotImplementedError):
|
||||
if format == __Format.VALUE:
|
||||
return {'x': str}
|
||||
elif format == __Format.VALUE_WITH_FAKE_GLOBALS:
|
||||
return {'x': int}
|
||||
elif format == __Format.STRING:
|
||||
return {'x': "float"}
|
||||
else:
|
||||
raise __NotImplementedError(format)
|
||||
|
||||
annotations = annotationlib.call_annotate_function(Annotate(), Format.VALUE)
|
||||
self.assertEqual(annotations, {"x": str})
|
||||
|
||||
annotations = annotationlib.call_annotate_function(Annotate(), Format.STRING)
|
||||
self.assertEqual(annotations, {"x": "float"})
|
||||
|
||||
with self.assertRaises(AttributeError) as cm:
|
||||
annotations = annotationlib.call_annotate_function(
|
||||
Annotate(), Format.FORWARDREF
|
||||
)
|
||||
|
||||
self.assertEqual(cm.exception.name, "__builtins__")
|
||||
self.assertIsInstance(cm.exception.obj, Annotate)
|
||||
|
||||
def test_full_non_function_annotate(self):
|
||||
def outer():
|
||||
local = str
|
||||
|
||||
class Annotate:
|
||||
called_formats = []
|
||||
|
||||
def __call__(self, format=None, *, _self=None):
|
||||
nonlocal local
|
||||
if _self is not None:
|
||||
self, format = _self, self
|
||||
|
||||
self.called_formats.append(format)
|
||||
if format == 1: # VALUE
|
||||
return {"x": MyClass, "y": int, "z": local}
|
||||
if format == 2: # VALUE_WITH_FAKE_GLOBALS
|
||||
return {"w": unknown, "x": MyClass, "y": int, "z": local}
|
||||
raise NotImplementedError
|
||||
|
||||
__globals__ = {"MyClass": MyClass}
|
||||
__builtins__ = {"int": int}
|
||||
__closure__ = (types.CellType(str),)
|
||||
__defaults__ = (None,)
|
||||
|
||||
__kwdefaults__ = property(lambda self: dict(_self=self))
|
||||
__code__ = property(lambda self: self.__call__.__code__)
|
||||
|
||||
return Annotate()
|
||||
|
||||
annotate = outer()
|
||||
|
||||
self.assertEqual(
|
||||
annotationlib.call_annotate_function(annotate, Format.VALUE),
|
||||
{"x": MyClass, "y": int, "z": str}
|
||||
)
|
||||
self.assertEqual(annotate.called_formats[-1], Format.VALUE)
|
||||
|
||||
self.assertEqual(
|
||||
annotationlib.call_annotate_function(annotate, Format.STRING),
|
||||
{"w": "unknown", "x": "MyClass", "y": "int", "z": "local"}
|
||||
)
|
||||
self.assertIn(Format.STRING, annotate.called_formats)
|
||||
self.assertEqual(annotate.called_formats[-1], Format.VALUE_WITH_FAKE_GLOBALS)
|
||||
|
||||
self.assertEqual(
|
||||
annotationlib.call_annotate_function(annotate, Format.FORWARDREF),
|
||||
{"w": support.EqualToForwardRef("unknown"), "x": MyClass, "y": int, "z": str}
|
||||
)
|
||||
self.assertIn(Format.FORWARDREF, annotate.called_formats)
|
||||
self.assertEqual(annotate.called_formats[-1], Format.VALUE_WITH_FAKE_GLOBALS)
|
||||
|
||||
def test_error_from_value_raised(self):
|
||||
# Test that the error from format.VALUE is raised
|
||||
# if all formats fail
|
||||
@@ -1961,6 +2039,15 @@ class TestForwardRefClass(unittest.TestCase):
|
||||
"typing.List[ForwardRef('int', owner='class')]",
|
||||
)
|
||||
|
||||
def test_forward_repr_extra_names(self):
|
||||
def f(a: undefined | str): ...
|
||||
|
||||
annos = get_annotations(f, format=Format.FORWARDREF)
|
||||
|
||||
self.assertRegex(
|
||||
repr(annos['a']), r"ForwardRef\('undefined \| str'.*\)"
|
||||
)
|
||||
|
||||
def test_forward_recursion_actually(self):
|
||||
def namespace1():
|
||||
a = ForwardRef("A")
|
||||
@@ -2037,6 +2124,17 @@ class TestForwardRefClass(unittest.TestCase):
|
||||
fr = ForwardRef("set[Any]")
|
||||
self.assertEqual(fr.evaluate(format=Format.STRING), "set[Any]")
|
||||
|
||||
def test_evaluate_string_format_extra_names(self):
|
||||
# Test that internal extra_names are replaced when evaluating as strings
|
||||
def f(a: unknown | str | int | list[str] | tuple[int, ...]): ...
|
||||
|
||||
fr = get_annotations(f, format=Format.FORWARDREF)['a']
|
||||
# Test the cache is not populated before access
|
||||
self.assertIsNone(fr.__resolved_str_cache__)
|
||||
|
||||
self.assertEqual(fr.evaluate(format=Format.STRING), "unknown | str | int | list[str] | tuple[int, ...]")
|
||||
self.assertEqual(fr.__resolved_str_cache__, "unknown | str | int | list[str] | tuple[int, ...]")
|
||||
|
||||
def test_evaluate_forwardref_format(self):
|
||||
fr = ForwardRef("undef")
|
||||
evaluated = fr.evaluate(format=Format.FORWARDREF)
|
||||
|
||||
Reference in New Issue
Block a user