mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Merge pull request #1372 from j30ng/itertools-accumulate
Implement itertools.accumulate
This commit is contained in:
@@ -209,4 +209,29 @@ assert 6 == next(it)
|
||||
assert 4 == next(it)
|
||||
assert 1 == next(it)
|
||||
with assertRaises(StopIteration):
|
||||
next(it)
|
||||
next(it)
|
||||
|
||||
|
||||
# itertools.accumulate
|
||||
it = itertools.accumulate([6, 3, 7, 1, 0, 9, 8, 8])
|
||||
assert 6 == next(it)
|
||||
assert 9 == next(it)
|
||||
assert 16 == next(it)
|
||||
assert 17 == next(it)
|
||||
assert 17 == next(it)
|
||||
assert 26 == next(it)
|
||||
assert 34 == next(it)
|
||||
assert 42 == next(it)
|
||||
with assertRaises(StopIteration):
|
||||
next(it)
|
||||
|
||||
it = itertools.accumulate([3, 2, 4, 1, 0, 5, 8], lambda a, v: a*v)
|
||||
assert 3 == next(it)
|
||||
assert 6 == next(it)
|
||||
assert 24 == next(it)
|
||||
assert 24 == next(it)
|
||||
assert 0 == next(it)
|
||||
assert 0 == next(it)
|
||||
assert 0 == next(it)
|
||||
with assertRaises(StopIteration):
|
||||
next(it)
|
||||
|
||||
@@ -537,6 +537,66 @@ impl PyItertoolsFilterFalse {
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug)]
|
||||
struct PyItertoolsAccumulate {
|
||||
iterable: PyObjectRef,
|
||||
binop: PyObjectRef,
|
||||
acc_value: RefCell<Option<PyObjectRef>>,
|
||||
}
|
||||
|
||||
impl PyValue for PyItertoolsAccumulate {
|
||||
fn class(vm: &VirtualMachine) -> PyClassRef {
|
||||
vm.class("itertools", "accumulate")
|
||||
}
|
||||
}
|
||||
|
||||
#[pyimpl]
|
||||
impl PyItertoolsAccumulate {
|
||||
#[pymethod(name = "__new__")]
|
||||
#[allow(clippy::new_ret_no_self)]
|
||||
fn new(
|
||||
cls: PyClassRef,
|
||||
iterable: PyObjectRef,
|
||||
binop: OptionalArg<PyObjectRef>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<PyRef<PyItertoolsAccumulate>> {
|
||||
let iter = get_iter(vm, &iterable)?;
|
||||
|
||||
PyItertoolsAccumulate {
|
||||
iterable: iter,
|
||||
binop: binop.unwrap_or_else(|| vm.get_none()),
|
||||
acc_value: RefCell::from(Option::None),
|
||||
}
|
||||
.into_ref_with_type(vm, cls)
|
||||
}
|
||||
|
||||
#[pymethod(name = "__next__")]
|
||||
fn next(&self, vm: &VirtualMachine) -> PyResult {
|
||||
let iterable = &self.iterable;
|
||||
let obj = call_next(vm, iterable)?;
|
||||
|
||||
let next_acc_value = match &*self.acc_value.borrow() {
|
||||
Option::None => obj.clone(),
|
||||
Option::Some(value) => {
|
||||
if self.binop.is(&vm.get_none()) {
|
||||
vm._add(value.clone(), obj.clone())?
|
||||
} else {
|
||||
vm.invoke(&self.binop, vec![value.clone(), obj.clone()])?
|
||||
}
|
||||
}
|
||||
};
|
||||
self.acc_value.replace(Option::from(next_acc_value.clone()));
|
||||
|
||||
Ok(next_acc_value)
|
||||
}
|
||||
|
||||
#[pymethod(name = "__iter__")]
|
||||
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
|
||||
zelf
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
let ctx = &vm.ctx;
|
||||
|
||||
@@ -561,6 +621,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
let filterfalse = ctx.new_class("filterfalse", ctx.object());
|
||||
PyItertoolsFilterFalse::extend_class(ctx, &filterfalse);
|
||||
|
||||
let accumulate = ctx.new_class("accumulate", ctx.object());
|
||||
PyItertoolsAccumulate::extend_class(ctx, &accumulate);
|
||||
|
||||
py_module!(vm, "itertools", {
|
||||
"chain" => chain,
|
||||
"count" => count,
|
||||
@@ -570,5 +633,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
"takewhile" => takewhile,
|
||||
"islice" => islice,
|
||||
"filterfalse" => filterfalse,
|
||||
"accumulate" => accumulate,
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user