From ca912fcf2720b329d6af2b2e36e5ef9ba7b6a612 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 23 Jan 2021 21:08:43 -0600 Subject: [PATCH 1/6] Fix BufferedReader over a SocketIO to properly read the right amount --- vm/src/builtins/memory.rs | 6 +++ vm/src/stdlib/io.rs | 22 +++++------ vm/src/stdlib/socket.rs | 83 ++++++++++++++++++++++++++++----------- 3 files changed, 77 insertions(+), 34 deletions(-) diff --git a/vm/src/builtins/memory.rs b/vm/src/builtins/memory.rs index 7087be9e8..626f663b8 100644 --- a/vm/src/builtins/memory.rs +++ b/vm/src/builtins/memory.rs @@ -95,7 +95,13 @@ impl Buffer for RcBuffer { pub trait Buffer: Debug + PyThreadingConstraint { fn get_options(&self) -> &BufferOptions; + /// Get the full inner buffer of this memory. You probably want [`as_contiguous()`], as + /// `obj_bytes` doesn't take into account the range a memoryview might operate on, among other + /// footguns. fn obj_bytes(&self) -> BorrowedValue<[u8]>; + /// Get the full inner buffer of this memory, mutably. You probably want + /// [`as_contiguous_mut()`], as `obj_bytes` doesn't take into account the range a memoryview + /// might operate on, among other footguns. fn obj_bytes_mut(&self) -> BorrowedValueMut<[u8]>; fn release(&self); diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index cbf844fc3..29f04350b 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -903,7 +903,7 @@ mod _io { // raw file is non-blocking if remaining > self.buffer.len() { // can't buffer everything, buffer what we can and error - let buf = rcbuf.obj_bytes(); + let buf = rcbuf.as_contiguous().unwrap(); let buffer_len = self.buffer.len(); self.buffer.copy_from_slice(&buf[written..][..buffer_len]); self.raw_pos = 0; @@ -926,7 +926,7 @@ mod _io { self.reset_read(); } if remaining > 0 { - let buf = rcbuf.obj_bytes(); + let buf = rcbuf.as_contiguous().unwrap(); self.buffer[..remaining].copy_from_slice(&buf[written..][..remaining]); written += remaining; } @@ -1025,13 +1025,13 @@ mod _io { 0 }; let buf_end = self.buffer.len(); - let res = self.raw_read(Either::A(None), start..buf_end, vm); - if let Ok(Some(n)) = &res { - let new_start = (start + *n) as Offset; + let res = self.raw_read(Either::A(None), start..buf_end, vm)?; + if let Some(n) = res.filter(|n| *n > 0) { + let new_start = (start + n) as Offset; self.read_end = new_start; self.raw_pos = new_start; } - res + Ok(res) } fn raw_read( @@ -1094,7 +1094,7 @@ mod _io { n, len ))); } - if self.abs_pos != -1 { + if n > 0 && self.abs_pos != -1 { self.abs_pos += n as Offset } Ok(Some(n as usize)) @@ -1194,10 +1194,10 @@ mod _io { let n = self.readahead(); let buf_len; { - let mut b = buf.obj_bytes_mut(); + let mut b = buf.as_contiguous_mut().unwrap(); buf_len = b.len(); if n > 0 { - if n as usize > b.len() { + if n as usize >= b.len() { b.copy_from_slice(&self.buffer[self.pos as usize..][..buf_len]); self.pos += buf_len as Offset; return Ok(Some(buf_len)); @@ -1224,7 +1224,7 @@ mod _io { let n = self.fill_buffer(vm)?; if let Some(n) = n.filter(|&n| n > 0) { let n = std::cmp::min(n, remaining); - rcbuf.obj_bytes_mut()[written..][..n] + rcbuf.as_contiguous_mut().unwrap()[written..][..n] .copy_from_slice(&self.buffer[self.pos as usize..][..n]); self.pos += n as Offset; written += n; @@ -1233,7 +1233,7 @@ mod _io { } n } else { - Some(0) + break; }; let n = match n { Some(0) => break, diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 23e876afe..32e112263 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -59,6 +59,7 @@ pub struct PySocket { kind: AtomicCell, family: AtomicCell, proto: AtomicCell, + timeout: AtomicCell, sock: PyRwLock, } @@ -86,6 +87,7 @@ impl PySocket { kind: AtomicCell::default(), family: AtomicCell::default(), proto: AtomicCell::default(), + timeout: AtomicCell::new(-1.0), sock: PyRwLock::new(invalid_sock()), } .into_ref_with_type(vm, cls) @@ -128,11 +130,33 @@ impl PySocket { Ok(()) } + fn sock_op(&self, vm: &VirtualMachine, mut f: F) -> PyResult + where + F: FnMut() -> io::Result, + { + loop { + let err = loop { + // loop on interrupt + match f() { + Ok(x) => return Ok(x), + Err(e) if e.kind() == io::ErrorKind::Interrupted => vm.check_signals()?, + Err(e) => break e, + } + }; + if self.timeout.load() > 0.0 && err.kind() == io::ErrorKind::WouldBlock { + continue; + } + return Err(convert_sock_error(vm, err)); + } + } + #[pymethod] fn connect(&self, address: Address, vm: &VirtualMachine) -> PyResult<()> { let sock_addr = get_addr(vm, address, Some(self.family.load()))?; - let res = if let Some(duration) = self.sock().read_timeout().unwrap() { - self.sock().connect_timeout(&sock_addr, duration) + let timeout = self.timeout.load(); + let res = if timeout > 0.0 { + let timeout = Duration::from_secs_f64(timeout); + self.sock().connect_timeout(&sock_addr, timeout) } else { self.sock().connect(&sock_addr) }; @@ -176,10 +200,8 @@ impl PySocket { ) -> PyResult> { let flags = flags.unwrap_or(0); let mut buffer = vec![0u8; bufsize]; - let n = self - .sock() - .recv_with_flags(&mut buffer, flags) - .map_err(|err| convert_sock_error(vm, err))?; + let sock = self.sock(); + let n = self.sock_op(vm, || sock.recv_with_flags(&mut buffer, flags))?; buffer.truncate(n); Ok(buffer) } @@ -192,8 +214,8 @@ impl PySocket { vm: &VirtualMachine, ) -> PyResult { let flags = flags.unwrap_or(0); - buf.with_ref(|buf| self.sock().recv_with_flags(buf, flags)) - .map_err(|err| convert_sock_error(vm, err)) + let sock = self.sock(); + buf.with_ref(|buf| self.sock_op(vm, || sock.recv_with_flags(buf, flags))) } #[pymethod] @@ -290,24 +312,29 @@ impl PySocket { } #[pymethod] - fn gettimeout(&self, vm: &VirtualMachine) -> PyResult> { - let dur = self - .sock() - .read_timeout() - .map_err(|err| convert_sock_error(vm, err))?; - Ok(dur.map(|d| d.as_secs_f64())) + fn gettimeout(&self) -> Option { + let timeout = self.timeout.load(); + if timeout >= 0.0 { + Some(timeout) + } else { + None + } } #[pymethod] fn setblocking(&self, block: bool, vm: &VirtualMachine) -> PyResult<()> { - self.sock() - .set_nonblocking(!block) - .map_err(|err| convert_sock_error(vm, err)) + self.timeout.store(if block { -1.0 } else { 0.0 }); + let timeout = if block { + Some(Duration::from_secs(0)) + } else { + None + }; + self.settimeout_sock(Some(block), timeout, vm) } #[pymethod] - fn getblocking(&self, vm: &VirtualMachine) -> PyResult { - Ok(self.gettimeout(vm)?.map_or(false, |t| t == 0.0)) + fn getblocking(&self) -> bool { + self.timeout.load() != 0.0 } #[pymethod] @@ -315,11 +342,21 @@ impl PySocket { // timeout is None: blocking, no timeout // timeout is 0: non-blocking, no timeout // otherwise: timeout is timeout, don't change blocking - let (block, timeout) = match timeout { - None => (Some(true), None), - Some(d) if d == Duration::from_secs(0) => (Some(false), None), - Some(d) => (None, Some(d)), + let (block, timeout, f64_timeout) = match timeout { + None => (Some(true), None, -1.0), + Some(d) if d == Duration::from_secs(0) => (Some(false), None, 0.0), + Some(d) => (None, Some(d), d.as_secs_f64()), }; + self.timeout.store(f64_timeout); + self.settimeout_sock(block, timeout, vm) + } + + fn settimeout_sock( + &self, + block: Option, + timeout: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { self.sock() .set_read_timeout(timeout) .map_err(|err| convert_sock_error(vm, err))?; From 9e5716124a0d4910635d444a6f41f9b847fd8419 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Sat, 23 Jan 2021 22:41:02 -0600 Subject: [PATCH 2/6] Don't panic on invalid wbits --- vm/src/stdlib/zlib.rs | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/vm/src/stdlib/zlib.rs b/vm/src/stdlib/zlib.rs index 2f2f1eacd..9d9e64806 100644 --- a/vm/src/stdlib/zlib.rs +++ b/vm/src/stdlib/zlib.rs @@ -89,10 +89,14 @@ mod decl { Ok(vm.ctx.new_bytes(encoded_bytes)) } - // TODO: validate wbits value here - fn header_from_wbits(wbits: OptionalArg) -> (bool, u8) { + fn header_from_wbits(wbits: OptionalArg, vm: &VirtualMachine) -> PyResult<(bool, u8)> { let wbits = wbits.unwrap_or(MAX_WBITS as i8); - (wbits > 0, wbits.abs() as u8) + let header = wbits > 0; + let wbits = wbits.abs() as u8; + match wbits { + 9..=15 => Ok((header, wbits)), + _ => Err(vm.new_value_error("Invalid initialization option".to_owned())), + } } fn _decompress( @@ -172,7 +176,7 @@ mod decl { vm: &VirtualMachine, ) -> PyResult> { data.with_ref(|data| { - let (header, wbits) = header_from_wbits(wbits); + let (header, wbits) = header_from_wbits(wbits, vm)?; let bufsize = bufsize.unwrap_or(DEF_BUF_SIZE); let mut d = Decompress::new_with_window_bits(header, wbits); @@ -187,18 +191,18 @@ mod decl { } #[pyfunction] - fn decompressobj(args: DecopmressobjArgs, vm: &VirtualMachine) -> PyDecompress { - let (header, wbits) = header_from_wbits(args.wbits); + fn decompressobj(args: DecopmressobjArgs, vm: &VirtualMachine) -> PyResult { + let (header, wbits) = header_from_wbits(args.wbits, vm)?; let mut decompress = Decompress::new_with_window_bits(header, wbits); if let OptionalArg::Present(dict) = args.zdict { dict.with_ref(|d| decompress.set_dictionary(d).unwrap()); } - PyDecompress { + Ok(PyDecompress { decompress: PyMutex::new(decompress), eof: AtomicCell::new(false), unused_data: PyMutex::new(PyBytes::from(vec![]).into_ref(vm)), unconsumed_tail: PyMutex::new(PyBytes::from(vec![]).into_ref(vm)), - } + }) } #[pyattr] #[pyclass(name = "Decompress")] @@ -350,7 +354,7 @@ mod decl { _zdict: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - let (header, wbits) = header_from_wbits(wbits); + let (header, wbits) = header_from_wbits(wbits, vm)?; let level = level.unwrap_or(-1); let level = match level { From f6139951980125e2c46dc7d25994d5ea5f814c7c Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Mon, 25 Jan 2021 12:40:24 -0600 Subject: [PATCH 3/6] zlib gzip wbits --- Cargo.lock | 4 ++-- vm/Cargo.toml | 2 +- vm/src/stdlib/zlib.rs | 24 +++++++++++++++++++----- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7eb8f7683..49b3f8023 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -783,9 +783,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7411863d55df97a419aa64cb4d2f167103ea9d767e2c54a1868b7ac3f6b47129" +checksum = "cd3aec53de10fe96d7d8c565eb17f2c687bb5518a2ec453b5b1252964526abe0" dependencies = [ "cfg-if 1.0.0", "crc32fast", diff --git a/vm/Cargo.toml b/vm/Cargo.toml index ad193f749..ba3285716 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -124,7 +124,7 @@ num_cpus = "1" [target.'cfg(not(any(target_arch = "wasm32", target_os = "redox")))'.dependencies] dns-lookup = "1.0" -flate2 = { version = "1.0", features = ["zlib"], default-features = false } +flate2 = { version = "1.0.20", features = ["zlib"], default-features = false } libz-sys = "1.0" [target.'cfg(windows)'.dependencies] diff --git a/vm/src/stdlib/zlib.rs b/vm/src/stdlib/zlib.rs index 9d9e64806..b80e40a01 100644 --- a/vm/src/stdlib/zlib.rs +++ b/vm/src/stdlib/zlib.rs @@ -89,12 +89,16 @@ mod decl { Ok(vm.ctx.new_bytes(encoded_bytes)) } - fn header_from_wbits(wbits: OptionalArg, vm: &VirtualMachine) -> PyResult<(bool, u8)> { + fn header_from_wbits( + wbits: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult<(Option, u8)> { let wbits = wbits.unwrap_or(MAX_WBITS as i8); let header = wbits > 0; let wbits = wbits.abs() as u8; match wbits { - 9..=15 => Ok((header, wbits)), + 9..=15 => Ok((Some(header), wbits)), + 25..=31 => Ok((None, wbits - 16)), _ => Err(vm.new_value_error("Invalid initialization option".to_owned())), } } @@ -179,7 +183,10 @@ mod decl { let (header, wbits) = header_from_wbits(wbits, vm)?; let bufsize = bufsize.unwrap_or(DEF_BUF_SIZE); - let mut d = Decompress::new_with_window_bits(header, wbits); + let mut d = match header { + Some(header) => Decompress::new_with_window_bits(header, wbits), + None => Decompress::new_gzip(wbits), + }; _decompress(data, &mut d, bufsize, None, vm).and_then(|(buf, stream_end)| { if stream_end { Ok(buf) @@ -193,7 +200,10 @@ mod decl { #[pyfunction] fn decompressobj(args: DecopmressobjArgs, vm: &VirtualMachine) -> PyResult { let (header, wbits) = header_from_wbits(args.wbits, vm)?; - let mut decompress = Decompress::new_with_window_bits(header, wbits); + let mut decompress = match header { + Some(header) => Decompress::new_with_window_bits(header, wbits), + None => Decompress::new_gzip(wbits), + }; if let OptionalArg::Present(dict) = args.zdict { dict.with_ref(|d| decompress.set_dictionary(d).unwrap()); } @@ -362,7 +372,11 @@ mod decl { n @ 0..=9 => n as u32, _ => return Err(vm.new_value_error("invalid initialization option".to_owned())), }; - let compress = Compress::new_with_window_bits(Compression::new(level), header, wbits); + let level = Compression::new(level); + let compress = match header { + Some(header) => Compress::new_with_window_bits(level, header, wbits), + None => Compress::new_gzip(level, wbits), + }; Ok(PyCompress { inner: PyMutex::new(CompressInner { compress, From d6eeed9770a89e56fb86a0737c984ab16a6a0404 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Fri, 29 Jan 2021 00:37:57 -0600 Subject: [PATCH 4/6] Yeah sure lets just reimplement _all_ the timeout code --- vm/src/stdlib/select.rs | 55 +++++--- vm/src/stdlib/socket.rs | 275 ++++++++++++++++++++++++++++++---------- 2 files changed, 247 insertions(+), 83 deletions(-) diff --git a/vm/src/stdlib/select.rs b/vm/src/stdlib/select.rs index 0069eb31a..bce2d82ab 100644 --- a/vm/src/stdlib/select.rs +++ b/vm/src/stdlib/select.rs @@ -1,5 +1,6 @@ use crate::pyobject::{PyObjectRef, PyResult, TryFromObject}; use crate::vm::VirtualMachine; +use std::io; pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef { #[cfg(windows)] @@ -68,8 +69,8 @@ impl TryFromObject for Selectable { } } -#[repr(C)] -struct FdSet(platform::fd_set); +#[repr(transparent)] +pub struct FdSet(platform::fd_set); impl FdSet { pub fn new() -> FdSet { @@ -103,6 +104,33 @@ impl FdSet { } } +pub fn select( + nfds: libc::c_int, + readfds: &mut FdSet, + writefds: &mut FdSet, + errfds: &mut FdSet, + timeout: Option<&mut timeval>, +) -> io::Result { + let timeout = match timeout { + Some(tv) => tv as *mut timeval, + None => std::ptr::null_mut(), + }; + let ret = unsafe { + platform::select( + nfds, + &mut readfds.0, + &mut writefds.0, + &mut errfds.0, + timeout, + ) + }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(ret) + } +} + fn sec_to_timeval(sec: f64) -> timeval { timeval { tv_sec: sec.trunc() as _, @@ -162,19 +190,14 @@ mod decl { .max() .map_or(0, |n| n + 1) as i32; - let (select_res, err) = loop { + loop { let mut tv = timeout.map(sec_to_timeval); - let timeout_ptr = match tv { - Some(ref mut tv) => tv as *mut _, - None => std::ptr::null_mut(), - }; - let res = - unsafe { super::platform::select(nfds, &mut r.0, &mut w.0, &mut x.0, timeout_ptr) }; + let res = super::select(nfds, &mut r, &mut w, &mut x, tv.as_mut()); - let err = std::io::Error::last_os_error(); - - if res >= 0 || err.kind() != std::io::ErrorKind::Interrupted { - break (res, err); + match res { + Ok(_) => break, + Err(err) if err.kind() == io::ErrorKind::Interrupted => {} + Err(err) => return Err(err.into_pyexception(vm)), } vm.check_signals()?; @@ -185,14 +208,10 @@ mod decl { r.clear(); w.clear(); x.clear(); - break (0, err); + break; } // retry select() if we haven't reached the deadline yet } - }; - - if select_res < 0 { - return Err(err.into_pyexception(vm)); } let set2list = |list: Vec, mut set: FdSet| { diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 32e112263..ca62e8190 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -6,7 +6,7 @@ use socket2::{Domain, Protocol, Socket, Type as SocketType}; use std::convert::TryFrom; use std::io; use std::net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs}; -use std::time::Duration; +use std::time::{Duration, Instant}; use crate::builtins::bytes::PyBytesRef; use crate::builtins::pystr::{PyStr, PyStrRef}; @@ -20,7 +20,7 @@ use crate::pyobject::{ BorrowValue, Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, StaticType, TryFromObject, }; -use crate::{py_io, VirtualMachine}; +use crate::VirtualMachine; #[cfg(unix)] type RawSocket = std::os::unix::io::RawFd; @@ -130,11 +130,57 @@ impl PySocket { Ok(()) } - fn sock_op(&self, vm: &VirtualMachine, mut f: F) -> PyResult + fn sock_op(&self, vm: &VirtualMachine, select: SelectKind, f: F) -> PyResult where F: FnMut() -> io::Result, { + let timeout = self.timeout.load(); + let timeout = if timeout > 0.0 { + Some(Duration::from_secs_f64(timeout)) + } else { + None + }; + self.sock_op_timeout(vm, select, timeout, f) + } + + fn sock_op_timeout( + &self, + vm: &VirtualMachine, + select: SelectKind, + timeout: Option, + mut f: F, + ) -> PyResult + where + F: FnMut() -> io::Result, + { + let mut deadline: Option = None; + loop { + if timeout.is_some() || matches!(select, SelectKind::Connect) { + let interval = timeout.map(|dur| match deadline { + Some(d) => d + .checked_duration_since(Instant::now()) + // past the deadline already + .ok_or_else(|| timeout_error(vm)), + None => { + let dl = Instant::now() + dur; + deadline = Some(dl); + Ok(dur) + } + }); + let interval = interval.transpose()?; + let res = sock_select(&self.sock(), select, interval); + match res { + Ok(true) => return Err(timeout_error(vm)), + Err(e) if e.kind() == io::ErrorKind::Interrupted => { + vm.check_signals()?; + continue; + } + Err(e) => return Err(convert_sock_error(vm, e)), + Ok(false) => {} // no timeout, continue as normal + } + } + let err = loop { // loop on interrupt match f() { @@ -143,7 +189,7 @@ impl PySocket { Err(e) => break e, } }; - if self.timeout.load() > 0.0 && err.kind() == io::ErrorKind::WouldBlock { + if timeout.is_some() && err.kind() == io::ErrorKind::WouldBlock { continue; } return Err(convert_sock_error(vm, err)); @@ -153,14 +199,35 @@ impl PySocket { #[pymethod] fn connect(&self, address: Address, vm: &VirtualMachine) -> PyResult<()> { let sock_addr = get_addr(vm, address, Some(self.family.load()))?; - let timeout = self.timeout.load(); - let res = if timeout > 0.0 { - let timeout = Duration::from_secs_f64(timeout); - self.sock().connect_timeout(&sock_addr, timeout) - } else { - self.sock().connect(&sock_addr) + + let err = match self.sock().connect(&sock_addr) { + Ok(()) => return Ok(()), + Err(e) => e, }; - res.map_err(|err| convert_sock_error(vm, err)) + + let wait_connect = if err.kind() == io::ErrorKind::Interrupted { + vm.check_signals()?; + self.timeout.load() != 0.0 + } else { + self.timeout.load() > 0.0 && err.raw_os_error() == Some(libc::EINPROGRESS) + }; + + if wait_connect { + // basically, connect() is async, and it registers an "error" on the socket when it's + // done connecting. SelectKind::Connect fills the errorfds fd_set, so if we wake up + // from poll and the error is EISCONN then we know that the connect is done + self.sock_op(vm, SelectKind::Connect, || { + let sock = self.sock(); + let err = sock.take_error()?; + match err { + Some(e) if e.raw_os_error() == Some(libc::EISCONN) => Ok(()), + Some(e) => Err(e), + None => Ok(()), + } + }) + } else { + Err(convert_sock_error(vm, err)) + } } #[pymethod] @@ -182,11 +249,7 @@ impl PySocket { #[pymethod] fn _accept(&self, vm: &VirtualMachine) -> PyResult<(RawSocket, AddrTuple)> { - let (sock, addr) = self - .sock() - .accept() - .map_err(|err| convert_sock_error(vm, err))?; - + let (sock, addr) = self.sock_op(vm, SelectKind::Read, || self.sock().accept())?; let fd = into_sock_fileno(sock); Ok((fd, get_addr_tuple(addr))) } @@ -201,7 +264,9 @@ impl PySocket { let flags = flags.unwrap_or(0); let mut buffer = vec![0u8; bufsize]; let sock = self.sock(); - let n = self.sock_op(vm, || sock.recv_with_flags(&mut buffer, flags))?; + let n = self.sock_op(vm, SelectKind::Read, || { + sock.recv_with_flags(&mut buffer, flags) + })?; buffer.truncate(n); Ok(buffer) } @@ -215,7 +280,9 @@ impl PySocket { ) -> PyResult { let flags = flags.unwrap_or(0); let sock = self.sock(); - buf.with_ref(|buf| self.sock_op(vm, || sock.recv_with_flags(buf, flags))) + self.sock_op(vm, SelectKind::Read, || { + buf.with_ref(|buf| sock.recv_with_flags(buf, flags)) + }) } #[pymethod] @@ -227,10 +294,9 @@ impl PySocket { ) -> PyResult<(Vec, AddrTuple)> { let flags = flags.unwrap_or(0); let mut buffer = vec![0u8; bufsize]; - let (n, addr) = self - .sock() - .recv_from_with_flags(&mut buffer, flags) - .map_err(|err| convert_sock_error(vm, err))?; + let (n, addr) = self.sock_op(vm, SelectKind::Read, || { + self.sock().recv_from_with_flags(&mut buffer, flags) + })?; buffer.truncate(n); Ok((buffer, get_addr_tuple(addr))) } @@ -243,9 +309,9 @@ impl PySocket { vm: &VirtualMachine, ) -> PyResult { let flags = flags.unwrap_or(0); - bytes - .with_ref(|b| self.sock().send_with_flags(b, flags)) - .map_err(|err| convert_sock_error(vm, err)) + self.sock_op(vm, SelectKind::Write, || { + bytes.with_ref(|b| self.sock().send_with_flags(b, flags)) + }) } #[pymethod] @@ -256,10 +322,42 @@ impl PySocket { vm: &VirtualMachine, ) -> PyResult<()> { let flags = flags.unwrap_or(0); - let sock = self.sock(); - bytes - .with_ref(|buf| py_io::write_all(buf, |b| sock.send_with_flags(b, flags))) - .map_err(|err| convert_sock_error(vm, err)) + + let timeout = self.timeout.load(); + let timeout = if timeout > 0.0 { + Some(Duration::from_secs_f64(timeout)) + } else { + None + }; + + let mut deadline: Option = None; + + let buf_len = bytes.len(); + let mut buf_offset = 0; + // now we have like 3 layers of interrupt loop :) + while buf_offset < buf_len { + let interval = timeout.map(|dur| match deadline { + Some(d) => d + .checked_duration_since(Instant::now()) + // past the deadline already + .ok_or_else(|| timeout_error(vm)), + None => { + let dl = Instant::now() + dur; + deadline = Some(dl); + Ok(dur) + } + }); + let interval = interval.transpose()?; + self.sock_op_timeout(vm, SelectKind::Write, interval, || { + bytes.with_ref(|b| { + let subbuf = &b[buf_offset..]; + buf_offset += self.sock().send_with_flags(subbuf, flags)?; + Ok(()) + }) + })?; + vm.check_signals()?; + } + Ok(()) } #[pymethod] @@ -269,13 +367,12 @@ impl PySocket { address: Address, flags: OptionalArg, vm: &VirtualMachine, - ) -> PyResult<()> { + ) -> PyResult { let flags = flags.unwrap_or(0); let addr = get_addr(vm, address, Some(self.family.load()))?; - bytes - .with_ref(|b| self.sock().send_to_with_flags(b, &addr, flags)) - .map_err(|err| convert_sock_error(vm, err))?; - Ok(()) + self.sock_op(vm, SelectKind::Write, || { + bytes.with_ref(|b| self.sock().send_to_with_flags(b, &addr, flags)) + }) } #[pymethod] @@ -324,12 +421,9 @@ impl PySocket { #[pymethod] fn setblocking(&self, block: bool, vm: &VirtualMachine) -> PyResult<()> { self.timeout.store(if block { -1.0 } else { 0.0 }); - let timeout = if block { - Some(Duration::from_secs(0)) - } else { - None - }; - self.settimeout_sock(Some(block), timeout, vm) + self.sock() + .set_nonblocking(!block) + .map_err(|err| convert_sock_error(vm, err)) } #[pymethod] @@ -342,33 +436,15 @@ impl PySocket { // timeout is None: blocking, no timeout // timeout is 0: non-blocking, no timeout // otherwise: timeout is timeout, don't change blocking - let (block, timeout, f64_timeout) = match timeout { - None => (Some(true), None, -1.0), - Some(d) if d == Duration::from_secs(0) => (Some(false), None, 0.0), - Some(d) => (None, Some(d), d.as_secs_f64()), + let (block, timeout) = match timeout { + None => (true, -1.0), + Some(d) if d == Duration::from_secs(0) => (false, 0.0), + Some(d) => (true, d.as_secs_f64()), }; - self.timeout.store(f64_timeout); - self.settimeout_sock(block, timeout, vm) - } - - fn settimeout_sock( - &self, - block: Option, - timeout: Option, - vm: &VirtualMachine, - ) -> PyResult<()> { + self.timeout.store(timeout); self.sock() - .set_read_timeout(timeout) - .map_err(|err| convert_sock_error(vm, err))?; - self.sock() - .set_write_timeout(timeout) - .map_err(|err| convert_sock_error(vm, err))?; - if let Some(blocking) = block { - self.sock() - .set_nonblocking(!blocking) - .map_err(|err| convert_sock_error(vm, err))?; - } - Ok(()) + .set_nonblocking(!block) + .map_err(|err| convert_sock_error(vm, err)) } #[pymethod] @@ -587,6 +663,71 @@ fn _socket_getservbyname( Ok(vm.ctx.new_int(u16::from_be(port as u16))) } +#[derive(Copy, Clone)] +enum SelectKind { + Read, + Write, + Connect, +} + +/// returns true if timed out +fn sock_select(sock: &Socket, kind: SelectKind, interval: Option) -> io::Result { + let fd = sock_fileno(sock); + #[cfg(unix)] + { + let mut pollfd = libc::pollfd { + fd, + events: match kind { + SelectKind::Read => libc::POLLIN, + SelectKind::Write => libc::POLLOUT, + SelectKind::Connect => libc::POLLOUT | libc::POLLERR, + }, + revents: 0, + }; + let timeout = match interval { + Some(d) => d.as_millis() as _, + None => -1, + }; + let ret = unsafe { libc::poll(&mut pollfd, 1, timeout) }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(ret == 0) + } + } + #[cfg(windows)] + { + use crate::stdlib::select; + + let mut reads = select::FdSet::new(); + let mut writes = select::FdSet::new(); + let mut errs = select::FdSet::new(); + + match kind { + SelectKind::Read => reads.insert(fd), + SelectKind::Write => writes.insert(fd), + SelectKind::Connect => { + writes.insert(fd); + errs.insert(fd); + } + } + + let mut interval = interval.map(|dur| libc::timeval { + tv_sec: dur.as_secs() as _, + tv_usec: dur.subsec_micros() as _, + }); + + select::select( + fd as i32 + 1, + &mut reads, + &mut writes, + &mut errs, + interval.as_mut(), + ) + .map(|ret| ret == 0) + } +} + #[derive(FromArgs)] struct GAIOptions { #[pyarg(positional)] @@ -854,6 +995,10 @@ fn convert_sock_error(vm: &VirtualMachine, err: io::Error) -> PyBaseExceptionRef } } +fn timeout_error(vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_exception_msg(TIMEOUT_ERROR.get().unwrap().clone(), "timed out".to_owned()) +} + fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String { match ipv6.to_ipv4() { Some(v4) if matches!(v4.octets(), [0, 0, _, _]) => format!("::{:x}", u32::from(v4)), From 235c384f49b801dc3f8b361709e8cdb8a8badaa9 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Fri, 29 Jan 2021 00:43:18 -0600 Subject: [PATCH 5/6] Remove py_io::write_all, fix windows --- vm/src/py_io.rs | 19 ------------------- vm/src/stdlib/select.rs | 15 ++++++++------- vm/src/stdlib/socket.rs | 3 ++- 3 files changed, 10 insertions(+), 27 deletions(-) diff --git a/vm/src/py_io.rs b/vm/src/py_io.rs index 3974dac5b..17c007772 100644 --- a/vm/src/py_io.rs +++ b/vm/src/py_io.rs @@ -30,25 +30,6 @@ impl Write for PyWriter<'_> { } } -pub fn write_all( - mut buf: &[u8], - mut write: impl FnMut(&[u8]) -> io::Result, -) -> io::Result<()> { - while !buf.is_empty() { - match write(buf) { - Ok(0) => { - return Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write whole buffer", - )) - } - Ok(n) => buf = &buf[n..], - Err(e) => return Err(e), - } - } - Ok(()) -} - pub fn file_readline(obj: &PyObjectRef, size: Option, vm: &VirtualMachine) -> PyResult { let args = size.map_or_else(Vec::new, |size| vec![vm.ctx.new_int(size)]); let ret = vm.call_method(obj, "readline", args)?; diff --git a/vm/src/stdlib/select.rs b/vm/src/stdlib/select.rs index bce2d82ab..7f697a314 100644 --- a/vm/src/stdlib/select.rs +++ b/vm/src/stdlib/select.rs @@ -13,18 +13,18 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef { #[cfg(unix)] mod platform { - pub(super) use libc::{fd_set, select, timeval, FD_ISSET, FD_SET, FD_SETSIZE, FD_ZERO}; - pub(super) use std::os::unix::io::RawFd; + pub use libc::{fd_set, select, timeval, FD_ISSET, FD_SET, FD_SETSIZE, FD_ZERO}; + pub use std::os::unix::io::RawFd; } #[allow(non_snake_case)] #[cfg(windows)] mod platform { - pub(super) use winapi::um::winsock2::{fd_set, select, timeval, FD_SETSIZE, SOCKET as RawFd}; + pub use winapi::um::winsock2::{fd_set, select, timeval, FD_SETSIZE, SOCKET as RawFd}; // from winsock2.h: https://gist.github.com/piscisaureus/906386#file-winsock2-h-L128-L141 - pub(super) unsafe fn FD_SET(fd: RawFd, set: *mut fd_set) { + pub unsafe fn FD_SET(fd: RawFd, set: *mut fd_set) { let mut i = 0; for idx in 0..(*set).fd_count as usize { i = idx; @@ -40,17 +40,18 @@ mod platform { } } - pub(super) unsafe fn FD_ZERO(set: *mut fd_set) { + pub unsafe fn FD_ZERO(set: *mut fd_set) { (*set).fd_count = 0; } - pub(super) unsafe fn FD_ISSET(fd: RawFd, set: *mut fd_set) -> bool { + pub unsafe fn FD_ISSET(fd: RawFd, set: *mut fd_set) -> bool { use winapi::um::winsock2::__WSAFDIsSet; __WSAFDIsSet(fd as _, set) != 0 } } -use platform::{timeval, RawFd}; +pub use platform::timeval; +use platform::RawFd; struct Selectable { obj: PyObjectRef, diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index ca62e8190..e4a751eb0 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -703,6 +703,7 @@ fn sock_select(sock: &Socket, kind: SelectKind, interval: Option) -> i let mut writes = select::FdSet::new(); let mut errs = select::FdSet::new(); + let fd = fd as usize; match kind { SelectKind::Read => reads.insert(fd), SelectKind::Write => writes.insert(fd), @@ -712,7 +713,7 @@ fn sock_select(sock: &Socket, kind: SelectKind, interval: Option) -> i } } - let mut interval = interval.map(|dur| libc::timeval { + let mut interval = interval.map(|dur| select::timeval { tv_sec: dur.as_secs() as _, tv_usec: dur.subsec_micros() as _, }); From 160505d0cf769eb19404d86d8b42084077e8296b Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Wed, 3 Feb 2021 10:50:28 -0600 Subject: [PATCH 6/6] Use helper type for deadlines --- vm/src/stdlib/socket.rs | 51 +++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index e4a751eb0..8733edc86 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -153,22 +153,11 @@ impl PySocket { where F: FnMut() -> io::Result, { - let mut deadline: Option = None; + let deadline = timeout.map(Deadline::new); loop { - if timeout.is_some() || matches!(select, SelectKind::Connect) { - let interval = timeout.map(|dur| match deadline { - Some(d) => d - .checked_duration_since(Instant::now()) - // past the deadline already - .ok_or_else(|| timeout_error(vm)), - None => { - let dl = Instant::now() + dur; - deadline = Some(dl); - Ok(dur) - } - }); - let interval = interval.transpose()?; + if deadline.is_some() || matches!(select, SelectKind::Connect) { + let interval = deadline.as_ref().map(|d| d.time_until(vm)).transpose()?; let res = sock_select(&self.sock(), select, interval); match res { Ok(true) => return Err(timeout_error(vm)), @@ -222,6 +211,7 @@ impl PySocket { match err { Some(e) if e.raw_os_error() == Some(libc::EISCONN) => Ok(()), Some(e) => Err(e), + // TODO: is this accurate? None => Ok(()), } }) @@ -330,24 +320,13 @@ impl PySocket { None }; - let mut deadline: Option = None; + let deadline = timeout.map(Deadline::new); let buf_len = bytes.len(); let mut buf_offset = 0; // now we have like 3 layers of interrupt loop :) while buf_offset < buf_len { - let interval = timeout.map(|dur| match deadline { - Some(d) => d - .checked_duration_since(Instant::now()) - // past the deadline already - .ok_or_else(|| timeout_error(vm)), - None => { - let dl = Instant::now() + dur; - deadline = Some(dl); - Ok(dur) - } - }); - let interval = interval.transpose()?; + let interval = deadline.as_ref().map(|d| d.time_until(vm)).transpose()?; self.sock_op_timeout(vm, SelectKind::Write, interval, || { bytes.with_ref(|b| { let subbuf = &b[buf_offset..]; @@ -1007,6 +986,24 @@ fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String { } } +struct Deadline { + deadline: Instant, +} + +impl Deadline { + fn new(timeout: Duration) -> Self { + Self { + deadline: Instant::now() + timeout, + } + } + fn time_until(&self, vm: &VirtualMachine) -> PyResult { + self.deadline + .checked_duration_since(Instant::now()) + // past the deadline already + .ok_or_else(|| timeout_error(vm)) + } +} + rustpython_common::static_cell! { static TIMEOUT_ERROR: PyTypeRef; static GAI_ERROR: PyTypeRef;