diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index e37e5084a0..c0127cabe9 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -795,8 +795,6 @@ class BaseBytesTest: q = pickle.loads(ps) self.assertEqual(b, q) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_iterator_pickling(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): for b in b"", b"a", b"abc", b"\xffab\x80", b"\0\0\377\0\0": diff --git a/extra_tests/snippets/iterations.py b/extra_tests/snippets/iterations.py index 98031a935e..b2b30961c0 100644 --- a/extra_tests/snippets/iterations.py +++ b/extra_tests/snippets/iterations.py @@ -9,3 +9,8 @@ assert next(i) == 3 assert next(i, 'w00t') == 'w00t' +s = '你好' +i = iter(s) +i.__setstate__(1) +assert i.__next__() == '好' +assert i.__reduce__()[2] == 2 diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 2ab6091dcd..e79fb91a58 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -1,9 +1,11 @@ //! Implementation of the python bytearray object. -use super::{PyBytes, PyBytesRef, PyDictRef, PyIntRef, PyStrRef, PyTupleRef, PyTypeRef}; +use super::{ + PositionIterInternal, PyBytes, PyBytesRef, PyDictRef, PyIntRef, PyStrRef, PyTupleRef, PyTypeRef, +}; use crate::common::{ borrow::{BorrowedValue, BorrowedValueMut}, lock::{ - PyMappedRwLockReadGuard, PyMappedRwLockWriteGuard, PyRwLock, PyRwLockReadGuard, + PyMappedRwLockReadGuard, PyMappedRwLockWriteGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }, }; @@ -717,8 +719,7 @@ impl Unhashable for PyByteArray {} impl Iterable for PyByteArray { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyByteArrayIterator { - position: AtomicCell::new(0), - bytearray: zelf, + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -731,8 +732,7 @@ impl Iterable for PyByteArray { #[pyclass(module = false, name = "bytearray_iterator")] #[derive(Debug)] pub struct PyByteArrayIterator { - position: AtomicCell, - bytearray: PyByteArrayRef, + internal: PyMutex>, } impl PyValue for PyByteArrayIterator { @@ -742,16 +742,34 @@ impl PyValue for PyByteArrayIterator { } #[pyimpl(with(SlotIterator))] -impl PyByteArrayIterator {} +impl PyByteArrayIterator { + #[pymethod(magic)] + fn length_hint(&self) -> usize { + self.internal.lock().length_hint(|obj| obj.len()) + } + #[pymethod(magic)] + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .lock() + .builtins_iter_reduce(|x| x.clone().into_object(), vm) + } + + #[pymethod(magic)] + fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + self.internal + .lock() + .set_state(state, |obj, pos| pos.min(obj.len()), vm) + } +} impl IteratorIterable for PyByteArrayIterator {} impl SlotIterator for PyByteArrayIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let pos = zelf.position.fetch_add(1); - let r = if let Some(&ret) = zelf.bytearray.borrow_buf().get(pos) { - PyIterReturn::Return(ret.into_pyobject(vm)) - } else { - PyIterReturn::StopIteration(None) - }; - Ok(r) + zelf.internal.lock().next(|bytearray, pos| { + let buf = bytearray.borrow_buf(); + Ok(match buf.get(pos) { + Some(&x) => PyIterReturn::Return(vm.ctx.new_int(x)), + None => PyIterReturn::StopIteration(None), + }) + }) } } diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 2df6000179..38b7185780 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -1,4 +1,4 @@ -use super::{PyDictRef, PyIntRef, PyStrRef, PyTupleRef, PyTypeRef}; +use super::{PositionIterInternal, PyDictRef, PyIntRef, PyStrRef, PyTupleRef, PyTypeRef}; use crate::{ anystr::{self, AnyStr}, bytesinner::{ @@ -17,8 +17,10 @@ use crate::{ PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, VirtualMachine, }; use bstr::ByteSlice; -use crossbeam_utils::atomic::AtomicCell; -use rustpython_common::borrow::{BorrowedValue, BorrowedValueMut}; +use rustpython_common::{ + borrow::{BorrowedValue, BorrowedValueMut}, + lock::PyMutex, +}; use std::mem::size_of; use std::ops::Deref; @@ -574,8 +576,7 @@ impl Comparable for PyBytes { impl Iterable for PyBytes { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyBytesIterator { - position: AtomicCell::new(0), - bytes: zelf, + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -584,8 +585,7 @@ impl Iterable for PyBytes { #[pyclass(module = false, name = "bytes_iterator")] #[derive(Debug)] pub struct PyBytesIterator { - position: AtomicCell, - bytes: PyBytesRef, + internal: PyMutex>, } impl PyValue for PyBytesIterator { @@ -595,17 +595,35 @@ impl PyValue for PyBytesIterator { } #[pyimpl(with(SlotIterator))] -impl PyBytesIterator {} +impl PyBytesIterator { + #[pymethod(magic)] + fn length_hint(&self) -> usize { + self.internal.lock().length_hint(|obj| obj.len()) + } + + #[pymethod(magic)] + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .lock() + .builtins_iter_reduce(|x| x.clone().into_object(), vm) + } + + #[pymethod(magic)] + fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + self.internal + .lock() + .set_state(state, |obj, pos| pos.min(obj.len()), vm) + } +} impl IteratorIterable for PyBytesIterator {} impl SlotIterator for PyBytesIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let pos = zelf.position.fetch_add(1); - let r = if let Some(&ret) = zelf.bytes.as_bytes().get(pos) { - PyIterReturn::Return(vm.ctx.new_int(ret)) - } else { - PyIterReturn::StopIteration(None) - }; - Ok(r) + zelf.internal.lock().next(|bytes, pos| { + Ok(match bytes.as_bytes().get(pos) { + Some(&x) => PyIterReturn::Return(vm.ctx.new_int(x)), + None => PyIterReturn::StopIteration(None), + }) + }) } } diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index c35b100c76..d77484858e 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -1,4 +1,4 @@ -use super::{IterStatus, PySet, PyStrRef, PyTypeRef}; +use super::{IterStatus, PositionIterInternal, PySet, PyStrRef, PyTypeRef}; use crate::{ builtins::PyBaseExceptionRef, common::ascii, @@ -14,7 +14,7 @@ use crate::{ PyAttributes, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use crossbeam_utils::atomic::AtomicCell; +use rustpython_common::lock::PyMutex; use std::fmt; use std::mem::size_of; @@ -595,7 +595,9 @@ impl Iterator for DictIter { type Item = (PyObjectRef, PyObjectRef); fn next(&mut self) -> Option { - self.dict.entries.next_entry(&mut self.position) + let (position, key, value) = self.dict.entries.next_entry(self.position)?; + self.position = position; + Some((key, value)) } fn size_hint(&self) -> (usize, Option) { @@ -712,10 +714,8 @@ macro_rules! dict_iterator { #[pyclass(module = false, name = $iter_class_name)] #[derive(Debug)] pub(crate) struct $iter_name { - pub dict: PyDictRef, pub size: dictdatatype::DictSize, - pub position: AtomicCell, - pub status: AtomicCell, + pub internal: PyMutex>, } impl PyValue for $iter_name { @@ -728,20 +728,14 @@ macro_rules! dict_iterator { impl $iter_name { fn new(dict: PyDictRef) -> Self { $iter_name { - position: AtomicCell::new(0), size: dict.size(), - dict, - status: AtomicCell::new(IterStatus::Active), + internal: PyMutex::new(PositionIterInternal::new(dict, 0)), } } #[pymethod(magic)] fn length_hint(&self) -> usize { - if let IterStatus::Exhausted = self.status.load() { - 0 - } else { - self.dict.entries.len_from_entry_index(self.position.load()) - } + self.internal.lock().length_hint(|_| self.size.entries_size) } } @@ -749,25 +743,26 @@ macro_rules! dict_iterator { impl SlotIterator for $iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - IterStatus::Exhausted => Ok(PyIterReturn::StopIteration(None)), - IterStatus::Active => { - if zelf.dict.entries.has_changed_size(&zelf.size) { - zelf.status.store(IterStatus::Exhausted); - return Err(vm.new_runtime_error( - "dictionary changed size during iteration".to_owned(), - )); + let mut internal = zelf.internal.lock(); + if let IterStatus::Active(dict) = &internal.status { + if dict.entries.has_changed_size(&zelf.size) { + internal.status = IterStatus::Exhausted; + return Err(vm.new_runtime_error( + "dictionary changed size during iteration".to_owned(), + )); + } + match dict.entries.next_entry(internal.position) { + Some((position, key, value)) => { + internal.position = position; + Ok(PyIterReturn::Return(($result_fn)(vm, key, value))) } - match zelf.dict.entries.next_entry_atomic(&zelf.position) { - Some((key, value)) => { - Ok(PyIterReturn::Return(($result_fn)(vm, key, value))) - } - None => { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + None => { + internal.status = IterStatus::Exhausted; + Ok(PyIterReturn::StopIteration(None)) } } + } else { + Ok(PyIterReturn::StopIteration(None)) } } } @@ -775,10 +770,8 @@ macro_rules! dict_iterator { #[pyclass(module = false, name = $reverse_iter_class_name)] #[derive(Debug)] pub(crate) struct $reverse_iter_name { - pub dict: PyDictRef, pub size: dictdatatype::DictSize, - pub position: AtomicCell, - pub status: AtomicCell, + internal: PyMutex>, } impl PyValue for $reverse_iter_name { @@ -790,21 +783,19 @@ macro_rules! dict_iterator { #[pyimpl(with(SlotIterator))] impl $reverse_iter_name { fn new(dict: PyDictRef) -> Self { + let size = dict.size(); + let position = size.entries_size.saturating_sub(1); $reverse_iter_name { - position: AtomicCell::new(0), - size: dict.size(), - dict, - status: AtomicCell::new(IterStatus::Active), + size, + internal: PyMutex::new(PositionIterInternal::new(dict, position)), } } #[pymethod(magic)] fn length_hint(&self) -> usize { - if let IterStatus::Exhausted = self.status.load() { - 0 - } else { - self.dict.entries.len_from_entry_index(self.position.load()) - } + self.internal + .lock() + .rev_length_hint(|_| self.size.entries_size) } } @@ -812,25 +803,30 @@ macro_rules! dict_iterator { impl SlotIterator for $reverse_iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - IterStatus::Exhausted => Ok(PyIterReturn::StopIteration(None)), - IterStatus::Active => { - if zelf.dict.entries.has_changed_size(&zelf.size) { - zelf.status.store(IterStatus::Exhausted); - return Err(vm.new_runtime_error( - "dictionary changed size during iteration".to_owned(), - )); + let mut internal = zelf.internal.lock(); + if let IterStatus::Active(dict) = &internal.status { + if dict.entries.has_changed_size(&zelf.size) { + internal.status = IterStatus::Exhausted; + return Err(vm.new_runtime_error( + "dictionary changed size during iteration".to_owned(), + )); + } + match dict.entries.prev_entry(internal.position) { + Some((position, key, value)) => { + if internal.position == position { + internal.status = IterStatus::Exhausted; + } else { + internal.position = position; + } + Ok(PyIterReturn::Return(($result_fn)(vm, key, value))) } - match zelf.dict.entries.next_entry_atomic_reversed(&zelf.position) { - Some((key, value)) => { - Ok(PyIterReturn::Return(($result_fn)(vm, key, value))) - } - None => { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + None => { + internal.status = IterStatus::Exhausted; + Ok(PyIterReturn::StopIteration(None)) } } + } else { + Ok(PyIterReturn::StopIteration(None)) } } } diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index e2cfc1bf41..115631753d 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -1,17 +1,12 @@ -use super::{ - int, - IterStatus::{self, Active, Exhausted}, - PyInt, PyIntRef, PyTypeRef, -}; -use crate::common::lock::PyRwLock; +use super::{IterStatus, PositionIterInternal, PyIntRef, PyTypeRef}; +use crate::common::lock::{PyMutex, PyRwLock}; use crate::{ function::OptionalArg, protocol::{PyIter, PyIterReturn}, slots::{IteratorIterable, SlotConstructor, SlotIterator}, IntoPyObject, ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, - TypeProtocol, VirtualMachine, + VirtualMachine, }; -use crossbeam_utils::atomic::AtomicCell; use num_bigint::BigInt; use num_traits::Zero; @@ -72,9 +67,7 @@ impl SlotIterator for PyEnumerate { #[pyclass(module = false, name = "reversed")] #[derive(Debug)] pub struct PyReverseSequenceIterator { - pub position: AtomicCell, - pub status: AtomicCell, - pub obj: PyObjectRef, + internal: PyMutex>, } impl PyValue for PyReverseSequenceIterator { @@ -86,80 +79,42 @@ impl PyValue for PyReverseSequenceIterator { #[pyimpl(with(SlotIterator))] impl PyReverseSequenceIterator { pub fn new(obj: PyObjectRef, len: usize) -> Self { + let position = len.saturating_sub(1); Self { - position: AtomicCell::new(len.saturating_sub(1)), - status: AtomicCell::new(if len == 0 { Exhausted } else { Active }), - obj, + internal: PyMutex::new(PositionIterInternal::new(obj, position)), } } #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyResult { - Ok(match self.status.load() { - Active => { - let position = self.position.load(); - if position > vm.obj_len(&self.obj)? { - 0 - } else { - position + 1 - } + let internal = self.internal.lock(); + if let IterStatus::Active(obj) = &internal.status { + if internal.position <= vm.obj_len(obj)? { + return Ok(internal.position + 1); } - Exhausted => 0, - }) + } + Ok(0) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } - let len = vm.obj_len(&self.obj)?; - let pos = state - .payload::() - .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - let pos = std::cmp::min( - int::try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), - len.saturating_sub(1), - ); - self.position.store(pos); - Ok(()) + self.internal.lock().set_state(state, |_, pos| pos, vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "reversed")?; - Ok(vm.ctx.new_tuple(match self.status.load() { - Exhausted => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_tuple(vec![])])], - Active => vec![ - iter, - vm.ctx.new_tuple(vec![self.obj.clone()]), - vm.ctx.new_int(self.position.load()), - ], - })) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .lock() + .builtins_reversed_reduce(|x| x.clone(), vm) } } impl IteratorIterable for PyReverseSequenceIterator {} impl SlotIterator for PyReverseSequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let pos = zelf.position.fetch_sub(1); - if pos == 0 { - zelf.status.store(Exhausted); - } - match zelf.obj.get_item(pos, vm) { - Err(ref e) - if e.isinstance(&vm.ctx.exceptions.index_error) - || e.isinstance(&vm.ctx.exceptions.stop_iteration) => - { - zelf.status.store(Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } - other => other.map(PyIterReturn::Return), - } + zelf.internal + .lock() + .rev_next(|obj, pos| PyIterReturn::from_getitem_result(obj.get_item(pos, vm), vm)) } } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index efb0756b24..6e9d64e49c 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -7,26 +7,165 @@ use crate::{ function::ArgCallable, protocol::PyIterReturn, slots::{IteratorIterable, SlotIterator}, - ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, - VirtualMachine, + ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine, +}; +use rustpython_common::{ + lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard}, + static_cell, }; -use crossbeam_utils::atomic::AtomicCell; /// Marks status of iterator. -#[derive(Debug, Clone, Copy)] -pub enum IterStatus { +#[derive(Debug, Clone)] +pub enum IterStatus { /// Iterator hasn't raised StopIteration. - Active, + Active(T), /// Iterator has raised StopIteration. Exhausted, } +#[derive(Debug)] +pub struct PositionIterInternal { + pub status: IterStatus, + pub position: usize, +} + +impl PositionIterInternal { + pub fn new(obj: T, position: usize) -> Self { + Self { + status: IterStatus::Active(obj), + position, + } + } + + pub fn set_state(&mut self, state: PyObjectRef, f: F, vm: &VirtualMachine) -> PyResult<()> + where + F: FnOnce(&T, usize) -> usize, + { + if let IterStatus::Active(obj) = &self.status { + if let Some(i) = state.payload::() { + let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0); + self.position = f(obj, i); + Ok(()) + } else { + Err(vm.new_type_error("an integer is required.".to_owned())) + } + } else { + Ok(()) + } + } + + fn _reduce(&self, func: PyObjectRef, f: F, vm: &VirtualMachine) -> PyObjectRef + where + F: FnOnce(&T) -> PyObjectRef, + { + if let IterStatus::Active(obj) = &self.status { + vm.ctx.new_tuple(vec![ + func, + vm.ctx.new_tuple(vec![f(obj)]), + vm.ctx.new_int(self.position), + ]) + } else { + vm.ctx + .new_tuple(vec![func, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]) + } + } + + pub fn builtins_iter_reduce(&self, f: F, vm: &VirtualMachine) -> PyObjectRef + where + F: FnOnce(&T) -> PyObjectRef, + { + let iter = builtins_iter(vm).clone(); + self._reduce(iter, f, vm) + } + + pub fn builtins_reversed_reduce(&self, f: F, vm: &VirtualMachine) -> PyObjectRef + where + F: FnOnce(&T) -> PyObjectRef, + { + let reversed = builtins_reversed(vm).clone(); + self._reduce(reversed, f, vm) + } + + fn _next(&mut self, f: F, op: OP) -> PyResult + where + F: FnOnce(&T, usize) -> PyResult, + OP: FnOnce(&mut Self), + { + if let IterStatus::Active(obj) = &self.status { + let ret = f(obj, self.position); + if let Ok(PyIterReturn::Return(_)) = ret { + op(self); + } else { + self.status = IterStatus::Exhausted; + } + ret + } else { + Ok(PyIterReturn::StopIteration(None)) + } + } + + pub fn next(&mut self, f: F) -> PyResult + where + F: FnOnce(&T, usize) -> PyResult, + { + self._next(f, |zelf| zelf.position += 1) + } + + pub fn rev_next(&mut self, f: F) -> PyResult + where + F: FnOnce(&T, usize) -> PyResult, + { + self._next(f, |zelf| { + if zelf.position == 0 { + zelf.status = IterStatus::Exhausted; + } else { + zelf.position -= 1; + } + }) + } + + pub fn length_hint(&self, f: F) -> usize + where + F: FnOnce(&T) -> usize, + { + if let IterStatus::Active(obj) = &self.status { + f(obj).saturating_sub(self.position) + } else { + 0 + } + } + + pub fn rev_length_hint(&self, f: F) -> usize + where + F: FnOnce(&T) -> usize, + { + if let IterStatus::Active(obj) = &self.status { + if self.position <= f(obj) { + return self.position + 1; + } + } + 0 + } +} + +pub fn builtins_iter(vm: &VirtualMachine) -> &PyObjectRef { + static_cell! { + static INSTANCE: PyObjectRef; + } + INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "iter").unwrap()) +} + +pub fn builtins_reversed(vm: &VirtualMachine) -> &PyObjectRef { + static_cell! { + static INSTANCE: PyObjectRef; + } + INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "reversed").unwrap()) +} + #[pyclass(module = false, name = "iterator")] #[derive(Debug)] pub struct PySequenceIterator { - pub position: AtomicCell, - pub obj: PyObjectRef, - pub status: AtomicCell, + internal: PyMutex>, } impl PyValue for PySequenceIterator { @@ -39,84 +178,47 @@ impl PyValue for PySequenceIterator { impl PySequenceIterator { pub fn new(obj: PyObjectRef) -> Self { Self { - position: AtomicCell::new(0), - obj, - status: AtomicCell::new(IterStatus::Active), + internal: PyMutex::new(PositionIterInternal::new(obj, 0)), } } #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - match self.status.load() { - IterStatus::Active => { - let pos = self.position.load(); - // return NotImplemented if no length is around. - vm.obj_len(&self.obj) - .map_or(vm.ctx.not_implemented(), |len| { - PyInt::from(len.saturating_sub(pos)).into_object(vm) - }) - } - IterStatus::Exhausted => PyInt::from(0).into_object(vm), + let internal = self.internal.lock(); + if let IterStatus::Active(obj) = &internal.status { + vm.obj_len(obj) + .map(|x| PyInt::from(x).into_object(vm)) + .unwrap_or_else(|_| vm.ctx.not_implemented()) + } else { + PyInt::from(0).into_object(vm) } } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(match self.status.load() { - IterStatus::Exhausted => vm - .ctx - .new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]), - IterStatus::Active => vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.obj.clone()]), - vm.ctx.new_int(self.position.load()), - ]), - }) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal.lock().builtins_iter_reduce(|x| x.clone(), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let IterStatus::Exhausted = self.status.load() { - return Ok(()); - } - if let Some(i) = state.payload::() { - self.position - .store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0)); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + self.internal.lock().set_state(state, |_, pos| pos, vm) } } impl IteratorIterable for PySequenceIterator {} impl SlotIterator for PySequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if let IterStatus::Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let pos = zelf.position.fetch_add(1); - match zelf.obj.get_item(pos, vm) { - Err(ref e) - if e.isinstance(&vm.ctx.exceptions.index_error) - || e.isinstance(&vm.ctx.exceptions.stop_iteration) => - { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } - ret => ret.map(PyIterReturn::Return), - } + zelf.internal + .lock() + .next(|obj, pos| PyIterReturn::from_getitem_result(obj.get_item(pos, vm), vm)) } } #[pyclass(module = false, name = "callable_iterator")] #[derive(Debug)] pub struct PyCallableIterator { - callable: ArgCallable, sentinel: PyObjectRef, - status: AtomicCell, + status: PyRwLock>, } impl PyValue for PyCallableIterator { @@ -129,9 +231,8 @@ impl PyValue for PyCallableIterator { impl PyCallableIterator { pub fn new(callable: ArgCallable, sentinel: PyObjectRef) -> Self { Self { - callable, sentinel, - status: AtomicCell::new(IterStatus::Active), + status: PyRwLock::new(IterStatus::Active(callable)), } } } @@ -139,15 +240,17 @@ impl PyCallableIterator { impl IteratorIterable for PyCallableIterator {} impl SlotIterator for PyCallableIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if let IterStatus::Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let ret = zelf.callable.invoke((), vm)?; - if vm.bool_eq(&ret, &zelf.sentinel)? { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) + let status = zelf.status.upgradable_read(); + if let IterStatus::Active(callable) = &*status { + let ret = callable.invoke((), vm)?; + if vm.bool_eq(&ret, &zelf.sentinel)? { + *PyRwLockUpgradableReadGuard::upgrade(status) = IterStatus::Exhausted; + Ok(PyIterReturn::StopIteration(None)) + } else { + Ok(PyIterReturn::Return(ret)) + } } else { - Ok(PyIterReturn::Return(ret)) + Ok(PyIterReturn::StopIteration(None)) } } } diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 1f0fdfad58..05a99fdd60 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -1,10 +1,6 @@ -use super::{ - int, - iter::IterStatus::{self, Active, Exhausted}, - PyGenericAlias, PyInt, PySliceRef, PyTypeRef, -}; +use super::{PositionIterInternal, PyGenericAlias, PySliceRef, PyTypeRef}; use crate::common::lock::{ - PyMappedRwLockReadGuard, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, + PyMappedRwLockReadGuard, PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }; use crate::{ function::{ArgIterable, FuncArgs, OptionalArg}, @@ -19,7 +15,6 @@ use crate::{ PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use crossbeam_utils::atomic::AtomicCell; use std::fmt; use std::iter::FromIterator; use std::mem::size_of; @@ -161,16 +156,9 @@ impl PyList { #[pymethod(magic)] fn reversed(zelf: PyRef) -> PyListReverseIterator { - let final_position = zelf.borrow_vec().len(); - // Mark iterator as exhausted immediately if its empty. + let position = zelf.len().saturating_sub(1); PyListReverseIterator { - position: AtomicCell::new(final_position.saturating_sub(1)), - status: AtomicCell::new(if final_position == 0 { - Exhausted - } else { - Active - }), - list: zelf, + internal: PyMutex::new(PositionIterInternal::new(zelf, position)), } } @@ -419,9 +407,7 @@ impl PyList { impl Iterable for PyList { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyListIterator { - position: AtomicCell::new(0), - status: AtomicCell::new(Active), - list: zelf, + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -476,9 +462,7 @@ fn do_sort( #[pyclass(module = false, name = "list_iterator")] #[derive(Debug)] pub struct PyListIterator { - pub position: AtomicCell, - status: AtomicCell, - pub list: PyListRef, + internal: PyMutex>, } impl PyValue for PyListIterator { @@ -491,61 +475,41 @@ impl PyValue for PyListIterator { impl PyListIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - match self.status.load() { - Active => { - let list = self.list.borrow_vec(); - let pos = self.position.load(); - list.len().saturating_sub(pos) - } - Exhausted => 0, - } + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } - let position = list_state(self.list.len(), state, vm)?; - self.position.store(position); - Ok(()) + self.internal + .lock() + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let pos = if let Exhausted = self.status.load() { - None - } else { - Some(self.position.load()) - }; - list_reduce(self.list.clone(), pos, false, vm) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .lock() + .builtins_iter_reduce(|x| x.clone().into_object(), vm) } } impl IteratorIterable for PyListIterator {} impl SlotIterator for PyListIterator { fn next(zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let list = zelf.list.borrow_vec(); - let pos = zelf.position.fetch_add(1); - if let Some(obj) = list.get(pos) { - Ok(PyIterReturn::Return(obj.clone())) - } else { - zelf.status.store(Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + zelf.internal.lock().next(|list, pos| { + let vec = list.borrow_vec(); + Ok(match vec.get(pos) { + Some(x) => PyIterReturn::Return(x.clone()), + None => PyIterReturn::StopIteration(None), + }) + }) } } #[pyclass(module = false, name = "list_reverseiterator")] #[derive(Debug)] pub struct PyListReverseIterator { - pub position: AtomicCell, - pub status: AtomicCell, - pub list: PyListRef, + internal: PyMutex>, } impl PyValue for PyListReverseIterator { @@ -558,102 +522,37 @@ impl PyValue for PyListReverseIterator { impl PyListReverseIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - match self.status.load() { - Active => { - let position = self.position.load(); - if position > self.list.len() { - // List was mutated. Report zero, next call to `__next__` will - // fail and set iterator to Exhausted. - 0 - } else { - position + 1 - } - } - Exhausted => 0, - } + self.internal.lock().rev_length_hint(|obj| obj.len()) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } - - // Max for position is list.len() - 1. - let position = list_state(self.list.len().saturating_sub(1), state, vm)?; - self.position.store(position); - Ok(()) + self.internal + .lock() + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let pos = if let Exhausted = self.status.load() { - None - } else { - Some(self.position.load()) - }; - list_reduce(self.list.clone(), pos, true, vm) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .lock() + .builtins_reversed_reduce(|x| x.clone().into_object(), vm) } } impl IteratorIterable for PyListReverseIterator {} impl SlotIterator for PyListReverseIterator { fn next(zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let list = zelf.list.borrow_vec(); - let pos = zelf.position.fetch_sub(1); - if pos > 0 { - if let Some(obj) = list.get(pos) { - return Ok(PyIterReturn::Return(obj.clone())); - } - } - // We either are == 0 or list.get returned None. Either way, set status - // to exhausted and return last item if pos == 0. - zelf.status.store(Exhausted); - if pos == 0 { - if let Some(obj) = list.get(pos) { - return Ok(PyIterReturn::Return(obj.clone())); - } - } - Ok(PyIterReturn::StopIteration(None)) + zelf.internal.lock().rev_next(|list, pos| { + let vec = list.borrow_vec(); + Ok(match vec.get(pos) { + Some(x) => PyIterReturn::Return(x.clone()), + None => PyIterReturn::StopIteration(None), + }) + }) } } -// Common reducer for forward and reverse list iterators. -fn list_reduce( - list: PyRef, - position: Option, - reverse: bool, - vm: &VirtualMachine, -) -> PyResult { - let attr = if reverse { "reversed" } else { "iter" }; - let iter = vm.get_attribute(vm.builtins.clone(), attr)?; - let elems = match position { - None => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])], - Some(position) => vec![ - iter, - vm.ctx.new_tuple(vec![list.into_object()]), - vm.ctx.new_int(position), - ], - }; - Ok(vm.ctx.new_tuple(elems)) -} - -// Common function to extract state. Clamps it in range [0, length]. -fn list_state(length: usize, state: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let position = state - .payload::() - .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - let position = std::cmp::min( - int::try_to_primitive(position.as_bigint(), vm).unwrap_or(0), - length, - ); - Ok(position) -} - pub fn init(context: &PyContext) { let list_type = &context.types.list_type; PyList::extend_class(context, list_type); diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index a3ff1d0bd8..26c1d45f3c 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -1,7 +1,7 @@ use super::{ - int::{try_to_primitive, PyInt, PyIntRef}, - iter::IterStatus::{self, Active, Exhausted}, - PyBytesRef, PyDict, PyTypeRef, + int::{PyInt, PyIntRef}, + iter::IterStatus::{self, Exhausted}, + PositionIterInternal, PyBytesRef, PyDict, PyTypeRef, }; use crate::{ anystr::{self, adjust_indices, AnyStr, AnyStrContainer, AnyStrWrapper}, @@ -19,13 +19,13 @@ use crate::{ PyObjectRef, PyRef, PyResult, PyValue, TryIntoRef, TypeProtocol, VirtualMachine, }; use bstr::ByteSlice; -use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; use num_traits::ToPrimitive; use rustpython_common::{ ascii, atomic::{self, PyAtomic, Radium}, hash, + lock::PyMutex, }; use std::mem::size_of; use std::ops::Range; @@ -170,9 +170,7 @@ impl TryIntoRef for &str { #[pyclass(module = false, name = "str_iterator")] #[derive(Debug)] pub struct PyStrIterator { - string: PyStrRef, - position: PyAtomic, - status: AtomicCell, + internal: PyMutex<(PositionIterInternal, usize)>, } impl PyValue for PyStrIterator { @@ -185,81 +183,51 @@ impl PyValue for PyStrIterator { impl PyStrIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - match self.status.load() { - Active => { - let pos = self.position.load(atomic::Ordering::SeqCst); - self.string.len().saturating_sub(pos) - } - Exhausted => 0, - } + self.internal.lock().0.length_hint(|obj| obj.char_len()) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } - let pos = state - .payload::() - .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - let pos = std::cmp::min( - try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), - self.string.len(), - ); - self.position.store(pos, atomic::Ordering::SeqCst); - Ok(()) + let mut internal = self.internal.lock(); + internal.1 = usize::MAX; + internal + .0 + .set_state(state, |obj, pos| pos.min(obj.char_len()), vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(vm.ctx.new_tuple(match self.status.load() { - Exhausted => vec![ - iter, - vm.ctx.new_tuple(vec![vm.ctx.new_ascii_literal(ascii!(""))]), - ], - Active => vec![ - iter, - vm.ctx.new_tuple(vec![self.string.clone().into_object()]), - vm.ctx - .new_int(self.position.load(atomic::Ordering::Relaxed)), - ], - })) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .lock() + .0 + .builtins_iter_reduce(|x| x.clone().into_object(), vm) } } impl IteratorIterable for PyStrIterator {} impl SlotIterator for PyStrIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let value = &*zelf.string.as_str(); - let mut start = zelf.position.load(atomic::Ordering::SeqCst); - loop { - if start == value.len() { - zelf.status.store(Exhausted); - return Ok(PyIterReturn::StopIteration(None)); - } - let ch = match value[start..].chars().next() { - Some(ch) => ch, - None => { - zelf.status.store(Exhausted); - return Ok(PyIterReturn::StopIteration(None)); - } - }; + let mut internal = zelf.internal.lock(); - match zelf.position.compare_exchange_weak( - start, - start + ch.len_utf8(), - atomic::Ordering::Release, - atomic::Ordering::Relaxed, - ) { - Ok(_) => break Ok(PyIterReturn::Return(ch.into_pyobject(vm))), - Err(cur) => start = cur, + if let IterStatus::Active(s) = &internal.0.status { + let value = s.as_str(); + + if internal.1 == usize::MAX { + if let Some((offset, ch)) = value.char_indices().nth(internal.0.position) { + internal.0.position += 1; + internal.1 = offset + ch.len_utf8(); + return Ok(PyIterReturn::Return(ch.into_pyobject(vm))); + } + } else if let Some(value) = value.get(internal.1..) { + if let Some(ch) = value.chars().next() { + internal.0.position += 1; + internal.1 += ch.len_utf8(); + return Ok(PyIterReturn::Return(ch.into_pyobject(vm))); + } } + internal.0.status = Exhausted; } + Ok(PyIterReturn::StopIteration(None)) } } @@ -1294,9 +1262,7 @@ impl Comparable for PyStr { impl Iterable for PyStr { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyStrIterator { - position: Radium::new(0), - string: zelf, - status: AtomicCell::new(Active), + internal: PyMutex::new((PositionIterInternal::new(zelf, 0), 0)), } .into_object(vm)) } diff --git a/vm/src/builtins/range.rs b/vm/src/builtins/range.rs index 4f6db0a594..05c543d90e 100644 --- a/vm/src/builtins/range.rs +++ b/vm/src/builtins/range.rs @@ -1,4 +1,5 @@ use super::{PyInt, PyIntRef, PySlice, PySliceRef, PyTypeRef}; +use crate::builtins::builtins_iter; use crate::common::hash::PyHash; use crate::{ function::{FuncArgs, OptionalArg}, @@ -608,7 +609,7 @@ fn range_iter_reduce( index: usize, vm: &VirtualMachine, ) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; + let iter = builtins_iter(vm).clone(); let stop = start.clone() + length * step.clone(); let range = PyRange { start: PyInt::from(start).into_ref(vm), diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index f162f58910..e57dac8fab 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -1,8 +1,8 @@ /* * Builtin set type with a sequence of unique items. */ -use super::{IterStatus, PyDictRef, PyTypeRef}; -use crate::common::{ascii, hash::PyHash, rc::PyRc}; +use super::{builtins_iter, IterStatus, PositionIterInternal, PyDictRef, PyTypeRef}; +use crate::common::{ascii, hash::PyHash, lock::PyMutex, rc::PyRc}; use crate::{ dictdatatype::{self, DictSize}, function::{ArgIterable, FuncArgs, OptionalArg, PosArgs}, @@ -15,7 +15,6 @@ use crate::{ IdProtocol, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use crossbeam_utils::atomic::AtomicCell; use std::fmt; pub type SetContentType = dictdatatype::Dict<()>; @@ -194,10 +193,8 @@ impl PySetInner { fn iter(&self) -> PySetIterator { PySetIterator { - dict: PyRc::clone(&self.content), size: self.content.size(), - position: AtomicCell::new(0), - status: AtomicCell::new(IterStatus::Active), + internal: PyMutex::new(PositionIterInternal::new(self.content.clone(), 0)), } } @@ -815,10 +812,8 @@ impl TryFromObject for SetIterable { #[pyclass(module = false, name = "set_iterator")] pub(crate) struct PySetIterator { - dict: PyRc, size: DictSize, - position: AtomicCell, - status: AtomicCell, + internal: PyMutex>>, } impl fmt::Debug for PySetIterator { @@ -838,25 +833,19 @@ impl PyValue for PySetIterator { impl PySetIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - if let IterStatus::Exhausted = self.status.load() { - 0 - } else { - self.dict.len_from_entry_index(self.position.load()) - } + self.internal.lock().length_hint(|_| self.size.entries_size) } #[pymethod(magic)] fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, (PyObjectRef,))> { + let internal = zelf.internal.lock(); Ok(( - vm.get_attribute(vm.builtins.clone(), "iter")?, - (vm.ctx.new_list(match zelf.status.load() { + builtins_iter(vm).clone(), + (vm.ctx.new_list(match &internal.status { IterStatus::Exhausted => vec![], - IterStatus::Active => zelf - .dict - .keys() - .into_iter() - .skip(zelf.position.load()) - .collect(), + IterStatus::Active(dict) => { + dict.keys().into_iter().skip(internal.position).collect() + } }),), )) } @@ -865,23 +854,24 @@ impl PySetIterator { impl IteratorIterable for PySetIterator {} impl SlotIterator for PySetIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - IterStatus::Exhausted => Ok(PyIterReturn::StopIteration(None)), - IterStatus::Active => { - if zelf.dict.has_changed_size(&zelf.size) { - zelf.status.store(IterStatus::Exhausted); - return Err( - vm.new_runtime_error("set changed size during iteration".to_owned()) - ); + let mut internal = zelf.internal.lock(); + if let IterStatus::Active(dict) = &internal.status { + if dict.has_changed_size(&zelf.size) { + internal.status = IterStatus::Exhausted; + return Err(vm.new_runtime_error("set changed size during iteration".to_owned())); + } + match dict.next_entry(internal.position) { + Some((position, key, _)) => { + internal.position = position; + Ok(PyIterReturn::Return(key)) } - match zelf.dict.next_entry_atomic(&zelf.position) { - Some((key, _)) => Ok(PyIterReturn::Return(key)), - None => { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + None => { + internal.status = IterStatus::Exhausted; + Ok(PyIterReturn::StopIteration(None)) } } + } else { + Ok(PyIterReturn::StopIteration(None)) } } } diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 03c11d4e50..466eade62d 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -1,8 +1,4 @@ -use super::{ - int, - iter::IterStatus::{self, Active, Exhausted}, - PyInt, PyTypeRef, -}; +use super::{PositionIterInternal, PyTypeRef}; use crate::common::hash::PyHash; use crate::{ function::OptionalArg, @@ -19,7 +15,7 @@ use crate::{ PyContext, PyObjectRef, PyRef, PyResult, PyValue, TransmuteFromObject, TryFromObject, TypeProtocol, }; -use crossbeam_utils::atomic::AtomicCell; +use rustpython_common::lock::PyMutex; use std::fmt; use std::marker::PhantomData; @@ -310,9 +306,7 @@ impl Comparable for PyTuple { impl Iterable for PyTuple { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyTupleIterator { - position: AtomicCell::new(0), - status: AtomicCell::new(Active), - tuple: zelf, + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -321,9 +315,7 @@ impl Iterable for PyTuple { #[pyclass(module = false, name = "tuple_iterator")] #[derive(Debug)] pub(crate) struct PyTupleIterator { - position: AtomicCell, - status: AtomicCell, - tuple: PyTupleRef, + internal: PyMutex>, } impl PyValue for PyTupleIterator { @@ -336,60 +328,34 @@ impl PyValue for PyTupleIterator { impl PyTupleIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - match self.status.load() { - Active => self.tuple.len().saturating_sub(self.position.load()), - Exhausted => 0, - } + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - if let Exhausted = self.status.load() { - return Ok(()); - } - // Else, set to min of (pos, tuple_size). - if let Some(i) = state.payload::() { - let position = std::cmp::min( - int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0), - self.tuple.len(), - ); - self.position.store(position); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + self.internal + .lock() + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(match self.status.load() { - Exhausted => vm - .ctx - .new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]), - Active => vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.tuple.clone().into_object()]), - vm.ctx.new_int(self.position.load()), - ]), - }) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .lock() + .builtins_iter_reduce(|x| x.clone().into_object(), vm) } } impl IteratorIterable for PyTupleIterator {} impl SlotIterator for PyTupleIterator { fn next(zelf: &PyRef, _vm: &VirtualMachine) -> PyResult { - if let Exhausted = zelf.status.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - let pos = zelf.position.fetch_add(1); - if let Some(obj) = zelf.tuple.as_slice().get(pos) { - Ok(PyIterReturn::Return(obj.clone())) - } else { - zelf.status.store(Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + zelf.internal.lock().next(|tuple, pos| { + Ok(if let Some(ret) = tuple.as_slice().get(pos) { + PyIterReturn::Return(ret.clone()) + } else { + PyIterReturn::StopIteration(None) + }) + }) } } diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 04ed619e00..8a5cd5a371 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -1,15 +1,14 @@ -use crate::builtins::{PyStr, PyStrRef}; /// Ordered dictionary implementation. /// Inspired by: https://morepypy.blogspot.com/2015/01/faster-more-memory-efficient-and-more.html /// And: https://www.youtube.com/watch?v=p33CVV29OG8 /// And: http://code.activestate.com/recipes/578375/ +use crate::builtins::{PyStr, PyStrRef}; use crate::common::{ hash, lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}, }; use crate::vm::VirtualMachine; use crate::{IdProtocol, IntoPyObject, PyObjectRef, PyRefExact, PyResult, TypeProtocol}; -use crossbeam_utils::atomic::AtomicCell; use std::fmt; use std::mem::size_of; @@ -110,8 +109,8 @@ struct DictEntry { #[derive(Debug, PartialEq)] pub struct DictSize { indices_size: usize, - entries_size: usize, - used: usize, + pub entries_size: usize, + pub used: usize, filled: usize, } @@ -479,45 +478,30 @@ impl Dict { self.read().size() } - pub fn next_entry(&self, position: &mut EntryIndex) -> Option<(PyObjectRef, T)> { + pub fn next_entry(&self, mut position: EntryIndex) -> Option<(usize, PyObjectRef, T)> { let inner = self.read(); loop { - let entry = inner.entries.get(*position)?; - *position += 1; + let entry = inner.entries.get(position)?; + position += 1; if let Some(entry) = entry { - break Some((entry.key.clone(), entry.value.clone())); + break Some((position, entry.key.clone(), entry.value.clone())); } } } - pub fn next_entry_atomic(&self, position: &AtomicCell) -> Option<(PyObjectRef, T)> { + pub fn prev_entry(&self, mut position: EntryIndex) -> Option<(usize, PyObjectRef, T)> { let inner = self.read(); loop { - let position_usize = position.fetch_add(1); - let entry = inner.entries.get(position_usize)?; + let entry = inner.entries.get(position)?; + position = position.saturating_sub(1); if let Some(entry) = entry { - break Some((entry.key.clone(), entry.value.clone())); - } - } - } - - pub fn next_entry_atomic_reversed( - &self, - position: &AtomicCell, - ) -> Option<(PyObjectRef, T)> { - let inner = self.read(); - loop { - let position_usize = position.fetch_add(1); - let position_index = inner.entries.len().checked_sub(position_usize + 1)?; - let entry = inner.entries.get(position_index)?; - if let Some(entry) = entry { - break Some((entry.key.clone(), entry.value.clone())); + break Some((position, entry.key.clone(), entry.value.clone())); } } } pub fn len_from_entry_index(&self, position: EntryIndex) -> usize { - self.read().entries.len() - position + self.read().entries.len().saturating_sub(position) } pub fn has_changed_size(&self, old: &DictSize) -> bool { diff --git a/vm/src/protocol/iter.rs b/vm/src/protocol/iter.rs index 967333bffb..9d081249fb 100644 --- a/vm/src/protocol/iter.rs +++ b/vm/src/protocol/iter.rs @@ -136,6 +136,19 @@ impl PyIterReturn { Err(err) => Err(err), } } + pub fn from_getitem_result(result: PyResult, vm: &VirtualMachine) -> PyResult { + match result { + Ok(obj) => Ok(Self::Return(obj)), + Err(err) if err.isinstance(&vm.ctx.exceptions.index_error) => { + Ok(Self::StopIteration(None)) + } + Err(err) if err.isinstance(&vm.ctx.exceptions.stop_iteration) => { + let args = err.get_arg(0); + Ok(Self::StopIteration(args)) + } + Err(err) => Err(err), + } + } } impl IntoPyResult for PyIterReturn { diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 9158ea67a7..d003c9c901 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -2,10 +2,11 @@ pub(crate) use _collections::make_module; #[pymodule] mod _collections { - use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; + use crate::builtins::PositionIterInternal; + use crate::common::lock::{PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; use crate::{ builtins::{ - IterStatus::{self, Active, Exhausted}, + IterStatus::{Active, Exhausted}, PyInt, PyTypeRef, }, function::{FuncArgs, KwArgs, OptionalArg}, @@ -357,12 +358,9 @@ mod _collections { #[pymethod(magic)] fn reversed(zelf: PyRef) -> PyResult { - let length = zelf.len(); Ok(PyReverseDequeIterator { - position: AtomicCell::new(length), - status: AtomicCell::new(if length > 0 { Active } else { Exhausted }), - length, - deque: zelf, + state: zelf.state.load(), + internal: PyMutex::new(PositionIterInternal::new(zelf, 0)), }) } @@ -573,10 +571,8 @@ mod _collections { #[pyclass(name = "_deque_iterator")] #[derive(Debug, PyValue)] struct PyDequeIterator { - position: AtomicCell, - status: AtomicCell, - length: usize, // To track length immutability. - deque: PyDequeRef, + state: usize, + internal: PyMutex>, } #[derive(FromArgs)] @@ -596,15 +592,10 @@ mod _collections { (DequeIterArgs { deque, index }, _kwargs): Self::Args, vm: &VirtualMachine, ) -> PyResult { - let len = deque.len(); let iter = PyDequeIterator::new(deque); if let OptionalArg::Present(index) = index { let index = max(index, 0) as usize; - iter.position.store(min(index, len)); - - if len.le(&index) { - iter.status.store(Exhausted); - } + iter.internal.lock().position = index; } iter.into_pyresult_with_type(vm, cls) } @@ -614,56 +605,46 @@ mod _collections { impl PyDequeIterator { pub(crate) fn new(deque: PyDequeRef) -> Self { PyDequeIterator { - position: AtomicCell::new(0), - status: AtomicCell::new(IterStatus::Active), - length: deque.len(), - deque, + state: deque.state.load(), + internal: PyMutex::new(PositionIterInternal::new(deque, 0)), } } #[pymethod(magic)] fn length_hint(&self) -> usize { - match self.status.load() { - Active => self.deque.len().saturating_sub(self.position.load()), - Exhausted => 0, - } + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] fn reduce( zelf: PyRef, vm: &VirtualMachine, - ) -> PyResult<(PyTypeRef, (PyDequeRef, PyObjectRef))> { - Ok(( + ) -> (PyTypeRef, (PyDequeRef, PyObjectRef)) { + let internal = zelf.internal.lock(); + let deque = match &internal.status { + Active(obj) => obj.clone(), + Exhausted => PyDeque::default().into_ref(vm), + }; + ( zelf.clone_class(), - (zelf.deque.clone(), vm.ctx.new_int(zelf.position.load())), - )) + (deque, vm.ctx.new_int(internal.position)), + ) } } impl IteratorIterable for PyDequeIterator {} impl SlotIterator for PyDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - Exhausted => Ok(PyIterReturn::StopIteration(None)), - Active => { - if zelf.length != zelf.deque.len() { - // Deque was changed while we iterated. - zelf.status.store(Exhausted); - Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())) - } else { - let pos = zelf.position.fetch_add(1); - let deque = zelf.deque.borrow_deque(); - if pos < deque.len() { - let ret = deque[pos].clone(); - Ok(PyIterReturn::Return(ret)) - } else { - zelf.status.store(Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } - } + zelf.internal.lock().next(|deque, pos| { + if zelf.state != deque.state.load() { + return Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())); } - } + let deque = deque.borrow_deque(); + Ok(match deque.get(pos) { + Some(x) => PyIterReturn::Return(x.clone()), + None => PyIterReturn::StopIteration(None), + }) + }) } } @@ -671,10 +652,9 @@ mod _collections { #[pyclass(name = "_deque_reverse_iterator")] #[derive(Debug, PyValue)] struct PyReverseDequeIterator { - position: AtomicCell, - status: AtomicCell, - length: usize, // To track length immutability. - deque: PyDequeRef, + state: usize, + // position is counting from the tail + internal: PyMutex>, } impl SlotConstructor for PyReverseDequeIterator { @@ -686,14 +666,10 @@ mod _collections { (DequeIterArgs { deque, index }, _kwargs): Self::Args, vm: &VirtualMachine, ) -> PyResult { - let len = deque.len(); let iter = PyDeque::reversed(deque)?; if let OptionalArg::Present(index) = index { let index = max(index, 0) as usize; - if len.le(&index) { - iter.status.store(Exhausted); - } - iter.position.store(len.saturating_sub(index)); + iter.internal.lock().position = index; } iter.into_pyresult_with_type(vm, cls) } @@ -703,10 +679,7 @@ mod _collections { impl PyReverseDequeIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { - match self.status.load() { - Active => self.position.load(), - Exhausted => 0, - } + self.internal.lock().length_hint(|obj| obj.len()) } #[pymethod(magic)] @@ -714,9 +687,14 @@ mod _collections { zelf: PyRef, vm: &VirtualMachine, ) -> PyResult<(PyTypeRef, (PyDequeRef, PyObjectRef))> { + let internal = zelf.internal.lock(); + let deque = match &internal.status { + Active(obj) => obj.clone(), + Exhausted => PyDeque::default().into_ref(vm), + }; Ok(( zelf.clone_class(), - (zelf.deque.clone(), vm.ctx.new_int(zelf.position.load())), + (deque, vm.ctx.new_int(internal.position)), )) } } @@ -724,32 +702,22 @@ mod _collections { impl IteratorIterable for PyReverseDequeIterator {} impl SlotIterator for PyReverseDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - Exhausted => Ok(PyIterReturn::StopIteration(None)), - Active => { - // If length changes while we iterate, set to Exhausted and bail. - if zelf.length != zelf.deque.len() { - zelf.status.store(Exhausted); - Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())) - } else { - let pos = zelf.position.fetch_sub(1) - 1; - let deque = zelf.deque.borrow_deque(); - if pos > 0 { - if let Some(obj) = deque.get(pos) { - return Ok(PyIterReturn::Return(obj.clone())); - } - } - // We either are == 0 or deque.get returned None. Either way, set status - // to exhausted and return last item if pos == 0. - zelf.status.store(Exhausted); - if pos == 0 { - // Can safely index directly. - return Ok(PyIterReturn::Return(deque[pos].clone())); - } - Ok(PyIterReturn::StopIteration(None)) - } + zelf.internal.lock().next(|deque, pos| { + if deque.state.load() != zelf.state { + return Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())); } - } + let deque = deque.borrow_deque(); + Ok( + match deque + .len() + .checked_sub(pos + 1) + .and_then(|pos| deque.get(pos)) + { + Some(x) => PyIterReturn::Return(x.clone()), + None => PyIterReturn::StopIteration(None), + }, + ) + }) } } }