diff --git a/benchmarks/benchmarks/strings.py b/benchmarks/benchmarks/strings.py new file mode 100644 index 000000000..e2e46e90b --- /dev/null +++ b/benchmarks/benchmarks/strings.py @@ -0,0 +1,4 @@ +long_string = "a" * 50000 + +for char in long_string: + pass diff --git a/benchmarks/test_benchmarks.py b/benchmarks/test_benchmarks.py index c1b80582d..2626461aa 100644 --- a/benchmarks/test_benchmarks.py +++ b/benchmarks/test_benchmarks.py @@ -18,6 +18,7 @@ pythons = [ benchmarks = [ ['benchmarks/nbody.py'], ['benchmarks/mandelbrot.py'], + ['benchmarks/strings.py'], ] exe_ids = ['cpython', 'rustpython'] diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index e7b0da07f..1659d3b35 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -277,7 +277,7 @@ assert "\u9487" == "钇" assert "\U0001F609" == "😉" # test str iter -iterable_str = "123456789" +iterable_str = "12345678😉" str_iter = iter(iterable_str) assert next(str_iter) == "1" @@ -288,13 +288,13 @@ assert next(str_iter) == "5" assert next(str_iter) == "6" assert next(str_iter) == "7" assert next(str_iter) == "8" -assert next(str_iter) == "9" +assert next(str_iter) == "😉" assert next(str_iter, None) == None assert_raises(StopIteration, next, str_iter) str_iter_reversed = reversed(iterable_str) -assert next(str_iter_reversed) == "9" +assert next(str_iter_reversed) == "😉" assert next(str_iter_reversed) == "8" assert next(str_iter_reversed) == "7" assert next(str_iter_reversed) == "6" diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 1f299867c..219ce280e 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -98,7 +98,7 @@ impl TryIntoRef for &str { #[derive(Debug)] pub struct PyStringIterator { pub string: PyStringRef, - position: Cell, + byte_position: Cell, } impl PyValue for PyStringIterator { @@ -111,15 +111,16 @@ impl PyValue for PyStringIterator { impl PyStringIterator { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let pos = self.position.get(); + let pos = self.byte_position.get(); - if pos < self.string.value.chars().count() { - self.position.set(self.position.get() + 1); + if pos < self.string.value.len() { + // We can be sure that chars() has a value, because of the pos check above. + let char_ = self.string.value[pos..].chars().next().unwrap(); - #[allow(clippy::range_plus_one)] - let value = self.string.value.do_slice(pos..pos + 1); + self.byte_position + .set(self.byte_position.get() + char_.len_utf8()); - value.into_pyobject(vm) + char_.to_string().into_pyobject(vm) } else { Err(objiter::new_stop_iteration(vm)) } @@ -1198,7 +1199,7 @@ impl PyString { #[pymethod(name = "__iter__")] fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyStringIterator { PyStringIterator { - position: Cell::new(0), + byte_position: Cell::new(0), string: zelf, } }