From 2dbe4a9c9aab508280665e8fc3436ef5bb707ac0 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Mon, 3 May 2021 14:23:08 -0500 Subject: [PATCH] Add sock.recvfrom_into + other --- vm/src/pyobject.rs | 23 ++++--- vm/src/stdlib/socket.rs | 147 ++++++++++++++++++++++++++++++---------- 2 files changed, 125 insertions(+), 45 deletions(-) diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index c8049880e..eefc56c73 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1204,15 +1204,20 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static { impl TryFromObject for std::time::Duration { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { use std::time::Duration; - u64::try_from_object(vm, obj.clone()) - .map(Duration::from_secs) - .or_else(|_| f64::try_from_object(vm, obj.clone()).map(Duration::from_secs_f64)) - .map_err(|_| { - vm.new_type_error(format!( - "expected an int or float for duration, got {}", - obj.class() - )) - }) + if let Some(float) = obj.payload::() { + Ok(Duration::from_secs_f64(float.to_f64())) + } else if let Some(int) = vm.to_index_opt(obj.clone()) { + let sec = int? + .borrow_value() + .to_u64() + .ok_or_else(|| vm.new_value_error("value out of range".to_owned()))?; + Ok(Duration::from_secs(sec)) + } else { + Err(vm.new_type_error(format!( + "expected an int or float for duration, got {}", + obj.class() + ))) + } } } diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 2dbd7d452..26e006c9a 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -2,6 +2,7 @@ use crossbeam_utils::atomic::AtomicCell; use gethostname::gethostname; #[cfg(all(unix, not(target_os = "redox")))] use nix::unistd::sethostname; +use num_traits::ToPrimitive; use socket2::{Domain, Protocol, Socket, Type as SocketType}; use std::convert::TryFrom; use std::io; @@ -9,7 +10,6 @@ use std::mem::MaybeUninit; use std::net::{self, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs}; use std::time::{Duration, Instant}; -use crate::builtins::bytes::PyBytesRef; use crate::builtins::int; use crate::builtins::pystr::PyStrRef; use crate::builtins::pytype::PyTypeRef; @@ -17,7 +17,7 @@ use crate::builtins::tuple::PyTupleRef; use crate::byteslike::{PyBytesLike, PyRwBytesLike}; use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; use crate::exceptions::{IntoPyException, PyBaseExceptionRef}; -use crate::function::{FuncArgs, OptionalArg}; +use crate::function::{FuncArgs, OptionalArg, OptionalOption}; use crate::pyobject::{ BorrowValue, Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, StaticType, TryFromObject, TypeProtocol, @@ -122,14 +122,17 @@ impl PySocket { #[pymethod(name = "__init__")] fn init( &self, - mut family: i32, - mut socket_kind: i32, - mut proto: i32, - fileno: Option, + family: OptionalArg, + socket_kind: OptionalArg, + proto: OptionalArg, + fileno: OptionalOption, vm: &VirtualMachine, ) -> PyResult<()> { + let mut family = family.unwrap_or(-1); + let mut socket_kind = socket_kind.unwrap_or(-1); + let mut proto = proto.unwrap_or(-1); // should really just be to_index() but test_socket tests the error messages explicitly - let fileno = match fileno { + let fileno = match fileno.flatten() { Some(o) if o.isinstance(&vm.ctx.types.float_type) => { return Err(vm.new_type_error("integer argument expected, got float".to_owned())) } @@ -153,7 +156,8 @@ impl PySocket { Some(errcode!(ENOTSOCK)) | Some(errcode!(EBADF)) ) => { - return Err(e.into_pyexception(vm)) + std::mem::forget(sock); + return Err(e.into_pyexception(vm)); } _ => {} } @@ -481,19 +485,24 @@ impl PySocket { ) -> PyResult { let flags = flags.unwrap_or(0); let sock = self.sock(); + let mut buf = buf.borrow_value(); + let buf = &mut *buf; self.sock_op(vm, SelectKind::Read, || { - buf.with_ref(|buf| sock.recv_with_flags(slice_as_uninit(buf), flags)) + sock.recv_with_flags(slice_as_uninit(buf), flags) }) } #[pymethod] fn recvfrom( &self, - bufsize: usize, + bufsize: isize, flags: OptionalArg, vm: &VirtualMachine, ) -> PyResult<(Vec, PyObjectRef)> { let flags = flags.unwrap_or(0); + let bufsize = bufsize + .to_usize() + .ok_or_else(|| vm.new_value_error("negative buffersize in recvfrom".to_owned()))?; let mut buffer = Vec::with_capacity(bufsize); let (n, addr) = self.sock_op(vm, SelectKind::Read, || { self.sock() @@ -503,6 +512,35 @@ impl PySocket { Ok((buffer, get_addr_tuple(&addr, vm))) } + #[pymethod] + fn recvfrom_into( + &self, + buf: PyRwBytesLike, + nbytes: OptionalArg, + flags: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult<(usize, PyObjectRef)> { + let mut buf = buf.borrow_value(); + let buf = &mut *buf; + let buf = match nbytes { + OptionalArg::Present(i) => { + let i = i.to_usize().ok_or_else(|| { + vm.new_value_error("negative buffersize in recvfrom_into".to_owned()) + })?; + buf.get_mut(..i).ok_or_else(|| { + vm.new_value_error("nbytes is greater than the length of the buffer".to_owned()) + })? + } + OptionalArg::Missing => buf, + }; + let flags = flags.unwrap_or(0); + let sock = self.sock(); + let (n, addr) = self.sock_op(vm, SelectKind::Read, || { + sock.recv_from_with_flags(slice_as_uninit(buf), flags) + })?; + Ok((n, get_addr_tuple(&addr, vm))) + } + #[pymethod] fn send( &self, @@ -511,8 +549,10 @@ impl PySocket { vm: &VirtualMachine, ) -> PyResult { let flags = flags.unwrap_or(0); + let buf = bytes.borrow_value(); + let buf = &*buf; self.sock_op(vm, SelectKind::Write, || { - bytes.with_ref(|b| self.sock().send_with_flags(b, flags)) + self.sock().send_with_flags(buf, flags) }) } @@ -534,20 +574,19 @@ impl PySocket { let deadline = timeout.map(Deadline::new); - let buf_len = bytes.len(); + let buf = bytes.borrow_value(); + let buf = &*buf; let mut buf_offset = 0; // now we have like 3 layers of interrupt loop :) - while buf_offset < buf_len { + while buf_offset < buf.len() { let interval = deadline .as_ref() .map(|d| d.time_until().map_err(|e| e.into_pyexception(vm))) .transpose()?; self.sock_op_timeout_err(vm, SelectKind::Write, interval, || { - bytes.with_ref(|b| { - let subbuf = &b[buf_offset..]; - buf_offset += self.sock().send_with_flags(subbuf, flags)?; - Ok(()) - }) + let subbuf = &buf[buf_offset..]; + buf_offset += self.sock().send_with_flags(subbuf, flags)?; + Ok(()) }) .map_err(|e| e.into_pyexception(vm))?; vm.check_signals()?; @@ -565,12 +604,21 @@ impl PySocket { ) -> PyResult { // signature is bytes[, flags], address let (flags, address) = match arg3 { - OptionalArg::Present(arg3) => (i32::try_from_object(vm, arg2)?, arg3), + OptionalArg::Present(arg3) => { + // should just be i32::try_from_obj but tests check for error message + let int = vm.to_index_opt(arg2).unwrap_or_else(|| { + Err(vm.new_type_error("an integer is required".to_owned())) + })?; + let flags = int::try_to_primitive::(int.borrow_value(), vm)?; + (flags, arg3) + } OptionalArg::Missing => (0, arg2), }; let addr = self.extract_address(address, "sendto", vm)?; + let buf = bytes.borrow_value(); + let buf = &*buf; self.sock_op(vm, SelectKind::Write, || { - bytes.with_ref(|b| self.sock().send_to_with_flags(b, &addr, flags)) + self.sock().send_to_with_flags(buf, &addr, flags) }) } @@ -758,6 +806,18 @@ impl PySocket { fn proto(&self) -> i32 { self.proto.load() } + + #[pymethod(magic)] + fn repr(&self) -> String { + format!( + "", + // cast because INVALID_SOCKET is unsigned, so would show usize::MAX instead of -1 + sock_fileno(&self.sock()) as i64, + self.family.load(), + self.kind.load(), + self.proto.load(), + ) + } } impl io::Read for PySocketRef { @@ -799,7 +859,6 @@ impl TryFromObject for Address { impl Address { fn from_tuple(tuple: &[PyObjectRef], vm: &VirtualMachine) -> PyResult { - use num_traits::ToPrimitive; let host = PyStrRef::try_from_object(vm, tuple[0].clone())?; let port = i32::try_from_object(vm, tuple[1].clone())?; let port = port @@ -875,8 +934,9 @@ fn _socket_inet_aton(ip_string: PyStrRef, vm: &VirtualMachine) -> PyResult PyResult { - let packed_ip = <&[u8; 4]>::try_from(&**packed_ip) +fn _socket_inet_ntoa(packed_ip: PyBytesLike, vm: &VirtualMachine) -> PyResult { + let packed_ip = packed_ip.borrow_value(); + let packed_ip = <&[u8; 4]>::try_from(&*packed_ip) .map_err(|_| vm.new_os_error("packed IP wrong length for inet_ntoa".to_owned()))?; Ok(vm.ctx.new_str(Ipv4Addr::from(*packed_ip).to_string())) } @@ -889,10 +949,15 @@ fn _socket_getservbyname( use std::ffi::CString; let cstr_name = CString::new(servicename.borrow_value()) .map_err(|_| vm.new_value_error("embedded null character".to_owned()))?; - let protocolname = protocolname.as_ref().map_or("", |s| s.borrow_value()); - let cstr_proto = CString::new(protocolname) + let cstr_proto = protocolname + .as_ref() + .map(|s| CString::new(s.borrow_value())) + .transpose() .map_err(|_| vm.new_value_error("embedded null character".to_owned()))?; - let serv = unsafe { c::getservbyname(cstr_name.as_ptr(), cstr_proto.as_ptr()) }; + let cstr_proto = cstr_proto + .as_ref() + .map_or_else(std::ptr::null, |s| s.as_ptr()); + let serv = unsafe { c::getservbyname(cstr_name.as_ptr(), cstr_proto) }; if serv.is_null() { return Err(vm.new_os_error("service/proto not found".to_owned())); } @@ -1114,16 +1179,21 @@ fn _socket_inet_pton(af_inet: i32, ip_string: PyStrRef, vm: &VirtualMachine) -> } } -fn _socket_inet_ntop(af_inet: i32, packed_ip: PyBytesRef, vm: &VirtualMachine) -> PyResult { +fn _socket_inet_ntop( + af_inet: i32, + packed_ip: PyBytesLike, + vm: &VirtualMachine, +) -> PyResult { + let packed_ip = packed_ip.borrow_value(); match af_inet { c::AF_INET => { - let packed_ip = <&[u8; 4]>::try_from(&**packed_ip).map_err(|_| { + let packed_ip = <&[u8; 4]>::try_from(&*packed_ip).map_err(|_| { vm.new_value_error("invalid length of packed IP address string".to_owned()) })?; Ok(Ipv4Addr::from(*packed_ip).to_string()) } c::AF_INET6 => { - let packed_ip = <&[u8; 16]>::try_from(&**packed_ip).map_err(|_| { + let packed_ip = <&[u8; 16]>::try_from(&*packed_ip).map_err(|_| { vm.new_value_error("invalid length of packed IP address string".to_owned()) })?; Ok(get_ipv6_addr_str(Ipv6Addr::from(*packed_ip))) @@ -1350,7 +1420,10 @@ fn timeout_error(vm: &VirtualMachine) -> PyBaseExceptionRef { fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String { match ipv6.to_ipv4() { - Some(v4) if matches!(v4.octets(), [0, 0, _, _]) => format!("::{:x}", u32::from(v4)), + // instead of "::0.0.ddd.ddd" it's "::xxxx" + Some(v4) if !ipv6.is_unspecified() && matches!(v4.octets(), [0, 0, _, _]) => { + format!("::{:x}", u32::from(v4)) + } _ => ipv6.to_string(), } } @@ -1438,8 +1511,10 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { }) .clone(); + let socket = PySocket::make_class(ctx); let module = py_module!(vm, "_socket", { - "socket" => PySocket::make_class(ctx), + "socket" => socket.clone(), + "SocketType" => socket, "error" => ctx.exceptions.os_error.clone(), "timeout" => socket_timeout, "gaierror" => socket_gaierror, @@ -1477,6 +1552,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "IPPROTO_IPIP" => ctx.new_int(c::IPPROTO_IP), "IPPROTO_IPV6" => ctx.new_int(c::IPPROTO_IPV6), "SOL_SOCKET" => ctx.new_int(c::SOL_SOCKET), + "SOL_TCP" => ctx.new_int(6), "SO_REUSEADDR" => ctx.new_int(c::SO_REUSEADDR), "SO_TYPE" => ctx.new_int(c::SO_TYPE), "SO_BROADCAST" => ctx.new_int(c::SO_BROADCAST), @@ -1498,17 +1574,14 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "AI_NUMERICSERV" => ctx.new_int(c::AI_NUMERICSERV), }); - #[cfg(not(windows))] - extend_module!(vm, module, { - "SO_REUSEPORT" => ctx.new_int(c::SO_REUSEPORT), - }); - #[cfg(not(target_os = "redox"))] extend_module!(vm, module, { "getaddrinfo" => named_function!(ctx, _socket, getaddrinfo), "gethostbyaddr" => named_function!(ctx, _socket, gethostbyaddr), "gethostbyname" => named_function!(ctx, _socket, gethostbyname), "getnameinfo" => named_function!(ctx, _socket, getnameinfo), + "SOCK_RAW" => ctx.new_int(c::SOCK_RAW), + "SOCK_RDM" => ctx.new_int(c::SOCK_RDM), }); extend_module_platform_specific(vm, &module); @@ -1526,11 +1599,13 @@ fn extend_module_platform_specific(vm: &VirtualMachine, module: &PyObjectRef) { extend_module!(vm, module, { "socketpair" => named_function!(ctx, _socket, socketpair), "AF_UNIX" => ctx.new_int(c::AF_UNIX), + "SO_REUSEPORT" => ctx.new_int(c::SO_REUSEPORT), }); #[cfg(not(target_os = "redox"))] extend_module!(vm, module, { "sethostname" => named_function!(ctx, _socket, sethostname), + "SOCK_SEQPACKET" => ctx.new_int(c::SOCK_SEQPACKET), }); }