diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 67664ebc2..eff1813a5 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -209,4 +209,29 @@ assert 6 == next(it) assert 4 == next(it) assert 1 == next(it) with assertRaises(StopIteration): - next(it) \ No newline at end of file + next(it) + + +# itertools.accumulate +it = itertools.accumulate([6, 3, 7, 1, 0, 9, 8, 8]) +assert 6 == next(it) +assert 9 == next(it) +assert 16 == next(it) +assert 17 == next(it) +assert 17 == next(it) +assert 26 == next(it) +assert 34 == next(it) +assert 42 == next(it) +with assertRaises(StopIteration): + next(it) + +it = itertools.accumulate([3, 2, 4, 1, 0, 5, 8], lambda a, v: a*v) +assert 3 == next(it) +assert 6 == next(it) +assert 24 == next(it) +assert 24 == next(it) +assert 0 == next(it) +assert 0 == next(it) +assert 0 == next(it) +with assertRaises(StopIteration): + next(it) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 1ba74f6d2..dabbb220c 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -537,6 +537,66 @@ impl PyItertoolsFilterFalse { } } +#[pyclass] +#[derive(Debug)] +struct PyItertoolsAccumulate { + iterable: PyObjectRef, + binop: PyObjectRef, + acc_value: RefCell>, +} + +impl PyValue for PyItertoolsAccumulate { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "accumulate") + } +} + +#[pyimpl] +impl PyItertoolsAccumulate { + #[pymethod(name = "__new__")] + #[allow(clippy::new_ret_no_self)] + fn new( + cls: PyClassRef, + iterable: PyObjectRef, + binop: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult> { + let iter = get_iter(vm, &iterable)?; + + PyItertoolsAccumulate { + iterable: iter, + binop: binop.unwrap_or_else(|| vm.get_none()), + acc_value: RefCell::from(Option::None), + } + .into_ref_with_type(vm, cls) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let iterable = &self.iterable; + let obj = call_next(vm, iterable)?; + + let next_acc_value = match &*self.acc_value.borrow() { + Option::None => obj.clone(), + Option::Some(value) => { + if self.binop.is(&vm.get_none()) { + vm._add(value.clone(), obj.clone())? + } else { + vm.invoke(&self.binop, vec![value.clone(), obj.clone()])? + } + } + }; + self.acc_value.replace(Option::from(next_acc_value.clone())); + + Ok(next_acc_value) + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -561,6 +621,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let filterfalse = ctx.new_class("filterfalse", ctx.object()); PyItertoolsFilterFalse::extend_class(ctx, &filterfalse); + let accumulate = ctx.new_class("accumulate", ctx.object()); + PyItertoolsAccumulate::extend_class(ctx, &accumulate); + py_module!(vm, "itertools", { "chain" => chain, "count" => count, @@ -570,5 +633,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "takewhile" => takewhile, "islice" => islice, "filterfalse" => filterfalse, + "accumulate" => accumulate, }) }