diff --git a/tests/snippets/generators.py b/tests/snippets/generators.py index f23e5918a..2c0cf32fb 100644 --- a/tests/snippets/generators.py +++ b/tests/snippets/generators.py @@ -88,3 +88,46 @@ l = list(g) # print(l) assert l == [99] assert r == ['a', 66, None] + +def binary(n): + if n <= 1: + return 1 + l = yield from binary(n - 1) + r = yield from binary(n - 1) + return l + 1 + r + +with assert_raises(StopIteration): + try: + next(binary(5)) + except StopIteration as stopiter: + # TODO: StopIteration.value + assert stopiter.args[0] == 31 + raise + +class SpamException(Exception): + pass + +l = [] + +def writer(): + while True: + try: + w = (yield) + except SpamException: + l.append('***') + else: + l.append(f'>> {w}') + +def wrapper(coro): + yield from coro + +w = writer() +wrap = wrapper(w) +wrap.send(None) # "prime" the coroutine +for i in [0, 1, 2, 'spam', 4]: + if i == 'spam': + wrap.throw(SpamException) + else: + wrap.send(i) + +assert l == ['>> 0', '>> 1', '>> 2', '***', '>> 4'] diff --git a/vm/src/frame.rs b/vm/src/frame.rs index b5e3eca16..3ccef016f 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -1,4 +1,4 @@ -use std::cell::RefCell; +use std::cell::{Cell, RefCell}; use std::fmt; use indexmap::IndexMap; @@ -90,7 +90,7 @@ pub struct Frame { /// Variables pub scope: Scope, /// index of last instruction ran - pub lasti: RefCell, + pub lasti: Cell, } impl PyValue for Frame { @@ -105,6 +105,38 @@ pub enum ExecutionResult { Yield(PyObjectRef), } +impl ExecutionResult { + /// Extract an ExecutionResult from a PyResult returned from e.g. gen.__next__() or gen.send() + pub fn from_result(vm: &VirtualMachine, res: PyResult) -> PyResult { + match res { + Ok(val) => Ok(ExecutionResult::Yield(val)), + Err(err) => { + if objtype::isinstance(&err, &vm.ctx.exceptions.stop_iteration) { + objiter::stop_iter_value(vm, &err).map(ExecutionResult::Return) + } else { + Err(err) + } + } + } + } + + /// Turn an ExecutionResult into a PyResult that would be returned from a generator or coroutine + pub fn into_result(self, vm: &VirtualMachine) -> PyResult { + match self { + ExecutionResult::Yield(value) => Ok(value), + ExecutionResult::Return(value) => { + let stop_iteration = vm.ctx.exceptions.stop_iteration.clone(); + let args = if vm.is_none(&value) { + vec![] + } else { + vec![value] + }; + Err(vm.new_exception_obj(stop_iteration, args).unwrap()) + } + } + } +} + /// A valid execution result, or an exception pub type FrameResult = PyResult>; @@ -128,7 +160,7 @@ impl Frame { // save the callargs as locals // globals: locals.clone(), scope, - lasti: RefCell::new(0), + lasti: Cell::new(0), } } @@ -185,17 +217,40 @@ impl Frame { } } - pub fn throw(&self, vm: &VirtualMachine, exception: PyObjectRef) -> PyResult { - match self.unwind_blocks(vm, UnwindReason::Raising { exception }) { - Ok(None) => self.run(vm), - Ok(Some(result)) => Ok(result), - Err(exception) => Err(exception), + pub(crate) fn gen_throw( + &self, + vm: &VirtualMachine, + exc_type: PyClassRef, + exc_val: PyObjectRef, + exc_tb: PyObjectRef, + ) -> PyResult { + if let bytecode::Instruction::YieldFrom = self.code.instructions[self.lasti.get()] { + let coro = self.last_value(); + vm.call_method( + &coro, + "throw", + vec![exc_type.into_object(), exc_val, exc_tb], + ) + .or_else(|err| { + self.pop_value(); + self.lasti.set(self.lasti.get() + 1); + let val = objiter::stop_iter_value(vm, &err)?; + self._send(coro, val, vm) + }) + } else { + let exception = vm.new_exception_obj(exc_type, vec![exc_val])?; + match self.unwind_blocks(vm, UnwindReason::Raising { exception }) { + Ok(None) => self.run(vm), + Ok(Some(result)) => Ok(result), + Err(exception) => Err(exception), + } + .and_then(|res| res.into_result(vm)) } } pub fn fetch_instruction(&self) -> &bytecode::Instruction { - let ins2 = &self.code.instructions[*self.lasti.borrow()]; - *self.lasti.borrow_mut() += 1; + let ins2 = &self.code.instructions[self.lasti.get()]; + self.lasti.set(self.lasti.get() + 1); ins2 } @@ -954,20 +1009,35 @@ impl Frame { Err(exception) } + fn _send(&self, coro: PyObjectRef, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if vm.is_none(&val) { + objiter::call_next(vm, &coro) + } else { + vm.call_method(&coro, "send", vec![val]) + } + } + fn execute_yield_from(&self, vm: &VirtualMachine) -> FrameResult { // Value send into iterator: - self.pop_value(); + let val = self.pop_value(); - let top_of_stack = self.last_value(); - let next_obj = objiter::get_next_object(vm, &top_of_stack)?; + let coro = self.last_value(); - match next_obj { - Some(value) => { + let result = self._send(coro, val, vm); + + let result = ExecutionResult::from_result(vm, result)?; + + match result { + ExecutionResult::Yield(value) => { // Set back program counter: - *self.lasti.borrow_mut() -= 1; + self.lasti.set(self.lasti.get() - 1); Ok(Some(ExecutionResult::Yield(value))) } - None => Ok(None), + ExecutionResult::Return(value) => { + self.pop_value(); + self.push_value(value); + Ok(None) + } } } @@ -1006,7 +1076,7 @@ impl Frame { let target_pc = self.code.label_map[&label]; #[cfg(feature = "vm-tracing-logging")] trace!("jump from {:?} to {:?}", self.lasti, target_pc); - self.lasti.replace(target_pc); + self.lasti.set(target_pc); } /// The top of stack contains the iterator, lets push it forward @@ -1238,7 +1308,7 @@ impl Frame { } pub fn get_lineno(&self) -> bytecode::Location { - self.code.locations[*self.lasti.borrow()].clone() + self.code.locations[self.lasti.get()].clone() } fn push_block(&self, typ: BlockType) { diff --git a/vm/src/obj/objframe.rs b/vm/src/obj/objframe.rs index 3b447f4f0..4988e634e 100644 --- a/vm/src/obj/objframe.rs +++ b/vm/src/obj/objframe.rs @@ -47,6 +47,6 @@ impl FrameRef { #[pyproperty] fn f_lasti(self, vm: &VirtualMachine) -> PyObjectRef { - vm.ctx.new_int(*self.lasti.borrow()) + vm.ctx.new_int(self.lasti.get()) } } diff --git a/vm/src/obj/objgenerator.rs b/vm/src/obj/objgenerator.rs index 8da98c4ef..11159e055 100644 --- a/vm/src/obj/objgenerator.rs +++ b/vm/src/obj/objgenerator.rs @@ -2,8 +2,9 @@ * The mythical generator. */ -use super::objtype::{isinstance, PyClassRef}; -use crate::frame::{ExecutionResult, FrameRef}; +use super::objtype::{issubclass, PyClassRef}; +use crate::frame::FrameRef; +use crate::function::OptionalArg; use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; @@ -41,36 +42,31 @@ impl PyGenerator { fn send(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { self.frame.push_value(value.clone()); - let result = vm.run_frame(self.frame.clone())?; - handle_execution_result(result, vm) + vm.run_frame(self.frame.clone())?.into_result(vm) } #[pymethod] fn throw( &self, - _exc_type: PyObjectRef, - exc_val: PyObjectRef, - _exc_tb: PyObjectRef, + exc_type: PyClassRef, + exc_val: OptionalArg, + exc_tb: OptionalArg, vm: &VirtualMachine, ) -> PyResult { // TODO what should we do with the other parameters? CPython normalises them with // PyErr_NormalizeException, do we want to do the same. - if !isinstance(&exc_val, &vm.ctx.exceptions.base_exception_type) { + if !issubclass(&exc_type, &vm.ctx.exceptions.base_exception_type) { return Err(vm.new_type_error("Can't throw non exception".to_string())); } - let result = vm.frame_throw(self.frame.clone(), exc_val)?; - handle_execution_result(result, vm) - } -} - -fn handle_execution_result(result: ExecutionResult, vm: &VirtualMachine) -> PyResult { - match result { - ExecutionResult::Yield(value) => Ok(value), - ExecutionResult::Return(_value) => { - // Stop iteration! - let stop_iteration = vm.ctx.exceptions.stop_iteration.clone(); - Err(vm.new_exception(stop_iteration, "End of generator".to_string())) - } + vm.frames.borrow_mut().push(self.frame.clone()); + let result = self.frame.gen_throw( + vm, + exc_type, + exc_val.unwrap_or(vm.get_none()), + exc_tb.unwrap_or(vm.get_none()), + ); + vm.frames.borrow_mut().pop(); + result } } diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 702bcc067..61267968d 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -4,6 +4,7 @@ use std::cell::Cell; +use super::objtuple::PyTuple; use super::objtype::{self, PyClassRef}; use crate::pyobject::{ PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, @@ -76,6 +77,17 @@ pub fn new_stop_iteration(vm: &VirtualMachine) -> PyObjectRef { vm.new_exception(stop_iteration_type, "End of iterator".to_string()) } +pub fn stop_iter_value(vm: &VirtualMachine, exc: &PyObjectRef) -> PyResult { + let args = vm.get_attribute(exc.clone(), "args")?; + let args: &PyTuple = args.payload().unwrap(); + let val = args + .elements + .first() + .cloned() + .unwrap_or_else(|| vm.get_none()); + Ok(val) +} + #[pyclass] #[derive(Debug)] pub struct PySequenceIterator { diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 93949bc05..06cceb143 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -234,17 +234,6 @@ impl VirtualMachine { } } - pub fn frame_throw( - &self, - frame: FrameRef, - exception: PyObjectRef, - ) -> PyResult { - self.frames.borrow_mut().push(frame.clone()); - let result = frame.throw(self, exception); - self.frames.borrow_mut().pop(); - result - } - pub fn current_frame(&self) -> Option> { let frames = self.frames.borrow(); if frames.is_empty() {