From 98d90c830b0c8b3e2545d0dc7895bc47c6f37172 Mon Sep 17 00:00:00 2001 From: Yonatan Goldschmidt Date: Sun, 12 May 2019 23:16:23 +0300 Subject: [PATCH] Add `itertools.takewhile` --- tests/snippets/stdlib_itertools.py | 44 ++++++++++++++++++++ vm/src/stdlib/itertools.rs | 66 +++++++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 61c031dd7..92fbe4a29 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -82,3 +82,47 @@ with assertRaises(StopIteration): r = itertools.repeat(1, -1) with assertRaises(StopIteration): next(r) + + +# itertools.takewhile tests + +from itertools import takewhile as tw + +t = tw(lambda n: n < 5, [1, 2, 5, 1, 3]) +assert next(t) == 1 +assert next(t) == 2 +with assertRaises(StopIteration): + next(t) + +# not iterable +with assertRaises(TypeError): + tw(lambda n: n < 1, 1) + +# not callable +t = tw(5, [1, 2]) +with assertRaises(TypeError): + next(t) + +# non-bool predicate +t = tw(lambda n: n, [1, 2, 0]) +assert next(t) == 1 +assert next(t) == 2 +with assertRaises(StopIteration): + next(t) + +# bad predicate prototype +t = tw(lambda: True, [1]) +with assertRaises(TypeError): + next(t) + +# StopIteration before attempting to call (bad) predicate +t = tw(lambda: True, []) +with assertRaises(StopIteration): + next(t) + +# doesn't try again after the first predicate failure +t = tw(lambda n: n < 1, [1, 0]) +with assertRaises(StopIteration): + next(t) +with assertRaises(StopIteration): + next(t) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 3856fe92f..e97c53bfa 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -5,8 +5,9 @@ use std::ops::{AddAssign, SubAssign}; use num_bigint::BigInt; use crate::function::OptionalArg; +use crate::obj::objbool; use crate::obj::objint::{PyInt, PyIntRef}; -use crate::obj::objiter::new_stop_iteration; +use crate::obj::objiter::{call_next, get_iter, new_stop_iteration}; use crate::obj::objtype::PyClassRef; use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; @@ -121,6 +122,65 @@ impl PyItertoolsRepeat { } } +#[pyclass] +#[derive(Debug)] +struct PyItertoolsTakewhile { + predicate: PyObjectRef, + iterable: PyObjectRef, + stop_flag: RefCell, +} + +impl PyValue for PyItertoolsTakewhile { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "takewhile") + } +} + +#[pyimpl] +impl PyItertoolsTakewhile { + #[pymethod(name = "__new__")] + fn new( + _cls: PyClassRef, + predicate: PyObjectRef, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult { + let iter = get_iter(vm, &iterable)?; + + Ok(PyItertoolsTakewhile { + predicate: predicate, + iterable: iter, + stop_flag: RefCell::new(false), + } + .into_ref(vm) + .into_object()) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + if *self.stop_flag.borrow() { + return Err(new_stop_iteration(vm)); + } + + // might be StopIteration or anything else, which is propaged upwwards + let obj = call_next(vm, &self.iterable)?; + + let verdict = vm.invoke(self.predicate.clone(), vec![obj.clone()])?; + let verdict = objbool::boolval(vm, verdict)?; + if verdict { + Ok(obj) + } else { + *self.stop_flag.borrow_mut() = true; + Err(new_stop_iteration(vm)) + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -130,8 +190,12 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let repeat = ctx.new_class("repeat", ctx.object()); PyItertoolsRepeat::extend_class(ctx, &repeat); + let takewhile = ctx.new_class("takewhile", ctx.object()); + PyItertoolsTakewhile::extend_class(ctx, &takewhile); + py_module!(vm, "itertools", { "count" => count, "repeat" => repeat, + "takewhile" => takewhile, }) }