diff --git a/parser/src/parser.rs b/parser/src/parser.rs index 853d7cd84..607c6e7a4 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -253,6 +253,40 @@ mod tests { ) } + #[test] + fn test_parse_tuples() { + let source = String::from("a, b = 4, 5\n"); + + assert_eq!( + parse_statement(&source), + Ok(ast::LocatedStatement { + location: ast::Location::new(1, 1), + node: ast::Statement::Assign { + targets: vec![ast::Expression::Tuple { + elements: vec![ + ast::Expression::Identifier { + name: "a".to_string() + }, + ast::Expression::Identifier { + name: "b".to_string() + } + ] + }], + value: ast::Expression::Tuple { + elements: vec![ + ast::Expression::Number { + value: ast::Number::Integer { value: 4 } + }, + ast::Expression::Number { + value: ast::Number::Integer { value: 5 } + } + ] + } + } + }) + ) + } + #[test] fn test_parse_class() { let source = String::from("class Foo(A, B):\n def __init__(self):\n pass\n def method_with_default(self, arg='default'):\n pass\n"); diff --git a/parser/src/python.lalrpop b/parser/src/python.lalrpop index 3135d7d15..11ee510bc 100644 --- a/parser/src/python.lalrpop +++ b/parser/src/python.lalrpop @@ -46,32 +46,38 @@ SmallStatement: ast::LocatedStatement = { }; ExpressionStatement: ast::LocatedStatement = { - => { - //match e2 { - // None => ast::Statement::Expression { expression: e }, - // Some(e3) => ast::Statement::Expression { expression: e }, - //} - if e2.len() > 0 { - // Dealing with assignment here - // TODO: for rhs in e2 { - let rhs = e2.into_iter().next().unwrap(); - // ast::Expression::Tuple { elements: e2.into_iter().next().unwrap() - let v = rhs.into_iter().next().unwrap(); - let lhs = ast::LocatedStatement { - location: loc.clone(), - node: ast::Statement::Assign { targets: e, value: v }, - }; - lhs - } else { - if e.len() > 1 { - panic!("Not good?"); - // ast::Statement::Expression { expression: e[0] } + => { + // Just an expression, no assignment: + if suffix.is_empty() { + if expr.len() > 1 { + ast::LocatedStatement { + location: loc.clone(), + node: ast::Statement::Expression { expression: ast::Expression::Tuple { elements: expr } } + } } else { ast::LocatedStatement { location: loc.clone(), - node: ast::Statement::Expression { expression: e.into_iter().next().unwrap() }, + node: ast::Statement::Expression { expression: expr[0].clone() }, } } + } else { + let mut targets = vec![if expr.len() > 1 { + ast::Expression::Tuple { elements: expr } + } else { + expr[0].clone() + }]; + let mut values : Vec = suffix.into_iter().map(|test_list| if test_list.len() > 1 { ast::Expression::Tuple { elements: test_list }} else { test_list[0].clone() }).collect(); + + while values.len() > 1 { + targets.push(values.remove(0)); + } + + let value = values[0].clone(); + + ast::LocatedStatement { + location: loc.clone(), + node: ast::Statement::Assign { targets, value }, + } } }, => { @@ -120,7 +126,7 @@ FlowStatement: ast::LocatedStatement = { "return" => { ast::LocatedStatement { location: loc, - node: ast::Statement::Return { value: t}, + node: ast::Statement::Return { value: t }, } }, "raise" => { diff --git a/tests/snippets/assignment.py b/tests/snippets/assignment.py new file mode 100644 index 000000000..457da9200 --- /dev/null +++ b/tests/snippets/assignment.py @@ -0,0 +1,28 @@ +x = 1 +assert x == 1 + +x = 1, 2, 3 +assert x == (1, 2, 3) + +x, y = 1, 2 +assert x == 1 +assert y == 2 + +x, y = (y, x) + +assert x == 2 +assert y == 1 + +((x, y), z) = ((1, 2), 3) + +assert (x, y, z) == (1, 2, 3) + +q = (1, 2, 3) +(x, y, z) = q +assert y == q[1] + +x = (a, b, c) = y = q + +assert (a, b, c) == q +assert x == q +assert y == q diff --git a/tests/snippets/ast_snippet.py b/tests/snippets/ast_snippet.py index 033d59696..43bf74756 100644 --- a/tests/snippets/ast_snippet.py +++ b/tests/snippets/ast_snippet.py @@ -5,13 +5,19 @@ print(ast) source = """ def foo(): print('bar') + pass """ n = ast.parse(source) print(n) print(n.body) print(n.body[0].name) assert n.body[0].name == 'foo' -print(n.body[0].body) -print(n.body[0].body[0]) -print(n.body[0].body[0].value.func.id) -assert n.body[0].body[0].value.func.id == 'print' +foo = n.body[0] +assert foo.lineno == 2 +print(foo.body) +assert len(foo.body) == 2 +print(foo.body[0]) +print(foo.body[0].value.func.id) +assert foo.body[0].value.func.id == 'print' +assert foo.body[0].lineno == 3 +assert foo.body[1].lineno == 4 diff --git a/tests/snippets/iterations.py b/tests/snippets/iterations.py new file mode 100644 index 000000000..98031a935 --- /dev/null +++ b/tests/snippets/iterations.py @@ -0,0 +1,11 @@ + + +ls = [1, 2, 3] + +i = iter(ls) +assert i.__next__() == 1 +assert i.__next__() == 2 +assert next(i) == 3 + +assert next(i, 'w00t') == 'w00t' + diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 4ecdc2b1b..297187d2e 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use std::io::{self, Write}; use super::compile; +use super::obj::objiter; use super::obj::objstr; use super::obj::objtype; use super::objbool; @@ -221,7 +222,10 @@ fn builtin_issubclass(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.context().new_bool(objtype::issubclass(cls1, cls2))) } -// builtin_iter +fn builtin_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(iter_target, None)]); + objiter::get_iter(vm, iter_target) +} fn builtin_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(obj, None)]); @@ -254,7 +258,30 @@ fn builtin_locals(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { // builtin_max // builtin_memoryview // builtin_min -// builtin_next + +fn builtin_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(iterator, None)], + optional = [(default_value, None)] + ); + + match vm.call_method(iterator.clone(), "__next__", vec![]) { + Ok(value) => Ok(value), + Err(value) => { + if objtype::isinstance(&value, vm.ctx.exceptions.stop_iteration.clone()) { + match default_value { + None => Err(value), + Some(value) => Ok(value.clone()), + } + } else { + Err(value) + } + } + } +} + // builtin_object // builtin_oct // builtin_open @@ -378,9 +405,11 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { String::from("issubclass"), ctx.new_rustfunc(builtin_issubclass), ); + dict.insert(String::from("iter"), ctx.new_rustfunc(builtin_iter)); dict.insert(String::from("len"), ctx.new_rustfunc(builtin_len)); dict.insert(String::from("list"), ctx.list_type()); dict.insert(String::from("locals"), ctx.new_rustfunc(builtin_locals)); + dict.insert(String::from("next"), ctx.new_rustfunc(builtin_next)); dict.insert(String::from("pow"), ctx.new_rustfunc(builtin_pow)); dict.insert(String::from("print"), ctx.new_rustfunc(builtin_print)); dict.insert(String::from("range"), ctx.new_rustfunc(builtin_range)); diff --git a/vm/src/bytecode.rs b/vm/src/bytecode.rs index f01f6ae4a..255d245ee 100644 --- a/vm/src/bytecode.rs +++ b/vm/src/bytecode.rs @@ -140,6 +140,9 @@ pub enum Instruction { PrintExpr, LoadBuildClass, StoreLocals, + UnpackSequence { + size: usize, + }, } #[derive(Debug, Clone, PartialEq)] diff --git a/vm/src/compile.rs b/vm/src/compile.rs index c9494a0e7..d902e9380 100644 --- a/vm/src/compile.rs +++ b/vm/src/compile.rs @@ -521,7 +521,10 @@ impl Compiler { ast::Statement::Assign { targets, value } => { self.compile_expression(value); - for target in targets { + for (i, target) in targets.into_iter().enumerate() { + if i + 1 != targets.len() { + self.emit(Instruction::Duplicate); + } self.compile_store(target); } } @@ -561,6 +564,14 @@ impl Compiler { name: name.to_string(), }); } + ast::Expression::Tuple { elements } => { + self.emit(Instruction::UnpackSequence { + size: elements.len(), + }); + for element in elements { + self.compile_store(element); + } + } _ => { panic!("WTF: {:?}", target); } diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index f962515d4..6de078f0e 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -89,6 +89,7 @@ pub struct ExceptionZoo { pub name_error: PyObjectRef, pub runtime_error: PyObjectRef, pub not_implemented_error: PyObjectRef, + pub stop_iteration: PyObjectRef, pub type_error: PyObjectRef, pub value_error: PyObjectRef, } @@ -138,6 +139,12 @@ impl ExceptionZoo { &runtime_error, &dict_type, ); + let stop_iteration = create_type( + &String::from("StopIteration"), + &type_type, + &exception_type, + &dict_type, + ); let type_error = create_type( &String::from("TypeError"), &type_type, @@ -159,6 +166,7 @@ impl ExceptionZoo { name_error: name_error, runtime_error: runtime_error, not_implemented_error: not_implemented_error, + stop_iteration: stop_iteration, type_error: type_error, value_error: value_error, } diff --git a/vm/src/obj/mod.rs b/vm/src/obj/mod.rs index 8dfa798ec..f49b4ac35 100644 --- a/vm/src/obj/mod.rs +++ b/vm/src/obj/mod.rs @@ -3,6 +3,7 @@ pub mod objdict; pub mod objfloat; pub mod objfunction; pub mod objint; +pub mod objiter; pub mod objlist; pub mod objobject; pub mod objsequence; diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs new file mode 100644 index 000000000..09afd8ce3 --- /dev/null +++ b/vm/src/obj/objiter.rs @@ -0,0 +1,92 @@ +/* + * Various types to support iteration. + */ + +use super::super::pyobject::{ + AttributeProtocol, PyContext, PyFuncArgs, PyObject, PyObjectKind, PyObjectRef, PyResult, + TypeProtocol, +}; +use super::super::vm::VirtualMachine; +use super::objstr; +use super::objtype; // Required for arg_check! to use isinstance + +/* + * This helper function is called at multiple places. First, it is called + * in the vm when a for loop is entered. Next, it is used when the builtin + * function 'iter' is called. + */ +pub fn get_iter(vm: &mut VirtualMachine, iter_target: &PyObjectRef) -> PyResult { + // Check what we are going to iterate over: + let iterated_obj = if objtype::isinstance(iter_target, vm.ctx.iter_type()) { + // If object is already an iterator, return that one. + return Ok(iter_target.clone()); + } else if objtype::isinstance(iter_target, vm.ctx.list_type()) { + iter_target.clone() + } else { + let type_str = objstr::get_value(&vm.to_str(iter_target.typ()).unwrap()); + let type_error = vm.new_type_error(format!("Cannot iterate over {}", type_str)); + return Err(type_error); + }; + + let iter_obj = PyObject::new( + PyObjectKind::Iterator { + position: 0, + iterated_obj: iterated_obj, + }, + vm.ctx.iter_type(), + ); + + // We are all good here: + Ok(iter_obj) +} + +// Sequence iterator: +fn iter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(iter_target, None)]); + + get_iter(vm, iter_target) +} + +fn iter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]); + // Return self: + Ok(iter.clone()) +} + +fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]); + + if let PyObjectKind::Iterator { + ref mut position, + iterated_obj: ref iterated_obj_ref, + } = iter.borrow_mut().kind + { + let iterated_obj = &*iterated_obj_ref.borrow_mut(); + match iterated_obj.kind { + PyObjectKind::List { ref elements } => { + if *position < elements.len() { + let obj_ref = elements[*position].clone(); + *position += 1; + Ok(obj_ref) + } else { + let stop_iteration_type = vm.ctx.exceptions.stop_iteration.clone(); + let stop_iteration = + vm.new_exception(stop_iteration_type, "End of iterator".to_string()); + Err(stop_iteration) + } + } + _ => { + panic!("NOT IMPL"); + } + } + } else { + panic!("NOT IMPL"); + } +} + +pub fn init(context: &PyContext) { + let ref iter_type = context.iter_type; + iter_type.set_attr("__new__", context.new_rustfunc(iter_new)); + iter_type.set_attr("__iter__", context.new_rustfunc(iter_iter)); + iter_type.set_attr("__next__", context.new_rustfunc(iter_next)); +} diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index cff8cca82..5a66135ce 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -5,6 +5,7 @@ use super::obj::objdict; use super::obj::objfloat; use super::obj::objfunction; use super::obj::objint; +use super::obj::objiter; use super::obj::objlist; use super::obj::objobject; use super::obj::objstr; @@ -60,6 +61,7 @@ pub struct PyContext { pub false_value: PyObjectRef, pub list_type: PyObjectRef, pub tuple_type: PyObjectRef, + pub iter_type: PyObjectRef, pub str_type: PyObjectRef, pub function_type: PyObjectRef, pub module_type: PyObjectRef, @@ -123,6 +125,7 @@ impl PyContext { let float_type = create_type("float", &type_type, &object_type, &dict_type); let bytes_type = create_type("bytes", &type_type, &object_type, &dict_type); let tuple_type = create_type("tuple", &type_type, &object_type, &dict_type); + let iter_type = create_type("iter", &type_type, &object_type, &dict_type); let bool_type = create_type("bool", &type_type, &int_type, &dict_type); let exceptions = exceptions::ExceptionZoo::new(&type_type, &object_type, &dict_type); @@ -142,6 +145,7 @@ impl PyContext { true_value: true_value, false_value: false_value, tuple_type: tuple_type, + iter_type: iter_type, dict_type: dict_type, none: none, str_type: str_type, @@ -164,6 +168,7 @@ impl PyContext { objbytes::init(&context); objstr::init(&context); objtuple::init(&context); + objiter::init(&context); objbool::init(&context); exceptions::init(&context); context @@ -190,6 +195,9 @@ impl PyContext { pub fn tuple_type(&self) -> PyObjectRef { self.tuple_type.clone() } + pub fn iter_type(&self) -> PyObjectRef { + self.iter_type.clone() + } pub fn dict_type(&self) -> PyObjectRef { self.dict_type.clone() } @@ -750,35 +758,6 @@ impl PyObject { } } - // Implement iterator protocol: - pub fn nxt(&mut self) -> Option { - match self.kind { - PyObjectKind::Iterator { - ref mut position, - iterated_obj: ref iterated_obj_ref, - } => { - let iterated_obj = &*iterated_obj_ref.borrow_mut(); - match iterated_obj.kind { - PyObjectKind::List { ref elements } => { - if *position < elements.len() { - let obj_ref = elements[*position].clone(); - *position += 1; - Some(obj_ref) - } else { - None - } - } - _ => { - panic!("NOT IMPL"); - } - } - } - _ => { - panic!("NOT IMPL"); - } - } - } - // Move this object into a reference object, transferring ownership. pub fn into_ref(self) -> PyObjectRef { Rc::new(RefCell::new(self)) diff --git a/vm/src/vm.rs b/vm/src/vm.rs index b165678a1..5c98144fc 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -7,17 +7,17 @@ extern crate rustpython_parser; use self::rustpython_parser::ast; -use std::cell::RefMut; use std::collections::hash_map::HashMap; -use std::ops::Deref; use super::builtins; use super::bytecode; use super::frame::{copy_code, Block, Frame}; use super::import::import; +use super::obj::objiter; use super::obj::objlist; use super::obj::objobject; use super::obj::objstr; +use super::obj::objtuple; use super::obj::objtype; use super::objbool; use super::pyobject::{ @@ -846,46 +846,50 @@ impl VirtualMachine { } bytecode::Instruction::GetIter => { let iterated_obj = self.pop_value(); - let iter_obj = PyObject::new( - PyObjectKind::Iterator { - position: 0, - iterated_obj: iterated_obj, - }, - self.ctx.type_type(), - ); - self.push_value(iter_obj); - None + match objiter::get_iter(self, &iterated_obj) { + Ok(iter_obj) => { + self.push_value(iter_obj); + None + } + Err(err) => Some(Err(err)), + } } bytecode::Instruction::ForIter => { // The top of stack contains the iterator, lets push it forward: - let next_obj: Option = { + let next_obj: PyResult = { let top_of_stack = self.last_value(); - let mut ref_mut: RefMut = top_of_stack.deref().borrow_mut(); - // We require a mutable pyobject here to update the iterator: - let mut iterator = ref_mut; // &mut PyObject = ref_mut.; - // let () = iterator; - iterator.nxt() + self.call_method(top_of_stack, "__next__", vec![]) }; // Check the next object: match next_obj { - Some(value) => { + Ok(value) => { self.push_value(value); + None } - None => { - // Pop iterator from stack: - self.pop_value(); + Err(next_error) => { + // Check if we have stopiteration, or something else: + if objtype::isinstance( + &next_error, + self.ctx.exceptions.stop_iteration.clone(), + ) { + // Pop iterator from stack: + self.pop_value(); - // End of for loop - let end_label = if let Block::Loop { start: _, end } = self.last_block() { - *end + // End of for loop + let end_label = if let Block::Loop { start: _, end } = self.last_block() + { + *end + } else { + panic!("Wrong block type") + }; + self.jump(&end_label); + None } else { - panic!("Wrong block type") - }; - self.jump(&end_label); + Some(Err(next_error)) + } } - }; - None + } } bytecode::Instruction::MakeFunction { flags } => { let _qualified_name = self.pop_value(); @@ -1062,6 +1066,19 @@ impl VirtualMachine { } None } + bytecode::Instruction::UnpackSequence { size } => { + let value = self.pop_value(); + + let elements = objtuple::get_elements(&value); + if elements.len() != *size { + panic!("Wrong number of values to unpack"); + } + + for element in elements.into_iter().rev() { + self.push_value(element); + } + None + } } }