diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index f049e4803..8b1a8011c 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -717,7 +717,7 @@ impl Unhashable for PyByteArray {} impl Iterable for PyByteArray { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyByteArrayIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -730,7 +730,7 @@ impl Iterable for PyByteArray { #[pyclass(module = false, name = "bytearray_iterator")] #[derive(Debug)] pub struct PyByteArrayIterator { - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for PyByteArrayIterator { @@ -743,14 +743,13 @@ impl PyValue for PyByteArrayIterator { impl PyByteArrayIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.read().reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_iter_reduce(|x| x.clone().into_object(), vm) } #[pymethod(magic)] @@ -762,8 +761,7 @@ impl IteratorIterable for PyByteArrayIterator {} impl SlotIterator for PyByteArrayIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().next( - |obj, pos| { - let bytearray = obj.payload::().unwrap(); + |bytearray, pos| { let buf = bytearray.borrow_buf(); buf.get(pos) .ok_or_else(|| vm.new_stop_iteration()) diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 5d710e6d5..200937a94 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -574,7 +574,7 @@ impl Comparable for PyBytes { impl Iterable for PyBytes { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyBytesIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -583,7 +583,7 @@ impl Iterable for PyBytes { #[pyclass(module = false, name = "bytes_iterator")] #[derive(Debug)] pub struct PyBytesIterator { - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for PyBytesIterator { @@ -596,15 +596,14 @@ impl PyValue for PyBytesIterator { impl PyBytesIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.read().reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_iter_reduce(|x| x.clone().into_object(), vm) } #[pymethod(magic)] @@ -616,8 +615,7 @@ impl IteratorIterable for PyBytesIterator {} impl SlotIterator for PyBytesIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().next( - |obj, pos| { - let bytes = obj.payload::().unwrap(); + |bytes, pos| { bytes .as_bytes() .get(pos) diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index 50f8498c5..39a612066 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -704,11 +704,8 @@ macro_rules! dict_iterator { #[pyclass(module = false, name = $iter_class_name)] #[derive(Debug)] pub(crate) struct $iter_name { - // pub dict: PyDictRef, pub size: dictdatatype::DictSize, - // pub position: AtomicCell, - // pub status: AtomicCell, - pub internal: PyRwLock, + pub internal: PyRwLock>, } impl PyValue for $iter_name { @@ -721,10 +718,8 @@ macro_rules! dict_iterator { impl $iter_name { fn new(dict: PyDictRef) -> Self { $iter_name { - // position: AtomicCell::new(0), size: dict.size(), - // dict, - internal: PyRwLock::new(PositionIterInternal::new(dict.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(dict, 0)), } } @@ -732,12 +727,7 @@ macro_rules! dict_iterator { fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .read() - .length_hint(|obj| obj.payload::().map(|x| x.entries.len()), vm) - // if let IterStatus::Exhausted = self.status.load() { - // 0 - // } else { - // self.dict.entries.len_from_entry_index(self.position.load()) - // } + .length_hint(|obj| Some(obj.entries.len()), vm) } } @@ -745,54 +735,34 @@ macro_rules! dict_iterator { impl SlotIterator for $iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let internal = zelf.internal.write(); - if let IterStatus::Active(obj) = &internal.status { - let dict = obj.payload::().unwrap(); + let mut status = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.status); + let mut position = + PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.position); + if let IterStatus::Active(dict) = &*status { if dict.entries.has_changed_size(&zelf.size) { - internal.status = IterStatus::Exhausted; + *status = IterStatus::Exhausted; return Err(vm.new_runtime_error( "dictionary changed size during iteration".to_owned(), )); } - match dict.entries.next_entry(&mut internal.position) { + match dict.entries.next_entry(&mut *position) { Some((key, value)) => Ok(($result_fn)(vm, key, value)), None => { - internal.status = IterStatus::Exhausted; + *status = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } } } else { Err(vm.new_stop_iteration()) } - // match zelf.status.load() { - // IterStatus::Exhausted => Err(vm.new_stop_iteration()), - // IterStatus::Active => { - // if zelf.dict.entries.has_changed_size(&zelf.size) { - // zelf.status.store(IterStatus::Exhausted); - // return Err(vm.new_runtime_error( - // "dictionary changed size during iteration".to_owned(), - // )); - // } - // match zelf.dict.entries.next_entry_atomic(&zelf.position) { - // Some((key, value)) => Ok(($result_fn)(vm, key, value)), - // None => { - // zelf.status.store(IterStatus::Exhausted); - // Err(vm.new_stop_iteration()) - // } - // } - // } - // } } } #[pyclass(module = false, name = $reverse_iter_class_name)] #[derive(Debug)] pub(crate) struct $reverse_iter_name { - // pub dict: PyDictRef, pub size: dictdatatype::DictSize, - // pub position: AtomicCell, - // pub status: AtomicCell, - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for $reverse_iter_name { @@ -804,16 +774,11 @@ macro_rules! dict_iterator { #[pyimpl(with(SlotIterator))] impl $reverse_iter_name { fn new(dict: PyDictRef) -> Self { - let position = dict.entries.len().saturating_sub(1); + let size = dict.size(); + let position = size.entries_size.saturating_sub(1); $reverse_iter_name { - // position: AtomicCell::new(0), - size: dict.size(), - // dict, - // status: AtomicCell::new(IterStatus::Active), - internal: PyRwLock::new(PositionIterInternal::new( - dict.into_object(), - position, - )), + size, + internal: PyRwLock::new(PositionIterInternal::new(dict, position)), } } @@ -821,12 +786,7 @@ macro_rules! dict_iterator { fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .read() - .rev_length_hint(|obj| obj.payload::().map(|x| x.entries.len()), vm) - // if let IterStatus::Exhausted = self.status.load() { - // 0 - // } else { - // self.dict.entries.len_from_entry_index(self.position.load()) - // } + .rev_length_hint(|_| Some(self.size.entries_size), vm) } } @@ -834,43 +794,26 @@ macro_rules! dict_iterator { impl SlotIterator for $reverse_iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let internal = zelf.internal.write(); - if let IterStatus::Active(obj) = &internal.status { - let dict = obj.payload::().unwrap(); + let mut status = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.status); + let mut position = + PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.position); + if let IterStatus::Active(dict) = &*status { if dict.entries.has_changed_size(&zelf.size) { - internal.status = IterStatus::Exhausted; + *status = IterStatus::Exhausted; return Err(vm.new_runtime_error( "dictionary changed size during iteration".to_owned(), )); } - match dict.entries.prev_entry(&mut internal.position) { + match dict.entries.prev_entry(&mut *position) { Some((key, value)) => Ok(($result_fn)(vm, key, value)), None => { - internal.status = IterStatus::Exhausted; + *status = IterStatus::Exhausted; Err(vm.new_stop_iteration()) } } } else { Err(vm.new_stop_iteration()) } - // match zelf.status.load() { - // IterStatus::Exhausted => Err(vm.new_stop_iteration()), - // IterStatus::Active => { - // if zelf.dict.entries.has_changed_size(&zelf.size) { - // zelf.status.store(IterStatus::Exhausted); - // return Err(vm.new_runtime_error( - // "dictionary changed size during iteration".to_owned(), - // )); - // } - // match zelf.dict.entries.next_entry_atomic_reversed(&zelf.position) { - // Some((key, value)) => Ok(($result_fn)(vm, key, value)), - // None => { - // zelf.status.store(IterStatus::Exhausted); - // Err(vm.new_stop_iteration()) - // } - // } - // } - // } } } }; diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index 3dcc589c3..7ad93f55f 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -72,10 +72,7 @@ impl SlotIterator for PyEnumerate { #[pyclass(module = false, name = "reversed")] #[derive(Debug)] pub struct PyReverseSequenceIterator { - internal: PyRwLock, - // pub position: AtomicCell, - // pub status: AtomicCell, - // pub obj: PyObjectRef, + internal: PyRwLock>, } impl PyValue for PyReverseSequenceIterator { @@ -89,10 +86,7 @@ impl PyReverseSequenceIterator { pub fn new(obj: PyObjectRef, len: usize) -> Self { let position = len.saturating_sub(1); Self { - internal: PyRwLock::new(PositionIterInternal::new(obj, position)) - // position: AtomicCell::new(len.saturating_sub(1)), - // status: AtomicCell::new(if len == 0 { Exhausted } else { Active }), - // obj, + internal: PyRwLock::new(PositionIterInternal::new(obj, position)), } } @@ -101,50 +95,18 @@ impl PyReverseSequenceIterator { self.internal .read() .rev_length_hint(|obj| vm.obj_len(obj).ok(), vm) - // Ok(match self.status.load() { - // Active => { - // let position = self.position.load(); - // if position > vm.obj_len(&self.obj)? { - // 0 - // } else { - // position + 1 - // } - // } - // Exhausted => 0, - // }) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.read().set_state(state, vm) - // // When we're exhausted, just return. - // if let Exhausted = self.status.load() { - // return Ok(()); - // } - // let len = vm.obj_len(&self.obj)?; - // let pos = state - // .payload::() - // .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - // let pos = std::cmp::min( - // try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), - // len.saturating_sub(1), - // ); - // self.position.store(pos); - // Ok(()) + self.internal.write().set_state(state, vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "reversed")?; - Ok(self.internal.read().reduce(iter, vm)) - // Ok(vm.ctx.new_tuple(match self.status.load() { - // Exhausted => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_tuple(vec![])])], - // Active => vec![ - // iter, - // vm.ctx.new_tuple(vec![self.obj.clone()]), - // vm.ctx.new_int(self.position.load()), - // ], - // })) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_reversed_reduce(|x| x.clone(), vm) } } @@ -152,23 +114,8 @@ impl IteratorIterable for PyReverseSequenceIterator {} impl SlotIterator for PyReverseSequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal - .read() + .write() .rev_next(|obj, pos| obj.get_item(pos, vm), vm) - // if let Exhausted = zelf.status.load() { - // return Err(vm.new_stop_iteration()); - // } - // let pos = zelf.position.fetch_sub(1); - // if pos == 0 { - // zelf.status.store(Exhausted); - // } - // match zelf.obj.get_item(pos, vm) { - // Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => { - // zelf.status.store(Exhausted); - // Err(vm.new_stop_iteration()) - // } - // // also catches stop_iteration => stop_iteration - // ret => ret, - // } } } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 0acae123c..45e1fa1a8 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -14,21 +14,21 @@ use crossbeam_utils::atomic::AtomicCell; /// Marks status of iterator. #[derive(Debug, Clone)] -pub enum IterStatus { +pub enum IterStatus { /// Iterator hasn't raised StopIteration. - Active(PyObjectRef), + Active(T), /// Iterator has raised StopIteration. Exhausted, } #[derive(Debug)] -pub struct PositionIterInternal { - pub status: IterStatus, +pub struct PositionIterInternal { + pub status: IterStatus, pub position: usize, } -impl PositionIterInternal { - pub fn new(obj: PyObjectRef, position: usize) -> Self { +impl PositionIterInternal { + pub fn new(obj: T, position: usize) -> Self { Self { status: IterStatus::Active(obj), position, @@ -49,11 +49,14 @@ impl PositionIterInternal { } } - pub fn reduce(&self, func: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + fn _reduce(&self, func: PyObjectRef, f: F, vm: &VirtualMachine) -> PyObjectRef + where + F: FnOnce(&T) -> PyObjectRef, + { if let IterStatus::Active(obj) = &self.status { vm.ctx.new_tuple(vec![ func, - vm.ctx.new_tuple(vec![obj.clone()]), + vm.ctx.new_tuple(vec![f(obj)]), vm.ctx.new_int(self.position), ]) } else { @@ -62,9 +65,25 @@ impl PositionIterInternal { } } + pub fn builtin_iter_reduce(&self, f: F, vm: &VirtualMachine) -> PyObjectRef + where + F: FnOnce(&T) -> PyObjectRef, + { + let iter = vm.get_attribute(vm.builtins.clone(), "iter").unwrap(); + self._reduce(iter, f, vm) + } + + pub fn builtin_reversed_reduce(&self, f: F, vm: &VirtualMachine) -> PyObjectRef + where + F: FnOnce(&T) -> PyObjectRef, + { + let reversed = vm.get_attribute(vm.builtins.clone(), "reversed").unwrap(); + self._reduce(reversed, f, vm) + } + pub fn next(&mut self, f: F, vm: &VirtualMachine) -> PyResult where - F: FnOnce(&PyObjectRef, usize) -> PyResult, + F: FnOnce(&T, usize) -> PyResult, { if let IterStatus::Active(obj) = &self.status { match f(obj, self.position) { @@ -89,7 +108,7 @@ impl PositionIterInternal { pub fn rev_next(&mut self, f: F, vm: &VirtualMachine) -> PyResult where - F: FnOnce(&PyObjectRef, usize) -> PyResult, + F: FnOnce(&T, usize) -> PyResult, { if let IterStatus::Active(obj) = &self.status { match f(obj, self.position) { @@ -118,7 +137,7 @@ impl PositionIterInternal { pub fn length_hint(&self, f: F, vm: &VirtualMachine) -> PyObjectRef where - F: FnOnce(&PyObjectRef) -> Option, + F: FnOnce(&T) -> Option, { let len = if let IterStatus::Active(obj) = &self.status { if let Some(obj_len) = f(obj) { @@ -134,7 +153,7 @@ impl PositionIterInternal { pub fn rev_length_hint(&self, f: F, vm: &VirtualMachine) -> PyObjectRef where - F: FnOnce(&PyObjectRef) -> Option, + F: FnOnce(&T) -> Option, { if let IterStatus::Active(obj) = &self.status { if let Some(obj_len) = f(obj) { @@ -142,6 +161,7 @@ impl PositionIterInternal { return PyInt::from(self.position + 1).into_object(vm); } } + // FIXME: return NotImplemented? } PyInt::from(0).into_object(vm) } @@ -150,7 +170,7 @@ impl PositionIterInternal { #[pyclass(module = false, name = "iterator")] #[derive(Debug)] pub struct PySequenceIterator { - internal: PositionIterInternal, + internal: PyRwLock>, } impl PyValue for PySequenceIterator { @@ -163,31 +183,34 @@ impl PyValue for PySequenceIterator { impl PySequenceIterator { pub fn new(obj: PyObjectRef) -> Self { Self { - internal: PositionIterInternal::new(obj, 0), + internal: PyRwLock::new(PositionIterInternal::new(obj, 0)), } } #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal.length_hint(|obj| vm.obj_len(obj).ok(), vm) + self.internal + .read() + .length_hint(|obj| vm.obj_len(obj).ok(), vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal.read().builtin_iter_reduce(|x| x.clone(), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.set_state(state, vm) + self.internal.write().set_state(state, vm) } } impl IteratorIterable for PySequenceIterator {} impl SlotIterator for PySequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - zelf.internal.next(|obj, pos| obj.get_item(pos, vm), vm) + zelf.internal + .write() + .next(|obj, pos| obj.get_item(pos, vm), vm) } } @@ -195,7 +218,7 @@ impl SlotIterator for PySequenceIterator { #[derive(Debug)] pub struct PyCallableIterator { sentinel: PyObjectRef, - status: PyRwLock, + status: PyRwLock>, } impl PyValue for PyCallableIterator { @@ -209,7 +232,7 @@ impl PyCallableIterator { pub fn new(callable: ArgCallable, sentinel: PyObjectRef) -> Self { Self { sentinel, - status: PyRwLock::new(IterStatus::Active(callable.into_object())), + status: PyRwLock::new(IterStatus::Active(callable)), } } } @@ -218,7 +241,7 @@ impl IteratorIterable for PyCallableIterator {} impl SlotIterator for PyCallableIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if let IterStatus::Active(callable) = &*zelf.status.read() { - let ret = vm.invoke(callable, ())?; + let ret = callable.invoke((), vm)?; if vm.bool_eq(&ret, &zelf.sentinel)? { *zelf.status.write() = IterStatus::Exhausted; Err(vm.new_stop_iteration()) diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 4c2e0b241..665658e05 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -162,16 +162,8 @@ impl PyList { #[pymethod(magic)] fn reversed(zelf: PyRef) -> PyListReverseIterator { let position = zelf.len().saturating_sub(1); - // Mark iterator as exhausted immediately if its empty. PyListReverseIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), position)) - // position: AtomicCell::new(final_position.saturating_sub(1)), - // status: AtomicCell::new(if final_position == 0 { - // Exhausted - // } else { - // Active - // }), - // list: zelf, + internal: PyRwLock::new(PositionIterInternal::new(zelf, position)), } } @@ -420,7 +412,7 @@ impl PyList { impl Iterable for PyList { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyListIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -475,7 +467,7 @@ fn do_sort( #[pyclass(module = false, name = "list_iterator")] #[derive(Debug)] pub struct PyListIterator { - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for PyListIterator { @@ -488,9 +480,7 @@ impl PyValue for PyListIterator { impl PyListIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] @@ -499,9 +489,10 @@ impl PyListIterator { } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.read().reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_iter_reduce(|x| x.clone().into_object(), vm) } } @@ -509,8 +500,7 @@ impl IteratorIterable for PyListIterator {} impl SlotIterator for PyListIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().next( - |obj, pos| { - let list = obj.payload::().unwrap(); + |list, pos| { let vec = list.borrow_vec(); vec.get(pos) .ok_or_else(|| vm.new_stop_iteration()) @@ -524,9 +514,7 @@ impl SlotIterator for PyListIterator { #[pyclass(module = false, name = "list_reverseiterator")] #[derive(Debug)] pub struct PyListReverseIterator { - internal: PyRwLock, // pub position: AtomicCell, - // pub status: AtomicCell, - // pub list: PyListRef + internal: PyRwLock>, } impl PyValue for PyListReverseIterator { @@ -541,46 +529,19 @@ impl PyListReverseIterator { fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .read() - .rev_length_hint(|obj| obj.payload::().map(|x| x.len()), vm) - // match self.status.load() { - // Active => { - // let position = self.position.load(); - // if position > self.list.len() { - // // List was mutated. Report zero, next call to `__next__` will - // // fail and set iterator to Exhausted. - // 0 - // } else { - // position + 1 - // } - // } - // Exhausted => 0, - // } + .rev_length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - // When we're exhausted, just return. - // if let Exhausted = self.status.load() { - // return Ok(()); - // } - - // // Max for position is list.len() - 1. - // let position = list_state(self.list.len().saturating_sub(1), state, vm)?; - // self.position.store(position); - // Ok(()) self.internal.write().set_state(state, vm) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "reversed")?; - Ok(self.internal.read().reduce(iter, vm)) - // let pos = if let Exhausted = self.status.load() { - // None - // } else { - // Some(self.position.load()) - // }; - // list_reduce(self.list.clone(), pos, true, vm) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_reversed_reduce(|x| x.clone().into_object(), vm) } } @@ -588,8 +549,7 @@ impl IteratorIterable for PyListReverseIterator {} impl SlotIterator for PyListReverseIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().rev_next( - |obj, pos| { - let list = obj.payload::().unwrap(); + |list, pos| { let vec = list.borrow_vec(); vec.get(pos) .ok_or_else(|| vm.new_stop_iteration()) @@ -597,60 +557,9 @@ impl SlotIterator for PyListReverseIterator { }, vm, ) - // if let Exhausted = zelf.status.load() { - // return Err(vm.new_stop_iteration()); - // } - // let list = zelf.list.borrow_vec(); - // let pos = zelf.position.fetch_sub(1); - // if pos > 0 { - // if let Some(obj) = list.get(pos) { - // return Ok(obj.clone()); - // } - // } - // // We either are == 0 or list.get returned None. Either way, set status - // // to exhausted and return last item if pos == 0. - // zelf.status.store(Exhausted); - // if pos == 0 { - // if let Some(obj) = list.get(pos) { - // return Ok(obj.clone()); - // } - // } - // Err(vm.new_stop_iteration()) } } -// Common reducer for forward and reverse list iterators. -fn list_reduce( - list: PyRef, - position: Option, - reverse: bool, - vm: &VirtualMachine, -) -> PyResult { - let attr = if reverse { "reversed" } else { "iter" }; - let iter = vm.get_attribute(vm.builtins.clone(), attr)?; - let elems = match position { - None => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])], - Some(position) => vec![ - iter, - vm.ctx.new_tuple(vec![list.into_object()]), - vm.ctx.new_int(position), - ], - }; - Ok(vm.ctx.new_tuple(elems)) -} - -// Common function to extract state. Clamps it in range [0, length]. -fn list_state(length: usize, state: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let position = state - .payload::() - .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - let position = std::cmp::min( - int::try_to_primitive(position.as_bigint(), vm).unwrap_or(0), - length, - ); - Ok(position) -} - pub fn init(context: &PyContext) { let list_type = &context.types.list_type; PyList::extend_class(context, list_type); diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index 70e5f3383..71eb569b2 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -19,7 +19,6 @@ use crate::{ PyObjectRef, PyRef, PyResult, PyValue, TryIntoRef, TypeProtocol, VirtualMachine, }; use bstr::ByteSlice; -use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; use num_traits::ToPrimitive; use rustpython_common::{ @@ -170,10 +169,7 @@ impl TryIntoRef for &str { #[pyclass(module = false, name = "str_iterator")] #[derive(Debug)] pub struct PyStrIterator { - // string: PyStrRef, - // position: PyAtomic, - // status: AtomicCell, - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for PyStrIterator { @@ -186,49 +182,27 @@ impl PyValue for PyStrIterator { impl PyStrIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) - // match self.status.load() { - // Active => { - // let pos = self.position.load(atomic::Ordering::SeqCst); - // self.string.len().saturating_sub(pos) - // } - // Exhausted => 0, - // } + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal.write().set_state(state, vm) - // // When we're exhausted, just return. - // if let Exhausted = self.status.load() { - // return Ok(()); - // } - // let pos = state - // .payload::() - // .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; - // let pos = std::cmp::min( - // try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), - // self.string.len(), - // ); - // self.position.store(pos, atomic::Ordering::SeqCst); - // Ok(()) } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.read().reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_iter_reduce(|x| x.clone().into_object(), vm) } } impl IteratorIterable for PyStrIterator {} impl SlotIterator for PyStrIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let internal = zelf.internal.write(); - if let IterStatus::Active(obj) = &internal.status { - let s = obj.payload::().unwrap(); + let mut internal = zelf.internal.write(); + if let IterStatus::Active(s) = &internal.status { let value = s.as_str(); if internal.position >= value.len() { internal.status = Exhausted; @@ -245,31 +219,6 @@ impl SlotIterator for PyStrIterator { } else { Err(vm.new_stop_iteration()) } - // if let Exhausted = zelf.status.load() { - // return Err(vm.new_stop_iteration()); - // } - // let value = &*zelf.string.as_str(); - // let mut start = zelf.position.load(atomic::Ordering::SeqCst); - // loop { - // if start == value.len() { - // zelf.status.store(Exhausted); - // return Err(vm.new_stop_iteration()); - // } - // let ch = value[start..].chars().next().ok_or_else(|| { - // zelf.status.store(Exhausted); - // vm.new_stop_iteration() - // })?; - - // match zelf.position.compare_exchange_weak( - // start, - // start + ch.len_utf8(), - // atomic::Ordering::Release, - // atomic::Ordering::Relaxed, - // ) { - // Ok(_) => break Ok(ch.into_pyobject(vm)), - // Err(cur) => start = cur, - // } - // } } } @@ -1304,10 +1253,7 @@ impl Comparable for PyStr { impl Iterable for PyStr { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyStrIterator { - // position: Radium::new(0), - // string: zelf, - // status: AtomicCell::new(Active), - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index f162f5891..7bcaa7988 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -1,3 +1,4 @@ +use super::PositionIterInternal; /* * Builtin set type with a sequence of unique items. */ @@ -15,7 +16,7 @@ use crate::{ IdProtocol, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; -use crossbeam_utils::atomic::AtomicCell; +use rustpython_common::lock::{PyRwLock, PyRwLockWriteGuard}; use std::fmt; pub type SetContentType = dictdatatype::Dict<()>; @@ -194,10 +195,12 @@ impl PySetInner { fn iter(&self) -> PySetIterator { PySetIterator { - dict: PyRc::clone(&self.content), size: self.content.size(), - position: AtomicCell::new(0), - status: AtomicCell::new(IterStatus::Active), + internal: PyRwLock::new(PositionIterInternal::new(self.content.clone(), 0)) + // dict: PyRc::clone(&self.content), + // size: self.content.size(), + // position: AtomicCell::new(0), + // status: AtomicCell::new(IterStatus::Active), } } @@ -815,10 +818,8 @@ impl TryFromObject for SetIterable { #[pyclass(module = false, name = "set_iterator")] pub(crate) struct PySetIterator { - dict: PyRc, size: DictSize, - position: AtomicCell, - status: AtomicCell, + internal: PyRwLock>>, } impl fmt::Debug for PySetIterator { @@ -837,26 +838,27 @@ impl PyValue for PySetIterator { #[pyimpl(with(SlotIterator))] impl PySetIterator { #[pymethod(magic)] - fn length_hint(&self) -> usize { - if let IterStatus::Exhausted = self.status.load() { - 0 - } else { - self.dict.len_from_entry_index(self.position.load()) - } + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .length_hint(|_| Some(self.size.entries_size), vm) + // if let IterStatus::Exhausted = self.status.load() { + // 0 + // } else { + // self.dict.len_from_entry_index(self.position.load()) + // } } #[pymethod(magic)] fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, (PyObjectRef,))> { + let internal = zelf.internal.read(); Ok(( vm.get_attribute(vm.builtins.clone(), "iter")?, - (vm.ctx.new_list(match zelf.status.load() { + (vm.ctx.new_list(match &internal.status { IterStatus::Exhausted => vec![], - IterStatus::Active => zelf - .dict - .keys() - .into_iter() - .skip(zelf.position.load()) - .collect(), + IterStatus::Active(dict) => { + dict.keys().into_iter().skip(internal.position).collect() + } }),), )) } @@ -864,25 +866,42 @@ impl PySetIterator { impl IteratorIterable for PySetIterator {} impl SlotIterator for PySetIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - match zelf.status.load() { - IterStatus::Exhausted => Ok(PyIterReturn::StopIteration(None)), - IterStatus::Active => { - if zelf.dict.has_changed_size(&zelf.size) { - zelf.status.store(IterStatus::Exhausted); - return Err( - vm.new_runtime_error("set changed size during iteration".to_owned()) - ); - } - match zelf.dict.next_entry_atomic(&zelf.position) { - Some((key, _)) => Ok(PyIterReturn::Return(key)), - None => { - zelf.status.store(IterStatus::Exhausted); - Ok(PyIterReturn::StopIteration(None)) - } + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + let mut status = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.status); + let mut position = PyRwLockWriteGuard::map(zelf.internal.write(), |x| &mut x.position); + if let IterStatus::Active(dict) = &*status { + if dict.has_changed_size(&zelf.size) { + *status = IterStatus::Exhausted; + return Err(vm.new_runtime_error("set changed size during iteration".to_owned())); + } + match dict.next_entry(&mut *position) { + Some((key, _)) => Ok(key), + None => { + *status = IterStatus::Exhausted; + Err(vm.new_stop_iteration()) } } + } else { + Err(vm.new_stop_iteration()) } + // match zelf.status.load() { + // IterStatus::Exhausted => Err(vm.new_stop_iteration()), + // IterStatus::Active => { + // if zelf.dict.has_changed_size(&zelf.size) { + // zelf.status.store(IterStatus::Exhausted); + // return Err( + // vm.new_runtime_error("set changed size during iteration".to_owned()) + // ); + // } + // match zelf.dict.next_entry_atomic(&zelf.position) { + // Some((key, _)) => Ok(key), + // None => { + // zelf.status.store(IterStatus::Exhausted); + // Err(vm.new_stop_iteration()) + // } + // } + // } + // } } } diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 833a7552d..c8aa6b95e 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -310,7 +310,7 @@ impl Comparable for PyTuple { impl Iterable for PyTuple { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyTupleIterator { - internal: PyRwLock::new(PositionIterInternal::new(zelf.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(zelf, 0)), } .into_object(vm)) } @@ -319,7 +319,7 @@ impl Iterable for PyTuple { #[pyclass(module = false, name = "tuple_iterator")] #[derive(Debug)] pub(crate) struct PyTupleIterator { - internal: PyRwLock, + internal: PyRwLock>, } impl PyValue for PyTupleIterator { @@ -332,9 +332,7 @@ impl PyValue for PyTupleIterator { impl PyTupleIterator { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] @@ -343,9 +341,10 @@ impl PyTupleIterator { } #[pymethod(magic)] - fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(self.internal.read().reduce(iter, vm)) + fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef { + self.internal + .read() + .builtin_iter_reduce(|x| x.clone().into_object(), vm) } } @@ -353,8 +352,7 @@ impl IteratorIterable for PyTupleIterator {} impl SlotIterator for PyTupleIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().next( - |obj, pos| { - let tuple = obj.payload::().unwrap(); + |tuple, pos| { tuple .as_slice() .get(pos) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 996b75778..6bb2cbd8a 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -110,8 +110,8 @@ struct DictEntry { #[derive(Debug, PartialEq)] pub struct DictSize { indices_size: usize, - entries_size: usize, - used: usize, + pub entries_size: usize, + pub used: usize, filled: usize, } diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 8c4aa28c9..8809b2a4a 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -6,7 +6,7 @@ mod _collections { use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; use crate::{ builtins::{ - IterStatus::{self, Active, Exhausted}, + IterStatus::{Active, Exhausted}, PyInt, PyTypeRef, }, function::{FuncArgs, KwArgs, OptionalArg}, @@ -358,13 +358,10 @@ mod _collections { #[pymethod(magic)] fn reversed(zelf: PyRef) -> PyResult { - // let length = zelf.len(); + let position = zelf.len().saturating_sub(1); Ok(PyReverseDequeIterator { state: zelf.state.load(), - internal: PyRwLock::new(PositionIterInternal::new( - zelf.into_object(), - zelf.len() - 1, - )), + internal: PyRwLock::new(PositionIterInternal::new(zelf, position)), }) } @@ -575,12 +572,8 @@ mod _collections { #[pyclass(name = "_deque_iterator")] #[derive(Debug, PyValue)] struct PyDequeIterator { - // position: AtomicCell, - // status: AtomicCell, - // length: usize, // To track length immutability. - // deque: PyDequeRef, state: usize, - internal: PyRwLock, + internal: PyRwLock>, } #[derive(FromArgs)] @@ -600,16 +593,10 @@ mod _collections { (DequeIterArgs { deque, index }, _kwargs): Self::Args, vm: &VirtualMachine, ) -> PyResult { - // let len = deque.len(); let iter = PyDequeIterator::new(deque); if let OptionalArg::Present(index) = index { let index = max(index, 0) as usize; iter.internal.write().position = index; - // iter.position.store(min(index, len)); - - // if len.le(&index) { - // iter.status.store(Exhausted); - // } } iter.into_pyresult_with_type(vm, cls) } @@ -619,35 +606,25 @@ mod _collections { impl PyDequeIterator { pub(crate) fn new(deque: PyDequeRef) -> Self { PyDequeIterator { - // position: AtomicCell::new(0), - // status: AtomicCell::new(IterStatus::Active), - // length: deque.len(), - // deque, state: deque.state.load(), - internal: PyRwLock::new(PositionIterInternal::new(deque.into_object(), 0)), + internal: PyRwLock::new(PositionIterInternal::new(deque, 0)), } } #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - self.internal - .read() - .length_hint(|obj| obj.payload::().map(|x| x.len()), vm) - // match self.status.load() { - // Active => self.deque.len().saturating_sub(self.position.load()), - // Exhausted => 0, - // } + self.internal.read().length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] fn reduce( zelf: PyRef, vm: &VirtualMachine, - ) -> (PyTypeRef, (PyObjectRef, PyObjectRef)) { + ) -> (PyTypeRef, (PyDequeRef, PyObjectRef)) { let internal = zelf.internal.read(); let deque = match &internal.status { Active(obj) => obj.clone(), - Exhausted => PyDeque::default().into_object(vm), + Exhausted => PyDeque::default().into_ref(vm), }; ( zelf.clone_class(), @@ -660,8 +637,7 @@ mod _collections { impl SlotIterator for PyDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().next( - |obj, pos| { - let deque = obj.payload::().unwrap(); + |deque, pos| { if zelf.state != deque.state.load() { return Err( vm.new_runtime_error("Deque mutated during iteration".to_owned()) @@ -675,26 +651,6 @@ mod _collections { }, vm, ) - // match zelf.status.load() { - // Exhausted => Err(vm.new_stop_iteration()), - // Active => { - // if zelf.length != zelf.deque.len() { - // // Deque was changed while we iterated. - // zelf.status.store(Exhausted); - // Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())) - // } else { - // let pos = zelf.position.fetch_add(1); - // let deque = zelf.deque.borrow_deque(); - // if pos < deque.len() { - // let ret = deque[pos].clone(); - // Ok(ret) - // } else { - // zelf.status.store(Exhausted); - // Err(vm.new_stop_iteration()) - // } - // } - // } - // } } } @@ -702,12 +658,8 @@ mod _collections { #[pyclass(name = "_deque_reverse_iterator")] #[derive(Debug, PyValue)] struct PyReverseDequeIterator { - // position: AtomicCell, - // status: AtomicCell, - // length: usize, // To track length immutability. - // deque: PyDequeRef, state: usize, - internal: PyRwLock, + internal: PyRwLock>, } impl SlotConstructor for PyReverseDequeIterator { @@ -719,15 +671,10 @@ mod _collections { (DequeIterArgs { deque, index }, _kwargs): Self::Args, vm: &VirtualMachine, ) -> PyResult { - // let len = deque.len(); let iter = PyDeque::reversed(deque)?; if let OptionalArg::Present(index) = index { let index = max(index, 0) as usize; iter.internal.write().position = index; - // if len.le(&index) { - // iter.status.store(Exhausted); - // } - // iter.position.store(len.saturating_sub(index)); } iter.into_pyresult_with_type(vm, cls) } @@ -739,22 +686,18 @@ mod _collections { fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { self.internal .read() - .rev_length_hint(|obj| obj.payload::().map(|x| x.len()), vm) - // match self.status.load() { - // Active => self.position.load(), - // Exhausted => 0, - // } + .rev_length_hint(|obj| Some(obj.len()), vm) } #[pymethod(magic)] fn reduce( zelf: PyRef, vm: &VirtualMachine, - ) -> PyResult<(PyTypeRef, (PyObjectRef, PyObjectRef))> { + ) -> PyResult<(PyTypeRef, (PyDequeRef, PyObjectRef))> { let internal = zelf.internal.read(); let deque = match &internal.status { Active(obj) => obj.clone(), - Exhausted => PyDeque::default().into_object(vm), + Exhausted => PyDeque::default().into_ref(vm), }; Ok(( zelf.clone_class(), @@ -767,8 +710,7 @@ mod _collections { impl SlotIterator for PyReverseDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.internal.write().rev_next( - |obj, pos| { - let deque = obj.payload::().unwrap(); + |deque, pos| { if deque.state.load() != zelf.state { return Err( vm.new_runtime_error("Deque mutated during iteration".to_owned()) @@ -782,32 +724,6 @@ mod _collections { }, vm, ) - // match zelf.status.load() { - // Exhausted => Err(vm.new_stop_iteration()), - // Active => { - // // If length changes while we iterate, set to Exhausted and bail. - // if zelf.length != zelf.deque.len() { - // zelf.status.store(Exhausted); - // Err(vm.new_runtime_error("Deque mutated during iteration".to_owned())) - // } else { - // let pos = zelf.position.fetch_sub(1) - 1; - // let deque = zelf.deque.borrow_deque(); - // if pos > 0 { - // if let Some(obj) = deque.get(pos) { - // return Ok(obj.clone()); - // } - // } - // // We either are == 0 or deque.get returned None. Either way, set status - // // to exhausted and return last item if pos == 0. - // zelf.status.store(Exhausted); - // if pos == 0 { - // // Can safely index directly. - // return Ok(deque[pos].clone()); - // } - // Err(vm.new_stop_iteration()) - // } - // } - // } } } }