Merge pull request #3378 from chrismoradi/update-collections-and-tests-from-cpython

Update collections from CPython, fix tests for UserDict/List/String
This commit is contained in:
Jeong YunWon
2021-10-27 02:35:06 +09:00
committed by GitHub
4 changed files with 600 additions and 153 deletions

View File

@@ -14,22 +14,30 @@ list, set, and tuple.
'''
__all__ = ['deque', 'defaultdict', 'namedtuple', 'UserDict', 'UserList',
'UserString', 'Counter', 'OrderedDict', 'ChainMap']
__all__ = [
'ChainMap',
'Counter',
'OrderedDict',
'UserDict',
'UserList',
'UserString',
'defaultdict',
'deque',
'namedtuple',
]
# For backwards compatibility, continue to make the collections ABCs
# available through the collections module.
from _collections_abc import *
import _collections_abc
__all__ += _collections_abc.__all__
from operator import itemgetter as _itemgetter, eq as _eq
from keyword import iskeyword as _iskeyword
import sys as _sys
import heapq as _heapq
from _weakref import proxy as _proxy
from itertools import repeat as _repeat, chain as _chain, starmap as _starmap
import sys as _sys
from itertools import chain as _chain
from itertools import repeat as _repeat
from itertools import starmap as _starmap
from keyword import iskeyword as _iskeyword
from operator import eq as _eq
from operator import itemgetter as _itemgetter
from reprlib import recursive_repr as _recursive_repr
from _weakref import proxy as _proxy
try:
from _collections import deque
@@ -41,7 +49,11 @@ else:
try:
from _collections import defaultdict
except ImportError:
pass
# FIXME: try to implement defaultdict in collections.rs rather than in Python
# I (coolreader18) couldn't figure out some class stuff with __new__ and
# __init__ and __missing__ and subclassing built-in types from Rust, so I went
# with this instead.
from ._defaultdict import defaultdict
def __getattr__(name):
@@ -52,13 +64,14 @@ def __getattr__(name):
obj = getattr(_collections_abc, name)
import warnings
warnings.warn("Using or importing the ABCs from 'collections' instead "
"of from 'collections.abc' is deprecated, "
"and in 3.8 it will stop working",
"of from 'collections.abc' is deprecated since Python 3.3, "
"and in 3.10 it will stop working",
DeprecationWarning, stacklevel=2)
globals()[name] = obj
return obj
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
################################################################################
### OrderedDict
################################################################################
@@ -108,6 +121,7 @@ class OrderedDict(dict):
self, *args = args
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
try:
self.__root
except AttributeError:
@@ -304,6 +318,24 @@ class OrderedDict(dict):
return dict.__eq__(self, other) and all(map(_eq, self, other))
return dict.__eq__(self, other)
def __ior__(self, other):
self.update(other)
return self
def __or__(self, other):
if not isinstance(other, dict):
return NotImplemented
new = self.__class__(self)
new.update(other)
return new
def __ror__(self, other):
if not isinstance(other, dict):
return NotImplemented
new = self.__class__(other)
new.update(self)
return new
try:
from _collections import OrderedDict
@@ -316,7 +348,10 @@ except ImportError:
### namedtuple
################################################################################
_nt_itemgetters = {}
try:
from _collections import _tuplegetter
except ImportError:
_tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc)
def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None):
"""Returns a new subclass of tuple with named fields.
@@ -389,18 +424,23 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
# Variables used in the methods and docstrings
field_names = tuple(map(_sys.intern, field_names))
num_fields = len(field_names)
arg_list = repr(field_names).replace("'", "")[1:-1]
arg_list = ', '.join(field_names)
if num_fields == 1:
arg_list += ','
repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')'
tuple_new = tuple.__new__
_len = len
_dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip
# Create all the named tuple methods to be added to the class namespace
s = f'def __new__(_cls, {arg_list}): return _tuple_new(_cls, ({arg_list}))'
namespace = {'_tuple_new': tuple_new, '__name__': f'namedtuple_{typename}'}
# Note: exec() has the side-effect of interning the field names
exec(s, namespace)
__new__ = namespace['__new__']
namespace = {
'_tuple_new': tuple_new,
'__builtins__': {},
'__name__': f'namedtuple_{typename}',
}
code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))'
__new__ = eval(code, namespace)
__new__.__name__ = '__new__'
__new__.__doc__ = f'Create new instance of {typename}({arg_list})'
if defaults is not None:
__new__.__defaults__ = defaults
@@ -415,8 +455,8 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
_make.__func__.__doc__ = (f'Make a new {typename} object from a sequence '
'or iterable')
def _replace(_self, **kwds):
result = _self._make(map(kwds.pop, field_names, _self))
def _replace(_self, /, **kwds):
result = _self._make(_map(kwds.pop, field_names, _self))
if kwds:
raise ValueError(f'Got unexpected field names: {list(kwds)!r}')
return result
@@ -429,17 +469,22 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
return self.__class__.__name__ + repr_fmt % self
def _asdict(self):
'Return a new OrderedDict which maps field names to their values.'
return OrderedDict(zip(self._fields, self))
'Return a new dict which maps field names to their values.'
return _dict(_zip(self._fields, self))
def __getnewargs__(self):
'Return self as a plain tuple. Used by copy and pickle.'
return tuple(self)
return _tuple(self)
# Modify function metadata to help with introspection and debugging
for method in (__new__, _make.__func__, _replace,
__repr__, _asdict, __getnewargs__):
for method in (
__new__,
_make.__func__,
_replace,
__repr__,
_asdict,
__getnewargs__,
):
method.__qualname__ = f'{typename}.{method.__name__}'
# Build-up the class namespace dictionary
@@ -448,7 +493,7 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
'__doc__': f'{typename}({arg_list})',
'__slots__': (),
'_fields': field_names,
'_fields_defaults': field_defaults,
'_field_defaults': field_defaults,
'__new__': __new__,
'_make': _make,
'_replace': _replace,
@@ -456,15 +501,9 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
'_asdict': _asdict,
'__getnewargs__': __getnewargs__,
}
cache = _nt_itemgetters
for index, name in enumerate(field_names):
try:
itemgetter_object, doc = cache[index]
except KeyError:
itemgetter_object = _itemgetter(index)
doc = f'Alias for field number {index}'
cache[index] = itemgetter_object, doc
class_namespace[name] = property(itemgetter_object, doc=doc)
doc = _sys.intern(f'Alias for field number {index}')
class_namespace[name] = _tuplegetter(index, doc)
result = type(typename, (tuple,), class_namespace)
@@ -579,8 +618,8 @@ class Counter(dict):
'''List the n most common elements and their counts from the most
common to the least. If n is None, then list all element counts.
>>> Counter('abcdeabcdabcaba').most_common(3)
[('a', 5), ('b', 4), ('c', 3)]
>>> Counter('abracadabra').most_common(3)
[('a', 5), ('b', 2), ('r', 2)]
'''
# Emulate Bag.sortedByCount from Smalltalk
@@ -614,8 +653,13 @@ class Counter(dict):
@classmethod
def fromkeys(cls, iterable, v=None):
# There is no equivalent method for counters because setting v=1
# means that no element can have a count greater than one.
# There is no equivalent method for counters because the semantics
# would be ambiguous in cases such as Counter.fromkeys('aaabbc', v=2).
# Initializing counters to zero values isn't necessary because zero
# is already the default value for counter lookups. Initializing
# to one is easily accomplished with Counter(set(iterable)). For
# more exotic cases, create a dictionary first using a dictionary
# comprehension or dict.fromkeys().
raise NotImplementedError(
'Counter.fromkeys() is undefined. Use Counter(iterable) instead.')
@@ -646,6 +690,7 @@ class Counter(dict):
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
iterable = args[0] if args else None
if iterable is not None:
if isinstance(iterable, _collections_abc.Mapping):
if self:
@@ -653,7 +698,8 @@ class Counter(dict):
for elem, count in iterable.items():
self[elem] = count + self_get(elem, 0)
else:
super(Counter, self).update(iterable) # fast path when counter is empty
# fast path when counter is empty
super(Counter, self).update(iterable)
else:
_count_elements(self, iterable)
if kwds:
@@ -682,6 +728,7 @@ class Counter(dict):
if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % len(args))
iterable = args[0] if args else None
if iterable is not None:
self_get = self.get
if isinstance(iterable, _collections_abc.Mapping):
@@ -707,13 +754,14 @@ class Counter(dict):
def __repr__(self):
if not self:
return '%s()' % self.__class__.__name__
return f'{self.__class__.__name__}()'
try:
items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
return '%s({%s})' % (self.__class__.__name__, items)
# dict() preserves the ordering returned by most_common()
d = dict(self.most_common())
except TypeError:
# handle case where values are not orderable
return '{0}({1!r})'.format(self.__class__.__name__, dict(self))
d = dict(self)
return f'{self.__class__.__name__}({d!r})'
# Multiset-style mathematical operations discussed in:
# Knuth TAOCP Volume II section 4.6.3 exercise 19
@@ -723,6 +771,13 @@ class Counter(dict):
#
# To strip negative and zero counts, add-in an empty counter:
# c += Counter()
#
# Rich comparison operators for multiset subset and superset tests
# are deliberately omitted due to semantic conflicts with the
# existing inherited dict equality method. Subset and superset
# semantics ignore zero counts and require that p≤q ∧ p≥q → p=q;
# however, that would not be the case for p=Counter(a=1, b=0)
# and q=Counter(a=1) where the dictionaries are not equal.
def __add__(self, other):
'''Add counts from two counters.
@@ -927,7 +982,7 @@ class ChainMap(_collections_abc.MutableMapping):
def __iter__(self):
d = {}
for mapping in reversed(self.maps):
d.update(mapping) # reuses stored hash values if possible
d.update(dict.fromkeys(mapping)) # reuses stored hash values if possible
return iter(d)
def __contains__(self, key):
@@ -938,8 +993,7 @@ class ChainMap(_collections_abc.MutableMapping):
@_recursive_repr()
def __repr__(self):
return '{0.__class__.__name__}({1})'.format(
self, ', '.join(map(repr, self.maps)))
return f'{self.__class__.__name__}({", ".join(map(repr, self.maps))})'
@classmethod
def fromkeys(cls, iterable, *args):
@@ -972,7 +1026,7 @@ class ChainMap(_collections_abc.MutableMapping):
try:
del self.maps[0][key]
except KeyError:
raise KeyError('Key not found in the first mapping: {!r}'.format(key))
raise KeyError(f'Key not found in the first mapping: {key!r}')
def popitem(self):
'Remove and return an item pair from maps[0]. Raise KeyError is maps[0] is empty.'
@@ -986,12 +1040,31 @@ class ChainMap(_collections_abc.MutableMapping):
try:
return self.maps[0].pop(key, *args)
except KeyError:
raise KeyError('Key not found in the first mapping: {!r}'.format(key))
raise KeyError(f'Key not found in the first mapping: {key!r}')
def clear(self):
'Clear maps[0], leaving maps[1:] intact.'
self.maps[0].clear()
def __ior__(self, other):
self.maps[0].update(other)
return self
def __or__(self, other):
if not isinstance(other, _collections_abc.Mapping):
return NotImplemented
m = self.copy()
m.maps[0].update(other)
return m
def __ror__(self, other):
if not isinstance(other, _collections_abc.Mapping):
return NotImplemented
m = dict(other)
for child in reversed(self.maps):
m.update(child)
return self.__class__(m)
################################################################################
### UserDict
@@ -1021,15 +1094,26 @@ class UserDict(_collections_abc.MutableMapping):
self.update(dict)
if len(kwargs):
self.update(kwargs)
def __len__(self): return len(self.data)
def __len__(self):
return len(self.data)
def __bool__(self):
return bool(self.data)
def __getitem__(self, key):
if key in self.data:
return self.data[key]
if hasattr(self.__class__, "__missing__"):
return self.__class__.__missing__(self, key)
raise KeyError(key)
def __setitem__(self, key, item): self.data[key] = item
def __delitem__(self, key): del self.data[key]
def __setitem__(self, key, item):
self.data[key] = item
def __delitem__(self, key):
del self.data[key]
def __iter__(self):
return iter(self.data)
@@ -1038,7 +1122,40 @@ class UserDict(_collections_abc.MutableMapping):
return key in self.data
# Now, add the methods in dicts but not in MutableMapping
def __repr__(self): return repr(self.data)
def __repr__(self):
return repr(self.data)
def __or__(self, other):
if isinstance(other, UserDict):
return self.__class__(self.data | other.data)
if isinstance(other, dict):
return self.__class__(self.data | other)
return NotImplemented
def __ror__(self, other):
if isinstance(other, UserDict):
return self.__class__(other.data | self.data)
if isinstance(other, dict):
return self.__class__(other | self.data)
return NotImplemented
def __ior__(self, other):
if isinstance(other, UserDict):
self.data |= other.data
else:
self.data |= other
return self
def __copy__(self):
inst = self.__class__.__new__(self.__class__)
inst.__dict__.update(self.__dict__)
# Create a copy and avoid triggering descriptors
inst.__dict__["data"] = self.__dict__["data"].copy()
return inst
def __sizeof__(self):
return _sys.getsizeof(self.data)
def copy(self):
if self.__class__ is UserDict:
return UserDict(self.data.copy())
@@ -1051,6 +1168,7 @@ class UserDict(_collections_abc.MutableMapping):
self.data = data
c.update(self)
return c
@classmethod
def fromkeys(cls, iterable, value=None):
d = cls()
@@ -1059,13 +1177,13 @@ class UserDict(_collections_abc.MutableMapping):
return d
################################################################################
### UserList
################################################################################
class UserList(_collections_abc.MutableSequence):
"""A more or less complete user-defined wrapper around list objects."""
def __init__(self, initlist=None):
self.data = []
if initlist is not None:
@@ -1076,31 +1194,63 @@ class UserList(_collections_abc.MutableSequence):
self.data[:] = initlist.data[:]
else:
self.data = list(initlist)
def __repr__(self): return repr(self.data)
def __lt__(self, other): return self.data < self.__cast(other)
def __le__(self, other): return self.data <= self.__cast(other)
def __eq__(self, other): return self.data == self.__cast(other)
def __gt__(self, other): return self.data > self.__cast(other)
def __ge__(self, other): return self.data >= self.__cast(other)
def __repr__(self):
return repr(self.data)
def __lt__(self, other):
return self.data < self.__cast(other)
def __le__(self, other):
return self.data <= self.__cast(other)
def __eq__(self, other):
return self.data == self.__cast(other)
def __gt__(self, other):
return self.data > self.__cast(other)
def __ge__(self, other):
return self.data >= self.__cast(other)
def __cast(self, other):
return other.data if isinstance(other, UserList) else other
def __contains__(self, item): return item in self.data
def __len__(self): return len(self.data)
def __getitem__(self, i): return self.data[i]
def __setitem__(self, i, item): self.data[i] = item
def __delitem__(self, i): del self.data[i]
def __contains__(self, item):
return item in self.data
def __len__(self):
return len(self.data)
def __bool__(self):
return bool(self.data)
def __getitem__(self, i):
if isinstance(i, slice):
return self.__class__(self.data[i])
else:
return self.data[i]
def __setitem__(self, i, item):
self.data[i] = item
def __delitem__(self, i):
del self.data[i]
def __add__(self, other):
if isinstance(other, UserList):
return self.__class__(self.data + other.data)
elif isinstance(other, type(self.data)):
return self.__class__(self.data + other)
return self.__class__(self.data + list(other))
def __radd__(self, other):
if isinstance(other, UserList):
return self.__class__(other.data + self.data)
elif isinstance(other, type(self.data)):
return self.__class__(other + self.data)
return self.__class__(list(other) + self.data)
def __iadd__(self, other):
if isinstance(other, UserList):
self.data += other.data
@@ -1109,22 +1259,56 @@ class UserList(_collections_abc.MutableSequence):
else:
self.data += list(other)
return self
def __mul__(self, n):
return self.__class__(self.data*n)
return self.__class__(self.data * n)
__rmul__ = __mul__
def __imul__(self, n):
self.data *= n
return self
def append(self, item): self.data.append(item)
def insert(self, i, item): self.data.insert(i, item)
def pop(self, i=-1): return self.data.pop(i)
def remove(self, item): self.data.remove(item)
def clear(self): self.data.clear()
def copy(self): return self.__class__(self)
def count(self, item): return self.data.count(item)
def index(self, item, *args): return self.data.index(item, *args)
def reverse(self): self.data.reverse()
def sort(self, *args, **kwds): self.data.sort(*args, **kwds)
def __copy__(self):
inst = self.__class__.__new__(self.__class__)
inst.__dict__.update(self.__dict__)
# Create a copy and avoid triggering descriptors
inst.__dict__["data"] = self.__dict__["data"][:]
return inst
def __sizeof__(self):
return _sys.getsizeof(self.data)
def append(self, item):
self.data.append(item)
def insert(self, i, item):
self.data.insert(i, item)
def pop(self, i=-1):
return self.data.pop(i)
def remove(self, item):
self.data.remove(item)
def clear(self):
self.data.clear()
def copy(self):
return self.__class__(self)
def count(self, item):
return self.data.count(item)
def index(self, item, *args):
return self.data.index(item, *args)
def reverse(self):
self.data.reverse()
def sort(self, /, *args, **kwds):
self.data.sort(*args, **kwds)
def extend(self, other):
if isinstance(other, UserList):
self.data.extend(other.data)
@@ -1132,12 +1316,12 @@ class UserList(_collections_abc.MutableSequence):
self.data.extend(other)
################################################################################
### UserString
################################################################################
class UserString(_collections_abc.Sequence):
def __init__(self, seq):
if isinstance(seq, str):
self.data = seq
@@ -1145,12 +1329,25 @@ class UserString(_collections_abc.Sequence):
self.data = seq.data[:]
else:
self.data = str(seq)
def __str__(self): return str(self.data)
def __repr__(self): return repr(self.data)
def __int__(self): return int(self.data)
def __float__(self): return float(self.data)
def __complex__(self): return complex(self.data)
def __hash__(self): return hash(self.data)
def __str__(self):
return str(self.data)
def __repr__(self):
return repr(self.data)
def __int__(self):
return int(self.data)
def __float__(self):
return float(self.data)
def __complex__(self):
return complex(self.data)
def __hash__(self):
return hash(self.data)
def __getnewargs__(self):
return (self.data[:],)
@@ -1158,18 +1355,22 @@ class UserString(_collections_abc.Sequence):
if isinstance(string, UserString):
return self.data == string.data
return self.data == string
def __lt__(self, string):
if isinstance(string, UserString):
return self.data < string.data
return self.data < string
def __le__(self, string):
if isinstance(string, UserString):
return self.data <= string.data
return self.data <= string
def __gt__(self, string):
if isinstance(string, UserString):
return self.data > string.data
return self.data > string
def __ge__(self, string):
if isinstance(string, UserString):
return self.data >= string.data
@@ -1180,111 +1381,194 @@ class UserString(_collections_abc.Sequence):
char = char.data
return char in self.data
def __len__(self): return len(self.data)
def __getitem__(self, index): return self.__class__(self.data[index])
def __bool__(self):
return bool(self.data)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.__class__(self.data[index])
def __add__(self, other):
if isinstance(other, UserString):
return self.__class__(self.data + other.data)
elif isinstance(other, str):
return self.__class__(self.data + other)
return self.__class__(self.data + str(other))
def __radd__(self, other):
if isinstance(other, str):
return self.__class__(other + self.data)
return self.__class__(str(other) + self.data)
def __mul__(self, n):
return self.__class__(self.data*n)
return self.__class__(self.data * n)
__rmul__ = __mul__
def __mod__(self, args):
return self.__class__(self.data % args)
def __rmod__(self, format):
return self.__class__(format % args)
def __rmod__(self, template):
return self.__class__(str(template) % self)
def __sizeof__(self):
return _sys.getsizeof(self.data)
# the following methods are defined in alphabetical order:
def capitalize(self): return self.__class__(self.data.capitalize())
def capitalize(self):
return self.__class__(self.data.capitalize())
def casefold(self):
return self.__class__(self.data.casefold())
def center(self, width, *args):
return self.__class__(self.data.center(width, *args))
def count(self, sub, start=0, end=_sys.maxsize):
if isinstance(sub, UserString):
sub = sub.data
return self.data.count(sub, start, end)
def encode(self, encoding=None, errors=None): # XXX improve this?
if encoding:
if errors:
return self.__class__(self.data.encode(encoding, errors))
return self.__class__(self.data.encode(encoding))
return self.__class__(self.data.encode())
def removeprefix(self, prefix, /):
if isinstance(prefix, UserString):
prefix = prefix.data
return self.__class__(self.data.removeprefix(prefix))
def removesuffix(self, suffix, /):
if isinstance(suffix, UserString):
suffix = suffix.data
return self.__class__(self.data.removesuffix(suffix))
def encode(self, encoding='utf-8', errors='strict'):
encoding = 'utf-8' if encoding is None else encoding
errors = 'strict' if errors is None else errors
return self.data.encode(encoding, errors)
def endswith(self, suffix, start=0, end=_sys.maxsize):
return self.data.endswith(suffix, start, end)
def expandtabs(self, tabsize=8):
return self.__class__(self.data.expandtabs(tabsize))
def find(self, sub, start=0, end=_sys.maxsize):
if isinstance(sub, UserString):
sub = sub.data
return self.data.find(sub, start, end)
def format(self, *args, **kwds):
def format(self, /, *args, **kwds):
return self.data.format(*args, **kwds)
def format_map(self, mapping):
return self.data.format_map(mapping)
def index(self, sub, start=0, end=_sys.maxsize):
return self.data.index(sub, start, end)
def isalpha(self): return self.data.isalpha()
def isalnum(self): return self.data.isalnum()
def isascii(self): return self.data.isascii()
def isdecimal(self): return self.data.isdecimal()
def isdigit(self): return self.data.isdigit()
def isidentifier(self): return self.data.isidentifier()
def islower(self): return self.data.islower()
def isnumeric(self): return self.data.isnumeric()
def isprintable(self): return self.data.isprintable()
def isspace(self): return self.data.isspace()
def istitle(self): return self.data.istitle()
def isupper(self): return self.data.isupper()
def join(self, seq): return self.data.join(seq)
def isalpha(self):
return self.data.isalpha()
def isalnum(self):
return self.data.isalnum()
def isascii(self):
return self.data.isascii()
def isdecimal(self):
return self.data.isdecimal()
def isdigit(self):
return self.data.isdigit()
def isidentifier(self):
return self.data.isidentifier()
def islower(self):
return self.data.islower()
def isnumeric(self):
return self.data.isnumeric()
def isprintable(self):
return self.data.isprintable()
def isspace(self):
return self.data.isspace()
def istitle(self):
return self.data.istitle()
def isupper(self):
return self.data.isupper()
def join(self, seq):
return self.data.join(seq)
def ljust(self, width, *args):
return self.__class__(self.data.ljust(width, *args))
def lower(self): return self.__class__(self.data.lower())
def lstrip(self, chars=None): return self.__class__(self.data.lstrip(chars))
def lower(self):
return self.__class__(self.data.lower())
def lstrip(self, chars=None):
return self.__class__(self.data.lstrip(chars))
maketrans = str.maketrans
def partition(self, sep):
return self.data.partition(sep)
def replace(self, old, new, maxsplit=-1):
if isinstance(old, UserString):
old = old.data
if isinstance(new, UserString):
new = new.data
return self.__class__(self.data.replace(old, new, maxsplit))
def rfind(self, sub, start=0, end=_sys.maxsize):
if isinstance(sub, UserString):
sub = sub.data
return self.data.rfind(sub, start, end)
def rindex(self, sub, start=0, end=_sys.maxsize):
return self.data.rindex(sub, start, end)
def rjust(self, width, *args):
return self.__class__(self.data.rjust(width, *args))
def rpartition(self, sep):
return self.data.rpartition(sep)
def rstrip(self, chars=None):
return self.__class__(self.data.rstrip(chars))
def split(self, sep=None, maxsplit=-1):
return self.data.split(sep, maxsplit)
def rsplit(self, sep=None, maxsplit=-1):
return self.data.rsplit(sep, maxsplit)
def splitlines(self, keepends=False): return self.data.splitlines(keepends)
def splitlines(self, keepends=False):
return self.data.splitlines(keepends)
def startswith(self, prefix, start=0, end=_sys.maxsize):
return self.data.startswith(prefix, start, end)
def strip(self, chars=None): return self.__class__(self.data.strip(chars))
def swapcase(self): return self.__class__(self.data.swapcase())
def title(self): return self.__class__(self.data.title())
def strip(self, chars=None):
return self.__class__(self.data.strip(chars))
def swapcase(self):
return self.__class__(self.data.swapcase())
def title(self):
return self.__class__(self.data.title())
def translate(self, *args):
return self.__class__(self.data.translate(*args))
def upper(self): return self.__class__(self.data.upper())
def zfill(self, width): return self.__class__(self.data.zfill(width))
# FIXME: try to implement defaultdict in collections.rs rather than in Python
# I (coolreader18) couldn't figure out some class stuff with __new__ and
# __init__ and __missing__ and subclassing built-in types from Rust, so I went
# with this instead.
from ._defaultdict import defaultdict
def upper(self):
return self.__class__(self.data.upper())
def zfill(self, width):
return self.__class__(self.data.zfill(width))

