Fix set in-place operators with self argument (#6661)

Co-authored-by: Hugh <HueCodes@users.noreply.github.com>
This commit is contained in:
HueCodes
2026-01-09 04:31:21 -08:00
committed by GitHub
parent b38cdaa30e
commit 83a0deac25
2 changed files with 36 additions and 6 deletions

View File

@@ -24,6 +24,7 @@ use crate::{
vm::VirtualMachine,
};
use alloc::fmt;
use core::borrow::Borrow;
use core::ops::Deref;
use rustpython_common::{
atomic::{Ordering, PyAtomic, Radium},
@@ -719,8 +720,10 @@ impl PySet {
}
fn __iand__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
zelf.inner
.intersection_update(core::iter::once(set.into_iterable(vm)?), vm)?;
if !set.is(zelf.as_object()) {
zelf.inner
.intersection_update(core::iter::once(set.into_iterable(vm)?), vm)?;
}
Ok(zelf)
}
@@ -731,8 +734,12 @@ impl PySet {
}
fn __isub__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
zelf.inner
.difference_update(set.into_iterable_iter(vm)?, vm)?;
if set.is(zelf.as_object()) {
zelf.inner.clear();
} else {
zelf.inner
.difference_update(set.into_iterable_iter(vm)?, vm)?;
}
Ok(zelf)
}
@@ -748,8 +755,12 @@ impl PySet {
}
fn __ixor__(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
zelf.inner
.symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
if set.is(zelf.as_object()) {
zelf.inner.clear();
} else {
zelf.inner
.symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
}
Ok(zelf)
}
@@ -1297,6 +1308,13 @@ struct AnySet {
object: PyObjectRef,
}
impl Borrow<PyObject> for AnySet {
#[inline(always)]
fn borrow(&self) -> &PyObject {
&self.object
}
}
impl AnySet {
/// Check if object is a set or frozenset (including subclasses)
/// Equivalent to CPython's PyAnySet_Check

View File

@@ -200,6 +200,18 @@ with assert_raises(TypeError):
with assert_raises(TypeError):
a &= [1, 2, 3]
a = set([1, 2, 3])
a &= a
assert a == set([1, 2, 3])
a = set([1, 2, 3])
a -= a
assert a == set()
a = set([1, 2, 3])
a ^= a
assert a == set()
a = set([1, 2, 3])
a.difference_update([3, 4, 5])
assert a == set([1, 2])