diff --git a/vm/src/builtins/pytype.rs b/vm/src/builtins/pytype.rs index 997419d14e..aeb407949f 100644 --- a/vm/src/builtins/pytype.rs +++ b/vm/src/builtins/pytype.rs @@ -1,6 +1,6 @@ use super::{ - mappingproxy::PyMappingProxy, object, PyClassMethod, PyDictRef, PyInt, PyList, PyStaticMethod, - PyStr, PyStrRef, PyTuple, PyTupleRef, PyWeak, + mappingproxy::PyMappingProxy, object, PyClassMethod, PyDictRef, PyList, PyStaticMethod, PyStr, + PyStrRef, PyTuple, PyTupleRef, PyWeak, }; use crate::common::{ ascii, @@ -9,14 +9,11 @@ use crate::common::{ }; use crate::{ function::{FuncArgs, KwArgs, OptionalArg}, - protocol::{PyIterReturn, PyMappingMethods}, - types::{self, Callable, PyComparisonOp, PyTypeFlags, PyTypeSlots, SlotGetattro, SlotSetattro}, - utils::Either, - IdProtocol, PyAttributes, PyClassImpl, PyComparisonValue, PyContext, PyLease, PyObjectRef, - PyRef, PyResult, PyValue, StaticType, TypeProtocol, VirtualMachine, + types::{Callable, PyTypeFlags, PyTypeSlots, SlotGetattro, SlotSetattro}, + IdProtocol, PyAttributes, PyClassImpl, PyContext, PyLease, PyObjectRef, PyRef, PyResult, + PyValue, StaticType, TypeProtocol, VirtualMachine, }; use itertools::Itertools; -use num_traits::ToPrimitive; use std::collections::HashSet; use std::fmt; use std::ops::Deref; @@ -195,161 +192,6 @@ impl PyType { attributes } - - pub(crate) fn update_slot(&self, name: &str, add: bool) { - macro_rules! update_slot { - ($name:ident, $func:expr) => {{ - self.slots.$name.store(if add { Some($func) } else { None }); - }}; - } - match name { - "__new__" => { - let func: types::NewFunc = - |cls: PyTypeRef, mut args: FuncArgs, vm: &VirtualMachine| { - let new = vm - .get_attribute_opt(cls.as_object().clone(), "__new__")? - .unwrap(); - args.prepend_arg(cls.into()); - vm.invoke(&new, args) - }; - update_slot!(new, func); - } - "__call__" => { - let func: types::GenericMethod = - |zelf, args, vm| vm.call_special_method(zelf.clone(), "__call__", args); - update_slot!(call, func); - } - "__get__" => { - let func: types::DescrGetFunc = - |zelf, obj, cls, vm| vm.call_special_method(zelf, "__get__", (obj, cls)); - update_slot!(descr_get, func); - } - "__set__" | "__delete__" => { - let func: types::DescrSetFunc = |zelf, obj, value, vm| { - match value { - Some(val) => vm.call_special_method(zelf, "__set__", (obj, val)), - None => vm.call_special_method(zelf, "__delete__", (obj,)), - } - .map(drop) - }; - update_slot!(descr_set, func); - } - "__hash__" => { - let func: types::HashFunc = |zelf, vm| { - let hash_obj = vm.call_special_method(zelf.clone(), "__hash__", ())?; - match hash_obj.payload_if_subclass::(vm) { - Some(py_int) => { - Ok(rustpython_common::hash::hash_bigint(py_int.as_bigint())) - } - None => Err(vm - .new_type_error("__hash__ method should return an integer".to_owned())), - } - } as _; - update_slot!(hash, func); - } - "__del__" => { - let func: types::DelFunc = |zelf, vm| { - vm.call_special_method(zelf.clone(), "__del__", ())?; - Ok(()) - } as _; - update_slot!(del, func); - } - "__eq__" | "__ne__" | "__le__" | "__lt__" | "__ge__" | "__gt__" => { - update_slot!(richcompare, richcompare_wrapper); - } - "__getattribute__" => { - let func: types::GetattroFunc = - |zelf, name, vm| vm.call_special_method(zelf, "__getattribute__", (name,)); - update_slot!(getattro, func); - } - "__setattr__" => { - let func: types::SetattroFunc = |zelf, name, value, vm| { - match value { - Some(value) => { - vm.call_special_method(zelf.clone(), "__setattr__", (name, value))?; - } - None => { - vm.call_special_method(zelf.clone(), "__delattr__", (name,))?; - } - }; - Ok(()) - }; - update_slot!(setattro, func); - } - "__iter__" => { - let func: types::IterFunc = |zelf, vm| vm.call_special_method(zelf, "__iter__", ()); - update_slot!(iter, func); - } - "__next__" => { - let func: types::IterNextFunc = |zelf, vm| { - PyIterReturn::from_pyresult( - vm.call_special_method(zelf.clone(), "__next__", ()), - vm, - ) - }; - update_slot!(iternext, func); - } - "__len__" | "__getitem__" | "__setitem__" | "__delitem__" => { - macro_rules! then_some_closure { - ($cond:expr, $closure:expr) => { - if $cond { - Some($closure) - } else { - None - } - }; - } - - let func: types::MappingFunc = |zelf, _vm| { - Ok(PyMappingMethods { - length: then_some_closure!(zelf.has_class_attr("__len__"), |zelf, vm| { - vm.call_special_method(zelf, "__len__", ()).map(|obj| { - obj.payload_if_subclass::(vm) - .map(|length_obj| { - length_obj.as_bigint().to_usize().ok_or_else(|| { - vm.new_value_error( - "__len__() should return >= 0".to_owned(), - ) - }) - }) - .unwrap() - })? - }), - subscript: then_some_closure!( - zelf.has_class_attr("__getitem__"), - |zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine| { - vm.call_special_method(zelf, "__getitem__", (needle,)) - } - ), - ass_subscript: then_some_closure!( - zelf.has_class_attr("__setitem__") | zelf.has_class_attr("__delitem__"), - |zelf, needle, value, vm| match value { - Some(value) => vm - .call_special_method(zelf, "__setitem__", (needle, value),) - .map(|_| Ok(()))?, - None => vm - .call_special_method(zelf, "__delitem__", (needle,)) - .map(|_| Ok(()))?, - } - ), - }) - }; - update_slot!(as_mapping, func); - // TODO: need to update sequence protocol too - } - _ => {} - } - } -} - -fn richcompare_wrapper( - zelf: &PyObjectRef, - other: &PyObjectRef, - op: PyComparisonOp, - vm: &VirtualMachine, -) -> PyResult> { - vm.call_special_method(zelf.clone(), op.method_name(), (other.clone(),)) - .map(Either::A) } impl PyTypeRef { diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 10a64cc00b..146fda3c09 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -1,6 +1,6 @@ use crate::common::{hash::PyHash, lock::PyRwLock}; use crate::{ - builtins::{PyStrRef, PyTypeRef}, + builtins::{PyInt, PyStrRef, PyType, PyTypeRef}, function::{FromArgs, FuncArgs, IntoPyResult, OptionalArg}, protocol::{PyBuffer, PyIterReturn, PyMappingMethods}, utils::Either, @@ -8,70 +8,9 @@ use crate::{ VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; +use num_traits::ToPrimitive; use std::cmp::Ordering; -bitflags! { - pub struct PyTypeFlags: u64 { - const HEAPTYPE = 1 << 9; - const BASETYPE = 1 << 10; - const METHOD_DESCR = 1 << 17; - const HAS_DICT = 1 << 40; - - #[cfg(debug_assertions)] - const _CREATED_WITH_FLAGS = 1 << 63; - } -} - -impl PyTypeFlags { - // Default used for both built-in and normal classes: empty, for now. - // CPython default: Py_TPFLAGS_HAVE_STACKLESS_EXTENSION | Py_TPFLAGS_HAVE_VERSION_TAG - pub const DEFAULT: Self = Self::empty(); - - // CPython: See initialization of flags in type_new. - /// Used for types created in Python. Subclassable and are a - /// heaptype. - pub fn heap_type_flags() -> Self { - Self::DEFAULT | Self::HEAPTYPE | Self::BASETYPE - } - - pub fn has_feature(self, flag: Self) -> bool { - self.contains(flag) - } - - #[cfg(debug_assertions)] - pub fn is_created_with_flags(self) -> bool { - self.contains(Self::_CREATED_WITH_FLAGS) - } -} - -impl Default for PyTypeFlags { - fn default() -> Self { - Self::DEFAULT - } -} - -pub(crate) type GenericMethod = fn(&PyObjectRef, FuncArgs, &VirtualMachine) -> PyResult; -pub(crate) type NewFunc = fn(PyTypeRef, FuncArgs, &VirtualMachine) -> PyResult; -pub(crate) type DelFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult<()>; -pub(crate) type DescrGetFunc = - fn(PyObjectRef, Option, Option, &VirtualMachine) -> PyResult; -pub(crate) type DescrSetFunc = - fn(PyObjectRef, PyObjectRef, Option, &VirtualMachine) -> PyResult<()>; -pub(crate) type HashFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; -pub(crate) type RichCompareFunc = fn( - &PyObjectRef, - &PyObjectRef, - PyComparisonOp, - &VirtualMachine, -) -> PyResult>; -pub(crate) type GetattroFunc = fn(PyObjectRef, PyStrRef, &VirtualMachine) -> PyResult; -pub(crate) type SetattroFunc = - fn(&PyObjectRef, PyStrRef, Option, &VirtualMachine) -> PyResult<()>; -pub(crate) type BufferFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; -pub(crate) type MappingFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; -pub(crate) type IterFunc = fn(PyObjectRef, &VirtualMachine) -> PyResult; -pub(crate) type IterNextFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; - // The corresponding field in CPython is `tp_` prefixed. // e.g. name -> tp_name #[derive(Default)] @@ -84,7 +23,7 @@ pub struct PyTypeSlots { // Method suites for standard classes // tp_as_number // tp_as_sequence - pub as_mapping: AtomicCell>, + pub as_mapping: AtomicCell>, // More standard operations (here for binary compatibility) pub hash: AtomicCell>, @@ -94,7 +33,7 @@ pub struct PyTypeSlots { pub setattro: AtomicCell>, // Functions to access object as input/output buffer - pub as_buffer: Option, + pub as_buffer: Option, // Assigned meaning in release 2.1 // rich comparisons @@ -144,6 +83,250 @@ impl std::fmt::Debug for PyTypeSlots { } } +bitflags! { + pub struct PyTypeFlags: u64 { + const HEAPTYPE = 1 << 9; + const BASETYPE = 1 << 10; + const METHOD_DESCR = 1 << 17; + const HAS_DICT = 1 << 40; + + #[cfg(debug_assertions)] + const _CREATED_WITH_FLAGS = 1 << 63; + } +} + +impl PyTypeFlags { + // Default used for both built-in and normal classes: empty, for now. + // CPython default: Py_TPFLAGS_HAVE_STACKLESS_EXTENSION | Py_TPFLAGS_HAVE_VERSION_TAG + pub const DEFAULT: Self = Self::empty(); + + // CPython: See initialization of flags in type_new. + /// Used for types created in Python. Subclassable and are a + /// heaptype. + pub fn heap_type_flags() -> Self { + Self::DEFAULT | Self::HEAPTYPE | Self::BASETYPE + } + + pub fn has_feature(self, flag: Self) -> bool { + self.contains(flag) + } + + #[cfg(debug_assertions)] + pub fn is_created_with_flags(self) -> bool { + self.contains(Self::_CREATED_WITH_FLAGS) + } +} + +impl Default for PyTypeFlags { + fn default() -> Self { + Self::DEFAULT + } +} + +pub(crate) type GenericMethod = fn(&PyObjectRef, FuncArgs, &VirtualMachine) -> PyResult; +pub(crate) type AsMappingFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; +pub(crate) type HashFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; +// CallFunc = GenericMethod +pub(crate) type GetattroFunc = fn(PyObjectRef, PyStrRef, &VirtualMachine) -> PyResult; +pub(crate) type SetattroFunc = + fn(&PyObjectRef, PyStrRef, Option, &VirtualMachine) -> PyResult<()>; +pub(crate) type AsBufferFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; +pub(crate) type RichCompareFunc = fn( + &PyObjectRef, + &PyObjectRef, + PyComparisonOp, + &VirtualMachine, +) -> PyResult>; +pub(crate) type IterFunc = fn(PyObjectRef, &VirtualMachine) -> PyResult; +pub(crate) type IterNextFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; +pub(crate) type DescrGetFunc = + fn(PyObjectRef, Option, Option, &VirtualMachine) -> PyResult; +pub(crate) type DescrSetFunc = + fn(PyObjectRef, PyObjectRef, Option, &VirtualMachine) -> PyResult<()>; +pub(crate) type NewFunc = fn(PyTypeRef, FuncArgs, &VirtualMachine) -> PyResult; +pub(crate) type DelFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult<()>; + +fn as_mapping_wrapper(zelf: &PyObjectRef, _vm: &VirtualMachine) -> PyResult { + macro_rules! then_some_closure { + ($cond:expr, $closure:expr) => { + if $cond { + Some($closure) + } else { + None + } + }; + } + Ok(PyMappingMethods { + length: then_some_closure!(zelf.has_class_attr("__len__"), |zelf, vm| { + vm.call_special_method(zelf, "__len__", ()).map(|obj| { + obj.payload_if_subclass::(vm) + .map(|length_obj| { + length_obj.as_bigint().to_usize().ok_or_else(|| { + vm.new_value_error("__len__() should return >= 0".to_owned()) + }) + }) + .unwrap() + })? + }), + subscript: then_some_closure!( + zelf.has_class_attr("__getitem__"), + |zelf: PyObjectRef, needle: PyObjectRef, vm: &VirtualMachine| { + vm.call_special_method(zelf, "__getitem__", (needle,)) + } + ), + ass_subscript: then_some_closure!( + zelf.has_class_attr("__setitem__") | zelf.has_class_attr("__delitem__"), + |zelf, needle, value, vm| match value { + Some(value) => vm + .call_special_method(zelf, "__setitem__", (needle, value),) + .map(|_| Ok(()))?, + None => vm + .call_special_method(zelf, "__delitem__", (needle,)) + .map(|_| Ok(()))?, + } + ), + }) +} + +fn hash_wrapper(zelf: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + let hash_obj = vm.call_special_method(zelf.clone(), "__hash__", ())?; + match hash_obj.payload_if_subclass::(vm) { + Some(py_int) => Ok(rustpython_common::hash::hash_bigint(py_int.as_bigint())), + None => Err(vm.new_type_error("__hash__ method should return an integer".to_owned())), + } +} + +fn call_wrapper(zelf: &PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + vm.call_special_method(zelf.clone(), "__call__", args) +} + +fn getattro_wrapper(zelf: PyObjectRef, name: PyStrRef, vm: &VirtualMachine) -> PyResult { + vm.call_special_method(zelf, "__getattribute__", (name,)) +} + +fn setattro_wrapper( + zelf: &PyObjectRef, + name: PyStrRef, + value: Option, + vm: &VirtualMachine, +) -> PyResult<()> { + match value { + Some(value) => { + vm.call_special_method(zelf.clone(), "__setattr__", (name, value))?; + } + None => { + vm.call_special_method(zelf.clone(), "__delattr__", (name,))?; + } + }; + Ok(()) +} + +fn richcompare_wrapper( + zelf: &PyObjectRef, + other: &PyObjectRef, + op: PyComparisonOp, + vm: &VirtualMachine, +) -> PyResult> { + vm.call_special_method(zelf.clone(), op.method_name(), (other.clone(),)) + .map(Either::A) +} + +fn iter_wrapper(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { + vm.call_special_method(zelf, "__iter__", ()) +} + +fn iternext_wrapper(zelf: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + PyIterReturn::from_pyresult(vm.call_special_method(zelf.clone(), "__next__", ()), vm) +} + +fn descr_get_wrapper( + zelf: PyObjectRef, + obj: Option, + cls: Option, + vm: &VirtualMachine, +) -> PyResult { + vm.call_special_method(zelf, "__get__", (obj, cls)) +} + +fn descr_set_wrapper( + zelf: PyObjectRef, + obj: PyObjectRef, + value: Option, + vm: &VirtualMachine, +) -> PyResult<()> { + match value { + Some(val) => vm.call_special_method(zelf, "__set__", (obj, val)), + None => vm.call_special_method(zelf, "__delete__", (obj,)), + } + .map(drop) +} + +fn new_wrapper(cls: PyTypeRef, mut args: FuncArgs, vm: &VirtualMachine) -> PyResult { + let new = vm + .get_attribute_opt(cls.as_object().clone(), "__new__")? + .unwrap(); + args.prepend_arg(cls.into()); + vm.invoke(&new, args) +} + +fn del_wrapper(zelf: &PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + vm.call_special_method(zelf.clone(), "__del__", ())?; + Ok(()) +} + +impl PyType { + pub(crate) fn update_slot(&self, name: &str, add: bool) { + debug_assert!(name.starts_with("__")); + debug_assert!(name.ends_with("__")); + + macro_rules! update_slot { + ($name:ident, $func:expr) => {{ + self.slots.$name.store(if add { Some($func) } else { None }); + }}; + } + match name { + "__len__" | "__getitem__" | "__setitem__" | "__delitem__" => { + update_slot!(as_mapping, as_mapping_wrapper); + // TODO: need to update sequence protocol too + } + "__hash__" => { + update_slot!(hash, hash_wrapper); + } + "__call__" => { + update_slot!(call, call_wrapper); + } + "__getattribute__" => { + update_slot!(getattro, getattro_wrapper); + } + "__setattr__" | "__delattr__" => { + update_slot!(setattro, setattro_wrapper); + } + "__eq__" | "__ne__" | "__le__" | "__lt__" | "__ge__" | "__gt__" => { + update_slot!(richcompare, richcompare_wrapper); + } + "__iter__" => { + update_slot!(iter, iter_wrapper); + } + "__next__" => { + update_slot!(iternext, iternext_wrapper); + } + "__get__" => { + update_slot!(descr_get, descr_get_wrapper); + } + "__set__" | "__delete__" => { + update_slot!(descr_set, descr_set_wrapper); + } + "__new__" => { + update_slot!(new, new_wrapper); + } + "__del__" => { + update_slot!(del, del_wrapper); + } + _ => {} + } + } +} + #[pyimpl] pub trait SlotConstructor: PyValue { type Args: FromArgs;