mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Fix asyncio related compiler/library issues (#6837)
* Fix socket bytes support * fix unwind_fblock * fix posix.sendfile * fix ssl_write * Fix SSL ZeroReturn * fix context * fix generator * Enable unittest test_async_case again
This commit is contained in:
6
Lib/test/test_context.py
vendored
6
Lib/test/test_context.py
vendored
@@ -217,8 +217,6 @@ class ContextTest(unittest.TestCase):
|
||||
|
||||
ctx.run(fun)
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@isolated_context
|
||||
def test_context_getset_1(self):
|
||||
c = contextvars.ContextVar('c')
|
||||
@@ -317,8 +315,6 @@ class ContextTest(unittest.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, 'different Context'):
|
||||
c.reset(tok)
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
@isolated_context
|
||||
def test_context_getset_5(self):
|
||||
c = contextvars.ContextVar('c', default=42)
|
||||
@@ -332,8 +328,6 @@ class ContextTest(unittest.TestCase):
|
||||
contextvars.copy_context().run(fun)
|
||||
self.assertEqual(c.get(), [])
|
||||
|
||||
# TODO: RUSTPYTHON
|
||||
@unittest.expectedFailure
|
||||
def test_context_copy_1(self):
|
||||
ctx1 = contextvars.Context()
|
||||
c = contextvars.ContextVar('c', default=42)
|
||||
|
||||
2
Lib/test/test_inspect/test_inspect.py
vendored
2
Lib/test/test_inspect/test_inspect.py
vendored
@@ -2797,7 +2797,6 @@ class TestGetGeneratorState(unittest.TestCase):
|
||||
self.assertIn(name, repr(state))
|
||||
self.assertIn(name, str(state))
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_getgeneratorlocals(self):
|
||||
def each(lst, a=None):
|
||||
b=(1, 2, 3)
|
||||
@@ -2985,7 +2984,6 @@ class TestGetAsyncGenState(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertIn(name, repr(state))
|
||||
self.assertIn(name, str(state))
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
async def test_getasyncgenlocals(self):
|
||||
async def each(lst, a=None):
|
||||
b=(1, 2, 3)
|
||||
|
||||
1
Lib/test/test_ssl.py
vendored
1
Lib/test/test_ssl.py
vendored
@@ -3525,7 +3525,6 @@ class ThreadedTests(unittest.TestCase):
|
||||
else:
|
||||
s.close()
|
||||
|
||||
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
|
||||
def test_socketserver(self):
|
||||
"""Using socketserver to create and manage SSL connections."""
|
||||
server = make_https_server(self, certfile=SIGNED_CERTFILE)
|
||||
|
||||
5
Lib/test/test_unittest/test_async_case.py
vendored
5
Lib/test/test_unittest/test_async_case.py
vendored
@@ -13,9 +13,7 @@ class MyException(Exception):
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
# XXX: RUSTPYTHON; asyncio.events._set_event_loop_policy is not implemented
|
||||
# asyncio.events._set_event_loop_policy(None)
|
||||
pass
|
||||
asyncio.events._set_event_loop_policy(None)
|
||||
|
||||
|
||||
class TestCM:
|
||||
@@ -52,7 +50,6 @@ class TestAsyncCase(unittest.TestCase):
|
||||
# starting a new event loop
|
||||
self.addCleanup(support.gc_collect)
|
||||
|
||||
@unittest.expectedFailure # TODO: RUSTPYTHON
|
||||
def test_full_cycle(self):
|
||||
expected = ['setUp',
|
||||
'asyncSetUp',
|
||||
|
||||
@@ -1528,27 +1528,30 @@ impl Compiler {
|
||||
// Otherwise, if an exception occurs during the finally body, the stack
|
||||
// will be unwound to the wrong depth and the return value will be lost.
|
||||
if preserve_tos {
|
||||
// Get the handler info from the saved fblock (or current handler)
|
||||
// and create a new handler with stack_depth + 1
|
||||
let (handler, stack_depth, preserve_lasti) =
|
||||
if let Some(handler) = saved_fblock.fb_handler {
|
||||
(
|
||||
Some(handler),
|
||||
saved_fblock.fb_stack_depth + 1, // +1 for return value
|
||||
saved_fblock.fb_preserve_lasti,
|
||||
)
|
||||
} else {
|
||||
// No handler in saved_fblock, check current handler
|
||||
if let Some(current_handler) = self.current_except_handler() {
|
||||
(
|
||||
Some(current_handler.handler_block),
|
||||
current_handler.stack_depth + 1, // +1 for return value
|
||||
current_handler.preserve_lasti,
|
||||
)
|
||||
} else {
|
||||
(None, 1, false) // No handler, but still track the return value
|
||||
// Find the outer handler for exceptions during finally body execution.
|
||||
// CRITICAL: Only search fblocks with index < fblock_idx (= outer fblocks).
|
||||
// Inner FinallyTry blocks may have been restored after their unwind
|
||||
// processing, and we must NOT use their handlers - that would cause
|
||||
// the inner finally body to execute again on exception.
|
||||
let (handler, stack_depth, preserve_lasti) = {
|
||||
let code = self.code_stack.last().unwrap();
|
||||
let mut found = None;
|
||||
// Only search fblocks at indices 0..fblock_idx (outer fblocks)
|
||||
// After removal, fblock_idx now points to where saved_fblock was,
|
||||
// so indices 0..fblock_idx are the outer fblocks
|
||||
for i in (0..fblock_idx).rev() {
|
||||
let fblock = &code.fblock[i];
|
||||
if let Some(handler) = fblock.fb_handler {
|
||||
found = Some((
|
||||
Some(handler),
|
||||
fblock.fb_stack_depth + 1, // +1 for return value
|
||||
fblock.fb_preserve_lasti,
|
||||
));
|
||||
break;
|
||||
}
|
||||
};
|
||||
}
|
||||
found.unwrap_or((None, 1, false))
|
||||
};
|
||||
|
||||
self.push_fblock_with_handler(
|
||||
FBlockType::PopValue,
|
||||
|
||||
@@ -168,11 +168,15 @@ mod _contextvars {
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn copy(&self) -> Self {
|
||||
fn copy(&self, vm: &VirtualMachine) -> Self {
|
||||
// Deep copy the vars - clone the underlying Hamt data, not just the PyRef
|
||||
let vars_copy = HamtObject {
|
||||
hamt: RefCell::new(self.inner.vars.hamt.borrow().clone()),
|
||||
};
|
||||
Self {
|
||||
inner: ContextInner {
|
||||
idx: Cell::new(usize::MAX),
|
||||
vars: self.inner.vars.clone(),
|
||||
vars: vars_copy.into_ref(&vm.ctx),
|
||||
entered: Cell::new(false),
|
||||
},
|
||||
}
|
||||
@@ -630,7 +634,7 @@ mod _contextvars {
|
||||
|
||||
#[pyfunction]
|
||||
fn copy_context(vm: &VirtualMachine) -> PyContext {
|
||||
PyContext::current(vm).copy()
|
||||
PyContext::current(vm).copy(vm)
|
||||
}
|
||||
|
||||
// Set Token.MISSING attribute
|
||||
|
||||
@@ -15,7 +15,10 @@ mod _socket {
|
||||
},
|
||||
common::os::ErrorExt,
|
||||
convert::{IntoPyException, ToPyObject, TryFromBorrowedObject, TryFromObject},
|
||||
function::{ArgBytesLike, ArgMemoryBuffer, Either, FsPath, OptionalArg, OptionalOption},
|
||||
function::{
|
||||
ArgBytesLike, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, OptionalArg,
|
||||
OptionalOption,
|
||||
},
|
||||
types::{Constructor, DefaultConstructor, Initializer, Representable},
|
||||
utils::ToCString,
|
||||
};
|
||||
@@ -2783,9 +2786,9 @@ mod _socket {
|
||||
#[derive(FromArgs)]
|
||||
struct GAIOptions {
|
||||
#[pyarg(positional)]
|
||||
host: Option<PyStrRef>,
|
||||
host: Option<ArgStrOrBytesLike>,
|
||||
#[pyarg(positional)]
|
||||
port: Option<Either<PyStrRef, i32>>,
|
||||
port: Option<Either<ArgStrOrBytesLike, i32>>,
|
||||
|
||||
#[pyarg(positional, default = c::AF_UNSPEC)]
|
||||
family: i32,
|
||||
@@ -2809,9 +2812,9 @@ mod _socket {
|
||||
flags: opts.flags,
|
||||
};
|
||||
|
||||
// Encode host using IDNA encoding
|
||||
// Encode host: str uses IDNA encoding, bytes must be valid UTF-8
|
||||
let host_encoded: Option<String> = match opts.host.as_ref() {
|
||||
Some(s) => {
|
||||
Some(ArgStrOrBytesLike::Str(s)) => {
|
||||
let encoded =
|
||||
vm.state
|
||||
.codec_registry
|
||||
@@ -2820,19 +2823,43 @@ mod _socket {
|
||||
.map_err(|_| vm.new_runtime_error("idna output is not utf8".to_owned()))?;
|
||||
Some(host_str.to_owned())
|
||||
}
|
||||
Some(ArgStrOrBytesLike::Buf(b)) => {
|
||||
let bytes = b.borrow_buf();
|
||||
let host_str = core::str::from_utf8(&bytes).map_err(|_| {
|
||||
vm.new_unicode_decode_error("host bytes is not utf8".to_owned())
|
||||
})?;
|
||||
Some(host_str.to_owned())
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
let host = host_encoded.as_deref();
|
||||
|
||||
// Encode port using UTF-8
|
||||
let port: Option<alloc::borrow::Cow<'_, str>> = match opts.port.as_ref() {
|
||||
Some(Either::A(s)) => Some(alloc::borrow::Cow::Borrowed(s.to_str().ok_or_else(
|
||||
|| vm.new_unicode_encode_error("surrogates not allowed".to_owned()),
|
||||
)?)),
|
||||
Some(Either::B(i)) => Some(alloc::borrow::Cow::Owned(i.to_string())),
|
||||
// Encode port: str/bytes as service name, int as port number
|
||||
let port_encoded: Option<String> = match opts.port.as_ref() {
|
||||
Some(Either::A(sb)) => {
|
||||
let port_str = match sb {
|
||||
ArgStrOrBytesLike::Str(s) => {
|
||||
// For str, check for surrogates and raise UnicodeEncodeError if found
|
||||
s.to_str()
|
||||
.ok_or_else(|| vm.new_unicode_encode_error("surrogates not allowed"))?
|
||||
.to_owned()
|
||||
}
|
||||
ArgStrOrBytesLike::Buf(b) => {
|
||||
// For bytes, check if it's valid UTF-8
|
||||
let bytes = b.borrow_buf();
|
||||
core::str::from_utf8(&bytes)
|
||||
.map_err(|_| {
|
||||
vm.new_unicode_decode_error("port is not utf8".to_owned())
|
||||
})?
|
||||
.to_owned()
|
||||
}
|
||||
};
|
||||
Some(port_str)
|
||||
}
|
||||
Some(Either::B(i)) => Some(i.to_string()),
|
||||
None => None,
|
||||
};
|
||||
let port = port.as_ref().map(|p| p.as_ref());
|
||||
let port = port_encoded.as_deref();
|
||||
|
||||
let addrs = dns_lookup::getaddrinfo(host, port, Some(hints))
|
||||
.map_err(|err| convert_socket_error(vm, err, SocketError::GaiError))?;
|
||||
|
||||
@@ -53,6 +53,7 @@ mod _ssl {
|
||||
// Import error types used in this module (others are exposed via pymodule(with(...)))
|
||||
use super::error::{
|
||||
PySSLError, create_ssl_eof_error, create_ssl_want_read_error, create_ssl_want_write_error,
|
||||
create_ssl_zero_return_error,
|
||||
};
|
||||
use alloc::sync::Arc;
|
||||
use core::{
|
||||
@@ -3593,7 +3594,7 @@ mod _ssl {
|
||||
let mut conn_guard = self.connection.lock();
|
||||
let conn = match conn_guard.as_mut() {
|
||||
Some(conn) => conn,
|
||||
None => return return_data(vec![], &buffer, vm),
|
||||
None => return Err(create_ssl_zero_return_error(vm).upcast()),
|
||||
};
|
||||
use std::io::BufRead;
|
||||
let mut reader = conn.reader();
|
||||
@@ -3613,8 +3614,20 @@ mod _ssl {
|
||||
return return_data(buf, &buffer, vm);
|
||||
}
|
||||
}
|
||||
// Clean closure with close_notify - return empty data
|
||||
return_data(vec![], &buffer, vm)
|
||||
// Clean closure with close_notify
|
||||
// CPython behavior depends on whether we've sent our close_notify:
|
||||
// - If we've already sent close_notify (unwrap was called): raise SSLZeroReturnError
|
||||
// - If we haven't sent close_notify yet: return empty bytes
|
||||
let our_shutdown_state = *self.shutdown_state.lock();
|
||||
if our_shutdown_state == ShutdownState::SentCloseNotify
|
||||
|| our_shutdown_state == ShutdownState::Completed
|
||||
{
|
||||
// We already sent close_notify, now receiving peer's → SSLZeroReturnError
|
||||
Err(create_ssl_zero_return_error(vm).upcast())
|
||||
} else {
|
||||
// We haven't sent close_notify yet → return empty bytes
|
||||
return_data(vec![], &buffer, vm)
|
||||
}
|
||||
}
|
||||
Err(crate::ssl::compat::SslError::WantRead) => {
|
||||
// Non-blocking mode: would block
|
||||
|
||||
@@ -1552,6 +1552,11 @@ pub(super) fn ssl_read(
|
||||
|
||||
// Try to read plaintext from rustls buffer
|
||||
if let Some(n) = try_read_plaintext(conn, buf)? {
|
||||
if n == 0 {
|
||||
// EOF from TLS - close_notify received
|
||||
// Return ZeroReturn so Python raises SSLZeroReturnError
|
||||
return Err(SslError::ZeroReturn);
|
||||
}
|
||||
return Ok(n);
|
||||
}
|
||||
|
||||
@@ -1740,17 +1745,40 @@ pub(super) fn ssl_write(
|
||||
let already_buffered = *socket.write_buffered_len.lock();
|
||||
|
||||
// Only write plaintext if not already buffered
|
||||
// Track how much we wrote for partial write handling
|
||||
let mut bytes_written_to_rustls = 0usize;
|
||||
|
||||
if already_buffered == 0 {
|
||||
// Write plaintext to rustls (= SSL_write_ex internal buffer write)
|
||||
{
|
||||
bytes_written_to_rustls = {
|
||||
let mut writer = conn.writer();
|
||||
use std::io::Write;
|
||||
writer
|
||||
.write_all(data)
|
||||
.map_err(|e| SslError::Syscall(format!("Write failed: {e}")))?;
|
||||
}
|
||||
// Mark data as buffered
|
||||
*socket.write_buffered_len.lock() = data.len();
|
||||
// Use write() instead of write_all() to support partial writes.
|
||||
// In BIO mode (asyncio), when the internal buffer is full,
|
||||
// we want to write as much as possible and return that count,
|
||||
// rather than failing completely.
|
||||
match writer.write(data) {
|
||||
Ok(0) if !data.is_empty() => {
|
||||
// Buffer is full and nothing could be written.
|
||||
// In BIO mode, return WantWrite so the caller can
|
||||
// drain the outgoing BIO and retry.
|
||||
if is_bio {
|
||||
return Err(SslError::WantWrite);
|
||||
}
|
||||
return Err(SslError::Syscall("Write failed: buffer full".to_string()));
|
||||
}
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
if is_bio {
|
||||
// In BIO mode, treat write errors as WantWrite
|
||||
return Err(SslError::WantWrite);
|
||||
}
|
||||
return Err(SslError::Syscall(format!("Write failed: {e}")));
|
||||
}
|
||||
}
|
||||
};
|
||||
// Mark data as buffered (only the portion we actually wrote)
|
||||
*socket.write_buffered_len.lock() = bytes_written_to_rustls;
|
||||
} else if already_buffered != data.len() {
|
||||
// Caller is retrying with different data - this is a protocol error
|
||||
// Clear the buffer state and return an SSL error (bad write retry)
|
||||
@@ -1790,13 +1818,23 @@ pub(super) fn ssl_write(
|
||||
}
|
||||
Err(SslError::WantWrite) => {
|
||||
// Non-blocking socket would block - return WANT_WRITE
|
||||
// If we had a partial write to rustls, return partial success
|
||||
// instead of error to match OpenSSL partial-write semantics
|
||||
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
|
||||
*socket.write_buffered_len.lock() = 0;
|
||||
return Ok(bytes_written_to_rustls);
|
||||
}
|
||||
// Keep write_buffered_len set so we don't re-buffer on retry
|
||||
return Err(SslError::WantWrite);
|
||||
}
|
||||
Err(SslError::WantRead) => {
|
||||
// Need to read before write can complete (e.g., renegotiation)
|
||||
// This matches CPython's handling of SSL_ERROR_WANT_READ in write
|
||||
if is_bio {
|
||||
// If we had a partial write to rustls, return partial success
|
||||
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
|
||||
*socket.write_buffered_len.lock() = 0;
|
||||
return Ok(bytes_written_to_rustls);
|
||||
}
|
||||
// Keep write_buffered_len set so we don't re-buffer on retry
|
||||
return Err(SslError::WantRead);
|
||||
}
|
||||
@@ -1807,6 +1845,11 @@ pub(super) fn ssl_write(
|
||||
// Continue loop
|
||||
}
|
||||
Err(e @ SslError::Timeout(_)) => {
|
||||
// If we had a partial write to rustls, return partial success
|
||||
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
|
||||
*socket.write_buffered_len.lock() = 0;
|
||||
return Ok(bytes_written_to_rustls);
|
||||
}
|
||||
// Preserve buffered state so retry doesn't duplicate data
|
||||
// (send_all_bytes saved unsent TLS bytes to pending_tls_output)
|
||||
return Err(e);
|
||||
@@ -1826,10 +1869,21 @@ pub(super) fn ssl_write(
|
||||
.map_err(SslError::Py)?;
|
||||
}
|
||||
|
||||
// Determine how many bytes we actually wrote
|
||||
let actual_written = if bytes_written_to_rustls > 0 {
|
||||
// Fresh write: return what we wrote to rustls
|
||||
bytes_written_to_rustls
|
||||
} else if already_buffered > 0 {
|
||||
// Retry of previous write: return the full buffered amount
|
||||
already_buffered
|
||||
} else {
|
||||
data.len()
|
||||
};
|
||||
|
||||
// Write completed successfully - clear buffer state
|
||||
*socket.write_buffered_len.lock() = 0;
|
||||
|
||||
Ok(data.len())
|
||||
Ok(actual_written)
|
||||
}
|
||||
|
||||
// Helper functions (private-ish, used by public SSL functions)
|
||||
|
||||
@@ -123,8 +123,12 @@ impl PyAsyncGen {
|
||||
self.inner.frame().yield_from_target()
|
||||
}
|
||||
#[pygetset]
|
||||
fn ag_frame(&self, _vm: &VirtualMachine) -> FrameRef {
|
||||
self.inner.frame()
|
||||
fn ag_frame(&self, _vm: &VirtualMachine) -> Option<FrameRef> {
|
||||
if self.inner.closed() {
|
||||
None
|
||||
} else {
|
||||
Some(self.inner.frame())
|
||||
}
|
||||
}
|
||||
#[pygetset]
|
||||
fn ag_running(&self, _vm: &VirtualMachine) -> bool {
|
||||
|
||||
@@ -76,8 +76,12 @@ impl PyCoroutine {
|
||||
self.inner.frame().yield_from_target()
|
||||
}
|
||||
#[pygetset]
|
||||
fn cr_frame(&self, _vm: &VirtualMachine) -> FrameRef {
|
||||
self.inner.frame()
|
||||
fn cr_frame(&self, _vm: &VirtualMachine) -> Option<FrameRef> {
|
||||
if self.inner.closed() {
|
||||
None
|
||||
} else {
|
||||
Some(self.inner.frame())
|
||||
}
|
||||
}
|
||||
#[pygetset]
|
||||
fn cr_running(&self, _vm: &VirtualMachine) -> bool {
|
||||
|
||||
@@ -66,8 +66,12 @@ impl PyGenerator {
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
fn gi_frame(&self, _vm: &VirtualMachine) -> FrameRef {
|
||||
self.inner.frame()
|
||||
fn gi_frame(&self, _vm: &VirtualMachine) -> Option<FrameRef> {
|
||||
if self.inner.closed() {
|
||||
None
|
||||
} else {
|
||||
Some(self.inner.frame())
|
||||
}
|
||||
}
|
||||
|
||||
#[pygetset]
|
||||
|
||||
@@ -2574,7 +2574,14 @@ pub mod module {
|
||||
headers,
|
||||
trailers,
|
||||
);
|
||||
res.map_err(|err| err.into_pyexception(vm))?;
|
||||
// On macOS, sendfile can return EAGAIN even when some bytes were written.
|
||||
// In that case, we should return the number of bytes written rather than
|
||||
// raising an exception. Only raise an error if no bytes were written.
|
||||
if let Err(err) = res
|
||||
&& written == 0
|
||||
{
|
||||
return Err(err.into_pyexception(vm));
|
||||
}
|
||||
Ok(vm.ctx.new_int(written as u64).into())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user