diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index d75c361fd..e6b6bdca7 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -12,9 +12,8 @@ use crate::format::FormatSpec; use crate::function::{OptionalArg, OptionalOption}; use crate::pyhash; use crate::pyobject::{ - IntoPyObject, - PyArithmaticValue::{self, *}, - PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, + IntoPyObject, PyArithmaticValue::*, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, + PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -197,15 +196,15 @@ impl PyFloat { float_op: F, int_op: G, vm: &VirtualMachine, - ) -> PyArithmaticValue + ) -> PyComparisonValue where F: Fn(f64, f64) -> bool, G: Fn(f64, &BigInt) -> bool, { if let Some(other) = other.payload_if_subclass::(vm) { - ArithmaticValue(float_op(self.value, other.value)) + Implemented(float_op(self.value, other.value)) } else if let Some(other) = other.payload_if_subclass::(vm) { - ArithmaticValue(int_op(self.value, other.as_bigint())) + Implemented(int_op(self.value, other.as_bigint())) } else { NotImplemented } @@ -222,22 +221,22 @@ impl PyFloat { } #[pymethod(name = "__eq__")] - fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { - self.cmp(other, |a, b| a == b, |a, b| int_eq(a, b), vm) + fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { + self.cmp(other, |a, b| a == b, int_eq, vm) } #[pymethod(name = "__ne__")] - fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { + fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.eq(other, vm).map(|v| !v) } #[pymethod(name = "__lt__")] - fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { - self.cmp(other, |a, b| a < b, |a, b| inner_lt_int(a, b), vm) + fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { + self.cmp(other, |a, b| a < b, inner_lt_int, vm) } #[pymethod(name = "__le__")] - fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { + fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.cmp( other, |a, b| a <= b, @@ -253,12 +252,12 @@ impl PyFloat { } #[pymethod(name = "__gt__")] - fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { - self.cmp(other, |a, b| a > b, |a, b| inner_gt_int(a, b), vm) + fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { + self.cmp(other, |a, b| a > b, inner_gt_int, vm) } #[pymethod(name = "__ge__")] - fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { + fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.cmp( other, |a, b| a >= b, diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index a212fc0f3..112bc6975 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -16,9 +16,8 @@ use crate::format::FormatSpec; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyhash; use crate::pyobject::{ - IdProtocol, IntoPyObject, - PyArithmaticValue::{self, *}, - PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, + IdProtocol, IntoPyObject, PyArithmaticValue, PyClassImpl, PyComparisonValue, PyContext, + PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -220,44 +219,43 @@ impl PyInt { } #[inline] - fn cmp(&self, other: PyObjectRef, op: F, vm: &VirtualMachine) -> PyArithmaticValue + fn cmp(&self, other: PyObjectRef, op: F, vm: &VirtualMachine) -> PyComparisonValue where F: Fn(&BigInt, &BigInt) -> bool, { - if let Some(other) = other.payload_if_subclass::(vm) { - ArithmaticValue(op(&self.value, &other.value)) - } else { - NotImplemented - } + let r = other + .payload_if_subclass::(vm) + .map(|other| op(&self.value, &other.value)); + PyComparisonValue::from_option(r) } #[pymethod(name = "__eq__")] - fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { + fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.cmp(other, |a, b| a == b, vm) } #[pymethod(name = "__ne__")] - fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { + fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.cmp(other, |a, b| a != b, vm) } #[pymethod(name = "__lt__")] - fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { + fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.cmp(other, |a, b| a < b, vm) } #[pymethod(name = "__le__")] - fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { + fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.cmp(other, |a, b| a <= b, vm) } #[pymethod(name = "__gt__")] - fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { + fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.cmp(other, |a, b| a > b, vm) } #[pymethod(name = "__ge__")] - fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { + fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.cmp(other, |a, b| a >= b, vm) } @@ -266,11 +264,10 @@ impl PyInt { where F: Fn(&BigInt, &BigInt) -> BigInt, { - if let Some(other) = other.payload_if_subclass::(vm) { - ArithmaticValue(op(&self.value, &other.value)) - } else { - NotImplemented - } + let r = other + .payload_if_subclass::(vm) + .map(|other| op(&self.value, &other.value)); + PyArithmaticValue::from_option(r) } #[inline] diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 90235b8d8..009e20438 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1218,7 +1218,7 @@ impl TryFromObject for std::time::Duration { } } -result_like::option_like!(pub PyArithmaticValue, ArithmaticValue, NotImplemented); +result_like::option_like!(pub PyArithmaticValue, Implemented, NotImplemented); impl IntoPyObject for PyArithmaticValue where @@ -1226,12 +1226,14 @@ where { fn into_pyobject(self, vm: &VirtualMachine) -> PyResult { match self { - PyArithmaticValue::ArithmaticValue(v) => v.into_pyobject(vm), + PyArithmaticValue::Implemented(v) => v.into_pyobject(vm), PyArithmaticValue::NotImplemented => Ok(vm.ctx.not_implemented()), } } } +pub type PyComparisonValue = PyArithmaticValue; + #[cfg(test)] mod tests { use super::*;