mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Add itertools.combinations_with_replacement()
This commit is contained in:
@@ -405,6 +405,25 @@ with assert_raises(ValueError):
|
||||
with assert_raises(TypeError):
|
||||
itertools.combinations([1, 2, 3, 4], None)
|
||||
|
||||
# itertools.combinations
|
||||
it = itertools.combinations_with_replacement([1, 2, 3], 0)
|
||||
assert list(it) == [()]
|
||||
|
||||
it = itertools.combinations_with_replacement([1, 2, 3], 1)
|
||||
assert list(it) == [(1,), (2,), (3,)]
|
||||
|
||||
it = itertools.combinations_with_replacement([1, 2, 3], 2)
|
||||
assert list(it) == [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
|
||||
|
||||
it = itertools.combinations_with_replacement([1, 2], 3)
|
||||
assert list(it) == [(1, 1, 1), (1, 1, 2), (1, 2, 2), (2, 2, 2)]
|
||||
|
||||
with assert_raises(ValueError):
|
||||
itertools.combinations_with_replacement([1, 2, 3, 4], -2)
|
||||
|
||||
with assert_raises(TypeError):
|
||||
itertools.combinations_with_replacement([1, 2, 3, 4], None)
|
||||
|
||||
# itertools.permutations
|
||||
it = itertools.permutations([1, 2, 3])
|
||||
assert list(it) == [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]
|
||||
|
||||
@@ -1031,6 +1031,100 @@ impl PyItertoolsCombinations {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug)]
|
||||
struct PyItertoolsCombinationsWithReplacement {
|
||||
pool: Vec<PyObjectRef>,
|
||||
indices: RefCell<Vec<usize>>,
|
||||
r: Cell<usize>,
|
||||
exhausted: Cell<bool>,
|
||||
}
|
||||
|
||||
impl PyValue for PyItertoolsCombinationsWithReplacement {
|
||||
fn class(vm: &VirtualMachine) -> PyClassRef {
|
||||
vm.class("itertools", "combinations_with_replacement")
|
||||
}
|
||||
}
|
||||
|
||||
#[pyimpl]
|
||||
impl PyItertoolsCombinationsWithReplacement {
|
||||
#[pyslot(new)]
|
||||
fn tp_new(
|
||||
cls: PyClassRef,
|
||||
iterable: PyObjectRef,
|
||||
r: PyIntRef,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<PyRef<Self>> {
|
||||
let iter = get_iter(vm, &iterable)?;
|
||||
let pool = get_all(vm, &iter)?;
|
||||
|
||||
let r = r.as_bigint();
|
||||
if r.is_negative() {
|
||||
return Err(vm.new_value_error("r must be non-negative".to_string()));
|
||||
}
|
||||
let r = r.to_usize().unwrap();
|
||||
|
||||
let n = pool.len();
|
||||
|
||||
PyItertoolsCombinationsWithReplacement {
|
||||
pool,
|
||||
indices: RefCell::new(vec![0; r]),
|
||||
r: Cell::new(r),
|
||||
exhausted: Cell::new(n == 0 && r > 0),
|
||||
}
|
||||
.into_ref_with_type(vm, cls)
|
||||
}
|
||||
|
||||
#[pymethod(name = "__iter__")]
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
|
||||
#[pymethod(name = "__next__")]
|
||||
fn next(&self, vm: &VirtualMachine) -> PyResult {
|
||||
// stop signal
|
||||
if self.exhausted.get() {
|
||||
return Err(new_stop_iteration(vm));
|
||||
}
|
||||
|
||||
let n = self.pool.len();
|
||||
let r = self.r.get();
|
||||
|
||||
if r == 0 {
|
||||
self.exhausted.set(true);
|
||||
return Ok(vm.ctx.new_tuple(vec![]));
|
||||
}
|
||||
|
||||
let mut indices = self.indices.borrow_mut();
|
||||
|
||||
let res = vm
|
||||
.ctx
|
||||
.new_tuple(indices.iter().map(|&i| self.pool[i].clone()).collect());
|
||||
|
||||
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
|
||||
let mut idx = r as isize - 1;
|
||||
while idx >= 0 && indices[idx as usize] == n - 1 {
|
||||
idx -= 1;
|
||||
}
|
||||
|
||||
// If no suitable index is found, then the indices are all at
|
||||
// their maximum value and we're done.
|
||||
if idx < 0 {
|
||||
self.exhausted.set(true);
|
||||
} else {
|
||||
let index = indices[idx as usize] + 1;
|
||||
|
||||
// Increment the current index which we know is not at its
|
||||
// maximum. Then set all to the right to the same value.
|
||||
for j in idx as usize..r {
|
||||
indices[j as usize] = index as usize;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug)]
|
||||
struct PyItertoolsPermutations {
|
||||
@@ -1257,6 +1351,10 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
let combinations = ctx.new_class("combinations", ctx.object());
|
||||
PyItertoolsCombinations::extend_class(ctx, &combinations);
|
||||
|
||||
let combinations_with_replacement =
|
||||
ctx.new_class("combinations_with_replacement", ctx.object());
|
||||
PyItertoolsCombinationsWithReplacement::extend_class(ctx, &combinations_with_replacement);
|
||||
|
||||
let count = ctx.new_class("count", ctx.object());
|
||||
PyItertoolsCount::extend_class(ctx, &count);
|
||||
|
||||
@@ -1296,6 +1394,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
"chain" => chain,
|
||||
"compress" => compress,
|
||||
"combinations" => combinations,
|
||||
"combinations_with_replacement" => combinations_with_replacement,
|
||||
"count" => count,
|
||||
"cycle" => cycle,
|
||||
"dropwhile" => dropwhile,
|
||||
|
||||
Reference in New Issue
Block a user