mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-09 22:49:57 +09:00
Merge pull request #2924 from RustPython/fix-class_getitem
Fix __class_getitem__ slightly
This commit is contained in:
286
Lib/test/test_genericclass.py
Normal file
286
Lib/test/test_genericclass.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import unittest
|
||||
from test import support
|
||||
|
||||
|
||||
class TestMROEntry(unittest.TestCase):
|
||||
def test_mro_entry_signature(self):
|
||||
tested = []
|
||||
class B: ...
|
||||
class C:
|
||||
def __mro_entries__(self, *args, **kwargs):
|
||||
tested.extend([args, kwargs])
|
||||
return (C,)
|
||||
c = C()
|
||||
self.assertEqual(tested, [])
|
||||
class D(B, c): ...
|
||||
self.assertEqual(tested[0], ((B, c),))
|
||||
self.assertEqual(tested[1], {})
|
||||
|
||||
def test_mro_entry(self):
|
||||
tested = []
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C:
|
||||
def __mro_entries__(self, bases):
|
||||
tested.append(bases)
|
||||
return (self.__class__,)
|
||||
c = C()
|
||||
self.assertEqual(tested, [])
|
||||
class D(A, c, B): ...
|
||||
self.assertEqual(tested[-1], (A, c, B))
|
||||
self.assertEqual(D.__bases__, (A, C, B))
|
||||
self.assertEqual(D.__orig_bases__, (A, c, B))
|
||||
self.assertEqual(D.__mro__, (D, A, C, B, object))
|
||||
d = D()
|
||||
class E(d): ...
|
||||
self.assertEqual(tested[-1], (d,))
|
||||
self.assertEqual(E.__bases__, (D,))
|
||||
|
||||
def test_mro_entry_none(self):
|
||||
tested = []
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C:
|
||||
def __mro_entries__(self, bases):
|
||||
tested.append(bases)
|
||||
return ()
|
||||
c = C()
|
||||
self.assertEqual(tested, [])
|
||||
class D(A, c, B): ...
|
||||
self.assertEqual(tested[-1], (A, c, B))
|
||||
self.assertEqual(D.__bases__, (A, B))
|
||||
self.assertEqual(D.__orig_bases__, (A, c, B))
|
||||
self.assertEqual(D.__mro__, (D, A, B, object))
|
||||
class E(c): ...
|
||||
self.assertEqual(tested[-1], (c,))
|
||||
self.assertEqual(E.__bases__, (object,))
|
||||
self.assertEqual(E.__orig_bases__, (c,))
|
||||
self.assertEqual(E.__mro__, (E, object))
|
||||
|
||||
def test_mro_entry_with_builtins(self):
|
||||
tested = []
|
||||
class A: ...
|
||||
class C:
|
||||
def __mro_entries__(self, bases):
|
||||
tested.append(bases)
|
||||
return (dict,)
|
||||
c = C()
|
||||
self.assertEqual(tested, [])
|
||||
class D(A, c): ...
|
||||
self.assertEqual(tested[-1], (A, c))
|
||||
self.assertEqual(D.__bases__, (A, dict))
|
||||
self.assertEqual(D.__orig_bases__, (A, c))
|
||||
self.assertEqual(D.__mro__, (D, A, dict, object))
|
||||
|
||||
def test_mro_entry_with_builtins_2(self):
|
||||
tested = []
|
||||
class C:
|
||||
def __mro_entries__(self, bases):
|
||||
tested.append(bases)
|
||||
return (C,)
|
||||
c = C()
|
||||
self.assertEqual(tested, [])
|
||||
class D(c, dict): ...
|
||||
self.assertEqual(tested[-1], (c, dict))
|
||||
self.assertEqual(D.__bases__, (C, dict))
|
||||
self.assertEqual(D.__orig_bases__, (c, dict))
|
||||
self.assertEqual(D.__mro__, (D, C, dict, object))
|
||||
|
||||
def test_mro_entry_errors(self):
|
||||
class C_too_many:
|
||||
def __mro_entries__(self, bases, something, other):
|
||||
return ()
|
||||
c = C_too_many()
|
||||
with self.assertRaises(TypeError):
|
||||
class D(c): ...
|
||||
class C_too_few:
|
||||
def __mro_entries__(self):
|
||||
return ()
|
||||
d = C_too_few()
|
||||
with self.assertRaises(TypeError):
|
||||
class D(d): ...
|
||||
|
||||
def test_mro_entry_errors_2(self):
|
||||
class C_not_callable:
|
||||
__mro_entries__ = "Surprise!"
|
||||
c = C_not_callable()
|
||||
with self.assertRaises(TypeError):
|
||||
class D(c): ...
|
||||
class C_not_tuple:
|
||||
def __mro_entries__(self):
|
||||
return object
|
||||
c = C_not_tuple()
|
||||
with self.assertRaises(TypeError):
|
||||
class D(c): ...
|
||||
|
||||
def test_mro_entry_metaclass(self):
|
||||
meta_args = []
|
||||
class Meta(type):
|
||||
def __new__(mcls, name, bases, ns):
|
||||
meta_args.extend([mcls, name, bases, ns])
|
||||
return super().__new__(mcls, name, bases, ns)
|
||||
class A: ...
|
||||
class C:
|
||||
def __mro_entries__(self, bases):
|
||||
return (A,)
|
||||
c = C()
|
||||
class D(c, metaclass=Meta):
|
||||
x = 1
|
||||
self.assertEqual(meta_args[0], Meta)
|
||||
self.assertEqual(meta_args[1], 'D')
|
||||
self.assertEqual(meta_args[2], (A,))
|
||||
self.assertEqual(meta_args[3]['x'], 1)
|
||||
self.assertEqual(D.__bases__, (A,))
|
||||
self.assertEqual(D.__orig_bases__, (c,))
|
||||
self.assertEqual(D.__mro__, (D, A, object))
|
||||
self.assertEqual(D.__class__, Meta)
|
||||
|
||||
def test_mro_entry_type_call(self):
|
||||
# Substitution should _not_ happen in direct type call
|
||||
class C:
|
||||
def __mro_entries__(self, bases):
|
||||
return ()
|
||||
c = C()
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
"MRO entry resolution; "
|
||||
"use types.new_class()"):
|
||||
type('Bad', (c,), {})
|
||||
|
||||
|
||||
class TestClassGetitem(unittest.TestCase):
|
||||
def test_class_getitem(self):
|
||||
getitem_args = []
|
||||
class C:
|
||||
def __class_getitem__(*args, **kwargs):
|
||||
getitem_args.extend([args, kwargs])
|
||||
return None
|
||||
C[int, str]
|
||||
self.assertEqual(getitem_args[0], (C, (int, str)))
|
||||
self.assertEqual(getitem_args[1], {})
|
||||
|
||||
def test_class_getitem_format(self):
|
||||
class C:
|
||||
def __class_getitem__(cls, item):
|
||||
return f'C[{item.__name__}]'
|
||||
self.assertEqual(C[int], 'C[int]')
|
||||
self.assertEqual(C[C], 'C[C]')
|
||||
|
||||
def test_class_getitem_inheritance(self):
|
||||
class C:
|
||||
def __class_getitem__(cls, item):
|
||||
return f'{cls.__name__}[{item.__name__}]'
|
||||
class D(C): ...
|
||||
self.assertEqual(D[int], 'D[int]')
|
||||
self.assertEqual(D[D], 'D[D]')
|
||||
|
||||
def test_class_getitem_inheritance_2(self):
|
||||
class C:
|
||||
def __class_getitem__(cls, item):
|
||||
return 'Should not see this'
|
||||
class D(C):
|
||||
def __class_getitem__(cls, item):
|
||||
return f'{cls.__name__}[{item.__name__}]'
|
||||
self.assertEqual(D[int], 'D[int]')
|
||||
self.assertEqual(D[D], 'D[D]')
|
||||
|
||||
def test_class_getitem_classmethod(self):
|
||||
class C:
|
||||
@classmethod
|
||||
def __class_getitem__(cls, item):
|
||||
return f'{cls.__name__}[{item.__name__}]'
|
||||
class D(C): ...
|
||||
self.assertEqual(D[int], 'D[int]')
|
||||
self.assertEqual(D[D], 'D[D]')
|
||||
|
||||
def test_class_getitem_patched(self):
|
||||
class C:
|
||||
def __init_subclass__(cls):
|
||||
def __class_getitem__(cls, item):
|
||||
return f'{cls.__name__}[{item.__name__}]'
|
||||
cls.__class_getitem__ = classmethod(__class_getitem__)
|
||||
class D(C): ...
|
||||
self.assertEqual(D[int], 'D[int]')
|
||||
self.assertEqual(D[D], 'D[D]')
|
||||
|
||||
def test_class_getitem_with_builtins(self):
|
||||
class A(dict):
|
||||
called_with = None
|
||||
|
||||
def __class_getitem__(cls, item):
|
||||
cls.called_with = item
|
||||
class B(A):
|
||||
pass
|
||||
self.assertIs(B.called_with, None)
|
||||
B[int]
|
||||
self.assertIs(B.called_with, int)
|
||||
|
||||
def test_class_getitem_errors(self):
|
||||
class C_too_few:
|
||||
def __class_getitem__(cls):
|
||||
return None
|
||||
with self.assertRaises(TypeError):
|
||||
C_too_few[int]
|
||||
class C_too_many:
|
||||
def __class_getitem__(cls, one, two):
|
||||
return None
|
||||
with self.assertRaises(TypeError):
|
||||
C_too_many[int]
|
||||
|
||||
def test_class_getitem_errors_2(self):
|
||||
class C:
|
||||
def __class_getitem__(cls, item):
|
||||
return None
|
||||
with self.assertRaises(TypeError):
|
||||
C()[int]
|
||||
class E: ...
|
||||
e = E()
|
||||
e.__class_getitem__ = lambda cls, item: 'This will not work'
|
||||
with self.assertRaises(TypeError):
|
||||
e[int]
|
||||
class C_not_callable:
|
||||
__class_getitem__ = "Surprise!"
|
||||
with self.assertRaises(TypeError):
|
||||
C_not_callable[int]
|
||||
|
||||
def test_class_getitem_metaclass(self):
|
||||
class Meta(type):
|
||||
def __class_getitem__(cls, item):
|
||||
return f'{cls.__name__}[{item.__name__}]'
|
||||
self.assertEqual(Meta[int], 'Meta[int]')
|
||||
|
||||
def test_class_getitem_with_metaclass(self):
|
||||
class Meta(type): pass
|
||||
class C(metaclass=Meta):
|
||||
def __class_getitem__(cls, item):
|
||||
return f'{cls.__name__}[{item.__name__}]'
|
||||
self.assertEqual(C[int], 'C[int]')
|
||||
|
||||
def test_class_getitem_metaclass_first(self):
|
||||
class Meta(type):
|
||||
def __getitem__(cls, item):
|
||||
return 'from metaclass'
|
||||
class C(metaclass=Meta):
|
||||
def __class_getitem__(cls, item):
|
||||
return 'from __class_getitem__'
|
||||
self.assertEqual(C[int], 'from metaclass')
|
||||
|
||||
|
||||
@support.cpython_only
|
||||
class CAPITest(unittest.TestCase):
|
||||
|
||||
def test_c_class(self):
|
||||
from _testcapi import Generic, GenericAlias
|
||||
self.assertIsInstance(Generic.__class_getitem__(int), GenericAlias)
|
||||
|
||||
IntGeneric = Generic[int]
|
||||
self.assertIs(type(IntGeneric), GenericAlias)
|
||||
self.assertEqual(IntGeneric.__mro_entries__(()), (int,))
|
||||
class C(IntGeneric):
|
||||
pass
|
||||
self.assertEqual(C.__bases__, (int,))
|
||||
self.assertEqual(C.__orig_bases__, (IntGeneric,))
|
||||
self.assertEqual(C.__mro__, (C, int, object))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -20,7 +20,7 @@ mod decl {
|
||||
use crate::builtins::pybool::IntoPyBool;
|
||||
use crate::builtins::pystr::{PyStr, PyStrRef};
|
||||
use crate::builtins::pytype::PyTypeRef;
|
||||
use crate::builtins::{PyByteArray, PyBytes};
|
||||
use crate::builtins::{PyByteArray, PyBytes, PyTupleRef};
|
||||
use crate::byteslike::ArgBytesLike;
|
||||
use crate::common::{hash::PyHash, str::to_ascii};
|
||||
#[cfg(feature = "rustpython-compiler")]
|
||||
@@ -832,7 +832,7 @@ mod decl {
|
||||
pub fn __build_class__(
|
||||
function: PyFunctionRef,
|
||||
qualified_name: PyStrRef,
|
||||
bases: Args<PyTypeRef>,
|
||||
bases: Args,
|
||||
mut kwargs: KwArgs,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult {
|
||||
@@ -845,7 +845,41 @@ mod decl {
|
||||
vm.ctx.types.type_type.clone()
|
||||
};
|
||||
|
||||
for base in bases.clone() {
|
||||
let mut new_bases: Option<Vec<PyObjectRef>> = None;
|
||||
|
||||
let bases = PyTupleRef::with_elements(bases.into_vec(), &vm.ctx);
|
||||
|
||||
for (i, base) in bases.as_slice().iter().enumerate() {
|
||||
if base.isinstance(&vm.ctx.types.type_type) {
|
||||
if let Some(bases) = &mut new_bases {
|
||||
bases.push(base.clone());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
let mro_entries = vm.get_attribute_opt(base.clone(), "__mro_entries__")?;
|
||||
let entries = match mro_entries {
|
||||
Some(meth) => vm.invoke(&meth, (bases.clone(),))?,
|
||||
None => {
|
||||
if let Some(bases) = &mut new_bases {
|
||||
bases.push(base.clone());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let entries: PyTupleRef = entries
|
||||
.downcast()
|
||||
.map_err(|_| vm.new_type_error("__mro_entries__ must return a tuple".to_owned()))?;
|
||||
let new_bases = new_bases.get_or_insert_with(|| bases.as_slice()[..i].to_vec());
|
||||
new_bases.extend_from_slice(entries.as_slice());
|
||||
}
|
||||
|
||||
let new_bases = new_bases.map(|v| PyTupleRef::with_elements(v, &vm.ctx));
|
||||
let (orig_bases, bases) = match new_bases {
|
||||
Some(new) => (Some(bases), new),
|
||||
None => (None, bases),
|
||||
};
|
||||
|
||||
for base in bases.as_slice().iter() {
|
||||
let base_class = base.class();
|
||||
if base_class.issubclass(&metaclass) {
|
||||
metaclass = base.clone_class();
|
||||
@@ -858,7 +892,7 @@ mod decl {
|
||||
}
|
||||
}
|
||||
|
||||
let bases = bases.into_tuple(vm);
|
||||
let bases = bases.into_object();
|
||||
|
||||
// Prepare uses full __getattribute__ resolution chain.
|
||||
let prepare = vm.get_attribute(metaclass.clone().into_object(), "__prepare__")?;
|
||||
@@ -872,6 +906,10 @@ mod decl {
|
||||
let classcell = function.invoke_with_locals(().into(), Some(namespace.clone()), vm)?;
|
||||
let classcell = <Option<PyCellRef>>::try_from_object(vm, classcell)?;
|
||||
|
||||
if let Some(orig_bases) = orig_bases {
|
||||
namespace.set_item("__orig_bases__", orig_bases.into_object(), vm)?;
|
||||
}
|
||||
|
||||
let class = vm.invoke(
|
||||
metaclass.as_object(),
|
||||
FuncArgs::new(vec![name_obj, bases, namespace.into_object()], kwargs),
|
||||
|
||||
@@ -229,11 +229,6 @@ impl PyBaseObject {
|
||||
#[pyclassmethod(magic)]
|
||||
fn init_subclass(_cls: PyTypeRef) {}
|
||||
|
||||
#[pyclassmethod(magic)]
|
||||
fn class_getitem(cls: PyTypeRef, _args: FuncArgs) -> PyObjectRef {
|
||||
cls.into_object()
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
pub fn dir(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyList> {
|
||||
let attributes: PyAttributes = obj.class().get_attributes();
|
||||
|
||||
@@ -13,9 +13,8 @@ use super::mappingproxy::PyMappingProxy;
|
||||
use super::object;
|
||||
use super::pystr::{PyStr, PyStrRef};
|
||||
use super::staticmethod::PyStaticMethod;
|
||||
use super::tuple::PyTuple;
|
||||
use super::tuple::{PyTuple, PyTupleRef};
|
||||
use super::weakref::PyWeak;
|
||||
use crate::builtins::tuple::PyTupleTyped;
|
||||
use crate::function::{FuncArgs, KwArgs, OptionalArg};
|
||||
use crate::slots::{self, Callable, PyTpFlags, PyTypeSlots, SlotGetattro, SlotSetattro};
|
||||
use crate::utils::Either;
|
||||
@@ -399,7 +398,7 @@ impl PyType {
|
||||
}));
|
||||
}
|
||||
|
||||
let (name, bases, dict, kwargs): (PyStrRef, PyTupleTyped<PyTypeRef>, PyDictRef, KwArgs) =
|
||||
let (name, bases, dict, kwargs): (PyStrRef, PyTupleRef, PyDictRef, KwArgs) =
|
||||
args.clone().bind(vm)?;
|
||||
|
||||
let bases = bases.as_slice();
|
||||
@@ -407,16 +406,25 @@ impl PyType {
|
||||
let base = vm.ctx.types.object_type.clone();
|
||||
(metatype, base.clone(), vec![base])
|
||||
} else {
|
||||
// TODO
|
||||
// for base in &bases {
|
||||
// if PyType_Check(base) { continue; }
|
||||
// _PyObject_LookupAttrId(base, PyId___mro_entries__, &base)?
|
||||
// Err(new_type_error( "type() doesn't support MRO entry resolution; "
|
||||
// "use types.new_class()"))
|
||||
// }
|
||||
let bases = bases
|
||||
.iter()
|
||||
.map(|obj| {
|
||||
obj.clone().downcast::<PyType>().or_else(|obj| {
|
||||
if vm.get_attribute_opt(obj, "__mro_entries__")?.is_some() {
|
||||
Err(vm.new_type_error(
|
||||
"type() doesn't support MRO entry resolution; \
|
||||
use types.new_class()"
|
||||
.to_owned(),
|
||||
))
|
||||
} else {
|
||||
Err(vm.new_type_error("bases must be types".to_owned()))
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
|
||||
// Search the bases for the proper metatype to deal with this:
|
||||
let winner = calculate_meta_class(metatype.clone(), bases, vm)?;
|
||||
let winner = calculate_meta_class(metatype.clone(), &bases, vm)?;
|
||||
let metatype = if !winner.is(&metatype) {
|
||||
#[allow(clippy::redundant_clone)] // false positive
|
||||
if let Some(ref tp_new) = winner.clone().slots.new {
|
||||
@@ -429,9 +437,9 @@ impl PyType {
|
||||
metatype
|
||||
};
|
||||
|
||||
let base = best_base(bases, vm)?;
|
||||
let base = best_base(&bases, vm)?;
|
||||
|
||||
(metatype, base, bases.to_vec())
|
||||
(metatype, base, bases)
|
||||
};
|
||||
|
||||
let mut attributes = dict.to_attributes();
|
||||
|
||||
@@ -686,25 +686,19 @@ where
|
||||
{
|
||||
fn get_item(&self, key: T, vm: &VirtualMachine) -> PyResult {
|
||||
match vm.get_special_method(self.clone(), "__getitem__")? {
|
||||
Ok(special_method) => special_method.invoke((key,), vm),
|
||||
Ok(special_method) => return special_method.invoke((key,), vm),
|
||||
Err(obj) => {
|
||||
if obj.isinstance(&vm.ctx.types.type_type) {
|
||||
vm.get_special_method(obj, "__class_getitem__")?
|
||||
.map_err(|obj2| {
|
||||
vm.new_type_error(format!(
|
||||
"'{}' object is not subscriptable",
|
||||
obj2.class().name
|
||||
))
|
||||
})?
|
||||
.invoke((key,), vm)
|
||||
} else {
|
||||
Err(vm.new_type_error(format!(
|
||||
"'{}' object is not subscriptable",
|
||||
obj.class().name
|
||||
)))
|
||||
if let Some(class_getitem) = vm.get_attribute_opt(obj, "__class_getitem__")? {
|
||||
return vm.invoke(&class_getitem, (key,));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(vm.new_type_error(format!(
|
||||
"'{}' object is not subscriptable",
|
||||
self.class().name
|
||||
)))
|
||||
}
|
||||
|
||||
fn set_item(&self, key: T, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
|
||||
|
||||
Reference in New Issue
Block a user