Merge pull request #348 from skinny121/iter_lazy

Made filter and map lazy
This commit is contained in:
Windel Bouwman
2019-02-06 07:43:07 +01:00
committed by GitHub
9 changed files with 291 additions and 75 deletions

View File

@@ -0,0 +1,32 @@
assert list(filter(lambda x: ((x % 2) == 0), [0, 1, 2])) == [0, 2]
# None implies identity
assert list(filter(None, [0, 1, 2])) == [1, 2]
assert type(filter(None, [])) == filter
# test infinite iterator
class Counter(object):
counter = 0
def __next__(self):
self.counter += 1
return self.counter
def __iter__(self):
return self
it = filter(lambda x: ((x % 2) == 0), Counter())
assert next(it) == 2
assert next(it) == 4
def predicate(x):
if x == 0:
raise StopIteration()
return True
assert list(filter(predicate, [1, 2, 0, 4, 5])) == [1, 2]

View File

@@ -0,0 +1,34 @@
a = list(map(str, [1, 2, 3]))
assert a == ['1', '2', '3']
b = list(map(lambda x, y: x + y, [1, 2, 4], [3, 5]))
assert b == [4, 7]
assert type(map(lambda x: x, [])) == map
# test infinite iterator
class Counter(object):
counter = 0
def __next__(self):
self.counter += 1
return self.counter
def __iter__(self):
return self
it = map(lambda x: x+1, Counter())
assert next(it) == 2
assert next(it) == 3
def mapping(x):
if x == 0:
raise StopIteration()
return x
assert list(map(mapping, [1, 2, 0, 4, 5])) == [1, 2]

View File

@@ -1,8 +1,4 @@
a = list(map(str, [1, 2, 3]))
assert a == ['1', '2', '3']
x = sum(map(int, a))
x = sum(map(int, ['1', '2', '3']))
assert x == 6
assert callable(type)
@@ -15,8 +11,6 @@ assert type(frozenset) is type
assert list(zip(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)]
assert list(filter(lambda x: ((x % 2) == 0), [0, 1, 2])) == [0, 2]
assert 3 == eval('1+2')
code = compile('5+3', 'x.py', 'eval')

View File

