Update dataclasses.py from 3.14.3

This commit is contained in:
ShaharNaveh
2026-02-08 10:03:57 +02:00
committed by Jeong, YunWon
parent d2d8eeea2f
commit 32b57785c3
2 changed files with 705 additions and 102 deletions

292
Lib/dataclasses.py vendored
View File

@@ -5,6 +5,7 @@ import types
import inspect
import keyword
import itertools
import annotationlib
import abc
from reprlib import recursive_repr
@@ -243,6 +244,10 @@ _ATOMIC_TYPES = frozenset({
property,
})
# Any marker is used in `make_dataclass` to mark unannotated fields as `Any`
# without importing `typing` module.
_ANY_MARKER = object()
class InitVar:
__slots__ = ('type', )
@@ -282,11 +287,12 @@ class Field:
'compare',
'metadata',
'kw_only',
'doc',
'_field_type', # Private: not to be used by user code.
)
def __init__(self, default, default_factory, init, repr, hash, compare,
metadata, kw_only):
metadata, kw_only, doc):
self.name = None
self.type = None
self.default = default
@@ -299,6 +305,7 @@ class Field:
if metadata is None else
types.MappingProxyType(metadata))
self.kw_only = kw_only
self.doc = doc
self._field_type = None
@recursive_repr()
@@ -314,6 +321,7 @@ class Field:
f'compare={self.compare!r},'
f'metadata={self.metadata!r},'
f'kw_only={self.kw_only!r},'
f'doc={self.doc!r},'
f'_field_type={self._field_type}'
')')
@@ -381,7 +389,7 @@ class _DataclassParams:
# so that a type checker can be told (via overloads) that this is a
# function whose type depends on its parameters.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
hash=None, compare=True, metadata=None, kw_only=MISSING):
hash=None, compare=True, metadata=None, kw_only=MISSING, doc=None):
"""Return an object to identify dataclass fields.
default is the default value of the field. default_factory is a
@@ -393,7 +401,7 @@ def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
comparison functions. metadata, if specified, must be a mapping
which is stored but not otherwise examined by dataclass. If kw_only
is true, the field will become a keyword-only parameter to
__init__().
__init__(). doc is an optional docstring for this field.
It is an error to specify both default and default_factory.
"""
@@ -401,7 +409,7 @@ def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
if default is not MISSING and default_factory is not MISSING:
raise ValueError('cannot specify both default and default_factory')
return Field(default, default_factory, init, repr, hash, compare,
metadata, kw_only)
metadata, kw_only, doc)
def _fields_in_init_order(fields):
@@ -433,9 +441,11 @@ class _FuncBuilder:
self.locals = {}
self.overwrite_errors = {}
self.unconditional_adds = {}
self.method_annotations = {}
def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
overwrite_error=False, unconditional_add=False, decorator=None):
overwrite_error=False, unconditional_add=False, decorator=None,
annotation_fields=None):
if locals is not None:
self.locals.update(locals)
@@ -456,16 +466,14 @@ class _FuncBuilder:
self.names.append(name)
if return_type is not MISSING:
self.locals[f'__dataclass_{name}_return_type__'] = return_type
return_annotation = f'->__dataclass_{name}_return_type__'
else:
return_annotation = ''
if annotation_fields is not None:
self.method_annotations[name] = (annotation_fields, return_type)
args = ','.join(args)
body = '\n'.join(body)
# Compute the text of the entire function, add it to the text we're generating.
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}):\n{body}')
def add_fns_to_class(self, cls):
# The source to all of the functions we're generating.
@@ -501,6 +509,15 @@ class _FuncBuilder:
# Now that we've generated the functions, assign them into cls.
for name, fn in zip(self.names, fns):
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
try:
annotation_fields, return_type = self.method_annotations[name]
except KeyError:
pass
else:
annotate_fn = _make_annotate_function(cls, name, annotation_fields, return_type)
fn.__annotate__ = annotate_fn
if self.unconditional_adds.get(name, False):
setattr(cls, name, fn)
else:
@@ -516,6 +533,49 @@ class _FuncBuilder:
raise TypeError(error_msg)
def _make_annotate_function(__class__, method_name, annotation_fields, return_type):
# Create an __annotate__ function for a dataclass
# Try to return annotations in the same format as they would be
# from a regular __init__ function
def __annotate__(format, /):
Format = annotationlib.Format
match format:
case Format.VALUE | Format.FORWARDREF | Format.STRING:
cls_annotations = {}
for base in reversed(__class__.__mro__):
cls_annotations.update(
annotationlib.get_annotations(base, format=format)
)
new_annotations = {}
for k in annotation_fields:
# gh-142214: The annotation may be missing in unusual dynamic cases.
# If so, just skip it.
try:
new_annotations[k] = cls_annotations[k]
except KeyError:
pass
if return_type is not MISSING:
if format == Format.STRING:
new_annotations["return"] = annotationlib.type_repr(return_type)
else:
new_annotations["return"] = return_type
return new_annotations
case _:
raise NotImplementedError(format)
# This is a flag for _add_slots to know it needs to regenerate this method
# In order to remove references to the original class when it is replaced
__annotate__.__generated_by_dataclasses__ = True
__annotate__.__qualname__ = f"{__class__.__qualname__}.{method_name}.__annotate__"
return __annotate__
def _field_assign(frozen, name, value, self_name):
# If we're a frozen class, then assign to our fields in __init__
# via object.__setattr__. Otherwise, just use a simple
@@ -604,7 +664,7 @@ def _init_param(f):
elif f.default_factory is not MISSING:
# There's a factory function. Set a marker.
default = '=__dataclass_HAS_DEFAULT_FACTORY__'
return f'{f.name}:__dataclass_type_{f.name}__{default}'
return f'{f.name}{default}'
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
@@ -627,11 +687,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
raise TypeError(f'non-default argument {f.name!r} '
f'follows default argument {seen_default.name!r}')
locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
'__dataclass_builtins_object__': object,
}
}
annotation_fields = [f.name for f in fields if f.init]
locals = {'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
'__dataclass_builtins_object__': object}
body_lines = []
for f in fields:
@@ -655,14 +714,15 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
if kw_only_fields:
# Add the keyword-only args. Because the * can only be added if
# there's at least one keyword-only arg, there needs to be a test here
# (instead of just concatenting the lists together).
# (instead of just concatenating the lists together).
_init_params += ['*']
_init_params += [_init_param(f) for f in kw_only_fields]
func_builder.add_fn('__init__',
[self_name] + _init_params,
body_lines,
locals=locals,
return_type=None)
return_type=None,
annotation_fields=annotation_fields)
def _frozen_get_del_attr(cls, fields, func_builder):
@@ -689,11 +749,8 @@ def _frozen_get_del_attr(cls, fields, func_builder):
def _is_classvar(a_type, typing):
# This test uses a typing internal class, but it's the best way to
# test if this is a ClassVar.
return (a_type is typing.ClassVar
or (type(a_type) is typing._GenericAlias
and a_type.__origin__ is typing.ClassVar))
or (typing.get_origin(a_type) is typing.ClassVar))
def _is_initvar(a_type, dataclasses):
@@ -981,7 +1038,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
# actual default value. Pseudo-fields ClassVars and InitVars are
# included, despite the fact that they're not real fields. That's
# dealt with later.
cls_annotations = inspect.get_annotations(cls)
cls_annotations = annotationlib.get_annotations(
cls, format=annotationlib.Format.FORWARDREF)
# Now find fields in our class. While doing so, validate some
# things, and set the default values (as class attributes) where
@@ -1161,7 +1219,10 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
try:
# In some cases fetching a signature is not possible.
# But, we surely should not fail in this case.
text_sig = str(inspect.signature(cls)).replace(' -> None', '')
text_sig = str(inspect.signature(
cls,
annotation_format=annotationlib.Format.FORWARDREF,
)).replace(' -> None', '')
except (TypeError, ValueError):
text_sig = ''
cls.__doc__ = (cls.__name__ + text_sig)
@@ -1175,7 +1236,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
if weakref_slot and not slots:
raise TypeError('weakref_slot is True but slots is False')
if slots:
cls = _add_slots(cls, frozen, weakref_slot)
cls = _add_slots(cls, frozen, weakref_slot, fields)
abc.update_abstractmethods(cls)
@@ -1219,14 +1280,65 @@ def _get_slots(cls):
raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")
def _add_slots(cls, is_frozen, weakref_slot):
# Need to create a new class, since we can't set __slots__
# after a class has been created.
def _update_func_cell_for__class__(f, oldcls, newcls):
# Returns True if we update a cell, else False.
if f is None:
# f will be None in the case of a property where not all of
# fget, fset, and fdel are used. Nothing to do in that case.
return False
try:
idx = f.__code__.co_freevars.index("__class__")
except ValueError:
# This function doesn't reference __class__, so nothing to do.
return False
# Fix the cell to point to the new class, if it's already pointing
# at the old class. I'm not convinced that the "is oldcls" test
# is needed, but other than performance can't hurt.
closure = f.__closure__[idx]
if closure.cell_contents is oldcls:
closure.cell_contents = newcls
return True
return False
def _create_slots(defined_fields, inherited_slots, field_names, weakref_slot):
# The slots for our class. Remove slots from our base classes. Add
# '__weakref__' if weakref_slot was given, unless it is already present.
seen_docs = False
slots = {}
for slot in itertools.filterfalse(
inherited_slots.__contains__,
itertools.chain(
# gh-93521: '__weakref__' also needs to be filtered out if
# already present in inherited_slots
field_names, ('__weakref__',) if weakref_slot else ()
)
):
doc = getattr(defined_fields.get(slot), 'doc', None)
if doc is not None:
seen_docs = True
slots[slot] = doc
# We only return dict if there's at least one doc member,
# otherwise we return tuple, which is the old default format.
if seen_docs:
return slots
return tuple(slots)
def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
# Need to create a new class, since we can't set __slots__ after a
# class has been created, and the @dataclass decorator is called
# after the class is created.
# Make sure __slots__ isn't already set.
if '__slots__' in cls.__dict__:
raise TypeError(f'{cls.__name__} already specifies __slots__')
# gh-102069: Remove existing __weakref__ descriptor.
# gh-135228: Make sure the original class can be garbage collected.
sys._clear_type_descriptors(cls)
# Create a new dict for our new class.
cls_dict = dict(cls.__dict__)
field_names = tuple(f.name for f in fields(cls))
@@ -1234,17 +1346,9 @@ def _add_slots(cls, is_frozen, weakref_slot):
inherited_slots = set(
itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1]))
)
# The slots for our class. Remove slots from our base classes. Add
# '__weakref__' if weakref_slot was given, unless it is already present.
cls_dict["__slots__"] = tuple(
itertools.filterfalse(
inherited_slots.__contains__,
itertools.chain(
# gh-93521: '__weakref__' also needs to be filtered out if
# already present in inherited_slots
field_names, ('__weakref__',) if weakref_slot else ()
)
),
cls_dict["__slots__"] = _create_slots(
defined_fields, inherited_slots, field_names, weakref_slot,
)
for field_name in field_names:
@@ -1252,26 +1356,59 @@ def _add_slots(cls, is_frozen, weakref_slot):
# available in _MARKER.
cls_dict.pop(field_name, None)
# Remove __dict__ itself.
cls_dict.pop('__dict__', None)
# Clear existing `__weakref__` descriptor, it belongs to a previous type:
cls_dict.pop('__weakref__', None) # gh-102069
# And finally create the class.
qualname = getattr(cls, '__qualname__', None)
cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
newcls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
if qualname is not None:
cls.__qualname__ = qualname
newcls.__qualname__ = qualname
if is_frozen:
# Need this for pickling frozen classes with slots.
if '__getstate__' not in cls_dict:
cls.__getstate__ = _dataclass_getstate
newcls.__getstate__ = _dataclass_getstate
if '__setstate__' not in cls_dict:
cls.__setstate__ = _dataclass_setstate
newcls.__setstate__ = _dataclass_setstate
return cls
# Fix up any closures which reference __class__. This is used to
# fix zero argument super so that it points to the correct class
# (the newly created one, which we're returning) and not the
# original class. We can break out of this loop as soon as we
# make an update, since all closures for a class will share a
# given cell.
for member in newcls.__dict__.values():
# If this is a wrapped function, unwrap it.
member = inspect.unwrap(member)
if isinstance(member, types.FunctionType):
if _update_func_cell_for__class__(member, cls, newcls):
break
elif isinstance(member, property):
if (_update_func_cell_for__class__(member.fget, cls, newcls)
or _update_func_cell_for__class__(member.fset, cls, newcls)
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
break
# Get new annotations to remove references to the original class
# in forward references
newcls_ann = annotationlib.get_annotations(
newcls, format=annotationlib.Format.FORWARDREF)
# Fix references in dataclass Fields
for f in getattr(newcls, _FIELDS).values():
try:
ann = newcls_ann[f.name]
except KeyError:
pass
else:
f.type = ann
# Fix the class reference in the __annotate__ method
init = newcls.__init__
if init_annotate := getattr(init, "__annotate__", None):
if getattr(init_annotate, "__generated_by_dataclasses__", False):
_update_func_cell_for__class__(init_annotate, cls, newcls)
return newcls
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
@@ -1490,7 +1627,7 @@ def _astuple_inner(obj, tuple_factory):
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
repr=True, eq=True, order=False, unsafe_hash=False,
frozen=False, match_args=True, kw_only=False, slots=False,
weakref_slot=False, module=None):
weakref_slot=False, module=None, decorator=dataclass):
"""Return a new dynamically created dataclass.
The dataclass name will be 'cls_name'. 'fields' is an iterable
@@ -1528,7 +1665,7 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
for item in fields:
if isinstance(item, str):
name = item
tp = 'typing.Any'
tp = _ANY_MARKER
elif len(item) == 2:
name, tp, = item
elif len(item) == 3:
@@ -1547,15 +1684,49 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
seen.add(name)
annotations[name] = tp
# We initially block the VALUE format, because inside dataclass() we'll
# call get_annotations(), which will try the VALUE format first. If we don't
# block, that means we'd always end up eagerly importing typing here, which
# is what we're trying to avoid.
value_blocked = True
def annotate_method(format):
def get_any():
match format:
case annotationlib.Format.STRING:
return 'typing.Any'
case annotationlib.Format.FORWARDREF:
typing = sys.modules.get("typing")
if typing is None:
return annotationlib.ForwardRef("Any", module="typing")
else:
return typing.Any
case annotationlib.Format.VALUE:
if value_blocked:
raise NotImplementedError
from typing import Any
return Any
case _:
raise NotImplementedError
annos = {
ann: get_any() if t is _ANY_MARKER else t
for ann, t in annotations.items()
}
if format == annotationlib.Format.STRING:
return annotationlib.annotations_to_string(annos)
else:
return annos
# Update 'ns' with the user-supplied namespace plus our calculated values.
def exec_body_callback(ns):
ns.update(namespace)
ns.update(defaults)
ns['__annotations__'] = annotations
# We use `types.new_class()` instead of simply `type()` to allow dynamic creation
# of generic dataclasses.
cls = types.new_class(cls_name, bases, {}, exec_body_callback)
# For now, set annotations including the _ANY_MARKER.
cls.__annotate__ = annotate_method
# For pickling to work, the __module__ variable needs to be set to the frame
# where the dataclass is created.
@@ -1570,11 +1741,14 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
if module is not None:
cls.__module__ = module
# Apply the normal decorator.
return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
unsafe_hash=unsafe_hash, frozen=frozen,
match_args=match_args, kw_only=kw_only, slots=slots,
weakref_slot=weakref_slot)
# Apply the normal provided decorator.
cls = decorator(cls, init=init, repr=repr, eq=eq, order=order,
unsafe_hash=unsafe_hash, frozen=frozen,
match_args=match_args, kw_only=kw_only, slots=slots,
weakref_slot=weakref_slot)
# Now that the class is ready, allow the VALUE format.
value_blocked = False
return cls
def replace(obj, /, **changes):

View File

@@ -5,6 +5,7 @@
from dataclasses import *
import abc
import annotationlib
import io
import pickle
import inspect
@@ -12,18 +13,21 @@ import builtins
import types
import weakref
import traceback
import sys
import textwrap
import unittest
from unittest.mock import Mock
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict
from typing import get_type_hints
from collections import deque, OrderedDict, namedtuple, defaultdict
from copy import deepcopy
from functools import total_ordering
from functools import total_ordering, wraps
import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
from test import support
from test.support import import_helper
# Just any custom exception we can catch.
class CustomError(Exception): pass
@@ -61,7 +65,7 @@ class TestCase(unittest.TestCase):
x: int = field(default=1, default_factory=int)
def test_field_repr(self):
int_field = field(default=1, init=True, repr=False)
int_field = field(default=1, init=True, repr=False, doc='Docstring')
int_field.name = "id"
repr_output = repr(int_field)
expected_output = "Field(name='id',type=None," \
@@ -69,6 +73,7 @@ class TestCase(unittest.TestCase):
"init=True,repr=False,hash=None," \
"compare=True,metadata=mappingproxy({})," \
f"kw_only={MISSING!r}," \
"doc='Docstring'," \
"_field_type=None)"
self.assertEqual(repr_output, expected_output)
@@ -115,7 +120,7 @@ class TestCase(unittest.TestCase):
for param in inspect.signature(dataclass).parameters:
if param == 'cls':
continue
self.assertTrue(hasattr(Some.__dataclass_params__, param), msg=param)
self.assertHasAttr(Some.__dataclass_params__, param)
def test_named_init_params(self):
@dataclass
@@ -666,7 +671,7 @@ class TestCase(unittest.TestCase):
self.assertEqual(the_fields[0].name, 'x')
self.assertEqual(the_fields[0].type, int)
self.assertFalse(hasattr(C, 'x'))
self.assertNotHasAttr(C, 'x')
self.assertTrue (the_fields[0].init)
self.assertTrue (the_fields[0].repr)
self.assertEqual(the_fields[1].name, 'y')
@@ -676,7 +681,7 @@ class TestCase(unittest.TestCase):
self.assertTrue (the_fields[1].repr)
self.assertEqual(the_fields[2].name, 'z')
self.assertEqual(the_fields[2].type, str)
self.assertFalse(hasattr(C, 'z'))
self.assertNotHasAttr(C, 'z')
self.assertTrue (the_fields[2].init)
self.assertFalse(the_fields[2].repr)
@@ -727,8 +732,8 @@ class TestCase(unittest.TestCase):
z: object = default
t: int = field(default=100)
self.assertFalse(hasattr(C, 'x'))
self.assertFalse(hasattr(C, 'y'))
self.assertNotHasAttr(C, 'x')
self.assertNotHasAttr(C, 'y')
self.assertIs (C.z, default)
self.assertEqual(C.t, 100)
@@ -922,6 +927,20 @@ class TestCase(unittest.TestCase):
validate_class(C)
def test_incomplete_annotations(self):
# gh-142214
@dataclass
class C:
"doc" # needed because otherwise we fetch the annotations at the wrong time
x: int
C.__annotate__ = lambda _: {}
self.assertEqual(
annotationlib.get_annotations(C.__init__),
{"return": None}
)
def test_missing_default(self):
# Test that MISSING works the same as a default not being
# specified.
@@ -1776,8 +1795,7 @@ class TestCase(unittest.TestCase):
self.assertIsNot(d['f'], t)
self.assertEqual(d['f'].my_a(), 6)
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_helper_asdict_defaultdict(self):
# Ensure asdict() does not throw exceptions when a
# defaultdict is a member of a dataclass
@@ -1920,8 +1938,7 @@ class TestCase(unittest.TestCase):
t = astuple(c, tuple_factory=list)
self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_helper_astuple_defaultdict(self):
# Ensure astuple() does not throw exceptions when a
# defaultdict is a member of a dataclass
@@ -2311,13 +2328,12 @@ class TestDocString(unittest.TestCase):
self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_docstring_one_field_with_default_none(self):
@dataclass
class C:
x: Union[int, type(None)] = None
self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")
self.assertDocStrEqual(C.__doc__, "C(x:int|None=None)")
def test_docstring_list_field(self):
@dataclass
@@ -2347,6 +2363,31 @@ class TestDocString(unittest.TestCase):
self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
def test_docstring_undefined_name(self):
@dataclass
class C:
x: undef
self.assertDocStrEqual(C.__doc__, "C(x:undef)")
def test_docstring_with_unsolvable_forward_ref_in_init(self):
# See: https://github.com/python/cpython/issues/128184
ns = {}
exec(
textwrap.dedent(
"""
from dataclasses import dataclass
@dataclass
class C:
def __init__(self, x: X, num: int) -> None: ...
""",
),
ns,
)
self.assertDocStrEqual(ns['C'].__doc__, "C(x:X,num:int)")
def test_docstring_with_no_signature(self):
# See https://github.com/python/cpython/issues/103449
class Meta(type):
@@ -2446,6 +2487,149 @@ class TestInit(unittest.TestCase):
self.assertEqual(D(5).a, 10)
class TestInitAnnotate(unittest.TestCase):
# Tests for the generated __annotate__ function for __init__
# See: https://github.com/python/cpython/issues/137530
def test_annotate_function(self):
# No forward references
@dataclass
class A:
a: int
value_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.VALUE)
forwardref_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.FORWARDREF)
string_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.STRING)
self.assertEqual(value_annos, {'a': int, 'return': None})
self.assertEqual(forwardref_annos, {'a': int, 'return': None})
self.assertEqual(string_annos, {'a': 'int', 'return': 'None'})
self.assertTrue(getattr(A.__init__.__annotate__, "__generated_by_dataclasses__"))
def test_annotate_function_forwardref(self):
# With forward references
@dataclass
class B:
b: undefined
# VALUE annotations should raise while unresolvable
with self.assertRaises(NameError):
_ = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)
self.assertEqual(forwardref_annos, {'b': support.EqualToForwardRef('undefined', owner=B, is_class=True), 'return': None})
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
# Now VALUE and FORWARDREF should resolve, STRING should be unchanged
undefined = int
value_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)
self.assertEqual(value_annos, {'b': int, 'return': None})
self.assertEqual(forwardref_annos, {'b': int, 'return': None})
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
def test_annotate_function_init_false(self):
# Check `init=False` attributes don't get into the annotations of the __init__ function
@dataclass
class C:
c: str = field(init=False)
self.assertEqual(annotationlib.get_annotations(C.__init__), {'return': None})
def test_annotate_function_contains_forwardref(self):
# Check string annotations on objects containing a ForwardRef
@dataclass
class D:
d: list[undefined]
with self.assertRaises(NameError):
annotationlib.get_annotations(D.__init__)
self.assertEqual(
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
{"d": list[support.EqualToForwardRef("undefined", is_class=True, owner=D)], "return": None}
)
self.assertEqual(
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
{"d": "list[undefined]", "return": "None"}
)
# Now test when it is defined
undefined = str
# VALUE should now resolve
self.assertEqual(
annotationlib.get_annotations(D.__init__),
{"d": list[str], "return": None}
)
self.assertEqual(
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
{"d": list[str], "return": None}
)
self.assertEqual(
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
{"d": "list[undefined]", "return": "None"}
)
def test_annotate_function_not_replaced(self):
# Check that __annotate__ is not replaced on non-generated __init__ functions
@dataclass(slots=True)
class E:
x: str
def __init__(self, x: int) -> None:
self.x = x
self.assertEqual(
annotationlib.get_annotations(E.__init__), {"x": int, "return": None}
)
self.assertFalse(hasattr(E.__init__.__annotate__, "__generated_by_dataclasses__"))
def test_slots_true_init_false(self):
# Test that slots=True and init=False work together and
# that __annotate__ is not added to __init__.
@dataclass(slots=True, init=False)
class F:
x: int
f = F()
f.x = 10
self.assertEqual(f.x, 10)
self.assertFalse(hasattr(F.__init__, "__annotate__"))
def test_init_false_forwardref(self):
# Test forward references in fields not required for __init__ annotations.
# At the moment this raises a NameError for VALUE annotations even though the
# undefined annotation is not required for the __init__ annotations.
# Ideally this will be fixed but currently there is no good way to resolve this
@dataclass
class F:
not_in_init: list[undefined] = field(init=False, default=None)
in_init: int
annos = annotationlib.get_annotations(F.__init__, format=annotationlib.Format.FORWARDREF)
self.assertEqual(
annos,
{"in_init": int, "return": None},
)
with self.assertRaises(NameError):
annos = annotationlib.get_annotations(F.__init__) # NameError on not_in_init
class TestRepr(unittest.TestCase):
def test_repr(self):
@dataclass
@@ -2887,10 +3071,10 @@ class TestFrozen(unittest.TestCase):
pass
c = C()
self.assertFalse(hasattr(c, 'i'))
self.assertNotHasAttr(c, 'i')
with self.assertRaises(FrozenInstanceError):
c.i = 5
self.assertFalse(hasattr(c, 'i'))
self.assertNotHasAttr(c, 'i')
with self.assertRaises(FrozenInstanceError):
del c.i
@@ -3119,7 +3303,7 @@ class TestFrozen(unittest.TestCase):
del s.y
self.assertEqual(s.y, 10)
del s.cached
self.assertFalse(hasattr(s, 'cached'))
self.assertNotHasAttr(s, 'cached')
with self.assertRaises(AttributeError) as cm:
del s.cached
self.assertNotIsInstance(cm.exception, FrozenInstanceError)
@@ -3133,12 +3317,12 @@ class TestFrozen(unittest.TestCase):
pass
s = S()
self.assertFalse(hasattr(s, 'x'))
self.assertNotHasAttr(s, 'x')
s.x = 5
self.assertEqual(s.x, 5)
del s.x
self.assertFalse(hasattr(s, 'x'))
self.assertNotHasAttr(s, 'x')
with self.assertRaises(AttributeError) as cm:
del s.x
self.assertNotIsInstance(cm.exception, FrozenInstanceError)
@@ -3309,7 +3493,7 @@ class TestSlots(unittest.TestCase):
j: str
h: str
self.assertEqual(Base.__slots__, ('y', ))
self.assertEqual(Base.__slots__, ('y',))
@dataclass(slots=True)
class Derived(Base):
@@ -3319,7 +3503,7 @@ class TestSlots(unittest.TestCase):
k: str
h: str
self.assertEqual(Derived.__slots__, ('z', ))
self.assertEqual(Derived.__slots__, ('z',))
@dataclass
class AnotherDerived(Base):
@@ -3327,6 +3511,24 @@ class TestSlots(unittest.TestCase):
self.assertNotIn('__slots__', AnotherDerived.__dict__)
def test_slots_with_docs(self):
class Root:
__slots__ = {'x': 'x'}
@dataclass(slots=True)
class Base(Root):
y1: int = field(doc='y1')
y2: int
self.assertEqual(Base.__slots__, {'y1': 'y1', 'y2': None})
@dataclass(slots=True)
class Child(Base):
z1: int = field(doc='z1')
z2: int
self.assertEqual(Child.__slots__, {'z1': 'z1', 'z2': None})
def test_cant_inherit_from_iterator_slots(self):
class Root:
@@ -3350,8 +3552,8 @@ class TestSlots(unittest.TestCase):
B = dataclass(A, slots=True)
self.assertIsNot(A, B)
self.assertFalse(hasattr(A, "__slots__"))
self.assertTrue(hasattr(B, "__slots__"))
self.assertNotHasAttr(A, "__slots__")
self.assertHasAttr(B, "__slots__")
# Can't be local to test_frozen_pickle.
@dataclass(frozen=True, slots=True)
@@ -3470,8 +3672,7 @@ class TestSlots(unittest.TestCase):
self.assertEqual(obj.a, 'a')
self.assertEqual(obj.b, 'b')
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_slots_no_weakref(self):
@dataclass(slots=True)
class A:
@@ -3486,8 +3687,7 @@ class TestSlots(unittest.TestCase):
with self.assertRaises(AttributeError):
a.__weakref__
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_slots_weakref(self):
@dataclass(slots=True, weakref_slot=True)
class A:
@@ -3548,8 +3748,7 @@ class TestSlots(unittest.TestCase):
"weakref_slot is True but slots is False"):
B = make_dataclass('B', [('a', int),], weakref_slot=True)
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_weakref_slot_subclass_weakref_slot(self):
@dataclass(slots=True, weakref_slot=True)
class Base:
@@ -3568,8 +3767,7 @@ class TestSlots(unittest.TestCase):
a_ref = weakref.ref(a)
self.assertIs(a.__weakref__, a_ref)
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_weakref_slot_subclass_no_weakref_slot(self):
@dataclass(slots=True, weakref_slot=True)
class Base:
@@ -3587,8 +3785,7 @@ class TestSlots(unittest.TestCase):
a_ref = weakref.ref(a)
self.assertIs(a.__weakref__, a_ref)
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_weakref_slot_normal_base_weakref_slot(self):
class Base:
__slots__ = ('__weakref__',)
@@ -3633,8 +3830,7 @@ class TestSlots(unittest.TestCase):
self.assertTrue(B.__weakref__)
B()
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_dataclass_derived_generic_from_base(self):
T = typing.TypeVar('T')
@@ -3726,7 +3922,6 @@ class TestSlots(unittest.TestCase):
@support.cpython_only
def test_dataclass_slot_dict_ctype(self):
# https://github.com/python/cpython/issues/123935
from test.support import import_helper
# Skips test if `_testcapi` is not present:
_testcapi = import_helper.import_module('_testcapi')
@@ -3774,6 +3969,50 @@ class TestSlots(unittest.TestCase):
# that we create internally.
self.assertEqual(CorrectSuper.args, ["default", "default"])
@unittest.skip("TODO: RUSTPYTHON; Crash - static type name must be already interned but async_generator_wrapped_value is not")
def test_original_class_is_gced(self):
# gh-135228: Make sure when we replace the class with slots=True, the original class
# gets garbage collected.
def make_simple():
@dataclass(slots=True)
class SlotsTest:
pass
return SlotsTest
def make_with_annotations():
@dataclass(slots=True)
class SlotsTest:
x: int
return SlotsTest
def make_with_annotations_and_method():
@dataclass(slots=True)
class SlotsTest:
x: int
def method(self) -> int:
return self.x
return SlotsTest
def make_with_forwardref():
@dataclass(slots=True)
class SlotsTest:
x: undefined
y: list[undefined]
return SlotsTest
for make in (make_simple, make_with_annotations, make_with_annotations_and_method, make_with_forwardref):
with self.subTest(make=make):
C = make()
support.gc_collect()
candidates = [cls for cls in object.__subclasses__() if cls.__name__ == 'SlotsTest'
and cls.__firstlineno__ == make.__code__.co_firstlineno + 1]
self.assertEqual(candidates, [C])
class TestDescriptors(unittest.TestCase):
def test_set_name(self):
@@ -4218,16 +4457,56 @@ class TestMakeDataclass(unittest.TestCase):
C = make_dataclass('Point', ['x', 'y', 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
'y': 'typing.Any',
'z': 'typing.Any'})
self.assertEqual(C.__annotations__, {'x': typing.Any,
'y': typing.Any,
'z': typing.Any})
C = make_dataclass('Point', ['x', ('y', int), 'z'])
c = C(1, 2, 3)
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
self.assertEqual(C.__annotations__, {'x': typing.Any,
'y': int,
'z': 'typing.Any'})
'z': typing.Any})
def test_no_types_get_annotations(self):
C = make_dataclass('C', ['x', ('y', int), 'z'])
self.assertEqual(
annotationlib.get_annotations(C, format=annotationlib.Format.VALUE),
{'x': typing.Any, 'y': int, 'z': typing.Any},
)
self.assertEqual(
annotationlib.get_annotations(
C, format=annotationlib.Format.FORWARDREF),
{'x': typing.Any, 'y': int, 'z': typing.Any},
)
self.assertEqual(
annotationlib.get_annotations(
C, format=annotationlib.Format.STRING),
{'x': 'typing.Any', 'y': 'int', 'z': 'typing.Any'},
)
def test_no_types_no_typing_import(self):
with import_helper.CleanImport('typing'):
self.assertNotIn('typing', sys.modules)
C = make_dataclass('C', ['x', ('y', int)])
self.assertNotIn('typing', sys.modules)
self.assertEqual(
C.__annotate__(annotationlib.Format.FORWARDREF),
{
'x': annotationlib.ForwardRef('Any', module='typing'),
'y': int,
},
)
self.assertNotIn('typing', sys.modules)
for field in fields(C):
if field.name == "x":
self.assertEqual(field.type, annotationlib.ForwardRef('Any', module='typing'))
else:
self.assertEqual(field.name, "y")
self.assertIs(field.type, int)
def test_module_attr(self):
self.assertEqual(ByMakeDataClass.__module__, __name__)
@@ -4314,6 +4593,23 @@ class TestMakeDataclass(unittest.TestCase):
C = make_dataclass(classname, ['a', 'b'])
self.assertEqual(C.__name__, classname)
def test_dataclass_decorator_default(self):
C = make_dataclass('C', [('x', int)], decorator=dataclass)
c = C(10)
self.assertEqual(c.x, 10)
def test_dataclass_custom_decorator(self):
def custom_dataclass(cls, *args, **kwargs):
dc = dataclass(cls, *args, **kwargs)
dc.__custom__ = True
return dc
C = make_dataclass('C', [('x', int)], decorator=custom_dataclass)
c = C(10)
self.assertEqual(c.x, 10)
self.assertEqual(c.__custom__, True)
class TestReplace(unittest.TestCase):
def test(self):
@dataclass(frozen=True)
@@ -4562,8 +4858,7 @@ class TestAbstract(unittest.TestCase):
self.assertFalse(inspect.isabstract(Date))
self.assertGreater(Date(2020,12,25), Date(2020,8,31))
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_maintain_abc(self):
class A(abc.ABC):
@abc.abstractmethod
@@ -4919,6 +5214,140 @@ class TestKeywordArgs(unittest.TestCase):
self.assertTrue(fields(B)[0].kw_only)
self.assertFalse(fields(B)[1].kw_only)
def test_deferred_annotations(self):
@dataclass
class A:
x: undefined
y: ClassVar[undefined]
fs = fields(A)
self.assertEqual(len(fs), 1)
self.assertEqual(fs[0].name, 'x')
class TestZeroArgumentSuperWithSlots(unittest.TestCase):
def test_zero_argument_super(self):
@dataclass(slots=True)
class A:
def foo(self):
super()
A().foo()
def test_dunder_class_with_old_property(self):
@dataclass(slots=True)
class A:
def _get_foo(slf):
self.assertIs(__class__, type(slf))
self.assertIs(__class__, slf.__class__)
return __class__
def _set_foo(slf, value):
self.assertIs(__class__, type(slf))
self.assertIs(__class__, slf.__class__)
def _del_foo(slf):
self.assertIs(__class__, type(slf))
self.assertIs(__class__, slf.__class__)
foo = property(_get_foo, _set_foo, _del_foo)
a = A()
self.assertIs(a.foo, A)
a.foo = 4
del a.foo
def test_dunder_class_with_new_property(self):
@dataclass(slots=True)
class A:
@property
def foo(slf):
return slf.__class__
@foo.setter
def foo(slf, value):
self.assertIs(__class__, type(slf))
@foo.deleter
def foo(slf):
self.assertIs(__class__, type(slf))
a = A()
self.assertIs(a.foo, A)
a.foo = 4
del a.foo
# Test the parts of a property individually.
def test_slots_dunder_class_property_getter(self):
@dataclass(slots=True)
class A:
@property
def foo(slf):
return __class__
a = A()
self.assertIs(a.foo, A)
def test_slots_dunder_class_property_setter(self):
@dataclass(slots=True)
class A:
foo = property()
@foo.setter
def foo(slf, val):
self.assertIs(__class__, type(slf))
a = A()
a.foo = 4
def test_slots_dunder_class_property_deleter(self):
@dataclass(slots=True)
class A:
foo = property()
@foo.deleter
def foo(slf):
self.assertIs(__class__, type(slf))
a = A()
del a.foo
def test_wrapped(self):
def mydecorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
return f(*args, **kwargs)
return wrapper
@dataclass(slots=True)
class A:
@mydecorator
def foo(self):
super()
A().foo()
def test_remembered_class(self):
# Apply the dataclass decorator manually (not when the class
# is created), so that we can keep a reference to the
# undecorated class.
class A:
def cls(self):
return __class__
self.assertIs(A().cls(), A)
B = dataclass(slots=True)(A)
self.assertIs(B().cls(), B)
# This is undesirable behavior, but is a function of how
# modifying __class__ in the closure works. I'm not sure this
# should be tested or not: I don't really want to guarantee
# this behavior, but I don't want to lose the point that this
# is how it works.
# The underlying class is "broken" by changing its __class__
# in A.foo() to B. This normally isn't a problem, because no
# one will be keeping a reference to the underlying class A.
self.assertIs(A().cls(), B)
if __name__ == '__main__':
unittest.main()
unittest.main()