Merge pull request #384 from skinny121/iter_lazy_2

Add enumerate and zip types
This commit is contained in:
Windel Bouwman
2019-02-08 20:05:20 +01:00
committed by GitHub
11 changed files with 243 additions and 135 deletions

View File

@@ -0,0 +1,23 @@
# enumerate() pairs each item with a running index, starting at 0 by default.
assert list(enumerate(['a', 'b', 'c'])) == [(0, 'a'), (1, 'b'), (2, 'c')]

# enumerate is a type, not a plain function.
assert type(enumerate([])) == enumerate

# The start value may be negative or larger than a machine word.
assert list(enumerate(['a', 'b', 'c'], -100)) == [(-100, 'a'), (-99, 'b'), (-98, 'c')]
assert list(enumerate(['a', 'b', 'c'], 2**200)) == [(2**200, 'a'), (2**200 + 1, 'b'), (2**200 + 2, 'c')]


# test infinite iterator
class Counter(object):
    """Endless iterator yielding 1, 2, 3, ...; only a lazy enumerate can wrap it."""

    counter = 0

    def __next__(self):
        self.counter += 1
        return self.counter

    def __iter__(self):
        return self


it = enumerate(Counter())
assert next(it) == (0, 1)
assert next(it) == (1, 2)

View File

@@ -0,0 +1,24 @@
# zip() stops as soon as the shortest input is exhausted.
assert list(zip(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)]

# A single input yields 1-tuples; no inputs yields nothing.
assert list(zip(['a', 'b', 'c'])) == [('a',), ('b',), ('c',)]
assert list(zip()) == []

# zip(*zip(...)) transposes back to the original columns.
assert list(zip(*zip(['a', 'b', 'c'], range(1, 4)))) == [('a', 'b', 'c'), (1, 2, 3)]


# test infinite iterator
class Counter(object):
    """Endless iterator yielding counter+1, counter+2, ...; only a lazy zip can wrap it."""

    def __init__(self, counter=0):
        self.counter = counter

    def __next__(self):
        self.counter += 1
        return self.counter

    def __iter__(self):
        return self


it = zip(Counter(), Counter(3))
assert next(it) == (1, 4)
assert next(it) == (2, 5)

View File

@@ -5,12 +5,8 @@ assert callable(type)
# TODO:
# assert callable(callable)
assert list(enumerate(['a', 'b', 'c'])) == [(0, 'a'), (1, 'b'), (2, 'c')]
assert type(frozenset) is type
assert list(zip(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)]
assert 3 == eval('1+2')
code = compile('5+3', 'x.py', 'eval')

View File

