diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 3f84751bd9..2a365a22ce 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -758,7 +758,7 @@ impl PyByteArrayIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.len(), vm) + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } } impl IteratorIterable for PyByteArrayIterator {} diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 4a849cab30..316ce8b086 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -612,7 +612,7 @@ impl PyBytesIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.len(), vm) + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } } impl IteratorIterable for PyBytesIterator {} diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index 2e5d6b35ee..596d228d01 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -98,7 +98,7 @@ impl PyReverseSequenceIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal.lock().set_state(state, |_, pos| pos, vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index baa55eb413..b285dbd820 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -38,33 +38,14 @@ impl PositionIterInternal { } } - pub fn set_state(&mut self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - if let IterStatus::Active(_) = &self.status { - if let Some(i) = state.payload::() { - let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0); - self.position = i; - Ok(()) - } else { - Err(vm.new_type_error("an integer is required.".to_owned())) - } - } else { - Ok(()) - } - } - - pub fn set_state_saturated( - &mut self, - state: PyObjectRef, - f: F, - vm: &VirtualMachine, - ) -> PyResult<()> + pub fn set_state(&mut self, state: PyObjectRef, f: F, vm: &VirtualMachine) -> PyResult<()> where - F: FnOnce(&T) -> usize, + F: FnOnce(&T, usize) -> usize, { if let IterStatus::Active(obj) = &self.status { if let Some(i) = state.payload::() { let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0); - self.position = i.min(f(obj)); + self.position = f(obj, i); Ok(()) } else { Err(vm.new_type_error("an integer is required.".to_owned())) @@ -245,7 +226,7 @@ impl PySequenceIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal.lock().set_state(state, |_, pos| pos, vm) } } diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index c3ea794d1e..8c7ae63593 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -482,7 +482,7 @@ impl PyListIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.len(), vm) + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } #[pymethod(magic)] @@ -531,7 +531,7 @@ impl PyListReverseIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.len(), vm) + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index 6ecd64636c..81b9152c78 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -190,7 +190,7 @@ impl PyStrIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.char_len(), vm) + .set_state(state, |obj, pos| pos.min(obj.char_len()), vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 34b4e58200..e692d636c1 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -335,7 +335,7 @@ impl PyTupleIterator { fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { self.internal .lock() - .set_state_saturated(state, |obj| obj.len(), vm) + .set_state(state, |obj, pos| pos.min(obj.len()), vm) } #[pymethod(magic)]