From 06b40b360220b01e96c904a211b1b90aad3a50cf Mon Sep 17 00:00:00 2001 From: Dean Li Date: Sat, 7 Aug 2021 10:19:48 +0800 Subject: [PATCH] socket: impl herror There is one todo that is waiting for hstrerror to land in libc crate. For now it shows the same error as cpython when hstrerror is not available. --- Lib/test/test_exception_hierarchy.py | 2 -- Lib/test/test_socket.py | 4 --- vm/src/stdlib/socket.rs | 54 ++++++++++++++++++++++------ 3 files changed, 44 insertions(+), 16 deletions(-) diff --git a/Lib/test/test_exception_hierarchy.py b/Lib/test/test_exception_hierarchy.py index 8041b254f..268d9e82b 100644 --- a/Lib/test/test_exception_hierarchy.py +++ b/Lib/test/test_exception_hierarchy.py @@ -39,8 +39,6 @@ class HierarchyTest(unittest.TestCase): self.assertIs(IOError, OSError) self.assertIs(EnvironmentError, OSError) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_socket_errors(self): self.assertIs(socket.error, IOError) self.assertIs(socket.gaierror.__base__, OSError) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 9c4062825..48c7ed4e0 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -812,8 +812,6 @@ class GeneralModuleTests(unittest.TestCase): else: self.fail('Socket proxy still exists') - # TODO: RUSTPYTHON, socket.herror - @unittest.expectedFailure def testSocketError(self): # Testing socket module exceptions msg = "Error raising socket exception (%s)." @@ -5037,8 +5035,6 @@ class UDPTimeoutTest(SocketUDPTest): class TestExceptions(unittest.TestCase): - # TODO: RUSTPYTHON, socket.herror - @unittest.expectedFailure def testExceptionTree(self): self.assertTrue(issubclass(OSError, Exception)) self.assertTrue(issubclass(socket.herror, OSError)) diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 071204310..7ca06d031 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -1262,7 +1262,7 @@ fn _socket_getaddrinfo(opts: GAIOptions, vm: &VirtualMachine) -> PyResult { let port = port.as_ref().map(|p| p.as_ref()); let addrs = dns_lookup::getaddrinfo(host, port, Some(hints)) - .map_err(|err| convert_gai_error(vm, err))?; + .map_err(|err| convert_socket_error(vm, err, SocketError::GaiError))?; let list = addrs .map(|ai| { @@ -1287,7 +1287,8 @@ fn _socket_gethostbyaddr( ) -> PyResult<(String, PyObjectRef, PyObjectRef)> { // TODO: figure out how to do this properly let addr = get_addr(vm, addr, c::AF_UNSPEC)?; - let (hostname, _) = dns_lookup::getnameinfo(&addr, 0).map_err(|e| convert_gai_error(vm, e))?; + let (hostname, _) = dns_lookup::getnameinfo(&addr, 0) + .map_err(|e| convert_socket_error(vm, e, SocketError::HError))?; Ok(( hostname, vm.ctx.new_list(vec![]), @@ -1375,7 +1376,7 @@ fn _socket_getnameinfo( }; let service = addr.port.to_string(); let mut res = dns_lookup::getaddrinfo(Some(addr.host.as_str()), Some(&service), Some(hints)) - .map_err(|e| convert_gai_error(vm, e))? + .map_err(|e| convert_socket_error(vm, e, SocketError::GaiError))? .filter_map(Result::ok); let mut ainfo = res.next().unwrap(); if res.next().is_some() { @@ -1392,7 +1393,8 @@ fn _socket_getnameinfo( addr.set_scope_id(scopeid); } } - dns_lookup::getnameinfo(&ainfo.sockaddr, flags).map_err(|e| convert_gai_error(vm, e)) + dns_lookup::getnameinfo(&ainfo.sockaddr, flags) + .map_err(|e| convert_socket_error(vm, e, SocketError::GaiError)) } #[cfg(unix)] @@ -1585,7 +1587,7 @@ fn get_addr(vm: &VirtualMachine, pyname: PyStrRef, af: i32) -> PyResult PyResult PyBaseExceptionRef { +fn convert_socket_error( + vm: &VirtualMachine, + err: dns_lookup::LookupError, + err_kind: SocketError, +) -> PyBaseExceptionRef { if let dns_lookup::LookupErrorKind::System = err.kind() { return io::Error::from(err).into_pyexception(vm); } let strerr = { #[cfg(unix)] { - let s = unsafe { ffi::CStr::from_ptr(libc::gai_strerror(err.error_num())) }; - std::str::from_utf8(s.to_bytes()).unwrap() + match err_kind { + SocketError::GaiError => { + let s = unsafe { ffi::CStr::from_ptr(libc::gai_strerror(err.error_num())) }; + std::str::from_utf8(s.to_bytes()).unwrap() + } + SocketError::HError => { + // TODO: wait for libc hstrerror to land (https://github.com/rust-lang/libc/pull/2323) + "host not found" + } + } } #[cfg(windows)] { "getaddrinfo failed" } }; + let exception_cls = match err_kind { + SocketError::GaiError => &GAI_ERROR, + SocketError::HError => &HERROR, + }; vm.new_exception( - GAI_ERROR.get().unwrap().clone(), + exception_cls.get().unwrap().clone(), vec![vm.ctx.new_int(err.error_num()), vm.ctx.new_str(strerr)], ) } @@ -1794,8 +1812,14 @@ fn close_inner(x: RawSocket, vm: &VirtualMachine) -> PyResult<()> { Ok(()) } +enum SocketError { + HError, + GaiError, +} + rustpython_common::static_cell! { static TIMEOUT_ERROR: PyTypeRef; + static HERROR: PyTypeRef; static GAI_ERROR: PyTypeRef; } @@ -1812,6 +1836,15 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { ) }) .clone(); + let socket_herror = HERROR + .get_or_init(|| { + ctx.new_class( + "socket.herror", + &vm.ctx.exceptions.os_error, + Default::default(), + ) + }) + .clone(); let socket_gaierror = GAI_ERROR .get_or_init(|| { ctx.new_class( @@ -1828,6 +1861,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "SocketType" => socket, "error" => ctx.exceptions.os_error.clone(), "timeout" => socket_timeout, + "herror" => socket_herror, "gaierror" => socket_gaierror, "inet_aton" => named_function!(ctx, _socket, inet_aton), "inet_ntoa" => named_function!(ctx, _socket, inet_ntoa),