@@ -21,8 +21,7 @@ use super::pyobject::{
use super::stdlib::io::io_open;
use super::vm::VirtualMachine;
use num_bigint::ToBigInt;
use num_traits::{Signed, ToPrimitive, Zero};
use num_traits::{Signed, ToPrimitive};
fn get_locals(vm: &mut VirtualMachine) -> PyObjectRef {
let d = vm.new_dict();
@@ -180,29 +179,6 @@ fn builtin_divmod(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
}
}
/// Eager `enumerate(iterable, start=0)`: materialises every element of
/// `iterable` and returns a list of `(index, item)` tuples.
///
/// `start` must be an `int`; the type is checked by `arg_check!` so that a
/// non-integer argument is rejected with an error instead of being handed to
/// `objint::get_value` unchecked.
fn builtin_enumerate(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
    arg_check!(
        vm,
        args,
        required = [(iterable, None)],
        optional = [(start, Some(vm.ctx.int_type()))]
    );

    // All elements are pulled up front, so this never returns for an
    // endless iterable.
    let items = vm.extract_elements(iterable)?;
    let start = match start {
        Some(start) => objint::get_value(start),
        None => Zero::zero(),
    };

    let new_items = items
        .into_iter()
        .enumerate()
        .map(|(i, item)| {
            // Indices are big integers so that `start` may be arbitrarily large.
            let index = vm.ctx.new_int(i.to_bigint().unwrap() + &start);
            vm.ctx.new_tuple(vec![index, item])
        })
        .collect();

    Ok(vm.ctx.new_list(new_items))
}
/// Implements `eval`.
/// See also: https://docs.python.org/3/library/functions.html#eval
fn builtin_eval(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -641,32 +617,6 @@ fn builtin_sum(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
}
// builtin_vars
/// Eager `zip(*iterables)`: materialises every argument and returns a list of
/// tuples, truncated to the length of the shortest argument (empty when no
/// arguments are given).
fn builtin_zip(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
    no_kwargs!(vm, args);

    // TODO: process one element at a time from iterators.
    let columns = args
        .args
        .iter()
        .map(|iterable| vm.extract_elements(iterable))
        .collect::<Result<Vec<_>, _>>()?;

    let shortest: usize = columns
        .iter()
        .map(|column| column.len())
        .min()
        .unwrap_or(0);

    let rows = (0..shortest)
        .map(|row| {
            let cells = columns.iter().map(|column| column[row].clone()).collect();
            vm.ctx.new_tuple(cells)
        })
        .collect();

    Ok(vm.ctx.new_list(rows))
}
// builtin___import__
pub fn make_module(ctx: &PyContext) -> PyObjectRef {
@@ -692,7 +642,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
ctx.set_attr(&py_mod, "dict", ctx.dict_type());
ctx.set_attr(&py_mod, "divmod", ctx.new_rustfunc(builtin_divmod));
ctx.set_attr(&py_mod, "dir", ctx.new_rustfunc(builtin_dir));
ctx.set_attr(&py_mod, "enumerate", ctx.new_rustfunc(builtin_enumerate));
ctx.set_attr(&py_mod, "enumerate", ctx.enumerate_type());
ctx.set_attr(&py_mod, "eval", ctx.new_rustfunc(builtin_eval));
ctx.set_attr(&py_mod, "exec", ctx.new_rustfunc(builtin_exec));
ctx.set_attr(&py_mod, "float", ctx.float_type());
@@ -733,7 +683,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
ctx.set_attr(&py_mod, "super", ctx.super_type());
ctx.set_attr(&py_mod, "tuple", ctx.tuple_type());
ctx.set_attr(&py_mod, "type", ctx.type_type());
ctx.set_attr(&py_mod, "zip", ctx.new_rustfunc(builtin_zip));
ctx.set_attr(&py_mod, "zip", ctx.zip_type());
// Exceptions:
ctx.set_attr(

View File

@@ -6,6 +6,7 @@ pub mod objbytes;
pub mod objcode;
pub mod objcomplex;
pub mod objdict;
pub mod objenumerate;
pub mod objfilter;
pub mod objfloat;
pub mod objframe;
@@ -25,3 +26,4 @@ pub mod objstr;
pub mod objsuper;
pub mod objtuple;
pub mod objtype;
pub mod objzip;

View File

@@ -0,0 +1,69 @@
use super::super::pyobject::{
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
};
use super::super::vm::VirtualMachine;
use super::objint;
use super::objiter;
use super::objtype; // Required for arg_check! to use isinstance
use num_bigint::BigInt;
use num_traits::Zero;
use std::ops::AddAssign;
/// `enumerate.__new__(cls, iterable, start=0)`: builds an enumerate object
/// whose payload holds the starting counter and an iterator over `iterable`.
///
/// Fails with whatever error `objiter::get_iter` reports for `iterable`.
fn enumerate_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
    arg_check!(
        vm,
        args,
        required = [(cls, Some(vm.ctx.type_type())), (iterable, None)],
        optional = [(start, Some(vm.ctx.int_type()))]
    );

    // Counter defaults to zero when no start value was supplied.
    let counter = start.map_or_else(BigInt::zero, |first| objint::get_value(first));
    let iterator = objiter::get_iter(vm, iterable)?;

    let payload = PyObjectPayload::EnumerateIterator { counter, iterator };
    Ok(PyObject::new(payload, cls.clone()))
}
/// `enumerate.__next__`: pulls the next item from the wrapped iterator, pairs
/// it with the current counter value and then advances the counter by one.
///
/// Any error returned by the wrapped iterator's `__next__` is propagated
/// unchanged and leaves the counter untouched.
///
/// # Panics
///
/// Panics if the object does not carry an `EnumerateIterator` payload.
fn enumerate_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
    arg_check!(
        vm,
        args,
        required = [(enumerate, Some(vm.ctx.enumerate_type()))]
    );

    // Copy the handle to the wrapped iterator under a short-lived shared
    // borrow. `call_next` may run arbitrary `__next__` code, so no borrow of
    // the enumerate object is held while it executes; a mutable borrow held
    // across that call would conflict with any re-entrant access to this
    // object.
    let iterator = match enumerate.borrow().payload {
        PyObjectPayload::EnumerateIterator { ref iterator, .. } => iterator.clone(),
        _ => panic!("enumerate doesn't have correct payload"),
    };
    let next_obj = objiter::call_next(vm, &iterator)?;

    // Only now take the mutable borrow, just long enough to read and bump
    // the counter.
    if let PyObjectPayload::EnumerateIterator {
        ref mut counter, ..
    } = enumerate.borrow_mut().payload
    {
        let index = vm.ctx.new_int(counter.clone());
        AddAssign::add_assign(counter, 1);
        Ok(vm.ctx.new_tuple(vec![index, next_obj]))
    } else {
        panic!("enumerate doesn't have correct payload");
    }
}
/// Installs the methods of the `enumerate` type on the type object stored in
/// `context`.
pub fn init(context: &PyContext) {
    let ty = &context.enumerate_type;

    // Shared iterator methods first, then the enumerate-specific ones.
    objiter::iter_type_init(context, ty);

    let constructor = context.new_rustfunc(enumerate_new);
    context.set_attr(ty, "__new__", constructor);

    let advance = context.new_rustfunc(enumerate_next);
    context.set_attr(ty, "__next__", advance);
}

View File

@@ -7,7 +7,7 @@ use super::objbool;
use super::objiter;
use super::objtype; // Required for arg_check! to use isinstance
pub fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
@@ -23,21 +23,6 @@ pub fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
))
}
/// `filter.__iter__`: accepts exactly one argument, which `arg_check!`
/// requires to be a `filter` instance, and returns that same object.
fn filter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(filter, Some(vm.ctx.filter_type()))]);
// Return self:
Ok(filter.clone())
}
/// `filter.__contains__`: takes a `filter` instance and a `needle` of any
/// type and delegates the membership test to `objiter::contains`.
fn filter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(filter, Some(vm.ctx.filter_type())), (needle, None)]
);
objiter::contains(vm, filter, needle)
}
fn filter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(filter, Some(vm.ctx.filter_type()))]);
@@ -72,12 +57,7 @@ fn filter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
/// Installs the methods of the `filter` type on the type object stored in
/// `context`.
pub fn init(context: &PyContext) {
let filter_type = &context.filter_type;
// NOTE(review): `objiter::iter_type_init` below also sets `__contains__`
// and `__iter__`, so these two explicit registrations look redundant with
// it — confirm which form is intended.
context.set_attr(
&filter_type,
"__contains__",
context.new_rustfunc(filter_contains),
);
context.set_attr(&filter_type, "__iter__", context.new_rustfunc(filter_iter));
objiter::iter_type_init(context, filter_type);
// Type-specific constructor and iteration step.
context.set_attr(&filter_type, "__new__", context.new_rustfunc(filter_new));
context.set_attr(&filter_type, "__next__", context.new_rustfunc(filter_next));
}

