Merge pull request #1645 from dralley/itertools

Add itertools.cycle() and itertools.chain.from_iterable()
This commit is contained in:
Noah
2019-12-25 13:05:45 -05:00
committed by GitHub
2 changed files with 145 additions and 2 deletions

View File

@@ -21,6 +21,36 @@ assert next(x) == 'b'
with assert_raises(TypeError):
next(x)
# empty
with assert_raises(TypeError):
chain.from_iterable()
with assert_raises(TypeError):
chain.from_iterable("abc", "def")
with assert_raises(TypeError):
# iterables are lazily evaluated -- can be constructed but will fail to execute
list(chain.from_iterable([1, 2, 3]))
with assert_raises(TypeError):
list(chain(1))
args = ["abc", "def"]
assert list(chain.from_iterable(args)) == ['a', 'b', 'c', 'd', 'e', 'f']
args = [[], "", b"", ()]
assert list(chain.from_iterable(args)) == []
args = ["ab", "cd", (), 'e']
assert list(chain.from_iterable(args)) == ['a', 'b', 'c', 'd', 'e']
x = chain.from_iterable(["ab", 1])
assert next(x) == 'a'
assert next(x) == 'b'
with assert_raises(TypeError):
next(x)
# itertools.count tests
# default arguments
@@ -76,6 +106,32 @@ assert next(c) == 5
# assert next(c) == 1.5
# itertools.cycle tests
r = itertools.cycle([1, 2, 3])
assert next(r) == 1
assert next(r) == 2
assert next(r) == 3
assert next(r) == 1
assert next(r) == 2
assert next(r) == 3
assert next(r) == 1
r = itertools.cycle([1])
assert next(r) == 1
assert next(r) == 1
assert next(r) == 1
r = itertools.cycle([])
with assert_raises(StopIteration):
next(r)
with assert_raises(TypeError):
itertools.cycle(None)
with assert_raises(TypeError):
itertools.cycle(10)
# itertools.repeat tests
# no times
@@ -91,7 +147,7 @@ assert next(r) == 1
with assert_raises(StopIteration):
next(r)
# timees = 0
# times = 0
r = itertools.repeat(1, 0)
with assert_raises(StopIteration):
next(r)

View File

@@ -11,7 +11,7 @@ use num_traits::ToPrimitive;
use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs};
use crate::obj::objbool;
use crate::obj::objint::{self, PyInt, PyIntRef};
use crate::obj::objiter::{call_next, get_all, get_iter, new_stop_iteration};
use crate::obj::objiter::{call_next, get_all, get_iter, get_next_object, new_stop_iteration};
use crate::obj::objtuple::PyTuple;
use crate::obj::objtype::{self, PyClassRef};
use crate::pyobject::{
@@ -73,6 +73,22 @@ impl PyItertoolsChain {
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
#[pyclassmethod(name = "from_iterable")]
fn from_iterable(
cls: PyClassRef,
iterable: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
let it = get_iter(vm, &iterable)?;
let iterables = get_all(vm, &it)?;
PyItertoolsChain {
iterables,
cur: RefCell::new((0, None)),
}
.into_ref_with_type(vm, cls)
}
}
#[pyclass(name = "compress")]
@@ -177,6 +193,73 @@ impl PyItertoolsCount {
}
}
#[pyclass]
#[derive(Debug)]
struct PyItertoolsCycle {
iter: RefCell<PyObjectRef>,
saved: RefCell<Vec<PyObjectRef>>,
index: Cell<usize>,
first_pass: Cell<bool>,
}
impl PyValue for PyItertoolsCycle {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("itertools", "cycle")
}
}
#[pyimpl]
impl PyItertoolsCycle {
#[pyslot(new)]
fn tp_new(
cls: PyClassRef,
iterable: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
let iter = get_iter(vm, &iterable)?;
PyItertoolsCycle {
iter: RefCell::new(iter.clone()),
saved: RefCell::new(Vec::new()),
index: Cell::new(0),
first_pass: Cell::new(false),
}
.into_ref_with_type(vm, cls)
}
#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
let item = if let Some(item) = get_next_object(vm, &self.iter.borrow())? {
if self.first_pass.get() {
return Ok(item);
}
self.saved.borrow_mut().push(item.clone());
item
} else {
if self.saved.borrow().len() == 0 {
return Err(new_stop_iteration(vm));
}
let last_index = self.index.get();
self.index.set(self.index.get() + 1);
if self.index.get() >= self.saved.borrow().len() {
self.index.set(0);
}
self.saved.borrow()[last_index].clone()
};
Ok(item)
}
#[pymethod(name = "__iter__")]
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}
}
#[pyclass]
#[derive(Debug)]
struct PyItertoolsRepeat {
@@ -1177,6 +1260,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let count = ctx.new_class("count", ctx.object());
PyItertoolsCount::extend_class(ctx, &count);
let cycle = ctx.new_class("cycle", ctx.object());
PyItertoolsCycle::extend_class(ctx, &cycle);
let dropwhile = ctx.new_class("dropwhile", ctx.object());
PyItertoolsDropwhile::extend_class(ctx, &dropwhile);
@@ -1211,6 +1297,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
"compress" => compress,
"combinations" => combinations,
"count" => count,
"cycle" => cycle,
"dropwhile" => dropwhile,
"islice" => islice,
"filterfalse" => filterfalse,