Support Args in set/frozenset methods

This commit is contained in:
Aviv Palivoda
2020-02-16 19:09:30 +02:00
parent 176fa581b6
commit 34d8c44952
2 changed files with 97 additions and 63 deletions

View File

@@ -289,6 +289,10 @@ impl<T> IntoIterator for KwArgs<T> {
pub struct Args<T = PyObjectRef>(Vec<T>);
impl<T> Args<T> {
pub fn new(args: Vec<T>) -> Self {
Args(args)
}
pub fn into_vec(self) -> Vec<T> {
self.0
}

View File

@@ -8,7 +8,7 @@ use std::fmt;
use super::objlist::PyListIterator;
use super::objtype::{self, PyClassRef};
use crate::dictdatatype;
use crate::function::OptionalArg;
use crate::function::{Args, OptionalArg};
use crate::pyobject::{
PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
TypeProtocol,
@@ -280,39 +280,51 @@ impl PySetInner {
}
}
fn update(&mut self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
for item in iterable.iter(vm)? {
self.add(&item?, vm)?;
}
Ok(())
}
fn intersection_update(&mut self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
let temp_inner = self.copy();
self.clear();
for item in iterable.iter(vm)? {
let obj = item?;
if temp_inner.contains(&obj, vm)? {
self.add(&obj, vm)?;
fn update(&mut self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<()> {
for iterable in others {
for item in iterable.iter(vm)? {
self.add(&item?, vm)?;
}
}
Ok(())
}
fn difference_update(&mut self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
for item in iterable.iter(vm)? {
self.content.delete_if_exists(vm, &item?)?;
fn intersection_update(
&mut self,
others: Args<PyIterable>,
vm: &VirtualMachine,
) -> PyResult<()> {
let temp_inner = self.copy();
self.clear();
for iterable in others {
for item in iterable.iter(vm)? {
let obj = item?;
if temp_inner.contains(&obj, vm)? {
self.add(&obj, vm)?;
}
}
}
Ok(())
}
fn difference_update(&mut self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<()> {
for iterable in others {
for item in iterable.iter(vm)? {
self.content.delete_if_exists(vm, &item?)?;
}
}
Ok(())
}
fn symmetric_difference_update(
&mut self,
iterable: PyIterable,
others: Args<PyIterable>,
vm: &VirtualMachine,
) -> PyResult<()> {
for item in iterable.iter(vm)? {
self.content.delete_or_insert(vm, &item?, ())?;
for iterable in others {
for item in iterable.iter(vm)? {
self.content.delete_or_insert(vm, &item?, ())?;
}
}
Ok(())
}
@@ -328,6 +340,18 @@ macro_rules! try_set_cmp {
};
}
macro_rules! multi_args_set {
($vm:expr, $others:expr, $zelf:expr, $op:tt) => {{
let mut res = $zelf.inner.borrow().copy();
for other in $others {
res = res.$op(other, $vm)?
}
Ok(Self {
inner: RefCell::new(res),
})
}};
}
#[pyimpl(flags(BASETYPE))]
impl PySet {
#[pyslot]
@@ -395,31 +419,27 @@ impl PySet {
}
#[pymethod]
fn union(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
Ok(Self {
inner: RefCell::new(self.inner.borrow().union(other, vm)?),
})
fn union(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_set!(vm, others, self, union)
}
#[pymethod]
fn intersection(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
Ok(Self {
inner: RefCell::new(self.inner.borrow().intersection(other, vm)?),
})
fn intersection(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_set!(vm, others, self, intersection)
}
#[pymethod]
fn difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
Ok(Self {
inner: RefCell::new(self.inner.borrow().difference(other, vm)?),
})
fn difference(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_set!(vm, others, self, difference)
}
#[pymethod]
fn symmetric_difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
Ok(Self {
inner: RefCell::new(self.inner.borrow().symmetric_difference(other, vm)?),
})
fn symmetric_difference(
&self,
others: Args<PyIterable>,
vm: &VirtualMachine,
) -> PyResult<Self> {
multi_args_set!(vm, others, self, symmetric_difference)
}
#[pymethod]
@@ -529,14 +549,14 @@ impl PySet {
}
#[pymethod]
fn update(&self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult {
self.inner.borrow_mut().update(iterable, vm)?;
fn update(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult {
self.inner.borrow_mut().update(others, vm)?;
Ok(vm.get_none())
}
#[pymethod]
fn intersection_update(&self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult {
self.inner.borrow_mut().intersection_update(iterable, vm)?;
fn intersection_update(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult {
self.inner.borrow_mut().intersection_update(others, vm)?;
Ok(vm.get_none())
}
@@ -549,8 +569,8 @@ impl PySet {
}
#[pymethod]
fn difference_update(&self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult {
self.inner.borrow_mut().difference_update(iterable, vm)?;
fn difference_update(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult {
self.inner.borrow_mut().difference_update(others, vm)?;
Ok(vm.get_none())
}
@@ -563,10 +583,14 @@ impl PySet {
}
#[pymethod]
fn symmetric_difference_update(&self, iterable: PyIterable, vm: &VirtualMachine) -> PyResult {
fn symmetric_difference_update(
&self,
others: Args<PyIterable>,
vm: &VirtualMachine,
) -> PyResult {
self.inner
.borrow_mut()
.symmetric_difference_update(iterable, vm)?;
.symmetric_difference_update(others, vm)?;
Ok(vm.get_none())
}
@@ -584,6 +608,16 @@ impl PySet {
}
}
macro_rules! multi_args_frozenset {
($vm:expr, $others:expr, $zelf:expr, $op:tt) => {{
let mut res = $zelf.inner.copy();
for other in $others {
res = res.$op(other, $vm)?
}
Ok(Self { inner: res })
}};
}
#[pyimpl(flags(BASETYPE))]
impl PyFrozenSet {
#[pyslot]
@@ -651,31 +685,27 @@ impl PyFrozenSet {
}
#[pymethod]
fn union(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
Ok(Self {
inner: self.inner.union(other, vm)?,
})
fn union(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_frozenset!(vm, others, self, union)
}
#[pymethod]
fn intersection(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
Ok(Self {
inner: self.inner.intersection(other, vm)?,
})
fn intersection(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_frozenset!(vm, others, self, intersection)
}
#[pymethod]
fn difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
Ok(Self {
inner: self.inner.difference(other, vm)?,
})
fn difference(&self, others: Args<PyIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_frozenset!(vm, others, self, difference)
}
#[pymethod]
fn symmetric_difference(&self, other: PyIterable, vm: &VirtualMachine) -> PyResult<Self> {
Ok(Self {
inner: self.inner.symmetric_difference(other, vm)?,
})
fn symmetric_difference(
&self,
others: Args<PyIterable>,
vm: &VirtualMachine,
) -> PyResult<Self> {
multi_args_frozenset!(vm, others, self, symmetric_difference)
}
#[pymethod]
@@ -753,7 +783,7 @@ impl PyFrozenSet {
}
struct SetIterable {
iterable: PyIterable,
iterable: Args<PyIterable>,
}
impl TryFromObject for SetIterable {
@@ -762,7 +792,7 @@ impl TryFromObject for SetIterable {
|| objtype::issubclass(&obj.class(), &vm.ctx.frozenset_type())
{
Ok(SetIterable {
iterable: PyIterable::try_from_object(vm, obj)?,
iterable: Args::new(vec![PyIterable::try_from_object(vm, obj)?]),
})
} else {
Err(vm.new_type_error(format!(