diff --git a/.cspell.dict/cpython.txt b/.cspell.dict/cpython.txt index a45760eaf..698ae7346 100644 --- a/.cspell.dict/cpython.txt +++ b/.cspell.dict/cpython.txt @@ -51,6 +51,7 @@ numer orelse pathconfig patma +phcount platstdlib posonlyarg posonlyargs diff --git a/Lib/functools.py b/Lib/functools.py index 4c1175b81..df4660eef 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -6,23 +6,22 @@ # Written by Nick Coghlan , # Raymond Hettinger , # and Łukasz Langa . -# Copyright (C) 2006-2013 Python Software Foundation. +# Copyright (C) 2006 Python Software Foundation. # See C source code for _functools credits/copyright __all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES', 'total_ordering', 'cache', 'cmp_to_key', 'lru_cache', 'reduce', 'partial', 'partialmethod', 'singledispatch', 'singledispatchmethod', - 'cached_property'] + 'cached_property', 'Placeholder'] from abc import get_cache_token from collections import namedtuple -# import types, weakref # Deferred to single_dispatch() +# import weakref # Deferred to single_dispatch() +from operator import itemgetter from reprlib import recursive_repr +from types import GenericAlias, MethodType, MappingProxyType, UnionType from _thread import RLock -# Avoid importing types, so we can speedup import time -GenericAlias = type(list[int]) - ################################################################################ ### update_wrapper() and wraps() decorator ################################################################################ @@ -31,7 +30,7 @@ GenericAlias = type(list[int]) # wrapper functions that can handle naive introspection WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__', - '__annotations__', '__type_params__') + '__annotate__', '__type_params__') WRAPPER_UPDATES = ('__dict__',) def update_wrapper(wrapper, wrapped, @@ -237,7 +236,7 @@ _initial_missing = object() def reduce(function, sequence, initial=_initial_missing): """ - reduce(function, iterable[, initial], /) -> value + reduce(function, iterable, /[, initial]) -> value Apply a function of two arguments cumulatively to the items of an iterable, from left to right. @@ -265,63 +264,138 @@ def reduce(function, sequence, initial=_initial_missing): return value -try: - from _functools import reduce -except ImportError: - pass - ################################################################################ ### partial() argument application ################################################################################ + +class _PlaceholderType: + """The type of the Placeholder singleton. + + Used as a placeholder for partial arguments. + """ + __instance = None + __slots__ = () + + def __init_subclass__(cls, *args, **kwargs): + raise TypeError(f"type '{cls.__name__}' is not an acceptable base type") + + def __new__(cls): + if cls.__instance is None: + cls.__instance = object.__new__(cls) + return cls.__instance + + def __repr__(self): + return 'Placeholder' + + def __reduce__(self): + return 'Placeholder' + +Placeholder = _PlaceholderType() + +def _partial_prepare_merger(args): + if not args: + return 0, None + nargs = len(args) + order = [] + j = nargs + for i, a in enumerate(args): + if a is Placeholder: + order.append(j) + j += 1 + else: + order.append(i) + phcount = j - nargs + merger = itemgetter(*order) if phcount else None + return phcount, merger + +def _partial_new(cls, func, /, *args, **keywords): + if issubclass(cls, partial): + base_cls = partial + if not callable(func): + raise TypeError("the first argument must be callable") + else: + base_cls = partialmethod + # func could be a descriptor like classmethod which isn't callable + if not callable(func) and not hasattr(func, "__get__"): + raise TypeError(f"the first argument {func!r} must be a callable " + "or a descriptor") + if args and args[-1] is Placeholder: + raise TypeError("trailing Placeholders are not allowed") + for value in keywords.values(): + if value is Placeholder: + raise TypeError("Placeholder cannot be passed as a keyword argument") + if isinstance(func, base_cls): + pto_phcount = func._phcount + tot_args = func.args + if args: + tot_args += args + if pto_phcount: + # merge args with args of `func` which is `partial` + nargs = len(args) + if nargs < pto_phcount: + tot_args += (Placeholder,) * (pto_phcount - nargs) + tot_args = func._merger(tot_args) + if nargs > pto_phcount: + tot_args += args[pto_phcount:] + phcount, merger = _partial_prepare_merger(tot_args) + else: # works for both pto_phcount == 0 and != 0 + phcount, merger = pto_phcount, func._merger + keywords = {**func.keywords, **keywords} + func = func.func + else: + tot_args = args + phcount, merger = _partial_prepare_merger(tot_args) + + self = object.__new__(cls) + self.func = func + self.args = tot_args + self.keywords = keywords + self._phcount = phcount + self._merger = merger + return self + +def _partial_repr(self): + cls = type(self) + module = cls.__module__ + qualname = cls.__qualname__ + args = [repr(self.func)] + args.extend(map(repr, self.args)) + args.extend(f"{k}={v!r}" for k, v in self.keywords.items()) + return f"{module}.{qualname}({', '.join(args)})" + # Purely functional, no descriptor behaviour class partial: """New function with partial application of the given arguments and keywords. """ - __slots__ = "func", "args", "keywords", "__dict__", "__weakref__" + __slots__ = ("func", "args", "keywords", "_phcount", "_merger", + "__dict__", "__weakref__") - def __new__(cls, func, /, *args, **keywords): - if not callable(func): - raise TypeError("the first argument must be callable") - - if isinstance(func, partial): - args = func.args + args - keywords = {**func.keywords, **keywords} - func = func.func - - self = super(partial, cls).__new__(cls) - - self.func = func - self.args = args - self.keywords = keywords - return self + __new__ = _partial_new + __repr__ = recursive_repr()(_partial_repr) def __call__(self, /, *args, **keywords): + phcount = self._phcount + if phcount: + try: + pto_args = self._merger(self.args + args) + args = args[phcount:] + except IndexError: + raise TypeError("missing positional arguments " + "in 'partial' call; expected " + f"at least {phcount}, got {len(args)}") + else: + pto_args = self.args keywords = {**self.keywords, **keywords} - return self.func(*self.args, *args, **keywords) - - @recursive_repr() - def __repr__(self): - cls = type(self) - qualname = cls.__qualname__ - module = cls.__module__ - args = [repr(self.func)] - args.extend(repr(x) for x in self.args) - args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items()) - return f"{module}.{qualname}({', '.join(args)})" + return self.func(*pto_args, *args, **keywords) def __get__(self, obj, objtype=None): if obj is None: return self - import warnings - warnings.warn('functools.partial will be a method descriptor in ' - 'future Python versions; wrap it in staticmethod() ' - 'if you want to preserve the old behavior', - FutureWarning, 2) - return self + return MethodType(self, obj) def __reduce__(self): return type(self), (self.func,), (self.func, self.args, @@ -338,6 +412,10 @@ class partial: (namespace is not None and not isinstance(namespace, dict))): raise TypeError("invalid partial state") + if args and args[-1] is Placeholder: + raise TypeError("trailing Placeholders are not allowed") + phcount, merger = _partial_prepare_merger(args) + args = tuple(args) # just in case it's a subclass if kwds is None: kwds = {} @@ -350,56 +428,43 @@ class partial: self.func = func self.args = args self.keywords = kwds + self._phcount = phcount + self._merger = merger __class_getitem__ = classmethod(GenericAlias) try: - from _functools import partial + from _functools import partial, Placeholder, _PlaceholderType except ImportError: pass # Descriptor version -class partialmethod(object): +class partialmethod: """Method descriptor with partial application of the given arguments and keywords. Supports wrapping existing descriptors and handles non-descriptor callables as instance methods. """ - - def __init__(self, func, /, *args, **keywords): - if not callable(func) and not hasattr(func, "__get__"): - raise TypeError("{!r} is not callable or a descriptor" - .format(func)) - - # func could be a descriptor like classmethod which isn't callable, - # so we can't inherit from partial (it verifies func is callable) - if isinstance(func, partialmethod): - # flattening is mandatory in order to place cls/self before all - # other arguments - # it's also more efficient since only one function will be called - self.func = func.func - self.args = func.args + args - self.keywords = {**func.keywords, **keywords} - else: - self.func = func - self.args = args - self.keywords = keywords - - def __repr__(self): - cls = type(self) - module = cls.__module__ - qualname = cls.__qualname__ - args = [repr(self.func)] - args.extend(map(repr, self.args)) - args.extend(f"{k}={v!r}" for k, v in self.keywords.items()) - return f"{module}.{qualname}({', '.join(args)})" + __new__ = _partial_new + __repr__ = _partial_repr def _make_unbound_method(self): def _method(cls_or_self, /, *args, **keywords): + phcount = self._phcount + if phcount: + try: + pto_args = self._merger(self.args + args) + args = args[phcount:] + except IndexError: + raise TypeError("missing positional arguments " + "in 'partialmethod' call; expected " + f"at least {phcount}, got {len(args)}") + else: + pto_args = self.args keywords = {**self.keywords, **keywords} - return self.func(cls_or_self, *self.args, *args, **keywords) + return self.func(cls_or_self, *pto_args, *args, **keywords) _method.__isabstractmethod__ = self.__isabstractmethod__ _method.__partialmethod__ = self return _method @@ -407,7 +472,7 @@ class partialmethod(object): def __get__(self, obj, cls=None): get = getattr(self.func, "__get__", None) result = None - if get is not None and not isinstance(self.func, partial): + if get is not None: new_func = get(obj, cls) if new_func is not self.func: # Assume __get__ returning something new indicates the @@ -454,22 +519,6 @@ def _unwrap_partialmethod(func): _CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"]) -class _HashedSeq(list): - """ This class guarantees that hash() will be called no more than once - per element. This is important because the lru_cache() will hash - the key multiple times on a cache miss. - - """ - - __slots__ = 'hashvalue' - - def __init__(self, tup, hash=hash): - self[:] = tup - self.hashvalue = hash(tup) - - def __hash__(self): - return self.hashvalue - def _make_key(args, kwds, typed, kwd_mark = (object(),), fasttypes = {int, str}, @@ -499,7 +548,7 @@ def _make_key(args, kwds, typed, key += tuple(type(v) for v in kwds.values()) elif len(key) == 1 and type(key[0]) in fasttypes: return key[0] - return _HashedSeq(key) + return key def lru_cache(maxsize=128, typed=False): """Least-recently-used cache decorator. @@ -835,7 +884,7 @@ def singledispatch(func): # There are many programs that use functools without singledispatch, so we # trade-off making singledispatch marginally slower for the benefit of # making start-up of such applications slightly faster. - import types, weakref + import weakref registry = {} dispatch_cache = weakref.WeakKeyDictionary() @@ -864,16 +913,11 @@ def singledispatch(func): dispatch_cache[cls] = impl return impl - def _is_union_type(cls): - from typing import get_origin, Union - return get_origin(cls) in {Union, types.UnionType} - def _is_valid_dispatch_type(cls): if isinstance(cls, type): return True - from typing import get_args - return (_is_union_type(cls) and - all(isinstance(arg, type) for arg in get_args(cls))) + return (isinstance(cls, UnionType) and + all(isinstance(arg, type) for arg in cls.__args__)) def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -891,8 +935,8 @@ def singledispatch(func): f"Invalid first argument to `register()`. " f"{cls!r} is not a class or union type." ) - ann = getattr(cls, '__annotations__', {}) - if not ann: + ann = getattr(cls, '__annotate__', None) + if ann is None: raise TypeError( f"Invalid first argument to `register()`: {cls!r}. " f"Use either `@register(some_class)` or plain `@register` " @@ -902,23 +946,27 @@ def singledispatch(func): # only import typing if annotation parsing is necessary from typing import get_type_hints - argname, cls = next(iter(get_type_hints(func).items())) + from annotationlib import Format, ForwardRef + argname, cls = next(iter(get_type_hints(func, format=Format.FORWARDREF).items())) if not _is_valid_dispatch_type(cls): - if _is_union_type(cls): + if isinstance(cls, UnionType): raise TypeError( f"Invalid annotation for {argname!r}. " f"{cls!r} not all arguments are classes." ) + elif isinstance(cls, ForwardRef): + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} is an unresolved forward reference." + ) else: raise TypeError( f"Invalid annotation for {argname!r}. " f"{cls!r} is not a class." ) - if _is_union_type(cls): - from typing import get_args - - for arg in get_args(cls): + if isinstance(cls, UnionType): + for arg in cls.__args__: registry[arg] = func else: registry[cls] = func @@ -937,7 +985,7 @@ def singledispatch(func): registry[object] = func wrapper.register = register wrapper.dispatch = dispatch - wrapper.registry = types.MappingProxyType(registry) + wrapper.registry = MappingProxyType(registry) wrapper._clear_cache = dispatch_cache.clear update_wrapper(wrapper, func) return wrapper @@ -947,8 +995,7 @@ def singledispatch(func): class singledispatchmethod: """Single-dispatch generic method descriptor. - Supports wrapping existing descriptors and handles non-descriptor - callables as instance methods. + Supports wrapping existing descriptors. """ def __init__(self, func): @@ -966,24 +1013,77 @@ class singledispatchmethod: return self.dispatcher.register(cls, func=method) def __get__(self, obj, cls=None): - dispatch = self.dispatcher.dispatch - funcname = getattr(self.func, '__name__', 'singledispatchmethod method') - def _method(*args, **kwargs): - if not args: - raise TypeError(f'{funcname} requires at least ' - '1 positional argument') - return dispatch(args[0].__class__).__get__(obj, cls)(*args, **kwargs) - - _method.__isabstractmethod__ = self.__isabstractmethod__ - _method.register = self.register - update_wrapper(_method, self.func) - - return _method + return _singledispatchmethod_get(self, obj, cls) @property def __isabstractmethod__(self): return getattr(self.func, '__isabstractmethod__', False) + def __repr__(self): + try: + name = self.func.__qualname__ + except AttributeError: + try: + name = self.func.__name__ + except AttributeError: + name = '?' + return f'' + +class _singledispatchmethod_get: + def __init__(self, unbound, obj, cls): + self._unbound = unbound + self._dispatch = unbound.dispatcher.dispatch + self._obj = obj + self._cls = cls + # Set instance attributes which cannot be handled in __getattr__() + # because they conflict with type descriptors. + func = unbound.func + try: + self.__module__ = func.__module__ + except AttributeError: + pass + try: + self.__doc__ = func.__doc__ + except AttributeError: + pass + + def __repr__(self): + try: + name = self.__qualname__ + except AttributeError: + try: + name = self.__name__ + except AttributeError: + name = '?' + if self._obj is not None: + return f'' + else: + return f'' + + def __call__(self, /, *args, **kwargs): + if not args: + funcname = getattr(self._unbound.func, '__name__', + 'singledispatchmethod method') + raise TypeError(f'{funcname} requires at least ' + '1 positional argument') + return self._dispatch(args[0].__class__).__get__(self._obj, self._cls)(*args, **kwargs) + + def __getattr__(self, name): + # Resolve these attributes lazily to speed up creation of + # the _singledispatchmethod_get instance. + if name not in {'__name__', '__qualname__', '__isabstractmethod__', + '__annotations__', '__type_params__'}: + raise AttributeError + return getattr(self._unbound.func, name) + + @property + def __wrapped__(self): + return self._unbound.func + + @property + def register(self): + return self._unbound.register + ################################################################################ ### cached_property() - property result cached as instance attribute @@ -1035,3 +1135,31 @@ class cached_property: return val __class_getitem__ = classmethod(GenericAlias) + +def _warn_python_reduce_kwargs(py_reduce): + @wraps(py_reduce) + def wrapper(*args, **kwargs): + if 'function' in kwargs or 'sequence' in kwargs: + import os + import warnings + warnings.warn( + 'Calling functools.reduce with keyword arguments ' + '"function" or "sequence" ' + 'is deprecated in Python 3.14 and will be ' + 'forbidden in Python 3.16.', + DeprecationWarning, + skip_file_prefixes=(os.path.dirname(__file__),)) + return py_reduce(*args, **kwargs) + return wrapper + +reduce = _warn_python_reduce_kwargs(reduce) +del _warn_python_reduce_kwargs + +# The import of the C accelerated version of reduce() has been moved +# here due to gh-121676. In Python 3.16, _warn_python_reduce_kwargs() +# should be removed and the import block should be moved back right +# after the definition of reduce(). +try: + from _functools import reduce +except ImportError: + pass diff --git a/Lib/inspect.py b/Lib/inspect.py index d2a045469..ed384cf8b 100644 --- a/Lib/inspect.py +++ b/Lib/inspect.py @@ -563,7 +563,8 @@ def isroutine(object): or isfunction(object) or ismethod(object) or ismethoddescriptor(object) - or ismethodwrapper(object)) + or ismethodwrapper(object) + or isinstance(object, functools._singledispatchmethod_get)) def isabstract(object): """Return true if the object is an abstract base class (ABC).""" diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index bf3e24481..9ab0c8917 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1,4 +1,5 @@ import abc +from annotationlib import Format, get_annotations import builtins import collections import collections.abc @@ -6,6 +7,7 @@ import copy from itertools import permutations import pickle from random import choice +import re import sys from test import support import threading @@ -19,8 +21,11 @@ from weakref import proxy import contextlib from inspect import Signature +from test.support import ALWAYS_EQ from test.support import import_helper from test.support import threading_helper +from test.support import cpython_only +from test.support import EqualToForwardRef import functools @@ -60,6 +65,14 @@ class BadTuple(tuple): class MyDict(dict): pass +class TestImportTime(unittest.TestCase): + + @cpython_only + def test_lazy_import(self): + import_helper.ensure_lazy_imports( + "functools", {"os", "weakref", "typing", "annotationlib", "warnings"} + ) + class TestPartial: @@ -210,6 +223,69 @@ class TestPartial: p2.new_attr = 'spam' self.assertEqual(p2.new_attr, 'spam') + def test_placeholders_trailing_raise(self): + PH = self.module.Placeholder + for args in [(PH,), (0, PH), (0, PH, 1, PH, PH, PH)]: + with self.assertRaises(TypeError): + self.partial(capture, *args) + + def test_placeholders(self): + PH = self.module.Placeholder + # 1 Placeholder + args = (PH, 0) + p = self.partial(capture, *args) + actual_args, actual_kwds = p('x') + self.assertEqual(actual_args, ('x', 0)) + self.assertEqual(actual_kwds, {}) + # 2 Placeholders + args = (PH, 0, PH, 1) + p = self.partial(capture, *args) + with self.assertRaises(TypeError): + p('x') + actual_args, actual_kwds = p('x', 'y') + self.assertEqual(actual_args, ('x', 0, 'y', 1)) + self.assertEqual(actual_kwds, {}) + # Checks via `is` and not `eq` + # thus ALWAYS_EQ isn't treated as Placeholder + p = self.partial(capture, ALWAYS_EQ) + actual_args, actual_kwds = p() + self.assertEqual(len(actual_args), 1) + self.assertIs(actual_args[0], ALWAYS_EQ) + self.assertEqual(actual_kwds, {}) + + def test_placeholders_optimization(self): + PH = self.module.Placeholder + p = self.partial(capture, PH, 0) + p2 = self.partial(p, PH, 1, 2, 3) + self.assertEqual(p2.args, (PH, 0, 1, 2, 3)) + p3 = self.partial(p2, -1, 4) + actual_args, actual_kwds = p3(5) + self.assertEqual(actual_args, (-1, 0, 1, 2, 3, 4, 5)) + self.assertEqual(actual_kwds, {}) + # inner partial has placeholders and outer partial has no args case + p = self.partial(capture, PH, 0) + p2 = self.partial(p) + self.assertEqual(p2.args, (PH, 0)) + self.assertEqual(p2(1), ((1, 0), {})) + + def test_placeholders_kw_restriction(self): + PH = self.module.Placeholder + with self.assertRaisesRegex(TypeError, "Placeholder"): + self.partial(capture, a=PH) + # Passes, as checks via `is` and not `eq` + p = self.partial(capture, a=ALWAYS_EQ) + actual_args, actual_kwds = p() + self.assertEqual(actual_args, ()) + self.assertEqual(len(actual_kwds), 1) + self.assertIs(actual_kwds['a'], ALWAYS_EQ) + + def test_construct_placeholder_singleton(self): + PH = self.module.Placeholder + tp = type(PH) + self.assertIs(tp(), PH) + self.assertRaises(TypeError, tp, 1, 2) + self.assertRaises(TypeError, tp, a=1, b=2) + def test_repr(self): args = (object(), object()) args_repr = ', '.join(repr(a) for a in args) @@ -311,8 +387,26 @@ class TestPartial: self.assertEqual(f(2), ((2,), {})) self.assertEqual(f(), ((), {})) + # Set State with placeholders + PH = self.module.Placeholder + f = self.partial(signature) + f.__setstate__((capture, (PH, 1), dict(a=10), dict(attr=[]))) + self.assertEqual(signature(f), (capture, (PH, 1), dict(a=10), dict(attr=[]))) + msg_regex = re.escape("missing positional arguments in 'partial' call; " + "expected at least 1, got 0") + with self.assertRaisesRegex(TypeError, f'^{msg_regex}$') as cm: + f() + self.assertEqual(f(2), ((2, 1), dict(a=10))) + + # Trailing Placeholder error + f = self.partial(signature) + msg_regex = re.escape("trailing Placeholders are not allowed") + with self.assertRaisesRegex(TypeError, f'^{msg_regex}$') as cm: + f.__setstate__((capture, (1, PH), dict(a=10), dict(attr=[]))) + def test_setstate_errors(self): f = self.partial(signature) + self.assertRaises(TypeError, f.__setstate__, (capture, (), {})) self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None)) self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None]) @@ -320,6 +414,8 @@ class TestPartial: self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None)) self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None)) self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None)) + self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, ())) + self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, 'test')) def test_setstate_subclasses(self): f = self.partial(signature) @@ -341,6 +437,8 @@ class TestPartial: self.assertEqual(r, ((1, 2), {})) self.assertIs(type(r[0]), tuple) + @support.skip_if_sanitizer("thread sanitizer crashes in __tsan::FuncEntry", thread=True) + @support.skip_emscripten_stack_overflow() def test_recursive_pickle(self): with replaced_module('functools', self.module): f = self.partial(capture) @@ -395,7 +493,6 @@ class TestPartial: f = self.partial(object) self.assertRaises(TypeError, f.__setstate__, BadSequence()) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_partial_as_method(self): class A: meth = self.partial(capture, 1, a=2) @@ -406,9 +503,7 @@ class TestPartial: self.assertEqual(A.meth(3, b=4), ((1, 3), {'a': 2, 'b': 4})) self.assertEqual(A.cmeth(3, b=4), ((1, A, 3), {'a': 2, 'b': 4})) self.assertEqual(A.smeth(3, b=4), ((1, 3), {'a': 2, 'b': 4})) - with self.assertWarns(FutureWarning) as w: - self.assertEqual(a.meth(3, b=4), ((1, 3), {'a': 2, 'b': 4})) - self.assertEqual(w.filename, __file__) + self.assertEqual(a.meth(3, b=4), ((1, a, 3), {'a': 2, 'b': 4})) self.assertEqual(a.cmeth(3, b=4), ((1, A, 3), {'a': 2, 'b': 4})) self.assertEqual(a.smeth(3, b=4), ((1, 3), {'a': 2, 'b': 4})) @@ -465,11 +560,18 @@ class TestPartialC(TestPartial, unittest.TestCase): self.assertIn('astr', r) self.assertIn("['sth']", r) - def test_repr(self): - return super().test_repr() - - def test_recursive_repr(self): - return super().test_recursive_repr() + def test_placeholders_refcount_smoke(self): + PH = self.module.Placeholder + # sum supports vector call + lst1, start = [], [] + sum_lists = self.partial(sum, PH, start) + for i in range(10): + sum_lists([lst1, lst1]) + # collections.ChainMap initializer does not support vectorcall + map1, map2 = {}, {} + partial_cm = self.partial(collections.ChainMap, PH, map1) + for i in range(10): + partial_cm(map2, map2) class TestPartialPy(TestPartial, unittest.TestCase): @@ -495,6 +597,19 @@ class TestPartialCSubclass(TestPartialC): class TestPartialPySubclass(TestPartialPy): partial = PyPartialSubclass + def test_subclass_optimization(self): + # `partial` input to `partial` subclass + p = py_functools.partial(min, 2) + p2 = self.partial(p, 1) + self.assertIs(p2.func, min) + self.assertEqual(p2(0), 0) + # `partial` subclass input to `partial` subclass + p = self.partial(min, 2) + p2 = self.partial(p, 1) + self.assertIs(p2.func, min) + self.assertEqual(p2(0), 0) + + class TestPartialMethod(unittest.TestCase): class A(object): @@ -564,11 +679,11 @@ class TestPartialMethod(unittest.TestCase): def test_unbound_method_retrieval(self): obj = self.A - self.assertFalse(hasattr(obj.both, "__self__")) - self.assertFalse(hasattr(obj.nested, "__self__")) - self.assertFalse(hasattr(obj.over_partial, "__self__")) - self.assertFalse(hasattr(obj.static, "__self__")) - self.assertFalse(hasattr(self.a.static, "__self__")) + self.assertNotHasAttr(obj.both, "__self__") + self.assertNotHasAttr(obj.nested, "__self__") + self.assertNotHasAttr(obj.over_partial, "__self__") + self.assertNotHasAttr(obj.static, "__self__") + self.assertNotHasAttr(self.a.static, "__self__") def test_descriptors(self): for obj in [self.A, self.a]: @@ -632,6 +747,20 @@ class TestPartialMethod(unittest.TestCase): p = functools.partial(f, 1) self.assertEqual(p(2), f(1, 2)) + def test_subclass_optimization(self): + class PartialMethodSubclass(functools.partialmethod): + pass + # `partialmethod` input to `partialmethod` subclass + p = functools.partialmethod(min, 2) + p2 = PartialMethodSubclass(p, 1) + self.assertIs(p2.func, min) + self.assertEqual(p2.__get__(0)(), 0) + # `partialmethod` subclass input to `partialmethod` subclass + p = PartialMethodSubclass(min, 2) + p2 = PartialMethodSubclass(p, 1) + self.assertIs(p2.func, min) + self.assertEqual(p2.__get__(0)(), 0) + class TestUpdateWrapper(unittest.TestCase): @@ -696,7 +825,7 @@ class TestUpdateWrapper(unittest.TestCase): self.assertNotEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.__doc__, None) self.assertEqual(wrapper.__annotations__, {}) - self.assertFalse(hasattr(wrapper, 'attr')) + self.assertNotHasAttr(wrapper, 'attr') def test_selective_update(self): def f(): @@ -745,7 +874,7 @@ class TestUpdateWrapper(unittest.TestCase): pass functools.update_wrapper(wrapper, max) self.assertEqual(wrapper.__name__, 'max') - self.assertTrue(wrapper.__doc__.startswith('max(')) + self.assertStartsWith(wrapper.__doc__, 'max(') self.assertEqual(wrapper.__annotations__, {}) def test_update_type_wrapper(self): @@ -756,6 +885,26 @@ class TestUpdateWrapper(unittest.TestCase): self.assertEqual(wrapper.__annotations__, {}) self.assertEqual(wrapper.__type_params__, ()) + def test_update_wrapper_annotations(self): + def inner(x: int): pass + def wrapper(*args): pass + + functools.update_wrapper(wrapper, inner) + self.assertEqual(wrapper.__annotations__, {'x': int}) + self.assertIs(wrapper.__annotate__, inner.__annotate__) + + def with_forward_ref(x: undefined): pass + def wrapper(*args): pass + + functools.update_wrapper(wrapper, with_forward_ref) + + self.assertIs(wrapper.__annotate__, with_forward_ref.__annotate__) + with self.assertRaises(NameError): + wrapper.__annotations__ + + undefined = str + self.assertEqual(wrapper.__annotations__, {'x': undefined}) + class TestWraps(TestUpdateWrapper): @@ -795,7 +944,7 @@ class TestWraps(TestUpdateWrapper): self.assertEqual(wrapper.__name__, 'wrapper') self.assertNotEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.__doc__, None) - self.assertFalse(hasattr(wrapper, 'attr')) + self.assertNotHasAttr(wrapper, 'attr') def test_selective_update(self): def f(): @@ -897,6 +1046,29 @@ class TestReduce: d = {"one": 1, "two": 2, "three": 3} self.assertEqual(self.reduce(add, d), "".join(d.keys())) + # test correctness of keyword usage of `initial` in `reduce` + def test_initial_keyword(self): + def add(x, y): + return x + y + self.assertEqual( + self.reduce(add, ['a', 'b', 'c'], ''), + self.reduce(add, ['a', 'b', 'c'], initial=''), + ) + self.assertEqual( + self.reduce(add, [['a', 'c'], [], ['d', 'w']], []), + self.reduce(add, [['a', 'c'], [], ['d', 'w']], initial=[]), + ) + self.assertEqual( + self.reduce(lambda x, y: x*y, range(2,8), 1), + self.reduce(lambda x, y: x*y, range(2,8), initial=1), + ) + self.assertEqual( + self.reduce(lambda x, y: x*y, range(2,21), 1), + self.reduce(lambda x, y: x*y, range(2,21), initial=1), + ) + self.assertRaises(TypeError, self.reduce, add, [0, 1], initial="") + self.assertEqual(self.reduce(42, "", initial="1"), "1") # func is never called with one item + @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestReduceC(TestReduce, unittest.TestCase): @@ -907,6 +1079,12 @@ class TestReduceC(TestReduce, unittest.TestCase): class TestReducePy(TestReduce, unittest.TestCase): reduce = staticmethod(py_functools.reduce) + def test_reduce_with_kwargs(self): + with self.assertWarns(DeprecationWarning): + self.reduce(function=lambda x, y: x + y, sequence=[1, 2, 3, 4, 5], initial=1) + with self.assertWarns(DeprecationWarning): + self.reduce(lambda x, y: x + y, sequence=[1, 2, 3, 4, 5], initial=1) + class TestCmpToKey: @@ -1014,35 +1192,35 @@ class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): self, type(c_functools.cmp_to_key(None)) ) - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_bad_cmp(self): return super().test_bad_cmp() - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_cmp_to_key(self): return super().test_cmp_to_key() - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_cmp_to_key_arguments(self): return super().test_cmp_to_key_arguments() - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_cmp_to_signature(self): return super().test_cmp_to_signature() - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_hash(self): return super().test_hash() - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_obj_field(self): return super().test_obj_field() - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_sort_int(self): return super().test_sort_int() - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_sort_int_str(self): return super().test_sort_int_str() @@ -1532,6 +1710,7 @@ class TestLRU: f(0, **{}) self.assertEqual(f.cache_info().hits, 1) + @unittest.expectedFailure # TODO: RUSTPYTHON; Python lru_cache impl doesn't cache hash like C impl def test_lru_hash_only_once(self): # To protect against weird reentrancy bugs and to improve # efficiency when faced with slow __hash__ methods, the @@ -1925,7 +2104,7 @@ class TestLRU: return 1 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True}) - @support.suppress_immortalization() + @unittest.expectedFailure # TODO: RUSTPYTHON; GC behavior differs from CPython's refcounting def test_lru_cache_weakrefable(self): @self.module.lru_cache def test_function(x): @@ -1963,8 +2142,38 @@ class TestLRU: self.assertEqual(str(Signature.from_callable(lru.cache_info)), '()') self.assertEqual(str(Signature.from_callable(lru.cache_clear)), '()') + def test_get_annotations(self): + def orig(a: int) -> str: ... + lru = self.module.lru_cache(1)(orig) + + self.assertEqual( + get_annotations(orig), {"a": int, "return": str}, + ) + self.assertEqual( + get_annotations(lru), {"a": int, "return": str}, + ) + + def test_get_annotations_with_forwardref(self): + def orig(a: int) -> nonexistent: ... + lru = self.module.lru_cache(1)(orig) + + self.assertEqual( + get_annotations(orig, format=Format.FORWARDREF), + {"a": int, "return": EqualToForwardRef('nonexistent', owner=orig)}, + ) + self.assertEqual( + get_annotations(lru, format=Format.FORWARDREF), + {"a": int, "return": EqualToForwardRef('nonexistent', owner=lru)}, + ) + with self.assertRaises(NameError): + get_annotations(orig, format=Format.VALUE) + with self.assertRaises(NameError): + get_annotations(lru, format=Format.VALUE) + @support.skip_on_s390x @unittest.skipIf(support.is_wasi, "WASI has limited C stack") + @support.skip_if_sanitizer("requires deep stack", ub=True, thread=True) + @support.skip_emscripten_stack_overflow() def test_lru_recursion(self): @self.module.lru_cache @@ -1973,15 +2182,12 @@ class TestLRU: return n return fib(n-1) + fib(n-2) - if not support.Py_DEBUG: - depth = support.get_c_recursion_limit()*2//7 - with support.infinite_recursion(): - fib(depth) + fib(100) if self.module == c_functools: fib.cache_clear() with support.infinite_recursion(): with self.assertRaises(RecursionError): - fib(10000) + fib(support.exceeds_recursion_limit()) @py_functools.lru_cache() @@ -2563,15 +2769,15 @@ class TestSingleDispatch(unittest.TestCase): a.t(0) self.assertEqual(a.arg, "int") aa = A() - self.assertFalse(hasattr(aa, 'arg')) + self.assertNotHasAttr(aa, 'arg') a.t('') self.assertEqual(a.arg, "str") aa = A() - self.assertFalse(hasattr(aa, 'arg')) + self.assertNotHasAttr(aa, 'arg') a.t(0.0) self.assertEqual(a.arg, "base") aa = A() - self.assertFalse(hasattr(aa, 'arg')) + self.assertNotHasAttr(aa, 'arg') def test_staticmethod_register(self): class A: @@ -2806,6 +3012,8 @@ class TestSingleDispatch(unittest.TestCase): A().static_func ): with self.subTest(meth=meth): + self.assertEqual(meth.__module__, __name__) + self.assertEqual(type(meth).__module__, 'functools') self.assertEqual(meth.__qualname__, prefix + meth.__name__) self.assertEqual(meth.__doc__, ('My function docstring' @@ -2820,6 +3028,67 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(A.static_func.__name__, 'static_func') self.assertEqual(A().static_func.__name__, 'static_func') + def test_method_repr(self): + class Callable: + def __call__(self, *args): + pass + + class CallableWithName: + __name__ = 'NOQUALNAME' + def __call__(self, *args): + pass + + class A: + @functools.singledispatchmethod + def func(self, arg): + pass + @functools.singledispatchmethod + @classmethod + def cls_func(cls, arg): + pass + @functools.singledispatchmethod + @staticmethod + def static_func(arg): + pass + # No __qualname__, only __name__ + no_qualname = functools.singledispatchmethod(CallableWithName()) + # No __qualname__, no __name__ + no_name = functools.singledispatchmethod(Callable()) + + self.assertEqual(repr(A.__dict__['func']), + f'') + self.assertEqual(repr(A.__dict__['cls_func']), + f'') + self.assertEqual(repr(A.__dict__['static_func']), + f'') + self.assertEqual(repr(A.__dict__['no_qualname']), + f'') + self.assertEqual(repr(A.__dict__['no_name']), + f'') + + self.assertEqual(repr(A.func), + f'') + self.assertEqual(repr(A.cls_func), + f'') + self.assertEqual(repr(A.static_func), + f'') + self.assertEqual(repr(A.no_qualname), + f'') + self.assertEqual(repr(A.no_name), + f'') + + a = A() + self.assertEqual(repr(a.func), + f'') + self.assertEqual(repr(a.cls_func), + f'') + self.assertEqual(repr(a.static_func), + f'') + self.assertEqual(repr(a.no_qualname), + f'') + self.assertEqual(repr(a.no_name), + f'') + def test_double_wrapped_methods(self): def classmethod_friendly_decorator(func): wrapped = func.__func__ @@ -2836,7 +3105,8 @@ class TestSingleDispatch(unittest.TestCase): try: yield str(arg) finally: - return 'Done' + pass + return 'Done' @classmethod_friendly_decorator @classmethod @@ -2852,7 +3122,8 @@ class TestSingleDispatch(unittest.TestCase): try: yield str(arg) finally: - return 'Done' + pass + return 'Done' @functools.singledispatchmethod @classmethod_friendly_decorator @@ -2922,7 +3193,6 @@ class TestSingleDispatch(unittest.TestCase): 'decorated_classmethod' ) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_invalid_registrations(self): msg_prefix = "Invalid first argument to `register()`: " msg_suffix = ( @@ -2936,16 +3206,16 @@ class TestSingleDispatch(unittest.TestCase): @i.register(42) def _(arg): return "I annotated with a non-type" - self.assertTrue(str(exc.exception).startswith(msg_prefix + "42")) - self.assertTrue(str(exc.exception).endswith(msg_suffix)) + self.assertStartsWith(str(exc.exception), msg_prefix + "42") + self.assertEndsWith(str(exc.exception), msg_suffix) with self.assertRaises(TypeError) as exc: @i.register def _(arg): return "I forgot to annotate" - self.assertTrue(str(exc.exception).startswith(msg_prefix + + self.assertStartsWith(str(exc.exception), msg_prefix + "._" - )) - self.assertTrue(str(exc.exception).endswith(msg_suffix)) + ) + self.assertEndsWith(str(exc.exception), msg_suffix) with self.assertRaises(TypeError) as exc: @i.register @@ -2955,23 +3225,23 @@ class TestSingleDispatch(unittest.TestCase): # types from `typing`. Instead, annotate with regular types # or ABCs. return "I annotated with a generic collection" - self.assertTrue(str(exc.exception).startswith( + self.assertStartsWith(str(exc.exception), "Invalid annotation for 'arg'." - )) - self.assertTrue(str(exc.exception).endswith( + ) + self.assertEndsWith(str(exc.exception), 'typing.Iterable[str] is not a class.' - )) + ) with self.assertRaises(TypeError) as exc: @i.register def _(arg: typing.Union[int, typing.Iterable[str]]): return "Invalid Union" - self.assertTrue(str(exc.exception).startswith( + self.assertStartsWith(str(exc.exception), "Invalid annotation for 'arg'." - )) - self.assertTrue(str(exc.exception).endswith( - 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.' - )) + ) + self.assertEndsWith(str(exc.exception), + 'int | typing.Iterable[str] not all arguments are classes.' + ) def test_invalid_positional_argument(self): @functools.singledispatch @@ -3118,6 +3388,28 @@ class TestSingleDispatch(unittest.TestCase): self.assertEqual(f(""), "default") self.assertEqual(f(b""), "default") + def test_forward_reference(self): + @functools.singledispatch + def f(arg, arg2=None): + return "default" + + @f.register + def _(arg: str, arg2: undefined = None): + return "forward reference" + + self.assertEqual(f(1), "default") + self.assertEqual(f(""), "forward reference") + + def test_unresolved_forward_reference(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "is an unresolved forward reference"): + @f.register + def _(arg: undefined): + return "forward reference" + def test_method_equal_instances(self): # gh-127750: Reference to self was cached class A: @@ -3300,7 +3592,7 @@ class TestCachedProperty(unittest.TestCase): ): MyClass.prop - @unittest.expectedFailure # TODO: RUSTPYTHON + @unittest.expectedFailure # TODO: RUSTPYTHON def test_reuse_different_names(self): """Disallow this case because decorated function a would not be cached.""" with self.assertRaises(TypeError) as ctx: diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py index c11898741..f421d745a 100644 --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -2059,6 +2059,7 @@ class MappingTestCase(TestBase): # copying should not result in a crash. self.check_threaded_weak_dict_copy(weakref.WeakKeyDictionary, False) + @unittest.skip('TODO: RUSTPYTHON; occasionally crash (malloc corruption)') @threading_helper.requires_working_threading() @support.requires_resource('cpu') def test_threaded_weak_key_dict_deepcopy(self): @@ -2066,13 +2067,14 @@ class MappingTestCase(TestBase): # copying should not result in a crash. self.check_threaded_weak_dict_copy(weakref.WeakKeyDictionary, True) - @unittest.skip('TODO: RUSTPYTHON; occasionally crash (Exit code -6)') + @unittest.skip('TODO: RUSTPYTHON; occasionally crash (malloc corruption)') @threading_helper.requires_working_threading() def test_threaded_weak_value_dict_copy(self): # Issue #35615: Weakref keys or values getting GC'ed during dict # copying should not result in a crash. self.check_threaded_weak_dict_copy(weakref.WeakValueDictionary, False) + @unittest.skip('TODO: RUSTPYTHON; occasionally crash (malloc corruption)') @threading_helper.requires_working_threading() @support.requires_resource('cpu') def test_threaded_weak_value_dict_deepcopy(self): diff --git a/crates/vm/src/builtins/classmethod.rs b/crates/vm/src/builtins/classmethod.rs index 5b7f92186..d2f1377be 100644 --- a/crates/vm/src/builtins/classmethod.rs +++ b/crates/vm/src/builtins/classmethod.rs @@ -3,7 +3,7 @@ use crate::{ AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, common::lock::PyMutex, - function::FuncArgs, + function::{FuncArgs, PySetterValue}, types::{Constructor, GetDescriptor, Initializer, Representable}, }; @@ -158,6 +158,27 @@ impl PyClassMethod { self.callable.lock().get_attr("__annotations__", vm) } + #[pygetset(setter)] + fn set___annotations__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> { + match value { + PySetterValue::Assign(v) => self.callable.lock().set_attr("__annotations__", v, vm), + PySetterValue::Delete => Ok(()), // Silently ignore delete like CPython + } + } + + #[pygetset] + fn __annotate__(&self, vm: &VirtualMachine) -> PyResult { + self.callable.lock().get_attr("__annotate__", vm) + } + + #[pygetset(setter)] + fn set___annotate__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> { + match value { + PySetterValue::Assign(v) => self.callable.lock().set_attr("__annotate__", v, vm), + PySetterValue::Delete => Ok(()), // Silently ignore delete like CPython + } + } + #[pygetset] fn __isabstractmethod__(&self, vm: &VirtualMachine) -> PyObjectRef { match vm.get_attribute_opt(self.callable.lock().clone(), "__isabstractmethod__") { diff --git a/crates/vm/src/builtins/staticmethod.rs b/crates/vm/src/builtins/staticmethod.rs index 5d2474a56..ac363415a 100644 --- a/crates/vm/src/builtins/staticmethod.rs +++ b/crates/vm/src/builtins/staticmethod.rs @@ -3,7 +3,7 @@ use crate::{ Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, class::PyClassImpl, common::lock::PyMutex, - function::FuncArgs, + function::{FuncArgs, PySetterValue}, types::{Callable, Constructor, GetDescriptor, Initializer, Representable}, }; @@ -121,6 +121,27 @@ impl PyStaticMethod { self.callable.lock().get_attr("__annotations__", vm) } + #[pygetset(setter)] + fn set___annotations__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> { + match value { + PySetterValue::Assign(v) => self.callable.lock().set_attr("__annotations__", v, vm), + PySetterValue::Delete => Ok(()), // Silently ignore delete like CPython + } + } + + #[pygetset] + fn __annotate__(&self, vm: &VirtualMachine) -> PyResult { + self.callable.lock().get_attr("__annotate__", vm) + } + + #[pygetset(setter)] + fn set___annotate__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> { + match value { + PySetterValue::Assign(v) => self.callable.lock().set_attr("__annotate__", v, vm), + PySetterValue::Delete => Ok(()), // Silently ignore delete like CPython + } + } + #[pygetset] fn __isabstractmethod__(&self, vm: &VirtualMachine) -> PyObjectRef { match vm.get_attribute_opt(self.callable.lock().clone(), "__isabstractmethod__") { diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 510c3fb84..3fe396932 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -876,11 +876,13 @@ impl PyType { } #[pygetset(setter)] - fn set___annotate__(&self, value: Option, vm: &VirtualMachine) -> PyResult<()> { - if value.is_none() { - return Err(vm.new_type_error("cannot delete __annotate__ attribute".to_owned())); - } - let value = value.unwrap(); + fn set___annotate__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> { + let value = match value { + PySetterValue::Delete => { + return Err(vm.new_type_error("cannot delete __annotate__ attribute".to_owned())); + } + PySetterValue::Assign(v) => v, + }; if self.slots.flags.has_feature(PyTypeFlags::IMMUTABLETYPE) { return Err(vm.new_type_error(format!( diff --git a/crates/vm/src/stdlib/functools.rs b/crates/vm/src/stdlib/functools.rs index a59fd48f6..307f7760f 100644 --- a/crates/vm/src/stdlib/functools.rs +++ b/crates/vm/src/stdlib/functools.rs @@ -4,28 +4,42 @@ pub(crate) use _functools::make_module; mod _functools { use crate::{ Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, - builtins::{PyDict, PyGenericAlias, PyTuple, PyTypeRef}, + builtins::{PyBoundMethod, PyDict, PyGenericAlias, PyTuple, PyType, PyTypeRef}, common::lock::PyRwLock, - function::{FuncArgs, KwArgs, OptionalArg}, + function::{FuncArgs, KwArgs, OptionalOption}, object::AsObject, protocol::PyIter, pyclass, recursion::ReprGuard, - types::{Callable, Constructor, Representable}, + types::{Callable, Constructor, GetDescriptor, Representable}, }; use indexmap::IndexMap; - #[pyfunction] - fn reduce( + #[derive(FromArgs)] + struct ReduceArgs { function: PyObjectRef, iterator: PyIter, - start_value: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { + #[pyarg(any, optional, name = "initial")] + initial: OptionalOption, + } + + #[pyfunction] + fn reduce(args: ReduceArgs, vm: &VirtualMachine) -> PyResult { + let ReduceArgs { + function, + iterator, + initial, + } = args; let mut iter = iterator.iter_without_hint(vm)?; - let start_value = if let OptionalArg::Present(val) = start_value { - val + // OptionalOption distinguishes between: + // - Missing: no argument provided → use first element from iterator + // - Present(None): explicitly passed None → use None as initial value + // - Present(Some(v)): passed a value → use that value + let start_value = if let Some(val) = initial.into_option() { + // initial was provided (could be None or Some value) + val.unwrap_or_else(|| vm.ctx.none()) } else { + // initial was not provided at all iter.next().transpose()?.ok_or_else(|| { let exc_type = vm.ctx.exceptions.type_error.to_owned(); vm.new_exception_msg( @@ -42,6 +56,72 @@ mod _functools { Ok(accumulator) } + // Placeholder singleton for partial arguments + // The singleton is stored as _instance on the type class + #[pyattr] + #[allow(non_snake_case)] + fn Placeholder(vm: &VirtualMachine) -> PyObjectRef { + let placeholder = PyPlaceholderType.into_pyobject(vm); + // Store the singleton on the type class for slot_new to find + let typ = placeholder.class(); + typ.set_attr(vm.ctx.intern_str("_instance"), placeholder.clone()); + placeholder + } + + #[pyattr] + #[pyclass(name = "_PlaceholderType", module = "functools")] + #[derive(Debug, PyPayload)] + pub struct PyPlaceholderType; + + impl Constructor for PyPlaceholderType { + type Args = FuncArgs; + + fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + if !args.args.is_empty() || !args.kwargs.is_empty() { + return Err(vm.new_type_error("_PlaceholderType takes no arguments".to_owned())); + } + // Return the singleton stored on the type class + if let Some(instance) = cls.get_attr(vm.ctx.intern_str("_instance")) { + return Ok(instance); + } + // Fallback: create a new instance (shouldn't happen for base type after module init) + Ok(PyPlaceholderType.into_pyobject(vm)) + } + + fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { + // This is never called because we override slot_new + Ok(PyPlaceholderType) + } + } + + #[pyclass(with(Constructor, Representable))] + impl PyPlaceholderType { + #[pymethod] + fn __reduce__(&self) -> &'static str { + "Placeholder" + } + + #[pymethod] + fn __init_subclass__(_cls: PyTypeRef, vm: &VirtualMachine) -> PyResult<()> { + Err(vm.new_type_error("cannot subclass '_PlaceholderType'".to_owned())) + } + } + + impl Representable for PyPlaceholderType { + #[inline] + fn repr_str(_zelf: &Py, _vm: &VirtualMachine) -> PyResult { + Ok("Placeholder".to_owned()) + } + } + + fn is_placeholder(obj: &PyObjectRef) -> bool { + &*obj.class().name() == "_PlaceholderType" + } + + fn count_placeholders(args: &[PyObjectRef]) -> usize { + args.iter().filter(|a| is_placeholder(a)).count() + } + #[pyattr] #[pyclass(name = "partial", module = "functools")] #[derive(Debug, PyPayload)] @@ -54,9 +134,13 @@ mod _functools { func: PyObjectRef, args: PyRef, keywords: PyRef, + phcount: usize, } - #[pyclass(with(Constructor, Callable, Representable), flags(BASETYPE, HAS_DICT))] + #[pyclass( + with(Constructor, Callable, GetDescriptor, Representable), + flags(BASETYPE, HAS_DICT) + )] impl PyPartial { #[pygetset] fn func(&self) -> PyObjectRef { @@ -157,6 +241,13 @@ mod _functools { } }; + // Validate no trailing placeholders + let args_slice = args_tuple.as_slice(); + if !args_slice.is_empty() && is_placeholder(args_slice.last().unwrap()) { + return Err(vm.new_type_error("trailing Placeholders are not allowed".to_owned())); + } + let phcount = count_placeholders(args_slice); + // Actually update the state let mut inner = zelf.inner.write(); inner.func = func.clone(); @@ -165,6 +256,7 @@ mod _functools { // Handle keywords - keep the original type inner.keywords = keywords_dict; + inner.phcount = phcount; // Update __dict__ if provided let Some(instance_dict) = zelf.as_object().dict() else { @@ -218,17 +310,54 @@ mod _functools { return Err(vm.new_type_error("the first argument must be callable")); } + // Check for placeholders in kwargs + for (key, value) in &args.kwargs { + if is_placeholder(value) { + return Err(vm.new_type_error(format!( + "Placeholder cannot be passed as a keyword argument to partial(). \ + Did you mean partial(..., {}=Placeholder, ...)(value)?", + key + ))); + } + } + // Handle nested partial objects let (final_func, final_args, final_keywords) = if let Some(partial) = func.downcast_ref::() { let inner = partial.inner.read(); - let mut combined_args = inner.args.as_slice().to_vec(); - combined_args.extend_from_slice(args_slice); - (inner.func.clone(), combined_args, inner.keywords.clone()) + let stored_args = inner.args.as_slice(); + + // Merge placeholders: replace placeholders in stored_args with new args + let mut merged_args = Vec::with_capacity(stored_args.len() + args_slice.len()); + let mut new_args_iter = args_slice.iter(); + + for stored_arg in stored_args { + if is_placeholder(stored_arg) { + // Replace placeholder with next new arg, or keep placeholder + if let Some(new_arg) = new_args_iter.next() { + merged_args.push(new_arg.clone()); + } else { + merged_args.push(stored_arg.clone()); + } + } else { + merged_args.push(stored_arg.clone()); + } + } + // Append remaining new args + merged_args.extend(new_args_iter.cloned()); + + (inner.func.clone(), merged_args, inner.keywords.clone()) } else { (func.clone(), args_slice.to_vec(), vm.ctx.new_dict()) }; + // Trailing placeholders are not allowed + if !final_args.is_empty() && is_placeholder(final_args.last().unwrap()) { + return Err(vm.new_type_error("trailing Placeholders are not allowed".to_owned())); + } + + let phcount = count_placeholders(&final_args); + // Add new keywords for (key, value) in args.kwargs { final_keywords.set_item(vm.ctx.intern_str(key.as_str()), value, vm)?; @@ -239,6 +368,7 @@ mod _functools { func: final_func, args: vm.ctx.new_tuple(final_args), keywords: final_keywords, + phcount, }), }) } @@ -249,17 +379,44 @@ mod _functools { fn call(zelf: &Py, args: FuncArgs, vm: &VirtualMachine) -> PyResult { // Clone and release lock before calling Python code to prevent deadlock - let (func, stored_args, keywords) = { + let (func, stored_args, keywords, phcount) = { let inner = zelf.inner.read(); ( inner.func.clone(), inner.args.clone(), inner.keywords.clone(), + inner.phcount, ) }; - let mut combined_args = stored_args.as_slice().to_vec(); - combined_args.extend_from_slice(&args.args); + // Check if we have enough args to fill placeholders + if phcount > 0 && args.args.len() < phcount { + return Err(vm.new_type_error(format!( + "missing positional arguments in 'partial' call; expected at least {}, got {}", + phcount, + args.args.len() + ))); + } + + // Build combined args, replacing placeholders + let mut combined_args = Vec::with_capacity(stored_args.len() + args.args.len()); + let mut new_args_iter = args.args.iter(); + + for stored_arg in stored_args.as_slice() { + if is_placeholder(stored_arg) { + // Replace placeholder with next new arg + if let Some(new_arg) = new_args_iter.next() { + combined_args.push(new_arg.clone()); + } else { + // This shouldn't happen if phcount check passed + combined_args.push(stored_arg.clone()); + } + } else { + combined_args.push(stored_arg.clone()); + } + } + // Append remaining new args + combined_args.extend(new_args_iter.cloned()); // Merge keywords from self.keywords and args.kwargs let mut final_kwargs = IndexMap::new(); @@ -281,6 +438,21 @@ mod _functools { } } + impl GetDescriptor for PyPartial { + fn descr_get( + zelf: PyObjectRef, + obj: Option, + _cls: Option, + vm: &VirtualMachine, + ) -> PyResult { + let obj = match obj { + Some(obj) if !vm.is_none(&obj) => obj, + _ => return Ok(zelf), + }; + Ok(PyBoundMethod::new(obj, zelf).into_ref(&vm.ctx).into()) + } + } + impl Representable for PyPartial { #[inline] fn repr_str(zelf: &Py, vm: &VirtualMachine) -> PyResult {