From d66ca54a2d05d83b1629b53921cd79a6eba0350f Mon Sep 17 00:00:00 2001 From: Joey Hain Date: Fri, 8 Feb 2019 18:20:58 -0800 Subject: [PATCH] Add complex.{__eq__, __neg__} --- tests/snippets/builtin_complex.py | 19 ++++++++++++++++++ vm/src/obj/objcomplex.rs | 32 +++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py index 5ac70ff59..8c897fd85 100644 --- a/tests/snippets/builtin_complex.py +++ b/tests/snippets/builtin_complex.py @@ -1,3 +1,22 @@ +# __abs__ + assert complex(3, 4).__abs__() == 5 assert complex(3, -4).__abs__() == 5 assert complex(1.5, 2.5).__abs__() == 2.9154759474226504 + +# __eq__ + +assert complex(1, -1).__eq__(complex(1, -1)) +assert complex(1, 0).__eq__(1) +assert not complex(1, 1).__eq__(1) +assert complex(1, 0).__eq__(1.0) +assert not complex(1, 1).__eq__(1.0) +assert not complex(1, 0).__eq__(1.5) +assert complex(1, 0).__eq__(True) +assert not complex(1, 2).__eq__(complex(1, 1)) +#assert complex(1, 2).__eq__('foo') == NotImplemented + +# __neg__ + +assert complex(1, -1).__neg__() == complex(-1, 1) +assert complex(0, 0).__neg__() == complex(0, 0) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index 00556096b..e8bb66ea7 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -3,8 +3,10 @@ use super::super::pyobject::{ }; use super::super::vm::VirtualMachine; use super::objfloat; +use super::objint; use super::objtype; use num_complex::Complex64; +use num_traits::ToPrimitive; pub fn init(context: &PyContext) { let complex_type = &context.complex_type; @@ -15,6 +17,8 @@ pub fn init(context: &PyContext) { context.set_attr(&complex_type, "__abs__", context.new_rustfunc(complex_abs)); context.set_attr(&complex_type, "__add__", context.new_rustfunc(complex_add)); + context.set_attr(&complex_type, "__eq__", context.new_rustfunc(complex_eq)); + context.set_attr(&complex_type, "__neg__", context.new_rustfunc(complex_neg)); context.set_attr(&complex_type, "__new__", context.new_rustfunc(complex_new)); context.set_attr( &complex_type, @@ -100,6 +104,34 @@ fn complex_conjugate(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_complex(v1.conj())) } +fn complex_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, Some(vm.ctx.complex_type())), (other, None)] + ); + + let z = get_value(zelf); + let result = if objtype::isinstance(other, &vm.ctx.complex_type()) { + z == get_value(other) + } else if objtype::isinstance(other, &vm.ctx.int_type()) { + match objint::get_value(other).to_f64() { + Some(f) => z.im == 0.0f64 && z.re == f, + None => false, + } + } else if objtype::isinstance(other, &vm.ctx.float_type()) { + z.im == 0.0 && z.re == objfloat::get_value(other) + } else { + false + }; + Ok(vm.ctx.new_bool(result)) +} + +fn complex_neg(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, Some(vm.ctx.complex_type()))]); + Ok(vm.ctx.new_complex(-get_value(zelf))) +} + fn complex_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(obj, Some(vm.ctx.complex_type()))]); let v = get_value(obj);