diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 0a79468cc..dc94e752b 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -5,7 +5,7 @@ use crate::{ exceptions::{IntoPyException, PyBaseExceptionRef}, function::{FuncArgs, OptionalArg, OptionalOption}, utils::{Either, ToCString}, - IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromBorrowedObject, + IntoPyObject, PyClassImpl, PyObjectRef, PyResult, PyValue, TryFromBorrowedObject, TryFromObject, TypeProtocol, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; @@ -18,7 +18,10 @@ use std::convert::TryFrom; use std::mem::MaybeUninit; use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs}; use std::time::{Duration, Instant}; -use std::{ffi, io}; +use std::{ + ffi, + io::{self, Read, Write}, +}; #[cfg(unix)] type RawSocket = std::os::unix::io::RawFd; @@ -158,13 +161,26 @@ impl Default for PySocket { } } -pub type PySocketRef = PyRef; - #[cfg(windows)] const CLOSED_ERR: i32 = c::WSAENOTSOCK; #[cfg(unix)] const CLOSED_ERR: i32 = c::EBADF; -#[pyimpl(flags(BASETYPE))] + +impl Read for &PySocket { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + (&mut &*self.sock_io()?).read(buf) + } +} +impl Write for &PySocket { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + (&mut &*self.sock_io()?).write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + (&mut &*self.sock_io()?).flush() + } +} + impl PySocket { pub fn sock_opt(&self) -> Option> { PyRwLockReadGuard::try_map(self.sock.read(), |sock| sock.get()).ok() @@ -179,94 +195,6 @@ impl PySocket { self.sock_io().map_err(|e| e.into_pyexception(vm)) } - #[pyslot] - fn tp_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { - Self::default().into_pyresult_with_type(vm, cls) - } - - #[pymethod(magic)] - fn init( - &self, - family: OptionalArg, - socket_kind: OptionalArg, - proto: OptionalArg, - fileno: OptionalOption, - vm: &VirtualMachine, - ) -> PyResult<()> { - let mut family = family.unwrap_or(-1); - let mut socket_kind = socket_kind.unwrap_or(-1); - let mut proto = proto.unwrap_or(-1); - let fileno = fileno - .flatten() - .map(|obj| get_raw_sock(obj, vm)) - .transpose()?; - let sock; - if let Some(fileno) = fileno { - sock = sock_from_raw(fileno, vm)?; - match sock.local_addr() { - Ok(addr) if family == -1 => family = addr.family() as i32, - Err(e) - if family == -1 - || matches!( - e.raw_os_error(), - Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF)) - ) => - { - std::mem::forget(sock); - return Err(e.into_pyexception(vm)); - } - _ => {} - } - if socket_kind == -1 { - // TODO: when socket2 cuts a new release, type will be available on all os - // socket_kind = sock.r#type().map_err(|e| e.into_pyexception(vm))?.into(); - let res = unsafe { - c::getsockopt( - sock_fileno(&sock) as _, - c::SOL_SOCKET, - c::SO_TYPE, - &mut socket_kind as *mut libc::c_int as *mut _, - &mut (std::mem::size_of::() as _), - ) - }; - if res < 0 { - return Err(super::os::errno_err(vm)); - } - } - cfg_if::cfg_if! { - if #[cfg(any( - target_os = "android", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "linux", - ))] { - if proto == -1 { - proto = sock.protocol().map_err(|e| e.into_pyexception(vm))?.map_or(0, Into::into); - } - } else { - proto = 0; - } - } - } else { - if family == -1 { - family = c::AF_INET as i32 - } - if socket_kind == -1 { - socket_kind = c::SOCK_STREAM - } - if proto == -1 { - proto = 0 - } - sock = Socket::new( - Domain::from(family), - SocketType::from(socket_kind), - Some(Protocol::from(proto)), - ) - .map_err(|err| err.into_pyexception(vm))?; - }; - self.init_inner(family, socket_kind, proto, sock, vm) - } - fn init_inner( &self, family: i32, @@ -523,6 +451,97 @@ impl PySocket { Err(err.into()) } } +} + +#[pyimpl(flags(BASETYPE))] +impl PySocket { + #[pyslot] + fn tp_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { + Self::default().into_pyresult_with_type(vm, cls) + } + + #[pymethod(magic)] + fn init( + &self, + family: OptionalArg, + socket_kind: OptionalArg, + proto: OptionalArg, + fileno: OptionalOption, + vm: &VirtualMachine, + ) -> PyResult<()> { + let mut family = family.unwrap_or(-1); + let mut socket_kind = socket_kind.unwrap_or(-1); + let mut proto = proto.unwrap_or(-1); + let fileno = fileno + .flatten() + .map(|obj| get_raw_sock(obj, vm)) + .transpose()?; + let sock; + if let Some(fileno) = fileno { + sock = sock_from_raw(fileno, vm)?; + match sock.local_addr() { + Ok(addr) if family == -1 => family = addr.family() as i32, + Err(e) + if family == -1 + || matches!( + e.raw_os_error(), + Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF)) + ) => + { + std::mem::forget(sock); + return Err(e.into_pyexception(vm)); + } + _ => {} + } + if socket_kind == -1 { + // TODO: when socket2 cuts a new release, type will be available on all os + // socket_kind = sock.r#type().map_err(|e| e.into_pyexception(vm))?.into(); + let res = unsafe { + c::getsockopt( + sock_fileno(&sock) as _, + c::SOL_SOCKET, + c::SO_TYPE, + &mut socket_kind as *mut libc::c_int as *mut _, + &mut (std::mem::size_of::() as _), + ) + }; + if res < 0 { + return Err(super::os::errno_err(vm)); + } + } + cfg_if::cfg_if! { + if #[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "linux", + ))] { + if proto == -1 { + proto = sock.protocol().map_err(|e| e.into_pyexception(vm))?.map_or(0, Into::into); + } + } else { + proto = 0; + } + } + } else { + if family == -1 { + family = c::AF_INET as i32 + } + if socket_kind == -1 { + socket_kind = c::SOCK_STREAM + } + if proto == -1 { + proto = 0 + } + sock = Socket::new( + Domain::from(family), + SocketType::from(socket_kind), + Some(Protocol::from(proto)), + ) + .map_err(|err| err.into_pyexception(vm))?; + }; + self.init_inner(family, socket_kind, proto, sock, vm) + } #[pymethod] fn connect(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { @@ -919,20 +938,6 @@ impl PySocket { } } -impl io::Read for PySocketRef { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - <&Socket as io::Read>::read(&mut &*self.sock_io()?, buf) - } -} -impl io::Write for PySocketRef { - fn write(&mut self, buf: &[u8]) -> io::Result { - <&Socket as io::Write>::write(&mut &*self.sock_io()?, buf) - } - fn flush(&mut self) -> io::Result<()> { - <&Socket as io::Write>::flush(&mut &*self.sock_io()?) - } -} - struct Address { host: PyStrRef, port: u16, diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index c59fc02f0..7df3d7b8f 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -1,4 +1,4 @@ -use super::socket::{self, PySocketRef}; +use super::socket::{self, PySocket}; use crate::common::lock::{PyRwLock, PyRwLockWriteGuard}; use crate::{ builtins::{pytype, weakref::PyWeak, PyStrRef, PyTypeRef}, @@ -22,6 +22,7 @@ use openssl::{ use std::convert::TryFrom; use std::ffi::CStr; use std::fmt; +use std::io::{Read, Write}; use std::time::Instant; mod sys { @@ -105,9 +106,9 @@ fn nid2obj(nid: Nid) -> Option { unsafe { ptr2obj(sys::OBJ_nid2obj(nid.as_raw())) } } fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option { - unsafe { - let no_name = if no_name { 1 } else { 0 }; - let ptr = obj.as_ptr(); + let no_name = if no_name { 1 } else { 0 }; + let ptr = obj.as_ptr(); + let s = unsafe { let buflen = sys::OBJ_obj2txt(std::ptr::null_mut(), 0, ptr, no_name); assert!(buflen >= 0); if buflen == 0 { @@ -116,10 +117,10 @@ fn obj2txt(obj: &Asn1ObjectRef, no_name: bool) -> Option { let mut buf = vec![0u8; buflen as usize]; let ret = sys::OBJ_obj2txt(buf.as_mut_ptr() as *mut libc::c_char, buflen, ptr, no_name); assert!(ret >= 0); - let s = String::from_utf8(buf) - .unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned()); - Some(s) - } + String::from_utf8(buf) + .unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned()) + }; + Some(s) } type PyNid = (libc::c_int, String, String, Option); @@ -232,9 +233,8 @@ fn _ssl_rand_bytes(n: i32, vm: &VirtualMachine) -> PyResult> { return Err(vm.new_value_error("num must be positive".to_owned())); } let mut buf = vec![0; n as usize]; - openssl::rand::rand_bytes(&mut buf) - .map(|()| buf) - .map_err(|e| convert_openssl_error(vm, e)) + openssl::rand::rand_bytes(&mut buf).map_err(|e| convert_openssl_error(vm, e))?; + Ok(buf) } fn _ssl_rand_pseudo_bytes(n: i32, vm: &VirtualMachine) -> PyResult<(Vec, bool)> { @@ -592,7 +592,7 @@ impl PySslContext { } } - let stream = ssl::SslStream::new(ssl, args.sock.clone()) + let stream = ssl::SslStream::new(ssl, SocketStream(args.sock.clone())) .map_err(|e| convert_openssl_error(vm, e))?; // TODO: use this @@ -611,7 +611,7 @@ impl PySslContext { #[derive(FromArgs)] struct WrapSocketArgs { #[pyarg(any)] - sock: PySocketRef, + sock: PyRef, #[pyarg(any)] server_side: bool, #[pyarg(any, default)] @@ -642,16 +642,9 @@ struct LoadCertChainArgs { password: Option>, } -struct SocketTimeout { - // Err is true if the socket is blocking - deadline: Result, -} -impl SocketTimeout { - fn get(s: &socket::PySocket) -> Self { - let deadline = s.get_timeout().map(|d| Instant::now() + d); - Self { deadline } - } -} +// Err is true if the socket is blocking +type SocketDeadline = Result; + enum SelectRet { Nonblocking, TimedOut, @@ -659,50 +652,60 @@ enum SelectRet { Closed, Ok, } -fn ssl_select(sock: &socket::PySocket, needs: SslNeeds, timeout: &SocketTimeout) -> SelectRet { - let sock = match sock.sock_opt() { - Some(s) => s, - None => return SelectRet::Closed, - }; - let timeout = match &timeout.deadline { - Ok(deadline) => match deadline.checked_duration_since(Instant::now()) { - Some(timeout) => timeout, - None => return SelectRet::TimedOut, - }, - Err(true) => return SelectRet::IsBlocking, - Err(false) => return SelectRet::Nonblocking, - }; - let res = socket::sock_select( - &sock, - match needs { - SslNeeds::Read => socket::SelectKind::Read, - SslNeeds::Write => socket::SelectKind::Write, - }, - Some(timeout), - ); - match res { - Ok(true) => SelectRet::TimedOut, - _ => SelectRet::Ok, - } -} + #[derive(Clone, Copy)] enum SslNeeds { Read, Write, } -fn socket_needs( - err: &ssl::Error, - sock: &socket::PySocket, - timeout: &SocketTimeout, -) -> (Option, SelectRet) { - let needs = match err.code() { - ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read), - ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write), - _ => None, - }; - let state = needs.map_or(SelectRet::Ok, |needs| ssl_select(sock, needs, timeout)); - (needs, state) +struct SocketStream(PyRef); + +impl SocketStream { + fn timeout_deadline(&self) -> SocketDeadline { + self.0.get_timeout().map(|d| Instant::now() + d) + } + + fn select(&self, needs: SslNeeds, deadline: &SocketDeadline) -> SelectRet { + let sock = match self.0.sock_opt() { + Some(s) => s, + None => return SelectRet::Closed, + }; + let deadline = match &deadline { + Ok(deadline) => match deadline.checked_duration_since(Instant::now()) { + Some(deadline) => deadline, + None => return SelectRet::TimedOut, + }, + Err(true) => return SelectRet::IsBlocking, + Err(false) => return SelectRet::Nonblocking, + }; + let res = socket::sock_select( + &sock, + match needs { + SslNeeds::Read => socket::SelectKind::Read, + SslNeeds::Write => socket::SelectKind::Write, + }, + Some(deadline), + ); + match res { + Ok(true) => SelectRet::TimedOut, + _ => SelectRet::Ok, + } + } + + fn socket_needs( + &self, + err: &ssl::Error, + deadline: &SocketDeadline, + ) -> (Option, SelectRet) { + let needs = match err.code() { + ssl::ErrorCode::WANT_READ => Some(SslNeeds::Read), + ssl::ErrorCode::WANT_WRITE => Some(SslNeeds::Write), + _ => None, + }; + let state = needs.map_or(SelectRet::Ok, |needs| self.select(needs, deadline)); + (needs, state) + } } fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef { @@ -716,7 +719,7 @@ fn socket_closed_error(vm: &VirtualMachine) -> PyBaseExceptionRef { #[derive(PyValue)] struct PySslSocket { ctx: PyRef, - stream: PyRwLock>, + stream: PyRwLock>, socket_type: SslServerOrClient, server_hostname: Option, owner: PyRwLock>, @@ -788,38 +791,37 @@ impl PySslSocket { .map(cipher_to_tuple) } + #[cfg(osslconf = "OPENSSL_NO_COMP")] #[pymethod] fn compression(&self) -> Option<&'static str> { - #[cfg(osslconf = "OPENSSL_NO_COMP")] - { - None + None + } + #[cfg(not(osslconf = "OPENSSL_NO_COMP"))] + #[pymethod] + fn compression(&self) -> Option<&'static str> { + let stream = self.stream.read(); + let comp_method = unsafe { sys::SSL_get_current_compression(stream.ssl().as_ptr()) }; + if comp_method.is_null() { + return None; } - #[cfg(not(osslconf = "OPENSSL_NO_COMP"))] - { - let stream = self.stream.read(); - let comp_method = unsafe { sys::SSL_get_current_compression(stream.ssl().as_ptr()) }; - if comp_method.is_null() { - return None; - } - let typ = unsafe { sys::COMP_get_type(comp_method) }; - let nid = Nid::from_raw(typ); - if nid == Nid::UNDEF { - return None; - } - nid.short_name().ok() + let typ = unsafe { sys::COMP_get_type(comp_method) }; + let nid = Nid::from_raw(typ); + if nid == Nid::UNDEF { + return None; } + nid.short_name().ok() } #[pymethod] fn do_handshake(&self, vm: &VirtualMachine) -> PyResult<()> { let mut stream = self.stream.write(); - let timeout = SocketTimeout::get(stream.get_ref()); + let timeout = stream.get_ref().timeout_deadline(); loop { let err = match stream.do_handshake() { Ok(()) => return Ok(()), Err(e) => e, }; - let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout); + let (needs, state) = stream.get_ref().socket_needs(&err, &timeout); match state { SelectRet::TimedOut => { return Err(socket::timeout_error_msg( @@ -844,8 +846,8 @@ impl PySslSocket { let mut stream = self.stream.write(); let data = data.borrow_buf(); let data = &*data; - let timeout = SocketTimeout::get(stream.get_ref()); - let state = ssl_select(stream.get_ref(), SslNeeds::Write, &timeout); + let timeout = stream.get_ref().timeout_deadline(); + let state = stream.get_ref().select(SslNeeds::Write, &timeout); match state { SelectRet::TimedOut => { return Err(socket::timeout_error_msg( @@ -861,7 +863,7 @@ impl PySslSocket { Ok(len) => return Ok(len), Err(e) => e, }; - let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout); + let (needs, state) = stream.get_ref().socket_needs(&err, &timeout); match state { SelectRet::TimedOut => { return Err(socket::timeout_error_msg( @@ -902,7 +904,7 @@ impl PySslSocket { Some(b) => b, None => buf, }; - let timeout = SocketTimeout::get(stream.get_ref()); + let timeout = stream.get_ref().timeout_deadline(); let count = loop { let err = match stream.ssl_read(buf) { Ok(count) => break count, @@ -913,7 +915,7 @@ impl PySslSocket { { break 0; } - let (needs, state) = socket_needs(&err, stream.get_ref(), &timeout); + let (needs, state) = stream.get_ref().socket_needs(&err, &timeout); match state { SelectRet::TimedOut => { return Err(socket::timeout_error_msg( @@ -996,10 +998,9 @@ fn cipher_to_tuple(cipher: &ssl::SslCipherRef) -> CipherTuple { } fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult { - if binary { - cert.to_der() - .map(|b| vm.ctx.new_bytes(b)) - .map_err(|e| convert_openssl_error(vm, e)) + let r = if binary { + let b = cert.to_der().map_err(|e| convert_openssl_error(vm, e))?; + vm.ctx.new_bytes(b) } else { let dict = vm.ctx.new_dict(); @@ -1073,8 +1074,9 @@ fn cert_to_py(vm: &VirtualMachine, cert: &X509Ref, binary: bool) -> PyResult { dict.set_item("subjectAltName", vm.ctx.new_tuple(san), vm)?; }; - Ok(dict.into_object()) - } + dict.into_object() + }; + Ok(r) } #[allow(non_snake_case)] @@ -1237,3 +1239,21 @@ fn extend_module_platform_specific(module: &PyObjectRef, vm: &VirtualMachine) { #[cfg(not(windows))] fn extend_module_platform_specific(_module: &PyObjectRef, _vm: &VirtualMachine) {} + +impl Read for SocketStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let mut socket: &PySocket = &self.0; + socket.read(buf) + } +} + +impl Write for SocketStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut socket: &PySocket = &self.0; + socket.write(buf) + } + fn flush(&mut self) -> std::io::Result<()> { + let mut socket: &PySocket = &self.0; + socket.flush() + } +}