Implement itertools.product

This implements `itertools.product` of standard library.

Related with #1361
This commit is contained in:
joshua1b
2019-10-11 01:41:36 +09:00
parent 40783d1bc5
commit 8deb936c22
2 changed files with 145 additions and 2 deletions

View File

@@ -299,3 +299,25 @@ assert list(t3) == [1,2,3]
t = itertools.tee([1,2,3])
assert list(t[0]) == [1,2,3]
assert list(t[0]) == []
# itertools.product
it = itertools.product([1, 2], [3, 4])
assert (1, 3) == next(it)
assert (1, 4) == next(it)
assert (2, 3) == next(it)
assert (2, 4) == next(it)
with assert_raises(StopIteration):
next(it)
it = itertools.product([1, 2], repeat=2)
assert (1, 1) == next(it)
assert (1, 2) == next(it)
assert (2, 1) == next(it)
assert (2, 2) == next(it)
with assert_raises(StopIteration):
next(it)
with assert_raises(TypeError):
itertools.product(None)
with assert_raises(TypeError):
itertools.product([1, 2], repeat=None)

View File

@@ -1,15 +1,16 @@
use std::cell::{Cell, RefCell};
use std::cmp::Ordering;
use std::iter;
use std::ops::{AddAssign, SubAssign};
use std::rc::Rc;
use num_bigint::BigInt;
use num_traits::ToPrimitive;
use crate::function::{OptionalArg, PyFuncArgs};
use crate::function::{Args, OptionalArg, PyFuncArgs};
use crate::obj::objbool;
use crate::obj::objint::{self, PyInt, PyIntRef};
use crate::obj::objiter::{call_next, get_iter, new_stop_iteration};
use crate::obj::objiter::{call_next, get_all, get_iter, new_stop_iteration};
use crate::obj::objtuple::PyTuple;
use crate::obj::objtype::{self, PyClassRef};
use crate::pyobject::{
@@ -736,6 +737,123 @@ impl PyItertoolsTee {
}
}
#[pyclass]
#[derive(Debug)]
struct PyIterToolsProduct {
pools: Vec<Vec<PyObjectRef>>,
idxs: RefCell<Vec<usize>>,
cur: Cell<usize>,
stop: Cell<bool>,
}
impl PyValue for PyIterToolsProduct {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("itertools", "product")
}
}
#[derive(FromArgs)]
struct ProductArgs {
#[pyarg(keyword_only, optional = true)]
repeat: OptionalArg<usize>,
}
#[pyimpl]
impl PyIterToolsProduct {
#[pyslot(new)]
fn tp_new(
cls: PyClassRef,
iterables: Args<PyObjectRef>,
args: ProductArgs,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
let repeat = match args.repeat.into_option() {
Some(i) => i,
None => 1,
};
let mut pools = Vec::new();
for arg in iterables.into_iter() {
let it = get_iter(vm, &arg)?;
let pool = get_all(vm, &it)?;
pools.push(pool);
}
let pools = iter::repeat(pools)
.take(repeat)
.flatten()
.collect::<Vec<Vec<PyObjectRef>>>();
let l = pools.len();
PyIterToolsProduct {
pools,
idxs: RefCell::new(vec![0; l]),
cur: Cell::new(l - 1),
stop: Cell::new(false),
}
.into_ref_with_type(vm, cls)
}
#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
// stop signal
if self.stop.get() {
return Err(new_stop_iteration(vm));
}
let pools = &self.pools;
for p in pools {
if p.is_empty() {
return Err(new_stop_iteration(vm));
}
}
let res = PyTuple::from(
pools
.iter()
.zip(self.idxs.borrow().iter())
.map(|(pool, idx)| pool[*idx].clone())
.collect::<Vec<PyObjectRef>>(),
);
self.update_idxs();
if self.is_end() {
self.stop.set(true);
}
Ok(res.into_ref(vm).into_object())
}
fn is_end(&self) -> bool {
(self.idxs.borrow()[self.cur.get()] == &self.pools[self.cur.get()].len() - 1
&& self.cur.get() == 0)
}
fn update_idxs(&self) {
let lst_idx = &self.pools[self.cur.get()].len() - 1;
if self.idxs.borrow()[self.cur.get()] == lst_idx {
if self.is_end() {
return;
}
self.idxs.borrow_mut()[self.cur.get()] = 0;
self.cur.set(self.cur.get() - 1);
self.update_idxs();
} else {
self.idxs.borrow_mut()[self.cur.get()] += 1;
self.cur.set(self.idxs.borrow().len() - 1);
}
}
#[pymethod(name = "__iter__")]
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
}
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let ctx = &vm.ctx;
@@ -767,6 +885,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let tee = ctx.new_class("tee", ctx.object());
PyItertoolsTee::extend_class(ctx, &tee);
let product = ctx.new_class("product", ctx.object());
PyIterToolsProduct::extend_class(ctx, &product);
py_module!(vm, "itertools", {
"chain" => chain,
@@ -780,5 +900,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
"filterfalse" => filterfalse,
"accumulate" => accumulate,
"tee" => tee,
"product" => product,
})
}