From 2895c8124a2673403297ffe162fec2cdc09a7513 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sat, 25 Jun 2022 14:00:51 +0200 Subject: [PATCH] use heaptypeext for number protocol --- vm/src/protocol/number.rs | 32 +++---------- vm/src/types/slot.rs | 95 +++++++++++++++++++++++++++++++-------- 2 files changed, 83 insertions(+), 44 deletions(-) diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs index 86a13f268..36999a227 100644 --- a/vm/src/protocol/number.rs +++ b/vm/src/protocol/number.rs @@ -63,6 +63,9 @@ pub struct PyNumberMethods { } impl PyNumberMethods { + /// this is NOT a global variable + // TODO: weak order read for performance + #[allow(clippy::declare_interior_mutable_const)] pub const NOT_IMPLEMENTED: PyNumberMethods = PyNumberMethods { add: AtomicCell::new(None), subtract: AtomicCell::new(None), @@ -101,28 +104,6 @@ impl PyNumberMethods { matrix_multiply: AtomicCell::new(None), inplace_matrix_multiply: AtomicCell::new(None), }; - - fn int(num: &PyNumber, vm: &VirtualMachine) -> PyResult> { - let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __int__), ())?; - ret.downcast::().map_err(|obj| { - vm.new_type_error(format!("__int__ returned non-int (type {})", obj.class())) - }) - } - fn float(num: &PyNumber, vm: &VirtualMachine) -> PyResult> { - let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __float__), ())?; - ret.downcast::().map_err(|obj| { - vm.new_type_error(format!( - "__float__ returned non-float (type {})", - obj.class() - )) - }) - } - fn index(num: &PyNumber, vm: &VirtualMachine) -> PyResult> { - let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __index__), ())?; - ret.downcast::().map_err(|obj| { - vm.new_type_error(format!("__index__ returned non-int (type {})", obj.class())) - }) - } } pub struct PyNumber<'a> { @@ -142,18 +123,19 @@ impl<'a> From<&'a PyObject> for PyNumber<'a> { impl PyNumber<'_> { pub fn methods(&self) -> &PyNumberMethods { + static GLOBAL_NOT_IMPLEMENTED: PyNumberMethods = PyNumberMethods::NOT_IMPLEMENTED; let as_number = self.methods.get_or_init(|| { - Self::find_methods(self.obj).unwrap_or(NonNull::from(&PyNumberMethods::NOT_IMPLEMENTED)) + Self::find_methods(self.obj).unwrap_or_else(|| NonNull::from(&GLOBAL_NOT_IMPLEMENTED)) }); unsafe { as_number.as_ref() } } - fn find_methods<'a>(obj: &'a PyObject) -> Option> { + fn find_methods(obj: &PyObject) -> Option> { obj.class().mro_find_map(|x| x.slots.as_number.load()) } // PyNumber_Check - pub fn check<'a>(obj: &'a PyObject, vm: &VirtualMachine) -> bool { + pub fn check(obj: &PyObject) -> bool { let num = PyNumber::from(obj); let methods = num.methods(); methods.int.load().is_some() diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 8a0c31b89..8fea8053a 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -1,6 +1,6 @@ use crate::common::{hash::PyHash, lock::PyRwLock}; use crate::{ - builtins::{PyInt, PyStrInterned, PyStrRef, PyType, PyTypeRef}, + builtins::{PyFloat, PyInt, PyStrInterned, PyStrRef, PyType, PyTypeRef}, bytecode::ComparisonOperator, convert::ToPyResult, function::Either, @@ -205,6 +205,30 @@ fn slot_as_sequence(zelf: &PyObject, vm: &VirtualMachine) -> &'static PySequence PySequenceMethods::generic(has_length, has_ass_item) } +fn int_wrapper(num: &PyNumber, vm: &VirtualMachine) -> PyResult> { + let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __int__), ())?; + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!("__int__ returned non-int (type {})", obj.class())) + }) +} + +fn index_wrapper(num: &PyNumber, vm: &VirtualMachine) -> PyResult> { + let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __index__), ())?; + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!("__index__ returned non-int (type {})", obj.class())) + }) +} + +fn float_wrapper(num: &PyNumber, vm: &VirtualMachine) -> PyResult> { + let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __float__), ())?; + ret.downcast::().map_err(|obj| { + vm.new_type_error(format!( + "__float__ returned non-float (type {})", + obj.class() + )) + }) +} + fn hash_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { let hash_obj = vm.call_special_method(zelf.to_owned(), identifier!(vm, __hash__), ())?; match hash_obj.payload_if_subclass::(vm) { @@ -318,21 +342,37 @@ impl PyType { debug_assert!(name.as_str().starts_with("__")); debug_assert!(name.as_str().ends_with("__")); - macro_rules! update_slot { + macro_rules! toggle_slot { ($name:ident, $func:expr) => {{ self.slots.$name.store(if add { Some($func) } else { None }); }}; } + + macro_rules! update_slot { + ($name:ident, $func:expr) => {{ + self.slots.$name.store(Some($func)); + }}; + } + + macro_rules! update_pointer_slot { + ($name:ident, $pointed:ident) => {{ + self.slots.$name.store( + self.heaptype_ext + .as_ref() + .map(|ext| NonNull::from(&ext.$pointed)), + ); + }}; + } match name.as_str() { "__len__" | "__getitem__" | "__setitem__" | "__delitem__" => { update_slot!(as_mapping, slot_as_mapping); update_slot!(as_sequence, slot_as_sequence); } "__hash__" => { - update_slot!(hash, hash_wrapper); + toggle_slot!(hash, hash_wrapper); } "__call__" => { - update_slot!(call, call_wrapper); + toggle_slot!(call, call_wrapper); } "__getattr__" | "__getattribute__" => { update_slot!(getattro, getattro_wrapper); @@ -344,28 +384,52 @@ impl PyType { update_slot!(richcompare, richcompare_wrapper); } "__iter__" => { - update_slot!(iter, iter_wrapper); + toggle_slot!(iter, iter_wrapper); } "__next__" => { - update_slot!(iternext, iternext_wrapper); + toggle_slot!(iternext, iternext_wrapper); } "__get__" => { - update_slot!(descr_get, descr_get_wrapper); + toggle_slot!(descr_get, descr_get_wrapper); } "__set__" | "__delete__" => { update_slot!(descr_set, descr_set_wrapper); } "__init__" => { - update_slot!(init, init_wrapper); + toggle_slot!(init, init_wrapper); } "__new__" => { - update_slot!(new, new_wrapper); + toggle_slot!(new, new_wrapper); } "__del__" => { - update_slot!(del, del_wrapper); + toggle_slot!(del, del_wrapper); } - "__int__" | "__index__" | "__float__" => { - // update_slot!(as_number, slot_as_number); + "__int__" => { + self.heaptype_ext + .as_ref() + .unwrap() + .number_methods + .int + .store(Some(int_wrapper)); + update_pointer_slot!(as_number, number_methods); + } + "__index__" => { + self.heaptype_ext + .as_ref() + .unwrap() + .number_methods + .index + .store(Some(index_wrapper)); + update_pointer_slot!(as_number, number_methods); + } + "__float__" => { + self.heaptype_ext + .as_ref() + .unwrap() + .number_methods + .float + .store(Some(float_wrapper)); + update_pointer_slot!(as_number, number_methods); } _ => {} } @@ -871,15 +935,8 @@ pub trait AsSequence: PyPayload { #[pyimpl] pub trait AsNumber: PyPayload { - // const AS_NUMBER: PyNumberMethods; - #[pyslot] fn as_number() -> &'static PyNumberMethods; - // #[inline] - // #[pyslot] - // fn as_number() -> &'static PyNumberMethods { - // &Self::AS_NUMBER - // } fn number_downcast<'a>(number: &'a PyNumber) -> &'a Py { unsafe { number.obj.downcast_unchecked_ref() }