diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py index 4d0fa1202..8d883a9dd 100644 --- a/tests/snippets/builtin_complex.py +++ b/tests/snippets/builtin_complex.py @@ -1,4 +1,4 @@ -from testutils import assertRaises +from testutils import assert_raises # __abs__ @@ -23,11 +23,57 @@ assert bool(complex(1, 0)) assert complex(1, 2) != complex(1, 1) assert complex(1, 2) != 'foo' assert complex(1, 2).__eq__('foo') == NotImplemented +assert 1j != 10 ** 1000 -# __mul__ +# __mul__, __rmul__ assert complex(2, -3) * complex(-5, 7) == complex(11, 29) assert complex(2, -3) * 5 == complex(10, -15) +assert 5 * complex(2, -3) == complex(2, -3) * 5 + +# __truediv__, __rtruediv__ + +assert complex(2, -3) / 2 == complex(1, -1.5) +assert 5 / complex(3, -4) == complex(0.6, 0.8) + +# __mod__, __rmod__ + +assert_raises( + TypeError, + lambda: complex(2, -3) % 2, + "can't mod complex numbers.") +assert_raises( + TypeError, + lambda: 2 % complex(2, -3), + "can't mod complex numbers.") + +# __floordiv__, __rfloordiv__ + +assert_raises( + TypeError, + lambda: complex(2, -3) // 2, + "can't take floor of complex number.") +assert_raises( + TypeError, + lambda: 2 // complex(2, -3), + "can't take floor of complex number.") + +# __divmod__, __rdivmod__ + +assert_raises( + TypeError, + lambda: divmod(complex(2, -3), 2), + "can't take floor or mod of complex number.") +assert_raises( + TypeError, + lambda: divmod(2, complex(2, -3)), + "can't take floor or mod of complex number.") + +# __pow__, __rpow__ + +# assert 1j ** 2 == -1 +assert complex(1) ** 2 == 1 +assert 2 ** complex(2) == 4 # __neg__ @@ -40,18 +86,19 @@ assert bool(complex(0, 0)) is False assert bool(complex(0, 1)) is True assert bool(complex(1, 0)) is True -# real +# numbers.Complex a = complex(3, 4) b = 4j assert a.real == 3 assert b.real == 0 -# imag - assert a.imag == 4 assert b.imag == 4 +assert a.conjugate() == 3 - 4j +assert b.conjugate() == -4j + # int and complex addition assert 1 + 1j == complex(1, 1) assert 1j + 1 == complex(1, 1) @@ -70,19 +117,13 @@ assert 1j - 1 == complex(-1, 1) assert 2j - 1j == complex(0, 1) # type error addition -with assertRaises(TypeError): - assert 1j + 'str' -with assertRaises(TypeError): - assert 1j - 'str' -with assertRaises(TypeError): - assert 'str' + 1j -with assertRaises(TypeError): - assert 'str' - 1j +assert_raises(TypeError, lambda: 1j + 'str') +assert_raises(TypeError, lambda: 1j - 'str') +assert_raises(TypeError, lambda: 'str' + 1j) +assert_raises(TypeError, lambda: 'str' - 1j) # overflow -with assertRaises(OverflowError): - complex(10 ** 1000, 0) -with assertRaises(OverflowError): - complex(0, 10 ** 1000) -with assertRaises(OverflowError): - complex(0, 0) + 10 ** 1000 +msg = 'int too large to convert to float' +assert_raises(OverflowError, lambda: complex(10 ** 1000, 0), msg) +assert_raises(OverflowError, lambda: complex(0, 10 ** 1000), msg) +assert_raises(OverflowError, lambda: 0j + 10 ** 1000, msg) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index f9ff9d9b0..677763213 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -1,14 +1,18 @@ use num_complex::Complex64; -use num_traits::{ToPrimitive, Zero}; +use num_traits::Zero; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{ + IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, +}; use crate::vm::VirtualMachine; use super::objfloat::{self, PyFloat}; -use super::objint; use super::objtype::{self, PyClassRef}; +/// Create a complex number from a real part and an optional imaginary part. +/// +/// This is equivalent to (real + imag*1j) where imag defaults to 0. #[pyclass(name = "complex")] #[derive(Debug, Copy, Clone, PartialEq)] pub struct PyComplex { @@ -22,6 +26,12 @@ impl PyValue for PyComplex { } } +impl IntoPyObject for Complex64 { + fn into_pyobject(self, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_complex(self)) + } +} + impl From for PyComplex { fn from(value: Complex64) -> Self { PyComplex { value } @@ -30,56 +40,20 @@ impl From for PyComplex { pub fn init(context: &PyContext) { PyComplex::extend_class(context, &context.complex_type); - let complex_doc = - "Create a complex number from a real part and an optional imaginary part.\n\n\ - This is equivalent to (real + imag*1j) where imag defaults to 0."; - - extend_class!(context, &context.complex_type, { - "__doc__" => context.new_str(complex_doc.to_string()), - "__new__" => context.new_rustfunc(PyComplexRef::new), - }); } pub fn get_value(obj: &PyObjectRef) -> Complex64 { obj.payload::().unwrap().value } -impl PyComplexRef { - fn new( - cls: PyClassRef, - real: OptionalArg, - imag: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - let real = match real { - OptionalArg::Missing => 0.0, - OptionalArg::Present(ref value) => objfloat::make_float(vm, value)?, - }; - - let imag = match imag { - OptionalArg::Missing => 0.0, - OptionalArg::Present(ref value) => objfloat::make_float(vm, value)?, - }; - - let value = Complex64::new(real, imag); - PyComplex { value }.into_ref_with_type(vm, cls) - } -} - -fn to_complex(value: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if objtype::isinstance(&value, &vm.ctx.complex_type()) { - Ok(Some(get_value(&value))) - } else if objtype::isinstance(&value, &vm.ctx.int_type()) { - match objint::get_value(&value).to_f64() { - Some(v) => Ok(Some(Complex64::new(v, 0.0))), - None => Err(vm.new_overflow_error("int too large to convert to float".to_string())), - } - } else if objtype::isinstance(&value, &vm.ctx.float_type()) { - let v = objfloat::get_value(&value); - Ok(Some(Complex64::new(v, 0.0))) +fn try_complex(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult> { + Ok(if objtype::isinstance(&value, &vm.ctx.complex_type()) { + Some(get_value(&value)) + } else if let Some(float) = objfloat::try_float(value, vm)? { + Some(Complex64::new(float, 0.0)) } else { - Ok(None) - } + None + }) } #[pyimpl] @@ -102,62 +76,48 @@ impl PyComplex { #[pymethod(name = "__add__")] fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - if objtype::isinstance(&other, &vm.ctx.complex_type()) { - Ok(vm.ctx.new_complex(self.value + get_value(&other))) - } else { - self.radd(other, vm) - } + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value + other).into_pyobject(vm), + ) } #[pymethod(name = "__radd__")] fn radd(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match to_complex(other, vm) { - Ok(Some(other)) => Ok(vm.ctx.new_complex(self.value + other)), - Ok(None) => Ok(vm.ctx.not_implemented()), - Err(err) => Err(err), - } + self.add(other, vm) } #[pymethod(name = "__sub__")] fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - if objtype::isinstance(&other, &vm.ctx.complex_type()) { - Ok(vm.ctx.new_complex(self.value - get_value(&other))) - } else { - match to_complex(other, vm) { - Ok(Some(other)) => Ok(vm.ctx.new_complex(self.value - other)), - Ok(None) => Ok(vm.ctx.not_implemented()), - Err(err) => Err(err), - } - } + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value - other).into_pyobject(vm), + ) } #[pymethod(name = "__rsub__")] fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match to_complex(other, vm) { - Ok(Some(other)) => Ok(vm.ctx.new_complex(other - self.value)), - Ok(None) => Ok(vm.ctx.not_implemented()), - Err(err) => Err(err), - } + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (other - self.value).into_pyobject(vm), + ) } #[pymethod(name = "conjugate")] - fn conjugate(&self, _vm: &VirtualMachine) -> PyComplex { - self.value.conj().into() + fn conjugate(&self, _vm: &VirtualMachine) -> Complex64 { + self.value.conj() } #[pymethod(name = "__eq__")] fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { let result = if objtype::isinstance(&other, &vm.ctx.complex_type()) { self.value == get_value(&other) - } else if objtype::isinstance(&other, &vm.ctx.int_type()) { - match objint::get_value(&other).to_f64() { - Some(f) => self.value.im == 0.0f64 && self.value.re == f, - None => false, - } - } else if objtype::isinstance(&other, &vm.ctx.float_type()) { - self.value.im == 0.0 && self.value.re == objfloat::get_value(&other) } else { - return vm.ctx.not_implemented(); + match objfloat::try_float(&other, vm) { + Ok(Some(other)) => self.value.im == 0.0f64 && self.value.re == other, + Err(_) => false, + Ok(None) => return vm.ctx.not_implemented(), + } }; vm.ctx.new_bool(result) @@ -175,19 +135,66 @@ impl PyComplex { #[pymethod(name = "__mul__")] fn mul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match to_complex(other, vm) { - Ok(Some(other)) => Ok(vm.ctx.new_complex(Complex64::new( - self.value.re * other.re - self.value.im * other.im, - self.value.re * other.im + self.value.im * other.re, - ))), - Ok(None) => Ok(vm.ctx.not_implemented()), - Err(err) => Err(err), - } + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value * other).into_pyobject(vm), + ) + } + + #[pymethod(name = "__rmul__")] + fn rmul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.mul(other, vm) + } + + #[pymethod(name = "__truediv__")] + fn truediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value / other).into_pyobject(vm), + ) + } + + #[pymethod(name = "__rtruediv__")] + fn rtruediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (other / self.value).into_pyobject(vm), + ) + } + + #[pymethod(name = "__mod__")] + fn mod_(&self, _other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error("can't mod complex numbers.".to_string())) + } + + #[pymethod(name = "__rmod__")] + fn rmod(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.mod_(other, vm) + } + + #[pymethod(name = "__floordiv__")] + fn floordiv(&self, _other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error("can't take floor of complex number.".to_string())) + } + + #[pymethod(name = "__rfloordiv__")] + fn rfloordiv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.floordiv(other, vm) + } + + #[pymethod(name = "__divmod__")] + fn divmod(&self, _other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error("can't take floor or mod of complex number.".to_string())) + } + + #[pymethod(name = "__rdivmod__")] + fn rdivmod(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.divmod(other, vm) } #[pymethod(name = "__neg__")] - fn neg(&self, _vm: &VirtualMachine) -> PyComplex { - PyComplex::from(-self.value) + fn neg(&self, _vm: &VirtualMachine) -> Complex64 { + -self.value } #[pymethod(name = "__repr__")] @@ -200,8 +207,45 @@ impl PyComplex { } } + #[pymethod(name = "__pow__")] + fn pow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value.powc(other)).into_pyobject(vm), + ) + } + + #[pymethod(name = "__rpow__")] + fn rpow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (other.powc(self.value)).into_pyobject(vm), + ) + } + #[pymethod(name = "__bool__")] fn bool(&self, _vm: &VirtualMachine) -> bool { - self.value != Complex64::zero() + !Complex64::is_zero(&self.value) + } + + #[pymethod(name = "__new__")] + fn complex_new( + cls: PyClassRef, + real: OptionalArg, + imag: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let real = match real { + OptionalArg::Missing => 0.0, + OptionalArg::Present(ref value) => objfloat::make_float(vm, value)?, + }; + + let imag = match imag { + OptionalArg::Missing => 0.0, + OptionalArg::Present(ref value) => objfloat::make_float(vm, value)?, + }; + + let value = Complex64::new(real, imag); + PyComplex { value }.into_ref_with_type(vm, cls) } } diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index de4ee9ff5..b52a469e1 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -44,7 +44,7 @@ impl From for PyFloat { } } -fn try_float(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult> { +pub fn try_float(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult> { Ok(if objtype::isinstance(&value, &vm.ctx.float_type()) { Some(get_value(&value)) } else if objtype::isinstance(&value, &vm.ctx.int_type()) {