View File

@@ -65,7 +65,17 @@ pub fn get_all(
Ok(elements)
}
pub fn contains(vm: &mut VirtualMachine, iter: &PyObjectRef, needle: &PyObjectRef) -> PyResult {
/// Builds a fresh `StopIteration` exception object carrying the message
/// "End of iterator".
pub fn new_stop_iteration(vm: &mut VirtualMachine) -> PyObjectRef {
    // The type handle is cloned into a local before `vm` is used mutably.
    let exc_type = vm.ctx.exceptions.stop_iteration.clone();
    let message = String::from("End of iterator");
    vm.new_exception(exc_type, message)
}
fn contains(vm: &mut VirtualMachine, args: PyFuncArgs, iter_type: PyObjectRef) -> PyResult {
arg_check!(
vm,
args,
required = [(iter, Some(iter_type)), (needle, None)]
);
loop {
if let Some(element) = get_next_object(vm, iter)? {
let equal = vm.call_method(needle, "__eq__", vec![element.clone()])?;
@@ -80,6 +90,34 @@ pub fn contains(vm: &mut VirtualMachine, iter: &PyObjectRef, needle: &PyObjectRe
}
}
/// Shared setup for iterator-like types: installs `__contains__` and
/// `__iter__` on `iter_type`.
///
/// Both installed functions type-check their first argument against
/// `iter_type`; each closure owns its own clone of the type handle.
pub fn iter_type_init(context: &PyContext, iter_type: &PyObjectRef) {
    // `__contains__`: forwards to the module-level `contains`.
    let type_for_contains = iter_type.clone();
    let contains_func = move |vm: &mut VirtualMachine, args: PyFuncArgs| {
        contains(vm, args, type_for_contains.clone())
    };
    context.set_attr(
        iter_type,
        "__contains__",
        context.new_rustfunc(contains_func),
    );

    // `__iter__`: an iterator is its own iterator.
    let type_for_iter = iter_type.clone();
    let iter_func = move |vm: &mut VirtualMachine, args: PyFuncArgs| {
        arg_check!(
            vm,
            args,
            required = [(iter, Some(type_for_iter.clone()))]
        );
        // Return self:
        Ok(iter.clone())
    };
    context.set_attr(iter_type, "__iter__", context.new_rustfunc(iter_func));
}
// Sequence iterator:
fn iter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(iter_target, None)]);
@@ -87,21 +125,6 @@ fn iter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
get_iter(vm, iter_target)
}
/// `iter.__iter__`: accepts exactly one argument, which `arg_check!` requires
/// to be an `iter` instance, and returns that same object.
fn iter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]);
// Return self:
Ok(iter.clone())
}
/// `iter.__contains__`: takes an `iter` instance and a `needle` of any type
/// and delegates the membership test to `contains`.
fn iter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(iter, Some(vm.ctx.iter_type())), (needle, None)]
);
contains(vm, iter, needle)
}
fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]);
@@ -118,10 +141,7 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
*position += 1;
Ok(obj_ref)
} else {
let stop_iteration_type = vm.ctx.exceptions.stop_iteration.clone();
let stop_iteration =
vm.new_exception(stop_iteration_type, "End of iterator".to_string());
Err(stop_iteration)
Err(new_stop_iteration(vm))
}
}
@@ -130,10 +150,7 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
*position += 1;
Ok(vm.ctx.new_int(int))
} else {
let stop_iteration_type = vm.ctx.exceptions.stop_iteration.clone();
let stop_iteration =
vm.new_exception(stop_iteration_type, "End of iterator".to_string());
Err(stop_iteration)
Err(new_stop_iteration(vm))
}
}
@@ -143,10 +160,7 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
*position += 1;
Ok(obj_ref)
} else {
let stop_iteration_type = vm.ctx.exceptions.stop_iteration.clone();
let stop_iteration =
vm.new_exception(stop_iteration_type, "End of iterator".to_string());
Err(stop_iteration)
Err(new_stop_iteration(vm))
}
}
@@ -161,12 +175,7 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
/// Installs the methods of the `iter` type on the type object stored in
/// `context`.
pub fn init(context: &PyContext) {
let iter_type = &context.iter_type;
// NOTE(review): `iter_type_init` below also sets `__contains__` and
// `__iter__`, so these two explicit registrations look redundant with it —
// confirm which form is intended.
context.set_attr(
&iter_type,
"__contains__",
context.new_rustfunc(iter_contains),
);
context.set_attr(&iter_type, "__iter__", context.new_rustfunc(iter_iter));
iter_type_init(context, iter_type);
// Type-specific constructor and iteration step.
context.set_attr(&iter_type, "__new__", context.new_rustfunc(iter_new));
context.set_attr(&iter_type, "__next__", context.new_rustfunc(iter_next));
}

