diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index de4d1ddad..f10565112 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1399,7 +1399,7 @@ mod decl { } #[derive(FromArgs)] - struct ZiplongestArgs { + struct ZipLongestArgs { #[pyarg(named, optional)] fillvalue: OptionalArg, } @@ -1410,7 +1410,7 @@ mod decl { fn tp_new( cls: PyTypeRef, iterables: Args, - args: ZiplongestArgs, + args: ZipLongestArgs, vm: &VirtualMachine, ) -> PyResult> { let fillvalue = args.fillvalue.unwrap_or_none(vm); @@ -1454,4 +1454,47 @@ mod decl { } } } + + #[pyattr] + #[pyclass(name = "pairwise")] + #[derive(Debug)] + struct PyItertoolsPairwise { + iterator: PyObjectRef, + old: PyRwLock>, + } + + impl PyValue for PyItertoolsPairwise { + fn class(_vm: &VirtualMachine) -> &PyTypeRef { + Self::static_type() + } + } + + #[pyimpl(with(PyIter))] + impl PyItertoolsPairwise { + #[pyslot] + fn tp_new( + cls: PyTypeRef, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult> { + let iterator = get_iter(vm, iterable)?; + + PyItertoolsPairwise { + iterator, + old: PyRwLock::new(None), + } + .into_ref_with_type(vm, cls) + } + } + impl PyIter for PyItertoolsPairwise { + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + let old = match zelf.old.read().clone() { + None => call_next(vm, &zelf.iterator)?, + Some(obj) => obj, + }; + let new = call_next(vm, &zelf.iterator)?; + *zelf.old.write() = Some(new.clone()); + Ok(vm.ctx.new_tuple(vec![old, new])) + } + } }