From 0c6375c0312c9151f24f5ed4a0cfcbef38e2e013 Mon Sep 17 00:00:00 2001 From: jfh Date: Thu, 12 Aug 2021 18:38:27 +0300 Subject: [PATCH] Fix pickling, length hint, iteration for reverse list iterators. Refactor common functionality. --- Lib/test/test_iterlen.py | 2 - Lib/test/test_list.py | 2 - vm/src/builtins/list.rs | 129 ++++++++++++++++++++++++++++++--------- 3 files changed, 101 insertions(+), 32 deletions(-) diff --git a/Lib/test/test_iterlen.py b/Lib/test/test_iterlen.py index f05b716e0..db01b5487 100644 --- a/Lib/test/test_iterlen.py +++ b/Lib/test/test_iterlen.py @@ -221,8 +221,6 @@ class TestListReversed(TestInvariantWithoutMutations, unittest.TestCase): def test_invariant(self): super().test_invariant() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_mutation(self): d = list(range(n)) it = reversed(d) diff --git a/Lib/test/test_list.py b/Lib/test/test_list.py index 87c1d2946..1a6eee45f 100644 --- a/Lib/test/test_list.py +++ b/Lib/test/test_list.py @@ -116,8 +116,6 @@ class ListTest(list_tests.CommonTest): a[:] = data self.assertEqual(list(it), []) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_reversed_pickle(self): orig = self.type2test([4, 5, 6, 7]) data = [10, 11, 12, 13, 14, 15] diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index e374c0386..4fa79e04b 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -164,8 +164,14 @@ impl PyList { #[pymethod(magic)] fn reversed(zelf: PyRef) -> PyListReverseIterator { let final_position = zelf.borrow_vec().len(); + // Mark iterator as exhausted immediately if its empty. PyListReverseIterator { - position: AtomicCell::new(final_position as isize), + position: AtomicCell::new(final_position.saturating_sub(1)), + status: AtomicCell::new(if final_position == 0 { + Exhausted + } else { + Active + }), list: zelf, } } @@ -482,7 +488,7 @@ impl PyValue for PyListIterator { #[pyimpl(with(PyIter))] impl PyListIterator { - #[pymethod(name = "__length_hint__")] + #[pymethod(magic)] fn length_hint(&self) -> usize { match self.status.load() { Active => { @@ -500,31 +506,19 @@ impl PyListIterator { if let Exhausted = self.status.load() { return Ok(()); } - if let Some(i) = state.payload::() { - let position = std::cmp::min( - int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0), - self.list.len(), - ); - self.position.store(position); - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } + let position = list_state(self.list.len(), state, vm)?; + self.position.store(position); + Ok(()) } #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> PyResult { - let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; - Ok(match self.status.load() { - Exhausted => vm - .ctx - .new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]), - Active => vm.ctx.new_tuple(vec![ - iter, - vm.ctx.new_tuple(vec![self.list.clone().into_object()]), - vm.ctx.new_int(self.position.load()), - ]), - }) + let pos = if let Exhausted = self.status.load() { + None + } else { + Some(self.position.load()) + }; + list_reduce(self.list.clone(), pos, false, vm) } } @@ -547,7 +541,8 @@ impl PyIter for PyListIterator { #[pyclass(module = false, name = "list_reverseiterator")] #[derive(Debug)] pub struct PyListReverseIterator { - pub position: AtomicCell, + pub position: AtomicCell, + pub status: AtomicCell, pub list: PyListRef, } @@ -559,25 +554,103 @@ impl PyValue for PyListReverseIterator { #[pyimpl(with(PyIter))] impl PyListReverseIterator { - #[pymethod(name = "__length_hint__")] + #[pymethod(magic)] fn length_hint(&self) -> usize { - std::cmp::max(self.position.load(), 0) as usize + match self.status.load() { + Active => { + let position = self.position.load(); + if position > self.list.len() { + // List was mutated. Report zero, next call to `__next__` will + // fail and set iterator to Exhausted. + 0 + } else { + position + 1 + } + } + Exhausted => 0, + } + } + + #[pymethod(magic)] + fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + // When we're exhausted, just return. + if let Exhausted = self.status.load() { + return Ok(()); + } + + // Max for position is list.len() - 1. + let position = list_state(self.list.len().saturating_sub(1), state, vm)?; + self.position.store(position); + Ok(()) + } + + #[pymethod(magic)] + fn reduce(&self, vm: &VirtualMachine) -> PyResult { + let pos = if let Exhausted = self.status.load() { + None + } else { + Some(self.position.load()) + }; + list_reduce(self.list.clone(), pos, true, vm) } } impl PyIter for PyListReverseIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + if let Exhausted = zelf.status.load() { + return Err(vm.new_stop_iteration()); + } let list = zelf.list.borrow_vec(); let pos = zelf.position.fetch_sub(1); if pos > 0 { - if let Some(ret) = list.get(pos as usize - 1) { - return Ok(ret.clone()); + if let Some(obj) = list.get(pos) { + return Ok(obj.clone()); + } + } + // We either are == 0 or list.get returned None. Either way, set status + // to exhausted and return last item if pos == 0. + zelf.status.store(Exhausted); + if pos == 0 { + if let Some(obj) = list.get(pos) { + return Ok(obj.clone()); } } Err(vm.new_stop_iteration()) } } +// Common reducer for forward and reverse list iterators. +fn list_reduce( + list: PyRef, + position: Option, + reverse: bool, + vm: &VirtualMachine, +) -> PyResult { + let attr = if reverse { "reversed" } else { "iter" }; + let iter = vm.get_attribute(vm.builtins.clone(), attr)?; + let elems = match position { + None => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])], + Some(position) => vec![ + iter, + vm.ctx.new_tuple(vec![list.into_object()]), + vm.ctx.new_int(position), + ], + }; + Ok(vm.ctx.new_tuple(elems)) +} + +// Common function to extract state. Clamps it in range [0, length]. +fn list_state(length: usize, state: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let position = state + .payload::() + .ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?; + let position = std::cmp::min( + int::try_to_primitive(position.as_bigint(), vm).unwrap_or(0), + length, + ); + Ok(position) +} + pub fn init(context: &PyContext) { let list_type = &context.types.list_type; PyList::extend_class(context, list_type);