diff --git a/derive/src/pyclass.rs b/derive/src/pyclass.rs index 6e7ce96ef..f06e66408 100644 --- a/derive/src/pyclass.rs +++ b/derive/src/pyclass.rs @@ -342,20 +342,22 @@ where let into_func = quote_spanned! {ident.span() => #transform(Self::#ident) }; - if slot_name == "call" { - quote! { - slots.#slot_ident.store( - Some( - |vm: &::rustpython_vm::VirtualMachine, args: ::rustpython_vm::function::PyFuncArgs| -> ::rustpython_vm::pyobject::PyResult { - ::rustpython_vm::function::IntoPyNativeFunc::call(&Self::#ident, vm, args) - } as _ - ) - ); - } - } else { - quote! { + match slot_name.as_str() { + "call" => quote! { + slots.#slot_ident.store(Some( + |vm: &::rustpython_vm::VirtualMachine, args: ::rustpython_vm::function::PyFuncArgs| -> ::rustpython_vm::pyobject::PyResult { + ::rustpython_vm::function::IntoPyNativeFunc::call(&Self::#ident, vm, args) + } as _ + )); + }, + "descr_get" => quote! { + slots.#slot_ident.store(Some( + Self::#ident as _ + )) + }, + _ => quote! { slots.#slot_ident = Some(#into_func); - } + }, } }; diff --git a/extra_tests/snippets/callables.py b/extra_tests/snippets/callables.py index 43b570628..c549ef468 100644 --- a/extra_tests/snippets/callables.py +++ b/extra_tests/snippets/callables.py @@ -9,3 +9,10 @@ class Callable(): c = Callable() assert 1 == c() assert 2 == c() + +class Inherited(Callable): + pass + +i = Inherited() + +assert 1 == i() diff --git a/vm/src/obj/objtype.rs b/vm/src/obj/objtype.rs index 3a2c030fa..e50d22dd3 100644 --- a/vm/src/obj/objtype.rs +++ b/vm/src/obj/objtype.rs @@ -210,24 +210,25 @@ impl PyClass { if let Some(attr) = mcl.get_attr(&name) { let attr_class = attr.lease_class(); if attr_class.has_attr("__set__") { - if let Some(ref descriptor) = attr_class.get_attr("__get__") { - drop(attr_class); + if let Some(ref descr_get) = + PyLease::into_pyref(attr_class).first_in_mro(|cls| cls.slots.descr_get.load()) + { let mcl = PyLease::into_pyref(mcl).into_object(); - return vm.invoke(descriptor, vec![attr, zelf.into_object(), mcl]); + return descr_get( + vm, + attr, + Some(zelf.into_object()), + OptionalArg::Present(mcl), + ); } } } if let Some(attr) = zelf.get_attr(&name) { let attr_class = attr.class(); - let slots = &attr_class.slots; - if let Some(ref descr_get) = slots.descr_get { + if let Some(ref descr_get) = attr_class.first_in_mro(|cls| cls.slots.descr_get.load()) { drop(mcl); return descr_get(vm, attr, None, OptionalArg::Present(zelf.into_object())); - } else if let Some(ref descriptor) = attr_class.get_attr("__get__") { - drop(mcl); - // TODO: is this nessessary? - return vm.invoke(descriptor, vec![attr, vm.ctx.none(), zelf.into_object()]); } } @@ -706,11 +707,12 @@ pub fn new( if base.slots.flags.has_feature(PyTpFlags::HAS_DICT) { slots.flags |= PyTpFlags::HAS_DICT } - for slot_name in ["__call__"].iter() { + for slot_name in ["__call__", "__get__"].iter() { if attrs.contains_key(*slot_name) { slots.update_slot_func(*slot_name); } } + let new_type = PyRef::new_ref( PyClass { name: String::from(name), diff --git a/vm/src/slots.rs b/vm/src/slots.rs index d492e0e1a..cd2297f9d 100644 --- a/vm/src/slots.rs +++ b/vm/src/slots.rs @@ -5,6 +5,7 @@ use crate::common::hash::PyHash; use crate::function::{IntoPyNativeFunc, OptionalArg, PyFuncArgs, PyNativeFunc}; use crate::pyobject::{ IdProtocol, PyComparisonValue, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + TypeProtocol, }; use crate::VirtualMachine; use crossbeam_utils::atomic::AtomicCell; @@ -40,18 +41,22 @@ impl Default for PyTpFlags { } } +type GenericFunc = fn(&VirtualMachine, PyFuncArgs) -> PyResult; +type DescrGetFunc = + fn(&VirtualMachine, PyObjectRef, Option, OptionalArg) -> PyResult; +type HashFunc = Box PyResult)>; + #[derive(Default)] pub struct PyClassSlots { pub flags: PyTpFlags, pub name: PyRwLock>, // tp_name, not class name pub new: Option, - pub call: AtomicCell PyResult>>, - pub descr_get: Option, + pub call: AtomicCell>, + pub descr_get: AtomicCell>, pub hash: Option, pub cmp: Option, } -type HashFunc = Box PyResult)>; type CmpFunc = Box< py_dyn_fn!( dyn Fn( @@ -75,10 +80,14 @@ impl PyClassSlots { pub(crate) fn update_slot_func(&self, name: &str) { match name { "__call__" => { - self.call - .store(Some(|vm: &VirtualMachine, args: PyFuncArgs| -> PyResult { - IntoPyNativeFunc::call(&call_magic_call, vm, args) - } as _)) + let func: GenericFunc = + |vm, args| { IntoPyNativeFunc::call(&call_magic_call, vm, args) } as _; + self.call.store(Some(func)) + } + "__get__" => { + let func: DescrGetFunc = + |vm, zelf, obj, cls| { call_magic_descr_get(vm, zelf, obj, cls) } as _; + self.descr_get.store(Some(func)) } _ => (), } @@ -98,17 +107,37 @@ pub trait SlotCall: PyValue { fn call(zelf: PyRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult; } +#[inline] +fn get_class_magic(zelf: &PyObjectRef, name: &str) -> PyObjectRef { + zelf.get_class_attr(name).unwrap() + // let cls = zelf.lease_class(); + // let attrs = cls.attributes.read(); + // let attr = attrs.get(name); + // attr.unwrap().clone() +} + fn call_magic_call(zelf: PyObjectRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { - use crate::obj::objstr::PyString; - use crate::pyobject::IntoPyRef; + let magic = get_class_magic(&zelf, "__call__"); + let magic = vm.call_if_get_descriptor(magic, zelf.clone())?; + args.insert(zelf); + vm.invoke(&magic, args) +} - let magic_call = vm.generic_getattribute(zelf, PyString::from("__call__").into_pyref(vm))?; - vm.invoke(&magic_call, args) - - // use crate::pyobject::TypeProtocol; - // let magic_call = zelf.get_class_attr("__call__").unwrap(); - // args.insert( zelf); - // vm.invoke(&magic_call, args) +fn call_magic_descr_get( + vm: &VirtualMachine, + zelf: PyObjectRef, + obj: Option, + cls: OptionalArg, +) -> PyResult { + let magic = get_class_magic(&zelf, "__get__"); + vm.invoke( + &magic, + vec![ + zelf, + vm.unwrap_or_none(obj), + vm.unwrap_or_none(cls.into_option()), + ], + ) } pub type PyDescrGetFunc = Box< diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 09b9a8b69..b84b06555 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -841,18 +841,10 @@ impl VirtualMachine { obj: Option, cls: Option, ) -> Option { - let descr_class = descr.class(); - let slots = &descr_class.slots; - if let Some(descr_get) = slots.descr_get.as_ref() { - Some(descr_get(self, descr, obj, OptionalArg::from_option(cls))) - } else if let Some(ref descriptor) = descr_class.get_attr("__get__") { - Some(self.invoke( - descriptor, - vec![descr, self.unwrap_or_none(obj), self.unwrap_or_none(cls)], - )) - } else { - None - } + descr + .class() + .first_in_mro(|cls| cls.slots.descr_get.load()) + .map(|descr_get| descr_get(self, descr, obj, OptionalArg::from_option(cls))) } pub fn call_get_descriptor(&self, descr: PyObjectRef, obj: PyObjectRef) -> Option {