diff --git a/Lib/test/test_enumerate.py b/Lib/test/test_enumerate.py index 6f0807ec6..92d52c3d1 100644 --- a/Lib/test/test_enumerate.py +++ b/Lib/test/test_enumerate.py @@ -172,8 +172,6 @@ class TestReversed(unittest.TestCase, PickleTest): x = range(1) self.assertEqual(type(reversed(x)), type(iter(x))) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_len(self): for s in ('hello', tuple('hello'), list('hello'), range(5)): self.assertEqual(operator.length_hint(reversed(s)), len(s)) @@ -243,8 +241,6 @@ class TestReversed(unittest.TestCase, PickleTest): b = Blocked() self.assertRaises(TypeError, reversed, b) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_pickle(self): for data in 'abc', range(5), tuple(enumerate('abc')), range(1,17,5): self.check_pickle(reversed(data), list(data)[::-1]) diff --git a/Lib/test/test_tuple.py b/Lib/test/test_tuple.py index 77fcf1b5d..38704f463 100644 --- a/Lib/test/test_tuple.py +++ b/Lib/test/test_tuple.py @@ -361,8 +361,6 @@ class TupleTest(seq_tests.CommonTest): d = pickle.dumps(it, proto) self.assertEqual(self.type2test(it), self.type2test(data)[1:]) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_reversed_pickle(self): data = self.type2test([4, 5, 6, 7]) for proto in range(pickle.HIGHEST_PROTOCOL + 1): diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index fe9266555..45c4f34bb 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -4,7 +4,11 @@ use crossbeam_utils::atomic::AtomicCell; use num_bigint::BigInt; use num_traits::Zero; -use super::int::PyIntRef; +use super::int::{try_to_primitive, PyInt, PyIntRef}; +use super::iter::{ + IterStatus, + IterStatus::{Active, Exhausted}, +}; use super::pytype::PyTypeRef; use crate::function::OptionalArg; use crate::slots::PyIter; @@ -64,7 +68,8 @@ impl PyIter for PyEnumerate { #[pyclass(module = false, name = "reversed")] #[derive(Debug)] pub struct PyReverseSequenceIterator { - pub position: AtomicCell, + pub position: AtomicCell, + pub status: AtomicCell, pub obj: PyObjectRef, } @@ -76,32 +81,77 @@ impl PyValue for PyReverseSequenceIterator { #[pyimpl(with(PyIter))] impl PyReverseSequenceIterator { - pub fn new(obj: PyObjectRef, len: isize) -> Self { + pub fn new(obj: PyObjectRef, len: usize) -> Self { Self { - position: AtomicCell::new(len - 1), + position: AtomicCell::new(len.saturating_sub(1)), + status: AtomicCell::new(if len == 0 { Exhausted } else { Active }), obj, } } #[pymethod(magic)] - fn length_hint(&self) -> PyResult { - Ok(self.position.load() + 1) + fn length_hint(&self, vm: &VirtualMachine) -> PyResult { + Ok(match self.status.load() { + Active => { + let position = self.position.load(); + if position > vm.obj_len(&self.obj)? { + 0 + } else { + position + 1 + } + } + Exhausted => 0, + }) + } + + #[pymethod(magic)] + fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + // When we're exhausted, just return. + if let Exhausted = self.status.load() { + return Ok(()); + } + let len = vm.obj_len(&self.obj)?; + let pos = state + .payload::() + .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; + let pos = std::cmp::min( + try_to_primitive(pos.as_bigint(), vm).unwrap_or(0), + len.saturating_sub(1), + ); + self.position.store(pos); + Ok(()) + } + + #[pymethod(magic)] + fn reduce(&self, vm: &VirtualMachine) -> PyResult { + let iter = vm.get_attribute(vm.builtins.clone(), "reversed")?; + Ok(vm.ctx.new_tuple(match self.status.load() { + Exhausted => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_tuple(vec![])])], + Active => vec![ + iter, + vm.ctx.new_tuple(vec![self.obj.clone()]), + vm.ctx.new_int(self.position.load()), + ], + })) } } impl PyIter for PyReverseSequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + if let Exhausted = zelf.status.load() { + return Err(vm.new_stop_iteration()); + } let pos = zelf.position.fetch_sub(1); - if pos >= 0 { - match zelf.obj.get_item(pos, vm) { - Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => { - Err(vm.new_stop_iteration()) - } - // also catches stop_iteration => stop_iteration - ret => ret, + if pos == 0 { + zelf.status.store(Exhausted); + } + match zelf.obj.get_item(pos, vm) { + Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => { + zelf.status.store(Exhausted); + Err(vm.new_stop_iteration()) } - } else { - Err(vm.new_stop_iteration()) + // also catches stop_iteration => stop_iteration + ret => ret, } } } diff --git a/vm/src/builtins/make_module.rs b/vm/src/builtins/make_module.rs index e2434f52d..f13f3f083 100644 --- a/vm/src/builtins/make_module.rs +++ b/vm/src/builtins/make_module.rs @@ -720,7 +720,7 @@ mod decl { vm.get_method_or_type_error(obj.clone(), "__getitem__", || { "argument to reversed() must be a sequence".to_owned() })?; - let len = vm.obj_len(&obj)? as isize; + let len = vm.obj_len(&obj)?; let obj_iterator = PyReverseSequenceIterator::new(obj, len); Ok(obj_iterator.into_object(vm)) } diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index 200be86eb..fafcf222a 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -202,45 +202,6 @@ impl PyIter for PyStrIterator { } } -#[pyclass(module = false, name = "str_reverseiterator")] -#[derive(Debug)] -pub struct PyStrReverseIterator { - string: PyStrRef, - position: PyAtomic, -} - -impl PyValue for PyStrReverseIterator { - fn class(vm: &VirtualMachine) -> &PyTypeRef { - &vm.ctx.types.str_reverseiterator_type - } -} - -#[pyimpl(with(PyIter))] -impl PyStrReverseIterator {} - -impl PyIter for PyStrReverseIterator { - fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let value = &*zelf.string.value; - let mut end = zelf.position.load(atomic::Ordering::Relaxed); - loop { - let ch = value[..end] - .chars() - .next_back() - .ok_or_else(|| vm.new_stop_iteration())?; - - match zelf.position.compare_exchange_weak( - end, - end - ch.len_utf8(), - atomic::Ordering::Release, - atomic::Ordering::Relaxed, - ) { - Ok(_) => break Ok(ch.into_pyobject(vm)), - Err(cur) => end = cur, - } - } - } -} - #[derive(FromArgs)] struct StrArgs { #[pyarg(any, optional)] @@ -1126,14 +1087,6 @@ impl PyStr { fn encode(zelf: PyRef, args: EncodeArgs, vm: &VirtualMachine) -> PyResult { encode_string(zelf, args.encoding, args.errors, vm) } - - #[pymethod(magic)] - fn reversed(zelf: PyRef) -> PyStrReverseIterator { - PyStrReverseIterator { - position: Radium::new(zelf.byte_len()), - string: zelf, - } - } } impl PyStrRef { @@ -1254,7 +1207,6 @@ pub fn init(ctx: &PyContext) { PyStr::extend_class(ctx, &ctx.types.str_type); PyStrIterator::extend_class(ctx, &ctx.types.str_iterator_type); - PyStrReverseIterator::extend_class(ctx, &ctx.types.str_reverseiterator_type); } impl PySliceableSequence for PyStr { diff --git a/vm/src/types.rs b/vm/src/types.rs index 68e50b66c..bc30f2689 100644 --- a/vm/src/types.rs +++ b/vm/src/types.rs @@ -74,7 +74,6 @@ pub struct TypeZoo { pub list_iterator_type: PyTypeRef, pub list_reverseiterator_type: PyTypeRef, pub str_iterator_type: PyTypeRef, - pub str_reverseiterator_type: PyTypeRef, pub dict_keyiterator_type: PyTypeRef, pub dict_reversekeyiterator_type: PyTypeRef, pub dict_valueiterator_type: PyTypeRef, @@ -193,7 +192,6 @@ impl TypeZoo { longrange_iterator_type: range::PyLongRangeIterator::init_bare_type().clone(), set_iterator_type: set::PySetIterator::init_bare_type().clone(), str_iterator_type: pystr::PyStrIterator::init_bare_type().clone(), - str_reverseiterator_type: pystr::PyStrReverseIterator::init_bare_type().clone(), traceback_type: traceback::PyTraceback::init_bare_type().clone(), tuple_iterator_type: tuple::PyTupleIterator::init_bare_type().clone(), weakproxy_type: weakproxy::PyWeakProxy::init_bare_type().clone(),