Update _ssl to use select()/poll() to handle sockets w/ timeouts

This commit is contained in:
Noah
2021-05-13 22:53:32 -05:00
parent 0f646a8f86
commit 6bc27257c4
2 changed files with 195 additions and 67 deletions

View File

@@ -235,6 +235,16 @@ impl PySocket {
.map_err(|e| e.into_pyexception(vm))
}
/// returns Err(blocking)
pub fn get_timeout(&self) -> Result<Duration, bool> {
let timeout = self.timeout.load();
if timeout > 0.0 {
Ok(Duration::from_secs_f64(timeout))
} else {
Err(timeout != 0.0)
}
}
fn sock_op_err<F, R>(
&self,
vm: &VirtualMachine,
@@ -244,13 +254,7 @@ impl PySocket {
where
F: FnMut() -> io::Result<R>,
{
let timeout = self.timeout.load();
let timeout = if timeout > 0.0 {
Some(Duration::from_secs_f64(timeout))
} else {
None
};
self.sock_op_timeout_err(vm, select, timeout, f)
self.sock_op_timeout_err(vm, select, self.get_timeout().ok(), f)
}
fn sock_op_timeout_err<F, R>(
@@ -558,12 +562,7 @@ impl PySocket {
) -> PyResult<()> {
let flags = flags.unwrap_or(0);
let timeout = self.timeout.load();
let timeout = if timeout > 0.0 {
Some(Duration::from_secs_f64(timeout))
} else {
None
};
let timeout = self.get_timeout().ok();
let deadline = timeout.map(Deadline::new);
@@ -1007,14 +1006,18 @@ impl IntoPyException for IoOrPyException {
}
#[derive(Copy, Clone)]
enum SelectKind {
pub(super) enum SelectKind {
Read,
Write,
Connect,
}
/// returns true if timed out
fn sock_select(sock: &Socket, kind: SelectKind, interval: Option<Duration>) -> io::Result<bool> {
pub(super) fn sock_select(
sock: &Socket,
kind: SelectKind,
interval: Option<Duration>,
) -> io::Result<bool> {
let fd = sock_fileno(sock);
#[cfg(unix)]
{
@@ -1343,7 +1346,7 @@ unsafe fn sock_from_raw_unchecked(fileno: RawSocket) -> Socket {
Socket::from_raw_socket(fileno)
}
}
fn sock_fileno(sock: &Socket) -> RawSocket {
pub(super) fn sock_fileno(sock: &Socket) -> RawSocket {
#[cfg(unix)]
{
use std::os::unix::io::AsRawFd;
@@ -1368,7 +1371,7 @@ fn into_sock_fileno(sock: Socket) -> RawSocket {
}
}
const INVALID_SOCKET: RawSocket = {
pub(super) const INVALID_SOCKET: RawSocket = {
#[cfg(unix)]
{
-1
@@ -1405,7 +1408,10 @@ fn convert_gai_error(vm: &VirtualMachine, err: dns_lookup::LookupError) -> PyBas
}
fn timeout_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(TIMEOUT_ERROR.get().unwrap().clone(), "timed out".to_owned())
timeout_error_msg(vm, "timed out".to_owned())
}
pub(super) fn timeout_error_msg(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef {
vm.new_exception_msg(TIMEOUT_ERROR.get().unwrap().clone(), msg)
}
fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String {

View File

@@ -1,5 +1,5 @@
use super::os::PyPathLike;
use super::socket::PySocketRef;
use super::socket::{self, PySocketRef};
use crate::builtins::{pytype, weakref::PyWeak, PyStrRef, PyTypeRef};
use crate::byteslike::{PyBytesLike, PyRwBytesLike};
use crate::common::lock::{PyRwLock, PyRwLockWriteGuard};
@@ -607,7 +607,6 @@ impl PySslContext {
}
#[derive(FromArgs)]
// #[allow(dead_code)]
struct WrapSocketArgs {
#[pyarg(any)]
sock: PySocketRef,
@@ -641,6 +640,76 @@ struct LoadCertChainArgs {
password: Option<Either<PyStrRef, PyCallable>>,
}
struct SocketTimeout {
// Err is true if the socket is blocking
deadline: Result<Instant, bool>,
}
impl SocketTimeout {
fn get(s: &socket::PySocket) -> Self {
let deadline = s.get_timeout().map(|d| Instant::now() + d);
Self { deadline }
}
}
enum SelectRet {
Nonblocking,
TimedOut,
IsBlocking,
Closed,
Ok,
}
fn ssl_select(sock: &socket::PySocket, needs: SslNeeds, timeout: &SocketTimeout) -> SelectRet {
let sock = sock.sock();
let timeout = match &timeout.deadline {
Ok(deadline) => match deadline.checked_duration_since(Instant::now()) {
Some(timeout) => timeout,
None => return SelectRet::TimedOut,
},
Err(true) => return SelectRet::IsBlocking,
Err(false) => return SelectRet::Nonblocking,
};
if socket::sock_fileno(&sock) == socket::INVALID_SOCKET {
return SelectRet::Closed;
}
let res = socket::sock_select(
&sock,
match needs {
SslNeeds::Read => socket::SelectKind::Read,
SslNeeds::Write => socket::SelectKind::Write,
},
Some(timeout),
);
match res {
Ok(true) => SelectRet::TimedOut,
_ => SelectRet::Ok,
}
}
#[derive(Clone, Copy)]
enum SslNeeds {
Read,
Write,
}
fn socket_needs(
err: &ssl::Error,
sock: &socket::PySocket,
timeout: &SocketTimeout,
) -> (Option<SslNeeds>, SelectRet) {
let needs = match err.code() {
ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read),
ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write),
_ => None,
};
let state = needs.map_or(SelectRet::Ok, |needs| ssl_select(sock, needs, &timeout));
(needs, state)
}
fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(
ssl_error(vm),
"Underlying socket has been closed.".to_owned(),
)
}
#[pyclass(module = "ssl", name = "_SSLSocket")]
struct PySslSocket {
ctx: PyRef<PySslContext>,
@@ -747,74 +816,127 @@ impl PySslSocket {
#[pymethod]
fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> {
let mut stream = self.stream.write();
let timeout = stream.get_ref().timeout.load();
let deadline = if timeout > 0.0 {
Some((std::time::Duration::from_secs_f64(timeout), Instant::now()))
} else {
None
};
let err = loop {
let timeout = SocketTimeout::get(stream.get_ref());
loop {
let err = match stream.do_handshake() {
Ok(()) => return Ok(()),
Err(e) => e,
};
match err.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => {
if let Some((timeout, ref start)) = deadline {
if start.elapsed() >= timeout {
let socket_timeout = vm.class("_socket", "timeout");
return Err(vm.new_exception_msg(
socket_timeout,
"The handshake operation timed out".to_owned(),
));
}
} else if timeout == 0.0 {
// socket's non-blocking, we tried once and now it needs more to read/write
break err;
}
continue; // keep blocking
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
match state {
SelectRet::TimedOut => {
return Err(socket::timeout_error_msg(
vm,
"The handshake operation timed out".to_owned(),
))
}
SelectRet::Closed => return Err(socket_closed_error(vm)),
SelectRet::Nonblocking => {}
_ => {
if needs.is_some() {
continue;
}
}
_ => break err,
}
};
Err(convert_ssl_error(vm, err))
return Err(convert_ssl_error(vm, err));
}
}
#[pymethod]
fn write(&self, data: PyBytesLike, vm: &VirtualMachine) -> PyResult<usize> {
let mut stream = self.stream.write();
data.with_ref(|b| stream.ssl_write(b))
.map_err(|e| convert_ssl_error(vm, e))
let data = data.borrow_value();
let data = &*data;
let timeout = SocketTimeout::get(stream.get_ref());
let state = ssl_select(stream.get_ref(), SslNeeds::Write, &timeout);
match state {
SelectRet::TimedOut => {
return Err(socket::timeout_error_msg(
vm,
"The write operation timed out".to_owned(),
))
}
SelectRet::Closed => return Err(socket_closed_error(vm)),
_ => {}
}
loop {
let err = match stream.ssl_write(data) {
Ok(len) => return Ok(len),
Err(e) => e,
};
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
match state {
SelectRet::TimedOut => {
return Err(socket::timeout_error_msg(
vm,
"The write operation timed out".to_owned(),
))
}
SelectRet::Closed => return Err(socket_closed_error(vm)),
SelectRet::Nonblocking => {}
_ => {
if needs.is_some() {
continue;
}
}
}
return Err(convert_ssl_error(vm, err));
}
}
#[pymethod]
fn read(&self, n: usize, buffer: OptionalArg<PyRwBytesLike>, vm: &VirtualMachine) -> PyResult {
let mut stream = self.stream.write();
let ret_nread = buffer.is_present();
let ssl_res = if let OptionalArg::Present(buffer) = buffer {
buffer.with_ref(|buf| stream.ssl_read(buf).map(|n| vm.ctx.new_int(n)))
let mut inner_buffer = if let OptionalArg::Present(buffer) = &buffer {
Either::A(buffer.borrow_value())
} else {
let mut buf = vec![0u8; n];
stream.ssl_read(&mut buf).map(|n| {
buf.truncate(n);
vm.ctx.new_bytes(buf)
})
Either::B(vec![0u8; n])
};
ssl_res.or_else(|e| {
if e.code() == ssl::ErrorCode::ZERO_RETURN
let buf = match &mut inner_buffer {
Either::A(b) => &mut **b,
Either::B(b) => b.as_mut_slice(),
};
let buf = match buf.get_mut(..n) {
Some(b) => b,
None => buf,
};
let timeout = SocketTimeout::get(stream.get_ref());
let count = loop {
let err = match stream.ssl_read(buf) {
Ok(count) => break count,
Err(e) => e,
};
if err.code() == ssl::ErrorCode::ZERO_RETURN
&& stream.get_shutdown() == ssl::ShutdownState::RECEIVED
{
Ok(if ret_nread {
vm.ctx.new_int(0)
} else {
vm.ctx.new_bytes(vec![])
})
} else {
Err(convert_ssl_error(vm, e))
break 0;
}
})
// .map_err(|e| convert_ssl_error(vm, e))?;
let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout);
match state {
SelectRet::TimedOut => {
return Err(socket::timeout_error_msg(
vm,
"The read operation timed out".to_owned(),
))
}
SelectRet::Nonblocking => {}
_ => {
if needs.is_some() {
continue;
}
}
}
return Err(convert_ssl_error(vm, err));
};
let ret = match inner_buffer {
Either::A(_buf) => vm.ctx.new_int(count),
Either::B(mut buf) => {
buf.truncate(n);
buf.shrink_to_fit();
vm.ctx.new_bytes(buf)
}
};
Ok(ret)
}
}