diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 8ef2871c6..c68ba49d9 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -1,16 +1,17 @@ +import collections.abc +import copy +import gc +import itertools +import operator +import pickle +import re import unittest +import warnings +import weakref +from random import randrange, shuffle from test import support from test.support import warnings_helper -import gc -import weakref -import operator -import copy -import pickle -from random import randrange, shuffle -import warnings -import collections -import collections.abc -import itertools + class PassThru(Exception): pass @@ -19,6 +20,14 @@ def check_pass_thru(): raise PassThru yield 1 +class CustomHash: + def __init__(self, hash): + self.hash = hash + def __hash__(self): + return self.hash + def __repr__(self): + return f'' + class BadCmp: def __hash__(self): return 1 @@ -227,14 +236,17 @@ class TestJointOps: def test_pickling(self): for i in range(pickle.HIGHEST_PROTOCOL + 1): + if type(self.s) not in (set, frozenset): + self.s.x = ['x'] + self.s.z = ['z'] p = pickle.dumps(self.s, i) dup = pickle.loads(p) self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup)) if type(self.s) not in (set, frozenset): - self.s.x = 10 - p = pickle.dumps(self.s, i) - dup = pickle.loads(p) self.assertEqual(self.s.x, dup.x) + self.assertEqual(self.s.z, dup.z) + self.assertNotHasAttr(self.s, 'y') + del self.s.x, self.s.z def test_iterator_pickling(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -318,8 +330,7 @@ class TestJointOps: name = repr(s).partition('(')[0] # strip class name self.assertEqual(repr(s), '%s({%s(...)})' % (name, name)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_do_not_rehash_dict_keys(self): n = 10 d = dict.fromkeys(map(HashCountingInt, range(n))) @@ -339,8 +350,7 @@ class TestJointOps: self.assertEqual(sum(elem.hash_count for elem in d), n) self.assertEqual(d3, dict.fromkeys(d, 123)) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_container_iterator(self): # Bug #3680: tp_traverse was not implemented for set iterator object class C(object): @@ -353,8 +363,7 @@ class TestJointOps: gc.collect() self.assertTrue(ref() is None, "Cycle was not collected") - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.expectedFailure # TODO: RUSTPYTHON def test_free_after_iterating(self): support.check_free_after_iterating(self, iter, self.thetype) @@ -430,7 +439,7 @@ class TestSet(TestJointOps, unittest.TestCase): self.assertRaises(KeyError, self.s.remove, self.thetype(self.word)) def test_remove_keyerror_unpacking(self): - # bug: www.python.org/sf/1576657 + # https://bugs.python.org/issue1576657 for v1 in ['Q', (1,)]: try: self.s.remove(v1) @@ -638,10 +647,68 @@ class TestSet(TestJointOps, unittest.TestCase): myset >= myobj self.assertTrue(myobj.le_called) - @unittest.skipUnless(hasattr(set, "test_c_api"), - 'C API test only available in a debug build') - def test_c_api(self): - self.assertEqual(set().test_c_api(), True) + def test_set_membership(self): + myfrozenset = frozenset(range(3)) + myset = {myfrozenset, "abc", 1} + self.assertIn(set(range(3)), myset) + self.assertNotIn(set(range(1)), myset) + myset.discard(set(range(3))) + self.assertEqual(myset, {"abc", 1}) + self.assertRaises(KeyError, myset.remove, set(range(1))) + self.assertRaises(KeyError, myset.remove, set(range(3))) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_unhashable_element(self): + myset = {'a'} + elem = [1, 2, 3] + + def check_unhashable_element(): + msg = "cannot use 'list' as a set element (unhashable type: 'list')" + return self.assertRaisesRegex(TypeError, re.escape(msg)) + + with check_unhashable_element(): + elem in myset + with check_unhashable_element(): + myset.add(elem) + with check_unhashable_element(): + myset.discard(elem) + + # Only TypeError exception is overriden, + # other exceptions are left unchanged. + class HashError: + def __hash__(self): + raise KeyError('error') + + elem2 = HashError() + with self.assertRaises(KeyError): + elem2 in myset + with self.assertRaises(KeyError): + myset.add(elem2) + with self.assertRaises(KeyError): + myset.discard(elem2) + + def test_hash_collision_remove_add(self): + self.maxDiff = None + # There should be enough space, so all elements with unique hash + # will be placed in corresponding cells without collision. + n = 64 + elems = [CustomHash(h) for h in range(n)] + # Elements with hash collision. + a = CustomHash(n) + b = CustomHash(n) + elems += [a, b] + s = self.thetype(elems) + self.assertEqual(len(s), len(elems), s) + s.remove(a) + # "a" has been replaced with a dummy. + del elems[n] + self.assertEqual(len(s), len(elems), s) + self.assertEqual(s, set(elems)) + s.add(b) + # "b" should not replace the dummy. + self.assertEqual(len(s), len(elems), s) + self.assertEqual(s, set(elems)) + class SetSubclass(set): pass @@ -650,15 +717,37 @@ class TestSetSubclass(TestSet): thetype = SetSubclass basetype = set -class SetSubclassWithKeywordArgs(set): - def __init__(self, iterable=[], newarg=None): - set.__init__(self, iterable) - -class TestSetSubclassWithKeywordArgs(TestSet): - def test_keywords_in_subclass(self): - 'SF bug #1486663 -- this used to erroneously raise a TypeError' - SetSubclassWithKeywordArgs(newarg=1) + class subclass(set): + pass + u = subclass([1, 2]) + self.assertIs(type(u), subclass) + self.assertEqual(set(u), {1, 2}) + with self.assertRaises(TypeError): + subclass(sequence=()) + + class subclass_with_init(set): + def __init__(self, arg, newarg=None): + super().__init__(arg) + self.newarg = newarg + u = subclass_with_init([1, 2], newarg=3) + self.assertIs(type(u), subclass_with_init) + self.assertEqual(set(u), {1, 2}) + self.assertEqual(u.newarg, 3) + + class subclass_with_new(set): + def __new__(cls, arg, newarg=None): + self = super().__new__(cls, arg) + self.newarg = newarg + return self + u = subclass_with_new([1, 2]) + self.assertIs(type(u), subclass_with_new) + self.assertEqual(set(u), {1, 2}) + self.assertIsNone(u.newarg) + # disallow kwargs in __new__ only (https://bugs.python.org/issue43413#msg402000) + with self.assertRaises(TypeError): + subclass_with_new([1, 2], newarg=3) + class TestFrozenSet(TestJointOps, unittest.TestCase): thetype = frozenset @@ -740,6 +829,34 @@ class TestFrozenSetSubclass(TestFrozenSet): thetype = FrozenSetSubclass basetype = frozenset + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_keywords_in_subclass(self): + class subclass(frozenset): + pass + u = subclass([1, 2]) + self.assertIs(type(u), subclass) + self.assertEqual(set(u), {1, 2}) + with self.assertRaises(TypeError): + subclass(sequence=()) + + class subclass_with_init(frozenset): + def __init__(self, arg, newarg=None): + self.newarg = newarg + u = subclass_with_init([1, 2], newarg=3) + self.assertIs(type(u), subclass_with_init) + self.assertEqual(set(u), {1, 2}) + self.assertEqual(u.newarg, 3) + + class subclass_with_new(frozenset): + def __new__(cls, arg, newarg=None): + self = super().__new__(cls, arg) + self.newarg = newarg + return self + u = subclass_with_new([1, 2], newarg=3) + self.assertIs(type(u), subclass_with_new) + self.assertEqual(set(u), {1, 2}) + self.assertEqual(u.newarg, 3) + def test_constructor_identity(self): s = self.thetype(range(3)) t = self.thetype(s) @@ -765,6 +882,25 @@ class TestFrozenSetSubclass(TestFrozenSet): # All empty frozenset subclass instances should have different ids self.assertEqual(len(set(map(id, efs))), len(efs)) + +class SetSubclassWithSlots(set): + __slots__ = ('x', 'y', '__dict__') + +class TestSetSubclassWithSlots(unittest.TestCase): + thetype = SetSubclassWithSlots + setUp = TestJointOps.setUp + test_pickling = TestJointOps.test_pickling + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_pickling(self): + return super().test_pickling() + +class FrozenSetSubclassWithSlots(frozenset): + __slots__ = ('x', 'y', '__dict__') + +class TestFrozenSetSubclassWithSlots(TestSetSubclassWithSlots): + thetype = FrozenSetSubclassWithSlots + # Tests taken from test_sets.py ============================================= empty_set = set() @@ -779,8 +915,8 @@ class TestBasicOps: def check_repr_against_values(self): text = repr(self.set) - self.assertTrue(text.startswith('{')) - self.assertTrue(text.endswith('}')) + self.assertStartsWith(text, '{') + self.assertEndsWith(text, '}') result = text[1:-1].split(', ') result.sort() @@ -961,8 +1097,7 @@ class TestBasicOpsBytes(TestBasicOps, unittest.TestCase): class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase): def setUp(self): - self._warning_filters = warnings_helper.check_warnings() - self._warning_filters.__enter__() + self.enterContext(warnings_helper.check_warnings()) warnings.simplefilter('ignore', BytesWarning) self.case = "string and bytes set" self.values = ["a", "b", b"a", b"b"] @@ -970,9 +1105,6 @@ class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase): self.dup = set(self.values) self.length = 4 - def tearDown(self): - self._warning_filters.__exit__(None, None, None) - def test_repr(self): self.check_repr_against_values() @@ -1760,6 +1892,7 @@ class TestWeirdBugs(unittest.TestCase): list(si) def test_merge_and_mutate(self): + # gh-141805 class X: def __hash__(self): return hash(0) @@ -1772,6 +1905,33 @@ class TestWeirdBugs(unittest.TestCase): s = {0} s.update(other) + def test_hash_collision_concurrent_add(self): + class X: + def __hash__(self): + return 0 + class Y: + flag = False + def __hash__(self): + return 0 + def __eq__(self, other): + if not self.flag: + self.flag = True + s.add(X()) + return self is other + + a = X() + s = set() + s.add(a) + s.add(X()) + s.remove(a) + # Now the set contains a dummy entry followed by an entry + # for an object with hash 0. + s.add(Y()) + # The following operations should not crash. + repr(s) + list(s) + set() | s + class TestOperationsMutating: """Regression test for bpo-46615""" diff --git a/crates/vm/src/builtins/set.rs b/crates/vm/src/builtins/set.rs index 7e2c10ac8..11f952df7 100644 --- a/crates/vm/src/builtins/set.rs +++ b/crates/vm/src/builtins/set.rs @@ -671,8 +671,7 @@ impl PySet { #[pymethod] pub fn add(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.inner.add(item, vm)?; - Ok(()) + self.inner.add(item, vm) } #[pymethod] @@ -682,8 +681,7 @@ impl PySet { #[pymethod] fn discard(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.inner.discard(&item, vm)?; - Ok(()) + self.inner.discard(&item, vm).map(|_| ()) } #[pymethod] @@ -729,8 +727,7 @@ impl PySet { #[pymethod] fn difference_update(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult<()> { - self.inner.difference_update(others.into_iter(), vm)?; - Ok(()) + self.inner.difference_update(others.into_iter(), vm) } fn __isub__(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { @@ -750,8 +747,7 @@ impl PySet { vm: &VirtualMachine, ) -> PyResult<()> { self.inner - .symmetric_difference_update(others.into_iter(), vm)?; - Ok(()) + .symmetric_difference_update(others.into_iter(), vm) } fn __ixor__(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { @@ -950,11 +946,7 @@ impl Representable for PySet { return Ok(Wtf8Buf::from(format!("{class_name}()"))); } if let Some(_guard) = ReprGuard::enter(vm, zelf.as_object()) { - let name = if class_name != "set" { - Some(class_name) - } else { - None - }; + let name = (class_name != "set").then_some(class_name); zelf.inner.repr(name, vm) } else { Ok(Wtf8Buf::from(format!("{class_name}(...)")))