From 994e86fa25b99e73e2b63f42cd85896520e61095 Mon Sep 17 00:00:00 2001 From: Space0726 Date: Sat, 23 Nov 2019 02:37:16 +0900 Subject: [PATCH] Implement itertools.zip_longest --- tests/snippets/stdlib_itertools.py | 43 +++++++++++++++ vm/src/stdlib/itertools.rs | 84 ++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 5a400db714..117ccb9380 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -344,3 +344,46 @@ with assert_raises(StopIteration): with assert_raises(ValueError): itertools.combinations([1, 2, 3, 4], -2) + +# itertools.zip_longest tests +zl = itertools.zip_longest +assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7])) \ + == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)] +assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) \ + == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7), (None, None, 99)] +assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7, 99], fillvalue='d')) \ + == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7), ('d', 'd', 99)] + +assert list(zl(['a', 'b', 'c'])) == [('a',), ('b',), ('c',)] +assert list(zl()) == [] + +assert list(zl(*zl(['a', 'b', 'c'], range(1, 4)))) \ + == [('a', 'b', 'c'), (1, 2, 3)] +assert list(zl(*zl(['a', 'b', 'c'], range(1, 5)))) \ + == [('a', 'b', 'c', None), (1, 2, 3, 4)] +assert list(zl(*zl(['a', 'b', 'c'], range(1, 5), fillvalue=100))) \ + == [('a', 'b', 'c', 100), (1, 2, 3, 4)] + + +# test infinite iterator +class Counter(object): + def __init__(self, counter=0): + self.counter = counter + + def __next__(self): + self.counter += 1 + return self.counter + + def __iter__(self): + return self + + +it = zl(Counter(), Counter(3)) +assert next(it) == (1, 4) +assert next(it) == (2, 5) + +it = zl([1,2], [3]) +assert next(it) == (1, 3) +assert next(it) == (2, None) +with assert_raises(StopIteration): + next(it) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 0d176029dc..19fd94c5c1 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -953,6 +953,86 @@ impl PyItertoolsCombinations { } } +#[pyclass] +#[derive(Debug)] +struct PyItertoolsZiplongest { + iterators: Vec, + fillvalue: PyObjectRef, + numactive: RefCell, +} + +impl PyValue for PyItertoolsZiplongest { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "zip_longest") + } +} + +#[derive(FromArgs)] +struct ZiplongestArgs { + #[pyarg(keyword_only, optional = true)] + fillvalue: OptionalArg, +} + +#[pyimpl] +impl PyItertoolsZiplongest { + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, + iterables: Args, + args: ZiplongestArgs, + vm: &VirtualMachine, + ) -> PyResult> { + let fillvalue = match args.fillvalue.into_option() { + Some(i) => i, + None => vm.get_none(), + }; + + let iterators = iterables + .into_iter() + .map(|iterable| get_iter(vm, &iterable)) + .collect::, _>>()?; + + let numactive = RefCell::new(iterators.len()); + + PyItertoolsZiplongest { + iterators, + fillvalue, + numactive, + }.into_ref_with_type(vm, cls) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm:&VirtualMachine) -> PyResult { + if self.iterators.is_empty() { + Err(new_stop_iteration(vm)) + } else { + let mut next_obj: PyObjectRef; + let mut result: Vec = Vec::new(); + let mut numactive = self.numactive.clone().into_inner(); + + for idx in 0..self.iterators.len(){ + next_obj = match call_next(vm, &self.iterators[idx]) { + Ok(obj) => obj, + Err(_) => { + numactive -= 1; + if numactive == 0 { + return Err(new_stop_iteration(vm)); + } + self.fillvalue.clone() + } + }; + result.push(next_obj); + } + Ok(vm.ctx.new_tuple(result)) + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -991,6 +1071,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let tee = ctx.new_class("tee", ctx.object()); PyItertoolsTee::extend_class(ctx, &tee); + let zip_longest = ctx.new_class("zip_longest", ctx.object()); + PyItertoolsZiplongest::extend_class(ctx, &zip_longest); + py_module!(vm, "itertools", { "accumulate" => accumulate, "chain" => chain, @@ -1005,5 +1088,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "takewhile" => takewhile, "tee" => tee, "product" => product, + "zip_longest" => zip_longest, }) }