Merge pull request #1372 from j30ng/itertools-accumulate

Implement itertools.accumulate
This commit is contained in:
Noah
2019-09-14 12:09:04 -05:00
committed by GitHub
2 changed files with 90 additions and 1 deletions

View File

@@ -209,4 +209,29 @@ assert 6 == next(it)
assert 4 == next(it)
assert 1 == next(it)
with assertRaises(StopIteration):
next(it)
next(it)
# itertools.accumulate
it = itertools.accumulate([6, 3, 7, 1, 0, 9, 8, 8])
assert 6 == next(it)
assert 9 == next(it)
assert 16 == next(it)
assert 17 == next(it)
assert 17 == next(it)
assert 26 == next(it)
assert 34 == next(it)
assert 42 == next(it)
with assertRaises(StopIteration):
next(it)
it = itertools.accumulate([3, 2, 4, 1, 0, 5, 8], lambda a, v: a*v)
assert 3 == next(it)
assert 6 == next(it)
assert 24 == next(it)
assert 24 == next(it)
assert 0 == next(it)
assert 0 == next(it)
assert 0 == next(it)
with assertRaises(StopIteration):
next(it)

View File

@@ -537,6 +537,66 @@ impl PyItertoolsFilterFalse {
}
}
#[pyclass]
#[derive(Debug)]
struct PyItertoolsAccumulate {
iterable: PyObjectRef,
binop: PyObjectRef,
acc_value: RefCell<Option<PyObjectRef>>,
}
impl PyValue for PyItertoolsAccumulate {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("itertools", "accumulate")
}
}
#[pyimpl]
impl PyItertoolsAccumulate {
#[pymethod(name = "__new__")]
#[allow(clippy::new_ret_no_self)]
fn new(
cls: PyClassRef,
iterable: PyObjectRef,
binop: OptionalArg<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<PyRef<PyItertoolsAccumulate>> {
let iter = get_iter(vm, &iterable)?;
PyItertoolsAccumulate {
iterable: iter,
binop: binop.unwrap_or_else(|| vm.get_none()),
acc_value: RefCell::from(Option::None),
}
.into_ref_with_type(vm, cls)
}
#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
let iterable = &self.iterable;
let obj = call_next(vm, iterable)?;
let next_acc_value = match &*self.acc_value.borrow() {
Option::None => obj.clone(),
Option::Some(value) => {
if self.binop.is(&vm.get_none()) {
vm._add(value.clone(), obj.clone())?
} else {
vm.invoke(&self.binop, vec![value.clone(), obj.clone()])?
}
}
};
self.acc_value.replace(Option::from(next_acc_value.clone()));
Ok(next_acc_value)
}
#[pymethod(name = "__iter__")]
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
}
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let ctx = &vm.ctx;
@@ -561,6 +621,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let filterfalse = ctx.new_class("filterfalse", ctx.object());
PyItertoolsFilterFalse::extend_class(ctx, &filterfalse);
let accumulate = ctx.new_class("accumulate", ctx.object());
PyItertoolsAccumulate::extend_class(ctx, &accumulate);
py_module!(vm, "itertools", {
"chain" => chain,
"count" => count,
@@ -570,5 +633,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
"takewhile" => takewhile,
"islice" => islice,
"filterfalse" => filterfalse,
"accumulate" => accumulate,
})
}