From 4e6c451b2a8b28f10c5218764d69f1d4248268e3 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sat, 18 Sep 2021 11:40:21 +0200 Subject: [PATCH 01/19] Impl pickling for bytes and bytearray --- Lib/test/test_bytes.py | 2 -- vm/src/builtins/bytearray.rs | 23 ++++++++++++++++++++++- vm/src/builtins/bytes.rs | 25 +++++++++++++++++++++++-- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index e37e5084a0..c0127cabe9 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -795,8 +795,6 @@ class BaseBytesTest: q = pickle.loads(ps) self.assertEqual(b, q) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_iterator_pickling(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): for b in b"", b"a", b"abc", b"\xffab\x80", b"\0\0\377\0\0": diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 2ab6091dcd..694cdadf9e 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -742,7 +742,28 @@ impl PyValue for PyByteArrayIterator { } #[pyimpl(with(SlotIterator))] -impl PyByteArrayIterator {} +impl PyByteArrayIterator { + #[pymethod(magic)] + fn reduce(&self, vm: &VirtualMachine) -> PyResult { + let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; + Ok(vm.ctx.new_tuple(vec![ + iter, + vm.ctx.new_tuple(vec![self.bytearray.clone().into_object()]), + vm.ctx.new_int(self.position.load()), + ])) + } + + #[pymethod(magic)] + fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if let Some(i) = state.payload::() { + self.position + .store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0)); + Ok(()) + } else { + Err(vm.new_type_error("an integer is required.".to_owned())) + } + } +} impl IteratorIterable for PyByteArrayIterator {} impl SlotIterator for PyByteArrayIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 2df6000179..dfc7e76dcb 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -1,4 +1,4 @@ -use super::{PyDictRef, PyIntRef, PyStrRef, PyTupleRef, PyTypeRef}; +use super::{PyDictRef, PyInt, PyIntRef, PyStrRef, PyTupleRef, PyTypeRef, int}; use crate::{ anystr::{self, AnyStr}, bytesinner::{ @@ -595,7 +595,28 @@ impl PyValue for PyBytesIterator { } #[pyimpl(with(SlotIterator))] -impl PyBytesIterator {} +impl PyBytesIterator { + #[pymethod(magic)] + fn reduce(&self, vm: &VirtualMachine) -> PyResult { + let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; + Ok(vm.ctx.new_tuple(vec![ + iter, + vm.ctx.new_tuple(vec![self.bytes.clone().into_object()]), + vm.ctx.new_int(self.position.load()), + ])) + } + + #[pymethod(magic)] + fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if let Some(i) = state.payload::() { + self.position + .store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0)); + Ok(()) + } else { + Err(vm.new_type_error("an integer is required.".to_owned())) + } + } +} impl IteratorIterable for PyBytesIterator {} impl SlotIterator for PyBytesIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { From bf04b505b1a4e43961581f7a6b9be626122151b4 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 19 Sep 2021 20:49:13 +0200 Subject: [PATCH 02/19] Refactor positional iterator with general logic --- vm/src/builtins/bytearray.rs | 55 +++++++------ vm/src/builtins/bytes.rs | 50 ++++++------ vm/src/builtins/iter.rs | 154 +++++++++++++++++++++++------------ vm/src/builtins/tuple.rs | 69 +++++----------- 4 files changed, 178 insertions(+), 150 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 694cdadf9e..77c6649466 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -717,8 +717,7 @@ impl Unhashable for PyByteArray {} impl Iterable for PyByteArray { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyByteArrayIterator { - position: AtomicCell::new(0), - bytearray: zelf, + internal: PositionIterInternal::new(zelf.into_object()), } .into_object(vm)) } @@ -731,8 +730,7 @@ impl Iterable for PyByteArray { #[pyclass(module = false, name = "bytearray_iterator")] #[derive(Debug)] pub struct PyByteArrayIterator { - position: AtomicCell, - bytearray: PyByteArrayRef, + internal: PositionIterInternal, } impl PyValue for PyByteArrayIterator { @@ -743,36 +741,45 @@ impl PyValue for PyByteArrayIterator { #[pyimpl(with(SlotIterator))] impl PyByteArrayIterator { + #[pymethod(magic)] + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + self.internal.length_hint( + || { + Ok(self + .internal + .obj + .read() + .payload::() + .unwrap() + .len()) + }, + vm, + ) + } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.bytearray.clone().into_object()]), - vm.ctx.new_int(self.position.load()), - ])) + Ok(self.internal.reduce(iter, vm)) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - if let Some(i) = state.payload::() { - self.position - .store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0)); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + self.internal.set_state(state, vm) } } impl IteratorIterable for PyByteArrayIterator {} impl SlotIterator for PyByteArrayIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let pos = zelf.position.fetch_add(1); - let r = if let Some(&ret) = zelf.bytearray.borrow_buf().get(pos) { - PyIterReturn::Return(ret.into_pyobject(vm)) - } else { - PyIterReturn::StopIteration(None) - }; - Ok(r) + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.next( + |pos| { + let bytearray = zelf.internal.obj.read(); + let bytearray = bytearray.payload::().unwrap(); + let buf = bytearray.borrow_buf(); + buf.get(pos) + .ok_or_else(|| vm.new_stop_iteration()) + .map(|&x| vm.ctx.new_int(x)) + }, + vm, + ) } } diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index dfc7e76dcb..a94da0a17d 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -17,7 +17,6 @@ use crate::{ PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, VirtualMachine, }; use bstr::ByteSlice; -use crossbeam_utils::atomic::AtomicCell; use rustpython_common::borrow::{BorrowedValue, BorrowedValueMut}; use std::mem::size_of; use std::ops::Deref; @@ -574,8 +573,7 @@ impl Comparable for PyBytes { impl Iterable for PyBytes { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyBytesIterator { - position: AtomicCell::new(0), - bytes: zelf, + internal: PositionIterInternal::new(zelf.into_object()), } .into_object(vm)) } @@ -584,8 +582,7 @@ impl Iterable for PyBytes { #[pyclass(module = false, name = "bytes_iterator")] #[derive(Debug)] pub struct PyBytesIterator { - position: AtomicCell, - bytes: PyBytesRef, + internal: PositionIterInternal, } impl PyValue for PyBytesIterator { @@ -596,37 +593,40 @@ impl PyValue for PyBytesIterator { #[pyimpl(with(SlotIterator))] impl PyBytesIterator { + #[pymethod(magic)] + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + self.internal.length_hint( + || Ok(self.internal.obj.read().payload::().unwrap().len()), + vm, + ) + } + #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.bytes.clone().into_object()]), - vm.ctx.new_int(self.position.load()), - ])) + Ok(self.internal.reduce(iter, vm)) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - if let Some(i) = state.payload::() { - self.position - .store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0)); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + self.internal.set_state(state, vm) } } impl IteratorIterable for PyBytesIterator {} impl SlotIterator for PyBytesIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let pos = zelf.position.fetch_add(1); - let r = if let Some(&ret) = zelf.bytes.as_bytes().get(pos) { - PyIterReturn::Return(vm.ctx.new_int(ret)) - } else { - PyIterReturn::StopIteration(None) - }; - Ok(r) + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.next( + |pos| { + let bytes = zelf.internal.obj.read(); + let bytes = bytes.payload::().unwrap(); + bytes + .as_bytes() + .get(pos) + .ok_or_else(|| vm.new_stop_iteration()) + .map(|&x| vm.ctx.new_int(x)) + }, + vm, + ) } } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index efb0756b24..3a578106f7 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -12,6 +12,88 @@ use crate::{ }; use crossbeam_utils::atomic::AtomicCell; +#[derive(Debug)] +pub struct PositionIterInternal { + pub position: AtomicCell, + /// object or PyNone if exhausted + pub obj: PyRwLock, +} + +impl PositionIterInternal { + pub fn new(obj: PyObjectRef) -> Self { + Self { + position: AtomicCell::new(0), + obj: PyRwLock::new(obj), + } + } + + pub fn is_active(&self, vm: &VirtualMachine) -> bool { + !vm.is_none(&self.obj.read()) + } + + pub fn set_state(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if self.is_active(vm) { + if let Some(i) = state.payload::() { + let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0); + self.position.store(i); + Ok(()) + } else { + Err(vm.new_type_error("an integer is required.".to_owned())) + } + } else { + Ok(()) + } + } + + pub fn reduce(&self, func: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + if self.is_active(vm) { + vm.ctx.new_tuple(vec![ + func, + vm.ctx.new_tuple(vec![self.obj.read().clone()]), + vm.ctx.new_int(self.position.load()), + ]) + } else { + vm.ctx + .new_tuple(vec![func, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]) + } + } + + pub fn next(&self, f: F, vm: &VirtualMachine) -> PyResult + where + F: FnOnce(usize) -> PyResult, + { + if self.is_active(vm) { + let pos = self.position.fetch_add(1); + match f(pos) { + Err(ref e) + if e.isinstance(&vm.ctx.exceptions.index_error) + || e.isinstance(&vm.ctx.exceptions.stop_iteration) => + { + *self.obj.write() = vm.ctx.none(); + Err(vm.new_stop_iteration()) + } + ret => ret, + } + } else { + Err(vm.new_stop_iteration()) + } + } + + pub fn length_hint(&self, f: F, vm: &VirtualMachine) -> PyResult + where + F: FnOnce() -> PyResult, + { + let len = if self.is_active(vm) { + let pos = self.position.load(); + let obj_len = f()?; + obj_len.saturating_sub(pos) + } else { + 0 + }; + Ok(PyInt::from(len).into_object(vm)) + } +} + /// Marks status of iterator. #[derive(Debug, Clone, Copy)] pub enum IterStatus { @@ -24,9 +106,10 @@ pub enum IterStatus { #[pyclass(module = false, name = "iterator")] #[derive(Debug)] pub struct PySequenceIterator { - pub position: AtomicCell, - pub obj: PyObjectRef, - pub status: AtomicCell, + internal: PositionIterInternal, + // pub position: AtomicCell, + // pub obj: PyObjectRef, + // pub status: AtomicCell } impl PyValue for PySequenceIterator { @@ -39,75 +122,38 @@ impl PyValue for PySequenceIterator { impl PySequenceIterator { pub fn new(obj: PyObjectRef) -> Self { Self { - position: AtomicCell::new(0), - obj, - status: AtomicCell::new(IterStatus::Active), + internal: PositionIterInternal::new(obj), } } #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - match self.status.load() { - IterStatus::Active => { - let pos = self.position.load(); - // return NotImplemented if no length is around. - vm.obj_len(&self.obj) - .map_or(vm.ctx.not_implemented(), |len| { - PyInt::from(len.saturating_sub(pos)).into_object(vm) - }) - } - IterStatus::Exhausted => PyInt::from(0).into_object(vm), - } + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + self.internal.length_hint( + || { + vm.obj_len(&self.internal.obj.read()) + .map_err(|_| vm.new_not_implemented_error("".to_owned())) + }, + vm, + ) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(match self.status.load() { - IterStatus::Exhausted => vm - .ctx - .new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]), - IterStatus::Active => vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.obj.clone()]), - vm.ctx.new_int(self.position.load()), - ]), - }) + Ok(self.internal.reduce(iter, vm)) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let IterStatus::Exhausted = self.status.load() { - return Ok(()); - } - if let Some(i) = state.payload::() { - self.position - .store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0)); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + self.internal.set_state(state, vm) } } impl IteratorIterable for PySequenceIterator {} impl SlotIterator for PySequenceIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if let IterStatus::Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let pos = zelf.position.fetch_add(1); - match zelf.obj.get_item(pos, vm) { - Err(ref e) - if e.isinstance(&vm.ctx.exceptions.index_error) - || e.isinstance(&vm.ctx.exceptions.stop_iteration) => - { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } - ret => ret.map(PyIterReturn::Return), - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal + .next(|pos| zelf.internal.obj.read().get_item(pos, vm), vm) } } diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index f733f222ad..160f01d479 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -310,9 +310,7 @@ impl Comparable for PyTuple { impl Iterable for PyTuple { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyTupleIterator { - position: AtomicCell::new(0), - status: AtomicCell::new(Active), - tuple: zelf, + internal: PositionIterInternal::new(zelf.into_object()), } .into_object(vm)) } @@ -321,9 +319,7 @@ impl Iterable for PyTuple { #[pyclass(module = false, name = "tuple_iterator")] #[derive(Debug)] pub(crate) struct PyTupleIterator { - position: AtomicCell, - status: AtomicCell, - tuple: PyTupleRef, + internal: PositionIterInternal, } impl PyValue for PyTupleIterator { @@ -335,61 +331,40 @@ impl PyValue for PyTupleIterator { #[pyimpl(with(SlotIterator))] impl PyTupleIterator { #[pymethod(magic)] - fn length_hint(&self) -> usize { - match self.status.load() { - Active => self.tuple.len().saturating_sub(self.position.load()), - Exhausted => 0, - } + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + self.internal.length_hint( + || Ok(self.internal.obj.read().payload::().unwrap().len()), + vm, + ) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } - // Else, set to min of (pos, tuple_size). - if let Some(i) = state.payload::() { - let position = std::cmp::min( - int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0), - self.tuple.len(), - ); - self.position.store(position); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + self.internal.set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(match self.status.load() { - Exhausted => vm - .ctx - .new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]), - Active => vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.tuple.clone().into_object()]), - vm.ctx.new_int(self.position.load()), - ]), - }) + Ok(self.internal.reduce(iter, vm)) } } impl IteratorIterable for PyTupleIterator {} impl SlotIterator for PyTupleIterator { - fn next(zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let pos = zelf.position.fetch_add(1); - if let Some(obj) = zelf.tuple.as_slice().get(pos) { - Ok(PyIterReturn::Return(obj.clone())) - } else { - zelf.status.store(Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.next( + |pos| { + let tuple = zelf.internal.obj.read(); + let tuple = tuple.payload::().unwrap(); + tuple + .as_slice() + .get(pos) + .ok_or_else(|| vm.new_stop_iteration()) + .map(|x| x.clone()) + }, + vm, + ) } } From 329afeaf154b81133348ffd482c2796cd19d2ffd Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 20 Sep 2021 10:18:47 +0200 Subject: [PATCH 03/19] Fix __length_hint__ to return not_implemented --- vm/src/builtins/bytearray.rs | 8 +++----- vm/src/builtins/bytes.rs | 10 ++++++++-- vm/src/builtins/iter.rs | 18 ++++++++++-------- vm/src/builtins/tuple.rs | 4 ++-- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 77c6649466..2447c3f433 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -742,16 +742,14 @@ impl PyValue for PyByteArrayIterator { #[pyimpl(with(SlotIterator))] impl PyByteArrayIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal.length_hint( || { - Ok(self - .internal + self.internal .obj .read() .payload::() - .unwrap() - .len()) + .map(|x| x.len()) }, vm, ) diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index a94da0a17d..616fb081c3 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -594,9 +594,15 @@ impl PyValue for PyBytesIterator { #[pyimpl(with(SlotIterator))] impl PyBytesIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal.length_hint( - || Ok(self.internal.obj.read().payload::().unwrap().len()), + || { + self.internal + .obj + .read() + .payload::() + .map(|x| x.len()) + }, vm, ) } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 3a578106f7..54065926fe 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -79,18 +79,21 @@ impl PositionIterInternal { } } - pub fn length_hint(&self, f: F, vm: &VirtualMachine) -> PyResult + pub fn length_hint(&self, f: F, vm: &VirtualMachine) -> PyObjectRef where - F: FnOnce() -> PyResult, + F: FnOnce() -> Option, { let len = if self.is_active(vm) { let pos = self.position.load(); - let obj_len = f()?; - obj_len.saturating_sub(pos) + if let Some(obj_len) = f() { + obj_len.saturating_sub(pos) + } else { + return vm.ctx.not_implemented(); + } } else { 0 }; - Ok(PyInt::from(len).into_object(vm)) + PyInt::from(len).into_object(vm) } } @@ -127,11 +130,10 @@ impl PySequenceIterator { } #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal.length_hint( || { - vm.obj_len(&self.internal.obj.read()) - .map_err(|_| vm.new_not_implemented_error("".to_owned())) + vm.obj_len(&self.internal.obj.read()).ok() }, vm, ) diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 160f01d479..3ee0edce41 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -331,9 +331,9 @@ impl PyValue for PyTupleIterator { #[pyimpl(with(SlotIterator))] impl PyTupleIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal.length_hint( - || Ok(self.internal.obj.read().payload::().unwrap().len()), + || self.internal.obj.read().payload::().map(|x| x.len()), vm, ) } From 563c04dea9476d27221366e33479f147421d587b Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 20 Sep 2021 10:32:29 +0200 Subject: [PATCH 04/19] Refactor PyListIterator with PositionIterInternal --- vm/src/builtins/iter.rs | 11 ++----- vm/src/builtins/list.rs | 68 +++++++++++++++++----------------------- vm/src/builtins/tuple.rs | 8 ++++- 3 files changed, 37 insertions(+), 50 deletions(-) diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 54065926fe..9f985da419 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -110,9 +110,6 @@ pub enum IterStatus { #[derive(Debug)] pub struct PySequenceIterator { internal: PositionIterInternal, - // pub position: AtomicCell, - // pub obj: PyObjectRef, - // pub status: AtomicCell } impl PyValue for PySequenceIterator { @@ -131,12 +128,8 @@ impl PySequenceIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.length_hint( - || { - vm.obj_len(&self.internal.obj.read()).ok() - }, - vm, - ) + self.internal + .length_hint(|| vm.obj_len(&self.internal.obj.read()).ok(), vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 1f0fdfad58..cebd4c8d71 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -419,9 +419,7 @@ impl PyList { impl Iterable for PyList { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyListIterator { - position: AtomicCell::new(0), - status: AtomicCell::new(Active), - list: zelf, + internal: PositionIterInternal::new(zelf.into_object()), } .into_object(vm)) } @@ -476,9 +474,7 @@ fn do_sort( #[pyclass(module = false, name = "list_iterator")] #[derive(Debug)] pub struct PyListIterator { - pub position: AtomicCell, - status: AtomicCell, - pub list: PyListRef, + internal: PositionIterInternal, } impl PyValue for PyListIterator { @@ -490,53 +486,45 @@ impl PyValue for PyListIterator { #[pyimpl(with(SlotIterator))] impl PyListIterator { #[pymethod(magic)] - fn length_hint(&self) -> usize { - match self.status.load() { - Active => { - let list = self.list.borrow_vec(); - let pos = self.position.load(); - list.len().saturating_sub(pos) - } - Exhausted => 0, - } + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal.length_hint( + || { + self.internal + .obj + .read() + .payload::() + .map(|x| x.len()) + }, + vm, + ) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } - let position = list_state(self.list.len(), state, vm)?; - self.position.store(position); - Ok(()) + self.internal.set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let pos = if let Exhausted = self.status.load() { - None - } else { - Some(self.position.load()) - }; - list_reduce(self.list.clone(), pos, false, vm) + let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; + Ok(self.internal.reduce(iter, vm)) } } impl IteratorIterable for PyListIterator {} impl SlotIterator for PyListIterator { - fn next(zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let list = zelf.list.borrow_vec(); - let pos = zelf.position.fetch_add(1); - if let Some(obj) = list.get(pos) { - Ok(PyIterReturn::Return(obj.clone())) - } else { - zelf.status.store(Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.next( + |pos| { + let list = zelf.internal.obj.read(); + let list = list.payload::().unwrap(); + let vec = list.borrow_vec(); + vec.get(pos) + .ok_or_else(|| vm.new_stop_iteration()) + .map(|x| x.clone()) + }, + vm, + ) } } diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 3ee0edce41..463f3a39ec 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -333,7 +333,13 @@ impl PyTupleIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal.length_hint( - || self.internal.obj.read().payload::().map(|x| x.len()), + || { + self.internal + .obj + .read() + .payload::() + .map(|x| x.len()) + }, vm, ) } From 547e19b4934fa3111b812633a2f6200f1d451e13 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 22 Sep 2021 15:42:22 +0200 Subject: [PATCH 05/19] Refactor IterStatus to hold the PyObjectRef --- vm/src/builtins/bytearray.rs | 28 ++--- vm/src/builtins/bytes.rs | 29 ++--- vm/src/builtins/dict.rs | 168 ++++++++++++++++--------- vm/src/builtins/enumerate.rs | 125 ++++++++++--------- vm/src/builtins/iter.rs | 154 ++++++++++++++--------- vm/src/builtins/list.rs | 161 +++++++++++++----------- vm/src/builtins/pystr.rs | 141 +++++++++++---------- vm/src/builtins/tuple.rs | 28 ++--- vm/src/dictdatatype.rs | 11 ++ vm/src/stdlib/collections.rs | 236 ++++++++++++++++++++++------------- 10 files changed, 621 insertions(+), 460 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 2447c3f433..f049e48035 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -717,7 +717,7 @@ impl Unhashable for PyByteArray {} impl Iterable for PyByteArray { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyByteArrayIterator { - internal: PositionIterInternal::new(zelf.into_object()), + internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), } .into_object(vm)) } @@ -730,7 +730,7 @@ impl Iterable for PyByteArray { #[pyclass(module = false, name = "bytearray_iterator")] #[derive(Debug)] pub struct PyByteArrayIterator { - internal: PositionIterInternal, + internal: PyRwLock, } impl PyValue for PyByteArrayIterator { @@ -743,35 +743,27 @@ impl PyValue for PyByteArrayIterator { impl PyByteArrayIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.length_hint( - || { - self.internal - .obj - .read() - .payload::() - .map(|x| x.len()) - }, - vm, - ) + self.internal + .read() + .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.reduce(iter, vm)) + Ok(self.internal.read().reduce(iter, vm)) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.set_state(state, vm) + self.internal.write().set_state(state, vm) } } impl IteratorIterable for PyByteArrayIterator {} impl SlotIterator for PyByteArrayIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.next( - |pos| { - let bytearray = zelf.internal.obj.read(); - let bytearray = bytearray.payload::().unwrap(); + zelf.internal.write().next( + |obj, pos| { + let bytearray = obj.payload::().unwrap(); let buf = bytearray.borrow_buf(); buf.get(pos) .ok_or_else(|| vm.new_stop_iteration()) diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 616fb081c3..5d710e6d50 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -18,6 +18,7 @@ use crate::{ }; use bstr::ByteSlice; use rustpython_common::borrow::{BorrowedValue, BorrowedValueMut}; +use rustpython_common::lock::PyRwLock; use std::mem::size_of; use std::ops::Deref; @@ -573,7 +574,7 @@ impl Comparable for PyBytes { impl Iterable for PyBytes { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyBytesIterator { - internal: PositionIterInternal::new(zelf.into_object()), + internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), } .into_object(vm)) } @@ -582,7 +583,7 @@ impl Iterable for PyBytes { #[pyclass(module = false, name = "bytes_iterator")] #[derive(Debug)] pub struct PyBytesIterator { - internal: PositionIterInternal, + internal: PyRwLock, } impl PyValue for PyBytesIterator { @@ -595,36 +596,28 @@ impl PyValue for PyBytesIterator { impl PyBytesIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.length_hint( - || { - self.internal - .obj - .read() - .payload::() - .map(|x| x.len()) - }, - vm, - ) + self.internal + .read() + .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.reduce(iter, vm)) + Ok(self.internal.read().reduce(iter, vm)) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.set_state(state, vm) + self.internal.write().set_state(state, vm) } } impl IteratorIterable for PyBytesIterator {} impl SlotIterator for PyBytesIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.next( - |pos| { - let bytes = zelf.internal.obj.read(); - let bytes = bytes.payload::().unwrap(); + zelf.internal.write().next( + |obj, pos| { + let bytes = obj.payload::().unwrap(); bytes .as_bytes() .get(pos) diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index 7a57d688df..50f8498c5d 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -15,6 +15,7 @@ use crate::{ PyResult, PyValue, TryFromObject, TypeProtocol, }; use crossbeam_utils::atomic::AtomicCell; +use rustpython_common::lock::PyRwLock; use std::fmt; use std::mem::size_of; @@ -703,10 +704,11 @@ macro_rules! dict_iterator { #[pyclass(module = false, name = $iter_class_name)] #[derive(Debug)] pub(crate) struct $iter_name { - pub dict: PyDictRef, + // pub dict: PyDictRef, pub size: dictdatatype::DictSize, - pub position: AtomicCell, - pub status: AtomicCell, + // pub position: AtomicCell, + // pub status: AtomicCell, + pub internal: PyRwLock, } impl PyValue for $iter_name { @@ -719,57 +721,78 @@ macro_rules! dict_iterator { impl $iter_name { fn new(dict: PyDictRef) -> Self { $iter_name { - position: AtomicCell::new(0), + // position: AtomicCell::new(0), size: dict.size(), - dict, - status: AtomicCell::new(IterStatus::Active), + // dict, + internal: PyRwLock::new(PositionIterInternal::new(dict.into_object(), 0)), } } #[pymethod(magic)] - fn length_hint(&self) -> usize { - if let IterStatus::Exhausted = self.status.load() { - 0 - } else { - self.dict.entries.len_from_entry_index(self.position.load()) - } + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .length_hint(|obj| obj.payload::().map(|x| x.entries.len()), vm) + // if let IterStatus::Exhausted = self.status.load() { + // 0 + // } else { + // self.dict.entries.len_from_entry_index(self.position.load()) + // } } } impl IteratorIterable for $iter_name {} impl SlotIterator for $iter_name { #[allow(clippy::redundant_closure_call)] - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - IterStatus::Exhausted => Ok(PyIterReturn::StopIteration(None)), - IterStatus::Active => { - if zelf.dict.entries.has_changed_size(&zelf.size) { - zelf.status.store(IterStatus::Exhausted); - return Err(vm.new_runtime_error( - "dictionary changed size during iteration".to_owned(), - )); - } - match zelf.dict.entries.next_entry_atomic(&zelf.position) { - Some((key, value)) => { - Ok(PyIterReturn::Return(($result_fn)(vm, key, value))) - } - None => { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + let internal = zelf.internal.write(); + if let IterStatus::Active(obj) = &internal.status { + let dict = obj.payload::().unwrap(); + if dict.entries.has_changed_size(&zelf.size) { + internal.status = IterStatus::Exhausted; + return Err(vm.new_runtime_error( + "dictionary changed size during iteration".to_owned(), + )); + } + match dict.entries.next_entry(&mut internal.position) { + Some((key, value)) => Ok(($result_fn)(vm, key, value)), + None => { + internal.status = IterStatus::Exhausted; + Err(vm.new_stop_iteration()) } } + } else { + Err(vm.new_stop_iteration()) } + // match zelf.status.load() { + // IterStatus::Exhausted => Err(vm.new_stop_iteration()), + // IterStatus::Active => { + // if zelf.dict.entries.has_changed_size(&zelf.size) { + // zelf.status.store(IterStatus::Exhausted); + // return Err(vm.new_runtime_error( + // "dictionary changed size during iteration".to_owned(), + // )); + // } + // match zelf.dict.entries.next_entry_atomic(&zelf.position) { + // Some((key, value)) => Ok(($result_fn)(vm, key, value)), + // None => { + // zelf.status.store(IterStatus::Exhausted); + // Err(vm.new_stop_iteration()) + // } + // } + // } + // } } } #[pyclass(module = false, name = $reverse_iter_class_name)] #[derive(Debug)] pub(crate) struct $reverse_iter_name { - pub dict: PyDictRef, + // pub dict: PyDictRef, pub size: dictdatatype::DictSize, - pub position: AtomicCell, - pub status: AtomicCell, + // pub position: AtomicCell, + // pub status: AtomicCell, + internal: PyRwLock, } impl PyValue for $reverse_iter_name { @@ -781,48 +804,73 @@ macro_rules! dict_iterator { #[pyimpl(with(SlotIterator))] impl $reverse_iter_name { fn new(dict: PyDictRef) -> Self { + let position = dict.entries.len().saturating_sub(1); $reverse_iter_name { - position: AtomicCell::new(0), + // position: AtomicCell::new(0), size: dict.size(), - dict, - status: AtomicCell::new(IterStatus::Active), + // dict, + // status: AtomicCell::new(IterStatus::Active), + internal: PyRwLock::new(PositionIterInternal::new( + dict.into_object(), + position, + )), } } #[pymethod(magic)] - fn length_hint(&self) -> usize { - if let IterStatus::Exhausted = self.status.load() { - 0 - } else { - self.dict.entries.len_from_entry_index(self.position.load()) - } + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .rev_length_hint(|obj| obj.payload::().map(|x| x.entries.len()), vm) + // if let IterStatus::Exhausted = self.status.load() { + // 0 + // } else { + // self.dict.entries.len_from_entry_index(self.position.load()) + // } } } impl IteratorIterable for $reverse_iter_name {} impl SlotIterator for $reverse_iter_name { #[allow(clippy::redundant_closure_call)] - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - IterStatus::Exhausted => Ok(PyIterReturn::StopIteration(None)), - IterStatus::Active => { - if zelf.dict.entries.has_changed_size(&zelf.size) { - zelf.status.store(IterStatus::Exhausted); - return Err(vm.new_runtime_error( - "dictionary changed size during iteration".to_owned(), - )); - } - match zelf.dict.entries.next_entry_atomic_reversed(&zelf.position) { - Some((key, value)) => { - Ok(PyIterReturn::Return(($result_fn)(vm, key, value))) - } - None => { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + let internal = zelf.internal.write(); + if let IterStatus::Active(obj) = &internal.status { + let dict = obj.payload::().unwrap(); + if dict.entries.has_changed_size(&zelf.size) { + internal.status = IterStatus::Exhausted; + return Err(vm.new_runtime_error( + "dictionary changed size during iteration".to_owned(), + )); + } + match dict.entries.prev_entry(&mut internal.position) { + Some((key, value)) => Ok(($result_fn)(vm, key, value)), + None => { + internal.status = IterStatus::Exhausted; + Err(vm.new_stop_iteration()) } } + } else { + Err(vm.new_stop_iteration()) } + // match zelf.status.load() { + // IterStatus::Exhausted => Err(vm.new_stop_iteration()), + // IterStatus::Active => { + // if zelf.dict.entries.has_changed_size(&zelf.size) { + // zelf.status.store(IterStatus::Exhausted); + // return Err(vm.new_runtime_error( + // "dictionary changed size during iteration".to_owned(), + // )); + // } + // match zelf.dict.entries.next_entry_atomic_reversed(&zelf.position) { + // Some((key, value)) => Ok(($result_fn)(vm, key, value)), + // None => { + // zelf.status.store(IterStatus::Exhausted); + // Err(vm.new_stop_iteration()) + // } + // } + // } + // } } } }; diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index e2cfc1bf41..3dcc589c3d 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -72,9 +72,10 @@ impl SlotIterator for PyEnumerate { #[pyclass(module = false, name = "reversed")] #[derive(Debug)] pub struct PyReverseSequenceIterator { - pub position: AtomicCell, - pub status: AtomicCell, - pub obj: PyObjectRef, + internal: PyRwLock, + // pub position: AtomicCell, + // pub status: AtomicCell, + // pub obj: PyObjectRef, } impl PyValue for PyReverseSequenceIterator { @@ -86,80 +87,88 @@ impl PyValue for PyReverseSequenceIterator { #[pyimpl(with(SlotIterator))] impl PyReverseSequenceIterator { pub fn new(obj: PyObjectRef, len: usize) -> Self { + let position = len.saturating_sub(1); Self { - position: AtomicCell::new(len.saturating_sub(1)), - status: AtomicCell::new(if len == 0 { Exhausted } else { Active }), - obj, + internal: PyRwLock::new(PositionIterInternal::new(obj, position)) + // position: AtomicCell::new(len.saturating_sub(1)), + // status: AtomicCell::new(if len == 0 { Exhausted } else { Active }), + // obj, } } #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyResult { - Ok(match self.status.load() { - Active => { - let position = self.position.load(); - if position > vm.obj_len(&self.obj)? { - 0 - } else { - position + 1 - } - } - Exhausted => 0, - }) + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .rev_length_hint(|obj| vm.obj_len(obj).ok(), vm) + // Ok(match self.status.load() { + // Active => { + // let position = self.position.load(); + // if position > vm.obj_len(&self.obj)? { + // 0 + // } else { + // position + 1 + // } + // } + // Exhausted => 0, + // }) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } - let len = vm.obj_len(&self.obj)?; - let pos = state - .payload::() - .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - let pos = std::cmp::min( - int::try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), - len.saturating_sub(1), - ); - self.position.store(pos); - Ok(()) + self.internal.read().set_state(state, vm) + // // When we're exhausted, just return. + // if let Exhausted = self.status.load() { + // return Ok(()); + // } + // let len = vm.obj_len(&self.obj)?; + // let pos = state + // .payload::() + // .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; + // let pos = std::cmp::min( + // try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), + // len.saturating_sub(1), + // ); + // self.position.store(pos); + // Ok(()) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "reversed")?; - Ok(vm.ctx.new_tuple(match self.status.load() { - Exhausted => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_tuple(vec![])])], - Active => vec![ - iter, - vm.ctx.new_tuple(vec![self.obj.clone()]), - vm.ctx.new_int(self.position.load()), - ], - })) + Ok(self.internal.read().reduce(iter, vm)) + // Ok(vm.ctx.new_tuple(match self.status.load() { + // Exhausted => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_tuple(vec![])])], + // Active => vec![ + // iter, + // vm.ctx.new_tuple(vec![self.obj.clone()]), + // vm.ctx.new_int(self.position.load()), + // ], + // })) } } impl IteratorIterable for PyReverseSequenceIterator {} impl SlotIterator for PyReverseSequenceIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let pos = zelf.position.fetch_sub(1); - if pos == 0 { - zelf.status.store(Exhausted); - } - match zelf.obj.get_item(pos, vm) { - Err(ref e) - if e.isinstance(&vm.ctx.exceptions.index_error) - || e.isinstance(&vm.ctx.exceptions.stop_iteration) => - { - zelf.status.store(Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } - other => other.map(PyIterReturn::Return), - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal + .read() + .rev_next(|obj, pos| obj.get_item(pos, vm), vm) + // if let Exhausted = zelf.status.load() { + // return Err(vm.new_stop_iteration()); + // } + // let pos = zelf.position.fetch_sub(1); + // if pos == 0 { + // zelf.status.store(Exhausted); + // } + // match zelf.obj.get_item(pos, vm) { + // Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => { + // zelf.status.store(Exhausted); + // Err(vm.new_stop_iteration()) + // } + // // also catches stop_iteration => stop_iteration + // ret => ret, + // } } } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 9f985da419..0acae123cd 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -12,30 +12,34 @@ use crate::{ }; use crossbeam_utils::atomic::AtomicCell; +/// Marks status of iterator. +#[derive(Debug, Clone)] +pub enum IterStatus { + /// Iterator hasn't raised StopIteration. + Active(PyObjectRef), + /// Iterator has raised StopIteration. + Exhausted, +} + #[derive(Debug)] pub struct PositionIterInternal { - pub position: AtomicCell, - /// object or PyNone if exhausted - pub obj: PyRwLock, + pub status: IterStatus, + pub position: usize, } impl PositionIterInternal { - pub fn new(obj: PyObjectRef) -> Self { + pub fn new(obj: PyObjectRef, position: usize) -> Self { Self { - position: AtomicCell::new(0), - obj: PyRwLock::new(obj), + status: IterStatus::Active(obj), + position, } } - pub fn is_active(&self, vm: &VirtualMachine) -> bool { - !vm.is_none(&self.obj.read()) - } - - pub fn set_state(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - if self.is_active(vm) { + pub fn set_state(&mut self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if let IterStatus::Active(_) = &self.status { if let Some(i) = state.payload::() { let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0); - self.position.store(i); + self.position = i; Ok(()) } else { Err(vm.new_type_error("an integer is required.".to_owned())) @@ -46,11 +50,11 @@ impl PositionIterInternal { } pub fn reduce(&self, func: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { - if self.is_active(vm) { + if let IterStatus::Active(obj) = &self.status { vm.ctx.new_tuple(vec![ func, - vm.ctx.new_tuple(vec![self.obj.read().clone()]), - vm.ctx.new_int(self.position.load()), + vm.ctx.new_tuple(vec![obj.clone()]), + vm.ctx.new_int(self.position), ]) } else { vm.ctx @@ -58,21 +62,54 @@ impl PositionIterInternal { } } - pub fn next(&self, f: F, vm: &VirtualMachine) -> PyResult + pub fn next(&mut self, f: F, vm: &VirtualMachine) -> PyResult where - F: FnOnce(usize) -> PyResult, + F: FnOnce(&PyObjectRef, usize) -> PyResult, { - if self.is_active(vm) { - let pos = self.position.fetch_add(1); - match f(pos) { - Err(ref e) - if e.isinstance(&vm.ctx.exceptions.index_error) - || e.isinstance(&vm.ctx.exceptions.stop_iteration) => - { - *self.obj.write() = vm.ctx.none(); + if let IterStatus::Active(obj) = &self.status { + match f(obj, self.position) { + Err(e) if e.isinstance(&vm.ctx.exceptions.stop_iteration) => { + self.status = IterStatus::Exhausted; + Err(e) + } + Err(e) if e.isinstance(&vm.ctx.exceptions.index_error) => { + self.status = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } - ret => ret, + Err(e) => Err(e), + Ok(ret) => { + self.position += 1; + Ok(ret) + } + } + } else { + Err(vm.new_stop_iteration()) + } + } + + pub fn rev_next(&mut self, f: F, vm: &VirtualMachine) -> PyResult + where + F: FnOnce(&PyObjectRef, usize) -> PyResult, + { + if let IterStatus::Active(obj) = &self.status { + match f(obj, self.position) { + Err(e) if e.isinstance(&vm.ctx.exceptions.stop_iteration) => { + self.status = IterStatus::Exhausted; + Err(e) + } + Err(e) if e.isinstance(&vm.ctx.exceptions.index_error) => { + self.status = IterStatus::Exhausted; + Err(vm.new_stop_iteration()) + } + Err(e) => Err(e), + Ok(ret) => { + if self.position == 0 { + self.status = IterStatus::Exhausted; + } else { + self.position -= 1; + } + Ok(ret) + } } } else { Err(vm.new_stop_iteration()) @@ -81,12 +118,11 @@ impl PositionIterInternal { pub fn length_hint(&self, f: F, vm: &VirtualMachine) -> PyObjectRef where - F: FnOnce() -> Option, + F: FnOnce(&PyObjectRef) -> Option, { - let len = if self.is_active(vm) { - let pos = self.position.load(); - if let Some(obj_len) = f() { - obj_len.saturating_sub(pos) + let len = if let IterStatus::Active(obj) = &self.status { + if let Some(obj_len) = f(obj) { + obj_len.saturating_sub(self.position) } else { return vm.ctx.not_implemented(); } @@ -95,15 +131,20 @@ impl PositionIterInternal { }; PyInt::from(len).into_object(vm) } -} -/// Marks status of iterator. -#[derive(Debug, Clone, Copy)] -pub enum IterStatus { - /// Iterator hasn't raised StopIteration. - Active, - /// Iterator has raised StopIteration. - Exhausted, + pub fn rev_length_hint(&self, f: F, vm: &VirtualMachine) -> PyObjectRef + where + F: FnOnce(&PyObjectRef) -> Option, + { + if let IterStatus::Active(obj) = &self.status { + if let Some(obj_len) = f(obj) { + if self.position <= obj_len { + return PyInt::from(self.position + 1).into_object(vm); + } + } + } + PyInt::from(0).into_object(vm) + } } #[pyclass(module = false, name = "iterator")] @@ -122,14 +163,13 @@ impl PyValue for PySequenceIterator { impl PySequenceIterator { pub fn new(obj: PyObjectRef) -> Self { Self { - internal: PositionIterInternal::new(obj), + internal: PositionIterInternal::new(obj, 0), } } #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .length_hint(|| vm.obj_len(&self.internal.obj.read()).ok(), vm) + self.internal.length_hint(|obj| vm.obj_len(obj).ok(), vm) } #[pymethod(magic)] @@ -147,17 +187,15 @@ impl PySequenceIterator { impl IteratorIterable for PySequenceIterator {} impl SlotIterator for PySequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal - .next(|pos| zelf.internal.obj.read().get_item(pos, vm), vm) + zelf.internal.next(|obj, pos| obj.get_item(pos, vm), vm) } } #[pyclass(module = false, name = "callable_iterator")] #[derive(Debug)] pub struct PyCallableIterator { - callable: ArgCallable, sentinel: PyObjectRef, - status: AtomicCell, + status: PyRwLock, } impl PyValue for PyCallableIterator { @@ -170,25 +208,25 @@ impl PyValue for PyCallableIterator { impl PyCallableIterator { pub fn new(callable: ArgCallable, sentinel: PyObjectRef) -> Self { Self { - callable, sentinel, - status: AtomicCell::new(IterStatus::Active), + status: PyRwLock::new(IterStatus::Active(callable.into_object())), } } } impl IteratorIterable for PyCallableIterator {} impl SlotIterator for PyCallableIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if let IterStatus::Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let ret = zelf.callable.invoke((), vm)?; - if vm.bool_eq(&ret, &zelf.sentinel)? { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + if let IterStatus::Active(callable) = &*zelf.status.read() { + let ret = vm.invoke(callable, ())?; + if vm.bool_eq(&ret, &zelf.sentinel)? { + *zelf.status.write() = IterStatus::Exhausted; + Err(vm.new_stop_iteration()) + } else { + Ok(ret) + } } else { - Ok(PyIterReturn::Return(ret)) + Err(vm.new_stop_iteration()) } } } diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index cebd4c8d71..4c2e0b2411 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -161,16 +161,17 @@ impl PyList { #[pymethod(magic)] fn reversed(zelf: PyRef) -> PyListReverseIterator { - let final_position = zelf.borrow_vec().len(); + let position = zelf.len().saturating_sub(1); // Mark iterator as exhausted immediately if its empty. PyListReverseIterator { - position: AtomicCell::new(final_position.saturating_sub(1)), - status: AtomicCell::new(if final_position == 0 { - Exhausted - } else { - Active - }), - list: zelf, + internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), position)) + // position: AtomicCell::new(final_position.saturating_sub(1)), + // status: AtomicCell::new(if final_position == 0 { + // Exhausted + // } else { + // Active + // }), + // list: zelf, } } @@ -419,7 +420,7 @@ impl PyList { impl Iterable for PyList { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyListIterator { - internal: PositionIterInternal::new(zelf.into_object()), + internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), } .into_object(vm)) } @@ -474,7 +475,7 @@ fn do_sort( #[pyclass(module = false, name = "list_iterator")] #[derive(Debug)] pub struct PyListIterator { - internal: PositionIterInternal, + internal: PyRwLock, } impl PyValue for PyListIterator { @@ -487,37 +488,29 @@ impl PyValue for PyListIterator { impl PyListIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.length_hint( - || { - self.internal - .obj - .read() - .payload::() - .map(|x| x.len()) - }, - vm, - ) + self.internal + .read() + .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.set_state(state, vm) + self.internal.write().set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.reduce(iter, vm)) + Ok(self.internal.read().reduce(iter, vm)) } } impl IteratorIterable for PyListIterator {} impl SlotIterator for PyListIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.next( - |pos| { - let list = zelf.internal.obj.read(); - let list = list.payload::().unwrap(); + zelf.internal.write().next( + |obj, pos| { + let list = obj.payload::().unwrap(); let vec = list.borrow_vec(); vec.get(pos) .ok_or_else(|| vm.new_stop_iteration()) @@ -531,9 +524,9 @@ impl SlotIterator for PyListIterator { #[pyclass(module = false, name = "list_reverseiterator")] #[derive(Debug)] pub struct PyListReverseIterator { - pub position: AtomicCell, - pub status: AtomicCell, - pub list: PyListRef, + internal: PyRwLock, // pub position: AtomicCell, + // pub status: AtomicCell, + // pub list: PyListRef } impl PyValue for PyListReverseIterator { @@ -545,68 +538,84 @@ impl PyValue for PyListReverseIterator { #[pyimpl(with(SlotIterator))] impl PyListReverseIterator { #[pymethod(magic)] - fn length_hint(&self) -> usize { - match self.status.load() { - Active => { - let position = self.position.load(); - if position > self.list.len() { - // List was mutated. Report zero, next call to `__next__` will - // fail and set iterator to Exhausted. - 0 - } else { - position + 1 - } - } - Exhausted => 0, - } + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .rev_length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + // match self.status.load() { + // Active => { + // let position = self.position.load(); + // if position > self.list.len() { + // // List was mutated. Report zero, next call to `__next__` will + // // fail and set iterator to Exhausted. + // 0 + // } else { + // position + 1 + // } + // } + // Exhausted => 0, + // } } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } + // if let Exhausted = self.status.load() { + // return Ok(()); + // } - // Max for position is list.len() - 1. - let position = list_state(self.list.len().saturating_sub(1), state, vm)?; - self.position.store(position); - Ok(()) + // // Max for position is list.len() - 1. + // let position = list_state(self.list.len().saturating_sub(1), state, vm)?; + // self.position.store(position); + // Ok(()) + self.internal.write().set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let pos = if let Exhausted = self.status.load() { - None - } else { - Some(self.position.load()) - }; - list_reduce(self.list.clone(), pos, true, vm) + let iter = vm.get_attribute(vm.builtins.clone(), "reversed")?; + Ok(self.internal.read().reduce(iter, vm)) + // let pos = if let Exhausted = self.status.load() { + // None + // } else { + // Some(self.position.load()) + // }; + // list_reduce(self.list.clone(), pos, true, vm) } } impl IteratorIterable for PyListReverseIterator {} impl SlotIterator for PyListReverseIterator { - fn next(zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let list = zelf.list.borrow_vec(); - let pos = zelf.position.fetch_sub(1); - if pos > 0 { - if let Some(obj) = list.get(pos) { - return Ok(PyIterReturn::Return(obj.clone())); - } - } - // We either are == 0 or list.get returned None. Either way, set status - // to exhausted and return last item if pos == 0. - zelf.status.store(Exhausted); - if pos == 0 { - if let Some(obj) = list.get(pos) { - return Ok(PyIterReturn::Return(obj.clone())); - } - } - Ok(PyIterReturn::StopIteration(None)) + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.write().rev_next( + |obj, pos| { + let list = obj.payload::().unwrap(); + let vec = list.borrow_vec(); + vec.get(pos) + .ok_or_else(|| vm.new_stop_iteration()) + .map(|x| x.clone()) + }, + vm, + ) + // if let Exhausted = zelf.status.load() { + // return Err(vm.new_stop_iteration()); + // } + // let list = zelf.list.borrow_vec(); + // let pos = zelf.position.fetch_sub(1); + // if pos > 0 { + // if let Some(obj) = list.get(pos) { + // return Ok(obj.clone()); + // } + // } + // // We either are == 0 or list.get returned None. Either way, set status + // // to exhausted and return last item if pos == 0. + // zelf.status.store(Exhausted); + // if pos == 0 { + // if let Some(obj) = list.get(pos) { + // return Ok(obj.clone()); + // } + // } + // Err(vm.new_stop_iteration()) } } diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index a3ff1d0bd8..70e5f33835 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -170,9 +170,10 @@ impl TryIntoRef for &str { #[pyclass(module = false, name = "str_iterator")] #[derive(Debug)] pub struct PyStrIterator { - string: PyStrRef, - position: PyAtomic, - status: AtomicCell, + // string: PyStrRef, + // position: PyAtomic, + // status: AtomicCell, + internal: PyRwLock, } impl PyValue for PyStrIterator { @@ -184,82 +185,91 @@ impl PyValue for PyStrIterator { #[pyimpl(with(SlotIterator))] impl PyStrIterator { #[pymethod(magic)] - fn length_hint(&self) -> usize { - match self.status.load() { - Active => { - let pos = self.position.load(atomic::Ordering::SeqCst); - self.string.len().saturating_sub(pos) - } - Exhausted => 0, - } + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + // match self.status.load() { + // Active => { + // let pos = self.position.load(atomic::Ordering::SeqCst); + // self.string.len().saturating_sub(pos) + // } + // Exhausted => 0, + // } } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } - let pos = state - .payload::() - .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - let pos = std::cmp::min( - try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), - self.string.len(), - ); - self.position.store(pos, atomic::Ordering::SeqCst); - Ok(()) + self.internal.write().set_state(state, vm) + // // When we're exhausted, just return. + // if let Exhausted = self.status.load() { + // return Ok(()); + // } + // let pos = state + // .payload::() + // .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; + // let pos = std::cmp::min( + // try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), + // self.string.len(), + // ); + // self.position.store(pos, atomic::Ordering::SeqCst); + // Ok(()) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(vm.ctx.new_tuple(match self.status.load() { - Exhausted => vec![ - iter, - vm.ctx.new_tuple(vec![vm.ctx.new_ascii_literal(ascii!(""))]), - ], - Active => vec![ - iter, - vm.ctx.new_tuple(vec![self.string.clone().into_object()]), - vm.ctx - .new_int(self.position.load(atomic::Ordering::Relaxed)), - ], - })) + Ok(self.internal.read().reduce(iter, vm)) } } impl IteratorIterable for PyStrIterator {} impl SlotIterator for PyStrIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let value = &*zelf.string.as_str(); - let mut start = zelf.position.load(atomic::Ordering::SeqCst); - loop { - if start == value.len() { - zelf.status.store(Exhausted); - return Ok(PyIterReturn::StopIteration(None)); + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + let internal = zelf.internal.write(); + if let IterStatus::Active(obj) = &internal.status { + let s = obj.payload::().unwrap(); + let value = s.as_str(); + if internal.position >= value.len() { + internal.status = Exhausted; + return Err(vm.new_stop_iteration()); } - let ch = match value[start..].chars().next() { - Some(ch) => ch, - None => { - zelf.status.store(Exhausted); - return Ok(PyIterReturn::StopIteration(None)); - } - }; - match zelf.position.compare_exchange_weak( - start, - start + ch.len_utf8(), - atomic::Ordering::Release, - atomic::Ordering::Relaxed, - ) { - Ok(_) => break Ok(PyIterReturn::Return(ch.into_pyobject(vm))), - Err(cur) => start = cur, - } + let ch = value[internal.position..].chars().next().ok_or_else(|| { + internal.status = Exhausted; + vm.new_stop_iteration() + })?; + + internal.position += ch.len_utf8(); + Ok(ch.into_pyobject(vm)) + } else { + Err(vm.new_stop_iteration()) } + // if let Exhausted = zelf.status.load() { + // return Err(vm.new_stop_iteration()); + // } + // let value = &*zelf.string.as_str(); + // let mut start = zelf.position.load(atomic::Ordering::SeqCst); + // loop { + // if start == value.len() { + // zelf.status.store(Exhausted); + // return Err(vm.new_stop_iteration()); + // } + // let ch = value[start..].chars().next().ok_or_else(|| { + // zelf.status.store(Exhausted); + // vm.new_stop_iteration() + // })?; + + // match zelf.position.compare_exchange_weak( + // start, + // start + ch.len_utf8(), + // atomic::Ordering::Release, + // atomic::Ordering::Relaxed, + // ) { + // Ok(_) => break Ok(ch.into_pyobject(vm)), + // Err(cur) => start = cur, + // } + // } } } @@ -1294,9 +1304,10 @@ impl Comparable for PyStr { impl Iterable for PyStr { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyStrIterator { - position: Radium::new(0), - string: zelf, - status: AtomicCell::new(Active), + // position: Radium::new(0), + // string: zelf, + // status: AtomicCell::new(Active), + internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), } .into_object(vm)) } diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 463f3a39ec..833a7552d7 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -310,7 +310,7 @@ impl Comparable for PyTuple { impl Iterable for PyTuple { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyTupleIterator { - internal: PositionIterInternal::new(zelf.into_object()), + internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), } .into_object(vm)) } @@ -319,7 +319,7 @@ impl Iterable for PyTuple { #[pyclass(module = false, name = "tuple_iterator")] #[derive(Debug)] pub(crate) struct PyTupleIterator { - internal: PositionIterInternal, + internal: PyRwLock, } impl PyValue for PyTupleIterator { @@ -332,37 +332,29 @@ impl PyValue for PyTupleIterator { impl PyTupleIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.length_hint( - || { - self.internal - .obj - .read() - .payload::() - .map(|x| x.len()) - }, - vm, - ) + self.internal + .read() + .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.set_state(state, vm) + self.internal.write().set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.reduce(iter, vm)) + Ok(self.internal.read().reduce(iter, vm)) } } impl IteratorIterable for PyTupleIterator {} impl SlotIterator for PyTupleIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.next( - |pos| { - let tuple = zelf.internal.obj.read(); - let tuple = tuple.payload::().unwrap(); + zelf.internal.write().next( + |obj, pos| { + let tuple = obj.payload::().unwrap(); tuple .as_slice() .get(pos) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 04ed619e00..996b757789 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -490,6 +490,17 @@ impl Dict { } } + pub fn prev_entry(&self, position: &mut EntryIndex) -> Option<(PyObjectRef, T)> { + let inner = self.read(); + loop { + let entry = inner.entries.get(*position)?; + *position = position.checked_sub(1)?; + if let Some(entry) = entry { + break Some((entry.key.clone(), entry.value.clone())); + } + } + } + pub fn next_entry_atomic(&self, position: &AtomicCell) -> Option<(PyObjectRef, T)> { let inner = self.read(); loop { diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 9158ea67a7..8c4aa28c9f 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -2,6 +2,7 @@ pub(crate) use _collections::make_module; #[pymodule] mod _collections { + use crate::builtins::PositionIterInternal; use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; use crate::{ builtins::{ @@ -357,12 +358,13 @@ mod _collections { #[pymethod(magic)] fn reversed(zelf: PyRef) -> PyResult { - let length = zelf.len(); + // let length = zelf.len(); Ok(PyReverseDequeIterator { - position: AtomicCell::new(length), - status: AtomicCell::new(if length > 0 { Active } else { Exhausted }), - length, - deque: zelf, + state: zelf.state.load(), + internal: PyRwLock::new(PositionIterInternal::new( + zelf.into_object(), + zelf.len() - 1, + )), }) } @@ -573,10 +575,12 @@ mod _collections { #[pyclass(name = "_deque_iterator")] #[derive(Debug, PyValue)] struct PyDequeIterator { - position: AtomicCell, - status: AtomicCell, - length: usize, // To track length immutability. - deque: PyDequeRef, + // position: AtomicCell, + // status: AtomicCell, + // length: usize, // To track length immutability. + // deque: PyDequeRef, + state: usize, + internal: PyRwLock, } #[derive(FromArgs)] @@ -596,15 +600,16 @@ mod _collections { (DequeIterArgs { deque, index }, _kwargs): Self::Args, vm: &VirtualMachine, ) -> PyResult { - let len = deque.len(); + // let len = deque.len(); let iter = PyDequeIterator::new(deque); if let OptionalArg::Present(index) = index { let index = max(index, 0) as usize; - iter.position.store(min(index, len)); + iter.internal.write().position = index; + // iter.position.store(min(index, len)); - if len.le(&index) { - iter.status.store(Exhausted); - } + // if len.le(&index) { + // iter.status.store(Exhausted); + // } } iter.into_pyresult_with_type(vm, cls) } @@ -614,56 +619,82 @@ mod _collections { impl PyDequeIterator { pub(crate) fn new(deque: PyDequeRef) -> Self { PyDequeIterator { - position: AtomicCell::new(0), - status: AtomicCell::new(IterStatus::Active), - length: deque.len(), - deque, + // position: AtomicCell::new(0), + // status: AtomicCell::new(IterStatus::Active), + // length: deque.len(), + // deque, + state: deque.state.load(), + internal: PyRwLock::new(PositionIterInternal::new(deque.into_object(), 0)), } } #[pymethod(magic)] - fn length_hint(&self) -> usize { - match self.status.load() { - Active => self.deque.len().saturating_sub(self.position.load()), - Exhausted => 0, - } + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + // match self.status.load() { + // Active => self.deque.len().saturating_sub(self.position.load()), + // Exhausted => 0, + // } } #[pymethod(magic)] fn reduce( zelf: PyRef, vm: &VirtualMachine, - ) -> PyResult<(PyTypeRef, (PyDequeRef, PyObjectRef))> { - Ok(( + ) -> (PyTypeRef, (PyObjectRef, PyObjectRef)) { + let internal = zelf.internal.read(); + let deque = match &internal.status { + Active(obj) => obj.clone(), + Exhausted => PyDeque::default().into_object(vm), + }; + ( zelf.clone_class(), - (zelf.deque.clone(), vm.ctx.new_int(zelf.position.load())), - )) + (deque, vm.ctx.new_int(internal.position)), + ) } } impl IteratorIterable for PyDequeIterator {} impl SlotIterator for PyDequeIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - Exhausted => Ok(PyIterReturn::StopIteration(None)), - Active => { - if zelf.length != zelf.deque.len() { - // Deque was changed while we iterated. - zelf.status.store(Exhausted); - Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())) - } else { - let pos = zelf.position.fetch_add(1); - let deque = zelf.deque.borrow_deque(); - if pos < deque.len() { - let ret = deque[pos].clone(); - Ok(PyIterReturn::Return(ret)) - } else { - zelf.status.store(Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.write().next( + |obj, pos| { + let deque = obj.payload::().unwrap(); + if zelf.state != deque.state.load() { + return Err( + vm.new_runtime_error("Deque mutated during iteration".to_owned()) + ); } - } - } + let deque = deque.borrow_deque(); + deque + .get(pos) + .ok_or_else(|| vm.new_stop_iteration()) + .map(|x| x.clone()) + }, + vm, + ) + // match zelf.status.load() { + // Exhausted => Err(vm.new_stop_iteration()), + // Active => { + // if zelf.length != zelf.deque.len() { + // // Deque was changed while we iterated. + // zelf.status.store(Exhausted); + // Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())) + // } else { + // let pos = zelf.position.fetch_add(1); + // let deque = zelf.deque.borrow_deque(); + // if pos < deque.len() { + // let ret = deque[pos].clone(); + // Ok(ret) + // } else { + // zelf.status.store(Exhausted); + // Err(vm.new_stop_iteration()) + // } + // } + // } + // } } } @@ -671,10 +702,12 @@ mod _collections { #[pyclass(name = "_deque_reverse_iterator")] #[derive(Debug, PyValue)] struct PyReverseDequeIterator { - position: AtomicCell, - status: AtomicCell, - length: usize, // To track length immutability. - deque: PyDequeRef, + // position: AtomicCell, + // status: AtomicCell, + // length: usize, // To track length immutability. + // deque: PyDequeRef, + state: usize, + internal: PyRwLock, } impl SlotConstructor for PyReverseDequeIterator { @@ -686,14 +719,15 @@ mod _collections { (DequeIterArgs { deque, index }, _kwargs): Self::Args, vm: &VirtualMachine, ) -> PyResult { - let len = deque.len(); + // let len = deque.len(); let iter = PyDeque::reversed(deque)?; if let OptionalArg::Present(index) = index { let index = max(index, 0) as usize; - if len.le(&index) { - iter.status.store(Exhausted); - } - iter.position.store(len.saturating_sub(index)); + iter.internal.write().position = index; + // if len.le(&index) { + // iter.status.store(Exhausted); + // } + // iter.position.store(len.saturating_sub(index)); } iter.into_pyresult_with_type(vm, cls) } @@ -702,54 +736,78 @@ mod _collections { #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyReverseDequeIterator { #[pymethod(magic)] - fn length_hint(&self) -> usize { - match self.status.load() { - Active => self.position.load(), - Exhausted => 0, - } + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .rev_length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + // match self.status.load() { + // Active => self.position.load(), + // Exhausted => 0, + // } } #[pymethod(magic)] fn reduce( zelf: PyRef, vm: &VirtualMachine, - ) -> PyResult<(PyTypeRef, (PyDequeRef, PyObjectRef))> { + ) -> PyResult<(PyTypeRef, (PyObjectRef, PyObjectRef))> { + let internal = zelf.internal.read(); + let deque = match &internal.status { + Active(obj) => obj.clone(), + Exhausted => PyDeque::default().into_object(vm), + }; Ok(( zelf.clone_class(), - (zelf.deque.clone(), vm.ctx.new_int(zelf.position.load())), + (deque, vm.ctx.new_int(internal.position)), )) } } impl IteratorIterable for PyReverseDequeIterator {} impl SlotIterator for PyReverseDequeIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - Exhausted => Ok(PyIterReturn::StopIteration(None)), - Active => { - // If length changes while we iterate, set to Exhausted and bail. - if zelf.length != zelf.deque.len() { - zelf.status.store(Exhausted); - Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())) - } else { - let pos = zelf.position.fetch_sub(1) - 1; - let deque = zelf.deque.borrow_deque(); - if pos > 0 { - if let Some(obj) = deque.get(pos) { - return Ok(PyIterReturn::Return(obj.clone())); - } - } - // We either are == 0 or deque.get returned None. Either way, set status - // to exhausted and return last item if pos == 0. - zelf.status.store(Exhausted); - if pos == 0 { - // Can safely index directly. - return Ok(PyIterReturn::Return(deque[pos].clone())); - } - Ok(PyIterReturn::StopIteration(None)) + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.write().rev_next( + |obj, pos| { + let deque = obj.payload::().unwrap(); + if deque.state.load() != zelf.state { + return Err( + vm.new_runtime_error("Deque mutated during iteration".to_owned()) + ); } - } - } + let deque = deque.borrow_deque(); + deque + .get(pos) + .ok_or_else(|| vm.new_stop_iteration()) + .map(|x| x.clone()) + }, + vm, + ) + // match zelf.status.load() { + // Exhausted => Err(vm.new_stop_iteration()), + // Active => { + // // If length changes while we iterate, set to Exhausted and bail. + // if zelf.length != zelf.deque.len() { + // zelf.status.store(Exhausted); + // Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())) + // } else { + // let pos = zelf.position.fetch_sub(1) - 1; + // let deque = zelf.deque.borrow_deque(); + // if pos > 0 { + // if let Some(obj) = deque.get(pos) { + // return Ok(obj.clone()); + // } + // } + // // We either are == 0 or deque.get returned None. Either way, set status + // // to exhausted and return last item if pos == 0. + // zelf.status.store(Exhausted); + // if pos == 0 { + // // Can safely index directly. + // return Ok(deque[pos].clone()); + // } + // Err(vm.new_stop_iteration()) + // } + // } + // } } } } From b6fa8670e8a2e9e42151e669fdf4480a5514a454 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 22 Sep 2021 19:41:54 +0200 Subject: [PATCH 06/19] Refactor IterStatus with generic payload --- vm/src/builtins/bytearray.rs | 18 +++-- vm/src/builtins/bytes.rs | 18 +++-- vm/src/builtins/dict.rs | 103 +++++++---------------------- vm/src/builtins/enumerate.rs | 69 +++----------------- vm/src/builtins/iter.rs | 69 +++++++++++++------- vm/src/builtins/list.rs | 123 +++++------------------------------ vm/src/builtins/pystr.rs | 72 +++----------------- vm/src/builtins/set.rs | 91 ++++++++++++++++---------- vm/src/builtins/tuple.rs | 18 +++-- vm/src/dictdatatype.rs | 4 +- vm/src/stdlib/collections.rs | 112 ++++--------------------------- 11 files changed, 197 insertions(+), 500 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index f049e48035..8b1a8011cf 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -717,7 +717,7 @@ impl Unhashable for PyByteArray {} impl Iterable for PyByteArray { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyByteArrayIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -730,7 +730,7 @@ impl Iterable for PyByteArray { #[pyclass(module = false, name = "bytearray_iterator")] #[derive(Debug)] pub struct PyByteArrayIterator { - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for PyByteArrayIterator { @@ -743,14 +743,13 @@ impl PyValue for PyByteArrayIterator { impl PyByteArrayIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.read().reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_iter_reduce(|x| x.clone().into_object(), vm) } #[pymethod(magic)] @@ -762,8 +761,7 @@ impl IteratorIterable for PyByteArrayIterator {} impl SlotIterator for PyByteArrayIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().next( - |obj, pos| { - let bytearray = obj.payload::().unwrap(); + |bytearray, pos| { let buf = bytearray.borrow_buf(); buf.get(pos) .ok_or_else(|| vm.new_stop_iteration()) diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 5d710e6d50..200937a94c 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -574,7 +574,7 @@ impl Comparable for PyBytes { impl Iterable for PyBytes { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyBytesIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -583,7 +583,7 @@ impl Iterable for PyBytes { #[pyclass(module = false, name = "bytes_iterator")] #[derive(Debug)] pub struct PyBytesIterator { - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for PyBytesIterator { @@ -596,15 +596,14 @@ impl PyValue for PyBytesIterator { impl PyBytesIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.read().reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_iter_reduce(|x| x.clone().into_object(), vm) } #[pymethod(magic)] @@ -616,8 +615,7 @@ impl IteratorIterable for PyBytesIterator {} impl SlotIterator for PyBytesIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().next( - |obj, pos| { - let bytes = obj.payload::().unwrap(); + |bytes, pos| { bytes .as_bytes() .get(pos) diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index 50f8498c5d..39a6120660 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -704,11 +704,8 @@ macro_rules! dict_iterator { #[pyclass(module = false, name = $iter_class_name)] #[derive(Debug)] pub(crate) struct $iter_name { - // pub dict: PyDictRef, pub size: dictdatatype::DictSize, - // pub position: AtomicCell, - // pub status: AtomicCell, - pub internal: PyRwLock, + pub internal: PyRwLock>, } impl PyValue for $iter_name { @@ -721,10 +718,8 @@ macro_rules! dict_iterator { impl $iter_name { fn new(dict: PyDictRef) -> Self { $iter_name { - // position: AtomicCell::new(0), size: dict.size(), - // dict, - internal: PyRwLock::new(PositionIterInternal::new(dict.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(dict, 0)), } } @@ -732,12 +727,7 @@ macro_rules! dict_iterator { fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .read() - .length_hint(|obj| obj.payload::().map(|x| x.entries.len()), vm) - // if let IterStatus::Exhausted = self.status.load() { - // 0 - // } else { - // self.dict.entries.len_from_entry_index(self.position.load()) - // } + .length_hint(|obj| Some(obj.entries.len()), vm) } } @@ -745,54 +735,34 @@ macro_rules! dict_iterator { impl SlotIterator for $iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let internal = zelf.internal.write(); - if let IterStatus::Active(obj) = &internal.status { - let dict = obj.payload::().unwrap(); + let mut status = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.status); + let mut position = + PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.position); + if let IterStatus::Active(dict) = &*status { if dict.entries.has_changed_size(&zelf.size) { - internal.status = IterStatus::Exhausted; + *status = IterStatus::Exhausted; return Err(vm.new_runtime_error( "dictionary changed size during iteration".to_owned(), )); } - match dict.entries.next_entry(&mut internal.position) { + match dict.entries.next_entry(&mut *position) { Some((key, value)) => Ok(($result_fn)(vm, key, value)), None => { - internal.status = IterStatus::Exhausted; + *status = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } } } else { Err(vm.new_stop_iteration()) } - // match zelf.status.load() { - // IterStatus::Exhausted => Err(vm.new_stop_iteration()), - // IterStatus::Active => { - // if zelf.dict.entries.has_changed_size(&zelf.size) { - // zelf.status.store(IterStatus::Exhausted); - // return Err(vm.new_runtime_error( - // "dictionary changed size during iteration".to_owned(), - // )); - // } - // match zelf.dict.entries.next_entry_atomic(&zelf.position) { - // Some((key, value)) => Ok(($result_fn)(vm, key, value)), - // None => { - // zelf.status.store(IterStatus::Exhausted); - // Err(vm.new_stop_iteration()) - // } - // } - // } - // } } } #[pyclass(module = false, name = $reverse_iter_class_name)] #[derive(Debug)] pub(crate) struct $reverse_iter_name { - // pub dict: PyDictRef, pub size: dictdatatype::DictSize, - // pub position: AtomicCell, - // pub status: AtomicCell, - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for $reverse_iter_name { @@ -804,16 +774,11 @@ macro_rules! dict_iterator { #[pyimpl(with(SlotIterator))] impl $reverse_iter_name { fn new(dict: PyDictRef) -> Self { - let position = dict.entries.len().saturating_sub(1); + let size = dict.size(); + let position = size.entries_size.saturating_sub(1); $reverse_iter_name { - // position: AtomicCell::new(0), - size: dict.size(), - // dict, - // status: AtomicCell::new(IterStatus::Active), - internal: PyRwLock::new(PositionIterInternal::new( - dict.into_object(), - position, - )), + size, + internal: PyRwLock::new(PositionIterInternal::new(dict, position)), } } @@ -821,12 +786,7 @@ macro_rules! dict_iterator { fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .read() - .rev_length_hint(|obj| obj.payload::().map(|x| x.entries.len()), vm) - // if let IterStatus::Exhausted = self.status.load() { - // 0 - // } else { - // self.dict.entries.len_from_entry_index(self.position.load()) - // } + .rev_length_hint(|_| Some(self.size.entries_size), vm) } } @@ -834,43 +794,26 @@ macro_rules! dict_iterator { impl SlotIterator for $reverse_iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let internal = zelf.internal.write(); - if let IterStatus::Active(obj) = &internal.status { - let dict = obj.payload::().unwrap(); + let mut status = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.status); + let mut position = + PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.position); + if let IterStatus::Active(dict) = &*status { if dict.entries.has_changed_size(&zelf.size) { - internal.status = IterStatus::Exhausted; + *status = IterStatus::Exhausted; return Err(vm.new_runtime_error( "dictionary changed size during iteration".to_owned(), )); } - match dict.entries.prev_entry(&mut internal.position) { + match dict.entries.prev_entry(&mut *position) { Some((key, value)) => Ok(($result_fn)(vm, key, value)), None => { - internal.status = IterStatus::Exhausted; + *status = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } } } else { Err(vm.new_stop_iteration()) } - // match zelf.status.load() { - // IterStatus::Exhausted => Err(vm.new_stop_iteration()), - // IterStatus::Active => { - // if zelf.dict.entries.has_changed_size(&zelf.size) { - // zelf.status.store(IterStatus::Exhausted); - // return Err(vm.new_runtime_error( - // "dictionary changed size during iteration".to_owned(), - // )); - // } - // match zelf.dict.entries.next_entry_atomic_reversed(&zelf.position) { - // Some((key, value)) => Ok(($result_fn)(vm, key, value)), - // None => { - // zelf.status.store(IterStatus::Exhausted); - // Err(vm.new_stop_iteration()) - // } - // } - // } - // } } } }; diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index 3dcc589c3d..7ad93f55f8 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -72,10 +72,7 @@ impl SlotIterator for PyEnumerate { #[pyclass(module = false, name = "reversed")] #[derive(Debug)] pub struct PyReverseSequenceIterator { - internal: PyRwLock, - // pub position: AtomicCell, - // pub status: AtomicCell, - // pub obj: PyObjectRef, + internal: PyRwLock>, } impl PyValue for PyReverseSequenceIterator { @@ -89,10 +86,7 @@ impl PyReverseSequenceIterator { pub fn new(obj: PyObjectRef, len: usize) -> Self { let position = len.saturating_sub(1); Self { - internal: PyRwLock::new(PositionIterInternal::new(obj, position)) - // position: AtomicCell::new(len.saturating_sub(1)), - // status: AtomicCell::new(if len == 0 { Exhausted } else { Active }), - // obj, + internal: PyRwLock::new(PositionIterInternal::new(obj, position)), } } @@ -101,50 +95,18 @@ impl PyReverseSequenceIterator { self.internal .read() .rev_length_hint(|obj| vm.obj_len(obj).ok(), vm) - // Ok(match self.status.load() { - // Active => { - // let position = self.position.load(); - // if position > vm.obj_len(&self.obj)? { - // 0 - // } else { - // position + 1 - // } - // } - // Exhausted => 0, - // }) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.read().set_state(state, vm) - // // When we're exhausted, just return. - // if let Exhausted = self.status.load() { - // return Ok(()); - // } - // let len = vm.obj_len(&self.obj)?; - // let pos = state - // .payload::() - // .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - // let pos = std::cmp::min( - // try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), - // len.saturating_sub(1), - // ); - // self.position.store(pos); - // Ok(()) + self.internal.write().set_state(state, vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "reversed")?; - Ok(self.internal.read().reduce(iter, vm)) - // Ok(vm.ctx.new_tuple(match self.status.load() { - // Exhausted => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_tuple(vec![])])], - // Active => vec![ - // iter, - // vm.ctx.new_tuple(vec![self.obj.clone()]), - // vm.ctx.new_int(self.position.load()), - // ], - // })) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_reversed_reduce(|x| x.clone(), vm) } } @@ -152,23 +114,8 @@ impl IteratorIterable for PyReverseSequenceIterator {} impl SlotIterator for PyReverseSequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal - .read() + .write() .rev_next(|obj, pos| obj.get_item(pos, vm), vm) - // if let Exhausted = zelf.status.load() { - // return Err(vm.new_stop_iteration()); - // } - // let pos = zelf.position.fetch_sub(1); - // if pos == 0 { - // zelf.status.store(Exhausted); - // } - // match zelf.obj.get_item(pos, vm) { - // Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => { - // zelf.status.store(Exhausted); - // Err(vm.new_stop_iteration()) - // } - // // also catches stop_iteration => stop_iteration - // ret => ret, - // } } } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 0acae123cd..45e1fa1a81 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -14,21 +14,21 @@ use crossbeam_utils::atomic::AtomicCell; /// Marks status of iterator. #[derive(Debug, Clone)] -pub enum IterStatus { +pub enum IterStatus { /// Iterator hasn't raised StopIteration. - Active(PyObjectRef), + Active(T), /// Iterator has raised StopIteration. Exhausted, } #[derive(Debug)] -pub struct PositionIterInternal { - pub status: IterStatus, +pub struct PositionIterInternal { + pub status: IterStatus, pub position: usize, } -impl PositionIterInternal { - pub fn new(obj: PyObjectRef, position: usize) -> Self { +impl PositionIterInternal { + pub fn new(obj: T, position: usize) -> Self { Self { status: IterStatus::Active(obj), position, @@ -49,11 +49,14 @@ impl PositionIterInternal { } } - pub fn reduce(&self, func: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + fn _reduce(&self, func: PyObjectRef, f: F, vm: &VirtualMachine) -> PyObjectRef + where + F: FnOnce(&T) -> PyObjectRef, + { if let IterStatus::Active(obj) = &self.status { vm.ctx.new_tuple(vec![ func, - vm.ctx.new_tuple(vec![obj.clone()]), + vm.ctx.new_tuple(vec![f(obj)]), vm.ctx.new_int(self.position), ]) } else { @@ -62,9 +65,25 @@ impl PositionIterInternal { } } + pub fn builtin_iter_reduce(&self, f: F, vm: &VirtualMachine) -> PyObjectRef + where + F: FnOnce(&T) -> PyObjectRef, + { + let iter = vm.get_attribute(vm.builtins.clone(), "iter").unwrap(); + self._reduce(iter, f, vm) + } + + pub fn builtin_reversed_reduce(&self, f: F, vm: &VirtualMachine) -> PyObjectRef + where + F: FnOnce(&T) -> PyObjectRef, + { + let reversed = vm.get_attribute(vm.builtins.clone(), "reversed").unwrap(); + self._reduce(reversed, f, vm) + } + pub fn next(&mut self, f: F, vm: &VirtualMachine) -> PyResult where - F: FnOnce(&PyObjectRef, usize) -> PyResult, + F: FnOnce(&T, usize) -> PyResult, { if let IterStatus::Active(obj) = &self.status { match f(obj, self.position) { @@ -89,7 +108,7 @@ impl PositionIterInternal { pub fn rev_next(&mut self, f: F, vm: &VirtualMachine) -> PyResult where - F: FnOnce(&PyObjectRef, usize) -> PyResult, + F: FnOnce(&T, usize) -> PyResult, { if let IterStatus::Active(obj) = &self.status { match f(obj, self.position) { @@ -118,7 +137,7 @@ impl PositionIterInternal { pub fn length_hint(&self, f: F, vm: &VirtualMachine) -> PyObjectRef where - F: FnOnce(&PyObjectRef) -> Option, + F: FnOnce(&T) -> Option, { let len = if let IterStatus::Active(obj) = &self.status { if let Some(obj_len) = f(obj) { @@ -134,7 +153,7 @@ impl PositionIterInternal { pub fn rev_length_hint(&self, f: F, vm: &VirtualMachine) -> PyObjectRef where - F: FnOnce(&PyObjectRef) -> Option, + F: FnOnce(&T) -> Option, { if let IterStatus::Active(obj) = &self.status { if let Some(obj_len) = f(obj) { @@ -142,6 +161,7 @@ impl PositionIterInternal { return PyInt::from(self.position + 1).into_object(vm); } } + // FIXME: return NotImplemented? } PyInt::from(0).into_object(vm) } @@ -150,7 +170,7 @@ impl PositionIterInternal { #[pyclass(module = false, name = "iterator")] #[derive(Debug)] pub struct PySequenceIterator { - internal: PositionIterInternal, + internal: PyRwLock>, } impl PyValue for PySequenceIterator { @@ -163,31 +183,34 @@ impl PyValue for PySequenceIterator { impl PySequenceIterator { pub fn new(obj: PyObjectRef) -> Self { Self { - internal: PositionIterInternal::new(obj, 0), + internal: PyRwLock::new(PositionIterInternal::new(obj, 0)), } } #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.length_hint(|obj| vm.obj_len(obj).ok(), vm) + self.internal + .read() + .length_hint(|obj| vm.obj_len(obj).ok(), vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal.read().builtin_iter_reduce(|x| x.clone(), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.set_state(state, vm) + self.internal.write().set_state(state, vm) } } impl IteratorIterable for PySequenceIterator {} impl SlotIterator for PySequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.next(|obj, pos| obj.get_item(pos, vm), vm) + zelf.internal + .write() + .next(|obj, pos| obj.get_item(pos, vm), vm) } } @@ -195,7 +218,7 @@ impl SlotIterator for PySequenceIterator { #[derive(Debug)] pub struct PyCallableIterator { sentinel: PyObjectRef, - status: PyRwLock, + status: PyRwLock>, } impl PyValue for PyCallableIterator { @@ -209,7 +232,7 @@ impl PyCallableIterator { pub fn new(callable: ArgCallable, sentinel: PyObjectRef) -> Self { Self { sentinel, - status: PyRwLock::new(IterStatus::Active(callable.into_object())), + status: PyRwLock::new(IterStatus::Active(callable)), } } } @@ -218,7 +241,7 @@ impl IteratorIterable for PyCallableIterator {} impl SlotIterator for PyCallableIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if let IterStatus::Active(callable) = &*zelf.status.read() { - let ret = vm.invoke(callable, ())?; + let ret = callable.invoke((), vm)?; if vm.bool_eq(&ret, &zelf.sentinel)? { *zelf.status.write() = IterStatus::Exhausted; Err(vm.new_stop_iteration()) diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 4c2e0b2411..665658e052 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -162,16 +162,8 @@ impl PyList { #[pymethod(magic)] fn reversed(zelf: PyRef) -> PyListReverseIterator { let position = zelf.len().saturating_sub(1); - // Mark iterator as exhausted immediately if its empty. PyListReverseIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), position)) - // position: AtomicCell::new(final_position.saturating_sub(1)), - // status: AtomicCell::new(if final_position == 0 { - // Exhausted - // } else { - // Active - // }), - // list: zelf, + internal: PyRwLock::new(PositionIterInternal::new(zelf, position)), } } @@ -420,7 +412,7 @@ impl PyList { impl Iterable for PyList { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyListIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -475,7 +467,7 @@ fn do_sort( #[pyclass(module = false, name = "list_iterator")] #[derive(Debug)] pub struct PyListIterator { - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for PyListIterator { @@ -488,9 +480,7 @@ impl PyValue for PyListIterator { impl PyListIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] @@ -499,9 +489,10 @@ impl PyListIterator { } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.read().reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_iter_reduce(|x| x.clone().into_object(), vm) } } @@ -509,8 +500,7 @@ impl IteratorIterable for PyListIterator {} impl SlotIterator for PyListIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().next( - |obj, pos| { - let list = obj.payload::().unwrap(); + |list, pos| { let vec = list.borrow_vec(); vec.get(pos) .ok_or_else(|| vm.new_stop_iteration()) @@ -524,9 +514,7 @@ impl SlotIterator for PyListIterator { #[pyclass(module = false, name = "list_reverseiterator")] #[derive(Debug)] pub struct PyListReverseIterator { - internal: PyRwLock, // pub position: AtomicCell, - // pub status: AtomicCell, - // pub list: PyListRef + internal: PyRwLock>, } impl PyValue for PyListReverseIterator { @@ -541,46 +529,19 @@ impl PyListReverseIterator { fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .read() - .rev_length_hint(|obj| obj.payload::().map(|x| x.len()), vm) - // match self.status.load() { - // Active => { - // let position = self.position.load(); - // if position > self.list.len() { - // // List was mutated. Report zero, next call to `__next__` will - // // fail and set iterator to Exhausted. - // 0 - // } else { - // position + 1 - // } - // } - // Exhausted => 0, - // } + .rev_length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - // if let Exhausted = self.status.load() { - // return Ok(()); - // } - - // // Max for position is list.len() - 1. - // let position = list_state(self.list.len().saturating_sub(1), state, vm)?; - // self.position.store(position); - // Ok(()) self.internal.write().set_state(state, vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "reversed")?; - Ok(self.internal.read().reduce(iter, vm)) - // let pos = if let Exhausted = self.status.load() { - // None - // } else { - // Some(self.position.load()) - // }; - // list_reduce(self.list.clone(), pos, true, vm) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_reversed_reduce(|x| x.clone().into_object(), vm) } } @@ -588,8 +549,7 @@ impl IteratorIterable for PyListReverseIterator {} impl SlotIterator for PyListReverseIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().rev_next( - |obj, pos| { - let list = obj.payload::().unwrap(); + |list, pos| { let vec = list.borrow_vec(); vec.get(pos) .ok_or_else(|| vm.new_stop_iteration()) @@ -597,60 +557,9 @@ impl SlotIterator for PyListReverseIterator { }, vm, ) - // if let Exhausted = zelf.status.load() { - // return Err(vm.new_stop_iteration()); - // } - // let list = zelf.list.borrow_vec(); - // let pos = zelf.position.fetch_sub(1); - // if pos > 0 { - // if let Some(obj) = list.get(pos) { - // return Ok(obj.clone()); - // } - // } - // // We either are == 0 or list.get returned None. Either way, set status - // // to exhausted and return last item if pos == 0. - // zelf.status.store(Exhausted); - // if pos == 0 { - // if let Some(obj) = list.get(pos) { - // return Ok(obj.clone()); - // } - // } - // Err(vm.new_stop_iteration()) } } -// Common reducer for forward and reverse list iterators. -fn list_reduce( - list: PyRef, - position: Option, - reverse: bool, - vm: &VirtualMachine, -) -> PyResult { - let attr = if reverse { "reversed" } else { "iter" }; - let iter = vm.get_attribute(vm.builtins.clone(), attr)?; - let elems = match position { - None => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])], - Some(position) => vec![ - iter, - vm.ctx.new_tuple(vec![list.into_object()]), - vm.ctx.new_int(position), - ], - }; - Ok(vm.ctx.new_tuple(elems)) -} - -// Common function to extract state. Clamps it in range [0, length]. -fn list_state(length: usize, state: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let position = state - .payload::() - .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - let position = std::cmp::min( - int::try_to_primitive(position.as_bigint(), vm).unwrap_or(0), - length, - ); - Ok(position) -} - pub fn init(context: &PyContext) { let list_type = &context.types.list_type; PyList::extend_class(context, list_type); diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index 70e5f33835..71eb569b24 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -19,7 +19,6 @@ use crate::{ PyObjectRef, PyRef, PyResult, PyValue, TryIntoRef, TypeProtocol, VirtualMachine, }; use bstr::ByteSlice; -use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; use num_traits::ToPrimitive; use rustpython_common::{ @@ -170,10 +169,7 @@ impl TryIntoRef for &str { #[pyclass(module = false, name = "str_iterator")] #[derive(Debug)] pub struct PyStrIterator { - // string: PyStrRef, - // position: PyAtomic, - // status: AtomicCell, - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for PyStrIterator { @@ -186,49 +182,27 @@ impl PyValue for PyStrIterator { impl PyStrIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) - // match self.status.load() { - // Active => { - // let pos = self.position.load(atomic::Ordering::SeqCst); - // self.string.len().saturating_sub(pos) - // } - // Exhausted => 0, - // } + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal.write().set_state(state, vm) - // // When we're exhausted, just return. - // if let Exhausted = self.status.load() { - // return Ok(()); - // } - // let pos = state - // .payload::() - // .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - // let pos = std::cmp::min( - // try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), - // self.string.len(), - // ); - // self.position.store(pos, atomic::Ordering::SeqCst); - // Ok(()) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.read().reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_iter_reduce(|x| x.clone().into_object(), vm) } } impl IteratorIterable for PyStrIterator {} impl SlotIterator for PyStrIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let internal = zelf.internal.write(); - if let IterStatus::Active(obj) = &internal.status { - let s = obj.payload::().unwrap(); + let mut internal = zelf.internal.write(); + if let IterStatus::Active(s) = &internal.status { let value = s.as_str(); if internal.position >= value.len() { internal.status = Exhausted; @@ -245,31 +219,6 @@ impl SlotIterator for PyStrIterator { } else { Err(vm.new_stop_iteration()) } - // if let Exhausted = zelf.status.load() { - // return Err(vm.new_stop_iteration()); - // } - // let value = &*zelf.string.as_str(); - // let mut start = zelf.position.load(atomic::Ordering::SeqCst); - // loop { - // if start == value.len() { - // zelf.status.store(Exhausted); - // return Err(vm.new_stop_iteration()); - // } - // let ch = value[start..].chars().next().ok_or_else(|| { - // zelf.status.store(Exhausted); - // vm.new_stop_iteration() - // })?; - - // match zelf.position.compare_exchange_weak( - // start, - // start + ch.len_utf8(), - // atomic::Ordering::Release, - // atomic::Ordering::Relaxed, - // ) { - // Ok(_) => break Ok(ch.into_pyobject(vm)), - // Err(cur) => start = cur, - // } - // } } } @@ -1304,10 +1253,7 @@ impl Comparable for PyStr { impl Iterable for PyStr { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyStrIterator { - // position: Radium::new(0), - // string: zelf, - // status: AtomicCell::new(Active), - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index f162f58910..7bcaa7988f 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -1,3 +1,4 @@ +use super::PositionIterInternal; /* * Builtin set type with a sequence of unique items. */ @@ -15,7 +16,7 @@ use crate::{ IdProtocol, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use crossbeam_utils::atomic::AtomicCell; +use rustpython_common::lock::{PyRwLock, PyRwLockWriteGuard}; use std::fmt; pub type SetContentType = dictdatatype::Dict<()>; @@ -194,10 +195,12 @@ impl PySetInner { fn iter(&self) -> PySetIterator { PySetIterator { - dict: PyRc::clone(&self.content), size: self.content.size(), - position: AtomicCell::new(0), - status: AtomicCell::new(IterStatus::Active), + internal: PyRwLock::new(PositionIterInternal::new(self.content.clone(), 0)) + // dict: PyRc::clone(&self.content), + // size: self.content.size(), + // position: AtomicCell::new(0), + // status: AtomicCell::new(IterStatus::Active), } } @@ -815,10 +818,8 @@ impl TryFromObject for SetIterable { #[pyclass(module = false, name = "set_iterator")] pub(crate) struct PySetIterator { - dict: PyRc, size: DictSize, - position: AtomicCell, - status: AtomicCell, + internal: PyRwLock>>, } impl fmt::Debug for PySetIterator { @@ -837,26 +838,27 @@ impl PyValue for PySetIterator { #[pyimpl(with(SlotIterator))] impl PySetIterator { #[pymethod(magic)] - fn length_hint(&self) -> usize { - if let IterStatus::Exhausted = self.status.load() { - 0 - } else { - self.dict.len_from_entry_index(self.position.load()) - } + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .length_hint(|_| Some(self.size.entries_size), vm) + // if let IterStatus::Exhausted = self.status.load() { + // 0 + // } else { + // self.dict.len_from_entry_index(self.position.load()) + // } } #[pymethod(magic)] fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, (PyObjectRef,))> { + let internal = zelf.internal.read(); Ok(( vm.get_attribute(vm.builtins.clone(), "iter")?, - (vm.ctx.new_list(match zelf.status.load() { + (vm.ctx.new_list(match &internal.status { IterStatus::Exhausted => vec![], - IterStatus::Active => zelf - .dict - .keys() - .into_iter() - .skip(zelf.position.load()) - .collect(), + IterStatus::Active(dict) => { + dict.keys().into_iter().skip(internal.position).collect() + } }),), )) } @@ -864,25 +866,42 @@ impl PySetIterator { impl IteratorIterable for PySetIterator {} impl SlotIterator for PySetIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - IterStatus::Exhausted => Ok(PyIterReturn::StopIteration(None)), - IterStatus::Active => { - if zelf.dict.has_changed_size(&zelf.size) { - zelf.status.store(IterStatus::Exhausted); - return Err( - vm.new_runtime_error("set changed size during iteration".to_owned()) - ); - } - match zelf.dict.next_entry_atomic(&zelf.position) { - Some((key, _)) => Ok(PyIterReturn::Return(key)), - None => { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + let mut status = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.status); + let mut position = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.position); + if let IterStatus::Active(dict) = &*status { + if dict.has_changed_size(&zelf.size) { + *status = IterStatus::Exhausted; + return Err(vm.new_runtime_error("set changed size during iteration".to_owned())); + } + match dict.next_entry(&mut *position) { + Some((key, _)) => Ok(key), + None => { + *status = IterStatus::Exhausted; + Err(vm.new_stop_iteration()) } } + } else { + Err(vm.new_stop_iteration()) } + // match zelf.status.load() { + // IterStatus::Exhausted => Err(vm.new_stop_iteration()), + // IterStatus::Active => { + // if zelf.dict.has_changed_size(&zelf.size) { + // zelf.status.store(IterStatus::Exhausted); + // return Err( + // vm.new_runtime_error("set changed size during iteration".to_owned()) + // ); + // } + // match zelf.dict.next_entry_atomic(&zelf.position) { + // Some((key, _)) => Ok(key), + // None => { + // zelf.status.store(IterStatus::Exhausted); + // Err(vm.new_stop_iteration()) + // } + // } + // } + // } } } diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 833a7552d7..c8aa6b95ea 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -310,7 +310,7 @@ impl Comparable for PyTuple { impl Iterable for PyTuple { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyTupleIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -319,7 +319,7 @@ impl Iterable for PyTuple { #[pyclass(module = false, name = "tuple_iterator")] #[derive(Debug)] pub(crate) struct PyTupleIterator { - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for PyTupleIterator { @@ -332,9 +332,7 @@ impl PyValue for PyTupleIterator { impl PyTupleIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] @@ -343,9 +341,10 @@ impl PyTupleIterator { } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.read().reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_iter_reduce(|x| x.clone().into_object(), vm) } } @@ -353,8 +352,7 @@ impl IteratorIterable for PyTupleIterator {} impl SlotIterator for PyTupleIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().next( - |obj, pos| { - let tuple = obj.payload::().unwrap(); + |tuple, pos| { tuple .as_slice() .get(pos) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 996b757789..6bb2cbd8ab 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -110,8 +110,8 @@ struct DictEntry { #[derive(Debug, PartialEq)] pub struct DictSize { indices_size: usize, - entries_size: usize, - used: usize, + pub entries_size: usize, + pub used: usize, filled: usize, } diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 8c4aa28c9f..8809b2a4a2 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -6,7 +6,7 @@ mod _collections { use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; use crate::{ builtins::{ - IterStatus::{self, Active, Exhausted}, + IterStatus::{Active, Exhausted}, PyInt, PyTypeRef, }, function::{FuncArgs, KwArgs, OptionalArg}, @@ -358,13 +358,10 @@ mod _collections { #[pymethod(magic)] fn reversed(zelf: PyRef) -> PyResult { - // let length = zelf.len(); + let position = zelf.len().saturating_sub(1); Ok(PyReverseDequeIterator { state: zelf.state.load(), - internal: PyRwLock::new(PositionIterInternal::new( - zelf.into_object(), - zelf.len() - 1, - )), + internal: PyRwLock::new(PositionIterInternal::new(zelf, position)), }) } @@ -575,12 +572,8 @@ mod _collections { #[pyclass(name = "_deque_iterator")] #[derive(Debug, PyValue)] struct PyDequeIterator { - // position: AtomicCell, - // status: AtomicCell, - // length: usize, // To track length immutability. - // deque: PyDequeRef, state: usize, - internal: PyRwLock, + internal: PyRwLock>, } #[derive(FromArgs)] @@ -600,16 +593,10 @@ mod _collections { (DequeIterArgs { deque, index }, _kwargs): Self::Args, vm: &VirtualMachine, ) -> PyResult { - // let len = deque.len(); let iter = PyDequeIterator::new(deque); if let OptionalArg::Present(index) = index { let index = max(index, 0) as usize; iter.internal.write().position = index; - // iter.position.store(min(index, len)); - - // if len.le(&index) { - // iter.status.store(Exhausted); - // } } iter.into_pyresult_with_type(vm, cls) } @@ -619,35 +606,25 @@ mod _collections { impl PyDequeIterator { pub(crate) fn new(deque: PyDequeRef) -> Self { PyDequeIterator { - // position: AtomicCell::new(0), - // status: AtomicCell::new(IterStatus::Active), - // length: deque.len(), - // deque, state: deque.state.load(), - internal: PyRwLock::new(PositionIterInternal::new(deque.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(deque, 0)), } } #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) - // match self.status.load() { - // Active => self.deque.len().saturating_sub(self.position.load()), - // Exhausted => 0, - // } + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] fn reduce( zelf: PyRef, vm: &VirtualMachine, - ) -> (PyTypeRef, (PyObjectRef, PyObjectRef)) { + ) -> (PyTypeRef, (PyDequeRef, PyObjectRef)) { let internal = zelf.internal.read(); let deque = match &internal.status { Active(obj) => obj.clone(), - Exhausted => PyDeque::default().into_object(vm), + Exhausted => PyDeque::default().into_ref(vm), }; ( zelf.clone_class(), @@ -660,8 +637,7 @@ mod _collections { impl SlotIterator for PyDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().next( - |obj, pos| { - let deque = obj.payload::().unwrap(); + |deque, pos| { if zelf.state != deque.state.load() { return Err( vm.new_runtime_error("Deque mutated during iteration".to_owned()) @@ -675,26 +651,6 @@ mod _collections { }, vm, ) - // match zelf.status.load() { - // Exhausted => Err(vm.new_stop_iteration()), - // Active => { - // if zelf.length != zelf.deque.len() { - // // Deque was changed while we iterated. - // zelf.status.store(Exhausted); - // Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())) - // } else { - // let pos = zelf.position.fetch_add(1); - // let deque = zelf.deque.borrow_deque(); - // if pos < deque.len() { - // let ret = deque[pos].clone(); - // Ok(ret) - // } else { - // zelf.status.store(Exhausted); - // Err(vm.new_stop_iteration()) - // } - // } - // } - // } } } @@ -702,12 +658,8 @@ mod _collections { #[pyclass(name = "_deque_reverse_iterator")] #[derive(Debug, PyValue)] struct PyReverseDequeIterator { - // position: AtomicCell, - // status: AtomicCell, - // length: usize, // To track length immutability. - // deque: PyDequeRef, state: usize, - internal: PyRwLock, + internal: PyRwLock>, } impl SlotConstructor for PyReverseDequeIterator { @@ -719,15 +671,10 @@ mod _collections { (DequeIterArgs { deque, index }, _kwargs): Self::Args, vm: &VirtualMachine, ) -> PyResult { - // let len = deque.len(); let iter = PyDeque::reversed(deque)?; if let OptionalArg::Present(index) = index { let index = max(index, 0) as usize; iter.internal.write().position = index; - // if len.le(&index) { - // iter.status.store(Exhausted); - // } - // iter.position.store(len.saturating_sub(index)); } iter.into_pyresult_with_type(vm, cls) } @@ -739,22 +686,18 @@ mod _collections { fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .read() - .rev_length_hint(|obj| obj.payload::().map(|x| x.len()), vm) - // match self.status.load() { - // Active => self.position.load(), - // Exhausted => 0, - // } + .rev_length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] fn reduce( zelf: PyRef, vm: &VirtualMachine, - ) -> PyResult<(PyTypeRef, (PyObjectRef, PyObjectRef))> { + ) -> PyResult<(PyTypeRef, (PyDequeRef, PyObjectRef))> { let internal = zelf.internal.read(); let deque = match &internal.status { Active(obj) => obj.clone(), - Exhausted => PyDeque::default().into_object(vm), + Exhausted => PyDeque::default().into_ref(vm), }; Ok(( zelf.clone_class(), @@ -767,8 +710,7 @@ mod _collections { impl SlotIterator for PyReverseDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().rev_next( - |obj, pos| { - let deque = obj.payload::().unwrap(); + |deque, pos| { if deque.state.load() != zelf.state { return Err( vm.new_runtime_error("Deque mutated during iteration".to_owned()) @@ -782,32 +724,6 @@ mod _collections { }, vm, ) - // match zelf.status.load() { - // Exhausted => Err(vm.new_stop_iteration()), - // Active => { - // // If length changes while we iterate, set to Exhausted and bail. - // if zelf.length != zelf.deque.len() { - // zelf.status.store(Exhausted); - // Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())) - // } else { - // let pos = zelf.position.fetch_sub(1) - 1; - // let deque = zelf.deque.borrow_deque(); - // if pos > 0 { - // if let Some(obj) = deque.get(pos) { - // return Ok(obj.clone()); - // } - // } - // // We either are == 0 or deque.get returned None. Either way, set status - // // to exhausted and return last item if pos == 0. - // zelf.status.store(Exhausted); - // if pos == 0 { - // // Can safely index directly. - // return Ok(deque[pos].clone()); - // } - // Err(vm.new_stop_iteration()) - // } - // } - // } } } } From 4676dd902eea32839944c80ff0adfb8056cb47cf Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 23 Sep 2021 10:05:05 +0200 Subject: [PATCH 07/19] clear up dictdatatype and fix arithmetic overflow --- vm/src/builtins/iter.rs | 2 +- vm/src/dictdatatype.rs | 29 +---------------------------- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 45e1fa1a81..caaed7dcec 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -218,7 +218,7 @@ impl SlotIterator for PySequenceIterator { #[derive(Debug)] pub struct PyCallableIterator { sentinel: PyObjectRef, - status: PyRwLock>, + status: PyRwLock>, } impl PyValue for PyCallableIterator { diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 6bb2cbd8ab..c542eb439d 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -9,7 +9,6 @@ use crate::common::{ }; use crate::vm::VirtualMachine; use crate::{IdProtocol, IntoPyObject, PyObjectRef, PyRefExact, PyResult, TypeProtocol}; -use crossbeam_utils::atomic::AtomicCell; use std::fmt; use std::mem::size_of; @@ -501,34 +500,8 @@ impl Dict { } } - pub fn next_entry_atomic(&self, position: &AtomicCell) -> Option<(PyObjectRef, T)> { - let inner = self.read(); - loop { - let position_usize = position.fetch_add(1); - let entry = inner.entries.get(position_usize)?; - if let Some(entry) = entry { - break Some((entry.key.clone(), entry.value.clone())); - } - } - } - - pub fn next_entry_atomic_reversed( - &self, - position: &AtomicCell, - ) -> Option<(PyObjectRef, T)> { - let inner = self.read(); - loop { - let position_usize = position.fetch_add(1); - let position_index = inner.entries.len().checked_sub(position_usize + 1)?; - let entry = inner.entries.get(position_index)?; - if let Some(entry) = entry { - break Some((entry.key.clone(), entry.value.clone())); - } - } - } - pub fn len_from_entry_index(&self, position: EntryIndex) -> usize { - self.read().entries.len() - position + self.read().entries.len().saturating_sub(position) } pub fn has_changed_size(&self, old: &DictSize) -> bool { From 1cd3d0395149e1da64e3c95082cc438b16cc0a5f Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 23 Sep 2021 11:27:03 +0200 Subject: [PATCH 08/19] fix deadlock and clear up --- vm/src/builtins/dict.rs | 42 ++++++++++++++++++++++++----------------- vm/src/builtins/iter.rs | 7 +++++-- vm/src/builtins/set.rs | 41 ++++++++++------------------------------ vm/src/dictdatatype.rs | 16 ++++++++-------- 4 files changed, 48 insertions(+), 58 deletions(-) diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index 39a6120660..d4ff3bd86a 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -587,7 +587,9 @@ impl Iterator for DictIter { type Item = (PyObjectRef, PyObjectRef); fn next(&mut self) -> Option { - self.dict.entries.next_entry(&mut self.position) + let (position, key, value) = self.dict.entries.next_entry(self.position)?; + self.position = position; + Some((key, value)) } fn size_hint(&self) -> (usize, Option) { @@ -735,20 +737,21 @@ macro_rules! dict_iterator { impl SlotIterator for $iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let mut status = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.status); - let mut position = - PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.position); - if let IterStatus::Active(dict) = &*status { + let mut internal = zelf.internal.write(); + if let IterStatus::Active(dict) = &internal.status { if dict.entries.has_changed_size(&zelf.size) { - *status = IterStatus::Exhausted; + internal.status = IterStatus::Exhausted; return Err(vm.new_runtime_error( "dictionary changed size during iteration".to_owned(), )); } - match dict.entries.next_entry(&mut *position) { - Some((key, value)) => Ok(($result_fn)(vm, key, value)), + match dict.entries.next_entry(internal.position) { + Some((position, key, value)) => { + internal.position = position; + Ok(($result_fn)(vm, key, value)) + } None => { - *status = IterStatus::Exhausted; + internal.status = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } } @@ -794,20 +797,25 @@ macro_rules! dict_iterator { impl SlotIterator for $reverse_iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let mut status = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.status); - let mut position = - PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.position); - if let IterStatus::Active(dict) = &*status { + let mut internal = zelf.internal.write(); + if let IterStatus::Active(dict) = &internal.status { if dict.entries.has_changed_size(&zelf.size) { - *status = IterStatus::Exhausted; + internal.status = IterStatus::Exhausted; return Err(vm.new_runtime_error( "dictionary changed size during iteration".to_owned(), )); } - match dict.entries.prev_entry(&mut *position) { - Some((key, value)) => Ok(($result_fn)(vm, key, value)), + match dict.entries.prev_entry(internal.position) { + Some((position, key, value)) => { + if internal.position == position { + internal.status = IterStatus::Exhausted; + } else { + internal.position = position; + } + Ok(($result_fn)(vm, key, value)) + } None => { - *status = IterStatus::Exhausted; + internal.status = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index caaed7dcec..b9e53f62c2 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -240,10 +240,13 @@ impl PyCallableIterator { impl IteratorIterable for PyCallableIterator {} impl SlotIterator for PyCallableIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if let IterStatus::Active(callable) = &*zelf.status.read() { + // let mut status = zelf.status.write(); + let status = zelf.status.upgradable_read(); + if let IterStatus::Active(callable) = &*status { let ret = callable.invoke((), vm)?; if vm.bool_eq(&ret, &zelf.sentinel)? { - *zelf.status.write() = IterStatus::Exhausted; + // *status = IterStatus::Exhausted; + *PyRwLockUpgradableReadGuard::upgrade(status) = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } else { Ok(ret) diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index 7bcaa7988f..e572691656 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -16,7 +16,7 @@ use crate::{ IdProtocol, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use rustpython_common::lock::{PyRwLock, PyRwLockWriteGuard}; +use rustpython_common::lock::PyRwLock; use std::fmt; pub type SetContentType = dictdatatype::Dict<()>; @@ -842,11 +842,6 @@ impl PySetIterator { self.internal .read() .length_hint(|_| Some(self.size.entries_size), vm) - // if let IterStatus::Exhausted = self.status.load() { - // 0 - // } else { - // self.dict.len_from_entry_index(self.position.load()) - // } } #[pymethod(magic)] @@ -867,41 +862,25 @@ impl PySetIterator { impl IteratorIterable for PySetIterator {} impl SlotIterator for PySetIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let mut status = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.status); - let mut position = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.position); - if let IterStatus::Active(dict) = &*status { + let mut internal = zelf.internal.write(); + if let IterStatus::Active(dict) = &internal.status { if dict.has_changed_size(&zelf.size) { - *status = IterStatus::Exhausted; + internal.status = IterStatus::Exhausted; return Err(vm.new_runtime_error("set changed size during iteration".to_owned())); } - match dict.next_entry(&mut *position) { - Some((key, _)) => Ok(key), + match dict.next_entry(internal.position) { + Some((position, key, _)) => { + internal.position = position; + Ok(key) + } None => { - *status = IterStatus::Exhausted; + internal.status = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } } } else { Err(vm.new_stop_iteration()) } - // match zelf.status.load() { - // IterStatus::Exhausted => Err(vm.new_stop_iteration()), - // IterStatus::Active => { - // if zelf.dict.has_changed_size(&zelf.size) { - // zelf.status.store(IterStatus::Exhausted); - // return Err( - // vm.new_runtime_error("set changed size during iteration".to_owned()) - // ); - // } - // match zelf.dict.next_entry_atomic(&zelf.position) { - // Some((key, _)) => Ok(key), - // None => { - // zelf.status.store(IterStatus::Exhausted); - // Err(vm.new_stop_iteration()) - // } - // } - // } - // } } } diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index c542eb439d..cfc6426b3c 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -478,24 +478,24 @@ impl Dict { self.read().size() } - pub fn next_entry(&self, position: &mut EntryIndex) -> Option<(PyObjectRef, T)> { + pub fn next_entry(&self, mut position: EntryIndex) -> Option<(usize, PyObjectRef, T)> { let inner = self.read(); loop { - let entry = inner.entries.get(*position)?; - *position += 1; + let entry = inner.entries.get(position)?; + position += 1; if let Some(entry) = entry { - break Some((entry.key.clone(), entry.value.clone())); + break Some((position, entry.key.clone(), entry.value.clone())); } } } - pub fn prev_entry(&self, position: &mut EntryIndex) -> Option<(PyObjectRef, T)> { + pub fn prev_entry(&self, mut position: EntryIndex) -> Option<(usize, PyObjectRef, T)> { let inner = self.read(); loop { - let entry = inner.entries.get(*position)?; - *position = position.checked_sub(1)?; + let entry = inner.entries.get(position)?; + position = position.saturating_sub(1); if let Some(entry) = entry { - break Some((entry.key.clone(), entry.value.clone())); + break Some((position, entry.key.clone(), entry.value.clone())); } } } From eab3609bde7867e9f31d870f9e97b62e36b21717 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sat, 25 Sep 2021 11:04:35 +0200 Subject: [PATCH 09/19] fix __length_hint__ --- vm/src/builtins/bytearray.rs | 8 +++-- vm/src/builtins/bytes.rs | 6 ++-- vm/src/builtins/dict.rs | 13 +++----- vm/src/builtins/enumerate.rs | 21 ++++++------ vm/src/builtins/iter.rs | 65 +++++++++++++++++++----------------- vm/src/builtins/list.rs | 17 +++------- vm/src/builtins/pystr.rs | 10 +++--- vm/src/builtins/set.rs | 6 ++-- vm/src/builtins/tuple.rs | 12 +++---- vm/src/stdlib/collections.rs | 10 +++--- 10 files changed, 77 insertions(+), 91 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 8b1a8011cf..ecf5183527 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -1,5 +1,7 @@ //! Implementation of the python bytearray object. -use super::{PyBytes, PyBytesRef, PyDictRef, PyIntRef, PyStrRef, PyTupleRef, PyTypeRef}; +use super::{ + PositionIterInternal, PyBytes, PyBytesRef, PyDictRef, PyIntRef, PyStrRef, PyTupleRef, PyTypeRef, +}; use crate::common::{ borrow::{BorrowedValue, BorrowedValueMut}, lock::{ @@ -742,8 +744,8 @@ impl PyValue for PyByteArrayIterator { #[pyimpl(with(SlotIterator))] impl PyByteArrayIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.read().length_hint(|obj| Some(obj.len()), vm) + fn length_hint(&self) -> usize { + self.internal.read().length_hint(|obj| obj.len()) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 200937a94c..32ac819f25 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -1,4 +1,4 @@ -use super::{PyDictRef, PyInt, PyIntRef, PyStrRef, PyTupleRef, PyTypeRef, int}; +use super::{PositionIterInternal, PyDictRef, PyIntRef, PyStrRef, PyTupleRef, PyTypeRef}; use crate::{ anystr::{self, AnyStr}, bytesinner::{ @@ -595,8 +595,8 @@ impl PyValue for PyBytesIterator { #[pyimpl(with(SlotIterator))] impl PyBytesIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.read().length_hint(|obj| Some(obj.len()), vm) + fn length_hint(&self) -> usize { + self.internal.read().length_hint(|obj| obj.len()) } #[pymethod(magic)] diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index d4ff3bd86a..dd688496cc 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -1,4 +1,4 @@ -use super::{IterStatus, PySet, PyStrRef, PyTypeRef}; +use super::{IterStatus, PositionIterInternal, PySet, PyStrRef, PyTypeRef}; use crate::{ builtins::PyBaseExceptionRef, common::ascii, @@ -14,7 +14,6 @@ use crate::{ PyAttributes, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use crossbeam_utils::atomic::AtomicCell; use rustpython_common::lock::PyRwLock; use std::fmt; use std::mem::size_of; @@ -726,10 +725,8 @@ macro_rules! dict_iterator { } #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| Some(obj.entries.len()), vm) + fn length_hint(&self) -> usize { + self.internal.read().length_hint(|_| self.size.entries_size) } } @@ -786,10 +783,10 @@ macro_rules! dict_iterator { } #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + fn length_hint(&self) -> usize { self.internal .read() - .rev_length_hint(|_| Some(self.size.entries_size), vm) + .rev_length_hint(|_| self.size.entries_size) } } diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index 7ad93f55f8..cf6a522d4b 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -1,17 +1,12 @@ -use super::{ - int, - IterStatus::{self, Active, Exhausted}, - PyInt, PyIntRef, PyTypeRef, -}; +use super::{IterStatus, PositionIterInternal, PyIntRef, PyTypeRef}; use crate::common::lock::PyRwLock; use crate::{ function::OptionalArg, protocol::{PyIter, PyIterReturn}, slots::{IteratorIterable, SlotConstructor, SlotIterator}, IntoPyObject, ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, - TypeProtocol, VirtualMachine, + VirtualMachine, }; -use crossbeam_utils::atomic::AtomicCell; use num_bigint::BigInt; use num_traits::Zero; @@ -91,10 +86,14 @@ impl PyReverseSequenceIterator { } #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .rev_length_hint(|obj| vm.obj_len(obj).ok(), vm) + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + let internal = self.internal.read(); + if let IterStatus::Active(obj) = &internal.status { + if internal.position <= vm.obj_len(obj)? { + return Ok(internal.position + 1); + } + } + Ok(0) } #[pymethod(magic)] diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index b9e53f62c2..32ae0b06cf 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -10,7 +10,7 @@ use crate::{ ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, VirtualMachine, }; -use crossbeam_utils::atomic::AtomicCell; +use rustpython_common::lock::{PyRwLock, PyRwLockUpgradableReadGuard}; /// Marks status of iterator. #[derive(Debug, Clone)] @@ -95,6 +95,10 @@ impl PositionIterInternal { self.status = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } + Err(e) if e.isinstance(&vm.ctx.exceptions.runtime_error) => { + self.status = IterStatus::Exhausted; + Err(e) + } Err(e) => Err(e), Ok(ret) => { self.position += 1; @@ -120,6 +124,10 @@ impl PositionIterInternal { self.status = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } + Err(e) if e.isinstance(&vm.ctx.exceptions.runtime_error) => { + self.status = IterStatus::Exhausted; + Err(e) + } Err(e) => Err(e), Ok(ret) => { if self.position == 0 { @@ -135,35 +143,27 @@ impl PositionIterInternal { } } - pub fn length_hint(&self, f: F, vm: &VirtualMachine) -> PyObjectRef + pub fn length_hint(&self, f: F) -> usize where - F: FnOnce(&T) -> Option, - { - let len = if let IterStatus::Active(obj) = &self.status { - if let Some(obj_len) = f(obj) { - obj_len.saturating_sub(self.position) - } else { - return vm.ctx.not_implemented(); - } - } else { - 0 - }; - PyInt::from(len).into_object(vm) - } - - pub fn rev_length_hint(&self, f: F, vm: &VirtualMachine) -> PyObjectRef - where - F: FnOnce(&T) -> Option, + F: FnOnce(&T) -> usize, { if let IterStatus::Active(obj) = &self.status { - if let Some(obj_len) = f(obj) { - if self.position <= obj_len { - return PyInt::from(self.position + 1).into_object(vm); - } - } - // FIXME: return NotImplemented? + f(obj).saturating_sub(self.position) + } else { + 0 } - PyInt::from(0).into_object(vm) + } + + pub fn rev_length_hint(&self, f: F) -> usize + where + F: FnOnce(&T) -> usize, + { + if let IterStatus::Active(obj) = &self.status { + if self.position <= f(obj) { + return self.position + 1; + } + } + 0 } } @@ -189,9 +189,14 @@ impl PySequenceIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| vm.obj_len(obj).ok(), vm) + let internal = self.internal.read(); + if let IterStatus::Active(obj) = &internal.status { + vm.obj_len(obj) + .map(|x| PyInt::from(x).into_object(vm)) + .unwrap_or_else(|_| vm.ctx.not_implemented()) + } else { + PyInt::from(0).into_object(vm) + } } #[pymethod(magic)] @@ -240,12 +245,10 @@ impl PyCallableIterator { impl IteratorIterable for PyCallableIterator {} impl SlotIterator for PyCallableIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - // let mut status = zelf.status.write(); let status = zelf.status.upgradable_read(); if let IterStatus::Active(callable) = &*status { let ret = callable.invoke((), vm)?; if vm.bool_eq(&ret, &zelf.sentinel)? { - // *status = IterStatus::Exhausted; *PyRwLockUpgradableReadGuard::upgrade(status) = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } else { diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 665658e052..8f9e0a8b60 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -1,8 +1,4 @@ -use super::{ - int, - iter::IterStatus::{self, Active, Exhausted}, - PyGenericAlias, PyInt, PySliceRef, PyTypeRef, -}; +use super::{PositionIterInternal, PyGenericAlias, PySliceRef, PyTypeRef}; use crate::common::lock::{ PyMappedRwLockReadGuard, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }; @@ -19,7 +15,6 @@ use crate::{ PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use crossbeam_utils::atomic::AtomicCell; use std::fmt; use std::iter::FromIterator; use std::mem::size_of; @@ -479,8 +474,8 @@ impl PyValue for PyListIterator { #[pyimpl(with(SlotIterator))] impl PyListIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.read().length_hint(|obj| Some(obj.len()), vm) + fn length_hint(&self) -> usize { + self.internal.read().length_hint(|obj| obj.len()) } #[pymethod(magic)] @@ -526,10 +521,8 @@ impl PyValue for PyListReverseIterator { #[pyimpl(with(SlotIterator))] impl PyListReverseIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .rev_length_hint(|obj| Some(obj.len()), vm) + fn length_hint(&self) -> usize { + self.internal.read().rev_length_hint(|obj| obj.len()) } #[pymethod(magic)] diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index 71eb569b24..a8672f25b7 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -1,7 +1,7 @@ use super::{ - int::{try_to_primitive, PyInt, PyIntRef}, - iter::IterStatus::{self, Active, Exhausted}, - PyBytesRef, PyDict, PyTypeRef, + int::{PyInt, PyIntRef}, + iter::IterStatus::{self, Exhausted}, + PositionIterInternal, PyBytesRef, PyDict, PyTypeRef, }; use crate::{ anystr::{self, adjust_indices, AnyStr, AnyStrContainer, AnyStrWrapper}, @@ -181,8 +181,8 @@ impl PyValue for PyStrIterator { #[pyimpl(with(SlotIterator))] impl PyStrIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.read().length_hint(|obj| Some(obj.len()), vm) + fn length_hint(&self) -> usize { + self.internal.read().length_hint(|obj| obj.len()) } #[pymethod(magic)] diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index e572691656..c6047c1496 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -838,10 +838,8 @@ impl PyValue for PySetIterator { #[pyimpl(with(SlotIterator))] impl PySetIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|_| Some(self.size.entries_size), vm) + fn length_hint(&self) -> usize { + self.internal.read().length_hint(|_| self.size.entries_size) } #[pymethod(magic)] diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index c8aa6b95ea..e23d2cc629 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -1,8 +1,4 @@ -use super::{ - int, - iter::IterStatus::{self, Active, Exhausted}, - PyInt, PyTypeRef, -}; +use super::{PositionIterInternal, PyTypeRef}; use crate::common::hash::PyHash; use crate::{ function::OptionalArg, @@ -19,7 +15,7 @@ use crate::{ PyContext, PyObjectRef, PyRef, PyResult, PyValue, TransmuteFromObject, TryFromObject, TypeProtocol, }; -use crossbeam_utils::atomic::AtomicCell; +use rustpython_common::lock::PyRwLock; use std::fmt; use std::marker::PhantomData; @@ -331,8 +327,8 @@ impl PyValue for PyTupleIterator { #[pyimpl(with(SlotIterator))] impl PyTupleIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.read().length_hint(|obj| Some(obj.len()), vm) + fn length_hint(&self) -> usize { + self.internal.read().length_hint(|obj| obj.len()) } #[pymethod(magic)] diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 8809b2a4a2..9ae8bac31b 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -612,8 +612,8 @@ mod _collections { } #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.read().length_hint(|obj| Some(obj.len()), vm) + fn length_hint(&self) -> usize { + self.internal.read().length_hint(|obj| obj.len()) } #[pymethod(magic)] @@ -683,10 +683,8 @@ mod _collections { #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyReverseDequeIterator { #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .rev_length_hint(|obj| Some(obj.len()), vm) + fn length_hint(&self) -> usize { + self.internal.read().rev_length_hint(|obj| obj.len()) } #[pymethod(magic)] From 1ade56fa8465162032726a43180011955422e559 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 26 Sep 2021 09:45:52 +0200 Subject: [PATCH 10/19] fix deque reversed iterator use the back counting --- vm/src/stdlib/collections.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 9ae8bac31b..dea1084589 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -358,10 +358,9 @@ mod _collections { #[pymethod(magic)] fn reversed(zelf: PyRef) -> PyResult { - let position = zelf.len().saturating_sub(1); Ok(PyReverseDequeIterator { state: zelf.state.load(), - internal: PyRwLock::new(PositionIterInternal::new(zelf, position)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), }) } @@ -659,6 +658,7 @@ mod _collections { #[derive(Debug, PyValue)] struct PyReverseDequeIterator { state: usize, + // position is counting from the tail internal: PyRwLock>, } @@ -684,7 +684,7 @@ mod _collections { impl PyReverseDequeIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().rev_length_hint(|obj| obj.len()) + self.internal.read().length_hint(|obj| obj.len()) } #[pymethod(magic)] @@ -707,7 +707,7 @@ mod _collections { impl IteratorIterable for PyReverseDequeIterator {} impl SlotIterator for PyReverseDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.write().rev_next( + zelf.internal.write().next( |deque, pos| { if deque.state.load() != zelf.state { return Err( @@ -715,6 +715,10 @@ mod _collections { ); } let deque = deque.borrow_deque(); + let pos = deque + .len() + .checked_sub(pos + 1) + .ok_or_else(|| vm.new_stop_iteration())?; deque .get(pos) .ok_or_else(|| vm.new_stop_iteration()) From 8b9c33c12a7e82d182c5ea5714675baf91c14b93 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 28 Sep 2021 07:31:35 +0200 Subject: [PATCH 11/19] Refactor rwlock -> mutex because multiple read is rare --- vm/src/builtins/bytearray.rs | 14 +++++++------- vm/src/builtins/bytes.rs | 18 ++++++++++-------- vm/src/builtins/dict.rs | 18 +++++++++--------- vm/src/builtins/enumerate.rs | 14 +++++++------- vm/src/builtins/iter.rs | 14 +++++++------- vm/src/builtins/list.rs | 26 +++++++++++++------------- vm/src/builtins/pystr.rs | 12 ++++++------ vm/src/builtins/set.rs | 16 ++++++---------- vm/src/builtins/tuple.rs | 14 +++++++------- vm/src/stdlib/collections.rs | 26 +++++++++++++------------- 10 files changed, 85 insertions(+), 87 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index ecf5183527..21a905069d 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -5,7 +5,7 @@ use super::{ use crate::common::{ borrow::{BorrowedValue, BorrowedValueMut}, lock::{ - PyMappedRwLockReadGuard, PyMappedRwLockWriteGuard, PyRwLock, PyRwLockReadGuard, + PyMappedRwLockReadGuard, PyMappedRwLockWriteGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }, }; @@ -719,7 +719,7 @@ impl Unhashable for PyByteArray {} impl Iterable for PyByteArray { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyByteArrayIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -732,7 +732,7 @@ impl Iterable for PyByteArray { #[pyclass(module = false, name = "bytearray_iterator")] #[derive(Debug)] pub struct PyByteArrayIterator { - internal: PyRwLock>, + internal: PyMutex>, } impl PyValue for PyByteArrayIterator { @@ -745,24 +745,24 @@ impl PyValue for PyByteArrayIterator { impl PyByteArrayIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().length_hint(|obj| obj.len()) + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal - .read() + .lock() .builtin_iter_reduce(|x| x.clone().into_object(), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.write().set_state(state, vm) + self.internal.lock().set_state(state, vm) } } impl IteratorIterable for PyByteArrayIterator {} impl SlotIterator for PyByteArrayIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.write().next( + zelf.internal.lock().next( |bytearray, pos| { let buf = bytearray.borrow_buf(); buf.get(pos) diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 32ac819f25..27ccce1dca 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -17,8 +17,10 @@ use crate::{ PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, VirtualMachine, }; use bstr::ByteSlice; -use rustpython_common::borrow::{BorrowedValue, BorrowedValueMut}; -use rustpython_common::lock::PyRwLock; +use rustpython_common::{ + borrow::{BorrowedValue, BorrowedValueMut}, + lock::PyMutex, +}; use std::mem::size_of; use std::ops::Deref; @@ -574,7 +576,7 @@ impl Comparable for PyBytes { impl Iterable for PyBytes { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyBytesIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -583,7 +585,7 @@ impl Iterable for PyBytes { #[pyclass(module = false, name = "bytes_iterator")] #[derive(Debug)] pub struct PyBytesIterator { - internal: PyRwLock>, + internal: PyMutex>, } impl PyValue for PyBytesIterator { @@ -596,25 +598,25 @@ impl PyValue for PyBytesIterator { impl PyBytesIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().length_hint(|obj| obj.len()) + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal - .read() + .lock() .builtin_iter_reduce(|x| x.clone().into_object(), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.write().set_state(state, vm) + self.internal.lock().set_state(state, vm) } } impl IteratorIterable for PyBytesIterator {} impl SlotIterator for PyBytesIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.write().next( + zelf.internal.lock().next( |bytes, pos| { bytes .as_bytes() diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index dd688496cc..26d56867fd 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -14,7 +14,7 @@ use crate::{ PyAttributes, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use rustpython_common::lock::PyRwLock; +use rustpython_common::lock::PyMutex; use std::fmt; use std::mem::size_of; @@ -706,7 +706,7 @@ macro_rules! dict_iterator { #[derive(Debug)] pub(crate) struct $iter_name { pub size: dictdatatype::DictSize, - pub internal: PyRwLock>, + pub internal: PyMutex>, } impl PyValue for $iter_name { @@ -720,13 +720,13 @@ macro_rules! dict_iterator { fn new(dict: PyDictRef) -> Self { $iter_name { size: dict.size(), - internal: PyRwLock::new(PositionIterInternal::new(dict, 0)), + internal: PyMutex::new(PositionIterInternal::new(dict, 0)), } } #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().length_hint(|_| self.size.entries_size) + self.internal.lock().length_hint(|_| self.size.entries_size) } } @@ -734,7 +734,7 @@ macro_rules! dict_iterator { impl SlotIterator for $iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let mut internal = zelf.internal.write(); + let mut internal = zelf.internal.lock(); if let IterStatus::Active(dict) = &internal.status { if dict.entries.has_changed_size(&zelf.size) { internal.status = IterStatus::Exhausted; @@ -762,7 +762,7 @@ macro_rules! dict_iterator { #[derive(Debug)] pub(crate) struct $reverse_iter_name { pub size: dictdatatype::DictSize, - internal: PyRwLock>, + internal: PyMutex>, } impl PyValue for $reverse_iter_name { @@ -778,14 +778,14 @@ macro_rules! dict_iterator { let position = size.entries_size.saturating_sub(1); $reverse_iter_name { size, - internal: PyRwLock::new(PositionIterInternal::new(dict, position)), + internal: PyMutex::new(PositionIterInternal::new(dict, position)), } } #[pymethod(magic)] fn length_hint(&self) -> usize { self.internal - .read() + .lock() .rev_length_hint(|_| self.size.entries_size) } } @@ -794,7 +794,7 @@ macro_rules! dict_iterator { impl SlotIterator for $reverse_iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let mut internal = zelf.internal.write(); + let mut internal = zelf.internal.lock(); if let IterStatus::Active(dict) = &internal.status { if dict.entries.has_changed_size(&zelf.size) { internal.status = IterStatus::Exhausted; diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index cf6a522d4b..2e5d6b35ee 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -1,5 +1,5 @@ use super::{IterStatus, PositionIterInternal, PyIntRef, PyTypeRef}; -use crate::common::lock::PyRwLock; +use crate::common::lock::{PyMutex, PyRwLock}; use crate::{ function::OptionalArg, protocol::{PyIter, PyIterReturn}, @@ -67,7 +67,7 @@ impl SlotIterator for PyEnumerate { #[pyclass(module = false, name = "reversed")] #[derive(Debug)] pub struct PyReverseSequenceIterator { - internal: PyRwLock>, + internal: PyMutex>, } impl PyValue for PyReverseSequenceIterator { @@ -81,13 +81,13 @@ impl PyReverseSequenceIterator { pub fn new(obj: PyObjectRef, len: usize) -> Self { let position = len.saturating_sub(1); Self { - internal: PyRwLock::new(PositionIterInternal::new(obj, position)), + internal: PyMutex::new(PositionIterInternal::new(obj, position)), } } #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyResult { - let internal = self.internal.read(); + let internal = self.internal.lock(); if let IterStatus::Active(obj) = &internal.status { if internal.position <= vm.obj_len(obj)? { return Ok(internal.position + 1); @@ -98,13 +98,13 @@ impl PyReverseSequenceIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.write().set_state(state, vm) + self.internal.lock().set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal - .read() + .lock() .builtin_reversed_reduce(|x| x.clone(), vm) } } @@ -113,7 +113,7 @@ impl IteratorIterable for PyReverseSequenceIterator {} impl SlotIterator for PyReverseSequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal - .write() + .lock() .rev_next(|obj, pos| obj.get_item(pos, vm), vm) } } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 32ae0b06cf..6b5e2da555 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -10,7 +10,7 @@ use crate::{ ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, VirtualMachine, }; -use rustpython_common::lock::{PyRwLock, PyRwLockUpgradableReadGuard}; +use rustpython_common::lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard}; /// Marks status of iterator. #[derive(Debug, Clone)] @@ -170,7 +170,7 @@ impl PositionIterInternal { #[pyclass(module = false, name = "iterator")] #[derive(Debug)] pub struct PySequenceIterator { - internal: PyRwLock>, + internal: PyMutex>, } impl PyValue for PySequenceIterator { @@ -183,13 +183,13 @@ impl PyValue for PySequenceIterator { impl PySequenceIterator { pub fn new(obj: PyObjectRef) -> Self { Self { - internal: PyRwLock::new(PositionIterInternal::new(obj, 0)), + internal: PyMutex::new(PositionIterInternal::new(obj, 0)), } } #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - let internal = self.internal.read(); + let internal = self.internal.lock(); if let IterStatus::Active(obj) = &internal.status { vm.obj_len(obj) .map(|x| PyInt::from(x).into_object(vm)) @@ -201,12 +201,12 @@ impl PySequenceIterator { #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.read().builtin_iter_reduce(|x| x.clone(), vm) + self.internal.lock().builtin_iter_reduce(|x| x.clone(), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.write().set_state(state, vm) + self.internal.lock().set_state(state, vm) } } @@ -214,7 +214,7 @@ impl IteratorIterable for PySequenceIterator {} impl SlotIterator for PySequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal - .write() + .lock() .next(|obj, pos| obj.get_item(pos, vm), vm) } } diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 8f9e0a8b60..a00c72a59e 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -1,6 +1,6 @@ use super::{PositionIterInternal, PyGenericAlias, PySliceRef, PyTypeRef}; use crate::common::lock::{ - PyMappedRwLockReadGuard, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, + PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }; use crate::{ function::{ArgIterable, FuncArgs, OptionalArg}, @@ -158,7 +158,7 @@ impl PyList { fn reversed(zelf: PyRef) -> PyListReverseIterator { let position = zelf.len().saturating_sub(1); PyListReverseIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf, position)), + internal: PyMutex::new(PositionIterInternal::new(zelf, position)), } } @@ -407,7 +407,7 @@ impl PyList { impl Iterable for PyList { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyListIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -462,7 +462,7 @@ fn do_sort( #[pyclass(module = false, name = "list_iterator")] #[derive(Debug)] pub struct PyListIterator { - internal: PyRwLock>, + internal: PyMutex>, } impl PyValue for PyListIterator { @@ -475,18 +475,18 @@ impl PyValue for PyListIterator { impl PyListIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().length_hint(|obj| obj.len()) + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.write().set_state(state, vm) + self.internal.lock().set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal - .read() + .lock() .builtin_iter_reduce(|x| x.clone().into_object(), vm) } } @@ -494,7 +494,7 @@ impl PyListIterator { impl IteratorIterable for PyListIterator {} impl SlotIterator for PyListIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.write().next( + zelf.internal.lock().next( |list, pos| { let vec = list.borrow_vec(); vec.get(pos) @@ -509,7 +509,7 @@ impl SlotIterator for PyListIterator { #[pyclass(module = false, name = "list_reverseiterator")] #[derive(Debug)] pub struct PyListReverseIterator { - internal: PyRwLock>, + internal: PyMutex>, } impl PyValue for PyListReverseIterator { @@ -522,18 +522,18 @@ impl PyValue for PyListReverseIterator { impl PyListReverseIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().rev_length_hint(|obj| obj.len()) + self.internal.lock().rev_length_hint(|obj| obj.len()) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.write().set_state(state, vm) + self.internal.lock().set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal - .read() + .lock() .builtin_reversed_reduce(|x| x.clone().into_object(), vm) } } @@ -541,7 +541,7 @@ impl PyListReverseIterator { impl IteratorIterable for PyListReverseIterator {} impl SlotIterator for PyListReverseIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.write().rev_next( + zelf.internal.lock().rev_next( |list, pos| { let vec = list.borrow_vec(); vec.get(pos) diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index a8672f25b7..dc154b11b9 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -169,7 +169,7 @@ impl TryIntoRef for &str { #[pyclass(module = false, name = "str_iterator")] #[derive(Debug)] pub struct PyStrIterator { - internal: PyRwLock>, + internal: PyMutex>, } impl PyValue for PyStrIterator { @@ -182,18 +182,18 @@ impl PyValue for PyStrIterator { impl PyStrIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().length_hint(|obj| obj.len()) + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.write().set_state(state, vm) + self.internal.lock().set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal - .read() + .lock() .builtin_iter_reduce(|x| x.clone().into_object(), vm) } } @@ -201,7 +201,7 @@ impl PyStrIterator { impl IteratorIterable for PyStrIterator {} impl SlotIterator for PyStrIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let mut internal = zelf.internal.write(); + let mut internal = zelf.internal.lock(); if let IterStatus::Active(s) = &internal.status { let value = s.as_str(); if internal.position >= value.len() { @@ -1253,7 +1253,7 @@ impl Comparable for PyStr { impl Iterable for PyStr { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyStrIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index c6047c1496..7bb5109680 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -16,7 +16,7 @@ use crate::{ IdProtocol, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use rustpython_common::lock::PyRwLock; +use rustpython_common::lock::PyMutex; use std::fmt; pub type SetContentType = dictdatatype::Dict<()>; @@ -196,11 +196,7 @@ impl PySetInner { fn iter(&self) -> PySetIterator { PySetIterator { size: self.content.size(), - internal: PyRwLock::new(PositionIterInternal::new(self.content.clone(), 0)) - // dict: PyRc::clone(&self.content), - // size: self.content.size(), - // position: AtomicCell::new(0), - // status: AtomicCell::new(IterStatus::Active), + internal: PyMutex::new(PositionIterInternal::new(self.content.clone(), 0)), } } @@ -819,7 +815,7 @@ impl TryFromObject for SetIterable { #[pyclass(module = false, name = "set_iterator")] pub(crate) struct PySetIterator { size: DictSize, - internal: PyRwLock>>, + internal: PyMutex>>, } impl fmt::Debug for PySetIterator { @@ -839,12 +835,12 @@ impl PyValue for PySetIterator { impl PySetIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().length_hint(|_| self.size.entries_size) + self.internal.lock().length_hint(|_| self.size.entries_size) } #[pymethod(magic)] fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, (PyObjectRef,))> { - let internal = zelf.internal.read(); + let internal = zelf.internal.lock(); Ok(( vm.get_attribute(vm.builtins.clone(), "iter")?, (vm.ctx.new_list(match &internal.status { @@ -860,7 +856,7 @@ impl PySetIterator { impl IteratorIterable for PySetIterator {} impl SlotIterator for PySetIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let mut internal = zelf.internal.write(); + let mut internal = zelf.internal.lock(); if let IterStatus::Active(dict) = &internal.status { if dict.has_changed_size(&zelf.size) { internal.status = IterStatus::Exhausted; diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index e23d2cc629..de3f427d3f 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -15,7 +15,7 @@ use crate::{ PyContext, PyObjectRef, PyRef, PyResult, PyValue, TransmuteFromObject, TryFromObject, TypeProtocol, }; -use rustpython_common::lock::PyRwLock; +use rustpython_common::lock::PyMutex; use std::fmt; use std::marker::PhantomData; @@ -306,7 +306,7 @@ impl Comparable for PyTuple { impl Iterable for PyTuple { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyTupleIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -315,7 +315,7 @@ impl Iterable for PyTuple { #[pyclass(module = false, name = "tuple_iterator")] #[derive(Debug)] pub(crate) struct PyTupleIterator { - internal: PyRwLock>, + internal: PyMutex>, } impl PyValue for PyTupleIterator { @@ -328,18 +328,18 @@ impl PyValue for PyTupleIterator { impl PyTupleIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().length_hint(|obj| obj.len()) + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.write().set_state(state, vm) + self.internal.lock().set_state(state, vm) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal - .read() + .lock() .builtin_iter_reduce(|x| x.clone().into_object(), vm) } } @@ -347,7 +347,7 @@ impl PyTupleIterator { impl IteratorIterable for PyTupleIterator {} impl SlotIterator for PyTupleIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.write().next( + zelf.internal.lock().next( |tuple, pos| { tuple .as_slice() diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index dea1084589..6444e8198b 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -3,7 +3,7 @@ pub(crate) use _collections::make_module; #[pymodule] mod _collections { use crate::builtins::PositionIterInternal; - use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; + use crate::common::lock::{PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; use crate::{ builtins::{ IterStatus::{Active, Exhausted}, @@ -360,7 +360,7 @@ mod _collections { fn reversed(zelf: PyRef) -> PyResult { Ok(PyReverseDequeIterator { state: zelf.state.load(), - internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), }) } @@ -572,7 +572,7 @@ mod _collections { #[derive(Debug, PyValue)] struct PyDequeIterator { state: usize, - internal: PyRwLock>, + internal: PyMutex>, } #[derive(FromArgs)] @@ -595,7 +595,7 @@ mod _collections { let iter = PyDequeIterator::new(deque); if let OptionalArg::Present(index) = index { let index = max(index, 0) as usize; - iter.internal.write().position = index; + iter.internal.lock().position = index; } iter.into_pyresult_with_type(vm, cls) } @@ -606,13 +606,13 @@ mod _collections { pub(crate) fn new(deque: PyDequeRef) -> Self { PyDequeIterator { state: deque.state.load(), - internal: PyRwLock::new(PositionIterInternal::new(deque, 0)), + internal: PyMutex::new(PositionIterInternal::new(deque, 0)), } } #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().length_hint(|obj| obj.len()) + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] @@ -620,7 +620,7 @@ mod _collections { zelf: PyRef, vm: &VirtualMachine, ) -> (PyTypeRef, (PyDequeRef, PyObjectRef)) { - let internal = zelf.internal.read(); + let internal = zelf.internal.lock(); let deque = match &internal.status { Active(obj) => obj.clone(), Exhausted => PyDeque::default().into_ref(vm), @@ -635,7 +635,7 @@ mod _collections { impl IteratorIterable for PyDequeIterator {} impl SlotIterator for PyDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.write().next( + zelf.internal.lock().next( |deque, pos| { if zelf.state != deque.state.load() { return Err( @@ -659,7 +659,7 @@ mod _collections { struct PyReverseDequeIterator { state: usize, // position is counting from the tail - internal: PyRwLock>, + internal: PyMutex>, } impl SlotConstructor for PyReverseDequeIterator { @@ -674,7 +674,7 @@ mod _collections { let iter = PyDeque::reversed(deque)?; if let OptionalArg::Present(index) = index { let index = max(index, 0) as usize; - iter.internal.write().position = index; + iter.internal.lock().position = index; } iter.into_pyresult_with_type(vm, cls) } @@ -684,7 +684,7 @@ mod _collections { impl PyReverseDequeIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.read().length_hint(|obj| obj.len()) + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] @@ -692,7 +692,7 @@ mod _collections { zelf: PyRef, vm: &VirtualMachine, ) -> PyResult<(PyTypeRef, (PyDequeRef, PyObjectRef))> { - let internal = zelf.internal.read(); + let internal = zelf.internal.lock(); let deque = match &internal.status { Active(obj) => obj.clone(), Exhausted => PyDeque::default().into_ref(vm), @@ -707,7 +707,7 @@ mod _collections { impl IteratorIterable for PyReverseDequeIterator {} impl SlotIterator for PyReverseDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.write().next( + zelf.internal.lock().next( |deque, pos| { if deque.state.load() != zelf.state { return Err( From fa8df88b5d7c86f9c2a265df84a839b5a3d71d94 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 28 Sep 2021 07:58:22 +0200 Subject: [PATCH 12/19] optimize avoid lookup for builtin iter and reversed --- vm/src/builtins/iter.rs | 16 +++++++++++++--- vm/src/builtins/pystr.rs | 1 + vm/src/builtins/range.rs | 3 ++- vm/src/builtins/set.rs | 8 +++----- vm/src/dictdatatype.rs | 2 +- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 6b5e2da555..1405f4bf59 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -10,7 +10,7 @@ use crate::{ ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, VirtualMachine, }; -use rustpython_common::lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard}; +use rustpython_common::lock::{OnceCell, PyMutex, PyRwLock, PyRwLockUpgradableReadGuard}; /// Marks status of iterator. #[derive(Debug, Clone)] @@ -69,7 +69,7 @@ impl PositionIterInternal { where F: FnOnce(&T) -> PyObjectRef, { - let iter = vm.get_attribute(vm.builtins.clone(), "iter").unwrap(); + let iter = get_builtin_attribute_iter(vm).clone(); self._reduce(iter, f, vm) } @@ -77,7 +77,7 @@ impl PositionIterInternal { where F: FnOnce(&T) -> PyObjectRef, { - let reversed = vm.get_attribute(vm.builtins.clone(), "reversed").unwrap(); + let reversed = get_builtin_attribute_reversed(vm).clone(); self._reduce(reversed, f, vm) } @@ -167,6 +167,16 @@ impl PositionIterInternal { } } +pub fn get_builtin_attribute_iter(vm: &VirtualMachine) -> &PyObjectRef { + static INSTANCE: OnceCell = OnceCell::new(); + INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "iter").unwrap()) +} + +pub fn get_builtin_attribute_reversed(vm: &VirtualMachine) -> &PyObjectRef { + static INSTANCE: OnceCell = OnceCell::new(); + INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "reversed").unwrap()) +} + #[pyclass(module = false, name = "iterator")] #[derive(Debug)] pub struct PySequenceIterator { diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index dc154b11b9..e1802c23e9 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -25,6 +25,7 @@ use rustpython_common::{ ascii, atomic::{self, PyAtomic, Radium}, hash, + lock::PyMutex, }; use std::mem::size_of; use std::ops::Range; diff --git a/vm/src/builtins/range.rs b/vm/src/builtins/range.rs index 4f6db0a594..d7b115ee5b 100644 --- a/vm/src/builtins/range.rs +++ b/vm/src/builtins/range.rs @@ -1,4 +1,5 @@ use super::{PyInt, PyIntRef, PySlice, PySliceRef, PyTypeRef}; +use crate::builtins::get_builtin_attribute_iter; use crate::common::hash::PyHash; use crate::{ function::{FuncArgs, OptionalArg}, @@ -608,7 +609,7 @@ fn range_iter_reduce( index: usize, vm: &VirtualMachine, ) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; + let iter = get_builtin_attribute_iter(vm).clone(); let stop = start.clone() + length * step.clone(); let range = PyRange { start: PyInt::from(start).into_ref(vm), diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index 7bb5109680..1f1b4fdbfc 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -1,9 +1,8 @@ -use super::PositionIterInternal; /* * Builtin set type with a sequence of unique items. */ -use super::{IterStatus, PyDictRef, PyTypeRef}; -use crate::common::{ascii, hash::PyHash, rc::PyRc}; +use super::{get_builtin_attribute_iter, IterStatus, PositionIterInternal, PyDictRef, PyTypeRef}; +use crate::common::{ascii, hash::PyHash, lock::PyMutex, rc::PyRc}; use crate::{ dictdatatype::{self, DictSize}, function::{ArgIterable, FuncArgs, OptionalArg, PosArgs}, @@ -16,7 +15,6 @@ use crate::{ IdProtocol, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use rustpython_common::lock::PyMutex; use std::fmt; pub type SetContentType = dictdatatype::Dict<()>; @@ -842,7 +840,7 @@ impl PySetIterator { fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, (PyObjectRef,))> { let internal = zelf.internal.lock(); Ok(( - vm.get_attribute(vm.builtins.clone(), "iter")?, + get_builtin_attribute_iter(vm).clone(), (vm.ctx.new_list(match &internal.status { IterStatus::Exhausted => vec![], IterStatus::Active(dict) => { diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index cfc6426b3c..8a5cd5a371 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -1,8 +1,8 @@ -use crate::builtins::{PyStr, PyStrRef}; /// Ordered dictionary implementation. /// Inspired by: https://morepypy.blogspot.com/2015/01/faster-more-memory-efficient-and-more.html /// And: https://www.youtube.com/watch?v=p33CVV29OG8 /// And: http://code.activestate.com/recipes/578375/ +use crate::builtins::{PyStr, PyStrRef}; use crate::common::{ hash, lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}, From a0d9ce030f5d5591898097edfef5340b914b22bf Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 28 Sep 2021 09:33:34 +0200 Subject: [PATCH 13/19] fix setstate saturated --- vm/src/builtins/bytearray.rs | 4 +++- vm/src/builtins/bytes.rs | 4 +++- vm/src/builtins/iter.rs | 35 ++++++++++++++++++++++++++++++++--- vm/src/builtins/list.rs | 8 ++++++-- vm/src/builtins/pystr.rs | 4 +++- vm/src/builtins/tuple.rs | 4 +++- 6 files changed, 50 insertions(+), 9 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 21a905069d..3f84751bd9 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -756,7 +756,9 @@ impl PyByteArrayIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.len(), vm) } } impl IteratorIterable for PyByteArrayIterator {} diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 27ccce1dca..4a849cab30 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -610,7 +610,9 @@ impl PyBytesIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.len(), vm) } } impl IteratorIterable for PyBytesIterator {} diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 1405f4bf59..476c66938c 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -10,7 +10,10 @@ use crate::{ ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, VirtualMachine, }; -use rustpython_common::lock::{OnceCell, PyMutex, PyRwLock, PyRwLockUpgradableReadGuard}; +use rustpython_common::{ + lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard}, + static_cell, +}; /// Marks status of iterator. #[derive(Debug, Clone)] @@ -49,6 +52,28 @@ impl PositionIterInternal { } } + pub fn set_state_saturated( + &mut self, + state: PyObjectRef, + f: F, + vm: &VirtualMachine, + ) -> PyResult<()> + where + F: FnOnce(&T) -> usize, + { + if let IterStatus::Active(obj) = &self.status { + if let Some(i) = state.payload::() { + let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0); + self.position = i.min(f(obj)); + Ok(()) + } else { + Err(vm.new_type_error("an integer is required.".to_owned())) + } + } else { + Ok(()) + } + } + fn _reduce(&self, func: PyObjectRef, f: F, vm: &VirtualMachine) -> PyObjectRef where F: FnOnce(&T) -> PyObjectRef, @@ -168,12 +193,16 @@ impl PositionIterInternal { } pub fn get_builtin_attribute_iter(vm: &VirtualMachine) -> &PyObjectRef { - static INSTANCE: OnceCell = OnceCell::new(); + static_cell! { + static INSTANCE: PyObjectRef; + } INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "iter").unwrap()) } pub fn get_builtin_attribute_reversed(vm: &VirtualMachine) -> &PyObjectRef { - static INSTANCE: OnceCell = OnceCell::new(); + static_cell! { + static INSTANCE: PyObjectRef; + } INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "reversed").unwrap()) } diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index a00c72a59e..c3ea794d1e 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -480,7 +480,9 @@ impl PyListIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.len(), vm) } #[pymethod(magic)] @@ -527,7 +529,9 @@ impl PyListReverseIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.len(), vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index e1802c23e9..6ecd64636c 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -188,7 +188,9 @@ impl PyStrIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.char_len(), vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index de3f427d3f..34b4e58200 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -333,7 +333,9 @@ impl PyTupleIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.len(), vm) } #[pymethod(magic)] From b4cbca0e8c11f393335b2642375841f8861b86a2 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 30 Sep 2021 21:43:23 +0200 Subject: [PATCH 14/19] rename functions for cached iter and reversed --- vm/src/builtins/iter.rs | 8 ++++---- vm/src/builtins/range.rs | 4 ++-- vm/src/builtins/set.rs | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 476c66938c..baa55eb413 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -94,7 +94,7 @@ impl PositionIterInternal { where F: FnOnce(&T) -> PyObjectRef, { - let iter = get_builtin_attribute_iter(vm).clone(); + let iter = builtins_iter(vm).clone(); self._reduce(iter, f, vm) } @@ -102,7 +102,7 @@ impl PositionIterInternal { where F: FnOnce(&T) -> PyObjectRef, { - let reversed = get_builtin_attribute_reversed(vm).clone(); + let reversed = builtins_reversed(vm).clone(); self._reduce(reversed, f, vm) } @@ -192,14 +192,14 @@ impl PositionIterInternal { } } -pub fn get_builtin_attribute_iter(vm: &VirtualMachine) -> &PyObjectRef { +pub fn builtins_iter(vm: &VirtualMachine) -> &PyObjectRef { static_cell! { static INSTANCE: PyObjectRef; } INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "iter").unwrap()) } -pub fn get_builtin_attribute_reversed(vm: &VirtualMachine) -> &PyObjectRef { +pub fn builtins_reversed(vm: &VirtualMachine) -> &PyObjectRef { static_cell! { static INSTANCE: PyObjectRef; } diff --git a/vm/src/builtins/range.rs b/vm/src/builtins/range.rs index d7b115ee5b..05c543d90e 100644 --- a/vm/src/builtins/range.rs +++ b/vm/src/builtins/range.rs @@ -1,5 +1,5 @@ use super::{PyInt, PyIntRef, PySlice, PySliceRef, PyTypeRef}; -use crate::builtins::get_builtin_attribute_iter; +use crate::builtins::builtins_iter; use crate::common::hash::PyHash; use crate::{ function::{FuncArgs, OptionalArg}, @@ -609,7 +609,7 @@ fn range_iter_reduce( index: usize, vm: &VirtualMachine, ) -> PyResult { - let iter = get_builtin_attribute_iter(vm).clone(); + let iter = builtins_iter(vm).clone(); let stop = start.clone() + length * step.clone(); let range = PyRange { start: PyInt::from(start).into_ref(vm), diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index 1f1b4fdbfc..d4ffc0c338 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -1,7 +1,7 @@ /* * Builtin set type with a sequence of unique items. */ -use super::{get_builtin_attribute_iter, IterStatus, PositionIterInternal, PyDictRef, PyTypeRef}; +use super::{builtins_iter, IterStatus, PositionIterInternal, PyDictRef, PyTypeRef}; use crate::common::{ascii, hash::PyHash, lock::PyMutex, rc::PyRc}; use crate::{ dictdatatype::{self, DictSize}, @@ -840,7 +840,7 @@ impl PySetIterator { fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, (PyObjectRef,))> { let internal = zelf.internal.lock(); Ok(( - get_builtin_attribute_iter(vm).clone(), + builtins_iter(vm).clone(), (vm.ctx.new_list(match &internal.status { IterStatus::Exhausted => vec![], IterStatus::Active(dict) => { From 70fc9102688d2efc73cab1d3e2d68e202b5baa1b Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 1 Oct 2021 14:26:48 +0200 Subject: [PATCH 15/19] refactor set_state --- vm/src/builtins/bytearray.rs | 2 +- vm/src/builtins/bytes.rs | 2 +- vm/src/builtins/enumerate.rs | 2 +- vm/src/builtins/iter.rs | 27 ++++----------------------- vm/src/builtins/list.rs | 4 ++-- vm/src/builtins/pystr.rs | 2 +- vm/src/builtins/tuple.rs | 2 +- 7 files changed, 11 insertions(+), 30 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 3f84751bd9..2a365a22ce 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -758,7 +758,7 @@ impl PyByteArrayIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.len(), vm) + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } } impl IteratorIterable for PyByteArrayIterator {} diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 4a849cab30..316ce8b086 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -612,7 +612,7 @@ impl PyBytesIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.len(), vm) + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } } impl IteratorIterable for PyBytesIterator {} diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index 2e5d6b35ee..596d228d01 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -98,7 +98,7 @@ impl PyReverseSequenceIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal.lock().set_state(state, |_, pos| pos, vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index baa55eb413..b285dbd820 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -38,33 +38,14 @@ impl PositionIterInternal { } } - pub fn set_state(&mut self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - if let IterStatus::Active(_) = &self.status { - if let Some(i) = state.payload::() { - let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0); - self.position = i; - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } - } else { - Ok(()) - } - } - - pub fn set_state_saturated( - &mut self, - state: PyObjectRef, - f: F, - vm: &VirtualMachine, - ) -> PyResult<()> + pub fn set_state(&mut self, state: PyObjectRef, f: F, vm: &VirtualMachine) -> PyResult<()> where - F: FnOnce(&T) -> usize, + F: FnOnce(&T, usize) -> usize, { if let IterStatus::Active(obj) = &self.status { if let Some(i) = state.payload::() { let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0); - self.position = i.min(f(obj)); + self.position = f(obj, i); Ok(()) } else { Err(vm.new_type_error("an integer is required.".to_owned())) @@ -245,7 +226,7 @@ impl PySequenceIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal.lock().set_state(state, |_, pos| pos, vm) } } diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index c3ea794d1e..8c7ae63593 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -482,7 +482,7 @@ impl PyListIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.len(), vm) + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } #[pymethod(magic)] @@ -531,7 +531,7 @@ impl PyListReverseIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.len(), vm) + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index 6ecd64636c..81b9152c78 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -190,7 +190,7 @@ impl PyStrIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.char_len(), vm) + .set_state(state, |obj, pos| pos.min(obj.char_len()), vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 34b4e58200..e692d636c1 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -335,7 +335,7 @@ impl PyTupleIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.len(), vm) + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } #[pymethod(magic)] From 87edbfece7ffa5140367ea0859a5dc254bbf6fb7 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 1 Oct 2021 15:48:18 +0200 Subject: [PATCH 16/19] fix str iterator position --- extra_tests/snippets/iterations.py | 5 ++++ vm/src/builtins/pystr.rs | 44 ++++++++++++++++++------------ 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/extra_tests/snippets/iterations.py b/extra_tests/snippets/iterations.py index 98031a935e..b2b30961c0 100644 --- a/extra_tests/snippets/iterations.py +++ b/extra_tests/snippets/iterations.py @@ -9,3 +9,8 @@ assert next(i) == 3 assert next(i, 'w00t') == 'w00t' +s = '你好' +i = iter(s) +i.__setstate__(1) +assert i.__next__() == '好' +assert i.__reduce__()[2] == 2 diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index 81b9152c78..249af8eafc 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -170,7 +170,7 @@ impl TryIntoRef for &str { #[pyclass(module = false, name = "str_iterator")] #[derive(Debug)] pub struct PyStrIterator { - internal: PyMutex>, + internal: PyMutex<(PositionIterInternal, usize)>, } impl PyValue for PyStrIterator { @@ -183,13 +183,15 @@ impl PyValue for PyStrIterator { impl PyStrIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - self.internal.lock().length_hint(|obj| obj.len()) + self.internal.lock().0.length_hint(|obj| obj.char_len()) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal - .lock() + let mut internal = self.internal.lock(); + internal.1 = usize::MAX; + internal + .0 .set_state(state, |obj, pos| pos.min(obj.char_len()), vm) } @@ -197,6 +199,7 @@ impl PyStrIterator { fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .lock() + .0 .builtin_iter_reduce(|x| x.clone().into_object(), vm) } } @@ -205,20 +208,27 @@ impl IteratorIterable for PyStrIterator {} impl SlotIterator for PyStrIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let mut internal = zelf.internal.lock(); - if let IterStatus::Active(s) = &internal.status { + + if let IterStatus::Active(s) = &internal.0.status { let value = s.as_str(); - if internal.position >= value.len() { - internal.status = Exhausted; - return Err(vm.new_stop_iteration()); + + if internal.1 == usize::MAX { + if let Some((offset, ch)) = value.char_indices().nth(internal.0.position) { + internal.0.position += 1; + internal.1 = offset + ch.len_utf8(); + return Ok(ch.into_pyobject(vm)); + } + } else { + if let Some(value) = value.get(internal.1..) { + if let Some(ch) = value.chars().next() { + internal.0.position += 1; + internal.1 += ch.len_utf8(); + return Ok(ch.into_pyobject(vm)); + } + } } - - let ch = value[internal.position..].chars().next().ok_or_else(|| { - internal.status = Exhausted; - vm.new_stop_iteration() - })?; - - internal.position += ch.len_utf8(); - Ok(ch.into_pyobject(vm)) + internal.0.status = Exhausted; + Err(vm.new_stop_iteration()) } else { Err(vm.new_stop_iteration()) } @@ -1256,7 +1266,7 @@ impl Comparable for PyStr { impl Iterable for PyStr { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyStrIterator { - internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), + internal: PyMutex::new((PositionIterInternal::new(zelf, 0), 0)), } .into_object(vm)) } From bff00893e22f24f7a40172342a4b9701c646411d Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 1 Oct 2021 16:49:12 +0200 Subject: [PATCH 17/19] fixup clippy --- vm/src/builtins/pystr.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index 249af8eafc..a6bf250bc8 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -218,13 +218,11 @@ impl SlotIterator for PyStrIterator { internal.1 = offset + ch.len_utf8(); return Ok(ch.into_pyobject(vm)); } - } else { - if let Some(value) = value.get(internal.1..) { - if let Some(ch) = value.chars().next() { - internal.0.position += 1; - internal.1 += ch.len_utf8(); - return Ok(ch.into_pyobject(vm)); - } + } else if let Some(value) = value.get(internal.1..) { + if let Some(ch) = value.chars().next() { + internal.0.position += 1; + internal.1 += ch.len_utf8(); + return Ok(ch.into_pyobject(vm)); } } internal.0.status = Exhausted; From 350c4de4e586c94c82379f95b514fb441b46aa26 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 1 Oct 2021 20:46:17 +0200 Subject: [PATCH 18/19] rebase to use PyIterReturn --- vm/src/builtins/bytearray.rs | 20 ++++--- vm/src/builtins/bytes.rs | 20 +++---- vm/src/builtins/dict.rs | 16 +++--- vm/src/builtins/enumerate.rs | 17 ++++-- vm/src/builtins/iter.rs | 102 +++++++++++++++-------------------- vm/src/builtins/list.rs | 40 +++++++------- vm/src/builtins/pystr.rs | 12 ++--- vm/src/builtins/set.rs | 8 +-- vm/src/builtins/tuple.rs | 21 ++++---- vm/src/stdlib/collections.rs | 60 +++++++++------------ 10 files changed, 145 insertions(+), 171 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 2a365a22ce..e79fb91a58 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -751,7 +751,7 @@ impl PyByteArrayIterator { fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .lock() - .builtin_iter_reduce(|x| x.clone().into_object(), vm) + .builtins_iter_reduce(|x| x.clone().into_object(), vm) } #[pymethod(magic)] @@ -763,15 +763,13 @@ impl PyByteArrayIterator { } impl IteratorIterable for PyByteArrayIterator {} impl SlotIterator for PyByteArrayIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.lock().next( - |bytearray, pos| { - let buf = bytearray.borrow_buf(); - buf.get(pos) - .ok_or_else(|| vm.new_stop_iteration()) - .map(|&x| vm.ctx.new_int(x)) - }, - vm, - ) + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.lock().next(|bytearray, pos| { + let buf = bytearray.borrow_buf(); + Ok(match buf.get(pos) { + Some(&x) => PyIterReturn::Return(vm.ctx.new_int(x)), + None => PyIterReturn::StopIteration(None), + }) + }) } } diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 316ce8b086..38b7185780 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -605,7 +605,7 @@ impl PyBytesIterator { fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .lock() - .builtin_iter_reduce(|x| x.clone().into_object(), vm) + .builtins_iter_reduce(|x| x.clone().into_object(), vm) } #[pymethod(magic)] @@ -617,17 +617,13 @@ impl PyBytesIterator { } impl IteratorIterable for PyBytesIterator {} impl SlotIterator for PyBytesIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.lock().next( - |bytes, pos| { - bytes - .as_bytes() - .get(pos) - .ok_or_else(|| vm.new_stop_iteration()) - .map(|&x| vm.ctx.new_int(x)) - }, - vm, - ) + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.lock().next(|bytes, pos| { + Ok(match bytes.as_bytes().get(pos) { + Some(&x) => PyIterReturn::Return(vm.ctx.new_int(x)), + None => PyIterReturn::StopIteration(None), + }) + }) } } diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index 26d56867fd..30087c0fbd 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -733,7 +733,7 @@ macro_rules! dict_iterator { impl IteratorIterable for $iter_name {} impl SlotIterator for $iter_name { #[allow(clippy::redundant_closure_call)] - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let mut internal = zelf.internal.lock(); if let IterStatus::Active(dict) = &internal.status { if dict.entries.has_changed_size(&zelf.size) { @@ -745,15 +745,15 @@ macro_rules! dict_iterator { match dict.entries.next_entry(internal.position) { Some((position, key, value)) => { internal.position = position; - Ok(($result_fn)(vm, key, value)) + Ok(PyIterReturn::Return(($result_fn)(vm, key, value))) } None => { internal.status = IterStatus::Exhausted; - Err(vm.new_stop_iteration()) + Ok(PyIterReturn::StopIteration(None)) } } } else { - Err(vm.new_stop_iteration()) + Ok(PyIterReturn::StopIteration(None)) } } } @@ -793,7 +793,7 @@ macro_rules! dict_iterator { impl IteratorIterable for $reverse_iter_name {} impl SlotIterator for $reverse_iter_name { #[allow(clippy::redundant_closure_call)] - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let mut internal = zelf.internal.lock(); if let IterStatus::Active(dict) = &internal.status { if dict.entries.has_changed_size(&zelf.size) { @@ -809,15 +809,15 @@ macro_rules! dict_iterator { } else { internal.position = position; } - Ok(($result_fn)(vm, key, value)) + Ok(PyIterReturn::Return(($result_fn)(vm, key, value))) } None => { internal.status = IterStatus::Exhausted; - Err(vm.new_stop_iteration()) + Ok(PyIterReturn::StopIteration(None)) } } } else { - Err(vm.new_stop_iteration()) + Ok(PyIterReturn::StopIteration(None)) } } } diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index 596d228d01..4922e491d9 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -1,5 +1,6 @@ use super::{IterStatus, PositionIterInternal, PyIntRef, PyTypeRef}; use crate::common::lock::{PyMutex, PyRwLock}; +use crate::TypeProtocol; use crate::{ function::OptionalArg, protocol::{PyIter, PyIterReturn}, @@ -105,16 +106,26 @@ impl PyReverseSequenceIterator { fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .lock() - .builtin_reversed_reduce(|x| x.clone(), vm) + .builtins_reversed_reduce(|x| x.clone(), vm) } } impl IteratorIterable for PyReverseSequenceIterator {} impl SlotIterator for PyReverseSequenceIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal .lock() - .rev_next(|obj, pos| obj.get_item(pos, vm), vm) + .rev_next(|obj, pos| match obj.get_item(pos, vm) { + Ok(ret) => Ok(PyIterReturn::Return(ret)), + Err(e) if e.isinstance(&vm.ctx.exceptions.index_error) => { + Ok(PyIterReturn::StopIteration(None)) + } + Err(e) if e.isinstance(&vm.ctx.exceptions.stop_iteration) => { + let args = e.get_arg(0); + Ok(PyIterReturn::StopIteration(args)) + } + Err(e) => Err(e), + }) } } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index b285dbd820..de39b23e35 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -71,7 +71,7 @@ impl PositionIterInternal { } } - pub fn builtin_iter_reduce(&self, f: F, vm: &VirtualMachine) -> PyObjectRef + pub fn builtins_iter_reduce(&self, f: F, vm: &VirtualMachine) -> PyObjectRef where F: FnOnce(&T) -> PyObjectRef, { @@ -79,7 +79,7 @@ impl PositionIterInternal { self._reduce(iter, f, vm) } - pub fn builtin_reversed_reduce(&self, f: F, vm: &VirtualMachine) -> PyObjectRef + pub fn builtins_reversed_reduce(&self, f: F, vm: &VirtualMachine) -> PyObjectRef where F: FnOnce(&T) -> PyObjectRef, { @@ -87,66 +87,42 @@ impl PositionIterInternal { self._reduce(reversed, f, vm) } - pub fn next(&mut self, f: F, vm: &VirtualMachine) -> PyResult + fn _next(&mut self, f: F, op: OP) -> PyResult where - F: FnOnce(&T, usize) -> PyResult, + F: FnOnce(&T, usize) -> PyResult, + OP: FnOnce(&mut Self), { if let IterStatus::Active(obj) = &self.status { - match f(obj, self.position) { - Err(e) if e.isinstance(&vm.ctx.exceptions.stop_iteration) => { - self.status = IterStatus::Exhausted; - Err(e) - } - Err(e) if e.isinstance(&vm.ctx.exceptions.index_error) => { - self.status = IterStatus::Exhausted; - Err(vm.new_stop_iteration()) - } - Err(e) if e.isinstance(&vm.ctx.exceptions.runtime_error) => { - self.status = IterStatus::Exhausted; - Err(e) - } - Err(e) => Err(e), - Ok(ret) => { - self.position += 1; - Ok(ret) - } + let ret = f(obj, self.position); + if let Ok(PyIterReturn::Return(_)) = ret { + op(self); + } else { + self.status = IterStatus::Exhausted; } + ret } else { - Err(vm.new_stop_iteration()) + Ok(PyIterReturn::StopIteration(None)) } } - pub fn rev_next(&mut self, f: F, vm: &VirtualMachine) -> PyResult + pub fn next(&mut self, f: F) -> PyResult where - F: FnOnce(&T, usize) -> PyResult, + F: FnOnce(&T, usize) -> PyResult, { - if let IterStatus::Active(obj) = &self.status { - match f(obj, self.position) { - Err(e) if e.isinstance(&vm.ctx.exceptions.stop_iteration) => { - self.status = IterStatus::Exhausted; - Err(e) - } - Err(e) if e.isinstance(&vm.ctx.exceptions.index_error) => { - self.status = IterStatus::Exhausted; - Err(vm.new_stop_iteration()) - } - Err(e) if e.isinstance(&vm.ctx.exceptions.runtime_error) => { - self.status = IterStatus::Exhausted; - Err(e) - } - Err(e) => Err(e), - Ok(ret) => { - if self.position == 0 { - self.status = IterStatus::Exhausted; - } else { - self.position -= 1; - } - Ok(ret) - } + self._next(f, |zelf| zelf.position += 1) + } + + pub fn rev_next(&mut self, f: F) -> PyResult + where + F: FnOnce(&T, usize) -> PyResult, + { + self._next(f, |zelf| { + if zelf.position == 0 { + zelf.status = IterStatus::Exhausted; + } else { + zelf.position -= 1; } - } else { - Err(vm.new_stop_iteration()) - } + }) } pub fn length_hint(&self, f: F) -> usize @@ -221,7 +197,7 @@ impl PySequenceIterator { #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.lock().builtin_iter_reduce(|x| x.clone(), vm) + self.internal.lock().builtins_iter_reduce(|x| x.clone(), vm) } #[pymethod(magic)] @@ -232,10 +208,20 @@ impl PySequenceIterator { impl IteratorIterable for PySequenceIterator {} impl SlotIterator for PySequenceIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal .lock() - .next(|obj, pos| obj.get_item(pos, vm), vm) + .next(|obj, pos| match obj.get_item(pos, vm) { + Ok(ret) => Ok(PyIterReturn::Return(ret)), + Err(e) if e.isinstance(&vm.ctx.exceptions.index_error) => { + Ok(PyIterReturn::StopIteration(None)) + } + Err(e) if e.isinstance(&vm.ctx.exceptions.stop_iteration) => { + let args = e.get_arg(0); + Ok(PyIterReturn::StopIteration(args)) + } + Err(e) => Err(e), + }) } } @@ -264,18 +250,18 @@ impl PyCallableIterator { impl IteratorIterable for PyCallableIterator {} impl SlotIterator for PyCallableIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let status = zelf.status.upgradable_read(); if let IterStatus::Active(callable) = &*status { let ret = callable.invoke((), vm)?; if vm.bool_eq(&ret, &zelf.sentinel)? { *PyRwLockUpgradableReadGuard::upgrade(status) = IterStatus::Exhausted; - Err(vm.new_stop_iteration()) + Ok(PyIterReturn::StopIteration(None)) } else { - Ok(ret) + Ok(PyIterReturn::Return(ret)) } } else { - Err(vm.new_stop_iteration()) + Ok(PyIterReturn::StopIteration(None)) } } } diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 8c7ae63593..05a99fdd60 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -489,22 +489,20 @@ impl PyListIterator { fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .lock() - .builtin_iter_reduce(|x| x.clone().into_object(), vm) + .builtins_iter_reduce(|x| x.clone().into_object(), vm) } } impl IteratorIterable for PyListIterator {} impl SlotIterator for PyListIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.lock().next( - |list, pos| { - let vec = list.borrow_vec(); - vec.get(pos) - .ok_or_else(|| vm.new_stop_iteration()) - .map(|x| x.clone()) - }, - vm, - ) + fn next(zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + zelf.internal.lock().next(|list, pos| { + let vec = list.borrow_vec(); + Ok(match vec.get(pos) { + Some(x) => PyIterReturn::Return(x.clone()), + None => PyIterReturn::StopIteration(None), + }) + }) } } @@ -538,22 +536,20 @@ impl PyListReverseIterator { fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .lock() - .builtin_reversed_reduce(|x| x.clone().into_object(), vm) + .builtins_reversed_reduce(|x| x.clone().into_object(), vm) } } impl IteratorIterable for PyListReverseIterator {} impl SlotIterator for PyListReverseIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.lock().rev_next( - |list, pos| { - let vec = list.borrow_vec(); - vec.get(pos) - .ok_or_else(|| vm.new_stop_iteration()) - .map(|x| x.clone()) - }, - vm, - ) + fn next(zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + zelf.internal.lock().rev_next(|list, pos| { + let vec = list.borrow_vec(); + Ok(match vec.get(pos) { + Some(x) => PyIterReturn::Return(x.clone()), + None => PyIterReturn::StopIteration(None), + }) + }) } } diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index a6bf250bc8..26c1d45f3c 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -200,13 +200,13 @@ impl PyStrIterator { self.internal .lock() .0 - .builtin_iter_reduce(|x| x.clone().into_object(), vm) + .builtins_iter_reduce(|x| x.clone().into_object(), vm) } } impl IteratorIterable for PyStrIterator {} impl SlotIterator for PyStrIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let mut internal = zelf.internal.lock(); if let IterStatus::Active(s) = &internal.0.status { @@ -216,20 +216,18 @@ impl SlotIterator for PyStrIterator { if let Some((offset, ch)) = value.char_indices().nth(internal.0.position) { internal.0.position += 1; internal.1 = offset + ch.len_utf8(); - return Ok(ch.into_pyobject(vm)); + return Ok(PyIterReturn::Return(ch.into_pyobject(vm))); } } else if let Some(value) = value.get(internal.1..) { if let Some(ch) = value.chars().next() { internal.0.position += 1; internal.1 += ch.len_utf8(); - return Ok(ch.into_pyobject(vm)); + return Ok(PyIterReturn::Return(ch.into_pyobject(vm))); } } internal.0.status = Exhausted; - Err(vm.new_stop_iteration()) - } else { - Err(vm.new_stop_iteration()) } + Ok(PyIterReturn::StopIteration(None)) } } diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index d4ffc0c338..e57dac8fab 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -853,7 +853,7 @@ impl PySetIterator { impl IteratorIterable for PySetIterator {} impl SlotIterator for PySetIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let mut internal = zelf.internal.lock(); if let IterStatus::Active(dict) = &internal.status { if dict.has_changed_size(&zelf.size) { @@ -863,15 +863,15 @@ impl SlotIterator for PySetIterator { match dict.next_entry(internal.position) { Some((position, key, _)) => { internal.position = position; - Ok(key) + Ok(PyIterReturn::Return(key)) } None => { internal.status = IterStatus::Exhausted; - Err(vm.new_stop_iteration()) + Ok(PyIterReturn::StopIteration(None)) } } } else { - Err(vm.new_stop_iteration()) + Ok(PyIterReturn::StopIteration(None)) } } } diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index e692d636c1..c99aabf7e0 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -342,23 +342,20 @@ impl PyTupleIterator { fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .lock() - .builtin_iter_reduce(|x| x.clone().into_object(), vm) + .builtins_iter_reduce(|x| x.clone().into_object(), vm) } } impl IteratorIterable for PyTupleIterator {} impl SlotIterator for PyTupleIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.lock().next( - |tuple, pos| { - tuple - .as_slice() - .get(pos) - .ok_or_else(|| vm.new_stop_iteration()) - .map(|x| x.clone()) - }, - vm, - ) + fn next(zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { + zelf.internal.lock().next(|tuple, pos| { + Ok(if let Some(ret) = tuple.as_slice().get(pos) { + PyIterReturn::Return(ret.clone()) + } else { + PyIterReturn::StopIteration(None) + }) + }) } } diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 6444e8198b..d003c9c901 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -634,22 +634,17 @@ mod _collections { impl IteratorIterable for PyDequeIterator {} impl SlotIterator for PyDequeIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.lock().next( - |deque, pos| { - if zelf.state != deque.state.load() { - return Err( - vm.new_runtime_error("Deque mutated during iteration".to_owned()) - ); - } - let deque = deque.borrow_deque(); - deque - .get(pos) - .ok_or_else(|| vm.new_stop_iteration()) - .map(|x| x.clone()) - }, - vm, - ) + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.lock().next(|deque, pos| { + if zelf.state != deque.state.load() { + return Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())); + } + let deque = deque.borrow_deque(); + Ok(match deque.get(pos) { + Some(x) => PyIterReturn::Return(x.clone()), + None => PyIterReturn::StopIteration(None), + }) + }) } } @@ -706,26 +701,23 @@ mod _collections { impl IteratorIterable for PyReverseDequeIterator {} impl SlotIterator for PyReverseDequeIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.lock().next( - |deque, pos| { - if deque.state.load() != zelf.state { - return Err( - vm.new_runtime_error("Deque mutated during iteration".to_owned()) - ); - } - let deque = deque.borrow_deque(); - let pos = deque + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.internal.lock().next(|deque, pos| { + if deque.state.load() != zelf.state { + return Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())); + } + let deque = deque.borrow_deque(); + Ok( + match deque .len() .checked_sub(pos + 1) - .ok_or_else(|| vm.new_stop_iteration())?; - deque - .get(pos) - .ok_or_else(|| vm.new_stop_iteration()) - .map(|x| x.clone()) - }, - vm, - ) + .and_then(|pos| deque.get(pos)) + { + Some(x) => PyIterReturn::Return(x.clone()), + None => PyIterReturn::StopIteration(None), + }, + ) + }) } } } From e1f689c5651749af68eeb214d3d9e1085034824c Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 1 Oct 2021 21:16:43 +0200 Subject: [PATCH 19/19] introduce PyIterReturn::from_getitem_result --- vm/src/builtins/enumerate.rs | 13 +------------ vm/src/builtins/iter.rs | 15 ++------------- vm/src/protocol/iter.rs | 13 +++++++++++++ 3 files changed, 16 insertions(+), 25 deletions(-) diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index 4922e491d9..115631753d 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -1,6 +1,5 @@ use super::{IterStatus, PositionIterInternal, PyIntRef, PyTypeRef}; use crate::common::lock::{PyMutex, PyRwLock}; -use crate::TypeProtocol; use crate::{ function::OptionalArg, protocol::{PyIter, PyIterReturn}, @@ -115,17 +114,7 @@ impl SlotIterator for PyReverseSequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal .lock() - .rev_next(|obj, pos| match obj.get_item(pos, vm) { - Ok(ret) => Ok(PyIterReturn::Return(ret)), - Err(e) if e.isinstance(&vm.ctx.exceptions.index_error) => { - Ok(PyIterReturn::StopIteration(None)) - } - Err(e) if e.isinstance(&vm.ctx.exceptions.stop_iteration) => { - let args = e.get_arg(0); - Ok(PyIterReturn::StopIteration(args)) - } - Err(e) => Err(e), - }) + .rev_next(|obj, pos| PyIterReturn::from_getitem_result(obj.get_item(pos, vm), vm)) } } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index de39b23e35..6e9d64e49c 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -7,8 +7,7 @@ use crate::{ function::ArgCallable, protocol::PyIterReturn, slots::{IteratorIterable, SlotIterator}, - ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, - VirtualMachine, + ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine, }; use rustpython_common::{ lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard}, @@ -211,17 +210,7 @@ impl SlotIterator for PySequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal .lock() - .next(|obj, pos| match obj.get_item(pos, vm) { - Ok(ret) => Ok(PyIterReturn::Return(ret)), - Err(e) if e.isinstance(&vm.ctx.exceptions.index_error) => { - Ok(PyIterReturn::StopIteration(None)) - } - Err(e) if e.isinstance(&vm.ctx.exceptions.stop_iteration) => { - let args = e.get_arg(0); - Ok(PyIterReturn::StopIteration(args)) - } - Err(e) => Err(e), - }) + .next(|obj, pos| PyIterReturn::from_getitem_result(obj.get_item(pos, vm), vm)) } } diff --git a/vm/src/protocol/iter.rs b/vm/src/protocol/iter.rs index 967333bffb..9d081249fb 100644 --- a/vm/src/protocol/iter.rs +++ b/vm/src/protocol/iter.rs @@ -136,6 +136,19 @@ impl PyIterReturn { Err(err) => Err(err), } } + pub fn from_getitem_result(result: PyResult, vm: &VirtualMachine) -> PyResult { + match result { + Ok(obj) => Ok(Self::Return(obj)), + Err(err) if err.isinstance(&vm.ctx.exceptions.index_error) => { + Ok(Self::StopIteration(None)) + } + Err(err) if err.isinstance(&vm.ctx.exceptions.stop_iteration) => { + let args = err.get_arg(0); + Ok(Self::StopIteration(args)) + } + Err(err) => Err(err), + } + } } impl IntoPyResult for PyIterReturn {