View File

@@ -5,7 +5,7 @@ use super::super::vm::VirtualMachine;
use super::objiter;
use super::objtype; // Required for arg_check! to use isinstance
pub fn map_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
fn map_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
no_kwargs!(vm, args);
let cls = &args.args[0];
if args.args.len() < 3 {
@@ -27,21 +27,6 @@ pub fn map_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
}
}
/// `map.__iter__`: accepts exactly one argument, which `arg_check!` requires
/// to be a `map` instance, and returns that same object.
fn map_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(map, Some(vm.ctx.map_type()))]);
// Return self:
Ok(map.clone())
}
/// `map.__contains__`: takes a `map` instance and a `needle` of any type and
/// delegates the membership test to `objiter::contains`.
fn map_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(map, Some(vm.ctx.map_type())), (needle, None)]
);
objiter::contains(vm, map, needle)
}
fn map_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(map, Some(vm.ctx.map_type()))]);
@@ -70,12 +55,7 @@ fn map_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
/// Installs the methods of the `map` type on the type object stored in
/// `context`.
pub fn init(context: &PyContext) {
let map_type = &context.map_type;
// NOTE(review): `objiter::iter_type_init` below also sets `__contains__`
// and `__iter__`, so these two explicit registrations look redundant with
// it — confirm which form is intended.
context.set_attr(
&map_type,
"__contains__",
context.new_rustfunc(map_contains),
);
context.set_attr(&map_type, "__iter__", context.new_rustfunc(map_iter));
objiter::iter_type_init(context, map_type);
// Type-specific constructor and iteration step.
context.set_attr(&map_type, "__new__", context.new_rustfunc(map_new));
context.set_attr(&map_type, "__next__", context.new_rustfunc(map_next));
}

