From ec247e0d4f458452308e53aef3b830c355b4b933 Mon Sep 17 00:00:00 2001 From: Chris Moradi <37349208+chrismoradi@users.noreply.github.com> Date: Mon, 25 Oct 2021 21:19:44 -0700 Subject: [PATCH 1/2] Update collections from CPython, fix tests for UserDict/List/String Clean implementation of changes in PR #3371 based on feedback. Copies from [CPython tag `v3.9.7` and adds back custom RustPython changes where needed for: - `Lib/collections/__init__.py` - `Lib/test/test_collections.py` Closes: #3371 --- Lib/collections/__init__.py | 557 ++++++++++++++++++++++++++--------- Lib/test/test_collections.py | 209 +++++++++++-- Lib/test/test_userstring.py | 4 - 3 files changed, 602 insertions(+), 168 deletions(-) diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 6260b9601..25d1db7f9 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -14,22 +14,30 @@ list, set, and tuple. ''' -__all__ = ['deque', 'defaultdict', 'namedtuple', 'UserDict', 'UserList', - 'UserString', 'Counter', 'OrderedDict', 'ChainMap'] +__all__ = [ + 'ChainMap', + 'Counter', + 'OrderedDict', + 'UserDict', + 'UserList', + 'UserString', + 'defaultdict', + 'deque', + 'namedtuple', +] -# For backwards compatibility, continue to make the collections ABCs -# available through the collections module. -from _collections_abc import * import _collections_abc -__all__ += _collections_abc.__all__ - -from operator import itemgetter as _itemgetter, eq as _eq -from keyword import iskeyword as _iskeyword -import sys as _sys import heapq as _heapq -from _weakref import proxy as _proxy -from itertools import repeat as _repeat, chain as _chain, starmap as _starmap +import sys as _sys + +from itertools import chain as _chain +from itertools import repeat as _repeat +from itertools import starmap as _starmap +from keyword import iskeyword as _iskeyword +from operator import eq as _eq +from operator import itemgetter as _itemgetter from reprlib import recursive_repr as _recursive_repr +from _weakref import proxy as _proxy try: from _collections import deque @@ -41,7 +49,11 @@ else: try: from _collections import defaultdict except ImportError: - pass + # FIXME: try to implement defaultdict in collections.rs rather than in Python + # I (coolreader18) couldn't figure out some class stuff with __new__ and + # __init__ and __missing__ and subclassing built-in types from Rust, so I went + # with this instead. + from ._defaultdict import defaultdict def __getattr__(name): @@ -52,13 +64,14 @@ def __getattr__(name): obj = getattr(_collections_abc, name) import warnings warnings.warn("Using or importing the ABCs from 'collections' instead " - "of from 'collections.abc' is deprecated, " - "and in 3.8 it will stop working", + "of from 'collections.abc' is deprecated since Python 3.3, " + "and in 3.10 it will stop working", DeprecationWarning, stacklevel=2) globals()[name] = obj return obj raise AttributeError(f'module {__name__!r} has no attribute {name!r}') + ################################################################################ ### OrderedDict ################################################################################ @@ -108,6 +121,7 @@ class OrderedDict(dict): self, *args = args if len(args) > 1: raise TypeError('expected at most 1 arguments, got %d' % len(args)) + try: self.__root except AttributeError: @@ -304,6 +318,24 @@ class OrderedDict(dict): return dict.__eq__(self, other) and all(map(_eq, self, other)) return dict.__eq__(self, other) + def __ior__(self, other): + self.update(other) + return self + + def __or__(self, other): + if not isinstance(other, dict): + return NotImplemented + new = self.__class__(self) + new.update(other) + return new + + def __ror__(self, other): + if not isinstance(other, dict): + return NotImplemented + new = self.__class__(other) + new.update(self) + return new + try: from _collections import OrderedDict @@ -316,7 +348,10 @@ except ImportError: ### namedtuple ################################################################################ -_nt_itemgetters = {} +try: + from _collections import _tuplegetter +except ImportError: + _tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc) def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None): """Returns a new subclass of tuple with named fields. @@ -389,18 +424,23 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non # Variables used in the methods and docstrings field_names = tuple(map(_sys.intern, field_names)) num_fields = len(field_names) - arg_list = repr(field_names).replace("'", "")[1:-1] + arg_list = ', '.join(field_names) + if num_fields == 1: + arg_list += ',' repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')' tuple_new = tuple.__new__ - _len = len + _dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip # Create all the named tuple methods to be added to the class namespace - s = f'def __new__(_cls, {arg_list}): return _tuple_new(_cls, ({arg_list}))' - namespace = {'_tuple_new': tuple_new, '__name__': f'namedtuple_{typename}'} - # Note: exec() has the side-effect of interning the field names - exec(s, namespace) - __new__ = namespace['__new__'] + namespace = { + '_tuple_new': tuple_new, + '__builtins__': {}, + '__name__': f'namedtuple_{typename}', + } + code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))' + __new__ = eval(code, namespace) + __new__.__name__ = '__new__' __new__.__doc__ = f'Create new instance of {typename}({arg_list})' if defaults is not None: __new__.__defaults__ = defaults @@ -415,8 +455,8 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non _make.__func__.__doc__ = (f'Make a new {typename} object from a sequence ' 'or iterable') - def _replace(_self, **kwds): - result = _self._make(map(kwds.pop, field_names, _self)) + def _replace(_self, /, **kwds): + result = _self._make(_map(kwds.pop, field_names, _self)) if kwds: raise ValueError(f'Got unexpected field names: {list(kwds)!r}') return result @@ -429,17 +469,22 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non return self.__class__.__name__ + repr_fmt % self def _asdict(self): - 'Return a new OrderedDict which maps field names to their values.' - return OrderedDict(zip(self._fields, self)) + 'Return a new dict which maps field names to their values.' + return _dict(_zip(self._fields, self)) def __getnewargs__(self): 'Return self as a plain tuple. Used by copy and pickle.' - return tuple(self) + return _tuple(self) # Modify function metadata to help with introspection and debugging - - for method in (__new__, _make.__func__, _replace, - __repr__, _asdict, __getnewargs__): + for method in ( + __new__, + _make.__func__, + _replace, + __repr__, + _asdict, + __getnewargs__, + ): method.__qualname__ = f'{typename}.{method.__name__}' # Build-up the class namespace dictionary @@ -448,7 +493,7 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non '__doc__': f'{typename}({arg_list})', '__slots__': (), '_fields': field_names, - '_fields_defaults': field_defaults, + '_field_defaults': field_defaults, '__new__': __new__, '_make': _make, '_replace': _replace, @@ -456,15 +501,9 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non '_asdict': _asdict, '__getnewargs__': __getnewargs__, } - cache = _nt_itemgetters for index, name in enumerate(field_names): - try: - itemgetter_object, doc = cache[index] - except KeyError: - itemgetter_object = _itemgetter(index) - doc = f'Alias for field number {index}' - cache[index] = itemgetter_object, doc - class_namespace[name] = property(itemgetter_object, doc=doc) + doc = _sys.intern(f'Alias for field number {index}') + class_namespace[name] = _tuplegetter(index, doc) result = type(typename, (tuple,), class_namespace) @@ -579,8 +618,8 @@ class Counter(dict): '''List the n most common elements and their counts from the most common to the least. If n is None, then list all element counts. - >>> Counter('abcdeabcdabcaba').most_common(3) - [('a', 5), ('b', 4), ('c', 3)] + >>> Counter('abracadabra').most_common(3) + [('a', 5), ('b', 2), ('r', 2)] ''' # Emulate Bag.sortedByCount from Smalltalk @@ -614,8 +653,13 @@ class Counter(dict): @classmethod def fromkeys(cls, iterable, v=None): - # There is no equivalent method for counters because setting v=1 - # means that no element can have a count greater than one. + # There is no equivalent method for counters because the semantics + # would be ambiguous in cases such as Counter.fromkeys('aaabbc', v=2). + # Initializing counters to zero values isn't necessary because zero + # is already the default value for counter lookups. Initializing + # to one is easily accomplished with Counter(set(iterable)). For + # more exotic cases, create a dictionary first using a dictionary + # comprehension or dict.fromkeys(). raise NotImplementedError( 'Counter.fromkeys() is undefined. Use Counter(iterable) instead.') @@ -646,6 +690,7 @@ class Counter(dict): if len(args) > 1: raise TypeError('expected at most 1 arguments, got %d' % len(args)) iterable = args[0] if args else None + if iterable is not None: if isinstance(iterable, _collections_abc.Mapping): if self: @@ -653,7 +698,8 @@ class Counter(dict): for elem, count in iterable.items(): self[elem] = count + self_get(elem, 0) else: - super(Counter, self).update(iterable) # fast path when counter is empty + # fast path when counter is empty + super(Counter, self).update(iterable) else: _count_elements(self, iterable) if kwds: @@ -682,6 +728,7 @@ class Counter(dict): if len(args) > 1: raise TypeError('expected at most 1 arguments, got %d' % len(args)) iterable = args[0] if args else None + if iterable is not None: self_get = self.get if isinstance(iterable, _collections_abc.Mapping): @@ -707,13 +754,14 @@ class Counter(dict): def __repr__(self): if not self: - return '%s()' % self.__class__.__name__ + return f'{self.__class__.__name__}()' try: - items = ', '.join(map('%r: %r'.__mod__, self.most_common())) - return '%s({%s})' % (self.__class__.__name__, items) + # dict() preserves the ordering returned by most_common() + d = dict(self.most_common()) except TypeError: # handle case where values are not orderable - return '{0}({1!r})'.format(self.__class__.__name__, dict(self)) + d = dict(self) + return f'{self.__class__.__name__}({d!r})' # Multiset-style mathematical operations discussed in: # Knuth TAOCP Volume II section 4.6.3 exercise 19 @@ -723,6 +771,13 @@ class Counter(dict): # # To strip negative and zero counts, add-in an empty counter: # c += Counter() + # + # Rich comparison operators for multiset subset and superset tests + # are deliberately omitted due to semantic conflicts with the + # existing inherited dict equality method. Subset and superset + # semantics ignore zero counts and require that p≤q ∧ p≥q → p=q; + # however, that would not be the case for p=Counter(a=1, b=0) + # and q=Counter(a=1) where the dictionaries are not equal. def __add__(self, other): '''Add counts from two counters. @@ -927,7 +982,7 @@ class ChainMap(_collections_abc.MutableMapping): def __iter__(self): d = {} for mapping in reversed(self.maps): - d.update(mapping) # reuses stored hash values if possible + d.update(dict.fromkeys(mapping)) # reuses stored hash values if possible return iter(d) def __contains__(self, key): @@ -938,8 +993,7 @@ class ChainMap(_collections_abc.MutableMapping): @_recursive_repr() def __repr__(self): - return '{0.__class__.__name__}({1})'.format( - self, ', '.join(map(repr, self.maps))) + return f'{self.__class__.__name__}({", ".join(map(repr, self.maps))})' @classmethod def fromkeys(cls, iterable, *args): @@ -972,7 +1026,7 @@ class ChainMap(_collections_abc.MutableMapping): try: del self.maps[0][key] except KeyError: - raise KeyError('Key not found in the first mapping: {!r}'.format(key)) + raise KeyError(f'Key not found in the first mapping: {key!r}') def popitem(self): 'Remove and return an item pair from maps[0]. Raise KeyError is maps[0] is empty.' @@ -986,12 +1040,31 @@ class ChainMap(_collections_abc.MutableMapping): try: return self.maps[0].pop(key, *args) except KeyError: - raise KeyError('Key not found in the first mapping: {!r}'.format(key)) + raise KeyError(f'Key not found in the first mapping: {key!r}') def clear(self): 'Clear maps[0], leaving maps[1:] intact.' self.maps[0].clear() + def __ior__(self, other): + self.maps[0].update(other) + return self + + def __or__(self, other): + if not isinstance(other, _collections_abc.Mapping): + return NotImplemented + m = self.copy() + m.maps[0].update(other) + return m + + def __ror__(self, other): + if not isinstance(other, _collections_abc.Mapping): + return NotImplemented + m = dict(other) + for child in reversed(self.maps): + m.update(child) + return self.__class__(m) + ################################################################################ ### UserDict @@ -1000,36 +1073,32 @@ class ChainMap(_collections_abc.MutableMapping): class UserDict(_collections_abc.MutableMapping): # Start by filling-out the abstract methods - def __init__(*args, **kwargs): - if not args: - raise TypeError("descriptor '__init__' of 'UserDict' object " - "needs an argument") - self, *args = args - if len(args) > 1: - raise TypeError('expected at most 1 arguments, got %d' % len(args)) - if args: - dict = args[0] - elif 'dict' in kwargs: - dict = kwargs.pop('dict') - import warnings - warnings.warn("Passing 'dict' as keyword argument is deprecated", - DeprecationWarning, stacklevel=2) - else: - dict = None + def __init__(self, dict=None, /, **kwargs): self.data = {} if dict is not None: self.update(dict) - if len(kwargs): + if kwargs: self.update(kwargs) - def __len__(self): return len(self.data) + + def __len__(self): + return len(self.data) + + def __bool__(self): + return bool(self.data) + def __getitem__(self, key): if key in self.data: return self.data[key] if hasattr(self.__class__, "__missing__"): return self.__class__.__missing__(self, key) raise KeyError(key) - def __setitem__(self, key, item): self.data[key] = item - def __delitem__(self, key): del self.data[key] + + def __setitem__(self, key, item): + self.data[key] = item + + def __delitem__(self, key): + del self.data[key] + def __iter__(self): return iter(self.data) @@ -1038,7 +1107,40 @@ class UserDict(_collections_abc.MutableMapping): return key in self.data # Now, add the methods in dicts but not in MutableMapping - def __repr__(self): return repr(self.data) + def __repr__(self): + return repr(self.data) + + def __or__(self, other): + if isinstance(other, UserDict): + return self.__class__(self.data | other.data) + if isinstance(other, dict): + return self.__class__(self.data | other) + return NotImplemented + + def __ror__(self, other): + if isinstance(other, UserDict): + return self.__class__(other.data | self.data) + if isinstance(other, dict): + return self.__class__(other | self.data) + return NotImplemented + + def __ior__(self, other): + if isinstance(other, UserDict): + self.data |= other.data + else: + self.data |= other + return self + + def __copy__(self): + inst = self.__class__.__new__(self.__class__) + inst.__dict__.update(self.__dict__) + # Create a copy and avoid triggering descriptors + inst.__dict__["data"] = self.__dict__["data"].copy() + return inst + + def __sizeof__(self): + return _sys.getsizeof(self.data) + def copy(self): if self.__class__ is UserDict: return UserDict(self.data.copy()) @@ -1051,6 +1153,7 @@ class UserDict(_collections_abc.MutableMapping): self.data = data c.update(self) return c + @classmethod def fromkeys(cls, iterable, value=None): d = cls() @@ -1059,13 +1162,13 @@ class UserDict(_collections_abc.MutableMapping): return d - ################################################################################ ### UserList ################################################################################ class UserList(_collections_abc.MutableSequence): """A more or less complete user-defined wrapper around list objects.""" + def __init__(self, initlist=None): self.data = [] if initlist is not None: @@ -1076,31 +1179,63 @@ class UserList(_collections_abc.MutableSequence): self.data[:] = initlist.data[:] else: self.data = list(initlist) - def __repr__(self): return repr(self.data) - def __lt__(self, other): return self.data < self.__cast(other) - def __le__(self, other): return self.data <= self.__cast(other) - def __eq__(self, other): return self.data == self.__cast(other) - def __gt__(self, other): return self.data > self.__cast(other) - def __ge__(self, other): return self.data >= self.__cast(other) + + def __repr__(self): + return repr(self.data) + + def __lt__(self, other): + return self.data < self.__cast(other) + + def __le__(self, other): + return self.data <= self.__cast(other) + + def __eq__(self, other): + return self.data == self.__cast(other) + + def __gt__(self, other): + return self.data > self.__cast(other) + + def __ge__(self, other): + return self.data >= self.__cast(other) + def __cast(self, other): return other.data if isinstance(other, UserList) else other - def __contains__(self, item): return item in self.data - def __len__(self): return len(self.data) - def __getitem__(self, i): return self.data[i] - def __setitem__(self, i, item): self.data[i] = item - def __delitem__(self, i): del self.data[i] + + def __contains__(self, item): + return item in self.data + + def __len__(self): + return len(self.data) + + def __bool__(self): + return bool(self.data) + + def __getitem__(self, i): + if isinstance(i, slice): + return self.__class__(self.data[i]) + else: + return self.data[i] + + def __setitem__(self, i, item): + self.data[i] = item + + def __delitem__(self, i): + del self.data[i] + def __add__(self, other): if isinstance(other, UserList): return self.__class__(self.data + other.data) elif isinstance(other, type(self.data)): return self.__class__(self.data + other) return self.__class__(self.data + list(other)) + def __radd__(self, other): if isinstance(other, UserList): return self.__class__(other.data + self.data) elif isinstance(other, type(self.data)): return self.__class__(other + self.data) return self.__class__(list(other) + self.data) + def __iadd__(self, other): if isinstance(other, UserList): self.data += other.data @@ -1109,22 +1244,56 @@ class UserList(_collections_abc.MutableSequence): else: self.data += list(other) return self + def __mul__(self, n): - return self.__class__(self.data*n) + return self.__class__(self.data * n) + __rmul__ = __mul__ + def __imul__(self, n): self.data *= n return self - def append(self, item): self.data.append(item) - def insert(self, i, item): self.data.insert(i, item) - def pop(self, i=-1): return self.data.pop(i) - def remove(self, item): self.data.remove(item) - def clear(self): self.data.clear() - def copy(self): return self.__class__(self) - def count(self, item): return self.data.count(item) - def index(self, item, *args): return self.data.index(item, *args) - def reverse(self): self.data.reverse() - def sort(self, *args, **kwds): self.data.sort(*args, **kwds) + + def __copy__(self): + inst = self.__class__.__new__(self.__class__) + inst.__dict__.update(self.__dict__) + # Create a copy and avoid triggering descriptors + inst.__dict__["data"] = self.__dict__["data"][:] + return inst + + def __sizeof__(self): + return _sys.getsizeof(self.data) + + def append(self, item): + self.data.append(item) + + def insert(self, i, item): + self.data.insert(i, item) + + def pop(self, i=-1): + return self.data.pop(i) + + def remove(self, item): + self.data.remove(item) + + def clear(self): + self.data.clear() + + def copy(self): + return self.__class__(self) + + def count(self, item): + return self.data.count(item) + + def index(self, item, *args): + return self.data.index(item, *args) + + def reverse(self): + self.data.reverse() + + def sort(self, /, *args, **kwds): + self.data.sort(*args, **kwds) + def extend(self, other): if isinstance(other, UserList): self.data.extend(other.data) @@ -1132,12 +1301,12 @@ class UserList(_collections_abc.MutableSequence): self.data.extend(other) - ################################################################################ ### UserString ################################################################################ class UserString(_collections_abc.Sequence): + def __init__(self, seq): if isinstance(seq, str): self.data = seq @@ -1145,12 +1314,25 @@ class UserString(_collections_abc.Sequence): self.data = seq.data[:] else: self.data = str(seq) - def __str__(self): return str(self.data) - def __repr__(self): return repr(self.data) - def __int__(self): return int(self.data) - def __float__(self): return float(self.data) - def __complex__(self): return complex(self.data) - def __hash__(self): return hash(self.data) + + def __str__(self): + return str(self.data) + + def __repr__(self): + return repr(self.data) + + def __int__(self): + return int(self.data) + + def __float__(self): + return float(self.data) + + def __complex__(self): + return complex(self.data) + + def __hash__(self): + return hash(self.data) + def __getnewargs__(self): return (self.data[:],) @@ -1158,18 +1340,22 @@ class UserString(_collections_abc.Sequence): if isinstance(string, UserString): return self.data == string.data return self.data == string + def __lt__(self, string): if isinstance(string, UserString): return self.data < string.data return self.data < string + def __le__(self, string): if isinstance(string, UserString): return self.data <= string.data return self.data <= string + def __gt__(self, string): if isinstance(string, UserString): return self.data > string.data return self.data > string + def __ge__(self, string): if isinstance(string, UserString): return self.data >= string.data @@ -1180,111 +1366,194 @@ class UserString(_collections_abc.Sequence): char = char.data return char in self.data - def __len__(self): return len(self.data) - def __getitem__(self, index): return self.__class__(self.data[index]) + def __bool__(self): + return bool(self.data) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.__class__(self.data[index]) + def __add__(self, other): if isinstance(other, UserString): return self.__class__(self.data + other.data) elif isinstance(other, str): return self.__class__(self.data + other) return self.__class__(self.data + str(other)) + def __radd__(self, other): if isinstance(other, str): return self.__class__(other + self.data) return self.__class__(str(other) + self.data) + def __mul__(self, n): - return self.__class__(self.data*n) + return self.__class__(self.data * n) + __rmul__ = __mul__ + def __mod__(self, args): return self.__class__(self.data % args) - def __rmod__(self, format): - return self.__class__(format % args) + + def __rmod__(self, template): + return self.__class__(str(template) % self) + + def __sizeof__(self): + return _sys.getsizeof(self.data) # the following methods are defined in alphabetical order: - def capitalize(self): return self.__class__(self.data.capitalize()) + def capitalize(self): + return self.__class__(self.data.capitalize()) + def casefold(self): return self.__class__(self.data.casefold()) + def center(self, width, *args): return self.__class__(self.data.center(width, *args)) + def count(self, sub, start=0, end=_sys.maxsize): if isinstance(sub, UserString): sub = sub.data return self.data.count(sub, start, end) - def encode(self, encoding=None, errors=None): # XXX improve this? - if encoding: - if errors: - return self.__class__(self.data.encode(encoding, errors)) - return self.__class__(self.data.encode(encoding)) - return self.__class__(self.data.encode()) + + def removeprefix(self, prefix, /): + if isinstance(prefix, UserString): + prefix = prefix.data + return self.__class__(self.data.removeprefix(prefix)) + + def removesuffix(self, suffix, /): + if isinstance(suffix, UserString): + suffix = suffix.data + return self.__class__(self.data.removesuffix(suffix)) + + def encode(self, encoding='utf-8', errors='strict'): + encoding = 'utf-8' if encoding is None else encoding + errors = 'strict' if errors is None else errors + return self.data.encode(encoding, errors) + def endswith(self, suffix, start=0, end=_sys.maxsize): return self.data.endswith(suffix, start, end) + def expandtabs(self, tabsize=8): return self.__class__(self.data.expandtabs(tabsize)) + def find(self, sub, start=0, end=_sys.maxsize): if isinstance(sub, UserString): sub = sub.data return self.data.find(sub, start, end) - def format(self, *args, **kwds): + + def format(self, /, *args, **kwds): return self.data.format(*args, **kwds) + def format_map(self, mapping): return self.data.format_map(mapping) + def index(self, sub, start=0, end=_sys.maxsize): return self.data.index(sub, start, end) - def isalpha(self): return self.data.isalpha() - def isalnum(self): return self.data.isalnum() - def isascii(self): return self.data.isascii() - def isdecimal(self): return self.data.isdecimal() - def isdigit(self): return self.data.isdigit() - def isidentifier(self): return self.data.isidentifier() - def islower(self): return self.data.islower() - def isnumeric(self): return self.data.isnumeric() - def isprintable(self): return self.data.isprintable() - def isspace(self): return self.data.isspace() - def istitle(self): return self.data.istitle() - def isupper(self): return self.data.isupper() - def join(self, seq): return self.data.join(seq) + + def isalpha(self): + return self.data.isalpha() + + def isalnum(self): + return self.data.isalnum() + + def isascii(self): + return self.data.isascii() + + def isdecimal(self): + return self.data.isdecimal() + + def isdigit(self): + return self.data.isdigit() + + def isidentifier(self): + return self.data.isidentifier() + + def islower(self): + return self.data.islower() + + def isnumeric(self): + return self.data.isnumeric() + + def isprintable(self): + return self.data.isprintable() + + def isspace(self): + return self.data.isspace() + + def istitle(self): + return self.data.istitle() + + def isupper(self): + return self.data.isupper() + + def join(self, seq): + return self.data.join(seq) + def ljust(self, width, *args): return self.__class__(self.data.ljust(width, *args)) - def lower(self): return self.__class__(self.data.lower()) - def lstrip(self, chars=None): return self.__class__(self.data.lstrip(chars)) + + def lower(self): + return self.__class__(self.data.lower()) + + def lstrip(self, chars=None): + return self.__class__(self.data.lstrip(chars)) + maketrans = str.maketrans + def partition(self, sep): return self.data.partition(sep) + def replace(self, old, new, maxsplit=-1): if isinstance(old, UserString): old = old.data if isinstance(new, UserString): new = new.data return self.__class__(self.data.replace(old, new, maxsplit)) + def rfind(self, sub, start=0, end=_sys.maxsize): if isinstance(sub, UserString): sub = sub.data return self.data.rfind(sub, start, end) + def rindex(self, sub, start=0, end=_sys.maxsize): return self.data.rindex(sub, start, end) + def rjust(self, width, *args): return self.__class__(self.data.rjust(width, *args)) + def rpartition(self, sep): return self.data.rpartition(sep) + def rstrip(self, chars=None): return self.__class__(self.data.rstrip(chars)) + def split(self, sep=None, maxsplit=-1): return self.data.split(sep, maxsplit) + def rsplit(self, sep=None, maxsplit=-1): return self.data.rsplit(sep, maxsplit) - def splitlines(self, keepends=False): return self.data.splitlines(keepends) + + def splitlines(self, keepends=False): + return self.data.splitlines(keepends) + def startswith(self, prefix, start=0, end=_sys.maxsize): return self.data.startswith(prefix, start, end) - def strip(self, chars=None): return self.__class__(self.data.strip(chars)) - def swapcase(self): return self.__class__(self.data.swapcase()) - def title(self): return self.__class__(self.data.title()) + + def strip(self, chars=None): + return self.__class__(self.data.strip(chars)) + + def swapcase(self): + return self.__class__(self.data.swapcase()) + + def title(self): + return self.__class__(self.data.title()) + def translate(self, *args): return self.__class__(self.data.translate(*args)) - def upper(self): return self.__class__(self.data.upper()) - def zfill(self, width): return self.__class__(self.data.zfill(width)) -# FIXME: try to implement defaultdict in collections.rs rather than in Python -# I (coolreader18) couldn't figure out some class stuff with __new__ and -# __init__ and __missing__ and subclassing built-in types from Rust, so I went -# with this instead. -from ._defaultdict import defaultdict + def upper(self): + return self.__class__(self.data.upper()) + + def zfill(self, width): + return self.__class__(self.data.zfill(width)) \ No newline at end of file diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index bbabeba61..c2550c471 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -51,30 +51,20 @@ class TestUserObjects(unittest.TestCase): self.assertEqual(obj.data, obj_copy.data) self.assertIs(obj.test, obj_copy.test) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_str_protocol(self): self._superset_test(UserString, str) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_list_protocol(self): self._superset_test(UserList, list) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_dict_protocol(self): self._superset_test(UserDict, dict) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_list_copy(self): obj = UserList() obj.append(123) self._copy_test(obj) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_dict_copy(self): obj = UserDict() obj[123] = "abc" @@ -205,6 +195,22 @@ class TestChainMap(unittest.TestCase): ('e', 55), ('f', 666), ('g', 777), ('h', 88888), ('i', 9999), ('j', 0)]) + def test_iter_not_calling_getitem_on_maps(self): + class DictWithGetItem(UserDict): + def __init__(self, *args, **kwds): + self.called = False + UserDict.__init__(self, *args, **kwds) + def __getitem__(self, item): + self.called = True + UserDict.__getitem__(self, item) + + d = DictWithGetItem(a=1) + c = ChainMap(d) + d.called = False + + set(c) # iterate over chain map + self.assertFalse(d.called, '__getitem__ was called') + def test_dict_coercion(self): d = ChainMap(dict(a=1, b=2), dict(b=20, c=30)) self.assertEqual(dict(d), dict(a=1, b=2, c=30)) @@ -242,6 +248,54 @@ class TestChainMap(unittest.TestCase): for k, v in dict(a=1, B=20, C=30, z=100).items(): # check get self.assertEqual(d.get(k, 100), v) + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_union_operators(self): + cm1 = ChainMap(dict(a=1, b=2), dict(c=3, d=4)) + cm2 = ChainMap(dict(a=10, e=5), dict(b=20, d=4)) + cm3 = cm1.copy() + d = dict(a=10, c=30) + pairs = [('c', 3), ('p',0)] + + tmp = cm1 | cm2 # testing between chainmaps + self.assertEqual(tmp.maps, [cm1.maps[0] | dict(cm2), *cm1.maps[1:]]) + cm1 |= cm2 + self.assertEqual(tmp, cm1) + + tmp = cm2 | d # testing between chainmap and mapping + self.assertEqual(tmp.maps, [cm2.maps[0] | d, *cm2.maps[1:]]) + self.assertEqual((d | cm2).maps, [d | dict(cm2)]) + cm2 |= d + self.assertEqual(tmp, cm2) + + # testing behavior between chainmap and iterable key-value pairs + with self.assertRaises(TypeError): + cm3 | pairs + tmp = cm3.copy() + cm3 |= pairs + self.assertEqual(cm3.maps, [tmp.maps[0] | dict(pairs), *tmp.maps[1:]]) + + # testing proper return types for ChainMap and it's subclasses + class Subclass(ChainMap): + pass + + class SubclassRor(ChainMap): + def __ror__(self, other): + return super().__ror__(other) + + tmp = ChainMap() | ChainMap() + self.assertIs(type(tmp), ChainMap) + self.assertIs(type(tmp.maps[0]), dict) + tmp = ChainMap() | Subclass() + self.assertIs(type(tmp), ChainMap) + self.assertIs(type(tmp.maps[0]), dict) + tmp = Subclass() | ChainMap() + self.assertIs(type(tmp), Subclass) + self.assertIs(type(tmp.maps[0]), dict) + tmp = ChainMap() | SubclassRor() + self.assertIs(type(tmp), SubclassRor) + self.assertIs(type(tmp.maps[0]), dict) + ################################################################################ ### Named Tuples @@ -280,8 +334,6 @@ class TestNamedTuple(unittest.TestCase): self.assertRaises(TypeError, Point._make, [11]) # catch too few args self.assertRaises(TypeError, Point._make, [11, 22, 33]) # catch too many args - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_defaults(self): Point = namedtuple('Point', 'x y', defaults=(10, 20)) # 2 defaults self.assertEqual(Point._field_defaults, {'x': 10, 'y': 20}) @@ -377,6 +429,18 @@ class TestNamedTuple(unittest.TestCase): self.assertIs(P.m.__doc__, Q.o.__doc__) self.assertIs(P.n.__doc__, Q.p.__doc__) + @support.cpython_only + def test_field_repr(self): + Point = namedtuple('Point', 'x y') + self.assertEqual(repr(Point.x), "_tuplegetter(0, 'Alias for field number 0')") + self.assertEqual(repr(Point.y), "_tuplegetter(1, 'Alias for field number 1')") + + Point.x.__doc__ = 'The x-coordinate' + Point.y.__doc__ = 'The y-coordinate' + + self.assertEqual(repr(Point.x), "_tuplegetter(0, 'The x-coordinate')") + self.assertEqual(repr(Point.y), "_tuplegetter(1, 'The y-coordinate')") + def test_name_fixer(self): for spec, renamed in [ [('efg', 'g%hi'), ('efg', '_1')], # field with non-alpha char @@ -436,8 +500,8 @@ class TestNamedTuple(unittest.TestCase): self.assertIsInstance(p, tuple) self.assertEqual(p, (11, 22)) # matches a real tuple - self.assertEqual(tuple(p), (11, 22)) # coercable to a real tuple - self.assertEqual(list(p), [11, 22]) # coercable to a list + self.assertEqual(tuple(p), (11, 22)) # coercible to a real tuple + self.assertEqual(list(p), [11, 22]) # coercible to a list self.assertEqual(max(p), 22) # iterable self.assertEqual(max(*p), 22) # star-able x, y = p @@ -620,6 +684,11 @@ class TestNamedTuple(unittest.TestCase): self.assertEqual(np.x, 1) self.assertEqual(np.y, 2) + def test_new_builtins_issue_43102(self): + self.assertEqual( + namedtuple('C', ()).__new__.__globals__['__builtins__'], + {}) + ################################################################################ ### Abstract Base Classes @@ -1445,8 +1514,6 @@ class TestCollectionABCs(ABCTestCase): s &= WithSet('cdef') # This used to fail self.assertEqual(set(s), set('cd')) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_issue_4920(self): # MutableSet.pop() method did not work class MySet(MutableSet): @@ -1471,8 +1538,12 @@ class TestCollectionABCs(ABCTestCase): return result def __repr__(self): return "MySet(%s)" % repr(list(self)) - s = MySet([5,43,2,1]) - self.assertEqual(s.pop(), 1) + items = [5,43,2,1] + s = MySet(items) + r = s.pop() + self.assertEquals(len(s), len(items) - 1) + self.assertNotIn(r, s) + self.assertIn(r, items) def test_issue8750(self): empty = WithSet() @@ -1523,6 +1594,7 @@ class TestCollectionABCs(ABCTestCase): class CustomEqualObject: def __eq__(self, other): return False + class CustomSequence(Sequence): def __init__(self, seq): self._seq = seq @@ -1532,7 +1604,7 @@ class TestCollectionABCs(ABCTestCase): return len(self._seq) nan = float('nan') - obj = CustomEqualObject() + obj = CustomEqualObject seq = CustomSequence([nan, obj, nan]) containers = [ seq, @@ -1551,6 +1623,64 @@ class TestCollectionABCs(ABCTestCase): # coerce both to a real set then check equality self.assertSetEqual(set(s1), set(s2)) + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_Set_from_iterable(self): + """Verify _from_iterable overriden to an instance method works.""" + class SetUsingInstanceFromIterable(MutableSet): + def __init__(self, values, created_by): + if not created_by: + raise ValueError(f'created_by must be specified') + self.created_by = created_by + self._values = set(values) + + def _from_iterable(self, values): + return type(self)(values, 'from_iterable') + + def __contains__(self, value): + return value in self._values + + def __iter__(self): + yield from self._values + + def __len__(self): + return len(self._values) + + def add(self, value): + self._values.add(value) + + def discard(self, value): + self._values.discard(value) + + impl = SetUsingInstanceFromIterable([1, 2, 3], 'test') + + actual = impl - {1} + self.assertIsInstance(actual, SetUsingInstanceFromIterable) + self.assertEqual('from_iterable', actual.created_by) + self.assertEqual({2, 3}, actual) + + actual = impl | {4} + self.assertIsInstance(actual, SetUsingInstanceFromIterable) + self.assertEqual('from_iterable', actual.created_by) + self.assertEqual({1, 2, 3, 4}, actual) + + actual = impl & {2} + self.assertIsInstance(actual, SetUsingInstanceFromIterable) + self.assertEqual('from_iterable', actual.created_by) + self.assertEqual({2}, actual) + + actual = impl ^ {3, 4} + self.assertIsInstance(actual, SetUsingInstanceFromIterable) + self.assertEqual('from_iterable', actual.created_by) + self.assertEqual({1, 2, 4}, actual) + + # NOTE: ixor'ing with a list is important here: internally, __ixor__ + # only calls _from_iterable if the other value isn't already a Set. + impl ^= [3, 4] + self.assertIsInstance(impl, SetUsingInstanceFromIterable) + self.assertEqual('test', impl.created_by) + self.assertEqual({1, 2, 4}, impl) + # TODO: RUSTPYTHON @unittest.expectedFailure def test_Set_interoperability_with_real_sets(self): @@ -1705,6 +1835,20 @@ class TestCollectionABCs(ABCTestCase): self.assertTrue(f1 != l1) self.assertTrue(f1 != l2) + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_Set_hash_matches_frozenset(self): + sets = [ + {}, {1}, {None}, {-1}, {0.0}, {"abc"}, {1, 2, 3}, + {10**100, 10**101}, {"a", "b", "ab", ""}, {False, True}, + {object(), object(), object()}, {float("nan")}, {frozenset()}, + {*range(1000)}, {*range(1000)} - {100, 200, 300}, + {*range(sys.maxsize - 10, sys.maxsize + 10)}, + ] + for s in sets: + fs = frozenset(s) + self.assertEqual(hash(fs), Set._hash(fs), msg=s) + # TODO: RUSTPYTHON @unittest.expectedFailure def test_Mapping(self): @@ -2125,6 +2269,31 @@ class TestCounter(unittest.TestCase): set_result = setop(set(p.elements()), set(q.elements())) self.assertEqual(counter_result, dict.fromkeys(set_result, 1)) + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_subset_superset_not_implemented(self): + # Verify that multiset comparison operations are not implemented. + + # These operations were intentionally omitted because multiset + # comparison semantics conflict with existing dict equality semantics. + + # For multisets, we would expect that if p<=q and p>=q are both true, + # then p==q. However, dict equality semantics require that p!=q when + # one of sets contains an element with a zero count and the other + # doesn't. + + p = Counter(a=1, b=0) + q = Counter(a=1, c=0) + self.assertNotEqual(p, q) + with self.assertRaises(TypeError): + p < q + with self.assertRaises(TypeError): + p <= q + with self.assertRaises(TypeError): + p > q + with self.assertRaises(TypeError): + p >= q + def test_inplace_operations(self): elements = 'abcd' for i in range(1000): @@ -2218,4 +2387,4 @@ def test_main(verbose=None): if __name__ == "__main__": - test_main(verbose=True) + test_main(verbose=True) \ No newline at end of file diff --git a/Lib/test/test_userstring.py b/Lib/test/test_userstring.py index b6c11cc4e..188c59e26 100644 --- a/Lib/test/test_userstring.py +++ b/Lib/test/test_userstring.py @@ -14,13 +14,9 @@ class UserStringTest( type2test = UserString - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_removeprefix(self): super().test_removeprefix() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_removesuffix(self): super().test_removesuffix() From 6abfc3a00c78511097ff6ea1e06a30ea53e0169b Mon Sep 17 00:00:00 2001 From: Chris Moradi <37349208+chrismoradi@users.noreply.github.com> Date: Mon, 25 Oct 2021 22:41:59 -0700 Subject: [PATCH 2/2] Revert UserDict.__init__ changes, new passing test of UserList --- Lib/collections/__init__.py | 19 +++++++++++++++++-- Lib/test/test_userlist.py | 2 -- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 25d1db7f9..ace8db113 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -1073,11 +1073,26 @@ class ChainMap(_collections_abc.MutableMapping): class UserDict(_collections_abc.MutableMapping): # Start by filling-out the abstract methods - def __init__(self, dict=None, /, **kwargs): + def __init__(*args, **kwargs): + if not args: + raise TypeError("descriptor '__init__' of 'UserDict' object " + "needs an argument") + self, *args = args + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + if args: + dict = args[0] + elif 'dict' in kwargs: + dict = kwargs.pop('dict') + import warnings + warnings.warn("Passing 'dict' as keyword argument is deprecated", + DeprecationWarning, stacklevel=2) + else: + dict = None self.data = {} if dict is not None: self.update(dict) - if kwargs: + if len(kwargs): self.update(kwargs) def __len__(self): diff --git a/Lib/test/test_userlist.py b/Lib/test/test_userlist.py index 26752353a..3b30cdc8e 100644 --- a/Lib/test/test_userlist.py +++ b/Lib/test/test_userlist.py @@ -27,8 +27,6 @@ class UserListTest(list_tests.CommonTest): for j in range(-3, 6): self.assertEqual(u[i:j], l[i:j]) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_slice_type(self): l = [0, 1, 2, 3, 4] u = UserList(l)