Add sock.connext_ex

This commit is contained in:
Noah
2021-05-03 12:27:48 -05:00
parent 8ac4a8fa7a
commit 43457da22e

View File

@@ -229,7 +229,21 @@ impl PySocket {
Ok(())
}
#[inline]
fn sock_op<F, R>(&self, vm: &VirtualMachine, select: SelectKind, f: F) -> PyResult<R>
where
F: FnMut() -> io::Result<R>,
{
self.sock_op_err(vm, select, f)
.map_err(|e| e.into_pyexception(vm))
}
fn sock_op_err<F, R>(
&self,
vm: &VirtualMachine,
select: SelectKind,
f: F,
) -> Result<R, IoOrPyException>
where
F: FnMut() -> io::Result<R>,
{
@@ -239,16 +253,16 @@ impl PySocket {
} else {
None
};
self.sock_op_timeout(vm, select, timeout, f)
self.sock_op_timeout_err(vm, select, timeout, f)
}
fn sock_op_timeout<F, R>(
fn sock_op_timeout_err<F, R>(
&self,
vm: &VirtualMachine,
select: SelectKind,
timeout: Option<Duration>,
mut f: F,
) -> PyResult<R>
) -> Result<R, IoOrPyException>
where
F: FnMut() -> io::Result<R>,
{
@@ -256,15 +270,15 @@ impl PySocket {
loop {
if deadline.is_some() || matches!(select, SelectKind::Connect) {
let interval = deadline.as_ref().map(|d| d.time_until(vm)).transpose()?;
let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?;
let res = sock_select(&self.sock(), select, interval);
match res {
Ok(true) => return Err(timeout_error(vm)),
Ok(true) => return Err(IoOrPyException::Timeout),
Err(e) if e.kind() == io::ErrorKind::Interrupted => {
vm.check_signals()?;
continue;
}
Err(e) => return Err(e.into_pyexception(vm)),
Err(e) => return Err(e.into()),
Ok(false) => {} // no timeout, continue as normal
}
}
@@ -280,7 +294,7 @@ impl PySocket {
if timeout.is_some() && err.kind() == io::ErrorKind::WouldBlock {
continue;
}
return Err(err.into_pyexception(vm));
return Err(err.into());
}
}
@@ -359,9 +373,13 @@ impl PySocket {
}
}
#[pymethod]
fn connect(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
let sock_addr = self.extract_address(address, "connect", vm)?;
fn connect_inner(
&self,
address: PyObjectRef,
caller: &str,
vm: &VirtualMachine,
) -> Result<(), IoOrPyException> {
let sock_addr = self.extract_address(address, caller, vm)?;
let err = match self.sock().connect(&sock_addr) {
Ok(()) => return Ok(()),
@@ -384,7 +402,7 @@ impl PySocket {
// basically, connect() is async, and it registers an "error" on the socket when it's
// done connecting. SelectKind::Connect fills the errorfds fd_set, so if we wake up
// from poll and the error is EISCONN then we know that the connect is done
self.sock_op(vm, SelectKind::Connect, || {
self.sock_op_err(vm, SelectKind::Connect, || {
let sock = self.sock();
let err = sock.take_error()?;
match err {
@@ -395,7 +413,21 @@ impl PySocket {
}
})
} else {
Err(err.into_pyexception(vm))
Err(err.into())
}
}
#[pymethod]
fn connect(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.connect_inner(address, "connect", vm)
.map_err(|e| e.into_pyexception(vm))
}
#[pymethod]
fn connect_ex(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<i32> {
match self.connect_inner(address, "connect_ex", vm) {
Ok(()) => Ok(0),
Err(err) => err.errno(),
}
}
@@ -506,14 +538,18 @@ impl PySocket {
let mut buf_offset = 0;
// now we have like 3 layers of interrupt loop :)
while buf_offset < buf_len {
let interval = deadline.as_ref().map(|d| d.time_until(vm)).transpose()?;
self.sock_op_timeout(vm, SelectKind::Write, interval, || {
let interval = deadline
.as_ref()
.map(|d| d.time_until().map_err(|e| e.into_pyexception(vm)))
.transpose()?;
self.sock_op_timeout_err(vm, SelectKind::Write, interval, || {
bytes.with_ref(|b| {
let subbuf = &b[buf_offset..];
buf_offset += self.sock().send_with_flags(subbuf, flags)?;
Ok(())
})
})?;
})
.map_err(|e| e.into_pyexception(vm))?;
vm.check_signals()?;
}
Ok(())
@@ -875,6 +911,43 @@ fn slice_as_uninit<T>(v: &mut [T]) -> &mut [MaybeUninit<T>] {
unsafe { &mut *(v as *mut [T] as *mut [MaybeUninit<T>]) }
}
enum IoOrPyException {
Timeout,
Py(PyBaseExceptionRef),
Io(io::Error),
}
impl From<PyBaseExceptionRef> for IoOrPyException {
fn from(exc: PyBaseExceptionRef) -> Self {
Self::Py(exc)
}
}
impl From<io::Error> for IoOrPyException {
fn from(err: io::Error) -> Self {
Self::Io(err)
}
}
impl IoOrPyException {
fn errno(self) -> PyResult<i32> {
match self {
Self::Timeout => Ok(errcode!(EWOULDBLOCK)),
Self::Io(err) => {
// TODO: just unwrap()?
Ok(err.raw_os_error().unwrap_or(1))
}
Self::Py(exc) => Err(exc),
}
}
}
impl IntoPyException for IoOrPyException {
fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef {
match self {
Self::Timeout => timeout_error(vm),
Self::Py(exc) => exc,
Self::Io(err) => err.into_pyexception(vm),
}
}
}
#[derive(Copy, Clone)]
enum SelectKind {
Read,
@@ -1292,11 +1365,11 @@ impl Deadline {
deadline: Instant::now() + timeout,
}
}
fn time_until(&self, vm: &VirtualMachine) -> PyResult<Duration> {
fn time_until(&self) -> Result<Duration, IoOrPyException> {
self.deadline
.checked_duration_since(Instant::now())
// past the deadline already
.ok_or_else(|| timeout_error(vm))
.ok_or(IoOrPyException::Timeout)
}
}