Add itertools.combinations()

re: #1361
This commit is contained in:
Daniel Alley
2019-11-21 10:51:38 -05:00
parent 53b391177b
commit 16b2b425b8
2 changed files with 132 additions and 0 deletions

View File

@@ -301,6 +301,7 @@ assert list(t[0]) == [1,2,3]
assert list(t[0]) == []
# itertools.product
it = itertools.product([1, 2], [3, 4])
assert (1, 3) == next(it)
assert (1, 4) == next(it)
@@ -321,3 +322,25 @@ with assert_raises(TypeError):
itertools.product(None)
with assert_raises(TypeError):
itertools.product([1, 2], repeat=None)
# itertools.combinations
it = itertools.combinations([1, 2, 3, 4], 2)
assert list(it) == [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
it = itertools.combinations([1, 2, 3], 1)
assert list(it) == [(1,), (2,), (3,)]
it = itertools.combinations([1, 2, 3], 2)
assert next(it) == (1, 2)
assert next(it) == (1, 3)
assert next(it) == (2, 3)
with assert_raises(StopIteration):
next(it)
it = itertools.combinations([1, 2, 3], 4)
with assert_raises(StopIteration):
next(it)
with assert_raises(ValueError):
itertools.combinations([1, 2, 3, 4], -2)

View File

@@ -5,6 +5,7 @@ use std::ops::{AddAssign, SubAssign};
use std::rc::Rc;
use num_bigint::BigInt;
use num_traits::sign::Signed;
use num_traits::ToPrimitive;
use crate::function::{Args, OptionalArg, PyFuncArgs};
@@ -848,6 +849,110 @@ impl PyItertoolsProduct {
}
}
#[pyclass]
#[derive(Debug)]
struct PyItertoolsCombinations {
pool: Vec<PyObjectRef>,
indices: RefCell<Vec<usize>>,
r: Cell<usize>,
exhausted: Cell<bool>,
}
impl PyValue for PyItertoolsCombinations {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("itertools", "combinations")
}
}
#[pyimpl]
impl PyItertoolsCombinations {
#[pyslot(new)]
fn tp_new(
cls: PyClassRef,
iterable: PyObjectRef,
r: PyIntRef,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
let iter = get_iter(vm, &iterable)?;
let pool = get_all(vm, &iter)?;
let r = r.as_bigint();
if r.is_negative() {
return Err(vm.new_value_error("r must be non-negative".to_string()));
}
let r = r.to_usize().unwrap();
let n = pool.len();
PyItertoolsCombinations {
pool,
indices: RefCell::new((0..r).collect()),
r: Cell::new(r),
exhausted: Cell::new(r > n),
}
.into_ref_with_type(vm, cls)
}
#[pymethod(name = "__iter__")]
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
// stop signal
if self.exhausted.get() {
return Err(new_stop_iteration(vm));
}
let n = self.pool.len();
let r = self.r.get();
let res = PyTuple::from(
self.pool
.iter()
.enumerate()
.filter(|(idx, _)| self.indices.borrow().contains(&idx))
.map(|(_, num)| num.clone())
.collect::<Vec<PyObjectRef>>(),
);
let mut indices = self.indices.borrow_mut();
let mut sentinel = false;
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
let mut idx = r - 1;
loop {
if indices[idx] != idx + n - r {
sentinel = true;
break;
}
if idx != 0 {
idx -= 1;
} else {
break;
}
}
// If no suitable index is found, then the indices are all at
// their maximum value and we're done.
if !sentinel {
self.exhausted.set(true);
}
// Increment the current index which we know is not at its
// maximum. Then move back to the right setting each index
// to its lowest possible value (one higher than the index
// to its left -- this maintains the sort order invariant).
indices[idx] += 1;
for j in idx + 1..r {
indices[j] = indices[j - 1] + 1;
}
Ok(res.into_ref(vm).into_object())
}
}
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let ctx = &vm.ctx;
@@ -858,6 +963,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let compress = PyItertoolsCompress::make_class(ctx);
let combinations = ctx.new_class("combinations", ctx.object());
PyItertoolsCombinations::extend_class(ctx, &combinations);
let count = ctx.new_class("count", ctx.object());
PyItertoolsCount::extend_class(ctx, &count);
@@ -887,6 +995,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
"accumulate" => accumulate,
"chain" => chain,
"compress" => compress,
"combinations" => combinations,
"count" => count,
"dropwhile" => dropwhile,
"islice" => islice,