mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-09 22:49:57 +09:00
Update _ssl to use select()/poll() to handle sockets w/ timeouts
This commit is contained in:
@@ -235,6 +235,16 @@ impl PySocket {
|
||||
.map_err(|e| e.into_pyexception(vm))
|
||||
}
|
||||
|
||||
/// returns Err(blocking)
|
||||
pub fn get_timeout(&self) -> Result<Duration, bool> {
|
||||
let timeout = self.timeout.load();
|
||||
if timeout > 0.0 {
|
||||
Ok(Duration::from_secs_f64(timeout))
|
||||
} else {
|
||||
Err(timeout != 0.0)
|
||||
}
|
||||
}
|
||||
|
||||
fn sock_op_err<F, R>(
|
||||
&self,
|
||||
vm: &VirtualMachine,
|
||||
@@ -244,13 +254,7 @@ impl PySocket {
|
||||
where
|
||||
F: FnMut() -> io::Result<R>,
|
||||
{
|
||||
let timeout = self.timeout.load();
|
||||
let timeout = if timeout > 0.0 {
|
||||
Some(Duration::from_secs_f64(timeout))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
self.sock_op_timeout_err(vm, select, timeout, f)
|
||||
self.sock_op_timeout_err(vm, select, self.get_timeout().ok(), f)
|
||||
}
|
||||
|
||||
fn sock_op_timeout_err<F, R>(
|
||||
@@ -558,12 +562,7 @@ impl PySocket {
|
||||
) -> PyResult<()> {
|
||||
let flags = flags.unwrap_or(0);
|
||||
|
||||
let timeout = self.timeout.load();
|
||||
let timeout = if timeout > 0.0 {
|
||||
Some(Duration::from_secs_f64(timeout))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let timeout = self.get_timeout().ok();
|
||||
|
||||
let deadline = timeout.map(Deadline::new);
|
||||
|
||||
@@ -1007,14 +1006,18 @@ impl IntoPyException for IoOrPyException {
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
enum SelectKind {
|
||||
pub(super) enum SelectKind {
|
||||
Read,
|
||||
Write,
|
||||
Connect,
|
||||
}
|
||||
|
||||
/// returns true if timed out
|
||||
fn sock_select(sock: &Socket, kind: SelectKind, interval: Option<Duration>) -> io::Result<bool> {
|
||||
pub(super) fn sock_select(
|
||||
sock: &Socket,
|
||||
kind: SelectKind,
|
||||
interval: Option<Duration>,
|
||||
) -> io::Result<bool> {
|
||||
let fd = sock_fileno(sock);
|
||||
#[cfg(unix)]
|
||||
{
|
||||
@@ -1343,7 +1346,7 @@ unsafe fn sock_from_raw_unchecked(fileno: RawSocket) -> Socket {
|
||||
Socket::from_raw_socket(fileno)
|
||||
}
|
||||
}
|
||||
fn sock_fileno(sock: &Socket) -> RawSocket {
|
||||
pub(super) fn sock_fileno(sock: &Socket) -> RawSocket {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::io::AsRawFd;
|
||||
@@ -1368,7 +1371,7 @@ fn into_sock_fileno(sock: Socket) -> RawSocket {
|
||||
}
|
||||
}
|
||||
|
||||
const INVALID_SOCKET: RawSocket = {
|
||||
pub(super) const INVALID_SOCKET: RawSocket = {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
-1
|
||||
@@ -1405,7 +1408,10 @@ fn convert_gai_error(vm: &VirtualMachine, err: dns_lookup::LookupError) -> PyBas
|
||||
}
|
||||
|
||||
fn timeout_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
|
||||
vm.new_exception_msg(TIMEOUT_ERROR.get().unwrap().clone(), "timed out".to_owned())
|
||||
timeout_error_msg(vm, "timed out".to_owned())
|
||||
}
|
||||
pub(super) fn timeout_error_msg(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef {
|
||||
vm.new_exception_msg(TIMEOUT_ERROR.get().unwrap().clone(), msg)
|
||||
}
|
||||
|
||||
fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::os::PyPathLike;
|
||||
use super::socket::PySocketRef;
|
||||
use super::socket::{self, PySocketRef};
|
||||
use crate::builtins::{pytype, weakref::PyWeak, PyStrRef, PyTypeRef};
|
||||
use crate::byteslike::{PyBytesLike, PyRwBytesLike};
|
||||
use crate::common::lock::{PyRwLock, PyRwLockWriteGuard};
|
||||
@@ -607,7 +607,6 @@ impl PySslContext {
|
||||
}
|
||||
|
||||
#[derive(FromArgs)]
|
||||
// #[allow(dead_code)]
|
||||
struct WrapSocketArgs {
|
||||
#[pyarg(any)]
|
||||
sock: PySocketRef,
|
||||
@@ -641,6 +640,76 @@ struct LoadCertChainArgs {
|
||||
password: Option<Either<PyStrRef, PyCallable>>,
|
||||
}
|
||||
|
||||
struct SocketTimeout {
|
||||
// Err is true if the socket is blocking
|
||||
deadline: Result<Instant, bool>,
|
||||
}
|
||||
impl SocketTimeout {
|
||||
fn get(s: &socket::PySocket) -> Self {
|
||||
let deadline = s.get_timeout().map(|d| Instant::now() + d);
|
||||
Self { deadline }
|
||||
}
|
||||
}
|
||||
enum SelectRet {
|
||||
Nonblocking,
|
||||
TimedOut,
|
||||
IsBlocking,
|
||||
Closed,
|
||||
Ok,
|
||||
}
|
||||
fn ssl_select(sock: &socket::PySocket, needs: SslNeeds, timeout: &SocketTimeout) -> SelectRet {
|
||||
let sock = sock.sock();
|
||||
let timeout = match &timeout.deadline {
|
||||
Ok(deadline) => match deadline.checked_duration_since(Instant::now()) {
|
||||
Some(timeout) => timeout,
|
||||
None => return SelectRet::TimedOut,
|
||||
},
|
||||
Err(true) => return SelectRet::IsBlocking,
|
||||
Err(false) => return SelectRet::Nonblocking,
|
||||
};
|
||||
if socket::sock_fileno(&sock) == socket::INVALID_SOCKET {
|
||||
return SelectRet::Closed;
|
||||
}
|
||||
let res = socket::sock_select(
|
||||
&sock,
|
||||
match needs {
|
||||
SslNeeds::Read => socket::SelectKind::Read,
|
||||
SslNeeds::Write => socket::SelectKind::Write,
|
||||
},
|
||||
Some(timeout),
|
||||
);
|
||||
match res {
|
||||
Ok(true) => SelectRet::TimedOut,
|
||||
_ => SelectRet::Ok,
|
||||
}
|
||||
}
|
||||
#[derive(Clone, Copy)]
|
||||
enum SslNeeds {
|
||||
Read,
|
||||
Write,
|
||||
}
|
||||
|
||||
fn socket_needs(
|
||||
err: &ssl::Error,
|
||||
sock: &socket::PySocket,
|
||||
timeout: &SocketTimeout,
|
||||
) -> (Option<SslNeeds>, SelectRet) {
|
||||
let needs = match err.code() {
|
||||
ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read),
|
||||
ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write),
|
||||
_ => None,
|
||||
};
|
||||
let state = needs.map_or(SelectRet::Ok, |needs| ssl_select(sock, needs, &timeout));
|
||||
(needs, state)
|
||||
}
|
||||
|
||||
fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
|
||||
vm.new_exception_msg(
|
||||
ssl_error(vm),
|
||||
"Underlying socket has been closed.".to_owned(),
|
||||
)
|
||||
}
|
||||
|
||||
#[pyclass(module = "ssl", name = "_SSLSocket")]
|
||||
struct PySslSocket {
|
||||
ctx: PyRef<PySslContext>,
|
||||
@@ -747,74 +816,127 @@ impl PySslSocket {
|
||||
#[pymethod]
|
||||
fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
|
||||
let mut stream = self.stream.write();
|
||||
let timeout = stream.get_ref().timeout.load();
|
||||
let deadline = if timeout > 0.0 {
|
||||
Some((std::time::Duration::from_secs_f64(timeout), Instant::now()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let err = loop {
|
||||
let timeout = SocketTimeout::get(stream.get_ref());
|
||||
loop {
|
||||
let err = match stream.do_handshake() {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => e,
|
||||
};
|
||||
match err.code() {
|
||||
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => {
|
||||
if let Some((timeout, ref start)) = deadline {
|
||||
if start.elapsed() >= timeout {
|
||||
let socket_timeout = vm.class("_socket", "timeout");
|
||||
return Err(vm.new_exception_msg(
|
||||
socket_timeout,
|
||||
"The handshake operation timed out".to_owned(),
|
||||
));
|
||||
}
|
||||
} else if timeout == 0.0 {
|
||||
// socket's non-blocking, we tried once and now it needs more to read/write
|
||||
break err;
|
||||
}
|
||||
continue; // keep blocking
|
||||
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
|
||||
match state {
|
||||
SelectRet::TimedOut => {
|
||||
return Err(socket::timeout_error_msg(
|
||||
vm,
|
||||
"The handshake operation timed out".to_owned(),
|
||||
))
|
||||
}
|
||||
SelectRet::Closed => return Err(socket_closed_error(vm)),
|
||||
SelectRet::Nonblocking => {}
|
||||
_ => {
|
||||
if needs.is_some() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
_ => break err,
|
||||
}
|
||||
};
|
||||
Err(convert_ssl_error(vm, err))
|
||||
return Err(convert_ssl_error(vm, err));
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn write(&self, data: PyBytesLike, vm: &VirtualMachine) -> PyResult<usize> {
|
||||
let mut stream = self.stream.write();
|
||||
data.with_ref(|b| stream.ssl_write(b))
|
||||
.map_err(|e| convert_ssl_error(vm, e))
|
||||
let data = data.borrow_value();
|
||||
let data = &*data;
|
||||
let timeout = SocketTimeout::get(stream.get_ref());
|
||||
let state = ssl_select(stream.get_ref(), SslNeeds::Write, &timeout);
|
||||
match state {
|
||||
SelectRet::TimedOut => {
|
||||
return Err(socket::timeout_error_msg(
|
||||
vm,
|
||||
"The write operation timed out".to_owned(),
|
||||
))
|
||||
}
|
||||
SelectRet::Closed => return Err(socket_closed_error(vm)),
|
||||
_ => {}
|
||||
}
|
||||
loop {
|
||||
let err = match stream.ssl_write(data) {
|
||||
Ok(len) => return Ok(len),
|
||||
Err(e) => e,
|
||||
};
|
||||
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
|
||||
match state {
|
||||
SelectRet::TimedOut => {
|
||||
return Err(socket::timeout_error_msg(
|
||||
vm,
|
||||
"The write operation timed out".to_owned(),
|
||||
))
|
||||
}
|
||||
SelectRet::Closed => return Err(socket_closed_error(vm)),
|
||||
SelectRet::Nonblocking => {}
|
||||
_ => {
|
||||
if needs.is_some() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Err(convert_ssl_error(vm, err));
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn read(&self, n: usize, buffer: OptionalArg<PyRwBytesLike>, vm: &VirtualMachine) -> PyResult {
|
||||
let mut stream = self.stream.write();
|
||||
let ret_nread = buffer.is_present();
|
||||
let ssl_res = if let OptionalArg::Present(buffer) = buffer {
|
||||
buffer.with_ref(|buf| stream.ssl_read(buf).map(|n| vm.ctx.new_int(n)))
|
||||
let mut inner_buffer = if let OptionalArg::Present(buffer) = &buffer {
|
||||
Either::A(buffer.borrow_value())
|
||||
} else {
|
||||
let mut buf = vec![0u8; n];
|
||||
stream.ssl_read(&mut buf).map(|n| {
|
||||
buf.truncate(n);
|
||||
vm.ctx.new_bytes(buf)
|
||||
})
|
||||
Either::B(vec![0u8; n])
|
||||
};
|
||||
ssl_res.or_else(|e| {
|
||||
if e.code() == ssl::ErrorCode::ZERO_RETURN
|
||||
let buf = match &mut inner_buffer {
|
||||
Either::A(b) => &mut **b,
|
||||
Either::B(b) => b.as_mut_slice(),
|
||||
};
|
||||
let buf = match buf.get_mut(..n) {
|
||||
Some(b) => b,
|
||||
None => buf,
|
||||
};
|
||||
let timeout = SocketTimeout::get(stream.get_ref());
|
||||
let count = loop {
|
||||
let err = match stream.ssl_read(buf) {
|
||||
Ok(count) => break count,
|
||||
Err(e) => e,
|
||||
};
|
||||
if err.code() == ssl::ErrorCode::ZERO_RETURN
|
||||
&& stream.get_shutdown() == ssl::ShutdownState::RECEIVED
|
||||
{
|
||||
Ok(if ret_nread {
|
||||
vm.ctx.new_int(0)
|
||||
} else {
|
||||
vm.ctx.new_bytes(vec![])
|
||||
})
|
||||
} else {
|
||||
Err(convert_ssl_error(vm, e))
|
||||
break 0;
|
||||
}
|
||||
})
|
||||
|
||||
// .map_err(|e| convert_ssl_error(vm, e))?;
|
||||
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
|
||||
match state {
|
||||
SelectRet::TimedOut => {
|
||||
return Err(socket::timeout_error_msg(
|
||||
vm,
|
||||
"The read operation timed out".to_owned(),
|
||||
))
|
||||
}
|
||||
SelectRet::Nonblocking => {}
|
||||
_ => {
|
||||
if needs.is_some() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Err(convert_ssl_error(vm, err));
|
||||
};
|
||||
let ret = match inner_buffer {
|
||||
Either::A(_buf) => vm.ctx.new_int(count),
|
||||
Either::B(mut buf) => {
|
||||
buf.truncate(n);
|
||||
buf.shrink_to_fit();
|
||||
vm.ctx.new_bytes(buf)
|
||||
}
|
||||
};
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user