From 9448254914d7202c846ac6fa00a300fafdc5aea0 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 1 May 2019 01:11:14 +0900 Subject: [PATCH 01/10] PyComplex uses extend_class for __new__ and __doc__ --- vm/src/obj/objcomplex.rs | 54 ++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index f9ff9d9b0..ec7afd110 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -9,6 +9,9 @@ use super::objfloat::{self, PyFloat}; use super::objint; use super::objtype::{self, PyClassRef}; +/// Create a complex number from a real part and an optional imaginary part. +/// +/// This is equivalent to (real + imag*1j) where imag defaults to 0. #[pyclass(name = "complex")] #[derive(Debug, Copy, Clone, PartialEq)] pub struct PyComplex { @@ -30,42 +33,12 @@ impl From for PyComplex { pub fn init(context: &PyContext) { PyComplex::extend_class(context, &context.complex_type); - let complex_doc = - "Create a complex number from a real part and an optional imaginary part.\n\n\ - This is equivalent to (real + imag*1j) where imag defaults to 0."; - - extend_class!(context, &context.complex_type, { - "__doc__" => context.new_str(complex_doc.to_string()), - "__new__" => context.new_rustfunc(PyComplexRef::new), - }); } pub fn get_value(obj: &PyObjectRef) -> Complex64 { obj.payload::().unwrap().value } -impl PyComplexRef { - fn new( - cls: PyClassRef, - real: OptionalArg, - imag: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - let real = match real { - OptionalArg::Missing => 0.0, - OptionalArg::Present(ref value) => objfloat::make_float(vm, value)?, - }; - - let imag = match imag { - OptionalArg::Missing => 0.0, - OptionalArg::Present(ref value) => objfloat::make_float(vm, value)?, - }; - - let value = Complex64::new(real, imag); - PyComplex { value }.into_ref_with_type(vm, cls) - } -} - fn to_complex(value: PyObjectRef, vm: &VirtualMachine) -> PyResult> { if objtype::isinstance(&value, &vm.ctx.complex_type()) { Ok(Some(get_value(&value))) @@ -204,4 +177,25 @@ impl PyComplex { fn bool(&self, _vm: &VirtualMachine) -> bool { self.value != Complex64::zero() } + + #[pymethod(name = "__new__")] + fn complex_new( + cls: PyClassRef, + real: OptionalArg, + imag: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let real = match real { + OptionalArg::Missing => 0.0, + OptionalArg::Present(ref value) => objfloat::make_float(vm, value)?, + }; + + let imag = match imag { + OptionalArg::Missing => 0.0, + OptionalArg::Present(ref value) => objfloat::make_float(vm, value)?, + }; + + let value = Complex64::new(real, imag); + PyComplex { value }.into_ref_with_type(vm, cls) + } } From 7b438d9be8e5b4a821e80e0d1494e678233c892c Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 1 May 2019 01:51:01 +0900 Subject: [PATCH 02/10] impl IntoPyObject for Complex64 --- vm/src/obj/objcomplex.rs | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index ec7afd110..ab258064a 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -2,7 +2,9 @@ use num_complex::Complex64; use num_traits::{ToPrimitive, Zero}; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{ + IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, +}; use crate::vm::VirtualMachine; use super::objfloat::{self, PyFloat}; @@ -25,6 +27,12 @@ impl PyValue for PyComplex { } } +impl IntoPyObject for Complex64 { + fn into_pyobject(self, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_complex(self)) + } +} + impl From for PyComplex { fn from(value: Complex64) -> Self { PyComplex { value } @@ -114,8 +122,8 @@ impl PyComplex { } #[pymethod(name = "conjugate")] - fn conjugate(&self, _vm: &VirtualMachine) -> PyComplex { - self.value.conj().into() + fn conjugate(&self, _vm: &VirtualMachine) -> Complex64 { + self.value.conj() } #[pymethod(name = "__eq__")] @@ -148,19 +156,15 @@ impl PyComplex { #[pymethod(name = "__mul__")] fn mul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match to_complex(other, vm) { - Ok(Some(other)) => Ok(vm.ctx.new_complex(Complex64::new( - self.value.re * other.re - self.value.im * other.im, - self.value.re * other.im + self.value.im * other.re, - ))), - Ok(None) => Ok(vm.ctx.not_implemented()), - Err(err) => Err(err), - } + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value * other).into_pyobject(vm), + ) } #[pymethod(name = "__neg__")] - fn neg(&self, _vm: &VirtualMachine) -> PyComplex { - PyComplex::from(-self.value) + fn neg(&self, _vm: &VirtualMachine) -> Complex64 { + -self.value } #[pymethod(name = "__repr__")] From 9a7fadcb6c457b1704f1e357991f485b0b7eb6d8 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 1 May 2019 01:52:42 +0900 Subject: [PATCH 03/10] Refactor PyComplex using try_complex --- vm/src/obj/objcomplex.rs | 57 ++++++++++++++-------------------------- vm/src/obj/objfloat.rs | 2 +- 2 files changed, 21 insertions(+), 38 deletions(-) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index ab258064a..eaee63259 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -47,20 +47,14 @@ pub fn get_value(obj: &PyObjectRef) -> Complex64 { obj.payload::().unwrap().value } -fn to_complex(value: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if objtype::isinstance(&value, &vm.ctx.complex_type()) { - Ok(Some(get_value(&value))) - } else if objtype::isinstance(&value, &vm.ctx.int_type()) { - match objint::get_value(&value).to_f64() { - Some(v) => Ok(Some(Complex64::new(v, 0.0))), - None => Err(vm.new_overflow_error("int too large to convert to float".to_string())), - } - } else if objtype::isinstance(&value, &vm.ctx.float_type()) { - let v = objfloat::get_value(&value); - Ok(Some(Complex64::new(v, 0.0))) +fn try_complex(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult> { + Ok(if objtype::isinstance(&value, &vm.ctx.complex_type()) { + Some(get_value(&value)) + } else if let Some(float) = objfloat::try_float(value, vm)? { + Some(Complex64::new(float, 0.0)) } else { - Ok(None) - } + None + }) } #[pyimpl] @@ -83,42 +77,31 @@ impl PyComplex { #[pymethod(name = "__add__")] fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - if objtype::isinstance(&other, &vm.ctx.complex_type()) { - Ok(vm.ctx.new_complex(self.value + get_value(&other))) - } else { - self.radd(other, vm) - } + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value + other).into_pyobject(vm), + ) } #[pymethod(name = "__radd__")] fn radd(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match to_complex(other, vm) { - Ok(Some(other)) => Ok(vm.ctx.new_complex(self.value + other)), - Ok(None) => Ok(vm.ctx.not_implemented()), - Err(err) => Err(err), - } + self.add(other, vm) } #[pymethod(name = "__sub__")] fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - if objtype::isinstance(&other, &vm.ctx.complex_type()) { - Ok(vm.ctx.new_complex(self.value - get_value(&other))) - } else { - match to_complex(other, vm) { - Ok(Some(other)) => Ok(vm.ctx.new_complex(self.value - other)), - Ok(None) => Ok(vm.ctx.not_implemented()), - Err(err) => Err(err), - } - } + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value - other).into_pyobject(vm), + ) } #[pymethod(name = "__rsub__")] fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match to_complex(other, vm) { - Ok(Some(other)) => Ok(vm.ctx.new_complex(other - self.value)), - Ok(None) => Ok(vm.ctx.not_implemented()), - Err(err) => Err(err), - } + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (other - self.value).into_pyobject(vm), + ) } #[pymethod(name = "conjugate")] diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index de4ee9ff5..b52a469e1 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -44,7 +44,7 @@ impl From for PyFloat { } } -fn try_float(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult> { +pub fn try_float(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult> { Ok(if objtype::isinstance(&value, &vm.ctx.float_type()) { Some(get_value(&value)) } else if objtype::isinstance(&value, &vm.ctx.int_type()) { From 61de5f2efc7f30d2f968638ab6848943f061d831 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 1 May 2019 02:03:28 +0900 Subject: [PATCH 04/10] complex.__eq__ using try_float --- tests/snippets/builtin_complex.py | 8 +++++--- vm/src/obj/objcomplex.rs | 16 ++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py index 4d0fa1202..1c0e78b4e 100644 --- a/tests/snippets/builtin_complex.py +++ b/tests/snippets/builtin_complex.py @@ -23,6 +23,7 @@ assert bool(complex(1, 0)) assert complex(1, 2) != complex(1, 1) assert complex(1, 2) != 'foo' assert complex(1, 2).__eq__('foo') == NotImplemented +assert 1j != 10 ** 1000 # __mul__ @@ -40,18 +41,19 @@ assert bool(complex(0, 0)) is False assert bool(complex(0, 1)) is True assert bool(complex(1, 0)) is True -# real +# numbers.Complex a = complex(3, 4) b = 4j assert a.real == 3 assert b.real == 0 -# imag - assert a.imag == 4 assert b.imag == 4 +assert a.conjugate() == 3 - 4j +assert b.conjugate() == -4j + # int and complex addition assert 1 + 1j == complex(1, 1) assert 1j + 1 == complex(1, 1) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index eaee63259..6f13e76fa 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -1,5 +1,5 @@ use num_complex::Complex64; -use num_traits::{ToPrimitive, Zero}; +use num_traits::Zero; use crate::function::OptionalArg; use crate::pyobject::{ @@ -8,7 +8,6 @@ use crate::pyobject::{ use crate::vm::VirtualMachine; use super::objfloat::{self, PyFloat}; -use super::objint; use super::objtype::{self, PyClassRef}; /// Create a complex number from a real part and an optional imaginary part. @@ -113,15 +112,12 @@ impl PyComplex { fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { let result = if objtype::isinstance(&other, &vm.ctx.complex_type()) { self.value == get_value(&other) - } else if objtype::isinstance(&other, &vm.ctx.int_type()) { - match objint::get_value(&other).to_f64() { - Some(f) => self.value.im == 0.0f64 && self.value.re == f, - None => false, - } - } else if objtype::isinstance(&other, &vm.ctx.float_type()) { - self.value.im == 0.0 && self.value.re == objfloat::get_value(&other) } else { - return vm.ctx.not_implemented(); + match objfloat::try_float(&other, vm) { + Ok(Some(other)) => self.value.im == 0.0f64 && self.value.re == other, + Err(_) => false, + Ok(None) => return vm.ctx.not_implemented(), + } }; vm.ctx.new_bool(result) From 982bbd69d803b288dc052c1b0e872d11d2628294 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 1 May 2019 02:07:50 +0900 Subject: [PATCH 05/10] complex.__bool__ uses Zero::is_zero instead of zero() --- vm/src/obj/objcomplex.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index 6f13e76fa..9c38650b0 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -158,7 +158,7 @@ impl PyComplex { #[pymethod(name = "__bool__")] fn bool(&self, _vm: &VirtualMachine) -> bool { - self.value != Complex64::zero() + !Complex64::is_zero(&self.value) } #[pymethod(name = "__new__")] From 2a2d0e47648fef006303a4a629a41861bd6c465d Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 1 May 2019 02:11:13 +0900 Subject: [PATCH 06/10] Add complex.__rmul__ --- tests/snippets/builtin_complex.py | 1 + vm/src/obj/objcomplex.rs | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py index 1c0e78b4e..20300c4e7 100644 --- a/tests/snippets/builtin_complex.py +++ b/tests/snippets/builtin_complex.py @@ -29,6 +29,7 @@ assert 1j != 10 ** 1000 assert complex(2, -3) * complex(-5, 7) == complex(11, 29) assert complex(2, -3) * 5 == complex(10, -15) +assert 5 * complex(2, -3) == complex(2, -3) * 5 # __neg__ diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index 9c38650b0..83dd4ebf1 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -141,6 +141,11 @@ impl PyComplex { ) } + #[pymethod(name = "__rmul__")] + fn rmul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.mul(other, vm) + } + #[pymethod(name = "__neg__")] fn neg(&self, _vm: &VirtualMachine) -> Complex64 { -self.value From 7c8880fb4ad5d30e42bca3272f1eb04414353ec1 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 1 May 2019 02:27:50 +0900 Subject: [PATCH 07/10] complex overflow message test --- tests/snippets/builtin_complex.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py index 20300c4e7..a3a911fa7 100644 --- a/tests/snippets/builtin_complex.py +++ b/tests/snippets/builtin_complex.py @@ -1,4 +1,4 @@ -from testutils import assertRaises +from testutils import assert_raises # __abs__ @@ -73,19 +73,13 @@ assert 1j - 1 == complex(-1, 1) assert 2j - 1j == complex(0, 1) # type error addition -with assertRaises(TypeError): - assert 1j + 'str' -with assertRaises(TypeError): - assert 1j - 'str' -with assertRaises(TypeError): - assert 'str' + 1j -with assertRaises(TypeError): - assert 'str' - 1j +assert_raises(TypeError, lambda: 1j + 'str') +assert_raises(TypeError, lambda: 1j - 'str') +assert_raises(TypeError, lambda: 'str' + 1j) +assert_raises(TypeError, lambda: 'str' - 1j) # overflow -with assertRaises(OverflowError): - complex(10 ** 1000, 0) -with assertRaises(OverflowError): - complex(0, 10 ** 1000) -with assertRaises(OverflowError): - complex(0, 0) + 10 ** 1000 +msg = 'int too large to convert to float' +assert_raises(OverflowError, lambda: complex(10 ** 1000, 0), msg) +assert_raises(OverflowError, lambda: complex(0, 10 ** 1000), msg) +assert_raises(OverflowError, lambda: 0j + 10 ** 1000, msg) From 9523baf5ac4c6f8e242aba9b624440082989de18 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 1 May 2019 02:35:16 +0900 Subject: [PATCH 08/10] complex [r]truediv, [r]floordiv --- tests/snippets/builtin_complex.py | 18 +++++++++++++++++- vm/src/obj/objcomplex.rs | 26 ++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py index a3a911fa7..5c28a6b11 100644 --- a/tests/snippets/builtin_complex.py +++ b/tests/snippets/builtin_complex.py @@ -25,12 +25,28 @@ assert complex(1, 2) != 'foo' assert complex(1, 2).__eq__('foo') == NotImplemented assert 1j != 10 ** 1000 -# __mul__ +# __mul__, __rmul__ assert complex(2, -3) * complex(-5, 7) == complex(11, 29) assert complex(2, -3) * 5 == complex(10, -15) assert 5 * complex(2, -3) == complex(2, -3) * 5 +# __truediv__, __rtruediv__ + +assert complex(2, -3) / 2 == complex(1, -1.5) +assert 5 / complex(3, -4) == complex(0.6, 0.8) + +# __floordiv__, __rfloordiv__ + +assert_raises( + TypeError, + lambda: complex(2, -3) // 2, + "can't take floor of complex number.") +assert_raises( + TypeError, + lambda: 2 // complex(2, -3), + "can't take floor of complex number.") + # __neg__ assert -complex(1, -1) == complex(-1, 1) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index 83dd4ebf1..c0392a3ab 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -146,6 +146,32 @@ impl PyComplex { self.mul(other, vm) } + #[pymethod(name = "__truediv__")] + fn truediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value / other).into_pyobject(vm), + ) + } + + #[pymethod(name = "__rtruediv__")] + fn rtruediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (other / self.value).into_pyobject(vm), + ) + } + + #[pymethod(name = "__floordiv__")] + fn floordiv(&self, _other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error("can't take floor of complex number.".to_string())) + } + + #[pymethod(name = "__rfloordiv__")] + fn rfloordiv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.floordiv(other, vm) + } + #[pymethod(name = "__neg__")] fn neg(&self, _vm: &VirtualMachine) -> Complex64 { -self.value From 930c8eef503a4b7254759c19a09a2982e944a96d Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 1 May 2019 02:57:23 +0900 Subject: [PATCH 09/10] Add complex.{__mod__, __rmod__, __divmod__, __rdivmod__} --- tests/snippets/builtin_complex.py | 22 ++++++++++++++++++++++ vm/src/obj/objcomplex.rs | 20 ++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py index 5c28a6b11..81cc9186e 100644 --- a/tests/snippets/builtin_complex.py +++ b/tests/snippets/builtin_complex.py @@ -36,6 +36,17 @@ assert 5 * complex(2, -3) == complex(2, -3) * 5 assert complex(2, -3) / 2 == complex(1, -1.5) assert 5 / complex(3, -4) == complex(0.6, 0.8) +# __mod__, __rmod__ + +assert_raises( + TypeError, + lambda: complex(2, -3) % 2, + "can't mod complex numbers.") +assert_raises( + TypeError, + lambda: 2 % complex(2, -3), + "can't mod complex numbers.") + # __floordiv__, __rfloordiv__ assert_raises( @@ -47,6 +58,17 @@ assert_raises( lambda: 2 // complex(2, -3), "can't take floor of complex number.") +# __divmod__, __rdivmod__ + +assert_raises( + TypeError, + lambda: divmod(complex(2, -3), 2), + "can't take floor or mod of complex number.") +assert_raises( + TypeError, + lambda: divmod(2, complex(2, -3)), + "can't take floor or mod of complex number.") + # __neg__ assert -complex(1, -1) == complex(-1, 1) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index c0392a3ab..3343b5dc7 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -162,6 +162,16 @@ impl PyComplex { ) } + #[pymethod(name = "__mod__")] + fn mod_(&self, _other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error("can't mod complex numbers.".to_string())) + } + + #[pymethod(name = "__rmod__")] + fn rmod(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.mod_(other, vm) + } + #[pymethod(name = "__floordiv__")] fn floordiv(&self, _other: PyObjectRef, vm: &VirtualMachine) -> PyResult { Err(vm.new_type_error("can't take floor of complex number.".to_string())) @@ -172,6 +182,16 @@ impl PyComplex { self.floordiv(other, vm) } + #[pymethod(name = "__divmod__")] + fn divmod(&self, _other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error("can't take floor or mod of complex number.".to_string())) + } + + #[pymethod(name = "__rdivmod__")] + fn rdivmod(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.divmod(other, vm) + } + #[pymethod(name = "__neg__")] fn neg(&self, _vm: &VirtualMachine) -> Complex64 { -self.value From 88e64adc561b4bac6b34f07b01f2c26252f25f9e Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 1 May 2019 05:43:31 +0900 Subject: [PATCH 10/10] Add complex __pow__ and __rpow__ --- tests/snippets/builtin_complex.py | 6 ++++++ vm/src/obj/objcomplex.rs | 16 ++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py index 81cc9186e..8d883a9dd 100644 --- a/tests/snippets/builtin_complex.py +++ b/tests/snippets/builtin_complex.py @@ -69,6 +69,12 @@ assert_raises( lambda: divmod(2, complex(2, -3)), "can't take floor or mod of complex number.") +# __pow__, __rpow__ + +# assert 1j ** 2 == -1 +assert complex(1) ** 2 == 1 +assert 2 ** complex(2) == 4 + # __neg__ assert -complex(1, -1) == complex(-1, 1) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index 3343b5dc7..677763213 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -207,6 +207,22 @@ impl PyComplex { } } + #[pymethod(name = "__pow__")] + fn pow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value.powc(other)).into_pyobject(vm), + ) + } + + #[pymethod(name = "__rpow__")] + fn rpow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_complex(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (other.powc(self.value)).into_pyobject(vm), + ) + } + #[pymethod(name = "__bool__")] fn bool(&self, _vm: &VirtualMachine) -> bool { !Complex64::is_zero(&self.value)