diff --git a/vm/src/builtins/complex.rs b/vm/src/builtins/complex.rs index 9e4f5af90..4ed47b93b 100644 --- a/vm/src/builtins/complex.rs +++ b/vm/src/builtins/complex.rs @@ -1,14 +1,15 @@ use super::{float, PyStr, PyType, PyTypeRef}; use crate::{ class::PyClassImpl, - convert::ToPyObject, + convert::{ToPyObject, ToPyResult}, function::{ OptionalArg, OptionalOption, PyArithmeticValue::{self, *}, PyComparisonValue, }, identifier, - types::{Comparable, Constructor, Hashable, PyComparisonOp}, + protocol::{PyNumber, PyNumberMethods}, + types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp}, AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use num_complex::Complex64; @@ -203,7 +204,7 @@ impl PyComplex { } } -#[pyimpl(flags(BASETYPE), with(Comparable, Hashable, Constructor))] +#[pyimpl(flags(BASETYPE), with(Comparable, Hashable, Constructor, AsNumber))] impl PyComplex { #[pymethod(magic)] fn complex(zelf: PyRef, vm: &VirtualMachine) -> PyRef { @@ -419,6 +420,98 @@ impl Hashable for PyComplex { } } +impl AsNumber for PyComplex { + const AS_NUMBER: PyNumberMethods = PyNumberMethods { + add: Some(|number, other, vm| Self::number_complex_op(number, other, |a, b| a + b, vm)), + subtract: Some(|number, other, vm| { + Self::number_complex_op(number, other, |a, b| a - b, vm) + }), + multiply: Some(|number, other, vm| { + Self::number_complex_op(number, other, |a, b| a * b, vm) + }), + remainder: None, + divmod: None, + power: Some(|number, other, vm| Self::number_general_op(number, other, inner_pow, vm)), + negative: Some(|number, vm| { + let value = Self::number_downcast(number).value; + (-value).to_pyresult(vm) + }), + positive: Some(|number, vm| Self::number_complex(number, vm).to_pyresult(vm)), + absolute: Some(|number, vm| { + let value = Self::number_downcast(number).value; + value.norm().to_pyresult(vm) + }), + boolean: Some(|number, _vm| Ok(Self::number_downcast(number).value.is_zero())), + invert: None, + lshift: None, + rshift: None, + and: None, + xor: None, + or: None, + int: None, + float: None, + inplace_add: None, + inplace_subtract: None, + inplace_multiply: None, + inplace_remainder: None, + inplace_divmod: None, + inplace_power: None, + inplace_lshift: None, + inplace_rshift: None, + inplace_and: None, + inplace_xor: None, + inplace_or: None, + floor_divide: None, + true_divide: Some(|number, other, vm| { + Self::number_general_op(number, other, inner_div, vm) + }), + inplace_floor_divide: None, + inplace_true_divide: None, + index: None, + matrix_multiply: None, + inplace_matrix_multiply: None, + }; +} + +impl PyComplex { + fn number_general_op( + number: &PyNumber, + other: &PyObject, + op: F, + vm: &VirtualMachine, + ) -> PyResult + where + F: FnOnce(Complex64, Complex64, &VirtualMachine) -> R, + R: ToPyResult, + { + if let (Some(a), Some(b)) = (number.obj.payload::(), other.payload::()) { + op(a.value, b.value, vm).to_pyresult(vm) + } else { + Ok(vm.ctx.not_implemented()) + } + } + + fn number_complex_op( + number: &PyNumber, + other: &PyObject, + op: F, + vm: &VirtualMachine, + ) -> PyResult + where + F: FnOnce(Complex64, Complex64) -> Complex64, + { + Self::number_general_op(number, other, |a, b, _vm| op(a, b), vm) + } + + fn number_complex(number: &PyNumber, vm: &VirtualMachine) -> PyRef { + if let Some(zelf) = number.obj.downcast_ref_if_exact::(vm) { + zelf.to_owned() + } else { + vm.ctx.new_complex(Self::number_downcast(number).value) + } + } +} + #[derive(FromArgs)] pub struct ComplexArgs { #[pyarg(any, optional)]