From cbbacbaa47149c0b8f3d5685cbcce4afe80d88ce Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sat, 21 Dec 2019 18:32:31 +0900 Subject: [PATCH] Refactor PyComplex --- vm/src/obj/objcomplex.rs | 66 +++++++++++++++------------------------- 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index 360371d38..80534a6cb 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -3,7 +3,7 @@ use num_traits::Zero; use std::num::Wrapping; use super::objfloat::{self, IntoPyFloat}; -use super::objtype::{self, PyClassRef}; +use super::objtype::PyClassRef; use crate::function::OptionalArg; use crate::pyhash; use crate::pyobject::{ @@ -43,18 +43,15 @@ pub fn init(context: &PyContext) { PyComplex::extend_class(context, &context.types.complex_type); } -pub fn get_value(obj: &PyObjectRef) -> Complex64 { - obj.payload::().unwrap().value -} - fn try_complex(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult> { - Ok(if objtype::isinstance(&value, &vm.ctx.complex_type()) { - Some(get_value(&value)) + let r = if let Some(complex) = value.payload_if_subclass::(vm) { + Some(complex.value) } else if let Some(float) = objfloat::try_float(value, vm)? { Some(Complex64::new(float, 0.0)) } else { None - }) + }; + Ok(r) } #[pyimpl] @@ -75,14 +72,22 @@ impl PyComplex { re.hypot(im) } - #[pymethod(name = "__add__")] - fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + #[inline] + fn op(&self, other: PyObjectRef, op: F, vm: &VirtualMachine) -> PyResult + where + F: Fn(Complex64, Complex64) -> Complex64, + { try_complex(&other, vm)?.map_or_else( || Ok(vm.ctx.not_implemented()), - |other| (self.value + other).into_pyobject(vm), + |other| op(self.value, other).into_pyobject(vm), ) } + #[pymethod(name = "__add__")] + fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.op(other, |a, b| a + b, vm) + } + #[pymethod(name = "__radd__")] fn radd(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { self.add(other, vm) @@ -90,18 +95,12 @@ impl PyComplex { #[pymethod(name = "__sub__")] fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - try_complex(&other, vm)?.map_or_else( - || Ok(vm.ctx.not_implemented()), - |other| (self.value - other).into_pyobject(vm), - ) + self.op(other, |a, b| a - b, vm) } #[pymethod(name = "__rsub__")] fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - try_complex(&other, vm)?.map_or_else( - || Ok(vm.ctx.not_implemented()), - |other| (other - self.value).into_pyobject(vm), - ) + self.op(other, |a, b| b - a, vm) } #[pymethod(name = "conjugate")] @@ -111,8 +110,8 @@ impl PyComplex { #[pymethod(name = "__eq__")] fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { - let result = if objtype::isinstance(&other, &vm.ctx.complex_type()) { - self.value == get_value(&other) + let result = if let Some(other) = other.payload_if_subclass::(vm) { + self.value == other.value } else { match objfloat::try_float(&other, vm) { Ok(Some(other)) => self.value.im == 0.0f64 && self.value.re == other, @@ -136,10 +135,7 @@ impl PyComplex { #[pymethod(name = "__mul__")] fn mul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - try_complex(&other, vm)?.map_or_else( - || Ok(vm.ctx.not_implemented()), - |other| (self.value * other).into_pyobject(vm), - ) + self.op(other, |a, b| a * b, vm) } #[pymethod(name = "__rmul__")] @@ -149,18 +145,12 @@ impl PyComplex { #[pymethod(name = "__truediv__")] fn truediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - try_complex(&other, vm)?.map_or_else( - || Ok(vm.ctx.not_implemented()), - |other| (self.value / other).into_pyobject(vm), - ) + self.op(other, |a, b| a / b, vm) } #[pymethod(name = "__rtruediv__")] fn rtruediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - try_complex(&other, vm)?.map_or_else( - || Ok(vm.ctx.not_implemented()), - |other| (other / self.value).into_pyobject(vm), - ) + self.op(other, |a, b| b / a, vm) } #[pymethod(name = "__mod__")] @@ -210,18 +200,12 @@ impl PyComplex { #[pymethod(name = "__pow__")] fn pow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - try_complex(&other, vm)?.map_or_else( - || Ok(vm.ctx.not_implemented()), - |other| (self.value.powc(other)).into_pyobject(vm), - ) + self.op(other, |a, b| a.powc(b), vm) } #[pymethod(name = "__rpow__")] fn rpow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - try_complex(&other, vm)?.map_or_else( - || Ok(vm.ctx.not_implemented()), - |other| (other.powc(self.value)).into_pyobject(vm), - ) + self.op(other, |a, b| b.powc(a), vm) } #[pymethod(name = "__bool__")]