diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index e7b00da71..8709948b9 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -1182,8 +1182,6 @@ class TestBasicOps(unittest.TestCase): with self.assertRaises(TypeError): pairwise(None) # non-iterable argument - # TODO: RUSTPYTHON - @unittest.skip("TODO: RUSTPYTHON, hangs") def test_pairwise_reenter(self): def check(reenter_at, expected): class I: @@ -1234,8 +1232,6 @@ class TestBasicOps(unittest.TestCase): ([5], [6]), ]) - # TODO: RUSTPYTHON - @unittest.skip("TODO: RUSTPYTHON, hangs") def test_pairwise_reenter2(self): def check(maxcount, expected): class I: diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index efe336812..df6d9487c 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1943,7 +1943,7 @@ mod decl { type Args = PyIter; fn py_new(cls: PyTypeRef, iterator: Self::Args, vm: &VirtualMachine) -> PyResult { - PyItertoolsPairwise { + Self { iterator, old: PyRwLock::new(None), } @@ -1959,18 +1959,29 @@ mod decl { impl IterNext for PyItertoolsPairwise { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { - let old = match zelf.old.read().clone() { + let old_clone = { + let guard = zelf.old.read(); + guard.clone() + }; + let old = match old_clone { None => match zelf.iterator.next(vm)? { - PyIterReturn::Return(obj) => obj, + PyIterReturn::Return(obj) => { + // Needed for when we reenter + *zelf.old.write() = Some(obj.clone()); + obj + } PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), }, Some(obj) => obj, }; + let new = match zelf.iterator.next(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), }; + *zelf.old.write() = Some(new.clone()); + Ok(PyIterReturn::Return(vm.new_tuple((old, new)).into())) } }