Merge pull request #2913 from DimitrisJim/iter_meths

Add remaining methods to sequence iterator.
This commit is contained in:
Jeong YunWon
2021-08-19 20:39:20 +09:00
committed by GitHub
4 changed files with 53 additions and 20 deletions

View File

@@ -155,19 +155,13 @@ class TestCase(unittest.TestCase):
self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10)))
# Test for loop on a sequence class without __iter__
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_seq_class_for(self):
self.check_for_loop(SequenceClass(10), list(range(10)))
# Test iter() on a sequence class without __iter__
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_seq_class_iter(self):
self.check_iterator(iter(SequenceClass(10)), list(range(10)))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_mutating_seq_class_iter_pickle(self):
orig = SequenceClass(5)
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@@ -204,8 +198,6 @@ class TestCase(unittest.TestCase):
self.assertTrue(isinstance(it, collections.abc.Iterator))
self.assertEqual(list(it), [])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_mutating_seq_class_exhausted_iter(self):
a = SequenceClass(5)
exhit = iter(a)
@@ -908,8 +900,6 @@ class TestCase(unittest.TestCase):
self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
self.assertEqual(list(b), [])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_sinkstate_sequence(self):
# This used to fail
a = SequenceClass(5)
@@ -1004,8 +994,6 @@ class TestCase(unittest.TestCase):
with self.assertRaises(OverflowError):
next(it)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_iter_neg_setstate(self):
it = iter(UnlimitedSequenceClass())
it.__setstate__(-42)

View File

@@ -5,6 +5,7 @@
use crossbeam_utils::atomic::AtomicCell;
use super::pytype::PyTypeRef;
use super::{int, PyInt};
use crate::slots::PyIter;
use crate::vm::VirtualMachine;
use crate::{
@@ -24,8 +25,9 @@ pub enum IterStatus {
#[pyclass(module = false, name = "iterator")]
#[derive(Debug)]
pub struct PySequenceIterator {
pub position: AtomicCell<isize>,
pub position: AtomicCell<usize>,
pub obj: PyObjectRef,
pub status: AtomicCell<IterStatus>,
}
impl PyValue for PySequenceIterator {
@@ -36,26 +38,69 @@ impl PyValue for PySequenceIterator {
#[pyimpl(with(PyIter))]
impl PySequenceIterator {
pub fn new_forward(obj: PyObjectRef) -> Self {
pub fn new(obj: PyObjectRef) -> Self {
Self {
position: AtomicCell::new(0),
obj,
status: AtomicCell::new(IterStatus::Active),
}
}
#[pymethod(magic)]
fn length_hint(&self, vm: &VirtualMachine) -> PyResult<isize> {
let pos = self.position.load();
let len = vm.obj_len(&self.obj)?;
Ok(len as isize - pos)
fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
match self.status.load() {
IterStatus::Active => {
let pos = self.position.load();
// return NotImplemented if no length is around.
vm.obj_len(&self.obj)
.map_or(vm.ctx.not_implemented(), |len| {
PyInt::from(len.saturating_sub(pos)).into_object(vm)
})
}
IterStatus::Exhausted => PyInt::from(0).into_object(vm),
}
}
#[pymethod(magic)]
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
Ok(match self.status.load() {
IterStatus::Exhausted => vm
.ctx
.new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]),
IterStatus::Active => vm.ctx.new_tuple(vec![
iter,
vm.ctx.new_tuple(vec![self.obj.clone()]),
vm.ctx.new_int(self.position.load()),
]),
})
}
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
// When we're exhausted, just return.
if let IterStatus::Exhausted = self.status.load() {
return Ok(());
}
if let Some(i) = state.payload::<PyInt>() {
self.position
.store(int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0));
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
}
}
}
impl PyIter for PySequenceIterator {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
if let IterStatus::Exhausted = zelf.status.load() {
return Err(vm.new_stop_iteration());
}
let pos = zelf.position.fetch_add(1);
match zelf.obj.get_item(pos, vm) {
Err(ref e) if e.isinstance(&vm.ctx.exceptions.index_error) => {
zelf.status.store(IterStatus::Exhausted);
Err(vm.new_stop_iteration())
}
// also catches stop_iteration => stop_iteration

View File

@@ -36,7 +36,7 @@ pub fn get_iter(vm: &VirtualMachine, iter_target: PyObjectRef) -> PyResult {
vm.get_method_or_type_error(iter_target.clone(), "__getitem__", || {
format!("'{}' object is not iterable", iter_target.class().name)
})?;
Ok(PySequenceIterator::new_forward(iter_target)
Ok(PySequenceIterator::new(iter_target)
.into_ref(vm)
.into_object())
}

View File

@@ -741,7 +741,7 @@ impl<T> PyIterable<T> {
pub fn iter<'a>(&self, vm: &'a VirtualMachine) -> PyResult<PyIterator<'a, T>> {
let iter_obj = match self.iterfn {
Some(f) => f(self.iterable.clone(), vm)?,
None => PySequenceIterator::new_forward(self.iterable.clone()).into_object(vm),
None => PySequenceIterator::new(self.iterable.clone()).into_object(vm),
};
let length_hint = iterator::length_hint(vm, iter_obj.clone())?;