diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index ecca0b544..be1b64d22 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -807,7 +807,7 @@ fn try_int_radix(obj: &PyObject, base: u32, vm: &VirtualMachine) -> PyResult Option { +pub(crate) fn bytes_to_int(lit: &[u8], mut base: u32) -> Option { // split sign let mut lit = lit.trim(); let sign = match lit.first()? { diff --git a/vm/src/protocol/number.rs b/vm/src/protocol/number.rs new file mode 100644 index 000000000..aee744b27 --- /dev/null +++ b/vm/src/protocol/number.rs @@ -0,0 +1,161 @@ +use std::borrow::Cow; + +use crate::{ + builtins::{int, PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr}, + common::{lock::OnceCell, static_cell}, + function::ArgBytesLike, + IdProtocol, PyObject, PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, + VirtualMachine, +}; + +#[allow(clippy::type_complexity)] +#[derive(Default, Clone)] +pub struct PyNumberMethods { + /* Number implementations must check *both* + arguments for proper type and implement the necessary conversions + in the slot functions themselves. */ + pub add: Option PyResult>, + pub subtract: Option PyResult>, + pub multiply: Option PyResult>, + pub remainder: Option PyResult>, + pub divmod: Option PyResult>, + pub power: Option PyResult>, + pub negative: Option PyResult>, + pub positive: Option PyResult>, + pub absolute: Option PyResult>, + pub boolean: Option PyResult>, + pub invert: Option PyResult>, + pub lshift: Option PyResult>, + pub rshift: Option PyResult>, + pub and: Option PyResult>, + pub xor: Option PyResult>, + pub or: Option PyResult>, + pub int: Option PyResult>, + pub float: Option PyResult>>, + + pub inplace_add: Option PyResult>, + pub inplace_substract: Option PyResult>, + pub inplace_multiply: Option PyResult>, + pub inplace_remainder: Option PyResult>, + pub inplace_divmod: Option PyResult>, + pub inplace_power: Option PyResult>, + pub inplace_lshift: Option PyResult>, + pub inplace_rshift: Option PyResult>, + pub inplace_and: Option PyResult>, + pub inplace_xor: Option PyResult>, + pub inplace_or: Option PyResult>, + + pub floor_divide: Option PyResult>, + pub true_divide: Option PyResult>, + pub inplace_floor_divide: Option PyResult>, + pub inplace_true_devide: Option PyResult>, + + pub index: Option PyResult>, + + pub matrix_multiply: Option PyResult>, + pub inplace_matrix_multiply: Option PyResult>, +} + +impl PyNumberMethods { + fn not_implemented() -> &'static Self { + static_cell! { + static NOT_IMPLEMENTED: PyNumberMethods; + } + NOT_IMPLEMENTED.get_or_init(Self::default) + } +} + +pub struct PyNumber<'a> { + pub obj: &'a PyObject, + // some fast path do not need methods, so we do lazy initialize + methods: OnceCell>, +} + +impl<'a> From<&'a PyObject> for PyNumber<'a> { + fn from(obj: &'a PyObject) -> Self { + Self { + obj, + methods: OnceCell::new(), + } + } +} + +impl<'a> PyNumber<'a> { + pub fn methods(&'a self, vm: &VirtualMachine) -> &'a Cow<'static, PyNumberMethods> { + self.methods.get_or_init(|| { + self.obj + .class() + .mro_find_map(|x| x.slots.as_number.load()) + .map(|f| f(self.obj, vm)) + .unwrap_or_else(|| Cow::Borrowed(PyNumberMethods::not_implemented())) + }) + } +} + +impl PyNumber<'_> { + // PyNumber_Check + pub fn is_numeric(&self, vm: &VirtualMachine) -> bool { + let methods = self.methods(vm); + methods.int.is_some() + || methods.index.is_some() + || methods.float.is_some() + || self.obj.payload_is::() + } + + // PyIndex_Check + pub fn is_index(&self, vm: &VirtualMachine) -> bool { + self.methods(vm).index.is_some() + } + + pub fn to_int(&self, vm: &VirtualMachine) -> PyResult { + fn try_convert(obj: &PyObject, lit: &[u8], vm: &VirtualMachine) -> PyResult { + let base = 10; + match int::bytes_to_int(lit, base) { + Some(i) => Ok(PyInt::from(i).into_ref(vm)), + None => Err(vm.new_value_error(format!( + "invalid literal for int() with base {}: {}", + base, + obj.repr(vm)?, + ))), + } + } + + if self.obj.class().is(PyInt::class(vm)) { + Ok(unsafe { self.obj.downcast_unchecked_ref::() }.to_owned()) + } else if let Some(f) = self.methods(vm).int { + f(self, vm) + } else if let Some(f) = self.methods(vm).index { + f(self, vm) + } else if let Ok(Ok(f)) = vm.get_special_method(self.obj.to_owned(), "__trunc__") { + let r = f.invoke((), vm)?; + PyNumber::from(r.as_ref()).to_index(vm) + } else if let Some(s) = self.obj.payload::() { + try_convert(self.obj, s.as_str().as_bytes(), vm) + } else if let Some(bytes) = self.obj.payload::() { + try_convert(self.obj, bytes, vm) + } else if let Some(bytearray) = self.obj.payload::() { + try_convert(self.obj, &bytearray.borrow_buf(), vm) + } else if let Ok(buffer) = ArgBytesLike::try_from_borrowed_object(vm, self.obj) { + // TODO: replace to PyBuffer + try_convert(self.obj, &buffer.borrow_buf(), vm) + } else { + Err(vm.new_type_error(format!( + "int() argument must be a string, a bytes-like object or a real number, not '{}'", + self.obj.class() + ))) + } + } + + pub fn to_index(&self, vm: &VirtualMachine) -> PyResult { + if self.obj.class().is(PyInt::class(vm)) { + Ok(unsafe { self.obj.downcast_unchecked_ref::() }.to_owned()) + } else if let Some(f) = self.methods(vm).index { + f(self, vm) + } else { + Err(vm.new_type_error(format!( + "'{}' object cannot be interpreted as an integer", + self.obj.class() + ))) + } + } +} diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 245e7a81c..0dac14a21 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -1,4 +1,5 @@ use crate::common::{hash::PyHash, lock::PyRwLock}; +use crate::protocol::PyNumberMethods; use crate::{ builtins::{PyInt, PyStrInterned, PyStrRef, PyType, PyTypeRef}, bytecode::ComparisonOperator, @@ -138,6 +139,7 @@ impl Default for PyTypeFlags { pub(crate) type GenericMethod = fn(&PyObject, FuncArgs, &VirtualMachine) -> PyResult; pub(crate) type AsMappingFunc = fn(&PyObject, &VirtualMachine) -> &'static PyMappingMethods; +pub(crate) type AsNumberFunc = fn(&PyObject, &VirtualMachine) -> Cow<'static, PyNumberMethods>; pub(crate) type HashFunc = fn(&PyObject, &VirtualMachine) -> PyResult; // CallFunc = GenericMethod pub(crate) type GetattroFunc = fn(&PyObject, PyStrRef, &VirtualMachine) -> PyResult;