forked from Rust-related/RustPython
@@ -301,6 +301,7 @@ assert list(t[0]) == [1,2,3]
|
||||
assert list(t[0]) == []
|
||||
|
||||
# itertools.product
|
||||
|
||||
it = itertools.product([1, 2], [3, 4])
|
||||
assert (1, 3) == next(it)
|
||||
assert (1, 4) == next(it)
|
||||
@@ -321,3 +322,25 @@ with assert_raises(TypeError):
|
||||
itertools.product(None)
|
||||
with assert_raises(TypeError):
|
||||
itertools.product([1, 2], repeat=None)
|
||||
|
||||
# itertools.combinations
|
||||
|
||||
it = itertools.combinations([1, 2, 3, 4], 2)
|
||||
assert list(it) == [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
|
||||
|
||||
it = itertools.combinations([1, 2, 3], 1)
|
||||
assert list(it) == [(1,), (2,), (3,)]
|
||||
|
||||
it = itertools.combinations([1, 2, 3], 2)
|
||||
assert next(it) == (1, 2)
|
||||
assert next(it) == (1, 3)
|
||||
assert next(it) == (2, 3)
|
||||
with assert_raises(StopIteration):
|
||||
next(it)
|
||||
|
||||
it = itertools.combinations([1, 2, 3], 4)
|
||||
with assert_raises(StopIteration):
|
||||
next(it)
|
||||
|
||||
with assert_raises(ValueError):
|
||||
itertools.combinations([1, 2, 3, 4], -2)
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::ops::{AddAssign, SubAssign};
|
||||
use std::rc::Rc;
|
||||
|
||||
use num_bigint::BigInt;
|
||||
use num_traits::sign::Signed;
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
use crate::function::{Args, OptionalArg, PyFuncArgs};
|
||||
@@ -848,6 +849,110 @@ impl PyItertoolsProduct {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug)]
|
||||
struct PyItertoolsCombinations {
|
||||
pool: Vec<PyObjectRef>,
|
||||
indices: RefCell<Vec<usize>>,
|
||||
r: Cell<usize>,
|
||||
exhausted: Cell<bool>,
|
||||
}
|
||||
|
||||
impl PyValue for PyItertoolsCombinations {
|
||||
fn class(vm: &VirtualMachine) -> PyClassRef {
|
||||
vm.class("itertools", "combinations")
|
||||
}
|
||||
}
|
||||
|
||||
#[pyimpl]
|
||||
impl PyItertoolsCombinations {
|
||||
#[pyslot(new)]
|
||||
fn tp_new(
|
||||
cls: PyClassRef,
|
||||
iterable: PyObjectRef,
|
||||
r: PyIntRef,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<PyRef<Self>> {
|
||||
let iter = get_iter(vm, &iterable)?;
|
||||
let pool = get_all(vm, &iter)?;
|
||||
|
||||
let r = r.as_bigint();
|
||||
if r.is_negative() {
|
||||
return Err(vm.new_value_error("r must be non-negative".to_string()));
|
||||
}
|
||||
let r = r.to_usize().unwrap();
|
||||
|
||||
let n = pool.len();
|
||||
|
||||
PyItertoolsCombinations {
|
||||
pool,
|
||||
indices: RefCell::new((0..r).collect()),
|
||||
r: Cell::new(r),
|
||||
exhausted: Cell::new(r > n),
|
||||
}
|
||||
.into_ref_with_type(vm, cls)
|
||||
}
|
||||
|
||||
#[pymethod(name = "__iter__")]
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
|
||||
#[pymethod(name = "__next__")]
|
||||
fn next(&self, vm: &VirtualMachine) -> PyResult {
|
||||
// stop signal
|
||||
if self.exhausted.get() {
|
||||
return Err(new_stop_iteration(vm));
|
||||
}
|
||||
|
||||
let n = self.pool.len();
|
||||
let r = self.r.get();
|
||||
|
||||
let res = PyTuple::from(
|
||||
self.pool
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(idx, _)| self.indices.borrow().contains(&idx))
|
||||
.map(|(_, num)| num.clone())
|
||||
.collect::<Vec<PyObjectRef>>(),
|
||||
);
|
||||
|
||||
let mut indices = self.indices.borrow_mut();
|
||||
let mut sentinel = false;
|
||||
|
||||
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
|
||||
let mut idx = r - 1;
|
||||
loop {
|
||||
if indices[idx] != idx + n - r {
|
||||
sentinel = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if idx != 0 {
|
||||
idx -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// If no suitable index is found, then the indices are all at
|
||||
// their maximum value and we're done.
|
||||
if !sentinel {
|
||||
self.exhausted.set(true);
|
||||
}
|
||||
|
||||
// Increment the current index which we know is not at its
|
||||
// maximum. Then move back to the right setting each index
|
||||
// to its lowest possible value (one higher than the index
|
||||
// to its left -- this maintains the sort order invariant).
|
||||
indices[idx] += 1;
|
||||
for j in idx + 1..r {
|
||||
indices[j] = indices[j - 1] + 1;
|
||||
}
|
||||
|
||||
Ok(res.into_ref(vm).into_object())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
let ctx = &vm.ctx;
|
||||
|
||||
@@ -858,6 +963,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
|
||||
let compress = PyItertoolsCompress::make_class(ctx);
|
||||
|
||||
let combinations = ctx.new_class("combinations", ctx.object());
|
||||
PyItertoolsCombinations::extend_class(ctx, &combinations);
|
||||
|
||||
let count = ctx.new_class("count", ctx.object());
|
||||
PyItertoolsCount::extend_class(ctx, &count);
|
||||
|
||||
@@ -887,6 +995,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
"accumulate" => accumulate,
|
||||
"chain" => chain,
|
||||
"compress" => compress,
|
||||
"combinations" => combinations,
|
||||
"count" => count,
|
||||
"dropwhile" => dropwhile,
|
||||
"islice" => islice,
|
||||
|
||||
Reference in New Issue
Block a user