Fix asyncio related compiler/library issues (#6837)

* Fix socket bytes support

* fix unwind_fblock

* fix posix.sendfile

* fix ssl_write

* Fix SSL ZeroReturn

* fix context

* fix generator

* Enable unittest test_async_case again
This commit is contained in:
Jeong, YunWon
2026-01-23 19:59:29 +09:00
committed by GitHub
parent 9b56aa5b60
commit efce325cbf
13 changed files with 175 additions and 67 deletions

View File

@@ -217,8 +217,6 @@ class ContextTest(unittest.TestCase):
ctx.run(fun)
# TODO: RUSTPYTHON
@unittest.expectedFailure
@isolated_context
def test_context_getset_1(self):
c = contextvars.ContextVar('c')
@@ -317,8 +315,6 @@ class ContextTest(unittest.TestCase):
with self.assertRaisesRegex(ValueError, 'different Context'):
c.reset(tok)
# TODO: RUSTPYTHON
@unittest.expectedFailure
@isolated_context
def test_context_getset_5(self):
c = contextvars.ContextVar('c', default=42)
@@ -332,8 +328,6 @@ class ContextTest(unittest.TestCase):
contextvars.copy_context().run(fun)
self.assertEqual(c.get(), [])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_context_copy_1(self):
ctx1 = contextvars.Context()
c = contextvars.ContextVar('c', default=42)

View File

@@ -2797,7 +2797,6 @@ class TestGetGeneratorState(unittest.TestCase):
self.assertIn(name, repr(state))
self.assertIn(name, str(state))
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_getgeneratorlocals(self):
def each(lst, a=None):
b=(1, 2, 3)
@@ -2985,7 +2984,6 @@ class TestGetAsyncGenState(unittest.IsolatedAsyncioTestCase):
self.assertIn(name, repr(state))
self.assertIn(name, str(state))
@unittest.expectedFailure # TODO: RUSTPYTHON
async def test_getasyncgenlocals(self):
async def each(lst, a=None):
b=(1, 2, 3)

View File

@@ -3525,7 +3525,6 @@ class ThreadedTests(unittest.TestCase):
else:
s.close()
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
def test_socketserver(self):
"""Using socketserver to create and manage SSL connections."""
server = make_https_server(self, certfile=SIGNED_CERTFILE)

View File

@@ -13,9 +13,7 @@ class MyException(Exception):
def tearDownModule():
# XXX: RUSTPYTHON; asyncio.events._set_event_loop_policy is not implemented
# asyncio.events._set_event_loop_policy(None)
pass
asyncio.events._set_event_loop_policy(None)
class TestCM:
@@ -52,7 +50,6 @@ class TestAsyncCase(unittest.TestCase):
# starting a new event loop
self.addCleanup(support.gc_collect)
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_full_cycle(self):
expected = ['setUp',
'asyncSetUp',

View File

@@ -1528,27 +1528,30 @@ impl Compiler {
// Otherwise, if an exception occurs during the finally body, the stack
// will be unwound to the wrong depth and the return value will be lost.
if preserve_tos {
// Get the handler info from the saved fblock (or current handler)
// and create a new handler with stack_depth + 1
let (handler, stack_depth, preserve_lasti) =
if let Some(handler) = saved_fblock.fb_handler {
(
Some(handler),
saved_fblock.fb_stack_depth + 1, // +1 for return value
saved_fblock.fb_preserve_lasti,
)
} else {
// No handler in saved_fblock, check current handler
if let Some(current_handler) = self.current_except_handler() {
(
Some(current_handler.handler_block),
current_handler.stack_depth + 1, // +1 for return value
current_handler.preserve_lasti,
)
} else {
(None, 1, false) // No handler, but still track the return value
// Find the outer handler for exceptions during finally body execution.
// CRITICAL: Only search fblocks with index < fblock_idx (= outer fblocks).
// Inner FinallyTry blocks may have been restored after their unwind
// processing, and we must NOT use their handlers - that would cause
// the inner finally body to execute again on exception.
let (handler, stack_depth, preserve_lasti) = {
let code = self.code_stack.last().unwrap();
let mut found = None;
// Only search fblocks at indices 0..fblock_idx (outer fblocks)
// After removal, fblock_idx now points to where saved_fblock was,
// so indices 0..fblock_idx are the outer fblocks
for i in (0..fblock_idx).rev() {
let fblock = &code.fblock[i];
if let Some(handler) = fblock.fb_handler {
found = Some((
Some(handler),
fblock.fb_stack_depth + 1, // +1 for return value
fblock.fb_preserve_lasti,
));
break;
}
};
}
found.unwrap_or((None, 1, false))
};
self.push_fblock_with_handler(
FBlockType::PopValue,

View File

@@ -168,11 +168,15 @@ mod _contextvars {
}
#[pymethod]
fn copy(&self) -> Self {
fn copy(&self, vm: &VirtualMachine) -> Self {
// Deep copy the vars - clone the underlying Hamt data, not just the PyRef
let vars_copy = HamtObject {
hamt: RefCell::new(self.inner.vars.hamt.borrow().clone()),
};
Self {
inner: ContextInner {
idx: Cell::new(usize::MAX),
vars: self.inner.vars.clone(),
vars: vars_copy.into_ref(&vm.ctx),
entered: Cell::new(false),
},
}
@@ -630,7 +634,7 @@ mod _contextvars {
#[pyfunction]
fn copy_context(vm: &VirtualMachine) -> PyContext {
PyContext::current(vm).copy()
PyContext::current(vm).copy(vm)
}
// Set Token.MISSING attribute

View File

@@ -15,7 +15,10 @@ mod _socket {
},
common::os::ErrorExt,
convert::{IntoPyException, ToPyObject, TryFromBorrowedObject, TryFromObject},
function::{ArgBytesLike, ArgMemoryBuffer, Either, FsPath, OptionalArg, OptionalOption},
function::{
ArgBytesLike, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, OptionalArg,
OptionalOption,
},
types::{Constructor, DefaultConstructor, Initializer, Representable},
utils::ToCString,
};
@@ -2783,9 +2786,9 @@ mod _socket {
#[derive(FromArgs)]
struct GAIOptions {
#[pyarg(positional)]
host: Option<PyStrRef>,
host: Option<ArgStrOrBytesLike>,
#[pyarg(positional)]
port: Option<Either<PyStrRef, i32>>,
port: Option<Either<ArgStrOrBytesLike, i32>>,
#[pyarg(positional, default = c::AF_UNSPEC)]
family: i32,
@@ -2809,9 +2812,9 @@ mod _socket {
flags: opts.flags,
};
// Encode host using IDNA encoding
// Encode host: str uses IDNA encoding, bytes must be valid UTF-8
let host_encoded: Option<String> = match opts.host.as_ref() {
Some(s) => {
Some(ArgStrOrBytesLike::Str(s)) => {
let encoded =
vm.state
.codec_registry
@@ -2820,19 +2823,43 @@ mod _socket {
.map_err(|_| vm.new_runtime_error("idna output is not utf8".to_owned()))?;
Some(host_str.to_owned())
}
Some(ArgStrOrBytesLike::Buf(b)) => {
let bytes = b.borrow_buf();
let host_str = core::str::from_utf8(&bytes).map_err(|_| {
vm.new_unicode_decode_error("host bytes is not utf8".to_owned())
})?;
Some(host_str.to_owned())
}
None => None,
};
let host = host_encoded.as_deref();
// Encode port using UTF-8
let port: Option<alloc::borrow::Cow<'_, str>> = match opts.port.as_ref() {
Some(Either::A(s)) => Some(alloc::borrow::Cow::Borrowed(s.to_str().ok_or_else(
|| vm.new_unicode_encode_error("surrogates not allowed".to_owned()),
)?)),
Some(Either::B(i)) => Some(alloc::borrow::Cow::Owned(i.to_string())),
// Encode port: str/bytes as service name, int as port number
let port_encoded: Option<String> = match opts.port.as_ref() {
Some(Either::A(sb)) => {
let port_str = match sb {
ArgStrOrBytesLike::Str(s) => {
// For str, check for surrogates and raise UnicodeEncodeError if found
s.to_str()
.ok_or_else(|| vm.new_unicode_encode_error("surrogates not allowed"))?
.to_owned()
}
ArgStrOrBytesLike::Buf(b) => {
// For bytes, check if it's valid UTF-8
let bytes = b.borrow_buf();
core::str::from_utf8(&bytes)
.map_err(|_| {
vm.new_unicode_decode_error("port is not utf8".to_owned())
})?
.to_owned()
}
};
Some(port_str)
}
Some(Either::B(i)) => Some(i.to_string()),
None => None,
};
let port = port.as_ref().map(|p| p.as_ref());
let port = port_encoded.as_deref();
let addrs = dns_lookup::getaddrinfo(host, port, Some(hints))
.map_err(|err| convert_socket_error(vm, err, SocketError::GaiError))?;

View File

@@ -53,6 +53,7 @@ mod _ssl {
// Import error types used in this module (others are exposed via pymodule(with(...)))
use super::error::{
PySSLError, create_ssl_eof_error, create_ssl_want_read_error, create_ssl_want_write_error,
create_ssl_zero_return_error,
};
use alloc::sync::Arc;
use core::{
@@ -3593,7 +3594,7 @@ mod _ssl {
let mut conn_guard = self.connection.lock();
let conn = match conn_guard.as_mut() {
Some(conn) => conn,
None => return return_data(vec![], &buffer, vm),
None => return Err(create_ssl_zero_return_error(vm).upcast()),
};
use std::io::BufRead;
let mut reader = conn.reader();
@@ -3613,8 +3614,20 @@ mod _ssl {
return return_data(buf, &buffer, vm);
}
}
// Clean closure with close_notify - return empty data
return_data(vec![], &buffer, vm)
// Clean closure with close_notify
// CPython behavior depends on whether we've sent our close_notify:
// - If we've already sent close_notify (unwrap was called): raise SSLZeroReturnError
// - If we haven't sent close_notify yet: return empty bytes
let our_shutdown_state = *self.shutdown_state.lock();
if our_shutdown_state == ShutdownState::SentCloseNotify
|| our_shutdown_state == ShutdownState::Completed
{
// We already sent close_notify, now receiving peer's → SSLZeroReturnError
Err(create_ssl_zero_return_error(vm).upcast())
} else {
// We haven't sent close_notify yet → return empty bytes
return_data(vec![], &buffer, vm)
}
}
Err(crate::ssl::compat::SslError::WantRead) => {
// Non-blocking mode: would block

View File

@@ -1552,6 +1552,11 @@ pub(super) fn ssl_read(
// Try to read plaintext from rustls buffer
if let Some(n) = try_read_plaintext(conn, buf)? {
if n == 0 {
// EOF from TLS - close_notify received
// Return ZeroReturn so Python raises SSLZeroReturnError
return Err(SslError::ZeroReturn);
}
return Ok(n);
}
@@ -1740,17 +1745,40 @@ pub(super) fn ssl_write(
let already_buffered = *socket.write_buffered_len.lock();
// Only write plaintext if not already buffered
// Track how much we wrote for partial write handling
let mut bytes_written_to_rustls = 0usize;
if already_buffered == 0 {
// Write plaintext to rustls (= SSL_write_ex internal buffer write)
{
bytes_written_to_rustls = {
let mut writer = conn.writer();
use std::io::Write;
writer
.write_all(data)
.map_err(|e| SslError::Syscall(format!("Write failed: {e}")))?;
}
// Mark data as buffered
*socket.write_buffered_len.lock() = data.len();
// Use write() instead of write_all() to support partial writes.
// In BIO mode (asyncio), when the internal buffer is full,
// we want to write as much as possible and return that count,
// rather than failing completely.
match writer.write(data) {
Ok(0) if !data.is_empty() => {
// Buffer is full and nothing could be written.
// In BIO mode, return WantWrite so the caller can
// drain the outgoing BIO and retry.
if is_bio {
return Err(SslError::WantWrite);
}
return Err(SslError::Syscall("Write failed: buffer full".to_string()));
}
Ok(n) => n,
Err(e) => {
if is_bio {
// In BIO mode, treat write errors as WantWrite
return Err(SslError::WantWrite);
}
return Err(SslError::Syscall(format!("Write failed: {e}")));
}
}
};
// Mark data as buffered (only the portion we actually wrote)
*socket.write_buffered_len.lock() = bytes_written_to_rustls;
} else if already_buffered != data.len() {
// Caller is retrying with different data - this is a protocol error
// Clear the buffer state and return an SSL error (bad write retry)
@@ -1790,13 +1818,23 @@ pub(super) fn ssl_write(
}
Err(SslError::WantWrite) => {
// Non-blocking socket would block - return WANT_WRITE
// If we had a partial write to rustls, return partial success
// instead of error to match OpenSSL partial-write semantics
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
*socket.write_buffered_len.lock() = 0;
return Ok(bytes_written_to_rustls);
}
// Keep write_buffered_len set so we don't re-buffer on retry
return Err(SslError::WantWrite);
}
Err(SslError::WantRead) => {
// Need to read before write can complete (e.g., renegotiation)
// This matches CPython's handling of SSL_ERROR_WANT_READ in write
if is_bio {
// If we had a partial write to rustls, return partial success
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
*socket.write_buffered_len.lock() = 0;
return Ok(bytes_written_to_rustls);
}
// Keep write_buffered_len set so we don't re-buffer on retry
return Err(SslError::WantRead);
}
@@ -1807,6 +1845,11 @@ pub(super) fn ssl_write(
// Continue loop
}
Err(e @ SslError::Timeout(_)) => {
// If we had a partial write to rustls, return partial success
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
*socket.write_buffered_len.lock() = 0;
return Ok(bytes_written_to_rustls);
}
// Preserve buffered state so retry doesn't duplicate data
// (send_all_bytes saved unsent TLS bytes to pending_tls_output)
return Err(e);
@@ -1826,10 +1869,21 @@ pub(super) fn ssl_write(
.map_err(SslError::Py)?;
}
// Determine how many bytes we actually wrote
let actual_written = if bytes_written_to_rustls > 0 {
// Fresh write: return what we wrote to rustls
bytes_written_to_rustls
} else if already_buffered > 0 {
// Retry of previous write: return the full buffered amount
already_buffered
} else {
data.len()
};
// Write completed successfully - clear buffer state
*socket.write_buffered_len.lock() = 0;
Ok(data.len())
Ok(actual_written)
}
// Helper functions (private-ish, used by public SSL functions)

View File

@@ -123,8 +123,12 @@ impl PyAsyncGen {
self.inner.frame().yield_from_target()
}
#[pygetset]
fn ag_frame(&self, _vm: &VirtualMachine) -> FrameRef {
self.inner.frame()
fn ag_frame(&self, _vm: &VirtualMachine) -> Option<FrameRef> {
if self.inner.closed() {
None
} else {
Some(self.inner.frame())
}
}
#[pygetset]
fn ag_running(&self, _vm: &VirtualMachine) -> bool {

View File

@@ -76,8 +76,12 @@ impl PyCoroutine {
self.inner.frame().yield_from_target()
}
#[pygetset]
fn cr_frame(&self, _vm: &VirtualMachine) -> FrameRef {
self.inner.frame()
fn cr_frame(&self, _vm: &VirtualMachine) -> Option<FrameRef> {
if self.inner.closed() {
None
} else {
Some(self.inner.frame())
}
}
#[pygetset]
fn cr_running(&self, _vm: &VirtualMachine) -> bool {

View File

@@ -66,8 +66,12 @@ impl PyGenerator {
}
#[pygetset]
fn gi_frame(&self, _vm: &VirtualMachine) -> FrameRef {
self.inner.frame()
fn gi_frame(&self, _vm: &VirtualMachine) -> Option<FrameRef> {
if self.inner.closed() {
None
} else {
Some(self.inner.frame())
}
}
#[pygetset]

View File

@@ -2574,7 +2574,14 @@ pub mod module {
headers,
trailers,
);
res.map_err(|err| err.into_pyexception(vm))?;
// On macOS, sendfile can return EAGAIN even when some bytes were written.
// In that case, we should return the number of bytes written rather than
// raising an exception. Only raise an error if no bytes were written.
if let Err(err) = res
&& written == 0
{
return Err(err.into_pyexception(vm));
}
Ok(vm.ctx.new_int(written as u64).into())
}