Merge pull request #2924 from RustPython/fix-class_getitem

Fix __class_getitem__ slightly
This commit is contained in:
Jeong YunWon
2021-08-22 13:51:27 +09:00
committed by GitHub
5 changed files with 357 additions and 36 deletions

View File

@@ -0,0 +1,286 @@
import unittest
from test import support
class TestMROEntry(unittest.TestCase):
def test_mro_entry_signature(self):
tested = []
class B: ...
class C:
def __mro_entries__(self, *args, **kwargs):
tested.extend([args, kwargs])
return (C,)
c = C()
self.assertEqual(tested, [])
class D(B, c): ...
self.assertEqual(tested[0], ((B, c),))
self.assertEqual(tested[1], {})
def test_mro_entry(self):
tested = []
class A: ...
class B: ...
class C:
def __mro_entries__(self, bases):
tested.append(bases)
return (self.__class__,)
c = C()
self.assertEqual(tested, [])
class D(A, c, B): ...
self.assertEqual(tested[-1], (A, c, B))
self.assertEqual(D.__bases__, (A, C, B))
self.assertEqual(D.__orig_bases__, (A, c, B))
self.assertEqual(D.__mro__, (D, A, C, B, object))
d = D()
class E(d): ...
self.assertEqual(tested[-1], (d,))
self.assertEqual(E.__bases__, (D,))
def test_mro_entry_none(self):
tested = []
class A: ...
class B: ...
class C:
def __mro_entries__(self, bases):
tested.append(bases)
return ()
c = C()
self.assertEqual(tested, [])
class D(A, c, B): ...
self.assertEqual(tested[-1], (A, c, B))
self.assertEqual(D.__bases__, (A, B))
self.assertEqual(D.__orig_bases__, (A, c, B))
self.assertEqual(D.__mro__, (D, A, B, object))
class E(c): ...
self.assertEqual(tested[-1], (c,))
self.assertEqual(E.__bases__, (object,))
self.assertEqual(E.__orig_bases__, (c,))
self.assertEqual(E.__mro__, (E, object))
def test_mro_entry_with_builtins(self):
tested = []
class A: ...
class C:
def __mro_entries__(self, bases):
tested.append(bases)
return (dict,)
c = C()
self.assertEqual(tested, [])
class D(A, c): ...
self.assertEqual(tested[-1], (A, c))
self.assertEqual(D.__bases__, (A, dict))
self.assertEqual(D.__orig_bases__, (A, c))
self.assertEqual(D.__mro__, (D, A, dict, object))
def test_mro_entry_with_builtins_2(self):
tested = []
class C:
def __mro_entries__(self, bases):
tested.append(bases)
return (C,)
c = C()
self.assertEqual(tested, [])
class D(c, dict): ...
self.assertEqual(tested[-1], (c, dict))
self.assertEqual(D.__bases__, (C, dict))
self.assertEqual(D.__orig_bases__, (c, dict))
self.assertEqual(D.__mro__, (D, C, dict, object))
def test_mro_entry_errors(self):
class C_too_many:
def __mro_entries__(self, bases, something, other):
return ()
c = C_too_many()
with self.assertRaises(TypeError):
class D(c): ...
class C_too_few:
def __mro_entries__(self):
return ()
d = C_too_few()
with self.assertRaises(TypeError):
class D(d): ...
def test_mro_entry_errors_2(self):
class C_not_callable:
__mro_entries__ = "Surprise!"
c = C_not_callable()
with self.assertRaises(TypeError):
class D(c): ...
class C_not_tuple:
def __mro_entries__(self):
return object
c = C_not_tuple()
with self.assertRaises(TypeError):
class D(c): ...
def test_mro_entry_metaclass(self):
meta_args = []
class Meta(type):
def __new__(mcls, name, bases, ns):
meta_args.extend([mcls, name, bases, ns])
return super().__new__(mcls, name, bases, ns)
class A: ...
class C:
def __mro_entries__(self, bases):
return (A,)
c = C()
class D(c, metaclass=Meta):
x = 1
self.assertEqual(meta_args[0], Meta)
self.assertEqual(meta_args[1], 'D')
self.assertEqual(meta_args[2], (A,))
self.assertEqual(meta_args[3]['x'], 1)
self.assertEqual(D.__bases__, (A,))
self.assertEqual(D.__orig_bases__, (c,))
self.assertEqual(D.__mro__, (D, A, object))
self.assertEqual(D.__class__, Meta)
def test_mro_entry_type_call(self):
# Substitution should _not_ happen in direct type call
class C:
def __mro_entries__(self, bases):
return ()
c = C()
with self.assertRaisesRegex(TypeError,
"MRO entry resolution; "
"use types.new_class()"):
type('Bad', (c,), {})
class TestClassGetitem(unittest.TestCase):
def test_class_getitem(self):
getitem_args = []
class C:
def __class_getitem__(*args, **kwargs):
getitem_args.extend([args, kwargs])
return None
C[int, str]
self.assertEqual(getitem_args[0], (C, (int, str)))
self.assertEqual(getitem_args[1], {})
def test_class_getitem_format(self):
class C:
def __class_getitem__(cls, item):
return f'C[{item.__name__}]'
self.assertEqual(C[int], 'C[int]')
self.assertEqual(C[C], 'C[C]')
def test_class_getitem_inheritance(self):
class C:
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
class D(C): ...
self.assertEqual(D[int], 'D[int]')
self.assertEqual(D[D], 'D[D]')
def test_class_getitem_inheritance_2(self):
class C:
def __class_getitem__(cls, item):
return 'Should not see this'
class D(C):
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
self.assertEqual(D[int], 'D[int]')
self.assertEqual(D[D], 'D[D]')
def test_class_getitem_classmethod(self):
class C:
@classmethod
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
class D(C): ...
self.assertEqual(D[int], 'D[int]')
self.assertEqual(D[D], 'D[D]')
def test_class_getitem_patched(self):
class C:
def __init_subclass__(cls):
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
cls.__class_getitem__ = classmethod(__class_getitem__)
class D(C): ...
self.assertEqual(D[int], 'D[int]')
self.assertEqual(D[D], 'D[D]')
def test_class_getitem_with_builtins(self):
class A(dict):
called_with = None
def __class_getitem__(cls, item):
cls.called_with = item
class B(A):
pass
self.assertIs(B.called_with, None)
B[int]
self.assertIs(B.called_with, int)
def test_class_getitem_errors(self):
class C_too_few:
def __class_getitem__(cls):
return None
with self.assertRaises(TypeError):
C_too_few[int]
class C_too_many:
def __class_getitem__(cls, one, two):
return None
with self.assertRaises(TypeError):
C_too_many[int]
def test_class_getitem_errors_2(self):
class C:
def __class_getitem__(cls, item):
return None
with self.assertRaises(TypeError):
C()[int]
class E: ...
e = E()
e.__class_getitem__ = lambda cls, item: 'This will not work'
with self.assertRaises(TypeError):
e[int]
class C_not_callable:
__class_getitem__ = "Surprise!"
with self.assertRaises(TypeError):
C_not_callable[int]
def test_class_getitem_metaclass(self):
class Meta(type):
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
self.assertEqual(Meta[int], 'Meta[int]')
def test_class_getitem_with_metaclass(self):
class Meta(type): pass
class C(metaclass=Meta):
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
self.assertEqual(C[int], 'C[int]')
def test_class_getitem_metaclass_first(self):
class Meta(type):
def __getitem__(cls, item):
return 'from metaclass'
class C(metaclass=Meta):
def __class_getitem__(cls, item):
return 'from __class_getitem__'
self.assertEqual(C[int], 'from metaclass')
@support.cpython_only
class CAPITest(unittest.TestCase):
def test_c_class(self):
from _testcapi import Generic, GenericAlias
self.assertIsInstance(Generic.__class_getitem__(int), GenericAlias)
IntGeneric = Generic[int]
self.assertIs(type(IntGeneric), GenericAlias)
self.assertEqual(IntGeneric.__mro_entries__(()), (int,))
class C(IntGeneric):
pass
self.assertEqual(C.__bases__, (int,))
self.assertEqual(C.__orig_bases__, (IntGeneric,))
self.assertEqual(C.__mro__, (C, int, object))
if __name__ == "__main__":
unittest.main()

