From 84099514e683da49511d44c05bb86de8eb83eed6 Mon Sep 17 00:00:00 2001 From: Noa Date: Sun, 21 Apr 2024 21:21:10 -0500 Subject: [PATCH] Implement socket.socket.sendmsg (#5205) * Implement socket.socket.sendmsg * debugger-friendly newlines * Fix control_buf error on macOS --------- Co-authored-by: Jeong YunWon --- Cargo.lock | 6 +- stdlib/Cargo.toml | 2 +- stdlib/src/socket.rs | 163 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 165 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 20fa9a531..b45f6b2e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2570,12 +2570,12 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "socket2" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" dependencies = [ "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index 052a583eb..d4386b24a 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -94,7 +94,7 @@ termios = "0.3.3" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] gethostname = "0.2.3" -socket2 = { version = "0.5.4", features = ["all"] } +socket2 = { version = "0.5.6", features = ["all"] } dns-lookup = "2" openssl = { version = "0.10.62", optional = true } openssl-sys = { version = "0.9.80", optional = true } diff --git a/stdlib/src/socket.rs b/stdlib/src/socket.rs index 5aea7be11..f3a86141e 100644 --- a/stdlib/src/socket.rs +++ b/stdlib/src/socket.rs @@ -822,7 +822,8 @@ mod _socket { impl PySocket { pub fn sock_opt(&self) -> Option> { - PyRwLockReadGuard::try_map(self.sock.read(), |sock| sock.as_ref()).ok() + let sock = PyRwLockReadGuard::try_map(self.sock.read(), |sock| sock.as_ref()); + sock.ok() } pub fn sock(&self) -> io::Result> { @@ -869,7 +870,8 @@ mod _socket { where F: FnMut() -> io::Result, { - self.sock_op_timeout_err(vm, select, self.get_timeout().ok(), f) + let timeout = self.get_timeout().ok(); + self.sock_op_timeout_err(vm, select, timeout, f) } fn sock_op_timeout_err( @@ -1319,6 +1321,121 @@ mod _socket { }) } + #[cfg(all(unix, not(target_os = "redox")))] + #[pymethod] + fn sendmsg( + &self, + buffers: Vec, + ancdata: OptionalArg, + flags: OptionalArg, + addr: OptionalOption, + vm: &VirtualMachine, + ) -> PyResult { + let flags = flags.unwrap_or(0); + let mut msg = socket2::MsgHdr::new(); + + let sockaddr; + if let Some(addr) = addr.flatten() { + sockaddr = self + .extract_address(addr, "sendmsg", vm) + .map_err(|e| e.into_pyexception(vm))?; + msg = msg.with_addr(&sockaddr); + } + + let buffers = buffers + .iter() + .map(|buf| buf.borrow_buf()) + .collect::>(); + let buffers = buffers + .iter() + .map(|buf| io::IoSlice::new(buf)) + .collect::>(); + msg = msg.with_buffers(&buffers); + + let control_buf; + if let OptionalArg::Present(ancdata) = ancdata { + let cmsgs = vm.extract_elements_with( + &ancdata, + |obj| -> PyResult<(i32, i32, ArgBytesLike)> { + let seq: Vec = obj.try_into_value(vm)?; + let [lvl, typ, data]: [PyObjectRef; 3] = seq.try_into().map_err(|_| { + vm.new_type_error("expected a sequence of length 3".to_owned()) + })?; + Ok(( + lvl.try_into_value(vm)?, + typ.try_into_value(vm)?, + data.try_into_value(vm)?, + )) + }, + )?; + control_buf = Self::pack_cmsgs_to_send(&cmsgs, vm)?; + if !control_buf.is_empty() { + msg = msg.with_control(&control_buf); + } + } + + self.sock_op(vm, SelectKind::Write, || { + let sock = self.sock()?; + sock.sendmsg(&msg, flags) + }) + .map_err(|e| e.into_pyexception(vm)) + } + + // based on nix's implementation + #[cfg(all(unix, not(target_os = "redox")))] + fn pack_cmsgs_to_send( + cmsgs: &[(i32, i32, ArgBytesLike)], + vm: &VirtualMachine, + ) -> PyResult> { + use std::{mem, ptr}; + + if cmsgs.is_empty() { + return Ok(vec![]); + } + + let capacity = cmsgs + .iter() + .map(|(_, _, buf)| buf.len()) + .try_fold(0, |sum, len| { + let space = checked_cmsg_space(len).ok_or_else(|| { + vm.new_os_error("ancillary data item too large".to_owned()) + })?; + usize::checked_add(sum, space) + .ok_or_else(|| vm.new_os_error("too much ancillary data".to_owned())) + })?; + + let mut cmsg_buffer = vec![0u8; capacity]; + + // make a dummy msghdr so we can use the CMSG_* apis + let mut mhdr = unsafe { mem::zeroed::() }; + mhdr.msg_control = cmsg_buffer.as_mut_ptr().cast(); + mhdr.msg_controllen = capacity as _; + + let mut pmhdr: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(&mhdr) }; + for (lvl, typ, buf) in cmsgs { + if pmhdr.is_null() { + return Err(vm.new_runtime_error( + "unexpected NULL result from CMSG_FIRSTHDR/CMSG_NXTHDR".to_owned(), + )); + } + let data = &*buf.borrow_buf(); + assert_eq!(data.len(), buf.len()); + // Safe because we know that pmhdr is valid, and we initialized it with + // sufficient space + unsafe { + (*pmhdr).cmsg_level = *lvl; + (*pmhdr).cmsg_type = *typ; + (*pmhdr).cmsg_len = data.len() as _; + ptr::copy_nonoverlapping(data.as_ptr(), libc::CMSG_DATA(pmhdr), data.len()); + } + + // Safe because mhdr is valid + pmhdr = unsafe { libc::CMSG_NXTHDR(&mhdr, pmhdr) }; + } + + Ok(cmsg_buffer) + } + #[pymethod] fn close(&self) -> io::Result<()> { let sock = self.detach(); @@ -2364,4 +2481,46 @@ mod _socket { HError, GaiError, } + + #[cfg(all(unix, not(target_os = "redox")))] + fn checked_cmsg_len(len: usize) -> Option { + // SAFETY: CMSG_LEN is always safe + let cmsg_len = |length| unsafe { libc::CMSG_LEN(length) }; + if len as u64 > (i32::MAX as u64 - cmsg_len(0) as u64) { + return None; + } + let res = cmsg_len(len as _) as usize; + if res > i32::MAX as usize || res < len { + return None; + } + Some(res) + } + + #[cfg(all(unix, not(target_os = "redox")))] + fn checked_cmsg_space(len: usize) -> Option { + // SAFETY: CMSG_SPACE is always safe + let cmsg_space = |length| unsafe { libc::CMSG_SPACE(length) }; + if len as u64 > (i32::MAX as u64 - cmsg_space(1) as u64) { + return None; + } + let res = cmsg_space(len as _) as usize; + if res > i32::MAX as usize || res < len { + return None; + } + Some(res) + } + + #[cfg(all(unix, not(target_os = "redox")))] + #[pyfunction(name = "CMSG_LEN")] + fn cmsg_len(length: usize, vm: &VirtualMachine) -> PyResult { + checked_cmsg_len(length) + .ok_or_else(|| vm.new_overflow_error("CMSG_LEN() argument out of range".to_owned())) + } + + #[cfg(all(unix, not(target_os = "redox")))] + #[pyfunction(name = "CMSG_SPACE")] + fn cmsg_space(length: usize, vm: &VirtualMachine) -> PyResult { + checked_cmsg_space(length) + .ok_or_else(|| vm.new_overflow_error("CMSG_SPACE() argument out of range".to_owned())) + } }