From 95ba82bd5424fccf2e276fe340b08774bc3ad2e0 Mon Sep 17 00:00:00 2001 From: jfh Date: Mon, 14 Jun 2021 19:32:22 +0300 Subject: [PATCH 1/2] Fix iteration issues, pickling by marking iterators as exhausted. --- Lib/test/test_iter.py | 5 ---- Lib/test/test_iterlen.py | 2 -- Lib/test/test_list.py | 2 -- vm/src/builtins/iter.rs | 19 +++++++++----- vm/src/builtins/list.rs | 56 +++++++++++++++++++++++++++++++++++++--- 5 files changed, 66 insertions(+), 18 deletions(-) diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py index 015e8ae19..3f4ff24ea 100644 --- a/Lib/test/test_iter.py +++ b/Lib/test/test_iter.py @@ -307,9 +307,6 @@ class TestCase(unittest.TestCase): def test_iter_big_range(self): self.check_for_loop(iter(range(10000)), list(range(10000))) - # Test an empty list - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_iter_empty(self): self.check_for_loop(iter([]), []) @@ -903,8 +900,6 @@ class TestCase(unittest.TestCase): # This tests various things that weren't sink states in Python 2.2.1, # plus various things that always were fine. - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_sinkstate_list(self): # This used to fail a = list(range(5)) diff --git a/Lib/test/test_iterlen.py b/Lib/test/test_iterlen.py index 46b470f67..3bedfdddb 100644 --- a/Lib/test/test_iterlen.py +++ b/Lib/test/test_iterlen.py @@ -212,8 +212,6 @@ class TestList(TestInvariantWithoutMutations, unittest.TestCase): def test_invariant(self): super().test_invariant() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_mutation(self): d = list(range(n)) it = iter(d) diff --git a/Lib/test/test_list.py b/Lib/test/test_list.py index 1c96ddffd..87c1d2946 100644 --- a/Lib/test/test_list.py +++ b/Lib/test/test_list.py @@ -80,8 +80,6 @@ class ListTest(list_tests.CommonTest): check(10) # check our checking code check(1000000) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_iterator_pickle(self): orig = self.type2test([4, 5, 6, 7]) data = [10, 11, 12, 13, 14, 15] diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index c4877c6b5..c10f6bd81 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -12,6 +12,15 @@ use crate::{ TypeProtocol, }; +/// Marks status of iterator. +#[derive(Debug, Clone, Copy)] +pub enum IterStatus { + /// Iterator hasn't raised StopIteration. + Active, + /// Iterator has raised StopIteration. + Exhausted, +} + #[pyclass(module = false, name = "iter")] #[derive(Debug)] pub struct PySequenceIterator { @@ -80,7 +89,7 @@ impl PyIter for PySequenceIterator { pub struct PyCallableIterator { callable: PyCallable, sentinel: PyObjectRef, - done: AtomicCell, + status: AtomicCell, } impl PyValue for PyCallableIterator { @@ -95,21 +104,19 @@ impl PyCallableIterator { Self { callable, sentinel, - done: AtomicCell::new(false), + status: AtomicCell::new(IterStatus::Active), } } } impl PyIter for PyCallableIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - if zelf.done.load() { + if let IterStatus::Exhausted = zelf.status.load() { return Err(vm.new_stop_iteration()); } - let ret = zelf.callable.invoke((), vm)?; - if vm.bool_eq(&ret, &zelf.sentinel)? { - zelf.done.store(true); + zelf.status.store(IterStatus::Exhausted); Err(vm.new_stop_iteration()) } else { Ok(ret) diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 91f57befa..d4cfa323a 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -5,8 +5,14 @@ use std::ops::DerefMut; use crossbeam_utils::atomic::AtomicCell; +use super::int; +use super::iter::{ + IterStatus, + IterStatus::{Active, Exhausted}, +}; use super::pytype::PyTypeRef; use super::slice::PySliceRef; +use super::PyInt; use crate::common::lock::{ PyMappedRwLockReadGuard, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }; @@ -402,6 +408,7 @@ impl Iterable for PyList { fn iter(zelf: PyRef, vm: &VirtualMachine) -> PyResult { Ok(PyListIterator { position: AtomicCell::new(0), + status: AtomicCell::new(Active), list: zelf, } .into_object(vm)) @@ -458,6 +465,7 @@ fn do_sort( #[derive(Debug)] pub struct PyListIterator { pub position: AtomicCell, + status: AtomicCell, pub list: PyListRef, } @@ -471,19 +479,61 @@ impl PyValue for PyListIterator { impl PyListIterator { #[pymethod(name = "__length_hint__")] fn length_hint(&self) -> usize { - let list = self.list.borrow_vec(); - let pos = self.position.load(); - list.len().saturating_sub(pos) + match self.status.load() { + Active => { + let list = self.list.borrow_vec(); + let pos = self.position.load(); + list.len().saturating_sub(pos) + } + Exhausted => 0, + } + } + + #[pymethod(name = "__setstate__")] + fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // When we're exhausted, just return. + if let Exhausted = self.status.load() { + return Ok(vm.ctx.none()); + } + if let Some(i) = state.payload::() { + let position = std::cmp::min( + int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0), + self.list.len(), + ); + self.position.store(position); + Ok(vm.ctx.none()) + } else { + Err(vm.new_type_error("an integer is required.".to_owned())) + } + } + + #[pymethod(magic)] + fn reduce(&self, vm: &VirtualMachine) -> PyResult { + let iter = vm.get_attribute(vm.builtins.clone(), "iter")?; + Ok(match self.status.load() { + Exhausted => vm + .ctx + .new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]), + Active => vm.ctx.new_tuple(vec![ + iter, + vm.ctx.new_tuple(vec![self.list.clone().into_object()]), + vm.ctx.new_int(self.position.load()), + ]), + }) } } impl PyIter for PyListIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + if let Exhausted = zelf.status.load() { + return Err(vm.new_stop_iteration()); + } let list = zelf.list.borrow_vec(); let pos = zelf.position.fetch_add(1); if let Some(obj) = list.get(pos) { Ok(obj.clone()) } else { + zelf.status.store(Exhausted); Err(vm.new_stop_iteration()) } } From 86cef6738d632eb8cfd6a464b86f643e1f098dc9 Mon Sep 17 00:00:00 2001 From: jfh Date: Mon, 12 Jul 2021 00:22:03 +0300 Subject: [PATCH 2/2] Address review comments. --- vm/src/builtins/list.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index d4cfa323a..71df39b76 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -490,10 +490,10 @@ impl PyListIterator { } #[pymethod(name = "__setstate__")] - fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { // When we're exhausted, just return. if let Exhausted = self.status.load() { - return Ok(vm.ctx.none()); + return Ok(()); } if let Some(i) = state.payload::() { let position = std::cmp::min( @@ -501,7 +501,7 @@ impl PyListIterator { self.list.len(), ); self.position.store(position); - Ok(vm.ctx.none()) + Ok(()) } else { Err(vm.new_type_error("an integer is required.".to_owned())) }