Merge pull request #2798 from deantvv/impl-socket-herror

socket: impl herror
This commit is contained in:
Jim Fasarakis-Hilliard
2021-08-07 12:50:44 +03:00
committed by GitHub
3 changed files with 44 additions and 16 deletions

View File

@@ -39,8 +39,6 @@ class HierarchyTest(unittest.TestCase):
self.assertIs(IOError, OSError)
self.assertIs(EnvironmentError, OSError)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_socket_errors(self):
self.assertIs(socket.error, IOError)
self.assertIs(socket.gaierror.__base__, OSError)

View File

@@ -812,8 +812,6 @@ class GeneralModuleTests(unittest.TestCase):
else:
self.fail('Socket proxy still exists')
# TODO: RUSTPYTHON, socket.herror
@unittest.expectedFailure
def testSocketError(self):
# Testing socket module exceptions
msg = "Error raising socket exception (%s)."
@@ -5037,8 +5035,6 @@ class UDPTimeoutTest(SocketUDPTest):
class TestExceptions(unittest.TestCase):
# TODO: RUSTPYTHON, socket.herror
@unittest.expectedFailure
def testExceptionTree(self):
self.assertTrue(issubclass(OSError, Exception))
self.assertTrue(issubclass(socket.herror, OSError))

View File

@@ -1262,7 +1262,7 @@ fn _socket_getaddrinfo(opts: GAIOptions, vm: &VirtualMachine) -> PyResult {
let port = port.as_ref().map(|p| p.as_ref());
let addrs = dns_lookup::getaddrinfo(host, port, Some(hints))
.map_err(|err| convert_gai_error(vm, err))?;
.map_err(|err| convert_socket_error(vm, err, SocketError::GaiError))?;
let list = addrs
.map(|ai| {
@@ -1287,7 +1287,8 @@ fn _socket_gethostbyaddr(
) -> PyResult<(String, PyObjectRef, PyObjectRef)> {
// TODO: figure out how to do this properly
let addr = get_addr(vm, addr, c::AF_UNSPEC)?;
let (hostname, _) = dns_lookup::getnameinfo(&addr, 0).map_err(|e| convert_gai_error(vm, e))?;
let (hostname, _) = dns_lookup::getnameinfo(&addr, 0)
.map_err(|e| convert_socket_error(vm, e, SocketError::HError))?;
Ok((
hostname,
vm.ctx.new_list(vec![]),
@@ -1375,7 +1376,7 @@ fn _socket_getnameinfo(
};
let service = addr.port.to_string();
let mut res = dns_lookup::getaddrinfo(Some(addr.host.as_str()), Some(&service), Some(hints))
.map_err(|e| convert_gai_error(vm, e))?
.map_err(|e| convert_socket_error(vm, e, SocketError::GaiError))?
.filter_map(Result::ok);
let mut ainfo = res.next().unwrap();
if res.next().is_some() {
@@ -1392,7 +1393,8 @@ fn _socket_getnameinfo(
addr.set_scope_id(scopeid);
}
}
dns_lookup::getnameinfo(&ainfo.sockaddr, flags).map_err(|e| convert_gai_error(vm, e))
dns_lookup::getnameinfo(&ainfo.sockaddr, flags)
.map_err(|e| convert_socket_error(vm, e, SocketError::GaiError))
}
#[cfg(unix)]
@@ -1585,7 +1587,7 @@ fn get_addr(vm: &VirtualMachine, pyname: PyStrRef, af: i32) -> PyResult<SocketAd
protocol: 0,
};
let mut res = dns_lookup::getaddrinfo(None, Some("0"), Some(hints))
.map_err(|e| convert_gai_error(vm, e))?;
.map_err(|e| convert_socket_error(vm, e, SocketError::GaiError))?;
let ainfo = res.next().unwrap().map_err(|e| e.into_pyexception(vm))?;
if res.next().is_some() {
return Err(vm.new_os_error("wildcard resolved to multiple address".to_owned()));
@@ -1623,7 +1625,7 @@ fn get_addr(vm: &VirtualMachine, pyname: PyStrRef, af: i32) -> PyResult<SocketAd
let name = std::str::from_utf8(name.as_bytes())
.map_err(|_| vm.new_runtime_error("idna output is not utf8".to_owned()))?;
let mut res = dns_lookup::getaddrinfo(Some(name), None, Some(hints))
.map_err(|e| convert_gai_error(vm, e))?;
.map_err(|e| convert_socket_error(vm, e, SocketError::GaiError))?;
res.next()
.unwrap()
.map(|ainfo| ainfo.sockaddr)
@@ -1694,23 +1696,39 @@ pub(super) const INVALID_SOCKET: RawSocket = {
}
};
fn convert_gai_error(vm: &VirtualMachine, err: dns_lookup::LookupError) -> PyBaseExceptionRef {
fn convert_socket_error(
vm: &VirtualMachine,
err: dns_lookup::LookupError,
err_kind: SocketError,
) -> PyBaseExceptionRef {
if let dns_lookup::LookupErrorKind::System = err.kind() {
return io::Error::from(err).into_pyexception(vm);
}
let strerr = {
#[cfg(unix)]
{
let s = unsafe { ffi::CStr::from_ptr(libc::gai_strerror(err.error_num())) };
std::str::from_utf8(s.to_bytes()).unwrap()
match err_kind {
SocketError::GaiError => {
let s = unsafe { ffi::CStr::from_ptr(libc::gai_strerror(err.error_num())) };
std::str::from_utf8(s.to_bytes()).unwrap()
}
SocketError::HError => {
// TODO: wait for libc hstrerror to land (https://github.com/rust-lang/libc/pull/2323)
"host not found"
}
}
}
#[cfg(windows)]
{
"getaddrinfo failed"
}
};
let exception_cls = match err_kind {
SocketError::GaiError => &GAI_ERROR,
SocketError::HError => &HERROR,
};
vm.new_exception(
GAI_ERROR.get().unwrap().clone(),
exception_cls.get().unwrap().clone(),
vec![vm.ctx.new_int(err.error_num()), vm.ctx.new_str(strerr)],
)
}
@@ -1794,8 +1812,14 @@ fn close_inner(x: RawSocket, vm: &VirtualMachine) -> PyResult<()> {
Ok(())
}
enum SocketError {
HError,
GaiError,
}
rustpython_common::static_cell! {
static TIMEOUT_ERROR: PyTypeRef;
static HERROR: PyTypeRef;
static GAI_ERROR: PyTypeRef;
}
@@ -1812,6 +1836,15 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
)
})
.clone();
let socket_herror = HERROR
.get_or_init(|| {
ctx.new_class(
"socket.herror",
&vm.ctx.exceptions.os_error,
Default::default(),
)
})
.clone();
let socket_gaierror = GAI_ERROR
.get_or_init(|| {
ctx.new_class(
@@ -1828,6 +1861,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
"SocketType" => socket,
"error" => ctx.exceptions.os_error.clone(),
"timeout" => socket_timeout,
"herror" => socket_herror,
"gaierror" => socket_gaierror,
"inet_aton" => named_function!(ctx, _socket, inet_aton),
"inet_ntoa" => named_function!(ctx, _socket, inet_ntoa),