mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Update dataclasses.py from 3.14.3
This commit is contained in:
committed by
Jeong, YunWon
parent
d2d8eeea2f
commit
32b57785c3
292
Lib/dataclasses.py
vendored
292
Lib/dataclasses.py
vendored
@@ -5,6 +5,7 @@ import types
|
||||
import inspect
|
||||
import keyword
|
||||
import itertools
|
||||
import annotationlib
|
||||
import abc
|
||||
from reprlib import recursive_repr
|
||||
|
||||
@@ -243,6 +244,10 @@ _ATOMIC_TYPES = frozenset({
|
||||
property,
|
||||
})
|
||||
|
||||
# Any marker is used in `make_dataclass` to mark unannotated fields as `Any`
|
||||
# without importing `typing` module.
|
||||
_ANY_MARKER = object()
|
||||
|
||||
|
||||
class InitVar:
|
||||
__slots__ = ('type', )
|
||||
@@ -282,11 +287,12 @@ class Field:
|
||||
'compare',
|
||||
'metadata',
|
||||
'kw_only',
|
||||
'doc',
|
||||
'_field_type', # Private: not to be used by user code.
|
||||
)
|
||||
|
||||
def __init__(self, default, default_factory, init, repr, hash, compare,
|
||||
metadata, kw_only):
|
||||
metadata, kw_only, doc):
|
||||
self.name = None
|
||||
self.type = None
|
||||
self.default = default
|
||||
@@ -299,6 +305,7 @@ class Field:
|
||||
if metadata is None else
|
||||
types.MappingProxyType(metadata))
|
||||
self.kw_only = kw_only
|
||||
self.doc = doc
|
||||
self._field_type = None
|
||||
|
||||
@recursive_repr()
|
||||
@@ -314,6 +321,7 @@ class Field:
|
||||
f'compare={self.compare!r},'
|
||||
f'metadata={self.metadata!r},'
|
||||
f'kw_only={self.kw_only!r},'
|
||||
f'doc={self.doc!r},'
|
||||
f'_field_type={self._field_type}'
|
||||
')')
|
||||
|
||||
@@ -381,7 +389,7 @@ class _DataclassParams:
|
||||
# so that a type checker can be told (via overloads) that this is a
|
||||
# function whose type depends on its parameters.
|
||||
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
|
||||
hash=None, compare=True, metadata=None, kw_only=MISSING):
|
||||
hash=None, compare=True, metadata=None, kw_only=MISSING, doc=None):
|
||||
"""Return an object to identify dataclass fields.
|
||||
|
||||
default is the default value of the field. default_factory is a
|
||||
@@ -393,7 +401,7 @@ def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
|
||||
comparison functions. metadata, if specified, must be a mapping
|
||||
which is stored but not otherwise examined by dataclass. If kw_only
|
||||
is true, the field will become a keyword-only parameter to
|
||||
__init__().
|
||||
__init__(). doc is an optional docstring for this field.
|
||||
|
||||
It is an error to specify both default and default_factory.
|
||||
"""
|
||||
@@ -401,7 +409,7 @@ def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
|
||||
if default is not MISSING and default_factory is not MISSING:
|
||||
raise ValueError('cannot specify both default and default_factory')
|
||||
return Field(default, default_factory, init, repr, hash, compare,
|
||||
metadata, kw_only)
|
||||
metadata, kw_only, doc)
|
||||
|
||||
|
||||
def _fields_in_init_order(fields):
|
||||
@@ -433,9 +441,11 @@ class _FuncBuilder:
|
||||
self.locals = {}
|
||||
self.overwrite_errors = {}
|
||||
self.unconditional_adds = {}
|
||||
self.method_annotations = {}
|
||||
|
||||
def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
|
||||
overwrite_error=False, unconditional_add=False, decorator=None):
|
||||
overwrite_error=False, unconditional_add=False, decorator=None,
|
||||
annotation_fields=None):
|
||||
if locals is not None:
|
||||
self.locals.update(locals)
|
||||
|
||||
@@ -456,16 +466,14 @@ class _FuncBuilder:
|
||||
|
||||
self.names.append(name)
|
||||
|
||||
if return_type is not MISSING:
|
||||
self.locals[f'__dataclass_{name}_return_type__'] = return_type
|
||||
return_annotation = f'->__dataclass_{name}_return_type__'
|
||||
else:
|
||||
return_annotation = ''
|
||||
if annotation_fields is not None:
|
||||
self.method_annotations[name] = (annotation_fields, return_type)
|
||||
|
||||
args = ','.join(args)
|
||||
body = '\n'.join(body)
|
||||
|
||||
# Compute the text of the entire function, add it to the text we're generating.
|
||||
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
|
||||
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}):\n{body}')
|
||||
|
||||
def add_fns_to_class(self, cls):
|
||||
# The source to all of the functions we're generating.
|
||||
@@ -501,6 +509,15 @@ class _FuncBuilder:
|
||||
# Now that we've generated the functions, assign them into cls.
|
||||
for name, fn in zip(self.names, fns):
|
||||
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
|
||||
|
||||
try:
|
||||
annotation_fields, return_type = self.method_annotations[name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
annotate_fn = _make_annotate_function(cls, name, annotation_fields, return_type)
|
||||
fn.__annotate__ = annotate_fn
|
||||
|
||||
if self.unconditional_adds.get(name, False):
|
||||
setattr(cls, name, fn)
|
||||
else:
|
||||
@@ -516,6 +533,49 @@ class _FuncBuilder:
|
||||
raise TypeError(error_msg)
|
||||
|
||||
|
||||
def _make_annotate_function(__class__, method_name, annotation_fields, return_type):
|
||||
# Create an __annotate__ function for a dataclass
|
||||
# Try to return annotations in the same format as they would be
|
||||
# from a regular __init__ function
|
||||
|
||||
def __annotate__(format, /):
|
||||
Format = annotationlib.Format
|
||||
match format:
|
||||
case Format.VALUE | Format.FORWARDREF | Format.STRING:
|
||||
cls_annotations = {}
|
||||
for base in reversed(__class__.__mro__):
|
||||
cls_annotations.update(
|
||||
annotationlib.get_annotations(base, format=format)
|
||||
)
|
||||
|
||||
new_annotations = {}
|
||||
for k in annotation_fields:
|
||||
# gh-142214: The annotation may be missing in unusual dynamic cases.
|
||||
# If so, just skip it.
|
||||
try:
|
||||
new_annotations[k] = cls_annotations[k]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if return_type is not MISSING:
|
||||
if format == Format.STRING:
|
||||
new_annotations["return"] = annotationlib.type_repr(return_type)
|
||||
else:
|
||||
new_annotations["return"] = return_type
|
||||
|
||||
return new_annotations
|
||||
|
||||
case _:
|
||||
raise NotImplementedError(format)
|
||||
|
||||
# This is a flag for _add_slots to know it needs to regenerate this method
|
||||
# In order to remove references to the original class when it is replaced
|
||||
__annotate__.__generated_by_dataclasses__ = True
|
||||
__annotate__.__qualname__ = f"{__class__.__qualname__}.{method_name}.__annotate__"
|
||||
|
||||
return __annotate__
|
||||
|
||||
|
||||
def _field_assign(frozen, name, value, self_name):
|
||||
# If we're a frozen class, then assign to our fields in __init__
|
||||
# via object.__setattr__. Otherwise, just use a simple
|
||||
@@ -604,7 +664,7 @@ def _init_param(f):
|
||||
elif f.default_factory is not MISSING:
|
||||
# There's a factory function. Set a marker.
|
||||
default = '=__dataclass_HAS_DEFAULT_FACTORY__'
|
||||
return f'{f.name}:__dataclass_type_{f.name}__{default}'
|
||||
return f'{f.name}{default}'
|
||||
|
||||
|
||||
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
||||
@@ -627,11 +687,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
||||
raise TypeError(f'non-default argument {f.name!r} '
|
||||
f'follows default argument {seen_default.name!r}')
|
||||
|
||||
locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
|
||||
**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
|
||||
'__dataclass_builtins_object__': object,
|
||||
}
|
||||
}
|
||||
annotation_fields = [f.name for f in fields if f.init]
|
||||
|
||||
locals = {'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
|
||||
'__dataclass_builtins_object__': object}
|
||||
|
||||
body_lines = []
|
||||
for f in fields:
|
||||
@@ -655,14 +714,15 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
|
||||
if kw_only_fields:
|
||||
# Add the keyword-only args. Because the * can only be added if
|
||||
# there's at least one keyword-only arg, there needs to be a test here
|
||||
# (instead of just concatenting the lists together).
|
||||
# (instead of just concatenating the lists together).
|
||||
_init_params += ['*']
|
||||
_init_params += [_init_param(f) for f in kw_only_fields]
|
||||
func_builder.add_fn('__init__',
|
||||
[self_name] + _init_params,
|
||||
body_lines,
|
||||
locals=locals,
|
||||
return_type=None)
|
||||
return_type=None,
|
||||
annotation_fields=annotation_fields)
|
||||
|
||||
|
||||
def _frozen_get_del_attr(cls, fields, func_builder):
|
||||
@@ -689,11 +749,8 @@ def _frozen_get_del_attr(cls, fields, func_builder):
|
||||
|
||||
|
||||
def _is_classvar(a_type, typing):
|
||||
# This test uses a typing internal class, but it's the best way to
|
||||
# test if this is a ClassVar.
|
||||
return (a_type is typing.ClassVar
|
||||
or (type(a_type) is typing._GenericAlias
|
||||
and a_type.__origin__ is typing.ClassVar))
|
||||
or (typing.get_origin(a_type) is typing.ClassVar))
|
||||
|
||||
|
||||
def _is_initvar(a_type, dataclasses):
|
||||
@@ -981,7 +1038,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
|
||||
# actual default value. Pseudo-fields ClassVars and InitVars are
|
||||
# included, despite the fact that they're not real fields. That's
|
||||
# dealt with later.
|
||||
cls_annotations = inspect.get_annotations(cls)
|
||||
cls_annotations = annotationlib.get_annotations(
|
||||
cls, format=annotationlib.Format.FORWARDREF)
|
||||
|
||||
# Now find fields in our class. While doing so, validate some
|
||||
# things, and set the default values (as class attributes) where
|
||||
@@ -1161,7 +1219,10 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
|
||||
try:
|
||||
# In some cases fetching a signature is not possible.
|
||||
# But, we surely should not fail in this case.
|
||||
text_sig = str(inspect.signature(cls)).replace(' -> None', '')
|
||||
text_sig = str(inspect.signature(
|
||||
cls,
|
||||
annotation_format=annotationlib.Format.FORWARDREF,
|
||||
)).replace(' -> None', '')
|
||||
except (TypeError, ValueError):
|
||||
text_sig = ''
|
||||
cls.__doc__ = (cls.__name__ + text_sig)
|
||||
@@ -1175,7 +1236,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
|
||||
if weakref_slot and not slots:
|
||||
raise TypeError('weakref_slot is True but slots is False')
|
||||
if slots:
|
||||
cls = _add_slots(cls, frozen, weakref_slot)
|
||||
cls = _add_slots(cls, frozen, weakref_slot, fields)
|
||||
|
||||
abc.update_abstractmethods(cls)
|
||||
|
||||
@@ -1219,14 +1280,65 @@ def _get_slots(cls):
|
||||
raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")
|
||||
|
||||
|
||||
def _add_slots(cls, is_frozen, weakref_slot):
|
||||
# Need to create a new class, since we can't set __slots__
|
||||
# after a class has been created.
|
||||
def _update_func_cell_for__class__(f, oldcls, newcls):
|
||||
# Returns True if we update a cell, else False.
|
||||
if f is None:
|
||||
# f will be None in the case of a property where not all of
|
||||
# fget, fset, and fdel are used. Nothing to do in that case.
|
||||
return False
|
||||
try:
|
||||
idx = f.__code__.co_freevars.index("__class__")
|
||||
except ValueError:
|
||||
# This function doesn't reference __class__, so nothing to do.
|
||||
return False
|
||||
# Fix the cell to point to the new class, if it's already pointing
|
||||
# at the old class. I'm not convinced that the "is oldcls" test
|
||||
# is needed, but other than performance can't hurt.
|
||||
closure = f.__closure__[idx]
|
||||
if closure.cell_contents is oldcls:
|
||||
closure.cell_contents = newcls
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _create_slots(defined_fields, inherited_slots, field_names, weakref_slot):
|
||||
# The slots for our class. Remove slots from our base classes. Add
|
||||
# '__weakref__' if weakref_slot was given, unless it is already present.
|
||||
seen_docs = False
|
||||
slots = {}
|
||||
for slot in itertools.filterfalse(
|
||||
inherited_slots.__contains__,
|
||||
itertools.chain(
|
||||
# gh-93521: '__weakref__' also needs to be filtered out if
|
||||
# already present in inherited_slots
|
||||
field_names, ('__weakref__',) if weakref_slot else ()
|
||||
)
|
||||
):
|
||||
doc = getattr(defined_fields.get(slot), 'doc', None)
|
||||
if doc is not None:
|
||||
seen_docs = True
|
||||
slots[slot] = doc
|
||||
|
||||
# We only return dict if there's at least one doc member,
|
||||
# otherwise we return tuple, which is the old default format.
|
||||
if seen_docs:
|
||||
return slots
|
||||
return tuple(slots)
|
||||
|
||||
|
||||
def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
|
||||
# Need to create a new class, since we can't set __slots__ after a
|
||||
# class has been created, and the @dataclass decorator is called
|
||||
# after the class is created.
|
||||
|
||||
# Make sure __slots__ isn't already set.
|
||||
if '__slots__' in cls.__dict__:
|
||||
raise TypeError(f'{cls.__name__} already specifies __slots__')
|
||||
|
||||
# gh-102069: Remove existing __weakref__ descriptor.
|
||||
# gh-135228: Make sure the original class can be garbage collected.
|
||||
sys._clear_type_descriptors(cls)
|
||||
|
||||
# Create a new dict for our new class.
|
||||
cls_dict = dict(cls.__dict__)
|
||||
field_names = tuple(f.name for f in fields(cls))
|
||||
@@ -1234,17 +1346,9 @@ def _add_slots(cls, is_frozen, weakref_slot):
|
||||
inherited_slots = set(
|
||||
itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1]))
|
||||
)
|
||||
# The slots for our class. Remove slots from our base classes. Add
|
||||
# '__weakref__' if weakref_slot was given, unless it is already present.
|
||||
cls_dict["__slots__"] = tuple(
|
||||
itertools.filterfalse(
|
||||
inherited_slots.__contains__,
|
||||
itertools.chain(
|
||||
# gh-93521: '__weakref__' also needs to be filtered out if
|
||||
# already present in inherited_slots
|
||||
field_names, ('__weakref__',) if weakref_slot else ()
|
||||
)
|
||||
),
|
||||
|
||||
cls_dict["__slots__"] = _create_slots(
|
||||
defined_fields, inherited_slots, field_names, weakref_slot,
|
||||
)
|
||||
|
||||
for field_name in field_names:
|
||||
@@ -1252,26 +1356,59 @@ def _add_slots(cls, is_frozen, weakref_slot):
|
||||
# available in _MARKER.
|
||||
cls_dict.pop(field_name, None)
|
||||
|
||||
# Remove __dict__ itself.
|
||||
cls_dict.pop('__dict__', None)
|
||||
|
||||
# Clear existing `__weakref__` descriptor, it belongs to a previous type:
|
||||
cls_dict.pop('__weakref__', None) # gh-102069
|
||||
|
||||
# And finally create the class.
|
||||
qualname = getattr(cls, '__qualname__', None)
|
||||
cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
|
||||
newcls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
|
||||
if qualname is not None:
|
||||
cls.__qualname__ = qualname
|
||||
newcls.__qualname__ = qualname
|
||||
|
||||
if is_frozen:
|
||||
# Need this for pickling frozen classes with slots.
|
||||
if '__getstate__' not in cls_dict:
|
||||
cls.__getstate__ = _dataclass_getstate
|
||||
newcls.__getstate__ = _dataclass_getstate
|
||||
if '__setstate__' not in cls_dict:
|
||||
cls.__setstate__ = _dataclass_setstate
|
||||
newcls.__setstate__ = _dataclass_setstate
|
||||
|
||||
return cls
|
||||
# Fix up any closures which reference __class__. This is used to
|
||||
# fix zero argument super so that it points to the correct class
|
||||
# (the newly created one, which we're returning) and not the
|
||||
# original class. We can break out of this loop as soon as we
|
||||
# make an update, since all closures for a class will share a
|
||||
# given cell.
|
||||
for member in newcls.__dict__.values():
|
||||
# If this is a wrapped function, unwrap it.
|
||||
member = inspect.unwrap(member)
|
||||
|
||||
if isinstance(member, types.FunctionType):
|
||||
if _update_func_cell_for__class__(member, cls, newcls):
|
||||
break
|
||||
elif isinstance(member, property):
|
||||
if (_update_func_cell_for__class__(member.fget, cls, newcls)
|
||||
or _update_func_cell_for__class__(member.fset, cls, newcls)
|
||||
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
|
||||
break
|
||||
|
||||
# Get new annotations to remove references to the original class
|
||||
# in forward references
|
||||
newcls_ann = annotationlib.get_annotations(
|
||||
newcls, format=annotationlib.Format.FORWARDREF)
|
||||
|
||||
# Fix references in dataclass Fields
|
||||
for f in getattr(newcls, _FIELDS).values():
|
||||
try:
|
||||
ann = newcls_ann[f.name]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
f.type = ann
|
||||
|
||||
# Fix the class reference in the __annotate__ method
|
||||
init = newcls.__init__
|
||||
if init_annotate := getattr(init, "__annotate__", None):
|
||||
if getattr(init_annotate, "__generated_by_dataclasses__", False):
|
||||
_update_func_cell_for__class__(init_annotate, cls, newcls)
|
||||
|
||||
return newcls
|
||||
|
||||
|
||||
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
|
||||
@@ -1490,7 +1627,7 @@ def _astuple_inner(obj, tuple_factory):
|
||||
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
|
||||
repr=True, eq=True, order=False, unsafe_hash=False,
|
||||
frozen=False, match_args=True, kw_only=False, slots=False,
|
||||
weakref_slot=False, module=None):
|
||||
weakref_slot=False, module=None, decorator=dataclass):
|
||||
"""Return a new dynamically created dataclass.
|
||||
|
||||
The dataclass name will be 'cls_name'. 'fields' is an iterable
|
||||
@@ -1528,7 +1665,7 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
|
||||
for item in fields:
|
||||
if isinstance(item, str):
|
||||
name = item
|
||||
tp = 'typing.Any'
|
||||
tp = _ANY_MARKER
|
||||
elif len(item) == 2:
|
||||
name, tp, = item
|
||||
elif len(item) == 3:
|
||||
@@ -1547,15 +1684,49 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
|
||||
seen.add(name)
|
||||
annotations[name] = tp
|
||||
|
||||
# We initially block the VALUE format, because inside dataclass() we'll
|
||||
# call get_annotations(), which will try the VALUE format first. If we don't
|
||||
# block, that means we'd always end up eagerly importing typing here, which
|
||||
# is what we're trying to avoid.
|
||||
value_blocked = True
|
||||
|
||||
def annotate_method(format):
|
||||
def get_any():
|
||||
match format:
|
||||
case annotationlib.Format.STRING:
|
||||
return 'typing.Any'
|
||||
case annotationlib.Format.FORWARDREF:
|
||||
typing = sys.modules.get("typing")
|
||||
if typing is None:
|
||||
return annotationlib.ForwardRef("Any", module="typing")
|
||||
else:
|
||||
return typing.Any
|
||||
case annotationlib.Format.VALUE:
|
||||
if value_blocked:
|
||||
raise NotImplementedError
|
||||
from typing import Any
|
||||
return Any
|
||||
case _:
|
||||
raise NotImplementedError
|
||||
annos = {
|
||||
ann: get_any() if t is _ANY_MARKER else t
|
||||
for ann, t in annotations.items()
|
||||
}
|
||||
if format == annotationlib.Format.STRING:
|
||||
return annotationlib.annotations_to_string(annos)
|
||||
else:
|
||||
return annos
|
||||
|
||||
# Update 'ns' with the user-supplied namespace plus our calculated values.
|
||||
def exec_body_callback(ns):
|
||||
ns.update(namespace)
|
||||
ns.update(defaults)
|
||||
ns['__annotations__'] = annotations
|
||||
|
||||
# We use `types.new_class()` instead of simply `type()` to allow dynamic creation
|
||||
# of generic dataclasses.
|
||||
cls = types.new_class(cls_name, bases, {}, exec_body_callback)
|
||||
# For now, set annotations including the _ANY_MARKER.
|
||||
cls.__annotate__ = annotate_method
|
||||
|
||||
# For pickling to work, the __module__ variable needs to be set to the frame
|
||||
# where the dataclass is created.
|
||||
@@ -1570,11 +1741,14 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
|
||||
if module is not None:
|
||||
cls.__module__ = module
|
||||
|
||||
# Apply the normal decorator.
|
||||
return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
|
||||
unsafe_hash=unsafe_hash, frozen=frozen,
|
||||
match_args=match_args, kw_only=kw_only, slots=slots,
|
||||
weakref_slot=weakref_slot)
|
||||
# Apply the normal provided decorator.
|
||||
cls = decorator(cls, init=init, repr=repr, eq=eq, order=order,
|
||||
unsafe_hash=unsafe_hash, frozen=frozen,
|
||||
match_args=match_args, kw_only=kw_only, slots=slots,
|
||||
weakref_slot=weakref_slot)
|
||||
# Now that the class is ready, allow the VALUE format.
|
||||
value_blocked = False
|
||||
return cls
|
||||
|
||||
|
||||
def replace(obj, /, **changes):
|
||||
|
||||
515
Lib/test/test_dataclasses/__init__.py
vendored
515
Lib/test/test_dataclasses/__init__.py
vendored
@@ -5,6 +5,7 @@
|
||||
from dataclasses import *
|
||||
|
||||
import abc
|
||||
import annotationlib
|
||||
import io
|
||||
import pickle
|
||||
import inspect
|
||||
@@ -12,18 +13,21 @@ import builtins
|
||||
import types
|
||||
import weakref
|
||||
import traceback
|
||||
import sys
|
||||
import textwrap
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict
|
||||
from typing import get_type_hints
|
||||
from collections import deque, OrderedDict, namedtuple, defaultdict
|
||||
from copy import deepcopy
|
||||
from functools import total_ordering
|
||||
from functools import total_ordering, wraps
|
||||
|
||||
import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
|
||||
import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
|
||||
|
||||
from test import support
|
||||
from test.support import import_helper
|
||||
|
||||
# Just any custom exception we can catch.
|
||||
class CustomError(Exception): pass
|
||||
@@ -61,7 +65,7 @@ class TestCase(unittest.TestCase):
|
||||
x: int = field(default=1, default_factory=int)
|
||||
|
||||
def test_field_repr(self):
|
||||
int_field = field(default=1, init=True, repr=False)
|
||||
int_field = field(default=1, init=True, repr=False, doc='Docstring')
|
||||
int_field.name = "id"
|
||||
repr_output = repr(int_field)
|
||||
expected_output = "Field(name='id',type=None," \
|
||||
@@ -69,6 +73,7 @@ class TestCase(unittest.TestCase):
|
||||
"init=True,repr=False,hash=None," \
|
||||
"compare=True,metadata=mappingproxy({})," \
|
||||
f"kw_only={MISSING!r}," \
|
||||
"doc='Docstring'," \
|
||||
"_field_type=None)"
|
||||
|
||||
self.assertEqual(repr_output, expected_output)
|
||||
@@ -115,7 +120,7 @@ class TestCase(unittest.TestCase):
|
||||
for param in inspect.signature(dataclass).parameters:
|
||||
if param == 'cls':
|
||||
continue
|
||||
self.assertTrue(hasattr(Some.__dataclass_params__, param), msg=param)
|
||||
self.assertHasAttr(Some.__dataclass_params__, param)
|
||||
|
||||
def test_named_init_params(self):
|
||||
@dataclass
|
||||
@@ -666,7 +671,7 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(the_fields[0].name, 'x')
|
||||
self.assertEqual(the_fields[0].type, int)
|
||||
self.assertFalse(hasattr(C, 'x'))
|
||||
self.assertNotHasAttr(C, 'x')
|
||||
self.assertTrue (the_fields[0].init)
|
||||
self.assertTrue (the_fields[0].repr)
|
||||
self.assertEqual(the_fields[1].name, 'y')
|
||||
@@ -676,7 +681,7 @@ class TestCase(unittest.TestCase):
|
||||
self.assertTrue (the_fields[1].repr)
|
||||
self.assertEqual(the_fields[2].name, 'z')
|
||||
self.assertEqual(the_fields[2].type, str)
|
||||
self.assertFalse(hasattr(C, 'z'))
|
||||
self.assertNotHasAttr(C, 'z')
|
||||
self.assertTrue (the_fields[2].init)
|
||||
self.assertFalse(the_fields[2].repr)
|
||||
|
||||
@@ -727,8 +732,8 @@ class TestCase(unittest.TestCase):
|
||||
z: object = default
|
||||
t: int = field(default=100)
|
||||
|
||||
self.assertFalse(hasattr(C, 'x'))
|
||||
self.assertFalse(hasattr(C, 'y'))
|
||||
self.assertNotHasAttr(C, 'x')
|
||||
self.assertNotHasAttr(C, 'y')
|
||||
self.assertIs (C.z, default)
|
||||
self.assertEqual(C.t, 100)
|
||||
|
||||
@@ -922,6 +927,20 @@ class TestCase(unittest.TestCase):
|
||||
|
||||
validate_class(C)
|
||||
|
||||
def test_incomplete_annotations(self):
|
||||
# gh-142214
|
||||
@dataclass
|
||||
class C:
|
||||
"doc" # needed because otherwise we fetch the annotations at the wrong time
|
||||
x: int
|
||||
|
||||
C.__annotate__ = lambda _: {}
|
||||
|
||||
self.assertEqual(
|
||||
annotationlib.get_annotations(C.__init__),
|
||||
{"return": None}
|
||||
)
|
||||
|
||||
def test_missing_default(self):
|
||||
# Test that MISSING works the same as a default not being
|
||||
# specified.
|
||||
@@ -1776,8 +1795,7 @@ class TestCase(unittest.TestCase):
|
||||
self.assertIsNot(d['f'], t)
|
||||
self.assertEqual(d['f'].my_a(), 6)
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_helper_asdict_defaultdict(self):
|
||||
# Ensure asdict() does not throw exceptions when a
|
||||
# defaultdict is a member of a dataclass
|
||||
@@ -1920,8 +1938,7 @@ class TestCase(unittest.TestCase):
|
||||
t = astuple(c, tuple_factory=list)
|
||||
self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_helper_astuple_defaultdict(self):
|
||||
# Ensure astuple() does not throw exceptions when a
|
||||
# defaultdict is a member of a dataclass
|
||||
@@ -2311,13 +2328,12 @@ class TestDocString(unittest.TestCase):
|
||||
|
||||
self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_docstring_one_field_with_default_none(self):
|
||||
@dataclass
|
||||
class C:
|
||||
x: Union[int, type(None)] = None
|
||||
|
||||
self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")
|
||||
self.assertDocStrEqual(C.__doc__, "C(x:int|None=None)")
|
||||
|
||||
def test_docstring_list_field(self):
|
||||
@dataclass
|
||||
@@ -2347,6 +2363,31 @@ class TestDocString(unittest.TestCase):
|
||||
|
||||
self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
|
||||
|
||||
def test_docstring_undefined_name(self):
|
||||
@dataclass
|
||||
class C:
|
||||
x: undef
|
||||
|
||||
self.assertDocStrEqual(C.__doc__, "C(x:undef)")
|
||||
|
||||
def test_docstring_with_unsolvable_forward_ref_in_init(self):
|
||||
# See: https://github.com/python/cpython/issues/128184
|
||||
ns = {}
|
||||
exec(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class C:
|
||||
def __init__(self, x: X, num: int) -> None: ...
|
||||
""",
|
||||
),
|
||||
ns,
|
||||
)
|
||||
|
||||
self.assertDocStrEqual(ns['C'].__doc__, "C(x:X,num:int)")
|
||||
|
||||
def test_docstring_with_no_signature(self):
|
||||
# See https://github.com/python/cpython/issues/103449
|
||||
class Meta(type):
|
||||
@@ -2446,6 +2487,149 @@ class TestInit(unittest.TestCase):
|
||||
self.assertEqual(D(5).a, 10)
|
||||
|
||||
|
||||
class TestInitAnnotate(unittest.TestCase):
|
||||
# Tests for the generated __annotate__ function for __init__
|
||||
# See: https://github.com/python/cpython/issues/137530
|
||||
|
||||
def test_annotate_function(self):
|
||||
# No forward references
|
||||
@dataclass
|
||||
class A:
|
||||
a: int
|
||||
|
||||
value_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.VALUE)
|
||||
forwardref_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.FORWARDREF)
|
||||
string_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.STRING)
|
||||
|
||||
self.assertEqual(value_annos, {'a': int, 'return': None})
|
||||
self.assertEqual(forwardref_annos, {'a': int, 'return': None})
|
||||
self.assertEqual(string_annos, {'a': 'int', 'return': 'None'})
|
||||
|
||||
self.assertTrue(getattr(A.__init__.__annotate__, "__generated_by_dataclasses__"))
|
||||
|
||||
def test_annotate_function_forwardref(self):
|
||||
# With forward references
|
||||
@dataclass
|
||||
class B:
|
||||
b: undefined
|
||||
|
||||
# VALUE annotations should raise while unresolvable
|
||||
with self.assertRaises(NameError):
|
||||
_ = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
|
||||
|
||||
forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
|
||||
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)
|
||||
|
||||
self.assertEqual(forwardref_annos, {'b': support.EqualToForwardRef('undefined', owner=B, is_class=True), 'return': None})
|
||||
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
|
||||
|
||||
# Now VALUE and FORWARDREF should resolve, STRING should be unchanged
|
||||
undefined = int
|
||||
|
||||
value_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
|
||||
forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
|
||||
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)
|
||||
|
||||
self.assertEqual(value_annos, {'b': int, 'return': None})
|
||||
self.assertEqual(forwardref_annos, {'b': int, 'return': None})
|
||||
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
|
||||
|
||||
def test_annotate_function_init_false(self):
|
||||
# Check `init=False` attributes don't get into the annotations of the __init__ function
|
||||
@dataclass
|
||||
class C:
|
||||
c: str = field(init=False)
|
||||
|
||||
self.assertEqual(annotationlib.get_annotations(C.__init__), {'return': None})
|
||||
|
||||
def test_annotate_function_contains_forwardref(self):
|
||||
# Check string annotations on objects containing a ForwardRef
|
||||
@dataclass
|
||||
class D:
|
||||
d: list[undefined]
|
||||
|
||||
with self.assertRaises(NameError):
|
||||
annotationlib.get_annotations(D.__init__)
|
||||
|
||||
self.assertEqual(
|
||||
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
|
||||
{"d": list[support.EqualToForwardRef("undefined", is_class=True, owner=D)], "return": None}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
|
||||
{"d": "list[undefined]", "return": "None"}
|
||||
)
|
||||
|
||||
# Now test when it is defined
|
||||
undefined = str
|
||||
|
||||
# VALUE should now resolve
|
||||
self.assertEqual(
|
||||
annotationlib.get_annotations(D.__init__),
|
||||
{"d": list[str], "return": None}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
|
||||
{"d": list[str], "return": None}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
|
||||
{"d": "list[undefined]", "return": "None"}
|
||||
)
|
||||
|
||||
def test_annotate_function_not_replaced(self):
|
||||
# Check that __annotate__ is not replaced on non-generated __init__ functions
|
||||
@dataclass(slots=True)
|
||||
class E:
|
||||
x: str
|
||||
def __init__(self, x: int) -> None:
|
||||
self.x = x
|
||||
|
||||
self.assertEqual(
|
||||
annotationlib.get_annotations(E.__init__), {"x": int, "return": None}
|
||||
)
|
||||
|
||||
self.assertFalse(hasattr(E.__init__.__annotate__, "__generated_by_dataclasses__"))
|
||||
|
||||
def test_slots_true_init_false(self):
|
||||
# Test that slots=True and init=False work together and
|
||||
# that __annotate__ is not added to __init__.
|
||||
|
||||
@dataclass(slots=True, init=False)
|
||||
class F:
|
||||
x: int
|
||||
|
||||
f = F()
|
||||
f.x = 10
|
||||
self.assertEqual(f.x, 10)
|
||||
|
||||
self.assertFalse(hasattr(F.__init__, "__annotate__"))
|
||||
|
||||
def test_init_false_forwardref(self):
|
||||
# Test forward references in fields not required for __init__ annotations.
|
||||
|
||||
# At the moment this raises a NameError for VALUE annotations even though the
|
||||
# undefined annotation is not required for the __init__ annotations.
|
||||
# Ideally this will be fixed but currently there is no good way to resolve this
|
||||
|
||||
@dataclass
|
||||
class F:
|
||||
not_in_init: list[undefined] = field(init=False, default=None)
|
||||
in_init: int
|
||||
|
||||
annos = annotationlib.get_annotations(F.__init__, format=annotationlib.Format.FORWARDREF)
|
||||
self.assertEqual(
|
||||
annos,
|
||||
{"in_init": int, "return": None},
|
||||
)
|
||||
|
||||
with self.assertRaises(NameError):
|
||||
annos = annotationlib.get_annotations(F.__init__) # NameError on not_in_init
|
||||
|
||||
|
||||
class TestRepr(unittest.TestCase):
|
||||
def test_repr(self):
|
||||
@dataclass
|
||||
@@ -2887,10 +3071,10 @@ class TestFrozen(unittest.TestCase):
|
||||
pass
|
||||
|
||||
c = C()
|
||||
self.assertFalse(hasattr(c, 'i'))
|
||||
self.assertNotHasAttr(c, 'i')
|
||||
with self.assertRaises(FrozenInstanceError):
|
||||
c.i = 5
|
||||
self.assertFalse(hasattr(c, 'i'))
|
||||
self.assertNotHasAttr(c, 'i')
|
||||
with self.assertRaises(FrozenInstanceError):
|
||||
del c.i
|
||||
|
||||
@@ -3119,7 +3303,7 @@ class TestFrozen(unittest.TestCase):
|
||||
del s.y
|
||||
self.assertEqual(s.y, 10)
|
||||
del s.cached
|
||||
self.assertFalse(hasattr(s, 'cached'))
|
||||
self.assertNotHasAttr(s, 'cached')
|
||||
with self.assertRaises(AttributeError) as cm:
|
||||
del s.cached
|
||||
self.assertNotIsInstance(cm.exception, FrozenInstanceError)
|
||||
@@ -3133,12 +3317,12 @@ class TestFrozen(unittest.TestCase):
|
||||
pass
|
||||
|
||||
s = S()
|
||||
self.assertFalse(hasattr(s, 'x'))
|
||||
self.assertNotHasAttr(s, 'x')
|
||||
s.x = 5
|
||||
self.assertEqual(s.x, 5)
|
||||
|
||||
del s.x
|
||||
self.assertFalse(hasattr(s, 'x'))
|
||||
self.assertNotHasAttr(s, 'x')
|
||||
with self.assertRaises(AttributeError) as cm:
|
||||
del s.x
|
||||
self.assertNotIsInstance(cm.exception, FrozenInstanceError)
|
||||
@@ -3309,7 +3493,7 @@ class TestSlots(unittest.TestCase):
|
||||
j: str
|
||||
h: str
|
||||
|
||||
self.assertEqual(Base.__slots__, ('y', ))
|
||||
self.assertEqual(Base.__slots__, ('y',))
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Derived(Base):
|
||||
@@ -3319,7 +3503,7 @@ class TestSlots(unittest.TestCase):
|
||||
k: str
|
||||
h: str
|
||||
|
||||
self.assertEqual(Derived.__slots__, ('z', ))
|
||||
self.assertEqual(Derived.__slots__, ('z',))
|
||||
|
||||
@dataclass
|
||||
class AnotherDerived(Base):
|
||||
@@ -3327,6 +3511,24 @@ class TestSlots(unittest.TestCase):
|
||||
|
||||
self.assertNotIn('__slots__', AnotherDerived.__dict__)
|
||||
|
||||
def test_slots_with_docs(self):
|
||||
class Root:
|
||||
__slots__ = {'x': 'x'}
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Base(Root):
|
||||
y1: int = field(doc='y1')
|
||||
y2: int
|
||||
|
||||
self.assertEqual(Base.__slots__, {'y1': 'y1', 'y2': None})
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Child(Base):
|
||||
z1: int = field(doc='z1')
|
||||
z2: int
|
||||
|
||||
self.assertEqual(Child.__slots__, {'z1': 'z1', 'z2': None})
|
||||
|
||||
def test_cant_inherit_from_iterator_slots(self):
|
||||
|
||||
class Root:
|
||||
@@ -3350,8 +3552,8 @@ class TestSlots(unittest.TestCase):
|
||||
B = dataclass(A, slots=True)
|
||||
self.assertIsNot(A, B)
|
||||
|
||||
self.assertFalse(hasattr(A, "__slots__"))
|
||||
self.assertTrue(hasattr(B, "__slots__"))
|
||||
self.assertNotHasAttr(A, "__slots__")
|
||||
self.assertHasAttr(B, "__slots__")
|
||||
|
||||
# Can't be local to test_frozen_pickle.
|
||||
@dataclass(frozen=True, slots=True)
|
||||
@@ -3470,8 +3672,7 @@ class TestSlots(unittest.TestCase):
|
||||
self.assertEqual(obj.a, 'a')
|
||||
self.assertEqual(obj.b, 'b')
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_slots_no_weakref(self):
|
||||
@dataclass(slots=True)
|
||||
class A:
|
||||
@@ -3486,8 +3687,7 @@ class TestSlots(unittest.TestCase):
|
||||
with self.assertRaises(AttributeError):
|
||||
a.__weakref__
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_slots_weakref(self):
|
||||
@dataclass(slots=True, weakref_slot=True)
|
||||
class A:
|
||||
@@ -3548,8 +3748,7 @@ class TestSlots(unittest.TestCase):
|
||||
"weakref_slot is True but slots is False"):
|
||||
B = make_dataclass('B', [('a', int),], weakref_slot=True)
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_weakref_slot_subclass_weakref_slot(self):
|
||||
@dataclass(slots=True, weakref_slot=True)
|
||||
class Base:
|
||||
@@ -3568,8 +3767,7 @@ class TestSlots(unittest.TestCase):
|
||||
a_ref = weakref.ref(a)
|
||||
self.assertIs(a.__weakref__, a_ref)
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_weakref_slot_subclass_no_weakref_slot(self):
|
||||
@dataclass(slots=True, weakref_slot=True)
|
||||
class Base:
|
||||
@@ -3587,8 +3785,7 @@ class TestSlots(unittest.TestCase):
|
||||
a_ref = weakref.ref(a)
|
||||
self.assertIs(a.__weakref__, a_ref)
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_weakref_slot_normal_base_weakref_slot(self):
|
||||
class Base:
|
||||
__slots__ = ('__weakref__',)
|
||||
@@ -3633,8 +3830,7 @@ class TestSlots(unittest.TestCase):
|
||||
self.assertTrue(B.__weakref__)
|
||||
B()
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_dataclass_derived_generic_from_base(self):
|
||||
T = typing.TypeVar('T')
|
||||
|
||||
@@ -3726,7 +3922,6 @@ class TestSlots(unittest.TestCase):
|
||||
@support.cpython_only
|
||||
def test_dataclass_slot_dict_ctype(self):
|
||||
# https://github.com/python/cpython/issues/123935
|
||||
from test.support import import_helper
|
||||
# Skips test if `_testcapi` is not present:
|
||||
_testcapi = import_helper.import_module('_testcapi')
|
||||
|
||||
@@ -3774,6 +3969,50 @@ class TestSlots(unittest.TestCase):
|
||||
# that we create internally.
|
||||
self.assertEqual(CorrectSuper.args, ["default", "default"])
|
||||
|
||||
@unittest.skip("TODO: RUSTPYTHON; Crash - static type name must be already interned but async_generator_wrapped_value is not")
|
||||
def test_original_class_is_gced(self):
|
||||
# gh-135228: Make sure when we replace the class with slots=True, the original class
|
||||
# gets garbage collected.
|
||||
def make_simple():
|
||||
@dataclass(slots=True)
|
||||
class SlotsTest:
|
||||
pass
|
||||
|
||||
return SlotsTest
|
||||
|
||||
def make_with_annotations():
|
||||
@dataclass(slots=True)
|
||||
class SlotsTest:
|
||||
x: int
|
||||
|
||||
return SlotsTest
|
||||
|
||||
def make_with_annotations_and_method():
|
||||
@dataclass(slots=True)
|
||||
class SlotsTest:
|
||||
x: int
|
||||
|
||||
def method(self) -> int:
|
||||
return self.x
|
||||
|
||||
return SlotsTest
|
||||
|
||||
def make_with_forwardref():
|
||||
@dataclass(slots=True)
|
||||
class SlotsTest:
|
||||
x: undefined
|
||||
y: list[undefined]
|
||||
|
||||
return SlotsTest
|
||||
|
||||
for make in (make_simple, make_with_annotations, make_with_annotations_and_method, make_with_forwardref):
|
||||
with self.subTest(make=make):
|
||||
C = make()
|
||||
support.gc_collect()
|
||||
candidates = [cls for cls in object.__subclasses__() if cls.__name__ == 'SlotsTest'
|
||||
and cls.__firstlineno__ == make.__code__.co_firstlineno + 1]
|
||||
self.assertEqual(candidates, [C])
|
||||
|
||||
|
||||
class TestDescriptors(unittest.TestCase):
|
||||
def test_set_name(self):
|
||||
@@ -4218,16 +4457,56 @@ class TestMakeDataclass(unittest.TestCase):
|
||||
C = make_dataclass('Point', ['x', 'y', 'z'])
|
||||
c = C(1, 2, 3)
|
||||
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
|
||||
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
|
||||
'y': 'typing.Any',
|
||||
'z': 'typing.Any'})
|
||||
self.assertEqual(C.__annotations__, {'x': typing.Any,
|
||||
'y': typing.Any,
|
||||
'z': typing.Any})
|
||||
|
||||
C = make_dataclass('Point', ['x', ('y', int), 'z'])
|
||||
c = C(1, 2, 3)
|
||||
self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
|
||||
self.assertEqual(C.__annotations__, {'x': 'typing.Any',
|
||||
self.assertEqual(C.__annotations__, {'x': typing.Any,
|
||||
'y': int,
|
||||
'z': 'typing.Any'})
|
||||
'z': typing.Any})
|
||||
|
||||
def test_no_types_get_annotations(self):
|
||||
C = make_dataclass('C', ['x', ('y', int), 'z'])
|
||||
|
||||
self.assertEqual(
|
||||
annotationlib.get_annotations(C, format=annotationlib.Format.VALUE),
|
||||
{'x': typing.Any, 'y': int, 'z': typing.Any},
|
||||
)
|
||||
self.assertEqual(
|
||||
annotationlib.get_annotations(
|
||||
C, format=annotationlib.Format.FORWARDREF),
|
||||
{'x': typing.Any, 'y': int, 'z': typing.Any},
|
||||
)
|
||||
self.assertEqual(
|
||||
annotationlib.get_annotations(
|
||||
C, format=annotationlib.Format.STRING),
|
||||
{'x': 'typing.Any', 'y': 'int', 'z': 'typing.Any'},
|
||||
)
|
||||
|
||||
def test_no_types_no_typing_import(self):
|
||||
with import_helper.CleanImport('typing'):
|
||||
self.assertNotIn('typing', sys.modules)
|
||||
C = make_dataclass('C', ['x', ('y', int)])
|
||||
|
||||
self.assertNotIn('typing', sys.modules)
|
||||
self.assertEqual(
|
||||
C.__annotate__(annotationlib.Format.FORWARDREF),
|
||||
{
|
||||
'x': annotationlib.ForwardRef('Any', module='typing'),
|
||||
'y': int,
|
||||
},
|
||||
)
|
||||
self.assertNotIn('typing', sys.modules)
|
||||
|
||||
for field in fields(C):
|
||||
if field.name == "x":
|
||||
self.assertEqual(field.type, annotationlib.ForwardRef('Any', module='typing'))
|
||||
else:
|
||||
self.assertEqual(field.name, "y")
|
||||
self.assertIs(field.type, int)
|
||||
|
||||
def test_module_attr(self):
|
||||
self.assertEqual(ByMakeDataClass.__module__, __name__)
|
||||
@@ -4314,6 +4593,23 @@ class TestMakeDataclass(unittest.TestCase):
|
||||
C = make_dataclass(classname, ['a', 'b'])
|
||||
self.assertEqual(C.__name__, classname)
|
||||
|
||||
def test_dataclass_decorator_default(self):
|
||||
C = make_dataclass('C', [('x', int)], decorator=dataclass)
|
||||
c = C(10)
|
||||
self.assertEqual(c.x, 10)
|
||||
|
||||
def test_dataclass_custom_decorator(self):
|
||||
def custom_dataclass(cls, *args, **kwargs):
|
||||
dc = dataclass(cls, *args, **kwargs)
|
||||
dc.__custom__ = True
|
||||
return dc
|
||||
|
||||
C = make_dataclass('C', [('x', int)], decorator=custom_dataclass)
|
||||
c = C(10)
|
||||
self.assertEqual(c.x, 10)
|
||||
self.assertEqual(c.__custom__, True)
|
||||
|
||||
|
||||
class TestReplace(unittest.TestCase):
|
||||
def test(self):
|
||||
@dataclass(frozen=True)
|
||||
@@ -4562,8 +4858,7 @@ class TestAbstract(unittest.TestCase):
|
||||
self.assertFalse(inspect.isabstract(Date))
|
||||
self.assertGreater(Date(2020,12,25), Date(2020,8,31))
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_maintain_abc(self):
|
||||
class A(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
@@ -4919,6 +5214,140 @@ class TestKeywordArgs(unittest.TestCase):
|
||||
self.assertTrue(fields(B)[0].kw_only)
|
||||
self.assertFalse(fields(B)[1].kw_only)
|
||||
|
||||
def test_deferred_annotations(self):
|
||||
@dataclass
|
||||
class A:
|
||||
x: undefined
|
||||
y: ClassVar[undefined]
|
||||
|
||||
fs = fields(A)
|
||||
self.assertEqual(len(fs), 1)
|
||||
self.assertEqual(fs[0].name, 'x')
|
||||
|
||||
|
||||
class TestZeroArgumentSuperWithSlots(unittest.TestCase):
|
||||
def test_zero_argument_super(self):
|
||||
@dataclass(slots=True)
|
||||
class A:
|
||||
def foo(self):
|
||||
super()
|
||||
|
||||
A().foo()
|
||||
|
||||
def test_dunder_class_with_old_property(self):
|
||||
@dataclass(slots=True)
|
||||
class A:
|
||||
def _get_foo(slf):
|
||||
self.assertIs(__class__, type(slf))
|
||||
self.assertIs(__class__, slf.__class__)
|
||||
return __class__
|
||||
|
||||
def _set_foo(slf, value):
|
||||
self.assertIs(__class__, type(slf))
|
||||
self.assertIs(__class__, slf.__class__)
|
||||
|
||||
def _del_foo(slf):
|
||||
self.assertIs(__class__, type(slf))
|
||||
self.assertIs(__class__, slf.__class__)
|
||||
|
||||
foo = property(_get_foo, _set_foo, _del_foo)
|
||||
|
||||
a = A()
|
||||
self.assertIs(a.foo, A)
|
||||
a.foo = 4
|
||||
del a.foo
|
||||
|
||||
def test_dunder_class_with_new_property(self):
|
||||
@dataclass(slots=True)
|
||||
class A:
|
||||
@property
|
||||
def foo(slf):
|
||||
return slf.__class__
|
||||
|
||||
@foo.setter
|
||||
def foo(slf, value):
|
||||
self.assertIs(__class__, type(slf))
|
||||
|
||||
@foo.deleter
|
||||
def foo(slf):
|
||||
self.assertIs(__class__, type(slf))
|
||||
|
||||
a = A()
|
||||
self.assertIs(a.foo, A)
|
||||
a.foo = 4
|
||||
del a.foo
|
||||
|
||||
# Test the parts of a property individually.
|
||||
def test_slots_dunder_class_property_getter(self):
|
||||
@dataclass(slots=True)
|
||||
class A:
|
||||
@property
|
||||
def foo(slf):
|
||||
return __class__
|
||||
|
||||
a = A()
|
||||
self.assertIs(a.foo, A)
|
||||
|
||||
def test_slots_dunder_class_property_setter(self):
|
||||
@dataclass(slots=True)
|
||||
class A:
|
||||
foo = property()
|
||||
@foo.setter
|
||||
def foo(slf, val):
|
||||
self.assertIs(__class__, type(slf))
|
||||
|
||||
a = A()
|
||||
a.foo = 4
|
||||
|
||||
def test_slots_dunder_class_property_deleter(self):
|
||||
@dataclass(slots=True)
|
||||
class A:
|
||||
foo = property()
|
||||
@foo.deleter
|
||||
def foo(slf):
|
||||
self.assertIs(__class__, type(slf))
|
||||
|
||||
a = A()
|
||||
del a.foo
|
||||
|
||||
def test_wrapped(self):
|
||||
def mydecorator(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
return f(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
@dataclass(slots=True)
|
||||
class A:
|
||||
@mydecorator
|
||||
def foo(self):
|
||||
super()
|
||||
|
||||
A().foo()
|
||||
|
||||
def test_remembered_class(self):
|
||||
# Apply the dataclass decorator manually (not when the class
|
||||
# is created), so that we can keep a reference to the
|
||||
# undecorated class.
|
||||
class A:
|
||||
def cls(self):
|
||||
return __class__
|
||||
|
||||
self.assertIs(A().cls(), A)
|
||||
|
||||
B = dataclass(slots=True)(A)
|
||||
self.assertIs(B().cls(), B)
|
||||
|
||||
# This is undesirable behavior, but is a function of how
|
||||
# modifying __class__ in the closure works. I'm not sure this
|
||||
# should be tested or not: I don't really want to guarantee
|
||||
# this behavior, but I don't want to lose the point that this
|
||||
# is how it works.
|
||||
|
||||
# The underlying class is "broken" by changing its __class__
|
||||
# in A.foo() to B. This normally isn't a problem, because no
|
||||
# one will be keeping a reference to the underlying class A.
|
||||
self.assertIs(A().cls(), B)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user