Implement itertools.zip_longest

This commit is contained in:
Space0726
2019-11-23 02:37:16 +09:00
parent 389db55a3f
commit 994e86fa25
2 changed files with 127 additions and 0 deletions

View File

@@ -344,3 +344,46 @@ with assert_raises(StopIteration):
with assert_raises(ValueError):
itertools.combinations([1, 2, 3, 4], -2)
# itertools.zip_longest tests
zl = itertools.zip_longest
assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7])) \
== [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)]
assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) \
== [('a', 0, 9), ('b', 1, 8), ('c', 2, 7), (None, None, 99)]
assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7, 99], fillvalue='d')) \
== [('a', 0, 9), ('b', 1, 8), ('c', 2, 7), ('d', 'd', 99)]
assert list(zl(['a', 'b', 'c'])) == [('a',), ('b',), ('c',)]
assert list(zl()) == []
assert list(zl(*zl(['a', 'b', 'c'], range(1, 4)))) \
== [('a', 'b', 'c'), (1, 2, 3)]
assert list(zl(*zl(['a', 'b', 'c'], range(1, 5)))) \
== [('a', 'b', 'c', None), (1, 2, 3, 4)]
assert list(zl(*zl(['a', 'b', 'c'], range(1, 5), fillvalue=100))) \
== [('a', 'b', 'c', 100), (1, 2, 3, 4)]
# test infinite iterator
class Counter(object):
def __init__(self, counter=0):
self.counter = counter
def __next__(self):
self.counter += 1
return self.counter
def __iter__(self):
return self
it = zl(Counter(), Counter(3))
assert next(it) == (1, 4)
assert next(it) == (2, 5)
it = zl([1,2], [3])
assert next(it) == (1, 3)
assert next(it) == (2, None)
with assert_raises(StopIteration):
next(it)

View File

@@ -953,6 +953,86 @@ impl PyItertoolsCombinations {
}
}
#[pyclass]
#[derive(Debug)]
struct PyItertoolsZiplongest {
iterators: Vec<PyObjectRef>,
fillvalue: PyObjectRef,
numactive: RefCell<usize>,
}
impl PyValue for PyItertoolsZiplongest {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("itertools", "zip_longest")
}
}
#[derive(FromArgs)]
struct ZiplongestArgs {
#[pyarg(keyword_only, optional = true)]
fillvalue: OptionalArg<PyObjectRef>,
}
#[pyimpl]
impl PyItertoolsZiplongest {
#[pyslot(new)]
fn tp_new(
cls: PyClassRef,
iterables: Args,
args: ZiplongestArgs,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
let fillvalue = match args.fillvalue.into_option() {
Some(i) => i,
None => vm.get_none(),
};
let iterators = iterables
.into_iter()
.map(|iterable| get_iter(vm, &iterable))
.collect::<Result<Vec<_>, _>>()?;
let numactive = RefCell::new(iterators.len());
PyItertoolsZiplongest {
iterators,
fillvalue,
numactive,
}.into_ref_with_type(vm, cls)
}
#[pymethod(name = "__next__")]
fn next(&self, vm:&VirtualMachine) -> PyResult {
if self.iterators.is_empty() {
Err(new_stop_iteration(vm))
} else {
let mut next_obj: PyObjectRef;
let mut result: Vec<PyObjectRef> = Vec::new();
let mut numactive = self.numactive.clone().into_inner();
for idx in 0..self.iterators.len(){
next_obj = match call_next(vm, &self.iterators[idx]) {
Ok(obj) => obj,
Err(_) => {
numactive -= 1;
if numactive == 0 {
return Err(new_stop_iteration(vm));
}
self.fillvalue.clone()
}
};
result.push(next_obj);
}
Ok(vm.ctx.new_tuple(result))
}
}
#[pymethod(name = "__iter__")]
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
}
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let ctx = &vm.ctx;
@@ -991,6 +1071,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let tee = ctx.new_class("tee", ctx.object());
PyItertoolsTee::extend_class(ctx, &tee);
let zip_longest = ctx.new_class("zip_longest", ctx.object());
PyItertoolsZiplongest::extend_class(ctx, &zip_longest);
py_module!(vm, "itertools", {
"accumulate" => accumulate,
"chain" => chain,
@@ -1005,5 +1088,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
"takewhile" => takewhile,
"tee" => tee,
"product" => product,
"zip_longest" => zip_longest,
})
}