diff --git a/tests/snippets/generators.py b/tests/snippets/generators.py index 2c0cf32fb..dccbff3ec 100644 --- a/tests/snippets/generators.py +++ b/tests/snippets/generators.py @@ -131,3 +131,13 @@ for i in [0, 1, 2, 'spam', 4]: wrap.send(i) assert l == ['>> 0', '>> 1', '>> 2', '***', '>> 4'] + +def a(): + yield + +g = a() + +next(g) +assert_raises(TypeError, g.throw, TypeError) +assert_raises(StopIteration, next, g) +assert_raises(TypeError, g.throw, TypeError) diff --git a/vm/src/obj/objcoroutine.rs b/vm/src/obj/objcoroutine.rs index e052b401b..6d36f81a7 100644 --- a/vm/src/obj/objcoroutine.rs +++ b/vm/src/obj/objcoroutine.rs @@ -1,15 +1,19 @@ +use super::objiter::new_stop_iteration; use super::objtype::{isinstance, issubclass, PyClassRef}; use crate::frame::{ExecutionResult, FrameRef}; use crate::function::OptionalArg; use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; +use std::cell::Cell; + pub type PyCoroutineRef = PyRef; #[pyclass(name = "coroutine")] #[derive(Debug)] pub struct PyCoroutine { frame: FrameRef, + closed: Cell, } impl PyValue for PyCoroutine { @@ -21,14 +25,32 @@ impl PyValue for PyCoroutine { #[pyimpl] impl PyCoroutine { pub fn new(frame: FrameRef, vm: &VirtualMachine) -> PyCoroutineRef { - PyCoroutine { frame }.into_ref(vm) + PyCoroutine { + frame, + closed: Cell::new(false), + } + .into_ref(vm) + } + + // TODO: deduplicate this code with objgenerator + fn maybe_close(&self, res: &PyResult) { + match res { + Ok(ExecutionResult::Return(_)) | Err(_) => self.closed.set(true), + Ok(ExecutionResult::Yield(_)) => {} + } } #[pymethod] pub(crate) fn send(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if self.closed.get() { + return Err(new_stop_iteration(vm)); + } + self.frame.push_value(value.clone()); - vm.run_frame(self.frame.clone())?.into_result(vm) + let result = vm.run_frame(self.frame.clone()); + self.maybe_close(&result); + result?.into_result(vm) } #[pymethod] @@ -39,27 +61,31 @@ impl PyCoroutine { exc_tb: OptionalArg, vm: &VirtualMachine, ) -> PyResult { + if self.closed.get() { + return Err(vm.invoke(exc_type.as_object(), vec![])?); + } // TODO what should we do with the other parameters? CPython normalises them with // PyErr_NormalizeException, do we want to do the same. if !issubclass(&exc_type, &vm.ctx.exceptions.base_exception_type) { return Err(vm.new_type_error("Can't throw non exception".to_string())); } vm.frames.borrow_mut().push(self.frame.clone()); - let result = self - .frame - .gen_throw( - vm, - exc_type, - exc_val.unwrap_or(vm.get_none()), - exc_tb.unwrap_or(vm.get_none()), - ) - .and_then(|res| res.into_result(vm)); + let result = self.frame.gen_throw( + vm, + exc_type, + exc_val.unwrap_or(vm.get_none()), + exc_tb.unwrap_or(vm.get_none()), + ); + self.maybe_close(&result); vm.frames.borrow_mut().pop(); - result + result?.into_result(vm) } #[pymethod] fn close(&self, vm: &VirtualMachine) -> PyResult<()> { + if self.closed.get() { + return Ok(()); + } vm.frames.borrow_mut().push(self.frame.clone()); let result = self.frame.gen_throw( vm, @@ -68,6 +94,7 @@ impl PyCoroutine { vm.get_none(), ); vm.frames.borrow_mut().pop(); + self.closed.set(true); match result { Ok(ExecutionResult::Yield(_)) => Err(vm.new_exception( vm.ctx.exceptions.runtime_error.clone(), diff --git a/vm/src/obj/objgenerator.rs b/vm/src/obj/objgenerator.rs index 44342bfed..e98500e20 100644 --- a/vm/src/obj/objgenerator.rs +++ b/vm/src/obj/objgenerator.rs @@ -2,18 +2,22 @@ * The mythical generator. */ +use super::objiter::new_stop_iteration; use super::objtype::{isinstance, issubclass, PyClassRef}; use crate::frame::{ExecutionResult, FrameRef}; use crate::function::OptionalArg; use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; +use std::cell::Cell; + pub type PyGeneratorRef = PyRef; #[pyclass(name = "generator")] #[derive(Debug)] pub struct PyGenerator { frame: FrameRef, + closed: Cell, } impl PyValue for PyGenerator { @@ -25,7 +29,18 @@ impl PyValue for PyGenerator { #[pyimpl] impl PyGenerator { pub fn new(frame: FrameRef, vm: &VirtualMachine) -> PyGeneratorRef { - PyGenerator { frame }.into_ref(vm) + PyGenerator { + frame, + closed: Cell::new(false), + } + .into_ref(vm) + } + + fn maybe_close(&self, res: &PyResult) { + match res { + Ok(ExecutionResult::Return(_)) | Err(_) => self.closed.set(true), + Ok(ExecutionResult::Yield(_)) => {} + } } #[pymethod(name = "__iter__")] @@ -40,9 +55,15 @@ impl PyGenerator { #[pymethod] pub(crate) fn send(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if self.closed.get() { + return Err(new_stop_iteration(vm)); + } + self.frame.push_value(value.clone()); - vm.run_frame(self.frame.clone())?.into_result(vm) + let result = vm.run_frame(self.frame.clone()); + self.maybe_close(&result); + result?.into_result(vm) } #[pymethod] @@ -53,27 +74,31 @@ impl PyGenerator { exc_tb: OptionalArg, vm: &VirtualMachine, ) -> PyResult { + if self.closed.get() { + return Err(vm.invoke(exc_type.as_object(), vec![])?); + } // TODO what should we do with the other parameters? CPython normalises them with // PyErr_NormalizeException, do we want to do the same. if !issubclass(&exc_type, &vm.ctx.exceptions.base_exception_type) { return Err(vm.new_type_error("Can't throw non exception".to_string())); } vm.frames.borrow_mut().push(self.frame.clone()); - let result = self - .frame - .gen_throw( - vm, - exc_type, - exc_val.unwrap_or(vm.get_none()), - exc_tb.unwrap_or(vm.get_none()), - ) - .and_then(|res| res.into_result(vm)); + let result = self.frame.gen_throw( + vm, + exc_type, + exc_val.unwrap_or(vm.get_none()), + exc_tb.unwrap_or(vm.get_none()), + ); + self.maybe_close(&result); vm.frames.borrow_mut().pop(); - result + result?.into_result(vm) } #[pymethod] fn close(&self, vm: &VirtualMachine) -> PyResult<()> { + if self.closed.get() { + return Ok(()); + } vm.frames.borrow_mut().push(self.frame.clone()); let result = self.frame.gen_throw( vm, @@ -82,6 +107,7 @@ impl PyGenerator { vm.get_none(), ); vm.frames.borrow_mut().pop(); + self.closed.set(true); match result { Ok(ExecutionResult::Yield(_)) => Err(vm.new_exception( vm.ctx.exceptions.runtime_error.clone(),