forked from Rust-related/RustPython
Make PyAsyncGen, PyAsyncGenASend, PyAsyncGenAThrow ThreadSafe
This commit is contained in:
@@ -4,18 +4,19 @@ use super::objtype::{self, PyClassRef};
|
||||
use crate::exceptions::PyBaseExceptionRef;
|
||||
use crate::frame::FrameRef;
|
||||
use crate::function::OptionalArg;
|
||||
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
|
||||
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe};
|
||||
use crate::vm::VirtualMachine;
|
||||
|
||||
use std::cell::Cell;
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
|
||||
#[pyclass(name = "async_generator")]
|
||||
#[derive(Debug)]
|
||||
pub struct PyAsyncGen {
|
||||
inner: Coro,
|
||||
running_async: Cell<bool>,
|
||||
running_async: AtomicCell<bool>,
|
||||
}
|
||||
pub type PyAsyncGenRef = PyRef<PyAsyncGen>;
|
||||
impl ThreadSafe for PyAsyncGen {}
|
||||
|
||||
impl PyValue for PyAsyncGen {
|
||||
fn class(vm: &VirtualMachine) -> PyClassRef {
|
||||
@@ -32,7 +33,7 @@ impl PyAsyncGen {
|
||||
pub fn new(frame: FrameRef, vm: &VirtualMachine) -> PyAsyncGenRef {
|
||||
PyAsyncGen {
|
||||
inner: Coro::new(frame, Variant::AsyncGen),
|
||||
running_async: Cell::new(false),
|
||||
running_async: AtomicCell::new(false),
|
||||
}
|
||||
.into_ref(vm)
|
||||
}
|
||||
@@ -57,7 +58,7 @@ impl PyAsyncGen {
|
||||
fn asend(zelf: PyRef<Self>, value: PyObjectRef, _vm: &VirtualMachine) -> PyAsyncGenASend {
|
||||
PyAsyncGenASend {
|
||||
ag: zelf,
|
||||
state: Cell::new(AwaitableState::Init),
|
||||
state: AtomicCell::new(AwaitableState::Init),
|
||||
value,
|
||||
}
|
||||
}
|
||||
@@ -73,7 +74,7 @@ impl PyAsyncGen {
|
||||
PyAsyncGenAThrow {
|
||||
ag: zelf,
|
||||
aclose: false,
|
||||
state: Cell::new(AwaitableState::Init),
|
||||
state: AtomicCell::new(AwaitableState::Init),
|
||||
value: (
|
||||
exc_type,
|
||||
exc_val.unwrap_or_else(|| vm.get_none()),
|
||||
@@ -87,7 +88,7 @@ impl PyAsyncGen {
|
||||
PyAsyncGenAThrow {
|
||||
ag: zelf,
|
||||
aclose: true,
|
||||
state: Cell::new(AwaitableState::Init),
|
||||
state: AtomicCell::new(AwaitableState::Init),
|
||||
value: (
|
||||
vm.ctx.exceptions.generator_exit.clone().into_object(),
|
||||
vm.get_none(),
|
||||
@@ -131,13 +132,13 @@ impl PyAsyncGenWrappedValue {
|
||||
{
|
||||
ag.inner.closed.store(true);
|
||||
}
|
||||
ag.running_async.set(false);
|
||||
ag.running_async.store(false);
|
||||
}
|
||||
let val = val?;
|
||||
|
||||
match_class!(match val {
|
||||
val @ Self => {
|
||||
ag.running_async.set(false);
|
||||
ag.running_async.store(false);
|
||||
Err(vm.new_exception(
|
||||
vm.ctx.exceptions.stop_iteration.clone(),
|
||||
vec![val.0.clone()],
|
||||
@@ -159,10 +160,12 @@ enum AwaitableState {
|
||||
#[derive(Debug)]
|
||||
struct PyAsyncGenASend {
|
||||
ag: PyAsyncGenRef,
|
||||
state: Cell<AwaitableState>,
|
||||
state: AtomicCell<AwaitableState>,
|
||||
value: PyObjectRef,
|
||||
}
|
||||
|
||||
impl ThreadSafe for PyAsyncGenASend {}
|
||||
|
||||
impl PyValue for PyAsyncGenASend {
|
||||
fn class(vm: &VirtualMachine) -> PyClassRef {
|
||||
vm.ctx.types.async_generator_asend.clone()
|
||||
@@ -187,7 +190,7 @@ impl PyAsyncGenASend {
|
||||
|
||||
#[pymethod]
|
||||
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
|
||||
let val = match self.state.get() {
|
||||
let val = match self.state.load() {
|
||||
AwaitableState::Closed => {
|
||||
return Err(vm.new_runtime_error(
|
||||
"cannot reuse already awaited __anext__()/asend()".to_owned(),
|
||||
@@ -195,13 +198,13 @@ impl PyAsyncGenASend {
|
||||
}
|
||||
AwaitableState::Iter => val, // already running, all good
|
||||
AwaitableState::Init => {
|
||||
if self.ag.running_async.get() {
|
||||
if self.ag.running_async.load() {
|
||||
return Err(vm.new_runtime_error(
|
||||
"anext(): asynchronous generator is already running".to_owned(),
|
||||
));
|
||||
}
|
||||
self.ag.running_async.set(true);
|
||||
self.state.set(AwaitableState::Iter);
|
||||
self.ag.running_async.store(true);
|
||||
self.state.store(AwaitableState::Iter);
|
||||
if vm.is_none(&val) {
|
||||
self.value.clone()
|
||||
} else {
|
||||
@@ -225,7 +228,7 @@ impl PyAsyncGenASend {
|
||||
exc_tb: OptionalArg,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult {
|
||||
if let AwaitableState::Closed = self.state.get() {
|
||||
if let AwaitableState::Closed = self.state.load() {
|
||||
return Err(
|
||||
vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()".to_owned())
|
||||
);
|
||||
@@ -246,7 +249,7 @@ impl PyAsyncGenASend {
|
||||
|
||||
#[pymethod]
|
||||
fn close(&self) {
|
||||
self.state.set(AwaitableState::Closed);
|
||||
self.state.store(AwaitableState::Closed);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,10 +258,12 @@ impl PyAsyncGenASend {
|
||||
struct PyAsyncGenAThrow {
|
||||
ag: PyAsyncGenRef,
|
||||
aclose: bool,
|
||||
state: Cell<AwaitableState>,
|
||||
state: AtomicCell<AwaitableState>,
|
||||
value: (PyObjectRef, PyObjectRef, PyObjectRef),
|
||||
}
|
||||
|
||||
impl ThreadSafe for PyAsyncGenAThrow {}
|
||||
|
||||
impl PyValue for PyAsyncGenAThrow {
|
||||
fn class(vm: &VirtualMachine) -> PyClassRef {
|
||||
vm.ctx.types.async_generator_athrow.clone()
|
||||
@@ -283,14 +288,14 @@ impl PyAsyncGenAThrow {
|
||||
|
||||
#[pymethod]
|
||||
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
|
||||
match self.state.get() {
|
||||
match self.state.load() {
|
||||
AwaitableState::Closed => {
|
||||
Err(vm
|
||||
.new_runtime_error("cannot reuse already awaited aclose()/athrow()".to_owned()))
|
||||
}
|
||||
AwaitableState::Init => {
|
||||
if self.ag.running_async.get() {
|
||||
self.state.set(AwaitableState::Closed);
|
||||
if self.ag.running_async.load() {
|
||||
self.state.store(AwaitableState::Closed);
|
||||
let msg = if self.aclose {
|
||||
"aclose(): asynchronous generator is already running"
|
||||
} else {
|
||||
@@ -299,7 +304,7 @@ impl PyAsyncGenAThrow {
|
||||
return Err(vm.new_runtime_error(msg.to_owned()));
|
||||
}
|
||||
if self.ag.inner.closed() {
|
||||
self.state.set(AwaitableState::Closed);
|
||||
self.state.store(AwaitableState::Closed);
|
||||
return Err(vm.new_exception_empty(vm.ctx.exceptions.stop_iteration.clone()));
|
||||
}
|
||||
if !vm.is_none(&val) {
|
||||
@@ -307,8 +312,8 @@ impl PyAsyncGenAThrow {
|
||||
"can't send non-None value to a just-started async generator".to_owned(),
|
||||
));
|
||||
}
|
||||
self.state.set(AwaitableState::Iter);
|
||||
self.ag.running_async.set(true);
|
||||
self.state.store(AwaitableState::Iter);
|
||||
self.ag.running_async.store(true);
|
||||
|
||||
let (ty, val, tb) = self.value.clone();
|
||||
let ret = self.ag.inner.throw(ty, val, tb, vm);
|
||||
@@ -368,7 +373,7 @@ impl PyAsyncGenAThrow {
|
||||
|
||||
#[pymethod]
|
||||
fn close(&self) {
|
||||
self.state.set(AwaitableState::Closed);
|
||||
self.state.store(AwaitableState::Closed);
|
||||
}
|
||||
|
||||
fn ignored_close(&self, res: &PyResult) -> bool {
|
||||
@@ -376,13 +381,13 @@ impl PyAsyncGenAThrow {
|
||||
.map_or(false, |v| v.payload_is::<PyAsyncGenWrappedValue>())
|
||||
}
|
||||
fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
|
||||
self.ag.running_async.set(false);
|
||||
self.state.set(AwaitableState::Closed);
|
||||
self.ag.running_async.store(false);
|
||||
self.state.store(AwaitableState::Closed);
|
||||
vm.new_runtime_error("async generator ignored GeneratorExit".to_owned())
|
||||
}
|
||||
fn check_error(&self, exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyBaseExceptionRef {
|
||||
self.ag.running_async.set(false);
|
||||
self.state.set(AwaitableState::Closed);
|
||||
self.ag.running_async.store(false);
|
||||
self.state.store(AwaitableState::Closed);
|
||||
if self.aclose
|
||||
&& (objtype::isinstance(&exc, &vm.ctx.exceptions.stop_async_iteration)
|
||||
|| objtype::isinstance(&exc, &vm.ctx.exceptions.generator_exit))
|
||||
|
||||
Reference in New Issue
Block a user