diff --git a/Lib/enum.py b/Lib/enum.py new file mode 100644 index 0000000000..2e868bff13 --- /dev/null +++ b/Lib/enum.py @@ -0,0 +1,868 @@ +import sys +from types import MappingProxyType, DynamicClassAttribute +from functools import reduce +from operator import or_ as _or_, and_ as _and_, xor, neg + +# try _collections first to reduce startup cost +try: + from _collections import OrderedDict +except ImportError: + from collections import OrderedDict + + +__all__ = [ + 'EnumMeta', + 'Enum', 'IntEnum', 'Flag', 'IntFlag', + 'auto', 'unique', + ] + + +def _is_descriptor(obj): + """Returns True if obj is a descriptor, False otherwise.""" + return ( + hasattr(obj, '__get__') or + hasattr(obj, '__set__') or + hasattr(obj, '__delete__')) + + +def _is_dunder(name): + """Returns True if a __dunder__ name, False otherwise.""" + return (name[:2] == name[-2:] == '__' and + name[2:3] != '_' and + name[-3:-2] != '_' and + len(name) > 4) + + +def _is_sunder(name): + """Returns True if a _sunder_ name, False otherwise.""" + return (name[0] == name[-1] == '_' and + name[1:2] != '_' and + name[-2:-1] != '_' and + len(name) > 2) + +def _make_class_unpicklable(cls): + """Make the given class un-picklable.""" + def _break_on_call_reduce(self, proto): + raise TypeError('%r cannot be pickled' % self) + cls.__reduce_ex__ = _break_on_call_reduce + cls.__module__ = '' + +_auto_null = object() +class auto: + """ + Instances are replaced with an appropriate value in Enum class suites. + """ + value = _auto_null + + +class _EnumDict(dict): + """Track enum member order and ensure member names are not reused. + + EnumMeta will use the names found in self._member_names as the + enumeration member names. + + """ + def __init__(self): + super().__init__() + self._member_names = [] + self._last_values = [] + + def __setitem__(self, key, value): + """Changes anything not dundered or not a descriptor. + + If an enum member name is used twice, an error is raised; duplicate + values are not checked for. + + Single underscore (sunder) names are reserved. + + """ + if _is_sunder(key): + if key not in ( + '_order_', '_create_pseudo_member_', + '_generate_next_value_', '_missing_', + ): + raise ValueError('_names_ are reserved for future Enum use') + if key == '_generate_next_value_': + setattr(self, '_generate_next_value', value) + elif _is_dunder(key): + if key == '__order__': + key = '_order_' + elif key in self._member_names: + # descriptor overwriting an enum? + raise TypeError('Attempted to reuse key: %r' % key) + elif not _is_descriptor(value): + if key in self: + # enum overwriting a descriptor? + raise TypeError('%r already defined as: %r' % (key, self[key])) + if isinstance(value, auto): + if value.value == _auto_null: + value.value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:]) + value = value.value + self._member_names.append(key) + self._last_values.append(value) + super().__setitem__(key, value) + + +# Dummy value for Enum as EnumMeta explicitly checks for it, but of course +# until EnumMeta finishes running the first time the Enum class doesn't exist. +# This is also why there are checks in EnumMeta like `if Enum is not None` +Enum = None + + +class EnumMeta(type): + """Metaclass for Enum""" + @classmethod + def __prepare__(metacls, cls, bases): + # create the namespace dict + enum_dict = _EnumDict() + # inherit previous flags and _generate_next_value_ function + member_type, first_enum = metacls._get_mixins_(bases) + if first_enum is not None: + enum_dict['_generate_next_value_'] = getattr(first_enum, '_generate_next_value_', None) + return enum_dict + + def __new__(metacls, cls, bases, classdict): + # an Enum class is final once enumeration items have been defined; it + # cannot be mixed with other types (int, float, etc.) if it has an + # inherited __new__ unless a new __new__ is defined (or the resulting + # class will fail). + member_type, first_enum = metacls._get_mixins_(bases) + __new__, save_new, use_args = metacls._find_new_(classdict, member_type, + first_enum) + + # save enum items into separate mapping so they don't get baked into + # the new class + enum_members = {k: classdict[k] for k in classdict._member_names} + for name in classdict._member_names: + del classdict[name] + + # adjust the sunders + _order_ = classdict.pop('_order_', None) + + # check for illegal enum names (any others?) + invalid_names = set(enum_members) & {'mro', } + if invalid_names: + raise ValueError('Invalid enum member name: {0}'.format( + ','.join(invalid_names))) + + # create a default docstring if one has not been provided + if '__doc__' not in classdict: + classdict['__doc__'] = 'An enumeration.' + + # create our new Enum type + enum_class = super().__new__(metacls, cls, bases, classdict) + enum_class._member_names_ = [] # names in definition order + enum_class._member_map_ = OrderedDict() # name->value map + enum_class._member_type_ = member_type + + # save attributes from super classes so we know if we can take + # the shortcut of storing members in the class dict + base_attributes = {a for b in enum_class.mro() for a in dir(b)} # XXX modified for rustpython + + # Reverse value->name map for hashable values. + enum_class._value2member_map_ = {} + + # If a custom type is mixed into the Enum, and it does not know how + # to pickle itself, pickle.dumps will succeed but pickle.loads will + # fail. Rather than have the error show up later and possibly far + # from the source, sabotage the pickle protocol for this class so + # that pickle.dumps also fails. + # + # However, if the new class implements its own __reduce_ex__, do not + # sabotage -- it's on them to make sure it works correctly. We use + # __reduce_ex__ instead of any of the others as it is preferred by + # pickle over __reduce__, and it handles all pickle protocols. + if '__reduce_ex__' not in classdict: + if member_type is not object: + methods = ('__getnewargs_ex__', '__getnewargs__', + '__reduce_ex__', '__reduce__') + if not any(m in member_type.__dict__ for m in methods): + _make_class_unpicklable(enum_class) + + # instantiate them, checking for duplicates as we go + # we instantiate first instead of checking for duplicates first in case + # a custom __new__ is doing something funky with the values -- such as + # auto-numbering ;) + for member_name in classdict._member_names: + value = enum_members[member_name] + if not isinstance(value, tuple): + args = (value, ) + else: + args = value + if member_type is tuple: # special case for tuple enums + args = (args, ) # wrap it one more time + if not use_args: + enum_member = __new__(enum_class) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = value + else: + enum_member = __new__(enum_class, *args) + if not hasattr(enum_member, '_value_'): + if member_type is object: + enum_member._value_ = value + else: + enum_member._value_ = member_type(*args) + value = enum_member._value_ + enum_member._name_ = member_name + enum_member.__objclass__ = enum_class + enum_member.__init__(*args) + # If another member with the same value was already defined, the + # new member becomes an alias to the existing one. + for name, canonical_member in enum_class._member_map_.items(): + if canonical_member._value_ == enum_member._value_: + enum_member = canonical_member + break + else: + # Aliases don't appear in member names (only in __members__). + enum_class._member_names_.append(member_name) + # performance boost for any member that would not shadow + # a DynamicClassAttribute + if member_name not in base_attributes: + setattr(enum_class, member_name, enum_member) + # now add to _member_map_ + enum_class._member_map_[member_name] = enum_member + try: + # This may fail if value is not hashable. We can't add the value + # to the map, and by-value lookups for this value will be + # linear. + enum_class._value2member_map_[value] = enum_member + except TypeError: + pass + + # double check that repr and friends are not the mixin's or various + # things break (such as pickle) + for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): + class_method = getattr(enum_class, name) + obj_method = getattr(member_type, name, None) + enum_method = getattr(first_enum, name, None) + if obj_method is not None and obj_method is class_method: + setattr(enum_class, name, enum_method) + + # replace any other __new__ with our own (as long as Enum is not None, + # anyway) -- again, this is to support pickle + if Enum is not None: + # if the user defined their own __new__, save it before it gets + # clobbered in case they subclass later + if save_new: + enum_class.__new_member__ = __new__ + enum_class.__new__ = Enum.__new__ + + # py3 support for definition order (helps keep py2/py3 code in sync) + if _order_ is not None: + if isinstance(_order_, str): + _order_ = _order_.replace(',', ' ').split() + if _order_ != enum_class._member_names_: + raise TypeError('member order does not match _order_') + + return enum_class + + def __bool__(self): + """ + classes/types should always be True. + """ + return True + + def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1): + """Either returns an existing member, or creates a new enum class. + + This method is used both when an enum class is given a value to match + to an enumeration member (i.e. Color(3)) and for the functional API + (i.e. Color = Enum('Color', names='RED GREEN BLUE')). + + When used for the functional API: + + `value` will be the name of the new class. + + `names` should be either a string of white-space/comma delimited names + (values will start at `start`), or an iterator/mapping of name, value pairs. + + `module` should be set to the module this class is being created in; + if it is not set, an attempt to find that module will be made, but if + it fails the class will not be picklable. + + `qualname` should be set to the actual location this class can be found + at in its module; by default it is set to the global scope. If this is + not correct, unpickling will fail in some circumstances. + + `type`, if set, will be mixed in as the first base class. + + """ + if names is None: # simple value lookup + return cls.__new__(cls, value) + # otherwise, functional API: we're creating a new Enum type + return cls._create_(value, names, module=module, qualname=qualname, type=type, start=start) + + def __contains__(cls, member): + return isinstance(member, cls) and member._name_ in cls._member_map_ + + def __delattr__(cls, attr): + # nicer error message when someone tries to delete an attribute + # (see issue19025). + if attr in cls._member_map_: + raise AttributeError( + "%s: cannot delete Enum member." % cls.__name__) + super().__delattr__(attr) + + def __dir__(self): + return (['__class__', '__doc__', '__members__', '__module__'] + + self._member_names_) + + def __getattr__(cls, name): + """Return the enum member matching `name` + + We use __getattr__ instead of descriptors or inserting into the enum + class' __dict__ in order to support `name` and `value` being both + properties for enum members (which live in the class' __dict__) and + enum members themselves. + + """ + if _is_dunder(name): + raise AttributeError(name) + try: + return cls._member_map_[name] + except KeyError: + raise AttributeError(name) from None + + def __getitem__(cls, name): + return cls._member_map_[name] + + def __iter__(cls): + return (cls._member_map_[name] for name in cls._member_names_) + + def __len__(cls): + return len(cls._member_names_) + + @property + def __members__(cls): + """Returns a mapping of member name->value. + + This mapping lists all enum members, including aliases. Note that this + is a read-only view of the internal mapping. + + """ + return MappingProxyType(cls._member_map_) + + def __repr__(cls): + return "" % cls.__name__ + + def __reversed__(cls): + return (cls._member_map_[name] for name in reversed(cls._member_names_)) + + def __setattr__(cls, name, value): + """Block attempts to reassign Enum members. + + A simple assignment to the class namespace only changes one of the + several possible ways to get an Enum member from the Enum class, + resulting in an inconsistent Enumeration. + + """ + member_map = cls.__dict__.get('_member_map_', {}) + if name in member_map: + raise AttributeError('Cannot reassign members.') + super().__setattr__(name, value) + + def _create_(cls, class_name, names=None, *, module=None, qualname=None, type=None, start=1): + """Convenience method to create a new Enum class. + + `names` can be: + + * A string containing member names, separated either with spaces or + commas. Values are incremented by 1 from `start`. + * An iterable of member names. Values are incremented by 1 from `start`. + * An iterable of (member name, value) pairs. + * A mapping of member name -> value pairs. + + """ + metacls = cls.__class__ + bases = (cls, ) if type is None else (type, cls) + _, first_enum = cls._get_mixins_(bases) + classdict = metacls.__prepare__(class_name, bases) + + # special processing needed for names? + if isinstance(names, str): + names = names.replace(',', ' ').split() + if isinstance(names, (tuple, list)) and isinstance(names[0], str): + original_names, names = names, [] + last_values = [] + for count, name in enumerate(original_names): + value = first_enum._generate_next_value_(name, start, count, last_values[:]) + last_values.append(value) + names.append((name, value)) + + # Here, names is either an iterable of (name, value) or a mapping. + for item in names: + if isinstance(item, str): + member_name, member_value = item, names[item] + else: + member_name, member_value = item + classdict[member_name] = member_value + enum_class = metacls.__new__(metacls, class_name, bases, classdict) + + # TODO: replace the frame hack if a blessed way to know the calling + # module is ever developed + if module is None: + try: + module = sys._getframe(2).f_globals['__name__'] + except (AttributeError, ValueError) as exc: + pass + if module is None: + _make_class_unpicklable(enum_class) + else: + enum_class.__module__ = module + if qualname is not None: + enum_class.__qualname__ = qualname + + return enum_class + + @staticmethod + def _get_mixins_(bases): + """Returns the type for creating enum members, and the first inherited + enum class. + + bases: the tuple of bases that was given to __new__ + + """ + if not bases: + return object, Enum + + # double check that we are not subclassing a class with existing + # enumeration members; while we're at it, see if any other data + # type has been mixed in so we can use the correct __new__ + member_type = first_enum = None + for base in bases: + if (base is not Enum and + issubclass(base, Enum) and + base._member_names_): + raise TypeError("Cannot extend enumerations") + # base is now the last base in bases + if not issubclass(base, Enum): + raise TypeError("new enumerations must be created as " + "`ClassName([mixin_type,] enum_type)`") + + # get correct mix-in type (either mix-in type of Enum subclass, or + # first base if last base is Enum) + if not issubclass(bases[0], Enum): + member_type = bases[0] # first data type + first_enum = bases[-1] # enum type + else: + for base in bases[0].__mro__: + # most common: (IntEnum, int, Enum, object) + # possible: (, , + # , , + # ) + if issubclass(base, Enum): + if first_enum is None: + first_enum = base + else: + if member_type is None: + member_type = base + + return member_type, first_enum + + @staticmethod + def _find_new_(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. + + classdict: the class dictionary given to __new__ + member_type: the data type whose __new__ will be used by default + first_enum: enumeration to check for an overriding __new__ + + """ + # now find the correct __new__, checking to see of one was defined + # by the user; also check earlier enum classes in case a __new__ was + # saved as __new_member__ + __new__ = classdict.get('__new__', None) + + # should __new__ be saved as __new_member__ later? + save_new = __new__ is not None + + if __new__ is None: + # check all possibles for __new_member__ before falling back to + # __new__ + for method in ('__new_member__', '__new__'): + for possible in (member_type, first_enum): + target = getattr(possible, method, None) + if target not in { + None, + None.__new__, + object.__new__, + Enum.__new__, + }: + __new__ = target + break + if __new__ is not None: + break + else: + __new__ = object.__new__ + + # if a non-object.__new__ is used then whatever value/tuple was + # assigned to the enum member name will be passed to __new__ and to the + # new enum member's __init__ + if __new__ is object.__new__: + use_args = False + else: + use_args = True + + return __new__, save_new, use_args + + +class Enum(metaclass=EnumMeta): + """Generic enumeration. + + Derive from this class to define new enumerations. + + """ + def __new__(cls, value): + # all enum instances are actually created during class construction + # without calling this method; this method is called by the metaclass' + # __call__ (i.e. Color(3) ), and by pickle + if type(value) is cls: + # For lookups like Color(Color.RED) + return value + # by-value search for a matching enum member + # see if it's in the reverse mapping (for hashable values) + try: + if value in cls._value2member_map_: + return cls._value2member_map_[value] + except TypeError: + # not there, now do long search -- O(n) behavior + for member in cls._member_map_.values(): + if member._value_ == value: + return member + # still not found -- try _missing_ hook + return cls._missing_(value) + + def _generate_next_value_(name, start, count, last_values): + for last_value in reversed(last_values): + try: + return last_value + 1 + except TypeError: + pass + else: + return start + + @classmethod + def _missing_(cls, value): + raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + + def __repr__(self): + return "<%s.%s: %r>" % ( + self.__class__.__name__, self._name_, self._value_) + + def __str__(self): + return "%s.%s" % (self.__class__.__name__, self._name_) + + def __dir__(self): + added_behavior = [ + m + for cls in self.__class__.mro() + for m in cls.__dict__ + if m[0] != '_' and m not in self._member_map_ + ] + return (['__class__', '__doc__', '__module__'] + added_behavior) + + def __format__(self, format_spec): + # mixed-in Enums should use the mixed-in type's __format__, otherwise + # we can get strange results with the Enum name showing up instead of + # the value + + # pure Enum branch + if self._member_type_ is object: + cls = str + val = str(self) + # mix-in branch + else: + cls = self._member_type_ + val = self._value_ + return cls.__format__(val, format_spec) + + def __hash__(self): + return hash(self._name_) + + def __reduce_ex__(self, proto): + return self.__class__, (self._value_, ) + + # DynamicClassAttribute is used to provide access to the `name` and + # `value` properties of enum members while keeping some measure of + # protection from modification, while still allowing for an enumeration + # to have members named `name` and `value`. This works because enumeration + # members are not set directly on the enum class -- __getattr__ is + # used to look them up. + + @DynamicClassAttribute + def name(self): + """The name of the Enum member.""" + return self._name_ + + @DynamicClassAttribute + def value(self): + """The value of the Enum member.""" + return self._value_ + + @classmethod + def _convert(cls, name, module, filter, source=None): + """ + Create a new Enum subclass that replaces a collection of global constants + """ + # convert all constants from source (or module) that pass filter() to + # a new Enum called name, and export the enum and its members back to + # module; + # also, replace the __reduce_ex__ method so unpickling works in + # previous Python versions + module_globals = vars(sys.modules[module]) + if source: + source = vars(source) + else: + source = module_globals + # We use an OrderedDict of sorted source keys so that the + # _value2member_map is populated in the same order every time + # for a consistent reverse mapping of number to name when there + # are multiple names for the same number rather than varying + # between runs due to hash randomization of the module dictionary. + members = [ + (name, source[name]) + for name in source.keys() + if filter(name)] + try: + # sort by value + members.sort(key=lambda t: (t[1], t[0])) + except TypeError: + # unless some values aren't comparable, in which case sort by name + members.sort(key=lambda t: t[0]) + cls = cls(name, members, module=module) + cls.__reduce_ex__ = _reduce_ex_by_name + module_globals.update(cls.__members__) + module_globals[name] = cls + return cls + + +class IntEnum(int, Enum): + """Enum where members are also (and must be) ints""" + + +def _reduce_ex_by_name(self, proto): + return self.name + +class Flag(Enum): + """Support for flags""" + + def _generate_next_value_(name, start, count, last_values): + """ + Generate the next value when not given. + + name: the name of the member + start: the initital start value or None + count: the number of existing members + last_value: the last value assigned or None + """ + if not count: + return start if start is not None else 1 + for last_value in reversed(last_values): + try: + high_bit = _high_bit(last_value) + break + except Exception: + raise TypeError('Invalid Flag value: %r' % last_value) from None + return 2 ** (high_bit+1) + + @classmethod + def _missing_(cls, value): + original_value = value + if value < 0: + value = ~value + possible_member = cls._create_pseudo_member_(value) + if original_value < 0: + possible_member = ~possible_member + return possible_member + + @classmethod + def _create_pseudo_member_(cls, value): + """ + Create a composite member iff value contains only members. + """ + pseudo_member = cls._value2member_map_.get(value, None) + if pseudo_member is None: + # verify all bits are accounted for + _, extra_flags = _decompose(cls, value) + if extra_flags: + raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + # construct a singleton enum pseudo-member + pseudo_member = object.__new__(cls) + pseudo_member._name_ = None + pseudo_member._value_ = value + cls._value2member_map_[value] = pseudo_member + return pseudo_member + + def __contains__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return other._value_ & self._value_ == other._value_ + + def __repr__(self): + cls = self.__class__ + if self._name_ is not None: + return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_) + members, uncovered = _decompose(cls, self._value_) + return '<%s.%s: %r>' % ( + cls.__name__, + '|'.join([str(m._name_ or m._value_) for m in members]), + self._value_, + ) + + def __str__(self): + cls = self.__class__ + if self._name_ is not None: + return '%s.%s' % (cls.__name__, self._name_) + members, uncovered = _decompose(cls, self._value_) + if len(members) == 1 and members[0]._name_ is None: + return '%s.%r' % (cls.__name__, members[0]._value_) + else: + return '%s.%s' % ( + cls.__name__, + '|'.join([str(m._name_ or m._value_) for m in members]), + ) + + def __bool__(self): + return bool(self._value_) + + def __or__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.__class__(self._value_ | other._value_) + + def __and__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.__class__(self._value_ & other._value_) + + def __xor__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.__class__(self._value_ ^ other._value_) + + def __invert__(self): + members, uncovered = _decompose(self.__class__, self._value_) + inverted_members = [ + m for m in self.__class__ + if m not in members and not m._value_ & self._value_ + ] + inverted = reduce(_or_, inverted_members, self.__class__(0)) + return self.__class__(inverted) + + +class IntFlag(int, Flag): + """Support for integer-based Flags""" + + @classmethod + def _missing_(cls, value): + if not isinstance(value, int): + raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + new_member = cls._create_pseudo_member_(value) + return new_member + + @classmethod + def _create_pseudo_member_(cls, value): + pseudo_member = cls._value2member_map_.get(value, None) + if pseudo_member is None: + need_to_create = [value] + # get unaccounted for bits + _, extra_flags = _decompose(cls, value) + # timer = 10 + while extra_flags: + # timer -= 1 + bit = _high_bit(extra_flags) + flag_value = 2 ** bit + if (flag_value not in cls._value2member_map_ and + flag_value not in need_to_create + ): + need_to_create.append(flag_value) + if extra_flags == -flag_value: + extra_flags = 0 + else: + extra_flags ^= flag_value + for value in reversed(need_to_create): + # construct singleton pseudo-members + pseudo_member = int.__new__(cls, value) + pseudo_member._name_ = None + pseudo_member._value_ = value + cls._value2member_map_[value] = pseudo_member + return pseudo_member + + def __or__(self, other): + if not isinstance(other, (self.__class__, int)): + return NotImplemented + result = self.__class__(self._value_ | self.__class__(other)._value_) + return result + + def __and__(self, other): + if not isinstance(other, (self.__class__, int)): + return NotImplemented + return self.__class__(self._value_ & self.__class__(other)._value_) + + def __xor__(self, other): + if not isinstance(other, (self.__class__, int)): + return NotImplemented + return self.__class__(self._value_ ^ self.__class__(other)._value_) + + __ror__ = __or__ + __rand__ = __and__ + __rxor__ = __xor__ + + def __invert__(self): + result = self.__class__(~self._value_) + return result + + +def _high_bit(value): + """returns index of highest bit, or -1 if value is zero or negative""" + return value.bit_length() - 1 + +def unique(enumeration): + """Class decorator for enumerations ensuring unique member values.""" + duplicates = [] + for name, member in enumeration.__members__.items(): + if name != member.name: + duplicates.append((name, member.name)) + if duplicates: + alias_details = ', '.join( + ["%s -> %s" % (alias, name) for (alias, name) in duplicates]) + raise ValueError('duplicate values found in %r: %s' % + (enumeration, alias_details)) + return enumeration + +def _decompose(flag, value): + """Extract all members from the value.""" + # _decompose is only called if the value is not named + not_covered = value + negative = value < 0 + if negative: + # only check for named flags + flags_to_check = [ + (m, v) + for v, m in flag._value2member_map_.items() + if m.name is not None + ] + else: + # check for named flags and powers-of-two flags + flags_to_check = [ + (m, v) + for v, m in flag._value2member_map_.items() + if m.name is not None or _power_of_two(v) + ] + members = [] + for member, member_value in flags_to_check: + if member_value and member_value & value == member_value: + members.append(member) + not_covered &= ~member_value + if not members and value in flag._value2member_map_: + members.append(flag._value2member_map_[value]) + members.sort(key=lambda m: m._value_, reverse=True) + if len(members) > 1 and members[0].value == value: + # we have the breakdown, don't need the value member itself + members.pop(0) + return members, not_covered + +def _power_of_two(value): + if value < 1: + return False + return value == 2 ** _high_bit(value) diff --git a/derive/src/pyclass.rs b/derive/src/pyclass.rs index f4f0ec2555..cad700fdb9 100644 --- a/derive/src/pyclass.rs +++ b/derive/src/pyclass.rs @@ -20,6 +20,10 @@ enum ClassItem { py_name: String, setter: bool, }, + Slot { + slot_ident: Ident, + item_ident: Ident, + }, } fn meta_to_vec(meta: Meta) -> Result, Meta> { @@ -192,6 +196,28 @@ impl ClassItem { setter, }); attr_idx = Some(i); + } else if name == "pyslot" { + if item.is_some() { + bail_span!( + sig.ident, + "You can only have one #[py*] attribute on an impl item" + ) + } + let pyslot_err = "#[pyslot] must be of the form #[pyslot(slotname)]"; + let nesteds = + meta_to_vec(meta).map_err(|meta| err_span!(meta, "{}", pyslot_err))?; + if nesteds.len() != 1 { + return Err(Diagnostic::spanned_error("e!(#(#nesteds)*), pyslot_err)); + } + let slot_ident = match nesteds.into_iter().next().unwrap() { + NestedMeta::Meta(Meta::Word(ident)) => ident, + bad => bail_span!(bad, "{}", pyslot_err), + }; + item = Some(ClassItem::Slot { + slot_ident, + item_ident: sig.ident.clone(), + }); + attr_idx = Some(i); } } if let Some(attr_idx) = attr_idx { @@ -257,6 +283,14 @@ pub fn impl_pyimpl(_attr: AttributeArgs, item: Item) -> Result Some(quote! { class.set_str_attr(#py_name, ctx.new_classmethod(Self::#item_ident)); }), + ClassItem::Slot { + slot_ident, + item_ident, + } => Some(quote! { + class.slots.borrow_mut().#slot_ident = Some( + ::rustpython_vm::function::IntoPyNativeFunc::into_func(Self::#item_ident) + ); + }), _ => None, }); let properties = properties diff --git a/tests/snippets/mappingproxy.py b/tests/snippets/mappingproxy.py index fa3b17beeb..cfba56a8df 100644 --- a/tests/snippets/mappingproxy.py +++ b/tests/snippets/mappingproxy.py @@ -18,3 +18,7 @@ assert 'b' in A.__dict__ assert 'c' not in A.__dict__ assert '__dict__' in A.__dict__ + +assert A.__dict__.get("not here", "default") == "default" +assert A.__dict__.get("a", "default") is A.a +assert A.__dict__.get("not here") is None diff --git a/vm/src/macros.rs b/vm/src/macros.rs index 1e3ae1d98c..a1e35368f6 100644 --- a/vm/src/macros.rs +++ b/vm/src/macros.rs @@ -137,12 +137,10 @@ macro_rules! extend_module { #[macro_export] macro_rules! py_class { - ( $ctx:expr, $class_name:expr, $class_base:expr, { $($name:expr => $value:expr),* $(,)* }) => { + ( $ctx:expr, $class_name:expr, $class_base:expr, { $($name:tt => $value:expr),* $(,)* }) => { { let py_class = $ctx.new_class($class_name, $class_base); - $( - py_class.set_str_attr($name, $value); - )* + $crate::extend_class!($ctx, &py_class, { $($name => $value),* }); py_class } } @@ -150,12 +148,20 @@ macro_rules! py_class { #[macro_export] macro_rules! extend_class { - ( $ctx:expr, $class:expr, { $($name:expr => $value:expr),* $(,)* }) => { - let class = $class; + ( $ctx:expr, $class:expr, { $($name:tt => $value:expr),* $(,)* }) => { $( - class.set_str_attr($name, $value); + $crate::extend_class!(@set_attr($ctx, $class, $name, $value)); )* - } + }; + + (@set_attr($ctx:expr, $class:expr, (slot $slot_name:ident), $value:expr)) => { + $class.slots.borrow_mut().$slot_name = Some( + $crate::function::IntoPyNativeFunc::into_func($value) + ); + }; + (@set_attr($ctx:expr, $class:expr, $name:expr, $value:expr)) => { + $class.set_str_attr($name, $value); + }; } #[macro_export] diff --git a/vm/src/obj/objbool.rs b/vm/src/obj/objbool.rs index 039b4688d9..bdf2306e9e 100644 --- a/vm/src/obj/objbool.rs +++ b/vm/src/obj/objbool.rs @@ -81,7 +81,7 @@ The class bool is a subclass of the class int, and cannot be subclassed."; let bool_type = &context.types.bool_type; extend_class!(context, bool_type, { - "__new__" => context.new_rustfunc(bool_new), + (slot new) => bool_new, "__repr__" => context.new_rustfunc(bool_repr), "__format__" => context.new_rustfunc(bool_format), "__or__" => context.new_rustfunc(bool_or), diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index c3d722ea74..68bfbd3f76 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -90,8 +90,8 @@ pub fn init(context: &PyContext) { #[pyimpl] impl PyByteArrayRef { - #[pymethod(name = "__new__")] - fn bytearray_new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, options: ByteInnerNewOptions, vm: &VirtualMachine, diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 585658944b..0e038290ca 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -97,8 +97,8 @@ pub fn init(context: &PyContext) { #[pyimpl] impl PyBytesRef { - #[pymethod(name = "__new__")] - fn bytes_new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, options: ByteInnerNewOptions, vm: &VirtualMachine, diff --git a/vm/src/obj/objclassmethod.rs b/vm/src/obj/objclassmethod.rs index 38ef47d9e5..e1fa5d266b 100644 --- a/vm/src/obj/objclassmethod.rs +++ b/vm/src/obj/objclassmethod.rs @@ -39,8 +39,8 @@ impl PyValue for PyClassMethod { #[pyimpl] impl PyClassMethod { - #[pymethod(name = "__new__")] - fn new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, callable: PyObjectRef, vm: &VirtualMachine, diff --git a/vm/src/obj/objcode.rs b/vm/src/obj/objcode.rs index a1671d0efd..40fb3e8d1c 100644 --- a/vm/src/obj/objcode.rs +++ b/vm/src/obj/objcode.rs @@ -82,7 +82,7 @@ impl PyCodeRef { pub fn init(context: &PyContext) { extend_class!(context, &context.types.code_type, { - "__new__" => context.new_rustfunc(PyCodeRef::new), + (slot new) => PyCodeRef::new, "__repr__" => context.new_rustfunc(PyCodeRef::repr), "co_argcount" => context.new_property(PyCodeRef::co_argcount), diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index 029b8805d3..6dff00ce75 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -230,8 +230,8 @@ impl PyComplex { !Complex64::is_zero(&self.value) } - #[pymethod(name = "__new__")] - fn complex_new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, real: OptionalArg, imag: OptionalArg, diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index 4576d71eda..33b9bb5ca2 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -577,7 +577,7 @@ pub fn init(context: &PyContext) { "__eq__" => context.new_rustfunc(PyDictRef::eq), "__getitem__" => context.new_rustfunc(PyDictRef::inner_getitem), "__iter__" => context.new_rustfunc(PyDictRef::iter), - "__new__" => context.new_rustfunc(PyDictRef::new), + (slot new) => PyDictRef::new, "__repr__" => context.new_rustfunc(PyDictRef::repr), "__setitem__" => context.new_rustfunc(PyDictRef::inner_setitem), "__hash__" => context.new_rustfunc(PyDictRef::hash), diff --git a/vm/src/obj/objellipsis.rs b/vm/src/obj/objellipsis.rs index 3b2b6d8bfe..cb9af2f58e 100644 --- a/vm/src/obj/objellipsis.rs +++ b/vm/src/obj/objellipsis.rs @@ -4,7 +4,7 @@ use crate::vm::VirtualMachine; pub fn init(context: &PyContext) { extend_class!(context, &context.ellipsis_type, { - "__new__" => context.new_rustfunc(ellipsis_new), + (slot new) => ellipsis_new, "__repr__" => context.new_rustfunc(ellipsis_repr), "__reduce__" => context.new_rustfunc(ellipsis_reduce), }); diff --git a/vm/src/obj/objenumerate.rs b/vm/src/obj/objenumerate.rs index 0ce34c689f..a1903aea49 100644 --- a/vm/src/obj/objenumerate.rs +++ b/vm/src/obj/objenumerate.rs @@ -70,6 +70,6 @@ impl PyEnumerate { pub fn init(context: &PyContext) { PyEnumerate::extend_class(context, &context.types.enumerate_type); extend_class!(context, &context.types.enumerate_type, { - "__new__" => context.new_rustfunc(enumerate_new), + (slot new) => enumerate_new, }); } diff --git a/vm/src/obj/objfilter.rs b/vm/src/obj/objfilter.rs index d046fda2cc..1e2222517e 100644 --- a/vm/src/obj/objfilter.rs +++ b/vm/src/obj/objfilter.rs @@ -69,6 +69,6 @@ impl PyFilter { pub fn init(context: &PyContext) { PyFilter::extend_class(context, &context.types.filter_type); extend_class!(context, &context.types.filter_type, { - "__new__" => context.new_rustfunc(filter_new), + (slot new) => filter_new, }); } diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index 8fdd7af890..16b7848471 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -157,8 +157,8 @@ fn inner_gt_int(value: f64, other_int: &BigInt) -> bool { #[pyimpl] #[allow(clippy::trivially_copy_pass_by_ref)] impl PyFloat { - #[pymethod(name = "__new__")] - fn float_new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, arg: OptionalArg, vm: &VirtualMachine, diff --git a/vm/src/obj/objframe.rs b/vm/src/obj/objframe.rs index efc833270f..2b2d2e2bdb 100644 --- a/vm/src/obj/objframe.rs +++ b/vm/src/obj/objframe.rs @@ -10,7 +10,7 @@ use crate::vm::VirtualMachine; pub fn init(context: &PyContext) { extend_class!(context, &context.types.frame_type, { - "__new__" => context.new_rustfunc(FrameRef::new), + (slot new) => FrameRef::new, "__repr__" => context.new_rustfunc(FrameRef::repr), "f_locals" => context.new_property(FrameRef::flocals), "f_globals" => context.new_property(FrameRef::f_globals), diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 0fd01b0784..45c262122f 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -88,7 +88,8 @@ macro_rules! impl_try_from_object_int { ($(($t:ty, $to_prim:ident),)*) => {$( impl TryFromObject for $t { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - match PyRef::::try_from_object(vm, obj)?.value.$to_prim() { + let int = PyIntRef::try_from_object(vm, obj)?; + match int.value.$to_prim() { Some(value) => Ok(value), None => Err( vm.new_overflow_error(concat!( @@ -923,6 +924,6 @@ fn get_py_int(obj: &PyObjectRef) -> &PyInt { pub fn init(context: &PyContext) { PyInt::extend_class(context, &context.types.int_type); extend_class!(context, &context.types.int_type, { - "__new__" => context.new_rustfunc(int_new), + (slot new) => int_new, }); } diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index caf174e463..0afb44e393 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -881,7 +881,7 @@ pub fn init(context: &PyContext) { "__rmul__" => context.new_rustfunc(PyListRef::rmul), "__imul__" => context.new_rustfunc(PyListRef::imul), "__len__" => context.new_rustfunc(PyListRef::len), - "__new__" => context.new_rustfunc(list_new), + (slot new) => list_new, "__repr__" => context.new_rustfunc(PyListRef::repr), "__hash__" => context.new_rustfunc(PyListRef::hash), "__doc__" => context.new_str(list_doc.to_string()), diff --git a/vm/src/obj/objmap.rs b/vm/src/obj/objmap.rs index a27b56e8b6..655b561ecb 100644 --- a/vm/src/obj/objmap.rs +++ b/vm/src/obj/objmap.rs @@ -63,6 +63,6 @@ impl PyMap { pub fn init(context: &PyContext) { PyMap::extend_class(context, &context.types.map_type); extend_class!(context, &context.types.map_type, { - "__new__" => context.new_rustfunc(map_new), + (slot new) => map_new, }); } diff --git a/vm/src/obj/objmappingproxy.rs b/vm/src/obj/objmappingproxy.rs index b38e9d1ef6..692fde2b37 100644 --- a/vm/src/obj/objmappingproxy.rs +++ b/vm/src/obj/objmappingproxy.rs @@ -1,6 +1,7 @@ use super::objstr::PyStringRef; use super::objtype::{self, PyClassRef}; -use crate::pyobject::{PyClassImpl, PyContext, PyRef, PyResult, PyValue}; +use crate::function::OptionalArg; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; #[pyclass] @@ -23,6 +24,14 @@ impl PyMappingProxy { PyMappingProxy { class } } + #[pymethod] + fn get(&self, key: PyStringRef, default: OptionalArg, vm: &VirtualMachine) -> PyObjectRef { + let default = default.into_option(); + objtype::class_get_attr(&self.class, key.as_str()) + .or(default) + .unwrap_or_else(|| vm.get_none()) + } + #[pymethod(name = "__getitem__")] pub fn getitem(&self, key: PyStringRef, vm: &VirtualMachine) -> PyResult { if let Some(value) = objtype::class_get_attr(&self.class, key.as_str()) { diff --git a/vm/src/obj/objmemory.rs b/vm/src/obj/objmemory.rs index ed4687032b..bb2ff9fe2c 100644 --- a/vm/src/obj/objmemory.rs +++ b/vm/src/obj/objmemory.rs @@ -17,8 +17,8 @@ impl PyMemoryView { try_as_byte(&self.obj_ref) } - #[pymethod(name = "__new__")] - fn new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, bytes_object: PyObjectRef, vm: &VirtualMachine, diff --git a/vm/src/obj/objmodule.rs b/vm/src/obj/objmodule.rs index 7fc6c13687..645fd22fd7 100644 --- a/vm/src/obj/objmodule.rs +++ b/vm/src/obj/objmodule.rs @@ -82,7 +82,7 @@ impl PyModuleRef { pub fn init(context: &PyContext) { extend_class!(&context, &context.types.module_type, { - "__new__" => context.new_rustfunc(PyModuleRef::new), + (slot new) => PyModuleRef::new, "__getattribute__" => context.new_rustfunc(PyModuleRef::getattribute), "__repr__" => context.new_rustfunc(PyModuleRef::repr), }); diff --git a/vm/src/obj/objnone.rs b/vm/src/obj/objnone.rs index fab016b63a..1a64123270 100644 --- a/vm/src/obj/objnone.rs +++ b/vm/src/obj/objnone.rs @@ -7,7 +7,7 @@ use crate::pyobject::{ }; use crate::vm::VirtualMachine; -#[pyclass(name = "none")] +#[pyclass(name = "NoneType")] #[derive(Debug)] pub struct PyNone; pub type PyNoneRef = PyRef; @@ -37,8 +37,8 @@ impl IntoPyObject for Option { #[pyimpl] impl PyNoneRef { - #[pymethod(name = "__new__")] - fn none_new(_: PyClassRef, vm: &VirtualMachine) -> PyNoneRef { + #[pyslot(new)] + fn tp_new(_: PyClassRef, vm: &VirtualMachine) -> PyNoneRef { vm.ctx.none.clone() } diff --git a/vm/src/obj/objobject.rs b/vm/src/obj/objobject.rs index dbd5498c42..eaebf8b8bf 100644 --- a/vm/src/obj/objobject.rs +++ b/vm/src/obj/objobject.rs @@ -157,7 +157,9 @@ pub fn init(context: &PyContext) { let object_doc = "The most base type"; extend_class!(context, object, { - "__new__" => context.new_rustfunc(new_instance), + (slot new) => new_instance, + // yeah, it's `type_new`, but we're putting here so it's available on every object + "__new__" => context.new_classmethod(objtype::type_new), "__init__" => context.new_rustfunc(object_init), "__class__" => PropertyBuilder::new(context) diff --git a/vm/src/obj/objproperty.rs b/vm/src/obj/objproperty.rs index 52091f29f6..776f06af97 100644 --- a/vm/src/obj/objproperty.rs +++ b/vm/src/obj/objproperty.rs @@ -105,12 +105,8 @@ struct PropertyArgs { #[pyimpl] impl PyProperty { - #[pymethod(name = "__new__")] - fn new_property( - cls: PyClassRef, - args: PropertyArgs, - vm: &VirtualMachine, - ) -> PyResult { + #[pyslot(new)] + fn tp_new(cls: PyClassRef, args: PropertyArgs, vm: &VirtualMachine) -> PyResult { PyProperty { getter: args.fget, setter: args.fset, diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index 32f8a48450..3b2c6e68e1 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -133,7 +133,6 @@ type PyRangeRef = PyRef; #[pyimpl] impl PyRange { - #[pymethod(name = "__new__")] fn new(cls: PyClassRef, stop: PyIntRef, vm: &VirtualMachine) -> PyResult { PyRange { start: PyInt::new(BigInt::zero()).into_ref(vm), @@ -403,8 +402,8 @@ impl PyRange { } } - #[pymethod(name = "__new__")] - fn range_new(args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { + #[pyslot(new)] + fn tp_new(args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { let range = if args.args.len() <= 2 { let (cls, stop) = args.bind(vm)?; PyRange::new(cls, stop, vm) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index f7cbc85013..4f7cba0550 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -324,8 +324,8 @@ macro_rules! try_set_inner { #[pyimpl] impl PySet { - #[pymethod(name = "__new__")] - fn new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, iterable: OptionalArg, vm: &VirtualMachine, @@ -570,8 +570,8 @@ impl PySet { #[pyimpl] impl PyFrozenSet { - #[pymethod(name = "__new__")] - fn new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, iterable: OptionalArg, vm: &VirtualMachine, diff --git a/vm/src/obj/objslice.rs b/vm/src/obj/objslice.rs index 63b145b397..32542c5fdf 100644 --- a/vm/src/obj/objslice.rs +++ b/vm/src/obj/objslice.rs @@ -87,8 +87,8 @@ impl PySlice { } } - #[pymethod(name = "__new__")] - fn slice_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { + #[pyslot(new)] + fn tp_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { let slice: PySlice = match args.args.len() { 0 => { return Err( diff --git a/vm/src/obj/objstaticmethod.rs b/vm/src/obj/objstaticmethod.rs index e330652a5e..d721112f75 100644 --- a/vm/src/obj/objstaticmethod.rs +++ b/vm/src/obj/objstaticmethod.rs @@ -35,6 +35,6 @@ pub fn init(context: &PyContext) { let staticmethod_type = &context.types.staticmethod_type; extend_class!(context, staticmethod_type, { "__get__" => context.new_rustfunc(PyStaticMethodRef::get), - "__new__" => context.new_rustfunc(PyStaticMethodRef::new), + (slot new) => PyStaticMethodRef::new, }); } diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 5ff1424f33..79181c497b 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -176,8 +176,8 @@ impl PyString { // TODO: should with following format // class str(object='') // class str(object=b'', encoding='utf-8', errors='strict') - #[pymethod(name = "__new__")] - fn new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, object: OptionalArg, vm: &VirtualMachine, diff --git a/vm/src/obj/objsuper.rs b/vm/src/obj/objsuper.rs index 25c8caffc0..6c29d433cc 100644 --- a/vm/src/obj/objsuper.rs +++ b/vm/src/obj/objsuper.rs @@ -78,8 +78,8 @@ impl PySuper { } } - #[pymethod(name = "__new__")] - fn new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, py_type: OptionalArg, py_obj: OptionalArg, diff --git a/vm/src/obj/objtuple.rs b/vm/src/obj/objtuple.rs index 9ce483fc01..a31c44b12d 100644 --- a/vm/src/obj/objtuple.rs +++ b/vm/src/obj/objtuple.rs @@ -290,7 +290,7 @@ If the argument is a tuple, the return value is the same object."; "__hash__" => context.new_rustfunc(PyTupleRef::hash), "__iter__" => context.new_rustfunc(PyTupleRef::iter), "__len__" => context.new_rustfunc(PyTupleRef::len), - "__new__" => context.new_rustfunc(tuple_new), + (slot new) => tuple_new, "__mul__" => context.new_rustfunc(PyTupleRef::mul), "__rmul__" => context.new_rustfunc(PyTupleRef::rmul), "__repr__" => context.new_rustfunc(PyTupleRef::repr), diff --git a/vm/src/obj/objtype.rs b/vm/src/obj/objtype.rs index 48e50bcc9d..6e83e2426b 100644 --- a/vm/src/obj/objtype.rs +++ b/vm/src/obj/objtype.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::fmt; -use crate::function::{Args, KwArgs, PyFuncArgs}; +use crate::function::{PyFuncArgs, PyNativeFunc}; use crate::pyobject::{ IdProtocol, PyAttributes, PyContext, PyIterable, PyObject, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, @@ -23,6 +23,17 @@ pub struct PyClass { pub mro: Vec, pub subclasses: RefCell>, pub attributes: RefCell, + pub slots: RefCell, +} + +#[derive(Default)] +pub struct PyClassSlots { + pub new: Option, +} +impl fmt::Debug for PyClassSlots { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("PyClassSlots") + } } impl fmt::Display for PyClass { @@ -149,8 +160,7 @@ impl PyClassRef { if let Some(attr) = class_get_attr(&self, &name) { let attr_class = attr.class(); if let Some(ref descriptor) = class_get_attr(&attr_class, "__get__") { - let none = vm.get_none(); - return vm.invoke(descriptor, vec![attr, none, self.into_object()]); + return vm.invoke(descriptor, vec![attr, vm.get_none(), self.into_object()]); } } @@ -203,6 +213,11 @@ impl PyClassRef { } } +fn type_mro(cls: PyClassRef, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx + .new_list(cls.mro.iter().map(|x| x.clone().into_object()).collect()) +} + /* * The magical type type */ @@ -213,13 +228,14 @@ pub fn init(ctx: &PyContext) { type(name, bases, dict) -> a new type"; extend_class!(&ctx, &ctx.types.type_type, { + "mro" => ctx.new_rustfunc(type_mro), "__call__" => ctx.new_rustfunc(type_call), "__dict__" => PropertyBuilder::new(ctx) .add_getter(type_dict) .add_setter(type_dict_setter) .create(), - "__new__" => ctx.new_rustfunc(type_new), + (slot new) => type_new_slot, "__mro__" => PropertyBuilder::new(ctx) .add_getter(PyClassRef::mro) @@ -260,39 +276,84 @@ pub fn issubclass(subclass: &PyClassRef, cls: &PyClassRef) -> bool { subclass.is(cls) || mro.iter().any(|c| c.is(cls.as_object())) } -pub fn type_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { +fn type_new_slot(metatype: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { vm_trace!("type.__new__ {:?}", args); - if args.args.len() == 2 { - Ok(args.args[1].class().into_object()) - } else if args.args.len() == 4 { - let (typ, name, bases, dict) = args.bind(vm)?; - type_new_class(vm, typ, name, bases, dict).map(PyRef::into_object) - } else { - Err(vm.new_type_error("type() takes 1 or 3 arguments".to_string())) - } -} -pub fn type_new_class( - vm: &VirtualMachine, - typ: PyClassRef, - name: PyStringRef, - bases: PyIterable, - dict: PyDictRef, -) -> PyResult { + if metatype.is(&vm.ctx.types.type_type) { + if args.args.len() == 1 && args.kwargs.is_empty() { + return Ok(args.args[0].class().into_object()); + } + if args.args.len() != 3 { + return Err(vm.new_type_error("type() takes 1 or 3 arguments".to_string())); + } + } + + let (name, bases, dict): (PyStringRef, PyIterable, PyDictRef) = args.bind(vm)?; + let mut bases: Vec = bases.iter(vm)?.collect::, _>>()?; bases.push(vm.ctx.object()); - new(typ.clone(), name.as_str(), bases, dict.to_attributes()) + + let attributes = dict.to_attributes(); + + let mut winner = metatype.clone(); + for base in &bases { + let base_type = base.class(); + if issubclass(&winner, &base_type) { + continue; + } else if issubclass(&base_type, &winner) { + winner = base_type.clone(); + continue; + } + + return Err(vm.new_type_error( + "metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass \ + of the metaclasses of all its bases" + .to_string(), + )); + } + + new(winner, name.as_str(), bases, attributes).map(Into::into) } -pub fn type_call(class: PyClassRef, args: Args, kwargs: KwArgs, vm: &VirtualMachine) -> PyResult { +pub fn type_new( + zelf: PyClassRef, + cls: PyClassRef, + args: PyFuncArgs, + vm: &VirtualMachine, +) -> PyResult { + if !issubclass(&cls, &zelf) { + return Err(vm.new_type_error(format!( + "{zelf}.__new__({cls}): {cls} is not a subtype of {zelf}", + zelf = zelf.name, + cls = cls.name, + ))); + } + + let class_with_new_slot = if cls.slots.borrow().new.is_some() { + cls.clone() + } else { + cls.mro + .iter() + .cloned() + .find(|cls| cls.slots.borrow().new.is_some()) + .expect("Should be able to find a new slot somewhere in the mro") + }; + + let slots = class_with_new_slot.slots.borrow(); + let new = slots.new.as_ref().unwrap(); + + new(vm, args.insert(cls.into_object())) +} + +pub fn type_call(class: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { vm_trace!("type_call: {:?}", class); - let new = class_get_attr(&class, "__new__").expect("All types should have a __new__."); - let new_wrapped = vm.call_get_descriptor(new, class.into_object())?; - let obj = vm.invoke(&new_wrapped, (&args, &kwargs))?; + let new = vm.get_attribute(class.as_object().clone(), "__new__")?; + let new_args = args.insert(class.into_object()); + let obj = vm.invoke(&new, new_args)?; if let Some(init_method_or_err) = vm.get_method(obj.clone(), "__init__") { let init_method = init_method_or_err?; - let res = vm.invoke(&init_method, (&args, &kwargs))?; + let res = vm.invoke(&init_method, args)?; if !res.is(&vm.get_none()) { return Err(vm.new_type_error("__init__ must return None".to_string())); } @@ -314,15 +375,19 @@ fn type_dict_setter(_instance: PyClassRef, _value: PyObjectRef, vm: &VirtualMach pub fn class_get_attr(class: &PyClassRef, attr_name: &str) -> Option { flame_guard!(format!("class_get_attr({:?})", attr_name)); - if let Some(item) = class.attributes.borrow().get(attr_name).cloned() { - return Some(item); - } - for class in &class.mro { - if let Some(item) = class.attributes.borrow().get(attr_name).cloned() { - return Some(item); - } - } - None + class + .attributes + .borrow() + .get(attr_name) + .cloned() + .or_else(|| class_get_super_attr(class, attr_name)) +} + +pub fn class_get_super_attr(class: &PyClassRef, attr_name: &str) -> Option { + class + .mro + .iter() + .find_map(|class| class.attributes.borrow().get(attr_name).cloned()) } // This is the internal has_attr implementation for fast lookup on a class. @@ -404,8 +469,9 @@ pub fn new( payload: PyClass { name: String::from(name), mro, - subclasses: RefCell::new(vec![]), + subclasses: RefCell::default(), attributes: RefCell::new(dict), + slots: RefCell::default(), }, dict: None, typ, diff --git a/vm/src/obj/objweakproxy.rs b/vm/src/obj/objweakproxy.rs index f510a03efa..1f1de31a9d 100644 --- a/vm/src/obj/objweakproxy.rs +++ b/vm/src/obj/objweakproxy.rs @@ -21,8 +21,8 @@ pub type PyWeakProxyRef = PyRef; #[pyimpl] impl PyWeakProxy { // TODO: callbacks - #[pymethod(name = "__new__")] - fn create( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, referent: PyObjectRef, callback: OptionalArg, diff --git a/vm/src/obj/objweakref.rs b/vm/src/obj/objweakref.rs index ae4375cd0e..64a0ae8392 100644 --- a/vm/src/obj/objweakref.rs +++ b/vm/src/obj/objweakref.rs @@ -49,7 +49,7 @@ impl PyWeakRef { pub fn init(context: &PyContext) { extend_class!(context, &context.types.weakref_type, { - "__new__" => context.new_rustfunc(PyWeakRef::create), + (slot new) => PyWeakRef::create, "__call__" => context.new_rustfunc(PyWeakRef::call) }); } diff --git a/vm/src/obj/objzip.rs b/vm/src/obj/objzip.rs index 0223b8bf23..1749b592e8 100644 --- a/vm/src/obj/objzip.rs +++ b/vm/src/obj/objzip.rs @@ -53,6 +53,6 @@ impl PyZip { pub fn init(context: &PyContext) { PyZip::extend_class(context, &context.types.zip_type); extend_class!(context, &context.types.zip_type, { - "__new__" => context.new_rustfunc(zip_new), + (slot new) => zip_new, }); } diff --git a/vm/src/stdlib/array.rs b/vm/src/stdlib/array.rs index 6654a4447d..bc99720a95 100644 --- a/vm/src/stdlib/array.rs +++ b/vm/src/stdlib/array.rs @@ -190,8 +190,8 @@ impl PyValue for PyArray { #[pyimpl] impl PyArray { - #[pymethod(name = "__new__")] - fn new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, spec: PyStringRef, init: OptionalArg, diff --git a/vm/src/stdlib/ast.rs b/vm/src/stdlib/ast.rs index 32986a24cd..b149088f7a 100644 --- a/vm/src/stdlib/ast.rs +++ b/vm/src/stdlib/ast.rs @@ -637,6 +637,7 @@ fn ast_parse(source: PyStringRef, vm: &VirtualMachine) -> PyResult { program_to_ast(&vm, &internal_ast) } +#[allow(clippy::cognitive_complexity)] pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 422269c967..76112b5e90 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -28,8 +28,8 @@ struct PyDequeOptions { #[pyimpl] impl PyDeque { - #[pymethod(name = "__new__")] - fn new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, iter: OptionalArg, PyDequeOptions { maxlen }: PyDequeOptions, diff --git a/vm/src/stdlib/hashlib.rs b/vm/src/stdlib/hashlib.rs index 6632d27747..62b1c4623c 100644 --- a/vm/src/stdlib/hashlib.rs +++ b/vm/src/stdlib/hashlib.rs @@ -41,8 +41,8 @@ impl PyHasher { } } - #[pymethod(name = "__new__")] - fn py_new(_cls: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { + #[pyslot(new)] + fn tp_new(_cls: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { Ok(PyHasher::new("md5", HashWrapper::md5()) .into_ref(vm) .into_object()) diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 350ff43ba5..44d6dc65e9 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -753,7 +753,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { //StringIO: in-memory text let string_io = py_class!(ctx, "StringIO", text_io_base.clone(), { - "__new__" => ctx.new_rustfunc(string_io_new), + (slot new) => string_io_new, "seek" => ctx.new_rustfunc(PyStringIORef::seek), "seekable" => ctx.new_rustfunc(PyStringIORef::seekable), "read" => ctx.new_rustfunc(PyStringIORef::read), @@ -763,7 +763,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { //BytesIO: in-memory bytes let bytes_io = py_class!(ctx, "BytesIO", buffered_io_base.clone(), { - "__new__" => ctx.new_rustfunc(bytes_io_new), + (slot new) => bytes_io_new, "read" => ctx.new_rustfunc(PyBytesIORef::read), "read1" => ctx.new_rustfunc(PyBytesIORef::read), "seek" => ctx.new_rustfunc(PyBytesIORef::seek), diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index ddac6cc263..55ae1bdb29 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -30,15 +30,13 @@ impl PyValue for PyItertoolsChain { #[pyimpl] impl PyItertoolsChain { - #[pymethod(name = "__new__")] - #[allow(clippy::new_ret_no_self)] - fn new(_cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { - Ok(PyItertoolsChain { + #[pyslot(new)] + fn tp_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult> { + PyItertoolsChain { iterables: args.args, cur: RefCell::new((0, None)), } - .into_ref(vm) - .into_object()) + .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] @@ -88,23 +86,21 @@ impl PyValue for PyItertoolsCompress { #[pyimpl] impl PyItertoolsCompress { - #[pymethod(name = "__new__")] - #[allow(clippy::new_ret_no_self)] - fn new( - _cls: PyClassRef, + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, data: PyObjectRef, selector: PyObjectRef, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult> { let data_iter = get_iter(vm, &data)?; let selector_iter = get_iter(vm, &selector)?; - Ok(PyItertoolsCompress { + PyItertoolsCompress { data: data_iter, selector: selector_iter, } - .into_ref(vm) - .into_object()) + .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] @@ -141,14 +137,13 @@ impl PyValue for PyItertoolsCount { #[pyimpl] impl PyItertoolsCount { - #[pymethod(name = "__new__")] - #[allow(clippy::new_ret_no_self)] - fn new( - _cls: PyClassRef, + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, start: OptionalArg, step: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult> { let start = match start.into_option() { Some(int) => int.as_bigint().clone(), None => BigInt::from(0), @@ -158,12 +153,11 @@ impl PyItertoolsCount { None => BigInt::from(1), }; - Ok(PyItertoolsCount { + PyItertoolsCount { cur: RefCell::new(start), step, } - .into_ref(vm) - .into_object()) + .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] @@ -194,25 +188,23 @@ impl PyValue for PyItertoolsRepeat { #[pyimpl] impl PyItertoolsRepeat { - #[pymethod(name = "__new__")] - #[allow(clippy::new_ret_no_self)] - fn new( - _cls: PyClassRef, + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, object: PyObjectRef, times: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult> { let times = match times.into_option() { Some(int) => Some(RefCell::new(int.as_bigint().clone())), None => None, }; - Ok(PyItertoolsRepeat { + PyItertoolsRepeat { object: object.clone(), times, } - .into_ref(vm) - .into_object()) + .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] @@ -253,19 +245,16 @@ impl PyValue for PyItertoolsStarmap { #[pyimpl] impl PyItertoolsStarmap { - #[pymethod(name = "__new__")] - #[allow(clippy::new_ret_no_self)] - fn new( - _cls: PyClassRef, + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, function: PyObjectRef, iterable: PyObjectRef, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult> { let iter = get_iter(vm, &iterable)?; - Ok(PyItertoolsStarmap { function, iter } - .into_ref(vm) - .into_object()) + PyItertoolsStarmap { function, iter }.into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] @@ -298,23 +287,21 @@ impl PyValue for PyItertoolsTakewhile { #[pyimpl] impl PyItertoolsTakewhile { - #[pymethod(name = "__new__")] - #[allow(clippy::new_ret_no_self)] - fn new( - _cls: PyClassRef, + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, predicate: PyObjectRef, iterable: PyObjectRef, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult> { let iter = get_iter(vm, &iterable)?; - Ok(PyItertoolsTakewhile { + PyItertoolsTakewhile { predicate, iterable: iter, stop_flag: RefCell::new(false), } - .into_ref(vm) - .into_object()) + .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] @@ -357,18 +344,15 @@ impl PyValue for PyItertoolsDropwhile { } } -type PyItertoolsDropwhileRef = PyRef; - #[pyimpl] impl PyItertoolsDropwhile { - #[pymethod(name = "__new__")] - #[allow(clippy::new_ret_no_self)] - fn new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, predicate: PyCallable, iterable: PyObjectRef, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult> { let iter = get_iter(vm, &iterable)?; PyItertoolsDropwhile { @@ -431,9 +415,8 @@ fn pyobject_to_opt_usize(obj: PyObjectRef, vm: &VirtualMachine) -> Option #[pyimpl] impl PyItertoolsIslice { - #[pymethod(name = "__new__")] - #[allow(clippy::new_ret_no_self)] - fn new(_cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { + #[pyslot(new)] + fn tp_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult> { let (iter, start, stop, step) = match args.args.len() { 0 | 1 => { return Err(vm.new_type_error(format!( @@ -492,15 +475,14 @@ impl PyItertoolsIslice { let iter = get_iter(vm, &iter)?; - Ok(PyItertoolsIslice { + PyItertoolsIslice { iterable: iter, cur: RefCell::new(0), next: RefCell::new(start), stop, step, } - .into_ref(vm) - .into_object()) + .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] @@ -547,22 +529,20 @@ impl PyValue for PyItertoolsFilterFalse { #[pyimpl] impl PyItertoolsFilterFalse { - #[pymethod(name = "__new__")] - #[allow(clippy::new_ret_no_self)] - fn new( - _cls: PyClassRef, + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, predicate: PyObjectRef, iterable: PyObjectRef, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult> { let iter = get_iter(vm, &iterable)?; - Ok(PyItertoolsFilterFalse { + PyItertoolsFilterFalse { predicate, iterable: iter, } - .into_ref(vm) - .into_object()) + .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] @@ -606,14 +586,13 @@ impl PyValue for PyItertoolsAccumulate { #[pyimpl] impl PyItertoolsAccumulate { - #[pymethod(name = "__new__")] - #[allow(clippy::new_ret_no_self)] - fn new( + #[pyslot(new)] + fn tp_new( cls: PyClassRef, iterable: PyObjectRef, binop: OptionalArg, vm: &VirtualMachine, - ) -> PyResult> { + ) -> PyResult> { let iter = get_iter(vm, &iterable)?; PyItertoolsAccumulate { diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 5f4079cd83..a8efc87ae7 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -566,7 +566,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let socket_gaierror = ctx.new_class("socket.gaierror", vm.ctx.exceptions.os_error.clone()); let socket = py_class!(ctx, "socket", ctx.object(), { - "__new__" => ctx.new_rustfunc(SocketRef::new), + (slot new) => SocketRef::new, "__enter__" => ctx.new_rustfunc(SocketRef::enter), "__exit__" => ctx.new_rustfunc(SocketRef::exit), "connect" => ctx.new_rustfunc(SocketRef::connect), diff --git a/vm/src/stdlib/subprocess.rs b/vm/src/stdlib/subprocess.rs index 3af308f7c5..466f84555c 100644 --- a/vm/src/stdlib/subprocess.rs +++ b/vm/src/stdlib/subprocess.rs @@ -194,7 +194,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let timeout_expired = ctx.new_class("TimeoutExpired", subprocess_error.clone()); let popen = py_class!(ctx, "Popen", ctx.object(), { - "__new__" => ctx.new_rustfunc(PopenRef::new), + (slot new) => PopenRef::new, "poll" => ctx.new_rustfunc(PopenRef::poll), "returncode" => ctx.new_property(PopenRef::return_code), "wait" => ctx.new_rustfunc(PopenRef::wait), diff --git a/vm/src/types.rs b/vm/src/types.rs index 048b3d8adb..e978df7106 100644 --- a/vm/src/types.rs +++ b/vm/src/types.rs @@ -233,8 +233,9 @@ fn init_type_hierarchy() -> (PyClassRef, PyClassRef) { payload: PyClass { name: String::from("object"), mro: vec![], - subclasses: RefCell::new(vec![]), + subclasses: RefCell::default(), attributes: RefCell::new(PyAttributes::new()), + slots: RefCell::default(), }, } .into_ref(); @@ -245,8 +246,9 @@ fn init_type_hierarchy() -> (PyClassRef, PyClassRef) { payload: PyClass { name: String::from("type"), mro: vec![object_type.clone().downcast().unwrap()], - subclasses: RefCell::new(vec![]), + subclasses: RefCell::default(), attributes: RefCell::new(PyAttributes::new()), + slots: RefCell::default(), }, } .into_ref();