From 708fd9bcfb7877b6e283e645b181a806f52f8ac0 Mon Sep 17 00:00:00 2001 From: j30ng Date: Sat, 14 Sep 2019 00:26:43 +0900 Subject: [PATCH 1/3] Implement itertools.accumulate --- tests/snippets/stdlib_itertools.py | 27 ++++++++++++- vm/src/stdlib/itertools.rs | 64 ++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 67664ebc2..eff1813a5 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -209,4 +209,29 @@ assert 6 == next(it) assert 4 == next(it) assert 1 == next(it) with assertRaises(StopIteration): - next(it) \ No newline at end of file + next(it) + + +# itertools.accumulate +it = itertools.accumulate([6, 3, 7, 1, 0, 9, 8, 8]) +assert 6 == next(it) +assert 9 == next(it) +assert 16 == next(it) +assert 17 == next(it) +assert 17 == next(it) +assert 26 == next(it) +assert 34 == next(it) +assert 42 == next(it) +with assertRaises(StopIteration): + next(it) + +it = itertools.accumulate([3, 2, 4, 1, 0, 5, 8], lambda a, v: a*v) +assert 3 == next(it) +assert 6 == next(it) +assert 24 == next(it) +assert 24 == next(it) +assert 0 == next(it) +assert 0 == next(it) +assert 0 == next(it) +with assertRaises(StopIteration): + next(it) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 1ba74f6d2..9c4b3ab60 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -537,6 +537,66 @@ impl PyItertoolsFilterFalse { } } +#[pyclass] +#[derive(Debug)] +struct PyItertoolsAccumulate { + iterable: PyObjectRef, + binop: PyObjectRef, + acc_value: RefCell, +} + +impl PyValue for PyItertoolsAccumulate { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "accumulate") + } +} + +#[pyimpl] +impl PyItertoolsAccumulate { + #[pymethod(name = "__new__")] + #[allow(clippy::new_ret_no_self)] + fn new( + _cls: PyClassRef, + iterable: PyObjectRef, + binop: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let iter = get_iter(vm, &iterable)?; + + Ok(PyItertoolsAccumulate { + iterable: iter, + binop: binop.unwrap_or_else(|| vm.get_none()), + acc_value: RefCell::from(vm.get_none()), + } + .into_ref(vm) + .into_object()) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let acc_value = self.acc_value.borrow().clone(); + let iterable = &self.iterable; + + let obj = call_next(vm, iterable)?; + + let next_acc_value = if acc_value.is(&vm.get_none()) { + obj.clone() + } else if self.binop.is(&vm.get_none()) { + vm._add(acc_value, obj.clone())? + } else { + vm.invoke(&self.binop, vec![acc_value, obj.clone()])? + }; + + self.acc_value.replace(next_acc_value.clone()); + Ok(next_acc_value) + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -561,6 +621,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let filterfalse = ctx.new_class("filterfalse", ctx.object()); PyItertoolsFilterFalse::extend_class(ctx, &filterfalse); + let accumulate = ctx.new_class("accumulate", ctx.object()); + PyItertoolsAccumulate::extend_class(ctx, &accumulate); + py_module!(vm, "itertools", { "chain" => chain, "count" => count, @@ -570,5 +633,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "takewhile" => takewhile, "islice" => islice, "filterfalse" => filterfalse, + "accumulate" => accumulate, }) } From f920982755c0b66a2125114640932eb65103dd53 Mon Sep 17 00:00:00 2001 From: j30ng Date: Sat, 14 Sep 2019 11:52:56 +0900 Subject: [PATCH 2/3] Python 'None' Value -> Option --- vm/src/stdlib/itertools.rs | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 9c4b3ab60..38d6166dd 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -542,7 +542,7 @@ impl PyItertoolsFilterFalse { struct PyItertoolsAccumulate { iterable: PyObjectRef, binop: PyObjectRef, - acc_value: RefCell, + acc_value: RefCell>, } impl PyValue for PyItertoolsAccumulate { @@ -566,7 +566,7 @@ impl PyItertoolsAccumulate { Ok(PyItertoolsAccumulate { iterable: iter, binop: binop.unwrap_or_else(|| vm.get_none()), - acc_value: RefCell::from(vm.get_none()), + acc_value: RefCell::from(Option::None), } .into_ref(vm) .into_object()) @@ -574,20 +574,21 @@ impl PyItertoolsAccumulate { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let acc_value = self.acc_value.borrow().clone(); let iterable = &self.iterable; - let obj = call_next(vm, iterable)?; - let next_acc_value = if acc_value.is(&vm.get_none()) { - obj.clone() - } else if self.binop.is(&vm.get_none()) { - vm._add(acc_value, obj.clone())? - } else { - vm.invoke(&self.binop, vec![acc_value, obj.clone()])? + let next_acc_value = match &*self.acc_value.borrow() { + Option::None => obj.clone(), + Option::Some(value) => { + if self.binop.is(&vm.get_none()) { + vm._add(value.clone(), obj.clone())? + } else { + vm.invoke(&self.binop, vec![value.clone(), obj.clone()])? + } + } }; + self.acc_value.replace(Option::from(next_acc_value.clone())); - self.acc_value.replace(next_acc_value.clone()); Ok(next_acc_value) } From d4e5d7644188a415b9cff43d70023690a7652a48 Mon Sep 17 00:00:00 2001 From: j30ng Date: Sat, 14 Sep 2019 12:02:05 +0900 Subject: [PATCH 3/3] Use into_ref_with_type() --- vm/src/stdlib/itertools.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 38d6166dd..dabbb220c 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -556,20 +556,19 @@ impl PyItertoolsAccumulate { #[pymethod(name = "__new__")] #[allow(clippy::new_ret_no_self)] fn new( - _cls: PyClassRef, + cls: PyClassRef, iterable: PyObjectRef, binop: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult> { let iter = get_iter(vm, &iterable)?; - Ok(PyItertoolsAccumulate { + PyItertoolsAccumulate { iterable: iter, binop: binop.unwrap_or_else(|| vm.get_none()), acc_value: RefCell::from(Option::None), } - .into_ref(vm) - .into_object()) + .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")]