46
vm/src/obj/objzip.rs Normal file
View File

@@ -0,0 +1,46 @@
use super::super::pyobject::{
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
};
use super::super::vm::VirtualMachine;
use super::objiter;
use super::objtype; // Required for arg_check! to use isinstance
/// `zip.__new__(cls, *iterables)`: builds a zip object whose payload holds one
/// iterator per positional argument after `cls`.
///
/// Keyword arguments are rejected by `no_kwargs!`. A call that does not even
/// supply `cls` yields a `TypeError` instead of indexing past the end of the
/// argument vector. Any error from `objiter::get_iter` on one of the
/// iterables is propagated.
fn zip_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
    no_kwargs!(vm, args);

    // NOTE(review): `vm.new_type_error` is assumed to be the VM's TypeError
    // constructor taking a `String` — confirm against `vm.rs`.
    let (cls, iterables) = match args.args.split_first() {
        Some(parts) => parts,
        None => {
            return Err(vm.new_type_error(
                "zip.__new__(): not enough arguments".to_string(),
            ));
        }
    };

    // Stops at the first argument for which no iterator can be obtained.
    let iterators = iterables
        .iter()
        .map(|iterable| objiter::get_iter(vm, iterable))
        .collect::<Result<Vec<_>, _>>()?;

    Ok(PyObject::new(
        PyObjectPayload::ZipIterator { iterators },
        cls.clone(),
    ))
}
/// `zip.__next__`: advances every wrapped iterator once, in order, and
/// returns the results as a tuple.
///
/// A zip over no iterators is exhausted immediately. Otherwise the first
/// error returned by a wrapped iterator's `__next__` ends the step and is
/// propagated; iterators earlier in the list have already been advanced.
///
/// # Panics
///
/// Panics if the object does not carry a `ZipIterator` payload.
fn zip_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
    arg_check!(vm, args, required = [(zip, Some(vm.ctx.zip_type()))]);

    // The payload is only read here, so a shared borrow is enough. Unlike an
    // exclusive borrow, it does not conflict with another shared borrow taken
    // if a wrapped `__next__` re-enters this zip object.
    if let PyObjectPayload::ZipIterator { ref iterators } = zip.borrow().payload {
        if iterators.is_empty() {
            Err(objiter::new_stop_iteration(vm))
        } else {
            let next_objs = iterators
                .iter()
                .map(|iterator| objiter::call_next(vm, iterator))
                .collect::<Result<Vec<_>, _>>()?;

            Ok(vm.ctx.new_tuple(next_objs))
        }
    } else {
        panic!("zip doesn't have correct payload");
    }
}
/// Installs the methods of the `zip` type on the type object stored in
/// `context`.
pub fn init(context: &PyContext) {
    let ty = &context.zip_type;

    // Shared iterator methods first, then the zip-specific ones.
    objiter::iter_type_init(context, ty);

    let constructor = context.new_rustfunc(zip_new);
    context.set_attr(ty, "__new__", constructor);

    let advance = context.new_rustfunc(zip_next);
    context.set_attr(ty, "__next__", advance);
}

View File

