Files
RustPython/Lib/collections/_defaultdict.py
Changjoon 330b18f2fe Preserve recursively-set value in defaultdict.__missing__ (#7718)
CPython's defaultdict.__missing__ (Modules/_collectionsmodule.c::defdict_missing)
calls default_factory() first; if the factory's recursion already populated
self[key] while running, the existing value is preserved instead of being
overwritten.

RustPython ships a Python fallback at Lib/collections/_defaultdict.py
(the C _collections.defaultdict is not available). That fallback
unconditionally executed self[key] = val after the factory returned,
overwriting any value the recursive call had already stored.

Add an 'if key in self: return self[key]' guard before the assignment.
dict.__contains__ does not invoke __missing__, so there's no recursion
risk; in the common non-reentrant case the check is False and behavior
is unchanged.

Unmasks test_defaultdict.TestDefaultDict.test_factory_conflict_with_set_value.
2026-04-29 18:14:32 +09:00

63 lines
1.9 KiB
Python
Vendored

from reprlib import recursive_repr as _recursive_repr
class defaultdict(dict):
    """Dictionary that calls a factory to supply values for missing keys.

    ``defaultdict(default_factory=None, /, *args, **kwargs)``

    The factory is called with no arguments from ``__missing__``, which
    ``dict.__getitem__`` invokes for absent keys. The produced value is stored
    under the key and returned. All remaining arguments are forwarded to the
    ``dict`` constructor unchanged.
    """

    # Class-level fallback. Instances whose subclass ``__init__`` never calls
    # ``super().__init__()``, or whose attribute was deleted, then behave as
    # if the factory were ``None``. This matches CPython, where the
    # ``default_factory`` slot reads as ``None`` until it is set.
    default_factory = None

    def __init__(self, *args, **kwargs):
        """Initialize from an optional leading factory plus ``dict`` arguments.

        Raises:
            TypeError: if the first positional argument is neither callable
                nor ``None``.
        """
        if args:
            default_factory = args[0]
            if default_factory is not None and not callable(default_factory):
                raise TypeError("first argument must be callable or None")
            args = args[1:]
        else:
            default_factory = None
        super().__init__(*args, **kwargs)
        self.default_factory = default_factory

    def __missing__(self, key):
        """Produce, store and return a value for the absent *key*.

        Raises:
            KeyError: if ``default_factory`` is ``None``.
        """
        factory = self.default_factory
        if factory is None:
            raise KeyError(key)
        val = factory()
        # CPython parity: a re-entrant lookup of ``key`` inside factory() may
        # already have populated it. In that case, keep that value instead of
        # overwriting it. ``in`` does not invoke ``__missing__``, so this
        # check cannot recurse.
        if key in self:
            return self[key]
        self[key] = val
        return val

    @_recursive_repr()
    def __repr_factory(factory):
        # Recursion-guarded repr of the factory. If repr(factory) re-enters
        # this helper for the same factory, reprlib substitutes '...'.
        return repr(factory)

    def __repr__(self):
        factory_repr = defaultdict.__repr_factory(self.default_factory)
        return f"{type(self).__name__}({factory_repr}, {dict.__repr__(self)})"

    def copy(self):
        """Return a shallow copy of the same type, sharing the factory."""
        return type(self)(self.default_factory, self)

    __copy__ = copy

    def __reduce__(self):
        """Pickle support: rebuild via ``type(self)(*args)``, then feed the items."""
        args = () if self.default_factory is None else (self.default_factory,)
        return type(self), args, None, None, iter(self.items())

    def __or__(self, other):
        """Return ``self | other`` as a new instance of ``type(self)``.

        Like ``copy()`` (and CPython), this calls the instance's own class,
        so subclasses keep their type. Subclasses with incompatible
        constructors must override ``__or__``/``__ror__``. Values from
        *other* win on key conflicts.
        """
        if not isinstance(other, dict):
            return NotImplemented
        new = type(self)(self.default_factory, self)
        new.update(other)
        return new

    def __ror__(self, other):
        """Return ``other | self``.

        Keys of *other* come first; values from ``self`` win on conflicts.
        The result type is ``type(self)``, sharing this instance's factory.
        """
        if not isinstance(other, dict):
            return NotImplemented
        new = type(self)(self.default_factory, other)
        new.update(self)
        return new


# Present the class as ``collections.defaultdict`` (e.g. for pickling by
# qualified name), since it is defined in a private fallback module.
defaultdict.__module__ = 'collections'