From 52149d0e708261eb69456e81a2e62e79e1f657d1 Mon Sep 17 00:00:00 2001 From: jfh Date: Thu, 19 Aug 2021 13:46:14 +0300 Subject: [PATCH] Add remaining methods to sequence iterator. --- Lib/test/test_iter.py | 12 --------- vm/src/builtins/iter.rs | 57 ++++++++++++++++++++++++++++++++++++----- vm/src/iterator.rs | 2 +- vm/src/pyobject.rs | 2 +- 4 files changed, 53 insertions(+), 20 deletions(-) diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py index a2e440a61..748ad5e93 100644 --- a/Lib/test/test_iter.py +++ b/Lib/test/test_iter.py @@ -155,19 +155,13 @@ class TestCase(unittest.TestCase): self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10))) # Test for loop on a sequence class without __iter__ - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_seq_class_for(self): self.check_for_loop(SequenceClass(10), list(range(10))) # Test iter() on a sequence class without __iter__ - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_seq_class_iter(self): self.check_iterator(iter(SequenceClass(10)), list(range(10))) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_mutating_seq_class_iter_pickle(self): orig = SequenceClass(5) for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -204,8 +198,6 @@ class TestCase(unittest.TestCase): self.assertTrue(isinstance(it, collections.abc.Iterator)) self.assertEqual(list(it), []) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_mutating_seq_class_exhausted_iter(self): a = SequenceClass(5) exhit = iter(a) @@ -908,8 +900,6 @@ class TestCase(unittest.TestCase): self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e']) self.assertEqual(list(b), []) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_sinkstate_sequence(self): # This used to fail a = SequenceClass(5) @@ -1004,8 +994,6 @@ class TestCase(unittest.TestCase): with self.assertRaises(OverflowError): next(it) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_iter_neg_setstate(self): it = iter(UnlimitedSequenceClass()) it.__setstate__(-42) diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index cad0eae48..d3264b417 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -5,6 +5,7 @@ use crossbeam_utils::atomic::AtomicCell; use super::pytype::PyTypeRef; +use super::{int, PyInt}; use crate::slots::PyIter; use crate::vm::VirtualMachine; use crate::{ @@ -24,8 +25,9 @@ pub enum IterStatus { #[pyclass(module = false, name = "iterator")] #[derive(Debug)] pub struct PySequenceIterator { - pub position: AtomicCell, + pub position: AtomicCell, pub obj: PyObjectRef, + pub status: AtomicCell, } impl PyValue for PySequenceIterator { @@ -36,26 +38,69 @@ impl PyValue for PySequenceIterator { #[pyimpl(with(PyIter))] impl PySequenceIterator { - pub fn new_forward(obj: PyObjectRef) -> Self { + pub fn new(obj: PyObjectRef) -> Self { Self { position: AtomicCell::new(0), obj, + status: AtomicCell::new(IterStatus::Active), } } #[pymethod(magic)] - fn length_hint(&self, vm: &VirtualMachine) -> PyResult { - let pos = self.position.load(); - let len = vm.obj_len(&self.obj)?; - Ok(len as isize - pos) + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + match self.status.load() { + IterStatus::Active => { + let pos = self.position.load(); + // return NotImplemented if no length is around. + vm.obj_len(&self.obj) + .map_or(vm.ctx.not_implemented(), |len| { + PyInt::from(len.saturating_sub(pos)).into_object(vm) + }) + } + IterStatus::Exhausted => PyInt::from(0).into_object(vm), + } + } + + #[pymethod(magic)] + fn reduce(&self, vm: &VirtualMachine) -> PyResult { + let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; + Ok(match self.status.load() { + IterStatus::Exhausted => vm + .ctx + .new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]), + IterStatus::Active => vm.ctx.new_tuple(vec![ + iter, + vm.ctx.new_tuple(vec![self.obj.clone()]), + vm.ctx.new_int(self.position.load()), + ]), + }) + } + + #[pymethod(magic)] + fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + // When we're exhausted, just return. + if let IterStatus::Exhausted = self.status.load() { + return Ok(()); + } + if let Some(i) = state.payload::() { + self.position + .store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0)); + Ok(()) + } else { + Err(vm.new_type_error("an integer is required.".to_owned())) + } } } impl PyIter for PySequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + if let IterStatus::Exhausted = zelf.status.load() { + return Err(vm.new_stop_iteration()); + } let pos = zelf.position.fetch_add(1); match zelf.obj.get_item(pos, vm) { Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => { + zelf.status.store(IterStatus::Exhausted); Err(vm.new_stop_iteration()) } // also catches stop_iteration => stop_iteration diff --git a/vm/src/iterator.rs b/vm/src/iterator.rs index f8708fa02..2bf41144e 100644 --- a/vm/src/iterator.rs +++ b/vm/src/iterator.rs @@ -36,7 +36,7 @@ pub fn get_iter(vm: &VirtualMachine, iter_target: PyObjectRef) -> PyResult { vm.get_method_or_type_error(iter_target.clone(), "__getitem__", || { format!("'{}' object is not iterable", iter_target.class().name) })?; - Ok(PySequenceIterator::new_forward(iter_target) + Ok(PySequenceIterator::new(iter_target) .into_ref(vm) .into_object()) } diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index a1598fb5f..ffcb70831 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -741,7 +741,7 @@ impl PyIterable { pub fn iter<'a>(&self, vm: &'a VirtualMachine) -> PyResult> { let iter_obj = match self.iterfn { Some(f) => f(self.iterable.clone(), vm)?, - None => PySequenceIterator::new_forward(self.iterable.clone()).into_object(vm), + None => PySequenceIterator::new(self.iterable.clone()).into_object(vm), }; let length_hint = iterator::length_hint(vm, iter_obj.clone())?;