Merge pull request #2709 from DimitrisJim/iter_exhaustion

Fix iteration issues, pickling by marking iterators as exhausted.
This commit is contained in:
Jeong YunWon
2021-07-12 16:08:55 +09:00
committed by GitHub
5 changed files with 66 additions and 18 deletions

View File

@@ -307,9 +307,6 @@ class TestCase(unittest.TestCase):
def test_iter_big_range(self):
self.check_for_loop(iter(range(10000)), list(range(10000)))
# Test an empty list
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_iter_empty(self):
self.check_for_loop(iter([]), [])
@@ -903,8 +900,6 @@ class TestCase(unittest.TestCase):
# This tests various things that weren't sink states in Python 2.2.1,
# plus various things that always were fine.
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_sinkstate_list(self):
# This used to fail
a = list(range(5))

View File

@@ -212,8 +212,6 @@ class TestList(TestInvariantWithoutMutations, unittest.TestCase):
def test_invariant(self):
super().test_invariant()
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_mutation(self):
d = list(range(n))
it = iter(d)

View File

@@ -80,8 +80,6 @@ class ListTest(list_tests.CommonTest):
check(10) # check our checking code
check(1000000)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_iterator_pickle(self):
orig = self.type2test([4, 5, 6, 7])
data = [10, 11, 12, 13, 14, 15]

View File

@@ -12,6 +12,15 @@ use crate::{
TypeProtocol,
};
/// Marks status of iterator.
#[derive(Debug, Clone, Copy)]
pub enum IterStatus {
/// Iterator hasn't raised StopIteration.
Active,
/// Iterator has raised StopIteration.
Exhausted,
}
#[pyclass(module = false, name = "iter")]
#[derive(Debug)]
pub struct PySequenceIterator {
@@ -80,7 +89,7 @@ impl PyIter for PySequenceIterator {
pub struct PyCallableIterator {
callable: PyCallable,
sentinel: PyObjectRef,
done: AtomicCell<bool>,
status: AtomicCell<IterStatus>,
}
impl PyValue for PyCallableIterator {
@@ -95,21 +104,19 @@ impl PyCallableIterator {
Self {
callable,
sentinel,
done: AtomicCell::new(false),
status: AtomicCell::new(IterStatus::Active),
}
}
}
impl PyIter for PyCallableIterator {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
if zelf.done.load() {
if let IterStatus::Exhausted = zelf.status.load() {
return Err(vm.new_stop_iteration());
}
let ret = zelf.callable.invoke((), vm)?;
if vm.bool_eq(&ret, &zelf.sentinel)? {
zelf.done.store(true);
zelf.status.store(IterStatus::Exhausted);
Err(vm.new_stop_iteration())
} else {
Ok(ret)

View File

@@ -5,8 +5,14 @@ use std::ops::DerefMut;
use crossbeam_utils::atomic::AtomicCell;
use super::int;
use super::iter::{
IterStatus,
IterStatus::{Active, Exhausted},
};
use super::pytype::PyTypeRef;
use super::slice::PySliceRef;
use super::PyInt;
use crate::common::lock::{
PyMappedRwLockReadGuard, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard,
};
@@ -402,6 +408,7 @@ impl Iterable for PyList {
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
Ok(PyListIterator {
position: AtomicCell::new(0),
status: AtomicCell::new(Active),
list: zelf,
}
.into_object(vm))
@@ -458,6 +465,7 @@ fn do_sort(
#[derive(Debug)]
pub struct PyListIterator {
pub position: AtomicCell<usize>,
status: AtomicCell<IterStatus>,
pub list: PyListRef,
}
@@ -471,19 +479,61 @@ impl PyValue for PyListIterator {
impl PyListIterator {
#[pymethod(name = "__length_hint__")]
fn length_hint(&self) -> usize {
let list = self.list.borrow_vec();
let pos = self.position.load();
list.len().saturating_sub(pos)
match self.status.load() {
Active => {
let list = self.list.borrow_vec();
let pos = self.position.load();
list.len().saturating_sub(pos)
}
Exhausted => 0,
}
}
#[pymethod(name = "__setstate__")]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
// When we're exhausted, just return.
if let Exhausted = self.status.load() {
return Ok(());
}
if let Some(i) = state.payload::<PyInt>() {
let position = std::cmp::min(
int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0),
self.list.len(),
);
self.position.store(position);
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
}
}
#[pymethod(magic)]
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
Ok(match self.status.load() {
Exhausted => vm
.ctx
.new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]),
Active => vm.ctx.new_tuple(vec![
iter,
vm.ctx.new_tuple(vec![self.list.clone().into_object()]),
vm.ctx.new_int(self.position.load()),
]),
})
}
}
impl PyIter for PyListIterator {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
if let Exhausted = zelf.status.load() {
return Err(vm.new_stop_iteration());
}
let list = zelf.list.borrow_vec();
let pos = zelf.position.fetch_add(1);
if let Some(obj) = list.get(pos) {
Ok(obj.clone())
} else {
zelf.status.store(Exhausted);
Err(vm.new_stop_iteration())
}
}