@@ -301,29 +301,6 @@ fn builtin_exec(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
vm.run_code_obj(code_obj, scope)
}
fn builtin_filter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(function, None), (iterable, None)]);
// TODO: process one element at a time from iterators.
let iterable = vm.extract_elements(iterable)?;
let mut new_items = vec![];
for element in iterable {
// apply function:
let args = PyFuncArgs {
args: vec![element.clone()],
kwargs: vec![],
};
let result = vm.invoke(function.clone(), args)?;
let result = objbool::boolval(vm, result)?;
if result {
new_items.push(element);
}
}
Ok(vm.ctx.new_list(new_items))
}
fn builtin_format(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
@@ -428,33 +405,6 @@ fn builtin_locals(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
Ok(vm.get_locals())
}
fn builtin_map(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(function, None), (iter_target, None)]);
let iterator = objiter::get_iter(vm, iter_target)?;
let mut elements = vec![];
loop {
match vm.call_method(&iterator, "__next__", vec![]) {
Ok(v) => {
// Now apply function:
let mapped_value = vm.invoke(
function.clone(),
PyFuncArgs {
args: vec![v],
kwargs: vec![],
},
)?;
elements.push(mapped_value);
}
Err(_) => break,
}
}
trace!("Mapped elements: {:?}", elements);
// TODO: when iterators are implemented, we can improve this function.
Ok(vm.ctx.new_list(elements))
}
fn builtin_max(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
let candidates = if args.args.len() > 1 {
args.args.clone()
@@ -749,7 +699,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
ctx.set_attr(&py_mod, "exec", ctx.new_rustfunc(builtin_exec));
ctx.set_attr(&py_mod, "float", ctx.float_type());
ctx.set_attr(&py_mod, "frozenset", ctx.frozenset_type());
ctx.set_attr(&py_mod, "filter", ctx.new_rustfunc(builtin_filter));
ctx.set_attr(&py_mod, "filter", ctx.filter_type());
ctx.set_attr(&py_mod, "format", ctx.new_rustfunc(builtin_format));
ctx.set_attr(&py_mod, "getattr", ctx.new_rustfunc(builtin_getattr));
ctx.set_attr(&py_mod, "hasattr", ctx.new_rustfunc(builtin_hasattr));
@@ -763,7 +713,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
ctx.set_attr(&py_mod, "len", ctx.new_rustfunc(builtin_len));
ctx.set_attr(&py_mod, "list", ctx.list_type());
ctx.set_attr(&py_mod, "locals", ctx.new_rustfunc(builtin_locals));
ctx.set_attr(&py_mod, "map", ctx.new_rustfunc(builtin_map));
ctx.set_attr(&py_mod, "map", ctx.map_type());
ctx.set_attr(&py_mod, "max", ctx.new_rustfunc(builtin_max));
ctx.set_attr(&py_mod, "memoryview", ctx.memoryview_type());
ctx.set_attr(&py_mod, "min", ctx.new_rustfunc(builtin_min));
@@ -819,6 +769,11 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
ctx.set_attr(&py_mod, "ValueError", ctx.exceptions.value_error.clone());
ctx.set_attr(&py_mod, "IndexError", ctx.exceptions.index_error.clone());
ctx.set_attr(&py_mod, "ImportError", ctx.exceptions.import_error.clone());
ctx.set_attr(
&py_mod,
"StopIteration",
ctx.exceptions.stop_iteration.clone(),
);
py_mod
}

View File

@@ -6,6 +6,7 @@ pub mod objbytes;
pub mod objcode;
pub mod objcomplex;
pub mod objdict;
pub mod objfilter;
pub mod objfloat;
pub mod objframe;
pub mod objfunction;
@@ -13,6 +14,7 @@ pub mod objgenerator;
pub mod objint;
pub mod objiter;
pub mod objlist;
pub mod objmap;
pub mod objmemory;
pub mod objobject;
pub mod objproperty;

83
vm/src/obj/objfilter.rs Normal file
View File

@@ -0,0 +1,83 @@
use super::super::pyobject::{
IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult,
TypeProtocol,
};
use super::super::vm::VirtualMachine;
use super::objbool;
use super::objiter;
use super::objtype; // Required for arg_check! to use isinstance
pub fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(cls, None), (function, None), (iterable, None)]
);
let iterator = objiter::get_iter(vm, iterable)?;
Ok(PyObject::new(
PyObjectPayload::FilterIterator {
predicate: function.clone(),
iterator,
},
cls.clone(),
))
}
fn filter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(filter, Some(vm.ctx.filter_type()))]);
// Return self:
Ok(filter.clone())
}
fn filter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(filter, Some(vm.ctx.filter_type())), (needle, None)]
);
objiter::contains(vm, filter, needle)
}
fn filter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(filter, Some(vm.ctx.filter_type()))]);
if let PyObjectPayload::FilterIterator {
ref mut predicate,
ref mut iterator,
} = filter.borrow_mut().payload
{
loop {
let next_obj = objiter::call_next(vm, iterator)?;
let predicate_value = if predicate.is(&vm.get_none()) {
next_obj.clone()
} else {
// the predicate itself can raise StopIteration which does stop the filter
// iteration
vm.invoke(
predicate.clone(),
PyFuncArgs {
args: vec![next_obj.clone()],
kwargs: vec![],
},
)?
};
if objbool::boolval(vm, predicate_value)? {
return Ok(next_obj);
}
}
} else {
panic!("filter doesn't have correct payload");
}
}
pub fn init(context: &PyContext) {
let filter_type = &context.filter_type;
context.set_attr(
&filter_type,
"__contains__",
context.new_rustfunc(filter_contains),
);
context.set_attr(&filter_type, "__iter__", context.new_rustfunc(filter_iter));
context.set_attr(&filter_type, "__new__", context.new_rustfunc(filter_new));
context.set_attr(&filter_type, "__next__", context.new_rustfunc(filter_next));
}

View File

