diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 37086576a3..6a1f91aec9 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -2,7 +2,7 @@ //! https://docs.python.org/3/c-api/object.html use crate::{ - builtins::{pystr::IntoPyStrRef, PyBytes, PyInt, PyStrRef, PyTupleRef}, + builtins::{pystr::IntoPyStrRef, PyBytes, PyInt, PyStrRef, PyTupleRef, PyTypeRef}, bytesinner::ByteInnerNewOptions, common::{hash::PyHash, str::to_ascii}, function::{IntoPyObject, OptionalArg}, @@ -227,8 +227,136 @@ impl PyObject { } } + // Equivalent to check_class. Masks Attribute errors (into TypeErrors) and lets everything + // else go through. + fn check_cls(&self, cls: &PyObject, vm: &VirtualMachine, msg: F) -> PyResult + where + F: Fn() -> String, + { + cls.to_owned().get_attr("__bases__", vm).map_err(|e| { + // Only mask AttributeErrors. + if e.class().is(&vm.ctx.exceptions.attribute_error) { + vm.new_type_error(msg()) + } else { + e + } + }) + } + + fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { + let mut derived = self; + let mut first_item: PyObjectRef; + loop { + if derived.is(cls) { + return Ok(true); + } + + let bases = derived.to_owned().get_attr("__bases__", vm)?; + let tuple = PyTupleRef::try_from_object(vm, bases)?; + + let n = tuple.len(); + match n { + 0 => { + return Ok(false); + } + 1 => { + first_item = tuple.fast_getitem(0).clone(); + derived = &first_item; + continue; + } + _ => { + for i in 0..n { + if let Ok(true) = tuple.fast_getitem(i).abstract_issubclass(cls, vm) { + return Ok(true); + } + } + } + } + + return Ok(false); + } + } + + fn recursive_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { + if let (Ok(obj), Ok(cls)) = ( + PyTypeRef::try_from_object(vm, self.to_owned()), + PyTypeRef::try_from_object(vm, cls.to_owned()), + ) { + Ok(obj.issubclass(cls)) + } else { + self.check_cls(self, vm, || { + format!("issubclass() arg 1 must be a class, not {}", self.class()) + }) + .and(self.check_cls(cls, vm, || { + format!( + "issubclass() arg 2 must be a class or tuple of classes, not {}", + cls.class() + ) + })) + .and(self.abstract_issubclass(cls, vm)) + } + } + + /// Determines if `self` is a subclass of `cls`, either directly, indirectly or virtually + /// via the __subclasscheck__ magic method. pub fn is_subclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { - vm.issubclass(self, cls) + if cls.class().is(&vm.ctx.types.type_type) { + if self.is(cls) { + return Ok(true); + } + return self.recursive_issubclass(cls, vm); + } + + if let Ok(tuple) = PyTupleRef::try_from_object(vm, cls.to_owned()) { + for typ in tuple.as_slice().iter() { + if vm.with_recursion("in __subclasscheck__", || self.is_subclass(typ, vm))? { + return Ok(true); + } + } + return Ok(false); + } + + if let Ok(meth) = vm.get_special_method(cls.to_owned(), "__subclasscheck__")? { + let ret = vm.with_recursion("in __subclasscheck__", || { + meth.invoke((self.to_owned(),), vm) + })?; + return ret.try_to_bool(vm); + } + + self.recursive_issubclass(cls, vm) + } + + fn abstract_isinstance(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { + if let Ok(typ) = PyTypeRef::try_from_object(vm, cls.to_owned()) { + if self.class().issubclass(typ.clone()) { + Ok(true) + } else if let Ok(icls) = + PyTypeRef::try_from_object(vm, self.to_owned().get_attr("__class__", vm)?) + { + if icls.is(&self.class()) { + Ok(false) + } else { + Ok(icls.issubclass(typ)) + } + } else { + Ok(false) + } + } else { + self.check_cls(cls, vm, || { + format!( + "isinstance() arg 2 must be a type or tuple of types, not {}", + cls.class() + ) + }) + .and_then(|_| { + let icls: PyObjectRef = self.to_owned().get_attr("__class__", vm)?; + if vm.is_none(&icls) { + Ok(false) + } else { + icls.abstract_issubclass(cls, vm) + } + }) + } } /// Determines if `self` is an instance of `cls`, either directly, indirectly or virtually via @@ -241,7 +369,7 @@ impl PyObject { } if cls.class().is(&vm.ctx.types.type_type) { - return vm.abstract_isinstance(self, cls); + return self.abstract_isinstance(cls, vm); } if let Ok(tuple) = PyTupleRef::try_from_object(vm, cls.to_owned()) { @@ -260,7 +388,7 @@ impl PyObject { return ret.try_to_bool(vm); } - vm.abstract_isinstance(self, cls) + self.abstract_isinstance(cls, vm) } pub fn hash(&self, vm: &VirtualMachine) -> PyResult { diff --git a/vm/src/stdlib/builtins.rs b/vm/src/stdlib/builtins.rs index 607ef42260..a8b4db02d1 100644 --- a/vm/src/stdlib/builtins.rs +++ b/vm/src/stdlib/builtins.rs @@ -402,7 +402,7 @@ mod builtins { #[pyfunction] fn issubclass(subclass: PyObjectRef, typ: PyObjectRef, vm: &VirtualMachine) -> PyResult { - vm.issubclass(&subclass, &typ) + subclass.is_subclass(&typ, vm) } #[pyfunction] diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 2539078c06..c76e041e55 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -992,139 +992,6 @@ impl VirtualMachine { } } - // Equivalent to check_class. Masks Attribute errors (into TypeErrors) and lets everything - // else go through. - fn check_cls(&self, cls: &PyObject, msg: F) -> PyResult - where - F: Fn() -> String, - { - cls.to_owned().get_attr("__bases__", self).map_err(|e| { - // Only mask AttributeErrors. - if e.class().is(&self.ctx.exceptions.attribute_error) { - self.new_type_error(msg()) - } else { - e - } - }) - } - - pub fn abstract_isinstance(&self, obj: &PyObject, cls: &PyObject) -> PyResult { - if let Ok(typ) = PyTypeRef::try_from_object(self, cls.to_owned()) { - if obj.class().issubclass(typ.clone()) { - Ok(true) - } else if let Ok(icls) = - PyTypeRef::try_from_object(self, obj.to_owned().get_attr("__class__", self)?) - { - if icls.is(&obj.class()) { - Ok(false) - } else { - Ok(icls.issubclass(typ)) - } - } else { - Ok(false) - } - } else { - self.check_cls(cls, || { - format!( - "isinstance() arg 2 must be a type or tuple of types, not {}", - cls.class() - ) - }) - .and_then(|_| { - let icls: PyObjectRef = obj.to_owned().get_attr("__class__", self)?; - if self.is_none(&icls) { - Ok(false) - } else { - self.abstract_issubclass(icls, cls) - } - }) - } - } - - fn abstract_issubclass(&self, subclass: PyObjectRef, cls: &PyObject) -> PyResult { - let mut derived = subclass; - loop { - if derived.is(cls) { - return Ok(true); - } - - let bases = derived.get_attr("__bases__", self)?; - let tuple = PyTupleRef::try_from_object(self, bases)?; - - let n = tuple.len(); - match n { - 0 => { - return Ok(false); - } - 1 => { - derived = tuple.fast_getitem(0); - continue; - } - _ => { - for i in 0..n { - if let Ok(true) = self.abstract_issubclass(tuple.fast_getitem(i), cls) { - return Ok(true); - } - } - } - } - - return Ok(false); - } - } - - fn recursive_issubclass(&self, subclass: &PyObject, cls: &PyObject) -> PyResult { - if let (Ok(subclass), Ok(cls)) = ( - PyTypeRef::try_from_object(self, subclass.to_owned()), - PyTypeRef::try_from_object(self, cls.to_owned()), - ) { - Ok(subclass.issubclass(cls)) - } else { - self.check_cls(subclass, || { - format!( - "issubclass() arg 1 must be a class, not {}", - subclass.class() - ) - }) - .and(self.check_cls(cls, || { - format!( - "issubclass() arg 2 must be a class or tuple of classes, not {}", - cls.class() - ) - })) - .and(self.abstract_issubclass(subclass.to_owned(), cls)) - } - } - - /// Determines if `subclass` is a subclass of `cls`, either directly, indirectly or virtually - /// via the __subclasscheck__ magic method. - pub fn issubclass(&self, subclass: &PyObject, cls: &PyObject) -> PyResult { - if cls.class().is(&self.ctx.types.type_type) { - if subclass.is(cls) { - return Ok(true); - } - return self.recursive_issubclass(subclass, cls); - } - - if let Ok(tuple) = PyTupleRef::try_from_object(self, cls.to_owned()) { - for typ in tuple.as_slice().iter() { - if self.with_recursion("in __subclasscheck__", || self.issubclass(subclass, typ))? { - return Ok(true); - } - } - return Ok(false); - } - - if let Ok(meth) = self.get_special_method(cls.to_owned(), "__subclasscheck__")? { - let ret = self.with_recursion("in __subclasscheck__", || { - meth.invoke((subclass.to_owned(),), self) - })?; - return ret.try_to_bool(self); - } - - self.recursive_issubclass(subclass, cls) - } - pub fn call_get_descriptor_specific( &self, descr: PyObjectRef,