From ce514d2aa5dcc30f8fa4e744adc8eb537b7776e3 Mon Sep 17 00:00:00 2001 From: Yonatan Goldschmidt Date: Fri, 10 May 2019 17:27:28 +0300 Subject: [PATCH] Add `itertools.count` --- tests/snippets/stdlib_itertools.py | 54 ++++++++++++++++++++++++ vm/src/stdlib/itertools.rs | 68 +++++++++++++++++++++++++++++- 2 files changed, 121 insertions(+), 1 deletion(-) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 994a79fbb..cfda2fb68 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -1 +1,55 @@ import itertools + +# count + +# default arguments +c = itertools.count() +assert next(c) == 0 +assert next(c) == 1 +assert next(c) == 2 + +# positional +c = itertools.count(2, 3) +assert next(c) == 2 +assert next(c) == 5 +assert next(c) == 8 + +# backwards +c = itertools.count(1, -10) +assert next(c) == 1 +assert next(c) == -9 +assert next(c) == -19 + +# step = 0 +c = itertools.count(5, 0) +assert next(c) == 5 +assert next(c) == 5 + +# itertools.count TODOs: kwargs and floats + +# step kwarg +# c = itertools.count(step=5) +# assert next(c) == 0 +# assert next(c) == 5 + +# start kwarg +# c = itertools.count(start=10) +# assert next(c) == 10 + +# float start +# c = itertools.count(0.5) +# assert next(c) == 0.5 +# assert next(c) == 1.5 +# assert next(c) == 2.5 + +# float step +# c = itertools.count(1, 0.5) +# assert next(c) == 1 +# assert next(c) == 1.5 +# assert next(c) == 2 + +# float start + step +# c = itertools.count(0.5, 0.5) +# assert next(c) == 0.5 +# assert next(c) == 1 +# assert next(c) == 1.5 diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index f37c9c5f8..a7a955fa5 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1,7 +1,73 @@ -use crate::pyobject::PyObjectRef; +use std::cell::RefCell; +use std::ops::AddAssign; + +use num_bigint::BigInt; + +use crate::function::OptionalArg; +use crate::obj::objint::{PyInt, PyIntRef}; +use crate::obj::objtype::PyClassRef; +use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; +#[pyclass] +#[derive(Debug)] +struct PyItertoolsCount { + cur: RefCell, + step: BigInt, +} + +impl PyValue for PyItertoolsCount { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "count") + } +} + +#[pyimpl] +impl PyItertoolsCount { + #[pymethod(name = "__new__")] + fn new( + _cls: PyClassRef, + start: OptionalArg, + step: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let start = match start.into_option() { + Some(int) => int.as_bigint().clone(), + None => BigInt::from(0), + }; + let step = match step.into_option() { + Some(int) => int.as_bigint().clone(), + None => BigInt::from(1), + }; + + Ok(PyItertoolsCount { + cur: RefCell::new(start), + step: step, + } + .into_ref(vm) + .into_object()) + } + + #[pymethod(name = "__next__")] + fn next(&self, _vm: &VirtualMachine) -> PyResult { + let result = self.cur.borrow().clone(); + AddAssign::add_assign(&mut self.cur.borrow_mut() as &mut BigInt, &self.step); + Ok(PyInt::new(result)) + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { + let ctx = &vm.ctx; + + let count = ctx.new_class("count", ctx.object()); + PyItertoolsCount::extend_class(ctx, &count); + py_module!(vm, "itertools", { + "count" => count, }) }