Make PyAsyncGen, PyAsyncGenASend, PyAsyncGenAThrow ThreadSafe

This commit is contained in:
Aviv Palivoda
2020-05-08 10:27:29 +03:00
parent ab628a3265
commit 73d9bac4f3

View File

@@ -4,18 +4,19 @@ use super::objtype::{self, PyClassRef};
use crate::exceptions::PyBaseExceptionRef;
use crate::frame::FrameRef;
use crate::function::OptionalArg;
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe};
use crate::vm::VirtualMachine;
use std::cell::Cell;
use crossbeam_utils::atomic::AtomicCell;
#[pyclass(name = "async_generator")]
#[derive(Debug)]
pub struct PyAsyncGen {
inner: Coro,
running_async: Cell<bool>,
running_async: AtomicCell<bool>,
}
pub type PyAsyncGenRef = PyRef<PyAsyncGen>;
impl ThreadSafe for PyAsyncGen {}
impl PyValue for PyAsyncGen {
fn class(vm: &VirtualMachine) -> PyClassRef {
@@ -32,7 +33,7 @@ impl PyAsyncGen {
pub fn new(frame: FrameRef, vm: &VirtualMachine) -> PyAsyncGenRef {
PyAsyncGen {
inner: Coro::new(frame, Variant::AsyncGen),
running_async: Cell::new(false),
running_async: AtomicCell::new(false),
}
.into_ref(vm)
}
@@ -57,7 +58,7 @@ impl PyAsyncGen {
fn asend(zelf: PyRef<Self>, value: PyObjectRef, _vm: &VirtualMachine) -> PyAsyncGenASend {
PyAsyncGenASend {
ag: zelf,
state: Cell::new(AwaitableState::Init),
state: AtomicCell::new(AwaitableState::Init),
value,
}
}
@@ -73,7 +74,7 @@ impl PyAsyncGen {
PyAsyncGenAThrow {
ag: zelf,
aclose: false,
state: Cell::new(AwaitableState::Init),
state: AtomicCell::new(AwaitableState::Init),
value: (
exc_type,
exc_val.unwrap_or_else(|| vm.get_none()),
@@ -87,7 +88,7 @@ impl PyAsyncGen {
PyAsyncGenAThrow {
ag: zelf,
aclose: true,
state: Cell::new(AwaitableState::Init),
state: AtomicCell::new(AwaitableState::Init),
value: (
vm.ctx.exceptions.generator_exit.clone().into_object(),
vm.get_none(),
@@ -131,13 +132,13 @@ impl PyAsyncGenWrappedValue {
{
ag.inner.closed.store(true);
}
ag.running_async.set(false);
ag.running_async.store(false);
}
let val = val?;
match_class!(match val {
val @ Self => {
ag.running_async.set(false);
ag.running_async.store(false);
Err(vm.new_exception(
vm.ctx.exceptions.stop_iteration.clone(),
vec![val.0.clone()],
@@ -159,10 +160,12 @@ enum AwaitableState {
#[derive(Debug)]
struct PyAsyncGenASend {
ag: PyAsyncGenRef,
state: Cell<AwaitableState>,
state: AtomicCell<AwaitableState>,
value: PyObjectRef,
}
impl ThreadSafe for PyAsyncGenASend {}
impl PyValue for PyAsyncGenASend {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.ctx.types.async_generator_asend.clone()
@@ -187,7 +190,7 @@ impl PyAsyncGenASend {
#[pymethod]
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let val = match self.state.get() {
let val = match self.state.load() {
AwaitableState::Closed => {
return Err(vm.new_runtime_error(
"cannot reuse already awaited __anext__()/asend()".to_owned(),
@@ -195,13 +198,13 @@ impl PyAsyncGenASend {
}
AwaitableState::Iter => val, // already running, all good
AwaitableState::Init => {
if self.ag.running_async.get() {
if self.ag.running_async.load() {
return Err(vm.new_runtime_error(
"anext(): asynchronous generator is already running".to_owned(),
));
}
self.ag.running_async.set(true);
self.state.set(AwaitableState::Iter);
self.ag.running_async.store(true);
self.state.store(AwaitableState::Iter);
if vm.is_none(&val) {
self.value.clone()
} else {
@@ -225,7 +228,7 @@ impl PyAsyncGenASend {
exc_tb: OptionalArg,
vm: &VirtualMachine,
) -> PyResult {
if let AwaitableState::Closed = self.state.get() {
if let AwaitableState::Closed = self.state.load() {
return Err(
vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()".to_owned())
);
@@ -246,7 +249,7 @@ impl PyAsyncGenASend {
#[pymethod]
fn close(&self) {
self.state.set(AwaitableState::Closed);
self.state.store(AwaitableState::Closed);
}
}
@@ -255,10 +258,12 @@ impl PyAsyncGenASend {
struct PyAsyncGenAThrow {
ag: PyAsyncGenRef,
aclose: bool,
state: Cell<AwaitableState>,
state: AtomicCell<AwaitableState>,
value: (PyObjectRef, PyObjectRef, PyObjectRef),
}
impl ThreadSafe for PyAsyncGenAThrow {}
impl PyValue for PyAsyncGenAThrow {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.ctx.types.async_generator_athrow.clone()
@@ -283,14 +288,14 @@ impl PyAsyncGenAThrow {
#[pymethod]
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match self.state.get() {
match self.state.load() {
AwaitableState::Closed => {
Err(vm
.new_runtime_error("cannot reuse already awaited aclose()/athrow()".to_owned()))
}
AwaitableState::Init => {
if self.ag.running_async.get() {
self.state.set(AwaitableState::Closed);
if self.ag.running_async.load() {
self.state.store(AwaitableState::Closed);
let msg = if self.aclose {
"aclose(): asynchronous generator is already running"
} else {
@@ -299,7 +304,7 @@ impl PyAsyncGenAThrow {
return Err(vm.new_runtime_error(msg.to_owned()));
}
if self.ag.inner.closed() {
self.state.set(AwaitableState::Closed);
self.state.store(AwaitableState::Closed);
return Err(vm.new_exception_empty(vm.ctx.exceptions.stop_iteration.clone()));
}
if !vm.is_none(&val) {
@@ -307,8 +312,8 @@ impl PyAsyncGenAThrow {
"can't send non-None value to a just-started async generator".to_owned(),
));
}
self.state.set(AwaitableState::Iter);
self.ag.running_async.set(true);
self.state.store(AwaitableState::Iter);
self.ag.running_async.store(true);
let (ty, val, tb) = self.value.clone();
let ret = self.ag.inner.throw(ty, val, tb, vm);
@@ -368,7 +373,7 @@ impl PyAsyncGenAThrow {
#[pymethod]
fn close(&self) {
self.state.set(AwaitableState::Closed);
self.state.store(AwaitableState::Closed);
}
fn ignored_close(&self, res: &PyResult) -> bool {
@@ -376,13 +381,13 @@ impl PyAsyncGenAThrow {
.map_or(false, |v| v.payload_is::<PyAsyncGenWrappedValue>())
}
fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
self.ag.running_async.set(false);
self.state.set(AwaitableState::Closed);
self.ag.running_async.store(false);
self.state.store(AwaitableState::Closed);
vm.new_runtime_error("async generator ignored GeneratorExit".to_owned())
}
fn check_error(&self, exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyBaseExceptionRef {
self.ag.running_async.set(false);
self.state.set(AwaitableState::Closed);
self.ag.running_async.store(false);
self.state.store(AwaitableState::Closed);
if self.aclose
&& (objtype::isinstance(&exc, &vm.ctx.exceptions.stop_async_iteration)
|| objtype::isinstance(&exc, &vm.ctx.exceptions.generator_exit))