@@ -23,6 +23,10 @@ pub fn get_iter(vm: &mut VirtualMachine, iter_target: &PyObjectRef) -> PyResult
// return Err(type_error);
}
pub fn call_next(vm: &mut VirtualMachine, iter_obj: &PyObjectRef) -> PyResult {
vm.call_method(iter_obj, "__next__", vec![])
}
/*
* Helper function to retrieve the next object (or none) from an iterator.
*/
@@ -30,7 +34,7 @@ pub fn get_next_object(
vm: &mut VirtualMachine,
iter_obj: &PyObjectRef,
) -> Result<Option<PyObjectRef>, PyObjectRef> {
let next_obj: PyResult = vm.call_method(iter_obj, "__next__", vec![]);
let next_obj: PyResult = call_next(vm, iter_obj);
match next_obj {
Ok(value) => Ok(Some(value)),
@@ -61,6 +65,21 @@ pub fn get_all(
Ok(elements)
}
pub fn contains(vm: &mut VirtualMachine, iter: &PyObjectRef, needle: &PyObjectRef) -> PyResult {
loop {
if let Some(element) = get_next_object(vm, iter)? {
let equal = vm.call_method(needle, "__eq__", vec![element.clone()])?;
if objbool::get_value(&equal) {
return Ok(vm.new_bool(true));
} else {
continue;
}
} else {
return Ok(vm.new_bool(false));
}
}
}
// Sequence iterator:
fn iter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(iter_target, None)]);
@@ -80,21 +99,7 @@ fn iter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
args,
required = [(iter, Some(vm.ctx.iter_type())), (needle, None)]
);
loop {
match vm.call_method(&iter, "__next__", vec![]) {
Ok(element) => match vm.call_method(needle, "__eq__", vec![element.clone()]) {
Ok(value) => {
if objbool::get_value(&value) {
return Ok(vm.new_bool(true));
} else {
continue;
}
}
Err(_) => return Err(vm.new_type_error("".to_string())),
},
Err(_) => return Ok(vm.new_bool(false)),
}
}
contains(vm, iter, needle)
}
fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {

81
vm/src/obj/objmap.rs Normal file
View File

@@ -0,0 +1,81 @@
use super::super::pyobject::{
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
};
use super::super::vm::VirtualMachine;
use super::objiter;
use super::objtype; // Required for arg_check! to use isinstance
pub fn map_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
no_kwargs!(vm, args);
let cls = &args.args[0];
if args.args.len() < 3 {
Err(vm.new_type_error("map() must have at least two arguments.".to_owned()))
} else {
let function = &args.args[1];
let iterables = &args.args[2..];
let iterators = iterables
.into_iter()
.map(|iterable| objiter::get_iter(vm, iterable))
.collect::<Result<Vec<_>, _>>()?;
Ok(PyObject::new(
PyObjectPayload::MapIterator {
mapper: function.clone(),
iterators,
},
cls.clone(),
))
}
}
fn map_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(map, Some(vm.ctx.map_type()))]);
// Return self:
Ok(map.clone())
}
fn map_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(map, Some(vm.ctx.map_type())), (needle, None)]
);
objiter::contains(vm, map, needle)
}
fn map_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(map, Some(vm.ctx.map_type()))]);
if let PyObjectPayload::MapIterator {
ref mut mapper,
ref mut iterators,
} = map.borrow_mut().payload
{
let next_objs = iterators
.iter()
.map(|iterator| objiter::call_next(vm, iterator))
.collect::<Result<Vec<_>, _>>()?;
// the mapper itself can raise StopIteration which does stop the map iteration
vm.invoke(
mapper.clone(),
PyFuncArgs {
args: next_objs,
kwargs: vec![],
},
)
} else {
panic!("map doesn't have correct payload");
}
}
pub fn init(context: &PyContext) {
let map_type = &context.map_type;
context.set_attr(
&map_type,
"__contains__",
context.new_rustfunc(map_contains),
);
context.set_attr(&map_type, "__iter__", context.new_rustfunc(map_iter));
context.set_attr(&map_type, "__new__", context.new_rustfunc(map_new));
context.set_attr(&map_type, "__next__", context.new_rustfunc(map_next));
}

