SetIterable -> AnySet

This commit is contained in:
Jeong Yunwon
2022-07-16 01:55:49 +09:00
parent 46073a2a65
commit 4230efae2e

View File

@@ -35,9 +35,39 @@ pub struct PySet {
}
impl PySet {
pub fn new_ref(ctx: &Context) -> PyRef<Self> {
// Initialized empty, as calling __hash__ is required for adding each object to the set
// which requires a VM context - this is done in the set code itself.
PyRef::new_ref(Self::default(), ctx.types.set_type.to_owned(), None)
}
pub fn elements(&self) -> Vec<PyObjectRef> {
self.inner.elements()
}
fn fold_op(
&self,
others: impl std::iter::Iterator<Item = ArgIterable>,
op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
vm: &VirtualMachine,
) -> PyResult<Self> {
Ok(Self {
inner: self.inner.fold_op(others, op, vm)?,
})
}
fn op(
&self,
other: AnySet,
op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
vm: &VirtualMachine,
) -> PyResult<Self> {
Ok(Self {
inner: self
.inner
.fold_op(std::iter::once(other.into_iterable(vm)?), op, vm)?,
})
}
}
/// frozenset() -> empty frozenset object
@@ -51,9 +81,46 @@ pub struct PyFrozenSet {
}
impl PyFrozenSet {
// Also used by ssl.rs windows.
pub fn from_iter(
vm: &VirtualMachine,
it: impl IntoIterator<Item = PyObjectRef>,
) -> PyResult<Self> {
let inner = PySetInner::default();
for elem in it {
inner.add(elem, vm)?;
}
// FIXME: empty set check
Ok(Self { inner })
}
pub fn elements(&self) -> Vec<PyObjectRef> {
self.inner.elements()
}
fn fold_op(
&self,
others: impl std::iter::Iterator<Item = ArgIterable>,
op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
vm: &VirtualMachine,
) -> PyResult<Self> {
Ok(Self {
inner: self.inner.fold_op(others, op, vm)?,
})
}
fn op(
&self,
other: AnySet,
op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult<PySetInner>,
vm: &VirtualMachine,
) -> PyResult<Self> {
Ok(Self {
inner: self
.inner
.fold_op(std::iter::once(other.into_iterable(vm)?), op, vm)?,
})
}
}
impl fmt::Debug for PySet {
@@ -99,6 +166,19 @@ impl PySetInner {
Ok(set)
}
fn fold_op<O>(
&self,
others: impl std::iter::Iterator<Item = O>,
op: fn(&Self, O, &VirtualMachine) -> PyResult<Self>,
vm: &VirtualMachine,
) -> PyResult<Self> {
let mut res = self.copy();
for other in others {
res = op(&res, other, vm)?;
}
Ok(res)
}
fn len(&self) -> usize {
self.content.len()
}
@@ -259,7 +339,11 @@ impl PySetInner {
}
}
fn update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> {
fn update(
&self,
others: impl std::iter::Iterator<Item = ArgIterable>,
vm: &VirtualMachine,
) -> PyResult<()> {
for iterable in others {
for item in iterable.iter(vm)? {
self.add(item?, vm)?;
@@ -270,7 +354,7 @@ impl PySetInner {
fn intersection_update(
&self,
others: PosArgs<ArgIterable>,
others: impl std::iter::Iterator<Item = ArgIterable>,
vm: &VirtualMachine,
) -> PyResult<()> {
let mut temp_inner = self.copy();
@@ -287,7 +371,11 @@ impl PySetInner {
Ok(())
}
fn difference_update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> {
fn difference_update(
&self,
others: impl std::iter::Iterator<Item = ArgIterable>,
vm: &VirtualMachine,
) -> PyResult<()> {
for iterable in others {
for item in iterable.iter(vm)? {
self.content.delete_if_exists(vm, &*item?)?;
@@ -298,7 +386,7 @@ impl PySetInner {
fn symmetric_difference_update(
&self,
others: PosArgs<ArgIterable>,
others: impl std::iter::Iterator<Item = ArgIterable>,
vm: &VirtualMachine,
) -> PyResult<()> {
for iterable in others {
@@ -373,24 +461,6 @@ fn reduce_set(
))
}
macro_rules! multi_args_set {
($vm:expr, $others:expr, $zelf:expr, $op:tt) => {{
let mut res = $zelf.inner.copy();
for other in $others {
res = res.$op(other, $vm)?
}
Ok(Self { inner: res })
}};
}
impl PySet {
pub fn new_ref(ctx: &Context) -> PyRef<Self> {
// Initialized empty, as calling __hash__ is required for adding each object to the set
// which requires a VM context - this is done in the set code itself.
PyRef::new_ref(Self::default(), ctx.types.set_type.to_owned(), None)
}
}
#[pyimpl(
with(Constructor, Initializer, AsSequence, Hashable, Comparable, Iterable),
flags(BASETYPE)
@@ -420,17 +490,17 @@ impl PySet {
#[pymethod]
fn union(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_set!(vm, others, self, union)
self.fold_op(others.into_iter(), PySetInner::union, vm)
}
#[pymethod]
fn intersection(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_set!(vm, others, self, intersection)
self.fold_op(others.into_iter(), PySetInner::intersection, vm)
}
#[pymethod]
fn difference(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_set!(vm, others, self, difference)
self.fold_op(others.into_iter(), PySetInner::difference, vm)
}
#[pymethod]
@@ -439,7 +509,7 @@ impl PySet {
others: PosArgs<ArgIterable>,
vm: &VirtualMachine,
) -> PyResult<Self> {
multi_args_set!(vm, others, self, symmetric_difference)
self.fold_op(others.into_iter(), PySetInner::symmetric_difference, vm)
}
#[pymethod]
@@ -460,10 +530,12 @@ impl PySet {
#[pymethod(name = "__ror__")]
#[pymethod(magic)]
fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
if let Ok(set_iter) = SetIterable::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(
self.union(set_iter.iterable, vm)?,
))
if let Ok(other) = AnySet::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(self.op(
other,
PySetInner::union,
vm,
)?))
} else {
Ok(PyArithmeticValue::NotImplemented)
}
@@ -472,10 +544,12 @@ impl PySet {
#[pymethod(name = "__rand__")]
#[pymethod(magic)]
fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
if let Ok(set_iter) = SetIterable::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(
self.intersection(set_iter.iterable, vm)?,
))
if let Ok(other) = AnySet::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(self.op(
other,
PySetInner::intersection,
vm,
)?))
} else {
Ok(PyArithmeticValue::NotImplemented)
}
@@ -483,10 +557,12 @@ impl PySet {
#[pymethod(magic)]
fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
if let Ok(set_iter) = SetIterable::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(
self.difference(set_iter.iterable, vm)?,
))
if let Ok(other) = AnySet::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(self.op(
other,
PySetInner::difference,
vm,
)?))
} else {
Ok(PyArithmeticValue::NotImplemented)
}
@@ -500,10 +576,12 @@ impl PySet {
#[pymethod(name = "__rxor__")]
#[pymethod(magic)]
fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
if let Ok(set_iter) = SetIterable::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(
self.symmetric_difference(set_iter.iterable, vm)?,
))
if let Ok(other) = AnySet::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(self.op(
other,
PySetInner::symmetric_difference,
vm,
)?))
} else {
Ok(PyArithmeticValue::NotImplemented)
}
@@ -557,14 +635,14 @@ impl PySet {
}
#[pymethod(magic)]
fn ior(zelf: PyRef<Self>, iterable: SetIterable, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
zelf.inner.update(iterable.iterable, vm)?;
fn ior(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
zelf.inner.update(set.into_iterable_iter(vm)?, vm)?;
Ok(zelf)
}
#[pymethod]
fn update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> {
self.inner.update(others, vm)?;
self.inner.update(others.into_iter(), vm)?;
Ok(())
}
@@ -574,33 +652,27 @@ impl PySet {
others: PosArgs<ArgIterable>,
vm: &VirtualMachine,
) -> PyResult<()> {
self.inner.intersection_update(others, vm)?;
self.inner.intersection_update(others.into_iter(), vm)?;
Ok(())
}
#[pymethod(magic)]
fn iand(
zelf: PyRef<Self>,
iterable: SetIterable,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
zelf.inner.intersection_update(iterable.iterable, vm)?;
fn iand(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
zelf.inner
.intersection_update(std::iter::once(set.into_iterable(vm)?), vm)?;
Ok(zelf)
}
#[pymethod]
fn difference_update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> {
self.inner.difference_update(others, vm)?;
self.inner.difference_update(others.into_iter(), vm)?;
Ok(())
}
#[pymethod(magic)]
fn isub(
zelf: PyRef<Self>,
iterable: SetIterable,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
zelf.inner.difference_update(iterable.iterable, vm)?;
fn isub(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
zelf.inner
.difference_update(set.into_iterable_iter(vm)?, vm)?;
Ok(zelf)
}
@@ -610,18 +682,15 @@ impl PySet {
others: PosArgs<ArgIterable>,
vm: &VirtualMachine,
) -> PyResult<()> {
self.inner.symmetric_difference_update(others, vm)?;
self.inner
.symmetric_difference_update(others.into_iter(), vm)?;
Ok(())
}
#[pymethod(magic)]
fn ixor(
zelf: PyRef<Self>,
iterable: SetIterable,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
fn ixor(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
zelf.inner
.symmetric_difference_update(iterable.iterable, vm)?;
.symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
Ok(zelf)
}
@@ -690,16 +759,6 @@ impl Iterable for PySet {
}
}
macro_rules! multi_args_frozenset {
($vm:expr, $others:expr, $zelf:expr, $op:tt) => {{
let mut res = $zelf.inner.copy();
for other in $others {
res = res.$op(other, $vm)?
}
Ok(Self { inner: res })
}};
}
impl Constructor for PyFrozenSet {
type Args = OptionalArg<PyObjectRef>;
@@ -733,19 +792,6 @@ impl Constructor for PyFrozenSet {
with(Constructor, AsSequence, Hashable, Comparable, Iterable)
)]
impl PyFrozenSet {
// Also used by ssl.rs windows.
pub fn from_iter(
vm: &VirtualMachine,
it: impl IntoIterator<Item = PyObjectRef>,
) -> PyResult<Self> {
let inner = PySetInner::default();
for elem in it {
inner.add(elem, vm)?;
}
// FIXME: empty set check
Ok(Self { inner })
}
#[pymethod(magic)]
fn len(&self) -> usize {
self.inner.len()
@@ -775,17 +821,17 @@ impl PyFrozenSet {
#[pymethod]
fn union(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_frozenset!(vm, others, self, union)
self.fold_op(others.into_iter(), PySetInner::union, vm)
}
#[pymethod]
fn intersection(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_frozenset!(vm, others, self, intersection)
self.fold_op(others.into_iter(), PySetInner::intersection, vm)
}
#[pymethod]
fn difference(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<Self> {
multi_args_frozenset!(vm, others, self, difference)
self.fold_op(others.into_iter(), PySetInner::difference, vm)
}
#[pymethod]
@@ -794,7 +840,7 @@ impl PyFrozenSet {
others: PosArgs<ArgIterable>,
vm: &VirtualMachine,
) -> PyResult<Self> {
multi_args_frozenset!(vm, others, self, symmetric_difference)
self.fold_op(others.into_iter(), PySetInner::symmetric_difference, vm)
}
#[pymethod]
@@ -815,10 +861,12 @@ impl PyFrozenSet {
#[pymethod(name = "__ror__")]
#[pymethod(magic)]
fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
if let Ok(set_iter) = SetIterable::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(
self.union(set_iter.iterable, vm)?,
))
if let Ok(set) = AnySet::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(self.op(
set,
PySetInner::union,
vm,
)?))
} else {
Ok(PyArithmeticValue::NotImplemented)
}
@@ -827,10 +875,12 @@ impl PyFrozenSet {
#[pymethod(name = "__rand__")]
#[pymethod(magic)]
fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
if let Ok(set_iter) = SetIterable::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(
self.intersection(set_iter.iterable, vm)?,
))
if let Ok(other) = AnySet::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(self.op(
other,
PySetInner::intersection,
vm,
)?))
} else {
Ok(PyArithmeticValue::NotImplemented)
}
@@ -838,10 +888,12 @@ impl PyFrozenSet {
#[pymethod(magic)]
fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
if let Ok(set_iter) = SetIterable::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(
self.difference(set_iter.iterable, vm)?,
))
if let Ok(other) = AnySet::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(self.op(
other,
PySetInner::difference,
vm,
)?))
} else {
Ok(PyArithmeticValue::NotImplemented)
}
@@ -855,10 +907,12 @@ impl PyFrozenSet {
#[pymethod(name = "__rxor__")]
#[pymethod(magic)]
fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyArithmeticValue<Self>> {
if let Ok(set_iter) = SetIterable::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(
self.symmetric_difference(set_iter.iterable, vm)?,
))
if let Ok(other) = AnySet::try_from_object(vm, other) {
Ok(PyArithmeticValue::Implemented(self.op(
other,
PySetInner::symmetric_difference,
vm,
)?))
} else {
Ok(PyArithmeticValue::NotImplemented)
}
@@ -927,11 +981,24 @@ impl Iterable for PyFrozenSet {
}
}
struct SetIterable {
iterable: PosArgs<ArgIterable>,
struct AnySet {
object: PyObjectRef,
}
impl TryFromObject for SetIterable {
impl AnySet {
fn into_iterable(self, vm: &VirtualMachine) -> PyResult<ArgIterable> {
self.object.try_into_value(vm)
}
fn into_iterable_iter(
self,
vm: &VirtualMachine,
) -> PyResult<impl std::iter::Iterator<Item = ArgIterable>> {
Ok(std::iter::once(self.into_iterable(vm)?))
}
}
impl TryFromObject for AnySet {
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
let class = obj.class();
if class.fast_issubclass(vm.ctx.types.set_type)
@@ -939,9 +1006,7 @@ impl TryFromObject for SetIterable {
{
// the class lease needs to be drop to be able to return the object
drop(class);
Ok(SetIterable {
iterable: PosArgs::new(vec![ArgIterable::try_from_object(vm, obj)?]),
})
Ok(AnySet { object: obj })
} else {
Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", class)))
}