diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index e8285fac95..66342f8715 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -183,3 +183,21 @@ assert_matches_seq(it, [1, 2]) it = i([1, 2, 3], None, None, 3) assert_matches_seq(it, [1]) + +# itertools.filterfalse +it = itertools.filterfalse(lambda x: x%2, range(10)) +assert 0 == next(it) +assert 2 == next(it) +assert 4 == next(it) +assert 6 == next(it) +assert 8 == next(it) +with assertRaises(StopIteration): + next(it) + +l = [0, 1, None, False, True, [], {}] +it = itertools.filterfalse(None, l) +assert 0 == next(it) +assert None == next(it) +assert False == next(it) +assert [] == next(it) +assert {} == next(it) \ No newline at end of file diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 6039c178a0..e5ed280a33 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -418,6 +418,64 @@ impl PyItertoolsIslice { } } +#[pyclass] +#[derive(Debug)] +struct PyItertoolsFilterFalse { + predicate: PyObjectRef, + iterable: PyObjectRef, +} + +impl PyValue for PyItertoolsFilterFalse { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "filterfalse") + } +} + +#[pyimpl] +impl PyItertoolsFilterFalse { + #[pymethod(name = "__new__")] + #[allow(clippy::new_ret_no_self)] + fn new( + _cls: PyClassRef, + predicate: PyObjectRef, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult { + let iter = get_iter(vm, &iterable)?; + + Ok(PyItertoolsFilterFalse { + predicate, + iterable: iter, + } + .into_ref(vm) + .into_object()) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let predicate = &self.predicate; + let iterable = &self.iterable; + + loop { + let obj = call_next(vm, iterable)?; + let pred_value = if predicate.is(&vm.get_none()) { + obj.clone() + } else { + vm.invoke(predicate, vec![obj.clone()])? + }; + + if !objbool::boolval(vm, pred_value)? { + return Ok(obj); + } + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -436,6 +494,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let islice = PyItertoolsIslice::make_class(ctx); + let filterfalse = ctx.new_class("filterfalse", ctx.object()); + PyItertoolsFilterFalse::extend_class(ctx, &filterfalse); + py_module!(vm, "itertools", { "chain" => chain, "count" => count, @@ -443,5 +504,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "starmap" => starmap, "takewhile" => takewhile, "islice" => islice, + "filterfalse" => filterfalse, }) }