diff --git a/tests/snippets/builtin_filter.py b/tests/snippets/builtin_filter.py new file mode 100644 index 000000000..d0b5ccd5c --- /dev/null +++ b/tests/snippets/builtin_filter.py @@ -0,0 +1,32 @@ +assert list(filter(lambda x: ((x % 2) == 0), [0, 1, 2])) == [0, 2] + +# None implies identity +assert list(filter(None, [0, 1, 2])) == [1, 2] + +assert type(filter(None, [])) == filter + + +# test infinite iterator +class Counter(object): + counter = 0 + + def __next__(self): + self.counter += 1 + return self.counter + + def __iter__(self): + return self + + +it = filter(lambda x: ((x % 2) == 0), Counter()) +assert next(it) == 2 +assert next(it) == 4 + + +def predicate(x): + if x == 0: + raise StopIteration() + return True + + +assert list(filter(predicate, [1, 2, 0, 4, 5])) == [1, 2] diff --git a/tests/snippets/builtin_map.py b/tests/snippets/builtin_map.py new file mode 100644 index 000000000..0de8d2c59 --- /dev/null +++ b/tests/snippets/builtin_map.py @@ -0,0 +1,34 @@ +a = list(map(str, [1, 2, 3])) +assert a == ['1', '2', '3'] + + +b = list(map(lambda x, y: x + y, [1, 2, 4], [3, 5])) +assert b == [4, 7] + +assert type(map(lambda x: x, [])) == map + + +# test infinite iterator +class Counter(object): + counter = 0 + + def __next__(self): + self.counter += 1 + return self.counter + + def __iter__(self): + return self + + +it = map(lambda x: x+1, Counter()) +assert next(it) == 2 +assert next(it) == 3 + + +def mapping(x): + if x == 0: + raise StopIteration() + return x + + +assert list(map(mapping, [1, 2, 0, 4, 5])) == [1, 2] diff --git a/tests/snippets/builtins.py b/tests/snippets/builtins.py index bbf116abc..539b49ef7 100644 --- a/tests/snippets/builtins.py +++ b/tests/snippets/builtins.py @@ -1,8 +1,4 @@ - -a = list(map(str, [1, 2, 3])) -assert a == ['1', '2', '3'] - -x = sum(map(int, a)) +x = sum(map(int, ['1', '2', '3'])) assert x == 6 assert callable(type) @@ -15,8 +11,6 @@ assert type(frozenset) is type assert list(zip(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)] -assert list(filter(lambda x: ((x % 2) == 0), [0, 1, 2])) == [0, 2] - assert 3 == eval('1+2') code = compile('5+3', 'x.py', 'eval') diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 83893ee83..afd7efc43 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -301,29 +301,6 @@ fn builtin_exec(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { vm.run_code_obj(code_obj, scope) } -fn builtin_filter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(function, None), (iterable, None)]); - - // TODO: process one element at a time from iterators. - let iterable = vm.extract_elements(iterable)?; - - let mut new_items = vec![]; - for element in iterable { - // apply function: - let args = PyFuncArgs { - args: vec![element.clone()], - kwargs: vec![], - }; - let result = vm.invoke(function.clone(), args)?; - let result = objbool::boolval(vm, result)?; - if result { - new_items.push(element); - } - } - - Ok(vm.ctx.new_list(new_items)) -} - fn builtin_format(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -428,33 +405,6 @@ fn builtin_locals(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.get_locals()) } -fn builtin_map(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(function, None), (iter_target, None)]); - let iterator = objiter::get_iter(vm, iter_target)?; - let mut elements = vec![]; - loop { - match vm.call_method(&iterator, "__next__", vec![]) { - Ok(v) => { - // Now apply function: - let mapped_value = vm.invoke( - function.clone(), - PyFuncArgs { - args: vec![v], - kwargs: vec![], - }, - )?; - elements.push(mapped_value); - } - Err(_) => break, - } - } - - trace!("Mapped elements: {:?}", elements); - - // TODO: when iterators are implemented, we can improve this function. - Ok(vm.ctx.new_list(elements)) -} - fn builtin_max(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let candidates = if args.args.len() > 1 { args.args.clone() @@ -749,7 +699,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "exec", ctx.new_rustfunc(builtin_exec)); ctx.set_attr(&py_mod, "float", ctx.float_type()); ctx.set_attr(&py_mod, "frozenset", ctx.frozenset_type()); - ctx.set_attr(&py_mod, "filter", ctx.new_rustfunc(builtin_filter)); + ctx.set_attr(&py_mod, "filter", ctx.filter_type()); ctx.set_attr(&py_mod, "format", ctx.new_rustfunc(builtin_format)); ctx.set_attr(&py_mod, "getattr", ctx.new_rustfunc(builtin_getattr)); ctx.set_attr(&py_mod, "hasattr", ctx.new_rustfunc(builtin_hasattr)); @@ -763,7 +713,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "len", ctx.new_rustfunc(builtin_len)); ctx.set_attr(&py_mod, "list", ctx.list_type()); ctx.set_attr(&py_mod, "locals", ctx.new_rustfunc(builtin_locals)); - ctx.set_attr(&py_mod, "map", ctx.new_rustfunc(builtin_map)); + ctx.set_attr(&py_mod, "map", ctx.map_type()); ctx.set_attr(&py_mod, "max", ctx.new_rustfunc(builtin_max)); ctx.set_attr(&py_mod, "memoryview", ctx.memoryview_type()); ctx.set_attr(&py_mod, "min", ctx.new_rustfunc(builtin_min)); @@ -819,6 +769,11 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "ValueError", ctx.exceptions.value_error.clone()); ctx.set_attr(&py_mod, "IndexError", ctx.exceptions.index_error.clone()); ctx.set_attr(&py_mod, "ImportError", ctx.exceptions.import_error.clone()); + ctx.set_attr( + &py_mod, + "StopIteration", + ctx.exceptions.stop_iteration.clone(), + ); py_mod } diff --git a/vm/src/obj/mod.rs b/vm/src/obj/mod.rs index 03017a30b..a60ce79ff 100644 --- a/vm/src/obj/mod.rs +++ b/vm/src/obj/mod.rs @@ -6,6 +6,7 @@ pub mod objbytes; pub mod objcode; pub mod objcomplex; pub mod objdict; +pub mod objfilter; pub mod objfloat; pub mod objframe; pub mod objfunction; @@ -13,6 +14,7 @@ pub mod objgenerator; pub mod objint; pub mod objiter; pub mod objlist; +pub mod objmap; pub mod objmemory; pub mod objobject; pub mod objproperty; diff --git a/vm/src/obj/objfilter.rs b/vm/src/obj/objfilter.rs new file mode 100644 index 000000000..009d2ad2e --- /dev/null +++ b/vm/src/obj/objfilter.rs @@ -0,0 +1,83 @@ +use super::super::pyobject::{ + IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, + TypeProtocol, +}; +use super::super::vm::VirtualMachine; +use super::objbool; +use super::objiter; +use super::objtype; // Required for arg_check! to use isinstance + +pub fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(cls, None), (function, None), (iterable, None)] + ); + let iterator = objiter::get_iter(vm, iterable)?; + Ok(PyObject::new( + PyObjectPayload::FilterIterator { + predicate: function.clone(), + iterator, + }, + cls.clone(), + )) +} + +fn filter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(filter, Some(vm.ctx.filter_type()))]); + // Return self: + Ok(filter.clone()) +} + +fn filter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(filter, Some(vm.ctx.filter_type())), (needle, None)] + ); + objiter::contains(vm, filter, needle) +} + +fn filter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(filter, Some(vm.ctx.filter_type()))]); + + if let PyObjectPayload::FilterIterator { + ref mut predicate, + ref mut iterator, + } = filter.borrow_mut().payload + { + loop { + let next_obj = objiter::call_next(vm, iterator)?; + let predicate_value = if predicate.is(&vm.get_none()) { + next_obj.clone() + } else { + // the predicate itself can raise StopIteration which does stop the filter + // iteration + vm.invoke( + predicate.clone(), + PyFuncArgs { + args: vec![next_obj.clone()], + kwargs: vec![], + }, + )? + }; + if objbool::boolval(vm, predicate_value)? { + return Ok(next_obj); + } + } + } else { + panic!("filter doesn't have correct payload"); + } +} + +pub fn init(context: &PyContext) { + let filter_type = &context.filter_type; + context.set_attr( + &filter_type, + "__contains__", + context.new_rustfunc(filter_contains), + ); + context.set_attr(&filter_type, "__iter__", context.new_rustfunc(filter_iter)); + context.set_attr(&filter_type, "__new__", context.new_rustfunc(filter_new)); + context.set_attr(&filter_type, "__next__", context.new_rustfunc(filter_next)); +} diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 53ed76804..19382bc28 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -23,6 +23,10 @@ pub fn get_iter(vm: &mut VirtualMachine, iter_target: &PyObjectRef) -> PyResult // return Err(type_error); } +pub fn call_next(vm: &mut VirtualMachine, iter_obj: &PyObjectRef) -> PyResult { + vm.call_method(iter_obj, "__next__", vec![]) +} + /* * Helper function to retrieve the next object (or none) from an iterator. */ @@ -30,7 +34,7 @@ pub fn get_next_object( vm: &mut VirtualMachine, iter_obj: &PyObjectRef, ) -> Result, PyObjectRef> { - let next_obj: PyResult = vm.call_method(iter_obj, "__next__", vec![]); + let next_obj: PyResult = call_next(vm, iter_obj); match next_obj { Ok(value) => Ok(Some(value)), @@ -61,6 +65,21 @@ pub fn get_all( Ok(elements) } +pub fn contains(vm: &mut VirtualMachine, iter: &PyObjectRef, needle: &PyObjectRef) -> PyResult { + loop { + if let Some(element) = get_next_object(vm, iter)? { + let equal = vm.call_method(needle, "__eq__", vec![element.clone()])?; + if objbool::get_value(&equal) { + return Ok(vm.new_bool(true)); + } else { + continue; + } + } else { + return Ok(vm.new_bool(false)); + } + } +} + // Sequence iterator: fn iter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(iter_target, None)]); @@ -80,21 +99,7 @@ fn iter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { args, required = [(iter, Some(vm.ctx.iter_type())), (needle, None)] ); - loop { - match vm.call_method(&iter, "__next__", vec![]) { - Ok(element) => match vm.call_method(needle, "__eq__", vec![element.clone()]) { - Ok(value) => { - if objbool::get_value(&value) { - return Ok(vm.new_bool(true)); - } else { - continue; - } - } - Err(_) => return Err(vm.new_type_error("".to_string())), - }, - Err(_) => return Ok(vm.new_bool(false)), - } - } + contains(vm, iter, needle) } fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/obj/objmap.rs b/vm/src/obj/objmap.rs new file mode 100644 index 000000000..ed6130643 --- /dev/null +++ b/vm/src/obj/objmap.rs @@ -0,0 +1,81 @@ +use super::super::pyobject::{ + PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, +}; +use super::super::vm::VirtualMachine; +use super::objiter; +use super::objtype; // Required for arg_check! to use isinstance + +pub fn map_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + no_kwargs!(vm, args); + let cls = &args.args[0]; + if args.args.len() < 3 { + Err(vm.new_type_error("map() must have at least two arguments.".to_owned())) + } else { + let function = &args.args[1]; + let iterables = &args.args[2..]; + let iterators = iterables + .into_iter() + .map(|iterable| objiter::get_iter(vm, iterable)) + .collect::, _>>()?; + Ok(PyObject::new( + PyObjectPayload::MapIterator { + mapper: function.clone(), + iterators, + }, + cls.clone(), + )) + } +} + +fn map_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(map, Some(vm.ctx.map_type()))]); + // Return self: + Ok(map.clone()) +} + +fn map_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(map, Some(vm.ctx.map_type())), (needle, None)] + ); + objiter::contains(vm, map, needle) +} + +fn map_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(map, Some(vm.ctx.map_type()))]); + + if let PyObjectPayload::MapIterator { + ref mut mapper, + ref mut iterators, + } = map.borrow_mut().payload + { + let next_objs = iterators + .iter() + .map(|iterator| objiter::call_next(vm, iterator)) + .collect::, _>>()?; + + // the mapper itself can raise StopIteration which does stop the map iteration + vm.invoke( + mapper.clone(), + PyFuncArgs { + args: next_objs, + kwargs: vec![], + }, + ) + } else { + panic!("map doesn't have correct payload"); + } +} + +pub fn init(context: &PyContext) { + let map_type = &context.map_type; + context.set_attr( + &map_type, + "__contains__", + context.new_rustfunc(map_contains), + ); + context.set_attr(&map_type, "__iter__", context.new_rustfunc(map_iter)); + context.set_attr(&map_type, "__new__", context.new_rustfunc(map_new)); + context.set_attr(&map_type, "__next__", context.new_rustfunc(map_next)); +} diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index fec9578be..245314093 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -7,6 +7,7 @@ use super::obj::objbytes; use super::obj::objcode; use super::obj::objcomplex; use super::obj::objdict; +use super::obj::objfilter; use super::obj::objfloat; use super::obj::objframe; use super::obj::objfunction; @@ -14,6 +15,7 @@ use super::obj::objgenerator; use super::obj::objint; use super::obj::objiter; use super::obj::objlist; +use super::obj::objmap; use super::obj::objmemory; use super::obj::objobject; use super::obj::objproperty; @@ -106,6 +108,7 @@ pub struct PyContext { pub classmethod_type: PyObjectRef, pub code_type: PyObjectRef, pub dict_type: PyObjectRef, + pub filter_type: PyObjectRef, pub float_type: PyObjectRef, pub frame_type: PyObjectRef, pub frozenset_type: PyObjectRef, @@ -116,6 +119,7 @@ pub struct PyContext { pub true_value: PyObjectRef, pub false_value: PyObjectRef, pub list_type: PyObjectRef, + pub map_type: PyObjectRef, pub memoryview_type: PyObjectRef, pub none: PyObjectRef, pub tuple_type: PyObjectRef, @@ -200,6 +204,8 @@ impl PyContext { let bytearray_type = create_type("bytearray", &type_type, &object_type, &dict_type); let tuple_type = create_type("tuple", &type_type, &object_type, &dict_type); let iter_type = create_type("iter", &type_type, &object_type, &dict_type); + let filter_type = create_type("filter", &type_type, &object_type, &dict_type); + let map_type = create_type("map", &type_type, &object_type, &dict_type); let bool_type = create_type("bool", &type_type, &int_type, &dict_type); let memoryview_type = create_type("memoryview", &type_type, &object_type, &dict_type); let code_type = create_type("code", &type_type, &int_type, &dict_type); @@ -240,6 +246,8 @@ impl PyContext { false_value, tuple_type, iter_type, + filter_type, + map_type, dict_type, none: none, str_type: str_type, @@ -275,6 +283,8 @@ impl PyContext { objsuper::init(&context); objtuple::init(&context); objiter::init(&context); + objfilter::init(&context); + objmap::init(&context); objbool::init(&context); objcode::init(&context); objframe::init(&context); @@ -346,6 +356,14 @@ impl PyContext { self.iter_type.clone() } + pub fn filter_type(&self) -> PyObjectRef { + self.filter_type.clone() + } + + pub fn map_type(&self) -> PyObjectRef { + self.map_type.clone() + } + pub fn str_type(&self) -> PyObjectRef { self.str_type.clone() } @@ -866,6 +884,14 @@ pub enum PyObjectPayload { position: usize, iterated_obj: PyObjectRef, }, + FilterIterator { + predicate: PyObjectRef, + iterator: PyObjectRef, + }, + MapIterator { + mapper: PyObjectRef, + iterators: Vec, + }, Slice { start: Option, stop: Option, @@ -934,6 +960,8 @@ impl fmt::Debug for PyObjectPayload { PyObjectPayload::WeakRef { .. } => write!(f, "weakref"), PyObjectPayload::Range { .. } => write!(f, "range"), PyObjectPayload::Iterator { .. } => write!(f, "iterator"), + PyObjectPayload::FilterIterator { .. } => write!(f, "filter"), + PyObjectPayload::MapIterator { .. } => write!(f, "map"), PyObjectPayload::Slice { .. } => write!(f, "slice"), PyObjectPayload::Code { ref code } => write!(f, "code: {:?}", code), PyObjectPayload::Function { .. } => write!(f, "function"), @@ -1030,6 +1058,8 @@ impl PyObject { position, iterated_obj.borrow_mut().str() ), + PyObjectPayload::FilterIterator { .. } => format!(""), + PyObjectPayload::MapIterator { .. } => format!(""), } }