Refactor PyComplex

This commit is contained in:
Jeong YunWon
2019-12-21 18:32:31 +09:00
parent 801d01161c
commit cbbacbaa47

View File

@@ -3,7 +3,7 @@ use num_traits::Zero;
use std::num::Wrapping;
use super::objfloat::{self, IntoPyFloat};
use super::objtype::{self, PyClassRef};
use super::objtype::PyClassRef;
use crate::function::OptionalArg;
use crate::pyhash;
use crate::pyobject::{
@@ -43,18 +43,15 @@ pub fn init(context: &PyContext) {
PyComplex::extend_class(context, &context.types.complex_type);
}
pub fn get_value(obj: &PyObjectRef) -> Complex64 {
obj.payload::<PyComplex>().unwrap().value
}
fn try_complex(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult<Option<Complex64>> {
Ok(if objtype::isinstance(&value, &vm.ctx.complex_type()) {
Some(get_value(&value))
let r = if let Some(complex) = value.payload_if_subclass::<PyComplex>(vm) {
Some(complex.value)
} else if let Some(float) = objfloat::try_float(value, vm)? {
Some(Complex64::new(float, 0.0))
} else {
None
})
};
Ok(r)
}
#[pyimpl]
@@ -75,14 +72,22 @@ impl PyComplex {
re.hypot(im)
}
#[pymethod(name = "__add__")]
fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
#[inline]
fn op<F>(&self, other: PyObjectRef, op: F, vm: &VirtualMachine) -> PyResult
where
F: Fn(Complex64, Complex64) -> Complex64,
{
try_complex(&other, vm)?.map_or_else(
|| Ok(vm.ctx.not_implemented()),
|other| (self.value + other).into_pyobject(vm),
|other| op(self.value, other).into_pyobject(vm),
)
}
#[pymethod(name = "__add__")]
fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
self.op(other, |a, b| a + b, vm)
}
#[pymethod(name = "__radd__")]
fn radd(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
self.add(other, vm)
@@ -90,18 +95,12 @@ impl PyComplex {
#[pymethod(name = "__sub__")]
fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
try_complex(&other, vm)?.map_or_else(
|| Ok(vm.ctx.not_implemented()),
|other| (self.value - other).into_pyobject(vm),
)
self.op(other, |a, b| a - b, vm)
}
#[pymethod(name = "__rsub__")]
fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
try_complex(&other, vm)?.map_or_else(
|| Ok(vm.ctx.not_implemented()),
|other| (other - self.value).into_pyobject(vm),
)
self.op(other, |a, b| b - a, vm)
}
#[pymethod(name = "conjugate")]
@@ -111,8 +110,8 @@ impl PyComplex {
#[pymethod(name = "__eq__")]
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
let result = if objtype::isinstance(&other, &vm.ctx.complex_type()) {
self.value == get_value(&other)
let result = if let Some(other) = other.payload_if_subclass::<PyComplex>(vm) {
self.value == other.value
} else {
match objfloat::try_float(&other, vm) {
Ok(Some(other)) => self.value.im == 0.0f64 && self.value.re == other,
@@ -136,10 +135,7 @@ impl PyComplex {
#[pymethod(name = "__mul__")]
fn mul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
try_complex(&other, vm)?.map_or_else(
|| Ok(vm.ctx.not_implemented()),
|other| (self.value * other).into_pyobject(vm),
)
self.op(other, |a, b| a * b, vm)
}
#[pymethod(name = "__rmul__")]
@@ -149,18 +145,12 @@ impl PyComplex {
#[pymethod(name = "__truediv__")]
fn truediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
try_complex(&other, vm)?.map_or_else(
|| Ok(vm.ctx.not_implemented()),
|other| (self.value / other).into_pyobject(vm),
)
self.op(other, |a, b| a / b, vm)
}
#[pymethod(name = "__rtruediv__")]
fn rtruediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
try_complex(&other, vm)?.map_or_else(
|| Ok(vm.ctx.not_implemented()),
|other| (other / self.value).into_pyobject(vm),
)
self.op(other, |a, b| b / a, vm)
}
#[pymethod(name = "__mod__")]
@@ -210,18 +200,12 @@ impl PyComplex {
#[pymethod(name = "__pow__")]
fn pow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
try_complex(&other, vm)?.map_or_else(
|| Ok(vm.ctx.not_implemented()),
|other| (self.value.powc(other)).into_pyobject(vm),
)
self.op(other, |a, b| a.powc(b), vm)
}
#[pymethod(name = "__rpow__")]
fn rpow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
try_complex(&other, vm)?.map_or_else(
|| Ok(vm.ctx.not_implemented()),
|other| (other.powc(self.value)).into_pyobject(vm),
)
self.op(other, |a, b| b.powc(a), vm)
}
#[pymethod(name = "__bool__")]