Merge pull request #2868 from DimitrisJim/reverse_list_it

Fix pickling, length hint, iteration for reverse list iterators.
This commit is contained in:
Jim Fasarakis-Hilliard
2021-08-12 21:02:55 +03:00
committed by GitHub
3 changed files with 101 additions and 32 deletions

View File

@@ -221,8 +221,6 @@ class TestListReversed(TestInvariantWithoutMutations, unittest.TestCase):
def test_invariant(self):
super().test_invariant()
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_mutation(self):
d = list(range(n))
it = reversed(d)

View File

@@ -116,8 +116,6 @@ class ListTest(list_tests.CommonTest):
a[:] = data
self.assertEqual(list(it), [])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_reversed_pickle(self):
orig = self.type2test([4, 5, 6, 7])
data = [10, 11, 12, 13, 14, 15]

View File

@@ -164,8 +164,14 @@ impl PyList {
#[pymethod(magic)]
fn reversed(zelf: PyRef<Self>) -> PyListReverseIterator {
let final_position = zelf.borrow_vec().len();
// Mark iterator as exhausted immediately if its empty.
PyListReverseIterator {
position: AtomicCell::new(final_position as isize),
position: AtomicCell::new(final_position.saturating_sub(1)),
status: AtomicCell::new(if final_position == 0 {
Exhausted
} else {
Active
}),
list: zelf,
}
}
@@ -482,7 +488,7 @@ impl PyValue for PyListIterator {
#[pyimpl(with(PyIter))]
impl PyListIterator {
#[pymethod(name = "__length_hint__")]
#[pymethod(magic)]
fn length_hint(&self) -> usize {
match self.status.load() {
Active => {
@@ -500,31 +506,19 @@ impl PyListIterator {
if let Exhausted = self.status.load() {
return Ok(());
}
if let Some(i) = state.payload::<PyInt>() {
let position = std::cmp::min(
int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0),
self.list.len(),
);
self.position.store(position);
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
}
let position = list_state(self.list.len(), state, vm)?;
self.position.store(position);
Ok(())
}
#[pymethod(magic)]
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
Ok(match self.status.load() {
Exhausted => vm
.ctx
.new_tuple(vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])]),
Active => vm.ctx.new_tuple(vec![
iter,
vm.ctx.new_tuple(vec![self.list.clone().into_object()]),
vm.ctx.new_int(self.position.load()),
]),
})
let pos = if let Exhausted = self.status.load() {
None
} else {
Some(self.position.load())
};
list_reduce(self.list.clone(), pos, false, vm)
}
}
@@ -547,7 +541,8 @@ impl PyIter for PyListIterator {
#[pyclass(module = false, name = "list_reverseiterator")]
#[derive(Debug)]
pub struct PyListReverseIterator {
pub position: AtomicCell<isize>,
pub position: AtomicCell<usize>,
pub status: AtomicCell<IterStatus>,
pub list: PyListRef,
}
@@ -559,25 +554,103 @@ impl PyValue for PyListReverseIterator {
#[pyimpl(with(PyIter))]
impl PyListReverseIterator {
#[pymethod(name = "__length_hint__")]
#[pymethod(magic)]
fn length_hint(&self) -> usize {
std::cmp::max(self.position.load(), 0) as usize
match self.status.load() {
Active => {
let position = self.position.load();
if position > self.list.len() {
// List was mutated. Report zero, next call to `__next__` will
// fail and set iterator to Exhausted.
0
} else {
position + 1
}
}
Exhausted => 0,
}
}
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
// When we're exhausted, just return.
if let Exhausted = self.status.load() {
return Ok(());
}
// Max for position is list.len() - 1.
let position = list_state(self.list.len().saturating_sub(1), state, vm)?;
self.position.store(position);
Ok(())
}
#[pymethod(magic)]
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
let pos = if let Exhausted = self.status.load() {
None
} else {
Some(self.position.load())
};
list_reduce(self.list.clone(), pos, true, vm)
}
}
impl PyIter for PyListReverseIterator {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
if let Exhausted = zelf.status.load() {
return Err(vm.new_stop_iteration());
}
let list = zelf.list.borrow_vec();
let pos = zelf.position.fetch_sub(1);
if pos > 0 {
if let Some(ret) = list.get(pos as usize - 1) {
return Ok(ret.clone());
if let Some(obj) = list.get(pos) {
return Ok(obj.clone());
}
}
// We either are == 0 or list.get returned None. Either way, set status
// to exhausted and return last item if pos == 0.
zelf.status.store(Exhausted);
if pos == 0 {
if let Some(obj) = list.get(pos) {
return Ok(obj.clone());
}
}
Err(vm.new_stop_iteration())
}
}
// Common reducer for forward and reverse list iterators.
fn list_reduce(
list: PyRef<PyList>,
position: Option<usize>,
reverse: bool,
vm: &VirtualMachine,
) -> PyResult {
let attr = if reverse { "reversed" } else { "iter" };
let iter = vm.get_attribute(vm.builtins.clone(), attr)?;
let elems = match position {
None => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_list(vec![])])],
Some(position) => vec![
iter,
vm.ctx.new_tuple(vec![list.into_object()]),
vm.ctx.new_int(position),
],
};
Ok(vm.ctx.new_tuple(elems))
}
// Common function to extract state. Clamps it in range [0, length].
fn list_state(length: usize, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
let position = state
.payload::<PyInt>()
.ok_or_else(|| vm.new_type_error("an integer is required.".to_owned()))?;
let position = std::cmp::min(
int::try_to_primitive(position.as_bigint(), vm).unwrap_or(0),
length,
);
Ok(position)
}
pub fn init(context: &PyContext) {
let list_type = &context.types.list_type;
PyList::extend_class(context, list_type);