fix str iterator position

This commit is contained in:
Kangzhi Shi
2021-10-01 15:48:18 +02:00
parent 70fc910268
commit 87edbfece7
2 changed files with 32 additions and 17 deletions

View File

@@ -9,3 +9,8 @@ assert next(i) == 3
assert next(i, 'w00t') == 'w00t'
s = '你好'
i = iter(s)
i.__setstate__(1)
assert i.__next__() == ''
assert i.__reduce__()[2] == 2

View File

@@ -170,7 +170,7 @@ impl TryIntoRef<PyStr> for &str {
#[pyclass(module = false, name = "str_iterator")]
#[derive(Debug)]
pub struct PyStrIterator {
internal: PyMutex<PositionIterInternal<PyStrRef>>,
internal: PyMutex<(PositionIterInternal<PyStrRef>, usize)>,
}
impl PyValue for PyStrIterator {
@@ -183,13 +183,15 @@ impl PyValue for PyStrIterator {
impl PyStrIterator {
#[pymethod(magic)]
fn length_hint(&self) -> usize {
self.internal.lock().length_hint(|obj| obj.len())
self.internal.lock().0.length_hint(|obj| obj.char_len())
}
#[pymethod(magic)]
fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.internal
.lock()
let mut internal = self.internal.lock();
internal.1 = usize::MAX;
internal
.0
.set_state(state, |obj, pos| pos.min(obj.char_len()), vm)
}
@@ -197,6 +199,7 @@ impl PyStrIterator {
fn reduce(&self, vm: &VirtualMachine) -> PyObjectRef {
self.internal
.lock()
.0
.builtin_iter_reduce(|x| x.clone().into_object(), vm)
}
}
@@ -205,20 +208,27 @@ impl IteratorIterable for PyStrIterator {}
impl SlotIterator for PyStrIterator {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult {
let mut internal = zelf.internal.lock();
if let IterStatus::Active(s) = &internal.status {
if let IterStatus::Active(s) = &internal.0.status {
let value = s.as_str();
if internal.position >= value.len() {
internal.status = Exhausted;
return Err(vm.new_stop_iteration());
if internal.1 == usize::MAX {
if let Some((offset, ch)) = value.char_indices().nth(internal.0.position) {
internal.0.position += 1;
internal.1 = offset + ch.len_utf8();
return Ok(ch.into_pyobject(vm));
}
} else {
if let Some(value) = value.get(internal.1..) {
if let Some(ch) = value.chars().next() {
internal.0.position += 1;
internal.1 += ch.len_utf8();
return Ok(ch.into_pyobject(vm));
}
}
}
let ch = value[internal.position..].chars().next().ok_or_else(|| {
internal.status = Exhausted;
vm.new_stop_iteration()
})?;
internal.position += ch.len_utf8();
Ok(ch.into_pyobject(vm))
internal.0.status = Exhausted;
Err(vm.new_stop_iteration())
} else {
Err(vm.new_stop_iteration())
}
@@ -1256,7 +1266,7 @@ impl Comparable for PyStr {
impl Iterable for PyStr {
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
Ok(PyStrIterator {
internal: PyMutex::new(PositionIterInternal::new(zelf, 0)),
internal: PyMutex::new((PositionIterInternal::new(zelf, 0), 0)),
}
.into_object(vm))
}