Add itertools.combinations_with_replacement()

This commit is contained in:
Daniel Alley
2019-12-25 22:07:02 -05:00
parent 920ef52592
commit 4bbca2bed2
2 changed files with 118 additions and 0 deletions

View File

@@ -405,6 +405,25 @@ with assert_raises(ValueError):
with assert_raises(TypeError):
itertools.combinations([1, 2, 3, 4], None)
# itertools.combinations
it = itertools.combinations_with_replacement([1, 2, 3], 0)
assert list(it) == [()]
it = itertools.combinations_with_replacement([1, 2, 3], 1)
assert list(it) == [(1,), (2,), (3,)]
it = itertools.combinations_with_replacement([1, 2, 3], 2)
assert list(it) == [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
it = itertools.combinations_with_replacement([1, 2], 3)
assert list(it) == [(1, 1, 1), (1, 1, 2), (1, 2, 2), (2, 2, 2)]
with assert_raises(ValueError):
itertools.combinations_with_replacement([1, 2, 3, 4], -2)
with assert_raises(TypeError):
itertools.combinations_with_replacement([1, 2, 3, 4], None)
# itertools.permutations
it = itertools.permutations([1, 2, 3])
assert list(it) == [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]

View File

@@ -1031,6 +1031,100 @@ impl PyItertoolsCombinations {
}
}
#[pyclass]
#[derive(Debug)]
struct PyItertoolsCombinationsWithReplacement {
pool: Vec<PyObjectRef>,
indices: RefCell<Vec<usize>>,
r: Cell<usize>,
exhausted: Cell<bool>,
}
impl PyValue for PyItertoolsCombinationsWithReplacement {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("itertools", "combinations_with_replacement")
}
}
#[pyimpl]
impl PyItertoolsCombinationsWithReplacement {
#[pyslot(new)]
fn tp_new(
cls: PyClassRef,
iterable: PyObjectRef,
r: PyIntRef,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
let iter = get_iter(vm, &iterable)?;
let pool = get_all(vm, &iter)?;
let r = r.as_bigint();
if r.is_negative() {
return Err(vm.new_value_error("r must be non-negative".to_string()));
}
let r = r.to_usize().unwrap();
let n = pool.len();
PyItertoolsCombinationsWithReplacement {
pool,
indices: RefCell::new(vec![0; r]),
r: Cell::new(r),
exhausted: Cell::new(n == 0 && r > 0),
}
.into_ref_with_type(vm, cls)
}
#[pymethod(name = "__iter__")]
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
// stop signal
if self.exhausted.get() {
return Err(new_stop_iteration(vm));
}
let n = self.pool.len();
let r = self.r.get();
if r == 0 {
self.exhausted.set(true);
return Ok(vm.ctx.new_tuple(vec![]));
}
let mut indices = self.indices.borrow_mut();
let res = vm
.ctx
.new_tuple(indices.iter().map(|&i| self.pool[i].clone()).collect());
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
let mut idx = r as isize - 1;
while idx >= 0 && indices[idx as usize] == n - 1 {
idx -= 1;
}
// If no suitable index is found, then the indices are all at
// their maximum value and we're done.
if idx < 0 {
self.exhausted.set(true);
} else {
let index = indices[idx as usize] + 1;
// Increment the current index which we know is not at its
// maximum. Then set all to the right to the same value.
for j in idx as usize..r {
indices[j as usize] = index as usize;
}
}
Ok(res)
}
}
#[pyclass]
#[derive(Debug)]
struct PyItertoolsPermutations {
@@ -1257,6 +1351,10 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let combinations = ctx.new_class("combinations", ctx.object());
PyItertoolsCombinations::extend_class(ctx, &combinations);
let combinations_with_replacement =
ctx.new_class("combinations_with_replacement", ctx.object());
PyItertoolsCombinationsWithReplacement::extend_class(ctx, &combinations_with_replacement);
let count = ctx.new_class("count", ctx.object());
PyItertoolsCount::extend_class(ctx, &count);
@@ -1296,6 +1394,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
"chain" => chain,
"compress" => compress,
"combinations" => combinations,
"combinations_with_replacement" => combinations_with_replacement,
"count" => count,
"cycle" => cycle,
"dropwhile" => dropwhile,