From 43457da22e1e8505544e69c680e7644c9f80c07e Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Mon, 3 May 2021 12:27:48 -0500 Subject: [PATCH] Add sock.connext_ex --- vm/src/stdlib/socket.rs | 107 +++++++++++++++++++++++++++++++++------- 1 file changed, 90 insertions(+), 17 deletions(-) diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index a20fde292..2dbd7d452 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -229,7 +229,21 @@ impl PySocket { Ok(()) } + #[inline] fn sock_op(&self, vm: &VirtualMachine, select: SelectKind, f: F) -> PyResult + where + F: FnMut() -> io::Result, + { + self.sock_op_err(vm, select, f) + .map_err(|e| e.into_pyexception(vm)) + } + + fn sock_op_err( + &self, + vm: &VirtualMachine, + select: SelectKind, + f: F, + ) -> Result where F: FnMut() -> io::Result, { @@ -239,16 +253,16 @@ impl PySocket { } else { None }; - self.sock_op_timeout(vm, select, timeout, f) + self.sock_op_timeout_err(vm, select, timeout, f) } - fn sock_op_timeout( + fn sock_op_timeout_err( &self, vm: &VirtualMachine, select: SelectKind, timeout: Option, mut f: F, - ) -> PyResult + ) -> Result where F: FnMut() -> io::Result, { @@ -256,15 +270,15 @@ impl PySocket { loop { if deadline.is_some() || matches!(select, SelectKind::Connect) { - let interval = deadline.as_ref().map(|d| d.time_until(vm)).transpose()?; + let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?; let res = sock_select(&self.sock(), select, interval); match res { - Ok(true) => return Err(timeout_error(vm)), + Ok(true) => return Err(IoOrPyException::Timeout), Err(e) if e.kind() == io::ErrorKind::Interrupted => { vm.check_signals()?; continue; } - Err(e) => return Err(e.into_pyexception(vm)), + Err(e) => return Err(e.into()), Ok(false) => {} // no timeout, continue as normal } } @@ -280,7 +294,7 @@ impl PySocket { if timeout.is_some() && err.kind() == io::ErrorKind::WouldBlock { continue; } - return Err(err.into_pyexception(vm)); + return Err(err.into()); } } @@ -359,9 +373,13 @@ impl PySocket { } } - #[pymethod] - fn connect(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let sock_addr = self.extract_address(address, "connect", vm)?; + fn connect_inner( + &self, + address: PyObjectRef, + caller: &str, + vm: &VirtualMachine, + ) -> Result<(), IoOrPyException> { + let sock_addr = self.extract_address(address, caller, vm)?; let err = match self.sock().connect(&sock_addr) { Ok(()) => return Ok(()), @@ -384,7 +402,7 @@ impl PySocket { // basically, connect() is async, and it registers an "error" on the socket when it's // done connecting. SelectKind::Connect fills the errorfds fd_set, so if we wake up // from poll and the error is EISCONN then we know that the connect is done - self.sock_op(vm, SelectKind::Connect, || { + self.sock_op_err(vm, SelectKind::Connect, || { let sock = self.sock(); let err = sock.take_error()?; match err { @@ -395,7 +413,21 @@ impl PySocket { } }) } else { - Err(err.into_pyexception(vm)) + Err(err.into()) + } + } + + #[pymethod] + fn connect(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + self.connect_inner(address, "connect", vm) + .map_err(|e| e.into_pyexception(vm)) + } + + #[pymethod] + fn connect_ex(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult { + match self.connect_inner(address, "connect_ex", vm) { + Ok(()) => Ok(0), + Err(err) => err.errno(), } } @@ -506,14 +538,18 @@ impl PySocket { let mut buf_offset = 0; // now we have like 3 layers of interrupt loop :) while buf_offset < buf_len { - let interval = deadline.as_ref().map(|d| d.time_until(vm)).transpose()?; - self.sock_op_timeout(vm, SelectKind::Write, interval, || { + let interval = deadline + .as_ref() + .map(|d| d.time_until().map_err(|e| e.into_pyexception(vm))) + .transpose()?; + self.sock_op_timeout_err(vm, SelectKind::Write, interval, || { bytes.with_ref(|b| { let subbuf = &b[buf_offset..]; buf_offset += self.sock().send_with_flags(subbuf, flags)?; Ok(()) }) - })?; + }) + .map_err(|e| e.into_pyexception(vm))?; vm.check_signals()?; } Ok(()) @@ -875,6 +911,43 @@ fn slice_as_uninit(v: &mut [T]) -> &mut [MaybeUninit] { unsafe { &mut *(v as *mut [T] as *mut [MaybeUninit]) } } +enum IoOrPyException { + Timeout, + Py(PyBaseExceptionRef), + Io(io::Error), +} +impl From for IoOrPyException { + fn from(exc: PyBaseExceptionRef) -> Self { + Self::Py(exc) + } +} +impl From for IoOrPyException { + fn from(err: io::Error) -> Self { + Self::Io(err) + } +} +impl IoOrPyException { + fn errno(self) -> PyResult { + match self { + Self::Timeout => Ok(errcode!(EWOULDBLOCK)), + Self::Io(err) => { + // TODO: just unwrap()? + Ok(err.raw_os_error().unwrap_or(1)) + } + Self::Py(exc) => Err(exc), + } + } +} +impl IntoPyException for IoOrPyException { + fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef { + match self { + Self::Timeout => timeout_error(vm), + Self::Py(exc) => exc, + Self::Io(err) => err.into_pyexception(vm), + } + } +} + #[derive(Copy, Clone)] enum SelectKind { Read, @@ -1292,11 +1365,11 @@ impl Deadline { deadline: Instant::now() + timeout, } } - fn time_until(&self, vm: &VirtualMachine) -> PyResult { + fn time_until(&self) -> Result { self.deadline .checked_duration_since(Instant::now()) // past the deadline already - .ok_or_else(|| timeout_error(vm)) + .ok_or(IoOrPyException::Timeout) } }