mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Merge pull request #6811 from youknowone/functools
Implement more functools features and Update functools from v3.14.2
This commit is contained in:
@@ -51,6 +51,7 @@ numer
|
||||
orelse
|
||||
pathconfig
|
||||
patma
|
||||
phcount
|
||||
platstdlib
|
||||
posonlyarg
|
||||
posonlyargs
|
||||
|
||||
382
Lib/functools.py
vendored
382
Lib/functools.py
vendored
@@ -6,23 +6,22 @@
|
||||
# Written by Nick Coghlan <ncoghlan at gmail.com>,
|
||||
# Raymond Hettinger <python at rcn.com>,
|
||||
# and Łukasz Langa <lukasz at langa.pl>.
|
||||
# Copyright (C) 2006-2013 Python Software Foundation.
|
||||
# Copyright (C) 2006 Python Software Foundation.
|
||||
# See C source code for _functools credits/copyright
|
||||
|
||||
__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
|
||||
'total_ordering', 'cache', 'cmp_to_key', 'lru_cache', 'reduce',
|
||||
'partial', 'partialmethod', 'singledispatch', 'singledispatchmethod',
|
||||
'cached_property']
|
||||
'cached_property', 'Placeholder']
|
||||
|
||||
from abc import get_cache_token
|
||||
from collections import namedtuple
|
||||
# import types, weakref # Deferred to single_dispatch()
|
||||
# import weakref # Deferred to single_dispatch()
|
||||
from operator import itemgetter
|
||||
from reprlib import recursive_repr
|
||||
from types import GenericAlias, MethodType, MappingProxyType, UnionType
|
||||
from _thread import RLock
|
||||
|
||||
# Avoid importing types, so we can speedup import time
|
||||
GenericAlias = type(list[int])
|
||||
|
||||
################################################################################
|
||||
### update_wrapper() and wraps() decorator
|
||||
################################################################################
|
||||
@@ -31,7 +30,7 @@ GenericAlias = type(list[int])
|
||||
# wrapper functions that can handle naive introspection
|
||||
|
||||
WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
|
||||
'__annotations__', '__type_params__')
|
||||
'__annotate__', '__type_params__')
|
||||
WRAPPER_UPDATES = ('__dict__',)
|
||||
def update_wrapper(wrapper,
|
||||
wrapped,
|
||||
@@ -237,7 +236,7 @@ _initial_missing = object()
|
||||
|
||||
def reduce(function, sequence, initial=_initial_missing):
|
||||
"""
|
||||
reduce(function, iterable[, initial], /) -> value
|
||||
reduce(function, iterable, /[, initial]) -> value
|
||||
|
||||
Apply a function of two arguments cumulatively to the items of an iterable, from left to right.
|
||||
|
||||
@@ -265,63 +264,138 @@ def reduce(function, sequence, initial=_initial_missing):
|
||||
|
||||
return value
|
||||
|
||||
try:
|
||||
from _functools import reduce
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
################################################################################
|
||||
### partial() argument application
|
||||
################################################################################
|
||||
|
||||
|
||||
class _PlaceholderType:
|
||||
"""The type of the Placeholder singleton.
|
||||
|
||||
Used as a placeholder for partial arguments.
|
||||
"""
|
||||
__instance = None
|
||||
__slots__ = ()
|
||||
|
||||
def __init_subclass__(cls, *args, **kwargs):
|
||||
raise TypeError(f"type '{cls.__name__}' is not an acceptable base type")
|
||||
|
||||
def __new__(cls):
|
||||
if cls.__instance is None:
|
||||
cls.__instance = object.__new__(cls)
|
||||
return cls.__instance
|
||||
|
||||
def __repr__(self):
|
||||
return 'Placeholder'
|
||||
|
||||
def __reduce__(self):
|
||||
return 'Placeholder'
|
||||
|
||||
Placeholder = _PlaceholderType()
|
||||
|
||||
def _partial_prepare_merger(args):
|
||||
if not args:
|
||||
return 0, None
|
||||
nargs = len(args)
|
||||
order = []
|
||||
j = nargs
|
||||
for i, a in enumerate(args):
|
||||
if a is Placeholder:
|
||||
order.append(j)
|
||||
j += 1
|
||||
else:
|
||||
order.append(i)
|
||||
phcount = j - nargs
|
||||
merger = itemgetter(*order) if phcount else None
|
||||
return phcount, merger
|
||||
|
||||
def _partial_new(cls, func, /, *args, **keywords):
|
||||
if issubclass(cls, partial):
|
||||
base_cls = partial
|
||||
if not callable(func):
|
||||
raise TypeError("the first argument must be callable")
|
||||
else:
|
||||
base_cls = partialmethod
|
||||
# func could be a descriptor like classmethod which isn't callable
|
||||
if not callable(func) and not hasattr(func, "__get__"):
|
||||
raise TypeError(f"the first argument {func!r} must be a callable "
|
||||
"or a descriptor")
|
||||
if args and args[-1] is Placeholder:
|
||||
raise TypeError("trailing Placeholders are not allowed")
|
||||
for value in keywords.values():
|
||||
if value is Placeholder:
|
||||
raise TypeError("Placeholder cannot be passed as a keyword argument")
|
||||
if isinstance(func, base_cls):
|
||||
pto_phcount = func._phcount
|
||||
tot_args = func.args
|
||||
if args:
|
||||
tot_args += args
|
||||
if pto_phcount:
|
||||
# merge args with args of `func` which is `partial`
|
||||
nargs = len(args)
|
||||
if nargs < pto_phcount:
|
||||
tot_args += (Placeholder,) * (pto_phcount - nargs)
|
||||
tot_args = func._merger(tot_args)
|
||||
if nargs > pto_phcount:
|
||||
tot_args += args[pto_phcount:]
|
||||
phcount, merger = _partial_prepare_merger(tot_args)
|
||||
else: # works for both pto_phcount == 0 and != 0
|
||||
phcount, merger = pto_phcount, func._merger
|
||||
keywords = {**func.keywords, **keywords}
|
||||
func = func.func
|
||||
else:
|
||||
tot_args = args
|
||||
phcount, merger = _partial_prepare_merger(tot_args)
|
||||
|
||||
self = object.__new__(cls)
|
||||
self.func = func
|
||||
self.args = tot_args
|
||||
self.keywords = keywords
|
||||
self._phcount = phcount
|
||||
self._merger = merger
|
||||
return self
|
||||
|
||||
def _partial_repr(self):
|
||||
cls = type(self)
|
||||
module = cls.__module__
|
||||
qualname = cls.__qualname__
|
||||
args = [repr(self.func)]
|
||||
args.extend(map(repr, self.args))
|
||||
args.extend(f"{k}={v!r}" for k, v in self.keywords.items())
|
||||
return f"{module}.{qualname}({', '.join(args)})"
|
||||
|
||||
# Purely functional, no descriptor behaviour
|
||||
class partial:
|
||||
"""New function with partial application of the given arguments
|
||||
and keywords.
|
||||
"""
|
||||
|
||||
__slots__ = "func", "args", "keywords", "__dict__", "__weakref__"
|
||||
__slots__ = ("func", "args", "keywords", "_phcount", "_merger",
|
||||
"__dict__", "__weakref__")
|
||||
|
||||
def __new__(cls, func, /, *args, **keywords):
|
||||
if not callable(func):
|
||||
raise TypeError("the first argument must be callable")
|
||||
|
||||
if isinstance(func, partial):
|
||||
args = func.args + args
|
||||
keywords = {**func.keywords, **keywords}
|
||||
func = func.func
|
||||
|
||||
self = super(partial, cls).__new__(cls)
|
||||
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.keywords = keywords
|
||||
return self
|
||||
__new__ = _partial_new
|
||||
__repr__ = recursive_repr()(_partial_repr)
|
||||
|
||||
def __call__(self, /, *args, **keywords):
|
||||
phcount = self._phcount
|
||||
if phcount:
|
||||
try:
|
||||
pto_args = self._merger(self.args + args)
|
||||
args = args[phcount:]
|
||||
except IndexError:
|
||||
raise TypeError("missing positional arguments "
|
||||
"in 'partial' call; expected "
|
||||
f"at least {phcount}, got {len(args)}")
|
||||
else:
|
||||
pto_args = self.args
|
||||
keywords = {**self.keywords, **keywords}
|
||||
return self.func(*self.args, *args, **keywords)
|
||||
|
||||
@recursive_repr()
|
||||
def __repr__(self):
|
||||
cls = type(self)
|
||||
qualname = cls.__qualname__
|
||||
module = cls.__module__
|
||||
args = [repr(self.func)]
|
||||
args.extend(repr(x) for x in self.args)
|
||||
args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
|
||||
return f"{module}.{qualname}({', '.join(args)})"
|
||||
return self.func(*pto_args, *args, **keywords)
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
if obj is None:
|
||||
return self
|
||||
import warnings
|
||||
warnings.warn('functools.partial will be a method descriptor in '
|
||||
'future Python versions; wrap it in staticmethod() '
|
||||
'if you want to preserve the old behavior',
|
||||
FutureWarning, 2)
|
||||
return self
|
||||
return MethodType(self, obj)
|
||||
|
||||
def __reduce__(self):
|
||||
return type(self), (self.func,), (self.func, self.args,
|
||||
@@ -338,6 +412,10 @@ class partial:
|
||||
(namespace is not None and not isinstance(namespace, dict))):
|
||||
raise TypeError("invalid partial state")
|
||||
|
||||
if args and args[-1] is Placeholder:
|
||||
raise TypeError("trailing Placeholders are not allowed")
|
||||
phcount, merger = _partial_prepare_merger(args)
|
||||
|
||||
args = tuple(args) # just in case it's a subclass
|
||||
if kwds is None:
|
||||
kwds = {}
|
||||
@@ -350,56 +428,43 @@ class partial:
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.keywords = kwds
|
||||
self._phcount = phcount
|
||||
self._merger = merger
|
||||
|
||||
__class_getitem__ = classmethod(GenericAlias)
|
||||
|
||||
|
||||
try:
|
||||
from _functools import partial
|
||||
from _functools import partial, Placeholder, _PlaceholderType
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Descriptor version
|
||||
class partialmethod(object):
|
||||
class partialmethod:
|
||||
"""Method descriptor with partial application of the given arguments
|
||||
and keywords.
|
||||
|
||||
Supports wrapping existing descriptors and handles non-descriptor
|
||||
callables as instance methods.
|
||||
"""
|
||||
|
||||
def __init__(self, func, /, *args, **keywords):
|
||||
if not callable(func) and not hasattr(func, "__get__"):
|
||||
raise TypeError("{!r} is not callable or a descriptor"
|
||||
.format(func))
|
||||
|
||||
# func could be a descriptor like classmethod which isn't callable,
|
||||
# so we can't inherit from partial (it verifies func is callable)
|
||||
if isinstance(func, partialmethod):
|
||||
# flattening is mandatory in order to place cls/self before all
|
||||
# other arguments
|
||||
# it's also more efficient since only one function will be called
|
||||
self.func = func.func
|
||||
self.args = func.args + args
|
||||
self.keywords = {**func.keywords, **keywords}
|
||||
else:
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.keywords = keywords
|
||||
|
||||
def __repr__(self):
|
||||
cls = type(self)
|
||||
module = cls.__module__
|
||||
qualname = cls.__qualname__
|
||||
args = [repr(self.func)]
|
||||
args.extend(map(repr, self.args))
|
||||
args.extend(f"{k}={v!r}" for k, v in self.keywords.items())
|
||||
return f"{module}.{qualname}({', '.join(args)})"
|
||||
__new__ = _partial_new
|
||||
__repr__ = _partial_repr
|
||||
|
||||
def _make_unbound_method(self):
|
||||
def _method(cls_or_self, /, *args, **keywords):
|
||||
phcount = self._phcount
|
||||
if phcount:
|
||||
try:
|
||||
pto_args = self._merger(self.args + args)
|
||||
args = args[phcount:]
|
||||
except IndexError:
|
||||
raise TypeError("missing positional arguments "
|
||||
"in 'partialmethod' call; expected "
|
||||
f"at least {phcount}, got {len(args)}")
|
||||
else:
|
||||
pto_args = self.args
|
||||
keywords = {**self.keywords, **keywords}
|
||||
return self.func(cls_or_self, *self.args, *args, **keywords)
|
||||
return self.func(cls_or_self, *pto_args, *args, **keywords)
|
||||
_method.__isabstractmethod__ = self.__isabstractmethod__
|
||||
_method.__partialmethod__ = self
|
||||
return _method
|
||||
@@ -407,7 +472,7 @@ class partialmethod(object):
|
||||
def __get__(self, obj, cls=None):
|
||||
get = getattr(self.func, "__get__", None)
|
||||
result = None
|
||||
if get is not None and not isinstance(self.func, partial):
|
||||
if get is not None:
|
||||
new_func = get(obj, cls)
|
||||
if new_func is not self.func:
|
||||
# Assume __get__ returning something new indicates the
|
||||
@@ -454,22 +519,6 @@ def _unwrap_partialmethod(func):
|
||||
|
||||
_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
|
||||
|
||||
class _HashedSeq(list):
|
||||
""" This class guarantees that hash() will be called no more than once
|
||||
per element. This is important because the lru_cache() will hash
|
||||
the key multiple times on a cache miss.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = 'hashvalue'
|
||||
|
||||
def __init__(self, tup, hash=hash):
|
||||
self[:] = tup
|
||||
self.hashvalue = hash(tup)
|
||||
|
||||
def __hash__(self):
|
||||
return self.hashvalue
|
||||
|
||||
def _make_key(args, kwds, typed,
|
||||
kwd_mark = (object(),),
|
||||
fasttypes = {int, str},
|
||||
@@ -499,7 +548,7 @@ def _make_key(args, kwds, typed,
|
||||
key += tuple(type(v) for v in kwds.values())
|
||||
elif len(key) == 1 and type(key[0]) in fasttypes:
|
||||
return key[0]
|
||||
return _HashedSeq(key)
|
||||
return key
|
||||
|
||||
def lru_cache(maxsize=128, typed=False):
|
||||
"""Least-recently-used cache decorator.
|
||||
@@ -835,7 +884,7 @@ def singledispatch(func):
|
||||
# There are many programs that use functools without singledispatch, so we
|
||||
# trade-off making singledispatch marginally slower for the benefit of
|
||||
# making start-up of such applications slightly faster.
|
||||
import types, weakref
|
||||
import weakref
|
||||
|
||||
registry = {}
|
||||
dispatch_cache = weakref.WeakKeyDictionary()
|
||||
@@ -864,16 +913,11 @@ def singledispatch(func):
|
||||
dispatch_cache[cls] = impl
|
||||
return impl
|
||||
|
||||
def _is_union_type(cls):
|
||||
from typing import get_origin, Union
|
||||
return get_origin(cls) in {Union, types.UnionType}
|
||||
|
||||
def _is_valid_dispatch_type(cls):
|
||||
if isinstance(cls, type):
|
||||
return True
|
||||
from typing import get_args
|
||||
return (_is_union_type(cls) and
|
||||
all(isinstance(arg, type) for arg in get_args(cls)))
|
||||
return (isinstance(cls, UnionType) and
|
||||
all(isinstance(arg, type) for arg in cls.__args__))
|
||||
|
||||
def register(cls, func=None):
|
||||
"""generic_func.register(cls, func) -> func
|
||||
@@ -891,8 +935,8 @@ def singledispatch(func):
|
||||
f"Invalid first argument to `register()`. "
|
||||
f"{cls!r} is not a class or union type."
|
||||
)
|
||||
ann = getattr(cls, '__annotations__', {})
|
||||
if not ann:
|
||||
ann = getattr(cls, '__annotate__', None)
|
||||
if ann is None:
|
||||
raise TypeError(
|
||||
f"Invalid first argument to `register()`: {cls!r}. "
|
||||
f"Use either `@register(some_class)` or plain `@register` "
|
||||
@@ -902,23 +946,27 @@ def singledispatch(func):
|
||||
|
||||
# only import typing if annotation parsing is necessary
|
||||
from typing import get_type_hints
|
||||
argname, cls = next(iter(get_type_hints(func).items()))
|
||||
from annotationlib import Format, ForwardRef
|
||||
argname, cls = next(iter(get_type_hints(func, format=Format.FORWARDREF).items()))
|
||||
if not _is_valid_dispatch_type(cls):
|
||||
if _is_union_type(cls):
|
||||
if isinstance(cls, UnionType):
|
||||
raise TypeError(
|
||||
f"Invalid annotation for {argname!r}. "
|
||||
f"{cls!r} not all arguments are classes."
|
||||
)
|
||||
elif isinstance(cls, ForwardRef):
|
||||
raise TypeError(
|
||||
f"Invalid annotation for {argname!r}. "
|
||||
f"{cls!r} is an unresolved forward reference."
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Invalid annotation for {argname!r}. "
|
||||
f"{cls!r} is not a class."
|
||||
)
|
||||
|
||||
if _is_union_type(cls):
|
||||
from typing import get_args
|
||||
|
||||
for arg in get_args(cls):
|
||||
if isinstance(cls, UnionType):
|
||||
for arg in cls.__args__:
|
||||
registry[arg] = func
|
||||
else:
|
||||
registry[cls] = func
|
||||
@@ -937,7 +985,7 @@ def singledispatch(func):
|
||||
registry[object] = func
|
||||
wrapper.register = register
|
||||
wrapper.dispatch = dispatch
|
||||
wrapper.registry = types.MappingProxyType(registry)
|
||||
wrapper.registry = MappingProxyType(registry)
|
||||
wrapper._clear_cache = dispatch_cache.clear
|
||||
update_wrapper(wrapper, func)
|
||||
return wrapper
|
||||
@@ -947,8 +995,7 @@ def singledispatch(func):
|
||||
class singledispatchmethod:
|
||||
"""Single-dispatch generic method descriptor.
|
||||
|
||||
Supports wrapping existing descriptors and handles non-descriptor
|
||||
callables as instance methods.
|
||||
Supports wrapping existing descriptors.
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
@@ -966,24 +1013,77 @@ class singledispatchmethod:
|
||||
return self.dispatcher.register(cls, func=method)
|
||||
|
||||
def __get__(self, obj, cls=None):
|
||||
dispatch = self.dispatcher.dispatch
|
||||
funcname = getattr(self.func, '__name__', 'singledispatchmethod method')
|
||||
def _method(*args, **kwargs):
|
||||
if not args:
|
||||
raise TypeError(f'{funcname} requires at least '
|
||||
'1 positional argument')
|
||||
return dispatch(args[0].__class__).__get__(obj, cls)(*args, **kwargs)
|
||||
|
||||
_method.__isabstractmethod__ = self.__isabstractmethod__
|
||||
_method.register = self.register
|
||||
update_wrapper(_method, self.func)
|
||||
|
||||
return _method
|
||||
return _singledispatchmethod_get(self, obj, cls)
|
||||
|
||||
@property
|
||||
def __isabstractmethod__(self):
|
||||
return getattr(self.func, '__isabstractmethod__', False)
|
||||
|
||||
def __repr__(self):
|
||||
try:
|
||||
name = self.func.__qualname__
|
||||
except AttributeError:
|
||||
try:
|
||||
name = self.func.__name__
|
||||
except AttributeError:
|
||||
name = '?'
|
||||
return f'<single dispatch method descriptor {name}>'
|
||||
|
||||
class _singledispatchmethod_get:
|
||||
def __init__(self, unbound, obj, cls):
|
||||
self._unbound = unbound
|
||||
self._dispatch = unbound.dispatcher.dispatch
|
||||
self._obj = obj
|
||||
self._cls = cls
|
||||
# Set instance attributes which cannot be handled in __getattr__()
|
||||
# because they conflict with type descriptors.
|
||||
func = unbound.func
|
||||
try:
|
||||
self.__module__ = func.__module__
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
self.__doc__ = func.__doc__
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
try:
|
||||
name = self.__qualname__
|
||||
except AttributeError:
|
||||
try:
|
||||
name = self.__name__
|
||||
except AttributeError:
|
||||
name = '?'
|
||||
if self._obj is not None:
|
||||
return f'<bound single dispatch method {name} of {self._obj!r}>'
|
||||
else:
|
||||
return f'<single dispatch method {name}>'
|
||||
|
||||
def __call__(self, /, *args, **kwargs):
|
||||
if not args:
|
||||
funcname = getattr(self._unbound.func, '__name__',
|
||||
'singledispatchmethod method')
|
||||
raise TypeError(f'{funcname} requires at least '
|
||||
'1 positional argument')
|
||||
return self._dispatch(args[0].__class__).__get__(self._obj, self._cls)(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
# Resolve these attributes lazily to speed up creation of
|
||||
# the _singledispatchmethod_get instance.
|
||||
if name not in {'__name__', '__qualname__', '__isabstractmethod__',
|
||||
'__annotations__', '__type_params__'}:
|
||||
raise AttributeError
|
||||
return getattr(self._unbound.func, name)
|
||||
|
||||
@property
|
||||
def __wrapped__(self):
|
||||
return self._unbound.func
|
||||
|
||||
@property
|
||||
def register(self):
|
||||
return self._unbound.register
|
||||
|
||||
|
||||
################################################################################
|
||||
### cached_property() - property result cached as instance attribute
|
||||
@@ -1035,3 +1135,31 @@ class cached_property:
|
||||
return val
|
||||
|
||||
__class_getitem__ = classmethod(GenericAlias)
|
||||
|
||||
def _warn_python_reduce_kwargs(py_reduce):
|
||||
@wraps(py_reduce)
|
||||
def wrapper(*args, **kwargs):
|
||||
if 'function' in kwargs or 'sequence' in kwargs:
|
||||
import os
|
||||
import warnings
|
||||
warnings.warn(
|
||||
'Calling functools.reduce with keyword arguments '
|
||||
'"function" or "sequence" '
|
||||
'is deprecated in Python 3.14 and will be '
|
||||
'forbidden in Python 3.16.',
|
||||
DeprecationWarning,
|
||||
skip_file_prefixes=(os.path.dirname(__file__),))
|
||||
return py_reduce(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
reduce = _warn_python_reduce_kwargs(reduce)
|
||||
del _warn_python_reduce_kwargs
|
||||
|
||||
# The import of the C accelerated version of reduce() has been moved
|
||||
# here due to gh-121676. In Python 3.16, _warn_python_reduce_kwargs()
|
||||
# should be removed and the import block should be moved back right
|
||||
# after the definition of reduce().
|
||||
try:
|
||||
from _functools import reduce
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
3
Lib/inspect.py
vendored
3
Lib/inspect.py
vendored
@@ -563,7 +563,8 @@ def isroutine(object):
|
||||
or isfunction(object)
|
||||
or ismethod(object)
|
||||
or ismethoddescriptor(object)
|
||||
or ismethodwrapper(object))
|
||||
or ismethodwrapper(object)
|
||||
or isinstance(object, functools._singledispatchmethod_get))
|
||||
|
||||
def isabstract(object):
|
||||
"""Return true if the object is an abstract base class (ABC)."""
|
||||
|
||||
396
Lib/test/test_functools.py
vendored
396
Lib/test/test_functools.py
vendored
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from annotationlib import Format, get_annotations
|
||||
import builtins
|
||||
import collections
|
||||
import collections.abc
|
||||
@@ -6,6 +7,7 @@ import copy
|
||||
from itertools import permutations
|
||||
import pickle
|
||||
from random import choice
|
||||
import re
|
||||
import sys
|
||||
from test import support
|
||||
import threading
|
||||
@@ -19,8 +21,11 @@ from weakref import proxy
|
||||
import contextlib
|
||||
from inspect import Signature
|
||||
|
||||
from test.support import ALWAYS_EQ
|
||||
from test.support import import_helper
|
||||
from test.support import threading_helper
|
||||
from test.support import cpython_only
|
||||
from test.support import EqualToForwardRef
|
||||
|
||||
import functools
|
||||
|
||||
@@ -60,6 +65,14 @@ class BadTuple(tuple):
|
||||
class MyDict(dict):
|
||||
pass
|
||||
|
||||
class TestImportTime(unittest.TestCase):
|
||||
|
||||
@cpython_only
|
||||
def test_lazy_import(self):
|
||||
import_helper.ensure_lazy_imports(
|
||||
"functools", {"os", "weakref", "typing", "annotationlib", "warnings"}
|
||||
)
|
||||
|
||||
|
||||
class TestPartial:
|
||||
|
||||
@@ -210,6 +223,69 @@ class TestPartial:
|
||||
p2.new_attr = 'spam'
|
||||
self.assertEqual(p2.new_attr, 'spam')
|
||||
|
||||
def test_placeholders_trailing_raise(self):
|
||||
PH = self.module.Placeholder
|
||||
for args in [(PH,), (0, PH), (0, PH, 1, PH, PH, PH)]:
|
||||
with self.assertRaises(TypeError):
|
||||
self.partial(capture, *args)
|
||||
|
||||
def test_placeholders(self):
|
||||
PH = self.module.Placeholder
|
||||
# 1 Placeholder
|
||||
args = (PH, 0)
|
||||
p = self.partial(capture, *args)
|
||||
actual_args, actual_kwds = p('x')
|
||||
self.assertEqual(actual_args, ('x', 0))
|
||||
self.assertEqual(actual_kwds, {})
|
||||
# 2 Placeholders
|
||||
args = (PH, 0, PH, 1)
|
||||
p = self.partial(capture, *args)
|
||||
with self.assertRaises(TypeError):
|
||||
p('x')
|
||||
actual_args, actual_kwds = p('x', 'y')
|
||||
self.assertEqual(actual_args, ('x', 0, 'y', 1))
|
||||
self.assertEqual(actual_kwds, {})
|
||||
# Checks via `is` and not `eq`
|
||||
# thus ALWAYS_EQ isn't treated as Placeholder
|
||||
p = self.partial(capture, ALWAYS_EQ)
|
||||
actual_args, actual_kwds = p()
|
||||
self.assertEqual(len(actual_args), 1)
|
||||
self.assertIs(actual_args[0], ALWAYS_EQ)
|
||||
self.assertEqual(actual_kwds, {})
|
||||
|
||||
def test_placeholders_optimization(self):
|
||||
PH = self.module.Placeholder
|
||||
p = self.partial(capture, PH, 0)
|
||||
p2 = self.partial(p, PH, 1, 2, 3)
|
||||
self.assertEqual(p2.args, (PH, 0, 1, 2, 3))
|
||||
p3 = self.partial(p2, -1, 4)
|
||||
actual_args, actual_kwds = p3(5)
|
||||
self.assertEqual(actual_args, (-1, 0, 1, 2, 3, 4, 5))
|
||||
self.assertEqual(actual_kwds, {})
|
||||
# inner partial has placeholders and outer partial has no args case
|
||||
p = self.partial(capture, PH, 0)
|
||||
p2 = self.partial(p)
|
||||
self.assertEqual(p2.args, (PH, 0))
|
||||
self.assertEqual(p2(1), ((1, 0), {}))
|
||||
|
||||
def test_placeholders_kw_restriction(self):
|
||||
PH = self.module.Placeholder
|
||||
with self.assertRaisesRegex(TypeError, "Placeholder"):
|
||||
self.partial(capture, a=PH)
|
||||
# Passes, as checks via `is` and not `eq`
|
||||
p = self.partial(capture, a=ALWAYS_EQ)
|
||||
actual_args, actual_kwds = p()
|
||||
self.assertEqual(actual_args, ())
|
||||
self.assertEqual(len(actual_kwds), 1)
|
||||
self.assertIs(actual_kwds['a'], ALWAYS_EQ)
|
||||
|
||||
def test_construct_placeholder_singleton(self):
|
||||
PH = self.module.Placeholder
|
||||
tp = type(PH)
|
||||
self.assertIs(tp(), PH)
|
||||
self.assertRaises(TypeError, tp, 1, 2)
|
||||
self.assertRaises(TypeError, tp, a=1, b=2)
|
||||
|
||||
def test_repr(self):
|
||||
args = (object(), object())
|
||||
args_repr = ', '.join(repr(a) for a in args)
|
||||
@@ -311,8 +387,26 @@ class TestPartial:
|
||||
self.assertEqual(f(2), ((2,), {}))
|
||||
self.assertEqual(f(), ((), {}))
|
||||
|
||||
# Set State with placeholders
|
||||
PH = self.module.Placeholder
|
||||
f = self.partial(signature)
|
||||
f.__setstate__((capture, (PH, 1), dict(a=10), dict(attr=[])))
|
||||
self.assertEqual(signature(f), (capture, (PH, 1), dict(a=10), dict(attr=[])))
|
||||
msg_regex = re.escape("missing positional arguments in 'partial' call; "
|
||||
"expected at least 1, got 0")
|
||||
with self.assertRaisesRegex(TypeError, f'^{msg_regex}$') as cm:
|
||||
f()
|
||||
self.assertEqual(f(2), ((2, 1), dict(a=10)))
|
||||
|
||||
# Trailing Placeholder error
|
||||
f = self.partial(signature)
|
||||
msg_regex = re.escape("trailing Placeholders are not allowed")
|
||||
with self.assertRaisesRegex(TypeError, f'^{msg_regex}$') as cm:
|
||||
f.__setstate__((capture, (1, PH), dict(a=10), dict(attr=[])))
|
||||
|
||||
def test_setstate_errors(self):
|
||||
f = self.partial(signature)
|
||||
|
||||
self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
|
||||
self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
|
||||
self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
|
||||
@@ -320,6 +414,8 @@ class TestPartial:
|
||||
self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
|
||||
self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
|
||||
self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))
|
||||
self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, ()))
|
||||
self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, 'test'))
|
||||
|
||||
def test_setstate_subclasses(self):
|
||||
f = self.partial(signature)
|
||||
@@ -341,6 +437,8 @@ class TestPartial:
|
||||
self.assertEqual(r, ((1, 2), {}))
|
||||
self.assertIs(type(r[0]), tuple)
|
||||
|
||||
@support.skip_if_sanitizer("thread sanitizer crashes in __tsan::FuncEntry", thread=True)
|
||||
@support.skip_emscripten_stack_overflow()
|
||||
def test_recursive_pickle(self):
|
||||
with replaced_module('functools', self.module):
|
||||
f = self.partial(capture)
|
||||
@@ -395,7 +493,6 @@ class TestPartial:
|
||||
f = self.partial(object)
|
||||
self.assertRaises(TypeError, f.__setstate__, BadSequence())
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_partial_as_method(self):
|
||||
class A:
|
||||
meth = self.partial(capture, 1, a=2)
|
||||
@@ -406,9 +503,7 @@ class TestPartial:
|
||||
self.assertEqual(A.meth(3, b=4), ((1, 3), {'a': 2, 'b': 4}))
|
||||
self.assertEqual(A.cmeth(3, b=4), ((1, A, 3), {'a': 2, 'b': 4}))
|
||||
self.assertEqual(A.smeth(3, b=4), ((1, 3), {'a': 2, 'b': 4}))
|
||||
with self.assertWarns(FutureWarning) as w:
|
||||
self.assertEqual(a.meth(3, b=4), ((1, 3), {'a': 2, 'b': 4}))
|
||||
self.assertEqual(w.filename, __file__)
|
||||
self.assertEqual(a.meth(3, b=4), ((1, a, 3), {'a': 2, 'b': 4}))
|
||||
self.assertEqual(a.cmeth(3, b=4), ((1, A, 3), {'a': 2, 'b': 4}))
|
||||
self.assertEqual(a.smeth(3, b=4), ((1, 3), {'a': 2, 'b': 4}))
|
||||
|
||||
@@ -465,11 +560,18 @@ class TestPartialC(TestPartial, unittest.TestCase):
|
||||
self.assertIn('astr', r)
|
||||
self.assertIn("['sth']", r)
|
||||
|
||||
def test_repr(self):
|
||||
return super().test_repr()
|
||||
|
||||
def test_recursive_repr(self):
|
||||
return super().test_recursive_repr()
|
||||
def test_placeholders_refcount_smoke(self):
|
||||
PH = self.module.Placeholder
|
||||
# sum supports vector call
|
||||
lst1, start = [], []
|
||||
sum_lists = self.partial(sum, PH, start)
|
||||
for i in range(10):
|
||||
sum_lists([lst1, lst1])
|
||||
# collections.ChainMap initializer does not support vectorcall
|
||||
map1, map2 = {}, {}
|
||||
partial_cm = self.partial(collections.ChainMap, PH, map1)
|
||||
for i in range(10):
|
||||
partial_cm(map2, map2)
|
||||
|
||||
|
||||
class TestPartialPy(TestPartial, unittest.TestCase):
|
||||
@@ -495,6 +597,19 @@ class TestPartialCSubclass(TestPartialC):
|
||||
class TestPartialPySubclass(TestPartialPy):
|
||||
partial = PyPartialSubclass
|
||||
|
||||
def test_subclass_optimization(self):
|
||||
# `partial` input to `partial` subclass
|
||||
p = py_functools.partial(min, 2)
|
||||
p2 = self.partial(p, 1)
|
||||
self.assertIs(p2.func, min)
|
||||
self.assertEqual(p2(0), 0)
|
||||
# `partial` subclass input to `partial` subclass
|
||||
p = self.partial(min, 2)
|
||||
p2 = self.partial(p, 1)
|
||||
self.assertIs(p2.func, min)
|
||||
self.assertEqual(p2(0), 0)
|
||||
|
||||
|
||||
class TestPartialMethod(unittest.TestCase):
|
||||
|
||||
class A(object):
|
||||
@@ -564,11 +679,11 @@ class TestPartialMethod(unittest.TestCase):
|
||||
|
||||
def test_unbound_method_retrieval(self):
|
||||
obj = self.A
|
||||
self.assertFalse(hasattr(obj.both, "__self__"))
|
||||
self.assertFalse(hasattr(obj.nested, "__self__"))
|
||||
self.assertFalse(hasattr(obj.over_partial, "__self__"))
|
||||
self.assertFalse(hasattr(obj.static, "__self__"))
|
||||
self.assertFalse(hasattr(self.a.static, "__self__"))
|
||||
self.assertNotHasAttr(obj.both, "__self__")
|
||||
self.assertNotHasAttr(obj.nested, "__self__")
|
||||
self.assertNotHasAttr(obj.over_partial, "__self__")
|
||||
self.assertNotHasAttr(obj.static, "__self__")
|
||||
self.assertNotHasAttr(self.a.static, "__self__")
|
||||
|
||||
def test_descriptors(self):
|
||||
for obj in [self.A, self.a]:
|
||||
@@ -632,6 +747,20 @@ class TestPartialMethod(unittest.TestCase):
|
||||
p = functools.partial(f, 1)
|
||||
self.assertEqual(p(2), f(1, 2))
|
||||
|
||||
def test_subclass_optimization(self):
|
||||
class PartialMethodSubclass(functools.partialmethod):
|
||||
pass
|
||||
# `partialmethod` input to `partialmethod` subclass
|
||||
p = functools.partialmethod(min, 2)
|
||||
p2 = PartialMethodSubclass(p, 1)
|
||||
self.assertIs(p2.func, min)
|
||||
self.assertEqual(p2.__get__(0)(), 0)
|
||||
# `partialmethod` subclass input to `partialmethod` subclass
|
||||
p = PartialMethodSubclass(min, 2)
|
||||
p2 = PartialMethodSubclass(p, 1)
|
||||
self.assertIs(p2.func, min)
|
||||
self.assertEqual(p2.__get__(0)(), 0)
|
||||
|
||||
|
||||
class TestUpdateWrapper(unittest.TestCase):
|
||||
|
||||
@@ -696,7 +825,7 @@ class TestUpdateWrapper(unittest.TestCase):
|
||||
self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
|
||||
self.assertEqual(wrapper.__doc__, None)
|
||||
self.assertEqual(wrapper.__annotations__, {})
|
||||
self.assertFalse(hasattr(wrapper, 'attr'))
|
||||
self.assertNotHasAttr(wrapper, 'attr')
|
||||
|
||||
def test_selective_update(self):
|
||||
def f():
|
||||
@@ -745,7 +874,7 @@ class TestUpdateWrapper(unittest.TestCase):
|
||||
pass
|
||||
functools.update_wrapper(wrapper, max)
|
||||
self.assertEqual(wrapper.__name__, 'max')
|
||||
self.assertTrue(wrapper.__doc__.startswith('max('))
|
||||
self.assertStartsWith(wrapper.__doc__, 'max(')
|
||||
self.assertEqual(wrapper.__annotations__, {})
|
||||
|
||||
def test_update_type_wrapper(self):
|
||||
@@ -756,6 +885,26 @@ class TestUpdateWrapper(unittest.TestCase):
|
||||
self.assertEqual(wrapper.__annotations__, {})
|
||||
self.assertEqual(wrapper.__type_params__, ())
|
||||
|
||||
def test_update_wrapper_annotations(self):
|
||||
def inner(x: int): pass
|
||||
def wrapper(*args): pass
|
||||
|
||||
functools.update_wrapper(wrapper, inner)
|
||||
self.assertEqual(wrapper.__annotations__, {'x': int})
|
||||
self.assertIs(wrapper.__annotate__, inner.__annotate__)
|
||||
|
||||
def with_forward_ref(x: undefined): pass
|
||||
def wrapper(*args): pass
|
||||
|
||||
functools.update_wrapper(wrapper, with_forward_ref)
|
||||
|
||||
self.assertIs(wrapper.__annotate__, with_forward_ref.__annotate__)
|
||||
with self.assertRaises(NameError):
|
||||
wrapper.__annotations__
|
||||
|
||||
undefined = str
|
||||
self.assertEqual(wrapper.__annotations__, {'x': undefined})
|
||||
|
||||
|
||||
class TestWraps(TestUpdateWrapper):
|
||||
|
||||
@@ -795,7 +944,7 @@ class TestWraps(TestUpdateWrapper):
|
||||
self.assertEqual(wrapper.__name__, 'wrapper')
|
||||
self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
|
||||
self.assertEqual(wrapper.__doc__, None)
|
||||
self.assertFalse(hasattr(wrapper, 'attr'))
|
||||
self.assertNotHasAttr(wrapper, 'attr')
|
||||
|
||||
def test_selective_update(self):
|
||||
def f():
|
||||
@@ -897,6 +1046,29 @@ class TestReduce:
|
||||
d = {"one": 1, "two": 2, "three": 3}
|
||||
self.assertEqual(self.reduce(add, d), "".join(d.keys()))
|
||||
|
||||
# test correctness of keyword usage of `initial` in `reduce`
|
||||
def test_initial_keyword(self):
|
||||
def add(x, y):
|
||||
return x + y
|
||||
self.assertEqual(
|
||||
self.reduce(add, ['a', 'b', 'c'], ''),
|
||||
self.reduce(add, ['a', 'b', 'c'], initial=''),
|
||||
)
|
||||
self.assertEqual(
|
||||
self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
|
||||
self.reduce(add, [['a', 'c'], [], ['d', 'w']], initial=[]),
|
||||
)
|
||||
self.assertEqual(
|
||||
self.reduce(lambda x, y: x*y, range(2,8), 1),
|
||||
self.reduce(lambda x, y: x*y, range(2,8), initial=1),
|
||||
)
|
||||
self.assertEqual(
|
||||
self.reduce(lambda x, y: x*y, range(2,21), 1),
|
||||
self.reduce(lambda x, y: x*y, range(2,21), initial=1),
|
||||
)
|
||||
self.assertRaises(TypeError, self.reduce, add, [0, 1], initial="")
|
||||
self.assertEqual(self.reduce(42, "", initial="1"), "1") # func is never called with one item
|
||||
|
||||
|
||||
@unittest.skipUnless(c_functools, 'requires the C _functools module')
|
||||
class TestReduceC(TestReduce, unittest.TestCase):
|
||||
@@ -907,6 +1079,12 @@ class TestReduceC(TestReduce, unittest.TestCase):
|
||||
class TestReducePy(TestReduce, unittest.TestCase):
|
||||
reduce = staticmethod(py_functools.reduce)
|
||||
|
||||
def test_reduce_with_kwargs(self):
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.reduce(function=lambda x, y: x + y, sequence=[1, 2, 3, 4, 5], initial=1)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
self.reduce(lambda x, y: x + y, sequence=[1, 2, 3, 4, 5], initial=1)
|
||||
|
||||
|
||||
class TestCmpToKey:
|
||||
|
||||
@@ -1014,35 +1192,35 @@ class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
|
||||
self, type(c_functools.cmp_to_key(None))
|
||||
)
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_bad_cmp(self):
|
||||
return super().test_bad_cmp()
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_cmp_to_key(self):
|
||||
return super().test_cmp_to_key()
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_cmp_to_key_arguments(self):
|
||||
return super().test_cmp_to_key_arguments()
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_cmp_to_signature(self):
|
||||
return super().test_cmp_to_signature()
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_hash(self):
|
||||
return super().test_hash()
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_obj_field(self):
|
||||
return super().test_obj_field()
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_sort_int(self):
|
||||
return super().test_sort_int()
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_sort_int_str(self):
|
||||
return super().test_sort_int_str()
|
||||
|
||||
@@ -1532,6 +1710,7 @@ class TestLRU:
|
||||
f(0, **{})
|
||||
self.assertEqual(f.cache_info().hits, 1)
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON; Python lru_cache impl doesn't cache hash like C impl
|
||||
def test_lru_hash_only_once(self):
|
||||
# To protect against weird reentrancy bugs and to improve
|
||||
# efficiency when faced with slow __hash__ methods, the
|
||||
@@ -1925,7 +2104,7 @@ class TestLRU:
|
||||
return 1
|
||||
self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
|
||||
|
||||
@support.suppress_immortalization()
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON; GC behavior differs from CPython's refcounting
|
||||
def test_lru_cache_weakrefable(self):
|
||||
@self.module.lru_cache
|
||||
def test_function(x):
|
||||
@@ -1963,8 +2142,38 @@ class TestLRU:
|
||||
self.assertEqual(str(Signature.from_callable(lru.cache_info)), '()')
|
||||
self.assertEqual(str(Signature.from_callable(lru.cache_clear)), '()')
|
||||
|
||||
def test_get_annotations(self):
|
||||
def orig(a: int) -> str: ...
|
||||
lru = self.module.lru_cache(1)(orig)
|
||||
|
||||
self.assertEqual(
|
||||
get_annotations(orig), {"a": int, "return": str},
|
||||
)
|
||||
self.assertEqual(
|
||||
get_annotations(lru), {"a": int, "return": str},
|
||||
)
|
||||
|
||||
def test_get_annotations_with_forwardref(self):
|
||||
def orig(a: int) -> nonexistent: ...
|
||||
lru = self.module.lru_cache(1)(orig)
|
||||
|
||||
self.assertEqual(
|
||||
get_annotations(orig, format=Format.FORWARDREF),
|
||||
{"a": int, "return": EqualToForwardRef('nonexistent', owner=orig)},
|
||||
)
|
||||
self.assertEqual(
|
||||
get_annotations(lru, format=Format.FORWARDREF),
|
||||
{"a": int, "return": EqualToForwardRef('nonexistent', owner=lru)},
|
||||
)
|
||||
with self.assertRaises(NameError):
|
||||
get_annotations(orig, format=Format.VALUE)
|
||||
with self.assertRaises(NameError):
|
||||
get_annotations(lru, format=Format.VALUE)
|
||||
|
||||
@support.skip_on_s390x
|
||||
@unittest.skipIf(support.is_wasi, "WASI has limited C stack")
|
||||
@support.skip_if_sanitizer("requires deep stack", ub=True, thread=True)
|
||||
@support.skip_emscripten_stack_overflow()
|
||||
def test_lru_recursion(self):
|
||||
|
||||
@self.module.lru_cache
|
||||
@@ -1973,15 +2182,12 @@ class TestLRU:
|
||||
return n
|
||||
return fib(n-1) + fib(n-2)
|
||||
|
||||
if not support.Py_DEBUG:
|
||||
depth = support.get_c_recursion_limit()*2//7
|
||||
with support.infinite_recursion():
|
||||
fib(depth)
|
||||
fib(100)
|
||||
if self.module == c_functools:
|
||||
fib.cache_clear()
|
||||
with support.infinite_recursion():
|
||||
with self.assertRaises(RecursionError):
|
||||
fib(10000)
|
||||
fib(support.exceeds_recursion_limit())
|
||||
|
||||
|
||||
@py_functools.lru_cache()
|
||||
@@ -2563,15 +2769,15 @@ class TestSingleDispatch(unittest.TestCase):
|
||||
a.t(0)
|
||||
self.assertEqual(a.arg, "int")
|
||||
aa = A()
|
||||
self.assertFalse(hasattr(aa, 'arg'))
|
||||
self.assertNotHasAttr(aa, 'arg')
|
||||
a.t('')
|
||||
self.assertEqual(a.arg, "str")
|
||||
aa = A()
|
||||
self.assertFalse(hasattr(aa, 'arg'))
|
||||
self.assertNotHasAttr(aa, 'arg')
|
||||
a.t(0.0)
|
||||
self.assertEqual(a.arg, "base")
|
||||
aa = A()
|
||||
self.assertFalse(hasattr(aa, 'arg'))
|
||||
self.assertNotHasAttr(aa, 'arg')
|
||||
|
||||
def test_staticmethod_register(self):
|
||||
class A:
|
||||
@@ -2806,6 +3012,8 @@ class TestSingleDispatch(unittest.TestCase):
|
||||
A().static_func
|
||||
):
|
||||
with self.subTest(meth=meth):
|
||||
self.assertEqual(meth.__module__, __name__)
|
||||
self.assertEqual(type(meth).__module__, 'functools')
|
||||
self.assertEqual(meth.__qualname__, prefix + meth.__name__)
|
||||
self.assertEqual(meth.__doc__,
|
||||
('My function docstring'
|
||||
@@ -2820,6 +3028,67 @@ class TestSingleDispatch(unittest.TestCase):
|
||||
self.assertEqual(A.static_func.__name__, 'static_func')
|
||||
self.assertEqual(A().static_func.__name__, 'static_func')
|
||||
|
||||
def test_method_repr(self):
|
||||
class Callable:
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
class CallableWithName:
|
||||
__name__ = 'NOQUALNAME'
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
class A:
|
||||
@functools.singledispatchmethod
|
||||
def func(self, arg):
|
||||
pass
|
||||
@functools.singledispatchmethod
|
||||
@classmethod
|
||||
def cls_func(cls, arg):
|
||||
pass
|
||||
@functools.singledispatchmethod
|
||||
@staticmethod
|
||||
def static_func(arg):
|
||||
pass
|
||||
# No __qualname__, only __name__
|
||||
no_qualname = functools.singledispatchmethod(CallableWithName())
|
||||
# No __qualname__, no __name__
|
||||
no_name = functools.singledispatchmethod(Callable())
|
||||
|
||||
self.assertEqual(repr(A.__dict__['func']),
|
||||
f'<single dispatch method descriptor {A.__qualname__}.func>')
|
||||
self.assertEqual(repr(A.__dict__['cls_func']),
|
||||
f'<single dispatch method descriptor {A.__qualname__}.cls_func>')
|
||||
self.assertEqual(repr(A.__dict__['static_func']),
|
||||
f'<single dispatch method descriptor {A.__qualname__}.static_func>')
|
||||
self.assertEqual(repr(A.__dict__['no_qualname']),
|
||||
f'<single dispatch method descriptor NOQUALNAME>')
|
||||
self.assertEqual(repr(A.__dict__['no_name']),
|
||||
f'<single dispatch method descriptor ?>')
|
||||
|
||||
self.assertEqual(repr(A.func),
|
||||
f'<single dispatch method {A.__qualname__}.func>')
|
||||
self.assertEqual(repr(A.cls_func),
|
||||
f'<single dispatch method {A.__qualname__}.cls_func>')
|
||||
self.assertEqual(repr(A.static_func),
|
||||
f'<single dispatch method {A.__qualname__}.static_func>')
|
||||
self.assertEqual(repr(A.no_qualname),
|
||||
f'<single dispatch method NOQUALNAME>')
|
||||
self.assertEqual(repr(A.no_name),
|
||||
f'<single dispatch method ?>')
|
||||
|
||||
a = A()
|
||||
self.assertEqual(repr(a.func),
|
||||
f'<bound single dispatch method {A.__qualname__}.func of {a!r}>')
|
||||
self.assertEqual(repr(a.cls_func),
|
||||
f'<bound single dispatch method {A.__qualname__}.cls_func of {a!r}>')
|
||||
self.assertEqual(repr(a.static_func),
|
||||
f'<bound single dispatch method {A.__qualname__}.static_func of {a!r}>')
|
||||
self.assertEqual(repr(a.no_qualname),
|
||||
f'<bound single dispatch method NOQUALNAME of {a!r}>')
|
||||
self.assertEqual(repr(a.no_name),
|
||||
f'<bound single dispatch method ? of {a!r}>')
|
||||
|
||||
def test_double_wrapped_methods(self):
|
||||
def classmethod_friendly_decorator(func):
|
||||
wrapped = func.__func__
|
||||
@@ -2836,7 +3105,8 @@ class TestSingleDispatch(unittest.TestCase):
|
||||
try:
|
||||
yield str(arg)
|
||||
finally:
|
||||
return 'Done'
|
||||
pass
|
||||
return 'Done'
|
||||
|
||||
@classmethod_friendly_decorator
|
||||
@classmethod
|
||||
@@ -2852,7 +3122,8 @@ class TestSingleDispatch(unittest.TestCase):
|
||||
try:
|
||||
yield str(arg)
|
||||
finally:
|
||||
return 'Done'
|
||||
pass
|
||||
return 'Done'
|
||||
|
||||
@functools.singledispatchmethod
|
||||
@classmethod_friendly_decorator
|
||||
@@ -2922,7 +3193,6 @@ class TestSingleDispatch(unittest.TestCase):
|
||||
'decorated_classmethod'
|
||||
)
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_invalid_registrations(self):
|
||||
msg_prefix = "Invalid first argument to `register()`: "
|
||||
msg_suffix = (
|
||||
@@ -2936,16 +3206,16 @@ class TestSingleDispatch(unittest.TestCase):
|
||||
@i.register(42)
|
||||
def _(arg):
|
||||
return "I annotated with a non-type"
|
||||
self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
|
||||
self.assertTrue(str(exc.exception).endswith(msg_suffix))
|
||||
self.assertStartsWith(str(exc.exception), msg_prefix + "42")
|
||||
self.assertEndsWith(str(exc.exception), msg_suffix)
|
||||
with self.assertRaises(TypeError) as exc:
|
||||
@i.register
|
||||
def _(arg):
|
||||
return "I forgot to annotate"
|
||||
self.assertTrue(str(exc.exception).startswith(msg_prefix +
|
||||
self.assertStartsWith(str(exc.exception), msg_prefix +
|
||||
"<function TestSingleDispatch.test_invalid_registrations.<locals>._"
|
||||
))
|
||||
self.assertTrue(str(exc.exception).endswith(msg_suffix))
|
||||
)
|
||||
self.assertEndsWith(str(exc.exception), msg_suffix)
|
||||
|
||||
with self.assertRaises(TypeError) as exc:
|
||||
@i.register
|
||||
@@ -2955,23 +3225,23 @@ class TestSingleDispatch(unittest.TestCase):
|
||||
# types from `typing`. Instead, annotate with regular types
|
||||
# or ABCs.
|
||||
return "I annotated with a generic collection"
|
||||
self.assertTrue(str(exc.exception).startswith(
|
||||
self.assertStartsWith(str(exc.exception),
|
||||
"Invalid annotation for 'arg'."
|
||||
))
|
||||
self.assertTrue(str(exc.exception).endswith(
|
||||
)
|
||||
self.assertEndsWith(str(exc.exception),
|
||||
'typing.Iterable[str] is not a class.'
|
||||
))
|
||||
)
|
||||
|
||||
with self.assertRaises(TypeError) as exc:
|
||||
@i.register
|
||||
def _(arg: typing.Union[int, typing.Iterable[str]]):
|
||||
return "Invalid Union"
|
||||
self.assertTrue(str(exc.exception).startswith(
|
||||
self.assertStartsWith(str(exc.exception),
|
||||
"Invalid annotation for 'arg'."
|
||||
))
|
||||
self.assertTrue(str(exc.exception).endswith(
|
||||
'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
|
||||
))
|
||||
)
|
||||
self.assertEndsWith(str(exc.exception),
|
||||
'int | typing.Iterable[str] not all arguments are classes.'
|
||||
)
|
||||
|
||||
def test_invalid_positional_argument(self):
|
||||
@functools.singledispatch
|
||||
@@ -3118,6 +3388,28 @@ class TestSingleDispatch(unittest.TestCase):
|
||||
self.assertEqual(f(""), "default")
|
||||
self.assertEqual(f(b""), "default")
|
||||
|
||||
def test_forward_reference(self):
|
||||
@functools.singledispatch
|
||||
def f(arg, arg2=None):
|
||||
return "default"
|
||||
|
||||
@f.register
|
||||
def _(arg: str, arg2: undefined = None):
|
||||
return "forward reference"
|
||||
|
||||
self.assertEqual(f(1), "default")
|
||||
self.assertEqual(f(""), "forward reference")
|
||||
|
||||
def test_unresolved_forward_reference(self):
|
||||
@functools.singledispatch
|
||||
def f(arg):
|
||||
return "default"
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "is an unresolved forward reference"):
|
||||
@f.register
|
||||
def _(arg: undefined):
|
||||
return "forward reference"
|
||||
|
||||
def test_method_equal_instances(self):
|
||||
# gh-127750: Reference to self was cached
|
||||
class A:
|
||||
@@ -3300,7 +3592,7 @@ class TestCachedProperty(unittest.TestCase):
|
||||
):
|
||||
MyClass.prop
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_reuse_different_names(self):
|
||||
"""Disallow this case because decorated function a would not be cached."""
|
||||
with self.assertRaises(TypeError) as ctx:
|
||||
|
||||
4
Lib/test/test_weakref.py
vendored
4
Lib/test/test_weakref.py
vendored
@@ -2059,6 +2059,7 @@ class MappingTestCase(TestBase):
|
||||
# copying should not result in a crash.
|
||||
self.check_threaded_weak_dict_copy(weakref.WeakKeyDictionary, False)
|
||||
|
||||
@unittest.skip('TODO: RUSTPYTHON; occasionally crash (malloc corruption)')
|
||||
@threading_helper.requires_working_threading()
|
||||
@support.requires_resource('cpu')
|
||||
def test_threaded_weak_key_dict_deepcopy(self):
|
||||
@@ -2066,13 +2067,14 @@ class MappingTestCase(TestBase):
|
||||
# copying should not result in a crash.
|
||||
self.check_threaded_weak_dict_copy(weakref.WeakKeyDictionary, True)
|
||||
|
||||
@unittest.skip('TODO: RUSTPYTHON; occasionally crash (Exit code -6)')
|
||||
@unittest.skip('TODO: RUSTPYTHON; occasionally crash (malloc corruption)')
|
||||
@threading_helper.requires_working_threading()
|
||||
def test_threaded_weak_value_dict_copy(self):
|
||||
# Issue #35615: Weakref keys or values getting GC'ed during dict
|
||||
# copying should not result in a crash.
|
||||
self.check_threaded_weak_dict_copy(weakref.WeakValueDictionary, False)
|
||||
|
||||
@unittest.skip('TODO: RUSTPYTHON; occasionally crash (malloc corruption)')
|
||||
@threading_helper.requires_working_threading()
|
||||
@support.requires_resource('cpu')
|
||||
def test_threaded_weak_value_dict_deepcopy(self):
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::{
|
||||
AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
|
||||
class::PyClassImpl,
|
||||
common::lock::PyMutex,
|
||||
function::FuncArgs,
|
||||
function::{FuncArgs, PySetterValue},
|
||||
types::{Constructor, GetDescriptor, Initializer, Representable},
|
||||
};
|
||||
|
||||
@@ -158,6 +158,27 @@ impl PyClassMethod {
|
||||
self.callable.lock().get_attr("__annotations__", vm)
|
||||
}
|
||||
|
||||
#[pygetset(setter)]
|
||||
fn set___annotations__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> {
|
||||
match value {
|
||||
PySetterValue::Assign(v) => self.callable.lock().set_attr("__annotations__", v, vm),
|
||||
PySetterValue::Delete => Ok(()), // Silently ignore delete like CPython
|
||||
}
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn __annotate__(&self, vm: &VirtualMachine) -> PyResult {
|
||||
self.callable.lock().get_attr("__annotate__", vm)
|
||||
}
|
||||
|
||||
#[pygetset(setter)]
|
||||
fn set___annotate__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> {
|
||||
match value {
|
||||
PySetterValue::Assign(v) => self.callable.lock().set_attr("__annotate__", v, vm),
|
||||
PySetterValue::Delete => Ok(()), // Silently ignore delete like CPython
|
||||
}
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn __isabstractmethod__(&self, vm: &VirtualMachine) -> PyObjectRef {
|
||||
match vm.get_attribute_opt(self.callable.lock().clone(), "__isabstractmethod__") {
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::{
|
||||
Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
|
||||
class::PyClassImpl,
|
||||
common::lock::PyMutex,
|
||||
function::FuncArgs,
|
||||
function::{FuncArgs, PySetterValue},
|
||||
types::{Callable, Constructor, GetDescriptor, Initializer, Representable},
|
||||
};
|
||||
|
||||
@@ -121,6 +121,27 @@ impl PyStaticMethod {
|
||||
self.callable.lock().get_attr("__annotations__", vm)
|
||||
}
|
||||
|
||||
#[pygetset(setter)]
|
||||
fn set___annotations__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> {
|
||||
match value {
|
||||
PySetterValue::Assign(v) => self.callable.lock().set_attr("__annotations__", v, vm),
|
||||
PySetterValue::Delete => Ok(()), // Silently ignore delete like CPython
|
||||
}
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn __annotate__(&self, vm: &VirtualMachine) -> PyResult {
|
||||
self.callable.lock().get_attr("__annotate__", vm)
|
||||
}
|
||||
|
||||
#[pygetset(setter)]
|
||||
fn set___annotate__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> {
|
||||
match value {
|
||||
PySetterValue::Assign(v) => self.callable.lock().set_attr("__annotate__", v, vm),
|
||||
PySetterValue::Delete => Ok(()), // Silently ignore delete like CPython
|
||||
}
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn __isabstractmethod__(&self, vm: &VirtualMachine) -> PyObjectRef {
|
||||
match vm.get_attribute_opt(self.callable.lock().clone(), "__isabstractmethod__") {
|
||||
|
||||
@@ -876,11 +876,13 @@ impl PyType {
|
||||
}
|
||||
|
||||
#[pygetset(setter)]
|
||||
fn set___annotate__(&self, value: Option<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
|
||||
if value.is_none() {
|
||||
return Err(vm.new_type_error("cannot delete __annotate__ attribute".to_owned()));
|
||||
}
|
||||
let value = value.unwrap();
|
||||
fn set___annotate__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> {
|
||||
let value = match value {
|
||||
PySetterValue::Delete => {
|
||||
return Err(vm.new_type_error("cannot delete __annotate__ attribute".to_owned()));
|
||||
}
|
||||
PySetterValue::Assign(v) => v,
|
||||
};
|
||||
|
||||
if self.slots.flags.has_feature(PyTypeFlags::IMMUTABLETYPE) {
|
||||
return Err(vm.new_type_error(format!(
|
||||
|
||||
@@ -4,28 +4,42 @@ pub(crate) use _functools::make_module;
|
||||
mod _functools {
|
||||
use crate::{
|
||||
Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
|
||||
builtins::{PyDict, PyGenericAlias, PyTuple, PyTypeRef},
|
||||
builtins::{PyBoundMethod, PyDict, PyGenericAlias, PyTuple, PyType, PyTypeRef},
|
||||
common::lock::PyRwLock,
|
||||
function::{FuncArgs, KwArgs, OptionalArg},
|
||||
function::{FuncArgs, KwArgs, OptionalOption},
|
||||
object::AsObject,
|
||||
protocol::PyIter,
|
||||
pyclass,
|
||||
recursion::ReprGuard,
|
||||
types::{Callable, Constructor, Representable},
|
||||
types::{Callable, Constructor, GetDescriptor, Representable},
|
||||
};
|
||||
use indexmap::IndexMap;
|
||||
|
||||
#[pyfunction]
|
||||
fn reduce(
|
||||
#[derive(FromArgs)]
|
||||
struct ReduceArgs {
|
||||
function: PyObjectRef,
|
||||
iterator: PyIter,
|
||||
start_value: OptionalArg<PyObjectRef>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult {
|
||||
#[pyarg(any, optional, name = "initial")]
|
||||
initial: OptionalOption<PyObjectRef>,
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn reduce(args: ReduceArgs, vm: &VirtualMachine) -> PyResult {
|
||||
let ReduceArgs {
|
||||
function,
|
||||
iterator,
|
||||
initial,
|
||||
} = args;
|
||||
let mut iter = iterator.iter_without_hint(vm)?;
|
||||
let start_value = if let OptionalArg::Present(val) = start_value {
|
||||
val
|
||||
// OptionalOption distinguishes between:
|
||||
// - Missing: no argument provided → use first element from iterator
|
||||
// - Present(None): explicitly passed None → use None as initial value
|
||||
// - Present(Some(v)): passed a value → use that value
|
||||
let start_value = if let Some(val) = initial.into_option() {
|
||||
// initial was provided (could be None or Some value)
|
||||
val.unwrap_or_else(|| vm.ctx.none())
|
||||
} else {
|
||||
// initial was not provided at all
|
||||
iter.next().transpose()?.ok_or_else(|| {
|
||||
let exc_type = vm.ctx.exceptions.type_error.to_owned();
|
||||
vm.new_exception_msg(
|
||||
@@ -42,6 +56,72 @@ mod _functools {
|
||||
Ok(accumulator)
|
||||
}
|
||||
|
||||
// Placeholder singleton for partial arguments
|
||||
// The singleton is stored as _instance on the type class
|
||||
#[pyattr]
|
||||
#[allow(non_snake_case)]
|
||||
fn Placeholder(vm: &VirtualMachine) -> PyObjectRef {
|
||||
let placeholder = PyPlaceholderType.into_pyobject(vm);
|
||||
// Store the singleton on the type class for slot_new to find
|
||||
let typ = placeholder.class();
|
||||
typ.set_attr(vm.ctx.intern_str("_instance"), placeholder.clone());
|
||||
placeholder
|
||||
}
|
||||
|
||||
#[pyattr]
|
||||
#[pyclass(name = "_PlaceholderType", module = "functools")]
|
||||
#[derive(Debug, PyPayload)]
|
||||
pub struct PyPlaceholderType;
|
||||
|
||||
impl Constructor for PyPlaceholderType {
|
||||
type Args = FuncArgs;
|
||||
|
||||
fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
|
||||
if !args.args.is_empty() || !args.kwargs.is_empty() {
|
||||
return Err(vm.new_type_error("_PlaceholderType takes no arguments".to_owned()));
|
||||
}
|
||||
// Return the singleton stored on the type class
|
||||
if let Some(instance) = cls.get_attr(vm.ctx.intern_str("_instance")) {
|
||||
return Ok(instance);
|
||||
}
|
||||
// Fallback: create a new instance (shouldn't happen for base type after module init)
|
||||
Ok(PyPlaceholderType.into_pyobject(vm))
|
||||
}
|
||||
|
||||
fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
|
||||
// This is never called because we override slot_new
|
||||
Ok(PyPlaceholderType)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(with(Constructor, Representable))]
|
||||
impl PyPlaceholderType {
|
||||
#[pymethod]
|
||||
fn __reduce__(&self) -> &'static str {
|
||||
"Placeholder"
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn __init_subclass__(_cls: PyTypeRef, vm: &VirtualMachine) -> PyResult<()> {
|
||||
Err(vm.new_type_error("cannot subclass '_PlaceholderType'".to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Representable for PyPlaceholderType {
|
||||
#[inline]
|
||||
fn repr_str(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
|
||||
Ok("Placeholder".to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
fn is_placeholder(obj: &PyObjectRef) -> bool {
|
||||
&*obj.class().name() == "_PlaceholderType"
|
||||
}
|
||||
|
||||
fn count_placeholders(args: &[PyObjectRef]) -> usize {
|
||||
args.iter().filter(|a| is_placeholder(a)).count()
|
||||
}
|
||||
|
||||
#[pyattr]
|
||||
#[pyclass(name = "partial", module = "functools")]
|
||||
#[derive(Debug, PyPayload)]
|
||||
@@ -54,9 +134,13 @@ mod _functools {
|
||||
func: PyObjectRef,
|
||||
args: PyRef<PyTuple>,
|
||||
keywords: PyRef<PyDict>,
|
||||
phcount: usize,
|
||||
}
|
||||
|
||||
#[pyclass(with(Constructor, Callable, Representable), flags(BASETYPE, HAS_DICT))]
|
||||
#[pyclass(
|
||||
with(Constructor, Callable, GetDescriptor, Representable),
|
||||
flags(BASETYPE, HAS_DICT)
|
||||
)]
|
||||
impl PyPartial {
|
||||
#[pygetset]
|
||||
fn func(&self) -> PyObjectRef {
|
||||
@@ -157,6 +241,13 @@ mod _functools {
|
||||
}
|
||||
};
|
||||
|
||||
// Validate no trailing placeholders
|
||||
let args_slice = args_tuple.as_slice();
|
||||
if !args_slice.is_empty() && is_placeholder(args_slice.last().unwrap()) {
|
||||
return Err(vm.new_type_error("trailing Placeholders are not allowed".to_owned()));
|
||||
}
|
||||
let phcount = count_placeholders(args_slice);
|
||||
|
||||
// Actually update the state
|
||||
let mut inner = zelf.inner.write();
|
||||
inner.func = func.clone();
|
||||
@@ -165,6 +256,7 @@ mod _functools {
|
||||
|
||||
// Handle keywords - keep the original type
|
||||
inner.keywords = keywords_dict;
|
||||
inner.phcount = phcount;
|
||||
|
||||
// Update __dict__ if provided
|
||||
let Some(instance_dict) = zelf.as_object().dict() else {
|
||||
@@ -218,17 +310,54 @@ mod _functools {
|
||||
return Err(vm.new_type_error("the first argument must be callable"));
|
||||
}
|
||||
|
||||
// Check for placeholders in kwargs
|
||||
for (key, value) in &args.kwargs {
|
||||
if is_placeholder(value) {
|
||||
return Err(vm.new_type_error(format!(
|
||||
"Placeholder cannot be passed as a keyword argument to partial(). \
|
||||
Did you mean partial(..., {}=Placeholder, ...)(value)?",
|
||||
key
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Handle nested partial objects
|
||||
let (final_func, final_args, final_keywords) =
|
||||
if let Some(partial) = func.downcast_ref::<Self>() {
|
||||
let inner = partial.inner.read();
|
||||
let mut combined_args = inner.args.as_slice().to_vec();
|
||||
combined_args.extend_from_slice(args_slice);
|
||||
(inner.func.clone(), combined_args, inner.keywords.clone())
|
||||
let stored_args = inner.args.as_slice();
|
||||
|
||||
// Merge placeholders: replace placeholders in stored_args with new args
|
||||
let mut merged_args = Vec::with_capacity(stored_args.len() + args_slice.len());
|
||||
let mut new_args_iter = args_slice.iter();
|
||||
|
||||
for stored_arg in stored_args {
|
||||
if is_placeholder(stored_arg) {
|
||||
// Replace placeholder with next new arg, or keep placeholder
|
||||
if let Some(new_arg) = new_args_iter.next() {
|
||||
merged_args.push(new_arg.clone());
|
||||
} else {
|
||||
merged_args.push(stored_arg.clone());
|
||||
}
|
||||
} else {
|
||||
merged_args.push(stored_arg.clone());
|
||||
}
|
||||
}
|
||||
// Append remaining new args
|
||||
merged_args.extend(new_args_iter.cloned());
|
||||
|
||||
(inner.func.clone(), merged_args, inner.keywords.clone())
|
||||
} else {
|
||||
(func.clone(), args_slice.to_vec(), vm.ctx.new_dict())
|
||||
};
|
||||
|
||||
// Trailing placeholders are not allowed
|
||||
if !final_args.is_empty() && is_placeholder(final_args.last().unwrap()) {
|
||||
return Err(vm.new_type_error("trailing Placeholders are not allowed".to_owned()));
|
||||
}
|
||||
|
||||
let phcount = count_placeholders(&final_args);
|
||||
|
||||
// Add new keywords
|
||||
for (key, value) in args.kwargs {
|
||||
final_keywords.set_item(vm.ctx.intern_str(key.as_str()), value, vm)?;
|
||||
@@ -239,6 +368,7 @@ mod _functools {
|
||||
func: final_func,
|
||||
args: vm.ctx.new_tuple(final_args),
|
||||
keywords: final_keywords,
|
||||
phcount,
|
||||
}),
|
||||
})
|
||||
}
|
||||
@@ -249,17 +379,44 @@ mod _functools {
|
||||
|
||||
fn call(zelf: &Py<Self>, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
|
||||
// Clone and release lock before calling Python code to prevent deadlock
|
||||
let (func, stored_args, keywords) = {
|
||||
let (func, stored_args, keywords, phcount) = {
|
||||
let inner = zelf.inner.read();
|
||||
(
|
||||
inner.func.clone(),
|
||||
inner.args.clone(),
|
||||
inner.keywords.clone(),
|
||||
inner.phcount,
|
||||
)
|
||||
};
|
||||
|
||||
let mut combined_args = stored_args.as_slice().to_vec();
|
||||
combined_args.extend_from_slice(&args.args);
|
||||
// Check if we have enough args to fill placeholders
|
||||
if phcount > 0 && args.args.len() < phcount {
|
||||
return Err(vm.new_type_error(format!(
|
||||
"missing positional arguments in 'partial' call; expected at least {}, got {}",
|
||||
phcount,
|
||||
args.args.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Build combined args, replacing placeholders
|
||||
let mut combined_args = Vec::with_capacity(stored_args.len() + args.args.len());
|
||||
let mut new_args_iter = args.args.iter();
|
||||
|
||||
for stored_arg in stored_args.as_slice() {
|
||||
if is_placeholder(stored_arg) {
|
||||
// Replace placeholder with next new arg
|
||||
if let Some(new_arg) = new_args_iter.next() {
|
||||
combined_args.push(new_arg.clone());
|
||||
} else {
|
||||
// This shouldn't happen if phcount check passed
|
||||
combined_args.push(stored_arg.clone());
|
||||
}
|
||||
} else {
|
||||
combined_args.push(stored_arg.clone());
|
||||
}
|
||||
}
|
||||
// Append remaining new args
|
||||
combined_args.extend(new_args_iter.cloned());
|
||||
|
||||
// Merge keywords from self.keywords and args.kwargs
|
||||
let mut final_kwargs = IndexMap::new();
|
||||
@@ -281,6 +438,21 @@ mod _functools {
|
||||
}
|
||||
}
|
||||
|
||||
impl GetDescriptor for PyPartial {
|
||||
fn descr_get(
|
||||
zelf: PyObjectRef,
|
||||
obj: Option<PyObjectRef>,
|
||||
_cls: Option<PyObjectRef>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult {
|
||||
let obj = match obj {
|
||||
Some(obj) if !vm.is_none(&obj) => obj,
|
||||
_ => return Ok(zelf),
|
||||
};
|
||||
Ok(PyBoundMethod::new(obj, zelf).into_ref(&vm.ctx).into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Representable for PyPartial {
|
||||
#[inline]
|
||||
fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
|
||||
|
||||
Reference in New Issue
Block a user