View File

@@ -51,30 +51,20 @@ class TestUserObjects(unittest.TestCase):
self.assertEqual(obj.data, obj_copy.data)
self.assertIs(obj.test, obj_copy.test)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_str_protocol(self):
self._superset_test(UserString, str)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_list_protocol(self):
self._superset_test(UserList, list)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_dict_protocol(self):
self._superset_test(UserDict, dict)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_list_copy(self):
obj = UserList()
obj.append(123)
self._copy_test(obj)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_dict_copy(self):
obj = UserDict()
obj[123] = "abc"
@@ -205,6 +195,22 @@ class TestChainMap(unittest.TestCase):
('e', 55), ('f', 666), ('g', 777), ('h', 88888),
('i', 9999), ('j', 0)])
def test_iter_not_calling_getitem_on_maps(self):
class DictWithGetItem(UserDict):
def __init__(self, *args, **kwds):
self.called = False
UserDict.__init__(self, *args, **kwds)
def __getitem__(self, item):
self.called = True
UserDict.__getitem__(self, item)
d = DictWithGetItem(a=1)
c = ChainMap(d)
d.called = False
set(c) # iterate over chain map
self.assertFalse(d.called, '__getitem__ was called')
def test_dict_coercion(self):
d = ChainMap(dict(a=1, b=2), dict(b=20, c=30))
self.assertEqual(dict(d), dict(a=1, b=2, c=30))
@@ -242,6 +248,54 @@ class TestChainMap(unittest.TestCase):
for k, v in dict(a=1, B=20, C=30, z=100).items(): # check get
self.assertEqual(d.get(k, 100), v)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_union_operators(self):
cm1 = ChainMap(dict(a=1, b=2), dict(c=3, d=4))
cm2 = ChainMap(dict(a=10, e=5), dict(b=20, d=4))
cm3 = cm1.copy()
d = dict(a=10, c=30)
pairs = [('c', 3), ('p',0)]
tmp = cm1 | cm2 # testing between chainmaps
self.assertEqual(tmp.maps, [cm1.maps[0] | dict(cm2), *cm1.maps[1:]])
cm1 |= cm2
self.assertEqual(tmp, cm1)
tmp = cm2 | d # testing between chainmap and mapping
self.assertEqual(tmp.maps, [cm2.maps[0] | d, *cm2.maps[1:]])
self.assertEqual((d | cm2).maps, [d | dict(cm2)])
cm2 |= d
self.assertEqual(tmp, cm2)
# testing behavior between chainmap and iterable key-value pairs
with self.assertRaises(TypeError):
cm3 | pairs
tmp = cm3.copy()
cm3 |= pairs
self.assertEqual(cm3.maps, [tmp.maps[0] | dict(pairs), *tmp.maps[1:]])
# testing proper return types for ChainMap and it's subclasses
class Subclass(ChainMap):
pass
class SubclassRor(ChainMap):
def __ror__(self, other):
return super().__ror__(other)
tmp = ChainMap() | ChainMap()
self.assertIs(type(tmp), ChainMap)
self.assertIs(type(tmp.maps[0]), dict)
tmp = ChainMap() | Subclass()
self.assertIs(type(tmp), ChainMap)
self.assertIs(type(tmp.maps[0]), dict)
tmp = Subclass() | ChainMap()
self.assertIs(type(tmp), Subclass)
self.assertIs(type(tmp.maps[0]), dict)
tmp = ChainMap() | SubclassRor()
self.assertIs(type(tmp), SubclassRor)
self.assertIs(type(tmp.maps[0]), dict)
################################################################################
### Named Tuples
@@ -280,8 +334,6 @@ class TestNamedTuple(unittest.TestCase):
self.assertRaises(TypeError, Point._make, [11]) # catch too few args
self.assertRaises(TypeError, Point._make, [11, 22, 33]) # catch too many args
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_defaults(self):
Point = namedtuple('Point', 'x y', defaults=(10, 20)) # 2 defaults
self.assertEqual(Point._field_defaults, {'x': 10, 'y': 20})
@@ -377,6 +429,18 @@ class TestNamedTuple(unittest.TestCase):
self.assertIs(P.m.__doc__, Q.o.__doc__)
self.assertIs(P.n.__doc__, Q.p.__doc__)
@support.cpython_only
def test_field_repr(self):
Point = namedtuple('Point', 'x y')
self.assertEqual(repr(Point.x), "_tuplegetter(0, 'Alias for field number 0')")
self.assertEqual(repr(Point.y), "_tuplegetter(1, 'Alias for field number 1')")
Point.x.__doc__ = 'The x-coordinate'
Point.y.__doc__ = 'The y-coordinate'
self.assertEqual(repr(Point.x), "_tuplegetter(0, 'The x-coordinate')")
self.assertEqual(repr(Point.y), "_tuplegetter(1, 'The y-coordinate')")
def test_name_fixer(self):
for spec, renamed in [
[('efg', 'g%hi'), ('efg', '_1')], # field with non-alpha char
@@ -436,8 +500,8 @@ class TestNamedTuple(unittest.TestCase):
self.assertIsInstance(p, tuple)
self.assertEqual(p, (11, 22)) # matches a real tuple
self.assertEqual(tuple(p), (11, 22)) # coercable to a real tuple
self.assertEqual(list(p), [11, 22]) # coercable to a list
self.assertEqual(tuple(p), (11, 22)) # coercible to a real tuple
self.assertEqual(list(p), [11, 22]) # coercible to a list
self.assertEqual(max(p), 22) # iterable
self.assertEqual(max(*p), 22) # star-able
x, y = p
@@ -620,6 +684,11 @@ class TestNamedTuple(unittest.TestCase):
self.assertEqual(np.x, 1)
self.assertEqual(np.y, 2)
def test_new_builtins_issue_43102(self):
self.assertEqual(
namedtuple('C', ()).__new__.__globals__['__builtins__'],
{})
################################################################################
### Abstract Base Classes
@@ -1445,8 +1514,6 @@ class TestCollectionABCs(ABCTestCase):
s &= WithSet('cdef') # This used to fail
self.assertEqual(set(s), set('cd'))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_issue_4920(self):
# MutableSet.pop() method did not work
class MySet(MutableSet):
@@ -1471,8 +1538,12 @@ class TestCollectionABCs(ABCTestCase):
return result
def __repr__(self):
return "MySet(%s)" % repr(list(self))
s = MySet([5,43,2,1])
self.assertEqual(s.pop(), 1)
items = [5,43,2,1]
s = MySet(items)
r = s.pop()
self.assertEquals(len(s), len(items) - 1)
self.assertNotIn(r, s)
self.assertIn(r, items)
def test_issue8750(self):
empty = WithSet()
@@ -1523,6 +1594,7 @@ class TestCollectionABCs(ABCTestCase):
class CustomEqualObject:
def __eq__(self, other):
return False
class CustomSequence(Sequence):
def __init__(self, seq):
self._seq = seq
@@ -1532,7 +1604,7 @@ class TestCollectionABCs(ABCTestCase):
return len(self._seq)
nan = float('nan')
obj = CustomEqualObject()
obj = CustomEqualObject
seq = CustomSequence([nan, obj, nan])
containers = [
seq,
@@ -1551,6 +1623,64 @@ class TestCollectionABCs(ABCTestCase):
# coerce both to a real set then check equality
self.assertSetEqual(set(s1), set(s2))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_Set_from_iterable(self):
"""Verify _from_iterable overriden to an instance method works."""
class SetUsingInstanceFromIterable(MutableSet):
def __init__(self, values, created_by):
if not created_by:
raise ValueError(f'created_by must be specified')
self.created_by = created_by
self._values = set(values)
def _from_iterable(self, values):
return type(self)(values, 'from_iterable')
def __contains__(self, value):
return value in self._values
def __iter__(self):
yield from self._values
def __len__(self):
return len(self._values)
def add(self, value):
self._values.add(value)
def discard(self, value):
self._values.discard(value)
impl = SetUsingInstanceFromIterable([1, 2, 3], 'test')
actual = impl - {1}
self.assertIsInstance(actual, SetUsingInstanceFromIterable)
self.assertEqual('from_iterable', actual.created_by)
self.assertEqual({2, 3}, actual)
actual = impl | {4}
self.assertIsInstance(actual, SetUsingInstanceFromIterable)
self.assertEqual('from_iterable', actual.created_by)
self.assertEqual({1, 2, 3, 4}, actual)
actual = impl & {2}
self.assertIsInstance(actual, SetUsingInstanceFromIterable)
self.assertEqual('from_iterable', actual.created_by)
self.assertEqual({2}, actual)
actual = impl ^ {3, 4}
self.assertIsInstance(actual, SetUsingInstanceFromIterable)
self.assertEqual('from_iterable', actual.created_by)
self.assertEqual({1, 2, 4}, actual)
# NOTE: ixor'ing with a list is important here: internally, __ixor__
# only calls _from_iterable if the other value isn't already a Set.
impl ^= [3, 4]
self.assertIsInstance(impl, SetUsingInstanceFromIterable)
self.assertEqual('test', impl.created_by)
self.assertEqual({1, 2, 4}, impl)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_Set_interoperability_with_real_sets(self):
@@ -1705,6 +1835,20 @@ class TestCollectionABCs(ABCTestCase):
self.assertTrue(f1 != l1)
self.assertTrue(f1 != l2)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_Set_hash_matches_frozenset(self):
sets = [
{}, {1}, {None}, {-1}, {0.0}, {"abc"}, {1, 2, 3},
{10**100, 10**101}, {"a", "b", "ab", ""}, {False, True},
{object(), object(), object()}, {float("nan")}, {frozenset()},
{*range(1000)}, {*range(1000)} - {100, 200, 300},
{*range(sys.maxsize - 10, sys.maxsize + 10)},
]
for s in sets:
fs = frozenset(s)
self.assertEqual(hash(fs), Set._hash(fs), msg=s)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_Mapping(self):
@@ -2125,6 +2269,31 @@ class TestCounter(unittest.TestCase):
set_result = setop(set(p.elements()), set(q.elements()))
self.assertEqual(counter_result, dict.fromkeys(set_result, 1))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_subset_superset_not_implemented(self):
# Verify that multiset comparison operations are not implemented.
# These operations were intentionally omitted because multiset
# comparison semantics conflict with existing dict equality semantics.
# For multisets, we would expect that if p<=q and p>=q are both true,
# then p==q. However, dict equality semantics require that p!=q when
# one of sets contains an element with a zero count and the other
# doesn't.
p = Counter(a=1, b=0)
q = Counter(a=1, c=0)
self.assertNotEqual(p, q)
with self.assertRaises(TypeError):
p < q
with self.assertRaises(TypeError):
p <= q
with self.assertRaises(TypeError):
p > q
with self.assertRaises(TypeError):
p >= q
def test_inplace_operations(self):
elements = 'abcd'
for i in range(1000):
@@ -2218,4 +2387,4 @@ def test_main(verbose=None):
if __name__ == "__main__":
test_main(verbose=True)
test_main(verbose=True)

View File

@@ -27,8 +27,6 @@ class UserListTest(list_tests.CommonTest):
for j in range(-3, 6):
self.assertEqual(u[i:j], l[i:j])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_slice_type(self):
l = [0, 1, 2, 3, 4]
u = UserList(l)

View File

@@ -14,13 +14,9 @@ class UserStringTest(
type2test = UserString
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_removeprefix(self):
super().test_removeprefix()
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_removesuffix(self):
super().test_removesuffix()