From 87edbfece7ffa5140367ea0859a5dc254bbf6fb7 Mon Sep 17 00:00:00 2001
From: Kangzhi Shi
Date: Fri, 1 Oct 2021 15:48:18 +0200
Subject: [PATCH] fix str iterator position

---
 extra_tests/snippets/iterations.py |  5 ++++
 vm/src/builtins/pystr.rs           | 44 ++++++++++++++++++------------
 2 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/extra_tests/snippets/iterations.py b/extra_tests/snippets/iterations.py
index 98031a935..b2b30961c 100644
--- a/extra_tests/snippets/iterations.py
+++ b/extra_tests/snippets/iterations.py
@@ -9,3 +9,8 @@ assert next(i) == 3
 
 assert next(i, 'w00t') == 'w00t'
 
+s = '你好'
+i = iter(s)
+i.__setstate__(1)
+assert i.__next__() == '好'
+assert i.__reduce__()[2] == 2
diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs
index 81b9152c7..249af8eaf 100644
--- a/vm/src/builtins/pystr.rs
+++ b/vm/src/builtins/pystr.rs
@@ -170,7 +170,7 @@ impl TryIntoRef<PyStr> for &str {
 #[pyclass(module = false, name = "str_iterator")]
 #[derive(Debug)]
 pub struct PyStrIterator {
-    internal: PyMutex<PositionIterInternal<PyStrRef>>,
+    internal: PyMutex<(PositionIterInternal<PyStrRef>, usize)>,
 }
 
 impl PyValue for PyStrIterator {
@@ -183,13 +183,15 @@ impl PyValue for PyStrIterator {
 impl PyStrIterator {
     #[pymethod(magic)]
     fn length_hint(&self) -> usize {
-        self.internal.lock().length_hint(|obj| obj.len())
+        self.internal.lock().0.length_hint(|obj| obj.char_len())
     }
 
     #[pymethod(magic)]
     fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
-        self.internal
-            .lock()
+        let mut internal = self.internal.lock();
+        internal.1 = usize::MAX;
+        internal
+            .0
             .set_state(state, |obj, pos| pos.min(obj.char_len()), vm)
     }
 
@@ -197,6 +199,7 @@ impl PyStrIterator {
     fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef {
         self.internal
             .lock()
+            .0
             .builtin_iter_reduce(|x| x.clone().into_object(), vm)
     }
 }
@@ -205,20 +208,27 @@ impl IteratorIterable for PyStrIterator {}
 impl SlotIterator for PyStrIterator {
     fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
         let mut internal = zelf.internal.lock();
-        if let IterStatus::Active(s) = &internal.status {
+
+        if let IterStatus::Active(s) = &internal.0.status {
             let value = s.as_str();
-            if internal.position >= value.len() {
-                internal.status = Exhausted;
-                return Err(vm.new_stop_iteration());
+
+            if internal.1 == usize::MAX {
+                if let Some((offset, ch)) = value.char_indices().nth(internal.0.position) {
+                    internal.0.position += 1;
+                    internal.1 = offset + ch.len_utf8();
+                    return Ok(ch.into_pyobject(vm));
+                }
+            } else {
+                if let Some(value) = value.get(internal.1..) {
+                    if let Some(ch) = value.chars().next() {
+                        internal.0.position += 1;
+                        internal.1 += ch.len_utf8();
+                        return Ok(ch.into_pyobject(vm));
+                    }
+                }
             }
-
-            let ch = value[internal.position..].chars().next().ok_or_else(|| {
-                internal.status = Exhausted;
-                vm.new_stop_iteration()
-            })?;
-
-            internal.position += ch.len_utf8();
-            Ok(ch.into_pyobject(vm))
+            internal.0.status = Exhausted;
+            Err(vm.new_stop_iteration())
         } else {
             Err(vm.new_stop_iteration())
         }
@@ -1256,7 +1266,7 @@ impl Comparable for PyStr {
 impl Iterable for PyStr {
     fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
         Ok(PyStrIterator {
-            internal: PyMutex::new(PositionIterInternal::new(zelf, 0)),
+            internal: PyMutex::new((PositionIterInternal::new(zelf, 0), 0)),
         }
         .into_object(vm))
     }