View File

@@ -20,7 +20,7 @@ mod decl {
use crate::builtins::pybool::IntoPyBool;
use crate::builtins::pystr::{PyStr, PyStrRef};
use crate::builtins::pytype::PyTypeRef;
use crate::builtins::{PyByteArray, PyBytes};
use crate::builtins::{PyByteArray, PyBytes, PyTupleRef};
use crate::byteslike::ArgBytesLike;
use crate::common::{hash::PyHash, str::to_ascii};
#[cfg(feature = "rustpython-compiler")]
@@ -832,7 +832,7 @@ mod decl {
pub fn __build_class__(
function: PyFunctionRef,
qualified_name: PyStrRef,
bases: Args<PyTypeRef>,
bases: Args,
mut kwargs: KwArgs,
vm: &VirtualMachine,
) -> PyResult {
@@ -845,7 +845,41 @@ mod decl {
vm.ctx.types.type_type.clone()
};
for base in bases.clone() {
let mut new_bases: Option<Vec<PyObjectRef>> = None;
let bases = PyTupleRef::with_elements(bases.into_vec(), &vm.ctx);
for (i, base) in bases.as_slice().iter().enumerate() {
if base.isinstance(&vm.ctx.types.type_type) {
if let Some(bases) = &mut new_bases {
bases.push(base.clone());
}
continue;
}
let mro_entries = vm.get_attribute_opt(base.clone(), "__mro_entries__")?;
let entries = match mro_entries {
Some(meth) => vm.invoke(&meth, (bases.clone(),))?,
None => {
if let Some(bases) = &mut new_bases {
bases.push(base.clone());
}
continue;
}
};
let entries: PyTupleRef = entries
.downcast()
.map_err(|_| vm.new_type_error("__mro_entries__ must return a tuple".to_owned()))?;
let new_bases = new_bases.get_or_insert_with(|| bases.as_slice()[..i].to_vec());
new_bases.extend_from_slice(entries.as_slice());
}
let new_bases = new_bases.map(|v| PyTupleRef::with_elements(v, &vm.ctx));
let (orig_bases, bases) = match new_bases {
Some(new) => (Some(bases), new),
None => (None, bases),
};
for base in bases.as_slice().iter() {
let base_class = base.class();
if base_class.issubclass(&metaclass) {
metaclass = base.clone_class();
@@ -858,7 +892,7 @@ mod decl {
}
}
let bases = bases.into_tuple(vm);
let bases = bases.into_object();
// Prepare uses full __getattribute__ resolution chain.
let prepare = vm.get_attribute(metaclass.clone().into_object(), "__prepare__")?;
@@ -872,6 +906,10 @@ mod decl {
let classcell = function.invoke_with_locals(().into(), Some(namespace.clone()), vm)?;
let classcell = <Option<PyCellRef>>::try_from_object(vm, classcell)?;
if let Some(orig_bases) = orig_bases {
namespace.set_item("__orig_bases__", orig_bases.into_object(), vm)?;
}
let class = vm.invoke(
metaclass.as_object(),
FuncArgs::new(vec![name_obj, bases, namespace.into_object()], kwargs),

View File

@@ -229,11 +229,6 @@ impl PyBaseObject {
#[pyclassmethod(magic)]
fn init_subclass(_cls: PyTypeRef) {}
#[pyclassmethod(magic)]
fn class_getitem(cls: PyTypeRef, _args: FuncArgs) -> PyObjectRef {
cls.into_object()
}
#[pymethod(magic)]
pub fn dir(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyList> {
let attributes: PyAttributes = obj.class().get_attributes();

View File

@@ -13,9 +13,8 @@ use super::mappingproxy::PyMappingProxy;
use super::object;
use super::pystr::{PyStr, PyStrRef};
use super::staticmethod::PyStaticMethod;
use super::tuple::PyTuple;
use super::tuple::{PyTuple, PyTupleRef};
use super::weakref::PyWeak;
use crate::builtins::tuple::PyTupleTyped;
use crate::function::{FuncArgs, KwArgs, OptionalArg};
use crate::slots::{self, Callable, PyTpFlags, PyTypeSlots, SlotGetattro, SlotSetattro};
use crate::utils::Either;
@@ -399,7 +398,7 @@ impl PyType {
}));
}
let (name, bases, dict, kwargs): (PyStrRef, PyTupleTyped<PyTypeRef>, PyDictRef, KwArgs) =
let (name, bases, dict, kwargs): (PyStrRef, PyTupleRef, PyDictRef, KwArgs) =
args.clone().bind(vm)?;
let bases = bases.as_slice();
@@ -407,16 +406,25 @@ impl PyType {
let base = vm.ctx.types.object_type.clone();
(metatype, base.clone(), vec![base])
} else {
// TODO
// for base in &bases {
// if PyType_Check(base) { continue; }
// _PyObject_LookupAttrId(base, PyId___mro_entries__, &base)?
// Err(new_type_error( "type() doesn't support MRO entry resolution; "
// "use types.new_class()"))
// }
let bases = bases
.iter()
.map(|obj| {
obj.clone().downcast::<PyType>().or_else(|obj| {
if vm.get_attribute_opt(obj, "__mro_entries__")?.is_some() {
Err(vm.new_type_error(
"type() doesn't support MRO entry resolution; \
use types.new_class()"
.to_owned(),
))
} else {
Err(vm.new_type_error("bases must be types".to_owned()))
}
})
})
.collect::<PyResult<Vec<_>>>()?;
// Search the bases for the proper metatype to deal with this:
let winner = calculate_meta_class(metatype.clone(), bases, vm)?;
let winner = calculate_meta_class(metatype.clone(), &bases, vm)?;
let metatype = if !winner.is(&metatype) {
#[allow(clippy::redundant_clone)] // false positive
if let Some(ref tp_new) = winner.clone().slots.new {
@@ -429,9 +437,9 @@ impl PyType {
metatype
};
let base = best_base(bases, vm)?;
let base = best_base(&bases, vm)?;
(metatype, base, bases.to_vec())
(metatype, base, bases)
};
let mut attributes = dict.to_attributes();

View File

@@ -686,25 +686,19 @@ where
{
fn get_item(&self, key: T, vm: &VirtualMachine) -> PyResult {
match vm.get_special_method(self.clone(), "__getitem__")? {
Ok(special_method) => special_method.invoke((key,), vm),
Ok(special_method) => return special_method.invoke((key,), vm),
Err(obj) => {
if obj.isinstance(&vm.ctx.types.type_type) {
vm.get_special_method(obj, "__class_getitem__")?
.map_err(|obj2| {
vm.new_type_error(format!(
"'{}' object is not subscriptable",
obj2.class().name
))
})?
.invoke((key,), vm)
} else {
Err(vm.new_type_error(format!(
"'{}' object is not subscriptable",
obj.class().name
)))
if let Some(class_getitem) = vm.get_attribute_opt(obj, "__class_getitem__")? {
return vm.invoke(&class_getitem, (key,));
}
}
}
}
Err(vm.new_type_error(format!(
"'{}' object is not subscriptable",
self.class().name
)))
}
fn set_item(&self, key: T, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {