From 56d4654961db83215d37cd5883e5142b018bea33 Mon Sep 17 00:00:00 2001 From: Padraic Fanning Date: Sat, 17 Apr 2021 23:03:55 -0400 Subject: [PATCH 1/2] Add test_socketserver from CPython 3.8 --- Lib/test/test_socketserver.py | 518 ++++++++++++++++++++++++++++++++++ 1 file changed, 518 insertions(+) create mode 100644 Lib/test/test_socketserver.py diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py new file mode 100644 index 000000000..1e5224df0 --- /dev/null +++ b/Lib/test/test_socketserver.py @@ -0,0 +1,518 @@ +""" +Test suite for socketserver. +""" + +import contextlib +import io +import os +import select +import signal +import socket +import tempfile +import threading +import unittest +import socketserver + +import test.support +from test.support import reap_children, reap_threads, verbose + + +test.support.requires("network") + +TEST_STR = b"hello world\n" +HOST = test.support.HOST + +HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") +requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS, + 'requires Unix sockets') +HAVE_FORKING = hasattr(os, "fork") +requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking') + +def signal_alarm(n): + """Call signal.alarm when it exists (i.e. not on Windows).""" + if hasattr(signal, 'alarm'): + signal.alarm(n) + +# Remember real select() to avoid interferences with mocking +_real_select = select.select + +def receive(sock, n, timeout=20): + r, w, x = _real_select([sock], [], [], timeout) + if sock in r: + return sock.recv(n) + else: + raise RuntimeError("timed out on %r" % (sock,)) + +if HAVE_UNIX_SOCKETS and HAVE_FORKING: + class ForkingUnixStreamServer(socketserver.ForkingMixIn, + socketserver.UnixStreamServer): + pass + + class ForkingUnixDatagramServer(socketserver.ForkingMixIn, + socketserver.UnixDatagramServer): + pass + + +@contextlib.contextmanager +def simple_subprocess(testcase): + """Tests that a custom child process is not waited on (Issue 1540386)""" + pid = os.fork() + if pid == 0: + # Don't raise an exception; it would be caught by the test harness. + os._exit(72) + try: + yield None + except: + raise + finally: + pid2, status = os.waitpid(pid, 0) + testcase.assertEqual(pid2, pid) + testcase.assertEqual(72 << 8, status) + + +class SocketServerTest(unittest.TestCase): + """Test all socket servers.""" + + def setUp(self): + signal_alarm(60) # Kill deadlocks after 60 seconds. + self.port_seed = 0 + self.test_files = [] + + def tearDown(self): + signal_alarm(0) # Didn't deadlock. + reap_children() + + for fn in self.test_files: + try: + os.remove(fn) + except OSError: + pass + self.test_files[:] = [] + + def pickaddr(self, proto): + if proto == socket.AF_INET: + return (HOST, 0) + else: + # XXX: We need a way to tell AF_UNIX to pick its own name + # like AF_INET provides port==0. + dir = None + fn = tempfile.mktemp(prefix='unix_socket.', dir=dir) + self.test_files.append(fn) + return fn + + def make_server(self, addr, svrcls, hdlrbase): + class MyServer(svrcls): + def handle_error(self, request, client_address): + self.close_request(request) + raise + + class MyHandler(hdlrbase): + def handle(self): + line = self.rfile.readline() + self.wfile.write(line) + + if verbose: print("creating server") + try: + server = MyServer(addr, MyHandler) + except PermissionError as e: + # Issue 29184: cannot bind() a Unix socket on Android. + self.skipTest('Cannot create server (%s, %s): %s' % + (svrcls, addr, e)) + self.assertEqual(server.server_address, server.socket.getsockname()) + return server + + @reap_threads + def run_server(self, svrcls, hdlrbase, testfunc): + server = self.make_server(self.pickaddr(svrcls.address_family), + svrcls, hdlrbase) + # We had the OS pick a port, so pull the real address out of + # the server. + addr = server.server_address + if verbose: + print("ADDR =", addr) + print("CLASS =", svrcls) + + t = threading.Thread( + name='%s serving' % svrcls, + target=server.serve_forever, + # Short poll interval to make the test finish quickly. + # Time between requests is short enough that we won't wake + # up spuriously too many times. + kwargs={'poll_interval':0.01}) + t.daemon = True # In case this function raises. + t.start() + if verbose: print("server running") + for i in range(3): + if verbose: print("test client", i) + testfunc(svrcls.address_family, addr) + if verbose: print("waiting for server") + server.shutdown() + t.join() + server.server_close() + self.assertEqual(-1, server.socket.fileno()) + if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn): + # bpo-31151: Check that ForkingMixIn.server_close() waits until + # all children completed + self.assertFalse(server.active_children) + if verbose: print("done") + + def stream_examine(self, proto, addr): + with socket.socket(proto, socket.SOCK_STREAM) as s: + s.connect(addr) + s.sendall(TEST_STR) + buf = data = receive(s, 100) + while data and b'\n' not in buf: + data = receive(s, 100) + buf += data + self.assertEqual(buf, TEST_STR) + + def dgram_examine(self, proto, addr): + with socket.socket(proto, socket.SOCK_DGRAM) as s: + if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX: + s.bind(self.pickaddr(proto)) + s.sendto(TEST_STR, addr) + buf = data = receive(s, 100) + while data and b'\n' not in buf: + data = receive(s, 100) + buf += data + self.assertEqual(buf, TEST_STR) + + def test_TCPServer(self): + self.run_server(socketserver.TCPServer, + socketserver.StreamRequestHandler, + self.stream_examine) + + def test_ThreadingTCPServer(self): + self.run_server(socketserver.ThreadingTCPServer, + socketserver.StreamRequestHandler, + self.stream_examine) + + @requires_forking + def test_ForkingTCPServer(self): + with simple_subprocess(self): + self.run_server(socketserver.ForkingTCPServer, + socketserver.StreamRequestHandler, + self.stream_examine) + + @requires_unix_sockets + def test_UnixStreamServer(self): + self.run_server(socketserver.UnixStreamServer, + socketserver.StreamRequestHandler, + self.stream_examine) + + @requires_unix_sockets + def test_ThreadingUnixStreamServer(self): + self.run_server(socketserver.ThreadingUnixStreamServer, + socketserver.StreamRequestHandler, + self.stream_examine) + + @requires_unix_sockets + @requires_forking + def test_ForkingUnixStreamServer(self): + with simple_subprocess(self): + self.run_server(ForkingUnixStreamServer, + socketserver.StreamRequestHandler, + self.stream_examine) + + def test_UDPServer(self): + self.run_server(socketserver.UDPServer, + socketserver.DatagramRequestHandler, + self.dgram_examine) + + def test_ThreadingUDPServer(self): + self.run_server(socketserver.ThreadingUDPServer, + socketserver.DatagramRequestHandler, + self.dgram_examine) + + @requires_forking + def test_ForkingUDPServer(self): + with simple_subprocess(self): + self.run_server(socketserver.ForkingUDPServer, + socketserver.DatagramRequestHandler, + self.dgram_examine) + + @requires_unix_sockets + def test_UnixDatagramServer(self): + self.run_server(socketserver.UnixDatagramServer, + socketserver.DatagramRequestHandler, + self.dgram_examine) + + @requires_unix_sockets + def test_ThreadingUnixDatagramServer(self): + self.run_server(socketserver.ThreadingUnixDatagramServer, + socketserver.DatagramRequestHandler, + self.dgram_examine) + + @requires_unix_sockets + @requires_forking + def test_ForkingUnixDatagramServer(self): + self.run_server(ForkingUnixDatagramServer, + socketserver.DatagramRequestHandler, + self.dgram_examine) + + @reap_threads + def test_shutdown(self): + # Issue #2302: shutdown() should always succeed in making an + # other thread leave serve_forever(). + class MyServer(socketserver.TCPServer): + pass + + class MyHandler(socketserver.StreamRequestHandler): + pass + + threads = [] + for i in range(20): + s = MyServer((HOST, 0), MyHandler) + t = threading.Thread( + name='MyServer serving', + target=s.serve_forever, + kwargs={'poll_interval':0.01}) + t.daemon = True # In case this function raises. + threads.append((t, s)) + for t, s in threads: + t.start() + s.shutdown() + for t, s in threads: + t.join() + s.server_close() + + def test_close_immediately(self): + class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer): + pass + + server = MyServer((HOST, 0), lambda: None) + server.server_close() + + def test_tcpserver_bind_leak(self): + # Issue #22435: the server socket wouldn't be closed if bind()/listen() + # failed. + # Create many servers for which bind() will fail, to see if this result + # in FD exhaustion. + for i in range(1024): + with self.assertRaises(OverflowError): + socketserver.TCPServer((HOST, -1), + socketserver.StreamRequestHandler) + + def test_context_manager(self): + with socketserver.TCPServer((HOST, 0), + socketserver.StreamRequestHandler) as server: + pass + self.assertEqual(-1, server.socket.fileno()) + + +class ErrorHandlerTest(unittest.TestCase): + """Test that the servers pass normal exceptions from the handler to + handle_error(), and that exiting exceptions like SystemExit and + KeyboardInterrupt are not passed.""" + + def tearDown(self): + test.support.unlink(test.support.TESTFN) + + def test_sync_handled(self): + BaseErrorTestServer(ValueError) + self.check_result(handled=True) + + def test_sync_not_handled(self): + with self.assertRaises(SystemExit): + BaseErrorTestServer(SystemExit) + self.check_result(handled=False) + + def test_threading_handled(self): + ThreadingErrorTestServer(ValueError) + self.check_result(handled=True) + + def test_threading_not_handled(self): + ThreadingErrorTestServer(SystemExit) + self.check_result(handled=False) + + @requires_forking + def test_forking_handled(self): + ForkingErrorTestServer(ValueError) + self.check_result(handled=True) + + @requires_forking + def test_forking_not_handled(self): + ForkingErrorTestServer(SystemExit) + self.check_result(handled=False) + + def check_result(self, handled): + with open(test.support.TESTFN) as log: + expected = 'Handler called\n' + 'Error handled\n' * handled + self.assertEqual(log.read(), expected) + + +class BaseErrorTestServer(socketserver.TCPServer): + def __init__(self, exception): + self.exception = exception + super().__init__((HOST, 0), BadHandler) + with socket.create_connection(self.server_address): + pass + try: + self.handle_request() + finally: + self.server_close() + self.wait_done() + + def handle_error(self, request, client_address): + with open(test.support.TESTFN, 'a') as log: + log.write('Error handled\n') + + def wait_done(self): + pass + + +class BadHandler(socketserver.BaseRequestHandler): + def handle(self): + with open(test.support.TESTFN, 'a') as log: + log.write('Handler called\n') + raise self.server.exception('Test error') + + +class ThreadingErrorTestServer(socketserver.ThreadingMixIn, + BaseErrorTestServer): + def __init__(self, *pos, **kw): + self.done = threading.Event() + super().__init__(*pos, **kw) + + def shutdown_request(self, *pos, **kw): + super().shutdown_request(*pos, **kw) + self.done.set() + + def wait_done(self): + self.done.wait() + + +if HAVE_FORKING: + class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer): + pass + + +class SocketWriterTest(unittest.TestCase): + def test_basics(self): + class Handler(socketserver.StreamRequestHandler): + def handle(self): + self.server.wfile = self.wfile + self.server.wfile_fileno = self.wfile.fileno() + self.server.request_fileno = self.request.fileno() + + server = socketserver.TCPServer((HOST, 0), Handler) + self.addCleanup(server.server_close) + s = socket.socket( + server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP) + with s: + s.connect(server.server_address) + server.handle_request() + self.assertIsInstance(server.wfile, io.BufferedIOBase) + self.assertEqual(server.wfile_fileno, server.request_fileno) + + def test_write(self): + # Test that wfile.write() sends data immediately, and that it does + # not truncate sends when interrupted by a Unix signal + pthread_kill = test.support.get_attribute(signal, 'pthread_kill') + + class Handler(socketserver.StreamRequestHandler): + def handle(self): + self.server.sent1 = self.wfile.write(b'write data\n') + # Should be sent immediately, without requiring flush() + self.server.received = self.rfile.readline() + big_chunk = b'\0' * test.support.SOCK_MAX_SIZE + self.server.sent2 = self.wfile.write(big_chunk) + + server = socketserver.TCPServer((HOST, 0), Handler) + self.addCleanup(server.server_close) + interrupted = threading.Event() + + def signal_handler(signum, frame): + interrupted.set() + + original = signal.signal(signal.SIGUSR1, signal_handler) + self.addCleanup(signal.signal, signal.SIGUSR1, original) + response1 = None + received2 = None + main_thread = threading.get_ident() + + def run_client(): + s = socket.socket(server.address_family, socket.SOCK_STREAM, + socket.IPPROTO_TCP) + with s, s.makefile('rb') as reader: + s.connect(server.server_address) + nonlocal response1 + response1 = reader.readline() + s.sendall(b'client response\n') + + reader.read(100) + # The main thread should now be blocking in a send() syscall. + # But in theory, it could get interrupted by other signals, + # and then retried. So keep sending the signal in a loop, in + # case an earlier signal happens to be delivered at an + # inconvenient moment. + while True: + pthread_kill(main_thread, signal.SIGUSR1) + if interrupted.wait(timeout=float(1)): + break + nonlocal received2 + received2 = len(reader.read()) + + background = threading.Thread(target=run_client) + background.start() + server.handle_request() + background.join() + self.assertEqual(server.sent1, len(response1)) + self.assertEqual(response1, b'write data\n') + self.assertEqual(server.received, b'client response\n') + self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE) + self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100) + + +class MiscTestCase(unittest.TestCase): + + def test_all(self): + # objects defined in the module should be in __all__ + expected = [] + for name in dir(socketserver): + if not name.startswith('_'): + mod_object = getattr(socketserver, name) + if getattr(mod_object, '__module__', None) == 'socketserver': + expected.append(name) + self.assertCountEqual(socketserver.__all__, expected) + + def test_shutdown_request_called_if_verify_request_false(self): + # Issue #26309: BaseServer should call shutdown_request even if + # verify_request is False + + class MyServer(socketserver.TCPServer): + def verify_request(self, request, client_address): + return False + + shutdown_called = 0 + def shutdown_request(self, request): + self.shutdown_called += 1 + socketserver.TCPServer.shutdown_request(self, request) + + server = MyServer((HOST, 0), socketserver.StreamRequestHandler) + s = socket.socket(server.address_family, socket.SOCK_STREAM) + s.connect(server.server_address) + s.close() + server.handle_request() + self.assertEqual(server.shutdown_called, 1) + server.server_close() + + def test_threads_reaped(self): + """ + In #37193, users reported a memory leak + due to the saving of every request thread. Ensure that + not all threads are kept forever. + """ + class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer): + pass + + server = MyServer((HOST, 0), socketserver.StreamRequestHandler) + for n in range(10): + with socket.create_connection(server.server_address): + server.handle_request() + self.assertLess(len(server._threads), 10) + server.server_close() + + +if __name__ == "__main__": + unittest.main() From 7b3f1a7e8017a7112e1956f3c4ce5a295af0c59b Mon Sep 17 00:00:00 2001 From: Padraic Fanning Date: Sat, 17 Apr 2021 23:05:16 -0400 Subject: [PATCH 2/2] Mark erroring test --- Lib/test/test_socketserver.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 1e5224df0..63fc7c597 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -497,6 +497,8 @@ class MiscTestCase(unittest.TestCase): self.assertEqual(server.shutdown_called, 1) server.server_close() + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_threads_reaped(self): """ In #37193, users reported a memory leak