fix setstate saturated

This commit is contained in:
Kangzhi Shi
2021-09-28 09:33:34 +02:00
parent fa8df88b5d
commit a0d9ce030f
6 changed files with 50 additions and 9 deletions

View File

@@ -756,7 +756,9 @@ impl PyByteArrayIterator {
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal.lock().set_state(state, vm)
self.internal
.lock()
.set_state_saturated(state, |obj| obj.len(), vm)
}
}
impl IteratorIterable for PyByteArrayIterator {}

View File

@@ -610,7 +610,9 @@ impl PyBytesIterator {
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal.lock().set_state(state, vm)
self.internal
.lock()
.set_state_saturated(state, |obj| obj.len(), vm)
}
}
impl IteratorIterable for PyBytesIterator {}

View File

@@ -10,7 +10,10 @@ use crate::{
ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol,
VirtualMachine,
};
use rustpython_common::lock::{OnceCell, PyMutex, PyRwLock, PyRwLockUpgradableReadGuard};
use rustpython_common::{
lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard},
static_cell,
};
/// Marks status of iterator.
#[derive(Debug, Clone)]
@@ -49,6 +52,28 @@ impl<T> PositionIterInternal<T> {
}
}
pub fn set_state_saturated<F>(
&mut self,
state: PyObjectRef,
f: F,
vm: &VirtualMachine,
) -> PyResult<()>
where
F: FnOnce(&T) -> usize,
{
if let IterStatus::Active(obj) = &self.status {
if let Some(i) = state.payload::<PyInt>() {
let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0);
self.position = i.min(f(obj));
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
}
} else {
Ok(())
}
}
fn _reduce<F>(&self, func: PyObjectRef, f: F, vm: &VirtualMachine) -> PyObjectRef
where
F: FnOnce(&T) -> PyObjectRef,
@@ -168,12 +193,16 @@ impl<T> PositionIterInternal<T> {
}
pub fn get_builtin_attribute_iter(vm: &VirtualMachine) -> &PyObjectRef {
static INSTANCE: OnceCell<PyObjectRef> = OnceCell::new();
static_cell! {
static INSTANCE: PyObjectRef;
}
INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "iter").unwrap())
}
pub fn get_builtin_attribute_reversed(vm: &VirtualMachine) -> &PyObjectRef {
static INSTANCE: OnceCell<PyObjectRef> = OnceCell::new();
static_cell! {
static INSTANCE: PyObjectRef;
}
INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "reversed").unwrap())
}

View File

@@ -480,7 +480,9 @@ impl PyListIterator {
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal.lock().set_state(state, vm)
self.internal
.lock()
.set_state_saturated(state, |obj| obj.len(), vm)
}
#[pymethod(magic)]
@@ -527,7 +529,9 @@ impl PyListReverseIterator {
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal.lock().set_state(state, vm)
self.internal
.lock()
.set_state_saturated(state, |obj| obj.len(), vm)
}
#[pymethod(magic)]

View File

@@ -188,7 +188,9 @@ impl PyStrIterator {
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal.lock().set_state(state, vm)
self.internal
.lock()
.set_state_saturated(state, |obj| obj.char_len(), vm)
}
#[pymethod(magic)]

View File

@@ -333,7 +333,9 @@ impl PyTupleIterator {
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal.lock().set_state(state, vm)
self.internal
.lock()
.set_state_saturated(state, |obj| obj.len(), vm)
}
#[pymethod(magic)]