From a0d9ce030f5d5591898097edfef5340b914b22bf Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 28 Sep 2021 09:33:34 +0200 Subject: [PATCH] fix setstate saturated --- vm/src/builtins/bytearray.rs | 4 +++- vm/src/builtins/bytes.rs | 4 +++- vm/src/builtins/iter.rs | 35 ++++++++++++++++++++++++++++++++--- vm/src/builtins/list.rs | 8 ++++++-- vm/src/builtins/pystr.rs | 4 +++- vm/src/builtins/tuple.rs | 4 +++- 6 files changed, 50 insertions(+), 9 deletions(-) diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 21a905069..3f84751bd 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -756,7 +756,9 @@ impl PyByteArrayIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.len(), vm) } } impl IteratorIterable for PyByteArrayIterator {} diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 27ccce1dc..4a849cab3 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -610,7 +610,9 @@ impl PyBytesIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.len(), vm) } } impl IteratorIterable for PyBytesIterator {} diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 1405f4bf5..476c66938 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -10,7 +10,10 @@ use crate::{ ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, VirtualMachine, }; -use rustpython_common::lock::{OnceCell, PyMutex, PyRwLock, PyRwLockUpgradableReadGuard}; +use rustpython_common::{ + lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard}, + static_cell, +}; /// Marks status of iterator. #[derive(Debug, Clone)] @@ -49,6 +52,28 @@ impl PositionIterInternal { } } + pub fn set_state_saturated( + &mut self, + state: PyObjectRef, + f: F, + vm: &VirtualMachine, + ) -> PyResult<()> + where + F: FnOnce(&T) -> usize, + { + if let IterStatus::Active(obj) = &self.status { + if let Some(i) = state.payload::() { + let i = int::try_to_primitive(i.as_bigint(), vm).unwrap_or(0); + self.position = i.min(f(obj)); + Ok(()) + } else { + Err(vm.new_type_error("an integer is required.".to_owned())) + } + } else { + Ok(()) + } + } + fn _reduce(&self, func: PyObjectRef, f: F, vm: &VirtualMachine) -> PyObjectRef where F: FnOnce(&T) -> PyObjectRef, @@ -168,12 +193,16 @@ impl PositionIterInternal { } pub fn get_builtin_attribute_iter(vm: &VirtualMachine) -> &PyObjectRef { - static INSTANCE: OnceCell = OnceCell::new(); + static_cell! { + static INSTANCE: PyObjectRef; + } INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "iter").unwrap()) } pub fn get_builtin_attribute_reversed(vm: &VirtualMachine) -> &PyObjectRef { - static INSTANCE: OnceCell = OnceCell::new(); + static_cell! { + static INSTANCE: PyObjectRef; + } INSTANCE.get_or_init(|| vm.get_attribute(vm.builtins.clone(), "reversed").unwrap()) } diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index a00c72a59..c3ea794d1 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -480,7 +480,9 @@ impl PyListIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.len(), vm) } #[pymethod(magic)] @@ -527,7 +529,9 @@ impl PyListReverseIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.len(), vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index e1802c23e..6ecd64636 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -188,7 +188,9 @@ impl PyStrIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.char_len(), vm) } #[pymethod(magic)] diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index de3f427d3..34b4e5820 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -333,7 +333,9 @@ impl PyTupleIterator { #[pymethod(magic)] fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.internal.lock().set_state(state, vm) + self.internal + .lock() + .set_state_saturated(state, |obj| obj.len(), vm) } #[pymethod(magic)]