forked from Rust-related/RustPython
Implement itertools.product
This implements `itertools.product` of standard library. Related with #1361
This commit is contained in:
@@ -299,3 +299,25 @@ assert list(t3) == [1,2,3]
|
||||
t = itertools.tee([1,2,3])
|
||||
assert list(t[0]) == [1,2,3]
|
||||
assert list(t[0]) == []
|
||||
|
||||
# itertools.product
|
||||
it = itertools.product([1, 2], [3, 4])
|
||||
assert (1, 3) == next(it)
|
||||
assert (1, 4) == next(it)
|
||||
assert (2, 3) == next(it)
|
||||
assert (2, 4) == next(it)
|
||||
with assert_raises(StopIteration):
|
||||
next(it)
|
||||
|
||||
it = itertools.product([1, 2], repeat=2)
|
||||
assert (1, 1) == next(it)
|
||||
assert (1, 2) == next(it)
|
||||
assert (2, 1) == next(it)
|
||||
assert (2, 2) == next(it)
|
||||
with assert_raises(StopIteration):
|
||||
next(it)
|
||||
|
||||
with assert_raises(TypeError):
|
||||
itertools.product(None)
|
||||
with assert_raises(TypeError):
|
||||
itertools.product([1, 2], repeat=None)
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::cmp::Ordering;
|
||||
use std::iter;
|
||||
use std::ops::{AddAssign, SubAssign};
|
||||
use std::rc::Rc;
|
||||
|
||||
use num_bigint::BigInt;
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
use crate::function::{OptionalArg, PyFuncArgs};
|
||||
use crate::function::{Args, OptionalArg, PyFuncArgs};
|
||||
use crate::obj::objbool;
|
||||
use crate::obj::objint::{self, PyInt, PyIntRef};
|
||||
use crate::obj::objiter::{call_next, get_iter, new_stop_iteration};
|
||||
use crate::obj::objiter::{call_next, get_all, get_iter, new_stop_iteration};
|
||||
use crate::obj::objtuple::PyTuple;
|
||||
use crate::obj::objtype::{self, PyClassRef};
|
||||
use crate::pyobject::{
|
||||
@@ -736,6 +737,123 @@ impl PyItertoolsTee {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug)]
|
||||
struct PyIterToolsProduct {
|
||||
pools: Vec<Vec<PyObjectRef>>,
|
||||
idxs: RefCell<Vec<usize>>,
|
||||
cur: Cell<usize>,
|
||||
stop: Cell<bool>,
|
||||
}
|
||||
|
||||
impl PyValue for PyIterToolsProduct {
|
||||
fn class(vm: &VirtualMachine) -> PyClassRef {
|
||||
vm.class("itertools", "product")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromArgs)]
|
||||
struct ProductArgs {
|
||||
#[pyarg(keyword_only, optional = true)]
|
||||
repeat: OptionalArg<usize>,
|
||||
}
|
||||
|
||||
#[pyimpl]
|
||||
impl PyIterToolsProduct {
|
||||
#[pyslot(new)]
|
||||
fn tp_new(
|
||||
cls: PyClassRef,
|
||||
iterables: Args<PyObjectRef>,
|
||||
args: ProductArgs,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<PyRef<Self>> {
|
||||
let repeat = match args.repeat.into_option() {
|
||||
Some(i) => i,
|
||||
None => 1,
|
||||
};
|
||||
|
||||
let mut pools = Vec::new();
|
||||
for arg in iterables.into_iter() {
|
||||
let it = get_iter(vm, &arg)?;
|
||||
let pool = get_all(vm, &it)?;
|
||||
|
||||
pools.push(pool);
|
||||
}
|
||||
let pools = iter::repeat(pools)
|
||||
.take(repeat)
|
||||
.flatten()
|
||||
.collect::<Vec<Vec<PyObjectRef>>>();
|
||||
|
||||
let l = pools.len();
|
||||
|
||||
PyIterToolsProduct {
|
||||
pools,
|
||||
idxs: RefCell::new(vec![0; l]),
|
||||
cur: Cell::new(l - 1),
|
||||
stop: Cell::new(false),
|
||||
}
|
||||
.into_ref_with_type(vm, cls)
|
||||
}
|
||||
|
||||
#[pymethod(name = "__next__")]
|
||||
fn next(&self, vm: &VirtualMachine) -> PyResult {
|
||||
// stop signal
|
||||
if self.stop.get() {
|
||||
return Err(new_stop_iteration(vm));
|
||||
}
|
||||
|
||||
let pools = &self.pools;
|
||||
|
||||
for p in pools {
|
||||
if p.is_empty() {
|
||||
return Err(new_stop_iteration(vm));
|
||||
}
|
||||
}
|
||||
|
||||
let res = PyTuple::from(
|
||||
pools
|
||||
.iter()
|
||||
.zip(self.idxs.borrow().iter())
|
||||
.map(|(pool, idx)| pool[*idx].clone())
|
||||
.collect::<Vec<PyObjectRef>>(),
|
||||
);
|
||||
|
||||
self.update_idxs();
|
||||
|
||||
if self.is_end() {
|
||||
self.stop.set(true);
|
||||
}
|
||||
|
||||
Ok(res.into_ref(vm).into_object())
|
||||
}
|
||||
|
||||
fn is_end(&self) -> bool {
|
||||
(self.idxs.borrow()[self.cur.get()] == &self.pools[self.cur.get()].len() - 1
|
||||
&& self.cur.get() == 0)
|
||||
}
|
||||
|
||||
fn update_idxs(&self) {
|
||||
let lst_idx = &self.pools[self.cur.get()].len() - 1;
|
||||
|
||||
if self.idxs.borrow()[self.cur.get()] == lst_idx {
|
||||
if self.is_end() {
|
||||
return;
|
||||
}
|
||||
self.idxs.borrow_mut()[self.cur.get()] = 0;
|
||||
self.cur.set(self.cur.get() - 1);
|
||||
self.update_idxs();
|
||||
} else {
|
||||
self.idxs.borrow_mut()[self.cur.get()] += 1;
|
||||
self.cur.set(self.idxs.borrow().len() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(name = "__iter__")]
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
let ctx = &vm.ctx;
|
||||
|
||||
@@ -767,6 +885,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
|
||||
let tee = ctx.new_class("tee", ctx.object());
|
||||
PyItertoolsTee::extend_class(ctx, &tee);
|
||||
let product = ctx.new_class("product", ctx.object());
|
||||
PyIterToolsProduct::extend_class(ctx, &product);
|
||||
|
||||
py_module!(vm, "itertools", {
|
||||
"chain" => chain,
|
||||
@@ -780,5 +900,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
"filterfalse" => filterfalse,
|
||||
"accumulate" => accumulate,
|
||||
"tee" => tee,
|
||||
"product" => product,
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user