diff --git a/Lib/getpass.py b/Lib/getpass.py index be5112115..6970d8adf 100644 --- a/Lib/getpass.py +++ b/Lib/getpass.py @@ -7,7 +7,6 @@ GetPassWarning - This UserWarning is issued when getpass() cannot prevent echoing of the password contents while reading. On Windows, the msvcrt module will be used. -On the Mac EasyDialogs.AskPassword is used, if available. """ @@ -53,7 +52,7 @@ def unix_getpass(prompt='Password: ', stream=None): stack.enter_context(input) if not stream: stream = input - except OSError as e: + except OSError: # If that fails, see if stdin can be controlled. stack.close() try: @@ -96,7 +95,7 @@ def unix_getpass(prompt='Password: ', stream=None): def win_getpass(prompt='Password: ', stream=None): - """Prompt for password with echo off, using Windows getch().""" + """Prompt for password with echo off, using Windows getwch().""" if sys.stdin is not sys.__stdin__: return fallback_getpass(prompt, stream) diff --git a/Lib/glob.py b/Lib/glob.py index 002cd9201..9fc08f45d 100644 --- a/Lib/glob.py +++ b/Lib/glob.py @@ -1,12 +1,16 @@ """Filename globbing utility.""" +import contextlib import os import re import fnmatch +import itertools +import stat +import sys __all__ = ["glob", "iglob", "escape"] -def glob(pathname, *, recursive=False): +def glob(pathname, *, root_dir=None, dir_fd=None, recursive=False): """Return a list of paths matching a pathname pattern. The pattern may contain simple shell-style wildcards a la @@ -17,9 +21,9 @@ def glob(pathname, *, recursive=False): If recursive is true, the pattern '**' will match any files and zero or more directories and subdirectories. """ - return list(iglob(pathname, recursive=recursive)) + return list(iglob(pathname, root_dir=root_dir, dir_fd=dir_fd, recursive=recursive)) -def iglob(pathname, *, recursive=False): +def iglob(pathname, *, root_dir=None, dir_fd=None, recursive=False): """Return an iterator which yields the paths matching a pathname pattern. The pattern may contain simple shell-style wildcards a la @@ -30,35 +34,45 @@ def iglob(pathname, *, recursive=False): If recursive is true, the pattern '**' will match any files and zero or more directories and subdirectories. """ - it = _iglob(pathname, recursive, False) - if recursive and _isrecursive(pathname): - s = next(it) # skip empty string - assert not s + sys.audit("glob.glob", pathname, recursive) + sys.audit("glob.glob/2", pathname, recursive, root_dir, dir_fd) + if root_dir is not None: + root_dir = os.fspath(root_dir) + else: + root_dir = pathname[:0] + it = _iglob(pathname, root_dir, dir_fd, recursive, False) + if not pathname or recursive and _isrecursive(pathname[:2]): + try: + s = next(it) # skip empty string + if s: + it = itertools.chain((s,), it) + except StopIteration: + pass return it -def _iglob(pathname, recursive, dironly): +def _iglob(pathname, root_dir, dir_fd, recursive, dironly): dirname, basename = os.path.split(pathname) if not has_magic(pathname): assert not dironly if basename: - if os.path.lexists(pathname): + if _lexists(_join(root_dir, pathname), dir_fd): yield pathname else: # Patterns ending with a slash should match only directories - if os.path.isdir(dirname): + if _isdir(_join(root_dir, dirname), dir_fd): yield pathname return if not dirname: if recursive and _isrecursive(basename): - yield from _glob2(dirname, basename, dironly) + yield from _glob2(root_dir, basename, dir_fd, dironly) else: - yield from _glob1(dirname, basename, dironly) + yield from _glob1(root_dir, basename, dir_fd, dironly) return # `os.path.split()` returns the argument itself as a dirname if it is a # drive or UNC path. Prevent an infinite recursion if a drive or UNC path # contains magic characters (i.e. r'\\?\C:'). if dirname != pathname and has_magic(dirname): - dirs = _iglob(dirname, recursive, True) + dirs = _iglob(dirname, root_dir, dir_fd, recursive, True) else: dirs = [dirname] if has_magic(basename): @@ -69,76 +83,125 @@ def _iglob(pathname, recursive, dironly): else: glob_in_dir = _glob0 for dirname in dirs: - for name in glob_in_dir(dirname, basename, dironly): + for name in glob_in_dir(_join(root_dir, dirname), basename, dir_fd, dironly): yield os.path.join(dirname, name) # These 2 helper functions non-recursively glob inside a literal directory. # They return a list of basenames. _glob1 accepts a pattern while _glob0 # takes a literal basename (so it only has to check for its existence). -def _glob1(dirname, pattern, dironly): - names = list(_iterdir(dirname, dironly)) +def _glob1(dirname, pattern, dir_fd, dironly): + names = _listdir(dirname, dir_fd, dironly) if not _ishidden(pattern): names = (x for x in names if not _ishidden(x)) return fnmatch.filter(names, pattern) -def _glob0(dirname, basename, dironly): - if not basename: - # `os.path.split()` returns an empty basename for paths ending with a - # directory separator. 'q*x/' should match only directories. - if os.path.isdir(dirname): +def _glob0(dirname, basename, dir_fd, dironly): + if basename: + if _lexists(_join(dirname, basename), dir_fd): return [basename] else: - if os.path.lexists(os.path.join(dirname, basename)): + # `os.path.split()` returns an empty basename for paths ending with a + # directory separator. 'q*x/' should match only directories. + if _isdir(dirname, dir_fd): return [basename] return [] # Following functions are not public but can be used by third-party code. def glob0(dirname, pattern): - return _glob0(dirname, pattern, False) + return _glob0(dirname, pattern, None, False) def glob1(dirname, pattern): - return _glob1(dirname, pattern, False) + return _glob1(dirname, pattern, None, False) # This helper function recursively yields relative pathnames inside a literal # directory. -def _glob2(dirname, pattern, dironly): +def _glob2(dirname, pattern, dir_fd, dironly): assert _isrecursive(pattern) yield pattern[:0] - yield from _rlistdir(dirname, dironly) + yield from _rlistdir(dirname, dir_fd, dironly) # If dironly is false, yields all file names inside a directory. # If dironly is true, yields only directory names. -def _iterdir(dirname, dironly): - if not dirname: - if isinstance(dirname, bytes): - dirname = bytes(os.curdir, 'ASCII') - else: - dirname = os.curdir +def _iterdir(dirname, dir_fd, dironly): try: - with os.scandir(dirname) as it: - for entry in it: - try: - if not dironly or entry.is_dir(): - yield entry.name - except OSError: - pass + fd = None + fsencode = None + if dir_fd is not None: + if dirname: + fd = arg = os.open(dirname, _dir_open_flags, dir_fd=dir_fd) + else: + arg = dir_fd + if isinstance(dirname, bytes): + fsencode = os.fsencode + elif dirname: + arg = dirname + elif isinstance(dirname, bytes): + arg = bytes(os.curdir, 'ASCII') + else: + arg = os.curdir + try: + with os.scandir(arg) as it: + for entry in it: + try: + if not dironly or entry.is_dir(): + if fsencode is not None: + yield fsencode(entry.name) + else: + yield entry.name + except OSError: + pass + finally: + if fd is not None: + os.close(fd) except OSError: return +def _listdir(dirname, dir_fd, dironly): + with contextlib.closing(_iterdir(dirname, dir_fd, dironly)) as it: + return list(it) + # Recursively yields relative pathnames inside a literal directory. -def _rlistdir(dirname, dironly): - names = list(_iterdir(dirname, dironly)) +def _rlistdir(dirname, dir_fd, dironly): + names = _listdir(dirname, dir_fd, dironly) for x in names: if not _ishidden(x): yield x - path = os.path.join(dirname, x) if dirname else x - for y in _rlistdir(path, dironly): - yield os.path.join(x, y) + path = _join(dirname, x) if dirname else x + for y in _rlistdir(path, dir_fd, dironly): + yield _join(x, y) +def _lexists(pathname, dir_fd): + # Same as os.path.lexists(), but with dir_fd + if dir_fd is None: + return os.path.lexists(pathname) + try: + os.lstat(pathname, dir_fd=dir_fd) + except (OSError, ValueError): + return False + else: + return True + +def _isdir(pathname, dir_fd): + # Same as os.path.isdir(), but with dir_fd + if dir_fd is None: + return os.path.isdir(pathname) + try: + st = os.stat(pathname, dir_fd=dir_fd) + except (OSError, ValueError): + return False + else: + return stat.S_ISDIR(st.st_mode) + +def _join(dirname, basename): + # It is common if dirname or basename is empty + if not dirname or not basename: + return dirname or basename + return os.path.join(dirname, basename) + magic_check = re.compile('([*?[])') magic_check_bytes = re.compile(b'([*?[])') @@ -169,3 +232,6 @@ def escape(pathname): else: pathname = magic_check.sub(r'[\1]', pathname) return drive + pathname + + +_dir_open_flags = os.O_RDONLY | getattr(os, 'O_DIRECTORY', 0) diff --git a/Lib/heapq.py b/Lib/heapq.py index 0e3555cf9..fabefd87f 100644 --- a/Lib/heapq.py +++ b/Lib/heapq.py @@ -597,5 +597,5 @@ except ImportError: if __name__ == "__main__": - import doctest - print(doctest.testmod()) + import doctest # pragma: no cover + print(doctest.testmod()) # pragma: no cover diff --git a/Lib/mailbox.py b/Lib/mailbox.py new file mode 100644 index 000000000..70da07ed2 --- /dev/null +++ b/Lib/mailbox.py @@ -0,0 +1,2151 @@ +"""Read/write support for Maildir, mbox, MH, Babyl, and MMDF mailboxes.""" + +# Notes for authors of new mailbox subclasses: +# +# Remember to fsync() changes to disk before closing a modified file +# or returning from a flush() method. See functions _sync_flush() and +# _sync_close(). + +import os +import time +import calendar +import socket +import errno +import copy +import warnings +import email +import email.message +import email.generator +import io +import contextlib +from types import GenericAlias +try: + import fcntl +except ImportError: + fcntl = None + +__all__ = ['Mailbox', 'Maildir', 'mbox', 'MH', 'Babyl', 'MMDF', + 'Message', 'MaildirMessage', 'mboxMessage', 'MHMessage', + 'BabylMessage', 'MMDFMessage', 'Error', 'NoSuchMailboxError', + 'NotEmptyError', 'ExternalClashError', 'FormatError'] + +linesep = os.linesep.encode('ascii') + +class Mailbox: + """A group of messages in a particular place.""" + + def __init__(self, path, factory=None, create=True): + """Initialize a Mailbox instance.""" + self._path = os.path.abspath(os.path.expanduser(path)) + self._factory = factory + + def add(self, message): + """Add message and return assigned key.""" + raise NotImplementedError('Method must be implemented by subclass') + + def remove(self, key): + """Remove the keyed message; raise KeyError if it doesn't exist.""" + raise NotImplementedError('Method must be implemented by subclass') + + def __delitem__(self, key): + self.remove(key) + + def discard(self, key): + """If the keyed message exists, remove it.""" + try: + self.remove(key) + except KeyError: + pass + + def __setitem__(self, key, message): + """Replace the keyed message; raise KeyError if it doesn't exist.""" + raise NotImplementedError('Method must be implemented by subclass') + + def get(self, key, default=None): + """Return the keyed message, or default if it doesn't exist.""" + try: + return self.__getitem__(key) + except KeyError: + return default + + def __getitem__(self, key): + """Return the keyed message; raise KeyError if it doesn't exist.""" + if not self._factory: + return self.get_message(key) + else: + with contextlib.closing(self.get_file(key)) as file: + return self._factory(file) + + def get_message(self, key): + """Return a Message representation or raise a KeyError.""" + raise NotImplementedError('Method must be implemented by subclass') + + def get_string(self, key): + """Return a string representation or raise a KeyError. + + Uses email.message.Message to create a 7bit clean string + representation of the message.""" + return email.message_from_bytes(self.get_bytes(key)).as_string() + + def get_bytes(self, key): + """Return a byte string representation or raise a KeyError.""" + raise NotImplementedError('Method must be implemented by subclass') + + def get_file(self, key): + """Return a file-like representation or raise a KeyError.""" + raise NotImplementedError('Method must be implemented by subclass') + + def iterkeys(self): + """Return an iterator over keys.""" + raise NotImplementedError('Method must be implemented by subclass') + + def keys(self): + """Return a list of keys.""" + return list(self.iterkeys()) + + def itervalues(self): + """Return an iterator over all messages.""" + for key in self.iterkeys(): + try: + value = self[key] + except KeyError: + continue + yield value + + def __iter__(self): + return self.itervalues() + + def values(self): + """Return a list of messages. Memory intensive.""" + return list(self.itervalues()) + + def iteritems(self): + """Return an iterator over (key, message) tuples.""" + for key in self.iterkeys(): + try: + value = self[key] + except KeyError: + continue + yield (key, value) + + def items(self): + """Return a list of (key, message) tuples. Memory intensive.""" + return list(self.iteritems()) + + def __contains__(self, key): + """Return True if the keyed message exists, False otherwise.""" + raise NotImplementedError('Method must be implemented by subclass') + + def __len__(self): + """Return a count of messages in the mailbox.""" + raise NotImplementedError('Method must be implemented by subclass') + + def clear(self): + """Delete all messages.""" + for key in self.keys(): + self.discard(key) + + def pop(self, key, default=None): + """Delete the keyed message and return it, or default.""" + try: + result = self[key] + except KeyError: + return default + self.discard(key) + return result + + def popitem(self): + """Delete an arbitrary (key, message) pair and return it.""" + for key in self.iterkeys(): + return (key, self.pop(key)) # This is only run once. + else: + raise KeyError('No messages in mailbox') + + def update(self, arg=None): + """Change the messages that correspond to certain keys.""" + if hasattr(arg, 'iteritems'): + source = arg.iteritems() + elif hasattr(arg, 'items'): + source = arg.items() + else: + source = arg + bad_key = False + for key, message in source: + try: + self[key] = message + except KeyError: + bad_key = True + if bad_key: + raise KeyError('No message with key(s)') + + def flush(self): + """Write any pending changes to the disk.""" + raise NotImplementedError('Method must be implemented by subclass') + + def lock(self): + """Lock the mailbox.""" + raise NotImplementedError('Method must be implemented by subclass') + + def unlock(self): + """Unlock the mailbox if it is locked.""" + raise NotImplementedError('Method must be implemented by subclass') + + def close(self): + """Flush and close the mailbox.""" + raise NotImplementedError('Method must be implemented by subclass') + + def _string_to_bytes(self, message): + # If a message is not 7bit clean, we refuse to handle it since it + # likely came from reading invalid messages in text mode, and that way + # lies mojibake. + try: + return message.encode('ascii') + except UnicodeError: + raise ValueError("String input must be ASCII-only; " + "use bytes or a Message instead") + + # Whether each message must end in a newline + _append_newline = False + + def _dump_message(self, message, target, mangle_from_=False): + # This assumes the target file is open in binary mode. + """Dump message contents to target file.""" + if isinstance(message, email.message.Message): + buffer = io.BytesIO() + gen = email.generator.BytesGenerator(buffer, mangle_from_, 0) + gen.flatten(message) + buffer.seek(0) + data = buffer.read() + data = data.replace(b'\n', linesep) + target.write(data) + if self._append_newline and not data.endswith(linesep): + # Make sure the message ends with a newline + target.write(linesep) + elif isinstance(message, (str, bytes, io.StringIO)): + if isinstance(message, io.StringIO): + warnings.warn("Use of StringIO input is deprecated, " + "use BytesIO instead", DeprecationWarning, 3) + message = message.getvalue() + if isinstance(message, str): + message = self._string_to_bytes(message) + if mangle_from_: + message = message.replace(b'\nFrom ', b'\n>From ') + message = message.replace(b'\n', linesep) + target.write(message) + if self._append_newline and not message.endswith(linesep): + # Make sure the message ends with a newline + target.write(linesep) + elif hasattr(message, 'read'): + if hasattr(message, 'buffer'): + warnings.warn("Use of text mode files is deprecated, " + "use a binary mode file instead", DeprecationWarning, 3) + message = message.buffer + lastline = None + while True: + line = message.readline() + # Universal newline support. + if line.endswith(b'\r\n'): + line = line[:-2] + b'\n' + elif line.endswith(b'\r'): + line = line[:-1] + b'\n' + if not line: + break + if mangle_from_ and line.startswith(b'From '): + line = b'>From ' + line[5:] + line = line.replace(b'\n', linesep) + target.write(line) + lastline = line + if self._append_newline and lastline and not lastline.endswith(linesep): + # Make sure the message ends with a newline + target.write(linesep) + else: + raise TypeError('Invalid message type: %s' % type(message)) + + __class_getitem__ = classmethod(GenericAlias) + + +class Maildir(Mailbox): + """A qmail-style Maildir mailbox.""" + + colon = ':' + + def __init__(self, dirname, factory=None, create=True): + """Initialize a Maildir instance.""" + Mailbox.__init__(self, dirname, factory, create) + self._paths = { + 'tmp': os.path.join(self._path, 'tmp'), + 'new': os.path.join(self._path, 'new'), + 'cur': os.path.join(self._path, 'cur'), + } + if not os.path.exists(self._path): + if create: + os.mkdir(self._path, 0o700) + for path in self._paths.values(): + os.mkdir(path, 0o700) + else: + raise NoSuchMailboxError(self._path) + self._toc = {} + self._toc_mtimes = {'cur': 0, 'new': 0} + self._last_read = 0 # Records last time we read cur/new + self._skewfactor = 0.1 # Adjust if os/fs clocks are skewing + + def add(self, message): + """Add message and return assigned key.""" + tmp_file = self._create_tmp() + try: + self._dump_message(message, tmp_file) + except BaseException: + tmp_file.close() + os.remove(tmp_file.name) + raise + _sync_close(tmp_file) + if isinstance(message, MaildirMessage): + subdir = message.get_subdir() + suffix = self.colon + message.get_info() + if suffix == self.colon: + suffix = '' + else: + subdir = 'new' + suffix = '' + uniq = os.path.basename(tmp_file.name).split(self.colon)[0] + dest = os.path.join(self._path, subdir, uniq + suffix) + if isinstance(message, MaildirMessage): + os.utime(tmp_file.name, + (os.path.getatime(tmp_file.name), message.get_date())) + # No file modification should be done after the file is moved to its + # final position in order to prevent race conditions with changes + # from other programs + try: + try: + os.link(tmp_file.name, dest) + except (AttributeError, PermissionError): + os.rename(tmp_file.name, dest) + else: + os.remove(tmp_file.name) + except OSError as e: + os.remove(tmp_file.name) + if e.errno == errno.EEXIST: + raise ExternalClashError('Name clash with existing message: %s' + % dest) + else: + raise + return uniq + + def remove(self, key): + """Remove the keyed message; raise KeyError if it doesn't exist.""" + os.remove(os.path.join(self._path, self._lookup(key))) + + def discard(self, key): + """If the keyed message exists, remove it.""" + # This overrides an inapplicable implementation in the superclass. + try: + self.remove(key) + except (KeyError, FileNotFoundError): + pass + + def __setitem__(self, key, message): + """Replace the keyed message; raise KeyError if it doesn't exist.""" + old_subpath = self._lookup(key) + temp_key = self.add(message) + temp_subpath = self._lookup(temp_key) + if isinstance(message, MaildirMessage): + # temp's subdir and suffix were specified by message. + dominant_subpath = temp_subpath + else: + # temp's subdir and suffix were defaults from add(). + dominant_subpath = old_subpath + subdir = os.path.dirname(dominant_subpath) + if self.colon in dominant_subpath: + suffix = self.colon + dominant_subpath.split(self.colon)[-1] + else: + suffix = '' + self.discard(key) + tmp_path = os.path.join(self._path, temp_subpath) + new_path = os.path.join(self._path, subdir, key + suffix) + if isinstance(message, MaildirMessage): + os.utime(tmp_path, + (os.path.getatime(tmp_path), message.get_date())) + # No file modification should be done after the file is moved to its + # final position in order to prevent race conditions with changes + # from other programs + os.rename(tmp_path, new_path) + + def get_message(self, key): + """Return a Message representation or raise a KeyError.""" + subpath = self._lookup(key) + with open(os.path.join(self._path, subpath), 'rb') as f: + if self._factory: + msg = self._factory(f) + else: + msg = MaildirMessage(f) + subdir, name = os.path.split(subpath) + msg.set_subdir(subdir) + if self.colon in name: + msg.set_info(name.split(self.colon)[-1]) + msg.set_date(os.path.getmtime(os.path.join(self._path, subpath))) + return msg + + def get_bytes(self, key): + """Return a bytes representation or raise a KeyError.""" + with open(os.path.join(self._path, self._lookup(key)), 'rb') as f: + return f.read().replace(linesep, b'\n') + + def get_file(self, key): + """Return a file-like representation or raise a KeyError.""" + f = open(os.path.join(self._path, self._lookup(key)), 'rb') + return _ProxyFile(f) + + def iterkeys(self): + """Return an iterator over keys.""" + self._refresh() + for key in self._toc: + try: + self._lookup(key) + except KeyError: + continue + yield key + + def __contains__(self, key): + """Return True if the keyed message exists, False otherwise.""" + self._refresh() + return key in self._toc + + def __len__(self): + """Return a count of messages in the mailbox.""" + self._refresh() + return len(self._toc) + + def flush(self): + """Write any pending changes to disk.""" + # Maildir changes are always written immediately, so there's nothing + # to do. + pass + + def lock(self): + """Lock the mailbox.""" + return + + def unlock(self): + """Unlock the mailbox if it is locked.""" + return + + def close(self): + """Flush and close the mailbox.""" + return + + def list_folders(self): + """Return a list of folder names.""" + result = [] + for entry in os.listdir(self._path): + if len(entry) > 1 and entry[0] == '.' and \ + os.path.isdir(os.path.join(self._path, entry)): + result.append(entry[1:]) + return result + + def get_folder(self, folder): + """Return a Maildir instance for the named folder.""" + return Maildir(os.path.join(self._path, '.' + folder), + factory=self._factory, + create=False) + + def add_folder(self, folder): + """Create a folder and return a Maildir instance representing it.""" + path = os.path.join(self._path, '.' + folder) + result = Maildir(path, factory=self._factory) + maildirfolder_path = os.path.join(path, 'maildirfolder') + if not os.path.exists(maildirfolder_path): + os.close(os.open(maildirfolder_path, os.O_CREAT | os.O_WRONLY, + 0o666)) + return result + + def remove_folder(self, folder): + """Delete the named folder, which must be empty.""" + path = os.path.join(self._path, '.' + folder) + for entry in os.listdir(os.path.join(path, 'new')) + \ + os.listdir(os.path.join(path, 'cur')): + if len(entry) < 1 or entry[0] != '.': + raise NotEmptyError('Folder contains message(s): %s' % folder) + for entry in os.listdir(path): + if entry != 'new' and entry != 'cur' and entry != 'tmp' and \ + os.path.isdir(os.path.join(path, entry)): + raise NotEmptyError("Folder contains subdirectory '%s': %s" % + (folder, entry)) + for root, dirs, files in os.walk(path, topdown=False): + for entry in files: + os.remove(os.path.join(root, entry)) + for entry in dirs: + os.rmdir(os.path.join(root, entry)) + os.rmdir(path) + + def clean(self): + """Delete old files in "tmp".""" + now = time.time() + for entry in os.listdir(os.path.join(self._path, 'tmp')): + path = os.path.join(self._path, 'tmp', entry) + if now - os.path.getatime(path) > 129600: # 60 * 60 * 36 + os.remove(path) + + _count = 1 # This is used to generate unique file names. + + def _create_tmp(self): + """Create a file in the tmp subdirectory and open and return it.""" + now = time.time() + hostname = socket.gethostname() + if '/' in hostname: + hostname = hostname.replace('/', r'\057') + if ':' in hostname: + hostname = hostname.replace(':', r'\072') + uniq = "%s.M%sP%sQ%s.%s" % (int(now), int(now % 1 * 1e6), os.getpid(), + Maildir._count, hostname) + path = os.path.join(self._path, 'tmp', uniq) + try: + os.stat(path) + except FileNotFoundError: + Maildir._count += 1 + try: + return _create_carefully(path) + except FileExistsError: + pass + + # Fall through to here if stat succeeded or open raised EEXIST. + raise ExternalClashError('Name clash prevented file creation: %s' % + path) + + def _refresh(self): + """Update table of contents mapping.""" + # If it has been less than two seconds since the last _refresh() call, + # we have to unconditionally re-read the mailbox just in case it has + # been modified, because os.path.mtime() has a 2 sec resolution in the + # most common worst case (FAT) and a 1 sec resolution typically. This + # results in a few unnecessary re-reads when _refresh() is called + # multiple times in that interval, but once the clock ticks over, we + # will only re-read as needed. Because the filesystem might be being + # served by an independent system with its own clock, we record and + # compare with the mtimes from the filesystem. Because the other + # system's clock might be skewing relative to our clock, we add an + # extra delta to our wait. The default is one tenth second, but is an + # instance variable and so can be adjusted if dealing with a + # particularly skewed or irregular system. + if time.time() - self._last_read > 2 + self._skewfactor: + refresh = False + for subdir in self._toc_mtimes: + mtime = os.path.getmtime(self._paths[subdir]) + if mtime > self._toc_mtimes[subdir]: + refresh = True + self._toc_mtimes[subdir] = mtime + if not refresh: + return + # Refresh toc + self._toc = {} + for subdir in self._toc_mtimes: + path = self._paths[subdir] + for entry in os.listdir(path): + p = os.path.join(path, entry) + if os.path.isdir(p): + continue + uniq = entry.split(self.colon)[0] + self._toc[uniq] = os.path.join(subdir, entry) + self._last_read = time.time() + + def _lookup(self, key): + """Use TOC to return subpath for given key, or raise a KeyError.""" + try: + if os.path.exists(os.path.join(self._path, self._toc[key])): + return self._toc[key] + except KeyError: + pass + self._refresh() + try: + return self._toc[key] + except KeyError: + raise KeyError('No message with key: %s' % key) from None + + # This method is for backward compatibility only. + def next(self): + """Return the next message in a one-time iteration.""" + if not hasattr(self, '_onetime_keys'): + self._onetime_keys = self.iterkeys() + while True: + try: + return self[next(self._onetime_keys)] + except StopIteration: + return None + except KeyError: + continue + + +class _singlefileMailbox(Mailbox): + """A single-file mailbox.""" + + def __init__(self, path, factory=None, create=True): + """Initialize a single-file mailbox.""" + Mailbox.__init__(self, path, factory, create) + try: + f = open(self._path, 'rb+') + except OSError as e: + if e.errno == errno.ENOENT: + if create: + f = open(self._path, 'wb+') + else: + raise NoSuchMailboxError(self._path) + elif e.errno in (errno.EACCES, errno.EROFS): + f = open(self._path, 'rb') + else: + raise + self._file = f + self._toc = None + self._next_key = 0 + self._pending = False # No changes require rewriting the file. + self._pending_sync = False # No need to sync the file + self._locked = False + self._file_length = None # Used to record mailbox size + + def add(self, message): + """Add message and return assigned key.""" + self._lookup() + self._toc[self._next_key] = self._append_message(message) + self._next_key += 1 + # _append_message appends the message to the mailbox file. We + # don't need a full rewrite + rename, sync is enough. + self._pending_sync = True + return self._next_key - 1 + + def remove(self, key): + """Remove the keyed message; raise KeyError if it doesn't exist.""" + self._lookup(key) + del self._toc[key] + self._pending = True + + def __setitem__(self, key, message): + """Replace the keyed message; raise KeyError if it doesn't exist.""" + self._lookup(key) + self._toc[key] = self._append_message(message) + self._pending = True + + def iterkeys(self): + """Return an iterator over keys.""" + self._lookup() + yield from self._toc.keys() + + def __contains__(self, key): + """Return True if the keyed message exists, False otherwise.""" + self._lookup() + return key in self._toc + + def __len__(self): + """Return a count of messages in the mailbox.""" + self._lookup() + return len(self._toc) + + def lock(self): + """Lock the mailbox.""" + if not self._locked: + _lock_file(self._file) + self._locked = True + + def unlock(self): + """Unlock the mailbox if it is locked.""" + if self._locked: + _unlock_file(self._file) + self._locked = False + + def flush(self): + """Write any pending changes to disk.""" + if not self._pending: + if self._pending_sync: + # Messages have only been added, so syncing the file + # is enough. + _sync_flush(self._file) + self._pending_sync = False + return + + # In order to be writing anything out at all, self._toc must + # already have been generated (and presumably has been modified + # by adding or deleting an item). + assert self._toc is not None + + # Check length of self._file; if it's changed, some other process + # has modified the mailbox since we scanned it. + self._file.seek(0, 2) + cur_len = self._file.tell() + if cur_len != self._file_length: + raise ExternalClashError('Size of mailbox file changed ' + '(expected %i, found %i)' % + (self._file_length, cur_len)) + + new_file = _create_temporary(self._path) + try: + new_toc = {} + self._pre_mailbox_hook(new_file) + for key in sorted(self._toc.keys()): + start, stop = self._toc[key] + self._file.seek(start) + self._pre_message_hook(new_file) + new_start = new_file.tell() + while True: + buffer = self._file.read(min(4096, + stop - self._file.tell())) + if not buffer: + break + new_file.write(buffer) + new_toc[key] = (new_start, new_file.tell()) + self._post_message_hook(new_file) + self._file_length = new_file.tell() + except: + new_file.close() + os.remove(new_file.name) + raise + _sync_close(new_file) + # self._file is about to get replaced, so no need to sync. + self._file.close() + # Make sure the new file's mode is the same as the old file's + mode = os.stat(self._path).st_mode + os.chmod(new_file.name, mode) + try: + os.rename(new_file.name, self._path) + except FileExistsError: + os.remove(self._path) + os.rename(new_file.name, self._path) + self._file = open(self._path, 'rb+') + self._toc = new_toc + self._pending = False + self._pending_sync = False + if self._locked: + _lock_file(self._file, dotlock=False) + + def _pre_mailbox_hook(self, f): + """Called before writing the mailbox to file f.""" + return + + def _pre_message_hook(self, f): + """Called before writing each message to file f.""" + return + + def _post_message_hook(self, f): + """Called after writing each message to file f.""" + return + + def close(self): + """Flush and close the mailbox.""" + try: + self.flush() + finally: + try: + if self._locked: + self.unlock() + finally: + self._file.close() # Sync has been done by self.flush() above. + + def _lookup(self, key=None): + """Return (start, stop) or raise KeyError.""" + if self._toc is None: + self._generate_toc() + if key is not None: + try: + return self._toc[key] + except KeyError: + raise KeyError('No message with key: %s' % key) from None + + def _append_message(self, message): + """Append message to mailbox and return (start, stop) offsets.""" + self._file.seek(0, 2) + before = self._file.tell() + if len(self._toc) == 0 and not self._pending: + # This is the first message, and the _pre_mailbox_hook + # hasn't yet been called. If self._pending is True, + # messages have been removed, so _pre_mailbox_hook must + # have been called already. + self._pre_mailbox_hook(self._file) + try: + self._pre_message_hook(self._file) + offsets = self._install_message(message) + self._post_message_hook(self._file) + except BaseException: + self._file.truncate(before) + raise + self._file.flush() + self._file_length = self._file.tell() # Record current length of mailbox + return offsets + + + +class _mboxMMDF(_singlefileMailbox): + """An mbox or MMDF mailbox.""" + + _mangle_from_ = True + + def get_message(self, key): + """Return a Message representation or raise a KeyError.""" + start, stop = self._lookup(key) + self._file.seek(start) + from_line = self._file.readline().replace(linesep, b'') + string = self._file.read(stop - self._file.tell()) + msg = self._message_factory(string.replace(linesep, b'\n')) + msg.set_from(from_line[5:].decode('ascii')) + return msg + + def get_string(self, key, from_=False): + """Return a string representation or raise a KeyError.""" + return email.message_from_bytes( + self.get_bytes(key, from_)).as_string(unixfrom=from_) + + def get_bytes(self, key, from_=False): + """Return a string representation or raise a KeyError.""" + start, stop = self._lookup(key) + self._file.seek(start) + if not from_: + self._file.readline() + string = self._file.read(stop - self._file.tell()) + return string.replace(linesep, b'\n') + + def get_file(self, key, from_=False): + """Return a file-like representation or raise a KeyError.""" + start, stop = self._lookup(key) + self._file.seek(start) + if not from_: + self._file.readline() + return _PartialFile(self._file, self._file.tell(), stop) + + def _install_message(self, message): + """Format a message and blindly write to self._file.""" + from_line = None + if isinstance(message, str): + message = self._string_to_bytes(message) + if isinstance(message, bytes) and message.startswith(b'From '): + newline = message.find(b'\n') + if newline != -1: + from_line = message[:newline] + message = message[newline + 1:] + else: + from_line = message + message = b'' + elif isinstance(message, _mboxMMDFMessage): + author = message.get_from().encode('ascii') + from_line = b'From ' + author + elif isinstance(message, email.message.Message): + from_line = message.get_unixfrom() # May be None. + if from_line is not None: + from_line = from_line.encode('ascii') + if from_line is None: + from_line = b'From MAILER-DAEMON ' + time.asctime(time.gmtime()).encode() + start = self._file.tell() + self._file.write(from_line + linesep) + self._dump_message(message, self._file, self._mangle_from_) + stop = self._file.tell() + return (start, stop) + + +class mbox(_mboxMMDF): + """A classic mbox mailbox.""" + + _mangle_from_ = True + + # All messages must end in a newline character, and + # _post_message_hooks outputs an empty line between messages. + _append_newline = True + + def __init__(self, path, factory=None, create=True): + """Initialize an mbox mailbox.""" + self._message_factory = mboxMessage + _mboxMMDF.__init__(self, path, factory, create) + + def _post_message_hook(self, f): + """Called after writing each message to file f.""" + f.write(linesep) + + def _generate_toc(self): + """Generate key-to-(start, stop) table of contents.""" + starts, stops = [], [] + last_was_empty = False + self._file.seek(0) + while True: + line_pos = self._file.tell() + line = self._file.readline() + if line.startswith(b'From '): + if len(stops) < len(starts): + if last_was_empty: + stops.append(line_pos - len(linesep)) + else: + # The last line before the "From " line wasn't + # blank, but we consider it a start of a + # message anyway. + stops.append(line_pos) + starts.append(line_pos) + last_was_empty = False + elif not line: + if last_was_empty: + stops.append(line_pos - len(linesep)) + else: + stops.append(line_pos) + break + elif line == linesep: + last_was_empty = True + else: + last_was_empty = False + self._toc = dict(enumerate(zip(starts, stops))) + self._next_key = len(self._toc) + self._file_length = self._file.tell() + + +class MMDF(_mboxMMDF): + """An MMDF mailbox.""" + + def __init__(self, path, factory=None, create=True): + """Initialize an MMDF mailbox.""" + self._message_factory = MMDFMessage + _mboxMMDF.__init__(self, path, factory, create) + + def _pre_message_hook(self, f): + """Called before writing each message to file f.""" + f.write(b'\001\001\001\001' + linesep) + + def _post_message_hook(self, f): + """Called after writing each message to file f.""" + f.write(linesep + b'\001\001\001\001' + linesep) + + def _generate_toc(self): + """Generate key-to-(start, stop) table of contents.""" + starts, stops = [], [] + self._file.seek(0) + next_pos = 0 + while True: + line_pos = next_pos + line = self._file.readline() + next_pos = self._file.tell() + if line.startswith(b'\001\001\001\001' + linesep): + starts.append(next_pos) + while True: + line_pos = next_pos + line = self._file.readline() + next_pos = self._file.tell() + if line == b'\001\001\001\001' + linesep: + stops.append(line_pos - len(linesep)) + break + elif not line: + stops.append(line_pos) + break + elif not line: + break + self._toc = dict(enumerate(zip(starts, stops))) + self._next_key = len(self._toc) + self._file.seek(0, 2) + self._file_length = self._file.tell() + + +class MH(Mailbox): + """An MH mailbox.""" + + def __init__(self, path, factory=None, create=True): + """Initialize an MH instance.""" + Mailbox.__init__(self, path, factory, create) + if not os.path.exists(self._path): + if create: + os.mkdir(self._path, 0o700) + os.close(os.open(os.path.join(self._path, '.mh_sequences'), + os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)) + else: + raise NoSuchMailboxError(self._path) + self._locked = False + + def add(self, message): + """Add message and return assigned key.""" + keys = self.keys() + if len(keys) == 0: + new_key = 1 + else: + new_key = max(keys) + 1 + new_path = os.path.join(self._path, str(new_key)) + f = _create_carefully(new_path) + closed = False + try: + if self._locked: + _lock_file(f) + try: + try: + self._dump_message(message, f) + except BaseException: + # Unlock and close so it can be deleted on Windows + if self._locked: + _unlock_file(f) + _sync_close(f) + closed = True + os.remove(new_path) + raise + if isinstance(message, MHMessage): + self._dump_sequences(message, new_key) + finally: + if self._locked: + _unlock_file(f) + finally: + if not closed: + _sync_close(f) + return new_key + + def remove(self, key): + """Remove the keyed message; raise KeyError if it doesn't exist.""" + path = os.path.join(self._path, str(key)) + try: + f = open(path, 'rb+') + except OSError as e: + if e.errno == errno.ENOENT: + raise KeyError('No message with key: %s' % key) + else: + raise + else: + f.close() + os.remove(path) + + def __setitem__(self, key, message): + """Replace the keyed message; raise KeyError if it doesn't exist.""" + path = os.path.join(self._path, str(key)) + try: + f = open(path, 'rb+') + except OSError as e: + if e.errno == errno.ENOENT: + raise KeyError('No message with key: %s' % key) + else: + raise + try: + if self._locked: + _lock_file(f) + try: + os.close(os.open(path, os.O_WRONLY | os.O_TRUNC)) + self._dump_message(message, f) + if isinstance(message, MHMessage): + self._dump_sequences(message, key) + finally: + if self._locked: + _unlock_file(f) + finally: + _sync_close(f) + + def get_message(self, key): + """Return a Message representation or raise a KeyError.""" + try: + if self._locked: + f = open(os.path.join(self._path, str(key)), 'rb+') + else: + f = open(os.path.join(self._path, str(key)), 'rb') + except OSError as e: + if e.errno == errno.ENOENT: + raise KeyError('No message with key: %s' % key) + else: + raise + with f: + if self._locked: + _lock_file(f) + try: + msg = MHMessage(f) + finally: + if self._locked: + _unlock_file(f) + for name, key_list in self.get_sequences().items(): + if key in key_list: + msg.add_sequence(name) + return msg + + def get_bytes(self, key): + """Return a bytes representation or raise a KeyError.""" + try: + if self._locked: + f = open(os.path.join(self._path, str(key)), 'rb+') + else: + f = open(os.path.join(self._path, str(key)), 'rb') + except OSError as e: + if e.errno == errno.ENOENT: + raise KeyError('No message with key: %s' % key) + else: + raise + with f: + if self._locked: + _lock_file(f) + try: + return f.read().replace(linesep, b'\n') + finally: + if self._locked: + _unlock_file(f) + + def get_file(self, key): + """Return a file-like representation or raise a KeyError.""" + try: + f = open(os.path.join(self._path, str(key)), 'rb') + except OSError as e: + if e.errno == errno.ENOENT: + raise KeyError('No message with key: %s' % key) + else: + raise + return _ProxyFile(f) + + def iterkeys(self): + """Return an iterator over keys.""" + return iter(sorted(int(entry) for entry in os.listdir(self._path) + if entry.isdigit())) + + def __contains__(self, key): + """Return True if the keyed message exists, False otherwise.""" + return os.path.exists(os.path.join(self._path, str(key))) + + def __len__(self): + """Return a count of messages in the mailbox.""" + return len(list(self.iterkeys())) + + def lock(self): + """Lock the mailbox.""" + if not self._locked: + self._file = open(os.path.join(self._path, '.mh_sequences'), 'rb+') + _lock_file(self._file) + self._locked = True + + def unlock(self): + """Unlock the mailbox if it is locked.""" + if self._locked: + _unlock_file(self._file) + _sync_close(self._file) + del self._file + self._locked = False + + def flush(self): + """Write any pending changes to the disk.""" + return + + def close(self): + """Flush and close the mailbox.""" + if self._locked: + self.unlock() + + def list_folders(self): + """Return a list of folder names.""" + result = [] + for entry in os.listdir(self._path): + if os.path.isdir(os.path.join(self._path, entry)): + result.append(entry) + return result + + def get_folder(self, folder): + """Return an MH instance for the named folder.""" + return MH(os.path.join(self._path, folder), + factory=self._factory, create=False) + + def add_folder(self, folder): + """Create a folder and return an MH instance representing it.""" + return MH(os.path.join(self._path, folder), + factory=self._factory) + + def remove_folder(self, folder): + """Delete the named folder, which must be empty.""" + path = os.path.join(self._path, folder) + entries = os.listdir(path) + if entries == ['.mh_sequences']: + os.remove(os.path.join(path, '.mh_sequences')) + elif entries == []: + pass + else: + raise NotEmptyError('Folder not empty: %s' % self._path) + os.rmdir(path) + + def get_sequences(self): + """Return a name-to-key-list dictionary to define each sequence.""" + results = {} + with open(os.path.join(self._path, '.mh_sequences'), 'r', encoding='ASCII') as f: + all_keys = set(self.keys()) + for line in f: + try: + name, contents = line.split(':') + keys = set() + for spec in contents.split(): + if spec.isdigit(): + keys.add(int(spec)) + else: + start, stop = (int(x) for x in spec.split('-')) + keys.update(range(start, stop + 1)) + results[name] = [key for key in sorted(keys) \ + if key in all_keys] + if len(results[name]) == 0: + del results[name] + except ValueError: + raise FormatError('Invalid sequence specification: %s' % + line.rstrip()) + return results + + def set_sequences(self, sequences): + """Set sequences using the given name-to-key-list dictionary.""" + f = open(os.path.join(self._path, '.mh_sequences'), 'r+', encoding='ASCII') + try: + os.close(os.open(f.name, os.O_WRONLY | os.O_TRUNC)) + for name, keys in sequences.items(): + if len(keys) == 0: + continue + f.write(name + ':') + prev = None + completing = False + for key in sorted(set(keys)): + if key - 1 == prev: + if not completing: + completing = True + f.write('-') + elif completing: + completing = False + f.write('%s %s' % (prev, key)) + else: + f.write(' %s' % key) + prev = key + if completing: + f.write(str(prev) + '\n') + else: + f.write('\n') + finally: + _sync_close(f) + + def pack(self): + """Re-name messages to eliminate numbering gaps. Invalidates keys.""" + sequences = self.get_sequences() + prev = 0 + changes = [] + for key in self.iterkeys(): + if key - 1 != prev: + changes.append((key, prev + 1)) + try: + os.link(os.path.join(self._path, str(key)), + os.path.join(self._path, str(prev + 1))) + except (AttributeError, PermissionError): + os.rename(os.path.join(self._path, str(key)), + os.path.join(self._path, str(prev + 1))) + else: + os.unlink(os.path.join(self._path, str(key))) + prev += 1 + self._next_key = prev + 1 + if len(changes) == 0: + return + for name, key_list in sequences.items(): + for old, new in changes: + if old in key_list: + key_list[key_list.index(old)] = new + self.set_sequences(sequences) + + def _dump_sequences(self, message, key): + """Inspect a new MHMessage and update sequences appropriately.""" + pending_sequences = message.get_sequences() + all_sequences = self.get_sequences() + for name, key_list in all_sequences.items(): + if name in pending_sequences: + key_list.append(key) + elif key in key_list: + del key_list[key_list.index(key)] + for sequence in pending_sequences: + if sequence not in all_sequences: + all_sequences[sequence] = [key] + self.set_sequences(all_sequences) + + +class Babyl(_singlefileMailbox): + """An Rmail-style Babyl mailbox.""" + + _special_labels = frozenset({'unseen', 'deleted', 'filed', 'answered', + 'forwarded', 'edited', 'resent'}) + + def __init__(self, path, factory=None, create=True): + """Initialize a Babyl mailbox.""" + _singlefileMailbox.__init__(self, path, factory, create) + self._labels = {} + + def add(self, message): + """Add message and return assigned key.""" + key = _singlefileMailbox.add(self, message) + if isinstance(message, BabylMessage): + self._labels[key] = message.get_labels() + return key + + def remove(self, key): + """Remove the keyed message; raise KeyError if it doesn't exist.""" + _singlefileMailbox.remove(self, key) + if key in self._labels: + del self._labels[key] + + def __setitem__(self, key, message): + """Replace the keyed message; raise KeyError if it doesn't exist.""" + _singlefileMailbox.__setitem__(self, key, message) + if isinstance(message, BabylMessage): + self._labels[key] = message.get_labels() + + def get_message(self, key): + """Return a Message representation or raise a KeyError.""" + start, stop = self._lookup(key) + self._file.seek(start) + self._file.readline() # Skip b'1,' line specifying labels. + original_headers = io.BytesIO() + while True: + line = self._file.readline() + if line == b'*** EOOH ***' + linesep or not line: + break + original_headers.write(line.replace(linesep, b'\n')) + visible_headers = io.BytesIO() + while True: + line = self._file.readline() + if line == linesep or not line: + break + visible_headers.write(line.replace(linesep, b'\n')) + # Read up to the stop, or to the end + n = stop - self._file.tell() + assert n >= 0 + body = self._file.read(n) + body = body.replace(linesep, b'\n') + msg = BabylMessage(original_headers.getvalue() + body) + msg.set_visible(visible_headers.getvalue()) + if key in self._labels: + msg.set_labels(self._labels[key]) + return msg + + def get_bytes(self, key): + """Return a string representation or raise a KeyError.""" + start, stop = self._lookup(key) + self._file.seek(start) + self._file.readline() # Skip b'1,' line specifying labels. + original_headers = io.BytesIO() + while True: + line = self._file.readline() + if line == b'*** EOOH ***' + linesep or not line: + break + original_headers.write(line.replace(linesep, b'\n')) + while True: + line = self._file.readline() + if line == linesep or not line: + break + headers = original_headers.getvalue() + n = stop - self._file.tell() + assert n >= 0 + data = self._file.read(n) + data = data.replace(linesep, b'\n') + return headers + data + + def get_file(self, key): + """Return a file-like representation or raise a KeyError.""" + return io.BytesIO(self.get_bytes(key).replace(b'\n', linesep)) + + def get_labels(self): + """Return a list of user-defined labels in the mailbox.""" + self._lookup() + labels = set() + for label_list in self._labels.values(): + labels.update(label_list) + labels.difference_update(self._special_labels) + return list(labels) + + def _generate_toc(self): + """Generate key-to-(start, stop) table of contents.""" + starts, stops = [], [] + self._file.seek(0) + next_pos = 0 + label_lists = [] + while True: + line_pos = next_pos + line = self._file.readline() + next_pos = self._file.tell() + if line == b'\037\014' + linesep: + if len(stops) < len(starts): + stops.append(line_pos - len(linesep)) + starts.append(next_pos) + labels = [label.strip() for label + in self._file.readline()[1:].split(b',') + if label.strip()] + label_lists.append(labels) + elif line == b'\037' or line == b'\037' + linesep: + if len(stops) < len(starts): + stops.append(line_pos - len(linesep)) + elif not line: + stops.append(line_pos - len(linesep)) + break + self._toc = dict(enumerate(zip(starts, stops))) + self._labels = dict(enumerate(label_lists)) + self._next_key = len(self._toc) + self._file.seek(0, 2) + self._file_length = self._file.tell() + + def _pre_mailbox_hook(self, f): + """Called before writing the mailbox to file f.""" + babyl = b'BABYL OPTIONS:' + linesep + babyl += b'Version: 5' + linesep + labels = self.get_labels() + labels = (label.encode() for label in labels) + babyl += b'Labels:' + b','.join(labels) + linesep + babyl += b'\037' + f.write(babyl) + + def _pre_message_hook(self, f): + """Called before writing each message to file f.""" + f.write(b'\014' + linesep) + + def _post_message_hook(self, f): + """Called after writing each message to file f.""" + f.write(linesep + b'\037') + + def _install_message(self, message): + """Write message contents and return (start, stop).""" + start = self._file.tell() + if isinstance(message, BabylMessage): + special_labels = [] + labels = [] + for label in message.get_labels(): + if label in self._special_labels: + special_labels.append(label) + else: + labels.append(label) + self._file.write(b'1') + for label in special_labels: + self._file.write(b', ' + label.encode()) + self._file.write(b',,') + for label in labels: + self._file.write(b' ' + label.encode() + b',') + self._file.write(linesep) + else: + self._file.write(b'1,,' + linesep) + if isinstance(message, email.message.Message): + orig_buffer = io.BytesIO() + orig_generator = email.generator.BytesGenerator(orig_buffer, False, 0) + orig_generator.flatten(message) + orig_buffer.seek(0) + while True: + line = orig_buffer.readline() + self._file.write(line.replace(b'\n', linesep)) + if line == b'\n' or not line: + break + self._file.write(b'*** EOOH ***' + linesep) + if isinstance(message, BabylMessage): + vis_buffer = io.BytesIO() + vis_generator = email.generator.BytesGenerator(vis_buffer, False, 0) + vis_generator.flatten(message.get_visible()) + while True: + line = vis_buffer.readline() + self._file.write(line.replace(b'\n', linesep)) + if line == b'\n' or not line: + break + else: + orig_buffer.seek(0) + while True: + line = orig_buffer.readline() + self._file.write(line.replace(b'\n', linesep)) + if line == b'\n' or not line: + break + while True: + buffer = orig_buffer.read(4096) # Buffer size is arbitrary. + if not buffer: + break + self._file.write(buffer.replace(b'\n', linesep)) + elif isinstance(message, (bytes, str, io.StringIO)): + if isinstance(message, io.StringIO): + warnings.warn("Use of StringIO input is deprecated, " + "use BytesIO instead", DeprecationWarning, 3) + message = message.getvalue() + if isinstance(message, str): + message = self._string_to_bytes(message) + body_start = message.find(b'\n\n') + 2 + if body_start - 2 != -1: + self._file.write(message[:body_start].replace(b'\n', linesep)) + self._file.write(b'*** EOOH ***' + linesep) + self._file.write(message[:body_start].replace(b'\n', linesep)) + self._file.write(message[body_start:].replace(b'\n', linesep)) + else: + self._file.write(b'*** EOOH ***' + linesep + linesep) + self._file.write(message.replace(b'\n', linesep)) + elif hasattr(message, 'readline'): + if hasattr(message, 'buffer'): + warnings.warn("Use of text mode files is deprecated, " + "use a binary mode file instead", DeprecationWarning, 3) + message = message.buffer + original_pos = message.tell() + first_pass = True + while True: + line = message.readline() + # Universal newline support. + if line.endswith(b'\r\n'): + line = line[:-2] + b'\n' + elif line.endswith(b'\r'): + line = line[:-1] + b'\n' + self._file.write(line.replace(b'\n', linesep)) + if line == b'\n' or not line: + if first_pass: + first_pass = False + self._file.write(b'*** EOOH ***' + linesep) + message.seek(original_pos) + else: + break + while True: + line = message.readline() + if not line: + break + # Universal newline support. + if line.endswith(b'\r\n'): + line = line[:-2] + linesep + elif line.endswith(b'\r'): + line = line[:-1] + linesep + elif line.endswith(b'\n'): + line = line[:-1] + linesep + self._file.write(line) + else: + raise TypeError('Invalid message type: %s' % type(message)) + stop = self._file.tell() + return (start, stop) + + +class Message(email.message.Message): + """Message with mailbox-format-specific properties.""" + + def __init__(self, message=None): + """Initialize a Message instance.""" + if isinstance(message, email.message.Message): + self._become_message(copy.deepcopy(message)) + if isinstance(message, Message): + message._explain_to(self) + elif isinstance(message, bytes): + self._become_message(email.message_from_bytes(message)) + elif isinstance(message, str): + self._become_message(email.message_from_string(message)) + elif isinstance(message, io.TextIOWrapper): + self._become_message(email.message_from_file(message)) + elif hasattr(message, "read"): + self._become_message(email.message_from_binary_file(message)) + elif message is None: + email.message.Message.__init__(self) + else: + raise TypeError('Invalid message type: %s' % type(message)) + + def _become_message(self, message): + """Assume the non-format-specific state of message.""" + type_specific = getattr(message, '_type_specific_attributes', []) + for name in message.__dict__: + if name not in type_specific: + self.__dict__[name] = message.__dict__[name] + + def _explain_to(self, message): + """Copy format-specific state to message insofar as possible.""" + if isinstance(message, Message): + return # There's nothing format-specific to explain. + else: + raise TypeError('Cannot convert to specified type') + + +class MaildirMessage(Message): + """Message with Maildir-specific properties.""" + + _type_specific_attributes = ['_subdir', '_info', '_date'] + + def __init__(self, message=None): + """Initialize a MaildirMessage instance.""" + self._subdir = 'new' + self._info = '' + self._date = time.time() + Message.__init__(self, message) + + def get_subdir(self): + """Return 'new' or 'cur'.""" + return self._subdir + + def set_subdir(self, subdir): + """Set subdir to 'new' or 'cur'.""" + if subdir == 'new' or subdir == 'cur': + self._subdir = subdir + else: + raise ValueError("subdir must be 'new' or 'cur': %s" % subdir) + + def get_flags(self): + """Return as a string the flags that are set.""" + if self._info.startswith('2,'): + return self._info[2:] + else: + return '' + + def set_flags(self, flags): + """Set the given flags and unset all others.""" + self._info = '2,' + ''.join(sorted(flags)) + + def add_flag(self, flag): + """Set the given flag(s) without changing others.""" + self.set_flags(''.join(set(self.get_flags()) | set(flag))) + + def remove_flag(self, flag): + """Unset the given string flag(s) without changing others.""" + if self.get_flags(): + self.set_flags(''.join(set(self.get_flags()) - set(flag))) + + def get_date(self): + """Return delivery date of message, in seconds since the epoch.""" + return self._date + + def set_date(self, date): + """Set delivery date of message, in seconds since the epoch.""" + try: + self._date = float(date) + except ValueError: + raise TypeError("can't convert to float: %s" % date) from None + + def get_info(self): + """Get the message's "info" as a string.""" + return self._info + + def set_info(self, info): + """Set the message's "info" string.""" + if isinstance(info, str): + self._info = info + else: + raise TypeError('info must be a string: %s' % type(info)) + + def _explain_to(self, message): + """Copy Maildir-specific state to message insofar as possible.""" + if isinstance(message, MaildirMessage): + message.set_flags(self.get_flags()) + message.set_subdir(self.get_subdir()) + message.set_date(self.get_date()) + elif isinstance(message, _mboxMMDFMessage): + flags = set(self.get_flags()) + if 'S' in flags: + message.add_flag('R') + if self.get_subdir() == 'cur': + message.add_flag('O') + if 'T' in flags: + message.add_flag('D') + if 'F' in flags: + message.add_flag('F') + if 'R' in flags: + message.add_flag('A') + message.set_from('MAILER-DAEMON', time.gmtime(self.get_date())) + elif isinstance(message, MHMessage): + flags = set(self.get_flags()) + if 'S' not in flags: + message.add_sequence('unseen') + if 'R' in flags: + message.add_sequence('replied') + if 'F' in flags: + message.add_sequence('flagged') + elif isinstance(message, BabylMessage): + flags = set(self.get_flags()) + if 'S' not in flags: + message.add_label('unseen') + if 'T' in flags: + message.add_label('deleted') + if 'R' in flags: + message.add_label('answered') + if 'P' in flags: + message.add_label('forwarded') + elif isinstance(message, Message): + pass + else: + raise TypeError('Cannot convert to specified type: %s' % + type(message)) + + +class _mboxMMDFMessage(Message): + """Message with mbox- or MMDF-specific properties.""" + + _type_specific_attributes = ['_from'] + + def __init__(self, message=None): + """Initialize an mboxMMDFMessage instance.""" + self.set_from('MAILER-DAEMON', True) + if isinstance(message, email.message.Message): + unixfrom = message.get_unixfrom() + if unixfrom is not None and unixfrom.startswith('From '): + self.set_from(unixfrom[5:]) + Message.__init__(self, message) + + def get_from(self): + """Return contents of "From " line.""" + return self._from + + def set_from(self, from_, time_=None): + """Set "From " line, formatting and appending time_ if specified.""" + if time_ is not None: + if time_ is True: + time_ = time.gmtime() + from_ += ' ' + time.asctime(time_) + self._from = from_ + + def get_flags(self): + """Return as a string the flags that are set.""" + return self.get('Status', '') + self.get('X-Status', '') + + def set_flags(self, flags): + """Set the given flags and unset all others.""" + flags = set(flags) + status_flags, xstatus_flags = '', '' + for flag in ('R', 'O'): + if flag in flags: + status_flags += flag + flags.remove(flag) + for flag in ('D', 'F', 'A'): + if flag in flags: + xstatus_flags += flag + flags.remove(flag) + xstatus_flags += ''.join(sorted(flags)) + try: + self.replace_header('Status', status_flags) + except KeyError: + self.add_header('Status', status_flags) + try: + self.replace_header('X-Status', xstatus_flags) + except KeyError: + self.add_header('X-Status', xstatus_flags) + + def add_flag(self, flag): + """Set the given flag(s) without changing others.""" + self.set_flags(''.join(set(self.get_flags()) | set(flag))) + + def remove_flag(self, flag): + """Unset the given string flag(s) without changing others.""" + if 'Status' in self or 'X-Status' in self: + self.set_flags(''.join(set(self.get_flags()) - set(flag))) + + def _explain_to(self, message): + """Copy mbox- or MMDF-specific state to message insofar as possible.""" + if isinstance(message, MaildirMessage): + flags = set(self.get_flags()) + if 'O' in flags: + message.set_subdir('cur') + if 'F' in flags: + message.add_flag('F') + if 'A' in flags: + message.add_flag('R') + if 'R' in flags: + message.add_flag('S') + if 'D' in flags: + message.add_flag('T') + del message['status'] + del message['x-status'] + maybe_date = ' '.join(self.get_from().split()[-5:]) + try: + message.set_date(calendar.timegm(time.strptime(maybe_date, + '%a %b %d %H:%M:%S %Y'))) + except (ValueError, OverflowError): + pass + elif isinstance(message, _mboxMMDFMessage): + message.set_flags(self.get_flags()) + message.set_from(self.get_from()) + elif isinstance(message, MHMessage): + flags = set(self.get_flags()) + if 'R' not in flags: + message.add_sequence('unseen') + if 'A' in flags: + message.add_sequence('replied') + if 'F' in flags: + message.add_sequence('flagged') + del message['status'] + del message['x-status'] + elif isinstance(message, BabylMessage): + flags = set(self.get_flags()) + if 'R' not in flags: + message.add_label('unseen') + if 'D' in flags: + message.add_label('deleted') + if 'A' in flags: + message.add_label('answered') + del message['status'] + del message['x-status'] + elif isinstance(message, Message): + pass + else: + raise TypeError('Cannot convert to specified type: %s' % + type(message)) + + +class mboxMessage(_mboxMMDFMessage): + """Message with mbox-specific properties.""" + + +class MHMessage(Message): + """Message with MH-specific properties.""" + + _type_specific_attributes = ['_sequences'] + + def __init__(self, message=None): + """Initialize an MHMessage instance.""" + self._sequences = [] + Message.__init__(self, message) + + def get_sequences(self): + """Return a list of sequences that include the message.""" + return self._sequences[:] + + def set_sequences(self, sequences): + """Set the list of sequences that include the message.""" + self._sequences = list(sequences) + + def add_sequence(self, sequence): + """Add sequence to list of sequences including the message.""" + if isinstance(sequence, str): + if not sequence in self._sequences: + self._sequences.append(sequence) + else: + raise TypeError('sequence type must be str: %s' % type(sequence)) + + def remove_sequence(self, sequence): + """Remove sequence from the list of sequences including the message.""" + try: + self._sequences.remove(sequence) + except ValueError: + pass + + def _explain_to(self, message): + """Copy MH-specific state to message insofar as possible.""" + if isinstance(message, MaildirMessage): + sequences = set(self.get_sequences()) + if 'unseen' in sequences: + message.set_subdir('cur') + else: + message.set_subdir('cur') + message.add_flag('S') + if 'flagged' in sequences: + message.add_flag('F') + if 'replied' in sequences: + message.add_flag('R') + elif isinstance(message, _mboxMMDFMessage): + sequences = set(self.get_sequences()) + if 'unseen' not in sequences: + message.add_flag('RO') + else: + message.add_flag('O') + if 'flagged' in sequences: + message.add_flag('F') + if 'replied' in sequences: + message.add_flag('A') + elif isinstance(message, MHMessage): + for sequence in self.get_sequences(): + message.add_sequence(sequence) + elif isinstance(message, BabylMessage): + sequences = set(self.get_sequences()) + if 'unseen' in sequences: + message.add_label('unseen') + if 'replied' in sequences: + message.add_label('answered') + elif isinstance(message, Message): + pass + else: + raise TypeError('Cannot convert to specified type: %s' % + type(message)) + + +class BabylMessage(Message): + """Message with Babyl-specific properties.""" + + _type_specific_attributes = ['_labels', '_visible'] + + def __init__(self, message=None): + """Initialize a BabylMessage instance.""" + self._labels = [] + self._visible = Message() + Message.__init__(self, message) + + def get_labels(self): + """Return a list of labels on the message.""" + return self._labels[:] + + def set_labels(self, labels): + """Set the list of labels on the message.""" + self._labels = list(labels) + + def add_label(self, label): + """Add label to list of labels on the message.""" + if isinstance(label, str): + if label not in self._labels: + self._labels.append(label) + else: + raise TypeError('label must be a string: %s' % type(label)) + + def remove_label(self, label): + """Remove label from the list of labels on the message.""" + try: + self._labels.remove(label) + except ValueError: + pass + + def get_visible(self): + """Return a Message representation of visible headers.""" + return Message(self._visible) + + def set_visible(self, visible): + """Set the Message representation of visible headers.""" + self._visible = Message(visible) + + def update_visible(self): + """Update and/or sensibly generate a set of visible headers.""" + for header in self._visible.keys(): + if header in self: + self._visible.replace_header(header, self[header]) + else: + del self._visible[header] + for header in ('Date', 'From', 'Reply-To', 'To', 'CC', 'Subject'): + if header in self and header not in self._visible: + self._visible[header] = self[header] + + def _explain_to(self, message): + """Copy Babyl-specific state to message insofar as possible.""" + if isinstance(message, MaildirMessage): + labels = set(self.get_labels()) + if 'unseen' in labels: + message.set_subdir('cur') + else: + message.set_subdir('cur') + message.add_flag('S') + if 'forwarded' in labels or 'resent' in labels: + message.add_flag('P') + if 'answered' in labels: + message.add_flag('R') + if 'deleted' in labels: + message.add_flag('T') + elif isinstance(message, _mboxMMDFMessage): + labels = set(self.get_labels()) + if 'unseen' not in labels: + message.add_flag('RO') + else: + message.add_flag('O') + if 'deleted' in labels: + message.add_flag('D') + if 'answered' in labels: + message.add_flag('A') + elif isinstance(message, MHMessage): + labels = set(self.get_labels()) + if 'unseen' in labels: + message.add_sequence('unseen') + if 'answered' in labels: + message.add_sequence('replied') + elif isinstance(message, BabylMessage): + message.set_visible(self.get_visible()) + for label in self.get_labels(): + message.add_label(label) + elif isinstance(message, Message): + pass + else: + raise TypeError('Cannot convert to specified type: %s' % + type(message)) + + +class MMDFMessage(_mboxMMDFMessage): + """Message with MMDF-specific properties.""" + + +class _ProxyFile: + """A read-only wrapper of a file.""" + + def __init__(self, f, pos=None): + """Initialize a _ProxyFile.""" + self._file = f + if pos is None: + self._pos = f.tell() + else: + self._pos = pos + + def read(self, size=None): + """Read bytes.""" + return self._read(size, self._file.read) + + def read1(self, size=None): + """Read bytes.""" + return self._read(size, self._file.read1) + + def readline(self, size=None): + """Read a line.""" + return self._read(size, self._file.readline) + + def readlines(self, sizehint=None): + """Read multiple lines.""" + result = [] + for line in self: + result.append(line) + if sizehint is not None: + sizehint -= len(line) + if sizehint <= 0: + break + return result + + def __iter__(self): + """Iterate over lines.""" + while True: + line = self.readline() + if not line: + return + yield line + + def tell(self): + """Return the position.""" + return self._pos + + def seek(self, offset, whence=0): + """Change position.""" + if whence == 1: + self._file.seek(self._pos) + self._file.seek(offset, whence) + self._pos = self._file.tell() + + def close(self): + """Close the file.""" + if hasattr(self, '_file'): + try: + if hasattr(self._file, 'close'): + self._file.close() + finally: + del self._file + + def _read(self, size, read_method): + """Read size bytes using read_method.""" + if size is None: + size = -1 + self._file.seek(self._pos) + result = read_method(size) + self._pos = self._file.tell() + return result + + def __enter__(self): + """Context management protocol support.""" + return self + + def __exit__(self, *exc): + self.close() + + def readable(self): + return self._file.readable() + + def writable(self): + return self._file.writable() + + def seekable(self): + return self._file.seekable() + + def flush(self): + return self._file.flush() + + @property + def closed(self): + if not hasattr(self, '_file'): + return True + if not hasattr(self._file, 'closed'): + return False + return self._file.closed + + __class_getitem__ = classmethod(GenericAlias) + + +class _PartialFile(_ProxyFile): + """A read-only wrapper of part of a file.""" + + def __init__(self, f, start=None, stop=None): + """Initialize a _PartialFile.""" + _ProxyFile.__init__(self, f, start) + self._start = start + self._stop = stop + + def tell(self): + """Return the position with respect to start.""" + return _ProxyFile.tell(self) - self._start + + def seek(self, offset, whence=0): + """Change position, possibly with respect to start or stop.""" + if whence == 0: + self._pos = self._start + whence = 1 + elif whence == 2: + self._pos = self._stop + whence = 1 + _ProxyFile.seek(self, offset, whence) + + def _read(self, size, read_method): + """Read size bytes using read_method, honoring start and stop.""" + remaining = self._stop - self._pos + if remaining <= 0: + return b'' + if size is None or size < 0 or size > remaining: + size = remaining + return _ProxyFile._read(self, size, read_method) + + def close(self): + # do *not* close the underlying file object for partial files, + # since it's global to the mailbox object + if hasattr(self, '_file'): + del self._file + + +def _lock_file(f, dotlock=True): + """Lock file f using lockf and dot locking.""" + dotlock_done = False + try: + if fcntl: + try: + fcntl.lockf(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + except OSError as e: + if e.errno in (errno.EAGAIN, errno.EACCES, errno.EROFS): + raise ExternalClashError('lockf: lock unavailable: %s' % + f.name) + else: + raise + if dotlock: + try: + pre_lock = _create_temporary(f.name + '.lock') + pre_lock.close() + except OSError as e: + if e.errno in (errno.EACCES, errno.EROFS): + return # Without write access, just skip dotlocking. + else: + raise + try: + try: + os.link(pre_lock.name, f.name + '.lock') + dotlock_done = True + except (AttributeError, PermissionError): + os.rename(pre_lock.name, f.name + '.lock') + dotlock_done = True + else: + os.unlink(pre_lock.name) + except FileExistsError: + os.remove(pre_lock.name) + raise ExternalClashError('dot lock unavailable: %s' % + f.name) + except: + if fcntl: + fcntl.lockf(f, fcntl.LOCK_UN) + if dotlock_done: + os.remove(f.name + '.lock') + raise + +def _unlock_file(f): + """Unlock file f using lockf and dot locking.""" + if fcntl: + fcntl.lockf(f, fcntl.LOCK_UN) + if os.path.exists(f.name + '.lock'): + os.remove(f.name + '.lock') + +def _create_carefully(path): + """Create a file if it doesn't exist and open for reading and writing.""" + fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR, 0o666) + try: + return open(path, 'rb+') + finally: + os.close(fd) + +def _create_temporary(path): + """Create a temp file based on path and open for reading and writing.""" + return _create_carefully('%s.%s.%s.%s' % (path, int(time.time()), + socket.gethostname(), + os.getpid())) + +def _sync_flush(f): + """Ensure changes to file f are physically on disk.""" + f.flush() + if hasattr(os, 'fsync'): + os.fsync(f.fileno()) + +def _sync_close(f): + """Close file f, ensuring all changes are physically on disk.""" + _sync_flush(f) + f.close() + + +class Error(Exception): + """Raised for module-specific errors.""" + +class NoSuchMailboxError(Error): + """The specified mailbox does not exist and won't be created.""" + +class NotEmptyError(Error): + """The specified mailbox is not empty and deletion was requested.""" + +class ExternalClashError(Error): + """Another process caused an action to fail.""" + +class FormatError(Error): + """A file appears to have an invalid format.""" diff --git a/Lib/test/test_genericalias.py b/Lib/test/test_genericalias.py index 2992dd885..560b96f7e 100644 --- a/Lib/test/test_genericalias.py +++ b/Lib/test/test_genericalias.py @@ -1,342 +1,346 @@ -"""Tests for C-implemented GenericAlias.""" - -import unittest -import pickle -import copy -from collections import ( - defaultdict, deque, OrderedDict, Counter, UserDict, UserList -) -from collections.abc import * -from concurrent.futures import Future -from concurrent.futures.thread import _WorkItem -from contextlib import AbstractContextManager, AbstractAsyncContextManager -# XXX RUSTPYTHON TODO: from contextvars import ContextVar, Token -from dataclasses import Field -from functools import partial, partialmethod, cached_property -# XXX RUSTPYTHON TODO: from mailbox import Mailbox, _PartialFile -try: - import ctypes -except ImportError: - ctypes = None -from difflib import SequenceMatcher -from filecmp import dircmp -# XXX RUSTPYTHON TODO: from fileinput import FileInput -from itertools import chain -from http.cookies import Morsel -from multiprocessing.managers import ValueProxy -from multiprocessing.pool import ApplyResult -try: - from multiprocessing.shared_memory import ShareableList -except ImportError: - # multiprocessing.shared_memory is not available on e.g. Android - ShareableList = None -from multiprocessing.queues import SimpleQueue as MPSimpleQueue -from os import DirEntry -from re import Pattern, Match -from types import GenericAlias, MappingProxyType, AsyncGeneratorType -from tempfile import TemporaryDirectory, SpooledTemporaryFile -from urllib.parse import SplitResult, ParseResult -from unittest.case import _AssertRaisesContext -from queue import Queue, SimpleQueue -from weakref import WeakSet, ReferenceType, ref -import typing - -from typing import TypeVar -T = TypeVar('T') -K = TypeVar('K') -V = TypeVar('V') - -class BaseTest(unittest.TestCase): - """Test basics.""" - generic_types = [type, tuple, list, dict, set, frozenset, enumerate, - defaultdict, deque, - SequenceMatcher, - dircmp, - # XXX RUSTPYTHON TODO: FileInput, - OrderedDict, Counter, UserDict, UserList, - Pattern, Match, - partialmethod, cached_property, # XXX RUSTPYTHON TODO: partial - # XXX RUSTPYTHON TODO: AbstractContextManager, AbstractAsyncContextManager, - Awaitable, Coroutine, - AsyncIterable, AsyncIterator, - AsyncGenerator, Generator, - Iterable, Iterator, - Reversible, - Container, Collection, - # XXX RUSTPYTHON TODO: Mailbox, _PartialFile, - # XXX RUSTPYTHON TODO: ContextVar, Token, - Field, - Set, MutableSet, - Mapping, MutableMapping, MappingView, - KeysView, ItemsView, ValuesView, - Sequence, MutableSequence, - MappingProxyType, AsyncGeneratorType, - DirEntry, - chain, - TemporaryDirectory, SpooledTemporaryFile, - Queue, SimpleQueue, - _AssertRaisesContext, - SplitResult, ParseResult, - ValueProxy, ApplyResult, - WeakSet, ReferenceType, ref, - ShareableList, MPSimpleQueue, - Future, _WorkItem, - Morsel] - if ctypes is not None: - generic_types.extend((ctypes.Array, ctypes.LibraryLoader)) - - def test_subscriptable(self): - for t in self.generic_types: - if t is None: - continue - tname = t.__name__ - with self.subTest(f"Testing {tname}"): - alias = t[int] - self.assertIs(alias.__origin__, t) - self.assertEqual(alias.__args__, (int,)) - self.assertEqual(alias.__parameters__, ()) - - def test_unsubscriptable(self): - for t in int, str, float, Sized, Hashable: - tname = t.__name__ - with self.subTest(f"Testing {tname}"): - with self.assertRaises(TypeError): - t[int] - - def test_instantiate(self): - for t in tuple, list, dict, set, frozenset, defaultdict, deque: - tname = t.__name__ - with self.subTest(f"Testing {tname}"): - alias = t[int] - self.assertEqual(alias(), t()) - if t is dict: - self.assertEqual(alias(iter([('a', 1), ('b', 2)])), dict(a=1, b=2)) - self.assertEqual(alias(a=1, b=2), dict(a=1, b=2)) - elif t is defaultdict: - def default(): - return 'value' - a = alias(default) - d = defaultdict(default) - self.assertEqual(a['test'], d['test']) - else: - self.assertEqual(alias(iter((1, 2, 3))), t((1, 2, 3))) - - def test_unbound_methods(self): - t = list[int] - a = t() - t.append(a, 'foo') - self.assertEqual(a, ['foo']) - x = t.__getitem__(a, 0) - self.assertEqual(x, 'foo') - self.assertEqual(t.__len__(a), 1) - - def test_subclassing(self): - class C(list[int]): - pass - self.assertEqual(C.__bases__, (list,)) - self.assertEqual(C.__class__, type) - - def test_class_methods(self): - t = dict[int, None] - self.assertEqual(dict.fromkeys(range(2)), {0: None, 1: None}) # This works - self.assertEqual(t.fromkeys(range(2)), {0: None, 1: None}) # Should be equivalent - - def test_no_chaining(self): - t = list[int] - with self.assertRaises(TypeError): - t[int] - - def test_generic_subclass(self): - class MyList(list): - pass - t = MyList[int] - self.assertIs(t.__origin__, MyList) - self.assertEqual(t.__args__, (int,)) - self.assertEqual(t.__parameters__, ()) - - def test_repr(self): - class MyList(list): - pass - self.assertEqual(repr(list[str]), 'list[str]') - self.assertEqual(repr(list[()]), 'list[()]') - self.assertEqual(repr(tuple[int, ...]), 'tuple[int, ...]') - self.assertTrue(repr(MyList[int]).endswith('.BaseTest.test_repr..MyList[int]')) - self.assertEqual(repr(list[str]()), '[]') # instances should keep their normal repr - - def test_exposed_type(self): - import types - a = types.GenericAlias(list, int) - self.assertEqual(str(a), 'list[int]') - self.assertIs(a.__origin__, list) - self.assertEqual(a.__args__, (int,)) - self.assertEqual(a.__parameters__, ()) - - def test_parameters(self): - from typing import List, Dict, Callable - D0 = dict[str, int] - self.assertEqual(D0.__args__, (str, int)) - self.assertEqual(D0.__parameters__, ()) - D1a = dict[str, V] - self.assertEqual(D1a.__args__, (str, V)) - self.assertEqual(D1a.__parameters__, (V,)) - D1b = dict[K, int] - self.assertEqual(D1b.__args__, (K, int)) - self.assertEqual(D1b.__parameters__, (K,)) - D2a = dict[K, V] - self.assertEqual(D2a.__args__, (K, V)) - self.assertEqual(D2a.__parameters__, (K, V)) - D2b = dict[T, T] - self.assertEqual(D2b.__args__, (T, T)) - self.assertEqual(D2b.__parameters__, (T,)) - L0 = list[str] - self.assertEqual(L0.__args__, (str,)) - self.assertEqual(L0.__parameters__, ()) - L1 = list[T] - self.assertEqual(L1.__args__, (T,)) - self.assertEqual(L1.__parameters__, (T,)) - L2 = list[list[T]] - self.assertEqual(L2.__args__, (list[T],)) - self.assertEqual(L2.__parameters__, (T,)) - L3 = list[List[T]] - self.assertEqual(L3.__args__, (List[T],)) - self.assertEqual(L3.__parameters__, (T,)) - L4a = list[Dict[K, V]] - self.assertEqual(L4a.__args__, (Dict[K, V],)) - self.assertEqual(L4a.__parameters__, (K, V)) - L4b = list[Dict[T, int]] - self.assertEqual(L4b.__args__, (Dict[T, int],)) - self.assertEqual(L4b.__parameters__, (T,)) - L5 = list[Callable[[K, V], K]] - self.assertEqual(L5.__args__, (Callable[[K, V], K],)) - self.assertEqual(L5.__parameters__, (K, V)) - - def test_parameter_chaining(self): - from typing import List, Dict, Union, Callable - self.assertEqual(list[T][int], list[int]) - self.assertEqual(dict[str, T][int], dict[str, int]) - self.assertEqual(dict[T, int][str], dict[str, int]) - self.assertEqual(dict[K, V][str, int], dict[str, int]) - self.assertEqual(dict[T, T][int], dict[int, int]) - - self.assertEqual(list[list[T]][int], list[list[int]]) - self.assertEqual(list[dict[T, int]][str], list[dict[str, int]]) - self.assertEqual(list[dict[str, T]][int], list[dict[str, int]]) - self.assertEqual(list[dict[K, V]][str, int], list[dict[str, int]]) - self.assertEqual(dict[T, list[int]][str], dict[str, list[int]]) - - self.assertEqual(list[List[T]][int], list[List[int]]) - self.assertEqual(list[Dict[K, V]][str, int], list[Dict[str, int]]) - self.assertEqual(list[Union[K, V]][str, int], list[Union[str, int]]) - self.assertEqual(list[Callable[[K, V], K]][str, int], - list[Callable[[str, int], str]]) - self.assertEqual(dict[T, List[int]][str], dict[str, List[int]]) - - with self.assertRaises(TypeError): - list[int][int] - dict[T, int][str, int] - dict[str, T][str, int] - dict[T, T][str, int] - - def test_equality(self): - self.assertEqual(list[int], list[int]) - self.assertEqual(dict[str, int], dict[str, int]) - self.assertNotEqual(dict[str, int], dict[str, str]) - self.assertNotEqual(list, list[int]) - self.assertNotEqual(list[int], list) - - def test_isinstance(self): - self.assertTrue(isinstance([], list)) - with self.assertRaises(TypeError): - isinstance([], list[str]) - - def test_issubclass(self): - class L(list): ... - self.assertTrue(issubclass(L, list)) - with self.assertRaises(TypeError): - issubclass(L, list[str]) - - def test_type_generic(self): - t = type[int] - Test = t('Test', (), {}) - self.assertTrue(isinstance(Test, type)) - test = Test() - self.assertEqual(t(test), Test) - self.assertEqual(t(0), int) - - def test_type_subclass_generic(self): - class MyType(type): - pass - with self.assertRaises(TypeError): - MyType[int] - - def test_pickle(self): - alias = GenericAlias(list, T) - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - s = pickle.dumps(alias, proto) - loaded = pickle.loads(s) - self.assertEqual(loaded.__origin__, alias.__origin__) - self.assertEqual(loaded.__args__, alias.__args__) - self.assertEqual(loaded.__parameters__, alias.__parameters__) - - def test_copy(self): - class X(list): - def __copy__(self): - return self - def __deepcopy__(self, memo): - return self - - for origin in list, deque, X: - alias = GenericAlias(origin, T) - copied = copy.copy(alias) - self.assertEqual(copied.__origin__, alias.__origin__) - self.assertEqual(copied.__args__, alias.__args__) - self.assertEqual(copied.__parameters__, alias.__parameters__) - copied = copy.deepcopy(alias) - self.assertEqual(copied.__origin__, alias.__origin__) - self.assertEqual(copied.__args__, alias.__args__) - self.assertEqual(copied.__parameters__, alias.__parameters__) - - def test_union(self): - a = typing.Union[list[int], list[str]] - self.assertEqual(a.__args__, (list[int], list[str])) - self.assertEqual(a.__parameters__, ()) - - def test_union_generic(self): - a = typing.Union[list[T], tuple[T, ...]] - self.assertEqual(a.__args__, (list[T], tuple[T, ...])) - self.assertEqual(a.__parameters__, (T,)) - - def test_dir(self): - dir_of_gen_alias = set(dir(list[int])) - self.assertTrue(dir_of_gen_alias.issuperset(dir(list))) - for generic_alias_property in ("__origin__", "__args__", "__parameters__"): - self.assertIn(generic_alias_property, dir_of_gen_alias) - - def test_weakref(self): - for t in self.generic_types: - if t is None: - continue - tname = t.__name__ - with self.subTest(f"Testing {tname}"): - alias = t[int] - self.assertEqual(ref(alias)(), alias) - - def test_no_kwargs(self): - # bpo-42576 - with self.assertRaises(TypeError): - GenericAlias(bad=float) - - def test_subclassing_types_genericalias(self): - class SubClass(GenericAlias): ... - alias = SubClass(list, int) - class Bad(GenericAlias): - def __new__(cls, *args, **kwargs): - super().__new__(cls, *args, **kwargs) - - self.assertEqual(alias, list[int]) - with self.assertRaises(TypeError): - Bad(list, int, bad=int) - - -if __name__ == "__main__": - unittest.main() +"""Tests for C-implemented GenericAlias.""" + +import unittest +import pickle +import copy +from collections import ( + defaultdict, deque, OrderedDict, Counter, UserDict, UserList +) +from collections.abc import * +from concurrent.futures import Future +from concurrent.futures.thread import _WorkItem +from contextlib import AbstractContextManager, AbstractAsyncContextManager +from contextvars import ContextVar, Token +from dataclasses import Field +from functools import partial, partialmethod, cached_property +from mailbox import Mailbox, _PartialFile +try: + import ctypes +except ImportError: + ctypes = None +from difflib import SequenceMatcher +from filecmp import dircmp +from fileinput import FileInput +from itertools import chain +from http.cookies import Morsel +from multiprocessing.managers import ValueProxy +from multiprocessing.pool import ApplyResult +try: + from multiprocessing.shared_memory import ShareableList +except ImportError: + # multiprocessing.shared_memory is not available on e.g. Android + ShareableList = None +from multiprocessing.queues import SimpleQueue as MPSimpleQueue +from os import DirEntry +from re import Pattern, Match +from types import GenericAlias, MappingProxyType, AsyncGeneratorType +from tempfile import TemporaryDirectory, SpooledTemporaryFile +from urllib.parse import SplitResult, ParseResult +from unittest.case import _AssertRaisesContext +from queue import Queue, SimpleQueue +from weakref import WeakSet, ReferenceType, ref +import typing + +from typing import TypeVar +T = TypeVar('T') +K = TypeVar('K') +V = TypeVar('V') + +class BaseTest(unittest.TestCase): + """Test basics.""" + generic_types = [type, tuple, list, dict, set, frozenset, enumerate, + defaultdict, deque, + SequenceMatcher, + dircmp, + FileInput, + OrderedDict, Counter, UserDict, UserList, + Pattern, Match, + partial, partialmethod, cached_property, + AbstractContextManager, AbstractAsyncContextManager, + Awaitable, Coroutine, + AsyncIterable, AsyncIterator, + AsyncGenerator, Generator, + Iterable, Iterator, + Reversible, + Container, Collection, + Mailbox, _PartialFile, + ContextVar, Token, + Field, + Set, MutableSet, + Mapping, MutableMapping, MappingView, + KeysView, ItemsView, ValuesView, + Sequence, MutableSequence, + MappingProxyType, AsyncGeneratorType, + DirEntry, + chain, + TemporaryDirectory, SpooledTemporaryFile, + Queue, SimpleQueue, + _AssertRaisesContext, + SplitResult, ParseResult, + ValueProxy, ApplyResult, + WeakSet, ReferenceType, ref, + ShareableList, MPSimpleQueue, + Future, _WorkItem, + Morsel] + if ctypes is not None: + generic_types.extend((ctypes.Array, ctypes.LibraryLoader)) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_subscriptable(self): + for t in self.generic_types: + if t is None: + continue + tname = t.__name__ + with self.subTest(f"Testing {tname}"): + alias = t[int] + self.assertIs(alias.__origin__, t) + self.assertEqual(alias.__args__, (int,)) + self.assertEqual(alias.__parameters__, ()) + + def test_unsubscriptable(self): + for t in int, str, float, Sized, Hashable: + tname = t.__name__ + with self.subTest(f"Testing {tname}"): + with self.assertRaises(TypeError): + t[int] + + def test_instantiate(self): + for t in tuple, list, dict, set, frozenset, defaultdict, deque: + tname = t.__name__ + with self.subTest(f"Testing {tname}"): + alias = t[int] + self.assertEqual(alias(), t()) + if t is dict: + self.assertEqual(alias(iter([('a', 1), ('b', 2)])), dict(a=1, b=2)) + self.assertEqual(alias(a=1, b=2), dict(a=1, b=2)) + elif t is defaultdict: + def default(): + return 'value' + a = alias(default) + d = defaultdict(default) + self.assertEqual(a['test'], d['test']) + else: + self.assertEqual(alias(iter((1, 2, 3))), t((1, 2, 3))) + + def test_unbound_methods(self): + t = list[int] + a = t() + t.append(a, 'foo') + self.assertEqual(a, ['foo']) + x = t.__getitem__(a, 0) + self.assertEqual(x, 'foo') + self.assertEqual(t.__len__(a), 1) + + def test_subclassing(self): + class C(list[int]): + pass + self.assertEqual(C.__bases__, (list,)) + self.assertEqual(C.__class__, type) + + def test_class_methods(self): + t = dict[int, None] + self.assertEqual(dict.fromkeys(range(2)), {0: None, 1: None}) # This works + self.assertEqual(t.fromkeys(range(2)), {0: None, 1: None}) # Should be equivalent + + def test_no_chaining(self): + t = list[int] + with self.assertRaises(TypeError): + t[int] + + def test_generic_subclass(self): + class MyList(list): + pass + t = MyList[int] + self.assertIs(t.__origin__, MyList) + self.assertEqual(t.__args__, (int,)) + self.assertEqual(t.__parameters__, ()) + + def test_repr(self): + class MyList(list): + pass + self.assertEqual(repr(list[str]), 'list[str]') + self.assertEqual(repr(list[()]), 'list[()]') + self.assertEqual(repr(tuple[int, ...]), 'tuple[int, ...]') + self.assertTrue(repr(MyList[int]).endswith('.BaseTest.test_repr..MyList[int]')) + self.assertEqual(repr(list[str]()), '[]') # instances should keep their normal repr + + def test_exposed_type(self): + import types + a = types.GenericAlias(list, int) + self.assertEqual(str(a), 'list[int]') + self.assertIs(a.__origin__, list) + self.assertEqual(a.__args__, (int,)) + self.assertEqual(a.__parameters__, ()) + + def test_parameters(self): + from typing import List, Dict, Callable + D0 = dict[str, int] + self.assertEqual(D0.__args__, (str, int)) + self.assertEqual(D0.__parameters__, ()) + D1a = dict[str, V] + self.assertEqual(D1a.__args__, (str, V)) + self.assertEqual(D1a.__parameters__, (V,)) + D1b = dict[K, int] + self.assertEqual(D1b.__args__, (K, int)) + self.assertEqual(D1b.__parameters__, (K,)) + D2a = dict[K, V] + self.assertEqual(D2a.__args__, (K, V)) + self.assertEqual(D2a.__parameters__, (K, V)) + D2b = dict[T, T] + self.assertEqual(D2b.__args__, (T, T)) + self.assertEqual(D2b.__parameters__, (T,)) + L0 = list[str] + self.assertEqual(L0.__args__, (str,)) + self.assertEqual(L0.__parameters__, ()) + L1 = list[T] + self.assertEqual(L1.__args__, (T,)) + self.assertEqual(L1.__parameters__, (T,)) + L2 = list[list[T]] + self.assertEqual(L2.__args__, (list[T],)) + self.assertEqual(L2.__parameters__, (T,)) + L3 = list[List[T]] + self.assertEqual(L3.__args__, (List[T],)) + self.assertEqual(L3.__parameters__, (T,)) + L4a = list[Dict[K, V]] + self.assertEqual(L4a.__args__, (Dict[K, V],)) + self.assertEqual(L4a.__parameters__, (K, V)) + L4b = list[Dict[T, int]] + self.assertEqual(L4b.__args__, (Dict[T, int],)) + self.assertEqual(L4b.__parameters__, (T,)) + L5 = list[Callable[[K, V], K]] + self.assertEqual(L5.__args__, (Callable[[K, V], K],)) + self.assertEqual(L5.__parameters__, (K, V)) + + def test_parameter_chaining(self): + from typing import List, Dict, Union, Callable + self.assertEqual(list[T][int], list[int]) + self.assertEqual(dict[str, T][int], dict[str, int]) + self.assertEqual(dict[T, int][str], dict[str, int]) + self.assertEqual(dict[K, V][str, int], dict[str, int]) + self.assertEqual(dict[T, T][int], dict[int, int]) + + self.assertEqual(list[list[T]][int], list[list[int]]) + self.assertEqual(list[dict[T, int]][str], list[dict[str, int]]) + self.assertEqual(list[dict[str, T]][int], list[dict[str, int]]) + self.assertEqual(list[dict[K, V]][str, int], list[dict[str, int]]) + self.assertEqual(dict[T, list[int]][str], dict[str, list[int]]) + + self.assertEqual(list[List[T]][int], list[List[int]]) + self.assertEqual(list[Dict[K, V]][str, int], list[Dict[str, int]]) + self.assertEqual(list[Union[K, V]][str, int], list[Union[str, int]]) + self.assertEqual(list[Callable[[K, V], K]][str, int], + list[Callable[[str, int], str]]) + self.assertEqual(dict[T, List[int]][str], dict[str, List[int]]) + + with self.assertRaises(TypeError): + list[int][int] + dict[T, int][str, int] + dict[str, T][str, int] + dict[T, T][str, int] + + def test_equality(self): + self.assertEqual(list[int], list[int]) + self.assertEqual(dict[str, int], dict[str, int]) + self.assertNotEqual(dict[str, int], dict[str, str]) + self.assertNotEqual(list, list[int]) + self.assertNotEqual(list[int], list) + + def test_isinstance(self): + self.assertTrue(isinstance([], list)) + with self.assertRaises(TypeError): + isinstance([], list[str]) + + def test_issubclass(self): + class L(list): ... + self.assertTrue(issubclass(L, list)) + with self.assertRaises(TypeError): + issubclass(L, list[str]) + + def test_type_generic(self): + t = type[int] + Test = t('Test', (), {}) + self.assertTrue(isinstance(Test, type)) + test = Test() + self.assertEqual(t(test), Test) + self.assertEqual(t(0), int) + + def test_type_subclass_generic(self): + class MyType(type): + pass + with self.assertRaises(TypeError): + MyType[int] + + def test_pickle(self): + alias = GenericAlias(list, T) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + s = pickle.dumps(alias, proto) + loaded = pickle.loads(s) + self.assertEqual(loaded.__origin__, alias.__origin__) + self.assertEqual(loaded.__args__, alias.__args__) + self.assertEqual(loaded.__parameters__, alias.__parameters__) + + def test_copy(self): + class X(list): + def __copy__(self): + return self + def __deepcopy__(self, memo): + return self + + for origin in list, deque, X: + alias = GenericAlias(origin, T) + copied = copy.copy(alias) + self.assertEqual(copied.__origin__, alias.__origin__) + self.assertEqual(copied.__args__, alias.__args__) + self.assertEqual(copied.__parameters__, alias.__parameters__) + copied = copy.deepcopy(alias) + self.assertEqual(copied.__origin__, alias.__origin__) + self.assertEqual(copied.__args__, alias.__args__) + self.assertEqual(copied.__parameters__, alias.__parameters__) + + def test_union(self): + a = typing.Union[list[int], list[str]] + self.assertEqual(a.__args__, (list[int], list[str])) + self.assertEqual(a.__parameters__, ()) + + def test_union_generic(self): + a = typing.Union[list[T], tuple[T, ...]] + self.assertEqual(a.__args__, (list[T], tuple[T, ...])) + self.assertEqual(a.__parameters__, (T,)) + + def test_dir(self): + dir_of_gen_alias = set(dir(list[int])) + self.assertTrue(dir_of_gen_alias.issuperset(dir(list))) + for generic_alias_property in ("__origin__", "__args__", "__parameters__"): + self.assertIn(generic_alias_property, dir_of_gen_alias) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_weakref(self): + for t in self.generic_types: + if t is None: + continue + tname = t.__name__ + with self.subTest(f"Testing {tname}"): + alias = t[int] + self.assertEqual(ref(alias)(), alias) + + def test_no_kwargs(self): + # bpo-42576 + with self.assertRaises(TypeError): + GenericAlias(bad=float) + + def test_subclassing_types_genericalias(self): + class SubClass(GenericAlias): ... + alias = SubClass(list, int) + class Bad(GenericAlias): + def __new__(cls, *args, **kwargs): + super().__new__(cls, *args, **kwargs) + + self.assertEqual(alias, list[int]) + with self.assertRaises(TypeError): + Bad(list, int, bad=int) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_glob.py b/Lib/test/test_glob.py index 13f7e4ac4..96db31b26 100644 --- a/Lib/test/test_glob.py +++ b/Lib/test/test_glob.py @@ -9,6 +9,7 @@ from test.support.os_helper import (TESTFN, skip_unless_symlink, class GlobTests(unittest.TestCase): + dir_fd = None def norm(self, *parts): return os.path.normpath(os.path.join(self.tempdir, *parts)) @@ -38,8 +39,14 @@ class GlobTests(unittest.TestCase): os.symlink(self.norm('broken'), self.norm('sym1')) os.symlink('broken', self.norm('sym2')) os.symlink(os.path.join('a', 'bcd'), self.norm('sym3')) + if {os.open, os.stat} <= os.supports_dir_fd and os.scandir in os.supports_fd: + self.dir_fd = os.open(self.tempdir, os.O_RDONLY | os.O_DIRECTORY) + else: + self.dir_fd = None def tearDown(self): + if self.dir_fd is not None: + os.close(self.dir_fd) shutil.rmtree(self.tempdir) def glob(self, *parts, **kwargs): @@ -53,6 +60,41 @@ class GlobTests(unittest.TestCase): bres = [os.fsencode(x) for x in res] self.assertCountEqual(glob.glob(os.fsencode(p), **kwargs), bres) self.assertCountEqual(glob.iglob(os.fsencode(p), **kwargs), bres) + + with change_cwd(self.tempdir): + res2 = glob.glob(pattern, **kwargs) + for x in res2: + self.assertFalse(os.path.isabs(x), x) + if pattern == '**' or pattern == '**' + os.sep: + expected = res[1:] + else: + expected = res + self.assertCountEqual([os.path.join(self.tempdir, x) for x in res2], + expected) + self.assertCountEqual(glob.iglob(pattern, **kwargs), res2) + bpattern = os.fsencode(pattern) + bres2 = [os.fsencode(x) for x in res2] + self.assertCountEqual(glob.glob(bpattern, **kwargs), bres2) + self.assertCountEqual(glob.iglob(bpattern, **kwargs), bres2) + + self.assertCountEqual(glob.glob(pattern, root_dir=self.tempdir, **kwargs), res2) + self.assertCountEqual(glob.iglob(pattern, root_dir=self.tempdir, **kwargs), res2) + btempdir = os.fsencode(self.tempdir) + self.assertCountEqual( + glob.glob(bpattern, root_dir=btempdir, **kwargs), bres2) + self.assertCountEqual( + glob.iglob(bpattern, root_dir=btempdir, **kwargs), bres2) + + if self.dir_fd is not None: + self.assertCountEqual( + glob.glob(pattern, dir_fd=self.dir_fd, **kwargs), res2) + self.assertCountEqual( + glob.iglob(pattern, dir_fd=self.dir_fd, **kwargs), res2) + self.assertCountEqual( + glob.glob(bpattern, dir_fd=self.dir_fd, **kwargs), bres2) + self.assertCountEqual( + glob.iglob(bpattern, dir_fd=self.dir_fd, **kwargs), bres2) + return res def assertSequencesEqual_noorder(self, l1, l2): @@ -78,6 +120,14 @@ class GlobTests(unittest.TestCase): res = glob.glob(os.path.join(os.fsencode(os.curdir), b'*')) self.assertEqual({type(r) for r in res}, {bytes}) + def test_glob_empty_pattern(self): + self.assertEqual(glob.glob(''), []) + self.assertEqual(glob.glob(b''), []) + self.assertEqual(glob.glob('', root_dir=self.tempdir), []) + self.assertEqual(glob.glob(b'', root_dir=os.fsencode(self.tempdir)), []) + self.assertEqual(glob.glob('', dir_fd=self.dir_fd), []) + self.assertEqual(glob.glob(b'', dir_fd=self.dir_fd), []) + def test_glob_one_directory(self): eq = self.assertSequencesEqual_noorder eq(self.glob('a*'), map(self.norm, ['a', 'aab', 'aaa'])) @@ -264,6 +314,23 @@ class GlobTests(unittest.TestCase): expect += [join('sym3', 'EF')] eq(glob.glob(join('**', 'EF'), recursive=True), expect) + def test_glob_many_open_files(self): + depth = 30 + base = os.path.join(self.tempdir, 'deep') + p = os.path.join(base, *(['d']*depth)) + os.makedirs(p) + pattern = os.path.join(base, *(['*']*depth)) + iters = [glob.iglob(pattern, recursive=True) for j in range(100)] + for it in iters: + self.assertEqual(next(it), p) + pattern = os.path.join(base, '**', 'd') + iters = [glob.iglob(pattern, recursive=True) for j in range(100)] + p = base + for i in range(depth): + p = os.path.join(p, 'd') + for it in iters: + self.assertEqual(next(it), p) + @skip_unless_symlink class SymlinkLoopGlobTests(unittest.TestCase): diff --git a/Lib/test/test_global.py b/Lib/test/test_global.py index c71d05529..d0bde3fd0 100644 --- a/Lib/test/test_global.py +++ b/Lib/test/test_global.py @@ -1,6 +1,6 @@ """Verify that warnings are issued for global statements following use.""" -from test.support import run_unittest, check_syntax_error +from test.support import check_syntax_error from test.support.warnings_helper import check_warnings import unittest import warnings @@ -53,10 +53,12 @@ x = 2 compile(prog_text_4, "", "exec") -def test_main(): - with warnings.catch_warnings(): - warnings.filterwarnings("error", module="") - run_unittest(GlobalTests) +def setUpModule(): + cm = warnings.catch_warnings() + cm.__enter__() + unittest.addModuleCleanup(cm.__exit__, None, None, None) + warnings.filterwarnings("error", module="") + if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_hash.py b/Lib/test/test_hash.py new file mode 100644 index 000000000..91a29da8a --- /dev/null +++ b/Lib/test/test_hash.py @@ -0,0 +1,358 @@ +# test the invariant that +# iff a==b then hash(a)==hash(b) +# +# Also test that hash implementations are inherited as expected + +import datetime +import os +import sys +import unittest +from test.support.script_helper import assert_python_ok +from collections.abc import Hashable + +IS_64BIT = sys.maxsize > 2**32 + +def lcg(x, length=16): + """Linear congruential generator""" + if x == 0: + return bytes(length) + out = bytearray(length) + for i in range(length): + x = (214013 * x + 2531011) & 0x7fffffff + out[i] = (x >> 16) & 0xff + return bytes(out) + +def pysiphash(uint64): + """Convert SipHash24 output to Py_hash_t + """ + assert 0 <= uint64 < (1 << 64) + # simple unsigned to signed int64 + if uint64 > (1 << 63) - 1: + int64 = uint64 - (1 << 64) + else: + int64 = uint64 + # mangle uint64 to uint32 + uint32 = (uint64 ^ uint64 >> 32) & 0xffffffff + # simple unsigned to signed int32 + if uint32 > (1 << 31) - 1: + int32 = uint32 - (1 << 32) + else: + int32 = uint32 + return int32, int64 + +def skip_unless_internalhash(test): + """Skip decorator for tests that depend on SipHash24 or FNV""" + ok = sys.hash_info.algorithm in {"fnv", "siphash24"} + msg = "Requires SipHash24 or FNV" + return test if ok else unittest.skip(msg)(test) + + +class HashEqualityTestCase(unittest.TestCase): + + def same_hash(self, *objlist): + # Hash each object given and fail if + # the hash values are not all the same. + hashed = list(map(hash, objlist)) + for h in hashed[1:]: + if h != hashed[0]: + self.fail("hashed values differ: %r" % (objlist,)) + + def test_numeric_literals(self): + self.same_hash(1, 1, 1.0, 1.0+0.0j) + self.same_hash(0, 0.0, 0.0+0.0j) + self.same_hash(-1, -1.0, -1.0+0.0j) + self.same_hash(-2, -2.0, -2.0+0.0j) + + def test_coerced_integers(self): + self.same_hash(int(1), int(1), float(1), complex(1), + int('1'), float('1.0')) + self.same_hash(int(-2**31), float(-2**31)) + self.same_hash(int(1-2**31), float(1-2**31)) + self.same_hash(int(2**31-1), float(2**31-1)) + # for 64-bit platforms + self.same_hash(int(2**31), float(2**31)) + self.same_hash(int(-2**63), float(-2**63)) + self.same_hash(int(2**63), float(2**63)) + + def test_coerced_floats(self): + self.same_hash(int(1.23e300), float(1.23e300)) + self.same_hash(float(0.5), complex(0.5, 0.0)) + + def test_unaligned_buffers(self): + # The hash function for bytes-like objects shouldn't have + # alignment-dependent results (example in issue #16427). + b = b"123456789abcdefghijklmnopqrstuvwxyz" * 128 + for i in range(16): + for j in range(16): + aligned = b[i:128+j] + unaligned = memoryview(b)[i:128+j] + self.assertEqual(hash(aligned), hash(unaligned)) + + +_default_hash = object.__hash__ +class DefaultHash(object): pass + +_FIXED_HASH_VALUE = 42 +class FixedHash(object): + def __hash__(self): + return _FIXED_HASH_VALUE + +class OnlyEquality(object): + def __eq__(self, other): + return self is other + +class OnlyInequality(object): + def __ne__(self, other): + return self is not other + +class InheritedHashWithEquality(FixedHash, OnlyEquality): pass +class InheritedHashWithInequality(FixedHash, OnlyInequality): pass + +class NoHash(object): + __hash__ = None + +class HashInheritanceTestCase(unittest.TestCase): + default_expected = [object(), + DefaultHash(), + OnlyInequality(), + ] + fixed_expected = [FixedHash(), + InheritedHashWithEquality(), + InheritedHashWithInequality(), + ] + error_expected = [NoHash(), + OnlyEquality(), + ] + + def test_default_hash(self): + for obj in self.default_expected: + self.assertEqual(hash(obj), _default_hash(obj)) + + def test_fixed_hash(self): + for obj in self.fixed_expected: + self.assertEqual(hash(obj), _FIXED_HASH_VALUE) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_error_hash(self): + for obj in self.error_expected: + self.assertRaises(TypeError, hash, obj) + + def test_hashable(self): + objects = (self.default_expected + + self.fixed_expected) + for obj in objects: + self.assertIsInstance(obj, Hashable) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_not_hashable(self): + for obj in self.error_expected: + self.assertNotIsInstance(obj, Hashable) + + +# Issue #4701: Check that some builtin types are correctly hashable +class DefaultIterSeq(object): + seq = range(10) + def __len__(self): + return len(self.seq) + def __getitem__(self, index): + return self.seq[index] + +class HashBuiltinsTestCase(unittest.TestCase): + hashes_to_check = [enumerate(range(10)), + iter(DefaultIterSeq()), + iter(lambda: 0, 0), + ] + + def test_hashes(self): + _default_hash = object.__hash__ + for obj in self.hashes_to_check: + self.assertEqual(hash(obj), _default_hash(obj)) + +class HashRandomizationTests: + + # Each subclass should define a field "repr_", containing the repr() of + # an object to be tested + + def get_hash_command(self, repr_): + return 'print(hash(eval(%a)))' % repr_ + + def get_hash(self, repr_, seed=None): + env = os.environ.copy() + env['__cleanenv'] = True # signal to assert_python not to do a copy + # of os.environ on its own + if seed is not None: + env['PYTHONHASHSEED'] = str(seed) + else: + env.pop('PYTHONHASHSEED', None) + out = assert_python_ok( + '-c', self.get_hash_command(repr_), + **env) + stdout = out[1].strip() + return int(stdout) + + def test_randomized_hash(self): + # two runs should return different hashes + run1 = self.get_hash(self.repr_, seed='random') + run2 = self.get_hash(self.repr_, seed='random') + self.assertNotEqual(run1, run2) + +class StringlikeHashRandomizationTests(HashRandomizationTests): + repr_ = None + repr_long = None + + # 32bit little, 64bit little, 32bit big, 64bit big + known_hashes = { + 'djba33x': [ # only used for small strings + # seed 0, 'abc' + [193485960, 193485960, 193485960, 193485960], + # seed 42, 'abc' + [-678966196, 573763426263223372, -820489388, -4282905804826039665], + ], + 'siphash24': [ + # NOTE: PyUCS2 layout depends on endianness + # seed 0, 'abc' + [1198583518, 4596069200710135518, 1198583518, 4596069200710135518], + # seed 42, 'abc' + [273876886, -4501618152524544106, 273876886, -4501618152524544106], + # seed 42, 'abcdefghijk' + [-1745215313, 4436719588892876975, -1745215313, 4436719588892876975], + # seed 0, 'äú∑ℇ' + [493570806, 5749986484189612790, -1006381564, -5915111450199468540], + # seed 42, 'äú∑ℇ' + [-1677110816, -2947981342227738144, -1860207793, -4296699217652516017], + ], + 'fnv': [ + # seed 0, 'abc' + [-1600925533, 1453079729188098211, -1600925533, + 1453079729188098211], + # seed 42, 'abc' + [-206076799, -4410911502303878509, -1024014457, + -3570150969479994130], + # seed 42, 'abcdefghijk' + [811136751, -5046230049376118746, -77208053 , + -4779029615281019666], + # seed 0, 'äú∑ℇ' + [44402817, 8998297579845987431, -1956240331, + -782697888614047887], + # seed 42, 'äú∑ℇ' + [-283066365, -4576729883824601543, -271871407, + -3927695501187247084], + ] + } + + def get_expected_hash(self, position, length): + if length < sys.hash_info.cutoff: + algorithm = "djba33x" + else: + algorithm = sys.hash_info.algorithm + if sys.byteorder == 'little': + platform = 1 if IS_64BIT else 0 + else: + assert(sys.byteorder == 'big') + platform = 3 if IS_64BIT else 2 + return self.known_hashes[algorithm][position][platform] + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_null_hash(self): + # PYTHONHASHSEED=0 disables the randomized hash + known_hash_of_obj = self.get_expected_hash(0, 3) + + # Randomization is enabled by default: + self.assertNotEqual(self.get_hash(self.repr_), known_hash_of_obj) + + # It can also be disabled by setting the seed to 0: + self.assertEqual(self.get_hash(self.repr_, seed=0), known_hash_of_obj) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + @skip_unless_internalhash + def test_fixed_hash(self): + # test a fixed seed for the randomized hash + # Note that all types share the same values: + h = self.get_expected_hash(1, 3) + self.assertEqual(self.get_hash(self.repr_, seed=42), h) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + @skip_unless_internalhash + def test_long_fixed_hash(self): + if self.repr_long is None: + return + h = self.get_expected_hash(2, 11) + self.assertEqual(self.get_hash(self.repr_long, seed=42), h) + + +class StrHashRandomizationTests(StringlikeHashRandomizationTests, + unittest.TestCase): + repr_ = repr('abc') + repr_long = repr('abcdefghijk') + repr_ucs2 = repr('äú∑ℇ') + + @skip_unless_internalhash + def test_empty_string(self): + self.assertEqual(hash(""), 0) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + @skip_unless_internalhash + def test_ucs2_string(self): + h = self.get_expected_hash(3, 6) + self.assertEqual(self.get_hash(self.repr_ucs2, seed=0), h) + h = self.get_expected_hash(4, 6) + self.assertEqual(self.get_hash(self.repr_ucs2, seed=42), h) + +class BytesHashRandomizationTests(StringlikeHashRandomizationTests, + unittest.TestCase): + repr_ = repr(b'abc') + repr_long = repr(b'abcdefghijk') + + @skip_unless_internalhash + def test_empty_string(self): + self.assertEqual(hash(b""), 0) + +class MemoryviewHashRandomizationTests(StringlikeHashRandomizationTests, + unittest.TestCase): + repr_ = "memoryview(b'abc')" + repr_long = "memoryview(b'abcdefghijk')" + + @skip_unless_internalhash + def test_empty_string(self): + self.assertEqual(hash(memoryview(b"")), 0) + +class DatetimeTests(HashRandomizationTests): + def get_hash_command(self, repr_): + return 'import datetime; print(hash(%s))' % repr_ + +class DatetimeDateTests(DatetimeTests, unittest.TestCase): + repr_ = repr(datetime.date(1066, 10, 14)) + +class DatetimeDatetimeTests(DatetimeTests, unittest.TestCase): + repr_ = repr(datetime.datetime(1, 2, 3, 4, 5, 6, 7)) + +class DatetimeTimeTests(DatetimeTests, unittest.TestCase): + repr_ = repr(datetime.time(0)) + + +class HashDistributionTestCase(unittest.TestCase): + + def test_hash_distribution(self): + # check for hash collision + base = "abcdefghabcdefg" + for i in range(1, len(base)): + prefix = base[:i] + with self.subTest(prefix=prefix): + s15 = set() + s255 = set() + for c in range(256): + h = hash(prefix + chr(c)) + s15.add(h & 0xf) + s255.add(h & 0xff) + # SipHash24 distribution depends on key, usually > 60% + self.assertGreater(len(s15), 8, prefix) + self.assertGreater(len(s255), 128, prefix) + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py index 054e0ab95..d61b27bde 100644 --- a/Lib/test/test_hashlib.py +++ b/Lib/test/test_hashlib.py @@ -21,6 +21,7 @@ from test import support from test.support import _4G, bigmemtest from test.support.import_helper import import_fresh_module from test.support import threading_helper +from test.support import warnings_helper from http.client import HTTPException # Were we compiled --with-pydebug or with #define Py_DEBUG? @@ -47,12 +48,15 @@ else: builtin_hashlib = None try: - from _hashlib import HASH, HASHXOF, openssl_md_meth_names + from _hashlib import HASH, HASHXOF, openssl_md_meth_names, get_fips_mode except ImportError: HASH = None HASHXOF = None openssl_md_meth_names = frozenset() + def get_fips_mode(): + return 0 + try: import _blake2 except ImportError: @@ -60,12 +64,9 @@ except ImportError: requires_blake2 = unittest.skipUnless(_blake2, 'requires _blake2') -try: - import _sha3 -except ImportError: - _sha3 = None - -requires_sha3 = unittest.skipUnless(_sha3, 'requires _sha3') +# bpo-46913: Don't test the _sha3 extension on a Python UBSAN build +SKIP_SHA3 = support.check_sanitizer(ub=True) +requires_sha3 = unittest.skipUnless(not SKIP_SHA3, 'requires _sha3') def hexstr(s): @@ -79,11 +80,10 @@ def hexstr(s): URL = "http://www.pythontest.net/hashlib/{}.txt" - def read_vectors(hash_name): url = URL.format(hash_name) try: - testdata = support.open_urlresource(url) + testdata = support.open_urlresource(url, encoding="utf-8") except (OSError, HTTPException): raise unittest.SkipTest("Could not retrieve {}".format(url)) with testdata: @@ -97,13 +97,13 @@ def read_vectors(hash_name): class HashLibTestCase(unittest.TestCase): - supported_hash_names = ('md5', 'MD5', 'sha1', 'SHA1', - 'sha224', 'SHA224', 'sha256', 'SHA256', - 'sha384', 'SHA384', 'sha512', 'SHA512', - 'blake2b', 'blake2s', - 'sha3_224', 'sha3_256', 'sha3_384', 'sha3_512', - # TODO: RUSTPYTHON - # 'shake_128', 'shake_256' + supported_hash_names = ( 'md5', 'MD5', 'sha1', 'SHA1', + 'sha224', 'SHA224', 'sha256', 'SHA256', + 'sha384', 'SHA384', 'sha512', 'SHA512', + 'blake2b', 'blake2s', + 'sha3_224', 'sha3_256', 'sha3_384', 'sha3_512', + # TODO: RUSTPYTHON + # 'shake_128', 'shake_256' ) shakes = {'shake_128', 'shake_256'} @@ -131,13 +131,14 @@ class HashLibTestCase(unittest.TestCase): self.constructors_to_test = {} for algorithm in algorithms: + if SKIP_SHA3 and algorithm.startswith('sha3_'): + continue self.constructors_to_test[algorithm] = set() # For each algorithm, test the direct constructor and the use # of hashlib.new given the algorithm name. for algorithm, constructors in self.constructors_to_test.items(): constructors.add(getattr(hashlib, algorithm)) - def _test_algorithm_via_hashlib_new(data=None, _alg=algorithm, **kwargs): if data is None: return hashlib.new(_alg, **kwargs) @@ -184,14 +185,15 @@ class HashLibTestCase(unittest.TestCase): add_builtin_constructor('blake2s') add_builtin_constructor('blake2b') - _sha3 = self._conditional_import_module('_sha3') - if _sha3: - add_builtin_constructor('sha3_224') - add_builtin_constructor('sha3_256') - add_builtin_constructor('sha3_384') - add_builtin_constructor('sha3_512') - add_builtin_constructor('shake_128') - add_builtin_constructor('shake_256') + if not SKIP_SHA3: + _sha3 = self._conditional_import_module('_sha3') + if _sha3: + add_builtin_constructor('sha3_224') + add_builtin_constructor('sha3_256') + add_builtin_constructor('sha3_384') + add_builtin_constructor('sha3_512') + add_builtin_constructor('shake_128') + add_builtin_constructor('shake_256') super(HashLibTestCase, self).__init__(*args, **kwargs) @@ -202,10 +204,7 @@ class HashLibTestCase(unittest.TestCase): @property def is_fips_mode(self): - if hasattr(self._hashlib, "get_fips_mode"): - return self._hashlib.get_fips_mode() - else: - return None + return get_fips_mode() # TODO: RUSTPYTHON @unittest.expectedFailure @@ -222,14 +221,18 @@ class HashLibTestCase(unittest.TestCase): @unittest.expectedFailure def test_algorithms_guaranteed(self): self.assertEqual(hashlib.algorithms_guaranteed, - set(_algo for _algo in self.supported_hash_names - if _algo.islower())) + set(_algo for _algo in self.supported_hash_names + if _algo.islower())) # TODO: RUSTPYTHON @unittest.expectedFailure def test_algorithms_available(self): self.assertTrue(set(hashlib.algorithms_guaranteed). - issubset(hashlib.algorithms_available)) + issubset(hashlib.algorithms_available)) + # all available algorithms must be loadable, bpo-47101 + self.assertNotIn("undefined", hashlib.algorithms_available) + for name in hashlib.algorithms_available: + digest = hashlib.new(name, usedforsecurity=False) def test_usedforsecurity_true(self): hashlib.new("sha256", usedforsecurity=True) @@ -337,7 +340,7 @@ class HashLibTestCase(unittest.TestCase): aas = b'a' * 128 bees = b'b' * 127 cees = b'c' * 126 - dees = b'd' * 2048 # HASHLIB_GIL_MINSIZE + dees = b'd' * 2048 # HASHLIB_GIL_MINSIZE for cons in self.hash_constructors: m1 = cons(usedforsecurity=False) @@ -375,11 +378,11 @@ class HashLibTestCase(unittest.TestCase): m = hash_object_constructor(data, **kwargs) computed = m.hexdigest() if not shake else m.hexdigest(length) self.assertEqual( - computed, hexdigest, - "Hash algorithm %s constructed using %s returned hexdigest" - " %r for %d byte input data that should have hashed to %r." - % (name, hash_object_constructor, - computed, len(data), hexdigest)) + computed, hexdigest, + "Hash algorithm %s constructed using %s returned hexdigest" + " %r for %d byte input data that should have hashed to %r." + % (name, hash_object_constructor, + computed, len(data), hexdigest)) computed = m.digest() if not shake else m.digest(length) digest = bytes.fromhex(hexdigest) self.assertEqual(computed, digest) @@ -405,6 +408,8 @@ class HashLibTestCase(unittest.TestCase): self.check_no_unicode('blake2b') self.check_no_unicode('blake2s') + # TODO: RUSTPYTHON + @unittest.expectedFailure @requires_sha3 def test_no_unicode_sha3(self): self.check_no_unicode('sha3_224') @@ -466,6 +471,8 @@ class HashLibTestCase(unittest.TestCase): self.assertEqual(m._rate_bits, rate) self.assertEqual(m._suffix, suffix) + # TODO: RUSTPYTHON + @unittest.expectedFailure @requires_sha3 def test_extra_sha3(self): self.check_sha3('sha3_224', 448, 1152, b'\x06') @@ -531,87 +538,91 @@ class HashLibTestCase(unittest.TestCase): self.check('sha1', b"a" * 1000000, "34aa973cd4c4daa4f61eeb2bdbad27316534016f") + # use the examples from Federal Information Processing Standards # Publication 180-2, Secure Hash Standard, 2002 August 1 # http://csrc.nist.gov/publications/fips/fips180-2/fips180-2.pdf def test_case_sha224_0(self): self.check('sha224', b"", - "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f") + "d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f") def test_case_sha224_1(self): self.check('sha224', b"abc", - "23097d223405d8228642a477bda255b32aadbce4bda0b3f7e36c9da7") + "23097d223405d8228642a477bda255b32aadbce4bda0b3f7e36c9da7") def test_case_sha224_2(self): self.check('sha224', - b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", - "75388b16512776cc5dba5da1fd890150b0c6455cb4f58b1952522525") + b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", + "75388b16512776cc5dba5da1fd890150b0c6455cb4f58b1952522525") def test_case_sha224_3(self): self.check('sha224', b"a" * 1000000, - "20794655980c91d8bbb4c1ea97618a4bf03f42581948b2ee4ee7ad67") + "20794655980c91d8bbb4c1ea97618a4bf03f42581948b2ee4ee7ad67") + def test_case_sha256_0(self): self.check('sha256', b"", - "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") def test_case_sha256_1(self): self.check('sha256', b"abc", - "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") def test_case_sha256_2(self): self.check('sha256', - b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", - "248d6a61d20638b8e5c026930c3e6039a33ce45964ff2167f6ecedd419db06c1") + b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq", + "248d6a61d20638b8e5c026930c3e6039a33ce45964ff2167f6ecedd419db06c1") def test_case_sha256_3(self): self.check('sha256', b"a" * 1000000, - "cdc76e5c9914fb9281a1c7e284d73e67f1809a48a497200e046d39ccc7112cd0") + "cdc76e5c9914fb9281a1c7e284d73e67f1809a48a497200e046d39ccc7112cd0") + def test_case_sha384_0(self): self.check('sha384', b"", - "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da" + - "274edebfe76f65fbd51ad2f14898b95b") + "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da"+ + "274edebfe76f65fbd51ad2f14898b95b") def test_case_sha384_1(self): self.check('sha384', b"abc", - "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed" + - "8086072ba1e7cc2358baeca134c825a7") + "cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed"+ + "8086072ba1e7cc2358baeca134c825a7") def test_case_sha384_2(self): self.check('sha384', - b"abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmn" + + b"abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmn"+ b"hijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu", - "09330c33f71147e83d192fc782cd1b4753111b173b3b05d22fa08086e3b0f712" + - "fcc7c71a557e2db966c3e9fa91746039") + "09330c33f71147e83d192fc782cd1b4753111b173b3b05d22fa08086e3b0f712"+ + "fcc7c71a557e2db966c3e9fa91746039") def test_case_sha384_3(self): self.check('sha384', b"a" * 1000000, - "9d0e1809716474cb086e834e310a4a1ced149e9c00f248527972cec5704c2a5b" + - "07b8b3dc38ecc4ebae97ddd87f3d8985") + "9d0e1809716474cb086e834e310a4a1ced149e9c00f248527972cec5704c2a5b"+ + "07b8b3dc38ecc4ebae97ddd87f3d8985") + def test_case_sha512_0(self): self.check('sha512', b"", - "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce" + - "47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e") + "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce"+ + "47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e") def test_case_sha512_1(self): self.check('sha512', b"abc", - "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a" + - "2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f") + "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a"+ + "2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f") def test_case_sha512_2(self): self.check('sha512', - b"abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmn" + + b"abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmn"+ b"hijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu", - "8e959b75dae313da8cf4f72814fc143f8f7779c6eb9f7fa17299aeadb6889018" + - "501d289e4900f7e4331b99dec4b5433ac7d329eeb6dd26545e96e55b874be909") + "8e959b75dae313da8cf4f72814fc143f8f7779c6eb9f7fa17299aeadb6889018"+ + "501d289e4900f7e4331b99dec4b5433ac7d329eeb6dd26545e96e55b874be909") def test_case_sha512_3(self): self.check('sha512', b"a" * 1000000, - "e718483d0ce769644e2e42c7bc15b4638e1f98b13b2044285632a803afa973eb" + - "de0ff244877ea60a4cb0432ce577c31beb009c5c2c49aa2e4eadb217ad8cc09b") + "e718483d0ce769644e2e42c7bc15b4638e1f98b13b2044285632a803afa973eb"+ + "de0ff244877ea60a4cb0432ce577c31beb009c5c2c49aa2e4eadb217ad8cc09b") def check_blake2(self, constructor, salt_size, person_size, key_size, digest_size, max_offset): @@ -664,9 +675,9 @@ class HashLibTestCase(unittest.TestCase): self.assertRaises(ValueError, constructor, inner_size=digest_size+1) constructor(leaf_size=0) - constructor(leaf_size=(1 << 32)-1) + constructor(leaf_size=(1<<32)-1) self.assertRaises(ValueError, constructor, leaf_size=-1) - self.assertRaises(OverflowError, constructor, leaf_size=1 << 32) + self.assertRaises(OverflowError, constructor, leaf_size=1<<32) constructor(node_offset=0) constructor(node_offset=max_offset) @@ -694,7 +705,7 @@ class HashLibTestCase(unittest.TestCase): def blake2_rfc7693(self, constructor, md_len, in_len): def selftest_seq(length, seed): - mask = (1 << 32)-1 + mask = (1<<32)-1 a = (0xDEAD4BAD * seed) & mask b = 1 out = bytearray(length) @@ -716,7 +727,7 @@ class HashLibTestCase(unittest.TestCase): @requires_blake2 def test_blake2b(self): - self.check_blake2(hashlib.blake2b, 16, 16, 64, 64, (1 << 64)-1) + self.check_blake2(hashlib.blake2b, 16, 16, 64, 64, (1<<64)-1) b2b_md_len = [20, 32, 48, 64] b2b_in_len = [0, 3, 128, 129, 255, 1024] self.assertEqual( @@ -726,32 +737,32 @@ class HashLibTestCase(unittest.TestCase): @requires_blake2 def test_case_blake2b_0(self): self.check('blake2b', b"", - "786a02f742015903c6c6fd852552d272912f4740e15847618a86e217f71f5419" + - "d25e1031afee585313896444934eb04b903a685b1448b755d56f701afe9be2ce") + "786a02f742015903c6c6fd852552d272912f4740e15847618a86e217f71f5419"+ + "d25e1031afee585313896444934eb04b903a685b1448b755d56f701afe9be2ce") @requires_blake2 def test_case_blake2b_1(self): self.check('blake2b', b"abc", - "ba80a53f981c4d0d6a2797b69f12f6e94c212f14685ac4b74b12bb6fdbffa2d1" + - "7d87c5392aab792dc252d5de4533cc9518d38aa8dbf1925ab92386edd4009923") + "ba80a53f981c4d0d6a2797b69f12f6e94c212f14685ac4b74b12bb6fdbffa2d1"+ + "7d87c5392aab792dc252d5de4533cc9518d38aa8dbf1925ab92386edd4009923") @requires_blake2 def test_case_blake2b_all_parameters(self): # This checks that all the parameters work in general, and also that # parameter byte order doesn't get confused on big endian platforms. self.check('blake2b', b"foo", - "920568b0c5873b2f0ab67bedb6cf1b2b", - digest_size=16, - key=b"bar", - salt=b"baz", - person=b"bing", - fanout=2, - depth=3, - leaf_size=4, - node_offset=5, - node_depth=6, - inner_size=7, - last_node=True) + "920568b0c5873b2f0ab67bedb6cf1b2b", + digest_size=16, + key=b"bar", + salt=b"baz", + person=b"bing", + fanout=2, + depth=3, + leaf_size=4, + node_offset=5, + node_depth=6, + inner_size=7, + last_node=True) @requires_blake2 def test_blake2b_vectors(self): @@ -761,7 +772,7 @@ class HashLibTestCase(unittest.TestCase): @requires_blake2 def test_blake2s(self): - self.check_blake2(hashlib.blake2s, 8, 8, 32, 32, (1 << 48)-1) + self.check_blake2(hashlib.blake2s, 8, 8, 32, 32, (1<<48)-1) b2s_md_len = [16, 20, 28, 32] b2s_in_len = [0, 3, 64, 65, 255, 1024] self.assertEqual( @@ -771,30 +782,30 @@ class HashLibTestCase(unittest.TestCase): @requires_blake2 def test_case_blake2s_0(self): self.check('blake2s', b"", - "69217a3079908094e11121d042354a7c1f55b6482ca1a51e1b250dfd1ed0eef9") + "69217a3079908094e11121d042354a7c1f55b6482ca1a51e1b250dfd1ed0eef9") @requires_blake2 def test_case_blake2s_1(self): self.check('blake2s', b"abc", - "508c5e8c327c14e2e1a72ba34eeb452f37458b209ed63a294d999b4c86675982") + "508c5e8c327c14e2e1a72ba34eeb452f37458b209ed63a294d999b4c86675982") @requires_blake2 def test_case_blake2s_all_parameters(self): # This checks that all the parameters work in general, and also that # parameter byte order doesn't get confused on big endian platforms. self.check('blake2s', b"foo", - "bf2a8f7fe3c555012a6f8046e646bc75", - digest_size=16, - key=b"bar", - salt=b"baz", - person=b"bing", - fanout=2, - depth=3, - leaf_size=4, - node_offset=5, - node_depth=6, - inner_size=7, - last_node=True) + "bf2a8f7fe3c555012a6f8046e646bc75", + digest_size=16, + key=b"bar", + salt=b"baz", + person=b"bing", + fanout=2, + depth=3, + leaf_size=4, + node_offset=5, + node_depth=6, + inner_size=7, + last_node=True) @requires_blake2 def test_blake2s_vectors(self): @@ -805,7 +816,7 @@ class HashLibTestCase(unittest.TestCase): @requires_sha3 def test_case_sha3_224_0(self): self.check('sha3_224', b"", - "6b4e03423667dbb73b6e15454f0eb1abd4597f9a1b078e3f5b5a6bc7") + "6b4e03423667dbb73b6e15454f0eb1abd4597f9a1b078e3f5b5a6bc7") @requires_sha3 def test_case_sha3_224_vector(self): @@ -815,7 +826,7 @@ class HashLibTestCase(unittest.TestCase): @requires_sha3 def test_case_sha3_256_0(self): self.check('sha3_256', b"", - "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a") + "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a") @requires_sha3 def test_case_sha3_256_vector(self): @@ -825,8 +836,8 @@ class HashLibTestCase(unittest.TestCase): @requires_sha3 def test_case_sha3_384_0(self): self.check('sha3_384', b"", - "0c63a75b845e4f7d01107d852e4c2485c51a50aaaa94fc61995e71bbee983a2a" + - "c3713831264adb47fb6bd1e058d5f004") + "0c63a75b845e4f7d01107d852e4c2485c51a50aaaa94fc61995e71bbee983a2a"+ + "c3713831264adb47fb6bd1e058d5f004") @requires_sha3 def test_case_sha3_384_vector(self): @@ -836,36 +847,38 @@ class HashLibTestCase(unittest.TestCase): @requires_sha3 def test_case_sha3_512_0(self): self.check('sha3_512', b"", - "a69f73cca23a9ac5c8b567dc185a756e97c982164fe25859e0d1dcc1475c80a6" + - "15b2123af1f5f94c11e3e9402c3ac558f500199d95b6d3e301758586281dcd26") + "a69f73cca23a9ac5c8b567dc185a756e97c982164fe25859e0d1dcc1475c80a6"+ + "15b2123af1f5f94c11e3e9402c3ac558f500199d95b6d3e301758586281dcd26") @requires_sha3 def test_case_sha3_512_vector(self): for msg, md in read_vectors('sha3_512'): self.check('sha3_512', msg, md) - @requires_sha3 + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_case_shake_128_0(self): self.check('shake_128', b"", - "7f9c2ba4e88f827d616045507605853ed73b8093f6efbc88eb1a6eacfa66ef26", - True) + "7f9c2ba4e88f827d616045507605853ed73b8093f6efbc88eb1a6eacfa66ef26", + True) self.check('shake_128', b"", "7f9c", True) - @requires_sha3 + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_case_shake128_vector(self): for msg, md in read_vectors('shake_128'): self.check('shake_128', msg, md, True) - @requires_sha3 + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_case_shake_256_0(self): self.check('shake_256', b"", - "46b9dd2b0ba88d13233b3feb743eeb243fcd52ea62b81b82b50c27646ed5762f", - True) + "46b9dd2b0ba88d13233b3feb743eeb243fcd52ea62b81b82b50c27646ed5762f", + True) self.check('shake_256', b"", "46b9", True) # TODO: RUSTPYTHON @unittest.expectedFailure - @requires_sha3 def test_case_shake256_vector(self): for msg, md in read_vectors('shake_256'): self.check('shake_256', msg, md, True) @@ -940,17 +953,42 @@ class HashLibTestCase(unittest.TestCase): if fips_mode is not None: self.assertIsInstance(fips_mode, int) + @support.cpython_only + def test_disallow_instantiation(self): + for algorithm, constructors in self.constructors_to_test.items(): + if algorithm.startswith(("sha3_", "shake", "blake")): + # _sha3 and _blake types can be instantiated + continue + # all other types have DISALLOW_INSTANTIATION + for constructor in constructors: + # In FIPS mode some algorithms are not available raising ValueError + try: + h = constructor() + except ValueError: + continue + with self.subTest(constructor=constructor): + support.check_disallow_instantiation(self, type(h)) + @unittest.skipUnless(HASH is not None, 'need _hashlib') - def test_internal_types(self): + def test_hash_disallow_instantiation(self): # internal types like _hashlib.HASH are not constructable - with self.assertRaisesRegex( - TypeError, "cannot create 'HASH' instance" - ): - HASH() - with self.assertRaisesRegex( - TypeError, "cannot create 'HASHXOF' instance" - ): - HASHXOF() + support.check_disallow_instantiation(self, HASH) + support.check_disallow_instantiation(self, HASHXOF) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_readonly_types(self): + for algorithm, constructors in self.constructors_to_test.items(): + # all other types have DISALLOW_INSTANTIATION + for constructor in constructors: + # In FIPS mode some algorithms are not available raising ValueError + try: + hash_type = type(constructor()) + except ValueError: + continue + with self.subTest(hash_type=hash_type): + with self.assertRaisesRegex(TypeError, "immutable type"): + hash_type.value = False class KDFTests(unittest.TestCase): @@ -967,13 +1005,10 @@ class KDFTests(unittest.TestCase): ] scrypt_test_vectors = [ - (b'', b'', 16, 1, 1, unhexlify( - '77d6576238657b203b19ca42c18a0497f16b4844e3074ae8dfdffa3fede21442fcd0069ded0948f8326a753a0fc81f17e8d3e0fb2e0d3628cf35e20c38d18906')), - (b'password', b'NaCl', 1024, 8, 16, unhexlify( - 'fdbabe1c9d3472007856e7190d01e9fe7c6ad7cbc8237830e77376634b3731622eaf30d92e22a3886ff109279d9830dac727afb94a83ee6d8360cbdfa2cc0640')), - (b'pleaseletmein', b'SodiumChloride', 16384, 8, 1, unhexlify( - '7023bdcb3afd7348461c06cd81fd38ebfda8fbba904f8e3ea9b543f6545da1f2d5432955613f0fcf62d49705242a9af9e61e85dc0d651e40dfcf017b45575887')), - ] + (b'', b'', 16, 1, 1, unhexlify('77d6576238657b203b19ca42c18a0497f16b4844e3074ae8dfdffa3fede21442fcd0069ded0948f8326a753a0fc81f17e8d3e0fb2e0d3628cf35e20c38d18906')), + (b'password', b'NaCl', 1024, 8, 16, unhexlify('fdbabe1c9d3472007856e7190d01e9fe7c6ad7cbc8237830e77376634b3731622eaf30d92e22a3886ff109279d9830dac727afb94a83ee6d8360cbdfa2cc0640')), + (b'pleaseletmein', b'SodiumChloride', 16384, 8, 1, unhexlify('7023bdcb3afd7348461c06cd81fd38ebfda8fbba904f8e3ea9b543f6545da1f2d5432955613f0fcf62d49705242a9af9e61e85dc0d651e40dfcf017b45575887')), + ] pbkdf2_results = { "sha1": [ @@ -984,7 +1019,7 @@ class KDFTests(unittest.TestCase): #(bytes.fromhex('eefe3d61cd4da4e4e9945b3d6ba2158c2634e984'), None), (bytes.fromhex('3d2eec4fe41c849b80c8d83662c0e44a8b291a964c' 'f2f07038'), 25), - (bytes.fromhex('56fa6aa75548099dcc37d7f03425e0c3'), None), ], + (bytes.fromhex('56fa6aa75548099dcc37d7f03425e0c3'), None),], "sha256": [ (bytes.fromhex('120fb6cffcf8b32c43e7225256c4f837' 'a86548c92ccc35480805987cb70be17b'), None), @@ -992,11 +1027,11 @@ class KDFTests(unittest.TestCase): '2a303f8ef3c251dfd6e2d85a95474c43'), None), (bytes.fromhex('c5e478d59288c841aa530db6845c4c8d' '962893a001ce4e11a4963873aa98134a'), None), - # (bytes.fromhex('cf81c66fe8cfc04d1f31ecb65dab4089' + #(bytes.fromhex('cf81c66fe8cfc04d1f31ecb65dab4089' # 'f7f179e89b3b0bcb17ad10e3ac6eba46'), None), (bytes.fromhex('348c89dbcbd32b2f32d814b8116e84cf2b17' '347ebc1800181c4e2a1fb8dd53e1c635518c7dac47e9'), 40), - (bytes.fromhex('89b69d0516f829893c696226650a8687'), None), ], + (bytes.fromhex('89b69d0516f829893c696226650a8687'), None),], "sha512": [ (bytes.fromhex('867f70cf1ade02cff3752599a3a53dc4af34c7a669815ae5' 'd513554e1c8cf252c02d470a285a0501bad999bfe943c08f' @@ -1010,7 +1045,7 @@ class KDFTests(unittest.TestCase): (bytes.fromhex('8c0511f4c6e597c6ac6315d8f0362e225f3c501495ba23b8' '68c005174dc4ee71115b59f9e60cd9532fa33e0f75aefe30' '225c583a186cd82bd4daea9724a3d3b8'), 64), - (bytes.fromhex('9d9e9c4cd21fe4be24d5b8244c759665'), None), ], + (bytes.fromhex('9d9e9c4cd21fe4be24d5b8244c759665'), None),], } def _test_pbkdf2_hmac(self, pbkdf2, supported): @@ -1036,7 +1071,7 @@ class KDFTests(unittest.TestCase): self.assertEqual(out, expected, (digest_name, password, salt, rounds)) - with self.assertRaisesRegex(ValueError, 'unsupported hash type'): + with self.assertRaisesRegex(ValueError, '.*unsupported.*'): pbkdf2('unknown', b'pass', b'salt', 1) if 'sha1' in supported: @@ -1059,22 +1094,26 @@ class KDFTests(unittest.TestCase): ValueError, pbkdf2, 'sha1', b'pass', b'salt', 1, -1 ) out = pbkdf2(hash_name='sha1', password=b'password', salt=b'salt', - iterations=1, dklen=None) + iterations=1, dklen=None) self.assertEqual(out, self.pbkdf2_results['sha1'][0][0]) # TODO: RUSTPYTHON @unittest.expectedFailure @unittest.skipIf(builtin_hashlib is None, "test requires builtin_hashlib") def test_pbkdf2_hmac_py(self): - self._test_pbkdf2_hmac(builtin_hashlib.pbkdf2_hmac, builtin_hashes) + with warnings_helper.check_warnings(): + self._test_pbkdf2_hmac( + builtin_hashlib.pbkdf2_hmac, builtin_hashes + ) @unittest.skipUnless(hasattr(openssl_hashlib, 'pbkdf2_hmac'), - ' test requires OpenSSL > 1.0') + ' test requires OpenSSL > 1.0') def test_pbkdf2_hmac_c(self): self._test_pbkdf2_hmac(openssl_hashlib.pbkdf2_hmac, openssl_md_meth_names) @unittest.skipUnless(hasattr(hashlib, 'scrypt'), - ' test requires OpenSSL > 1.1') + ' test requires OpenSSL > 1.1') + @unittest.skipIf(get_fips_mode(), reason="scrypt is blocked in FIPS mode") def test_scrypt(self): for password, salt, n, r, p, expected in self.scrypt_test_vectors: result = hashlib.scrypt(password, salt=salt, n=n, r=r, p=p) diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py index 98ee090b1..cb1e4505b 100644 --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -146,11 +146,11 @@ class TestHeap: self.assertEqual(type(h[0]), int) self.assertEqual(type(x), float) - h = [10]; + h = [10] x = self.module.heappushpop(h, 9) self.assertEqual((h, x), ([10], 9)) - h = [10]; + h = [10] x = self.module.heappushpop(h, 11) self.assertEqual((h, x), ([11], 10)) @@ -433,6 +433,37 @@ class TestErrorHandling: with self.assertRaises((IndexError, RuntimeError)): self.module.heappop(heap) + def test_comparison_operator_modifiying_heap(self): + # See bpo-39421: Strong references need to be taken + # when comparing objects as they can alter the heap + class EvilClass(int): + def __lt__(self, o): + heap.clear() + return NotImplemented + + heap = [] + self.module.heappush(heap, EvilClass(0)) + self.assertRaises(IndexError, self.module.heappushpop, heap, 1) + + def test_comparison_operator_modifiying_heap_two_heaps(self): + + class h(int): + def __lt__(self, o): + list2.clear() + return NotImplemented + + class g(int): + def __lt__(self, o): + list1.clear() + return NotImplemented + + list1, list2 = [], [] + + self.module.heappush(list1, h(0)) + self.module.heappush(list2, g(0)) + + self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1)) + self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1)) class TestErrorHandlingPython(TestErrorHandling, TestCase): module = py_heapq diff --git a/Lib/test/test_isinstance.py b/Lib/test/test_isinstance.py index bfefe3b97..a8d6ddb90 100644 --- a/Lib/test/test_isinstance.py +++ b/Lib/test/test_isinstance.py @@ -1,295 +1,364 @@ -# Tests some corner cases with isinstance() and issubclass(). While these -# tests use new style classes and properties, they actually do whitebox -# testing of error conditions uncovered when using extension types. - -import unittest -import sys - - - -class TestIsInstanceExceptions(unittest.TestCase): - # Test to make sure that an AttributeError when accessing the instance's - # class's bases is masked. This was actually a bug in Python 2.2 and - # 2.2.1 where the exception wasn't caught but it also wasn't being cleared - # (leading to an "undetected error" in the debug build). Set up is, - # isinstance(inst, cls) where: - # - # - cls isn't a type, or a tuple - # - cls has a __bases__ attribute - # - inst has a __class__ attribute - # - inst.__class__ as no __bases__ attribute - # - # Sounds complicated, I know, but this mimics a situation where an - # extension type raises an AttributeError when its __bases__ attribute is - # gotten. In that case, isinstance() should return False. - def test_class_has_no_bases(self): - class I(object): - def getclass(self): - # This must return an object that has no __bases__ attribute - return None - __class__ = property(getclass) - - class C(object): - def getbases(self): - return () - __bases__ = property(getbases) - - self.assertEqual(False, isinstance(I(), C())) - - # Like above except that inst.__class__.__bases__ raises an exception - # other than AttributeError - def test_bases_raises_other_than_attribute_error(self): - class E(object): - def getbases(self): - raise RuntimeError - __bases__ = property(getbases) - - class I(object): - def getclass(self): - return E() - __class__ = property(getclass) - - class C(object): - def getbases(self): - return () - __bases__ = property(getbases) - - self.assertRaises(RuntimeError, isinstance, I(), C()) - - # Here's a situation where getattr(cls, '__bases__') raises an exception. - # If that exception is not AttributeError, it should not get masked - def test_dont_mask_non_attribute_error(self): - class I: pass - - class C(object): - def getbases(self): - raise RuntimeError - __bases__ = property(getbases) - - self.assertRaises(RuntimeError, isinstance, I(), C()) - - # Like above, except that getattr(cls, '__bases__') raises an - # AttributeError, which /should/ get masked as a TypeError - def test_mask_attribute_error(self): - class I: pass - - class C(object): - def getbases(self): - raise AttributeError - __bases__ = property(getbases) - - self.assertRaises(TypeError, isinstance, I(), C()) - - # check that we don't mask non AttributeErrors - # see: http://bugs.python.org/issue1574217 - def test_isinstance_dont_mask_non_attribute_error(self): - class C(object): - def getclass(self): - raise RuntimeError - __class__ = property(getclass) - - c = C() - self.assertRaises(RuntimeError, isinstance, c, bool) - - # test another code path - class D: pass - self.assertRaises(RuntimeError, isinstance, c, D) - - -# These tests are similar to above, but tickle certain code paths in -# issubclass() instead of isinstance() -- really PyObject_IsSubclass() -# vs. PyObject_IsInstance(). -class TestIsSubclassExceptions(unittest.TestCase): - def test_dont_mask_non_attribute_error(self): - class C(object): - def getbases(self): - raise RuntimeError - __bases__ = property(getbases) - - class S(C): pass - - self.assertRaises(RuntimeError, issubclass, C(), S()) - - def test_mask_attribute_error(self): - class C(object): - def getbases(self): - raise AttributeError - __bases__ = property(getbases) - - class S(C): pass - - self.assertRaises(TypeError, issubclass, C(), S()) - - # Like above, but test the second branch, where the __bases__ of the - # second arg (the cls arg) is tested. This means the first arg must - # return a valid __bases__, and it's okay for it to be a normal -- - # unrelated by inheritance -- class. - def test_dont_mask_non_attribute_error_in_cls_arg(self): - class B: pass - - class C(object): - def getbases(self): - raise RuntimeError - __bases__ = property(getbases) - - self.assertRaises(RuntimeError, issubclass, B, C()) - - def test_mask_attribute_error_in_cls_arg(self): - class B: pass - - class C(object): - def getbases(self): - raise AttributeError - __bases__ = property(getbases) - - self.assertRaises(TypeError, issubclass, B, C()) - - - -# meta classes for creating abstract classes and instances -class AbstractClass(object): - def __init__(self, bases): - self.bases = bases - - def getbases(self): - return self.bases - __bases__ = property(getbases) - - def __call__(self): - return AbstractInstance(self) - -class AbstractInstance(object): - def __init__(self, klass): - self.klass = klass - - def getclass(self): - return self.klass - __class__ = property(getclass) - -# abstract classes -AbstractSuper = AbstractClass(bases=()) - -AbstractChild = AbstractClass(bases=(AbstractSuper,)) - -# normal classes -class Super: - pass - -class Child(Super): - pass - -class TestIsInstanceIsSubclass(unittest.TestCase): - # Tests to ensure that isinstance and issubclass work on abstract - # classes and instances. Before the 2.2 release, TypeErrors were - # raised when boolean values should have been returned. The bug was - # triggered by mixing 'normal' classes and instances were with - # 'abstract' classes and instances. This case tries to test all - # combinations. - - def test_isinstance_normal(self): - # normal instances - self.assertEqual(True, isinstance(Super(), Super)) - self.assertEqual(False, isinstance(Super(), Child)) - self.assertEqual(False, isinstance(Super(), AbstractSuper)) - self.assertEqual(False, isinstance(Super(), AbstractChild)) - - self.assertEqual(True, isinstance(Child(), Super)) - self.assertEqual(False, isinstance(Child(), AbstractSuper)) - - def test_isinstance_abstract(self): - # abstract instances - self.assertEqual(True, isinstance(AbstractSuper(), AbstractSuper)) - self.assertEqual(False, isinstance(AbstractSuper(), AbstractChild)) - self.assertEqual(False, isinstance(AbstractSuper(), Super)) - self.assertEqual(False, isinstance(AbstractSuper(), Child)) - - self.assertEqual(True, isinstance(AbstractChild(), AbstractChild)) - self.assertEqual(True, isinstance(AbstractChild(), AbstractSuper)) - self.assertEqual(False, isinstance(AbstractChild(), Super)) - self.assertEqual(False, isinstance(AbstractChild(), Child)) - - def test_subclass_normal(self): - # normal classes - self.assertEqual(True, issubclass(Super, Super)) - self.assertEqual(False, issubclass(Super, AbstractSuper)) - self.assertEqual(False, issubclass(Super, Child)) - - self.assertEqual(True, issubclass(Child, Child)) - self.assertEqual(True, issubclass(Child, Super)) - self.assertEqual(False, issubclass(Child, AbstractSuper)) - - def test_subclass_abstract(self): - # abstract classes - self.assertEqual(True, issubclass(AbstractSuper, AbstractSuper)) - self.assertEqual(False, issubclass(AbstractSuper, AbstractChild)) - self.assertEqual(False, issubclass(AbstractSuper, Child)) - - self.assertEqual(True, issubclass(AbstractChild, AbstractChild)) - self.assertEqual(True, issubclass(AbstractChild, AbstractSuper)) - self.assertEqual(False, issubclass(AbstractChild, Super)) - self.assertEqual(False, issubclass(AbstractChild, Child)) - - def test_subclass_tuple(self): - # test with a tuple as the second argument classes - self.assertEqual(True, issubclass(Child, (Child,))) - self.assertEqual(True, issubclass(Child, (Super,))) - self.assertEqual(False, issubclass(Super, (Child,))) - self.assertEqual(True, issubclass(Super, (Child, Super))) - self.assertEqual(False, issubclass(Child, ())) - self.assertEqual(True, issubclass(Super, (Child, (Super,)))) - - self.assertEqual(True, issubclass(int, (int, (float, int)))) - self.assertEqual(True, issubclass(str, (str, (Child, str)))) - - def test_subclass_recursion_limit(self): - # make sure that issubclass raises RecursionError before the C stack is - # blown - self.assertRaises(RecursionError, blowstack, issubclass, str, str) - - def test_isinstance_recursion_limit(self): - # make sure that issubclass raises RecursionError before the C stack is - # blown - self.assertRaises(RecursionError, blowstack, isinstance, '', str) - - def test_issubclass_refcount_handling(self): - # bpo-39382: abstract_issubclass() didn't hold item reference while - # peeking in the bases tuple, in the single inheritance case. - class A: - @property - def __bases__(self): - return (int, ) - - class B: - def __init__(self): - # setting this here increases the chances of exhibiting the bug, - # probably due to memory layout changes. - self.x = 1 - - @property - def __bases__(self): - return (A(), ) - - self.assertEqual(True, issubclass(B(), int)) - - def test_infinite_recursion_in_bases(self): - class X: - @property - def __bases__(self): - return self.__bases__ - - self.assertRaises(RecursionError, issubclass, X(), int) - self.assertRaises(RecursionError, issubclass, int, X()) - self.assertRaises(RecursionError, isinstance, 1, X()) - - -def blowstack(fxn, arg, compare_to): - # Make sure that calling isinstance with a deeply nested tuple for its - # argument will raise RecursionError eventually. - tuple_arg = (compare_to,) - for cnt in range(sys.getrecursionlimit()+5): - tuple_arg = (tuple_arg,) - fxn(arg, tuple_arg) - - -if __name__ == '__main__': - unittest.main() +# Tests some corner cases with isinstance() and issubclass(). While these +# tests use new style classes and properties, they actually do whitebox +# testing of error conditions uncovered when using extension types. + +import unittest +import sys +import typing +from test import support + + + +class TestIsInstanceExceptions(unittest.TestCase): + # Test to make sure that an AttributeError when accessing the instance's + # class's bases is masked. This was actually a bug in Python 2.2 and + # 2.2.1 where the exception wasn't caught but it also wasn't being cleared + # (leading to an "undetected error" in the debug build). Set up is, + # isinstance(inst, cls) where: + # + # - cls isn't a type, or a tuple + # - cls has a __bases__ attribute + # - inst has a __class__ attribute + # - inst.__class__ as no __bases__ attribute + # + # Sounds complicated, I know, but this mimics a situation where an + # extension type raises an AttributeError when its __bases__ attribute is + # gotten. In that case, isinstance() should return False. + def test_class_has_no_bases(self): + class I(object): + def getclass(self): + # This must return an object that has no __bases__ attribute + return None + __class__ = property(getclass) + + class C(object): + def getbases(self): + return () + __bases__ = property(getbases) + + self.assertEqual(False, isinstance(I(), C())) + + # Like above except that inst.__class__.__bases__ raises an exception + # other than AttributeError + def test_bases_raises_other_than_attribute_error(self): + class E(object): + def getbases(self): + raise RuntimeError + __bases__ = property(getbases) + + class I(object): + def getclass(self): + return E() + __class__ = property(getclass) + + class C(object): + def getbases(self): + return () + __bases__ = property(getbases) + + self.assertRaises(RuntimeError, isinstance, I(), C()) + + # Here's a situation where getattr(cls, '__bases__') raises an exception. + # If that exception is not AttributeError, it should not get masked + def test_dont_mask_non_attribute_error(self): + class I: pass + + class C(object): + def getbases(self): + raise RuntimeError + __bases__ = property(getbases) + + self.assertRaises(RuntimeError, isinstance, I(), C()) + + # Like above, except that getattr(cls, '__bases__') raises an + # AttributeError, which /should/ get masked as a TypeError + def test_mask_attribute_error(self): + class I: pass + + class C(object): + def getbases(self): + raise AttributeError + __bases__ = property(getbases) + + self.assertRaises(TypeError, isinstance, I(), C()) + + # check that we don't mask non AttributeErrors + # see: http://bugs.python.org/issue1574217 + def test_isinstance_dont_mask_non_attribute_error(self): + class C(object): + def getclass(self): + raise RuntimeError + __class__ = property(getclass) + + c = C() + self.assertRaises(RuntimeError, isinstance, c, bool) + + # test another code path + class D: pass + self.assertRaises(RuntimeError, isinstance, c, D) + + +# These tests are similar to above, but tickle certain code paths in +# issubclass() instead of isinstance() -- really PyObject_IsSubclass() +# vs. PyObject_IsInstance(). +class TestIsSubclassExceptions(unittest.TestCase): + def test_dont_mask_non_attribute_error(self): + class C(object): + def getbases(self): + raise RuntimeError + __bases__ = property(getbases) + + class S(C): pass + + self.assertRaises(RuntimeError, issubclass, C(), S()) + + def test_mask_attribute_error(self): + class C(object): + def getbases(self): + raise AttributeError + __bases__ = property(getbases) + + class S(C): pass + + self.assertRaises(TypeError, issubclass, C(), S()) + + # Like above, but test the second branch, where the __bases__ of the + # second arg (the cls arg) is tested. This means the first arg must + # return a valid __bases__, and it's okay for it to be a normal -- + # unrelated by inheritance -- class. + def test_dont_mask_non_attribute_error_in_cls_arg(self): + class B: pass + + class C(object): + def getbases(self): + raise RuntimeError + __bases__ = property(getbases) + + self.assertRaises(RuntimeError, issubclass, B, C()) + + def test_mask_attribute_error_in_cls_arg(self): + class B: pass + + class C(object): + def getbases(self): + raise AttributeError + __bases__ = property(getbases) + + self.assertRaises(TypeError, issubclass, B, C()) + + + +# meta classes for creating abstract classes and instances +class AbstractClass(object): + def __init__(self, bases): + self.bases = bases + + def getbases(self): + return self.bases + __bases__ = property(getbases) + + def __call__(self): + return AbstractInstance(self) + +class AbstractInstance(object): + def __init__(self, klass): + self.klass = klass + + def getclass(self): + return self.klass + __class__ = property(getclass) + +# abstract classes +AbstractSuper = AbstractClass(bases=()) + +AbstractChild = AbstractClass(bases=(AbstractSuper,)) + +# normal classes +class Super: + pass + +class Child(Super): + pass + +class TestIsInstanceIsSubclass(unittest.TestCase): + # Tests to ensure that isinstance and issubclass work on abstract + # classes and instances. Before the 2.2 release, TypeErrors were + # raised when boolean values should have been returned. The bug was + # triggered by mixing 'normal' classes and instances were with + # 'abstract' classes and instances. This case tries to test all + # combinations. + + def test_isinstance_normal(self): + # normal instances + self.assertEqual(True, isinstance(Super(), Super)) + self.assertEqual(False, isinstance(Super(), Child)) + self.assertEqual(False, isinstance(Super(), AbstractSuper)) + self.assertEqual(False, isinstance(Super(), AbstractChild)) + + self.assertEqual(True, isinstance(Child(), Super)) + self.assertEqual(False, isinstance(Child(), AbstractSuper)) + + def test_isinstance_abstract(self): + # abstract instances + self.assertEqual(True, isinstance(AbstractSuper(), AbstractSuper)) + self.assertEqual(False, isinstance(AbstractSuper(), AbstractChild)) + self.assertEqual(False, isinstance(AbstractSuper(), Super)) + self.assertEqual(False, isinstance(AbstractSuper(), Child)) + + self.assertEqual(True, isinstance(AbstractChild(), AbstractChild)) + self.assertEqual(True, isinstance(AbstractChild(), AbstractSuper)) + self.assertEqual(False, isinstance(AbstractChild(), Super)) + self.assertEqual(False, isinstance(AbstractChild(), Child)) + + def test_isinstance_with_or_union(self): + self.assertTrue(isinstance(Super(), Super | int)) + self.assertFalse(isinstance(None, str | int)) + self.assertTrue(isinstance(3, str | int)) + self.assertTrue(isinstance("", str | int)) + self.assertTrue(isinstance([], typing.List | typing.Tuple)) + self.assertTrue(isinstance(2, typing.List | int)) + self.assertFalse(isinstance(2, typing.List | typing.Tuple)) + self.assertTrue(isinstance(None, int | None)) + self.assertFalse(isinstance(3.14, int | str)) + with self.assertRaises(TypeError): + isinstance(2, list[int]) + with self.assertRaises(TypeError): + isinstance(2, list[int] | int) + with self.assertRaises(TypeError): + isinstance(2, int | str | list[int] | float) + + + + def test_subclass_normal(self): + # normal classes + self.assertEqual(True, issubclass(Super, Super)) + self.assertEqual(False, issubclass(Super, AbstractSuper)) + self.assertEqual(False, issubclass(Super, Child)) + + self.assertEqual(True, issubclass(Child, Child)) + self.assertEqual(True, issubclass(Child, Super)) + self.assertEqual(False, issubclass(Child, AbstractSuper)) + self.assertTrue(issubclass(typing.List, typing.List|typing.Tuple)) + self.assertFalse(issubclass(int, typing.List|typing.Tuple)) + + def test_subclass_abstract(self): + # abstract classes + self.assertEqual(True, issubclass(AbstractSuper, AbstractSuper)) + self.assertEqual(False, issubclass(AbstractSuper, AbstractChild)) + self.assertEqual(False, issubclass(AbstractSuper, Child)) + + self.assertEqual(True, issubclass(AbstractChild, AbstractChild)) + self.assertEqual(True, issubclass(AbstractChild, AbstractSuper)) + self.assertEqual(False, issubclass(AbstractChild, Super)) + self.assertEqual(False, issubclass(AbstractChild, Child)) + + def test_subclass_tuple(self): + # test with a tuple as the second argument classes + self.assertEqual(True, issubclass(Child, (Child,))) + self.assertEqual(True, issubclass(Child, (Super,))) + self.assertEqual(False, issubclass(Super, (Child,))) + self.assertEqual(True, issubclass(Super, (Child, Super))) + self.assertEqual(False, issubclass(Child, ())) + self.assertEqual(True, issubclass(Super, (Child, (Super,)))) + + self.assertEqual(True, issubclass(int, (int, (float, int)))) + self.assertEqual(True, issubclass(str, (str, (Child, str)))) + + def test_subclass_recursion_limit(self): + # make sure that issubclass raises RecursionError before the C stack is + # blown + with support.infinite_recursion(): + self.assertRaises(RecursionError, blowstack, issubclass, str, str) + + def test_isinstance_recursion_limit(self): + # make sure that issubclass raises RecursionError before the C stack is + # blown + with support.infinite_recursion(): + self.assertRaises(RecursionError, blowstack, isinstance, '', str) + + def test_subclass_with_union(self): + self.assertTrue(issubclass(int, int | float | int)) + self.assertTrue(issubclass(str, str | Child | str)) + self.assertFalse(issubclass(dict, float|str)) + self.assertFalse(issubclass(object, float|str)) + with self.assertRaises(TypeError): + issubclass(2, Child | Super) + with self.assertRaises(TypeError): + issubclass(int, list[int] | Child) + + def test_issubclass_refcount_handling(self): + # bpo-39382: abstract_issubclass() didn't hold item reference while + # peeking in the bases tuple, in the single inheritance case. + class A: + @property + def __bases__(self): + return (int, ) + + class B: + def __init__(self): + # setting this here increases the chances of exhibiting the bug, + # probably due to memory layout changes. + self.x = 1 + + @property + def __bases__(self): + return (A(), ) + + self.assertEqual(True, issubclass(B(), int)) + + def test_infinite_recursion_in_bases(self): + class X: + @property + def __bases__(self): + return self.__bases__ + with support.infinite_recursion(): + self.assertRaises(RecursionError, issubclass, X(), int) + self.assertRaises(RecursionError, issubclass, int, X()) + self.assertRaises(RecursionError, isinstance, 1, X()) + + @unittest.skip("TODO: RUSTPYTHON") + def test_infinite_recursion_via_bases_tuple(self): + """Regression test for bpo-30570.""" + class Failure(object): + def __getattr__(self, attr): + return (self, None) + with support.infinite_recursion(): + with self.assertRaises(RecursionError): + issubclass(Failure(), int) + + @unittest.skip("TODO: RUSTPYTHON") + def test_infinite_cycle_in_bases(self): + """Regression test for bpo-30570.""" + class X: + @property + def __bases__(self): + return (self, self, self) + with support.infinite_recursion(): + self.assertRaises(RecursionError, issubclass, X(), int) + + def test_infinitely_many_bases(self): + """Regression test for bpo-30570.""" + class X: + def __getattr__(self, attr): + self.assertEqual(attr, "__bases__") + class A: + pass + class B: + pass + A.__getattr__ = B.__getattr__ = X.__getattr__ + return (A(), B()) + with support.infinite_recursion(): + self.assertRaises(RecursionError, issubclass, X(), int) + + +def blowstack(fxn, arg, compare_to): + # Make sure that calling isinstance with a deeply nested tuple for its + # argument will raise RecursionError eventually. + tuple_arg = (compare_to,) + for cnt in range(sys.getrecursionlimit()+5): + tuple_arg = (tuple_arg,) + fxn(arg, tuple_arg) + + +if __name__ == '__main__': + unittest.main()