forked from Rust-related/RustPython
Implement socket.socket.sendmsg (#5205)
* Implement socket.socket.sendmsg * debugger-friendly newlines * Fix control_buf error on macOS --------- Co-authored-by: Jeong YunWon <jeong@youknowone.org>
This commit is contained in:
6
Cargo.lock
generated
6
Cargo.lock
generated
@@ -2570,12 +2570,12 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.5.5"
|
||||
version = "0.5.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
|
||||
checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.48.0",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -94,7 +94,7 @@ termios = "0.3.3"
|
||||
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
gethostname = "0.2.3"
|
||||
socket2 = { version = "0.5.4", features = ["all"] }
|
||||
socket2 = { version = "0.5.6", features = ["all"] }
|
||||
dns-lookup = "2"
|
||||
openssl = { version = "0.10.62", optional = true }
|
||||
openssl-sys = { version = "0.9.80", optional = true }
|
||||
|
||||
@@ -822,7 +822,8 @@ mod _socket {
|
||||
|
||||
impl PySocket {
|
||||
pub fn sock_opt(&self) -> Option<PyMappedRwLockReadGuard<'_, Socket>> {
|
||||
PyRwLockReadGuard::try_map(self.sock.read(), |sock| sock.as_ref()).ok()
|
||||
let sock = PyRwLockReadGuard::try_map(self.sock.read(), |sock| sock.as_ref());
|
||||
sock.ok()
|
||||
}
|
||||
|
||||
pub fn sock(&self) -> io::Result<PyMappedRwLockReadGuard<'_, Socket>> {
|
||||
@@ -869,7 +870,8 @@ mod _socket {
|
||||
where
|
||||
F: FnMut() -> io::Result<R>,
|
||||
{
|
||||
self.sock_op_timeout_err(vm, select, self.get_timeout().ok(), f)
|
||||
let timeout = self.get_timeout().ok();
|
||||
self.sock_op_timeout_err(vm, select, timeout, f)
|
||||
}
|
||||
|
||||
fn sock_op_timeout_err<F, R>(
|
||||
@@ -1319,6 +1321,121 @@ mod _socket {
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(all(unix, not(target_os = "redox")))]
|
||||
#[pymethod]
|
||||
fn sendmsg(
|
||||
&self,
|
||||
buffers: Vec<ArgBytesLike>,
|
||||
ancdata: OptionalArg,
|
||||
flags: OptionalArg<i32>,
|
||||
addr: OptionalOption,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<usize> {
|
||||
let flags = flags.unwrap_or(0);
|
||||
let mut msg = socket2::MsgHdr::new();
|
||||
|
||||
let sockaddr;
|
||||
if let Some(addr) = addr.flatten() {
|
||||
sockaddr = self
|
||||
.extract_address(addr, "sendmsg", vm)
|
||||
.map_err(|e| e.into_pyexception(vm))?;
|
||||
msg = msg.with_addr(&sockaddr);
|
||||
}
|
||||
|
||||
let buffers = buffers
|
||||
.iter()
|
||||
.map(|buf| buf.borrow_buf())
|
||||
.collect::<Vec<_>>();
|
||||
let buffers = buffers
|
||||
.iter()
|
||||
.map(|buf| io::IoSlice::new(buf))
|
||||
.collect::<Vec<_>>();
|
||||
msg = msg.with_buffers(&buffers);
|
||||
|
||||
let control_buf;
|
||||
if let OptionalArg::Present(ancdata) = ancdata {
|
||||
let cmsgs = vm.extract_elements_with(
|
||||
&ancdata,
|
||||
|obj| -> PyResult<(i32, i32, ArgBytesLike)> {
|
||||
let seq: Vec<PyObjectRef> = obj.try_into_value(vm)?;
|
||||
let [lvl, typ, data]: [PyObjectRef; 3] = seq.try_into().map_err(|_| {
|
||||
vm.new_type_error("expected a sequence of length 3".to_owned())
|
||||
})?;
|
||||
Ok((
|
||||
lvl.try_into_value(vm)?,
|
||||
typ.try_into_value(vm)?,
|
||||
data.try_into_value(vm)?,
|
||||
))
|
||||
},
|
||||
)?;
|
||||
control_buf = Self::pack_cmsgs_to_send(&cmsgs, vm)?;
|
||||
if !control_buf.is_empty() {
|
||||
msg = msg.with_control(&control_buf);
|
||||
}
|
||||
}
|
||||
|
||||
self.sock_op(vm, SelectKind::Write, || {
|
||||
let sock = self.sock()?;
|
||||
sock.sendmsg(&msg, flags)
|
||||
})
|
||||
.map_err(|e| e.into_pyexception(vm))
|
||||
}
|
||||
|
||||
// based on nix's implementation
|
||||
#[cfg(all(unix, not(target_os = "redox")))]
|
||||
fn pack_cmsgs_to_send(
|
||||
cmsgs: &[(i32, i32, ArgBytesLike)],
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<Vec<u8>> {
|
||||
use std::{mem, ptr};
|
||||
|
||||
if cmsgs.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let capacity = cmsgs
|
||||
.iter()
|
||||
.map(|(_, _, buf)| buf.len())
|
||||
.try_fold(0, |sum, len| {
|
||||
let space = checked_cmsg_space(len).ok_or_else(|| {
|
||||
vm.new_os_error("ancillary data item too large".to_owned())
|
||||
})?;
|
||||
usize::checked_add(sum, space)
|
||||
.ok_or_else(|| vm.new_os_error("too much ancillary data".to_owned()))
|
||||
})?;
|
||||
|
||||
let mut cmsg_buffer = vec![0u8; capacity];
|
||||
|
||||
// make a dummy msghdr so we can use the CMSG_* apis
|
||||
let mut mhdr = unsafe { mem::zeroed::<libc::msghdr>() };
|
||||
mhdr.msg_control = cmsg_buffer.as_mut_ptr().cast();
|
||||
mhdr.msg_controllen = capacity as _;
|
||||
|
||||
let mut pmhdr: *mut libc::cmsghdr = unsafe { libc::CMSG_FIRSTHDR(&mhdr) };
|
||||
for (lvl, typ, buf) in cmsgs {
|
||||
if pmhdr.is_null() {
|
||||
return Err(vm.new_runtime_error(
|
||||
"unexpected NULL result from CMSG_FIRSTHDR/CMSG_NXTHDR".to_owned(),
|
||||
));
|
||||
}
|
||||
let data = &*buf.borrow_buf();
|
||||
assert_eq!(data.len(), buf.len());
|
||||
// Safe because we know that pmhdr is valid, and we initialized it with
|
||||
// sufficient space
|
||||
unsafe {
|
||||
(*pmhdr).cmsg_level = *lvl;
|
||||
(*pmhdr).cmsg_type = *typ;
|
||||
(*pmhdr).cmsg_len = data.len() as _;
|
||||
ptr::copy_nonoverlapping(data.as_ptr(), libc::CMSG_DATA(pmhdr), data.len());
|
||||
}
|
||||
|
||||
// Safe because mhdr is valid
|
||||
pmhdr = unsafe { libc::CMSG_NXTHDR(&mhdr, pmhdr) };
|
||||
}
|
||||
|
||||
Ok(cmsg_buffer)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn close(&self) -> io::Result<()> {
|
||||
let sock = self.detach();
|
||||
@@ -2364,4 +2481,46 @@ mod _socket {
|
||||
HError,
|
||||
GaiError,
|
||||
}
|
||||
|
||||
#[cfg(all(unix, not(target_os = "redox")))]
|
||||
fn checked_cmsg_len(len: usize) -> Option<usize> {
|
||||
// SAFETY: CMSG_LEN is always safe
|
||||
let cmsg_len = |length| unsafe { libc::CMSG_LEN(length) };
|
||||
if len as u64 > (i32::MAX as u64 - cmsg_len(0) as u64) {
|
||||
return None;
|
||||
}
|
||||
let res = cmsg_len(len as _) as usize;
|
||||
if res > i32::MAX as usize || res < len {
|
||||
return None;
|
||||
}
|
||||
Some(res)
|
||||
}
|
||||
|
||||
#[cfg(all(unix, not(target_os = "redox")))]
|
||||
fn checked_cmsg_space(len: usize) -> Option<usize> {
|
||||
// SAFETY: CMSG_SPACE is always safe
|
||||
let cmsg_space = |length| unsafe { libc::CMSG_SPACE(length) };
|
||||
if len as u64 > (i32::MAX as u64 - cmsg_space(1) as u64) {
|
||||
return None;
|
||||
}
|
||||
let res = cmsg_space(len as _) as usize;
|
||||
if res > i32::MAX as usize || res < len {
|
||||
return None;
|
||||
}
|
||||
Some(res)
|
||||
}
|
||||
|
||||
#[cfg(all(unix, not(target_os = "redox")))]
|
||||
#[pyfunction(name = "CMSG_LEN")]
|
||||
fn cmsg_len(length: usize, vm: &VirtualMachine) -> PyResult<usize> {
|
||||
checked_cmsg_len(length)
|
||||
.ok_or_else(|| vm.new_overflow_error("CMSG_LEN() argument out of range".to_owned()))
|
||||
}
|
||||
|
||||
#[cfg(all(unix, not(target_os = "redox")))]
|
||||
#[pyfunction(name = "CMSG_SPACE")]
|
||||
fn cmsg_space(length: usize, vm: &VirtualMachine) -> PyResult<usize> {
|
||||
checked_cmsg_space(length)
|
||||
.ok_or_else(|| vm.new_overflow_error("CMSG_SPACE() argument out of range".to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user