refactor set_state

This commit is contained in:
Kangzhi Shi
2021-10-01 14:26:48 +02:00
parent b4cbca0e8c
commit 70fc910268
7 changed files with 11 additions and 30 deletions

View File

@@ -758,7 +758,7 @@ impl PyByteArrayIterator {
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal
.lock()
.set_state_saturated(state, |obj| obj.len(), vm)
.set_state(state, |obj, pos| pos.min(obj.len()), vm)
}
}
impl IteratorIterable for PyByteArrayIterator {}

View File

@@ -612,7 +612,7 @@ impl PyBytesIterator {
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal
.lock()
.set_state_saturated(state, |obj| obj.len(), vm)
.set_state(state, |obj, pos| pos.min(obj.len()), vm)
}
}
impl IteratorIterable for PyBytesIterator {}

View File

@@ -98,7 +98,7 @@ impl PyReverseSequenceIterator {
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal.lock().set_state(state, vm)
self.internal.lock().set_state(state, |_, pos| pos, vm)
}
#[pymethod(magic)]

View File

@@ -38,33 +38,14 @@ impl<T> PositionIterInternal<T> {
}
}
pub fn set_state(&mut self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
if let IterStatus::Active(_) = &self.status {
if let Some(i) = state.payload::<PyInt>() {
let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0);
self.position = i;
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
}
} else {
Ok(())
}
}
pub fn set_state_saturated<F>(
&mut self,
state: PyObjectRef,
f: F,
vm: &VirtualMachine,
) -> PyResult<()>
pub fn set_state<F>(&mut self, state: PyObjectRef, f: F, vm: &VirtualMachine) -> PyResult<()>
where
F: FnOnce(&T) -> usize,
F: FnOnce(&T, usize) -> usize,
{
if let IterStatus::Active(obj) = &self.status {
if let Some(i) = state.payload::<PyInt>() {
let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0);
self.position = i.min(f(obj));
self.position = f(obj, i);
Ok(())
} else {
Err(vm.new_type_error("an integer is required.".to_owned()))
@@ -245,7 +226,7 @@ impl PySequenceIterator {
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal.lock().set_state(state, vm)
self.internal.lock().set_state(state, |_, pos| pos, vm)
}
}

View File

@@ -482,7 +482,7 @@ impl PyListIterator {
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal
.lock()
.set_state_saturated(state, |obj| obj.len(), vm)
.set_state(state, |obj, pos| pos.min(obj.len()), vm)
}
#[pymethod(magic)]
@@ -531,7 +531,7 @@ impl PyListReverseIterator {
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal
.lock()
.set_state_saturated(state, |obj| obj.len(), vm)
.set_state(state, |obj, pos| pos.min(obj.len()), vm)
}
#[pymethod(magic)]

View File

@@ -190,7 +190,7 @@ impl PyStrIterator {
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal
.lock()
.set_state_saturated(state, |obj| obj.char_len(), vm)
.set_state(state, |obj, pos| pos.min(obj.char_len()), vm)
}
#[pymethod(magic)]

View File

@@ -335,7 +335,7 @@ impl PyTupleIterator {
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal
.lock()
.set_state_saturated(state, |obj| obj.len(), vm)
.set_state(state, |obj, pos| pos.min(obj.len()), vm)
}
#[pymethod(magic)]