View File

@@ -7,6 +7,7 @@ use super::obj::objbytes;
use super::obj::objcode;
use super::obj::objcomplex;
use super::obj::objdict;
use super::obj::objfilter;
use super::obj::objfloat;
use super::obj::objframe;
use super::obj::objfunction;
@@ -14,6 +15,7 @@ use super::obj::objgenerator;
use super::obj::objint;
use super::obj::objiter;
use super::obj::objlist;
use super::obj::objmap;
use super::obj::objmemory;
use super::obj::objobject;
use super::obj::objproperty;
@@ -106,6 +108,7 @@ pub struct PyContext {
pub classmethod_type: PyObjectRef,
pub code_type: PyObjectRef,
pub dict_type: PyObjectRef,
pub filter_type: PyObjectRef,
pub float_type: PyObjectRef,
pub frame_type: PyObjectRef,
pub frozenset_type: PyObjectRef,
@@ -116,6 +119,7 @@ pub struct PyContext {
pub true_value: PyObjectRef,
pub false_value: PyObjectRef,
pub list_type: PyObjectRef,
pub map_type: PyObjectRef,
pub memoryview_type: PyObjectRef,
pub none: PyObjectRef,
pub tuple_type: PyObjectRef,
@@ -200,6 +204,8 @@ impl PyContext {
let bytearray_type = create_type("bytearray", &type_type, &object_type, &dict_type);
let tuple_type = create_type("tuple", &type_type, &object_type, &dict_type);
let iter_type = create_type("iter", &type_type, &object_type, &dict_type);
let filter_type = create_type("filter", &type_type, &object_type, &dict_type);
let map_type = create_type("map", &type_type, &object_type, &dict_type);
let bool_type = create_type("bool", &type_type, &int_type, &dict_type);
let memoryview_type = create_type("memoryview", &type_type, &object_type, &dict_type);
let code_type = create_type("code", &type_type, &int_type, &dict_type);
@@ -240,6 +246,8 @@ impl PyContext {
false_value,
tuple_type,
iter_type,
filter_type,
map_type,
dict_type,
none: none,
str_type: str_type,
@@ -275,6 +283,8 @@ impl PyContext {
objsuper::init(&context);
objtuple::init(&context);
objiter::init(&context);
objfilter::init(&context);
objmap::init(&context);
objbool::init(&context);
objcode::init(&context);
objframe::init(&context);
@@ -346,6 +356,14 @@ impl PyContext {
self.iter_type.clone()
}
pub fn filter_type(&self) -> PyObjectRef {
self.filter_type.clone()
}
pub fn map_type(&self) -> PyObjectRef {
self.map_type.clone()
}
pub fn str_type(&self) -> PyObjectRef {
self.str_type.clone()
}
@@ -866,6 +884,14 @@ pub enum PyObjectPayload {
position: usize,
iterated_obj: PyObjectRef,
},
FilterIterator {
predicate: PyObjectRef,
iterator: PyObjectRef,
},
MapIterator {
mapper: PyObjectRef,
iterators: Vec<PyObjectRef>,
},
Slice {
start: Option<i32>,
stop: Option<i32>,
@@ -934,6 +960,8 @@ impl fmt::Debug for PyObjectPayload {
PyObjectPayload::WeakRef { .. } => write!(f, "weakref"),
PyObjectPayload::Range { .. } => write!(f, "range"),
PyObjectPayload::Iterator { .. } => write!(f, "iterator"),
PyObjectPayload::FilterIterator { .. } => write!(f, "filter"),
PyObjectPayload::MapIterator { .. } => write!(f, "map"),
PyObjectPayload::Slice { .. } => write!(f, "slice"),
PyObjectPayload::Code { ref code } => write!(f, "code: {:?}", code),
PyObjectPayload::Function { .. } => write!(f, "function"),
@@ -1030,6 +1058,8 @@ impl PyObject {
position,
iterated_obj.borrow_mut().str()
),
PyObjectPayload::FilterIterator { .. } => format!("<filter>"),
PyObjectPayload::MapIterator { .. } => format!("<map>"),
}
}