diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index f133cfc04e..5a400db714 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -301,6 +301,7 @@ assert list(t[0]) == [1,2,3] assert list(t[0]) == [] # itertools.product + it = itertools.product([1, 2], [3, 4]) assert (1, 3) == next(it) assert (1, 4) == next(it) @@ -321,3 +322,25 @@ with assert_raises(TypeError): itertools.product(None) with assert_raises(TypeError): itertools.product([1, 2], repeat=None) + +# itertools.combinations + +it = itertools.combinations([1, 2, 3, 4], 2) +assert list(it) == [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] + +it = itertools.combinations([1, 2, 3], 1) +assert list(it) == [(1,), (2,), (3,)] + +it = itertools.combinations([1, 2, 3], 2) +assert next(it) == (1, 2) +assert next(it) == (1, 3) +assert next(it) == (2, 3) +with assert_raises(StopIteration): + next(it) + +it = itertools.combinations([1, 2, 3], 4) +with assert_raises(StopIteration): + next(it) + +with assert_raises(ValueError): + itertools.combinations([1, 2, 3, 4], -2) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 750eb9a9a6..0d176029dc 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -5,6 +5,7 @@ use std::ops::{AddAssign, SubAssign}; use std::rc::Rc; use num_bigint::BigInt; +use num_traits::sign::Signed; use num_traits::ToPrimitive; use crate::function::{Args, OptionalArg, PyFuncArgs}; @@ -848,6 +849,110 @@ impl PyItertoolsProduct { } } +#[pyclass] +#[derive(Debug)] +struct PyItertoolsCombinations { + pool: Vec, + indices: RefCell>, + r: Cell, + exhausted: Cell, +} + +impl PyValue for PyItertoolsCombinations { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "combinations") + } +} + +#[pyimpl] +impl PyItertoolsCombinations { + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, + iterable: PyObjectRef, + r: PyIntRef, + vm: &VirtualMachine, + ) -> PyResult> { + let iter = get_iter(vm, &iterable)?; + let pool = get_all(vm, &iter)?; + + let r = r.as_bigint(); + if r.is_negative() { + return Err(vm.new_value_error("r must be non-negative".to_string())); + } + let r = r.to_usize().unwrap(); + + let n = pool.len(); + + PyItertoolsCombinations { + pool, + indices: RefCell::new((0..r).collect()), + r: Cell::new(r), + exhausted: Cell::new(r > n), + } + .into_ref_with_type(vm, cls) + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + // stop signal + if self.exhausted.get() { + return Err(new_stop_iteration(vm)); + } + + let n = self.pool.len(); + let r = self.r.get(); + + let res = PyTuple::from( + self.pool + .iter() + .enumerate() + .filter(|(idx, _)| self.indices.borrow().contains(&idx)) + .map(|(_, num)| num.clone()) + .collect::>(), + ); + + let mut indices = self.indices.borrow_mut(); + let mut sentinel = false; + + // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). + let mut idx = r - 1; + loop { + if indices[idx] != idx + n - r { + sentinel = true; + break; + } + + if idx != 0 { + idx -= 1; + } else { + break; + } + } + // If no suitable index is found, then the indices are all at + // their maximum value and we're done. + if !sentinel { + self.exhausted.set(true); + } + + // Increment the current index which we know is not at its + // maximum. Then move back to the right setting each index + // to its lowest possible value (one higher than the index + // to its left -- this maintains the sort order invariant). + indices[idx] += 1; + for j in idx + 1..r { + indices[j] = indices[j - 1] + 1; + } + + Ok(res.into_ref(vm).into_object()) + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -858,6 +963,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let compress = PyItertoolsCompress::make_class(ctx); + let combinations = ctx.new_class("combinations", ctx.object()); + PyItertoolsCombinations::extend_class(ctx, &combinations); + let count = ctx.new_class("count", ctx.object()); PyItertoolsCount::extend_class(ctx, &count); @@ -887,6 +995,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "accumulate" => accumulate, "chain" => chain, "compress" => compress, + "combinations" => combinations, "count" => count, "dropwhile" => dropwhile, "islice" => islice,