diff --git a/vm/src/builtins/asyncgenerator.rs b/vm/src/builtins/asyncgenerator.rs index ae6a3ad41b..b25178092a 100644 --- a/vm/src/builtins/asyncgenerator.rs +++ b/vm/src/builtins/asyncgenerator.rs @@ -1,7 +1,7 @@ use super::{PyCode, PyStrRef, PyTypeRef}; use crate::{ builtins::PyBaseExceptionRef, - coroutine::{Coro, Variant}, + coroutine::Coro, frame::FrameRef, function::OptionalArg, protocol::PyIterReturn, @@ -34,7 +34,7 @@ impl PyAsyncGen { pub fn new(frame: FrameRef, name: PyStrRef) -> Self { PyAsyncGen { - inner: Coro::new(frame, Variant::AsyncGen, name), + inner: Coro::new(frame, name), running_async: AtomicCell::new(false), } } @@ -50,8 +50,8 @@ impl PyAsyncGen { } #[pymethod(magic)] - fn repr(zelf: PyRef) -> String { - zelf.inner.repr(zelf.get_id()) + fn repr(zelf: PyRef, vm: &VirtualMachine) -> String { + zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm) } #[pymethod(magic)] @@ -138,17 +138,20 @@ impl PyValue for PyAsyncGenWrappedValue { impl PyAsyncGenWrappedValue {} impl PyAsyncGenWrappedValue { - fn unbox(ag: &PyAsyncGen, val: PyResult, vm: &VirtualMachine) -> PyResult { - if let Err(ref e) = val { - if e.isinstance(&vm.ctx.exceptions.stop_async_iteration) - || e.isinstance(&vm.ctx.exceptions.generator_exit) - { - ag.inner.closed.store(true); - } + fn unbox(ag: &PyAsyncGen, val: PyResult, vm: &VirtualMachine) -> PyResult { + let (closed, async_done) = match &val { + Ok(PyIterReturn::StopIteration(_)) => (true, true), + Err(e) if e.isinstance(&vm.ctx.exceptions.generator_exit) => (true, true), + Err(_) => (false, true), + _ => (false, false), + }; + if closed { + ag.inner.closed.store(true); + } + if async_done { ag.running_async.store(false); } - let val = val?; - + let val = val?.into_async_pyresult(vm)?; match_class!(match val { val @ Self => { ag.running_async.store(false); @@ -214,7 +217,7 @@ impl PyAsyncGenASend { } } }; - let res = self.ag.inner.send(val, vm); + let res = self.ag.inner.send(self.ag.as_object(), val, vm); let res = PyAsyncGenWrappedValue::unbox(&self.ag, res, vm); if res.is_err() { self.close(); @@ -237,6 +240,7 @@ impl PyAsyncGenASend { } let res = self.ag.inner.throw( + self.ag.as_object(), exc_type, exc_val.unwrap_or_none(vm), exc_tb.unwrap_or_none(vm), @@ -258,8 +262,7 @@ impl PyAsyncGenASend { impl IteratorIterable for PyAsyncGenASend {} impl SlotIterator for PyAsyncGenASend { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - // TODO: Fix zelf.send to return PyIterReturn - PyIterReturn::from_result(zelf.send(vm.ctx.none(), vm), vm) + PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm) } } @@ -315,12 +318,12 @@ impl PyAsyncGenAThrow { self.ag.running_async.store(true); let (ty, val, tb) = self.value.clone(); - let ret = self.ag.inner.throw(ty, val, tb, vm); + let ret = self.ag.inner.throw(self.ag.as_object(), ty, val, tb, vm); let ret = if self.aclose { if self.ignored_close(&ret) { Err(self.yield_close(vm)) } else { - ret + ret.and_then(|o| o.into_async_pyresult(vm)) } } else { PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm) @@ -328,14 +331,15 @@ impl PyAsyncGenAThrow { ret.map_err(|e| self.check_error(e, vm)) } AwaitableState::Iter => { - let ret = self.ag.inner.send(val, vm); + let ret = self.ag.inner.send(self.ag.as_object(), val, vm); if self.aclose { match ret { - Ok(v) if v.payload_is::() => { + Ok(PyIterReturn::Return(v)) if v.payload_is::() => { Err(self.yield_close(vm)) } - Ok(v) => Ok(v), - Err(e) => Err(self.check_error(e, vm)), + other => other + .and_then(|o| o.into_async_pyresult(vm)) + .map_err(|e| self.check_error(e, vm)), } } else { PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm) @@ -353,6 +357,7 @@ impl PyAsyncGenAThrow { vm: &VirtualMachine, ) -> PyResult { let ret = self.ag.inner.throw( + self.ag.as_object(), exc_type, exc_val.unwrap_or_none(vm), exc_tb.unwrap_or_none(vm), @@ -362,7 +367,7 @@ impl PyAsyncGenAThrow { if self.ignored_close(&ret) { Err(self.yield_close(vm)) } else { - ret + ret.and_then(|o| o.into_async_pyresult(vm)) } } else { PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm) @@ -375,9 +380,11 @@ impl PyAsyncGenAThrow { self.state.store(AwaitableState::Closed); } - fn ignored_close(&self, res: &PyResult) -> bool { - res.as_ref() - .map_or(false, |v| v.payload_is::()) + fn ignored_close(&self, res: &PyResult) -> bool { + res.as_ref().map_or(false, |v| match v { + PyIterReturn::Return(obj) => obj.payload_is::(), + PyIterReturn::StopIteration(_) => false, + }) } fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { self.ag.running_async.store(false); @@ -401,8 +408,7 @@ impl PyAsyncGenAThrow { impl IteratorIterable for PyAsyncGenAThrow {} impl SlotIterator for PyAsyncGenAThrow { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - // TODO: Fix zelf.send to return PyIterReturn - PyIterReturn::from_result(zelf.send(vm.ctx.none(), vm), vm) + PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm) } } diff --git a/vm/src/builtins/coroutine.rs b/vm/src/builtins/coroutine.rs index 45e6ec6b01..6bff0cd7bd 100644 --- a/vm/src/builtins/coroutine.rs +++ b/vm/src/builtins/coroutine.rs @@ -1,6 +1,6 @@ use super::{PyCode, PyStrRef, PyTypeRef}; use crate::{ - coroutine::{Coro, Variant}, + coroutine::Coro, frame::FrameRef, function::OptionalArg, protocol::PyIterReturn, @@ -29,7 +29,7 @@ impl PyCoroutine { pub fn new(frame: FrameRef, name: PyStrRef) -> Self { PyCoroutine { - inner: Coro::new(frame, Variant::Coroutine, name), + inner: Coro::new(frame, name), } } @@ -44,24 +44,25 @@ impl PyCoroutine { } #[pymethod(magic)] - fn repr(zelf: PyRef) -> String { - zelf.inner.repr(zelf.get_id()) + fn repr(zelf: PyRef, vm: &VirtualMachine) -> String { + zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm) } #[pymethod] - fn send(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.send(value, vm) + fn send(zelf: PyRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + zelf.inner.send(zelf.as_object(), value, vm) } #[pymethod] fn throw( - &self, + zelf: PyRef, exc_type: PyObjectRef, exc_val: OptionalArg, exc_tb: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { - self.inner.throw( + ) -> PyResult { + zelf.inner.throw( + zelf.as_object(), exc_type, exc_val.unwrap_or_none(vm), exc_tb.unwrap_or_none(vm), @@ -70,8 +71,8 @@ impl PyCoroutine { } #[pymethod] - fn close(&self, vm: &VirtualMachine) -> PyResult<()> { - self.inner.close(vm) + fn close(zelf: PyRef, vm: &VirtualMachine) -> PyResult<()> { + zelf.inner.close(zelf.as_object(), vm) } #[pymethod(name = "__await__")] @@ -106,8 +107,7 @@ impl PyCoroutine { impl IteratorIterable for PyCoroutine {} impl SlotIterator for PyCoroutine { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - // TODO: Fix zelf.send to return PyIterReturn - PyIterReturn::from_result(zelf.send(vm.ctx.none(), vm), vm) + Self::send(zelf.clone(), vm.ctx.none(), vm) } } @@ -127,8 +127,8 @@ impl PyValue for PyCoroutineWrapper { #[pyimpl(with(SlotIterator))] impl PyCoroutineWrapper { #[pymethod] - fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.coro.send(val, vm) + fn send(zelf: PyRef, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { + PyCoroutine::send(zelf.coro.clone(), val, vm) } #[pymethod] @@ -138,16 +138,15 @@ impl PyCoroutineWrapper { exc_val: OptionalArg, exc_tb: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { - self.coro.throw(exc_type, exc_val, exc_tb, vm) + ) -> PyResult { + PyCoroutine::throw(self.coro.clone(), exc_type, exc_val, exc_tb, vm) } } impl IteratorIterable for PyCoroutineWrapper {} impl SlotIterator for PyCoroutineWrapper { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - // TODO: Fix zelf.send to return PyIterReturn - PyIterReturn::from_result(zelf.send(vm.ctx.none(), vm), vm) + Self::send(zelf.clone(), vm.ctx.none(), vm) } } diff --git a/vm/src/builtins/filter.rs b/vm/src/builtins/filter.rs index 88e8bf9f28..06103360ae 100644 --- a/vm/src/builtins/filter.rs +++ b/vm/src/builtins/filter.rs @@ -51,7 +51,8 @@ impl SlotIterator for PyFilter { } else { // the predicate itself can raise StopIteration which does stop the filter // iteration - match PyIterReturn::from_result(vm.invoke(predicate, vec![next_obj.clone()]), vm)? { + match PyIterReturn::from_pyresult(vm.invoke(predicate, vec![next_obj.clone()]), vm)? + { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), } diff --git a/vm/src/builtins/generator.rs b/vm/src/builtins/generator.rs index deb17e5188..a8d0124ad7 100644 --- a/vm/src/builtins/generator.rs +++ b/vm/src/builtins/generator.rs @@ -4,7 +4,7 @@ use super::{PyCode, PyStrRef, PyTypeRef}; use crate::{ - coroutine::{Coro, Variant}, + coroutine::Coro, frame::FrameRef, function::OptionalArg, protocol::PyIterReturn, @@ -32,7 +32,7 @@ impl PyGenerator { pub fn new(frame: FrameRef, name: PyStrRef) -> Self { PyGenerator { - inner: Coro::new(frame, Variant::Gen, name), + inner: Coro::new(frame, name), } } @@ -47,24 +47,25 @@ impl PyGenerator { } #[pymethod(magic)] - fn repr(zelf: PyRef) -> String { - zelf.inner.repr(zelf.get_id()) + fn repr(zelf: PyRef, vm: &VirtualMachine) -> String { + zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm) } #[pymethod] - fn send(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.inner.send(value, vm) + fn send(zelf: PyRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + zelf.inner.send(zelf.as_object(), value, vm) } #[pymethod] fn throw( - &self, + zelf: PyRef, exc_type: PyObjectRef, exc_val: OptionalArg, exc_tb: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { - self.inner.throw( + ) -> PyResult { + zelf.inner.throw( + zelf.as_object(), exc_type, exc_val.unwrap_or_none(vm), exc_tb.unwrap_or_none(vm), @@ -73,8 +74,8 @@ impl PyGenerator { } #[pymethod] - fn close(&self, vm: &VirtualMachine) -> PyResult<()> { - self.inner.close(vm) + fn close(zelf: PyRef, vm: &VirtualMachine) -> PyResult<()> { + zelf.inner.close(zelf.as_object(), vm) } #[pyproperty] @@ -98,8 +99,7 @@ impl PyGenerator { impl IteratorIterable for PyGenerator {} impl SlotIterator for PyGenerator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - // TODO: Fix zelf.send to return PyIterReturn - PyIterReturn::from_result(zelf.send(vm.ctx.none(), vm), vm) + Self::send(zelf.clone(), vm.ctx.none(), vm) } } diff --git a/vm/src/builtins/map.rs b/vm/src/builtins/map.rs index 870cb90d3a..c3f2a44d6d 100644 --- a/vm/src/builtins/map.rs +++ b/vm/src/builtins/map.rs @@ -58,7 +58,7 @@ impl SlotIterator for PyMap { } // the mapper itself can raise StopIteration which does stop the map iteration - PyIterReturn::from_result(vm.invoke(&zelf.mapper, next_objs), vm) + PyIterReturn::from_pyresult(vm.invoke(&zelf.mapper, next_objs), vm) } } diff --git a/vm/src/builtins/pytype.rs b/vm/src/builtins/pytype.rs index 9db1530c8a..72f5fc85d7 100644 --- a/vm/src/builtins/pytype.rs +++ b/vm/src/builtins/pytype.rs @@ -264,7 +264,7 @@ impl PyType { } "__next__" => { let func: slots::IterNextFunc = |zelf, vm| { - PyIterReturn::from_result( + PyIterReturn::from_pyresult( vm.call_special_method(zelf.clone(), "__next__", ()), vm, ) diff --git a/vm/src/coroutine.rs b/vm/src/coroutine.rs index ffc4e67cb9..440ebff27c 100644 --- a/vm/src/coroutine.rs +++ b/vm/src/coroutine.rs @@ -1,33 +1,26 @@ use crate::{ - builtins::{PyBaseExceptionRef, PyStrRef, PyTypeRef}, + builtins::{PyBaseExceptionRef, PyStrRef}, common::lock::PyMutex, exceptions, frame::{ExecutionResult, FrameRef}, - PyObjectRef, PyResult, TypeProtocol, VirtualMachine, + protocol::PyIterReturn, + IdProtocol, PyObjectRef, PyResult, TypeProtocol, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; -#[derive(Debug, PartialEq, Clone, Copy)] -pub enum Variant { - Gen, - Coroutine, - AsyncGen, -} -impl Variant { - fn exec_result(self, res: ExecutionResult, vm: &VirtualMachine) -> PyResult { - res.into_result(self == Self::AsyncGen, vm) - } - fn name(self) -> &'static str { +impl ExecutionResult { + /// Turn an ExecutionResult into a PyResult that would be returned from a generator or coroutine + fn into_iter_return(self, vm: &VirtualMachine) -> PyIterReturn { match self { - Self::Gen => "generator", - Self::Coroutine => "coroutine", - Self::AsyncGen => "async generator", - } - } - fn stop_iteration(self, vm: &VirtualMachine) -> PyTypeRef { - match self { - Self::AsyncGen => vm.ctx.exceptions.stop_async_iteration.clone(), - _ => vm.ctx.exceptions.stop_iteration.clone(), + ExecutionResult::Yield(value) => PyIterReturn::Return(value), + ExecutionResult::Return(value) => { + let arg = if vm.is_none(&value) { + None + } else { + Some(value) + }; + PyIterReturn::StopIteration(arg) + } } } } @@ -35,9 +28,8 @@ impl Variant { #[derive(Debug)] pub struct Coro { frame: FrameRef, - pub closed: AtomicCell, // redudant? + pub closed: AtomicCell, // TODO: https://github.com/RustPython/RustPython/pull/3183#discussion_r720560652 running: AtomicCell, - variant: Variant, // code // _weakreflist name: PyMutex, @@ -45,14 +37,24 @@ pub struct Coro { exception: PyMutex>, // exc_state } +fn gen_name(gen: &PyObjectRef, vm: &VirtualMachine) -> &'static str { + let typ = gen.class(); + if typ.is(&vm.ctx.types.coroutine_type) { + "coroutine" + } else if typ.is(&vm.ctx.types.async_generator) { + "async generator" + } else { + "generator" + } +} + impl Coro { - pub fn new(frame: FrameRef, variant: Variant, name: PyStrRef) -> Self { + pub fn new(frame: FrameRef, name: PyStrRef) -> Self { Coro { frame, closed: AtomicCell::new(false), running: AtomicCell::new(false), exception: PyMutex::default(), - variant, name: PyMutex::new(name), } } @@ -64,12 +66,17 @@ impl Coro { } } - fn run_with_context(&self, vm: &VirtualMachine, func: F) -> PyResult + fn run_with_context( + &self, + gen: &PyObjectRef, + vm: &VirtualMachine, + func: F, + ) -> PyResult where F: FnOnce(FrameRef) -> PyResult, { if self.running.compare_exchange(false, true).is_err() { - return Err(vm.new_value_error(format!("{} already executing", self.variant.name()))); + return Err(vm.new_value_error(format!("{} already executing", gen_name(gen, vm)))); } vm.push_exception(self.exception.lock().take()); @@ -82,31 +89,36 @@ impl Coro { result } - pub fn send(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + pub fn send( + &self, + gen: &PyObjectRef, + value: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult { if self.closed.load() { - return Err(vm.new_exception_empty(self.variant.stop_iteration(vm))); + return Ok(PyIterReturn::StopIteration(None)); } let value = if self.frame.lasti() > 0 { Some(value) } else if !vm.is_none(&value) { return Err(vm.new_type_error(format!( "can't send non-None value to a just-started {}", - self.variant.name() + gen_name(gen, vm), ))); } else { None }; - let result = self.run_with_context(vm, |f| f.resume(value, vm)); + let result = self.run_with_context(gen, vm, |f| f.resume(value, vm)); self.maybe_close(&result); match result { - Ok(exec_res) => self.variant.exec_result(exec_res, vm), + Ok(exec_res) => Ok(exec_res.into_iter_return(vm)), Err(e) => { if e.isinstance(&vm.ctx.exceptions.stop_iteration) { - let err = vm - .new_runtime_error(format!("{} raised StopIteration", self.variant.name())); + let err = + vm.new_runtime_error(format!("{} raised StopIteration", gen_name(gen, vm))); err.set_cause(Some(e)); Err(err) - } else if self.variant == Variant::AsyncGen + } else if gen.class().is(&vm.ctx.types.async_generator) && e.isinstance(&vm.ctx.exceptions.stop_async_iteration) { let err = vm @@ -121,24 +133,25 @@ impl Coro { } pub fn throw( &self, + gen: &PyObjectRef, exc_type: PyObjectRef, exc_val: PyObjectRef, exc_tb: PyObjectRef, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult { if self.closed.load() { return Err(exceptions::normalize(exc_type, exc_val, exc_tb, vm)?); } - let result = self.run_with_context(vm, |f| f.gen_throw(vm, exc_type, exc_val, exc_tb)); + let result = self.run_with_context(gen, vm, |f| f.gen_throw(vm, exc_type, exc_val, exc_tb)); self.maybe_close(&result); - self.variant.exec_result(result?, vm) + Ok(result?.into_iter_return(vm)) } - pub fn close(&self, vm: &VirtualMachine) -> PyResult<()> { + pub fn close(&self, gen: &PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { if self.closed.load() { return Ok(()); } - let result = self.run_with_context(vm, |f| { + let result = self.run_with_context(gen, vm, |f| { f.gen_throw( vm, vm.ctx.exceptions.generator_exit.clone().into_object(), @@ -149,7 +162,7 @@ impl Coro { self.closed.store(true); match result { Ok(ExecutionResult::Yield(_)) => { - Err(vm.new_runtime_error(format!("{} ignored GeneratorExit", self.variant.name()))) + Err(vm.new_runtime_error(format!("{} ignored GeneratorExit", gen_name(gen, vm)))) } Err(e) if !is_gen_exit(&e, vm) => Err(e), _ => Ok(()), @@ -171,10 +184,10 @@ impl Coro { pub fn set_name(&self, name: PyStrRef) { *self.name.lock() = name; } - pub fn repr(&self, id: usize) -> String { + pub fn repr(&self, gen: &PyObjectRef, id: usize, vm: &VirtualMachine) -> String { format!( "<{} object {} at {:#x}>", - self.variant.name(), + gen_name(gen, vm), self.name.lock(), id ) diff --git a/vm/src/frame.rs b/vm/src/frame.rs index ee9898fc1c..e277ae7d20 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -128,42 +128,6 @@ pub enum ExecutionResult { Yield(PyObjectRef), } -impl ExecutionResult { - /// Extract an ExecutionResult from a PyResult returned from e.g. gen.__next__() or gen.send() - pub fn from_result(vm: &VirtualMachine, res: PyResult) -> PyResult { - match res { - Ok(val) => Ok(ExecutionResult::Yield(val)), - Err(err) => { - if err.isinstance(&vm.ctx.exceptions.stop_iteration) { - Ok(ExecutionResult::Return(iterator::stop_iter_value(vm, &err))) - } else { - Err(err) - } - } - } - } - - /// Turn an ExecutionResult into a PyResult that would be returned from a generator or coroutine - pub fn into_result(self, async_stopiter: bool, vm: &VirtualMachine) -> PyResult { - match self { - ExecutionResult::Yield(value) => Ok(value), - ExecutionResult::Return(value) => { - let stop_iteration = if async_stopiter { - vm.ctx.exceptions.stop_async_iteration.clone() - } else { - vm.ctx.exceptions.stop_iteration.clone() - }; - let args = if vm.is_none(&value) { - vec![] - } else { - vec![value] - }; - Err(vm.new_exception(stop_iteration, args)) - } - } - } -} - /// A valid execution result, or an exception pub type FrameResult = PyResult>; @@ -428,18 +392,20 @@ impl ExecutingFrame<'_> { exc_val: PyObjectRef, exc_tb: PyObjectRef, ) -> PyResult { - if let Some(coro) = self.yield_from_target() { + if let Some(gen) = self.yield_from_target() { use crate::utils::Either; // borrow checker shenanigans - we only need to use exc_type/val/tb if the following // variable is Some - let thrower = if let Some(coro) = self.builtin_coro(coro) { + let thrower = if let Some(coro) = self.builtin_coro(gen) { Some(Either::A(coro)) } else { - vm.get_attribute_opt(coro.clone(), "throw")?.map(Either::B) + vm.get_attribute_opt(gen.clone(), "throw")?.map(Either::B) }; if let Some(thrower) = thrower { let ret = match thrower { - Either::A(coro) => coro.throw(exc_type, exc_val, exc_tb, vm), + Either::A(coro) => coro + .throw(gen, exc_type, exc_val, exc_tb, vm) + .into_pyresult(vm), // FIXME: Either::B(meth) => vm.invoke(&meth, vec![exc_type, exc_val, exc_tb]), }; return ret.map(ExecutionResult::Yield).or_else(|err| { @@ -1439,14 +1405,19 @@ impl ExecutingFrame<'_> { }) } - fn _send(&self, coro: &PyObjectRef, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match self.builtin_coro(coro) { - Some(coro) => coro.send(val, vm), + fn _send( + &self, + gen: &PyObjectRef, + val: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult { + match self.builtin_coro(gen) { + Some(coro) => coro.send(gen, val, vm), // FIXME: turn return type to PyResult then ExecutionResult will be simplified - None if vm.is_none(&val) => PyIter::new(coro).next(vm).into_pyresult(vm), + None if vm.is_none(&val) => PyIter::new(gen).next(vm), None => { - let meth = vm.get_attribute(coro.clone(), "send")?; - vm.invoke(&meth, (val,)) + let meth = vm.get_attribute(gen.clone(), "send")?; + PyIterReturn::from_pyresult(vm.invoke(&meth, (val,)), vm) } } } @@ -1454,20 +1425,18 @@ impl ExecutingFrame<'_> { fn execute_yield_from(&mut self, vm: &VirtualMachine) -> FrameResult { // Value send into iterator: let val = self.pop_value(); - let coro = self.last_value_ref(); + let result = self._send(coro, val, vm)?; - let result = self._send(coro, val, vm); - - let result = ExecutionResult::from_result(vm, result)?; - + // PyIterReturn returned from e.g. gen.__next__() or gen.send() match result { - ExecutionResult::Yield(value) => { + PyIterReturn::Return(value) => { // Set back program counter: self.update_lasti(|i| *i -= 1); Ok(Some(ExecutionResult::Yield(value))) } - ExecutionResult::Return(value) => { + PyIterReturn::StopIteration(value) => { + let value = vm.unwrap_or_none(value); self.pop_value(); self.push_value(value); Ok(None) diff --git a/vm/src/protocol/iter.rs b/vm/src/protocol/iter.rs index 9d081249fb..6a5cc5b0ec 100644 --- a/vm/src/protocol/iter.rs +++ b/vm/src/protocol/iter.rs @@ -126,7 +126,7 @@ pub enum PyIterReturn { } impl PyIterReturn { - pub fn from_result(result: PyResult, vm: &VirtualMachine) -> PyResult { + pub fn from_pyresult(result: PyResult, vm: &VirtualMachine) -> PyResult { match result { Ok(obj) => Ok(Self::Return(obj)), Err(err) if err.isinstance(&vm.ctx.exceptions.stop_iteration) => { @@ -136,6 +136,7 @@ impl PyIterReturn { Err(err) => Err(err), } } + pub fn from_getitem_result(result: PyResult, vm: &VirtualMachine) -> PyResult { match result { Ok(obj) => Ok(Self::Return(obj)), @@ -149,6 +150,16 @@ impl PyIterReturn { Err(err) => Err(err), } } + + pub fn into_async_pyresult(self, vm: &VirtualMachine) -> PyResult { + match self { + Self::Return(obj) => Ok(obj), + Self::StopIteration(v) => Err({ + let args = if let Some(v) = v { vec![v] } else { Vec::new() }; + vm.new_exception(vm.ctx.exceptions.stop_async_iteration.clone(), args) + }), + } + } } impl IntoPyResult for PyIterReturn { diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index d81227a0aa..ecd32b4220 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -371,7 +371,7 @@ mod decl { let function = &zelf.function; match obj { PyIterReturn::Return(obj) => { - PyIterReturn::from_result(vm.invoke(function, vm.extract_elements(&obj)?), vm) + PyIterReturn::from_pyresult(vm.invoke(function, vm.extract_elements(&obj)?), vm) } PyIterReturn::StopIteration(v) => Ok(PyIterReturn::StopIteration(v)), } diff --git a/wasm/lib/src/js_module.rs b/wasm/lib/src/js_module.rs index 5a5014bc19..028e9c8753 100644 --- a/wasm/lib/src/js_module.rs +++ b/wasm/lib/src/js_module.rs @@ -590,7 +590,7 @@ impl AwaitPromise { impl IteratorIterable for AwaitPromise {} impl SlotIterator for AwaitPromise { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - PyIterReturn::from_result(zelf.send(None, vm), vm) + PyIterReturn::from_pyresult(zelf.send(None, vm), vm) } }