From 95a947d7e35e7dc8a8cb1c4da30c90a052f6bec1 Mon Sep 17 00:00:00 2001 From: Noa <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Aug 2021 16:07:50 -0500 Subject: [PATCH 1/4] __class_getitem__ isn't a special method --- vm/src/builtins/make_module.rs | 46 +++++++++++++++++++++++++++++++--- vm/src/pyobject.rs | 22 ++++++---------- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/vm/src/builtins/make_module.rs b/vm/src/builtins/make_module.rs index f13f3f083..fcf9098ef 100644 --- a/vm/src/builtins/make_module.rs +++ b/vm/src/builtins/make_module.rs @@ -20,7 +20,7 @@ mod decl { use crate::builtins::pybool::IntoPyBool; use crate::builtins::pystr::{PyStr, PyStrRef}; use crate::builtins::pytype::PyTypeRef; - use crate::builtins::{PyByteArray, PyBytes}; + use crate::builtins::{PyByteArray, PyBytes, PyTupleRef}; use crate::byteslike::ArgBytesLike; use crate::common::{hash::PyHash, str::to_ascii}; #[cfg(feature = "rustpython-compiler")] @@ -832,7 +832,7 @@ mod decl { pub fn __build_class__( function: PyFunctionRef, qualified_name: PyStrRef, - bases: Args, + bases: Args, mut kwargs: KwArgs, vm: &VirtualMachine, ) -> PyResult { @@ -845,7 +845,41 @@ mod decl { vm.ctx.types.type_type.clone() }; - for base in bases.clone() { + let mut new_bases: Option> = None; + + let bases = PyTupleRef::with_elements(bases.into_vec(), &vm.ctx); + + for (i, base) in bases.as_slice().iter().enumerate() { + if base.isinstance(&vm.ctx.types.type_type) { + if let Some(bases) = &mut new_bases { + bases.push(base.clone()); + } + continue; + } + let mro_entries = vm.get_attribute_opt(base.clone(), "__mro_entries__")?; + let entries = match mro_entries { + Some(meth) => vm.invoke(&meth, (bases.clone(),))?, + None => { + if let Some(bases) = &mut new_bases { + bases.push(base.clone()); + } + continue; + } + }; + let entries: PyTupleRef = entries + .downcast() + .map_err(|_| vm.new_type_error("__mro_entries__ must return a tuple".to_owned()))?; + let new_bases = new_bases.get_or_insert_with(|| bases.as_slice()[..i].to_vec()); + new_bases.extend_from_slice(entries.as_slice()); + } + + let new_bases = new_bases.map(|v| PyTupleRef::with_elements(v, &vm.ctx)); + let (orig_bases, bases) = match new_bases { + Some(new) => (Some(bases), new), + None => (None, bases), + }; + + for base in bases.as_slice().iter() { let base_class = base.class(); if base_class.issubclass(&metaclass) { metaclass = base.clone_class(); @@ -858,7 +892,7 @@ mod decl { } } - let bases = bases.into_tuple(vm); + let bases = bases.into_object(); // Prepare uses full __getattribute__ resolution chain. let prepare = vm.get_attribute(metaclass.clone().into_object(), "__prepare__")?; @@ -872,6 +906,10 @@ mod decl { let classcell = function.invoke_with_locals(().into(), Some(namespace.clone()), vm)?; let classcell = >::try_from_object(vm, classcell)?; + if let Some(orig_bases) = orig_bases { + namespace.set_item("__orig_bases__", orig_bases.into_object(), vm)?; + } + let class = vm.invoke( metaclass.as_object(), FuncArgs::new(vec![name_obj, bases, namespace.into_object()], kwargs), diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index da8525187..cdd89e66c 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -686,25 +686,19 @@ where { fn get_item(&self, key: T, vm: &VirtualMachine) -> PyResult { match vm.get_special_method(self.clone(), "__getitem__")? { - Ok(special_method) => special_method.invoke((key,), vm), + Ok(special_method) => return special_method.invoke((key,), vm), Err(obj) => { if obj.isinstance(&vm.ctx.types.type_type) { - vm.get_special_method(obj, "__class_getitem__")? - .map_err(|obj2| { - vm.new_type_error(format!( - "'{}' object is not subscriptable", - obj2.class().name - )) - })? - .invoke((key,), vm) - } else { - Err(vm.new_type_error(format!( - "'{}' object is not subscriptable", - obj.class().name - ))) + if let Some(class_getitem) = vm.get_attribute_opt(obj, "__class_getitem__")? { + return vm.invoke(&class_getitem, (key,)); + } } } } + Err(vm.new_type_error(format!( + "'{}' object is not subscriptable", + self.class().name + ))) } fn set_item(&self, key: T, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { From 0494472c0eaafecafd078a8a68d99a09c8129bf4 Mon Sep 17 00:00:00 2001 From: Noa <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Aug 2021 16:15:58 -0500 Subject: [PATCH 2/4] Don't define object.__class_getitem__ --- vm/src/builtins/object.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vm/src/builtins/object.rs b/vm/src/builtins/object.rs index 284d46982..6f47968dd 100644 --- a/vm/src/builtins/object.rs +++ b/vm/src/builtins/object.rs @@ -229,11 +229,6 @@ impl PyBaseObject { #[pyclassmethod(magic)] fn init_subclass(_cls: PyTypeRef) {} - #[pyclassmethod(magic)] - fn class_getitem(cls: PyTypeRef, _args: FuncArgs) -> PyObjectRef { - cls.into_object() - } - #[pymethod(magic)] pub fn dir(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { let attributes: PyAttributes = obj.class().get_attributes(); From 5cfdb3361ebe45fe061693b150200fd93a1d5668 Mon Sep 17 00:00:00 2001 From: Noa <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Aug 2021 16:33:31 -0500 Subject: [PATCH 3/4] Add test.test_genericclass from CPython 3.8 --- Lib/test/test_genericclass.py | 286 ++++++++++++++++++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 Lib/test/test_genericclass.py diff --git a/Lib/test/test_genericclass.py b/Lib/test/test_genericclass.py new file mode 100644 index 000000000..27420d4f2 --- /dev/null +++ b/Lib/test/test_genericclass.py @@ -0,0 +1,286 @@ +import unittest +from test import support + + +class TestMROEntry(unittest.TestCase): + def test_mro_entry_signature(self): + tested = [] + class B: ... + class C: + def __mro_entries__(self, *args, **kwargs): + tested.extend([args, kwargs]) + return (C,) + c = C() + self.assertEqual(tested, []) + class D(B, c): ... + self.assertEqual(tested[0], ((B, c),)) + self.assertEqual(tested[1], {}) + + def test_mro_entry(self): + tested = [] + class A: ... + class B: ... + class C: + def __mro_entries__(self, bases): + tested.append(bases) + return (self.__class__,) + c = C() + self.assertEqual(tested, []) + class D(A, c, B): ... + self.assertEqual(tested[-1], (A, c, B)) + self.assertEqual(D.__bases__, (A, C, B)) + self.assertEqual(D.__orig_bases__, (A, c, B)) + self.assertEqual(D.__mro__, (D, A, C, B, object)) + d = D() + class E(d): ... + self.assertEqual(tested[-1], (d,)) + self.assertEqual(E.__bases__, (D,)) + + def test_mro_entry_none(self): + tested = [] + class A: ... + class B: ... + class C: + def __mro_entries__(self, bases): + tested.append(bases) + return () + c = C() + self.assertEqual(tested, []) + class D(A, c, B): ... + self.assertEqual(tested[-1], (A, c, B)) + self.assertEqual(D.__bases__, (A, B)) + self.assertEqual(D.__orig_bases__, (A, c, B)) + self.assertEqual(D.__mro__, (D, A, B, object)) + class E(c): ... + self.assertEqual(tested[-1], (c,)) + self.assertEqual(E.__bases__, (object,)) + self.assertEqual(E.__orig_bases__, (c,)) + self.assertEqual(E.__mro__, (E, object)) + + def test_mro_entry_with_builtins(self): + tested = [] + class A: ... + class C: + def __mro_entries__(self, bases): + tested.append(bases) + return (dict,) + c = C() + self.assertEqual(tested, []) + class D(A, c): ... + self.assertEqual(tested[-1], (A, c)) + self.assertEqual(D.__bases__, (A, dict)) + self.assertEqual(D.__orig_bases__, (A, c)) + self.assertEqual(D.__mro__, (D, A, dict, object)) + + def test_mro_entry_with_builtins_2(self): + tested = [] + class C: + def __mro_entries__(self, bases): + tested.append(bases) + return (C,) + c = C() + self.assertEqual(tested, []) + class D(c, dict): ... + self.assertEqual(tested[-1], (c, dict)) + self.assertEqual(D.__bases__, (C, dict)) + self.assertEqual(D.__orig_bases__, (c, dict)) + self.assertEqual(D.__mro__, (D, C, dict, object)) + + def test_mro_entry_errors(self): + class C_too_many: + def __mro_entries__(self, bases, something, other): + return () + c = C_too_many() + with self.assertRaises(TypeError): + class D(c): ... + class C_too_few: + def __mro_entries__(self): + return () + d = C_too_few() + with self.assertRaises(TypeError): + class D(d): ... + + def test_mro_entry_errors_2(self): + class C_not_callable: + __mro_entries__ = "Surprise!" + c = C_not_callable() + with self.assertRaises(TypeError): + class D(c): ... + class C_not_tuple: + def __mro_entries__(self): + return object + c = C_not_tuple() + with self.assertRaises(TypeError): + class D(c): ... + + def test_mro_entry_metaclass(self): + meta_args = [] + class Meta(type): + def __new__(mcls, name, bases, ns): + meta_args.extend([mcls, name, bases, ns]) + return super().__new__(mcls, name, bases, ns) + class A: ... + class C: + def __mro_entries__(self, bases): + return (A,) + c = C() + class D(c, metaclass=Meta): + x = 1 + self.assertEqual(meta_args[0], Meta) + self.assertEqual(meta_args[1], 'D') + self.assertEqual(meta_args[2], (A,)) + self.assertEqual(meta_args[3]['x'], 1) + self.assertEqual(D.__bases__, (A,)) + self.assertEqual(D.__orig_bases__, (c,)) + self.assertEqual(D.__mro__, (D, A, object)) + self.assertEqual(D.__class__, Meta) + + def test_mro_entry_type_call(self): + # Substitution should _not_ happen in direct type call + class C: + def __mro_entries__(self, bases): + return () + c = C() + with self.assertRaisesRegex(TypeError, + "MRO entry resolution; " + "use types.new_class()"): + type('Bad', (c,), {}) + + +class TestClassGetitem(unittest.TestCase): + def test_class_getitem(self): + getitem_args = [] + class C: + def __class_getitem__(*args, **kwargs): + getitem_args.extend([args, kwargs]) + return None + C[int, str] + self.assertEqual(getitem_args[0], (C, (int, str))) + self.assertEqual(getitem_args[1], {}) + + def test_class_getitem_format(self): + class C: + def __class_getitem__(cls, item): + return f'C[{item.__name__}]' + self.assertEqual(C[int], 'C[int]') + self.assertEqual(C[C], 'C[C]') + + def test_class_getitem_inheritance(self): + class C: + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + class D(C): ... + self.assertEqual(D[int], 'D[int]') + self.assertEqual(D[D], 'D[D]') + + def test_class_getitem_inheritance_2(self): + class C: + def __class_getitem__(cls, item): + return 'Should not see this' + class D(C): + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + self.assertEqual(D[int], 'D[int]') + self.assertEqual(D[D], 'D[D]') + + def test_class_getitem_classmethod(self): + class C: + @classmethod + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + class D(C): ... + self.assertEqual(D[int], 'D[int]') + self.assertEqual(D[D], 'D[D]') + + def test_class_getitem_patched(self): + class C: + def __init_subclass__(cls): + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + cls.__class_getitem__ = classmethod(__class_getitem__) + class D(C): ... + self.assertEqual(D[int], 'D[int]') + self.assertEqual(D[D], 'D[D]') + + def test_class_getitem_with_builtins(self): + class A(dict): + called_with = None + + def __class_getitem__(cls, item): + cls.called_with = item + class B(A): + pass + self.assertIs(B.called_with, None) + B[int] + self.assertIs(B.called_with, int) + + def test_class_getitem_errors(self): + class C_too_few: + def __class_getitem__(cls): + return None + with self.assertRaises(TypeError): + C_too_few[int] + class C_too_many: + def __class_getitem__(cls, one, two): + return None + with self.assertRaises(TypeError): + C_too_many[int] + + def test_class_getitem_errors_2(self): + class C: + def __class_getitem__(cls, item): + return None + with self.assertRaises(TypeError): + C()[int] + class E: ... + e = E() + e.__class_getitem__ = lambda cls, item: 'This will not work' + with self.assertRaises(TypeError): + e[int] + class C_not_callable: + __class_getitem__ = "Surprise!" + with self.assertRaises(TypeError): + C_not_callable[int] + + def test_class_getitem_metaclass(self): + class Meta(type): + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + self.assertEqual(Meta[int], 'Meta[int]') + + def test_class_getitem_with_metaclass(self): + class Meta(type): pass + class C(metaclass=Meta): + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + self.assertEqual(C[int], 'C[int]') + + def test_class_getitem_metaclass_first(self): + class Meta(type): + def __getitem__(cls, item): + return 'from metaclass' + class C(metaclass=Meta): + def __class_getitem__(cls, item): + return 'from __class_getitem__' + self.assertEqual(C[int], 'from metaclass') + + +@support.cpython_only +class CAPITest(unittest.TestCase): + + def test_c_class(self): + from _testcapi import Generic, GenericAlias + self.assertIsInstance(Generic.__class_getitem__(int), GenericAlias) + + IntGeneric = Generic[int] + self.assertIs(type(IntGeneric), GenericAlias) + self.assertEqual(IntGeneric.__mro_entries__(()), (int,)) + class C(IntGeneric): + pass + self.assertEqual(C.__bases__, (int,)) + self.assertEqual(C.__orig_bases__, (IntGeneric,)) + self.assertEqual(C.__mro__, (C, int, object)) + + +if __name__ == "__main__": + unittest.main() From 4a71420f4ff891171c715313509b069426511fa5 Mon Sep 17 00:00:00 2001 From: Noa <33094578+coolreader18@users.noreply.github.com> Date: Sat, 21 Aug 2021 16:48:25 -0500 Subject: [PATCH 4/4] Give correct error message for type with __mro_entries__ passed to type() --- vm/src/builtins/pytype.rs | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/vm/src/builtins/pytype.rs b/vm/src/builtins/pytype.rs index 25f3d2698..03deb756d 100644 --- a/vm/src/builtins/pytype.rs +++ b/vm/src/builtins/pytype.rs @@ -13,9 +13,8 @@ use super::mappingproxy::PyMappingProxy; use super::object; use super::pystr::{PyStr, PyStrRef}; use super::staticmethod::PyStaticMethod; -use super::tuple::PyTuple; +use super::tuple::{PyTuple, PyTupleRef}; use super::weakref::PyWeak; -use crate::builtins::tuple::PyTupleTyped; use crate::function::{FuncArgs, KwArgs, OptionalArg}; use crate::slots::{self, Callable, PyTpFlags, PyTypeSlots, SlotGetattro, SlotSetattro}; use crate::utils::Either; @@ -399,7 +398,7 @@ impl PyType { })); } - let (name, bases, dict, kwargs): (PyStrRef, PyTupleTyped, PyDictRef, KwArgs) = + let (name, bases, dict, kwargs): (PyStrRef, PyTupleRef, PyDictRef, KwArgs) = args.clone().bind(vm)?; let bases = bases.as_slice(); @@ -407,16 +406,25 @@ impl PyType { let base = vm.ctx.types.object_type.clone(); (metatype, base.clone(), vec![base]) } else { - // TODO - // for base in &bases { - // if PyType_Check(base) { continue; } - // _PyObject_LookupAttrId(base, PyId___mro_entries__, &base)? - // Err(new_type_error( "type() doesn't support MRO entry resolution; " - // "use types.new_class()")) - // } + let bases = bases + .iter() + .map(|obj| { + obj.clone().downcast::().or_else(|obj| { + if vm.get_attribute_opt(obj, "__mro_entries__")?.is_some() { + Err(vm.new_type_error( + "type() doesn't support MRO entry resolution; \ + use types.new_class()" + .to_owned(), + )) + } else { + Err(vm.new_type_error("bases must be types".to_owned())) + } + }) + }) + .collect::>>()?; // Search the bases for the proper metatype to deal with this: - let winner = calculate_meta_class(metatype.clone(), bases, vm)?; + let winner = calculate_meta_class(metatype.clone(), &bases, vm)?; let metatype = if !winner.is(&metatype) { #[allow(clippy::redundant_clone)] // false positive if let Some(ref tp_new) = winner.clone().slots.new { @@ -429,9 +437,9 @@ impl PyType { metatype }; - let base = best_base(bases, vm)?; + let base = best_base(&bases, vm)?; - (metatype, base, bases.to_vec()) + (metatype, base, bases) }; let mut attributes = dict.to_attributes();