From da53eeea5743fbc608e475d651421dfc69a24d81 Mon Sep 17 00:00:00 2001 From: Hunter Harloff Date: Wed, 27 Mar 2024 23:19:36 -0400 Subject: [PATCH] Add itertools.batched Support --- vm/src/stdlib/itertools.rs | 71 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 3bc26c0a8..6b946887c 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1952,4 +1952,75 @@ mod decl { Ok(PyIterReturn::Return(vm.new_tuple((old, new)).into())) } } + + #[pyattr] + #[pyclass(name = "batched")] + #[derive(Debug, PyPayload)] + struct PyItertoolsBatched { + exhausted: AtomicCell, + iterable: PyIter, + n: AtomicCell, + } + + #[derive(FromArgs)] + struct BatchedNewArgs { + #[pyarg(positional)] + iterable: PyIter, + #[pyarg(positional)] + n: PyIntRef, + } + + impl Constructor for PyItertoolsBatched { + type Args = BatchedNewArgs; + + fn py_new( + cls: PyTypeRef, + Self::Args { iterable, n }: Self::Args, + vm: &VirtualMachine, + ) -> PyResult { + let n = n.as_bigint(); + if n.is_negative() { + return Err(vm.new_value_error("n must be at least one".to_owned())); + } + let n = n.to_usize().unwrap(); + + Self { + iterable, + n: AtomicCell::new(n), + exhausted: AtomicCell::new(false), + } + .into_ref_with_type(vm, cls) + .map(Into::into) + } + } + + #[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE, HAS_DICT))] + impl PyItertoolsBatched {} + + impl SelfIter for PyItertoolsBatched {} + + impl IterNext for PyItertoolsBatched { + fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { + if zelf.exhausted.load() { + return Ok(PyIterReturn::StopIteration(None)); + } + let mut result: Vec = Vec::new(); + let n = zelf.n.load(); + for _ in 0..n { + match zelf.iterable.next(vm)? { + PyIterReturn::Return(obj) => { + result.push(obj); + } + PyIterReturn::StopIteration(_) => { + zelf.exhausted.store(true); + break; + } + } + } + match result.len() { + 0 => Ok(PyIterReturn::StopIteration(None)), + _ => Ok(PyIterReturn::Return(vm.ctx.new_tuple(result).into())), + } + } + } }