Add sock.recvfrom_into + other

This commit is contained in:
Noah
2021-05-03 14:23:08 -05:00
parent 43457da22e
commit 2dbe4a9c9a
2 changed files with 125 additions and 45 deletions

View File

@@ -1204,15 +1204,20 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
impl TryFromObject for std::time::Duration {
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
use std::time::Duration;
u64::try_from_object(vm, obj.clone())
.map(Duration::from_secs)
.or_else(|_| f64::try_from_object(vm, obj.clone()).map(Duration::from_secs_f64))
.map_err(|_| {
vm.new_type_error(format!(
"expected an int or float for duration, got {}",
obj.class()
))
})
if let Some(float) = obj.payload::<PyFloat>() {
Ok(Duration::from_secs_f64(float.to_f64()))
} else if let Some(int) = vm.to_index_opt(obj.clone()) {
let sec = int?
.borrow_value()
.to_u64()
.ok_or_else(|| vm.new_value_error("value out of range".to_owned()))?;
Ok(Duration::from_secs(sec))
} else {
Err(vm.new_type_error(format!(
"expected an int or float for duration, got {}",
obj.class()
)))
}
}
}

View File

@@ -2,6 +2,7 @@ use crossbeam_utils::atomic::AtomicCell;
use gethostname::gethostname;
#[cfg(all(unix, not(target_os = "redox")))]
use nix::unistd::sethostname;
use num_traits::ToPrimitive;
use socket2::{Domain, Protocol, Socket, Type as SocketType};
use std::convert::TryFrom;
use std::io;
@@ -9,7 +10,6 @@ use std::mem::MaybeUninit;
use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs};
use std::time::{Duration, Instant};
use crate::builtins::bytes::PyBytesRef;
use crate::builtins::int;
use crate::builtins::pystr::PyStrRef;
use crate::builtins::pytype::PyTypeRef;
@@ -17,7 +17,7 @@ use crate::builtins::tuple::PyTupleRef;
use crate::byteslike::{PyBytesLike, PyRwBytesLike};
use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard};
use crate::exceptions::{IntoPyException, PyBaseExceptionRef};
use crate::function::{FuncArgs, OptionalArg};
use crate::function::{FuncArgs, OptionalArg, OptionalOption};
use crate::pyobject::{
BorrowValue, Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue,
StaticType, TryFromObject, TypeProtocol,
@@ -122,14 +122,17 @@ impl PySocket {
#[pymethod(name = "__init__")]
fn init(
&self,
mut family: i32,
mut socket_kind: i32,
mut proto: i32,
fileno: Option<PyObjectRef>,
family: OptionalArg<i32>,
socket_kind: OptionalArg<i32>,
proto: OptionalArg<i32>,
fileno: OptionalOption<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
let mut family = family.unwrap_or(-1);
let mut socket_kind = socket_kind.unwrap_or(-1);
let mut proto = proto.unwrap_or(-1);
// should really just be to_index() but test_socket tests the error messages explicitly
let fileno = match fileno {
let fileno = match fileno.flatten() {
Some(o) if o.isinstance(&vm.ctx.types.float_type) => {
return Err(vm.new_type_error("integer argument expected, got float".to_owned()))
}
@@ -153,7 +156,8 @@ impl PySocket {
Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF))
) =>
{
return Err(e.into_pyexception(vm))
std::mem::forget(sock);
return Err(e.into_pyexception(vm));
}
_ => {}
}
@@ -481,19 +485,24 @@ impl PySocket {
) -> PyResult<usize> {
let flags = flags.unwrap_or(0);
let sock = self.sock();
let mut buf = buf.borrow_value();
let buf = &mut *buf;
self.sock_op(vm, SelectKind::Read, || {
buf.with_ref(|buf| sock.recv_with_flags(slice_as_uninit(buf), flags))
sock.recv_with_flags(slice_as_uninit(buf), flags)
})
}
#[pymethod]
fn recvfrom(
&self,
bufsize: usize,
bufsize: isize,
flags: OptionalArg<i32>,
vm: &VirtualMachine,
) -> PyResult<(Vec<u8>, PyObjectRef)> {
let flags = flags.unwrap_or(0);
let bufsize = bufsize
.to_usize()
.ok_or_else(|| vm.new_value_error("negative buffersize in recvfrom".to_owned()))?;
let mut buffer = Vec::with_capacity(bufsize);
let (n, addr) = self.sock_op(vm, SelectKind::Read, || {
self.sock()
@@ -503,6 +512,35 @@ impl PySocket {
Ok((buffer, get_addr_tuple(&addr, vm)))
}
#[pymethod]
fn recvfrom_into(
&self,
buf: PyRwBytesLike,
nbytes: OptionalArg<isize>,
flags: OptionalArg<i32>,
vm: &VirtualMachine,
) -> PyResult<(usize, PyObjectRef)> {
let mut buf = buf.borrow_value();
let buf = &mut *buf;
let buf = match nbytes {
OptionalArg::Present(i) => {
let i = i.to_usize().ok_or_else(|| {
vm.new_value_error("negative buffersize in recvfrom_into".to_owned())
})?;
buf.get_mut(..i).ok_or_else(|| {
vm.new_value_error("nbytes is greater than the length of the buffer".to_owned())
})?
}
OptionalArg::Missing => buf,
};
let flags = flags.unwrap_or(0);
let sock = self.sock();
let (n, addr) = self.sock_op(vm, SelectKind::Read, || {
sock.recv_from_with_flags(slice_as_uninit(buf), flags)
})?;
Ok((n, get_addr_tuple(&addr, vm)))
}
#[pymethod]
fn send(
&self,
@@ -511,8 +549,10 @@ impl PySocket {
vm: &VirtualMachine,
) -> PyResult<usize> {
let flags = flags.unwrap_or(0);
let buf = bytes.borrow_value();
let buf = &*buf;
self.sock_op(vm, SelectKind::Write, || {
bytes.with_ref(|b| self.sock().send_with_flags(b, flags))
self.sock().send_with_flags(buf, flags)
})
}
@@ -534,20 +574,19 @@ impl PySocket {
let deadline = timeout.map(Deadline::new);
let buf_len = bytes.len();
let buf = bytes.borrow_value();
let buf = &*buf;
let mut buf_offset = 0;
// now we have like 3 layers of interrupt loop :)
while buf_offset < buf_len {
while buf_offset < buf.len() {
let interval = deadline
.as_ref()
.map(|d| d.time_until().map_err(|e| e.into_pyexception(vm)))
.transpose()?;
self.sock_op_timeout_err(vm, SelectKind::Write, interval, || {
bytes.with_ref(|b| {
let subbuf = &b[buf_offset..];
buf_offset += self.sock().send_with_flags(subbuf, flags)?;
Ok(())
})
let subbuf = &buf[buf_offset..];
buf_offset += self.sock().send_with_flags(subbuf, flags)?;
Ok(())
})
.map_err(|e| e.into_pyexception(vm))?;
vm.check_signals()?;
@@ -565,12 +604,21 @@ impl PySocket {
) -> PyResult<usize> {
// signature is bytes[, flags], address
let (flags, address) = match arg3 {
OptionalArg::Present(arg3) => (i32::try_from_object(vm, arg2)?, arg3),
OptionalArg::Present(arg3) => {
// should just be i32::try_from_obj but tests check for error message
let int = vm.to_index_opt(arg2).unwrap_or_else(|| {
Err(vm.new_type_error("an integer is required".to_owned()))
})?;
let flags = int::try_to_primitive::<i32>(int.borrow_value(), vm)?;
(flags, arg3)
}
OptionalArg::Missing => (0, arg2),
};
let addr = self.extract_address(address, "sendto", vm)?;
let buf = bytes.borrow_value();
let buf = &*buf;
self.sock_op(vm, SelectKind::Write, || {
bytes.with_ref(|b| self.sock().send_to_with_flags(b, &addr, flags))
self.sock().send_to_with_flags(buf, &addr, flags)
})
}
@@ -758,6 +806,18 @@ impl PySocket {
fn proto(&self) -> i32 {
self.proto.load()
}
#[pymethod(magic)]
fn repr(&self) -> String {
format!(
"<socket object, fd={}, family={}, type={}, proto={}>",
// cast because INVALID_SOCKET is unsigned, so would show usize::MAX instead of -1
sock_fileno(&self.sock()) as i64,
self.family.load(),
self.kind.load(),
self.proto.load(),
)
}
}
impl io::Read for PySocketRef {
@@ -799,7 +859,6 @@ impl TryFromObject for Address {
impl Address {
fn from_tuple(tuple: &[PyObjectRef], vm: &VirtualMachine) -> PyResult<Self> {
use num_traits::ToPrimitive;
let host = PyStrRef::try_from_object(vm, tuple[0].clone())?;
let port = i32::try_from_object(vm, tuple[1].clone())?;
let port = port
@@ -875,8 +934,9 @@ fn _socket_inet_aton(ip_string: PyStrRef, vm: &VirtualMachine) -> PyResult<Vec<u
.map_err(|_| vm.new_os_error("illegal IP address string passed to inet_aton".to_owned()))
}
fn _socket_inet_ntoa(packed_ip: PyBytesRef, vm: &VirtualMachine) -> PyResult {
let packed_ip = <&[u8; 4]>::try_from(&**packed_ip)
fn _socket_inet_ntoa(packed_ip: PyBytesLike, vm: &VirtualMachine) -> PyResult {
let packed_ip = packed_ip.borrow_value();
let packed_ip = <&[u8; 4]>::try_from(&*packed_ip)
.map_err(|_| vm.new_os_error("packed IP wrong length for inet_ntoa".to_owned()))?;
Ok(vm.ctx.new_str(Ipv4Addr::from(*packed_ip).to_string()))
}
@@ -889,10 +949,15 @@ fn _socket_getservbyname(
use std::ffi::CString;
let cstr_name = CString::new(servicename.borrow_value())
.map_err(|_| vm.new_value_error("embedded null character".to_owned()))?;
let protocolname = protocolname.as_ref().map_or("", |s| s.borrow_value());
let cstr_proto = CString::new(protocolname)
let cstr_proto = protocolname
.as_ref()
.map(|s| CString::new(s.borrow_value()))
.transpose()
.map_err(|_| vm.new_value_error("embedded null character".to_owned()))?;
let serv = unsafe { c::getservbyname(cstr_name.as_ptr(), cstr_proto.as_ptr()) };
let cstr_proto = cstr_proto
.as_ref()
.map_or_else(std::ptr::null, |s| s.as_ptr());
let serv = unsafe { c::getservbyname(cstr_name.as_ptr(), cstr_proto) };
if serv.is_null() {
return Err(vm.new_os_error("service/proto not found".to_owned()));
}
@@ -1114,16 +1179,21 @@ fn _socket_inet_pton(af_inet: i32, ip_string: PyStrRef, vm: &VirtualMachine) ->
}
}
fn _socket_inet_ntop(af_inet: i32, packed_ip: PyBytesRef, vm: &VirtualMachine) -> PyResult<String> {
fn _socket_inet_ntop(
af_inet: i32,
packed_ip: PyBytesLike,
vm: &VirtualMachine,
) -> PyResult<String> {
let packed_ip = packed_ip.borrow_value();
match af_inet {
c::AF_INET => {
let packed_ip = <&[u8; 4]>::try_from(&**packed_ip).map_err(|_| {
let packed_ip = <&[u8; 4]>::try_from(&*packed_ip).map_err(|_| {
vm.new_value_error("invalid length of packed IP address string".to_owned())
})?;
Ok(Ipv4Addr::from(*packed_ip).to_string())
}
c::AF_INET6 => {
let packed_ip = <&[u8; 16]>::try_from(&**packed_ip).map_err(|_| {
let packed_ip = <&[u8; 16]>::try_from(&*packed_ip).map_err(|_| {
vm.new_value_error("invalid length of packed IP address string".to_owned())
})?;
Ok(get_ipv6_addr_str(Ipv6Addr::from(*packed_ip)))
@@ -1350,7 +1420,10 @@ fn timeout_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String {
match ipv6.to_ipv4() {
Some(v4) if matches!(v4.octets(), [0, 0, _, _]) => format!("::{:x}", u32::from(v4)),
// instead of "::0.0.ddd.ddd" it's "::xxxx"
Some(v4) if !ipv6.is_unspecified() && matches!(v4.octets(), [0, 0, _, _]) => {
format!("::{:x}", u32::from(v4))
}
_ => ipv6.to_string(),
}
}
@@ -1438,8 +1511,10 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
})
.clone();
let socket = PySocket::make_class(ctx);
let module = py_module!(vm, "_socket", {
"socket" => PySocket::make_class(ctx),
"socket" => socket.clone(),
"SocketType" => socket,
"error" => ctx.exceptions.os_error.clone(),
"timeout" => socket_timeout,
"gaierror" => socket_gaierror,
@@ -1477,6 +1552,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
"IPPROTO_IPIP" => ctx.new_int(c::IPPROTO_IP),
"IPPROTO_IPV6" => ctx.new_int(c::IPPROTO_IPV6),
"SOL_SOCKET" => ctx.new_int(c::SOL_SOCKET),
"SOL_TCP" => ctx.new_int(6),
"SO_REUSEADDR" => ctx.new_int(c::SO_REUSEADDR),
"SO_TYPE" => ctx.new_int(c::SO_TYPE),
"SO_BROADCAST" => ctx.new_int(c::SO_BROADCAST),
@@ -1498,17 +1574,14 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
"AI_NUMERICSERV" => ctx.new_int(c::AI_NUMERICSERV),
});
#[cfg(not(windows))]
extend_module!(vm, module, {
"SO_REUSEPORT" => ctx.new_int(c::SO_REUSEPORT),
});
#[cfg(not(target_os = "redox"))]
extend_module!(vm, module, {
"getaddrinfo" => named_function!(ctx, _socket, getaddrinfo),
"gethostbyaddr" => named_function!(ctx, _socket, gethostbyaddr),
"gethostbyname" => named_function!(ctx, _socket, gethostbyname),
"getnameinfo" => named_function!(ctx, _socket, getnameinfo),
"SOCK_RAW" => ctx.new_int(c::SOCK_RAW),
"SOCK_RDM" => ctx.new_int(c::SOCK_RDM),
});
extend_module_platform_specific(vm, &module);
@@ -1526,11 +1599,13 @@ fn extend_module_platform_specific(vm: &VirtualMachine, module: &PyObjectRef) {
extend_module!(vm, module, {
"socketpair" => named_function!(ctx, _socket, socketpair),
"AF_UNIX" => ctx.new_int(c::AF_UNIX),
"SO_REUSEPORT" => ctx.new_int(c::SO_REUSEPORT),
});
#[cfg(not(target_os = "redox"))]
extend_module!(vm, module, {
"sethostname" => named_function!(ctx, _socket, sethostname),
"SOCK_SEQPACKET" => ctx.new_int(c::SOCK_SEQPACKET),
});
}