First step for Python 3.12 support (#5078)

* Mark 3.12

* Update importlib from Python 3.12.0

* Update test_importlib from Python3.12

* Mark failings tests from importlib

* Update test.support from Python3.12

* Fix unsupported parser feature

* mark failing test

* Update functools from Python 3.12

* manual type annotation

* slice behavior changed in 3.12

* empty unittest.main returns non-zero

* test_decimal from CPython 3.12

* Mark failing tests

* Update test_unicode from CPython 3.12

* Update test_functools from Python 3.12

* Update enum from Python 3.12

* enum

* Doc format changed

* Update test_module from CPython

---------

Co-authored-by: CPython developers <>
This commit is contained in:
Jeong, YunWon
2023-10-22 19:19:05 -07:00
committed by GitHub
parent a75f26b922
commit af884cb284
113 changed files with 4826 additions and 3624 deletions

View File

@@ -105,7 +105,7 @@ env:
test_weakref
test_yield_from
# Python version targeted by the CI.
PYTHON_VERSION: "3.11.4"
PYTHON_VERSION: "3.12.0"
jobs:
rust_tests:

View File

@@ -7,7 +7,7 @@ name: Periodic checks/tasks
env:
CARGO_ARGS: --no-default-features --features stdlib,zlib,importlib,encodings,ssl,jit
PYTHON_VERSION: "3.11.4"
PYTHON_VERSION: "3.12.0"
jobs:
# codecov collects code coverage data from the rust tests, python snippets and python test suite.

View File

@@ -25,7 +25,7 @@ RustPython requires the following:
stable version: `rustup update stable`
- If you do not have Rust installed, use [rustup](https://rustup.rs/) to
do so.
- CPython version 3.11 or higher
- CPython version 3.12 or higher
- CPython can be installed by your operating system's package manager,
from the [Python website](https://www.python.org/downloads/), or
using a third-party distribution, such as

254
Lib/enum.py vendored
View File

@@ -190,41 +190,48 @@ class property(DynamicClassAttribute):
a corresponding enum member.
"""
member = None
_attr_type = None
_cls_type = None
def __get__(self, instance, ownerclass=None):
if instance is None:
try:
return ownerclass._member_map_[self.name]
except KeyError:
if self.member is not None:
return self.member
else:
raise AttributeError(
'%r has no attribute %r' % (ownerclass, self.name)
)
else:
if self.fget is None:
# look for a member by this name.
try:
return ownerclass._member_map_[self.name]
except KeyError:
raise AttributeError(
'%r has no attribute %r' % (ownerclass, self.name)
) from None
else:
return self.fget(instance)
if self.fget is not None:
# use previous enum.property
return self.fget(instance)
elif self._attr_type == 'attr':
# look up previous attibute
return getattr(self._cls_type, self.name)
elif self._attr_type == 'desc':
# use previous descriptor
return getattr(instance._value_, self.name)
# look for a member by this name.
try:
return ownerclass._member_map_[self.name]
except KeyError:
raise AttributeError(
'%r has no attribute %r' % (ownerclass, self.name)
) from None
def __set__(self, instance, value):
if self.fset is None:
raise AttributeError(
"<enum %r> cannot set attribute %r" % (self.clsname, self.name)
)
else:
if self.fset is not None:
return self.fset(instance, value)
raise AttributeError(
"<enum %r> cannot set attribute %r" % (self.clsname, self.name)
)
def __delete__(self, instance):
if self.fdel is None:
raise AttributeError(
"<enum %r> cannot delete attribute %r" % (self.clsname, self.name)
)
else:
if self.fdel is not None:
return self.fdel(instance)
raise AttributeError(
"<enum %r> cannot delete attribute %r" % (self.clsname, self.name)
)
def __set_name__(self, ownerclass, name):
self.name = name
@@ -312,27 +319,38 @@ class _proto_member:
enum_class._member_names_.append(member_name)
# if necessary, get redirect in place and then add it to _member_map_
found_descriptor = None
descriptor_type = None
class_type = None
for base in enum_class.__mro__[1:]:
descriptor = base.__dict__.get(member_name)
if descriptor is not None:
if isinstance(descriptor, (property, DynamicClassAttribute)):
found_descriptor = descriptor
attr = base.__dict__.get(member_name)
if attr is not None:
if isinstance(attr, (property, DynamicClassAttribute)):
found_descriptor = attr
class_type = base
descriptor_type = 'enum'
break
elif (
hasattr(descriptor, 'fget') and
hasattr(descriptor, 'fset') and
hasattr(descriptor, 'fdel')
):
found_descriptor = descriptor
elif _is_descriptor(attr):
found_descriptor = attr
descriptor_type = descriptor_type or 'desc'
class_type = class_type or base
continue
else:
descriptor_type = 'attr'
class_type = base
if found_descriptor:
redirect = property()
redirect.member = enum_member
redirect.__set_name__(enum_class, member_name)
# earlier descriptor found; copy fget, fset, fdel to this one.
redirect.fget = found_descriptor.fget
redirect.fset = found_descriptor.fset
redirect.fdel = found_descriptor.fdel
if descriptor_type in ('enum','desc'):
# earlier descriptor found; copy fget, fset, fdel to this one.
redirect.fget = getattr(found_descriptor, 'fget', None)
redirect._get = getattr(found_descriptor, '__get__', None)
redirect.fset = getattr(found_descriptor, 'fset', None)
redirect._set = getattr(found_descriptor, '__set__', None)
redirect.fdel = getattr(found_descriptor, 'fdel', None)
redirect._del = getattr(found_descriptor, '__delete__', None)
redirect._attr_type = descriptor_type
redirect._cls_type = class_type
setattr(enum_class, member_name, redirect)
else:
setattr(enum_class, member_name, enum_member)
@@ -521,8 +539,13 @@ class EnumType(type):
#
# adjust the sunders
_order_ = classdict.pop('_order_', None)
_gnv = classdict.get('_generate_next_value_')
if _gnv is not None and type(_gnv) is not staticmethod:
_gnv = staticmethod(_gnv)
# convert to normal dict
classdict = dict(classdict.items())
if _gnv is not None:
classdict['_generate_next_value_'] = _gnv
#
# data type of member and the controlling Enum class
member_type, first_enum = metacls._get_mixins_(cls, bases)
@@ -674,7 +697,7 @@ class EnumType(type):
'member order does not match _order_:\n %r\n %r'
% (enum_class._member_names_, _order_)
)
#
return enum_class
def __bool__(cls):
@@ -683,7 +706,7 @@ class EnumType(type):
"""
return True
def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None):
def __call__(cls, value, names=None, *values, module=None, qualname=None, type=None, start=1, boundary=None):
"""
Either returns an existing member, or creates a new enum class.
@@ -691,6 +714,8 @@ class EnumType(type):
to an enumeration member (i.e. Color(3)) and for the functional API
(i.e. Color = Enum('Color', names='RED GREEN BLUE')).
The value lookup branch is chosen if the enum is final.
When used for the functional API:
`value` will be the name of the new class.
@@ -708,12 +733,20 @@ class EnumType(type):
`type`, if set, will be mixed in as the first base class.
"""
if names is None: # simple value lookup
if cls._member_map_:
# simple value lookup if members exist
if names:
value = (value, names) + values
return cls.__new__(cls, value)
# otherwise, functional API: we're creating a new Enum type
if names is None and type is None:
# no body? no data-type? possibly wrong usage
raise TypeError(
f"{cls} has no members; specify `names=()` if you meant to create a new, empty, enum"
)
return cls._create_(
value,
names,
class_name=value,
names=names,
module=module,
qualname=qualname,
type=type,
@@ -721,26 +754,16 @@ class EnumType(type):
boundary=boundary,
)
def __contains__(cls, member):
"""
Return True if member is a member of this enum
raises TypeError if member is not an enum member
def __contains__(cls, value):
"""Return True if `value` is in `cls`.
note: in 3.12 TypeError will no longer be raised, and True will also be
returned if member is the value of a member in this enum
`value` is in `cls` if:
1) `value` is a member of `cls`, or
2) `value` is the value of one of the `cls`'s members.
"""
if not isinstance(member, Enum):
import warnings
warnings.warn(
"in 3.12 __contains__ will no longer raise TypeError, but will return True or\n"
"False depending on whether the value is a member or the value of a member",
DeprecationWarning,
stacklevel=2,
)
raise TypeError(
"unsupported operand type(s) for 'in': '%s' and '%s'" % (
type(member).__qualname__, cls.__class__.__qualname__))
return isinstance(member, cls) and member._name_ in cls._member_map_
if isinstance(value, cls):
return True
return value in cls._value2member_map_ or value in cls._unhashable_values_
def __delattr__(cls, attr):
# nicer error message when someone tries to delete an attribute
@@ -767,22 +790,6 @@ class EnumType(type):
# return whatever mixed-in data type has
return sorted(set(dir(cls._member_type_)) | interesting)
def __getattr__(cls, name):
"""
Return the enum member matching `name`
We use __getattr__ instead of descriptors or inserting into the enum
class' __dict__ in order to support `name` and `value` being both
properties for enum members (which live in the class' __dict__) and
enum members themselves.
"""
if _is_dunder(name):
raise AttributeError(name)
try:
return cls._member_map_[name]
except KeyError:
raise AttributeError(name) from None
def __getitem__(cls, name):
"""
Return the member matching `name`.
@@ -863,6 +870,8 @@ class EnumType(type):
value = first_enum._generate_next_value_(name, start, count, last_values[:])
last_values.append(value)
names.append((name, value))
if names is None:
names = ()
# Here, names is either an iterable of (name, value) or a mapping.
for item in names:
@@ -872,13 +881,15 @@ class EnumType(type):
member_name, member_value = item
classdict[member_name] = member_value
# TODO: replace the frame hack if a blessed way to know the calling
# module is ever developed
if module is None:
try:
module = sys._getframe(2).f_globals['__name__']
except (AttributeError, ValueError, KeyError):
pass
module = sys._getframemodulename(2)
except AttributeError:
# Fall back on _getframe if _getframemodulename is missing
try:
module = sys._getframe(2).f_globals['__name__']
except (AttributeError, ValueError, KeyError):
pass
if module is None:
_make_class_unpicklable(classdict)
else:
@@ -946,9 +957,6 @@ class EnumType(type):
"""
if not bases:
return object, Enum
mcls._check_for_existing_members_(class_name, bases)
# ensure final parent class is an Enum derivative, find any concrete
# data type, and check that Enum has no members
first_enum = bases[-1]
@@ -969,12 +977,20 @@ class EnumType(type):
return base._value_repr_
elif '__repr__' in base.__dict__:
# this is our data repr
return base.__dict__['__repr__']
# double-check if a dataclass with a default __repr__
if (
'__dataclass_fields__' in base.__dict__
and '__dataclass_params__' in base.__dict__
and base.__dict__['__dataclass_params__'].repr
):
return _dataclass_repr
else:
return base.__dict__['__repr__']
return None
@classmethod
def _find_data_type_(mcls, class_name, bases):
# a datatype has a __new__ method
# a datatype has a __new__ method, or a __dataclass_fields__ attribute
data_types = set()
base_chain = set()
for chain in bases:
@@ -988,8 +1004,6 @@ class EnumType(type):
data_types.add(base._member_type_)
break
elif '__new__' in base.__dict__ or '__dataclass_fields__' in base.__dict__:
if isinstance(base, EnumType):
continue
data_types.add(candidate or base)
break
else:
@@ -1061,20 +1075,20 @@ class Enum(metaclass=EnumType):
Access them by:
- attribute access::
- attribute access:
>>> Color.RED
<Color.RED: 1>
>>> Color.RED
<Color.RED: 1>
- value lookup:
>>> Color(1)
<Color.RED: 1>
>>> Color(1)
<Color.RED: 1>
- name lookup:
>>> Color['RED']
<Color.RED: 1>
>>> Color['RED']
<Color.RED: 1>
Enumerations can be iterated over, and know how many members they have:
@@ -1088,6 +1102,13 @@ class Enum(metaclass=EnumType):
attributes -- see the documentation for details.
"""
@classmethod
def __signature__(cls):
if cls._member_names_:
return '(*values)'
else:
return '(new_class_name, /, names, *, module=None, qualname=None, type=None, start=1, boundary=None)'
def __new__(cls, value):
# all enum instances are actually created during class construction
# without calling this method; this method is called by the metaclass'
@@ -1107,6 +1128,11 @@ class Enum(metaclass=EnumType):
for member in cls._member_map_.values():
if member._value_ == value:
return member
# still not found -- verify that members exist, in-case somebody got here mistakenly
# (such as via super when trying to override __new__)
if not cls._member_map_:
raise TypeError("%r has no members defined" % cls)
#
# still not found -- try _missing_ hook
try:
exc = None
@@ -1142,6 +1168,7 @@ class Enum(metaclass=EnumType):
def __init__(self, *args, **kwds):
pass
@staticmethod
def _generate_next_value_(name, start, count, last_values):
"""
Generate the next value when not given.
@@ -1236,10 +1263,10 @@ class Enum(metaclass=EnumType):
# enum.property is used to provide access to the `name` and
# `value` attributes of enum members while keeping some measure of
# protection from modification, while still allowing for an enumeration
# to have members named `name` and `value`. This works because enumeration
# members are not set directly on the enum class; they are kept in a
# separate structure, _member_map_, which is where enum.property looks for
# them
# to have members named `name` and `value`. This works because each
# instance of enum.property saves its companion member, which it returns
# on class lookup; on instance lookup it either executes a provided function
# or raises an AttributeError.
@property
def name(self):
@@ -1290,6 +1317,7 @@ class StrEnum(str, ReprEnum):
member._value_ = value
return member
@staticmethod
def _generate_next_value_(name, start, count, last_values):
"""
Return the lower-cased version of the member name.
@@ -1328,6 +1356,7 @@ class Flag(Enum, boundary=STRICT):
_numeric_repr_ = repr
@staticmethod
def _generate_next_value_(name, start, count, last_values):
"""
Generate the next value when not given.
@@ -1566,10 +1595,13 @@ def unique(enumeration):
(enumeration, alias_details))
return enumeration
def _power_of_two(value):
if value < 1:
return False
return value == 2 ** _high_bit(value)
def _dataclass_repr(self):
dcf = self.__dataclass_fields__
return ', '.join(
'%s=%r' % (k, getattr(self, k))
for k in dcf.keys()
if dcf[k].repr
)
def global_enum_repr(self):
"""
@@ -1713,10 +1745,12 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
value = gnv(name, 1, len(member_names), gnv_last_values)
if value in value2member_map:
# an alias to an existing member
member = value2member_map[value]
redirect = property()
redirect.member = member
redirect.__set_name__(enum_class, name)
setattr(enum_class, name, redirect)
member_map[name] = value2member_map[value]
member_map[name] = member
else:
# create the member
if use_args:
@@ -1732,6 +1766,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
member.__objclass__ = enum_class
member.__init__(value)
redirect = property()
redirect.member = member
redirect.__set_name__(enum_class, name)
setattr(enum_class, name, redirect)
member_map[name] = member
@@ -1760,10 +1795,12 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
value = value.value
if value in value2member_map:
# an alias to an existing member
member = value2member_map[value]
redirect = property()
redirect.member = member
redirect.__set_name__(enum_class, name)
setattr(enum_class, name, redirect)
member_map[name] = value2member_map[value]
member_map[name] = member
else:
# create the member
if use_args:
@@ -1780,6 +1817,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
member.__init__(value)
member._sort_order_ = len(member_names)
redirect = property()
redirect.member = member
redirect.__set_name__(enum_class, name)
setattr(enum_class, name, redirect)
member_map[name] = member
@@ -1903,8 +1941,8 @@ def _test_simple_enum(checked_enum, simple_enum):
... RED = auto()
... GREEN = auto()
... BLUE = auto()
>>> # TODO: RUSTPYTHON
>>> # _test_simple_enum(CheckedColor, Color)
... # TODO: RUSTPYTHON
>>> _test_simple_enum(CheckedColor, Color) # doctest: +SKIP
If differences are found, a :exc:`TypeError` is raised.
"""

183
Lib/functools.py vendored
View File

@@ -10,9 +10,9 @@
# See C source code for _functools credits/copyright
__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial',
'partialmethod', 'singledispatch', 'singledispatchmethod',
"cached_property"]
'total_ordering', 'cache', 'cmp_to_key', 'lru_cache', 'reduce',
'partial', 'partialmethod', 'singledispatch', 'singledispatchmethod',
'cached_property']
from abc import get_cache_token
from collections import namedtuple
@@ -30,7 +30,7 @@ from types import GenericAlias
# wrapper functions that can handle naive introspection
WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
'__annotations__')
'__annotations__', '__type_params__')
WRAPPER_UPDATES = ('__dict__',)
def update_wrapper(wrapper,
wrapped,
@@ -86,82 +86,86 @@ def wraps(wrapped,
# infinite recursion that could occur when the operator dispatch logic
# detects a NotImplemented result and then calls a reflected method.
def _gt_from_lt(self, other, NotImplemented=NotImplemented):
def _gt_from_lt(self, other):
'Return a > b. Computed by @total_ordering from (not a < b) and (a != b).'
op_result = self.__lt__(other)
op_result = type(self).__lt__(self, other)
if op_result is NotImplemented:
return op_result
return not op_result and self != other
def _le_from_lt(self, other, NotImplemented=NotImplemented):
def _le_from_lt(self, other):
'Return a <= b. Computed by @total_ordering from (a < b) or (a == b).'
op_result = self.__lt__(other)
op_result = type(self).__lt__(self, other)
if op_result is NotImplemented:
return op_result
return op_result or self == other
def _ge_from_lt(self, other, NotImplemented=NotImplemented):
def _ge_from_lt(self, other):
'Return a >= b. Computed by @total_ordering from (not a < b).'
op_result = self.__lt__(other)
op_result = type(self).__lt__(self, other)
if op_result is NotImplemented:
return op_result
return not op_result
def _ge_from_le(self, other, NotImplemented=NotImplemented):
def _ge_from_le(self, other):
'Return a >= b. Computed by @total_ordering from (not a <= b) or (a == b).'
op_result = self.__le__(other)
op_result = type(self).__le__(self, other)
if op_result is NotImplemented:
return op_result
return not op_result or self == other
def _lt_from_le(self, other, NotImplemented=NotImplemented):
def _lt_from_le(self, other):
'Return a < b. Computed by @total_ordering from (a <= b) and (a != b).'
op_result = self.__le__(other)
op_result = type(self).__le__(self, other)
if op_result is NotImplemented:
return op_result
return op_result and self != other
def _gt_from_le(self, other, NotImplemented=NotImplemented):
def _gt_from_le(self, other):
'Return a > b. Computed by @total_ordering from (not a <= b).'
op_result = self.__le__(other)
op_result = type(self).__le__(self, other)
if op_result is NotImplemented:
return op_result
return not op_result
def _lt_from_gt(self, other, NotImplemented=NotImplemented):
def _lt_from_gt(self, other):
'Return a < b. Computed by @total_ordering from (not a > b) and (a != b).'
op_result = self.__gt__(other)
op_result = type(self).__gt__(self, other)
if op_result is NotImplemented:
return op_result
return not op_result and self != other
def _ge_from_gt(self, other, NotImplemented=NotImplemented):
def _ge_from_gt(self, other):
'Return a >= b. Computed by @total_ordering from (a > b) or (a == b).'
op_result = self.__gt__(other)
op_result = type(self).__gt__(self, other)
if op_result is NotImplemented:
return op_result
return op_result or self == other
def _le_from_gt(self, other, NotImplemented=NotImplemented):
def _le_from_gt(self, other):
'Return a <= b. Computed by @total_ordering from (not a > b).'
op_result = self.__gt__(other)
op_result = type(self).__gt__(self, other)
if op_result is NotImplemented:
return op_result
return not op_result
def _le_from_ge(self, other, NotImplemented=NotImplemented):
def _le_from_ge(self, other):
'Return a <= b. Computed by @total_ordering from (not a >= b) or (a == b).'
op_result = self.__ge__(other)
op_result = type(self).__ge__(self, other)
if op_result is NotImplemented:
return op_result
return not op_result or self == other
def _gt_from_ge(self, other, NotImplemented=NotImplemented):
def _gt_from_ge(self, other):
'Return a > b. Computed by @total_ordering from (a >= b) and (a != b).'
op_result = self.__ge__(other)
op_result = type(self).__ge__(self, other)
if op_result is NotImplemented:
return op_result
return op_result and self != other
def _lt_from_ge(self, other, NotImplemented=NotImplemented):
def _lt_from_ge(self, other):
'Return a < b. Computed by @total_ordering from (not a >= b).'
op_result = self.__ge__(other)
op_result = type(self).__ge__(self, other)
if op_result is NotImplemented:
return op_result
return not op_result
@@ -232,14 +236,14 @@ _initial_missing = object()
def reduce(function, sequence, initial=_initial_missing):
"""
reduce(function, sequence[, initial]) -> value
reduce(function, iterable[, initial]) -> value
Apply a function of two arguments cumulatively to the items of a sequence,
from left to right, so as to reduce the sequence to a single value.
For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates
Apply a function of two arguments cumulatively to the items of a sequence
or iterable, from left to right, so as to reduce the iterable to a single
value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates
((((1+2)+3)+4)+5). If initial is present, it is placed before the items
of the sequence in the calculation, and serves as a default when the
sequence is empty.
of the iterable in the calculation, and serves as a default when the
iterable is empty.
"""
it = iter(sequence)
@@ -248,7 +252,8 @@ def reduce(function, sequence, initial=_initial_missing):
try:
value = next(it)
except StopIteration:
raise TypeError("reduce() of empty sequence with no initial value") from None
raise TypeError(
"reduce() of empty iterable with no initial value") from None
else:
value = initial
@@ -347,23 +352,7 @@ class partialmethod(object):
callables as instance methods.
"""
def __init__(*args, **keywords):
if len(args) >= 2:
self, func, *args = args
elif not args:
raise TypeError("descriptor '__init__' of partialmethod "
"needs an argument")
elif 'func' in keywords:
func = keywords.pop('func')
self, *args = args
import warnings
warnings.warn("Passing 'func' as keyword argument is deprecated",
DeprecationWarning, stacklevel=2)
else:
raise TypeError("type 'partialmethod' takes at least one argument, "
"got %d" % (len(args)-1))
args = tuple(args)
def __init__(self, func, /, *args, **keywords):
if not callable(func) and not hasattr(func, "__get__"):
raise TypeError("{!r} is not callable or a descriptor"
.format(func))
@@ -381,7 +370,6 @@ class partialmethod(object):
self.func = func
self.args = args
self.keywords = keywords
__init__.__text_signature__ = '($self, func, /, *args, **keywords)'
def __repr__(self):
args = ", ".join(map(repr, self.args))
@@ -427,6 +415,7 @@ class partialmethod(object):
__class_getitem__ = classmethod(GenericAlias)
# Helper functions
def _unwrap_partial(func):
@@ -503,7 +492,7 @@ def lru_cache(maxsize=128, typed=False):
with f.cache_info(). Clear the cache and statistics with f.cache_clear().
Access the underlying function with f.__wrapped__.
See: http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
"""
@@ -520,6 +509,7 @@ def lru_cache(maxsize=128, typed=False):
# The user_function was passed in directly via the maxsize argument
user_function, maxsize = maxsize, 128
wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
return update_wrapper(wrapper, user_function)
elif maxsize is not None:
raise TypeError(
@@ -527,6 +517,7 @@ def lru_cache(maxsize=128, typed=False):
def decorating_function(user_function):
wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
return update_wrapper(wrapper, user_function)
return decorating_function
@@ -653,6 +644,15 @@ except ImportError:
pass
################################################################################
### cache -- simplified access to the infinity cache
################################################################################
def cache(user_function, /):
'Simple lightweight unbounded cache. Sometimes called "memoize".'
return lru_cache(maxsize=None)(user_function)
################################################################################
### singledispatch() - single-dispatch generic function decorator
################################################################################
@@ -660,7 +660,7 @@ except ImportError:
def _c3_merge(sequences):
"""Merges MROs in *sequences* to a single MRO using the C3 algorithm.
Adapted from http://www.python.org/download/releases/2.3/mro/.
Adapted from https://www.python.org/download/releases/2.3/mro/.
"""
result = []
@@ -740,6 +740,7 @@ def _compose_mro(cls, types):
# Remove entries which are already present in the __mro__ or unrelated.
def is_related(typ):
return (typ not in bases and hasattr(typ, '__mro__')
and not isinstance(typ, GenericAlias)
and issubclass(cls, typ))
types = [n for n in types if is_related(n)]
# Remove entries which are strict bases of other entries (they will end up
@@ -837,6 +838,17 @@ def singledispatch(func):
dispatch_cache[cls] = impl
return impl
def _is_union_type(cls):
from typing import get_origin, Union
return get_origin(cls) in {Union, types.UnionType}
def _is_valid_dispatch_type(cls):
if isinstance(cls, type):
return True
from typing import get_args
return (_is_union_type(cls) and
all(isinstance(arg, type) for arg in get_args(cls)))
def register(cls, func=None):
"""generic_func.register(cls, func) -> func
@@ -844,9 +856,15 @@ def singledispatch(func):
"""
nonlocal cache_token
if func is None:
if isinstance(cls, type):
if _is_valid_dispatch_type(cls):
if func is None:
return lambda f: register(cls, f)
else:
if func is not None:
raise TypeError(
f"Invalid first argument to `register()`. "
f"{cls!r} is not a class or union type."
)
ann = getattr(cls, '__annotations__', {})
if not ann:
raise TypeError(
@@ -859,12 +877,25 @@ def singledispatch(func):
# only import typing if annotation parsing is necessary
from typing import get_type_hints
argname, cls = next(iter(get_type_hints(func).items()))
if not isinstance(cls, type):
raise TypeError(
f"Invalid annotation for {argname!r}. "
f"{cls!r} is not a class."
)
registry[cls] = func
if not _is_valid_dispatch_type(cls):
if _is_union_type(cls):
raise TypeError(
f"Invalid annotation for {argname!r}. "
f"{cls!r} not all arguments are classes."
)
else:
raise TypeError(
f"Invalid annotation for {argname!r}. "
f"{cls!r} is not a class."
)
if _is_union_type(cls):
from typing import get_args
for arg in get_args(cls):
registry[arg] = func
else:
registry[cls] = func
if cache_token is None and hasattr(cls, '__abstractmethods__'):
cache_token = get_cache_token()
dispatch_cache.clear()
@@ -925,18 +956,16 @@ class singledispatchmethod:
################################################################################
### cached_property() - computed once per instance, cached as attribute
### cached_property() - property result cached as instance attribute
################################################################################
_NOT_FOUND = object()
class cached_property:
def __init__(self, func):
self.func = func
self.attrname = None
self.__doc__ = func.__doc__
self.lock = RLock()
def __set_name__(self, owner, name):
if self.attrname is None:
@@ -963,19 +992,15 @@ class cached_property:
raise TypeError(msg) from None
val = cache.get(self.attrname, _NOT_FOUND)
if val is _NOT_FOUND:
with self.lock:
# check if another thread filled cache while we awaited lock
val = cache.get(self.attrname, _NOT_FOUND)
if val is _NOT_FOUND:
val = self.func(instance)
try:
cache[self.attrname] = val
except TypeError:
msg = (
f"The '__dict__' attribute on {type(instance).__name__!r} instance "
f"does not support item assignment for caching {self.attrname!r} property."
)
raise TypeError(msg) from None
val = self.func(instance)
try:
cache[self.attrname] = val
except TypeError:
msg = (
f"The '__dict__' attribute on {type(instance).__name__!r} instance "
f"does not support item assignment for caching {self.attrname!r} property."
)
raise TypeError(msg) from None
return val
__class_getitem__ = classmethod(GenericAlias)

View File

@@ -70,41 +70,6 @@ def invalidate_caches():
finder.invalidate_caches()
def find_loader(name, path=None):
"""Return the loader for the specified module.
This is a backward-compatible wrapper around find_spec().
This function is deprecated in favor of importlib.util.find_spec().
"""
warnings.warn('Deprecated since Python 3.4 and slated for removal in '
'Python 3.12; use importlib.util.find_spec() instead',
DeprecationWarning, stacklevel=2)
try:
loader = sys.modules[name].__loader__
if loader is None:
raise ValueError('{}.__loader__ is None'.format(name))
else:
return loader
except KeyError:
pass
except AttributeError:
raise ValueError('{}.__loader__ is not set'.format(name)) from None
spec = _bootstrap._find_spec(name, path)
# We won't worry about malformed specs (missing attributes).
if spec is None:
return None
if spec.loader is None:
if spec.submodule_search_locations is None:
raise ImportError('spec for {} missing loader'.format(name),
name=name)
raise ImportError('namespace packages do not have loaders',
name=name)
return spec.loader
def import_module(name, package=None):
"""Import a module.
@@ -116,9 +81,8 @@ def import_module(name, package=None):
level = 0
if name.startswith('.'):
if not package:
msg = ("the 'package' argument is required to perform a relative "
"import for {!r}")
raise TypeError(msg.format(name))
raise TypeError("the 'package' argument is required to perform a "
f"relative import for {name!r}")
for character in name:
if character != '.':
break
@@ -144,8 +108,7 @@ def reload(module):
raise TypeError("reload() argument must be a module")
if sys.modules.get(name) is not module:
msg = "module {} not in sys.modules"
raise ImportError(msg.format(name), name=name)
raise ImportError(f"module {name} not in sys.modules", name=name)
if name in _RELOADING:
return _RELOADING[name]
_RELOADING[name] = module
@@ -155,8 +118,7 @@ def reload(module):
try:
parent = sys.modules[parent_name]
except KeyError:
msg = "parent {!r} not in sys.modules"
raise ImportError(msg.format(parent_name),
raise ImportError(f"parent {parent_name!r} not in sys.modules",
name=parent_name) from None
else:
pkgpath = parent.__path__

15
Lib/importlib/_abc.py vendored
View File

@@ -1,7 +1,6 @@
"""Subset of importlib.abc used to reduce importlib.util imports."""
from . import _bootstrap
import abc
import warnings
class Loader(metaclass=abc.ABCMeta):
@@ -38,17 +37,3 @@ class Loader(metaclass=abc.ABCMeta):
raise ImportError
# Warning implemented in _load_module_shim().
return _bootstrap._load_module_shim(self, fullname)
def module_repr(self, module):
"""Return a module's repr.
Used by the module type when the method does not raise
NotImplementedError.
This method is deprecated.
"""
warnings.warn("importlib.abc.Loader.module_repr() is deprecated and "
"slated for removal in Python 3.12", DeprecationWarning)
# The exception will cause ModuleType.__repr__ to ignore this method.
raise NotImplementedError

View File

@@ -51,17 +51,178 @@ def _new_module(name):
# Module-level locking ########################################################
# A dict mapping module names to weakrefs of _ModuleLock instances
# Dictionary protected by the global import lock
# For a list that can have a weakref to it.
class _List(list):
pass
# Copied from weakref.py with some simplifications and modifications unique to
# bootstrapping importlib. Many methods were simply deleting for simplicity, so if they
# are needed in the future they may work if simply copied back in.
class _WeakValueDictionary:
def __init__(self):
self_weakref = _weakref.ref(self)
# Inlined to avoid issues with inheriting from _weakref.ref before _weakref is
# set by _setup(). Since there's only one instance of this class, this is
# not expensive.
class KeyedRef(_weakref.ref):
__slots__ = "key",
def __new__(type, ob, key):
self = super().__new__(type, ob, type.remove)
self.key = key
return self
def __init__(self, ob, key):
super().__init__(ob, self.remove)
@staticmethod
def remove(wr):
nonlocal self_weakref
self = self_weakref()
if self is not None:
if self._iterating:
self._pending_removals.append(wr.key)
else:
_weakref._remove_dead_weakref(self.data, wr.key)
self._KeyedRef = KeyedRef
self.clear()
def clear(self):
self._pending_removals = []
self._iterating = set()
self.data = {}
def _commit_removals(self):
pop = self._pending_removals.pop
d = self.data
while True:
try:
key = pop()
except IndexError:
return
_weakref._remove_dead_weakref(d, key)
def get(self, key, default=None):
if self._pending_removals:
self._commit_removals()
try:
wr = self.data[key]
except KeyError:
return default
else:
if (o := wr()) is None:
return default
else:
return o
def setdefault(self, key, default=None):
try:
o = self.data[key]()
except KeyError:
o = None
if o is None:
if self._pending_removals:
self._commit_removals()
self.data[key] = self._KeyedRef(default, key)
return default
else:
return o
# A dict mapping module names to weakrefs of _ModuleLock instances.
# Dictionary protected by the global import lock.
_module_locks = {}
# A dict mapping thread ids to _ModuleLock instances
_blocking_on = {}
# A dict mapping thread IDs to weakref'ed lists of _ModuleLock instances.
# This maps a thread to the module locks it is blocking on acquiring. The
# values are lists because a single thread could perform a re-entrant import
# and be "in the process" of blocking on locks for more than one module. A
# thread can be "in the process" because a thread cannot actually block on
# acquiring more than one lock but it can have set up bookkeeping that reflects
# that it intends to block on acquiring more than one lock.
#
# The dictionary uses a WeakValueDictionary to avoid keeping unnecessary
# lists around, regardless of GC runs. This way there's no memory leak if
# the list is no longer needed (GH-106176).
_blocking_on = None
class _BlockingOnManager:
"""A context manager responsible to updating ``_blocking_on``."""
def __init__(self, thread_id, lock):
self.thread_id = thread_id
self.lock = lock
def __enter__(self):
"""Mark the running thread as waiting for self.lock. via _blocking_on."""
# Interactions with _blocking_on are *not* protected by the global
# import lock here because each thread only touches the state that it
# owns (state keyed on its thread id). The global import lock is
# re-entrant (i.e., a single thread may take it more than once) so it
# wouldn't help us be correct in the face of re-entrancy either.
self.blocked_on = _blocking_on.setdefault(self.thread_id, _List())
self.blocked_on.append(self.lock)
def __exit__(self, *args, **kwargs):
"""Remove self.lock from this thread's _blocking_on list."""
self.blocked_on.remove(self.lock)
class _DeadlockError(RuntimeError):
pass
def _has_deadlocked(target_id, *, seen_ids, candidate_ids, blocking_on):
"""Check if 'target_id' is holding the same lock as another thread(s).
The search within 'blocking_on' starts with the threads listed in
'candidate_ids'. 'seen_ids' contains any threads that are considered
already traversed in the search.
Keyword arguments:
target_id -- The thread id to try to reach.
seen_ids -- A set of threads that have already been visited.
candidate_ids -- The thread ids from which to begin.
blocking_on -- A dict representing the thread/blocking-on graph. This may
be the same object as the global '_blocking_on' but it is
a parameter to reduce the impact that global mutable
state has on the result of this function.
"""
if target_id in candidate_ids:
# If we have already reached the target_id, we're done - signal that it
# is reachable.
return True
# Otherwise, try to reach the target_id from each of the given candidate_ids.
for tid in candidate_ids:
if not (candidate_blocking_on := blocking_on.get(tid)):
# There are no edges out from this node, skip it.
continue
elif tid in seen_ids:
# bpo 38091: the chain of tid's we encounter here eventually leads
# to a fixed point or a cycle, but does not reach target_id.
# This means we would not actually deadlock. This can happen if
# other threads are at the beginning of acquire() below.
return False
seen_ids.add(tid)
# Follow the edges out from this thread.
edges = [lock.owner for lock in candidate_blocking_on]
if _has_deadlocked(target_id, seen_ids=seen_ids, candidate_ids=edges,
blocking_on=blocking_on):
return True
return False
class _ModuleLock:
"""A recursive lock implementation which is able to detect deadlocks
(e.g. thread 1 trying to take locks A then B, and thread 2 trying to
@@ -69,33 +230,76 @@ class _ModuleLock:
"""
def __init__(self, name):
self.lock = _thread.allocate_lock()
# Create an RLock for protecting the import process for the
# corresponding module. Since it is an RLock, a single thread will be
# able to take it more than once. This is necessary to support
# re-entrancy in the import system that arises from (at least) signal
# handlers and the garbage collector. Consider the case of:
#
# import foo
# -> ...
# -> importlib._bootstrap._ModuleLock.acquire
# -> ...
# -> <garbage collector>
# -> __del__
# -> import foo
# -> ...
# -> importlib._bootstrap._ModuleLock.acquire
# -> _BlockingOnManager.__enter__
#
# If a different thread than the running one holds the lock then the
# thread will have to block on taking the lock, which is what we want
# for thread safety.
self.lock = _thread.RLock()
self.wakeup = _thread.allocate_lock()
# The name of the module for which this is a lock.
self.name = name
# Can end up being set to None if this lock is not owned by any thread
# or the thread identifier for the owning thread.
self.owner = None
self.count = 0
self.waiters = 0
# Represent the number of times the owning thread has acquired this lock
# via a list of True. This supports RLock-like ("re-entrant lock")
# behavior, necessary in case a single thread is following a circular
# import dependency and needs to take the lock for a single module
# more than once.
#
# Counts are represented as a list of True because list.append(True)
# and list.pop() are both atomic and thread-safe in CPython and it's hard
# to find another primitive with the same properties.
self.count = []
# This is a count of the number of threads that are blocking on
# self.wakeup.acquire() awaiting to get their turn holding this module
# lock. When the module lock is released, if this is greater than
# zero, it is decremented and `self.wakeup` is released one time. The
# intent is that this will let one other thread make more progress on
# acquiring this module lock. This repeats until all the threads have
# gotten a turn.
#
# This is incremented in self.acquire() when a thread notices it is
# going to have to wait for another thread to finish.
#
# See the comment above count for explanation of the representation.
self.waiters = []
def has_deadlock(self):
# Deadlock avoidance for concurrent circular imports.
me = _thread.get_ident()
tid = self.owner
seen = set()
while True:
lock = _blocking_on.get(tid)
if lock is None:
return False
tid = lock.owner
if tid == me:
return True
if tid in seen:
# bpo 38091: the chain of tid's we encounter here
# eventually leads to a fixpoint or a cycle, but
# does not reach 'me'. This means we would not
# actually deadlock. This can happen if other
# threads are at the beginning of acquire() below.
return False
seen.add(tid)
# To avoid deadlocks for concurrent or re-entrant circular imports,
# look at _blocking_on to see if any threads are blocking
# on getting the import lock for any module for which the import lock
# is held by this thread.
return _has_deadlocked(
# Try to find this thread.
target_id=_thread.get_ident(),
seen_ids=set(),
# Start from the thread that holds the import lock for this
# module.
candidate_ids=[self.owner],
# Use the global "blocking on" state.
blocking_on=_blocking_on,
)
def acquire(self):
"""
@@ -104,39 +308,82 @@ class _ModuleLock:
Otherwise, the lock is always acquired and True is returned.
"""
tid = _thread.get_ident()
_blocking_on[tid] = self
try:
with _BlockingOnManager(tid, self):
while True:
# Protect interaction with state on self with a per-module
# lock. This makes it safe for more than one thread to try to
# acquire the lock for a single module at the same time.
with self.lock:
if self.count == 0 or self.owner == tid:
if self.count == [] or self.owner == tid:
# If the lock for this module is unowned then we can
# take the lock immediately and succeed. If the lock
# for this module is owned by the running thread then
# we can also allow the acquire to succeed. This
# supports circular imports (thread T imports module A
# which imports module B which imports module A).
self.owner = tid
self.count += 1
self.count.append(True)
return True
# At this point we know the lock is held (because count !=
# 0) by another thread (because owner != tid). We'll have
# to get in line to take the module lock.
# But first, check to see if this thread would create a
# deadlock by acquiring this module lock. If it would
# then just stop with an error.
#
# It's not clear who is expected to handle this error.
# There is one handler in _lock_unlock_module but many
# times this method is called when entering the context
# manager _ModuleLockManager instead - so _DeadlockError
# will just propagate up to application code.
#
# This seems to be more than just a hypothetical -
# https://stackoverflow.com/questions/59509154
# https://github.com/encode/django-rest-framework/issues/7078
if self.has_deadlock():
raise _DeadlockError('deadlock detected by %r' % self)
raise _DeadlockError(f'deadlock detected by {self!r}')
# Check to see if we're going to be able to acquire the
# lock. If we are going to have to wait then increment
# the waiters so `self.release` will know to unblock us
# later on. We do this part non-blockingly so we don't
# get stuck here before we increment waiters. We have
# this extra acquire call (in addition to the one below,
# outside the self.lock context manager) to make sure
# self.wakeup is held when the next acquire is called (so
# we block). This is probably needlessly complex and we
# should just take self.wakeup in the return codepath
# above.
if self.wakeup.acquire(False):
self.waiters += 1
# Wait for a release() call
self.waiters.append(None)
# Now take the lock in a blocking fashion. This won't
# complete until the thread holding this lock
# (self.owner) calls self.release.
self.wakeup.acquire()
# Taking the lock has served its purpose (making us wait), so we can
# give it up now. We'll take it w/o blocking again on the
# next iteration around this 'while' loop.
self.wakeup.release()
finally:
del _blocking_on[tid]
def release(self):
tid = _thread.get_ident()
with self.lock:
if self.owner != tid:
raise RuntimeError('cannot release un-acquired lock')
assert self.count > 0
self.count -= 1
if self.count == 0:
assert len(self.count) > 0
self.count.pop()
if not len(self.count):
self.owner = None
if self.waiters:
self.waiters -= 1
if len(self.waiters) > 0:
self.waiters.pop()
self.wakeup.release()
def __repr__(self):
return '_ModuleLock({!r}) at {}'.format(self.name, id(self))
return f'_ModuleLock({self.name!r}) at {id(self)}'
class _DummyModuleLock:
@@ -157,7 +404,7 @@ class _DummyModuleLock:
self.count -= 1
def __repr__(self):
return '_DummyModuleLock({!r}) at {}'.format(self.name, id(self))
return f'_DummyModuleLock({self.name!r}) at {id(self)}'
class _ModuleLockManager:
@@ -254,7 +501,7 @@ def _requires_builtin(fxn):
"""Decorator to verify the named module is built-in."""
def _requires_builtin_wrapper(self, fullname):
if fullname not in sys.builtin_module_names:
raise ImportError('{!r} is not a built-in module'.format(fullname),
raise ImportError(f'{fullname!r} is not a built-in module',
name=fullname)
return fxn(self, fullname)
_wrap(_requires_builtin_wrapper, fxn)
@@ -265,7 +512,7 @@ def _requires_frozen(fxn):
"""Decorator to verify the named module is frozen."""
def _requires_frozen_wrapper(self, fullname):
if not _imp.is_frozen(fullname):
raise ImportError('{!r} is not a frozen module'.format(fullname),
raise ImportError(f'{fullname!r} is not a frozen module',
name=fullname)
return fxn(self, fullname)
_wrap(_requires_frozen_wrapper, fxn)
@@ -297,11 +544,6 @@ def _module_repr(module):
loader = getattr(module, '__loader__', None)
if spec := getattr(module, "__spec__", None):
return _module_repr_from_spec(spec)
elif hasattr(loader, 'module_repr'):
try:
return loader.module_repr(module)
except Exception:
pass
# Fall through to a catch-all which always succeeds.
try:
name = module.__name__
@@ -311,11 +553,11 @@ def _module_repr(module):
filename = module.__file__
except AttributeError:
if loader is None:
return '<module {!r}>'.format(name)
return f'<module {name!r}>'
else:
return '<module {!r} ({!r})>'.format(name, loader)
return f'<module {name!r} ({loader!r})>'
else:
return '<module {!r} from {!r}>'.format(name, filename)
return f'<module {name!r} from {filename!r}>'
class ModuleSpec:
@@ -369,14 +611,12 @@ class ModuleSpec:
self._cached = None
def __repr__(self):
args = ['name={!r}'.format(self.name),
'loader={!r}'.format(self.loader)]
args = [f'name={self.name!r}', f'loader={self.loader!r}']
if self.origin is not None:
args.append('origin={!r}'.format(self.origin))
args.append(f'origin={self.origin!r}')
if self.submodule_search_locations is not None:
args.append('submodule_search_locations={}'
.format(self.submodule_search_locations))
return '{}({})'.format(self.__class__.__name__, ', '.join(args))
args.append(f'submodule_search_locations={self.submodule_search_locations}')
return f'{self.__class__.__name__}({", ".join(args)})'
def __eq__(self, other):
smsl = self.submodule_search_locations
@@ -583,18 +823,17 @@ def module_from_spec(spec):
def _module_repr_from_spec(spec):
"""Return the repr to use for the module."""
# We mostly replicate _module_repr() using the spec attributes.
name = '?' if spec.name is None else spec.name
if spec.origin is None:
if spec.loader is None:
return '<module {!r}>'.format(name)
return f'<module {name!r}>'
else:
return '<module {!r} ({!r})>'.format(name, spec.loader)
return f'<module {name!r} (namespace) from {list(spec.loader._path)}>'
else:
if spec.has_location:
return '<module {!r} from {!r}>'.format(name, spec.origin)
return f'<module {name!r} from {spec.origin!r}>'
else:
return '<module {!r} ({})>'.format(spec.name, spec.origin)
return f'<module {spec.name!r} ({spec.origin})>'
# Used by importlib.reload() and _load_module_shim().
@@ -603,7 +842,7 @@ def _exec(spec, module):
name = spec.name
with _ModuleLockManager(name):
if sys.modules.get(name) is not module:
msg = 'module {!r} not in sys.modules'.format(name)
msg = f'module {name!r} not in sys.modules'
raise ImportError(msg, name=name)
try:
if spec.loader is None:
@@ -735,46 +974,18 @@ class BuiltinImporter:
_ORIGIN = "built-in"
@staticmethod
def module_repr(module):
"""Return repr for the module.
The method is deprecated. The import machinery does the job itself.
"""
_warnings.warn("BuiltinImporter.module_repr() is deprecated and "
"slated for removal in Python 3.12", DeprecationWarning)
return f'<module {module.__name__!r} ({BuiltinImporter._ORIGIN})>'
@classmethod
def find_spec(cls, fullname, path=None, target=None):
if path is not None:
return None
if _imp.is_builtin(fullname):
return spec_from_loader(fullname, cls, origin=cls._ORIGIN)
else:
return None
@classmethod
def find_module(cls, fullname, path=None):
"""Find the built-in module.
If 'path' is ever specified then the search is considered a failure.
This method is deprecated. Use find_spec() instead.
"""
_warnings.warn("BuiltinImporter.find_module() is deprecated and "
"slated for removal in Python 3.12; use find_spec() instead",
DeprecationWarning)
spec = cls.find_spec(fullname, path)
return spec.loader if spec is not None else None
@staticmethod
def create_module(spec):
"""Create a built-in module"""
if spec.name not in sys.builtin_module_names:
raise ImportError('{!r} is not a built-in module'.format(spec.name),
raise ImportError(f'{spec.name!r} is not a built-in module',
name=spec.name)
return _call_with_frames_removed(_imp.create_builtin, spec)
@@ -815,17 +1026,6 @@ class FrozenImporter:
_ORIGIN = "frozen"
@staticmethod
def module_repr(m):
"""Return repr for the module.
The method is deprecated. The import machinery does the job itself.
"""
_warnings.warn("FrozenImporter.module_repr() is deprecated and "
"slated for removal in Python 3.12", DeprecationWarning)
return '<module {!r} ({})>'.format(m.__name__, FrozenImporter._ORIGIN)
@classmethod
def _fix_up_module(cls, module):
spec = module.__spec__
@@ -950,18 +1150,6 @@ class FrozenImporter:
spec.submodule_search_locations.insert(0, pkgdir)
return spec
@classmethod
def find_module(cls, fullname, path=None):
"""Find a frozen module.
This method is deprecated. Use find_spec() instead.
"""
_warnings.warn("FrozenImporter.find_module() is deprecated and "
"slated for removal in Python 3.12; use find_spec() instead",
DeprecationWarning)
return cls if _imp.is_frozen(fullname) else None
@staticmethod
def create_module(spec):
"""Set __file__, if able."""
@@ -1041,17 +1229,7 @@ def _resolve_name(name, package, level):
if len(bits) < level:
raise ImportError('attempted relative import beyond top-level package')
base = bits[0]
return '{}.{}'.format(base, name) if name else base
def _find_spec_legacy(finder, name, path):
msg = (f"{_object_name(finder)}.find_spec() not found; "
"falling back to find_module()")
_warnings.warn(msg, ImportWarning)
loader = finder.find_module(name, path)
if loader is None:
return None
return spec_from_loader(name, loader)
return f'{base}.{name}' if name else base
def _find_spec(name, path, target=None):
@@ -1074,9 +1252,7 @@ def _find_spec(name, path, target=None):
try:
find_spec = finder.find_spec
except AttributeError:
spec = _find_spec_legacy(finder, name, path)
if spec is None:
continue
continue
else:
spec = find_spec(name, path, target)
if spec is not None:
@@ -1104,7 +1280,7 @@ def _find_spec(name, path, target=None):
def _sanity_check(name, package, level):
"""Verify arguments are "sane"."""
if not isinstance(name, str):
raise TypeError('module name must be str, not {}'.format(type(name)))
raise TypeError(f'module name must be str, not {type(name)}')
if level < 0:
raise ValueError('level must be >= 0')
if level > 0:
@@ -1134,13 +1310,13 @@ def _find_and_load_unlocked(name, import_):
try:
path = parent_module.__path__
except AttributeError:
msg = (_ERR_MSG + '; {!r} is not a package').format(name, parent)
msg = f'{_ERR_MSG_PREFIX}{name!r}; {parent!r} is not a package'
raise ModuleNotFoundError(msg, name=name) from None
parent_spec = parent_module.__spec__
child = name.rpartition('.')[2]
spec = _find_spec(name, path)
if spec is None:
raise ModuleNotFoundError(_ERR_MSG.format(name), name=name)
raise ModuleNotFoundError(f'{_ERR_MSG_PREFIX}{name!r}', name=name)
else:
if parent_spec:
# Temporarily add child we are currently importing to parent's
@@ -1185,8 +1361,7 @@ def _find_and_load(name, import_):
_lock_unlock_module(name)
if module is None:
message = ('import of {} halted; '
'None in sys.modules'.format(name))
message = f'import of {name} halted; None in sys.modules'
raise ModuleNotFoundError(message, name=name)
return module
@@ -1230,7 +1405,7 @@ def _handle_fromlist(module, fromlist, import_, *, recursive=False):
_handle_fromlist(module, module.__all__, import_,
recursive=True)
elif not hasattr(module, x):
from_name = '{}.{}'.format(module.__name__, x)
from_name = f'{module.__name__}.{x}'
try:
_call_with_frames_removed(import_, from_name)
except ModuleNotFoundError as exc:
@@ -1257,7 +1432,7 @@ def _calc___package__(globals):
if spec is not None and package != spec.parent:
_warnings.warn("__package__ != __spec__.parent "
f"({package!r} != {spec.parent!r})",
ImportWarning, stacklevel=3)
DeprecationWarning, stacklevel=3)
return package
elif spec is not None:
return spec.parent
@@ -1323,7 +1498,7 @@ def _setup(sys_module, _imp_module):
modules, those two modules must be explicitly passed in.
"""
global _imp, sys
global _imp, sys, _blocking_on
_imp = _imp_module
sys = sys_module
@@ -1351,6 +1526,9 @@ def _setup(sys_module, _imp_module):
builtin_module = sys.modules[builtin_name]
setattr(self_module, builtin_name, builtin_module)
# Instantiation requires _weakref to have been set.
_blocking_on = _WeakValueDictionary()
def _install(sys_module, _imp_module):
"""Install importers for builtin and frozen modules"""

View File

@@ -182,12 +182,22 @@ else:
return path.startswith(path_separators)
def _path_abspath(path):
"""Replacement for os.path.abspath."""
if not _path_isabs(path):
for sep in path_separators:
path = path.removeprefix(f".{sep}")
return _path_join(_os.getcwd(), path)
else:
return path
def _write_atomic(path, data, mode=0o666):
"""Best-effort function to write data to a path atomically.
Be prepared to handle a FileExistsError if concurrent writing of the
temporary file is attempted."""
# id() is used to generate a pseudo-random filename.
path_tmp = '{}.{}'.format(path, id(path))
path_tmp = f'{path}.{id(path)}'
fd = _os.open(path_tmp,
_os.O_EXCL | _os.O_CREAT | _os.O_WRONLY, mode & 0o666)
try:
@@ -403,11 +413,45 @@ _code_type = type(_write_atomic.__code__)
# Python 3.11a7 3492 (make POP_JUMP_IF_NONE/NOT_NONE/TRUE/FALSE relative)
# Python 3.11a7 3493 (Make JUMP_IF_TRUE_OR_POP/JUMP_IF_FALSE_OR_POP relative)
# Python 3.11a7 3494 (New location info table)
# Python 3.11b4 3495 (Set line number of module's RESUME instr to 0 per PEP 626)
# Python 3.12 will start with magic number 3500
# Python 3.12a1 3500 (Remove PRECALL opcode)
# Python 3.12a1 3501 (YIELD_VALUE oparg == stack_depth)
# Python 3.12a1 3502 (LOAD_FAST_CHECK, no NULL-check in LOAD_FAST)
# Python 3.12a1 3503 (Shrink LOAD_METHOD cache)
# Python 3.12a1 3504 (Merge LOAD_METHOD back into LOAD_ATTR)
# Python 3.12a1 3505 (Specialization/Cache for FOR_ITER)
# Python 3.12a1 3506 (Add BINARY_SLICE and STORE_SLICE instructions)
# Python 3.12a1 3507 (Set lineno of module's RESUME to 0)
# Python 3.12a1 3508 (Add CLEANUP_THROW)
# Python 3.12a1 3509 (Conditional jumps only jump forward)
# Python 3.12a2 3510 (FOR_ITER leaves iterator on the stack)
# Python 3.12a2 3511 (Add STOPITERATION_ERROR instruction)
# Python 3.12a2 3512 (Remove all unused consts from code objects)
# Python 3.12a4 3513 (Add CALL_INTRINSIC_1 instruction, removed STOPITERATION_ERROR, PRINT_EXPR, IMPORT_STAR)
# Python 3.12a4 3514 (Remove ASYNC_GEN_WRAP, LIST_TO_TUPLE, and UNARY_POSITIVE)
# Python 3.12a5 3515 (Embed jump mask in COMPARE_OP oparg)
# Python 3.12a5 3516 (Add COMPARE_AND_BRANCH instruction)
# Python 3.12a5 3517 (Change YIELD_VALUE oparg to exception block depth)
# Python 3.12a6 3518 (Add RETURN_CONST instruction)
# Python 3.12a6 3519 (Modify SEND instruction)
# Python 3.12a6 3520 (Remove PREP_RERAISE_STAR, add CALL_INTRINSIC_2)
# Python 3.12a7 3521 (Shrink the LOAD_GLOBAL caches)
# Python 3.12a7 3522 (Removed JUMP_IF_FALSE_OR_POP/JUMP_IF_TRUE_OR_POP)
# Python 3.12a7 3523 (Convert COMPARE_AND_BRANCH back to COMPARE_OP)
# Python 3.12a7 3524 (Shrink the BINARY_SUBSCR caches)
# Python 3.12b1 3525 (Shrink the CALL caches)
# Python 3.12b1 3526 (Add instrumentation support)
# Python 3.12b1 3527 (Add LOAD_SUPER_ATTR)
# Python 3.12b1 3528 (Add LOAD_SUPER_ATTR_METHOD specialization)
# Python 3.12b1 3529 (Inline list/dict/set comprehensions)
# Python 3.12b1 3530 (Shrink the LOAD_SUPER_ATTR caches)
# Python 3.12b1 3531 (Add PEP 695 changes)
# Python 3.13 will start with 3550
# Please don't copy-paste the same pre-release tag for new entries above!!!
# You should always use the *upcoming* tag. For example, if 3.12a6 came out
# a week ago, I should put "Python 3.12a7" next to my new magic number.
#
# MAGIC must change whenever the bytecode emitted by the compiler may no
# longer be understood by older implementations of the eval loop (usually
# due to the addition of new opcodes).
@@ -417,7 +461,7 @@ _code_type = type(_write_atomic.__code__)
# Whenever MAGIC_NUMBER is changed, the ranges in the magic_values array
# in PC/launcher.c must also be updated.
MAGIC_NUMBER = (3495).to_bytes(2, 'little') + b'\r\n'
MAGIC_NUMBER = (3531).to_bytes(2, 'little') + b'\r\n'
_RAW_MAGIC_NUMBER = int.from_bytes(MAGIC_NUMBER, 'little') # For import.c
@@ -474,8 +518,8 @@ def cache_from_source(path, debug_override=None, *, optimization=None):
optimization = str(optimization)
if optimization != '':
if not optimization.isalnum():
raise ValueError('{!r} is not alphanumeric'.format(optimization))
almost_filename = '{}.{}{}'.format(almost_filename, _OPT, optimization)
raise ValueError(f'{optimization!r} is not alphanumeric')
almost_filename = f'{almost_filename}.{_OPT}{optimization}'
filename = almost_filename + BYTECODE_SUFFIXES[0]
if sys.pycache_prefix is not None:
# We need an absolute path to the py file to avoid the possibility of
@@ -486,8 +530,7 @@ def cache_from_source(path, debug_override=None, *, optimization=None):
# make it absolute (`C:\Somewhere\Foo\Bar`), then make it root-relative
# (`Somewhere\Foo\Bar`), so we end up placing the bytecode file in an
# unambiguous `C:\Bytecode\Somewhere\Foo\Bar\`.
if not _path_isabs(head):
head = _path_join(_os.getcwd(), head)
head = _path_abspath(head)
# Strip initial drive from a Windows path. We know we have an absolute
# path here, so the second part of the check rules out a POSIX path that
@@ -619,26 +662,6 @@ def _check_name(method):
return _check_name_wrapper
def _find_module_shim(self, fullname):
"""Try to find a loader for the specified module by delegating to
self.find_loader().
This method is deprecated in favor of finder.find_spec().
"""
_warnings.warn("find_module() is deprecated and "
"slated for removal in Python 3.12; use find_spec() instead",
DeprecationWarning)
# Call find_loader(). If it returns a string (indicating this
# is a namespace package portion), generate a warning and
# return None.
loader, portions = self.find_loader(fullname)
if loader is None and len(portions):
msg = 'Not importing directory {}: missing __init__'
_warnings.warn(msg.format(portions[0]), ImportWarning)
return loader
def _classify_pyc(data, name, exc_details):
"""Perform basic validity checking of a pyc header and return the flags field,
which determines how the pyc should be further validated against the source.
@@ -733,7 +756,7 @@ def _compile_bytecode(data, name=None, bytecode_path=None, source_path=None):
_imp._fix_co_filename(code, source_path)
return code
else:
raise ImportError('Non-code object in {!r}'.format(bytecode_path),
raise ImportError(f'Non-code object in {bytecode_path!r}',
name=name, path=bytecode_path)
@@ -800,11 +823,10 @@ def spec_from_file_location(name, location=None, *, loader=None,
pass
else:
location = _os.fspath(location)
if not _path_isabs(location):
try:
location = _path_join(_os.getcwd(), location)
except OSError:
pass
try:
location = _path_abspath(location)
except OSError:
pass
# If the location is on the filesystem, but doesn't actually exist,
# we could return None here, indicating that the location is not
@@ -846,6 +868,54 @@ def spec_from_file_location(name, location=None, *, loader=None,
return spec
def _bless_my_loader(module_globals):
"""Helper function for _warnings.c
See GH#97850 for details.
"""
# 2022-10-06(warsaw): For now, this helper is only used in _warnings.c and
# that use case only has the module globals. This function could be
# extended to accept either that or a module object. However, in the
# latter case, it would be better to raise certain exceptions when looking
# at a module, which should have either a __loader__ or __spec__.loader.
# For backward compatibility, it is possible that we'll get an empty
# dictionary for the module globals, and that cannot raise an exception.
if not isinstance(module_globals, dict):
return None
missing = object()
loader = module_globals.get('__loader__', None)
spec = module_globals.get('__spec__', missing)
if loader is None:
if spec is missing:
# If working with a module:
# raise AttributeError('Module globals is missing a __spec__')
return None
elif spec is None:
raise ValueError('Module globals is missing a __spec__.loader')
spec_loader = getattr(spec, 'loader', missing)
if spec_loader in (missing, None):
if loader is None:
exc = AttributeError if spec_loader is missing else ValueError
raise exc('Module globals is missing a __spec__.loader')
_warnings.warn(
'Module globals is missing a __spec__.loader',
DeprecationWarning)
spec_loader = loader
assert spec_loader is not None
if loader is not None and loader != spec_loader:
_warnings.warn(
'Module globals; __loader__ != __spec__.loader',
DeprecationWarning)
return loader
return spec_loader
# Loaders #####################################################################
class WindowsRegistryFinder:
@@ -898,22 +968,6 @@ class WindowsRegistryFinder:
origin=filepath)
return spec
@classmethod
def find_module(cls, fullname, path=None):
"""Find module named in the registry.
This method is deprecated. Use find_spec() instead.
"""
_warnings.warn("WindowsRegistryFinder.find_module() is deprecated and "
"slated for removal in Python 3.12; use find_spec() instead",
DeprecationWarning)
spec = cls.find_spec(fullname, path)
if spec is not None:
return spec.loader
else:
return None
class _LoaderBasics:
@@ -935,8 +989,8 @@ class _LoaderBasics:
"""Execute the module."""
code = self.get_code(module.__name__)
if code is None:
raise ImportError('cannot load module {!r} when get_code() '
'returns None'.format(module.__name__))
raise ImportError(f'cannot load module {module.__name__!r} when '
'get_code() returns None')
_bootstrap._call_with_frames_removed(exec, code, module.__dict__)
def load_module(self, fullname):
@@ -1077,7 +1131,8 @@ class SourceLoader(_LoaderBasics):
source_mtime is not None):
if hash_based:
if source_hash is None:
source_hash = _imp.source_hash(source_bytes)
source_hash = _imp.source_hash(_RAW_MAGIC_NUMBER,
source_bytes)
data = _code_to_hash_pyc(code_object, source_hash, check_source)
else:
data = _code_to_timestamp_pyc(code_object, source_mtime,
@@ -1321,7 +1376,7 @@ class _NamespacePath:
return len(self._recalculate())
def __repr__(self):
return '_NamespacePath({!r})'.format(self._path)
return f'_NamespacePath({self._path!r})'
def __contains__(self, item):
return item in self._recalculate()
@@ -1332,22 +1387,11 @@ class _NamespacePath:
# This class is actually exposed publicly in a namespace package's __loader__
# attribute, so it should be available through a non-private name.
# https://bugs.python.org/issue35673
# https://github.com/python/cpython/issues/92054
class NamespaceLoader:
def __init__(self, name, path, path_finder):
self._path = _NamespacePath(name, path, path_finder)
@staticmethod
def module_repr(module):
"""Return repr for the module.
The method is deprecated. The import machinery does the job itself.
"""
_warnings.warn("NamespaceLoader.module_repr() is deprecated and "
"slated for removal in Python 3.12", DeprecationWarning)
return '<module {!r} (namespace)>'.format(module.__name__)
def is_package(self, fullname):
return True
@@ -1440,27 +1484,6 @@ class PathFinder:
sys.path_importer_cache[path] = finder
return finder
@classmethod
def _legacy_get_spec(cls, fullname, finder):
# This would be a good place for a DeprecationWarning if
# we ended up going that route.
if hasattr(finder, 'find_loader'):
msg = (f"{_bootstrap._object_name(finder)}.find_spec() not found; "
"falling back to find_loader()")
_warnings.warn(msg, ImportWarning)
loader, portions = finder.find_loader(fullname)
else:
msg = (f"{_bootstrap._object_name(finder)}.find_spec() not found; "
"falling back to find_module()")
_warnings.warn(msg, ImportWarning)
loader = finder.find_module(fullname)
portions = []
if loader is not None:
return _bootstrap.spec_from_loader(fullname, loader)
spec = _bootstrap.ModuleSpec(fullname, None)
spec.submodule_search_locations = portions
return spec
@classmethod
def _get_spec(cls, fullname, path, target=None):
"""Find the loader or namespace_path for this module/package name."""
@@ -1472,10 +1495,7 @@ class PathFinder:
continue
finder = cls._path_importer_cache(entry)
if finder is not None:
if hasattr(finder, 'find_spec'):
spec = finder.find_spec(fullname, target)
else:
spec = cls._legacy_get_spec(fullname, finder)
spec = finder.find_spec(fullname, target)
if spec is None:
continue
if spec.loader is not None:
@@ -1517,22 +1537,6 @@ class PathFinder:
else:
return spec
@classmethod
def find_module(cls, fullname, path=None):
"""find the module on sys.path or 'path' based on sys.path_hooks and
sys.path_importer_cache.
This method is deprecated. Use find_spec() instead.
"""
_warnings.warn("PathFinder.find_module() is deprecated and "
"slated for removal in Python 3.12; use find_spec() instead",
DeprecationWarning)
spec = cls.find_spec(fullname, path)
if spec is None:
return None
return spec.loader
@staticmethod
def find_distributions(*args, **kwargs):
"""
@@ -1567,10 +1571,8 @@ class FileFinder:
# Base (directory) path
if not path or path == '.':
self.path = _os.getcwd()
elif not _path_isabs(path):
self.path = _path_join(_os.getcwd(), path)
else:
self.path = path
self.path = _path_abspath(path)
self._path_mtime = -1
self._path_cache = set()
self._relaxed_path_cache = set()
@@ -1579,23 +1581,6 @@ class FileFinder:
"""Invalidate the directory mtime."""
self._path_mtime = -1
find_module = _find_module_shim
def find_loader(self, fullname):
"""Try to find a loader for the specified module, or the namespace
package portions. Returns (loader, list-of-portions).
This method is deprecated. Use find_spec() instead.
"""
_warnings.warn("FileFinder.find_loader() is deprecated and "
"slated for removal in Python 3.12; use find_spec() instead",
DeprecationWarning)
spec = self.find_spec(fullname)
if spec is None:
return None, []
return spec.loader, spec.submodule_search_locations or []
def _get_spec(self, loader_class, fullname, path, smsl, target):
loader = loader_class(fullname, path)
return spec_from_file_location(fullname, path, loader=loader,
@@ -1675,7 +1660,7 @@ class FileFinder:
for item in contents:
name, dot, suffix = item.partition('.')
if dot:
new_name = '{}.{}'.format(name, suffix.lower())
new_name = f'{name}.{suffix.lower()}'
else:
new_name = name
lower_suffix_contents.add(new_name)
@@ -1702,7 +1687,7 @@ class FileFinder:
return path_hook_for_FileFinder
def __repr__(self):
return 'FileFinder({!r})'.format(self.path)
return f'FileFinder({self.path!r})'
# Import setup ###############################################################
@@ -1720,6 +1705,8 @@ def _fix_up_module(ns, name, pathname, cpathname=None):
loader = SourceFileLoader(name, pathname)
if not spec:
spec = spec_from_file_location(name, pathname, loader=loader)
if cpathname:
spec.cached = _path_abspath(cpathname)
try:
ns['__spec__'] = spec
ns['__loader__'] = loader

111
Lib/importlib/abc.py vendored
View File

@@ -15,20 +15,29 @@ from ._abc import Loader
import abc
import warnings
# for compatibility with Python 3.10
from .resources.abc import ResourceReader, Traversable, TraversableResources
from .resources import abc as _resources_abc
__all__ = [
'Loader', 'Finder', 'MetaPathFinder', 'PathEntryFinder',
'Loader', 'MetaPathFinder', 'PathEntryFinder',
'ResourceLoader', 'InspectLoader', 'ExecutionLoader',
'FileLoader', 'SourceLoader',
# for compatibility with Python 3.10
'ResourceReader', 'Traversable', 'TraversableResources',
]
def __getattr__(name):
"""
For backwards compatibility, continue to make names
from _resources_abc available through this module. #93963
"""
if name in _resources_abc.__all__:
obj = getattr(_resources_abc, name)
warnings._deprecated(f"{__name__}.{name}", remove=(3, 14))
globals()[name] = obj
return obj
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
def _register(abstract_cls, *classes):
for cls in classes:
abstract_cls.register(cls)
@@ -40,38 +49,6 @@ def _register(abstract_cls, *classes):
abstract_cls.register(frozen_cls)
class Finder(metaclass=abc.ABCMeta):
"""Legacy abstract base class for import finders.
It may be subclassed for compatibility with legacy third party
reimplementations of the import system. Otherwise, finder
implementations should derive from the more specific MetaPathFinder
or PathEntryFinder ABCs.
Deprecated since Python 3.3
"""
def __init__(self):
warnings.warn("the Finder ABC is deprecated and "
"slated for removal in Python 3.12; use MetaPathFinder "
"or PathEntryFinder instead",
DeprecationWarning)
@abc.abstractmethod
def find_module(self, fullname, path=None):
"""An abstract method that should find a module.
The fullname is a str and the optional path is a str or None.
Returns a Loader object or None.
"""
warnings.warn("importlib.abc.Finder along with its find_module() "
"method are deprecated and "
"slated for removal in Python 3.12; use "
"MetaPathFinder.find_spec() or "
"PathEntryFinder.find_spec() instead",
DeprecationWarning)
class MetaPathFinder(metaclass=abc.ABCMeta):
"""Abstract base class for import finders on sys.meta_path."""
@@ -79,27 +56,6 @@ class MetaPathFinder(metaclass=abc.ABCMeta):
# We don't define find_spec() here since that would break
# hasattr checks we do to support backward compatibility.
def find_module(self, fullname, path):
"""Return a loader for the module.
If no module is found, return None. The fullname is a str and
the path is a list of strings or None.
This method is deprecated since Python 3.4 in favor of
finder.find_spec(). If find_spec() exists then backwards-compatible
functionality is provided for this method.
"""
warnings.warn("MetaPathFinder.find_module() is deprecated since Python "
"3.4 in favor of MetaPathFinder.find_spec() and is "
"slated for removal in Python 3.12",
DeprecationWarning,
stacklevel=2)
if not hasattr(self, 'find_spec'):
return None
found = self.find_spec(fullname, path)
return found.loader if found is not None else None
def invalidate_caches(self):
"""An optional method for clearing the finder's cache, if any.
This method is used by importlib.invalidate_caches().
@@ -113,43 +69,6 @@ class PathEntryFinder(metaclass=abc.ABCMeta):
"""Abstract base class for path entry finders used by PathFinder."""
# We don't define find_spec() here since that would break
# hasattr checks we do to support backward compatibility.
def find_loader(self, fullname):
"""Return (loader, namespace portion) for the path entry.
The fullname is a str. The namespace portion is a sequence of
path entries contributing to part of a namespace package. The
sequence may be empty. If loader is not None, the portion will
be ignored.
The portion will be discarded if another path entry finder
locates the module as a normal module or package.
This method is deprecated since Python 3.4 in favor of
finder.find_spec(). If find_spec() is provided than backwards-compatible
functionality is provided.
"""
warnings.warn("PathEntryFinder.find_loader() is deprecated since Python "
"3.4 in favor of PathEntryFinder.find_spec() "
"(available since 3.4)",
DeprecationWarning,
stacklevel=2)
if not hasattr(self, 'find_spec'):
return None, []
found = self.find_spec(fullname)
if found is not None:
if not found.submodule_search_locations:
portions = []
else:
portions = found.submodule_search_locations
return found.loader, portions
else:
return None, []
find_module = _bootstrap_external._find_module_shim
def invalidate_caches(self):
"""An optional method for clearing the finder's cache, if any.
This method is used by PathFinder.invalidate_caches().

View File

@@ -12,7 +12,9 @@ import warnings
import functools
import itertools
import posixpath
import contextlib
import collections
import inspect
from . import _adapters, _meta
from ._collections import FreezableDefaultDict, Pair
@@ -24,7 +26,7 @@ from contextlib import suppress
from importlib import import_module
from importlib.abc import MetaPathFinder
from itertools import starmap
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, cast
__all__ = [
@@ -140,6 +142,7 @@ class DeprecatedTuple:
1
"""
# Do not remove prior to 2023-05-01 or Python 3.13
_warn = functools.partial(
warnings.warn,
"EntryPoint tuple interface is deprecated. Access members by name.",
@@ -228,17 +231,6 @@ class EntryPoint(DeprecatedTuple):
vars(self).update(dist=dist)
return self
def __iter__(self):
"""
Supply iter so one may construct dicts of EntryPoints by name.
"""
msg = (
"Construction of dict of EntryPoints is deprecated in "
"favor of EntryPoints."
)
warnings.warn(msg, DeprecationWarning)
return iter((self.name, self))
def matches(self, **params):
"""
EntryPoint matches the given parameters.
@@ -284,77 +276,7 @@ class EntryPoint(DeprecatedTuple):
return hash(self._key())
class DeprecatedList(list):
"""
Allow an otherwise immutable object to implement mutability
for compatibility.
>>> recwarn = getfixture('recwarn')
>>> dl = DeprecatedList(range(3))
>>> dl[0] = 1
>>> dl.append(3)
>>> del dl[3]
>>> dl.reverse()
>>> dl.sort()
>>> dl.extend([4])
>>> dl.pop(-1)
4
>>> dl.remove(1)
>>> dl += [5]
>>> dl + [6]
[1, 2, 5, 6]
>>> dl + (6,)
[1, 2, 5, 6]
>>> dl.insert(0, 0)
>>> dl
[0, 1, 2, 5]
>>> dl == [0, 1, 2, 5]
True
>>> dl == (0, 1, 2, 5)
True
>>> len(recwarn)
1
"""
__slots__ = ()
_warn = functools.partial(
warnings.warn,
"EntryPoints list interface is deprecated. Cast to list if needed.",
DeprecationWarning,
stacklevel=2,
)
def _wrap_deprecated_method(method_name: str): # type: ignore
def wrapped(self, *args, **kwargs):
self._warn()
return getattr(super(), method_name)(*args, **kwargs)
return method_name, wrapped
locals().update(
map(
_wrap_deprecated_method,
'__setitem__ __delitem__ append reverse extend pop remove '
'__iadd__ insert sort'.split(),
)
)
def __add__(self, other):
if not isinstance(other, tuple):
self._warn()
other = tuple(other)
return self.__class__(tuple(self) + other)
def __eq__(self, other):
if not isinstance(other, tuple):
self._warn()
other = tuple(other)
return tuple(self).__eq__(other)
class EntryPoints(DeprecatedList):
class EntryPoints(tuple):
"""
An immutable collection of selectable EntryPoint objects.
"""
@@ -365,14 +287,6 @@ class EntryPoints(DeprecatedList):
"""
Get the EntryPoint in self matching name.
"""
if isinstance(name, int):
warnings.warn(
"Accessing entry points by index is deprecated. "
"Cast to tuple if needed.",
DeprecationWarning,
stacklevel=2,
)
return super().__getitem__(name)
try:
return next(iter(self.select(name=name)))
except StopIteration:
@@ -396,10 +310,6 @@ class EntryPoints(DeprecatedList):
def groups(self):
"""
Return the set of all groups of all entry points.
For coverage while SelectableGroups is present.
>>> EntryPoints().groups
set()
"""
return {ep.group for ep in self}
@@ -415,101 +325,6 @@ class EntryPoints(DeprecatedList):
)
class Deprecated:
"""
Compatibility add-in for mapping to indicate that
mapping behavior is deprecated.
>>> recwarn = getfixture('recwarn')
>>> class DeprecatedDict(Deprecated, dict): pass
>>> dd = DeprecatedDict(foo='bar')
>>> dd.get('baz', None)
>>> dd['foo']
'bar'
>>> list(dd)
['foo']
>>> list(dd.keys())
['foo']
>>> 'foo' in dd
True
>>> list(dd.values())
['bar']
>>> len(recwarn)
1
"""
_warn = functools.partial(
warnings.warn,
"SelectableGroups dict interface is deprecated. Use select.",
DeprecationWarning,
stacklevel=2,
)
def __getitem__(self, name):
self._warn()
return super().__getitem__(name)
def get(self, name, default=None):
self._warn()
return super().get(name, default)
def __iter__(self):
self._warn()
return super().__iter__()
def __contains__(self, *args):
self._warn()
return super().__contains__(*args)
def keys(self):
self._warn()
return super().keys()
def values(self):
self._warn()
return super().values()
class SelectableGroups(Deprecated, dict):
"""
A backward- and forward-compatible result from
entry_points that fully implements the dict interface.
"""
@classmethod
def load(cls, eps):
by_group = operator.attrgetter('group')
ordered = sorted(eps, key=by_group)
grouped = itertools.groupby(ordered, by_group)
return cls((group, EntryPoints(eps)) for group, eps in grouped)
@property
def _all(self):
"""
Reconstruct a list of all entrypoints from the groups.
"""
groups = super(Deprecated, self).values()
return EntryPoints(itertools.chain.from_iterable(groups))
@property
def groups(self):
return self._all.groups
@property
def names(self):
"""
for coverage:
>>> SelectableGroups().names
set()
"""
return self._all.names
def select(self, **params):
if not params:
return self
return self._all.select(**params)
class PackagePath(pathlib.PurePosixPath):
"""A reference to a path in a package"""
@@ -534,11 +349,30 @@ class FileHash:
return f'<FileHash mode: {self.mode} value: {self.value}>'
class Distribution:
class DeprecatedNonAbstract:
def __new__(cls, *args, **kwargs):
all_names = {
name for subclass in inspect.getmro(cls) for name in vars(subclass)
}
abstract = {
name
for name in all_names
if getattr(getattr(cls, name), '__isabstractmethod__', False)
}
if abstract:
warnings.warn(
f"Unimplemented abstract methods {abstract}",
DeprecationWarning,
stacklevel=2,
)
return super().__new__(cls)
class Distribution(DeprecatedNonAbstract):
"""A Python distribution package."""
@abc.abstractmethod
def read_text(self, filename):
def read_text(self, filename) -> Optional[str]:
"""Attempt to load metadata file given by the name.
:param filename: The name of the file in the distribution info.
@@ -612,7 +446,7 @@ class Distribution:
The returned object will have keys that name the various bits of
metadata. See PEP 566 for details.
"""
text = (
opt_text = (
self.read_text('METADATA')
or self.read_text('PKG-INFO')
# This last clause is here to support old egg-info files. Its
@@ -620,6 +454,7 @@ class Distribution:
# (which points to the egg-info file) attribute unchanged.
or self.read_text('')
)
text = cast(str, opt_text)
return _adapters.Message(email.message_from_string(text))
@property
@@ -648,8 +483,8 @@ class Distribution:
:return: List of PackagePath for this distribution or None
Result is `None` if the metadata file that enumerates files
(i.e. RECORD for dist-info or SOURCES.txt for egg-info) is
missing.
(i.e. RECORD for dist-info, or installed-files.txt or
SOURCES.txt for egg-info) is missing.
Result may be empty if the metadata exists but is empty.
"""
@@ -662,9 +497,19 @@ class Distribution:
@pass_none
def make_files(lines):
return list(starmap(make_file, csv.reader(lines)))
return starmap(make_file, csv.reader(lines))
return make_files(self._read_files_distinfo() or self._read_files_egginfo())
@pass_none
def skip_missing_files(package_paths):
return list(filter(lambda path: path.locate().exists(), package_paths))
return skip_missing_files(
make_files(
self._read_files_distinfo()
or self._read_files_egginfo_installed()
or self._read_files_egginfo_sources()
)
)
def _read_files_distinfo(self):
"""
@@ -673,10 +518,45 @@ class Distribution:
text = self.read_text('RECORD')
return text and text.splitlines()
def _read_files_egginfo(self):
def _read_files_egginfo_installed(self):
"""
SOURCES.txt might contain literal commas, so wrap each line
in quotes.
Read installed-files.txt and return lines in a similar
CSV-parsable format as RECORD: each file must be placed
relative to the site-packages directory and must also be
quoted (since file names can contain literal commas).
This file is written when the package is installed by pip,
but it might not be written for other installation methods.
Assume the file is accurate if it exists.
"""
text = self.read_text('installed-files.txt')
# Prepend the .egg-info/ subdir to the lines in this file.
# But this subdir is only available from PathDistribution's
# self._path.
subdir = getattr(self, '_path', None)
if not text or not subdir:
return
paths = (
(subdir / name)
.resolve()
.relative_to(self.locate_file('').resolve())
.as_posix()
for name in text.splitlines()
)
return map('"{}"'.format, paths)
def _read_files_egginfo_sources(self):
"""
Read SOURCES.txt and return lines in a similar CSV-parsable
format as RECORD: each file name must be quoted (since it
might contain literal commas).
Note that SOURCES.txt is not a reliable source for what
files are installed by a package. This file is generated
for a source archive, and the files that are present
there (e.g. setup.py) may not correctly reflect the files
that are present after the package has been installed.
"""
text = self.read_text('SOURCES.txt')
return text and map('"{}"'.format, text.splitlines())
@@ -1023,27 +903,19 @@ Wrapper for ``distributions`` to return unique distributions by name.
"""
def entry_points(**params) -> Union[EntryPoints, SelectableGroups]:
def entry_points(**params) -> EntryPoints:
"""Return EntryPoint objects for all installed packages.
Pass selection parameters (group or name) to filter the
result to entry points matching those properties (see
EntryPoints.select()).
For compatibility, returns ``SelectableGroups`` object unless
selection parameters are supplied. In the future, this function
will return ``EntryPoints`` instead of ``SelectableGroups``
even when no selection parameters are supplied.
For maximum future compatibility, pass selection parameters
or invoke ``.select`` with parameters on the result.
:return: EntryPoints or SelectableGroups for all installed packages.
:return: EntryPoints for all installed packages.
"""
eps = itertools.chain.from_iterable(
dist.entry_points for dist in _unique(distributions())
)
return SelectableGroups.load(eps).select(**params)
return EntryPoints(eps).select(**params)
def files(distribution_name):
@@ -1087,8 +959,13 @@ def _top_level_declared(dist):
def _top_level_inferred(dist):
return {
f.parts[0] if len(f.parts) > 1 else f.with_suffix('').name
opt_names = {
f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f)
for f in always_iterable(dist.files)
if f.suffix == ".py"
}
@pass_none
def importable_name(name):
return '.' not in name
return filter(importable_name, opt_names)

View File

@@ -1,3 +1,5 @@
import functools
import warnings
import re
import textwrap
import email.message
@@ -5,6 +7,15 @@ import email.message
from ._text import FoldedCase
# Do not remove prior to 2024-01-01 or Python 3.14
_warn = functools.partial(
warnings.warn,
"Implicit None on return values is deprecated and will raise KeyErrors.",
DeprecationWarning,
stacklevel=2,
)
class Message(email.message.Message):
multiple_use_keys = set(
map(
@@ -39,6 +50,16 @@ class Message(email.message.Message):
def __iter__(self):
return super().__iter__()
def __getitem__(self, item):
"""
Warn users that a ``KeyError`` can be expected when a
mising key is supplied. Ref python/importlib_metadata#371.
"""
res = super().__getitem__(item)
if res is None:
_warn()
return res
def _repair_headers(self):
def redent(value):
"Correct for RFC822 indentation"

View File

@@ -1,4 +1,5 @@
from typing import Any, Dict, Iterator, List, Protocol, TypeVar, Union
from typing import Protocol
from typing import Any, Dict, Iterator, List, Optional, TypeVar, Union, overload
_T = TypeVar("_T")
@@ -17,7 +18,21 @@ class PackageMetadata(Protocol):
def __iter__(self) -> Iterator[str]:
... # pragma: no cover
def get_all(self, name: str, failobj: _T = ...) -> Union[List[Any], _T]:
@overload
def get(self, name: str, failobj: None = None) -> Optional[str]:
... # pragma: no cover
@overload
def get(self, name: str, failobj: _T) -> Union[str, _T]:
... # pragma: no cover
# overload per python/importlib_metadata#435
@overload
def get_all(self, name: str, failobj: None = None) -> Optional[List[Any]]:
... # pragma: no cover
@overload
def get_all(self, name: str, failobj: _T) -> Union[List[Any], _T]:
"""
Return all values associated with a possibly multi-valued key.
"""
@@ -29,18 +44,19 @@ class PackageMetadata(Protocol):
"""
class SimplePath(Protocol):
class SimplePath(Protocol[_T]):
"""
A minimal subset of pathlib.Path required by PathDistribution.
"""
def joinpath(self) -> 'SimplePath':
def joinpath(self) -> _T:
... # pragma: no cover
def __truediv__(self) -> 'SimplePath':
def __truediv__(self, other: Union[str, _T]) -> _T:
... # pragma: no cover
def parent(self) -> 'SimplePath':
@property
def parent(self) -> _T:
... # pragma: no cover
def read_text(self) -> str:

View File

@@ -34,9 +34,7 @@ def _io_wrapper(file, mode='r', *args, **kwargs):
return TextIOWrapper(file, *args, **kwargs)
elif mode == 'rb':
return file
raise ValueError(
"Invalid mode value '{}', only 'r' and 'rb' are supported".format(mode)
)
raise ValueError(f"Invalid mode value '{mode}', only 'r' and 'rb' are supported")
class CompatibilityFiles:

View File

@@ -5,25 +5,58 @@ import functools
import contextlib
import types
import importlib
import inspect
import warnings
import itertools
from typing import Union, Optional
from typing import Union, Optional, cast
from .abc import ResourceReader, Traversable
from ._adapters import wrap_spec
Package = Union[types.ModuleType, str]
Anchor = Package
def files(package):
# type: (Package) -> Traversable
def package_to_anchor(func):
"""
Get a Traversable resource from a package
Replace 'package' parameter as 'anchor' and warn about the change.
Other errors should fall through.
>>> files('a', 'b')
Traceback (most recent call last):
TypeError: files() takes from 0 to 1 positional arguments but 2 were given
"""
return from_package(get_package(package))
undefined = object()
@functools.wraps(func)
def wrapper(anchor=undefined, package=undefined):
if package is not undefined:
if anchor is not undefined:
return func(anchor, package)
warnings.warn(
"First parameter to files is renamed to 'anchor'",
DeprecationWarning,
stacklevel=2,
)
return func(package)
elif anchor is undefined:
return func()
return func(anchor)
return wrapper
def get_resource_reader(package):
# type: (types.ModuleType) -> Optional[ResourceReader]
@package_to_anchor
def files(anchor: Optional[Anchor] = None) -> Traversable:
"""
Get a Traversable resource for an anchor.
"""
return from_package(resolve(anchor))
def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]:
"""
Return the package's loader if it's a ResourceReader.
"""
@@ -39,24 +72,39 @@ def get_resource_reader(package):
return reader(spec.name) # type: ignore
def resolve(cand):
# type: (Package) -> types.ModuleType
return cand if isinstance(cand, types.ModuleType) else importlib.import_module(cand)
@functools.singledispatch
def resolve(cand: Optional[Anchor]) -> types.ModuleType:
return cast(types.ModuleType, cand)
def get_package(package):
# type: (Package) -> types.ModuleType
"""Take a package name or module object and return the module.
@resolve.register(str) # TODO: RUSTPYTHON; manual type annotation
def _(cand: str) -> types.ModuleType:
return importlib.import_module(cand)
Raise an exception if the resolved module is not a package.
@resolve.register(type(None)) # TODO: RUSTPYTHON; manual type annotation
def _(cand: None) -> types.ModuleType:
return resolve(_infer_caller().f_globals['__name__'])
def _infer_caller():
"""
resolved = resolve(package)
if wrap_spec(resolved).submodule_search_locations is None:
raise TypeError(f'{package!r} is not a package')
return resolved
Walk the stack and find the frame of the first caller not in this module.
"""
def is_this_file(frame_info):
return frame_info.filename == __file__
def is_wrapper(frame_info):
return frame_info.function == 'wrapper'
not_this_file = itertools.filterfalse(is_this_file, inspect.stack())
# also exclude 'wrapper' due to singledispatch in the call stack
callers = itertools.filterfalse(is_wrapper, not_this_file)
return next(callers).frame
def from_package(package):
def from_package(package: types.ModuleType):
"""
Return a Traversable object for the given package.
@@ -67,10 +115,14 @@ def from_package(package):
@contextlib.contextmanager
def _tempfile(reader, suffix='',
# gh-93353: Keep a reference to call os.remove() in late Python
# finalization.
*, _os_remove=os.remove):
def _tempfile(
reader,
suffix='',
# gh-93353: Keep a reference to call os.remove() in late Python
# finalization.
*,
_os_remove=os.remove,
):
# Not using tempfile.NamedTemporaryFile as it leads to deeper 'try'
# blocks due to the need to close the temporary file to work on Windows
# properly.
@@ -89,13 +141,30 @@ def _tempfile(reader, suffix='',
pass
def _temp_file(path):
return _tempfile(path.read_bytes, suffix=path.name)
def _is_present_dir(path: Traversable) -> bool:
"""
Some Traversables implement ``is_dir()`` to raise an
exception (i.e. ``FileNotFoundError``) when the
directory doesn't exist. This function wraps that call
to always return a boolean and only return True
if there's a dir and it exists.
"""
with contextlib.suppress(FileNotFoundError):
return path.is_dir()
return False
@functools.singledispatch
def as_file(path):
"""
Given a Traversable object, return that object as a
path on the local file system in a context manager.
"""
return _tempfile(path.read_bytes, suffix=path.name)
return _temp_dir(path) if _is_present_dir(path) else _temp_file(path)
@as_file.register(pathlib.Path)
@@ -105,3 +174,34 @@ def _(path):
Degenerate behavior for pathlib.Path objects.
"""
yield path
@contextlib.contextmanager
def _temp_path(dir: tempfile.TemporaryDirectory):
"""
Wrap tempfile.TemporyDirectory to return a pathlib object.
"""
with dir as result:
yield pathlib.Path(result)
@contextlib.contextmanager
def _temp_dir(path):
"""
Given a traversable dir, recursively replicate the whole tree
to the file system in a context manager.
"""
assert path.is_dir()
with _temp_path(tempfile.TemporaryDirectory()) as temp_dir:
yield _write_contents(temp_dir, path)
def _write_contents(target, source):
child = target.joinpath(source.name)
if source.is_dir():
child.mkdir()
for item in source.iterdir():
_write_contents(child, item)
else:
child.write_bytes(source.read_bytes())
return child

View File

@@ -1,35 +1,38 @@
from itertools import filterfalse
# from more_itertools 9.0
def only(iterable, default=None, too_long=None):
"""If *iterable* has only one item, return it.
If it has zero items, return *default*.
If it has more than one item, raise the exception given by *too_long*,
which is ``ValueError`` by default.
>>> only([], default='missing')
'missing'
>>> only([1])
1
>>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError: Expected exactly one item in iterable, but got 1, 2,
and perhaps more.'
>>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TypeError
Note that :func:`only` attempts to advance *iterable* twice to ensure there
is only one item. See :func:`spy` or :func:`peekable` to check
iterable contents less destructively.
"""
it = iter(iterable)
first_value = next(it, default)
from typing import (
Callable,
Iterable,
Iterator,
Optional,
Set,
TypeVar,
Union,
)
# Type and type variable definitions
_T = TypeVar('_T')
_U = TypeVar('_U')
def unique_everseen(
iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = None
) -> Iterator[_T]:
"List unique elements, preserving order. Remember all elements ever seen."
# unique_everseen('AAAABBBCCDAABBB') --> A B C D
# unique_everseen('ABBCcAD', str.lower) --> A B C D
seen: Set[Union[_T, _U]] = set()
seen_add = seen.add
if key is None:
for element in filterfalse(seen.__contains__, iterable):
seen_add(element)
yield element
try:
second_value = next(it)
except StopIteration:
pass
else:
for element in iterable:
k = key(element)
if k not in seen:
seen_add(k)
yield element
msg = (
'Expected exactly one item in iterable, but got {!r}, {!r}, '
'and perhaps more.'.format(first_value, second_value)
)
raise too_long or ValueError(msg)
return first_value

View File

@@ -27,8 +27,7 @@ def deprecated(func):
return wrapper
def normalize_path(path):
# type: (Any) -> str
def normalize_path(path: Any) -> str:
"""Normalize a path by ensuring it is a string.
If the resulting string contains path separators, an exception is raised.

View File

@@ -1,6 +1,8 @@
import abc
import io
import itertools
import os
import pathlib
from typing import Any, BinaryIO, Iterable, Iterator, NoReturn, Text, Optional
from typing import runtime_checkable, Protocol
from typing import Union
@@ -53,6 +55,10 @@ class ResourceReader(metaclass=abc.ABCMeta):
raise FileNotFoundError
class TraversalError(Exception):
pass
@runtime_checkable
class Traversable(Protocol):
"""
@@ -95,7 +101,6 @@ class Traversable(Protocol):
Return True if self is a file
"""
@abc.abstractmethod
def joinpath(self, *descendants: StrPath) -> "Traversable":
"""
Return Traversable resolved with any descendants applied.
@@ -104,6 +109,22 @@ class Traversable(Protocol):
and each may contain multiple levels separated by
``posixpath.sep`` (``/``).
"""
if not descendants:
return self
names = itertools.chain.from_iterable(
path.parts for path in map(pathlib.PurePosixPath, descendants)
)
target = next(names)
matches = (
traversable for traversable in self.iterdir() if traversable.name == target
)
try:
match = next(matches)
except StopIteration:
raise TraversalError(
"Target not found during traversal.", target, list(names)
)
return match.joinpath(*names)
def __truediv__(self, child: StrPath) -> "Traversable":
"""
@@ -121,7 +142,8 @@ class Traversable(Protocol):
accepted by io.TextIOWrapper.
"""
@abc.abstractproperty
@property
@abc.abstractmethod
def name(self) -> str:
"""
The base name of this object without any parent references.

View File

@@ -1,11 +1,12 @@
import collections
import operator
import itertools
import pathlib
import operator
import zipfile
from . import abc
from ._itertools import unique_everseen
from ._itertools import only
def remove_duplicates(items):
@@ -41,8 +42,10 @@ class ZipReader(abc.TraversableResources):
raise FileNotFoundError(exc.args[0])
def is_resource(self, path):
# workaround for `zipfile.Path.is_file` returning true
# for non-existent paths.
"""
Workaround for `zipfile.Path.is_file` returning true
for non-existent paths.
"""
target = self.files().joinpath(path)
return target.is_file() and target.exists()
@@ -67,8 +70,10 @@ class MultiplexedPath(abc.Traversable):
raise NotADirectoryError('MultiplexedPath only supports directories')
def iterdir(self):
files = (file for path in self._paths for file in path.iterdir())
return unique_everseen(files, key=operator.attrgetter('name'))
children = (child for path in self._paths for child in path.iterdir())
by_name = operator.attrgetter('name')
groups = itertools.groupby(sorted(children, key=by_name), key=by_name)
return map(self._follow, (locs for name, locs in groups))
def read_bytes(self):
raise FileNotFoundError(f'{self} is not a file')
@@ -82,15 +87,32 @@ class MultiplexedPath(abc.Traversable):
def is_file(self):
return False
def joinpath(self, child):
# first try to find child in current paths
for file in self.iterdir():
if file.name == child:
return file
# if it does not exist, construct it with the first path
return self._paths[0] / child
def joinpath(self, *descendants):
try:
return super().joinpath(*descendants)
except abc.TraversalError:
# One of the paths did not resolve (a directory does not exist).
# Just return something that will not exist.
return self._paths[0].joinpath(*descendants)
__truediv__ = joinpath
@classmethod
def _follow(cls, children):
"""
Construct a MultiplexedPath if needed.
If children contains a sole element, return it.
Otherwise, return a MultiplexedPath of the items.
Unless one of the items is not a Directory, then return the first.
"""
subdirs, one_dir, one_file = itertools.tee(children, 3)
try:
return only(one_dir)
except ValueError:
try:
return cls(*subdirs)
except NotADirectoryError:
return next(one_file)
def open(self, *args, **kwargs):
raise FileNotFoundError(f'{self} is not a file')

View File

@@ -16,31 +16,28 @@ class SimpleReader(abc.ABC):
provider.
"""
@abc.abstractproperty
def package(self):
# type: () -> str
@property
@abc.abstractmethod
def package(self) -> str:
"""
The name of the package for which this reader loads resources.
"""
@abc.abstractmethod
def children(self):
# type: () -> List['SimpleReader']
def children(self) -> List['SimpleReader']:
"""
Obtain an iterable of SimpleReader for available
child containers (e.g. directories).
"""
@abc.abstractmethod
def resources(self):
# type: () -> List[str]
def resources(self) -> List[str]:
"""
Obtain available named resources for this virtual package.
"""
@abc.abstractmethod
def open_binary(self, resource):
# type: (str) -> BinaryIO
def open_binary(self, resource: str) -> BinaryIO:
"""
Obtain a File-like for a named resource.
"""
@@ -50,13 +47,35 @@ class SimpleReader(abc.ABC):
return self.package.split('.')[-1]
class ResourceContainer(Traversable):
"""
Traversable container for a package's resources via its reader.
"""
def __init__(self, reader: SimpleReader):
self.reader = reader
def is_dir(self):
return True
def is_file(self):
return False
def iterdir(self):
files = (ResourceHandle(self, name) for name in self.reader.resources)
dirs = map(ResourceContainer, self.reader.children())
return itertools.chain(files, dirs)
def open(self, *args, **kwargs):
raise IsADirectoryError()
class ResourceHandle(Traversable):
"""
Handle to a named resource in a ResourceReader.
"""
def __init__(self, parent, name):
# type: (ResourceContainer, str) -> None
def __init__(self, parent: ResourceContainer, name: str):
self.parent = parent
self.name = name # type: ignore
@@ -76,44 +95,6 @@ class ResourceHandle(Traversable):
raise RuntimeError("Cannot traverse into a resource")
class ResourceContainer(Traversable):
"""
Traversable container for a package's resources via its reader.
"""
def __init__(self, reader):
# type: (SimpleReader) -> None
self.reader = reader
def is_dir(self):
return True
def is_file(self):
return False
def iterdir(self):
files = (ResourceHandle(self, name) for name in self.reader.resources)
dirs = map(ResourceContainer, self.reader.children())
return itertools.chain(files, dirs)
def open(self, *args, **kwargs):
raise IsADirectoryError()
@staticmethod
def _flatten(compound_names):
for name in compound_names:
yield from name.split('/')
def joinpath(self, *descendants):
if not descendants:
return self
names = self._flatten(descendants)
target = next(names)
return next(
traversable for traversable in self.iterdir() if traversable.name == target
).joinpath(*names)
class TraversableReader(TraversableResources, SimpleReader):
"""
A TraversableResources based on SimpleReader. Resource providers

144
Lib/importlib/util.py vendored
View File

@@ -11,12 +11,9 @@ from ._bootstrap_external import decode_source
from ._bootstrap_external import source_from_cache
from ._bootstrap_external import spec_from_file_location
from contextlib import contextmanager
import _imp
import functools
import sys
import types
import warnings
def source_hash(source_bytes):
@@ -63,10 +60,10 @@ def _find_spec_from_path(name, path=None):
try:
spec = module.__spec__
except AttributeError:
raise ValueError('{}.__spec__ is not set'.format(name)) from None
raise ValueError(f'{name}.__spec__ is not set') from None
else:
if spec is None:
raise ValueError('{}.__spec__ is None'.format(name))
raise ValueError(f'{name}.__spec__ is None')
return spec
@@ -108,115 +105,64 @@ def find_spec(name, package=None):
try:
spec = module.__spec__
except AttributeError:
raise ValueError('{}.__spec__ is not set'.format(name)) from None
raise ValueError(f'{name}.__spec__ is not set') from None
else:
if spec is None:
raise ValueError('{}.__spec__ is None'.format(name))
raise ValueError(f'{name}.__spec__ is None')
return spec
@contextmanager
def _module_to_load(name):
is_reload = name in sys.modules
# Normally we would use contextlib.contextmanager. However, this module
# is imported by runpy, which means we want to avoid any unnecessary
# dependencies. Thus we use a class.
module = sys.modules.get(name)
if not is_reload:
# This must be done before open() is called as the 'io' module
# implicitly imports 'locale' and would otherwise trigger an
# infinite loop.
module = type(sys)(name)
# This must be done before putting the module in sys.modules
# (otherwise an optimization shortcut in import.c becomes wrong)
module.__initializing__ = True
sys.modules[name] = module
try:
yield module
except Exception:
if not is_reload:
try:
del sys.modules[name]
except KeyError:
pass
finally:
module.__initializing__ = False
class _incompatible_extension_module_restrictions:
"""A context manager that can temporarily skip the compatibility check.
NOTE: This function is meant to accommodate an unusual case; one
which is likely to eventually go away. There's is a pretty good
chance this is not what you were looking for.
def set_package(fxn):
"""Set __package__ on the returned module.
WARNING: Using this function to disable the check can lead to
unexpected behavior and even crashes. It should only be used during
extension module development.
This function is deprecated.
If "disable_check" is True then the compatibility check will not
happen while the context manager is active. Otherwise the check
*will* happen.
Normally, extensions that do not support multiple interpreters
may not be imported in a subinterpreter. That implies modules
that do not implement multi-phase init or that explicitly of out.
Likewise for modules import in a subinterpeter with its own GIL
when the extension does not support a per-interpreter GIL. This
implies the module does not have a Py_mod_multiple_interpreters slot
set to Py_MOD_PER_INTERPRETER_GIL_SUPPORTED.
In both cases, this context manager may be used to temporarily
disable the check for compatible extension modules.
You can get the same effect as this function by implementing the
basic interface of multi-phase init (PEP 489) and lying about
support for mulitple interpreters (or per-interpreter GIL).
"""
@functools.wraps(fxn)
def set_package_wrapper(*args, **kwargs):
warnings.warn('The import system now takes care of this automatically; '
'this decorator is slated for removal in Python 3.12',
DeprecationWarning, stacklevel=2)
module = fxn(*args, **kwargs)
if getattr(module, '__package__', None) is None:
module.__package__ = module.__name__
if not hasattr(module, '__path__'):
module.__package__ = module.__package__.rpartition('.')[0]
return module
return set_package_wrapper
def __init__(self, *, disable_check):
self.disable_check = bool(disable_check)
def set_loader(fxn):
"""Set __loader__ on the returned module.
def __enter__(self):
self.old = _imp._override_multi_interp_extensions_check(self.override)
return self
This function is deprecated.
def __exit__(self, *args):
old = self.old
del self.old
_imp._override_multi_interp_extensions_check(old)
"""
@functools.wraps(fxn)
def set_loader_wrapper(self, *args, **kwargs):
warnings.warn('The import system now takes care of this automatically; '
'this decorator is slated for removal in Python 3.12',
DeprecationWarning, stacklevel=2)
module = fxn(self, *args, **kwargs)
if getattr(module, '__loader__', None) is None:
module.__loader__ = self
return module
return set_loader_wrapper
def module_for_loader(fxn):
"""Decorator to handle selecting the proper module for loaders.
The decorated function is passed the module to use instead of the module
name. The module passed in to the function is either from sys.modules if
it already exists or is a new module. If the module is new, then __name__
is set the first argument to the method, __loader__ is set to self, and
__package__ is set accordingly (if self.is_package() is defined) will be set
before it is passed to the decorated function (if self.is_package() does
not work for the module it will be set post-load).
If an exception is raised and the decorator created the module it is
subsequently removed from sys.modules.
The decorator assumes that the decorated function takes the module name as
the second argument.
"""
warnings.warn('The import system now takes care of this automatically; '
'this decorator is slated for removal in Python 3.12',
DeprecationWarning, stacklevel=2)
@functools.wraps(fxn)
def module_for_loader_wrapper(self, fullname, *args, **kwargs):
with _module_to_load(fullname) as module:
module.__loader__ = self
try:
is_package = self.is_package(fullname)
except (ImportError, AttributeError):
pass
else:
if is_package:
module.__package__ = fullname
else:
module.__package__ = fullname.rpartition('.')[0]
# If __package__ was not set above, __import__() will do it later.
return fxn(self, module, *args, **kwargs)
return module_for_loader_wrapper
@property
def override(self):
return -1 if self.disable_check else 1
class _LazyModule(types.ModuleType):

View File

@@ -4,13 +4,16 @@ if __name__ != 'test.support':
raise ImportError('support must be imported from the test package')
import contextlib
import dataclasses
import functools
import getpass
import opcode
import os
import re
import stat
import sys
import sysconfig
import textwrap
import time
import types
import unittest
@@ -19,11 +22,6 @@ import warnings
from .testresult import get_test_runner
try:
from _testcapi import unicode_legacy_string
except ImportError:
unicode_legacy_string = None
__all__ = [
# globals
"PIPE_MAX_SIZE", "verbose", "max_memuse", "use_resources", "failfast",
@@ -36,7 +34,7 @@ __all__ = [
"is_resource_enabled", "requires", "requires_freebsd_version",
"requires_linux_version", "requires_mac_ver",
"check_syntax_error",
"BasicTestRunner", "run_unittest", "run_doctest",
"run_unittest", "run_doctest",
"requires_gzip", "requires_bz2", "requires_lzma",
"bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute",
"requires_IEEE_754", "requires_zlib",
@@ -46,9 +44,12 @@ __all__ = [
"anticipate_failure", "load_package_tests", "detect_api_mismatch",
"check__all__", "skip_if_buggy_ucrt_strfptime",
"check_disallow_instantiation", "check_sanitizer", "skip_if_sanitizer",
"requires_limited_api", "requires_specialization",
# sys
"is_jython", "is_android", "is_emscripten", "is_wasi",
"check_impl_detail", "unix_shell", "setswitchinterval",
# os
"get_pagesize",
# network
"open_urlresource",
# processes
@@ -59,6 +60,8 @@ __all__ = [
"run_with_tz", "PGO", "missing_compiler_executable",
"ALWAYS_EQ", "NEVER_EQ", "LARGEST", "SMALLEST",
"LOOPBACK_TIMEOUT", "INTERNET_TIMEOUT", "SHORT_TIMEOUT", "LONG_TIMEOUT",
"Py_DEBUG", "EXCEEDS_RECURSION_LIMIT", "C_RECURSION_LIMIT",
"skip_on_s390x",
]
@@ -116,18 +119,21 @@ class Error(Exception):
class TestFailed(Error):
"""Test failed."""
class TestFailedWithDetails(TestFailed):
"""Test failed."""
def __init__(self, msg, errors, failures):
def __init__(self, msg, *args, stats=None):
self.msg = msg
self.errors = errors
self.failures = failures
super().__init__(msg, errors, failures)
self.stats = stats
super().__init__(msg, *args)
def __str__(self):
return self.msg
class TestFailedWithDetails(TestFailed):
"""Test failed."""
def __init__(self, msg, errors, failures, stats):
self.errors = errors
self.failures = failures
super().__init__(msg, errors, failures, stats=stats)
class TestDidNotRun(Error):
"""Test did not run any subtests."""
@@ -408,7 +414,7 @@ def check_sanitizer(*, address=False, memory=False, ub=False):
)
address_sanitizer = (
'-fsanitize=address' in _cflags or
'--with-memory-sanitizer' in _config_args
'--with-address-sanitizer' in _config_args
)
ub_sanitizer = (
'-fsanitize=undefined' in _cflags or
@@ -500,9 +506,16 @@ def has_no_debug_ranges():
def requires_debug_ranges(reason='requires co_positions / debug_ranges'):
return unittest.skipIf(has_no_debug_ranges(), reason)
requires_legacy_unicode_capi = unittest.skipUnless(unicode_legacy_string,
'requires legacy Unicode C API')
def requires_legacy_unicode_capi():
try:
from _testcapi import unicode_legacy_string
except ImportError:
unicode_legacy_string = None
return unittest.skipUnless(unicode_legacy_string,
'requires legacy Unicode C API')
# Is not actually used in tests, but is kept for compatibility.
is_jython = sys.platform.startswith('java')
is_android = hasattr(sys, 'getandroidapilevel')
@@ -578,7 +591,8 @@ def darwin_malloc_err_warning(test_name):
msg = ' NOTICE '
detail = (f'{test_name} may generate "malloc can\'t allocate region"\n'
'warnings on macOS systems. This behavior is known. Do not\n'
'report a bug unless tests are also failing. See bpo-40928.')
'report a bug unless tests are also failing.\n'
'See https://github.com/python/cpython/issues/85100')
padding, _ = shutil.get_terminal_size()
print(msg.center(padding, '-'))
@@ -612,6 +626,14 @@ def sortdict(dict):
withcommas = ", ".join(reprpairs)
return "{%s}" % withcommas
def run_code(code: str) -> dict[str, object]:
"""Run a piece of code after dedenting it, and return its global namespace."""
ns = {}
exec(textwrap.dedent(code), ns)
return ns
def check_syntax_error(testcase, statement, errtext='', *, lineno=None, offset=None):
with testcase.assertRaisesRegex(SyntaxError, errtext) as cm:
compile(statement, '<test string>', 'exec')
@@ -994,12 +1016,6 @@ def bigaddrspacetest(f):
#=======================================================================
# unittest integration.
class BasicTestRunner:
def run(self, test):
result = unittest.TestResult()
test(result)
return result
def _id(obj):
return obj
@@ -1078,6 +1094,18 @@ def refcount_test(test):
return no_tracing(cpython_only(test))
def requires_limited_api(test):
try:
import _testcapi
except ImportError:
return unittest.skip('needs _testcapi module')(test)
return unittest.skipUnless(
_testcapi.LIMITED_API_AVAILABLE, 'needs Limited API support')(test)
def requires_specialization(test):
return unittest.skipUnless(
opcode.ENABLE_SPECIALIZATION, "requires specialization")(test)
def _filter_suite(suite, pred):
"""Recursively filter test cases in a suite based on a predicate."""
newtests = []
@@ -1090,6 +1118,29 @@ def _filter_suite(suite, pred):
newtests.append(test)
suite._tests = newtests
@dataclasses.dataclass(slots=True)
class TestStats:
tests_run: int = 0
failures: int = 0
skipped: int = 0
@staticmethod
def from_unittest(result):
return TestStats(result.testsRun,
len(result.failures),
len(result.skipped))
@staticmethod
def from_doctest(results):
return TestStats(results.attempted,
results.failed)
def accumulate(self, stats):
self.tests_run += stats.tests_run
self.failures += stats.failures
self.skipped += stats.skipped
def _run_suite(suite):
"""Run tests from a unittest.TestSuite-derived class."""
runner = get_test_runner(sys.stdout,
@@ -1101,9 +1152,10 @@ def _run_suite(suite):
if junit_xml_list is not None:
junit_xml_list.append(result.get_xml_element())
if not result.testsRun and not result.skipped:
if not result.testsRun and not result.skipped and not result.errors:
raise TestDidNotRun
if not result.wasSuccessful():
stats = TestStats.from_unittest(result)
if len(result.errors) == 1 and not result.failures:
err = result.errors[0][1]
elif len(result.failures) == 1 and not result.errors:
@@ -1113,7 +1165,8 @@ def _run_suite(suite):
if not verbose: err += "; run in verbose mode for details"
errors = [(str(tc), exc_str) for tc, exc_str in result.errors]
failures = [(str(tc), exc_str) for tc, exc_str in result.failures]
raise TestFailedWithDetails(err, errors, failures)
raise TestFailedWithDetails(err, errors, failures, stats=stats)
return result
# By default, don't filter tests
@@ -1144,7 +1197,6 @@ def _is_full_match_test(pattern):
def set_match_tests(accept_patterns=None, ignore_patterns=None):
global _match_test_func, _accept_test_patterns, _ignore_test_patterns
if accept_patterns is None:
accept_patterns = ()
if ignore_patterns is None:
@@ -1222,7 +1274,7 @@ def run_unittest(*classes):
else:
suite.addTest(loader.loadTestsFromTestCase(cls))
_filter_suite(suite, match_test)
_run_suite(suite)
return _run_suite(suite)
#=======================================================================
# Check for the presence of docstrings.
@@ -1262,13 +1314,18 @@ def run_doctest(module, verbosity=None, optionflags=0):
else:
verbosity = None
f, t = doctest.testmod(module, verbose=verbosity, optionflags=optionflags)
if f:
raise TestFailed("%d of %d doctests failed" % (f, t))
results = doctest.testmod(module,
verbose=verbosity,
optionflags=optionflags)
if results.failed:
stats = TestStats.from_doctest(results)
raise TestFailed(f"{results.failed} of {results.attempted} "
f"doctests failed",
stats=stats)
if verbose:
print('doctest (%s) ... %d tests with zero failures' %
(module.__name__, t))
return f, t
(module.__name__, results.attempted))
return results
#=======================================================================
@@ -1792,6 +1849,25 @@ def run_in_subinterp(code):
Run code in a subinterpreter. Raise unittest.SkipTest if the tracemalloc
module is enabled.
"""
_check_tracemalloc()
import _testcapi
return _testcapi.run_in_subinterp(code)
def run_in_subinterp_with_config(code, *, own_gil=None, **config):
"""
Run code in a subinterpreter. Raise unittest.SkipTest if the tracemalloc
module is enabled.
"""
_check_tracemalloc()
import _testcapi
if own_gil is not None:
assert 'gil' not in config, (own_gil, config)
config['gil'] = 2 if own_gil else 1
return _testcapi.run_in_subinterp_with_config(code, **config)
def _check_tracemalloc():
# Issue #10915, #15751: PyGILState_*() functions don't work with
# sub-interpreters, the tracemalloc module uses these functions internally
try:
@@ -1803,8 +1879,6 @@ def run_in_subinterp(code):
raise unittest.SkipTest("run_in_subinterp() cannot be used "
"if tracemalloc module is tracing "
"memory allocations")
import _testcapi
return _testcapi.run_in_subinterp(code)
# TODO: RUSTPYTHON (comment out before)
@@ -1836,15 +1910,16 @@ def missing_compiler_executable(cmd_names=[]):
missing.
"""
# TODO (PEP 632): alternate check without using distutils
from distutils import ccompiler, sysconfig, spawn, errors
from setuptools._distutils import ccompiler, sysconfig, spawn
from setuptools import errors
compiler = ccompiler.new_compiler()
sysconfig.customize_compiler(compiler)
if compiler.compiler_type == "msvc":
# MSVC has no executables, so check whether initialization succeeds
try:
compiler.initialize()
except errors.DistutilsPlatformError:
except errors.PlatformError:
return "msvc"
for name in compiler.executables:
if cmd_names and name not in cmd_names:
@@ -1875,6 +1950,18 @@ def setswitchinterval(interval):
return sys.setswitchinterval(interval)
def get_pagesize():
"""Get size of a page in bytes."""
try:
page_size = os.sysconf('SC_PAGESIZE')
except (ValueError, AttributeError):
try:
page_size = os.sysconf('SC_PAGE_SIZE')
except (ValueError, AttributeError):
page_size = 4096
return page_size
@contextlib.contextmanager
def disable_faulthandler():
import faulthandler
@@ -2092,31 +2179,26 @@ def wait_process(pid, *, exitcode, timeout=None):
if timeout is None:
timeout = LONG_TIMEOUT
t0 = time.monotonic()
sleep = 0.001
max_sleep = 0.1
while True:
start_time = time.monotonic()
for _ in sleeping_retry(timeout, error=False):
pid2, status = os.waitpid(pid, os.WNOHANG)
if pid2 != 0:
break
# process is still running
# rety: the process is still running
else:
try:
os.kill(pid, signal.SIGKILL)
os.waitpid(pid, 0)
except OSError:
# Ignore errors like ChildProcessError or PermissionError
pass
dt = time.monotonic() - t0
if dt > timeout:
try:
os.kill(pid, signal.SIGKILL)
os.waitpid(pid, 0)
except OSError:
# Ignore errors like ChildProcessError or PermissionError
pass
raise AssertionError(f"process {pid} is still running "
f"after {dt:.1f} seconds")
sleep = min(sleep * 2, max_sleep)
time.sleep(sleep)
dt = time.monotonic() - start_time
raise AssertionError(f"process {pid} is still running "
f"after {dt:.1f} seconds")
else:
# Windows implementation
# Windows implementation: don't support timeout :-(
pid2, status = os.waitpid(pid, 0)
exitcode2 = os.waitstatus_to_exitcode(status)
@@ -2168,20 +2250,61 @@ def check_disallow_instantiation(testcase, tp, *args, **kwds):
msg = f"cannot create '{re.escape(qualname)}' instances"
testcase.assertRaisesRegex(TypeError, msg, tp, *args, **kwds)
def get_recursion_depth():
"""Get the recursion depth of the caller function.
In the __main__ module, at the module level, it should be 1.
"""
try:
import _testinternalcapi
depth = _testinternalcapi.get_recursion_depth()
except (ImportError, RecursionError) as exc:
# sys._getframe() + frame.f_back implementation.
try:
depth = 0
frame = sys._getframe()
while frame is not None:
depth += 1
frame = frame.f_back
finally:
# Break any reference cycles.
frame = None
# Ignore get_recursion_depth() frame.
return max(depth - 1, 1)
def get_recursion_available():
"""Get the number of available frames before RecursionError.
It depends on the current recursion depth of the caller function and
sys.getrecursionlimit().
"""
limit = sys.getrecursionlimit()
depth = get_recursion_depth()
return limit - depth
@contextlib.contextmanager
def infinite_recursion(max_depth=75):
def set_recursion_limit(limit):
"""Temporarily change the recursion limit."""
original_limit = sys.getrecursionlimit()
try:
sys.setrecursionlimit(limit)
yield
finally:
sys.setrecursionlimit(original_limit)
def infinite_recursion(max_depth=100):
"""Set a lower limit for tests that interact with infinite recursions
(e.g test_ast.ASTHelpers_Test.test_recursion_direct) since on some
debug windows builds, due to not enough functions being inlined the
stack size might not handle the default recursion limit (1000). See
bpo-11105 for details."""
original_depth = sys.getrecursionlimit()
try:
sys.setrecursionlimit(max_depth)
yield
finally:
sys.setrecursionlimit(original_depth)
if max_depth < 3:
raise ValueError("max_depth must be at least 3, got {max_depth}")
depth = get_recursion_depth()
depth = max(depth - 1, 1) # Ignore infinite_recursion() frame.
limit = depth + max_depth
return set_recursion_limit(limit)
def ignore_deprecations_from(module: str, *, like: str) -> object:
token = object()
@@ -2230,6 +2353,180 @@ def requires_venv_with_pip():
return unittest.skipUnless(ctypes, 'venv: pip requires ctypes')
@functools.cache
def _findwheel(pkgname):
"""Try to find a wheel with the package specified as pkgname.
If set, the wheels are searched for in WHEEL_PKG_DIR (see ensurepip).
Otherwise, they are searched for in the test directory.
"""
wheel_dir = sysconfig.get_config_var('WHEEL_PKG_DIR') or TEST_HOME_DIR
filenames = os.listdir(wheel_dir)
filenames = sorted(filenames, reverse=True) # approximate "newest" first
for filename in filenames:
# filename is like 'setuptools-67.6.1-py3-none-any.whl'
if not filename.endswith(".whl"):
continue
prefix = pkgname + '-'
if filename.startswith(prefix):
return os.path.join(wheel_dir, filename)
raise FileNotFoundError(f"No wheel for {pkgname} found in {wheel_dir}")
# Context manager that creates a virtual environment, install setuptools and wheel in it
# and returns the path to the venv directory and the path to the python executable
@contextlib.contextmanager
def setup_venv_with_pip_setuptools_wheel(venv_dir):
import subprocess
from .os_helper import temp_cwd
with temp_cwd() as temp_dir:
# Create virtual environment to get setuptools
cmd = [sys.executable, '-X', 'dev', '-m', 'venv', venv_dir]
if verbose:
print()
print('Run:', ' '.join(cmd))
subprocess.run(cmd, check=True)
venv = os.path.join(temp_dir, venv_dir)
# Get the Python executable of the venv
python_exe = os.path.basename(sys.executable)
if sys.platform == 'win32':
python = os.path.join(venv, 'Scripts', python_exe)
else:
python = os.path.join(venv, 'bin', python_exe)
cmd = [python, '-X', 'dev',
'-m', 'pip', 'install',
_findwheel('setuptools'),
_findwheel('wheel')]
if verbose:
print()
print('Run:', ' '.join(cmd))
subprocess.run(cmd, check=True)
yield python
# True if Python is built with the Py_DEBUG macro defined: if
# Python is built in debug mode (./configure --with-pydebug).
Py_DEBUG = hasattr(sys, 'gettotalrefcount')
def late_deletion(obj):
"""
Keep a Python alive as long as possible.
Create a reference cycle and store the cycle in an object deleted late in
Python finalization. Try to keep the object alive until the very last
garbage collection.
The function keeps a strong reference by design. It should be called in a
subprocess to not mark a test as "leaking a reference".
"""
# Late CPython finalization:
# - finalize_interp_clear()
# - _PyInterpreterState_Clear(): Clear PyInterpreterState members
# (ex: codec_search_path, before_forkers)
# - clear os.register_at_fork() callbacks
# - clear codecs.register() callbacks
ref_cycle = [obj]
ref_cycle.append(ref_cycle)
# Store a reference in PyInterpreterState.codec_search_path
import codecs
def search_func(encoding):
return None
search_func.reference = ref_cycle
codecs.register(search_func)
if hasattr(os, 'register_at_fork'):
# Store a reference in PyInterpreterState.before_forkers
def atfork_func():
pass
atfork_func.reference = ref_cycle
os.register_at_fork(before=atfork_func)
def busy_retry(timeout, err_msg=None, /, *, error=True):
"""
Run the loop body until "break" stops the loop.
After *timeout* seconds, raise an AssertionError if *error* is true,
or just stop if *error is false.
Example:
for _ in support.busy_retry(support.SHORT_TIMEOUT):
if check():
break
Example of error=False usage:
for _ in support.busy_retry(support.SHORT_TIMEOUT, error=False):
if check():
break
else:
raise RuntimeError('my custom error')
"""
if timeout <= 0:
raise ValueError("timeout must be greater than zero")
start_time = time.monotonic()
deadline = start_time + timeout
while True:
yield
if time.monotonic() >= deadline:
break
if error:
dt = time.monotonic() - start_time
msg = f"timeout ({dt:.1f} seconds)"
if err_msg:
msg = f"{msg}: {err_msg}"
raise AssertionError(msg)
def sleeping_retry(timeout, err_msg=None, /,
*, init_delay=0.010, max_delay=1.0, error=True):
"""
Wait strategy that applies exponential backoff.
Run the loop body until "break" stops the loop. Sleep at each loop
iteration, but not at the first iteration. The sleep delay is doubled at
each iteration (up to *max_delay* seconds).
See busy_retry() documentation for the parameters usage.
Example raising an exception after SHORT_TIMEOUT seconds:
for _ in support.sleeping_retry(support.SHORT_TIMEOUT):
if check():
break
Example of error=False usage:
for _ in support.sleeping_retry(support.SHORT_TIMEOUT, error=False):
if check():
break
else:
raise RuntimeError('my custom error')
"""
delay = init_delay
for _ in busy_retry(timeout, err_msg, error=error):
yield
time.sleep(delay)
delay = min(delay * 2, max_delay)
@contextlib.contextmanager
def adjust_int_max_str_digits(max_digits):
"""Temporarily change the integer string conversion length limit."""
@@ -2239,3 +2536,13 @@ def adjust_int_max_str_digits(max_digits):
yield
finally:
sys.set_int_max_str_digits(current)
#For recursion tests, easily exceeds default recursion limit
EXCEEDS_RECURSION_LIMIT = 5000
# The default C recursion limit (from Include/cpython/pystate.h).
C_RECURSION_LIMIT = 1500
#Windows doesn't have os.uname() but it doesn't support s390x.
skip_on_s390x = unittest.skipIf(hasattr(os, 'uname') and os.uname().machine == 's390x',
'skipped on s390x')

View File

@@ -3,6 +3,7 @@
import unittest
import dis
import io
from _testinternalcapi import compiler_codegen, optimize_cfg, assemble_code_object
_UNSPECIFIED = object()
@@ -16,6 +17,7 @@ class BytecodeTestCase(unittest.TestCase):
def assertInBytecode(self, x, opname, argval=_UNSPECIFIED):
"""Returns instr if opname is found, otherwise throws AssertionError"""
self.assertIn(opname, dis.opmap)
for instr in dis.get_instructions(x):
if instr.opname == opname:
if argval is _UNSPECIFIED or instr.argval == argval:
@@ -30,6 +32,7 @@ class BytecodeTestCase(unittest.TestCase):
def assertNotInBytecode(self, x, opname, argval=_UNSPECIFIED):
"""Throws AssertionError if opname is found"""
self.assertIn(opname, dis.opmap)
for instr in dis.get_instructions(x):
if instr.opname == opname:
disassembly = self.get_disassembly_as_string(x)
@@ -40,3 +43,101 @@ class BytecodeTestCase(unittest.TestCase):
msg = '(%s,%r) occurs in bytecode:\n%s'
msg = msg % (opname, argval, disassembly)
self.fail(msg)
class CompilationStepTestCase(unittest.TestCase):
HAS_ARG = set(dis.hasarg)
HAS_TARGET = set(dis.hasjrel + dis.hasjabs + dis.hasexc)
HAS_ARG_OR_TARGET = HAS_ARG.union(HAS_TARGET)
class Label:
pass
def assertInstructionsMatch(self, actual_, expected_):
# get two lists where each entry is a label or
# an instruction tuple. Normalize the labels to the
# instruction count of the target, and compare the lists.
self.assertIsInstance(actual_, list)
self.assertIsInstance(expected_, list)
actual = self.normalize_insts(actual_)
expected = self.normalize_insts(expected_)
self.assertEqual(len(actual), len(expected))
# compare instructions
for act, exp in zip(actual, expected):
if isinstance(act, int):
self.assertEqual(exp, act)
continue
self.assertIsInstance(exp, tuple)
self.assertIsInstance(act, tuple)
# crop comparison to the provided expected values
if len(act) > len(exp):
act = act[:len(exp)]
self.assertEqual(exp, act)
def resolveAndRemoveLabels(self, insts):
idx = 0
res = []
for item in insts:
assert isinstance(item, (self.Label, tuple))
if isinstance(item, self.Label):
item.value = idx
else:
idx += 1
res.append(item)
return res
def normalize_insts(self, insts):
""" Map labels to instruction index.
Map opcodes to opnames.
"""
insts = self.resolveAndRemoveLabels(insts)
res = []
for item in insts:
assert isinstance(item, tuple)
opcode, oparg, *loc = item
opcode = dis.opmap.get(opcode, opcode)
if isinstance(oparg, self.Label):
arg = oparg.value
else:
arg = oparg if opcode in self.HAS_ARG else None
opcode = dis.opname[opcode]
res.append((opcode, arg, *loc))
return res
def complete_insts_info(self, insts):
# fill in omitted fields in location, and oparg 0 for ops with no arg.
res = []
for item in insts:
assert isinstance(item, tuple)
inst = list(item)
opcode = dis.opmap[inst[0]]
oparg = inst[1]
loc = inst[2:] + [-1] * (6 - len(inst))
res.append((opcode, oparg, *loc))
return res
class CodegenTestCase(CompilationStepTestCase):
def generate_code(self, ast):
insts, _ = compiler_codegen(ast, "my_file.py", 0)
return insts
class CfgOptimizationTestCase(CompilationStepTestCase):
def get_optimized(self, insts, consts, nlocals=0):
insts = self.normalize_insts(insts)
insts = self.complete_insts_info(insts)
insts = optimize_cfg(insts, consts, nlocals)
return insts, consts
class AssemblerTestCase(CompilationStepTestCase):
def get_code_object(self, filename, insts, metadata):
co = assemble_code_object(filename, insts, metadata)
return co

View File

@@ -105,6 +105,26 @@ def frozen_modules(enabled=True):
_imp._override_frozen_modules_for_tests(0)
@contextlib.contextmanager
def multi_interp_extensions_check(enabled=True):
"""Force legacy modules to be allowed in subinterpreters (or not).
("legacy" == single-phase init)
This only applies to modules that haven't been imported yet.
It overrides the PyInterpreterConfig.check_multi_interp_extensions
setting (see support.run_in_subinterp_with_config() and
_xxsubinterpreters.create()).
Also see importlib.utils.allowing_all_extensions().
"""
old = _imp._override_multi_interp_extensions_check(1 if enabled else -1)
try:
yield
finally:
_imp._override_multi_interp_extensions_check(old)
def import_fresh_module(name, fresh=(), blocked=(), *,
deprecated=False,
usefrozen=False,
@@ -246,3 +266,11 @@ def modules_cleanup(oldmodules):
# do currently). Implicitly imported *real* modules should be left alone
# (see issue 10556).
sys.modules.update(oldmodules)
def mock_register_at_fork(func):
# bpo-30599: Mock os.register_at_fork() when importing the random module,
# since this function doesn't allow to unregister callbacks and would leak
# memory.
from unittest import mock
return mock.patch('os.register_at_fork', create=True)(func)

View File

@@ -2,11 +2,12 @@
import time
import _xxsubinterpreters as _interpreters
import _xxinterpchannels as _channels
# aliases:
from _xxsubinterpreters import (
from _xxsubinterpreters import is_shareable, RunFailedError
from _xxinterpchannels import (
ChannelError, ChannelNotFoundError, ChannelEmptyError,
is_shareable,
)
@@ -102,7 +103,7 @@ def create_channel():
The channel may be used to pass data safely between interpreters.
"""
cid = _interpreters.channel_create()
cid = _channels.create()
recv, send = RecvChannel(cid), SendChannel(cid)
return recv, send
@@ -110,14 +111,14 @@ def create_channel():
def list_all_channels():
"""Return a list of (recv, send) for all open channels."""
return [(RecvChannel(cid), SendChannel(cid))
for cid in _interpreters.channel_list_all()]
for cid in _channels.list_all()]
class _ChannelEnd:
"""The base class for RecvChannel and SendChannel."""
def __init__(self, id):
if not isinstance(id, (int, _interpreters.ChannelID)):
if not isinstance(id, (int, _channels.ChannelID)):
raise TypeError(f'id must be an int, got {id!r}')
self._id = id
@@ -152,10 +153,10 @@ class RecvChannel(_ChannelEnd):
This blocks until an object has been sent, if none have been
sent already.
"""
obj = _interpreters.channel_recv(self._id, _sentinel)
obj = _channels.recv(self._id, _sentinel)
while obj is _sentinel:
time.sleep(_delay)
obj = _interpreters.channel_recv(self._id, _sentinel)
obj = _channels.recv(self._id, _sentinel)
return obj
def recv_nowait(self, default=_NOT_SET):
@@ -166,9 +167,9 @@ class RecvChannel(_ChannelEnd):
is the same as recv().
"""
if default is _NOT_SET:
return _interpreters.channel_recv(self._id)
return _channels.recv(self._id)
else:
return _interpreters.channel_recv(self._id, default)
return _channels.recv(self._id, default)
class SendChannel(_ChannelEnd):
@@ -179,7 +180,7 @@ class SendChannel(_ChannelEnd):
This blocks until the object is received.
"""
_interpreters.channel_send(self._id, obj)
_channels.send(self._id, obj)
# XXX We are missing a low-level channel_send_wait().
# See bpo-32604 and gh-19829.
# Until that shows up we fake it:
@@ -194,4 +195,4 @@ class SendChannel(_ChannelEnd):
# XXX Note that at the moment channel_send() only ever returns
# None. This should be fixed when channel_send_wait() is added.
# See bpo-32604 and gh-19829.
return _interpreters.channel_send(self._id, obj)
return _channels.send(self._id, obj)

View File

@@ -4,6 +4,7 @@ import errno
import os
import re
import stat
import string
import sys
import time
import unittest
@@ -11,11 +12,7 @@ import warnings
# Filename used for testing
if os.name == 'java':
# Jython disallows @ in module names
TESTFN_ASCII = '$test'
else:
TESTFN_ASCII = '@test'
TESTFN_ASCII = '@test'
# Disambiguate TESTFN for parallel testing, while letting it remain a valid
# module name.
@@ -141,6 +138,11 @@ for name in (
try:
name.decode(sys.getfilesystemencoding())
except UnicodeDecodeError:
try:
name.decode(sys.getfilesystemencoding(),
sys.getfilesystemencodeerrors())
except UnicodeDecodeError:
continue
TESTFN_UNDECODABLE = os.fsencode(TESTFN_ASCII) + name
break
@@ -567,7 +569,7 @@ def fs_is_case_insensitive(directory):
class FakePath:
"""Simple implementing of the path protocol.
"""Simple implementation of the path protocol.
"""
def __init__(self, path):
self.path = path
@@ -715,3 +717,37 @@ class EnvironmentVarGuard(collections.abc.MutableMapping):
else:
self._environ[k] = v
os.environ = self._environ
try:
import ctypes
kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
ERROR_FILE_NOT_FOUND = 2
DDD_REMOVE_DEFINITION = 2
DDD_EXACT_MATCH_ON_REMOVE = 4
DDD_NO_BROADCAST_SYSTEM = 8
except (ImportError, AttributeError):
def subst_drive(path):
raise unittest.SkipTest('ctypes or kernel32 is not available')
else:
@contextlib.contextmanager
def subst_drive(path):
"""Temporarily yield a substitute drive for a given path."""
for c in reversed(string.ascii_uppercase):
drive = f'{c}:'
if (not kernel32.QueryDosDeviceW(drive, None, 0) and
ctypes.get_last_error() == ERROR_FILE_NOT_FOUND):
break
else:
raise unittest.SkipTest('no available logical drive')
if not kernel32.DefineDosDeviceW(
DDD_NO_BROADCAST_SYSTEM, drive, path):
raise ctypes.WinError(ctypes.get_last_error())
try:
yield drive
finally:
if not kernel32.DefineDosDeviceW(
DDD_REMOVE_DEFINITION | DDD_EXACT_MATCH_ON_REMOVE,
drive, path):
raise ctypes.WinError(ctypes.get_last_error())

View File

@@ -1,8 +1,11 @@
import contextlib
import errno
import os.path
import socket
import unittest
import sys
import subprocess
import tempfile
import unittest
from .. import support
from . import warnings_helper
@@ -61,7 +64,7 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM):
http://bugs.python.org/issue2550 for more info. The following site also
has a very thorough description about the implications of both REUSEADDR
and EXCLUSIVEADDRUSE on Windows:
http://msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx)
https://learn.microsoft.com/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
XXX: although this approach is a vast improvement on previous attempts to
elicit unused ports, it rests heavily on the assumption that the ephemeral
@@ -270,3 +273,73 @@ def transient_internet(resource_name, *, timeout=_NOT_SET, errnos=()):
# __cause__ or __context__?
finally:
socket.setdefaulttimeout(old_timeout)
def create_unix_domain_name():
"""
Create a UNIX domain name: socket.bind() argument of a AF_UNIX socket.
Return a path relative to the current directory to get a short path
(around 27 ASCII characters).
"""
return tempfile.mktemp(prefix="test_python_", suffix='.sock',
dir=os.path.curdir)
# consider that sysctl values should not change while tests are running
_sysctl_cache = {}
def _get_sysctl(name):
"""Get a sysctl value as an integer."""
try:
return _sysctl_cache[name]
except KeyError:
pass
# At least Linux and FreeBSD support the "-n" option
cmd = ['sysctl', '-n', name]
proc = subprocess.run(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True)
if proc.returncode:
support.print_warning(f'{" ".join(cmd)!r} command failed with '
f'exit code {proc.returncode}')
# cache the error to only log the warning once
_sysctl_cache[name] = None
return None
output = proc.stdout
# Parse '0\n' to get '0'
try:
value = int(output.strip())
except Exception as exc:
support.print_warning(f'Failed to parse {" ".join(cmd)!r} '
f'command output {output!r}: {exc!r}')
# cache the error to only log the warning once
_sysctl_cache[name] = None
return None
_sysctl_cache[name] = value
return value
def tcp_blackhole():
if not sys.platform.startswith('freebsd'):
return False
# gh-109015: test if FreeBSD TCP blackhole is enabled
value = _get_sysctl('net.inet.tcp.blackhole')
if value is None:
# don't skip if we fail to get the sysctl value
return False
return (value != 0)
def skip_if_tcp_blackhole(test):
"""Decorator skipping test if TCP blackhole is enabled."""
skip_if = unittest.skipIf(
tcp_blackhole(),
"TCP blackhole is enabled (sysctl net.inet.tcp.blackhole)"
)
return skip_if(test)

View File

@@ -8,6 +8,7 @@ import sys
import time
import traceback
import unittest
from test import support
class RegressionTestResult(unittest.TextTestResult):
USE_XML = False
@@ -18,10 +19,13 @@ class RegressionTestResult(unittest.TextTestResult):
self.buffer = True
if self.USE_XML:
from xml.etree import ElementTree as ET
from datetime import datetime
from datetime import datetime, UTC
self.__ET = ET
self.__suite = ET.Element('testsuite')
self.__suite.set('start', datetime.utcnow().isoformat(' '))
self.__suite.set('start',
datetime.now(UTC)
.replace(tzinfo=None)
.isoformat(' '))
self.__e = None
self.__start_time = None
@@ -109,6 +113,8 @@ class RegressionTestResult(unittest.TextTestResult):
def addFailure(self, test, err):
self._add_result(test, True, failure=self.__makeErrorDict(*err))
super().addFailure(test, err)
if support.failfast:
self.stop()
def addSkip(self, test, reason):
self._add_result(test, skipped=reason)

View File

@@ -88,19 +88,17 @@ def wait_threads_exit(timeout=None):
yield
finally:
start_time = time.monotonic()
deadline = start_time + timeout
while True:
for _ in support.sleeping_retry(timeout, error=False):
support.gc_collect()
count = _thread._count()
if count <= old_count:
break
if time.monotonic() > deadline:
dt = time.monotonic() - start_time
msg = (f"wait_threads() failed to cleanup {count - old_count} "
f"threads after {dt:.1f} seconds "
f"(count: {count}, old count: {old_count})")
raise AssertionError(msg)
time.sleep(0.010)
support.gc_collect()
else:
dt = time.monotonic() - start_time
msg = (f"wait_threads() failed to cleanup {count - old_count} "
f"threads after {dt:.1f} seconds "
f"(count: {count}, old count: {old_count})")
raise AssertionError(msg)
def join_thread(thread, timeout=None):
@@ -117,7 +115,11 @@ def join_thread(thread, timeout=None):
@contextlib.contextmanager
def start_threads(threads, unlock=None):
import faulthandler
try:
import faulthandler
except ImportError:
# It isn't supported on subinterpreters yet.
faulthandler = None
threads = list(threads)
started = []
try:
@@ -149,7 +151,8 @@ def start_threads(threads, unlock=None):
finally:
started = [t for t in started if t.is_alive()]
if started:
faulthandler.dump_traceback(sys.stdout)
if faulthandler is not None:
faulthandler.dump_traceback(sys.stdout)
raise AssertionError('Unable to join %d threads' % len(started))

View File

@@ -44,7 +44,7 @@ def check_syntax_warning(testcase, statement, errtext='',
def ignore_warnings(*, category):
"""Decorator to suppress deprecation warnings.
"""Decorator to suppress warnings.
Use of context managers to hide warnings make diffs
more noisy and tools like 'git blame' less useful.

View File

@@ -20,7 +20,7 @@ Cowlishaw's tests can be downloaded from:
This test module can be called from command line with one parameter (Arithmetic
or Behaviour) to test each part, or without parameter to test both parts. If
you're working through IDLE, you can import this test module and call test_main()
you're working through IDLE, you can import this test module and call test()
with the corresponding argument.
"""
@@ -32,13 +32,14 @@ import pickle, copy
import unittest
import numbers
import locale
from test.support import (run_unittest, run_doctest, is_resource_enabled,
from test.support import (is_resource_enabled,
requires_IEEE_754, requires_docstrings,
requires_legacy_unicode_capi, check_sanitizer)
from test.support import (TestFailed,
run_with_locale, cpython_only,
darwin_malloc_err_warning)
darwin_malloc_err_warning, is_emscripten)
from test.support.import_helper import import_fresh_module
from test.support import threading_helper
from test.support import warnings_helper
import random
import inspect
@@ -61,6 +62,7 @@ sys.modules['decimal'] = C
fractions = {C:cfractions, P:pfractions}
sys.modules['decimal'] = orig_sys_decimal
requires_cdecimal = unittest.skipUnless(C, "test requires C version")
# Useful Test Constant
Signals = {
@@ -98,7 +100,7 @@ RoundingModes = [
]
# Tests are built around these assumed context defaults.
# test_main() restores the original context.
# test() restores the original context.
ORIGINAL_CONTEXT = {
C: C.getcontext().copy() if C else None,
P: P.getcontext().copy()
@@ -132,7 +134,7 @@ skip_if_extra_functionality = unittest.skipIf(
EXTRA_FUNCTIONALITY, "test requires regular build")
class IBMTestCases(unittest.TestCase):
class IBMTestCases:
"""Class which tests the Decimal class against the IBM test cases."""
def setUp(self):
@@ -487,14 +489,10 @@ class IBMTestCases(unittest.TestCase):
def change_clamp(self, clamp):
self.context.clamp = clamp
class CIBMTestCases(IBMTestCases):
decimal = C
class PyIBMTestCases(IBMTestCases):
decimal = P
# The following classes test the behaviour of Decimal according to PEP 327
class ExplicitConstructionTest(unittest.TestCase):
class ExplicitConstructionTest:
'''Unit tests for Explicit Construction cases of Decimal.'''
def test_explicit_empty(self):
@@ -589,7 +587,7 @@ class ExplicitConstructionTest(unittest.TestCase):
self.assertRaises(InvalidOperation, Decimal, "1_2_\u00003")
@cpython_only
@requires_legacy_unicode_capi
@requires_legacy_unicode_capi()
@warnings_helper.ignore_warnings(category=DeprecationWarning)
def test_from_legacy_strings(self):
import _testcapi
@@ -839,12 +837,13 @@ class ExplicitConstructionTest(unittest.TestCase):
for input, expected in test_values.items():
self.assertEqual(str(Decimal(input)), expected)
class CExplicitConstructionTest(ExplicitConstructionTest):
@requires_cdecimal
class CExplicitConstructionTest(ExplicitConstructionTest, unittest.TestCase):
decimal = C
class PyExplicitConstructionTest(ExplicitConstructionTest):
class PyExplicitConstructionTest(ExplicitConstructionTest, unittest.TestCase):
decimal = P
class ImplicitConstructionTest(unittest.TestCase):
class ImplicitConstructionTest:
'''Unit tests for Implicit Construction cases of Decimal.'''
def test_implicit_from_None(self):
@@ -921,13 +920,16 @@ class ImplicitConstructionTest(unittest.TestCase):
self.assertEqual(eval('Decimal(10)' + sym + 'E()'),
'10' + rop + 'str')
class CImplicitConstructionTest(ImplicitConstructionTest):
@requires_cdecimal
class CImplicitConstructionTest(ImplicitConstructionTest, unittest.TestCase):
decimal = C
class PyImplicitConstructionTest(ImplicitConstructionTest):
class PyImplicitConstructionTest(ImplicitConstructionTest, unittest.TestCase):
decimal = P
class FormatTest(unittest.TestCase):
class FormatTest:
'''Unit tests for the format function.'''
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_formatting(self):
Decimal = self.decimal.Decimal
@@ -1073,6 +1075,57 @@ class FormatTest(unittest.TestCase):
(',e', '123456', '1.23456e+5'),
(',E', '123456', '1.23456E+5'),
# negative zero: default behavior
('.1f', '-0', '-0.0'),
('.1f', '-.0', '-0.0'),
('.1f', '-.01', '-0.0'),
# negative zero: z option
('z.1f', '0.', '0.0'),
('z6.1f', '0.', ' 0.0'),
('z6.1f', '-1.', ' -1.0'),
('z.1f', '-0.', '0.0'),
('z.1f', '.01', '0.0'),
('z.1f', '-.01', '0.0'),
('z.2f', '0.', '0.00'),
('z.2f', '-0.', '0.00'),
('z.2f', '.001', '0.00'),
('z.2f', '-.001', '0.00'),
('z.1e', '0.', '0.0e+1'),
('z.1e', '-0.', '0.0e+1'),
('z.1E', '0.', '0.0E+1'),
('z.1E', '-0.', '0.0E+1'),
('z.2e', '-0.001', '-1.00e-3'), # tests for mishandled rounding
('z.2g', '-0.001', '-0.001'),
('z.2%', '-0.001', '-0.10%'),
('zf', '-0.0000', '0.0000'), # non-normalized form is preserved
('z.1f', '-00000.000001', '0.0'),
('z.1f', '-00000.', '0.0'),
('z.1f', '-.0000000000', '0.0'),
('z.2f', '-00000.000001', '0.00'),
('z.2f', '-00000.', '0.00'),
('z.2f', '-.0000000000', '0.00'),
('z.1f', '.09', '0.1'),
('z.1f', '-.09', '-0.1'),
(' z.0f', '-0.', ' 0'),
('+z.0f', '-0.', '+0'),
('-z.0f', '-0.', '0'),
(' z.0f', '-1.', '-1'),
('+z.0f', '-1.', '-1'),
('-z.0f', '-1.', '-1'),
('z>6.1f', '-0.', 'zz-0.0'),
('z>z6.1f', '-0.', 'zzz0.0'),
('x>z6.1f', '-0.', 'xxx0.0'),
('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char
# issue 6850
('a=-7.0', '0.12345', 'aaaa0.1'),
@@ -1087,6 +1140,17 @@ class FormatTest(unittest.TestCase):
# bytes format argument
self.assertRaises(TypeError, Decimal(1).__format__, b'-020')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_negative_zero_format_directed_rounding(self):
with self.decimal.localcontext() as ctx:
ctx.rounding = ROUND_CEILING
self.assertEqual(format(self.decimal.Decimal('-0.001'), 'z.2f'),
'0.00')
def test_negative_zero_bad_format(self):
self.assertRaises(ValueError, format, self.decimal.Decimal('1.23'), 'fz')
def test_n_format(self):
Decimal = self.decimal.Decimal
@@ -1205,12 +1269,13 @@ class FormatTest(unittest.TestCase):
a = A.from_float(42)
self.assertEqual(self.decimal.Decimal, a.a_type)
class CFormatTest(FormatTest):
@requires_cdecimal
class CFormatTest(FormatTest, unittest.TestCase):
decimal = C
class PyFormatTest(FormatTest):
class PyFormatTest(FormatTest, unittest.TestCase):
decimal = P
class ArithmeticOperatorsTest(unittest.TestCase):
class ArithmeticOperatorsTest:
'''Unit tests for all arithmetic operators, binary and unary.'''
def test_addition(self):
@@ -1466,14 +1531,17 @@ class ArithmeticOperatorsTest(unittest.TestCase):
equality_ops = operator.eq, operator.ne
# results when InvalidOperation is not trapped
for x, y in qnan_pairs + snan_pairs:
for op in order_ops + equality_ops:
got = op(x, y)
expected = True if op is operator.ne else False
self.assertIs(expected, got,
"expected {0!r} for operator.{1}({2!r}, {3!r}); "
"got {4!r}".format(
expected, op.__name__, x, y, got))
with localcontext() as ctx:
ctx.traps[InvalidOperation] = 0
for x, y in qnan_pairs + snan_pairs:
for op in order_ops + equality_ops:
got = op(x, y)
expected = True if op is operator.ne else False
self.assertIs(expected, got,
"expected {0!r} for operator.{1}({2!r}, {3!r}); "
"got {4!r}".format(
expected, op.__name__, x, y, got))
# repeat the above, but this time trap the InvalidOperation
with localcontext() as ctx:
@@ -1505,9 +1573,10 @@ class ArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(Decimal(1).copy_sign(-2), d)
self.assertRaises(TypeError, Decimal(1).copy_sign, '-2')
class CArithmeticOperatorsTest(ArithmeticOperatorsTest):
@requires_cdecimal
class CArithmeticOperatorsTest(ArithmeticOperatorsTest, unittest.TestCase):
decimal = C
class PyArithmeticOperatorsTest(ArithmeticOperatorsTest):
class PyArithmeticOperatorsTest(ArithmeticOperatorsTest, unittest.TestCase):
decimal = P
# The following are two functions used to test threading in the next class
@@ -1595,7 +1664,9 @@ def thfunc2(cls):
for sig in Overflow, Underflow, DivisionByZero, InvalidOperation:
cls.assertFalse(thiscontext.flags[sig])
class ThreadingTest(unittest.TestCase):
@threading_helper.requires_working_threading()
class ThreadingTest:
'''Unit tests for thread local contexts in Decimal.'''
# Take care executing this test from IDLE, there's an issue in threading
@@ -1640,13 +1711,14 @@ class ThreadingTest(unittest.TestCase):
DefaultContext.Emin = save_emin
class CThreadingTest(ThreadingTest):
@requires_cdecimal
class CThreadingTest(ThreadingTest, unittest.TestCase):
decimal = C
class PyThreadingTest(ThreadingTest):
class PyThreadingTest(ThreadingTest, unittest.TestCase):
decimal = P
class UsabilityTest(unittest.TestCase):
class UsabilityTest:
'''Unit tests for Usability cases of Decimal.'''
def test_comparison_operators(self):
@@ -2466,12 +2538,22 @@ class UsabilityTest(unittest.TestCase):
self.assertEqual(Decimal(-12).fma(45, Decimal(67)),
Decimal(-12).fma(Decimal(45), Decimal(67)))
class CUsabilityTest(UsabilityTest):
@requires_cdecimal
class CUsabilityTest(UsabilityTest, unittest.TestCase):
decimal = C
class PyUsabilityTest(UsabilityTest):
class PyUsabilityTest(UsabilityTest, unittest.TestCase):
decimal = P
class PythonAPItests(unittest.TestCase):
def setUp(self):
super().setUp()
self._previous_int_limit = sys.get_int_max_str_digits()
sys.set_int_max_str_digits(7000)
def tearDown(self):
sys.set_int_max_str_digits(self._previous_int_limit)
super().tearDown()
class PythonAPItests:
def test_abc(self):
Decimal = self.decimal.Decimal
@@ -2549,6 +2631,13 @@ class PythonAPItests(unittest.TestCase):
self.assertRaises(OverflowError, int, Decimal('inf'))
self.assertRaises(OverflowError, int, Decimal('-inf'))
@cpython_only
def test_small_ints(self):
Decimal = self.decimal.Decimal
# bpo-46361
for x in range(-5, 257):
self.assertIs(int(Decimal(x)), x)
def test_trunc(self):
Decimal = self.decimal.Decimal
@@ -2815,12 +2904,13 @@ class PythonAPItests(unittest.TestCase):
self.assertTrue(issubclass(decimal.DivisionUndefined, ZeroDivisionError))
self.assertTrue(issubclass(decimal.InvalidContext, InvalidOperation))
class CPythonAPItests(PythonAPItests):
@requires_cdecimal
class CPythonAPItests(PythonAPItests, unittest.TestCase):
decimal = C
class PyPythonAPItests(PythonAPItests):
class PyPythonAPItests(PythonAPItests, unittest.TestCase):
decimal = P
class ContextAPItests(unittest.TestCase):
class ContextAPItests:
def test_none_args(self):
Context = self.decimal.Context
@@ -2843,7 +2933,7 @@ class ContextAPItests(unittest.TestCase):
Overflow])
@cpython_only
@requires_legacy_unicode_capi
@requires_legacy_unicode_capi()
@warnings_helper.ignore_warnings(category=DeprecationWarning)
def test_from_legacy_strings(self):
import _testcapi
@@ -3566,12 +3656,13 @@ class ContextAPItests(unittest.TestCase):
self.assertRaises(TypeError, c.to_integral_value, '10')
self.assertRaises(TypeError, c.to_integral_value, 10, 'x')
class CContextAPItests(ContextAPItests):
@requires_cdecimal
class CContextAPItests(ContextAPItests, unittest.TestCase):
decimal = C
class PyContextAPItests(ContextAPItests):
class PyContextAPItests(ContextAPItests, unittest.TestCase):
decimal = P
class ContextWithStatement(unittest.TestCase):
class ContextWithStatement:
# Can't do these as docstrings until Python 2.6
# as doctest can't handle __future__ statements
@@ -3605,6 +3696,48 @@ class ContextWithStatement(unittest.TestCase):
self.assertIsNot(new_ctx, set_ctx, 'did not copy the context')
self.assertIs(set_ctx, enter_ctx, '__enter__ returned wrong context')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_localcontext_kwargs(self):
with self.decimal.localcontext(
prec=10, rounding=ROUND_HALF_DOWN,
Emin=-20, Emax=20, capitals=0,
clamp=1
) as ctx:
self.assertEqual(ctx.prec, 10)
self.assertEqual(ctx.rounding, self.decimal.ROUND_HALF_DOWN)
self.assertEqual(ctx.Emin, -20)
self.assertEqual(ctx.Emax, 20)
self.assertEqual(ctx.capitals, 0)
self.assertEqual(ctx.clamp, 1)
self.assertRaises(TypeError, self.decimal.localcontext, precision=10)
self.assertRaises(ValueError, self.decimal.localcontext, Emin=1)
self.assertRaises(ValueError, self.decimal.localcontext, Emax=-1)
self.assertRaises(ValueError, self.decimal.localcontext, capitals=2)
self.assertRaises(ValueError, self.decimal.localcontext, clamp=2)
self.assertRaises(TypeError, self.decimal.localcontext, rounding="")
self.assertRaises(TypeError, self.decimal.localcontext, rounding=1)
self.assertRaises(TypeError, self.decimal.localcontext, flags="")
self.assertRaises(TypeError, self.decimal.localcontext, traps="")
self.assertRaises(TypeError, self.decimal.localcontext, Emin="")
self.assertRaises(TypeError, self.decimal.localcontext, Emax="")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_local_context_kwargs_does_not_overwrite_existing_argument(self):
ctx = self.decimal.getcontext()
orig_prec = ctx.prec
with self.decimal.localcontext(prec=10) as ctx2:
self.assertEqual(ctx2.prec, 10)
self.assertEqual(ctx.prec, orig_prec)
with self.decimal.localcontext(prec=20) as ctx2:
self.assertEqual(ctx2.prec, 20)
self.assertEqual(ctx.prec, orig_prec)
def test_nested_with_statements(self):
# Use a copy of the supplied context in the block
Decimal = self.decimal.Decimal
@@ -3697,12 +3830,13 @@ class ContextWithStatement(unittest.TestCase):
self.assertEqual(c4.prec, 4)
del c4
class CContextWithStatement(ContextWithStatement):
@requires_cdecimal
class CContextWithStatement(ContextWithStatement, unittest.TestCase):
decimal = C
class PyContextWithStatement(ContextWithStatement):
class PyContextWithStatement(ContextWithStatement, unittest.TestCase):
decimal = P
class ContextFlags(unittest.TestCase):
class ContextFlags:
def test_flags_irrelevant(self):
# check that the result (numeric result + flags raised) of an
@@ -3969,12 +4103,13 @@ class ContextFlags(unittest.TestCase):
self.assertTrue(context.traps[FloatOperation])
self.assertTrue(context.traps[Inexact])
class CContextFlags(ContextFlags):
@requires_cdecimal
class CContextFlags(ContextFlags, unittest.TestCase):
decimal = C
class PyContextFlags(ContextFlags):
class PyContextFlags(ContextFlags, unittest.TestCase):
decimal = P
class SpecialContexts(unittest.TestCase):
class SpecialContexts:
"""Test the context templates."""
def test_context_templates(self):
@@ -4054,12 +4189,13 @@ class SpecialContexts(unittest.TestCase):
if ex:
raise ex
class CSpecialContexts(SpecialContexts):
@requires_cdecimal
class CSpecialContexts(SpecialContexts, unittest.TestCase):
decimal = C
class PySpecialContexts(SpecialContexts):
class PySpecialContexts(SpecialContexts, unittest.TestCase):
decimal = P
class ContextInputValidation(unittest.TestCase):
class ContextInputValidation:
def test_invalid_context(self):
Context = self.decimal.Context
@@ -4121,12 +4257,13 @@ class ContextInputValidation(unittest.TestCase):
self.assertRaises(TypeError, Context, flags=(0,1))
self.assertRaises(TypeError, Context, traps=(1,0))
class CContextInputValidation(ContextInputValidation):
@requires_cdecimal
class CContextInputValidation(ContextInputValidation, unittest.TestCase):
decimal = C
class PyContextInputValidation(ContextInputValidation):
class PyContextInputValidation(ContextInputValidation, unittest.TestCase):
decimal = P
class ContextSubclassing(unittest.TestCase):
class ContextSubclassing:
def test_context_subclassing(self):
decimal = self.decimal
@@ -4235,12 +4372,14 @@ class ContextSubclassing(unittest.TestCase):
for signal in OrderedSignals[decimal]:
self.assertFalse(c.traps[signal])
class CContextSubclassing(ContextSubclassing):
@requires_cdecimal
class CContextSubclassing(ContextSubclassing, unittest.TestCase):
decimal = C
class PyContextSubclassing(ContextSubclassing):
class PyContextSubclassing(ContextSubclassing, unittest.TestCase):
decimal = P
@skip_if_extra_functionality
@requires_cdecimal
class CheckAttributes(unittest.TestCase):
def test_module_attributes(self):
@@ -4270,7 +4409,7 @@ class CheckAttributes(unittest.TestCase):
y = [s for s in dir(C.Decimal(9)) if '__' in s or not s.startswith('_')]
self.assertEqual(set(x) - set(y), set())
class Coverage(unittest.TestCase):
class Coverage:
def test_adjusted(self):
Decimal = self.decimal.Decimal
@@ -4527,11 +4666,21 @@ class Coverage(unittest.TestCase):
y = c.copy_sign(x, 1)
self.assertEqual(y, -x)
class CCoverage(Coverage):
@requires_cdecimal
class CCoverage(Coverage, unittest.TestCase):
decimal = C
class PyCoverage(Coverage):
class PyCoverage(Coverage, unittest.TestCase):
decimal = P
def setUp(self):
super().setUp()
self._previous_int_limit = sys.get_int_max_str_digits()
sys.set_int_max_str_digits(7000)
def tearDown(self):
sys.set_int_max_str_digits(self._previous_int_limit)
super().tearDown()
class PyFunctionality(unittest.TestCase):
"""Extra functionality in decimal.py"""
@@ -4773,6 +4922,7 @@ class CFunctionality(unittest.TestCase):
self.assertEqual(C.DecTraps,
C.DecErrors|C.DecOverflow|C.DecUnderflow)
@requires_cdecimal
class CWhitebox(unittest.TestCase):
"""Whitebox testing for _decimal"""
@@ -5426,6 +5576,7 @@ class CWhitebox(unittest.TestCase):
with localcontext() as c:
c.prec = 9
c.traps[InvalidOperation] = True
c.traps[Overflow] = True
c.traps[Underflow] = True
@@ -5510,6 +5661,7 @@ class CWhitebox(unittest.TestCase):
# Issue 41540:
@unittest.skipIf(sys.platform.startswith("aix"),
"AIX: default ulimit: test is flaky because of extreme over-allocation")
@unittest.skipIf(is_emscripten, "Test is unstable on Emscripten")
@unittest.skipIf(check_sanitizer(address=True, memory=True),
"ASAN/MSAN sanitizer defaults to crashing "
"instead of returning NULL for malloc failure.")
@@ -5548,8 +5700,38 @@ class CWhitebox(unittest.TestCase):
self.assertEqual(Decimal(400) ** -1, Decimal('0.0025'))
def test_c_signaldict_segfault(self):
# See gh-106263 for details.
SignalDict = type(C.Context().flags)
sd = SignalDict()
err_msg = "invalid signal dict"
with self.assertRaisesRegex(ValueError, err_msg):
len(sd)
with self.assertRaisesRegex(ValueError, err_msg):
iter(sd)
with self.assertRaisesRegex(ValueError, err_msg):
repr(sd)
with self.assertRaisesRegex(ValueError, err_msg):
sd[C.InvalidOperation] = True
with self.assertRaisesRegex(ValueError, err_msg):
sd[C.InvalidOperation]
with self.assertRaisesRegex(ValueError, err_msg):
sd == C.Context().flags
with self.assertRaisesRegex(ValueError, err_msg):
C.Context().flags == sd
with self.assertRaisesRegex(ValueError, err_msg):
sd.copy()
@requires_docstrings
@unittest.skipUnless(C, "test requires C version")
@requires_cdecimal
class SignatureTest(unittest.TestCase):
"""Function signatures"""
@@ -5685,52 +5867,10 @@ class SignatureTest(unittest.TestCase):
doit('Context')
all_tests = [
CExplicitConstructionTest, PyExplicitConstructionTest,
CImplicitConstructionTest, PyImplicitConstructionTest,
CFormatTest, PyFormatTest,
CArithmeticOperatorsTest, PyArithmeticOperatorsTest,
CThreadingTest, PyThreadingTest,
CUsabilityTest, PyUsabilityTest,
CPythonAPItests, PyPythonAPItests,
CContextAPItests, PyContextAPItests,
CContextWithStatement, PyContextWithStatement,
CContextFlags, PyContextFlags,
CSpecialContexts, PySpecialContexts,
CContextInputValidation, PyContextInputValidation,
CContextSubclassing, PyContextSubclassing,
CCoverage, PyCoverage,
CFunctionality, PyFunctionality,
CWhitebox, PyWhitebox,
CIBMTestCases, PyIBMTestCases,
]
# Delete C tests if _decimal.so is not present.
if not C:
all_tests = all_tests[1::2]
else:
all_tests.insert(0, CheckAttributes)
all_tests.insert(1, SignatureTest)
def test_main(arith=None, verbose=None, todo_tests=None, debug=None):
""" Execute the tests.
Runs all arithmetic tests if arith is True or if the "decimal" resource
is enabled in regrtest.py
"""
init(C)
init(P)
global TEST_ALL, DEBUG
TEST_ALL = arith if arith is not None else is_resource_enabled('decimal')
DEBUG = debug
if todo_tests is None:
test_classes = all_tests
else:
test_classes = [CIBMTestCases, PyIBMTestCases]
def load_tests(loader, tests, pattern):
if TODO_TESTS is not None:
# Run only Arithmetic tests
tests = loader.suiteClass()
# Dynamically build custom test definition for each file in the test
# directory and add the definitions to the DecimalTest class. This
# procedure insures that new files do not get skipped.
@@ -5738,34 +5878,69 @@ def test_main(arith=None, verbose=None, todo_tests=None, debug=None):
if '.decTest' not in filename or filename.startswith("."):
continue
head, tail = filename.split('.')
if todo_tests is not None and head not in todo_tests:
if TODO_TESTS is not None and head not in TODO_TESTS:
continue
tester = lambda self, f=filename: self.eval_file(directory + f)
setattr(CIBMTestCases, 'test_' + head, tester)
setattr(PyIBMTestCases, 'test_' + head, tester)
setattr(IBMTestCases, 'test_' + head, tester)
del filename, head, tail, tester
for prefix, mod in ('C', C), ('Py', P):
if not mod:
continue
test_class = type(prefix + 'IBMTestCases',
(IBMTestCases, unittest.TestCase),
{'decimal': mod})
tests.addTest(loader.loadTestsFromTestCase(test_class))
if TODO_TESTS is None:
from doctest import DocTestSuite, IGNORE_EXCEPTION_DETAIL
for mod in C, P:
if not mod:
continue
def setUp(slf, mod=mod):
sys.modules['decimal'] = mod
def tearDown(slf):
sys.modules['decimal'] = orig_sys_decimal
optionflags = IGNORE_EXCEPTION_DETAIL if mod is C else 0
sys.modules['decimal'] = mod
tests.addTest(DocTestSuite(mod, setUp=setUp, tearDown=tearDown,
optionflags=optionflags))
sys.modules['decimal'] = orig_sys_decimal
return tests
def setUpModule():
init(C)
init(P)
global TEST_ALL
TEST_ALL = ARITH if ARITH is not None else is_resource_enabled('decimal')
def tearDownModule():
if C: C.setcontext(ORIGINAL_CONTEXT[C])
P.setcontext(ORIGINAL_CONTEXT[P])
if not C:
warnings.warn('C tests skipped: no module named _decimal.',
UserWarning)
if not orig_sys_decimal is sys.modules['decimal']:
raise TestFailed("Internal error: unbalanced number of changes to "
"sys.modules['decimal'].")
try:
run_unittest(*test_classes)
if todo_tests is None:
from doctest import IGNORE_EXCEPTION_DETAIL
savedecimal = sys.modules['decimal']
if C:
sys.modules['decimal'] = C
run_doctest(C, verbose, optionflags=IGNORE_EXCEPTION_DETAIL)
sys.modules['decimal'] = P
run_doctest(P, verbose)
sys.modules['decimal'] = savedecimal
finally:
if C: C.setcontext(ORIGINAL_CONTEXT[C])
P.setcontext(ORIGINAL_CONTEXT[P])
if not C:
warnings.warn('C tests skipped: no module named _decimal.',
UserWarning)
if not orig_sys_decimal is sys.modules['decimal']:
raise TestFailed("Internal error: unbalanced number of changes to "
"sys.modules['decimal'].")
ARITH = None
TEST_ALL = True
TODO_TESTS = None
DEBUG = False
def test(arith=None, verbose=None, todo_tests=None, debug=None):
""" Execute the tests.
Runs all arithmetic tests if arith is True or if the "decimal" resource
is enabled in regrtest.py
"""
global ARITH, TODO_TESTS, DEBUG
ARITH = arith
TODO_TESTS = todo_tests
DEBUG = debug
unittest.main(__name__, verbosity=2 if verbose else 1, exit=False, argv=[__name__])
if __name__ == '__main__':
@@ -5776,8 +5951,8 @@ if __name__ == '__main__':
(opt, args) = p.parse_args()
if opt.skip:
test_main(arith=False, verbose=True)
test(arith=False, verbose=True)
elif args:
test_main(arith=True, verbose=True, todo_tests=args, debug=opt.debug)
test(arith=True, verbose=True, todo_tests=args, debug=opt.debug)
else:
test_main(arith=True, verbose=True)
test(arith=True, verbose=True)

714
Lib/test/test_enum.py vendored

File diff suppressed because it is too large Load Diff

View File

@@ -13,19 +13,29 @@ import time
import typing
import unittest
import unittest.mock
import weakref
import gc
from weakref import proxy
import contextlib
from inspect import Signature
from test.support import import_helper
from test.support import threading_helper
import functools
py_functools = import_helper.import_fresh_module('functools', blocked=['_functools'])
c_functools = import_helper.import_fresh_module('functools', fresh=['_functools'])
py_functools = import_helper.import_fresh_module('functools',
blocked=['_functools'])
c_functools = import_helper.import_fresh_module('functools',
fresh=['_functools'])
decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
_partial_types = [py_functools.partial]
if c_functools:
_partial_types.append(c_functools.partial)
@contextlib.contextmanager
def replaced_module(name, replacement):
original_module = sys.modules[name]
@@ -162,6 +172,7 @@ class TestPartial:
p = proxy(f)
self.assertEqual(f.func, p.func)
f = None
support.gc_collect() # For PyPy or other GCs.
self.assertRaises(ReferenceError, getattr, p, 'func')
def test_with_bound_and_unbound_methods(self):
@@ -196,7 +207,7 @@ class TestPartial:
kwargs = {'a': object(), 'b': object()}
kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
'b={b!r}, a={a!r}'.format_map(kwargs)]
if self.partial in (c_functools.partial, py_functools.partial):
if self.partial in _partial_types:
name = 'functools.partial'
else:
name = self.partial.__name__
@@ -218,7 +229,7 @@ class TestPartial:
for kwargs_repr in kwargs_reprs])
def test_recursive_repr(self):
if self.partial in (c_functools.partial, py_functools.partial):
if self.partial in _partial_types:
name = 'functools.partial'
else:
name = self.partial.__name__
@@ -245,7 +256,7 @@ class TestPartial:
f.__setstate__((capture, (), {}, {}))
def test_pickle(self):
with self.AllowPickle():
with replaced_module('functools', self.module):
f = self.partial(signature, ['asdf'], bar=[True])
f.attr = []
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
@@ -328,7 +339,7 @@ class TestPartial:
self.assertIs(type(r[0]), tuple)
def test_recursive_pickle(self):
with self.AllowPickle():
with replaced_module('functools', self.module):
f = self.partial(capture)
f.__setstate__((f, (), {}, {}))
try:
@@ -382,24 +393,9 @@ class TestPartial:
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
if c_functools:
module = c_functools
partial = c_functools.partial
class AllowPickle:
def __enter__(self):
return self
def __exit__(self, type, value, tb):
return False
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_pickle(self):
super().test_pickle()
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_recursive_pickle(self):
super().test_recursive_pickle()
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_attributes_unwritable(self):
@@ -444,15 +440,9 @@ class TestPartialC(TestPartial, unittest.TestCase):
class TestPartialPy(TestPartial, unittest.TestCase):
module = py_functools
partial = py_functools.partial
class AllowPickle:
def __init__(self):
self._cm = replaced_module("functools", py_functools)
def __enter__(self):
return self._cm.__enter__()
def __exit__(self, type, value, tb):
return self._cm.__exit__(type, value, tb)
if c_functools:
class CPartialSubclass(c_functools.partial):
@@ -579,11 +569,9 @@ class TestPartialMethod(unittest.TestCase):
with self.assertRaises(TypeError):
class B:
method = functools.partialmethod()
with self.assertWarns(DeprecationWarning):
with self.assertRaises(TypeError):
class B:
method = functools.partialmethod(func=capture, a=1)
b = B()
self.assertEqual(b.method(2, x=3), ((b, 2), {'a': 1, 'x': 3}))
def test_repr(self):
self.assertEqual(repr(vars(self.A)['both']),
@@ -634,6 +622,8 @@ class TestUpdateWrapper(unittest.TestCase):
def _default_update(self):
# XXX: RUSTPYTHON; f[T] is not supported yet
# def f[T](a:'This is a new annotation'):
def f(a:'This is a new annotation'):
"""This is a test"""
pass
@@ -644,15 +634,19 @@ class TestUpdateWrapper(unittest.TestCase):
functools.update_wrapper(wrapper, f)
return wrapper, f
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_default_update(self):
wrapper, f = self._default_update()
self.check_wrapper(wrapper, f)
T, = f.__type_params__
self.assertIs(wrapper.__wrapped__, f)
self.assertEqual(wrapper.__name__, 'f')
self.assertEqual(wrapper.__qualname__, f.__qualname__)
self.assertEqual(wrapper.attr, 'This is also a test')
self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
self.assertNotIn('b', wrapper.__annotations__)
self.assertEqual(wrapper.__type_params__, (T,))
@unittest.skipIf(sys.flags.optimize >= 2,
"Docstrings are omitted with -O2 and above")
@@ -959,6 +953,10 @@ class TestCmpToKey:
self.assertRaises(TypeError, hash, k)
self.assertNotIsInstance(k, collections.abc.Hashable)
def test_cmp_to_signature(self):
self.assertEqual(str(Signature.from_callable(self.cmp_to_key)),
'(mycmp)')
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
@@ -1000,6 +998,18 @@ class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
def test_sort_int_str(self):
super().test_sort_int_str()
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_cmp_to_signature(self):
super().test_cmp_to_signature()
@support.cpython_only
def test_disallow_instantiation(self):
# Ensure that the type disallows instantiation (bpo-43916)
support.check_disallow_instantiation(
self, type(c_functools.cmp_to_key(None))
)
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
cmp_to_key = staticmethod(py_functools.cmp_to_key)
@@ -1093,6 +1103,73 @@ class TestTotalOrdering(unittest.TestCase):
class A:
pass
def test_notimplemented(self):
# Verify NotImplemented results are correctly handled
@functools.total_ordering
class ImplementsLessThan:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, ImplementsLessThan):
return self.value == other.value
return False
def __lt__(self, other):
if isinstance(other, ImplementsLessThan):
return self.value < other.value
return NotImplemented
@functools.total_ordering
class ImplementsLessThanEqualTo:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, ImplementsLessThanEqualTo):
return self.value == other.value
return False
def __le__(self, other):
if isinstance(other, ImplementsLessThanEqualTo):
return self.value <= other.value
return NotImplemented
@functools.total_ordering
class ImplementsGreaterThan:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, ImplementsGreaterThan):
return self.value == other.value
return False
def __gt__(self, other):
if isinstance(other, ImplementsGreaterThan):
return self.value > other.value
return NotImplemented
@functools.total_ordering
class ImplementsGreaterThanEqualTo:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, ImplementsGreaterThanEqualTo):
return self.value == other.value
return False
def __ge__(self, other):
if isinstance(other, ImplementsGreaterThanEqualTo):
return self.value >= other.value
return NotImplemented
self.assertIs(ImplementsLessThan(1).__le__(1), NotImplemented)
self.assertIs(ImplementsLessThan(1).__gt__(1), NotImplemented)
self.assertIs(ImplementsLessThan(1).__ge__(1), NotImplemented)
self.assertIs(ImplementsLessThanEqualTo(1).__lt__(1), NotImplemented)
self.assertIs(ImplementsLessThanEqualTo(1).__gt__(1), NotImplemented)
self.assertIs(ImplementsLessThanEqualTo(1).__ge__(1), NotImplemented)
self.assertIs(ImplementsGreaterThan(1).__lt__(1), NotImplemented)
self.assertIs(ImplementsGreaterThan(1).__gt__(1), NotImplemented)
self.assertIs(ImplementsGreaterThan(1).__ge__(1), NotImplemented)
self.assertIs(ImplementsGreaterThanEqualTo(1).__lt__(1), NotImplemented)
self.assertIs(ImplementsGreaterThanEqualTo(1).__le__(1), NotImplemented)
self.assertIs(ImplementsGreaterThanEqualTo(1).__gt__(1), NotImplemented)
def test_type_error_when_not_implemented(self):
# bug 10042; ensure stack overflow does not occur
# when decorated types return NotImplemented
@@ -1208,6 +1285,34 @@ class TestTotalOrdering(unittest.TestCase):
method_copy = pickle.loads(pickle.dumps(method, proto))
self.assertIs(method_copy, method)
def test_total_ordering_for_metaclasses_issue_44605(self):
@functools.total_ordering
class SortableMeta(type):
def __new__(cls, name, bases, ns):
return super().__new__(cls, name, bases, ns)
def __lt__(self, other):
if not isinstance(other, SortableMeta):
pass
return self.__name__ < other.__name__
def __eq__(self, other):
if not isinstance(other, SortableMeta):
pass
return self.__name__ == other.__name__
class B(metaclass=SortableMeta):
pass
class A(metaclass=SortableMeta):
pass
self.assertTrue(A < B)
self.assertFalse(A > B)
@functools.total_ordering
class Orderable_LT:
def __init__(self, value):
@@ -1218,6 +1323,25 @@ class Orderable_LT:
return self.value == other.value
class TestCache:
# This tests that the pass-through is working as designed.
# The underlying functionality is tested in TestLRU.
def test_cache(self):
@self.module.cache
def fib(n):
if n < 2:
return n
return fib(n-1) + fib(n-2)
self.assertEqual([fib(n) for n in range(16)],
[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
self.assertEqual(fib.cache_info(),
self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
fib.cache_clear()
self.assertEqual(fib.cache_info(),
self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
class TestLRU:
def test_lru(self):
@@ -1411,7 +1535,7 @@ class TestLRU:
def test_lru_star_arg_handling(self):
# Test regression that arose in ea064ff3c10f
@functools.lru_cache()
@self.module.lru_cache()
def f(*args):
return args
@@ -1423,11 +1547,11 @@ class TestLRU:
# lru_cache was leaking when one of the arguments
# wasn't cacheable.
@functools.lru_cache(maxsize=None)
@self.module.lru_cache(maxsize=None)
def infinite_cache(o):
pass
@functools.lru_cache(maxsize=10)
@self.module.lru_cache(maxsize=10)
def limited_cache(o):
pass
@@ -1492,6 +1616,33 @@ class TestLRU:
self.assertEqual(square.cache_info().hits, 4)
self.assertEqual(square.cache_info().misses, 4)
def test_lru_cache_typed_is_not_recursive(self):
cached = self.module.lru_cache(typed=True)(repr)
self.assertEqual(cached(1), '1')
self.assertEqual(cached(True), 'True')
self.assertEqual(cached(1.0), '1.0')
self.assertEqual(cached(0), '0')
self.assertEqual(cached(False), 'False')
self.assertEqual(cached(0.0), '0.0')
self.assertEqual(cached((1,)), '(1,)')
self.assertEqual(cached((True,)), '(1,)')
self.assertEqual(cached((1.0,)), '(1,)')
self.assertEqual(cached((0,)), '(0,)')
self.assertEqual(cached((False,)), '(0,)')
self.assertEqual(cached((0.0,)), '(0,)')
class T(tuple):
pass
self.assertEqual(cached(T((1,))), '(1,)')
self.assertEqual(cached(T((True,))), '(1,)')
self.assertEqual(cached(T((1.0,))), '(1,)')
self.assertEqual(cached(T((0,))), '(0,)')
self.assertEqual(cached(T((False,))), '(0,)')
self.assertEqual(cached(T((0.0,))), '(0,)')
def test_lru_with_keyword_args(self):
@self.module.lru_cache()
def fib(n):
@@ -1542,6 +1693,7 @@ class TestLRU:
# TODO: RUSTPYTHON
@unittest.expectedFailure
@threading_helper.requires_working_threading()
def test_lru_cache_threaded(self):
n, m = 5, 11
def orig(x, y):
@@ -1590,6 +1742,7 @@ class TestLRU:
finally:
sys.setswitchinterval(orig_si)
@threading_helper.requires_working_threading()
def test_lru_cache_threaded2(self):
# Simultaneous call with the same arguments
n, m = 5, 7
@@ -1617,6 +1770,7 @@ class TestLRU:
pause.reset()
self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
@threading_helper.requires_working_threading()
def test_lru_cache_threaded3(self):
@self.module.lru_cache(maxsize=2)
def f(x):
@@ -1717,14 +1871,62 @@ class TestLRU:
f_copy = copy.deepcopy(f)
self.assertIs(f_copy, f)
def test_lru_cache_parameters(self):
@self.module.lru_cache(maxsize=2)
def f():
return 1
self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
@self.module.lru_cache(maxsize=1000, typed=True)
def f():
return 1
self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
def test_lru_cache_weakrefable(self):
@self.module.lru_cache
def test_function(x):
return x
class A:
@self.module.lru_cache
def test_method(self, x):
return (self, x)
@staticmethod
@self.module.lru_cache
def test_staticmethod(x):
return (self, x)
refs = [weakref.ref(test_function),
weakref.ref(A.test_method),
weakref.ref(A.test_staticmethod)]
for ref in refs:
self.assertIsNotNone(ref())
del A
del test_function
gc.collect()
for ref in refs:
self.assertIsNone(ref())
def test_common_signatures(self):
def orig(): ...
lru = self.module.lru_cache(1)(orig)
self.assertEqual(str(Signature.from_callable(lru.cache_info)), '()')
self.assertEqual(str(Signature.from_callable(lru.cache_clear)), '()')
@py_functools.lru_cache()
def py_cached_func(x, y):
return 3 * x + y
@c_functools.lru_cache()
def c_cached_func(x, y):
return 3 * x + y
if c_functools:
@c_functools.lru_cache()
def c_cached_func(x, y):
return 3 * x + y
class TestLRUPy(TestLRU, unittest.TestCase):
@@ -1741,18 +1943,20 @@ class TestLRUPy(TestLRU, unittest.TestCase):
return 3 * x + y
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestLRUC(TestLRU, unittest.TestCase):
module = c_functools
cached_func = c_cached_func,
if c_functools:
module = c_functools
cached_func = c_cached_func,
@module.lru_cache()
def cached_meth(self, x, y):
return 3 * x + y
@module.lru_cache()
def cached_meth(self, x, y):
return 3 * x + y
@staticmethod
@module.lru_cache()
def cached_staticmeth(x, y):
return 3 * x + y
@staticmethod
@module.lru_cache()
def cached_staticmeth(x, y):
return 3 * x + y
class TestSingleDispatch(unittest.TestCase):
@@ -1867,7 +2071,7 @@ class TestSingleDispatch(unittest.TestCase):
c.MutableSequence.register(D)
bases = [c.MutableSequence, c.MutableMapping]
for haystack in permutations(bases):
m = mro(D, bases)
m = mro(D, haystack)
self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
collections.defaultdict, dict, c.MutableMapping, c.Mapping,
c.Collection, c.Sized, c.Iterable, c.Container,
@@ -2370,7 +2574,7 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(A.t(0.0).arg, "base")
def test_abstractmethod_register(self):
class Abstract(abc.ABCMeta):
class Abstract(metaclass=abc.ABCMeta):
@functools.singledispatchmethod
@abc.abstractmethod
@@ -2378,6 +2582,10 @@ class TestSingleDispatch(unittest.TestCase):
pass
self.assertTrue(Abstract.add.__isabstractmethod__)
self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)
with self.assertRaises(TypeError):
Abstract()
def test_type_ann_register(self):
class A:
@@ -2396,6 +2604,183 @@ class TestSingleDispatch(unittest.TestCase):
self.assertEqual(a.t(''), "str")
self.assertEqual(a.t(0.0), "base")
def test_staticmethod_type_ann_register(self):
class A:
@functools.singledispatchmethod
@staticmethod
def t(arg):
return arg
@t.register
@staticmethod
def _(arg: int):
return isinstance(arg, int)
@t.register
@staticmethod
def _(arg: str):
return isinstance(arg, str)
a = A()
self.assertTrue(A.t(0))
self.assertTrue(A.t(''))
self.assertEqual(A.t(0.0), 0.0)
def test_classmethod_type_ann_register(self):
class A:
def __init__(self, arg):
self.arg = arg
@functools.singledispatchmethod
@classmethod
def t(cls, arg):
return cls("base")
@t.register
@classmethod
def _(cls, arg: int):
return cls("int")
@t.register
@classmethod
def _(cls, arg: str):
return cls("str")
self.assertEqual(A.t(0).arg, "int")
self.assertEqual(A.t('').arg, "str")
self.assertEqual(A.t(0.0).arg, "base")
def test_method_wrapping_attributes(self):
class A:
@functools.singledispatchmethod
def func(self, arg: int) -> str:
"""My function docstring"""
return str(arg)
@functools.singledispatchmethod
@classmethod
def cls_func(cls, arg: int) -> str:
"""My function docstring"""
return str(arg)
@functools.singledispatchmethod
@staticmethod
def static_func(arg: int) -> str:
"""My function docstring"""
return str(arg)
for meth in (
A.func,
A().func,
A.cls_func,
A().cls_func,
A.static_func,
A().static_func
):
with self.subTest(meth=meth):
self.assertEqual(meth.__doc__, 'My function docstring')
self.assertEqual(meth.__annotations__['arg'], int)
self.assertEqual(A.func.__name__, 'func')
self.assertEqual(A().func.__name__, 'func')
self.assertEqual(A.cls_func.__name__, 'cls_func')
self.assertEqual(A().cls_func.__name__, 'cls_func')
self.assertEqual(A.static_func.__name__, 'static_func')
self.assertEqual(A().static_func.__name__, 'static_func')
def test_double_wrapped_methods(self):
def classmethod_friendly_decorator(func):
wrapped = func.__func__
@classmethod
@functools.wraps(wrapped)
def wrapper(*args, **kwargs):
return wrapped(*args, **kwargs)
return wrapper
class WithoutSingleDispatch:
@classmethod
@contextlib.contextmanager
def cls_context_manager(cls, arg: int) -> str:
try:
yield str(arg)
finally:
return 'Done'
@classmethod_friendly_decorator
@classmethod
def decorated_classmethod(cls, arg: int) -> str:
return str(arg)
class WithSingleDispatch:
@functools.singledispatchmethod
@classmethod
@contextlib.contextmanager
def cls_context_manager(cls, arg: int) -> str:
"""My function docstring"""
try:
yield str(arg)
finally:
return 'Done'
@functools.singledispatchmethod
@classmethod_friendly_decorator
@classmethod
def decorated_classmethod(cls, arg: int) -> str:
"""My function docstring"""
return str(arg)
# These are sanity checks
# to test the test itself is working as expected
with WithoutSingleDispatch.cls_context_manager(5) as foo:
without_single_dispatch_foo = foo
with WithSingleDispatch.cls_context_manager(5) as foo:
single_dispatch_foo = foo
self.assertEqual(without_single_dispatch_foo, single_dispatch_foo)
self.assertEqual(single_dispatch_foo, '5')
self.assertEqual(
WithoutSingleDispatch.decorated_classmethod(5),
WithSingleDispatch.decorated_classmethod(5)
)
self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5')
# Behavioural checks now follow
for method_name in ('cls_context_manager', 'decorated_classmethod'):
with self.subTest(method=method_name):
self.assertEqual(
getattr(WithSingleDispatch, method_name).__name__,
getattr(WithoutSingleDispatch, method_name).__name__
)
self.assertEqual(
getattr(WithSingleDispatch(), method_name).__name__,
getattr(WithoutSingleDispatch(), method_name).__name__
)
for meth in (
WithSingleDispatch.cls_context_manager,
WithSingleDispatch().cls_context_manager,
WithSingleDispatch.decorated_classmethod,
WithSingleDispatch().decorated_classmethod
):
with self.subTest(meth=meth):
self.assertEqual(meth.__doc__, 'My function docstring')
self.assertEqual(meth.__annotations__['arg'], int)
self.assertEqual(
WithSingleDispatch.cls_context_manager.__name__,
'cls_context_manager'
)
self.assertEqual(
WithSingleDispatch().cls_context_manager.__name__,
'cls_context_manager'
)
self.assertEqual(
WithSingleDispatch.decorated_classmethod.__name__,
'decorated_classmethod'
)
self.assertEqual(
WithSingleDispatch().decorated_classmethod.__name__,
'decorated_classmethod'
)
def test_invalid_registrations(self):
msg_prefix = "Invalid first argument to `register()`: "
msg_suffix = (
@@ -2435,6 +2820,17 @@ class TestSingleDispatch(unittest.TestCase):
'typing.Iterable[str] is not a class.'
))
with self.assertRaises(TypeError) as exc:
@i.register
def _(arg: typing.Union[int, typing.Iterable[str]]):
return "Invalid Union"
self.assertTrue(str(exc.exception).startswith(
"Invalid annotation for 'arg'."
))
self.assertTrue(str(exc.exception).endswith(
'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
))
def test_invalid_positional_argument(self):
@functools.singledispatch
def f(*args):
@@ -2443,6 +2839,134 @@ class TestSingleDispatch(unittest.TestCase):
with self.assertRaisesRegex(TypeError, msg):
f()
def test_union(self):
@functools.singledispatch
def f(arg):
return "default"
@f.register
def _(arg: typing.Union[str, bytes]):
return "typing.Union"
@f.register
def _(arg: int | float):
return "types.UnionType"
self.assertEqual(f([]), "default")
self.assertEqual(f(""), "typing.Union")
self.assertEqual(f(b""), "typing.Union")
self.assertEqual(f(1), "types.UnionType")
self.assertEqual(f(1.0), "types.UnionType")
def test_union_conflict(self):
@functools.singledispatch
def f(arg):
return "default"
@f.register
def _(arg: typing.Union[str, bytes]):
return "typing.Union"
@f.register
def _(arg: int | str):
return "types.UnionType"
self.assertEqual(f([]), "default")
self.assertEqual(f(""), "types.UnionType") # last one wins
self.assertEqual(f(b""), "typing.Union")
self.assertEqual(f(1), "types.UnionType")
def test_union_None(self):
@functools.singledispatch
def typing_union(arg):
return "default"
@typing_union.register
def _(arg: typing.Union[str, None]):
return "typing.Union"
self.assertEqual(typing_union(1), "default")
self.assertEqual(typing_union(""), "typing.Union")
self.assertEqual(typing_union(None), "typing.Union")
@functools.singledispatch
def types_union(arg):
return "default"
@types_union.register
def _(arg: int | None):
return "types.UnionType"
self.assertEqual(types_union(""), "default")
self.assertEqual(types_union(1), "types.UnionType")
self.assertEqual(types_union(None), "types.UnionType")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_register_genericalias(self):
@functools.singledispatch
def f(arg):
return "default"
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(list[int], lambda arg: "types.GenericAlias")
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(typing.List[int], lambda arg: "typing.GenericAlias")
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)")
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]")
self.assertEqual(f([1]), "default")
self.assertEqual(f([1.0]), "default")
self.assertEqual(f(""), "default")
self.assertEqual(f(b""), "default")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_register_genericalias_decorator(self):
@functools.singledispatch
def f(arg):
return "default"
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(list[int])
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(typing.List[int])
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(list[int] | str)
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
f.register(typing.List[int] | str)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_register_genericalias_annotation(self):
@functools.singledispatch
def f(arg):
return "default"
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: list[int]):
return "types.GenericAlias"
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: typing.List[float]):
return "typing.GenericAlias"
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: list[int] | str):
return "types.UnionType(types.GenericAlias)"
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
@f.register
def _(arg: typing.List[float] | bytes):
return "typing.Union[typing.GenericAlias]"
self.assertEqual(f([1]), "default")
self.assertEqual(f([1.0]), "default")
self.assertEqual(f(""), "default")
self.assertEqual(f(b""), "default")
class CachedCostItem:
_cost = 1
@@ -2469,21 +2993,6 @@ class OptionallyCachedCostItem:
cached_cost = py_functools.cached_property(get_cost)
class CachedCostItemWait:
def __init__(self, event):
self._cost = 1
self.lock = py_functools.RLock()
self.event = event
@py_functools.cached_property
def cost(self):
self.event.wait(1)
with self.lock:
self._cost += 1
return self._cost
class CachedCostItemWithSlots:
__slots__ = ('_cost')
@@ -2508,28 +3017,6 @@ class TestCachedProperty(unittest.TestCase):
self.assertEqual(item.get_cost(), 4)
self.assertEqual(item.cached_cost, 3)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_threaded(self):
go = threading.Event()
item = CachedCostItemWait(go)
num_threads = 3
orig_si = sys.getswitchinterval()
sys.setswitchinterval(1e-6)
try:
threads = [
threading.Thread(target=lambda: item.cost)
for k in range(num_threads)
]
with threading_helper.start_threads(threads):
go.set()
finally:
sys.setswitchinterval(orig_si)
self.assertEqual(item.cost, 2)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_object_with_slots(self):
@@ -2559,7 +3046,7 @@ class TestCachedProperty(unittest.TestCase):
@unittest.expectedFailure
def test_reuse_different_names(self):
"""Disallow this case because decorated function a would not be cached."""
with self.assertRaises(RuntimeError) as ctx:
with self.assertRaises(TypeError) as ctx:
class ReusedCachedProperty:
@py_functools.cached_property
def a(self):
@@ -2568,7 +3055,7 @@ class TestCachedProperty(unittest.TestCase):
b = a
self.assertEqual(
str(ctx.exception.__context__),
str(ctx.exception),
str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b')."))
)
@@ -2614,6 +3101,25 @@ class TestCachedProperty(unittest.TestCase):
def test_doc(self):
self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
def test_subclass_with___set__(self):
"""Caching still works for a subclass defining __set__."""
class readonly_cached_property(py_functools.cached_property):
def __set__(self, obj, value):
raise AttributeError("read only property")
class Test:
def __init__(self, prop):
self._prop = prop
@readonly_cached_property
def prop(self):
return self._prop
t = Test(1)
self.assertEqual(t.prop, 1)
t._prop = 999
self.assertEqual(t.prop, 1)
if __name__ == '__main__':
unittest.main()

13
Lib/test/test_importlib/_context.py vendored Normal file
View File

@@ -0,0 +1,13 @@
import contextlib
# from jaraco.context 4.3
class suppress(contextlib.suppress, contextlib.ContextDecorator):
"""
A version of contextlib.suppress with decorator support.
>>> @suppress(KeyError)
... def key_error():
... {}['']
>>> key_error()
"""

109
Lib/test/test_importlib/_path.py vendored Normal file
View File

@@ -0,0 +1,109 @@
# from jaraco.path 3.5
import functools
import pathlib
from typing import Dict, Union
try:
from typing import Protocol, runtime_checkable
except ImportError: # pragma: no cover
# Python 3.7
from typing_extensions import Protocol, runtime_checkable # type: ignore
FilesSpec = Dict[str, Union[str, bytes, 'FilesSpec']] # type: ignore
@runtime_checkable
class TreeMaker(Protocol):
def __truediv__(self, *args, **kwargs):
... # pragma: no cover
def mkdir(self, **kwargs):
... # pragma: no cover
def write_text(self, content, **kwargs):
... # pragma: no cover
def write_bytes(self, content):
... # pragma: no cover
def _ensure_tree_maker(obj: Union[str, TreeMaker]) -> TreeMaker:
return obj if isinstance(obj, TreeMaker) else pathlib.Path(obj) # type: ignore
def build(
spec: FilesSpec,
prefix: Union[str, TreeMaker] = pathlib.Path(), # type: ignore
):
"""
Build a set of files/directories, as described by the spec.
Each key represents a pathname, and the value represents
the content. Content may be a nested directory.
>>> spec = {
... 'README.txt': "A README file",
... "foo": {
... "__init__.py": "",
... "bar": {
... "__init__.py": "",
... },
... "baz.py": "# Some code",
... }
... }
>>> target = getfixture('tmp_path')
>>> build(spec, target)
>>> target.joinpath('foo/baz.py').read_text(encoding='utf-8')
'# Some code'
"""
for name, contents in spec.items():
create(contents, _ensure_tree_maker(prefix) / name)
@functools.singledispatch
def create(content: Union[str, bytes, FilesSpec], path):
path.mkdir(exist_ok=True)
build(content, prefix=path) # type: ignore
@create.register
def _(content: bytes, path):
path.write_bytes(content)
@create.register
def _(content: str, path):
path.write_text(content, encoding='utf-8')
@create.register
def _(content: str, path):
path.write_text(content, encoding='utf-8')
class Recording:
"""
A TreeMaker object that records everything that would be written.
>>> r = Recording()
>>> build({'foo': {'foo1.txt': 'yes'}, 'bar.txt': 'abc'}, r)
>>> r.record
['foo/foo1.txt', 'bar.txt']
"""
def __init__(self, loc=pathlib.PurePosixPath(), record=None):
self.loc = loc
self.record = record if record is not None else []
def __truediv__(self, other):
return Recording(self.loc / other, self.record)
def write_text(self, content, **kwargs):
self.record.append(str(self.loc))
write_bytes = write_text
def mkdir(self, **kwargs):
return

View File

@@ -37,61 +37,11 @@ class FindSpecTests(abc.FinderTests):
spec = self.machinery.BuiltinImporter.find_spec(name)
self.assertIsNone(spec)
def test_ignore_path(self):
# The value for 'path' should always trigger a failed import.
with util.uncache(util.BUILTINS.good_name):
spec = self.machinery.BuiltinImporter.find_spec(util.BUILTINS.good_name,
['pkg'])
self.assertIsNone(spec)
(Frozen_FindSpecTests,
Source_FindSpecTests
) = util.test_both(FindSpecTests, machinery=machinery)
@unittest.skipIf(util.BUILTINS.good_name is None, 'no reasonable builtin module')
class FinderTests(abc.FinderTests):
"""Test find_module() for built-in modules."""
def test_module(self):
# Common case.
with util.uncache(util.BUILTINS.good_name):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
found = self.machinery.BuiltinImporter.find_module(util.BUILTINS.good_name)
self.assertTrue(found)
self.assertTrue(hasattr(found, 'load_module'))
# Built-in modules cannot be a package.
test_package = test_package_in_package = test_package_over_module = None
# Built-in modules cannot be in a package.
test_module_in_package = None
def test_failure(self):
assert 'importlib' not in sys.builtin_module_names
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
loader = self.machinery.BuiltinImporter.find_module('importlib')
self.assertIsNone(loader)
def test_ignore_path(self):
# The value for 'path' should always trigger a failed import.
with util.uncache(util.BUILTINS.good_name):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
loader = self.machinery.BuiltinImporter.find_module(
util.BUILTINS.good_name,
['pkg'])
self.assertIsNone(loader)
(Frozen_FinderTests,
Source_FinderTests
) = util.test_both(FinderTests, machinery=machinery)
if __name__ == '__main__':
unittest.main()

View File

@@ -8,7 +8,7 @@ importlib = util.import_importlib('importlib')
machinery = util.import_importlib('importlib.machinery')
@unittest.skipIf(util.EXTENSIONS.filename is None, '_testcapi not available')
@unittest.skipIf(util.EXTENSIONS.filename is None, f'{util.EXTENSIONS.name} not available')
@util.case_insensitive_tests
class ExtensionModuleCaseSensitivityTest(util.CASEOKTestBase):

View File

@@ -13,9 +13,9 @@ import importlib
from test.support.script_helper import assert_python_failure
class LoaderTests(abc.LoaderTests):
class LoaderTests:
"""Test load_module() for extension modules."""
"""Test ExtensionFileLoader."""
def setUp(self):
if not self.machinery.EXTENSION_SUFFIXES:
@@ -32,17 +32,6 @@ class LoaderTests(abc.LoaderTests):
warnings.simplefilter("ignore", DeprecationWarning)
return self.loader.load_module(fullname)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_load_module_API(self):
# Test the default argument for load_module().
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
self.loader.load_module()
self.loader.load_module(None)
with self.assertRaises(ImportError):
self.load_module('XXX')
def test_equality(self):
other = self.machinery.ExtensionFileLoader(util.EXTENSIONS.name,
util.EXTENSIONS.file_path)
@@ -53,6 +42,15 @@ class LoaderTests(abc.LoaderTests):
util.EXTENSIONS.file_path)
self.assertNotEqual(self.loader, other)
def test_load_module_API(self):
# Test the default argument for load_module().
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
self.loader.load_module()
self.loader.load_module(None)
with self.assertRaises(ImportError):
self.load_module('XXX')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_module(self):
@@ -72,14 +70,6 @@ class LoaderTests(abc.LoaderTests):
# No extension module in a package available for testing.
test_lacking_parent = None
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_module_reuse(self):
with util.uncache(util.EXTENSIONS.name):
module1 = self.load_module(util.EXTENSIONS.name)
module2 = self.load_module(util.EXTENSIONS.name)
self.assertIs(module1, module2)
# No easy way to trigger a failure after a successful import.
test_state_after_failure = None
@@ -89,6 +79,12 @@ class LoaderTests(abc.LoaderTests):
self.load_module(name)
self.assertEqual(cm.exception.name, name)
def test_module_reuse(self):
with util.uncache(util.EXTENSIONS.name):
module1 = self.load_module(util.EXTENSIONS.name)
module2 = self.load_module(util.EXTENSIONS.name)
self.assertIs(module1, module2)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_is_package(self):
@@ -98,11 +94,94 @@ class LoaderTests(abc.LoaderTests):
loader = self.machinery.ExtensionFileLoader('pkg', path)
self.assertTrue(loader.is_package('pkg'))
(Frozen_LoaderTests,
Source_LoaderTests
) = util.test_both(LoaderTests, machinery=machinery)
@unittest.skip("TODO: RUSTPYTHON, AssertionError")
class SinglePhaseExtensionModuleTests(abc.LoaderTests):
# Test loading extension modules without multi-phase initialization.
def setUp(self):
if not self.machinery.EXTENSION_SUFFIXES:
raise unittest.SkipTest("Requires dynamic loading support.")
self.name = '_testsinglephase'
if self.name in sys.builtin_module_names:
raise unittest.SkipTest(
f"{self.name} is a builtin module"
)
finder = self.machinery.FileFinder(None)
self.spec = importlib.util.find_spec(self.name)
assert self.spec
self.loader = self.machinery.ExtensionFileLoader(
self.name, self.spec.origin)
def load_module(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
return self.loader.load_module(self.name)
def load_module_by_name(self, fullname):
# Load a module from the test extension by name.
origin = self.spec.origin
loader = self.machinery.ExtensionFileLoader(fullname, origin)
spec = importlib.util.spec_from_loader(fullname, loader)
module = importlib.util.module_from_spec(spec)
loader.exec_module(module)
return module
def test_module(self):
# Test loading an extension module.
with util.uncache(self.name):
module = self.load_module()
for attr, value in [('__name__', self.name),
('__file__', self.spec.origin),
('__package__', '')]:
self.assertEqual(getattr(module, attr), value)
with self.assertRaises(AttributeError):
module.__path__
self.assertIs(module, sys.modules[self.name])
self.assertIsInstance(module.__loader__,
self.machinery.ExtensionFileLoader)
# No extension module as __init__ available for testing.
test_package = None
# No extension module in a package available for testing.
test_lacking_parent = None
# No easy way to trigger a failure after a successful import.
test_state_after_failure = None
def test_unloadable(self):
name = 'asdfjkl;'
with self.assertRaises(ImportError) as cm:
self.load_module_by_name(name)
self.assertEqual(cm.exception.name, name)
def test_unloadable_nonascii(self):
# Test behavior with nonexistent module with non-ASCII name.
name = 'fo\xf3'
with self.assertRaises(ImportError) as cm:
self.load_module_by_name(name)
self.assertEqual(cm.exception.name, name)
# It may make sense to add the equivalent to
# the following MultiPhaseExtensionModuleTests tests:
#
# * test_nonmodule
# * test_nonmodule_with_methods
# * test_bad_modules
# * test_nonascii
(Frozen_SinglePhaseExtensionModuleTests,
Source_SinglePhaseExtensionModuleTests
) = util.test_both(SinglePhaseExtensionModuleTests, machinery=machinery)
# @unittest.skip("TODO: RUSTPYTHON, AssertionError")
class MultiPhaseExtensionModuleTests(abc.LoaderTests):
# Test loading extension modules with multi-phase initialization (PEP 489).
@@ -188,15 +267,16 @@ class MultiPhaseExtensionModuleTests(abc.LoaderTests):
def test_try_registration(self):
# Assert that the PyState_{Find,Add,Remove}Module C API doesn't work.
module = self.load_module()
with self.subTest('PyState_FindModule'):
self.assertEqual(module.call_state_registration_func(0), None)
with self.subTest('PyState_AddModule'):
with self.assertRaises(SystemError):
module.call_state_registration_func(1)
with self.subTest('PyState_RemoveModule'):
with self.assertRaises(SystemError):
module.call_state_registration_func(2)
with util.uncache(self.name):
module = self.load_module()
with self.subTest('PyState_FindModule'):
self.assertEqual(module.call_state_registration_func(0), None)
with self.subTest('PyState_AddModule'):
with self.assertRaises(SystemError):
module.call_state_registration_func(1)
with self.subTest('PyState_RemoveModule'):
with self.assertRaises(SystemError):
module.call_state_registration_func(2)
def test_load_submodule(self):
# Test loading a simulated submodule.
@@ -274,12 +354,19 @@ class MultiPhaseExtensionModuleTests(abc.LoaderTests):
'exec_err',
'exec_raise',
'exec_unreported_exception',
'multiple_create_slots',
'multiple_multiple_interpreters_slots',
]:
with self.subTest(name_base):
name = self.name + '_' + name_base
with self.assertRaises(SystemError):
with self.assertRaises(SystemError) as cm:
self.load_module_by_name(name)
# If there is an unreported exception, it should be chained
# with the `SystemError`.
if "unreported_exception" in name_base:
self.assertIsNotNone(cm.exception.__cause__)
def test_nonascii(self):
# Test that modules with non-ASCII names can be loaded.
# punycode behaves slightly differently in some-ASCII and no-ASCII

View File

@@ -19,7 +19,7 @@ class PathHookTests:
def test_success(self):
# Path hook should handle a directory where a known extension module
# exists.
self.assertTrue(hasattr(self.hook(util.EXTENSIONS.path), 'find_module'))
self.assertTrue(hasattr(self.hook(util.EXTENSIONS.path), 'find_spec'))
(Frozen_PathHooksTests,

View File

@@ -10,7 +10,10 @@ import contextlib
from test.support.os_helper import FS_NONASCII
from test.support import requires_zlib
from typing import Dict, Union
from . import _path
from ._path import FilesSpec
try:
from importlib import resources # type: ignore
@@ -83,13 +86,8 @@ class OnSysPath(Fixtures):
self.fixtures.enter_context(self.add_sys_path(self.site_dir))
# Except for python/mypy#731, prefer to define
# FilesDef = Dict[str, Union['FilesDef', str]]
FilesDef = Dict[str, Union[Dict[str, Union[Dict[str, str], str]], str]]
class DistInfoPkg(OnSysPath, SiteDir):
files: FilesDef = {
files: FilesSpec = {
"distinfo_pkg-1.0.0.dist-info": {
"METADATA": """
Name: distinfo-pkg
@@ -131,7 +129,7 @@ class DistInfoPkg(OnSysPath, SiteDir):
class DistInfoPkgWithDot(OnSysPath, SiteDir):
files: FilesDef = {
files: FilesSpec = {
"pkg_dot-1.0.0.dist-info": {
"METADATA": """
Name: pkg.dot
@@ -146,7 +144,7 @@ class DistInfoPkgWithDot(OnSysPath, SiteDir):
class DistInfoPkgWithDotLegacy(OnSysPath, SiteDir):
files: FilesDef = {
files: FilesSpec = {
"pkg.dot-1.0.0.dist-info": {
"METADATA": """
Name: pkg.dot
@@ -173,7 +171,7 @@ class DistInfoPkgOffPath(SiteDir):
class EggInfoPkg(OnSysPath, SiteDir):
files: FilesDef = {
files: FilesSpec = {
"egginfo_pkg.egg-info": {
"PKG-INFO": """
Name: egginfo-pkg
@@ -212,8 +210,99 @@ class EggInfoPkg(OnSysPath, SiteDir):
build_files(EggInfoPkg.files, prefix=self.site_dir)
class EggInfoPkgPipInstalledNoToplevel(OnSysPath, SiteDir):
files: FilesSpec = {
"egg_with_module_pkg.egg-info": {
"PKG-INFO": "Name: egg_with_module-pkg",
# SOURCES.txt is made from the source archive, and contains files
# (setup.py) that are not present after installation.
"SOURCES.txt": """
egg_with_module.py
setup.py
egg_with_module_pkg.egg-info/PKG-INFO
egg_with_module_pkg.egg-info/SOURCES.txt
egg_with_module_pkg.egg-info/top_level.txt
""",
# installed-files.txt is written by pip, and is a strictly more
# accurate source than SOURCES.txt as to the installed contents of
# the package.
"installed-files.txt": """
../egg_with_module.py
PKG-INFO
SOURCES.txt
top_level.txt
""",
# missing top_level.txt (to trigger fallback to installed-files.txt)
},
"egg_with_module.py": """
def main():
print("hello world")
""",
}
def setUp(self):
super().setUp()
build_files(EggInfoPkgPipInstalledNoToplevel.files, prefix=self.site_dir)
class EggInfoPkgPipInstalledNoModules(OnSysPath, SiteDir):
files: FilesSpec = {
"egg_with_no_modules_pkg.egg-info": {
"PKG-INFO": "Name: egg_with_no_modules-pkg",
# SOURCES.txt is made from the source archive, and contains files
# (setup.py) that are not present after installation.
"SOURCES.txt": """
setup.py
egg_with_no_modules_pkg.egg-info/PKG-INFO
egg_with_no_modules_pkg.egg-info/SOURCES.txt
egg_with_no_modules_pkg.egg-info/top_level.txt
""",
# installed-files.txt is written by pip, and is a strictly more
# accurate source than SOURCES.txt as to the installed contents of
# the package.
"installed-files.txt": """
PKG-INFO
SOURCES.txt
top_level.txt
""",
# top_level.txt correctly reflects that no modules are installed
"top_level.txt": b"\n",
},
}
def setUp(self):
super().setUp()
build_files(EggInfoPkgPipInstalledNoModules.files, prefix=self.site_dir)
class EggInfoPkgSourcesFallback(OnSysPath, SiteDir):
files: FilesSpec = {
"sources_fallback_pkg.egg-info": {
"PKG-INFO": "Name: sources_fallback-pkg",
# SOURCES.txt is made from the source archive, and contains files
# (setup.py) that are not present after installation.
"SOURCES.txt": """
sources_fallback.py
setup.py
sources_fallback_pkg.egg-info/PKG-INFO
sources_fallback_pkg.egg-info/SOURCES.txt
""",
# missing installed-files.txt (i.e. not installed by pip) and
# missing top_level.txt (to trigger fallback to SOURCES.txt)
},
"sources_fallback.py": """
def main():
print("hello world")
""",
}
def setUp(self):
super().setUp()
build_files(EggInfoPkgSourcesFallback.files, prefix=self.site_dir)
class EggInfoFile(OnSysPath, SiteDir):
files: FilesDef = {
files: FilesSpec = {
"egginfo_file.egg-info": """
Metadata-Version: 1.0
Name: egginfo_file
@@ -233,38 +322,22 @@ class EggInfoFile(OnSysPath, SiteDir):
build_files(EggInfoFile.files, prefix=self.site_dir)
def build_files(file_defs, prefix=pathlib.Path()):
"""Build a set of files/directories, as described by the
# dedent all text strings before writing
orig = _path.create.registry[str]
_path.create.register(str, lambda content, path: orig(DALS(content), path))
file_defs dictionary. Each key/value pair in the dictionary is
interpreted as a filename/contents pair. If the contents value is a
dictionary, a directory is created, and the dictionary interpreted
as the files within it, recursively.
For example:
build_files = _path.build
{"README.txt": "A README file",
"foo": {
"__init__.py": "",
"bar": {
"__init__.py": "",
},
"baz.py": "# Some code",
}
}
"""
for name, contents in file_defs.items():
full_name = prefix / name
if isinstance(contents, dict):
full_name.mkdir()
build_files(contents, prefix=full_name)
else:
if isinstance(contents, bytes):
with full_name.open('wb') as f:
f.write(contents)
else:
with full_name.open('w', encoding='utf-8') as f:
f.write(DALS(contents))
def build_record(file_defs):
return ''.join(f'{name},,\n' for name in record_names(file_defs))
def record_names(file_defs):
recording = _path.Recording()
_path.build(file_defs, recording)
return recording.record
class FileBuilder:
@@ -277,11 +350,6 @@ def DALS(str):
return textwrap.dedent(str).lstrip()
class NullFinder:
def find_module(self, name):
pass
@requires_zlib()
class ZipFixtures:
root = 'test.test_importlib.data'

View File

@@ -70,14 +70,6 @@ class FindSpecTests(abc.FinderTests):
expected = [os.path.dirname(filename)]
self.assertListEqual(spec.submodule_search_locations, expected)
def test_package(self):
spec = self.find('__phello__')
self.assertIsNotNone(spec)
def test_module_in_package(self):
spec = self.find('__phello__.spam', ['__phello__'])
self.assertIsNotNone(spec)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_module(self):
@@ -196,45 +188,5 @@ class FindSpecTests(abc.FinderTests):
) = util.test_both(FindSpecTests, machinery=machinery)
class FinderTests(abc.FinderTests):
"""Test finding frozen modules."""
def find(self, name, path=None):
finder = self.machinery.FrozenImporter
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
with import_helper.frozen_modules():
return finder.find_module(name, path)
def test_module(self):
name = '__hello__'
loader = self.find(name)
self.assertTrue(hasattr(loader, 'load_module'))
def test_package(self):
loader = self.find('__phello__')
self.assertTrue(hasattr(loader, 'load_module'))
def test_module_in_package(self):
loader = self.find('__phello__.spam', ['__phello__'])
self.assertTrue(hasattr(loader, 'load_module'))
# No frozen package within another package to test with.
test_package_in_package = None
# No easy way to test.
test_package_over_module = None
def test_failure(self):
loader = self.find('<not real>')
self.assertIsNone(loader)
(Frozen_FinderTests,
Source_FinderTests
) = util.test_both(FinderTests, machinery=machinery)
if __name__ == '__main__':
unittest.main()

View File

@@ -103,15 +103,7 @@ class ExecModuleTests(abc.LoaderTests):
expected=value))
self.assertEqual(output, 'Hello world!\n')
def test_module_repr(self):
name = '__hello__'
module, output = self.exec_module(name)
with deprecated():
repr_str = self.machinery.FrozenImporter.module_repr(module)
self.assertEqual(repr_str,
"<module '__hello__' (frozen)>")
def test_module_repr_indirect(self):
def test_module_repr_indirect_through_spec(self):
name = '__hello__'
module, output = self.exec_module(name)
self.assertEqual(repr(module),
@@ -133,101 +125,6 @@ class ExecModuleTests(abc.LoaderTests):
) = util.test_both(ExecModuleTests, machinery=machinery)
class LoaderTests(abc.LoaderTests):
def load_module(self, name):
with fresh(name, oldapi=True):
module = self.machinery.FrozenImporter.load_module(name)
with captured_stdout() as stdout:
module.main()
return module, stdout
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_module(self):
module, stdout = self.load_module('__hello__')
filename = resolve_stdlib_file('__hello__')
check = {'__name__': '__hello__',
'__package__': '',
'__loader__': self.machinery.FrozenImporter,
'__file__': filename,
}
for attr, value in check.items():
self.assertEqual(getattr(module, attr, None), value)
self.assertEqual(stdout.getvalue(), 'Hello world!\n')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_package(self):
module, stdout = self.load_module('__phello__')
filename = resolve_stdlib_file('__phello__', ispkg=True)
pkgdir = os.path.dirname(filename)
check = {'__name__': '__phello__',
'__package__': '__phello__',
'__path__': [pkgdir],
'__loader__': self.machinery.FrozenImporter,
'__file__': filename,
}
for attr, value in check.items():
attr_value = getattr(module, attr, None)
self.assertEqual(attr_value, value,
"for __phello__.%s, %r != %r" %
(attr, attr_value, value))
self.assertEqual(stdout.getvalue(), 'Hello world!\n')
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_lacking_parent(self):
with util.uncache('__phello__'):
module, stdout = self.load_module('__phello__.spam')
filename = resolve_stdlib_file('__phello__.spam')
check = {'__name__': '__phello__.spam',
'__package__': '__phello__',
'__loader__': self.machinery.FrozenImporter,
'__file__': filename,
}
for attr, value in check.items():
attr_value = getattr(module, attr)
self.assertEqual(attr_value, value,
"for __phello__.spam.%s, %r != %r" %
(attr, attr_value, value))
self.assertEqual(stdout.getvalue(), 'Hello world!\n')
def test_module_reuse(self):
with fresh('__hello__', oldapi=True):
module1 = self.machinery.FrozenImporter.load_module('__hello__')
module2 = self.machinery.FrozenImporter.load_module('__hello__')
with captured_stdout() as stdout:
module1.main()
module2.main()
self.assertIs(module1, module2)
self.assertEqual(stdout.getvalue(),
'Hello world!\nHello world!\n')
def test_module_repr(self):
with fresh('__hello__', oldapi=True):
module = self.machinery.FrozenImporter.load_module('__hello__')
repr_str = self.machinery.FrozenImporter.module_repr(module)
self.assertEqual(repr_str,
"<module '__hello__' (frozen)>")
# No way to trigger an error in a frozen module.
test_state_after_failure = None
def test_unloadable(self):
with import_helper.frozen_modules():
with deprecated():
assert self.machinery.FrozenImporter.find_module('_not_real') is None
with self.assertRaises(ImportError) as cm:
self.load_module('_not_real')
self.assertEqual(cm.exception.name, '_not_real')
(Frozen_LoaderTests,
Source_LoaderTests
) = util.test_both(LoaderTests, machinery=machinery)
class InspectLoaderTests:
"""Tests for the InspectLoader methods for FrozenImporter."""

View File

@@ -33,48 +33,5 @@ class SpecLoaderAttributeTests:
) = util.test_both(SpecLoaderAttributeTests, __import__=util.__import__)
class LoaderMock:
def find_module(self, fullname, path=None):
return self
def load_module(self, fullname):
sys.modules[fullname] = self.module
return self.module
class LoaderAttributeTests:
def test___loader___missing(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
module = types.ModuleType('blah')
try:
del module.__loader__
except AttributeError:
pass
loader = LoaderMock()
loader.module = module
with util.uncache('blah'), util.import_state(meta_path=[loader]):
module = self.__import__('blah')
self.assertEqual(loader, module.__loader__)
def test___loader___is_None(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
module = types.ModuleType('blah')
module.__loader__ = None
loader = LoaderMock()
loader.module = module
with util.uncache('blah'), util.import_state(meta_path=[loader]):
returned_module = self.__import__('blah')
self.assertEqual(loader, module.__loader__)
(Frozen_Tests,
Source_Tests
) = util.test_both(LoaderAttributeTests, __import__=util.__import__)
if __name__ == '__main__':
unittest.main()

View File

@@ -78,8 +78,8 @@ class Using__package__:
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_warn_when_package_and_spec_disagree(self):
# Raise an ImportWarning if __package__ != __spec__.parent.
with self.assertWarns(ImportWarning):
# Raise a DeprecationWarning if __package__ != __spec__.parent.
with self.assertWarns(DeprecationWarning):
self.import_module({'__package__': 'pkg.fake',
'__spec__': FakeSpec('pkg.fakefake')})
@@ -99,25 +99,6 @@ class FakeSpec:
self.parent = parent
class Using__package__PEP302(Using__package__):
mock_modules = util.mock_modules
def test_using___package__(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
super().test_using___package__()
def test_spec_fallback(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
super().test_spec_fallback()
(Frozen_UsingPackagePEP302,
Source_UsingPackagePEP302
) = util.test_both(Using__package__PEP302, __import__=util.__import__)
class Using__package__PEP451(Using__package__):
mock_modules = util.mock_spec
@@ -166,23 +147,6 @@ class Setting__package__:
module = getattr(pkg, 'mod')
self.assertEqual(module.__package__, 'pkg')
class Setting__package__PEP302(Setting__package__, unittest.TestCase):
mock_modules = util.mock_modules
def test_top_level(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
super().test_top_level()
def test_package(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
super().test_package()
def test_submodule(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
super().test_submodule()
class Setting__package__PEP451(Setting__package__, unittest.TestCase):
mock_modules = util.mock_spec

View File

@@ -28,11 +28,6 @@ class BadSpecFinderLoader:
class BadLoaderFinder:
@classmethod
def find_module(cls, fullname, path):
if fullname == SUBMOD_NAME:
return cls
@classmethod
def load_module(cls, fullname):
if fullname == SUBMOD_NAME:

View File

@@ -52,12 +52,11 @@ class ImportlibUseCache(UseCache, unittest.TestCase):
__import__ = util.__import__['Source']
def create_mock(self, *names, return_=None):
mock = util.mock_modules(*names)
original_load = mock.load_module
def load_module(self, fullname):
original_load(fullname)
return return_
mock.load_module = MethodType(load_module, mock)
mock = util.mock_spec(*names)
original_spec = mock.find_spec
def find_spec(self, fullname, path, target=None):
return original_spec(fullname)
mock.find_spec = MethodType(find_spec, mock)
return mock
# __import__ inconsistent between loaders and built-in import when it comes
@@ -86,14 +85,12 @@ class ImportlibUseCache(UseCache, unittest.TestCase):
# See test_using_cache_after_loader() for reasoning.
def test_using_cache_for_fromlist(self):
# [from cache for fromlist]
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
with self.create_mock('pkg.__init__', 'pkg.module') as importer:
with util.import_state(meta_path=[importer]):
module = self.__import__('pkg', fromlist=['module'])
self.assertTrue(hasattr(module, 'module'))
self.assertEqual(id(module.module),
id(sys.modules['pkg.module']))
with self.create_mock('pkg.__init__', 'pkg.module') as importer:
with util.import_state(meta_path=[importer]):
module = self.__import__('pkg', fromlist=['module'])
self.assertTrue(hasattr(module, 'module'))
self.assertEqual(id(module.module),
id(sys.modules['pkg.module']))
if __name__ == '__main__':

View File

@@ -0,0 +1,192 @@
"""Tests for helper functions used by import.c ."""
from importlib import _bootstrap_external, machinery
import os.path
from types import ModuleType, SimpleNamespace
import unittest
import warnings
from .. import util
class FixUpModuleTests:
def test_no_loader_but_spec(self):
loader = object()
name = "hello"
path = "hello.py"
spec = machinery.ModuleSpec(name, loader)
ns = {"__spec__": spec}
_bootstrap_external._fix_up_module(ns, name, path)
expected = {"__spec__": spec, "__loader__": loader, "__file__": path,
"__cached__": None}
self.assertEqual(ns, expected)
def test_no_loader_no_spec_but_sourceless(self):
name = "hello"
path = "hello.py"
ns = {}
_bootstrap_external._fix_up_module(ns, name, path, path)
expected = {"__file__": path, "__cached__": path}
for key, val in expected.items():
with self.subTest(f"{key}: {val}"):
self.assertEqual(ns[key], val)
spec = ns["__spec__"]
self.assertIsInstance(spec, machinery.ModuleSpec)
self.assertEqual(spec.name, name)
self.assertEqual(spec.origin, os.path.abspath(path))
self.assertEqual(spec.cached, os.path.abspath(path))
self.assertIsInstance(spec.loader, machinery.SourcelessFileLoader)
self.assertEqual(spec.loader.name, name)
self.assertEqual(spec.loader.path, path)
self.assertEqual(spec.loader, ns["__loader__"])
def test_no_loader_no_spec_but_source(self):
name = "hello"
path = "hello.py"
ns = {}
_bootstrap_external._fix_up_module(ns, name, path)
expected = {"__file__": path, "__cached__": None}
for key, val in expected.items():
with self.subTest(f"{key}: {val}"):
self.assertEqual(ns[key], val)
spec = ns["__spec__"]
self.assertIsInstance(spec, machinery.ModuleSpec)
self.assertEqual(spec.name, name)
self.assertEqual(spec.origin, os.path.abspath(path))
self.assertIsInstance(spec.loader, machinery.SourceFileLoader)
self.assertEqual(spec.loader.name, name)
self.assertEqual(spec.loader.path, path)
self.assertEqual(spec.loader, ns["__loader__"])
FrozenFixUpModuleTests, SourceFixUpModuleTests = util.test_both(FixUpModuleTests)
class TestBlessMyLoader(unittest.TestCase):
# GH#86298 is part of the migration away from module attributes and toward
# __spec__ attributes. There are several cases to test here. This will
# have to change in Python 3.14 when we actually remove/ignore __loader__
# in favor of requiring __spec__.loader.
def test_gh86298_no_loader_and_no_spec(self):
bar = ModuleType('bar')
del bar.__loader__
del bar.__spec__
# 2022-10-06(warsaw): For backward compatibility with the
# implementation in _warnings.c, this can't raise an
# AttributeError. See _bless_my_loader() in _bootstrap_external.py
# If working with a module:
## self.assertRaises(
## AttributeError, _bootstrap_external._bless_my_loader,
## bar.__dict__)
self.assertIsNone(_bootstrap_external._bless_my_loader(bar.__dict__))
def test_gh86298_loader_is_none_and_no_spec(self):
bar = ModuleType('bar')
bar.__loader__ = None
del bar.__spec__
# 2022-10-06(warsaw): For backward compatibility with the
# implementation in _warnings.c, this can't raise an
# AttributeError. See _bless_my_loader() in _bootstrap_external.py
# If working with a module:
## self.assertRaises(
## AttributeError, _bootstrap_external._bless_my_loader,
## bar.__dict__)
self.assertIsNone(_bootstrap_external._bless_my_loader(bar.__dict__))
def test_gh86298_no_loader_and_spec_is_none(self):
bar = ModuleType('bar')
del bar.__loader__
bar.__spec__ = None
self.assertRaises(
ValueError,
_bootstrap_external._bless_my_loader, bar.__dict__)
def test_gh86298_loader_is_none_and_spec_is_none(self):
bar = ModuleType('bar')
bar.__loader__ = None
bar.__spec__ = None
self.assertRaises(
ValueError,
_bootstrap_external._bless_my_loader, bar.__dict__)
def test_gh86298_loader_is_none_and_spec_loader_is_none(self):
bar = ModuleType('bar')
bar.__loader__ = None
bar.__spec__ = SimpleNamespace(loader=None)
self.assertRaises(
ValueError,
_bootstrap_external._bless_my_loader, bar.__dict__)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_gh86298_no_spec(self):
bar = ModuleType('bar')
bar.__loader__ = object()
del bar.__spec__
with warnings.catch_warnings():
self.assertWarns(
DeprecationWarning,
_bootstrap_external._bless_my_loader, bar.__dict__)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_gh86298_spec_is_none(self):
bar = ModuleType('bar')
bar.__loader__ = object()
bar.__spec__ = None
with warnings.catch_warnings():
self.assertWarns(
DeprecationWarning,
_bootstrap_external._bless_my_loader, bar.__dict__)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_gh86298_no_spec_loader(self):
bar = ModuleType('bar')
bar.__loader__ = object()
bar.__spec__ = SimpleNamespace()
with warnings.catch_warnings():
self.assertWarns(
DeprecationWarning,
_bootstrap_external._bless_my_loader, bar.__dict__)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_gh86298_loader_and_spec_loader_disagree(self):
bar = ModuleType('bar')
bar.__loader__ = object()
bar.__spec__ = SimpleNamespace(loader=object())
with warnings.catch_warnings():
self.assertWarns(
DeprecationWarning,
_bootstrap_external._bless_my_loader, bar.__dict__)
def test_gh86298_no_loader_and_no_spec_loader(self):
bar = ModuleType('bar')
del bar.__loader__
bar.__spec__ = SimpleNamespace()
self.assertRaises(
AttributeError,
_bootstrap_external._bless_my_loader, bar.__dict__)
def test_gh86298_no_loader_with_spec_loader_okay(self):
bar = ModuleType('bar')
del bar.__loader__
loader = object()
bar.__spec__ = SimpleNamespace(loader=loader)
self.assertEqual(
_bootstrap_external._bless_my_loader(bar.__dict__),
loader)
if __name__ == "__main__":
unittest.main()

View File

@@ -115,16 +115,6 @@ class CallSignoreSuppressImportWarning(CallSignature):
super().test_no_path()
class CallSignaturePEP302(CallSignoreSuppressImportWarning):
mock_modules = util.mock_modules
finder_name = 'find_module'
(Frozen_CallSignaturePEP302,
Source_CallSignaturePEP302
) = util.test_both(CallSignaturePEP302, __import__=util.__import__)
class CallSignaturePEP451(CallSignature):
mock_modules = util.mock_spec
finder_name = 'find_spec'

View File

@@ -118,46 +118,6 @@ class FinderTests:
if email is not missing:
sys.modules['email'] = email
def test_finder_with_find_module(self):
class TestFinder:
def find_module(self, fullname):
return self.to_return
failing_finder = TestFinder()
failing_finder.to_return = None
path = 'testing path'
with util.import_state(path_importer_cache={path: failing_finder}):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
self.assertIsNone(
self.machinery.PathFinder.find_spec('whatever', [path]))
success_finder = TestFinder()
success_finder.to_return = __loader__
with util.import_state(path_importer_cache={path: success_finder}):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
spec = self.machinery.PathFinder.find_spec('whatever', [path])
self.assertEqual(spec.loader, __loader__)
def test_finder_with_find_loader(self):
class TestFinder:
loader = None
portions = []
def find_loader(self, fullname):
return self.loader, self.portions
path = 'testing path'
with util.import_state(path_importer_cache={path: TestFinder()}):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
self.assertIsNone(
self.machinery.PathFinder.find_spec('whatever', [path]))
success_finder = TestFinder()
success_finder.loader = __loader__
with util.import_state(path_importer_cache={path: success_finder}):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
spec = self.machinery.PathFinder.find_spec('whatever', [path])
self.assertEqual(spec.loader, __loader__)
def test_finder_with_find_spec(self):
class TestFinder:
spec = None
@@ -230,9 +190,9 @@ class FinderTests:
class FindModuleTests(FinderTests):
def find(self, *args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
return self.machinery.PathFinder.find_module(*args, **kwargs)
spec = self.machinery.PathFinder.find_spec(*args, **kwargs)
return None if spec is None else spec.loader
def check_found(self, found, importer):
self.assertIs(found, importer)
@@ -257,16 +217,14 @@ class FindSpecTests(FinderTests):
class PathEntryFinderTests:
def test_finder_with_failing_find_spec(self):
# PathEntryFinder with find_module() defined should work.
# Issue #20763.
class Finder:
path_location = 'test_finder_with_find_module'
path_location = 'test_finder_with_find_spec'
def __init__(self, path):
if path != self.path_location:
raise ImportError
@staticmethod
def find_module(fullname):
def find_spec(fullname, target=None):
return None
@@ -276,27 +234,6 @@ class PathEntryFinderTests:
warnings.simplefilter("ignore", ImportWarning)
self.machinery.PathFinder.find_spec('importlib')
def test_finder_with_failing_find_module(self):
# PathEntryFinder with find_module() defined should work.
# Issue #20763.
class Finder:
path_location = 'test_finder_with_find_module'
def __init__(self, path):
if path != self.path_location:
raise ImportError
@staticmethod
def find_module(fullname):
return None
with util.import_state(path=[Finder.path_location]+sys.path[:],
path_hooks=[Finder]):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
warnings.simplefilter("ignore", DeprecationWarning)
self.machinery.PathFinder.find_module('importlib')
(Frozen_PEFTests,
Source_PEFTests

View File

@@ -0,0 +1,56 @@
import pathlib
import functools
from typing import Dict, Union
####
# from jaraco.path 3.4.1
FilesSpec = Dict[str, Union[str, bytes, 'FilesSpec']] # type: ignore
def build(spec: FilesSpec, prefix=pathlib.Path()):
"""
Build a set of files/directories, as described by the spec.
Each key represents a pathname, and the value represents
the content. Content may be a nested directory.
>>> spec = {
... 'README.txt': "A README file",
... "foo": {
... "__init__.py": "",
... "bar": {
... "__init__.py": "",
... },
... "baz.py": "# Some code",
... }
... }
>>> target = getfixture('tmp_path')
>>> build(spec, target)
>>> target.joinpath('foo/baz.py').read_text(encoding='utf-8')
'# Some code'
"""
for name, contents in spec.items():
create(contents, pathlib.Path(prefix) / name)
@functools.singledispatch
def create(content: Union[str, bytes, FilesSpec], path):
path.mkdir(exist_ok=True)
build(content, prefix=path) # type: ignore
@create.register
def _(content: bytes, path):
path.write_bytes(content)
@create.register
def _(content: str, path):
path.write_text(content, encoding='utf-8')
# end from jaraco.path
####

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1 @@
a resource

Binary file not shown.

Binary file not shown.

View File

@@ -8,7 +8,7 @@ from importlib.resources._adapters import (
wrap_spec,
)
from .resources import util
from . import util
class CompatibilityFilesTests(unittest.TestCase):
@@ -64,11 +64,13 @@ class CompatibilityFilesTests(unittest.TestCase):
def test_spec_path_open(self):
self.assertEqual(self.files.read_bytes(), b'Hello, world!')
self.assertEqual(self.files.read_text(), 'Hello, world!')
self.assertEqual(self.files.read_text(encoding='utf-8'), 'Hello, world!')
def test_child_path_open(self):
self.assertEqual((self.files / 'a').read_bytes(), b'Hello, world!')
self.assertEqual((self.files / 'a').read_text(), 'Hello, world!')
self.assertEqual(
(self.files / 'a').read_text(encoding='utf-8'), 'Hello, world!'
)
def test_orphan_path_open(self):
with self.assertRaises(FileNotFoundError):

View File

@@ -2,7 +2,7 @@ import unittest
from importlib import resources
from . import data01
from .resources import util
from . import util
class ContentsTests:

View File

@@ -0,0 +1,46 @@
import unittest
import contextlib
import pathlib
from test.support import os_helper
from importlib import resources
from importlib.resources.abc import TraversableResources, ResourceReader
from . import util
class SimpleLoader:
"""
A simple loader that only implements a resource reader.
"""
def __init__(self, reader: ResourceReader):
self.reader = reader
def get_resource_reader(self, package):
return self.reader
class MagicResources(TraversableResources):
"""
Magically returns the resources at path.
"""
def __init__(self, path: pathlib.Path):
self.path = path
def files(self):
return self.path
class CustomTraversableResourcesTests(unittest.TestCase):
def setUp(self):
self.fixtures = contextlib.ExitStack()
self.addCleanup(self.fixtures.close)
def test_custom_loader(self):
temp_dir = self.fixtures.enter_context(os_helper.temp_dir())
loader = SimpleLoader(MagicResources(temp_dir))
pkg = util.create_package_from_loader(loader)
files = resources.files(pkg)
assert files is temp_dir

View File

@@ -0,0 +1,113 @@
import typing
import textwrap
import unittest
import warnings
import importlib
import contextlib
from importlib import resources
from importlib.resources.abc import Traversable
from . import data01
from . import util
from . import _path
from test.support import os_helper
from test.support import import_helper
@contextlib.contextmanager
def suppress_known_deprecation():
with warnings.catch_warnings(record=True) as ctx:
warnings.simplefilter('default', category=DeprecationWarning)
yield ctx
class FilesTests:
def test_read_bytes(self):
files = resources.files(self.data)
actual = files.joinpath('utf-8.file').read_bytes()
assert actual == b'Hello, UTF-8 world!\n'
def test_read_text(self):
files = resources.files(self.data)
actual = files.joinpath('utf-8.file').read_text(encoding='utf-8')
assert actual == 'Hello, UTF-8 world!\n'
@unittest.skipUnless(
hasattr(typing, 'runtime_checkable'),
"Only suitable when typing supports runtime_checkable",
)
def test_traversable(self):
assert isinstance(resources.files(self.data), Traversable)
def test_old_parameter(self):
"""
Files used to take a 'package' parameter. Make sure anyone
passing by name is still supported.
"""
with suppress_known_deprecation():
resources.files(package=self.data)
class OpenDiskTests(FilesTests, unittest.TestCase):
def setUp(self):
self.data = data01
class OpenZipTests(FilesTests, util.ZipSetup, unittest.TestCase):
pass
class OpenNamespaceTests(FilesTests, unittest.TestCase):
def setUp(self):
from . import namespacedata01
self.data = namespacedata01
class SiteDir:
def setUp(self):
self.fixtures = contextlib.ExitStack()
self.addCleanup(self.fixtures.close)
self.site_dir = self.fixtures.enter_context(os_helper.temp_dir())
self.fixtures.enter_context(import_helper.DirsOnSysPath(self.site_dir))
self.fixtures.enter_context(import_helper.CleanImport())
class ModulesFilesTests(SiteDir, unittest.TestCase):
def test_module_resources(self):
"""
A module can have resources found adjacent to the module.
"""
spec = {
'mod.py': '',
'res.txt': 'resources are the best',
}
_path.build(spec, self.site_dir)
import mod
actual = resources.files(mod).joinpath('res.txt').read_text(encoding='utf-8')
assert actual == spec['res.txt']
class ImplicitContextFilesTests(SiteDir, unittest.TestCase):
def test_implicit_files(self):
"""
Without any parameter, files() will infer the location as the caller.
"""
spec = {
'somepkg': {
'__init__.py': textwrap.dedent(
"""
import importlib.resources as res
val = res.files().joinpath('res.txt').read_text(encoding='utf-8')
"""
),
'res.txt': 'resources are the best',
},
}
_path.build(spec, self.site_dir)
assert importlib.import_module('somepkg').val == 'resources are the best'
if __name__ == '__main__':
unittest.main()

View File

@@ -2,7 +2,7 @@ import unittest
from importlib import resources
from . import data01
from .resources import util
from . import util
class CommonBinaryTests(util.CommonTests, unittest.TestCase):
@@ -15,7 +15,7 @@ class CommonBinaryTests(util.CommonTests, unittest.TestCase):
class CommonTextTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
target = resources.files(package).joinpath(path)
with target.open():
with target.open(encoding='utf-8'):
pass
@@ -28,7 +28,7 @@ class OpenTests:
def test_open_text_default_encoding(self):
target = resources.files(self.data) / 'utf-8.file'
with target.open() as fp:
with target.open(encoding='utf-8') as fp:
result = fp.read()
self.assertEqual(result, 'Hello, UTF-8 world!\n')
@@ -39,7 +39,9 @@ class OpenTests:
self.assertEqual(result, 'Hello, UTF-16 world!\n')
def test_open_text_with_errors(self):
# Raises UnicodeError without the 'errors' argument.
"""
Raises UnicodeError without the 'errors' argument.
"""
target = resources.files(self.data) / 'utf-16.file'
with target.open(encoding='utf-8', errors='strict') as fp:
self.assertRaises(UnicodeError, fp.read)
@@ -54,11 +56,13 @@ class OpenTests:
def test_open_binary_FileNotFoundError(self):
target = resources.files(self.data) / 'does-not-exist'
self.assertRaises(FileNotFoundError, target.open, 'rb')
with self.assertRaises(FileNotFoundError):
target.open('rb')
def test_open_text_FileNotFoundError(self):
target = resources.files(self.data) / 'does-not-exist'
self.assertRaises(FileNotFoundError, target.open)
with self.assertRaises(FileNotFoundError):
target.open(encoding='utf-8')
class OpenDiskTests(OpenTests, unittest.TestCase):
@@ -72,12 +76,6 @@ class OpenDiskNamespaceTests(OpenTests, unittest.TestCase):
self.data = namespacedata01
# TODO: RUSTPYTHON
import sys
if sys.platform == 'win32':
@unittest.expectedFailure
def test_open_text_default_encoding(self):
super().test_open_text_default_encoding()
class OpenZipTests(OpenTests, util.ZipSetup, unittest.TestCase):
pass

View File

@@ -3,7 +3,7 @@ import unittest
from importlib import resources
from . import data01
from .resources import util
from . import util
class CommonTests(util.CommonTests, unittest.TestCase):
@@ -14,9 +14,12 @@ class CommonTests(util.CommonTests, unittest.TestCase):
class PathTests:
def test_reading(self):
# Path should be readable.
# Test also implicitly verifies the returned object is a pathlib.Path
# instance.
"""
Path should be readable.
Test also implicitly verifies the returned object is a pathlib.Path
instance.
"""
target = resources.files(self.data) / 'utf-8.file'
with resources.as_file(target) as path:
self.assertTrue(path.name.endswith("utf-8.file"), repr(path))
@@ -51,8 +54,10 @@ class PathMemoryTests(PathTests, unittest.TestCase):
class PathZipTests(PathTests, util.ZipSetup, unittest.TestCase):
def test_remove_in_context_manager(self):
# It is not an error if the file that was temporarily stashed on the
# file system is removed inside the `with` stanza.
"""
It is not an error if the file that was temporarily stashed on the
file system is removed inside the `with` stanza.
"""
target = resources.files(self.data) / 'utf-8.file'
with resources.as_file(target) as path:
path.unlink()

View File

@@ -2,7 +2,7 @@ import unittest
from importlib import import_module, resources
from . import data01
from .resources import util
from . import util
class CommonBinaryTests(util.CommonTests, unittest.TestCase):
@@ -12,7 +12,7 @@ class CommonBinaryTests(util.CommonTests, unittest.TestCase):
class CommonTextTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
resources.files(package).joinpath(path).read_text()
resources.files(package).joinpath(path).read_text(encoding='utf-8')
class ReadTests:
@@ -21,7 +21,11 @@ class ReadTests:
self.assertEqual(result, b'\0\1\2\3')
def test_read_text_default_encoding(self):
result = resources.files(self.data).joinpath('utf-8.file').read_text()
result = (
resources.files(self.data)
.joinpath('utf-8.file')
.read_text(encoding='utf-8')
)
self.assertEqual(result, 'Hello, UTF-8 world!\n')
def test_read_text_given_encoding(self):
@@ -33,7 +37,9 @@ class ReadTests:
self.assertEqual(result, 'Hello, UTF-16 world!\n')
def test_read_text_with_errors(self):
# Raises UnicodeError without the 'errors' argument.
"""
Raises UnicodeError without the 'errors' argument.
"""
target = resources.files(self.data) / 'utf-16.file'
self.assertRaises(UnicodeError, target.read_text, encoding='utf-8')
result = target.read_text(encoding='utf-8', errors='ignore')

View File

@@ -75,6 +75,22 @@ class MultiplexedPathTest(unittest.TestCase):
str(path.joinpath('imaginary'))[len(prefix) + 1 :],
os.path.join('namespacedata01', 'imaginary'),
)
self.assertEqual(path.joinpath(), path)
def test_join_path_compound(self):
path = MultiplexedPath(self.folder)
assert not path.joinpath('imaginary/foo.py').exists()
def test_join_path_common_subdir(self):
prefix = os.path.abspath(os.path.join(__file__, '..'))
data01 = os.path.join(prefix, 'data01')
data02 = os.path.join(prefix, 'data02')
path = MultiplexedPath(data01, data02)
self.assertIsInstance(path.joinpath('subdirectory'), MultiplexedPath)
self.assertEqual(
str(path.joinpath('subdirectory', 'subsubdir'))[len(prefix) + 1 :],
os.path.join('data02', 'subdirectory', 'subsubdir'),
)
def test_repr(self):
self.assertEqual(

View File

@@ -1,3 +1,4 @@
import contextlib
import sys
import unittest
import uuid
@@ -5,9 +6,9 @@ import pathlib
from . import data01
from . import zipdata01, zipdata02
from .resources import util
from . import util
from importlib import resources, import_module
from test.support import import_helper
from test.support import import_helper, os_helper
from test.support.os_helper import unlink
@@ -69,10 +70,12 @@ class ResourceLoaderTests(unittest.TestCase):
class ResourceCornerCaseTests(unittest.TestCase):
def test_package_has_no_reader_fallback(self):
# Test odd ball packages which:
"""
Test odd ball packages which:
# 1. Do not have a ResourceReader as a loader
# 2. Are not on the file system
# 3. Are not in a zip file
"""
module = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C']
)
@@ -111,6 +114,14 @@ class ResourceFromZipsTest01(util.ZipSetupBase, unittest.TestCase):
{'__init__.py', 'binary.file'},
)
def test_as_file_directory(self):
with resources.as_file(resources.files('ziptestdata')) as data:
assert data.name == 'ziptestdata'
assert data.is_dir()
assert data.joinpath('subdirectory').is_dir()
assert len(list(data.iterdir()))
assert not data.parent.exists()
class ResourceFromZipsTest02(util.ZipSetupBase, unittest.TestCase):
ZIP_MODULE = zipdata02 # type: ignore
@@ -130,82 +141,71 @@ class ResourceFromZipsTest02(util.ZipSetupBase, unittest.TestCase):
)
@contextlib.contextmanager
def zip_on_path(dir):
data_path = pathlib.Path(zipdata01.__file__)
source_zip_path = data_path.parent.joinpath('ziptestdata.zip')
zip_path = pathlib.Path(dir) / f'{uuid.uuid4()}.zip'
zip_path.write_bytes(source_zip_path.read_bytes())
sys.path.append(str(zip_path))
import_module('ziptestdata')
try:
yield
finally:
with contextlib.suppress(ValueError):
sys.path.remove(str(zip_path))
with contextlib.suppress(KeyError):
del sys.path_importer_cache[str(zip_path)]
del sys.modules['ziptestdata']
with contextlib.suppress(OSError):
unlink(zip_path)
class DeletingZipsTest(unittest.TestCase):
"""Having accessed resources in a zip file should not keep an open
reference to the zip.
"""
ZIP_MODULE = zipdata01
def setUp(self):
self.fixtures = contextlib.ExitStack()
self.addCleanup(self.fixtures.close)
modules = import_helper.modules_setup()
self.addCleanup(import_helper.modules_cleanup, *modules)
data_path = pathlib.Path(self.ZIP_MODULE.__file__)
data_dir = data_path.parent
self.source_zip_path = data_dir / 'ziptestdata.zip'
self.zip_path = pathlib.Path(f'{uuid.uuid4()}.zip').absolute()
self.zip_path.write_bytes(self.source_zip_path.read_bytes())
sys.path.append(str(self.zip_path))
self.data = import_module('ziptestdata')
def tearDown(self):
try:
sys.path.remove(str(self.zip_path))
except ValueError:
pass
try:
del sys.path_importer_cache[str(self.zip_path)]
del sys.modules[self.data.__name__]
except KeyError:
pass
try:
unlink(self.zip_path)
except OSError:
# If the test fails, this will probably fail too
pass
temp_dir = self.fixtures.enter_context(os_helper.temp_dir())
self.fixtures.enter_context(zip_on_path(temp_dir))
def test_iterdir_does_not_keep_open(self):
c = [item.name for item in resources.files('ziptestdata').iterdir()]
self.zip_path.unlink()
del c
[item.name for item in resources.files('ziptestdata').iterdir()]
def test_is_file_does_not_keep_open(self):
c = resources.files('ziptestdata').joinpath('binary.file').is_file()
self.zip_path.unlink()
del c
resources.files('ziptestdata').joinpath('binary.file').is_file()
def test_is_file_failure_does_not_keep_open(self):
c = resources.files('ziptestdata').joinpath('not-present').is_file()
self.zip_path.unlink()
del c
resources.files('ziptestdata').joinpath('not-present').is_file()
@unittest.skip("Desired but not supported.")
def test_as_file_does_not_keep_open(self): # pragma: no cover
c = resources.as_file(resources.files('ziptestdata') / 'binary.file')
self.zip_path.unlink()
del c
resources.as_file(resources.files('ziptestdata') / 'binary.file')
def test_entered_path_does_not_keep_open(self):
# This is what certifi does on import to make its bundle
# available for the process duration.
c = resources.as_file(
resources.files('ziptestdata') / 'binary.file'
).__enter__()
self.zip_path.unlink()
del c
"""
Mimic what certifi does on import to make its bundle
available for the process duration.
"""
resources.as_file(resources.files('ziptestdata') / 'binary.file').__enter__()
def test_read_binary_does_not_keep_open(self):
c = resources.files('ziptestdata').joinpath('binary.file').read_bytes()
self.zip_path.unlink()
del c
resources.files('ziptestdata').joinpath('binary.file').read_bytes()
def test_read_text_does_not_keep_open(self):
c = resources.files('ziptestdata').joinpath('utf-8.file').read_text()
self.zip_path.unlink()
del c
resources.files('ziptestdata').joinpath('utf-8.file').read_text(
encoding='utf-8'
)
class ResourceFromNamespaceTest01(unittest.TestCase):

View File

@@ -3,11 +3,11 @@ import importlib
import io
import sys
import types
from pathlib import Path, PurePath
import pathlib
from .. import data01
from .. import zipdata01
from importlib.abc import ResourceReader
from . import data01
from . import zipdata01
from importlib.resources.abc import ResourceReader
from test.support import import_helper
@@ -80,43 +80,44 @@ class CommonTests(metaclass=abc.ABCMeta):
"""
def test_package_name(self):
# Passing in the package name should succeed.
"""
Passing in the package name should succeed.
"""
self.execute(data01.__name__, 'utf-8.file')
def test_package_object(self):
# Passing in the package itself should succeed.
"""
Passing in the package itself should succeed.
"""
self.execute(data01, 'utf-8.file')
def test_string_path(self):
# Passing in a string for the path should succeed.
"""
Passing in a string for the path should succeed.
"""
path = 'utf-8.file'
self.execute(data01, path)
def test_pathlib_path(self):
# Passing in a pathlib.PurePath object for the path should succeed.
path = PurePath('utf-8.file')
"""
Passing in a pathlib.PurePath object for the path should succeed.
"""
path = pathlib.PurePath('utf-8.file')
self.execute(data01, path)
def test_importing_module_as_side_effect(self):
# The anchor package can already be imported.
"""
The anchor package can already be imported.
"""
del sys.modules[data01.__name__]
self.execute(data01.__name__, 'utf-8.file')
def test_non_package_by_name(self):
# The anchor package cannot be a module.
with self.assertRaises(TypeError):
self.execute(__name__, 'utf-8.file')
def test_non_package_by_package(self):
# The anchor package cannot be a module.
with self.assertRaises(TypeError):
module = sys.modules['test.test_importlib.resources.util']
self.execute(module, 'utf-8.file')
def test_missing_path(self):
# Attempting to open or read or request the path for a
# non-existent path should succeed if open_resource
# can return a viable data stream.
"""
Attempting to open or read or request the path for a
non-existent path should succeed if open_resource
can return a viable data stream.
"""
bytes_data = io.BytesIO(b'Hello, world!')
package = create_package(file=bytes_data, path=FileNotFoundError())
self.execute(package, 'utf-8.file')
@@ -144,7 +145,7 @@ class ZipSetupBase:
@classmethod
def setUpClass(cls):
data_path = Path(cls.ZIP_MODULE.__file__)
data_path = pathlib.Path(cls.ZIP_MODULE.__file__)
data_dir = data_path.parent
cls._zip_path = str(data_dir / 'ziptestdata.zip')
sys.path.append(cls._zip_path)

Binary file not shown.

Binary file not shown.

View File

@@ -63,19 +63,6 @@ class CaseSensitivityTest(util.CASEOKTestBase):
self.assertIn(self.name, insensitive.get_filename(self.name))
class CaseSensitivityTestPEP302(CaseSensitivityTest):
def find(self, finder):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
return finder.find_module(self.name)
(Frozen_CaseSensitivityTestPEP302,
Source_CaseSensitivityTestPEP302
) = util.test_both(CaseSensitivityTestPEP302, importlib=importlib,
machinery=machinery)
class CaseSensitivityTestPEP451(CaseSensitivityTest):
def find(self, finder):
found = finder.find_spec(self.name)

View File

@@ -51,7 +51,6 @@ class SimpleTest(abc.LoaderTests):
def get_code(self, _): pass
def get_source(self, _): pass
def is_package(self, _): pass
def module_repr(self, _): pass
path = 'some_path'
name = 'some_name'

View File

@@ -120,7 +120,7 @@ class FinderTests(abc.FinderTests):
def test_failure(self):
with util.create_modules('blah') as mapping:
nothing = self.import_(mapping['.root'], 'sdfsadsadf')
self.assertIsNone(nothing)
self.assertEqual(nothing, self.NOT_FOUND)
def test_empty_string_for_dir(self):
# The empty string from sys.path means to search in the cwd.
@@ -150,7 +150,7 @@ class FinderTests(abc.FinderTests):
found = self._find(finder, 'mod', loader_only=True)
self.assertIsNotNone(found)
found = self._find(finder, 'mod', loader_only=True)
self.assertIsNone(found)
self.assertEqual(found, self.NOT_FOUND)
@unittest.skipUnless(sys.platform != 'win32',
'os.chmod() does not support the needed arguments under Windows')
@@ -197,10 +197,12 @@ class FinderTestsPEP420(FinderTests):
NOT_FOUND = (None, [])
def _find(self, finder, name, loader_only=False):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
loader_portions = finder.find_loader(name)
return loader_portions[0] if loader_only else loader_portions
spec = finder.find_spec(name)
if spec is None:
return self.NOT_FOUND
if loader_only:
return spec.loader
return spec.loader, spec.submodule_search_locations
(Frozen_FinderTestsPEP420,
@@ -208,20 +210,5 @@ class FinderTestsPEP420(FinderTests):
) = util.test_both(FinderTestsPEP420, machinery=machinery)
class FinderTestsPEP302(FinderTests):
NOT_FOUND = None
def _find(self, finder, name, loader_only=False):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
return finder.find_module(name)
(Frozen_FinderTestsPEP302,
Source_FinderTestsPEP302
) = util.test_both(FinderTestsPEP302, machinery=machinery)
if __name__ == '__main__':
unittest.main()

View File

@@ -18,19 +18,10 @@ class PathHookTest:
self.assertTrue(hasattr(self.path_hook()(mapping['.root']),
'find_spec'))
def test_success_legacy(self):
with util.create_modules('dummy') as mapping:
self.assertTrue(hasattr(self.path_hook()(mapping['.root']),
'find_module'))
def test_empty_string(self):
# The empty string represents the cwd.
self.assertTrue(hasattr(self.path_hook()(''), 'find_spec'))
def test_empty_string_legacy(self):
# The empty string represents the cwd.
self.assertTrue(hasattr(self.path_hook()(''), 'find_module'))
(Frozen_PathHookTest,
Source_PathHooktest

View File

@@ -2,7 +2,6 @@ import io
import marshal
import os
import sys
from test import support
from test.support import import_helper
import types
import unittest
@@ -148,20 +147,13 @@ class ABCTestHarness:
class MetaPathFinder:
def find_module(self, fullname, path):
return super().find_module(fullname, path)
pass
class MetaPathFinderDefaultsTests(ABCTestHarness):
SPLIT = make_abc_subclasses(MetaPathFinder)
def test_find_module(self):
# Default should return None.
with self.assertWarns(DeprecationWarning):
found = self.ins.find_module('something', None)
self.assertIsNone(found)
def test_invalidate_caches(self):
# Calling the method is a no-op.
self.ins.invalidate_caches()
@@ -174,22 +166,13 @@ class MetaPathFinderDefaultsTests(ABCTestHarness):
class PathEntryFinder:
def find_loader(self, fullname):
return super().find_loader(fullname)
pass
class PathEntryFinderDefaultsTests(ABCTestHarness):
SPLIT = make_abc_subclasses(PathEntryFinder)
def test_find_loader(self):
with self.assertWarns(DeprecationWarning):
found = self.ins.find_loader('something')
self.assertEqual(found, (None, []))
def find_module(self):
self.assertEqual(None, self.ins.find_module('something'))
def test_invalidate_caches(self):
# Should be a no-op.
self.ins.invalidate_caches()
@@ -202,8 +185,7 @@ class PathEntryFinderDefaultsTests(ABCTestHarness):
class Loader:
def load_module(self, fullname):
return super().load_module(fullname)
pass
class LoaderDefaultsTests(ABCTestHarness):
@@ -222,8 +204,6 @@ class LoaderDefaultsTests(ABCTestHarness):
mod = types.ModuleType('blah')
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
with self.assertRaises(NotImplementedError):
self.ins.module_repr(mod)
original_repr = repr(mod)
mod.__loader__ = self.ins
# Should still return a proper repr.
@@ -323,32 +303,6 @@ class ResourceReader:
return super().contents(*args, **kwargs)
class ResourceReaderDefaultsTests(ABCTestHarness):
SPLIT = make_abc_subclasses(ResourceReader)
def test_open_resource(self):
with self.assertRaises(FileNotFoundError):
self.ins.open_resource('dummy_file')
def test_resource_path(self):
with self.assertRaises(FileNotFoundError):
self.ins.resource_path('dummy_file')
def test_is_resource(self):
with self.assertRaises(FileNotFoundError):
self.ins.is_resource('dummy_file')
def test_contents(self):
with self.assertRaises(FileNotFoundError):
self.ins.contents()
(Frozen_RRDefaultTests,
Source_RRDefaultsTests
) = test_util.test_both(ResourceReaderDefaultsTests)
##### MetaPathFinder concrete methods ##########################################
class MetaPathFinderFindModuleTests:
@@ -362,14 +316,6 @@ class MetaPathFinderFindModuleTests:
return MetaPathSpecFinder()
def test_find_module(self):
finder = self.finder(None)
path = ['a', 'b', 'c']
name = 'blah'
with self.assertWarns(DeprecationWarning):
found = finder.find_module(name, path)
self.assertIsNone(found)
def test_find_spec_with_explicit_target(self):
loader = object()
spec = self.util.spec_from_loader('blah', loader)
@@ -399,53 +345,6 @@ class MetaPathFinderFindModuleTests:
) = test_util.test_both(MetaPathFinderFindModuleTests, abc=abc, util=util)
##### PathEntryFinder concrete methods #########################################
class PathEntryFinderFindLoaderTests:
@classmethod
def finder(cls, spec):
class PathEntrySpecFinder(cls.abc.PathEntryFinder):
def find_spec(self, fullname, target=None):
self.called_for = fullname
return spec
return PathEntrySpecFinder()
def test_no_spec(self):
finder = self.finder(None)
name = 'blah'
with self.assertWarns(DeprecationWarning):
found = finder.find_loader(name)
self.assertIsNone(found[0])
self.assertEqual([], found[1])
self.assertEqual(name, finder.called_for)
def test_spec_with_loader(self):
loader = object()
spec = self.util.spec_from_loader('blah', loader)
finder = self.finder(spec)
with self.assertWarns(DeprecationWarning):
found = finder.find_loader('blah')
self.assertIs(found[0], spec.loader)
def test_spec_with_portions(self):
spec = self.machinery.ModuleSpec('blah', None)
paths = ['a', 'b', 'c']
spec.submodule_search_locations = paths
finder = self.finder(spec)
with self.assertWarns(DeprecationWarning):
found = finder.find_loader('blah')
self.assertIsNone(found[0])
self.assertEqual(paths, found[1])
(Frozen_PEFFindLoaderTests,
Source_PEFFindLoaderTests
) = test_util.test_both(PathEntryFinderFindLoaderTests, abc=abc, util=util,
machinery=machinery)
##### Loader concrete methods ##################################################
class LoaderLoadModuleTests:
@@ -716,9 +615,6 @@ class SourceOnlyLoader:
def get_filename(self, fullname):
return self.path
def module_repr(self, module):
return '<module>'
SPLIT_SOL = make_abc_subclasses(SourceOnlyLoader, 'SourceLoader')
@@ -803,13 +699,7 @@ class SourceLoaderTestHarness:
class SourceOnlyLoaderTests(SourceLoaderTestHarness):
"""Test importlib.abc.SourceLoader for source-only loading.
Reload testing is subsumed by the tests for
importlib.util.module_for_loader.
"""
"""Test importlib.abc.SourceLoader for source-only loading."""
# TODO: RUSTPYTHON
@unittest.expectedFailure

View File

@@ -6,7 +6,6 @@ machinery = test_util.import_importlib('importlib.machinery')
import os.path
import sys
from test import support
from test.support import import_helper
from test.support import os_helper
import types
@@ -96,7 +95,8 @@ class ImportModuleTests:
(Frozen_ImportModuleTests,
Source_ImportModuleTests
) = test_util.test_both(ImportModuleTests, init=init)
) = test_util.test_both(
ImportModuleTests, init=init, util=util, machinery=machinery)
class FindLoaderTests:
@@ -104,29 +104,26 @@ class FindLoaderTests:
FakeMetaFinder = None
def test_sys_modules(self):
# If a module with __loader__ is in sys.modules, then return it.
# If a module with __spec__.loader is in sys.modules, then return it.
name = 'some_mod'
with test_util.uncache(name):
module = types.ModuleType(name)
loader = 'a loader!'
module.__loader__ = loader
module.__spec__ = self.machinery.ModuleSpec(name, loader)
sys.modules[name] = module
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
found = self.init.find_loader(name)
self.assertEqual(loader, found)
spec = self.util.find_spec(name)
self.assertIsNotNone(spec)
self.assertEqual(spec.loader, loader)
def test_sys_modules_loader_is_None(self):
# If sys.modules[name].__loader__ is None, raise ValueError.
# If sys.modules[name].__spec__.loader is None, raise ValueError.
name = 'some_mod'
with test_util.uncache(name):
module = types.ModuleType(name)
module.__loader__ = None
sys.modules[name] = module
with self.assertRaises(ValueError):
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
self.init.find_loader(name)
self.util.find_spec(name)
def test_sys_modules_loader_is_not_set(self):
# Should raise ValueError
@@ -135,24 +132,20 @@ class FindLoaderTests:
with test_util.uncache(name):
module = types.ModuleType(name)
try:
del module.__loader__
del module.__spec__.loader
except AttributeError:
pass
sys.modules[name] = module
with self.assertRaises(ValueError):
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
self.init.find_loader(name)
self.util.find_spec(name)
def test_success(self):
# Return the loader found on sys.meta_path.
name = 'some_mod'
with test_util.uncache(name):
with test_util.import_state(meta_path=[self.FakeMetaFinder]):
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
warnings.simplefilter('ignore', ImportWarning)
self.assertEqual((name, None), self.init.find_loader(name))
spec = self.util.find_spec(name)
self.assertEqual((name, (name, None)), (spec.name, spec.loader))
def test_success_path(self):
# Searching on a path should work.
@@ -160,17 +153,12 @@ class FindLoaderTests:
path = 'path to some place'
with test_util.uncache(name):
with test_util.import_state(meta_path=[self.FakeMetaFinder]):
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
warnings.simplefilter('ignore', ImportWarning)
self.assertEqual((name, path),
self.init.find_loader(name, path))
spec = self.util.find_spec(name, path)
self.assertEqual(name, spec.name)
def test_nothing(self):
# None is returned upon failure to find a loader.
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
self.assertIsNone(self.init.find_loader('nevergoingtofindthismodule'))
self.assertIsNone(self.util.find_spec('nevergoingtofindthismodule'))
class FindLoaderPEP451Tests(FindLoaderTests):
@@ -183,20 +171,8 @@ class FindLoaderPEP451Tests(FindLoaderTests):
(Frozen_FindLoaderPEP451Tests,
Source_FindLoaderPEP451Tests
) = test_util.test_both(FindLoaderPEP451Tests, init=init)
class FindLoaderPEP302Tests(FindLoaderTests):
class FakeMetaFinder:
@staticmethod
def find_module(name, path=None):
return name, path
(Frozen_FindLoaderPEP302Tests,
Source_FindLoaderPEP302Tests
) = test_util.test_both(FindLoaderPEP302Tests, init=init)
) = test_util.test_both(
FindLoaderPEP451Tests, init=init, util=util, machinery=machinery)
class ReloadTests:
@@ -301,7 +277,8 @@ class ReloadTests:
name = 'spam'
with os_helper.temp_cwd(None) as cwd:
with test_util.uncache('spam'):
with import_helper.DirsOnSysPath(cwd):
with test_util.import_state(path=[cwd]):
self.init._bootstrap_external._install(self.init._bootstrap)
# Start as a namespace package.
self.init.invalidate_caches()
bad_path = os.path.join(cwd, name, '__init.py')
@@ -380,7 +357,8 @@ class ReloadTests:
(Frozen_ReloadTests,
Source_ReloadTests
) = test_util.test_both(ReloadTests, init=init, util=util)
) = test_util.test_both(
ReloadTests, init=init, util=util, machinery=machinery)
class InvalidateCacheTests:
@@ -390,8 +368,6 @@ class InvalidateCacheTests:
class InvalidatingNullFinder:
def __init__(self, *ignored):
self.called = False
def find_module(self, *args):
return None
def invalidate_caches(self):
self.called = True
@@ -416,7 +392,8 @@ class InvalidateCacheTests:
(Frozen_InvalidateCacheTests,
Source_InvalidateCacheTests
) = test_util.test_both(InvalidateCacheTests, init=init)
) = test_util.test_both(
InvalidateCacheTests, init=init, util=util, machinery=machinery)
class FrozenImportlibTests(unittest.TestCase):

View File

@@ -1,46 +0,0 @@
import typing
import unittest
from importlib import resources
from importlib.abc import Traversable
from . import data01
from .resources import util
class FilesTests:
def test_read_bytes(self):
files = resources.files(self.data)
actual = files.joinpath('utf-8.file').read_bytes()
assert actual == b'Hello, UTF-8 world!\n'
def test_read_text(self):
files = resources.files(self.data)
actual = files.joinpath('utf-8.file').read_text(encoding='utf-8')
assert actual == 'Hello, UTF-8 world!\n'
@unittest.skipUnless(
hasattr(typing, 'runtime_checkable'),
"Only suitable when typing supports runtime_checkable",
)
def test_traversable(self):
assert isinstance(resources.files(self.data), Traversable)
class OpenDiskTests(FilesTests, unittest.TestCase):
def setUp(self):
self.data = data01
class OpenZipTests(FilesTests, util.ZipSetup, unittest.TestCase):
pass
class OpenNamespaceTests(FilesTests, unittest.TestCase):
def setUp(self):
from . import namespacedata01
self.data = namespacedata01
if __name__ == '__main__':
unittest.main()

View File

@@ -33,6 +33,11 @@ class ModuleLockAsRLockTests:
test_repr = None
test_locked_repr = None
def tearDown(self):
for splitinit in init.values():
splitinit._bootstrap._blocking_on.clear()
LOCK_TYPES = {kind: splitinit._bootstrap._ModuleLock
for kind, splitinit in init.items()}

View File

@@ -1,9 +1,10 @@
import re
import json
import pickle
import unittest
import warnings
import importlib.metadata
import contextlib
import itertools
try:
import pyfakefs.fake_filesystem_unittest as ffs
@@ -11,6 +12,7 @@ except ImportError:
from .stubs import fake_filesystem_unittest as ffs
from . import fixtures
from ._context import suppress
from importlib.metadata import (
Distribution,
EntryPoint,
@@ -24,6 +26,13 @@ from importlib.metadata import (
)
@contextlib.contextmanager
def suppress_known_deprecation():
with warnings.catch_warnings(record=True) as ctx:
warnings.simplefilter('default', category=DeprecationWarning)
yield ctx
class BasicTests(fixtures.DistInfoPkg, unittest.TestCase):
version_pattern = r'\d+\.\d+(\.\d)?'
@@ -39,7 +48,7 @@ class BasicTests(fixtures.DistInfoPkg, unittest.TestCase):
def test_package_not_found_mentions_metadata(self):
"""
When a package is not found, that could indicate that the
packgae is not installed or that it is installed without
package is not installed or that it is installed without
metadata. Ensure the exception mentions metadata to help
guide users toward the cause. See #124.
"""
@@ -48,15 +57,19 @@ class BasicTests(fixtures.DistInfoPkg, unittest.TestCase):
assert "metadata" in str(ctx.exception)
def test_new_style_classes(self):
self.assertIsInstance(Distribution, type)
# expected to fail until ABC is enforced
@suppress(AssertionError)
@suppress_known_deprecation()
def test_abc_enforced(self):
with self.assertRaises(TypeError):
type('DistributionSubclass', (Distribution,), {})()
@fixtures.parameterize(
dict(name=None),
dict(name=''),
)
def test_invalid_inputs_to_from_name(self, name):
with self.assertRaises(Exception):
with self.assertRaises(ValueError):
Distribution.from_name(name)
@@ -174,11 +187,21 @@ class NonASCIITests(fixtures.OnSysPath, fixtures.SiteDir, unittest.TestCase):
assert meta['Description'] == 'pôrˈtend'
class DiscoveryTests(fixtures.EggInfoPkg, fixtures.DistInfoPkg, unittest.TestCase):
class DiscoveryTests(
fixtures.EggInfoPkg,
fixtures.EggInfoPkgPipInstalledNoToplevel,
fixtures.EggInfoPkgPipInstalledNoModules,
fixtures.EggInfoPkgSourcesFallback,
fixtures.DistInfoPkg,
unittest.TestCase,
):
def test_package_discovery(self):
dists = list(distributions())
assert all(isinstance(dist, Distribution) for dist in dists)
assert any(dist.metadata['Name'] == 'egginfo-pkg' for dist in dists)
assert any(dist.metadata['Name'] == 'egg_with_module-pkg' for dist in dists)
assert any(dist.metadata['Name'] == 'egg_with_no_modules-pkg' for dist in dists)
assert any(dist.metadata['Name'] == 'sources_fallback-pkg' for dist in dists)
assert any(dist.metadata['Name'] == 'distinfo-pkg' for dist in dists)
def test_invalid_usage(self):
@@ -260,14 +283,6 @@ class TestEntryPoints(unittest.TestCase):
"""EntryPoints should be hashable"""
hash(self.ep)
def test_json_dump(self):
"""
json should not expect to be able to dump an EntryPoint
"""
with self.assertRaises(Exception):
with warnings.catch_warnings(record=True):
json.dumps(self.ep)
def test_module(self):
assert self.ep.module == 'value'
@@ -334,3 +349,79 @@ class PackagesDistributionsTest(
prefix=self.site_dir,
)
packages_distributions()
def test_packages_distributions_all_module_types(self):
"""
Test top-level modules detected on a package without 'top-level.txt'.
"""
suffixes = importlib.machinery.all_suffixes()
metadata = dict(
METADATA="""
Name: all_distributions
Version: 1.0.0
""",
)
files = {
'all_distributions-1.0.0.dist-info': metadata,
}
for i, suffix in enumerate(suffixes):
files.update(
{
f'importable-name {i}{suffix}': '',
f'in_namespace_{i}': {
f'mod{suffix}': '',
},
f'in_package_{i}': {
'__init__.py': '',
f'mod{suffix}': '',
},
}
)
metadata.update(RECORD=fixtures.build_record(files))
fixtures.build_files(files, prefix=self.site_dir)
distributions = packages_distributions()
for i in range(len(suffixes)):
assert distributions[f'importable-name {i}'] == ['all_distributions']
assert distributions[f'in_namespace_{i}'] == ['all_distributions']
assert distributions[f'in_package_{i}'] == ['all_distributions']
assert not any(name.endswith('.dist-info') for name in distributions)
class PackagesDistributionsEggTest(
fixtures.EggInfoPkg,
fixtures.EggInfoPkgPipInstalledNoToplevel,
fixtures.EggInfoPkgPipInstalledNoModules,
fixtures.EggInfoPkgSourcesFallback,
unittest.TestCase,
):
def test_packages_distributions_on_eggs(self):
"""
Test old-style egg packages with a variation of 'top_level.txt',
'SOURCES.txt', and 'installed-files.txt', available.
"""
distributions = packages_distributions()
def import_names_from_package(package_name):
return {
import_name
for import_name, package_names in distributions.items()
if package_name in package_names
}
# egginfo-pkg declares one import ('mod') via top_level.txt
assert import_names_from_package('egginfo-pkg') == {'mod'}
# egg_with_module-pkg has one import ('egg_with_module') inferred from
# installed-files.txt (top_level.txt is missing)
assert import_names_from_package('egg_with_module-pkg') == {'egg_with_module'}
# egg_with_no_modules-pkg should not be associated with any import names
# (top_level.txt is empty, and installed-files.txt has no .py files)
assert import_names_from_package('egg_with_no_modules-pkg') == set()
# sources_fallback-pkg has one import ('sources_fallback') inferred from
# SOURCES.txt (top_level.txt and installed-files.txt is missing)
assert import_names_from_package('sources_fallback-pkg') == {'sources_fallback'}

View File

@@ -27,12 +27,14 @@ def suppress_known_deprecation():
class APITests(
fixtures.EggInfoPkg,
fixtures.EggInfoPkgPipInstalledNoToplevel,
fixtures.EggInfoPkgPipInstalledNoModules,
fixtures.EggInfoPkgSourcesFallback,
fixtures.DistInfoPkg,
fixtures.DistInfoPkgWithDot,
fixtures.EggInfoFile,
unittest.TestCase,
):
version_pattern = r'\d+\.\d+(\.\d)?'
def test_retrieves_version_of_self(self):
@@ -63,15 +65,28 @@ class APITests(
distribution(prefix)
def test_for_top_level(self):
self.assertEqual(
distribution('egginfo-pkg').read_text('top_level.txt').strip(), 'mod'
)
tests = [
('egginfo-pkg', 'mod'),
('egg_with_no_modules-pkg', ''),
]
for pkg_name, expect_content in tests:
with self.subTest(pkg_name):
self.assertEqual(
distribution(pkg_name).read_text('top_level.txt').strip(),
expect_content,
)
def test_read_text(self):
top_level = [
path for path in files('egginfo-pkg') if path.name == 'top_level.txt'
][0]
self.assertEqual(top_level.read_text(), 'mod\n')
tests = [
('egginfo-pkg', 'mod\n'),
('egg_with_no_modules-pkg', '\n'),
]
for pkg_name, expect_content in tests:
with self.subTest(pkg_name):
top_level = [
path for path in files(pkg_name) if path.name == 'top_level.txt'
][0]
self.assertEqual(top_level.read_text(), expect_content)
def test_entry_points(self):
eps = entry_points()
@@ -124,62 +139,6 @@ class APITests(
def test_entry_points_missing_group(self):
assert entry_points(group='missing') == ()
def test_entry_points_dict_construction(self):
"""
Prior versions of entry_points() returned simple lists and
allowed casting those lists into maps by name using ``dict()``.
Capture this now deprecated use-case.
"""
with suppress_known_deprecation() as caught:
eps = dict(entry_points(group='entries'))
assert 'main' in eps
assert eps['main'] == entry_points(group='entries')['main']
# check warning
expected = next(iter(caught))
assert expected.category is DeprecationWarning
assert "Construction of dict of EntryPoints is deprecated" in str(expected)
def test_entry_points_by_index(self):
"""
Prior versions of Distribution.entry_points would return a
tuple that allowed access by index.
Capture this now deprecated use-case
See python/importlib_metadata#300 and bpo-44246.
"""
eps = distribution('distinfo-pkg').entry_points
with suppress_known_deprecation() as caught:
eps[0]
# check warning
expected = next(iter(caught))
assert expected.category is DeprecationWarning
assert "Accessing entry points by index is deprecated" in str(expected)
def test_entry_points_groups_getitem(self):
"""
Prior versions of entry_points() returned a dict. Ensure
that callers using '.__getitem__()' are supported but warned to
migrate.
"""
with suppress_known_deprecation():
entry_points()['entries'] == entry_points(group='entries')
with self.assertRaises(KeyError):
entry_points()['missing']
def test_entry_points_groups_get(self):
"""
Prior versions of entry_points() returned a dict. Ensure
that callers using '.get()' are supported but warned to
migrate.
"""
with suppress_known_deprecation():
entry_points().get('missing', 'default') == 'default'
entry_points().get('entries', 'default') == entry_points()['entries']
entry_points().get('missing', ()) == ()
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_entry_points_allows_no_attributes(self):
@@ -195,6 +154,28 @@ class APITests(
classifiers = md.get_all('Classifier')
assert 'Topic :: Software Development :: Libraries' in classifiers
def test_missing_key_legacy(self):
"""
Requesting a missing key will still return None, but warn.
"""
md = metadata('distinfo-pkg')
with suppress_known_deprecation():
assert md['does-not-exist'] is None
def test_get_key(self):
"""
Getting a key gets the key.
"""
md = metadata('egginfo-pkg')
assert md.get('Name') == 'egginfo-pkg'
def test_get_missing_key(self):
"""
Requesting a missing key will return None.
"""
md = metadata('distinfo-pkg')
assert md.get('does-not-exist') is None
@staticmethod
def _test_files(files):
root = files[0].root
@@ -217,6 +198,9 @@ class APITests(
def test_files_egg_info(self):
self._test_files(files('egginfo-pkg'))
self._test_files(files('egg_with_module-pkg'))
self._test_files(files('egg_with_no_modules-pkg'))
self._test_files(files('sources_fallback-pkg'))
def test_version_egg_info_file(self):
self.assertEqual(version('egginfo-file'), '0.1')

View File

@@ -79,12 +79,9 @@ class SingleNamespacePackage(NamespacePackageTest):
with self.assertRaises(ImportError):
import foo.two
def test_module_repr(self):
def test_simple_repr(self):
import foo.one
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.assertEqual(foo.__spec__.loader.module_repr(foo),
"<module 'foo' (namespace)>")
assert repr(foo).startswith("<module 'foo' (namespace) from [")
class DynamicPathNamespacePackage(NamespacePackageTest):

View File

@@ -47,21 +47,6 @@ class NewLoader(TestLoader):
module.eggs = self.EGGS
class LegacyLoader(TestLoader):
HAM = -1
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
frozen_util = util['Frozen']
@frozen_util.module_for_loader
def load_module(self, module):
module.ham = self.HAM
return module
class ModuleSpecTests:
def setUp(self):
@@ -302,26 +287,6 @@ class ModuleSpecMethodsTests:
loaded = self.bootstrap._load(self.spec)
self.assertNotIn(self.spec.name, sys.modules)
def test_load_legacy(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
self.spec.loader = LegacyLoader()
with CleanImport(self.spec.name):
loaded = self.bootstrap._load(self.spec)
self.assertEqual(loaded.ham, -1)
def test_load_legacy_attributes(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
self.spec.loader = LegacyLoader()
with CleanImport(self.spec.name):
loaded = self.bootstrap._load(self.spec)
self.assertIs(loaded.__loader__, self.spec.loader)
self.assertEqual(loaded.__package__, self.spec.parent)
self.assertIs(loaded.__spec__, self.spec)
def test_load_legacy_attributes_immutable(self):
module = object()
with warnings.catch_warnings():
@@ -387,19 +352,6 @@ class ModuleSpecMethodsTests:
self.assertFalse(hasattr(loaded, '__file__'))
self.assertFalse(hasattr(loaded, '__cached__'))
def test_reload_legacy(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
self.spec.loader = LegacyLoader()
with CleanImport(self.spec.name):
loaded = self.bootstrap._load(self.spec)
reloaded = self.bootstrap._exec(self.spec, loaded)
installed = sys.modules[self.spec.name]
self.assertEqual(loaded.ham, -1)
self.assertIs(reloaded, loaded)
self.assertIs(installed, loaded)
(Frozen_ModuleSpecMethodsTests,
Source_ModuleSpecMethodsTests
@@ -407,101 +359,6 @@ class ModuleSpecMethodsTests:
machinery=machinery)
class ModuleReprTests:
@property
def bootstrap(self):
return self.init._bootstrap
def setUp(self):
self.module = type(os)('spam')
self.spec = self.machinery.ModuleSpec('spam', TestLoader())
def test_module___loader___module_repr(self):
class Loader:
def module_repr(self, module):
return '<delicious {}>'.format(module.__name__)
self.module.__loader__ = Loader()
modrepr = self.bootstrap._module_repr(self.module)
self.assertEqual(modrepr, '<delicious spam>')
def test_module___loader___module_repr_bad(self):
class Loader(TestLoader):
def module_repr(self, module):
raise Exception
self.module.__loader__ = Loader()
modrepr = self.bootstrap._module_repr(self.module)
self.assertEqual(modrepr,
'<module {!r} (<TestLoader object>)>'.format('spam'))
def test_module___spec__(self):
origin = 'in a hole, in the ground'
self.spec.origin = origin
self.module.__spec__ = self.spec
modrepr = self.bootstrap._module_repr(self.module)
self.assertEqual(modrepr, '<module {!r} ({})>'.format('spam', origin))
def test_module___spec___location(self):
location = 'in_a_galaxy_far_far_away.py'
self.spec.origin = location
self.spec._set_fileattr = True
self.module.__spec__ = self.spec
modrepr = self.bootstrap._module_repr(self.module)
self.assertEqual(modrepr,
'<module {!r} from {!r}>'.format('spam', location))
def test_module___spec___no_origin(self):
self.spec.loader = TestLoader()
self.module.__spec__ = self.spec
modrepr = self.bootstrap._module_repr(self.module)
self.assertEqual(modrepr,
'<module {!r} (<TestLoader object>)>'.format('spam'))
def test_module___spec___no_origin_no_loader(self):
self.spec.loader = None
self.module.__spec__ = self.spec
modrepr = self.bootstrap._module_repr(self.module)
self.assertEqual(modrepr, '<module {!r}>'.format('spam'))
def test_module_no_name(self):
del self.module.__name__
modrepr = self.bootstrap._module_repr(self.module)
self.assertEqual(modrepr, '<module {!r}>'.format('?'))
def test_module_with_file(self):
filename = 'e/i/e/i/o/spam.py'
self.module.__file__ = filename
modrepr = self.bootstrap._module_repr(self.module)
self.assertEqual(modrepr,
'<module {!r} from {!r}>'.format('spam', filename))
def test_module_no_file(self):
self.module.__loader__ = TestLoader()
modrepr = self.bootstrap._module_repr(self.module)
self.assertEqual(modrepr,
'<module {!r} (<TestLoader object>)>'.format('spam'))
def test_module_no_file_no_loader(self):
modrepr = self.bootstrap._module_repr(self.module)
self.assertEqual(modrepr, '<module {!r}>'.format('spam'))
(Frozen_ModuleReprTests,
Source_ModuleReprTests
) = test_util.test_both(ModuleReprTests, init=init, util=util,
machinery=machinery)
class FactoryTests:
def setUp(self):

View File

@@ -16,7 +16,7 @@ import threading
import unittest
from unittest import mock
from test.support import verbose
from test.support.import_helper import forget
from test.support.import_helper import forget, mock_register_at_fork
from test.support.os_helper import (TESTFN, unlink, rmtree)
from test.support import script_helper, threading_helper
@@ -42,12 +42,6 @@ def task(N, done, done_tasks, errors):
if finished:
done.set()
def mock_register_at_fork(func):
# bpo-30599: Mock os.register_at_fork() when importing the random module,
# since this function doesn't allow to unregister callbacks and would leak
# memory.
return mock.patch('os.register_at_fork', create=True)(func)
# Create a circular import structure: A -> C -> B -> D -> A
# NOTE: `time` is already loaded and therefore doesn't threaten to deadlock.
@@ -251,7 +245,8 @@ class ThreadedImportTests(unittest.TestCase):
self.addCleanup(forget, TESTFN)
self.addCleanup(rmtree, '__pycache__')
importlib.invalidate_caches()
__import__(TESTFN)
with threading_helper.wait_threads_exit():
__import__(TESTFN)
del sys.modules[TESTFN]
@unittest.skip("TODO: RUSTPYTHON; hang")

View File

@@ -8,14 +8,29 @@ importlib_util = util.import_importlib('importlib.util')
import importlib.util
import os
import pathlib
import re
import string
import sys
from test import support
import textwrap
import types
import unittest
import unittest.mock
import warnings
try:
import _testsinglephase
except ImportError:
_testsinglephase = None
try:
import _testmultiphase
except ImportError:
_testmultiphase = None
try:
import _xxsubinterpreters as _interpreters
except ModuleNotFoundError:
_interpreters = None
class DecodeSourceBytesTests:
@@ -127,247 +142,6 @@ class ModuleFromSpecTests:
util=importlib_util)
class ModuleForLoaderTests:
"""Tests for importlib.util.module_for_loader."""
@classmethod
def module_for_loader(cls, func):
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return cls.util.module_for_loader(func)
def test_warning(self):
# Should raise a PendingDeprecationWarning when used.
with warnings.catch_warnings():
warnings.simplefilter('error', DeprecationWarning)
with self.assertRaises(DeprecationWarning):
func = self.util.module_for_loader(lambda x: x)
def return_module(self, name):
fxn = self.module_for_loader(lambda self, module: module)
return fxn(self, name)
def raise_exception(self, name):
def to_wrap(self, module):
raise ImportError
fxn = self.module_for_loader(to_wrap)
try:
fxn(self, name)
except ImportError:
pass
def test_new_module(self):
# Test that when no module exists in sys.modules a new module is
# created.
module_name = 'a.b.c'
with util.uncache(module_name):
module = self.return_module(module_name)
self.assertIn(module_name, sys.modules)
self.assertIsInstance(module, types.ModuleType)
self.assertEqual(module.__name__, module_name)
def test_reload(self):
# Test that a module is reused if already in sys.modules.
class FakeLoader:
def is_package(self, name):
return True
@self.module_for_loader
def load_module(self, module):
return module
name = 'a.b.c'
module = types.ModuleType('a.b.c')
module.__loader__ = 42
module.__package__ = 42
with util.uncache(name):
sys.modules[name] = module
loader = FakeLoader()
returned_module = loader.load_module(name)
self.assertIs(returned_module, sys.modules[name])
self.assertEqual(module.__loader__, loader)
self.assertEqual(module.__package__, name)
def test_new_module_failure(self):
# Test that a module is removed from sys.modules if added but an
# exception is raised.
name = 'a.b.c'
with util.uncache(name):
self.raise_exception(name)
self.assertNotIn(name, sys.modules)
def test_reload_failure(self):
# Test that a failure on reload leaves the module in-place.
name = 'a.b.c'
module = types.ModuleType(name)
with util.uncache(name):
sys.modules[name] = module
self.raise_exception(name)
self.assertIs(module, sys.modules[name])
def test_decorator_attrs(self):
def fxn(self, module): pass
wrapped = self.module_for_loader(fxn)
self.assertEqual(wrapped.__name__, fxn.__name__)
self.assertEqual(wrapped.__qualname__, fxn.__qualname__)
def test_false_module(self):
# If for some odd reason a module is considered false, still return it
# from sys.modules.
class FalseModule(types.ModuleType):
def __bool__(self): return False
name = 'mod'
module = FalseModule(name)
with util.uncache(name):
self.assertFalse(module)
sys.modules[name] = module
given = self.return_module(name)
self.assertIs(given, module)
def test_attributes_set(self):
# __name__, __loader__, and __package__ should be set (when
# is_package() is defined; undefined implicitly tested elsewhere).
class FakeLoader:
def __init__(self, is_package):
self._pkg = is_package
def is_package(self, name):
return self._pkg
@self.module_for_loader
def load_module(self, module):
return module
name = 'pkg.mod'
with util.uncache(name):
loader = FakeLoader(False)
module = loader.load_module(name)
self.assertEqual(module.__name__, name)
self.assertIs(module.__loader__, loader)
self.assertEqual(module.__package__, 'pkg')
name = 'pkg.sub'
with util.uncache(name):
loader = FakeLoader(True)
module = loader.load_module(name)
self.assertEqual(module.__name__, name)
self.assertIs(module.__loader__, loader)
self.assertEqual(module.__package__, name)
(Frozen_ModuleForLoaderTests,
Source_ModuleForLoaderTests
) = util.test_both(ModuleForLoaderTests, util=importlib_util)
class SetPackageTests:
"""Tests for importlib.util.set_package."""
def verify(self, module, expect):
"""Verify the module has the expected value for __package__ after
passing through set_package."""
fxn = lambda: module
wrapped = self.util.set_package(fxn)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
wrapped()
self.assertTrue(hasattr(module, '__package__'))
self.assertEqual(expect, module.__package__)
def test_top_level(self):
# __package__ should be set to the empty string if a top-level module.
# Implicitly tests when package is set to None.
module = types.ModuleType('module')
module.__package__ = None
self.verify(module, '')
def test_package(self):
# Test setting __package__ for a package.
module = types.ModuleType('pkg')
module.__path__ = ['<path>']
module.__package__ = None
self.verify(module, 'pkg')
def test_submodule(self):
# Test __package__ for a module in a package.
module = types.ModuleType('pkg.mod')
module.__package__ = None
self.verify(module, 'pkg')
def test_setting_if_missing(self):
# __package__ should be set if it is missing.
module = types.ModuleType('mod')
if hasattr(module, '__package__'):
delattr(module, '__package__')
self.verify(module, '')
def test_leaving_alone(self):
# If __package__ is set and not None then leave it alone.
for value in (True, False):
module = types.ModuleType('mod')
module.__package__ = value
self.verify(module, value)
def test_decorator_attrs(self):
def fxn(module): pass
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
wrapped = self.util.set_package(fxn)
self.assertEqual(wrapped.__name__, fxn.__name__)
self.assertEqual(wrapped.__qualname__, fxn.__qualname__)
(Frozen_SetPackageTests,
Source_SetPackageTests
) = util.test_both(SetPackageTests, util=importlib_util)
class SetLoaderTests:
"""Tests importlib.util.set_loader()."""
@property
def DummyLoader(self):
# Set DummyLoader on the class lazily.
class DummyLoader:
@self.util.set_loader
def load_module(self, module):
return self.module
self.__class__.DummyLoader = DummyLoader
return DummyLoader
def test_no_attribute(self):
loader = self.DummyLoader()
loader.module = types.ModuleType('blah')
try:
del loader.module.__loader__
except AttributeError:
pass
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
self.assertEqual(loader, loader.load_module('blah').__loader__)
def test_attribute_is_None(self):
loader = self.DummyLoader()
loader.module = types.ModuleType('blah')
loader.module.__loader__ = None
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
self.assertEqual(loader, loader.load_module('blah').__loader__)
def test_not_reset(self):
loader = self.DummyLoader()
loader.module = types.ModuleType('blah')
loader.module.__loader__ = 42
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
self.assertEqual(42, loader.load_module('blah').__loader__)
(Frozen_SetLoaderTests,
Source_SetLoaderTests
) = util.test_both(SetLoaderTests, util=importlib_util)
class ResolveNameTests:
"""Tests importlib.util.resolve_name()."""
@@ -877,7 +651,7 @@ class MagicNumberTests(unittest.TestCase):
# stakeholders such as OS package maintainers must be notified
# in advance. Such exceptional releases will then require an
# adjustment to this test case.
EXPECTED_MAGIC_NUMBER = 3495
EXPECTED_MAGIC_NUMBER = 3531
actual = int.from_bytes(importlib.util.MAGIC_NUMBER[:2], 'little')
msg = (
@@ -895,5 +669,111 @@ class MagicNumberTests(unittest.TestCase):
self.assertEqual(EXPECTED_MAGIC_NUMBER, actual, msg)
@unittest.skipIf(_interpreters is None, 'subinterpreters required')
class IncompatibleExtensionModuleRestrictionsTests(unittest.TestCase):
ERROR = re.compile("^<class 'ImportError'>: module (.*) does not support loading in subinterpreters")
def run_with_own_gil(self, script):
interpid = _interpreters.create(isolated=True)
try:
_interpreters.run_string(interpid, script)
except _interpreters.RunFailedError as exc:
if m := self.ERROR.match(str(exc)):
modname, = m.groups()
raise ImportError(modname)
def run_with_shared_gil(self, script):
interpid = _interpreters.create(isolated=False)
try:
_interpreters.run_string(interpid, script)
except _interpreters.RunFailedError as exc:
if m := self.ERROR.match(str(exc)):
modname, = m.groups()
raise ImportError(modname)
@unittest.skipIf(_testsinglephase is None, "test requires _testsinglephase module")
def test_single_phase_init_module(self):
script = textwrap.dedent('''
from importlib.util import _incompatible_extension_module_restrictions
with _incompatible_extension_module_restrictions(disable_check=True):
import _testsinglephase
''')
with self.subTest('check disabled, shared GIL'):
self.run_with_shared_gil(script)
with self.subTest('check disabled, per-interpreter GIL'):
self.run_with_own_gil(script)
script = textwrap.dedent(f'''
from importlib.util import _incompatible_extension_module_restrictions
with _incompatible_extension_module_restrictions(disable_check=False):
import _testsinglephase
''')
with self.subTest('check enabled, shared GIL'):
with self.assertRaises(ImportError):
self.run_with_shared_gil(script)
with self.subTest('check enabled, per-interpreter GIL'):
with self.assertRaises(ImportError):
self.run_with_own_gil(script)
@unittest.skipIf(_testmultiphase is None, "test requires _testmultiphase module")
def test_incomplete_multi_phase_init_module(self):
prescript = textwrap.dedent(f'''
from importlib.util import spec_from_loader, module_from_spec
from importlib.machinery import ExtensionFileLoader
name = '_test_shared_gil_only'
filename = {_testmultiphase.__file__!r}
loader = ExtensionFileLoader(name, filename)
spec = spec_from_loader(name, loader)
''')
script = prescript + textwrap.dedent('''
from importlib.util import _incompatible_extension_module_restrictions
with _incompatible_extension_module_restrictions(disable_check=True):
module = module_from_spec(spec)
loader.exec_module(module)
''')
with self.subTest('check disabled, shared GIL'):
self.run_with_shared_gil(script)
with self.subTest('check disabled, per-interpreter GIL'):
self.run_with_own_gil(script)
script = prescript + textwrap.dedent('''
from importlib.util import _incompatible_extension_module_restrictions
with _incompatible_extension_module_restrictions(disable_check=False):
module = module_from_spec(spec)
loader.exec_module(module)
''')
with self.subTest('check enabled, shared GIL'):
self.run_with_shared_gil(script)
with self.subTest('check enabled, per-interpreter GIL'):
with self.assertRaises(ImportError):
self.run_with_own_gil(script)
@unittest.skipIf(_testmultiphase is None, "test requires _testmultiphase module")
def test_complete_multi_phase_init_module(self):
script = textwrap.dedent('''
from importlib.util import _incompatible_extension_module_restrictions
with _incompatible_extension_module_restrictions(disable_check=True):
import _testmultiphase
''')
with self.subTest('check disabled, shared GIL'):
self.run_with_shared_gil(script)
with self.subTest('check disabled, per-interpreter GIL'):
self.run_with_own_gil(script)
script = textwrap.dedent(f'''
from importlib.util import _incompatible_extension_module_restrictions
with _incompatible_extension_module_restrictions(disable_check=False):
import _testmultiphase
''')
with self.subTest('check enabled, shared GIL'):
self.run_with_shared_gil(script)
with self.subTest('check enabled, per-interpreter GIL'):
self.run_with_own_gil(script)
if __name__ == '__main__':
unittest.main()

View File

@@ -92,30 +92,16 @@ class WindowsRegistryFinderTests:
def test_find_spec_missing(self):
spec = self.machinery.WindowsRegistryFinder.find_spec('spam')
self.assertIs(spec, None)
def test_find_module_missing(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
loader = self.machinery.WindowsRegistryFinder.find_module('spam')
self.assertIs(loader, None)
self.assertIsNone(spec)
def test_module_found(self):
with setup_module(self.machinery, self.test_module):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
loader = self.machinery.WindowsRegistryFinder.find_module(self.test_module)
spec = self.machinery.WindowsRegistryFinder.find_spec(self.test_module)
self.assertIsNot(loader, None)
self.assertIsNot(spec, None)
self.assertIsNotNone(spec)
def test_module_not_found(self):
with setup_module(self.machinery, self.test_module, path="."):
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
loader = self.machinery.WindowsRegistryFinder.find_module(self.test_module)
spec = self.machinery.WindowsRegistryFinder.find_spec(self.test_module)
self.assertIsNone(loader)
self.assertIsNone(spec)
(Frozen_WindowsRegistryFinderTests,

View File

@@ -27,7 +27,7 @@ EXTENSIONS.path = None
EXTENSIONS.ext = None
EXTENSIONS.filename = None
EXTENSIONS.file_path = None
EXTENSIONS.name = '_testcapi'
EXTENSIONS.name = '_testsinglephase'
def _extension_details():
global EXTENSIONS
@@ -131,9 +131,8 @@ def uncache(*names):
"""
for name in names:
if name in ('sys', 'marshal', 'imp'):
raise ValueError(
"cannot uncache {0}".format(name))
if name in ('sys', 'marshal'):
raise ValueError("cannot uncache {}".format(name))
try:
del sys.modules[name]
except KeyError:
@@ -195,8 +194,7 @@ def import_state(**kwargs):
new_value = default
setattr(sys, attr, new_value)
if len(kwargs):
raise ValueError(
'unrecognized arguments: {0}'.format(kwargs.keys()))
raise ValueError('unrecognized arguments: {}'.format(kwargs))
yield
finally:
for attr, value in originals.items():
@@ -244,30 +242,6 @@ class _ImporterMock:
self._uncache.__exit__(None, None, None)
class mock_modules(_ImporterMock):
"""Importer mock using PEP 302 APIs."""
def find_module(self, fullname, path=None):
if fullname not in self.modules:
return None
else:
return self
def load_module(self, fullname):
if fullname not in self.modules:
raise ImportError
else:
sys.modules[fullname] = self.modules[fullname]
if fullname in self.module_code:
try:
self.module_code[fullname]()
except Exception:
del sys.modules[fullname]
raise
return self.modules[fullname]
class mock_spec(_ImporterMock):
"""Importer mock using PEP 451 APIs."""

Some files were not shown because too many files have changed in this diff Show More