diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 0e6f94eb4..292ddb993 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -235,6 +235,16 @@ impl PySocket { .map_err(|e| e.into_pyexception(vm)) } + /// returns Err(blocking) + pub fn get_timeout(&self) -> Result { + let timeout = self.timeout.load(); + if timeout > 0.0 { + Ok(Duration::from_secs_f64(timeout)) + } else { + Err(timeout != 0.0) + } + } + fn sock_op_err( &self, vm: &VirtualMachine, @@ -244,13 +254,7 @@ impl PySocket { where F: FnMut() -> io::Result, { - let timeout = self.timeout.load(); - let timeout = if timeout > 0.0 { - Some(Duration::from_secs_f64(timeout)) - } else { - None - }; - self.sock_op_timeout_err(vm, select, timeout, f) + self.sock_op_timeout_err(vm, select, self.get_timeout().ok(), f) } fn sock_op_timeout_err( @@ -558,12 +562,7 @@ impl PySocket { ) -> PyResult<()> { let flags = flags.unwrap_or(0); - let timeout = self.timeout.load(); - let timeout = if timeout > 0.0 { - Some(Duration::from_secs_f64(timeout)) - } else { - None - }; + let timeout = self.get_timeout().ok(); let deadline = timeout.map(Deadline::new); @@ -1007,14 +1006,18 @@ impl IntoPyException for IoOrPyException { } #[derive(Copy, Clone)] -enum SelectKind { +pub(super) enum SelectKind { Read, Write, Connect, } /// returns true if timed out -fn sock_select(sock: &Socket, kind: SelectKind, interval: Option) -> io::Result { +pub(super) fn sock_select( + sock: &Socket, + kind: SelectKind, + interval: Option, +) -> io::Result { let fd = sock_fileno(sock); #[cfg(unix)] { @@ -1343,7 +1346,7 @@ unsafe fn sock_from_raw_unchecked(fileno: RawSocket) -> Socket { Socket::from_raw_socket(fileno) } } -fn sock_fileno(sock: &Socket) -> RawSocket { +pub(super) fn sock_fileno(sock: &Socket) -> RawSocket { #[cfg(unix)] { use std::os::unix::io::AsRawFd; @@ -1368,7 +1371,7 @@ fn into_sock_fileno(sock: Socket) -> RawSocket { } } -const INVALID_SOCKET: RawSocket = { +pub(super) const INVALID_SOCKET: RawSocket = { #[cfg(unix)] { -1 @@ -1405,7 +1408,10 @@ fn convert_gai_error(vm: &VirtualMachine, err: dns_lookup::LookupError) -> PyBas } fn timeout_error(vm: &VirtualMachine) -> PyBaseExceptionRef { - vm.new_exception_msg(TIMEOUT_ERROR.get().unwrap().clone(), "timed out".to_owned()) + timeout_error_msg(vm, "timed out".to_owned()) +} +pub(super) fn timeout_error_msg(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef { + vm.new_exception_msg(TIMEOUT_ERROR.get().unwrap().clone(), msg) } fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String { diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index 6d54fefa2..836a44c88 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -1,5 +1,5 @@ use super::os::PyPathLike; -use super::socket::PySocketRef; +use super::socket::{self, PySocketRef}; use crate::builtins::{pytype, weakref::PyWeak, PyStrRef, PyTypeRef}; use crate::byteslike::{PyBytesLike, PyRwBytesLike}; use crate::common::lock::{PyRwLock, PyRwLockWriteGuard}; @@ -607,7 +607,6 @@ impl PySslContext { } #[derive(FromArgs)] -// #[allow(dead_code)] struct WrapSocketArgs { #[pyarg(any)] sock: PySocketRef, @@ -641,6 +640,76 @@ struct LoadCertChainArgs { password: Option>, } +struct SocketTimeout { + // Err is true if the socket is blocking + deadline: Result, +} +impl SocketTimeout { + fn get(s: &socket::PySocket) -> Self { + let deadline = s.get_timeout().map(|d| Instant::now() + d); + Self { deadline } + } +} +enum SelectRet { + Nonblocking, + TimedOut, + IsBlocking, + Closed, + Ok, +} +fn ssl_select(sock: &socket::PySocket, needs: SslNeeds, timeout: &SocketTimeout) -> SelectRet { + let sock = sock.sock(); + let timeout = match &timeout.deadline { + Ok(deadline) => match deadline.checked_duration_since(Instant::now()) { + Some(timeout) => timeout, + None => return SelectRet::TimedOut, + }, + Err(true) => return SelectRet::IsBlocking, + Err(false) => return SelectRet::Nonblocking, + }; + if socket::sock_fileno(&sock) == socket::INVALID_SOCKET { + return SelectRet::Closed; + } + let res = socket::sock_select( + &sock, + match needs { + SslNeeds::Read => socket::SelectKind::Read, + SslNeeds::Write => socket::SelectKind::Write, + }, + Some(timeout), + ); + match res { + Ok(true) => SelectRet::TimedOut, + _ => SelectRet::Ok, + } +} +#[derive(Clone, Copy)] +enum SslNeeds { + Read, + Write, +} + +fn socket_needs( + err: &ssl::Error, + sock: &socket::PySocket, + timeout: &SocketTimeout, +) -> (Option, SelectRet) { + let needs = match err.code() { + ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read), + ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write), + _ => None, + }; + let state = needs.map_or(SelectRet::Ok, |needs| ssl_select(sock, needs, &timeout)); + (needs, state) +} + +fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_exception_msg( + ssl_error(vm), + "Underlying socket has been closed.".to_owned(), + ) +} + #[pyclass(module = "ssl", name = "_SSLSocket")] struct PySslSocket { ctx: PyRef, @@ -747,74 +816,127 @@ impl PySslSocket { #[pymethod] fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { let mut stream = self.stream.write(); - let timeout = stream.get_ref().timeout.load(); - let deadline = if timeout > 0.0 { - Some((std::time::Duration::from_secs_f64(timeout), Instant::now())) - } else { - None - }; - let err = loop { + let timeout = SocketTimeout::get(stream.get_ref()); + loop { let err = match stream.do_handshake() { Ok(()) => return Ok(()), Err(e) => e, }; - match err.code() { - ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => { - if let Some((timeout, ref start)) = deadline { - if start.elapsed() >= timeout { - let socket_timeout = vm.class("_socket", "timeout"); - return Err(vm.new_exception_msg( - socket_timeout, - "The handshake operation timed out".to_owned(), - )); - } - } else if timeout == 0.0 { - // socket's non-blocking, we tried once and now it needs more to read/write - break err; - } - continue; // keep blocking + let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout); + match state { + SelectRet::TimedOut => { + return Err(socket::timeout_error_msg( + vm, + "The handshake operation timed out".to_owned(), + )) + } + SelectRet::Closed => return Err(socket_closed_error(vm)), + SelectRet::Nonblocking => {} + _ => { + if needs.is_some() { + continue; + } } - _ => break err, } - }; - Err(convert_ssl_error(vm, err)) + return Err(convert_ssl_error(vm, err)); + } } #[pymethod] fn write(&self, data: PyBytesLike, vm: &VirtualMachine) -> PyResult { let mut stream = self.stream.write(); - data.with_ref(|b| stream.ssl_write(b)) - .map_err(|e| convert_ssl_error(vm, e)) + let data = data.borrow_value(); + let data = &*data; + let timeout = SocketTimeout::get(stream.get_ref()); + let state = ssl_select(stream.get_ref(), SslNeeds::Write, &timeout); + match state { + SelectRet::TimedOut => { + return Err(socket::timeout_error_msg( + vm, + "The write operation timed out".to_owned(), + )) + } + SelectRet::Closed => return Err(socket_closed_error(vm)), + _ => {} + } + loop { + let err = match stream.ssl_write(data) { + Ok(len) => return Ok(len), + Err(e) => e, + }; + let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout); + match state { + SelectRet::TimedOut => { + return Err(socket::timeout_error_msg( + vm, + "The write operation timed out".to_owned(), + )) + } + SelectRet::Closed => return Err(socket_closed_error(vm)), + SelectRet::Nonblocking => {} + _ => { + if needs.is_some() { + continue; + } + } + } + return Err(convert_ssl_error(vm, err)); + } } #[pymethod] fn read(&self, n: usize, buffer: OptionalArg, vm: &VirtualMachine) -> PyResult { let mut stream = self.stream.write(); - let ret_nread = buffer.is_present(); - let ssl_res = if let OptionalArg::Present(buffer) = buffer { - buffer.with_ref(|buf| stream.ssl_read(buf).map(|n| vm.ctx.new_int(n))) + let mut inner_buffer = if let OptionalArg::Present(buffer) = &buffer { + Either::A(buffer.borrow_value()) } else { - let mut buf = vec![0u8; n]; - stream.ssl_read(&mut buf).map(|n| { - buf.truncate(n); - vm.ctx.new_bytes(buf) - }) + Either::B(vec![0u8; n]) }; - ssl_res.or_else(|e| { - if e.code() == ssl::ErrorCode::ZERO_RETURN + let buf = match &mut inner_buffer { + Either::A(b) => &mut **b, + Either::B(b) => b.as_mut_slice(), + }; + let buf = match buf.get_mut(..n) { + Some(b) => b, + None => buf, + }; + let timeout = SocketTimeout::get(stream.get_ref()); + let count = loop { + let err = match stream.ssl_read(buf) { + Ok(count) => break count, + Err(e) => e, + }; + if err.code() == ssl::ErrorCode::ZERO_RETURN && stream.get_shutdown() == ssl::ShutdownState::RECEIVED { - Ok(if ret_nread { - vm.ctx.new_int(0) - } else { - vm.ctx.new_bytes(vec![]) - }) - } else { - Err(convert_ssl_error(vm, e)) + break 0; } - }) - - // .map_err(|e| convert_ssl_error(vm, e))?; + let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout); + match state { + SelectRet::TimedOut => { + return Err(socket::timeout_error_msg( + vm, + "The read operation timed out".to_owned(), + )) + } + SelectRet::Nonblocking => {} + _ => { + if needs.is_some() { + continue; + } + } + } + return Err(convert_ssl_error(vm, err)); + }; + let ret = match inner_buffer { + Either::A(_buf) => vm.ctx.new_int(count), + Either::B(mut buf) => { + buf.truncate(n); + buf.shrink_to_fit(); + vm.ctx.new_bytes(buf) + } + }; + Ok(ret) } }