diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index c3be54b06..fe9266555 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -1,14 +1,15 @@ use crate::common::lock::PyRwLock; +use crossbeam_utils::atomic::AtomicCell; use num_bigint::BigInt; use num_traits::Zero; use super::int::PyIntRef; use super::pytype::PyTypeRef; use crate::function::OptionalArg; -use crate::iterator; use crate::slots::PyIter; use crate::vm::VirtualMachine; +use crate::{iterator, ItemProtocol, TypeProtocol}; use crate::{IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; #[pyclass(module = false, name = "enumerate")] @@ -60,6 +61,52 @@ impl PyIter for PyEnumerate { } } +#[pyclass(module = false, name = "reversed")] +#[derive(Debug)] +pub struct PyReverseSequenceIterator { + pub position: AtomicCell, + pub obj: PyObjectRef, +} + +impl PyValue for PyReverseSequenceIterator { + fn class(vm: &VirtualMachine) -> &PyTypeRef { + &vm.ctx.types.reverse_iter_type + } +} + +#[pyimpl(with(PyIter))] +impl PyReverseSequenceIterator { + pub fn new(obj: PyObjectRef, len: isize) -> Self { + Self { + position: AtomicCell::new(len - 1), + obj, + } + } + + #[pymethod(magic)] + fn length_hint(&self) -> PyResult { + Ok(self.position.load() + 1) + } +} + +impl PyIter for PyReverseSequenceIterator { + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + let pos = zelf.position.fetch_sub(1); + if pos >= 0 { + match zelf.obj.get_item(pos, vm) { + Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => { + Err(vm.new_stop_iteration()) + } + // also catches stop_iteration => stop_iteration + ret => ret, + } + } else { + Err(vm.new_stop_iteration()) + } + } +} + pub fn init(context: &PyContext) { PyEnumerate::extend_class(context, &context.types.enumerate_type); + PyReverseSequenceIterator::extend_class(context, &context.types.reverse_iter_type); } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 1e7954349..cad0eae48 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -21,12 +21,11 @@ pub enum IterStatus { Exhausted, } -#[pyclass(module = false, name = "iter")] +#[pyclass(module = false, name = "iterator")] #[derive(Debug)] pub struct PySequenceIterator { pub position: AtomicCell, pub obj: PyObjectRef, - pub reversed: bool, } impl PyValue for PySequenceIterator { @@ -41,45 +40,26 @@ impl PySequenceIterator { Self { position: AtomicCell::new(0), obj, - reversed: false, - } - } - - pub fn new_reversed(obj: PyObjectRef, len: isize) -> Self { - Self { - position: AtomicCell::new(len - 1), - obj, - reversed: true, } } #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyResult { let pos = self.position.load(); - let hint = if self.reversed { - pos + 1 - } else { - let len = vm.obj_len(&self.obj)?; - len as isize - pos - }; - Ok(hint) + let len = vm.obj_len(&self.obj)?; + Ok(len as isize - pos) } } impl PyIter for PySequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let step: isize = if zelf.reversed { -1 } else { 1 }; - let pos = zelf.position.fetch_add(step); - if pos >= 0 { - match zelf.obj.get_item(pos, vm) { - Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => { - Err(vm.new_stop_iteration()) - } - // also catches stop_iteration => stop_iteration - ret => ret, + let pos = zelf.position.fetch_add(1); + match zelf.obj.get_item(pos, vm) { + Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => { + Err(vm.new_stop_iteration()) } - } else { - Err(vm.new_stop_iteration()) + // also catches stop_iteration => stop_iteration + ret => ret, } } } diff --git a/vm/src/builtins/make_module.rs b/vm/src/builtins/make_module.rs index 0fcdfc47c..e2434f52d 100644 --- a/vm/src/builtins/make_module.rs +++ b/vm/src/builtins/make_module.rs @@ -12,9 +12,10 @@ mod decl { use crate::builtins::bytes::PyBytesRef; use crate::builtins::code::PyCodeRef; use crate::builtins::dict::PyDictRef; + use crate::builtins::enumerate::PyReverseSequenceIterator; use crate::builtins::function::{PyCellRef, PyFunctionRef}; use crate::builtins::int::{self, PyIntRef}; - use crate::builtins::iter::{PyCallableIterator, PySequenceIterator}; + use crate::builtins::iter::PyCallableIterator; use crate::builtins::list::{PyList, SortOptions}; use crate::builtins::pybool::IntoPyBool; use crate::builtins::pystr::{PyStr, PyStrRef}; @@ -720,7 +721,7 @@ mod decl { "argument to reversed() must be a sequence".to_owned() })?; let len = vm.obj_len(&obj)? as isize; - let obj_iterator = PySequenceIterator::new_reversed(obj, len); + let obj_iterator = PyReverseSequenceIterator::new(obj, len); Ok(obj_iterator.into_object(vm)) } } diff --git a/vm/src/types.rs b/vm/src/types.rs index a5443334e..68e50b66c 100644 --- a/vm/src/types.rs +++ b/vm/src/types.rs @@ -68,6 +68,7 @@ pub struct TypeZoo { pub generator_type: PyTypeRef, pub int_type: PyTypeRef, pub iter_type: PyTypeRef, + pub reverse_iter_type: PyTypeRef, pub complex_type: PyTypeRef, pub list_type: PyTypeRef, pub list_iterator_type: PyTypeRef, @@ -182,6 +183,7 @@ impl TypeZoo { generator_type: generator::PyGenerator::init_bare_type().clone(), getset_type: getset::PyGetSet::init_bare_type().clone(), iter_type: iter::PySequenceIterator::init_bare_type().clone(), + reverse_iter_type: enumerate::PyReverseSequenceIterator::init_bare_type().clone(), list_iterator_type: list::PyListIterator::init_bare_type().clone(), list_reverseiterator_type: list::PyListReverseIterator::init_bare_type().clone(), mappingproxy_type: mappingproxy::PyMappingProxy::init_bare_type().clone(),