mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
CPython's defaultdict.__missing__ (Modules/_collectionsmodule.c::defdict_missing) calls default_factory() first; if a reentrant lookup made by the factory has already populated self[key] by the time the factory returns, that existing value is preserved instead of being overwritten. RustPython ships a pure-Python fallback at Lib/collections/_defaultdict.py (the C _collections.defaultdict is not available). That fallback unconditionally executed self[key] = val after the factory returned, overwriting any value the recursive call had already stored. Add an 'if key in self: return self[key]' guard before the assignment. dict.__contains__ does not invoke __missing__, so there is no recursion risk; in the common non-reentrant case the check is False and behavior is unchanged. Unmasks test_defaultdict.TestDefaultDict.test_factory_conflict_with_set_value.
63 lines
1.9 KiB
Python
Vendored
63 lines
1.9 KiB
Python
Vendored
from reprlib import recursive_repr as _recursive_repr
|
|
|
|
class defaultdict(dict):
    """Pure-Python fallback for ``collections.defaultdict``.

    A ``dict`` subclass that calls ``default_factory()`` to supply, and
    store, a value for any key missing from ``d[key]`` lookups.  When
    ``default_factory`` is ``None``, a missing key raises ``KeyError``
    exactly like a plain ``dict``.
    """

    # Class-level fallback: an instance created without running __init__
    # (e.g. ``defaultdict.__new__(defaultdict)``) still reads
    # ``default_factory`` as None, so __missing__ raises KeyError rather than
    # AttributeError.  __init__ always shadows this with an instance value.
    default_factory = None

    def __init__(self, *args, **kwargs):
        """Initialize from an optional factory plus ordinary dict arguments.

        The first positional argument, if any, becomes ``default_factory``.
        All remaining positional and keyword arguments are forwarded unchanged
        to ``dict.__init__``.  Keywords are never treated as the factory.

        Raises:
            TypeError: if the first positional argument is neither callable
                nor None.
        """
        if args:
            default_factory = args[0]
            if default_factory is not None and not callable(default_factory):
                raise TypeError("first argument must be callable or None")
            args = args[1:]
        else:
            default_factory = None
        super().__init__(*args, **kwargs)
        self.default_factory = default_factory

    def __missing__(self, key):
        """Return (and store) a default value for a missing *key*.

        ``dict.__getitem__`` calls this when *key* is absent.  With no factory,
        it raises ``KeyError(key)``.  Otherwise it calls ``default_factory()``
        with no arguments and stores the result under *key*.  The result is
        returned.

        The factory may itself store a value under *key* while it runs, for
        example by looking *key* up on this same mapping.  In that case the
        already-stored value is kept and returned instead of being overwritten.
        This matches CPython.
        """
        factory = self.default_factory
        if factory is None:
            raise KeyError(key)
        val = factory()
        # CPython parity: a recursive __missing__ via factory() may have
        # already populated key; preserve that value instead of overwriting.
        # dict.__contains__ never consults __missing__, so this cannot recurse.
        if key in self:
            return self[key]
        self[key] = val
        return val

    @_recursive_repr()
    def __repr_factory(factory):
        # Called through the class (see __repr__), so ``factory`` is the only
        # argument.  recursive_repr substitutes "..." when the factory's own
        # repr re-enters this helper for the same factory object, e.g. a bound
        # method of this very mapping.
        return repr(factory)

    def __repr__(self):
        """Return ``TypeName(<factory repr>, <dict repr>)``."""
        return f"{type(self).__name__}({defaultdict.__repr_factory(self.default_factory)}, {dict.__repr__(self)})"

    def copy(self):
        """Return a shallow copy with the same class and ``default_factory``."""
        return type(self)(self.default_factory, self)

    __copy__ = copy

    def __reduce__(self):
        """Pickle/copy support.

        The object is rebuilt by calling ``type(self)`` with the factory (or
        with no arguments when the factory is None).  The items are then
        re-inserted from an iterator of ``(key, value)`` pairs.
        """
        if self.default_factory is not None:
            args = (self.default_factory,)
        else:
            args = ()
        return type(self), args, None, None, iter(self.items())

    def __or__(self, other):
        """Return ``self | other``: a copy of *self* updated with *other*.

        As in ``copy()``, the result is built with ``type(self)``, so a
        subclass gets an instance of its own class (CPython parity).  A
        subclass whose constructor does not accept
        ``(default_factory, mapping)`` must override ``__or__``/``__ror__``.
        Returns NotImplemented when *other* is not a dict.
        """
        if not isinstance(other, dict):
            return NotImplemented

        new = type(self)(self.default_factory, self)
        new.update(other)
        return new

    def __ror__(self, other):
        """Return ``other | self``: a copy of *other* updated with *self*.

        The result has ``type(self)`` and keeps ``self.default_factory``.
        Returns NotImplemented when *other* is not a dict.
        """
        if not isinstance(other, dict):
            return NotImplemented

        new = type(self)(self.default_factory, other)
        new.update(self)
        return new
|
|
|
|
# Present this fallback class under its public home, so that the class repr
# and pickle's by-name class references say collections.defaultdict rather
# than naming this private fallback module.
setattr(defaultdict, '__module__', 'collections')
|