diff --git a/Lib/multiprocessing/__init__.py b/Lib/multiprocessing/__init__.py index 86df638370..8336f381de 100644 --- a/Lib/multiprocessing/__init__.py +++ b/Lib/multiprocessing/__init__.py @@ -19,9 +19,8 @@ from . import context # Copy stuff from default context # -globals().update((name, getattr(context._default_context, name)) - for name in context._default_context.__all__) -__all__ = context._default_context.__all__ +__all__ = [x for x in dir(context._default_context) if not x.startswith('_')] +globals().update((name, getattr(context._default_context, name)) for name in __all__) # # XXX These should not really be documented or public. diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py index 3f4ded2301..510e4b5aba 100644 --- a/Lib/multiprocessing/connection.py +++ b/Lib/multiprocessing/connection.py @@ -57,10 +57,10 @@ if sys.platform == 'win32': def _init_timeout(timeout=CONNECTION_TIMEOUT): - return time.time() + timeout + return time.monotonic() + timeout def _check_timeout(t): - return time.time() > t + return time.monotonic() > t # # @@ -73,6 +73,11 @@ def arbitrary_address(family): if family == 'AF_INET': return ('localhost', 0) elif family == 'AF_UNIX': + # Prefer abstract sockets if possible to avoid problems with the address + # size. When coding portable applications, some implementations have + # sun_path as short as 92 bytes in the sockaddr_un struct. + if util.abstract_sockets_supported: + return f"\0listener-{os.getpid()}-{next(_mmap_counter)}" return tempfile.mktemp(prefix='listener-', dir=util.get_temp_dir()) elif family == 'AF_PIPE': return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' % @@ -102,7 +107,7 @@ def address_type(address): return 'AF_INET' elif type(address) is str and address.startswith('\\\\'): return 'AF_PIPE' - elif type(address) is str: + elif type(address) is str or util.is_abstract_socket_namespace(address): return 'AF_UNIX' else: raise ValueError('address type of %r unrecognized' % address) @@ -389,23 +394,33 @@ class Connection(_ConnectionBase): def _send_bytes(self, buf): n = len(buf) - # For wire compatibility with 3.2 and lower - header = struct.pack("!i", n) - if n > 16384: - # The payload is large so Nagle's algorithm won't be triggered - # and we'd better avoid the cost of concatenation. + if n > 0x7fffffff: + pre_header = struct.pack("!i", -1) + header = struct.pack("!Q", n) + self._send(pre_header) self._send(header) self._send(buf) else: - # Issue #20540: concatenate before sending, to avoid delays due - # to Nagle's algorithm on a TCP socket. - # Also note we want to avoid sending a 0-length buffer separately, - # to avoid "broken pipe" errors if the other end closed the pipe. - self._send(header + buf) + # For wire compatibility with 3.7 and lower + header = struct.pack("!i", n) + if n > 16384: + # The payload is large so Nagle's algorithm won't be triggered + # and we'd better avoid the cost of concatenation. + self._send(header) + self._send(buf) + else: + # Issue #20540: concatenate before sending, to avoid delays due + # to Nagle's algorithm on a TCP socket. + # Also note we want to avoid sending a 0-length buffer separately, + # to avoid "broken pipe" errors if the other end closed the pipe. + self._send(header + buf) def _recv_bytes(self, maxsize=None): buf = self._recv(4) size, = struct.unpack("!i", buf.getvalue()) + if size == -1: + buf = self._recv(8) + size, = struct.unpack("!Q", buf.getvalue()) if maxsize is not None and size > maxsize: return None return self._recv(size) @@ -465,8 +480,13 @@ class Listener(object): self._listener = None listener.close() - address = property(lambda self: self._listener._address) - last_accepted = property(lambda self: self._listener._last_accepted) + @property + def address(self): + return self._listener._address + + @property + def last_accepted(self): + return self._listener._last_accepted def __enter__(self): return self @@ -582,7 +602,8 @@ class SocketListener(object): self._family = family self._last_accepted = None - if family == 'AF_UNIX': + if family == 'AF_UNIX' and not util.is_abstract_socket_namespace(address): + # Linux abstract socket namespaces do not need to be explicitly unlinked self._unlink = util.Finalize( self, os.unlink, args=(address,), exitpriority=0 ) @@ -715,7 +736,9 @@ FAILURE = b'#FAILURE#' def deliver_challenge(connection, authkey): import hmac - assert isinstance(authkey, bytes) + if not isinstance(authkey, bytes): + raise ValueError( + "Authkey must be bytes, not {0!s}".format(type(authkey))) message = os.urandom(MESSAGE_LENGTH) connection.send_bytes(CHALLENGE + message) digest = hmac.new(authkey, message, 'md5').digest() @@ -728,7 +751,9 @@ def deliver_challenge(connection, authkey): def answer_challenge(connection, authkey): import hmac - assert isinstance(authkey, bytes) + if not isinstance(authkey, bytes): + raise ValueError( + "Authkey must be bytes, not {0!s}".format(type(authkey))) message = connection.recv_bytes(256) # reject large message assert message[:len(CHALLENGE)] == CHALLENGE, 'message = %r' % message message = message[len(CHALLENGE):] @@ -780,8 +805,7 @@ def XmlClient(*args, **kwds): # Wait # -# XXX RustPython TODO: implement all the functions in this block -if sys.platform == 'win32' and False: +if sys.platform == 'win32': def _exhaustive_wait(handles, timeout): # Return ALL handles which are currently signalled. (Only @@ -906,7 +930,7 @@ else: selector.register(obj, selectors.EVENT_READ) if timeout is not None: - deadline = time.time() + timeout + deadline = time.monotonic() + timeout while True: ready = selector.select(timeout) @@ -914,7 +938,7 @@ else: return [key.fileobj for (key, events) in ready] else: if timeout is not None: - timeout = deadline - time.time() + timeout = deadline - time.monotonic() if timeout < 0: return ready diff --git a/Lib/multiprocessing/context.py b/Lib/multiprocessing/context.py index 09455e2ec7..8d0525d5d6 100644 --- a/Lib/multiprocessing/context.py +++ b/Lib/multiprocessing/context.py @@ -5,7 +5,7 @@ import threading from . import process from . import reduction -__all__ = [] # things are copied from here to __init__.py +__all__ = () # # Exceptions @@ -24,7 +24,7 @@ class AuthenticationError(ProcessError): pass # -# Base type for contexts +# Base type for contexts. Bound methods of an instance of this type are included in __all__ of __init__.py # class BaseContext(object): @@ -35,6 +35,7 @@ class BaseContext(object): AuthenticationError = AuthenticationError current_process = staticmethod(process.current_process) + parent_process = staticmethod(process.parent_process) active_children = staticmethod(process.active_children) def cpu_count(self): @@ -189,14 +190,14 @@ class BaseContext(object): try: ctx = _concrete_contexts[method] except KeyError: - raise ValueError('cannot find context for %r' % method) + raise ValueError('cannot find context for %r' % method) from None ctx._check_available() return ctx def get_start_method(self, allow_none=False): return self._name - def set_start_method(self, method=None): + def set_start_method(self, method, force=False): raise ValueError('cannot set start method of concrete context') @property @@ -256,12 +257,11 @@ class DefaultContext(BaseContext): if sys.platform == 'win32': return ['spawn'] else: + methods = ['spawn', 'fork'] if sys.platform == 'darwin' else ['fork', 'spawn'] if reduction.HAVE_SEND_HANDLE: - return ['fork', 'spawn', 'forkserver'] - else: - return ['fork', 'spawn'] + methods.append('forkserver') + return methods -DefaultContext.__all__ = list(x for x in dir(DefaultContext) if x[0] != '_') # # Context types for fixed start method @@ -310,7 +310,12 @@ if sys.platform != 'win32': 'spawn': SpawnContext(), 'forkserver': ForkServerContext(), } - _default_context = DefaultContext(_concrete_contexts['fork']) + if sys.platform == 'darwin': + # bpo-33725: running arbitrary code after fork() is no longer reliable + # on macOS since macOS 10.14 (Mojave). Use spawn by default instead. + _default_context = DefaultContext(_concrete_contexts['spawn']) + else: + _default_context = DefaultContext(_concrete_contexts['fork']) else: diff --git a/Lib/multiprocessing/dummy/__init__.py b/Lib/multiprocessing/dummy/__init__.py index 1abea64419..6a1468609e 100644 --- a/Lib/multiprocessing/dummy/__init__.py +++ b/Lib/multiprocessing/dummy/__init__.py @@ -41,7 +41,10 @@ class DummyProcess(threading.Thread): self._parent = current_process() def start(self): - assert self._parent is current_process() + if self._parent is not current_process(): + raise RuntimeError( + "Parent is {0!r} but current_process is {1!r}".format( + self._parent, current_process())) self._start_called = True if hasattr(self._parent, '_children'): self._parent._children[self] = None @@ -77,7 +80,7 @@ def freeze_support(): # class Namespace(object): - def __init__(self, **kwds): + def __init__(self, /, **kwds): self.__dict__.update(kwds) def __repr__(self): items = list(self.__dict__.items()) @@ -98,11 +101,15 @@ class Value(object): def __init__(self, typecode, value, lock=True): self._typecode = typecode self._value = value - def _get(self): + + @property + def value(self): return self._value - def _set(self, value): + + @value.setter + def value(self, value): self._value = value - value = property(_get, _set) + def __repr__(self): return '<%s(%r, %r)>'%(type(self).__name__,self._typecode,self._value) diff --git a/Lib/multiprocessing/dummy/connection.py b/Lib/multiprocessing/dummy/connection.py index 19843751c0..f0ce320fcf 100644 --- a/Lib/multiprocessing/dummy/connection.py +++ b/Lib/multiprocessing/dummy/connection.py @@ -26,7 +26,9 @@ class Listener(object): def close(self): self._backlog_queue = None - address = property(lambda self: self._backlog_queue) + @property + def address(self): + return self._backlog_queue def __enter__(self): return self diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py index f2c179e4e0..22a911a7a2 100644 --- a/Lib/multiprocessing/forkserver.py +++ b/Lib/multiprocessing/forkserver.py @@ -6,11 +6,12 @@ import socket import struct import sys import threading +import warnings from . import connection from . import process from .context import reduction -from . import semaphore_tracker +from . import resource_tracker from . import spawn from . import util @@ -22,7 +23,7 @@ __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process', # MAXFDS_TO_SEND = 256 -UNSIGNED_STRUCT = struct.Struct('Q') # large enough for pid_t +SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t # # Forkserver class @@ -33,10 +34,31 @@ class ForkServer(object): def __init__(self): self._forkserver_address = None self._forkserver_alive_fd = None + self._forkserver_pid = None self._inherited_fds = None self._lock = threading.Lock() self._preload_modules = ['__main__'] + def _stop(self): + # Method used by unit tests to stop the server + with self._lock: + self._stop_unlocked() + + def _stop_unlocked(self): + if self._forkserver_pid is None: + return + + # close the "alive" file descriptor asks the server to stop + os.close(self._forkserver_alive_fd) + self._forkserver_alive_fd = None + + os.waitpid(self._forkserver_pid, 0) + self._forkserver_pid = None + + if not util.is_abstract_socket_namespace(self._forkserver_address): + os.unlink(self._forkserver_address) + self._forkserver_address = None + def set_forkserver_preload(self, modules_names): '''Set list of module names to try to load in forkserver process.''' if not all(type(mod) is str for mod in self._preload_modules): @@ -67,7 +89,7 @@ class ForkServer(object): parent_r, child_w = os.pipe() child_r, parent_w = os.pipe() allfds = [child_r, child_w, self._forkserver_alive_fd, - semaphore_tracker.getfd()] + resource_tracker.getfd()] allfds += fds try: reduction.sendfds(client, allfds) @@ -88,9 +110,18 @@ class ForkServer(object): ensure_running() will do nothing. ''' with self._lock: - semaphore_tracker.ensure_running() - if self._forkserver_alive_fd is not None: - return + resource_tracker.ensure_running() + if self._forkserver_pid is not None: + # forkserver was launched before, is it still running? + pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG) + if not pid: + # still alive + return + # dead, launch it again + os.close(self._forkserver_alive_fd) + self._forkserver_address = None + self._forkserver_alive_fd = None + self._forkserver_pid = None cmd = ('from multiprocessing.forkserver import main; ' + 'main(%d, %d, %r, **%r)') @@ -98,15 +129,15 @@ class ForkServer(object): if self._preload_modules: desired_keys = {'main_path', 'sys_path'} data = spawn.get_preparation_data('ignore') - data = dict((x,y) for (x,y) in data.items() - if x in desired_keys) + data = {x: y for x, y in data.items() if x in desired_keys} else: data = {} with socket.socket(socket.AF_UNIX) as listener: address = connection.arbitrary_address('AF_UNIX') listener.bind(address) - os.chmod(address, 0o600) + if not util.is_abstract_socket_namespace(address): + os.chmod(address, 0o600) listener.listen() # all client processes own the write end of the "alive" pipe; @@ -127,6 +158,7 @@ class ForkServer(object): os.close(alive_r) self._forkserver_address = address self._forkserver_alive_fd = alive_w + self._forkserver_pid = pid # # @@ -149,14 +181,36 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): util._close_stdin() - # ignoring SIGCHLD means no need to reap zombie processes - handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN) + sig_r, sig_w = os.pipe() + os.set_blocking(sig_r, False) + os.set_blocking(sig_w, False) + + def sigchld_handler(*_unused): + # Dummy signal handler, doesn't do anything + pass + + handlers = { + # unblocking SIGCHLD allows the wakeup fd to notify our event loop + signal.SIGCHLD: sigchld_handler, + # protect the process from ^C + signal.SIGINT: signal.SIG_IGN, + } + old_handlers = {sig: signal.signal(sig, val) + for (sig, val) in handlers.items()} + + # calling os.write() in the Python signal handler is racy + signal.set_wakeup_fd(sig_w) + + # map child pids to client fds + pid_to_fd = {} + with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \ selectors.DefaultSelector() as selector: _forkserver._forkserver_address = listener.getsockname() selector.register(listener, selectors.EVENT_READ) selector.register(alive_r, selectors.EVENT_READ) + selector.register(sig_r, selectors.EVENT_READ) while True: try: @@ -167,69 +221,116 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): if alive_r in rfds: # EOF because no more client processes left - assert os.read(alive_r, 1) == b'' + assert os.read(alive_r, 1) == b'', "Not at EOF?" raise SystemExit - assert listener in rfds - with listener.accept()[0] as s: - code = 1 - if os.fork() == 0: + if sig_r in rfds: + # Got SIGCHLD + os.read(sig_r, 65536) # exhaust + while True: + # Scan for child processes try: - _serve_one(s, listener, alive_r, handler) - except Exception: - sys.excepthook(*sys.exc_info()) - sys.stderr.flush() - finally: - os._exit(code) + pid, sts = os.waitpid(-1, os.WNOHANG) + except ChildProcessError: + break + if pid == 0: + break + child_w = pid_to_fd.pop(pid, None) + if child_w is not None: + returncode = os.waitstatus_to_exitcode(sts) + + # Send exit code to client process + try: + write_signed(child_w, returncode) + except BrokenPipeError: + # client vanished + pass + os.close(child_w) + else: + # This shouldn't happen really + warnings.warn('forkserver: waitpid returned ' + 'unexpected pid %d' % pid) + + if listener in rfds: + # Incoming fork request + with listener.accept()[0] as s: + # Receive fds from client + fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) + if len(fds) > MAXFDS_TO_SEND: + raise RuntimeError( + "Too many ({0:n}) fds to send".format( + len(fds))) + child_r, child_w, *fds = fds + s.close() + pid = os.fork() + if pid == 0: + # Child + code = 1 + try: + listener.close() + selector.close() + unused_fds = [alive_r, child_w, sig_r, sig_w] + unused_fds.extend(pid_to_fd.values()) + code = _serve_one(child_r, fds, + unused_fds, + old_handlers) + except Exception: + sys.excepthook(*sys.exc_info()) + sys.stderr.flush() + finally: + os._exit(code) + else: + # Send pid to client process + try: + write_signed(child_w, pid) + except BrokenPipeError: + # client vanished + pass + pid_to_fd[pid] = child_w + os.close(child_r) + for fd in fds: + os.close(fd) except OSError as e: if e.errno != errno.ECONNABORTED: raise -def _serve_one(s, listener, alive_r, handler): - # close unnecessary stuff and reset SIGCHLD handler - listener.close() - os.close(alive_r) - signal.signal(signal.SIGCHLD, handler) - # receive fds from parent process - fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) - s.close() - assert len(fds) <= MAXFDS_TO_SEND - (child_r, child_w, _forkserver._forkserver_alive_fd, - stfd, *_forkserver._inherited_fds) = fds - semaphore_tracker._semaphore_tracker._fd = stfd +def _serve_one(child_r, fds, unused_fds, handlers): + # close unnecessary stuff and reset signal handlers + signal.set_wakeup_fd(-1) + for sig, val in handlers.items(): + signal.signal(sig, val) + for fd in unused_fds: + os.close(fd) - # send pid to client processes - write_unsigned(child_w, os.getpid()) + (_forkserver._forkserver_alive_fd, + resource_tracker._resource_tracker._fd, + *_forkserver._inherited_fds) = fds - # reseed random number generator - if 'random' in sys.modules: - import random - random.seed() + # Run process object received over pipe + parent_sentinel = os.dup(child_r) + code = spawn._main(child_r, parent_sentinel) - # run process object received over pipe - code = spawn._main(child_r) + return code - # write the exit code to the pipe - write_unsigned(child_w, code) # -# Read and write unsigned numbers +# Read and write signed numbers # -def read_unsigned(fd): +def read_signed(fd): data = b'' - length = UNSIGNED_STRUCT.size + length = SIGNED_STRUCT.size while len(data) < length: s = os.read(fd, length - len(data)) if not s: raise EOFError('unexpected EOF') data += s - return UNSIGNED_STRUCT.unpack(data)[0] + return SIGNED_STRUCT.unpack(data)[0] -def write_unsigned(fd, n): - msg = UNSIGNED_STRUCT.pack(n) +def write_signed(fd, n): + msg = SIGNED_STRUCT.pack(n) while msg: nbytes = os.write(fd, msg) if nbytes == 0: diff --git a/Lib/multiprocessing/heap.py b/Lib/multiprocessing/heap.py index 443321535e..6217dfe126 100644 --- a/Lib/multiprocessing/heap.py +++ b/Lib/multiprocessing/heap.py @@ -8,6 +8,7 @@ # import bisect +from collections import defaultdict import mmap import os import sys @@ -28,6 +29,9 @@ if sys.platform == 'win32': import _winapi class Arena(object): + """ + A shared memory area backed by anonymous memory (Windows). + """ _rand = tempfile._RandomNameSequence() @@ -52,6 +56,7 @@ if sys.platform == 'win32': def __setstate__(self, state): self.size, self.name = self._state = state + # Reopen existing mmap self.buffer = mmap.mmap(-1, self.size, tagname=self.name) # XXX Temporarily preventing buildbot failures while determining # XXX the correct long-term fix. See issue 23060 @@ -60,26 +65,38 @@ if sys.platform == 'win32': else: class Arena(object): + """ + A shared memory area backed by a temporary file (POSIX). + """ + + if sys.platform == 'linux': + _dir_candidates = ['/dev/shm'] + else: + _dir_candidates = [] def __init__(self, size, fd=-1): self.size = size self.fd = fd if fd == -1: + # Arena is created anew (if fd != -1, it means we're coming + # from rebuild_arena() below) self.fd, name = tempfile.mkstemp( - prefix='pym-%d-'%os.getpid(), dir=util.get_temp_dir()) + prefix='pym-%d-'%os.getpid(), + dir=self._choose_dir(size)) os.unlink(name) util.Finalize(self, os.close, (self.fd,)) - with open(self.fd, 'wb', closefd=False) as f: - bs = 1024 * 1024 - if size >= bs: - zeros = b'\0' * bs - for _ in range(size // bs): - f.write(zeros) - del zeros - f.write(b'\0' * (size % bs)) - assert f.tell() == size + os.ftruncate(self.fd, size) self.buffer = mmap.mmap(self.fd, self.size) + def _choose_dir(self, size): + # Choose a non-storage backed directory if possible, + # to improve performance + for d in self._dir_candidates: + st = os.statvfs(d) + if st.f_bavail * st.f_frsize >= size: # enough free space? + return d + return util.get_temp_dir() + def reduce_arena(a): if a.fd == -1: raise ValueError('Arena is unpicklable because ' @@ -97,37 +114,82 @@ else: class Heap(object): + # Minimum malloc() alignment _alignment = 8 + _DISCARD_FREE_SPACE_LARGER_THAN = 4 * 1024 ** 2 # 4 MB + _DOUBLE_ARENA_SIZE_UNTIL = 4 * 1024 ** 2 + def __init__(self, size=mmap.PAGESIZE): self._lastpid = os.getpid() self._lock = threading.Lock() + # Current arena allocation size self._size = size + # A sorted list of available block sizes in arenas self._lengths = [] + + # Free block management: + # - map each block size to a list of `(Arena, start, stop)` blocks self._len_to_seq = {} + # - map `(Arena, start)` tuple to the `(Arena, start, stop)` block + # starting at that offset self._start_to_block = {} + # - map `(Arena, stop)` tuple to the `(Arena, start, stop)` block + # ending at that offset self._stop_to_block = {} - self._allocated_blocks = set() + + # Map arenas to their `(Arena, start, stop)` blocks in use + self._allocated_blocks = defaultdict(set) self._arenas = [] - # list of pending blocks to free - see free() comment below + + # List of pending blocks to free - see comment in free() below self._pending_free_blocks = [] + # Statistics + self._n_mallocs = 0 + self._n_frees = 0 + @staticmethod def _roundup(n, alignment): # alignment must be a power of 2 mask = alignment - 1 return (n + mask) & ~mask + def _new_arena(self, size): + # Create a new arena with at least the given *size* + length = self._roundup(max(self._size, size), mmap.PAGESIZE) + # We carve larger and larger arenas, for efficiency, until we + # reach a large-ish size (roughly L3 cache-sized) + if self._size < self._DOUBLE_ARENA_SIZE_UNTIL: + self._size *= 2 + util.info('allocating a new mmap of length %d', length) + arena = Arena(length) + self._arenas.append(arena) + return (arena, 0, length) + + def _discard_arena(self, arena): + # Possibly delete the given (unused) arena + length = arena.size + # Reusing an existing arena is faster than creating a new one, so + # we only reclaim space if it's large enough. + if length < self._DISCARD_FREE_SPACE_LARGER_THAN: + return + blocks = self._allocated_blocks.pop(arena) + assert not blocks + del self._start_to_block[(arena, 0)] + del self._stop_to_block[(arena, length)] + self._arenas.remove(arena) + seq = self._len_to_seq[length] + seq.remove((arena, 0, length)) + if not seq: + del self._len_to_seq[length] + self._lengths.remove(length) + def _malloc(self, size): # returns a large enough block -- it might be much larger i = bisect.bisect_left(self._lengths, size) if i == len(self._lengths): - length = self._roundup(max(self._size, size), mmap.PAGESIZE) - self._size *= 2 - util.info('allocating a new mmap of length %d', length) - arena = Arena(length) - self._arenas.append(arena) - return (arena, 0, length) + return self._new_arena(size) else: length = self._lengths[i] seq = self._len_to_seq[length] @@ -140,8 +202,8 @@ class Heap(object): del self._stop_to_block[(arena, stop)] return block - def _free(self, block): - # free location and try to merge with neighbours + def _add_free_block(self, block): + # make block available and try to merge with its neighbours in the arena (arena, start, stop) = block try: @@ -185,6 +247,14 @@ class Heap(object): return start, stop + def _remove_allocated_block(self, block): + arena, start, stop = block + blocks = self._allocated_blocks[arena] + blocks.remove((start, stop)) + if not blocks: + # Arena is entirely free, discard it from this process + self._discard_arena(arena) + def _free_pending_blocks(self): # Free all the blocks in the pending list - called with the lock held. while True: @@ -192,8 +262,8 @@ class Heap(object): block = self._pending_free_blocks.pop() except IndexError: break - self._allocated_blocks.remove(block) - self._free(block) + self._add_free_block(block) + self._remove_allocated_block(block) def free(self, block): # free a block returned by malloc() @@ -204,8 +274,11 @@ class Heap(object): # immediately, the block is added to a list of blocks to be freed # synchronously sometimes later from malloc() or free(), by calling # _free_pending_blocks() (appending and retrieving from a list is not - # strictly thread-safe but under cPython it's atomic thanks to the GIL). - assert os.getpid() == self._lastpid + # strictly thread-safe but under CPython it's atomic thanks to the GIL). + if os.getpid() != self._lastpid: + raise ValueError( + "My pid ({0:n}) is not last pid {1:n}".format( + os.getpid(),self._lastpid)) if not self._lock.acquire(False): # can't acquire the lock right now, add the block to the list of # pending blocks to free @@ -213,30 +286,37 @@ class Heap(object): else: # we hold the lock try: + self._n_frees += 1 self._free_pending_blocks() - self._allocated_blocks.remove(block) - self._free(block) + self._add_free_block(block) + self._remove_allocated_block(block) finally: self._lock.release() def malloc(self, size): # return a block of right size (possibly rounded up) - assert 0 <= size < sys.maxsize + if size < 0: + raise ValueError("Size {0:n} out of range".format(size)) + if sys.maxsize <= size: + raise OverflowError("Size {0:n} too large".format(size)) if os.getpid() != self._lastpid: self.__init__() # reinitialize after fork with self._lock: + self._n_mallocs += 1 + # allow pending blocks to be marked available self._free_pending_blocks() - size = self._roundup(max(size,1), self._alignment) + size = self._roundup(max(size, 1), self._alignment) (arena, start, stop) = self._malloc(size) - new_stop = start + size - if new_stop < stop: - self._free((arena, new_stop, stop)) - block = (arena, start, new_stop) - self._allocated_blocks.add(block) - return block + real_stop = start + size + if real_stop < stop: + # if the returned block is larger than necessary, mark + # the remainder available + self._add_free_block((arena, real_stop, stop)) + self._allocated_blocks[arena].add((start, real_stop)) + return (arena, start, real_stop) # -# Class representing a chunk of an mmap -- can be inherited by child process +# Class wrapping a block allocated out of a Heap -- can be inherited by child process # class BufferWrapper(object): @@ -244,7 +324,10 @@ class BufferWrapper(object): _heap = Heap() def __init__(self, size): - assert 0 <= size < sys.maxsize + if size < 0: + raise ValueError("Size {0:n} out of range".format(size)) + if sys.maxsize <= size: + raise OverflowError("Size {0:n} too large".format(size)) block = BufferWrapper._heap.malloc(size) self._state = (block, size) util.Finalize(self, BufferWrapper._heap.free, args=(block,)) diff --git a/Lib/multiprocessing/managers.py b/Lib/multiprocessing/managers.py index e298b2cbaf..22292c78b7 100644 --- a/Lib/multiprocessing/managers.py +++ b/Lib/multiprocessing/managers.py @@ -1,5 +1,5 @@ # -# Module providing the `SyncManager` class for dealing +# Module providing manager classes for dealing # with shared objects # # multiprocessing/managers.py @@ -16,19 +16,29 @@ __all__ = [ 'BaseManager', 'SyncManager', 'BaseProxy', 'Token' ] import sys import threading +import signal import array import queue +import time import types +import os +from os import getpid -from time import time as _time from traceback import format_exc from . import connection -from .context import reduction, get_spawning_popen +from .context import reduction, get_spawning_popen, ProcessError from . import pool from . import process from . import util from . import get_context +try: + from . import shared_memory +except ImportError: + HAS_SHMEM = False +else: + HAS_SHMEM = True + __all__.append('SharedMemoryManager') # # Register some things for pickling @@ -51,7 +61,7 @@ if view_types[0] is not list: # only needed in Py3.0 class Token(object): ''' - Type to uniquely indentify a shared object + Type to uniquely identify a shared object ''' __slots__ = ('typeid', 'address', 'id') @@ -85,14 +95,17 @@ def dispatch(c, id, methodname, args=(), kwds={}): def convert_to_error(kind, result): if kind == '#ERROR': return result - elif kind == '#TRACEBACK': - assert type(result) is str - return RemoteError(result) - elif kind == '#UNSERIALIZABLE': - assert type(result) is str - return RemoteError('Unserializable message: %s\n' % result) + elif kind in ('#TRACEBACK', '#UNSERIALIZABLE'): + if not isinstance(result, str): + raise TypeError( + "Result {0!r} (kind '{1}') type is {2}, not str".format( + result, kind, type(result))) + if kind == '#UNSERIALIZABLE': + return RemoteError('Unserializable message: %s\n' % result) + else: + return RemoteError(result) else: - return ValueError('Unrecognized message type') + return ValueError('Unrecognized message type {!r}'.format(kind)) class RemoteError(Exception): def __str__(self): @@ -131,7 +144,10 @@ class Server(object): 'debug_info', 'number_of_objects', 'dummy', 'incref', 'decref'] def __init__(self, registry, address, authkey, serializer): - assert isinstance(authkey, bytes) + if not isinstance(authkey, bytes): + raise TypeError( + "Authkey {0!r} is type {1!s}, not bytes".format( + authkey, type(authkey))) self.registry = registry self.authkey = process.AuthenticationString(authkey) Listener, Client = listener_client[serializer] @@ -161,7 +177,7 @@ class Server(object): except (KeyboardInterrupt, SystemExit): pass finally: - if sys.stdout != sys.__stdout__: + if sys.stdout != sys.__stdout__: # what about stderr? util.debug('resetting stdout, stderr') sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ @@ -177,11 +193,8 @@ class Server(object): t.daemon = True t.start() - def handle_request(self, c): - ''' - Handle a new connection - ''' - funcname = result = request = None + def _handle_request(self, c): + request = None try: connection.deliver_challenge(c, self.authkey) connection.answer_challenge(c, self.authkey) @@ -198,6 +211,7 @@ class Server(object): msg = ('#TRACEBACK', format_exc()) else: msg = ('#RETURN', result) + try: c.send(msg) except Exception as e: @@ -209,7 +223,17 @@ class Server(object): util.info(' ... request was %r', request) util.info(' ... exception was %r', e) - c.close() + def handle_request(self, conn): + ''' + Handle a new connection + ''' + try: + self._handle_request(conn) + except SystemExit: + # Server.serve_client() calls sys.exit(0) on EOF + pass + finally: + conn.close() def serve_client(self, conn): ''' @@ -234,7 +258,7 @@ class Server(object): try: obj, exposed, gettypeid = \ self.id_to_local_proxy_obj[ident] - except KeyError as second_ke: + except KeyError: raise ke if methodname not in exposed: @@ -282,7 +306,7 @@ class Server(object): try: try: send(msg) - except Exception as e: + except Exception: send(('#UNSERIALIZABLE', format_exc())) except Exception as e: util.info('exception in thread serving %r', @@ -314,6 +338,7 @@ class Server(object): ''' Return some info --- useful to spot problems with refcounting ''' + # Perhaps include debug info about 'c'? with self.mutex: result = [] keys = list(self.id_to_refcount.keys()) @@ -345,7 +370,7 @@ class Server(object): finally: self.stop_event.set() - def create(self, c, typeid, *args, **kwds): + def create(self, c, typeid, /, *args, **kwds): ''' Create a new shared object and return its id ''' @@ -354,7 +379,9 @@ class Server(object): self.registry[typeid] if callable is None: - assert len(args) == 1 and not kwds + if kwds or (len(args) != 1): + raise ValueError( + "Without callable, must have one non-keyword argument") obj = args[0] else: obj = callable(*args, **kwds) @@ -362,7 +389,10 @@ class Server(object): if exposed is None: exposed = public_methods(obj) if method_to_typeid is not None: - assert type(method_to_typeid) is dict + if not isinstance(method_to_typeid, dict): + raise TypeError( + "Method_to_typeid {0!r}: type {1!s}, not dict".format( + method_to_typeid, type(method_to_typeid))) exposed = list(exposed) + list(method_to_typeid) ident = '%x' % id(obj) # convert to string because xmlrpclib @@ -415,7 +445,11 @@ class Server(object): return with self.mutex: - assert self.id_to_refcount[ident] >= 1 + if self.id_to_refcount[ident] <= 0: + raise AssertionError( + "Id {0!s} ({1!r}) has refcount {2:n}, not 1+".format( + ident, self.id_to_obj[ident], + self.id_to_refcount[ident])) self.id_to_refcount[ident] -= 1 if self.id_to_refcount[ident] == 0: del self.id_to_refcount[ident] @@ -478,7 +512,14 @@ class BaseManager(object): ''' Return server object with serve_forever() method and address attribute ''' - assert self._state.value == State.INITIAL + if self._state.value != State.INITIAL: + if self._state.value == State.STARTED: + raise ProcessError("Already started server") + elif self._state.value == State.SHUTDOWN: + raise ProcessError("Manager has shut down") + else: + raise ProcessError( + "Unknown state {!r}".format(self._state.value)) return Server(self._registry, self._address, self._authkey, self._serializer) @@ -495,7 +536,14 @@ class BaseManager(object): ''' Spawn a server process for this manager object ''' - assert self._state.value == State.INITIAL + if self._state.value != State.INITIAL: + if self._state.value == State.STARTED: + raise ProcessError("Already started server") + elif self._state.value == State.SHUTDOWN: + raise ProcessError("Manager has shut down") + else: + raise ProcessError( + "Unknown state {!r}".format(self._state.value)) if initializer is not None and not callable(initializer): raise TypeError('initializer must be a callable') @@ -533,6 +581,9 @@ class BaseManager(object): ''' Create a server, report its address and run it ''' + # bpo-36368: protect server process from KeyboardInterrupt signals + signal.signal(signal.SIGINT, signal.SIG_IGN) + if initializer is not None: initializer(*initargs) @@ -547,7 +598,7 @@ class BaseManager(object): util.info('manager serving at %r', server.address) server.serve_forever() - def _create(self, typeid, *args, **kwds): + def _create(self, typeid, /, *args, **kwds): ''' Create a new shared object; return the token and exposed tuple ''' @@ -591,7 +642,14 @@ class BaseManager(object): def __enter__(self): if self._state.value == State.INITIAL: self.start() - assert self._state.value == State.STARTED + if self._state.value != State.STARTED: + if self._state.value == State.INITIAL: + raise ProcessError("Unable to start server") + elif self._state.value == State.SHUTDOWN: + raise ProcessError("Manager has shut down") + else: + raise ProcessError( + "Unknown state {!r}".format(self._state.value)) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -619,7 +677,7 @@ class BaseManager(object): if hasattr(process, 'terminate'): util.info('trying to `terminate()` manager process') process.terminate() - process.join(timeout=0.1) + process.join(timeout=1.0) if process.is_alive(): util.info('manager still alive after terminate') @@ -629,7 +687,9 @@ class BaseManager(object): except KeyError: pass - address = property(lambda self: self._address) + @property + def address(self): + return self._address @classmethod def register(cls, typeid, callable=None, proxytype=None, exposed=None, @@ -649,7 +709,7 @@ class BaseManager(object): getattr(proxytype, '_method_to_typeid_', None) if method_to_typeid: - for key, value in list(method_to_typeid.items()): + for key, value in list(method_to_typeid.items()): # isinstance? assert type(key) is str, '%r is not a string' % key assert type(value) is str, '%r is not a string' % value @@ -658,7 +718,7 @@ class BaseManager(object): ) if create_method: - def temp(self, *args, **kwds): + def temp(self, /, *args, **kwds): util.debug('requesting creation of a shared %r object', typeid) token, exp = self._create(typeid, *args, **kwds) proxy = proxytype( @@ -744,7 +804,7 @@ class BaseProxy(object): def _callmethod(self, methodname, args=(), kwds={}): ''' - Try to call a method of the referrent and return a copy of the result + Try to call a method of the referent and return a copy of the result ''' try: conn = self._tls.connection @@ -898,7 +958,7 @@ def MakeProxyType(name, exposed, _cache={}): dic = {} for meth in exposed: - exec('''def %s(self, *args, **kwds): + exec('''def %s(self, /, *args, **kwds): return self._callmethod(%r, args, kwds)''' % (meth, meth), dic) ProxyType = type(name, (BaseProxy,), dic) @@ -908,7 +968,7 @@ def MakeProxyType(name, exposed, _cache={}): def AutoProxy(token, serializer, manager=None, authkey=None, - exposed=None, incref=True): + exposed=None, incref=True, manager_owned=False): ''' Return an auto-proxy for `token` ''' @@ -928,7 +988,7 @@ def AutoProxy(token, serializer, manager=None, authkey=None, ProxyType = MakeProxyType('AutoProxy[%s]' % token.typeid, exposed) proxy = ProxyType(token, serializer, manager=manager, authkey=authkey, - incref=incref) + incref=incref, manager_owned=manager_owned) proxy._isauto = True return proxy @@ -937,7 +997,7 @@ def AutoProxy(token, serializer, manager=None, authkey=None, # class Namespace(object): - def __init__(self, **kwds): + def __init__(self, /, **kwds): self.__dict__.update(kwds) def __repr__(self): items = list(self.__dict__.items()) @@ -998,8 +1058,8 @@ class ConditionProxy(AcquirerProxy): _exposed_ = ('acquire', 'release', 'wait', 'notify', 'notify_all') def wait(self, timeout=None): return self._callmethod('wait', (timeout,)) - def notify(self): - return self._callmethod('notify') + def notify(self, n=1): + return self._callmethod('notify', (n,)) def notify_all(self): return self._callmethod('notify_all') def wait_for(self, predicate, timeout=None): @@ -1007,13 +1067,13 @@ class ConditionProxy(AcquirerProxy): if result: return result if timeout is not None: - endtime = _time() + timeout + endtime = time.monotonic() + timeout else: endtime = None waittime = None while not result: if endtime is not None: - waittime = endtime - _time() + waittime = endtime - time.monotonic() if waittime <= 0: break self.wait(waittime) @@ -1098,10 +1158,13 @@ class ListProxy(BaseListProxy): DictProxy = MakeProxyType('DictProxy', ( - '__contains__', '__delitem__', '__getitem__', '__len__', - '__setitem__', 'clear', 'copy', 'get', 'has_key', 'items', + '__contains__', '__delitem__', '__getitem__', '__iter__', '__len__', + '__setitem__', 'clear', 'copy', 'get', 'items', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values' )) +DictProxy._method_to_typeid_ = { + '__iter__': 'Iterator', + } ArrayProxy = MakeProxyType('ArrayProxy', ( @@ -1161,3 +1224,155 @@ SyncManager.register('Namespace', Namespace, NamespaceProxy) # types returned by methods of PoolProxy SyncManager.register('Iterator', proxytype=IteratorProxy, create_method=False) SyncManager.register('AsyncResult', create_method=False) + +# +# Definition of SharedMemoryManager and SharedMemoryServer +# + +if HAS_SHMEM: + class _SharedMemoryTracker: + "Manages one or more shared memory segments." + + def __init__(self, name, segment_names=[]): + self.shared_memory_context_name = name + self.segment_names = segment_names + + def register_segment(self, segment_name): + "Adds the supplied shared memory block name to tracker." + util.debug(f"Register segment {segment_name!r} in pid {getpid()}") + self.segment_names.append(segment_name) + + def destroy_segment(self, segment_name): + """Calls unlink() on the shared memory block with the supplied name + and removes it from the list of blocks being tracked.""" + util.debug(f"Destroy segment {segment_name!r} in pid {getpid()}") + self.segment_names.remove(segment_name) + segment = shared_memory.SharedMemory(segment_name) + segment.close() + segment.unlink() + + def unlink(self): + "Calls destroy_segment() on all tracked shared memory blocks." + for segment_name in self.segment_names[:]: + self.destroy_segment(segment_name) + + def __del__(self): + util.debug(f"Call {self.__class__.__name__}.__del__ in {getpid()}") + self.unlink() + + def __getstate__(self): + return (self.shared_memory_context_name, self.segment_names) + + def __setstate__(self, state): + self.__init__(*state) + + + class SharedMemoryServer(Server): + + public = Server.public + \ + ['track_segment', 'release_segment', 'list_segments'] + + def __init__(self, *args, **kwargs): + Server.__init__(self, *args, **kwargs) + address = self.address + # The address of Linux abstract namespaces can be bytes + if isinstance(address, bytes): + address = os.fsdecode(address) + self.shared_memory_context = \ + _SharedMemoryTracker(f"shm_{address}_{getpid()}") + util.debug(f"SharedMemoryServer started by pid {getpid()}") + + def create(self, c, typeid, /, *args, **kwargs): + """Create a new distributed-shared object (not backed by a shared + memory block) and return its id to be used in a Proxy Object.""" + # Unless set up as a shared proxy, don't make shared_memory_context + # a standard part of kwargs. This makes things easier for supplying + # simple functions. + if hasattr(self.registry[typeid][-1], "_shared_memory_proxy"): + kwargs['shared_memory_context'] = self.shared_memory_context + return Server.create(self, c, typeid, *args, **kwargs) + + def shutdown(self, c): + "Call unlink() on all tracked shared memory, terminate the Server." + self.shared_memory_context.unlink() + return Server.shutdown(self, c) + + def track_segment(self, c, segment_name): + "Adds the supplied shared memory block name to Server's tracker." + self.shared_memory_context.register_segment(segment_name) + + def release_segment(self, c, segment_name): + """Calls unlink() on the shared memory block with the supplied name + and removes it from the tracker instance inside the Server.""" + self.shared_memory_context.destroy_segment(segment_name) + + def list_segments(self, c): + """Returns a list of names of shared memory blocks that the Server + is currently tracking.""" + return self.shared_memory_context.segment_names + + + class SharedMemoryManager(BaseManager): + """Like SyncManager but uses SharedMemoryServer instead of Server. + + It provides methods for creating and returning SharedMemory instances + and for creating a list-like object (ShareableList) backed by shared + memory. It also provides methods that create and return Proxy Objects + that support synchronization across processes (i.e. multi-process-safe + locks and semaphores). + """ + + _Server = SharedMemoryServer + + def __init__(self, *args, **kwargs): + if os.name == "posix": + # bpo-36867: Ensure the resource_tracker is running before + # launching the manager process, so that concurrent + # shared_memory manipulation both in the manager and in the + # current process does not create two resource_tracker + # processes. + from . import resource_tracker + resource_tracker.ensure_running() + BaseManager.__init__(self, *args, **kwargs) + util.debug(f"{self.__class__.__name__} created by pid {getpid()}") + + def __del__(self): + util.debug(f"{self.__class__.__name__}.__del__ by pid {getpid()}") + pass + + def get_server(self): + 'Better than monkeypatching for now; merge into Server ultimately' + if self._state.value != State.INITIAL: + if self._state.value == State.STARTED: + raise ProcessError("Already started SharedMemoryServer") + elif self._state.value == State.SHUTDOWN: + raise ProcessError("SharedMemoryManager has shut down") + else: + raise ProcessError( + "Unknown state {!r}".format(self._state.value)) + return self._Server(self._registry, self._address, + self._authkey, self._serializer) + + def SharedMemory(self, size): + """Returns a new SharedMemory instance with the specified size in + bytes, to be tracked by the manager.""" + with self._Client(self._address, authkey=self._authkey) as conn: + sms = shared_memory.SharedMemory(None, create=True, size=size) + try: + dispatch(conn, None, 'track_segment', (sms.name,)) + except BaseException as e: + sms.unlink() + raise e + return sms + + def ShareableList(self, sequence): + """Returns a new ShareableList instance populated with the values + from the input sequence, to be tracked by the manager.""" + with self._Client(self._address, authkey=self._authkey) as conn: + sl = shared_memory.ShareableList(sequence) + try: + dispatch(conn, None, 'track_segment', (sl.shm.name,)) + except BaseException as e: + sl.shm.unlink() + raise e + return sl diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py index 88f4c397dc..bbe05a550c 100644 --- a/Lib/multiprocessing/pool.py +++ b/Lib/multiprocessing/pool.py @@ -13,27 +13,30 @@ __all__ = ['Pool', 'ThreadPool'] # Imports # -import threading -import queue -import itertools import collections +import itertools import os +import queue +import threading import time import traceback import types +import warnings # If threading is available then ThreadPool should be provided. Therefore # we avoid top-level imports which are liable to fail on some systems. from . import util from . import get_context, TimeoutError +from .connection import wait # # Constants representing the state of a pool # -RUN = 0 -CLOSE = 1 -TERMINATE = 2 +INIT = "INIT" +RUN = "RUN" +CLOSE = "CLOSE" +TERMINATE = "TERMINATE" # # Miscellaneous @@ -93,7 +96,9 @@ class MaybeEncodingError(Exception): def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, wrap_exception=False): - assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0) + if (maxtasks is not None) and not (isinstance(maxtasks, int) + and maxtasks >= 1): + raise AssertionError("Maxtasks {!r} is not valid".format(maxtasks)) put = outqueue.put get = inqueue.get if hasattr(inqueue, '_writer'): @@ -119,7 +124,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, try: result = (True, func(*args, **kwds)) except Exception as e: - if wrap_exception: + if wrap_exception and func is not _helper_reraises_exception: e = ExceptionWithTraceback(e, e.__traceback__) result = (False, e) try: @@ -129,29 +134,67 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, util.debug("Possible encoding error while sending result: %s" % ( wrapped)) put((job, i, (False, wrapped))) + + task = job = result = func = args = kwds = None completed += 1 util.debug('worker exiting after %d tasks' % completed) +def _helper_reraises_exception(ex): + 'Pickle-able helper function for use by _guarded_task_generation.' + raise ex + # # Class representing a process pool # +class _PoolCache(dict): + """ + Class that implements a cache for the Pool class that will notify + the pool management threads every time the cache is emptied. The + notification is done by the use of a queue that is provided when + instantiating the cache. + """ + def __init__(self, /, *args, notifier=None, **kwds): + self.notifier = notifier + super().__init__(*args, **kwds) + + def __delitem__(self, item): + super().__delitem__(item) + + # Notify that the cache is empty. This is important because the + # pool keeps maintaining workers until the cache gets drained. This + # eliminates a race condition in which a task is finished after the + # the pool's _handle_workers method has enter another iteration of the + # loop. In this situation, the only event that can wake up the pool + # is the cache to be emptied (no more tasks available). + if not self: + self.notifier.put(None) + class Pool(object): ''' Class which supports an async version of applying functions to arguments. ''' _wrap_exception = True - def Process(self, *args, **kwds): - return self._ctx.Process(*args, **kwds) + @staticmethod + def Process(ctx, *args, **kwds): + return ctx.Process(*args, **kwds) def __init__(self, processes=None, initializer=None, initargs=(), maxtasksperchild=None, context=None): + # Attributes initialized early to make sure that they exist in + # __del__() if __init__() raises an exception + self._pool = [] + self._state = INIT + self._ctx = context or get_context() self._setup_queues() - self._taskqueue = queue.Queue() - self._cache = {} - self._state = RUN + self._taskqueue = queue.SimpleQueue() + # The _change_notifier queue exist to wake up self._handle_workers() + # when the cache (self._cache) is empty or when there is a change in + # the _state variable of the thread that runs _handle_workers. + self._change_notifier = self._ctx.SimpleQueue() + self._cache = _PoolCache(notifier=self._change_notifier) self._maxtasksperchild = maxtasksperchild self._initializer = initializer self._initargs = initargs @@ -165,12 +208,24 @@ class Pool(object): raise TypeError('initializer must be a callable') self._processes = processes - self._pool = [] - self._repopulate_pool() + try: + self._repopulate_pool() + except Exception: + for p in self._pool: + if p.exitcode is None: + p.terminate() + for p in self._pool: + p.join() + raise + + sentinels = self._get_sentinels() self._worker_handler = threading.Thread( target=Pool._handle_workers, - args=(self, ) + args=(self._cache, self._taskqueue, self._ctx, self.Process, + self._processes, self._pool, self._inqueue, self._outqueue, + self._initializer, self._initargs, self._maxtasksperchild, + self._wrap_exception, sentinels, self._change_notifier) ) self._worker_handler.daemon = True self._worker_handler._state = RUN @@ -197,48 +252,92 @@ class Pool(object): self._terminate = util.Finalize( self, self._terminate_pool, args=(self._taskqueue, self._inqueue, self._outqueue, self._pool, - self._worker_handler, self._task_handler, + self._change_notifier, self._worker_handler, self._task_handler, self._result_handler, self._cache), exitpriority=15 ) + self._state = RUN - def _join_exited_workers(self): + # Copy globals as function locals to make sure that they are available + # during Python shutdown when the Pool is destroyed. + def __del__(self, _warn=warnings.warn, RUN=RUN): + if self._state == RUN: + _warn(f"unclosed running multiprocessing pool {self!r}", + ResourceWarning, source=self) + if getattr(self, '_change_notifier', None) is not None: + self._change_notifier.put(None) + + def __repr__(self): + cls = self.__class__ + return (f'<{cls.__module__}.{cls.__qualname__} ' + f'state={self._state} ' + f'pool_size={len(self._pool)}>') + + def _get_sentinels(self): + task_queue_sentinels = [self._outqueue._reader] + self_notifier_sentinels = [self._change_notifier._reader] + return [*task_queue_sentinels, *self_notifier_sentinels] + + @staticmethod + def _get_worker_sentinels(workers): + return [worker.sentinel for worker in + workers if hasattr(worker, "sentinel")] + + @staticmethod + def _join_exited_workers(pool): """Cleanup after any worker processes which have exited due to reaching their specified lifetime. Returns True if any workers were cleaned up. """ cleaned = False - for i in reversed(range(len(self._pool))): - worker = self._pool[i] + for i in reversed(range(len(pool))): + worker = pool[i] if worker.exitcode is not None: # worker exited util.debug('cleaning up worker %d' % i) worker.join() cleaned = True - del self._pool[i] + del pool[i] return cleaned def _repopulate_pool(self): + return self._repopulate_pool_static(self._ctx, self.Process, + self._processes, + self._pool, self._inqueue, + self._outqueue, self._initializer, + self._initargs, + self._maxtasksperchild, + self._wrap_exception) + + @staticmethod + def _repopulate_pool_static(ctx, Process, processes, pool, inqueue, + outqueue, initializer, initargs, + maxtasksperchild, wrap_exception): """Bring the number of pool processes up to the specified number, for use after reaping workers which have exited. """ - for i in range(self._processes - len(self._pool)): - w = self.Process(target=worker, - args=(self._inqueue, self._outqueue, - self._initializer, - self._initargs, self._maxtasksperchild, - self._wrap_exception) - ) - self._pool.append(w) + for i in range(processes - len(pool)): + w = Process(ctx, target=worker, + args=(inqueue, outqueue, + initializer, + initargs, maxtasksperchild, + wrap_exception)) w.name = w.name.replace('Process', 'PoolWorker') w.daemon = True w.start() + pool.append(w) util.debug('added worker') - def _maintain_pool(self): + @staticmethod + def _maintain_pool(ctx, Process, processes, pool, inqueue, outqueue, + initializer, initargs, maxtasksperchild, + wrap_exception): """Clean up any exited workers and start replacements for them. """ - if self._join_exited_workers(): - self._repopulate_pool() + if Pool._join_exited_workers(pool): + Pool._repopulate_pool_static(ctx, Process, processes, pool, + inqueue, outqueue, initializer, + initargs, maxtasksperchild, + wrap_exception) def _setup_queues(self): self._inqueue = self._ctx.SimpleQueue() @@ -246,11 +345,15 @@ class Pool(object): self._quick_put = self._inqueue._writer.send self._quick_get = self._outqueue._reader.recv + def _check_running(self): + if self._state != RUN: + raise ValueError("Pool not running") + def apply(self, func, args=(), kwds={}): ''' Equivalent of `func(*args, **kwds)`. + Pool must be running. ''' - assert self._state == RUN return self.apply_async(func, args, kwds).get() def map(self, func, iterable, chunksize=None): @@ -276,42 +379,72 @@ class Pool(object): return self._map_async(func, iterable, starmapstar, chunksize, callback, error_callback) + def _guarded_task_generation(self, result_job, func, iterable): + '''Provides a generator of tasks for imap and imap_unordered with + appropriate handling for iterables which throw exceptions during + iteration.''' + try: + i = -1 + for i, x in enumerate(iterable): + yield (result_job, i, func, (x,), {}) + except Exception as e: + yield (result_job, i+1, _helper_reraises_exception, (e,), {}) + def imap(self, func, iterable, chunksize=1): ''' Equivalent of `map()` -- can be MUCH slower than `Pool.map()`. ''' - if self._state != RUN: - raise ValueError("Pool not running") + self._check_running() if chunksize == 1: - result = IMapIterator(self._cache) - self._taskqueue.put((((result._job, i, func, (x,), {}) - for i, x in enumerate(iterable)), result._set_length)) + result = IMapIterator(self) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, func, iterable), + result._set_length + )) return result else: - assert chunksize > 1 + if chunksize < 1: + raise ValueError( + "Chunksize must be 1+, not {0:n}".format( + chunksize)) task_batches = Pool._get_tasks(func, iterable, chunksize) - result = IMapIterator(self._cache) - self._taskqueue.put((((result._job, i, mapstar, (x,), {}) - for i, x in enumerate(task_batches)), result._set_length)) + result = IMapIterator(self) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, + mapstar, + task_batches), + result._set_length + )) return (item for chunk in result for item in chunk) def imap_unordered(self, func, iterable, chunksize=1): ''' Like `imap()` method but ordering of results is arbitrary. ''' - if self._state != RUN: - raise ValueError("Pool not running") + self._check_running() if chunksize == 1: - result = IMapUnorderedIterator(self._cache) - self._taskqueue.put((((result._job, i, func, (x,), {}) - for i, x in enumerate(iterable)), result._set_length)) + result = IMapUnorderedIterator(self) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, func, iterable), + result._set_length + )) return result else: - assert chunksize > 1 + if chunksize < 1: + raise ValueError( + "Chunksize must be 1+, not {0!r}".format(chunksize)) task_batches = Pool._get_tasks(func, iterable, chunksize) - result = IMapUnorderedIterator(self._cache) - self._taskqueue.put((((result._job, i, mapstar, (x,), {}) - for i, x in enumerate(task_batches)), result._set_length)) + result = IMapUnorderedIterator(self) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, + mapstar, + task_batches), + result._set_length + )) return (item for chunk in result for item in chunk) def apply_async(self, func, args=(), kwds={}, callback=None, @@ -319,10 +452,9 @@ class Pool(object): ''' Asynchronous version of `apply()` method. ''' - if self._state != RUN: - raise ValueError("Pool not running") - result = ApplyResult(self._cache, callback, error_callback) - self._taskqueue.put(([(result._job, None, func, args, kwds)], None)) + self._check_running() + result = ApplyResult(self, callback, error_callback) + self._taskqueue.put(([(result._job, 0, func, args, kwds)], None)) return result def map_async(self, func, iterable, chunksize=None, callback=None, @@ -338,8 +470,7 @@ class Pool(object): ''' Helper function to implement map, starmap and their async counterparts. ''' - if self._state != RUN: - raise ValueError("Pool not running") + self._check_running() if not hasattr(iterable, '__len__'): iterable = list(iterable) @@ -351,23 +482,43 @@ class Pool(object): chunksize = 0 task_batches = Pool._get_tasks(func, iterable, chunksize) - result = MapResult(self._cache, chunksize, len(iterable), callback, + result = MapResult(self, chunksize, len(iterable), callback, error_callback=error_callback) - self._taskqueue.put((((result._job, i, mapper, (x,), {}) - for i, x in enumerate(task_batches)), None)) + self._taskqueue.put( + ( + self._guarded_task_generation(result._job, + mapper, + task_batches), + None + ) + ) return result @staticmethod - def _handle_workers(pool): + def _wait_for_updates(sentinels, change_notifier, timeout=None): + wait(sentinels, timeout=timeout) + while not change_notifier.empty(): + change_notifier.get() + + @classmethod + def _handle_workers(cls, cache, taskqueue, ctx, Process, processes, + pool, inqueue, outqueue, initializer, initargs, + maxtasksperchild, wrap_exception, sentinels, + change_notifier): thread = threading.current_thread() # Keep maintaining workers until the cache gets drained, unless the pool # is terminated. - while thread._state == RUN or (pool._cache and thread._state != TERMINATE): - pool._maintain_pool() - time.sleep(0.1) + while thread._state == RUN or (cache and thread._state != TERMINATE): + cls._maintain_pool(ctx, Process, processes, pool, inqueue, + outqueue, initializer, initargs, + maxtasksperchild, wrap_exception) + + current_sentinels = [*cls._get_worker_sentinels(pool), *sentinels] + + cls._wait_for_updates(current_sentinels, change_notifier) # send sentinel to stop workers - pool._taskqueue.put(None) + taskqueue.put(None) util.debug('worker handler exiting') @staticmethod @@ -376,37 +527,32 @@ class Pool(object): for taskseq, set_length in iter(taskqueue.get, None): task = None - i = -1 try: - for i, task in enumerate(taskseq): - if thread._state: + # iterating taskseq cannot fail + for task in taskseq: + if thread._state != RUN: util.debug('task handler found thread._state != RUN') break try: put(task) except Exception as e: - job, ind = task[:2] + job, idx = task[:2] try: - cache[job]._set(ind, (False, e)) + cache[job]._set(idx, (False, e)) except KeyError: pass else: if set_length: util.debug('doing set_length()') - set_length(i+1) + idx = task[1] if task else -1 + set_length(idx + 1) continue break - except Exception as ex: - job, ind = task[:2] if task else (0, 0) - if job in cache: - cache[job]._set(ind + 1, (False, ex)) - if set_length: - util.debug('doing set_length()') - set_length(i+1) + finally: + task = taskseq = job = None else: util.debug('task handler got sentinel') - try: # tell result handler to finish when cache is empty util.debug('task handler sending sentinel to result handler') @@ -432,8 +578,8 @@ class Pool(object): util.debug('result handler got EOFError/OSError -- exiting') return - if thread._state: - assert thread._state == TERMINATE + if thread._state != RUN: + assert thread._state == TERMINATE, "Thread not in TERMINATE" util.debug('result handler found thread._state=TERMINATE') break @@ -446,6 +592,7 @@ class Pool(object): cache[job]._set(i, obj) except KeyError: pass + task = job = obj = None while cache and thread._state != TERMINATE: try: @@ -462,6 +609,7 @@ class Pool(object): cache[job]._set(i, obj) except KeyError: pass + task = job = obj = None if hasattr(outqueue, '_reader'): util.debug('ensuring that outqueue is not full') @@ -498,16 +646,19 @@ class Pool(object): if self._state == RUN: self._state = CLOSE self._worker_handler._state = CLOSE + self._change_notifier.put(None) def terminate(self): util.debug('terminating pool') self._state = TERMINATE - self._worker_handler._state = TERMINATE self._terminate() def join(self): util.debug('joining pool') - assert self._state in (CLOSE, TERMINATE) + if self._state == RUN: + raise ValueError("Pool is still running") + elif self._state not in (CLOSE, TERMINATE): + raise ValueError("In unknown state") self._worker_handler.join() self._task_handler.join() self._result_handler.join() @@ -524,20 +675,28 @@ class Pool(object): time.sleep(0) @classmethod - def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, + def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, change_notifier, worker_handler, task_handler, result_handler, cache): # this is guaranteed to only be called once util.debug('finalizing pool') + # Notify that the worker_handler state has been changed so the + # _handle_workers loop can be unblocked (and exited) in order to + # send the finalization sentinel all the workers. worker_handler._state = TERMINATE + change_notifier.put(None) + task_handler._state = TERMINATE util.debug('helping task handler/workers to finish') cls._help_stuff_finish(inqueue, task_handler, len(pool)) - assert result_handler.is_alive() or len(cache) == 0 + if (not result_handler.is_alive()) and (len(cache) != 0): + raise AssertionError( + "Cannot have cache with result_hander not alive") result_handler._state = TERMINATE + change_notifier.put(None) outqueue.put(None) # sentinel # We must wait for the worker handler to exit before terminating @@ -570,6 +729,7 @@ class Pool(object): p.join() def __enter__(self): + self._check_running() return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -581,19 +741,21 @@ class Pool(object): class ApplyResult(object): - def __init__(self, cache, callback, error_callback): + def __init__(self, pool, callback, error_callback): + self._pool = pool self._event = threading.Event() self._job = next(job_counter) - self._cache = cache + self._cache = pool._cache self._callback = callback self._error_callback = error_callback - cache[self._job] = self + self._cache[self._job] = self def ready(self): return self._event.is_set() def successful(self): - assert self.ready() + if not self.ready(): + raise ValueError("{0!r} not ready".format(self)) return self._success def wait(self, timeout=None): @@ -616,6 +778,7 @@ class ApplyResult(object): self._error_callback(self._value) self._event.set() del self._cache[self._job] + self._pool = None __class_getitem__ = classmethod(types.GenericAlias) @@ -627,8 +790,8 @@ AsyncResult = ApplyResult # create alias -- see #17805 class MapResult(ApplyResult): - def __init__(self, cache, chunksize, length, callback, error_callback): - ApplyResult.__init__(self, cache, callback, + def __init__(self, pool, chunksize, length, callback, error_callback): + ApplyResult.__init__(self, pool, callback, error_callback=error_callback) self._success = True self._value = [None] * length @@ -636,7 +799,7 @@ class MapResult(ApplyResult): if chunksize <= 0: self._number_left = 0 self._event.set() - del cache[self._job] + del self._cache[self._job] else: self._number_left = length//chunksize + bool(length % chunksize) @@ -650,6 +813,7 @@ class MapResult(ApplyResult): self._callback(self._value) del self._cache[self._job] self._event.set() + self._pool = None else: if not success and self._success: # only store first exception @@ -661,6 +825,7 @@ class MapResult(ApplyResult): self._error_callback(self._value) del self._cache[self._job] self._event.set() + self._pool = None # # Class whose instances are returned by `Pool.imap()` @@ -668,15 +833,16 @@ class MapResult(ApplyResult): class IMapIterator(object): - def __init__(self, cache): + def __init__(self, pool): + self._pool = pool self._cond = threading.Condition(threading.Lock()) self._job = next(job_counter) - self._cache = cache + self._cache = pool._cache self._items = collections.deque() self._index = 0 self._length = None self._unsorted = {} - cache[self._job] = self + self._cache[self._job] = self def __iter__(self): return self @@ -687,14 +853,16 @@ class IMapIterator(object): item = self._items.popleft() except IndexError: if self._index == self._length: - raise StopIteration + self._pool = None + raise StopIteration from None self._cond.wait(timeout) try: item = self._items.popleft() except IndexError: if self._index == self._length: - raise StopIteration - raise TimeoutError + self._pool = None + raise StopIteration from None + raise TimeoutError from None success, value = item if success: @@ -718,6 +886,7 @@ class IMapIterator(object): if self._index == self._length: del self._cache[self._job] + self._pool = None def _set_length(self, length): with self._cond: @@ -725,6 +894,7 @@ class IMapIterator(object): if self._index == self._length: self._cond.notify() del self._cache[self._job] + self._pool = None # # Class whose instances are returned by `Pool.imap_unordered()` @@ -739,6 +909,7 @@ class IMapUnorderedIterator(IMapIterator): self._cond.notify() if self._index == self._length: del self._cache[self._job] + self._pool = None # # @@ -748,7 +919,7 @@ class ThreadPool(Pool): _wrap_exception = False @staticmethod - def Process(*args, **kwds): + def Process(ctx, *args, **kwds): from .dummy import Process return Process(*args, **kwds) @@ -756,15 +927,28 @@ class ThreadPool(Pool): Pool.__init__(self, processes, initializer, initargs) def _setup_queues(self): - self._inqueue = queue.Queue() - self._outqueue = queue.Queue() + self._inqueue = queue.SimpleQueue() + self._outqueue = queue.SimpleQueue() self._quick_put = self._inqueue.put self._quick_get = self._outqueue.get + def _get_sentinels(self): + return [self._change_notifier._reader] + + @staticmethod + def _get_worker_sentinels(workers): + return [] + @staticmethod def _help_stuff_finish(inqueue, task_handler, size): - # put sentinels at head of inqueue to make workers finish - with inqueue.not_empty: - inqueue.queue.clear() - inqueue.queue.extend([None] * size) - inqueue.not_empty.notify_all() + # drain inqueue, and put sentinels at its head to make workers finish + try: + while True: + inqueue.get(block=False) + except queue.Empty: + pass + for i in range(size): + inqueue.put(None) + + def _wait_for_updates(self, sentinels, change_notifier, timeout): + time.sleep(timeout) diff --git a/Lib/multiprocessing/popen_fork.py b/Lib/multiprocessing/popen_fork.py index d2ebd7cfbe..625981cf47 100644 --- a/Lib/multiprocessing/popen_fork.py +++ b/Lib/multiprocessing/popen_fork.py @@ -1,5 +1,4 @@ import os -import sys import signal from . import util @@ -14,9 +13,9 @@ class Popen(object): method = 'fork' def __init__(self, process_obj): - sys.stdout.flush() - sys.stderr.flush() + util._flush_std_streams() self.returncode = None + self.finalizer = None self._launch(process_obj) def duplicate_for_child(self, fd): @@ -24,21 +23,14 @@ class Popen(object): def poll(self, flag=os.WNOHANG): if self.returncode is None: - while True: - try: - pid, sts = os.waitpid(self.pid, flag) - except OSError as e: - # Child process not yet created. See #1731717 - # e.errno == errno.ECHILD == 10 - return None - else: - break + try: + pid, sts = os.waitpid(self.pid, flag) + except OSError: + # Child process not yet created. See #1731717 + # e.errno == errno.ECHILD == 10 + return None if pid == self.pid: - if os.WIFSIGNALED(sts): - self.returncode = -os.WTERMSIG(sts) - else: - assert os.WIFEXITED(sts) - self.returncode = os.WEXITSTATUS(sts) + self.returncode = os.waitstatus_to_exitcode(sts) return self.returncode def wait(self, timeout=None): @@ -51,30 +43,41 @@ class Popen(object): return self.poll(os.WNOHANG if timeout == 0.0 else 0) return self.returncode - def terminate(self): + def _send_signal(self, sig): if self.returncode is None: try: - os.kill(self.pid, signal.SIGTERM) + os.kill(self.pid, sig) except ProcessLookupError: pass except OSError: if self.wait(timeout=0.1) is None: raise + def terminate(self): + self._send_signal(signal.SIGTERM) + + def kill(self): + self._send_signal(signal.SIGKILL) + def _launch(self, process_obj): code = 1 parent_r, child_w = os.pipe() + child_r, parent_w = os.pipe() self.pid = os.fork() if self.pid == 0: try: os.close(parent_r) - if 'random' in sys.modules: - import random - random.seed() - code = process_obj._bootstrap() + os.close(parent_w) + code = process_obj._bootstrap(parent_sentinel=child_r) finally: os._exit(code) else: os.close(child_w) - util.Finalize(self, os.close, (parent_r,)) + os.close(child_r) + self.finalizer = util.Finalize(self, util.close_fds, + (parent_r, parent_w,)) self.sentinel = parent_r + + def close(self): + if self.finalizer is not None: + self.finalizer() diff --git a/Lib/multiprocessing/popen_forkserver.py b/Lib/multiprocessing/popen_forkserver.py index 222db2d90a..a56eb9bf11 100644 --- a/Lib/multiprocessing/popen_forkserver.py +++ b/Lib/multiprocessing/popen_forkserver.py @@ -49,10 +49,14 @@ class Popen(popen_fork.Popen): set_spawning_popen(None) self.sentinel, w = forkserver.connect_to_new_process(self._fds) - util.Finalize(self, os.close, (self.sentinel,)) + # Keep a duplicate of the data pipe's write end as a sentinel of the + # parent process used by the child process. + _parent_w = os.dup(w) + self.finalizer = util.Finalize(self, util.close_fds, + (_parent_w, self.sentinel)) with open(w, 'wb', closefd=True) as f: f.write(buf.getbuffer()) - self.pid = forkserver.read_unsigned(self.sentinel) + self.pid = forkserver.read_signed(self.sentinel) def poll(self, flag=os.WNOHANG): if self.returncode is None: @@ -61,8 +65,10 @@ class Popen(popen_fork.Popen): if not wait([self.sentinel], timeout): return None try: - self.returncode = forkserver.read_unsigned(self.sentinel) + self.returncode = forkserver.read_signed(self.sentinel) except (OSError, EOFError): - # The process ended abnormally perhaps because of a signal + # This should not happen usually, but perhaps the forkserver + # process itself got killed self.returncode = 255 + return self.returncode diff --git a/Lib/multiprocessing/popen_spawn_posix.py b/Lib/multiprocessing/popen_spawn_posix.py index 98f8f0ab33..24b8634523 100644 --- a/Lib/multiprocessing/popen_spawn_posix.py +++ b/Lib/multiprocessing/popen_spawn_posix.py @@ -36,8 +36,8 @@ class Popen(popen_fork.Popen): return fd def _launch(self, process_obj): - from . import semaphore_tracker - tracker_fd = semaphore_tracker.getfd() + from . import resource_tracker + tracker_fd = resource_tracker.getfd() self._fds.append(tracker_fd) prep_data = spawn.get_preparation_data(process_obj._name) fp = io.BytesIO() @@ -61,8 +61,12 @@ class Popen(popen_fork.Popen): with open(parent_w, 'wb', closefd=False) as f: f.write(fp.getbuffer()) finally: - if parent_r is not None: - util.Finalize(self, os.close, (parent_r,)) - for fd in (child_r, child_w, parent_w): + fds_to_close = [] + for fd in (parent_r, parent_w): + if fd is not None: + fds_to_close.append(fd) + self.finalizer = util.Finalize(self, util.close_fds, fds_to_close) + + for fd in (child_r, child_w): if fd is not None: os.close(fd) diff --git a/Lib/multiprocessing/popen_spawn_win32.py b/Lib/multiprocessing/popen_spawn_win32.py index 6fd588f542..9c4098d0fa 100644 --- a/Lib/multiprocessing/popen_spawn_win32.py +++ b/Lib/multiprocessing/popen_spawn_win32.py @@ -18,6 +18,18 @@ TERMINATE = 0x10000 WINEXE = (sys.platform == 'win32' and getattr(sys, 'frozen', False)) WINSERVICE = sys.executable.lower().endswith("pythonservice.exe") + +def _path_eq(p1, p2): + return p1 == p2 or os.path.normcase(p1) == os.path.normcase(p2) + +WINENV = not _path_eq(sys.executable, sys._base_executable) + + +def _close_handles(*handles): + for handle in handles: + _winapi.CloseHandle(handle) + + # # We define a Popen class similar to the one from subprocess, but # whose constructor takes a process object as its argument. @@ -32,20 +44,35 @@ class Popen(object): def __init__(self, process_obj): prep_data = spawn.get_preparation_data(process_obj._name) - # read end of pipe will be "stolen" by the child process + # read end of pipe will be duplicated by the child process # -- see spawn_main() in spawn.py. + # + # bpo-33929: Previously, the read end of pipe was "stolen" by the child + # process, but it leaked a handle if the child process had been + # terminated before it could steal the handle from the parent process. rhandle, whandle = _winapi.CreatePipe(None, 0) wfd = msvcrt.open_osfhandle(whandle, 0) cmd = spawn.get_command_line(parent_pid=os.getpid(), pipe_handle=rhandle) cmd = ' '.join('"%s"' % x for x in cmd) + python_exe = spawn.get_executable() + + # bpo-35797: When running in a venv, we bypass the redirect + # executor and launch our base Python. + if WINENV and _path_eq(python_exe, sys.executable): + python_exe = sys._base_executable + env = os.environ.copy() + env["__PYVENV_LAUNCHER__"] = sys.executable + else: + env = None + with open(wfd, 'wb', closefd=True) as to_child: # start process try: hp, ht, pid, tid = _winapi.CreateProcess( - spawn.get_executable(), cmd, - None, None, False, 0, None, None, None) + python_exe, cmd, + None, None, False, 0, env, None, None) _winapi.CloseHandle(ht) except: _winapi.CloseHandle(rhandle) @@ -56,7 +83,8 @@ class Popen(object): self.returncode = None self._handle = hp self.sentinel = int(hp) - util.Finalize(self, _winapi.CloseHandle, (self.sentinel,)) + self.finalizer = util.Finalize(self, _close_handles, + (self.sentinel, int(rhandle))) # send information to child set_spawning_popen(self) @@ -96,3 +124,8 @@ class Popen(object): except OSError: if self.wait(timeout=1.0) is None: raise + + kill = terminate + + def close(self): + self.finalizer() diff --git a/Lib/multiprocessing/process.py b/Lib/multiprocessing/process.py index bca8b7a004..0b2e0b45b2 100644 --- a/Lib/multiprocessing/process.py +++ b/Lib/multiprocessing/process.py @@ -7,7 +7,8 @@ # Licensed to PSF under a Contributor Agreement. # -__all__ = ['BaseProcess', 'current_process', 'active_children'] +__all__ = ['BaseProcess', 'current_process', 'active_children', + 'parent_process'] # # Imports @@ -17,6 +18,7 @@ import os import sys import signal import itertools +import threading from _weakrefset import WeakSet # @@ -45,6 +47,13 @@ def active_children(): _cleanup() return list(_children) + +def parent_process(): + ''' + Return process object representing the parent process + ''' + return _parent_process + # # # @@ -75,7 +84,9 @@ class BaseProcess(object): self._identity = _current_process._identity + (count,) self._config = _current_process._config.copy() self._parent_pid = os.getpid() + self._parent_name = _current_process.name self._popen = None + self._closed = False self._target = target self._args = tuple(args) self._kwargs = dict(kwargs) @@ -85,6 +96,10 @@ class BaseProcess(object): self.daemon = daemon _dangling.add(self) + def _check_closed(self): + if self._closed: + raise ValueError("process object is closed") + def run(self): ''' Method to be run in sub-process; can be overridden in sub-class @@ -96,6 +111,7 @@ class BaseProcess(object): ''' Start child process ''' + self._check_closed() assert self._popen is None, 'cannot start a process twice' assert self._parent_pid == os.getpid(), \ 'can only start a process object created by current process' @@ -104,18 +120,30 @@ class BaseProcess(object): _cleanup() self._popen = self._Popen(self) self._sentinel = self._popen.sentinel + # Avoid a refcycle if the target function holds an indirect + # reference to the process object (see bpo-30775) + del self._target, self._args, self._kwargs _children.add(self) def terminate(self): ''' Terminate process; sends SIGTERM signal or uses TerminateProcess() ''' + self._check_closed() self._popen.terminate() + def kill(self): + ''' + Terminate process; sends SIGKILL signal or uses TerminateProcess() + ''' + self._check_closed() + self._popen.kill() + def join(self, timeout=None): ''' Wait until child process terminates ''' + self._check_closed() assert self._parent_pid == os.getpid(), 'can only join a child process' assert self._popen is not None, 'can only join a started process' res = self._popen.wait(timeout) @@ -126,13 +154,37 @@ class BaseProcess(object): ''' Return whether process is alive ''' + self._check_closed() if self is _current_process: return True assert self._parent_pid == os.getpid(), 'can only test a child process' + if self._popen is None: return False - self._popen.poll() - return self._popen.returncode is None + + returncode = self._popen.poll() + if returncode is None: + return True + else: + _children.discard(self) + return False + + def close(self): + ''' + Close the Process object. + + This method releases resources held by the Process object. It is + an error to call this method if the child process is still running. + ''' + if self._popen is not None: + if self._popen.poll() is None: + raise ValueError("Cannot close a process while it is still running. " + "You should first call join() or terminate().") + self._popen.close() + self._popen = None + del self._sentinel + _children.discard(self) + self._closed = True @property def name(self): @@ -174,6 +226,7 @@ class BaseProcess(object): ''' Return exit code of process or `None` if it has yet to stop ''' + self._check_closed() if self._popen is None: return self._popen return self._popen.poll() @@ -183,6 +236,7 @@ class BaseProcess(object): ''' Return identifier (PID) of process or `None` if it has yet to start ''' + self._check_closed() if self is _current_process: return os.getpid() else: @@ -196,38 +250,46 @@ class BaseProcess(object): Return a file descriptor (Unix) or handle (Windows) suitable for waiting for process termination. ''' + self._check_closed() try: return self._sentinel except AttributeError: - raise ValueError("process not started") + raise ValueError("process not started") from None def __repr__(self): + exitcode = None if self is _current_process: status = 'started' + elif self._closed: + status = 'closed' elif self._parent_pid != os.getpid(): status = 'unknown' elif self._popen is None: status = 'initial' else: - if self._popen.poll() is not None: - status = self.exitcode + exitcode = self._popen.poll() + if exitcode is not None: + status = 'stopped' else: status = 'started' - if type(status) is int: - if status == 0: - status = 'stopped' - else: - status = 'stopped[%s]' % _exitcode_to_name.get(status, status) - - return '<%s(%s, %s%s)>' % (type(self).__name__, self._name, - status, self.daemon and ' daemon' or '') + info = [type(self).__name__, 'name=%r' % self._name] + if self._popen is not None: + info.append('pid=%s' % self._popen.pid) + info.append('parent=%s' % self._parent_pid) + info.append(status) + if exitcode is not None: + exitcode = _exitcode_to_name.get(exitcode, exitcode) + info.append('exitcode=%s' % exitcode) + if self.daemon: + info.append('daemon') + return '<%s>' % ' '.join(info) ## - def _bootstrap(self): + def _bootstrap(self, parent_sentinel=None): from . import util, context - global _current_process, _process_counter, _children + global _current_process, _parent_process, _process_counter, _children try: if self._start_method is not None: @@ -237,6 +299,10 @@ class BaseProcess(object): util._close_stdin() old_process = _current_process _current_process = self + _parent_process = _ParentProcess( + self._parent_name, self._parent_pid, parent_sentinel) + if threading._HAVE_THREAD_NATIVE_ID: + threading.main_thread()._set_native_id() try: util._finalizer_registry.clear() util._run_after_forkers() @@ -251,12 +317,12 @@ class BaseProcess(object): finally: util._exit_function() except SystemExit as e: - if not e.args: - exitcode = 1 - elif isinstance(e.args[0], int): - exitcode = e.args[0] + if e.code is None: + exitcode = 0 + elif isinstance(e.code, int): + exitcode = e.code else: - sys.stderr.write(str(e.args[0]) + '\n') + sys.stderr.write(str(e.code) + '\n') exitcode = 1 except: exitcode = 1 @@ -264,9 +330,9 @@ class BaseProcess(object): sys.stderr.write('Process %s:\n' % self.name) traceback.print_exc() finally: + threading._shutdown() util.info('process exiting with exitcode %d' % exitcode) - sys.stdout.flush() - sys.stderr.flush() + util._flush_std_streams() return exitcode @@ -284,6 +350,40 @@ class AuthenticationString(bytes): ) return AuthenticationString, (bytes(self),) + +# +# Create object representing the parent process +# + +class _ParentProcess(BaseProcess): + + def __init__(self, name, pid, sentinel): + self._identity = () + self._name = name + self._pid = pid + self._parent_pid = None + self._popen = None + self._closed = False + self._sentinel = sentinel + self._config = {} + + def is_alive(self): + from multiprocessing.connection import wait + return not wait([self._sentinel], timeout=0) + + @property + def ident(self): + return self._pid + + def join(self, timeout=None): + ''' + Wait until parent process terminates + ''' + from multiprocessing.connection import wait + wait([self._sentinel], timeout=timeout) + + pid = ident + # # Create object representing the main process # @@ -295,6 +395,7 @@ class _MainProcess(BaseProcess): self._name = 'MainProcess' self._parent_pid = None self._popen = None + self._closed = False self._config = {'authkey': AuthenticationString(os.urandom(32)), 'semprefix': '/mp'} # Note that some versions of FreeBSD only allow named @@ -307,7 +408,11 @@ class _MainProcess(BaseProcess): # Everything in self._config will be inherited by descendant # processes. + def close(self): + pass + +_parent_process = None _current_process = _MainProcess() _process_counter = itertools.count(1) _children = set() @@ -321,7 +426,7 @@ _exitcode_to_name = {} for name, signum in list(signal.__dict__.items()): if name[:3]=='SIG' and '_' not in name: - _exitcode_to_name[-signum] = name + _exitcode_to_name[-signum] = f'-{name}' # For debug and leak testing _dangling = WeakSet() diff --git a/Lib/multiprocessing/queues.py b/Lib/multiprocessing/queues.py index 70e2dca062..f37f114a96 100644 --- a/Lib/multiprocessing/queues.py +++ b/Lib/multiprocessing/queues.py @@ -49,8 +49,7 @@ class Queue(object): self._sem = ctx.BoundedSemaphore(maxsize) # For use by concurrent.futures self._ignore_epipe = False - - self._after_fork() + self._reset() if sys.platform != 'win32': register_after_fork(self, Queue._after_fork) @@ -63,11 +62,17 @@ class Queue(object): def __setstate__(self, state): (self._ignore_epipe, self._maxsize, self._reader, self._writer, self._rlock, self._wlock, self._sem, self._opid) = state - self._after_fork() + self._reset() def _after_fork(self): debug('Queue._after_fork()') - self._notempty = threading.Condition(threading.Lock()) + self._reset(after_fork=True) + + def _reset(self, after_fork=False): + if after_fork: + self._notempty._at_fork_reinit() + else: + self._notempty = threading.Condition(threading.Lock()) self._buffer = collections.deque() self._thread = None self._jointhread = None @@ -79,7 +84,8 @@ class Queue(object): self._poll = self._reader.poll def put(self, obj, block=True, timeout=None): - assert not self._closed + if self._closed: + raise ValueError(f"Queue {self!r} is closed") if not self._sem.acquire(block, timeout): raise Full @@ -90,19 +96,21 @@ class Queue(object): self._notempty.notify() def get(self, block=True, timeout=None): + if self._closed: + raise ValueError(f"Queue {self!r} is closed") if block and timeout is None: with self._rlock: res = self._recv_bytes() self._sem.release() else: if block: - deadline = time.time() + timeout + deadline = time.monotonic() + timeout if not self._rlock.acquire(block, timeout): raise Empty try: if block: - timeout = deadline - time.time() - if timeout < 0 or not self._poll(timeout): + timeout = deadline - time.monotonic() + if not self._poll(timeout): raise Empty elif not self._poll(): raise Empty @@ -131,17 +139,14 @@ class Queue(object): def close(self): self._closed = True - try: - self._reader.close() - finally: - close = self._close - if close: - self._close = None - close() + close = self._close + if close: + self._close = None + close() def join_thread(self): debug('Queue.join_thread()') - assert self._closed + assert self._closed, "Queue {0!r} not closed".format(self) if self._jointhread: self._jointhread() @@ -161,23 +166,18 @@ class Queue(object): self._thread = threading.Thread( target=Queue._feed, args=(self._buffer, self._notempty, self._send_bytes, - self._wlock, self._writer.close, self._ignore_epipe), + self._wlock, self._reader.close, self._writer.close, + self._ignore_epipe, self._on_queue_feeder_error, + self._sem), name='QueueFeederThread' - ) + ) self._thread.daemon = True debug('doing self._thread.start()') self._thread.start() debug('... done self._thread.start()') - # On process exit we will wait for data to be flushed to pipe. - # - # However, if this process created the queue then all - # processes which use the queue will be descendants of this - # process. Therefore waiting for the queue to be flushed - # is pointless once all the child processes have been joined. - created_by_this_process = (self._opid == os.getpid()) - if not self._joincancelled and not created_by_this_process: + if not self._joincancelled: self._jointhread = Finalize( self._thread, Queue._finalize_join, [weakref.ref(self._thread)], @@ -209,7 +209,8 @@ class Queue(object): notempty.notify() @staticmethod - def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe): + def _feed(buffer, notempty, send_bytes, writelock, reader_close, + writer_close, ignore_epipe, onerror, queue_sem): debug('starting thread to feed data to pipe') nacquire = notempty.acquire nrelease = notempty.release @@ -222,8 +223,8 @@ class Queue(object): else: wacquire = None - try: - while 1: + while 1: + try: nacquire() try: if not buffer: @@ -235,7 +236,8 @@ class Queue(object): obj = bpopleft() if obj is sentinel: debug('feeder thread got sentinel -- exiting') - close() + reader_close() + writer_close() return # serialize the data before acquiring the lock @@ -250,21 +252,34 @@ class Queue(object): wrelease() except IndexError: pass - except Exception as e: - if ignore_epipe and getattr(e, 'errno', 0) == errno.EPIPE: - return - # Since this runs in a daemon thread the resources it uses - # may be become unusable while the process is cleaning up. - # We ignore errors which happen after the process has - # started to cleanup. - try: + except Exception as e: + if ignore_epipe and getattr(e, 'errno', 0) == errno.EPIPE: + return + # Since this runs in a daemon thread the resources it uses + # may be become unusable while the process is cleaning up. + # We ignore errors which happen after the process has + # started to cleanup. if is_exiting(): info('error in queue thread: %s', e) + return else: - import traceback - traceback.print_exc() - except Exception: - pass + # Since the object has not been sent in the queue, we need + # to decrease the size of the queue. The error acts as + # if the object had been silently removed from the queue + # and this step is necessary to have a properly working + # queue. + queue_sem.release() + onerror(e, obj) + + @staticmethod + def _on_queue_feeder_error(e, obj): + """ + Private API hook called when feeding data in the background thread + raises an exception. For overriding by concurrent.futures. + """ + import traceback + traceback.print_exc() + _sentinel = object() @@ -291,7 +306,8 @@ class JoinableQueue(Queue): self._cond, self._unfinished_tasks = state[-2:] def put(self, obj, block=True, timeout=None): - assert not self._closed + if self._closed: + raise ValueError(f"Queue {self!r} is closed") if not self._sem.acquire(block, timeout): raise Full @@ -329,6 +345,10 @@ class SimpleQueue(object): else: self._wlock = ctx.Lock() + def close(self): + self._reader.close() + self._writer.close() + def empty(self): return not self._poll() @@ -338,6 +358,7 @@ class SimpleQueue(object): def __setstate__(self, state): (self._reader, self._writer, self._rlock, self._wlock) = state + self._poll = self._reader.poll def get(self): with self._rlock: diff --git a/Lib/multiprocessing/reduction.py b/Lib/multiprocessing/reduction.py index c043c9a0dc..5593f0682f 100644 --- a/Lib/multiprocessing/reduction.py +++ b/Lib/multiprocessing/reduction.py @@ -7,7 +7,7 @@ # Licensed to PSF under a Contributor Agreement. # -from abc import ABCMeta, abstractmethod +from abc import ABCMeta import copyreg import functools import io @@ -68,12 +68,16 @@ if sys.platform == 'win32': __all__ += ['DupHandle', 'duplicate', 'steal_handle'] import _winapi - def duplicate(handle, target_process=None, inheritable=False): + def duplicate(handle, target_process=None, inheritable=False, + *, source_process=None): '''Duplicate a handle. (target_process is a handle not a pid!)''' + current_process = _winapi.GetCurrentProcess() + if source_process is None: + source_process = current_process if target_process is None: - target_process = _winapi.GetCurrentProcess() + target_process = current_process return _winapi.DuplicateHandle( - _winapi.GetCurrentProcess(), handle, target_process, + source_process, handle, target_process, 0, inheritable, _winapi.DUPLICATE_SAME_ACCESS) def steal_handle(source_pid, handle): @@ -150,7 +154,7 @@ else: '''Receive an array of fds over an AF_UNIX socket.''' a = array.array('i') bytes_size = a.itemsize * size - msg, ancdata, flags, addr = sock.recvmsg(1, socket.CMSG_LEN(bytes_size)) + msg, ancdata, flags, addr = sock.recvmsg(1, socket.CMSG_SPACE(bytes_size)) if not msg and not ancdata: raise EOFError try: @@ -165,7 +169,10 @@ else: if len(cmsg_data) % a.itemsize != 0: raise ValueError a.frombytes(cmsg_data) - assert len(a) % 256 == msg[0] + if len(a) % 256 != msg[0]: + raise AssertionError( + "Len is {0:n} but msg[0] is {1!r}".format( + len(a), msg[0])) return list(a) except (ValueError, IndexError): pass diff --git a/Lib/multiprocessing/resource_sharer.py b/Lib/multiprocessing/resource_sharer.py index e44a728fa9..66076509a1 100644 --- a/Lib/multiprocessing/resource_sharer.py +++ b/Lib/multiprocessing/resource_sharer.py @@ -59,11 +59,10 @@ else: class _ResourceSharer(object): - '''Manager for resouces using background thread.''' + '''Manager for resources using background thread.''' def __init__(self): self._key = 0 self._cache = {} - self._old_locks = [] self._lock = threading.Lock() self._listener = None self._address = None @@ -113,10 +112,7 @@ class _ResourceSharer(object): for key, (send, close) in self._cache.items(): close() self._cache.clear() - # If self._lock was locked at the time of the fork, it may be broken - # -- see issue 6721. Replace it without letting it be gc'ed. - self._old_locks.append(self._lock) - self._lock = threading.Lock() + self._lock._at_fork_reinit() if self._listener is not None: self._listener.close() self._listener = None @@ -125,7 +121,7 @@ class _ResourceSharer(object): def _start(self): from .connection import Listener - assert self._listener is None + assert self._listener is None, "Already have Listener" util.debug('starting listener and thread for sending handles') self._listener = Listener(authkey=process.current_process().authkey) self._address = self._listener.address @@ -136,7 +132,7 @@ class _ResourceSharer(object): def _serve(self): if hasattr(signal, 'pthread_sigmask'): - signal.pthread_sigmask(signal.SIG_BLOCK, range(1, signal.NSIG)) + signal.pthread_sigmask(signal.SIG_BLOCK, signal.valid_signals()) while 1: try: with self._listener.accept() as conn: diff --git a/Lib/multiprocessing/resource_tracker.py b/Lib/multiprocessing/resource_tracker.py new file mode 100644 index 0000000000..cc42dbdda0 --- /dev/null +++ b/Lib/multiprocessing/resource_tracker.py @@ -0,0 +1,239 @@ +############################################################################### +# Server process to keep track of unlinked resources (like shared memory +# segments, semaphores etc.) and clean them. +# +# On Unix we run a server process which keeps track of unlinked +# resources. The server ignores SIGINT and SIGTERM and reads from a +# pipe. Every other process of the program has a copy of the writable +# end of the pipe, so we get EOF when all other processes have exited. +# Then the server process unlinks any remaining resource names. +# +# This is important because there may be system limits for such resources: for +# instance, the system only supports a limited number of named semaphores, and +# shared-memory segments live in the RAM. If a python process leaks such a +# resource, this resource will not be removed till the next reboot. Without +# this resource tracker process, "killall python" would probably leave unlinked +# resources. + +import os +import signal +import sys +import threading +import warnings + +from . import spawn +from . import util + +__all__ = ['ensure_running', 'register', 'unregister'] + +_HAVE_SIGMASK = hasattr(signal, 'pthread_sigmask') +_IGNORED_SIGNALS = (signal.SIGINT, signal.SIGTERM) + +_CLEANUP_FUNCS = { + 'noop': lambda: None, +} + +if os.name == 'posix': + import _multiprocessing + import _posixshmem + + # Use sem_unlink() to clean up named semaphores. + # + # sem_unlink() may be missing if the Python build process detected the + # absence of POSIX named semaphores. In that case, no named semaphores were + # ever opened, so no cleanup would be necessary. + if hasattr(_multiprocessing, 'sem_unlink'): + _CLEANUP_FUNCS.update({ + 'semaphore': _multiprocessing.sem_unlink, + }) + _CLEANUP_FUNCS.update({ + 'shared_memory': _posixshmem.shm_unlink, + }) + + +class ResourceTracker(object): + + def __init__(self): + self._lock = threading.Lock() + self._fd = None + self._pid = None + + def _stop(self): + with self._lock: + if self._fd is None: + # not running + return + + # closing the "alive" file descriptor stops main() + os.close(self._fd) + self._fd = None + + os.waitpid(self._pid, 0) + self._pid = None + + def getfd(self): + self.ensure_running() + return self._fd + + def ensure_running(self): + '''Make sure that resource tracker process is running. + + This can be run from any process. Usually a child process will use + the resource created by its parent.''' + with self._lock: + if self._fd is not None: + # resource tracker was launched before, is it still running? + if self._check_alive(): + # => still alive + return + # => dead, launch it again + os.close(self._fd) + + # Clean-up to avoid dangling processes. + try: + # _pid can be None if this process is a child from another + # python process, which has started the resource_tracker. + if self._pid is not None: + os.waitpid(self._pid, 0) + except ChildProcessError: + # The resource_tracker has already been terminated. + pass + self._fd = None + self._pid = None + + warnings.warn('resource_tracker: process died unexpectedly, ' + 'relaunching. Some resources might leak.') + + fds_to_pass = [] + try: + fds_to_pass.append(sys.stderr.fileno()) + except Exception: + pass + cmd = 'from multiprocessing.resource_tracker import main;main(%d)' + r, w = os.pipe() + try: + fds_to_pass.append(r) + # process will out live us, so no need to wait on pid + exe = spawn.get_executable() + args = [exe] + util._args_from_interpreter_flags() + args += ['-c', cmd % r] + # bpo-33613: Register a signal mask that will block the signals. + # This signal mask will be inherited by the child that is going + # to be spawned and will protect the child from a race condition + # that can make the child die before it registers signal handlers + # for SIGINT and SIGTERM. The mask is unregistered after spawning + # the child. + try: + if _HAVE_SIGMASK: + signal.pthread_sigmask(signal.SIG_BLOCK, _IGNORED_SIGNALS) + pid = util.spawnv_passfds(exe, args, fds_to_pass) + finally: + if _HAVE_SIGMASK: + signal.pthread_sigmask(signal.SIG_UNBLOCK, _IGNORED_SIGNALS) + except: + os.close(w) + raise + else: + self._fd = w + self._pid = pid + finally: + os.close(r) + + def _check_alive(self): + '''Check that the pipe has not been closed by sending a probe.''' + try: + # We cannot use send here as it calls ensure_running, creating + # a cycle. + os.write(self._fd, b'PROBE:0:noop\n') + except OSError: + return False + else: + return True + + def register(self, name, rtype): + '''Register name of resource with resource tracker.''' + self._send('REGISTER', name, rtype) + + def unregister(self, name, rtype): + '''Unregister name of resource with resource tracker.''' + self._send('UNREGISTER', name, rtype) + + def _send(self, cmd, name, rtype): + self.ensure_running() + msg = '{0}:{1}:{2}\n'.format(cmd, name, rtype).encode('ascii') + if len(name) > 512: + # posix guarantees that writes to a pipe of less than PIPE_BUF + # bytes are atomic, and that PIPE_BUF >= 512 + raise ValueError('name too long') + nbytes = os.write(self._fd, msg) + assert nbytes == len(msg), "nbytes {0:n} but len(msg) {1:n}".format( + nbytes, len(msg)) + + +_resource_tracker = ResourceTracker() +ensure_running = _resource_tracker.ensure_running +register = _resource_tracker.register +unregister = _resource_tracker.unregister +getfd = _resource_tracker.getfd + +def main(fd): + '''Run resource tracker.''' + # protect the process from ^C and "killall python" etc + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + if _HAVE_SIGMASK: + signal.pthread_sigmask(signal.SIG_UNBLOCK, _IGNORED_SIGNALS) + + for f in (sys.stdin, sys.stdout): + try: + f.close() + except Exception: + pass + + cache = {rtype: set() for rtype in _CLEANUP_FUNCS.keys()} + try: + # keep track of registered/unregistered resources + with open(fd, 'rb') as f: + for line in f: + try: + cmd, name, rtype = line.strip().decode('ascii').split(':') + cleanup_func = _CLEANUP_FUNCS.get(rtype, None) + if cleanup_func is None: + raise ValueError( + f'Cannot register {name} for automatic cleanup: ' + f'unknown resource type {rtype}') + + if cmd == 'REGISTER': + cache[rtype].add(name) + elif cmd == 'UNREGISTER': + cache[rtype].remove(name) + elif cmd == 'PROBE': + pass + else: + raise RuntimeError('unrecognized command %r' % cmd) + except Exception: + try: + sys.excepthook(*sys.exc_info()) + except: + pass + finally: + # all processes have terminated; cleanup any remaining resources + for rtype, rtype_cache in cache.items(): + if rtype_cache: + try: + warnings.warn('resource_tracker: There appear to be %d ' + 'leaked %s objects to clean up at shutdown' % + (len(rtype_cache), rtype)) + except Exception: + pass + for name in rtype_cache: + # For some reason the process which created and registered this + # resource has failed to unregister it. Presumably it has + # died. We therefore unlink it. + try: + try: + _CLEANUP_FUNCS[rtype](name) + except Exception as e: + warnings.warn('resource_tracker: %r: %s' % (name, e)) + finally: + pass diff --git a/Lib/multiprocessing/semaphore_tracker.py b/Lib/multiprocessing/semaphore_tracker.py deleted file mode 100644 index de7738eeee..0000000000 --- a/Lib/multiprocessing/semaphore_tracker.py +++ /dev/null @@ -1,143 +0,0 @@ -# -# On Unix we run a server process which keeps track of unlinked -# semaphores. The server ignores SIGINT and SIGTERM and reads from a -# pipe. Every other process of the program has a copy of the writable -# end of the pipe, so we get EOF when all other processes have exited. -# Then the server process unlinks any remaining semaphore names. -# -# This is important because the system only supports a limited number -# of named semaphores, and they will not be automatically removed till -# the next reboot. Without this semaphore tracker process, "killall -# python" would probably leave unlinked semaphores. -# - -import os -import signal -import sys -import threading -import warnings -import _multiprocessing - -from . import spawn -from . import util - -__all__ = ['ensure_running', 'register', 'unregister'] - - -class SemaphoreTracker(object): - - def __init__(self): - self._lock = threading.Lock() - self._fd = None - - def getfd(self): - self.ensure_running() - return self._fd - - def ensure_running(self): - '''Make sure that semaphore tracker process is running. - - This can be run from any process. Usually a child process will use - the semaphore created by its parent.''' - with self._lock: - if self._fd is not None: - return - fds_to_pass = [] - try: - fds_to_pass.append(sys.stderr.fileno()) - except Exception: - pass - cmd = 'from multiprocessing.semaphore_tracker import main;main(%d)' - r, w = os.pipe() - try: - fds_to_pass.append(r) - # process will out live us, so no need to wait on pid - exe = spawn.get_executable() - args = [exe] + util._args_from_interpreter_flags() - args += ['-c', cmd % r] - util.spawnv_passfds(exe, args, fds_to_pass) - except: - os.close(w) - raise - else: - self._fd = w - finally: - os.close(r) - - def register(self, name): - '''Register name of semaphore with semaphore tracker.''' - self._send('REGISTER', name) - - def unregister(self, name): - '''Unregister name of semaphore with semaphore tracker.''' - self._send('UNREGISTER', name) - - def _send(self, cmd, name): - self.ensure_running() - msg = '{0}:{1}\n'.format(cmd, name).encode('ascii') - if len(name) > 512: - # posix guarantees that writes to a pipe of less than PIPE_BUF - # bytes are atomic, and that PIPE_BUF >= 512 - raise ValueError('name too long') - nbytes = os.write(self._fd, msg) - assert nbytes == len(msg) - - -_semaphore_tracker = SemaphoreTracker() -ensure_running = _semaphore_tracker.ensure_running -register = _semaphore_tracker.register -unregister = _semaphore_tracker.unregister -getfd = _semaphore_tracker.getfd - - -def main(fd): - '''Run semaphore tracker.''' - # protect the process from ^C and "killall python" etc - signal.signal(signal.SIGINT, signal.SIG_IGN) - signal.signal(signal.SIGTERM, signal.SIG_IGN) - - for f in (sys.stdin, sys.stdout): - try: - f.close() - except Exception: - pass - - cache = set() - try: - # keep track of registered/unregistered semaphores - with open(fd, 'rb') as f: - for line in f: - try: - cmd, name = line.strip().split(b':') - if cmd == b'REGISTER': - cache.add(name) - elif cmd == b'UNREGISTER': - cache.remove(name) - else: - raise RuntimeError('unrecognized command %r' % cmd) - except Exception: - try: - sys.excepthook(*sys.exc_info()) - except: - pass - finally: - # all processes have terminated; cleanup any remaining semaphores - if cache: - try: - warnings.warn('semaphore_tracker: There appear to be %d ' - 'leaked semaphores to clean up at shutdown' % - len(cache)) - except Exception: - pass - for name in cache: - # For some reason the process which created and registered this - # semaphore has failed to unregister it. Presumably it has died. - # We therefore unlink it. - try: - name = name.decode('ascii') - try: - _multiprocessing.sem_unlink(name) - except Exception as e: - warnings.warn('semaphore_tracker: %r: %s' % (name, e)) - finally: - pass diff --git a/Lib/multiprocessing/shared_memory.py b/Lib/multiprocessing/shared_memory.py new file mode 100644 index 0000000000..122b3fcebf --- /dev/null +++ b/Lib/multiprocessing/shared_memory.py @@ -0,0 +1,532 @@ +"""Provides shared memory for direct access across processes. + +The API of this package is currently provisional. Refer to the +documentation for details. +""" + + +__all__ = [ 'SharedMemory', 'ShareableList' ] + + +from functools import partial +import mmap +import os +import errno +import struct +import secrets +import types + +if os.name == "nt": + import _winapi + _USE_POSIX = False +else: + import _posixshmem + _USE_POSIX = True + + +_O_CREX = os.O_CREAT | os.O_EXCL + +# FreeBSD (and perhaps other BSDs) limit names to 14 characters. +_SHM_SAFE_NAME_LENGTH = 14 + +# Shared memory block name prefix +if _USE_POSIX: + _SHM_NAME_PREFIX = '/psm_' +else: + _SHM_NAME_PREFIX = 'wnsm_' + + +def _make_filename(): + "Create a random filename for the shared memory object." + # number of random bytes to use for name + nbytes = (_SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)) // 2 + assert nbytes >= 2, '_SHM_NAME_PREFIX too long' + name = _SHM_NAME_PREFIX + secrets.token_hex(nbytes) + assert len(name) <= _SHM_SAFE_NAME_LENGTH + return name + + +class SharedMemory: + """Creates a new shared memory block or attaches to an existing + shared memory block. + + Every shared memory block is assigned a unique name. This enables + one process to create a shared memory block with a particular name + so that a different process can attach to that same shared memory + block using that same name. + + As a resource for sharing data across processes, shared memory blocks + may outlive the original process that created them. When one process + no longer needs access to a shared memory block that might still be + needed by other processes, the close() method should be called. + When a shared memory block is no longer needed by any process, the + unlink() method should be called to ensure proper cleanup.""" + + # Defaults; enables close() and unlink() to run without errors. + _name = None + _fd = -1 + _mmap = None + _buf = None + _flags = os.O_RDWR + _mode = 0o600 + _prepend_leading_slash = True if _USE_POSIX else False + + def __init__(self, name=None, create=False, size=0): + if not size >= 0: + raise ValueError("'size' must be a positive integer") + if create: + self._flags = _O_CREX | os.O_RDWR + if size == 0: + raise ValueError("'size' must be a positive number different from zero") + if name is None and not self._flags & os.O_EXCL: + raise ValueError("'name' can only be None if create=True") + + if _USE_POSIX: + + # POSIX Shared Memory + + if name is None: + while True: + name = _make_filename() + try: + self._fd = _posixshmem.shm_open( + name, + self._flags, + mode=self._mode + ) + except FileExistsError: + continue + self._name = name + break + else: + name = "/" + name if self._prepend_leading_slash else name + self._fd = _posixshmem.shm_open( + name, + self._flags, + mode=self._mode + ) + self._name = name + try: + if create and size: + os.ftruncate(self._fd, size) + stats = os.fstat(self._fd) + size = stats.st_size + self._mmap = mmap.mmap(self._fd, size) + except OSError: + self.unlink() + raise + + from .resource_tracker import register + register(self._name, "shared_memory") + + else: + + # Windows Named Shared Memory + + if create: + while True: + temp_name = _make_filename() if name is None else name + # Create and reserve shared memory block with this name + # until it can be attached to by mmap. + h_map = _winapi.CreateFileMapping( + _winapi.INVALID_HANDLE_VALUE, + _winapi.NULL, + _winapi.PAGE_READWRITE, + (size >> 32) & 0xFFFFFFFF, + size & 0xFFFFFFFF, + temp_name + ) + try: + last_error_code = _winapi.GetLastError() + if last_error_code == _winapi.ERROR_ALREADY_EXISTS: + if name is not None: + raise FileExistsError( + errno.EEXIST, + os.strerror(errno.EEXIST), + name, + _winapi.ERROR_ALREADY_EXISTS + ) + else: + continue + self._mmap = mmap.mmap(-1, size, tagname=temp_name) + finally: + _winapi.CloseHandle(h_map) + self._name = temp_name + break + + else: + self._name = name + # Dynamically determine the existing named shared memory + # block's size which is likely a multiple of mmap.PAGESIZE. + h_map = _winapi.OpenFileMapping( + _winapi.FILE_MAP_READ, + False, + name + ) + try: + p_buf = _winapi.MapViewOfFile( + h_map, + _winapi.FILE_MAP_READ, + 0, + 0, + 0 + ) + finally: + _winapi.CloseHandle(h_map) + size = _winapi.VirtualQuerySize(p_buf) + self._mmap = mmap.mmap(-1, size, tagname=name) + + self._size = size + self._buf = memoryview(self._mmap) + + def __del__(self): + try: + self.close() + except OSError: + pass + + def __reduce__(self): + return ( + self.__class__, + ( + self.name, + False, + self.size, + ), + ) + + def __repr__(self): + return f'{self.__class__.__name__}({self.name!r}, size={self.size})' + + @property + def buf(self): + "A memoryview of contents of the shared memory block." + return self._buf + + @property + def name(self): + "Unique name that identifies the shared memory block." + reported_name = self._name + if _USE_POSIX and self._prepend_leading_slash: + if self._name.startswith("/"): + reported_name = self._name[1:] + return reported_name + + @property + def size(self): + "Size in bytes." + return self._size + + def close(self): + """Closes access to the shared memory from this instance but does + not destroy the shared memory block.""" + if self._buf is not None: + self._buf.release() + self._buf = None + if self._mmap is not None: + self._mmap.close() + self._mmap = None + if _USE_POSIX and self._fd >= 0: + os.close(self._fd) + self._fd = -1 + + def unlink(self): + """Requests that the underlying shared memory block be destroyed. + + In order to ensure proper cleanup of resources, unlink should be + called once (and only once) across all processes which have access + to the shared memory block.""" + if _USE_POSIX and self._name: + from .resource_tracker import unregister + _posixshmem.shm_unlink(self._name) + unregister(self._name, "shared_memory") + + +_encoding = "utf8" + +class ShareableList: + """Pattern for a mutable list-like object shareable via a shared + memory block. It differs from the built-in list type in that these + lists can not change their overall length (i.e. no append, insert, + etc.) + + Because values are packed into a memoryview as bytes, the struct + packing format for any storable value must require no more than 8 + characters to describe its format.""" + + # The shared memory area is organized as follows: + # - 8 bytes: number of items (N) as a 64-bit integer + # - (N + 1) * 8 bytes: offsets of each element from the start of the + # data area + # - K bytes: the data area storing item values (with encoding and size + # depending on their respective types) + # - N * 8 bytes: `struct` format string for each element + # - N bytes: index into _back_transforms_mapping for each element + # (for reconstructing the corresponding Python value) + _types_mapping = { + int: "q", + float: "d", + bool: "xxxxxxx?", + str: "%ds", + bytes: "%ds", + None.__class__: "xxxxxx?x", + } + _alignment = 8 + _back_transforms_mapping = { + 0: lambda value: value, # int, float, bool + 1: lambda value: value.rstrip(b'\x00').decode(_encoding), # str + 2: lambda value: value.rstrip(b'\x00'), # bytes + 3: lambda _value: None, # None + } + + @staticmethod + def _extract_recreation_code(value): + """Used in concert with _back_transforms_mapping to convert values + into the appropriate Python objects when retrieving them from + the list as well as when storing them.""" + if not isinstance(value, (str, bytes, None.__class__)): + return 0 + elif isinstance(value, str): + return 1 + elif isinstance(value, bytes): + return 2 + else: + return 3 # NoneType + + def __init__(self, sequence=None, *, name=None): + if name is None or sequence is not None: + sequence = sequence or () + _formats = [ + self._types_mapping[type(item)] + if not isinstance(item, (str, bytes)) + else self._types_mapping[type(item)] % ( + self._alignment * (len(item) // self._alignment + 1), + ) + for item in sequence + ] + self._list_len = len(_formats) + assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len + offset = 0 + # The offsets of each list element into the shared memory's + # data area (0 meaning the start of the data area, not the start + # of the shared memory area). + self._allocated_offsets = [0] + for fmt in _formats: + offset += self._alignment if fmt[-1] != "s" else int(fmt[:-1]) + self._allocated_offsets.append(offset) + _recreation_codes = [ + self._extract_recreation_code(item) for item in sequence + ] + requested_size = struct.calcsize( + "q" + self._format_size_metainfo + + "".join(_formats) + + self._format_packing_metainfo + + self._format_back_transform_codes + ) + + self.shm = SharedMemory(name, create=True, size=requested_size) + else: + self.shm = SharedMemory(name) + + if sequence is not None: + _enc = _encoding + struct.pack_into( + "q" + self._format_size_metainfo, + self.shm.buf, + 0, + self._list_len, + *(self._allocated_offsets) + ) + struct.pack_into( + "".join(_formats), + self.shm.buf, + self._offset_data_start, + *(v.encode(_enc) if isinstance(v, str) else v for v in sequence) + ) + struct.pack_into( + self._format_packing_metainfo, + self.shm.buf, + self._offset_packing_formats, + *(v.encode(_enc) for v in _formats) + ) + struct.pack_into( + self._format_back_transform_codes, + self.shm.buf, + self._offset_back_transform_codes, + *(_recreation_codes) + ) + + else: + self._list_len = len(self) # Obtains size from offset 0 in buffer. + self._allocated_offsets = list( + struct.unpack_from( + self._format_size_metainfo, + self.shm.buf, + 1 * 8 + ) + ) + + def _get_packing_format(self, position): + "Gets the packing format for a single value stored in the list." + position = position if position >= 0 else position + self._list_len + if (position >= self._list_len) or (self._list_len < 0): + raise IndexError("Requested position out of range.") + + v = struct.unpack_from( + "8s", + self.shm.buf, + self._offset_packing_formats + position * 8 + )[0] + fmt = v.rstrip(b'\x00') + fmt_as_str = fmt.decode(_encoding) + + return fmt_as_str + + def _get_back_transform(self, position): + "Gets the back transformation function for a single value." + + if (position >= self._list_len) or (self._list_len < 0): + raise IndexError("Requested position out of range.") + + transform_code = struct.unpack_from( + "b", + self.shm.buf, + self._offset_back_transform_codes + position + )[0] + transform_function = self._back_transforms_mapping[transform_code] + + return transform_function + + def _set_packing_format_and_transform(self, position, fmt_as_str, value): + """Sets the packing format and back transformation code for a + single value in the list at the specified position.""" + + if (position >= self._list_len) or (self._list_len < 0): + raise IndexError("Requested position out of range.") + + struct.pack_into( + "8s", + self.shm.buf, + self._offset_packing_formats + position * 8, + fmt_as_str.encode(_encoding) + ) + + transform_code = self._extract_recreation_code(value) + struct.pack_into( + "b", + self.shm.buf, + self._offset_back_transform_codes + position, + transform_code + ) + + def __getitem__(self, position): + position = position if position >= 0 else position + self._list_len + try: + offset = self._offset_data_start + self._allocated_offsets[position] + (v,) = struct.unpack_from( + self._get_packing_format(position), + self.shm.buf, + offset + ) + except IndexError: + raise IndexError("index out of range") + + back_transform = self._get_back_transform(position) + v = back_transform(v) + + return v + + def __setitem__(self, position, value): + position = position if position >= 0 else position + self._list_len + try: + item_offset = self._allocated_offsets[position] + offset = self._offset_data_start + item_offset + current_format = self._get_packing_format(position) + except IndexError: + raise IndexError("assignment index out of range") + + if not isinstance(value, (str, bytes)): + new_format = self._types_mapping[type(value)] + encoded_value = value + else: + allocated_length = self._allocated_offsets[position + 1] - item_offset + + encoded_value = (value.encode(_encoding) + if isinstance(value, str) else value) + if len(encoded_value) > allocated_length: + raise ValueError("bytes/str item exceeds available storage") + if current_format[-1] == "s": + new_format = current_format + else: + new_format = self._types_mapping[str] % ( + allocated_length, + ) + + self._set_packing_format_and_transform( + position, + new_format, + value + ) + struct.pack_into(new_format, self.shm.buf, offset, encoded_value) + + def __reduce__(self): + return partial(self.__class__, name=self.shm.name), () + + def __len__(self): + return struct.unpack_from("q", self.shm.buf, 0)[0] + + def __repr__(self): + return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})' + + @property + def format(self): + "The struct packing format used by all currently stored items." + return "".join( + self._get_packing_format(i) for i in range(self._list_len) + ) + + @property + def _format_size_metainfo(self): + "The struct packing format used for the items' storage offsets." + return "q" * (self._list_len + 1) + + @property + def _format_packing_metainfo(self): + "The struct packing format used for the items' packing formats." + return "8s" * self._list_len + + @property + def _format_back_transform_codes(self): + "The struct packing format used for the items' back transforms." + return "b" * self._list_len + + @property + def _offset_data_start(self): + # - 8 bytes for the list length + # - (N + 1) * 8 bytes for the element offsets + return (self._list_len + 2) * 8 + + @property + def _offset_packing_formats(self): + return self._offset_data_start + self._allocated_offsets[-1] + + @property + def _offset_back_transform_codes(self): + return self._offset_packing_formats + self._list_len * 8 + + def count(self, value): + "L.count(value) -> integer -- return number of occurrences of value." + + return sum(value == entry for entry in self) + + def index(self, value): + """L.index(value) -> integer -- return first index of value. + Raises ValueError if the value is not present.""" + + for position, entry in enumerate(self): + if value == entry: + return position + else: + raise ValueError(f"{value!r} not in this container") + + __class_getitem__ = classmethod(types.GenericAlias) diff --git a/Lib/multiprocessing/sharedctypes.py b/Lib/multiprocessing/sharedctypes.py index 25cbcf2ae4..6071707027 100644 --- a/Lib/multiprocessing/sharedctypes.py +++ b/Lib/multiprocessing/sharedctypes.py @@ -23,12 +23,13 @@ __all__ = ['RawValue', 'RawArray', 'Value', 'Array', 'copy', 'synchronized'] # typecode_to_type = { - 'c': ctypes.c_char, 'u': ctypes.c_wchar, - 'b': ctypes.c_byte, 'B': ctypes.c_ubyte, - 'h': ctypes.c_short, 'H': ctypes.c_ushort, - 'i': ctypes.c_int, 'I': ctypes.c_uint, - 'l': ctypes.c_long, 'L': ctypes.c_ulong, - 'f': ctypes.c_float, 'd': ctypes.c_double + 'c': ctypes.c_char, 'u': ctypes.c_wchar, + 'b': ctypes.c_byte, 'B': ctypes.c_ubyte, + 'h': ctypes.c_short, 'H': ctypes.c_ushort, + 'i': ctypes.c_int, 'I': ctypes.c_uint, + 'l': ctypes.c_long, 'L': ctypes.c_ulong, + 'q': ctypes.c_longlong, 'Q': ctypes.c_ulonglong, + 'f': ctypes.c_float, 'd': ctypes.c_double } # @@ -77,7 +78,7 @@ def Value(typecode_or_type, *args, lock=True, ctx=None): ctx = ctx or get_context() lock = ctx.RLock() if not hasattr(lock, 'acquire'): - raise AttributeError("'%r' has no method 'acquire'" % lock) + raise AttributeError("%r has no method 'acquire'" % lock) return synchronized(obj, lock, ctx=ctx) def Array(typecode_or_type, size_or_initializer, *, lock=True, ctx=None): @@ -91,7 +92,7 @@ def Array(typecode_or_type, size_or_initializer, *, lock=True, ctx=None): ctx = ctx or get_context() lock = ctx.RLock() if not hasattr(lock, 'acquire'): - raise AttributeError("'%r' has no method 'acquire'" % lock) + raise AttributeError("%r has no method 'acquire'" % lock) return synchronized(obj, lock, ctx=ctx) def copy(obj): @@ -115,7 +116,7 @@ def synchronized(obj, lock=None, ctx=None): scls = class_cache[cls] except KeyError: names = [field[0] for field in cls._fields_] - d = dict((name, make_property(name)) for name in names) + d = {name: make_property(name) for name in names} classname = 'Synchronized' + cls.__name__ scls = class_cache[cls] = type(classname, (SynchronizedBase,), d) return scls(obj, lock, ctx) diff --git a/Lib/multiprocessing/spawn.py b/Lib/multiprocessing/spawn.py index dfb9f65270..7cc129e261 100644 --- a/Lib/multiprocessing/spawn.py +++ b/Lib/multiprocessing/spawn.py @@ -30,7 +30,7 @@ if sys.platform != 'win32': WINEXE = False WINSERVICE = False else: - WINEXE = (sys.platform == 'win32' and getattr(sys, 'frozen', False)) + WINEXE = getattr(sys, 'frozen', False) WINSERVICE = sys.executable.lower().endswith("pythonservice.exe") if WINSERVICE: @@ -93,20 +93,31 @@ def spawn_main(pipe_handle, parent_pid=None, tracker_fd=None): ''' Run code specified by data received over pipe ''' - assert is_forking(sys.argv) + assert is_forking(sys.argv), "Not forking" if sys.platform == 'win32': import msvcrt - new_handle = reduction.steal_handle(parent_pid, pipe_handle) + import _winapi + + if parent_pid is not None: + source_process = _winapi.OpenProcess( + _winapi.SYNCHRONIZE | _winapi.PROCESS_DUP_HANDLE, + False, parent_pid) + else: + source_process = None + new_handle = reduction.duplicate(pipe_handle, + source_process=source_process) fd = msvcrt.open_osfhandle(new_handle, os.O_RDONLY) + parent_sentinel = source_process else: - from . import semaphore_tracker - semaphore_tracker._semaphore_tracker._fd = tracker_fd + from . import resource_tracker + resource_tracker._resource_tracker._fd = tracker_fd fd = pipe_handle - exitcode = _main(fd) + parent_sentinel = os.dup(pipe_handle) + exitcode = _main(fd, parent_sentinel) sys.exit(exitcode) -def _main(fd): +def _main(fd, parent_sentinel): with os.fdopen(fd, 'rb', closefd=True) as from_parent: process.current_process()._inheriting = True try: @@ -115,7 +126,7 @@ def _main(fd): self = reduction.pickle.load(from_parent) finally: del process.current_process()._inheriting - return self._bootstrap() + return self._bootstrap(parent_sentinel) def _check_not_importing_main(): @@ -217,7 +228,7 @@ def prepare(data): process.ORIGINAL_DIR = data['orig_dir'] if 'start_method' in data: - set_start_method(data['start_method']) + set_start_method(data['start_method'], force=True) if 'init_main_from_name' in data: _fixup_main_from_name(data['init_main_from_name']) diff --git a/Lib/multiprocessing/synchronize.py b/Lib/multiprocessing/synchronize.py index d4bdf0e8b1..d0be48f1fd 100644 --- a/Lib/multiprocessing/synchronize.py +++ b/Lib/multiprocessing/synchronize.py @@ -15,8 +15,7 @@ import threading import sys import tempfile import _multiprocessing - -from time import time as _time +import time from . import context from . import process @@ -77,16 +76,16 @@ class SemLock(object): # We only get here if we are on Unix with forking # disabled. When the object is garbage collected or the # process shuts down we unlink the semaphore name - from .semaphore_tracker import register - register(self._semlock.name) + from .resource_tracker import register + register(self._semlock.name, "semaphore") util.Finalize(self, SemLock._cleanup, (self._semlock.name,), exitpriority=0) @staticmethod def _cleanup(name): - from .semaphore_tracker import unregister + from .resource_tracker import unregister sem_unlink(name) - unregister(name) + unregister(name, "semaphore") def _make_methods(self): self.acquire = self._semlock.acquire @@ -268,35 +267,21 @@ class Condition(object): for i in range(count): self._lock.acquire() - def notify(self): + def notify(self, n=1): assert self._lock._semlock._is_mine(), 'lock is not owned' - assert not self._wait_semaphore.acquire(False) - - # to take account of timeouts since last notify() we subtract - # woken_count from sleeping_count and rezero woken_count - while self._woken_count.acquire(False): - res = self._sleeping_count.acquire(False) - assert res - - if self._sleeping_count.acquire(False): # try grabbing a sleeper - self._wait_semaphore.release() # wake up one sleeper - self._woken_count.acquire() # wait for the sleeper to wake - - # rezero _wait_semaphore in case a timeout just happened - self._wait_semaphore.acquire(False) - - def notify_all(self): - assert self._lock._semlock._is_mine(), 'lock is not owned' - assert not self._wait_semaphore.acquire(False) + assert not self._wait_semaphore.acquire( + False), ('notify: Should not have been able to acquire ' + + '_wait_semaphore') # to take account of timeouts since last notify*() we subtract # woken_count from sleeping_count and rezero woken_count while self._woken_count.acquire(False): res = self._sleeping_count.acquire(False) - assert res + assert res, ('notify: Bug in sleeping_count.acquire' + + '- res should not be False') sleepers = 0 - while self._sleeping_count.acquire(False): + while sleepers < n and self._sleeping_count.acquire(False): self._wait_semaphore.release() # wake up one sleeper sleepers += 1 @@ -308,18 +293,21 @@ class Condition(object): while self._wait_semaphore.acquire(False): pass + def notify_all(self): + self.notify(n=sys.maxsize) + def wait_for(self, predicate, timeout=None): result = predicate() if result: return result if timeout is not None: - endtime = _time() + timeout + endtime = time.monotonic() + timeout else: endtime = None waittime = None while not result: if endtime is not None: - waittime = endtime - _time() + waittime = endtime - time.monotonic() if waittime <= 0: break self.wait(waittime) diff --git a/Lib/multiprocessing/util.py b/Lib/multiprocessing/util.py index 1a2c0db40b..9e07a4e93e 100644 --- a/Lib/multiprocessing/util.py +++ b/Lib/multiprocessing/util.py @@ -102,10 +102,42 @@ def log_to_stderr(level=None): _log_to_stderr = True return _logger + +# Abstract socket support + +def _platform_supports_abstract_sockets(): + if sys.platform == "linux": + return True + if hasattr(sys, 'getandroidapilevel'): + return True + return False + + +def is_abstract_socket_namespace(address): + if not address: + return False + if isinstance(address, bytes): + return address[0] == 0 + elif isinstance(address, str): + return address[0] == "\0" + raise TypeError(f'address type of {address!r} unrecognized') + + +abstract_sockets_supported = _platform_supports_abstract_sockets() + # # Function returning a temp directory which will be removed on exit # +def _remove_temp_dir(rmtree, tempdir): + rmtree(tempdir) + + current_process = process.current_process() + # current_process() can be None if the finalizer is called + # late during Python finalization + if current_process is not None: + current_process._config['tempdir'] = None + def get_temp_dir(): # get name of a temp directory which will be automatically cleaned up tempdir = process.current_process()._config.get('tempdir') @@ -113,7 +145,10 @@ def get_temp_dir(): import shutil, tempfile tempdir = tempfile.mkdtemp(prefix='pymp-') info('created temp directory %s', tempdir) - Finalize(None, shutil.rmtree, args=[tempdir], exitpriority=-100) + # keep a strong reference to shutil.rmtree(), since the finalizer + # can be called late during Python shutdown + Finalize(None, _remove_temp_dir, args=(shutil.rmtree, tempdir), + exitpriority=-100) process.current_process()._config['tempdir'] = tempdir return tempdir @@ -149,12 +184,15 @@ class Finalize(object): Class which supports object finalization using weakrefs ''' def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None): - assert exitpriority is None or type(exitpriority) is int + if (exitpriority is not None) and not isinstance(exitpriority,int): + raise TypeError( + "Exitpriority ({0!r}) must be None or int, not {1!s}".format( + exitpriority, type(exitpriority))) if obj is not None: self._weakref = weakref.ref(obj, self) - else: - assert exitpriority is not None + elif exitpriority is None: + raise ValueError("Without object, exitpriority cannot be None") self._callback = callback self._args = args @@ -223,7 +261,7 @@ class Finalize(object): if self._kwargs: x += ', kwargs=' + str(self._kwargs) if self._key[0] is not None: - x += ', exitprority=' + str(self._key[0]) + x += ', exitpriority=' + str(self._key[0]) return x + '>' @@ -241,20 +279,28 @@ def _run_finalizers(minpriority=None): return if minpriority is None: - f = lambda p : p[0][0] is not None + f = lambda p : p[0] is not None else: - f = lambda p : p[0][0] is not None and p[0][0] >= minpriority + f = lambda p : p[0] is not None and p[0] >= minpriority - items = [x for x in list(_finalizer_registry.items()) if f(x)] - items.sort(reverse=True) + # Careful: _finalizer_registry may be mutated while this function + # is running (either by a GC run or by another thread). - for key, finalizer in items: - sub_debug('calling %s', finalizer) - try: - finalizer() - except Exception: - import traceback - traceback.print_exc() + # list(_finalizer_registry) should be atomic, while + # list(_finalizer_registry.items()) is not. + keys = [key for key in list(_finalizer_registry) if f(key)] + keys.sort(reverse=True) + + for key in keys: + finalizer = _finalizer_registry.get(key) + # key may have been removed from the registry + if finalizer is not None: + sub_debug('calling %s', finalizer) + try: + finalizer() + except Exception: + import traceback + traceback.print_exc() if minpriority is None: _finalizer_registry.clear() @@ -321,13 +367,13 @@ atexit.register(_exit_function) class ForkAwareThreadLock(object): def __init__(self): - self._reset() - register_after_fork(self, ForkAwareThreadLock._reset) - - def _reset(self): self._lock = threading.Lock() self.acquire = self._lock.acquire self.release = self._lock.release + register_after_fork(self, ForkAwareThreadLock._at_fork_reinit) + + def _at_fork_reinit(self): + self._lock._at_fork_reinit() def __enter__(self): return self._lock.__enter__() @@ -373,26 +419,71 @@ def _close_stdin(): try: fd = os.open(os.devnull, os.O_RDONLY) try: - sys.stdin = open(fd, closefd=False) + sys.stdin = open(fd, encoding="utf-8", closefd=False) except: os.close(fd) raise except (OSError, ValueError): pass +# +# Flush standard streams, if any +# + +def _flush_std_streams(): + try: + sys.stdout.flush() + except (AttributeError, ValueError): + pass + try: + sys.stderr.flush() + except (AttributeError, ValueError): + pass + # # Start a program with only specified fds kept open # def spawnv_passfds(path, args, passfds): import _posixsubprocess - passfds = sorted(passfds) + passfds = tuple(sorted(map(int, passfds))) errpipe_read, errpipe_write = os.pipe() try: return _posixsubprocess.fork_exec( args, [os.fsencode(path)], True, passfds, None, None, -1, -1, -1, -1, -1, -1, errpipe_read, errpipe_write, - False, False, None) + False, False, None, None, None, -1, None) finally: os.close(errpipe_read) os.close(errpipe_write) + + +def close_fds(*fds): + """Close each file descriptor given as an argument""" + for fd in fds: + os.close(fd) + + +def _cleanup_tests(): + """Cleanup multiprocessing resources when multiprocessing tests + completed.""" + + from test import support + + # cleanup multiprocessing + process._cleanup() + + # Stop the ForkServer process if it's running + from multiprocessing import forkserver + forkserver._forkserver._stop() + + # Stop the ResourceTracker process if it's running + from multiprocessing import resource_tracker + resource_tracker._resource_tracker._stop() + + # bpo-37421: Explicitly call _run_finalizers() to remove immediately + # temporary directories created by multiprocessing.util.get_temp_dir(). + _run_finalizers() + support.gc_collect() + + support.reap_children()