Implement socket.socket.sendmsg (#5205)

* Implement socket.socket.sendmsg

* debugger-friendly newlines

* Fix control_buf error on macOS

---------

Co-authored-by: Jeong YunWon <jeong@youknowone.org>
This commit is contained in:
Noa
2024-04-21 21:21:10 -05:00
committed by GitHub
parent 3286e683e6
commit 84099514e6
3 changed files with 165 additions and 6 deletions

6
Cargo.lock generated
View File

@@ -2570,12 +2570,12 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
[[package]]
name = "socket2"
version = "0.5.5"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871"
dependencies = [
"libc",
"windows-sys 0.48.0",
"windows-sys 0.52.0",
]
[[package]]

View File

@@ -94,7 +94,7 @@ termios = "0.3.3"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
gethostname = "0.2.3"
socket2 = { version = "0.5.4", features = ["all"] }
socket2 = { version = "0.5.6", features = ["all"] }
dns-lookup = "2"
openssl = { version = "0.10.62", optional = true }
openssl-sys = { version = "0.9.80", optional = true }

View File

@@ -822,7 +822,8 @@ mod _socket {
impl PySocket {
pub fn sock_opt(&self) -> Option<PyMappedRwLockReadGuard<'_, Socket>> {
PyRwLockReadGuard::try_map(self.sock.read(), |sock| sock.as_ref()).ok()
let sock = PyRwLockReadGuard::try_map(self.sock.read(), |sock| sock.as_ref());
sock.ok()
}
pub fn sock(&self) -> io::Result<PyMappedRwLockReadGuard<'_, Socket>> {
@@ -869,7 +870,8 @@ mod _socket {
where
F: FnMut() -> io::Result<R>,
{
self.sock_op_timeout_err(vm, select, self.get_timeout().ok(), f)
let timeout = self.get_timeout().ok();
self.sock_op_timeout_err(vm, select, timeout, f)
}
fn sock_op_timeout_err<F, R>(
@@ -1319,6 +1321,121 @@ mod _socket {
})
}
#[cfg(all(unix, not(target_os = "redox")))]
#[pymethod]
fn sendmsg(
&self,
buffers: Vec<ArgBytesLike>,
ancdata: OptionalArg,
flags: OptionalArg<i32>,
addr: OptionalOption,
vm: &VirtualMachine,
) -> PyResult<usize> {
let flags = flags.unwrap_or(0);
let mut msg = socket2::MsgHdr::new();
let sockaddr;
if let Some(addr) = addr.flatten() {
sockaddr = self
.extract_address(addr, "sendmsg", vm)
.map_err(|e| e.into_pyexception(vm))?;
msg = msg.with_addr(&sockaddr);
}
let buffers = buffers
.iter()
.map(|buf| buf.borrow_buf())
.collect::<Vec<_>>();
let buffers = buffers
.iter()
.map(|buf| io::IoSlice::new(buf))
.collect::<Vec<_>>();
msg = msg.with_buffers(&buffers);
let control_buf;
if let OptionalArg::Present(ancdata) = ancdata {
let cmsgs = vm.extract_elements_with(
&ancdata,
|obj| -> PyResult<(i32, i32, ArgBytesLike)> {
let seq: Vec<PyObjectRef> = obj.try_into_value(vm)?;
let [lvl, typ, data]: [PyObjectRef; 3] = seq.try_into().map_err(|_| {
vm.new_type_error("expected a sequence of length 3".to_owned())
})?;
Ok((
lvl.try_into_value(vm)?,
typ.try_into_value(vm)?,
data.try_into_value(vm)?,
))
},
)?;
control_buf = Self::pack_cmsgs_to_send(&cmsgs, vm)?;
if !control_buf.is_empty() {
msg = msg.with_control(&control_buf);
}
}
self.sock_op(vm, SelectKind::Write, || {
let sock = self.sock()?;
sock.sendmsg(&msg, flags)
})
.map_err(|e| e.into_pyexception(vm))
}
// based on nix's implementation
#[cfg(all(unix, not(target_os = "redox")))]
fn pack_cmsgs_to_send(
cmsgs: &[(i32, i32, ArgBytesLike)],
vm: &VirtualMachine,
) -> PyResult<Vec<u8>> {
use std::{mem, ptr};
if cmsgs.is_empty() {
return Ok(vec![]);
}
let capacity = cmsgs
.iter()
.map(|(_, _, buf)| buf.len())
.try_fold(0, |sum, len| {
let space = checked_cmsg_space(len).ok_or_else(|| {
vm.new_os_error("ancillary data item too large".to_owned())
})?;
usize::checked_add(sum, space)
.ok_or_else(|| vm.new_os_error("too much ancillary data".to_owned()))
})?;
let mut cmsg_buffer = vec![0u8; capacity];
// make a dummy msghdr so we can use the CMSG_* apis
let mut mhdr = unsafe { mem::zeroed::<libc::msghdr>() };
mhdr.msg_control = cmsg_buffer.as_mut_ptr().cast();
mhdr.msg_controllen = capacity as _;
let mut pmhdr: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(&mhdr) };
for (lvl, typ, buf) in cmsgs {
if pmhdr.is_null() {
return Err(vm.new_runtime_error(
"unexpected NULL result from CMSG_FIRSTHDR/CMSG_NXTHDR".to_owned(),
));
}
let data = &*buf.borrow_buf();
assert_eq!(data.len(), buf.len());
// Safe because we know that pmhdr is valid, and we initialized it with
// sufficient space
unsafe {
(*pmhdr).cmsg_level = *lvl;
(*pmhdr).cmsg_type = *typ;
(*pmhdr).cmsg_len = data.len() as _;
ptr::copy_nonoverlapping(data.as_ptr(), libc::CMSG_DATA(pmhdr), data.len());
}
// Safe because mhdr is valid
pmhdr = unsafe { libc::CMSG_NXTHDR(&mhdr, pmhdr) };
}
Ok(cmsg_buffer)
}
#[pymethod]
fn close(&self) -> io::Result<()> {
let sock = self.detach();
@@ -2364,4 +2481,46 @@ mod _socket {
HError,
GaiError,
}
#[cfg(all(unix, not(target_os = "redox")))]
fn checked_cmsg_len(len: usize) -> Option<usize> {
// SAFETY: CMSG_LEN is always safe
let cmsg_len = |length| unsafe { libc::CMSG_LEN(length) };
if len as u64 > (i32::MAX as u64 - cmsg_len(0) as u64) {
return None;
}
let res = cmsg_len(len as _) as usize;
if res > i32::MAX as usize || res < len {
return None;
}
Some(res)
}
#[cfg(all(unix, not(target_os = "redox")))]
fn checked_cmsg_space(len: usize) -> Option<usize> {
// SAFETY: CMSG_SPACE is always safe
let cmsg_space = |length| unsafe { libc::CMSG_SPACE(length) };
if len as u64 > (i32::MAX as u64 - cmsg_space(1) as u64) {
return None;
}
let res = cmsg_space(len as _) as usize;
if res > i32::MAX as usize || res < len {
return None;
}
Some(res)
}
#[cfg(all(unix, not(target_os = "redox")))]
#[pyfunction(name = "CMSG_LEN")]
fn cmsg_len(length: usize, vm: &VirtualMachine) -> PyResult<usize> {
checked_cmsg_len(length)
.ok_or_else(|| vm.new_overflow_error("CMSG_LEN() argument out of range".to_owned()))
}
#[cfg(all(unix, not(target_os = "redox")))]
#[pyfunction(name = "CMSG_SPACE")]
fn cmsg_space(length: usize, vm: &VirtualMachine) -> PyResult<usize> {
checked_cmsg_space(length)
.ok_or_else(|| vm.new_overflow_error("CMSG_SPACE() argument out of range".to_owned()))
}
}