forked from Rust-related/RustPython
Support Args in set/frozenset methods
This commit is contained in:
@@ -289,6 +289,10 @@ impl<T> IntoIterator for KwArgs<T> {
|
||||
pub struct Args<T = PyObjectRef>(Vec<T>);
|
||||
|
||||
impl<T> Args<T> {
|
||||
pub fn new(args: Vec<T>) -> Self {
|
||||
Args(args)
|
||||
}
|
||||
|
||||
pub fn into_vec(self) -> Vec<T> {
|
||||
self.0
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::fmt;
|
||||
use super::objlist::PyListIterator;
|
||||
use super::objtype::{self, PyClassRef};
|
||||
use crate::dictdatatype;
|
||||
use crate::function::OptionalArg;
|
||||
use crate::function::{Args, OptionalArg};
|
||||
use crate::pyobject::{
|
||||
PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
|
||||
TypeProtocol,
|
||||
@@ -280,39 +280,51 @@ impl PySetInner {
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
|
||||
for item in iterable.iter(vm)? {
|
||||
self.add(&item?, vm)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn intersection_update(&mut self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
|
||||
let temp_inner = self.copy();
|
||||
self.clear();
|
||||
for item in iterable.iter(vm)? {
|
||||
let obj = item?;
|
||||
if temp_inner.contains(&obj, vm)? {
|
||||
self.add(&obj, vm)?;
|
||||
fn update(&mut self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<()> {
|
||||
for iterable in others {
|
||||
for item in iterable.iter(vm)? {
|
||||
self.add(&item?, vm)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn difference_update(&mut self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
|
||||
for item in iterable.iter(vm)? {
|
||||
self.content.delete_if_exists(vm, &item?)?;
|
||||
fn intersection_update(
|
||||
&mut self,
|
||||
others: Args<PyIterable>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<()> {
|
||||
let temp_inner = self.copy();
|
||||
self.clear();
|
||||
for iterable in others {
|
||||
for item in iterable.iter(vm)? {
|
||||
let obj = item?;
|
||||
if temp_inner.contains(&obj, vm)? {
|
||||
self.add(&obj, vm)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn difference_update(&mut self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<()> {
|
||||
for iterable in others {
|
||||
for item in iterable.iter(vm)? {
|
||||
self.content.delete_if_exists(vm, &item?)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn symmetric_difference_update(
|
||||
&mut self,
|
||||
iterable: PyIterable,
|
||||
others: Args<PyIterable>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<()> {
|
||||
for item in iterable.iter(vm)? {
|
||||
self.content.delete_or_insert(vm, &item?, ())?;
|
||||
for iterable in others {
|
||||
for item in iterable.iter(vm)? {
|
||||
self.content.delete_or_insert(vm, &item?, ())?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -328,6 +340,18 @@ macro_rules! try_set_cmp {
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! multi_args_set {
|
||||
($vm:expr, $others:expr, $zelf:expr, $op:tt) => {{
|
||||
let mut res = $zelf.inner.borrow().copy();
|
||||
for other in $others {
|
||||
res = res.$op(other, $vm)?
|
||||
}
|
||||
Ok(Self {
|
||||
inner: RefCell::new(res),
|
||||
})
|
||||
}};
|
||||
}
|
||||
|
||||
#[pyimpl(flags(BASETYPE))]
|
||||
impl PySet {
|
||||
#[pyslot]
|
||||
@@ -395,31 +419,27 @@ impl PySet {
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn union(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: RefCell::new(self.inner.borrow().union(other, vm)?),
|
||||
})
|
||||
fn union(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
multi_args_set!(vm, others, self, union)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn intersection(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: RefCell::new(self.inner.borrow().intersection(other, vm)?),
|
||||
})
|
||||
fn intersection(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
multi_args_set!(vm, others, self, intersection)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: RefCell::new(self.inner.borrow().difference(other, vm)?),
|
||||
})
|
||||
fn difference(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
multi_args_set!(vm, others, self, difference)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn symmetric_difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: RefCell::new(self.inner.borrow().symmetric_difference(other, vm)?),
|
||||
})
|
||||
fn symmetric_difference(
|
||||
&self,
|
||||
others: Args<PyIterable>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<Self> {
|
||||
multi_args_set!(vm, others, self, symmetric_difference)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
@@ -529,14 +549,14 @@ impl PySet {
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn update(&self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult {
|
||||
self.inner.borrow_mut().update(iterable, vm)?;
|
||||
fn update(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult {
|
||||
self.inner.borrow_mut().update(others, vm)?;
|
||||
Ok(vm.get_none())
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn intersection_update(&self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult {
|
||||
self.inner.borrow_mut().intersection_update(iterable, vm)?;
|
||||
fn intersection_update(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult {
|
||||
self.inner.borrow_mut().intersection_update(others, vm)?;
|
||||
Ok(vm.get_none())
|
||||
}
|
||||
|
||||
@@ -549,8 +569,8 @@ impl PySet {
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn difference_update(&self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult {
|
||||
self.inner.borrow_mut().difference_update(iterable, vm)?;
|
||||
fn difference_update(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult {
|
||||
self.inner.borrow_mut().difference_update(others, vm)?;
|
||||
Ok(vm.get_none())
|
||||
}
|
||||
|
||||
@@ -563,10 +583,14 @@ impl PySet {
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn symmetric_difference_update(&self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult {
|
||||
fn symmetric_difference_update(
|
||||
&self,
|
||||
others: Args<PyIterable>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult {
|
||||
self.inner
|
||||
.borrow_mut()
|
||||
.symmetric_difference_update(iterable, vm)?;
|
||||
.symmetric_difference_update(others, vm)?;
|
||||
Ok(vm.get_none())
|
||||
}
|
||||
|
||||
@@ -584,6 +608,16 @@ impl PySet {
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! multi_args_frozenset {
|
||||
($vm:expr, $others:expr, $zelf:expr, $op:tt) => {{
|
||||
let mut res = $zelf.inner.copy();
|
||||
for other in $others {
|
||||
res = res.$op(other, $vm)?
|
||||
}
|
||||
Ok(Self { inner: res })
|
||||
}};
|
||||
}
|
||||
|
||||
#[pyimpl(flags(BASETYPE))]
|
||||
impl PyFrozenSet {
|
||||
#[pyslot]
|
||||
@@ -651,31 +685,27 @@ impl PyFrozenSet {
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn union(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: self.inner.union(other, vm)?,
|
||||
})
|
||||
fn union(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
multi_args_frozenset!(vm, others, self, union)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn intersection(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: self.inner.intersection(other, vm)?,
|
||||
})
|
||||
fn intersection(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
multi_args_frozenset!(vm, others, self, intersection)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: self.inner.difference(other, vm)?,
|
||||
})
|
||||
fn difference(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
multi_args_frozenset!(vm, others, self, difference)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn symmetric_difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: self.inner.symmetric_difference(other, vm)?,
|
||||
})
|
||||
fn symmetric_difference(
|
||||
&self,
|
||||
others: Args<PyIterable>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<Self> {
|
||||
multi_args_frozenset!(vm, others, self, symmetric_difference)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
@@ -753,7 +783,7 @@ impl PyFrozenSet {
|
||||
}
|
||||
|
||||
struct SetIterable {
|
||||
iterable: PyIterable,
|
||||
iterable: Args<PyIterable>,
|
||||
}
|
||||
|
||||
impl TryFromObject for SetIterable {
|
||||
@@ -762,7 +792,7 @@ impl TryFromObject for SetIterable {
|
||||
|| objtype::issubclass(&obj.class(), &vm.ctx.frozenset_type())
|
||||
{
|
||||
Ok(SetIterable {
|
||||
iterable: PyIterable::try_from_object(vm, obj)?,
|
||||
iterable: Args::new(vec![PyIterable::try_from_object(vm, obj)?]),
|
||||
})
|
||||
} else {
|
||||
Err(vm.new_type_error(format!(
|
||||
|
||||
Reference in New Issue
Block a user