diff --git a/vm/src/obj/objasyncgenerator.rs b/vm/src/obj/objasyncgenerator.rs index 367e4b9864..c26e2a7c0e 100644 --- a/vm/src/obj/objasyncgenerator.rs +++ b/vm/src/obj/objasyncgenerator.rs @@ -4,18 +4,19 @@ use super::objtype::{self, PyClassRef}; use crate::exceptions::PyBaseExceptionRef; use crate::frame::FrameRef; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; use crate::vm::VirtualMachine; -use std::cell::Cell; +use crossbeam_utils::atomic::AtomicCell; #[pyclass(name = "async_generator")] #[derive(Debug)] pub struct PyAsyncGen { inner: Coro, - running_async: Cell, + running_async: AtomicCell, } pub type PyAsyncGenRef = PyRef; +impl ThreadSafe for PyAsyncGen {} impl PyValue for PyAsyncGen { fn class(vm: &VirtualMachine) -> PyClassRef { @@ -32,7 +33,7 @@ impl PyAsyncGen { pub fn new(frame: FrameRef, vm: &VirtualMachine) -> PyAsyncGenRef { PyAsyncGen { inner: Coro::new(frame, Variant::AsyncGen), - running_async: Cell::new(false), + running_async: AtomicCell::new(false), } .into_ref(vm) } @@ -57,7 +58,7 @@ impl PyAsyncGen { fn asend(zelf: PyRef, value: PyObjectRef, _vm: &VirtualMachine) -> PyAsyncGenASend { PyAsyncGenASend { ag: zelf, - state: Cell::new(AwaitableState::Init), + state: AtomicCell::new(AwaitableState::Init), value, } } @@ -73,7 +74,7 @@ impl PyAsyncGen { PyAsyncGenAThrow { ag: zelf, aclose: false, - state: Cell::new(AwaitableState::Init), + state: AtomicCell::new(AwaitableState::Init), value: ( exc_type, exc_val.unwrap_or_else(|| vm.get_none()), @@ -87,7 +88,7 @@ impl PyAsyncGen { PyAsyncGenAThrow { ag: zelf, aclose: true, - state: Cell::new(AwaitableState::Init), + state: AtomicCell::new(AwaitableState::Init), value: ( vm.ctx.exceptions.generator_exit.clone().into_object(), vm.get_none(), @@ -131,13 +132,13 @@ impl PyAsyncGenWrappedValue { { ag.inner.closed.store(true); } - ag.running_async.set(false); + ag.running_async.store(false); } let val = val?; match_class!(match val { val @ Self => { - ag.running_async.set(false); + ag.running_async.store(false); Err(vm.new_exception( vm.ctx.exceptions.stop_iteration.clone(), vec![val.0.clone()], @@ -159,10 +160,12 @@ enum AwaitableState { #[derive(Debug)] struct PyAsyncGenASend { ag: PyAsyncGenRef, - state: Cell, + state: AtomicCell, value: PyObjectRef, } +impl ThreadSafe for PyAsyncGenASend {} + impl PyValue for PyAsyncGenASend { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.types.async_generator_asend.clone() @@ -187,7 +190,7 @@ impl PyAsyncGenASend { #[pymethod] fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let val = match self.state.get() { + let val = match self.state.load() { AwaitableState::Closed => { return Err(vm.new_runtime_error( "cannot reuse already awaited __anext__()/asend()".to_owned(), @@ -195,13 +198,13 @@ impl PyAsyncGenASend { } AwaitableState::Iter => val, // already running, all good AwaitableState::Init => { - if self.ag.running_async.get() { + if self.ag.running_async.load() { return Err(vm.new_runtime_error( "anext(): asynchronous generator is already running".to_owned(), )); } - self.ag.running_async.set(true); - self.state.set(AwaitableState::Iter); + self.ag.running_async.store(true); + self.state.store(AwaitableState::Iter); if vm.is_none(&val) { self.value.clone() } else { @@ -225,7 +228,7 @@ impl PyAsyncGenASend { exc_tb: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - if let AwaitableState::Closed = self.state.get() { + if let AwaitableState::Closed = self.state.load() { return Err( vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()".to_owned()) ); @@ -246,7 +249,7 @@ impl PyAsyncGenASend { #[pymethod] fn close(&self) { - self.state.set(AwaitableState::Closed); + self.state.store(AwaitableState::Closed); } } @@ -255,10 +258,12 @@ impl PyAsyncGenASend { struct PyAsyncGenAThrow { ag: PyAsyncGenRef, aclose: bool, - state: Cell, + state: AtomicCell, value: (PyObjectRef, PyObjectRef, PyObjectRef), } +impl ThreadSafe for PyAsyncGenAThrow {} + impl PyValue for PyAsyncGenAThrow { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.types.async_generator_athrow.clone() @@ -283,14 +288,14 @@ impl PyAsyncGenAThrow { #[pymethod] fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match self.state.get() { + match self.state.load() { AwaitableState::Closed => { Err(vm .new_runtime_error("cannot reuse already awaited aclose()/athrow()".to_owned())) } AwaitableState::Init => { - if self.ag.running_async.get() { - self.state.set(AwaitableState::Closed); + if self.ag.running_async.load() { + self.state.store(AwaitableState::Closed); let msg = if self.aclose { "aclose(): asynchronous generator is already running" } else { @@ -299,7 +304,7 @@ impl PyAsyncGenAThrow { return Err(vm.new_runtime_error(msg.to_owned())); } if self.ag.inner.closed() { - self.state.set(AwaitableState::Closed); + self.state.store(AwaitableState::Closed); return Err(vm.new_exception_empty(vm.ctx.exceptions.stop_iteration.clone())); } if !vm.is_none(&val) { @@ -307,8 +312,8 @@ impl PyAsyncGenAThrow { "can't send non-None value to a just-started async generator".to_owned(), )); } - self.state.set(AwaitableState::Iter); - self.ag.running_async.set(true); + self.state.store(AwaitableState::Iter); + self.ag.running_async.store(true); let (ty, val, tb) = self.value.clone(); let ret = self.ag.inner.throw(ty, val, tb, vm); @@ -368,7 +373,7 @@ impl PyAsyncGenAThrow { #[pymethod] fn close(&self) { - self.state.set(AwaitableState::Closed); + self.state.store(AwaitableState::Closed); } fn ignored_close(&self, res: &PyResult) -> bool { @@ -376,13 +381,13 @@ impl PyAsyncGenAThrow { .map_or(false, |v| v.payload_is::()) } fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { - self.ag.running_async.set(false); - self.state.set(AwaitableState::Closed); + self.ag.running_async.store(false); + self.state.store(AwaitableState::Closed); vm.new_runtime_error("async generator ignored GeneratorExit".to_owned()) } fn check_error(&self, exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyBaseExceptionRef { - self.ag.running_async.set(false); - self.state.set(AwaitableState::Closed); + self.ag.running_async.store(false); + self.state.store(AwaitableState::Closed); if self.aclose && (objtype::isinstance(&exc, &vm.ctx.exceptions.stop_async_iteration) || objtype::isinstance(&exc, &vm.ctx.exceptions.generator_exit))