diff --git a/vm/src/builtins/function.rs b/vm/src/builtins/function.rs index 339cb44ce..82ebaa600 100644 --- a/vm/src/builtins/function.rs +++ b/vm/src/builtins/function.rs @@ -461,7 +461,8 @@ impl Comparable for PyBoundMethod { impl GetAttr for PyBoundMethod { fn getattro(zelf: PyRef, name: PyStrRef, vm: &VirtualMachine) -> PyResult { - if let Some(obj) = zelf.get_class_attr(name.as_str()) { + let class_attr = zelf.get_class_attr(name.as_str()); + if let Some(obj) = class_attr { return vm.call_if_get_descriptor(obj, zelf.into()); } zelf.function.clone().get_attr(name, vm) diff --git a/vm/src/builtins/module.rs b/vm/src/builtins/module.rs index c142eb381..1d34d0e63 100644 --- a/vm/src/builtins/module.rs +++ b/vm/src/builtins/module.rs @@ -42,7 +42,8 @@ impl PyModule { #[pymethod(magic)] fn init(zelf: PyRef, args: ModuleInitArgs, vm: &VirtualMachine) { - debug_assert!(crate::AsPyObject::class(zelf.as_object()) + debug_assert!(zelf + .class() .slots .flags .has_feature(crate::types::PyTypeFlags::HAS_DICT)); diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 48d6e9564..230fafb4d 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -428,10 +428,6 @@ where self.class().get_attr(attr_name) } - fn has_class_attr(&self, attr_name: &str) -> bool { - self.class().has_attr(attr_name) - } - /// Determines if `obj` actually an instance of `cls`, this doesn't call __instancecheck__, so only /// use this if `cls` is known to have not overridden the base __instancecheck__ magic method. #[inline] diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 0e4cd3ff8..d578b16ab 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -187,14 +187,17 @@ fn length_wrapper(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { fn as_mapping_wrapper(zelf: &PyObject, _vm: &VirtualMachine) -> PyMappingMethods { PyMappingMethods { - length: then_some_closure!(zelf.has_class_attr("__len__"), |mapping, vm| { + length: then_some_closure!(zelf.class().has_attr("__len__"), |mapping, vm| { length_wrapper(mapping.obj.to_owned(), vm) }), - subscript: then_some_closure!(zelf.has_class_attr("__getitem__"), |mapping, needle, vm| { - vm.call_special_method(mapping.obj.to_owned(), "__getitem__", (needle.to_owned(),)) - }), + subscript: then_some_closure!( + zelf.class().has_attr("__getitem__"), + |mapping, needle, vm| { + vm.call_special_method(mapping.obj.to_owned(), "__getitem__", (needle.to_owned(),)) + } + ), ass_subscript: then_some_closure!( - zelf.has_class_attr("__setitem__") | zelf.has_class_attr("__delitem__"), + zelf.class().has_attr("__setitem__") | zelf.class().has_attr("__delitem__"), |mapping, needle, value, vm| match value { Some(value) => vm .call_special_method( @@ -216,19 +219,19 @@ fn as_mapping_wrapper(zelf: &PyObject, _vm: &VirtualMachine) -> PyMappingMethods } fn as_sequence_wrapper(zelf: &PyObject, _vm: &VirtualMachine) -> Cow<'static, PySequenceMethods> { - if !zelf.has_class_attr("__getitem__") { + if !zelf.class().has_attr("__getitem__") { return Cow::Borrowed(PySequenceMethods::not_implemented()); } Cow::Owned(PySequenceMethods { - length: then_some_closure!(zelf.has_class_attr("__len__"), |seq, vm| { + length: then_some_closure!(zelf.class().has_attr("__len__"), |seq, vm| { length_wrapper(seq.obj.to_owned(), vm) }), item: Some(|seq, i, vm| { vm.call_special_method(seq.obj.to_owned(), "__getitem__", (i.into_pyobject(vm),)) }), ass_item: then_some_closure!( - zelf.has_class_attr("__setitem__") | zelf.has_class_attr("__delitem__"), + zelf.class().has_attr("__setitem__") | zelf.class().has_attr("__delitem__"), |seq, i, value, vm| match value { Some(value) => vm .call_special_method( diff --git a/vm/src/vm.rs b/vm/src/vm.rs index cdc663e9e..5975553dc 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -747,10 +747,11 @@ impl VirtualMachine { where F: FnOnce() -> String, { - match obj.get_class_attr(method_name) { - Some(method) => self.call_if_get_descriptor(method, obj), - None => Err(self.new_type_error(err_msg())), - } + let method = obj + .class() + .get_attr(method_name) + .ok_or_else(|| self.new_type_error(err_msg()))?; + self.call_if_get_descriptor(method, obj) } // TODO: remove + transfer over to get_special_method diff --git a/vm/src/vm_ops.rs b/vm/src/vm_ops.rs index b75515485..eaca41057 100644 --- a/vm/src/vm_ops.rs +++ b/vm/src/vm_ops.rs @@ -170,8 +170,8 @@ impl VirtualMachine { unsupported: fn(&VirtualMachine, &PyObject, &PyObject) -> PyResult, ) -> PyResult { if rhs.isinstance(&lhs.class()) { - let lop = lhs.class().get_attr(reflection); - let rop = rhs.class().get_attr(reflection); + let lop = lhs.get_class_attr(reflection); + let rop = rhs.get_class_attr(reflection); if let Some((lop, rop)) = lop.zip(rop) { if !lop.is(&rop) { if let Ok(r) = self.call_or_unsupported(rhs, lhs, reflection, |vm, _, _| {