use heaptypeext for number protocol

This commit is contained in:
Kangzhi Shi
2022-06-25 14:00:51 +02:00
committed by Jeong YunWon
parent 44daeef9c8
commit 2895c8124a
2 changed files with 83 additions and 44 deletions

View File

@@ -63,6 +63,9 @@ pub struct PyNumberMethods {
}
impl PyNumberMethods {
/// this is NOT a global variable
// TODO: weak order read for performance
#[allow(clippy::declare_interior_mutable_const)]
pub const NOT_IMPLEMENTED: PyNumberMethods = PyNumberMethods {
add: AtomicCell::new(None),
subtract: AtomicCell::new(None),
@@ -101,28 +104,6 @@ impl PyNumberMethods {
matrix_multiply: AtomicCell::new(None),
inplace_matrix_multiply: AtomicCell::new(None),
};
fn int(num: &PyNumber, vm: &VirtualMachine) -> PyResult<PyRef<PyInt>> {
let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __int__), ())?;
ret.downcast::<PyInt>().map_err(|obj| {
vm.new_type_error(format!("__int__ returned non-int (type {})", obj.class()))
})
}
fn float(num: &PyNumber, vm: &VirtualMachine) -> PyResult<PyRef<PyFloat>> {
let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __float__), ())?;
ret.downcast::<PyFloat>().map_err(|obj| {
vm.new_type_error(format!(
"__float__ returned non-float (type {})",
obj.class()
))
})
}
fn index(num: &PyNumber, vm: &VirtualMachine) -> PyResult<PyRef<PyInt>> {
let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __index__), ())?;
ret.downcast::<PyInt>().map_err(|obj| {
vm.new_type_error(format!("__index__ returned non-int (type {})", obj.class()))
})
}
}
pub struct PyNumber<'a> {
@@ -142,18 +123,19 @@ impl<'a> From<&'a PyObject> for PyNumber<'a> {
impl PyNumber<'_> {
pub fn methods(&self) -> &PyNumberMethods {
static GLOBAL_NOT_IMPLEMENTED: PyNumberMethods = PyNumberMethods::NOT_IMPLEMENTED;
let as_number = self.methods.get_or_init(|| {
Self::find_methods(self.obj).unwrap_or(NonNull::from(&PyNumberMethods::NOT_IMPLEMENTED))
Self::find_methods(self.obj).unwrap_or_else(|| NonNull::from(&GLOBAL_NOT_IMPLEMENTED))
});
unsafe { as_number.as_ref() }
}
fn find_methods<'a>(obj: &'a PyObject) -> Option<NonNull<PyNumberMethods>> {
fn find_methods(obj: &PyObject) -> Option<NonNull<PyNumberMethods>> {
obj.class().mro_find_map(|x| x.slots.as_number.load())
}
// PyNumber_Check
pub fn check<'a>(obj: &'a PyObject, vm: &VirtualMachine) -> bool {
pub fn check(obj: &PyObject) -> bool {
let num = PyNumber::from(obj);
let methods = num.methods();
methods.int.load().is_some()

View File

@@ -1,6 +1,6 @@
use crate::common::{hash::PyHash, lock::PyRwLock};
use crate::{
builtins::{PyInt, PyStrInterned, PyStrRef, PyType, PyTypeRef},
builtins::{PyFloat, PyInt, PyStrInterned, PyStrRef, PyType, PyTypeRef},
bytecode::ComparisonOperator,
convert::ToPyResult,
function::Either,
@@ -205,6 +205,30 @@ fn slot_as_sequence(zelf: &PyObject, vm: &VirtualMachine) -> &'static PySequence
PySequenceMethods::generic(has_length, has_ass_item)
}
fn int_wrapper(num: &PyNumber, vm: &VirtualMachine) -> PyResult<PyRef<PyInt>> {
let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __int__), ())?;
ret.downcast::<PyInt>().map_err(|obj| {
vm.new_type_error(format!("__int__ returned non-int (type {})", obj.class()))
})
}
fn index_wrapper(num: &PyNumber, vm: &VirtualMachine) -> PyResult<PyRef<PyInt>> {
let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __index__), ())?;
ret.downcast::<PyInt>().map_err(|obj| {
vm.new_type_error(format!("__index__ returned non-int (type {})", obj.class()))
})
}
fn float_wrapper(num: &PyNumber, vm: &VirtualMachine) -> PyResult<PyRef<PyFloat>> {
let ret = vm.call_special_method(num.obj.to_owned(), identifier!(vm, __float__), ())?;
ret.downcast::<PyFloat>().map_err(|obj| {
vm.new_type_error(format!(
"__float__ returned non-float (type {})",
obj.class()
))
})
}
fn hash_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult<PyHash> {
let hash_obj = vm.call_special_method(zelf.to_owned(), identifier!(vm, __hash__), ())?;
match hash_obj.payload_if_subclass::<PyInt>(vm) {
@@ -318,21 +342,37 @@ impl PyType {
debug_assert!(name.as_str().starts_with("__"));
debug_assert!(name.as_str().ends_with("__"));
macro_rules! update_slot {
macro_rules! toggle_slot {
($name:ident, $func:expr) => {{
self.slots.$name.store(if add { Some($func) } else { None });
}};
}
macro_rules! update_slot {
($name:ident, $func:expr) => {{
self.slots.$name.store(Some($func));
}};
}
macro_rules! update_pointer_slot {
($name:ident, $pointed:ident) => {{
self.slots.$name.store(
self.heaptype_ext
.as_ref()
.map(|ext| NonNull::from(&ext.$pointed)),
);
}};
}
match name.as_str() {
"__len__" | "__getitem__" | "__setitem__" | "__delitem__" => {
update_slot!(as_mapping, slot_as_mapping);
update_slot!(as_sequence, slot_as_sequence);
}
"__hash__" => {
update_slot!(hash, hash_wrapper);
toggle_slot!(hash, hash_wrapper);
}
"__call__" => {
update_slot!(call, call_wrapper);
toggle_slot!(call, call_wrapper);
}
"__getattr__" | "__getattribute__" => {
update_slot!(getattro, getattro_wrapper);
@@ -344,28 +384,52 @@ impl PyType {
update_slot!(richcompare, richcompare_wrapper);
}
"__iter__" => {
update_slot!(iter, iter_wrapper);
toggle_slot!(iter, iter_wrapper);
}
"__next__" => {
update_slot!(iternext, iternext_wrapper);
toggle_slot!(iternext, iternext_wrapper);
}
"__get__" => {
update_slot!(descr_get, descr_get_wrapper);
toggle_slot!(descr_get, descr_get_wrapper);
}
"__set__" | "__delete__" => {
update_slot!(descr_set, descr_set_wrapper);
}
"__init__" => {
update_slot!(init, init_wrapper);
toggle_slot!(init, init_wrapper);
}
"__new__" => {
update_slot!(new, new_wrapper);
toggle_slot!(new, new_wrapper);
}
"__del__" => {
update_slot!(del, del_wrapper);
toggle_slot!(del, del_wrapper);
}
"__int__" | "__index__" | "__float__" => {
// update_slot!(as_number, slot_as_number);
"__int__" => {
self.heaptype_ext
.as_ref()
.unwrap()
.number_methods
.int
.store(Some(int_wrapper));
update_pointer_slot!(as_number, number_methods);
}
"__index__" => {
self.heaptype_ext
.as_ref()
.unwrap()
.number_methods
.index
.store(Some(index_wrapper));
update_pointer_slot!(as_number, number_methods);
}
"__float__" => {
self.heaptype_ext
.as_ref()
.unwrap()
.number_methods
.float
.store(Some(float_wrapper));
update_pointer_slot!(as_number, number_methods);
}
_ => {}
}
@@ -871,15 +935,8 @@ pub trait AsSequence: PyPayload {
#[pyimpl]
pub trait AsNumber: PyPayload {
// const AS_NUMBER: PyNumberMethods;
#[pyslot]
fn as_number() -> &'static PyNumberMethods;
// #[inline]
// #[pyslot]
// fn as_number() -> &'static PyNumberMethods {
// &Self::AS_NUMBER
// }
fn number_downcast<'a>(number: &'a PyNumber) -> &'a Py<Self> {
unsafe { number.obj.downcast_unchecked_ref() }