@@ -7,6 +7,7 @@ use super::obj::objbytes;
use super::obj::objcode;
use super::obj::objcomplex;
use super::obj::objdict;
use super::obj::objenumerate;
use super::obj::objfilter;
use super::obj::objfloat;
use super::obj::objframe;
@@ -25,6 +26,7 @@ use super::obj::objstr;
use super::obj::objsuper;
use super::obj::objtuple;
use super::obj::objtype;
use super::obj::objzip;
use super::vm::VirtualMachine;
use num_bigint::BigInt;
use num_bigint::ToBigInt;
@@ -113,6 +115,7 @@ pub struct PyContext {
pub classmethod_type: PyObjectRef,
pub code_type: PyObjectRef,
pub dict_type: PyObjectRef,
pub enumerate_type: PyObjectRef,
pub filter_type: PyObjectRef,
pub float_type: PyObjectRef,
pub frame_type: PyObjectRef,
@@ -134,6 +137,7 @@ pub struct PyContext {
pub str_type: PyObjectRef,
pub range_type: PyObjectRef,
pub type_type: PyObjectRef,
pub zip_type: PyObjectRef,
pub function_type: PyObjectRef,
pub property_type: PyObjectRef,
pub module_type: PyObjectRef,
@@ -204,8 +208,10 @@ impl PyContext {
let bytearray_type = create_type("bytearray", &type_type, &object_type, &dict_type);
let tuple_type = create_type("tuple", &type_type, &object_type, &dict_type);
let iter_type = create_type("iter", &type_type, &object_type, &dict_type);
let enumerate_type = create_type("enumerate", &type_type, &object_type, &dict_type);
let filter_type = create_type("filter", &type_type, &object_type, &dict_type);
let map_type = create_type("map", &type_type, &object_type, &dict_type);
let zip_type = create_type("zip", &type_type, &object_type, &dict_type);
let bool_type = create_type("bool", &type_type, &int_type, &dict_type);
let memoryview_type = create_type("memoryview", &type_type, &object_type, &dict_type);
let code_type = create_type("code", &type_type, &int_type, &dict_type);
@@ -246,8 +252,10 @@ impl PyContext {
false_value,
tuple_type,
iter_type,
enumerate_type,
filter_type,
map_type,
zip_type,
dict_type,
none,
str_type,
@@ -283,8 +291,10 @@ impl PyContext {
objsuper::init(&context);
objtuple::init(&context);
objiter::init(&context);
objenumerate::init(&context);
objfilter::init(&context);
objmap::init(&context);
objzip::init(&context);
objbool::init(&context);
objcode::init(&context);
objframe::init(&context);
@@ -356,6 +366,10 @@ impl PyContext {
self.iter_type.clone()
}
pub fn enumerate_type(&self) -> PyObjectRef {
self.enumerate_type.clone()
}
pub fn filter_type(&self) -> PyObjectRef {
self.filter_type.clone()
}
@@ -364,6 +378,10 @@ impl PyContext {
self.map_type.clone()
}
pub fn zip_type(&self) -> PyObjectRef {
self.zip_type.clone()
}
pub fn str_type(&self) -> PyObjectRef {
self.str_type.clone()
}
@@ -886,6 +904,10 @@ pub enum PyObjectPayload {
position: usize,
iterated_obj: PyObjectRef,
},
EnumerateIterator {
counter: BigInt,
iterator: PyObjectRef,
},
FilterIterator {
predicate: PyObjectRef,
iterator: PyObjectRef,
@@ -894,6 +916,9 @@ pub enum PyObjectPayload {
mapper: PyObjectRef,
iterators: Vec<PyObjectRef>,
},
ZipIterator {
iterators: Vec<PyObjectRef>,
},
Slice {
start: Option<i32>,
stop: Option<i32>,
@@ -962,8 +987,10 @@ impl fmt::Debug for PyObjectPayload {
PyObjectPayload::WeakRef { .. } => write!(f, "weakref"),
PyObjectPayload::Range { .. } => write!(f, "range"),
PyObjectPayload::Iterator { .. } => write!(f, "iterator"),
PyObjectPayload::EnumerateIterator { .. } => write!(f, "enumerate"),
PyObjectPayload::FilterIterator { .. } => write!(f, "filter"),
PyObjectPayload::MapIterator { .. } => write!(f, "map"),
PyObjectPayload::ZipIterator { .. } => write!(f, "zip"),
PyObjectPayload::Slice { .. } => write!(f, "slice"),
PyObjectPayload::Code { ref code } => write!(f, "code: {:?}", code),
PyObjectPayload::Function { .. } => write!(f, "function"),
@@ -1059,8 +1086,10 @@ impl PyObject {
position,
iterated_obj.borrow_mut().str()
),
PyObjectPayload::EnumerateIterator { .. } => format!("<enumerate>"),
PyObjectPayload::FilterIterator { .. } => format!("<filter>"),
PyObjectPayload::MapIterator { .. } => format!("<map>"),
PyObjectPayload::ZipIterator { .. } => format!("<zip>"),
}
}