PyComparisonValue

This commit is contained in:
Jeong YunWon
2020-01-05 03:00:14 +09:00
parent b1e582e138
commit 6bba9ff446
3 changed files with 35 additions and 37 deletions

View File

@@ -12,9 +12,8 @@ use crate::format::FormatSpec;
use crate::function::{OptionalArg, OptionalOption};
use crate::pyhash;
use crate::pyobject::{
IntoPyObject,
PyArithmaticValue::{self, *},
PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
IntoPyObject, PyArithmaticValue::*, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef,
PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
};
use crate::vm::VirtualMachine;
@@ -197,15 +196,15 @@ impl PyFloat {
float_op: F,
int_op: G,
vm: &VirtualMachine,
) -> PyArithmaticValue<bool>
) -> PyComparisonValue
where
F: Fn(f64, f64) -> bool,
G: Fn(f64, &BigInt) -> bool,
{
if let Some(other) = other.payload_if_subclass::<PyFloat>(vm) {
ArithmaticValue(float_op(self.value, other.value))
Implemented(float_op(self.value, other.value))
} else if let Some(other) = other.payload_if_subclass::<PyInt>(vm) {
ArithmaticValue(int_op(self.value, other.as_bigint()))
Implemented(int_op(self.value, other.as_bigint()))
} else {
NotImplemented
}
@@ -222,22 +221,22 @@ impl PyFloat {
}
#[pymethod(name = "__eq__")]
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
self.cmp(other, |a, b| a == b, |a, b| int_eq(a, b), vm)
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(other, |a, b| a == b, int_eq, vm)
}
#[pymethod(name = "__ne__")]
fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.eq(other, vm).map(|v| !v)
}
#[pymethod(name = "__lt__")]
fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
self.cmp(other, |a, b| a < b, |a, b| inner_lt_int(a, b), vm)
fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(other, |a, b| a < b, inner_lt_int, vm)
}
#[pymethod(name = "__le__")]
fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(
other,
|a, b| a <= b,
@@ -253,12 +252,12 @@ impl PyFloat {
}
#[pymethod(name = "__gt__")]
fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
self.cmp(other, |a, b| a > b, |a, b| inner_gt_int(a, b), vm)
fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(other, |a, b| a > b, inner_gt_int, vm)
}
#[pymethod(name = "__ge__")]
fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(
other,
|a, b| a >= b,

View File

@@ -16,9 +16,8 @@ use crate::format::FormatSpec;
use crate::function::{OptionalArg, PyFuncArgs};
use crate::pyhash;
use crate::pyobject::{
IdProtocol, IntoPyObject,
PyArithmaticValue::{self, *},
PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
IdProtocol, IntoPyObject, PyArithmaticValue, PyClassImpl, PyComparisonValue, PyContext,
PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
};
use crate::vm::VirtualMachine;
@@ -220,44 +219,43 @@ impl PyInt {
}
#[inline]
fn cmp<F>(&self, other: PyObjectRef, op: F, vm: &VirtualMachine) -> PyArithmaticValue<bool>
fn cmp<F>(&self, other: PyObjectRef, op: F, vm: &VirtualMachine) -> PyComparisonValue
where
F: Fn(&BigInt, &BigInt) -> bool,
{
if let Some(other) = other.payload_if_subclass::<PyInt>(vm) {
ArithmaticValue(op(&self.value, &other.value))
} else {
NotImplemented
}
let r = other
.payload_if_subclass::<PyInt>(vm)
.map(|other| op(&self.value, &other.value));
PyComparisonValue::from_option(r)
}
#[pymethod(name = "__eq__")]
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(other, |a, b| a == b, vm)
}
#[pymethod(name = "__ne__")]
fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(other, |a, b| a != b, vm)
}
#[pymethod(name = "__lt__")]
fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(other, |a, b| a < b, vm)
}
#[pymethod(name = "__le__")]
fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(other, |a, b| a <= b, vm)
}
#[pymethod(name = "__gt__")]
fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(other, |a, b| a > b, vm)
}
#[pymethod(name = "__ge__")]
fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue<bool> {
fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue {
self.cmp(other, |a, b| a >= b, vm)
}
@@ -266,11 +264,10 @@ impl PyInt {
where
F: Fn(&BigInt, &BigInt) -> BigInt,
{
if let Some(other) = other.payload_if_subclass::<PyInt>(vm) {
ArithmaticValue(op(&self.value, &other.value))
} else {
NotImplemented
}
let r = other
.payload_if_subclass::<PyInt>(vm)
.map(|other| op(&self.value, &other.value));
PyArithmaticValue::from_option(r)
}
#[inline]

View File

@@ -1218,7 +1218,7 @@ impl TryFromObject for std::time::Duration {
}
}
result_like::option_like!(pub PyArithmaticValue, ArithmaticValue, NotImplemented);
result_like::option_like!(pub PyArithmaticValue, Implemented, NotImplemented);
impl<T> IntoPyObject for PyArithmaticValue<T>
where
@@ -1226,12 +1226,14 @@ where
{
fn into_pyobject(self, vm: &VirtualMachine) -> PyResult {
match self {
PyArithmaticValue::ArithmaticValue(v) => v.into_pyobject(vm),
PyArithmaticValue::Implemented(v) => v.into_pyobject(vm),
PyArithmaticValue::NotImplemented => Ok(vm.ctx.not_implemented()),
}
}
}
pub type PyComparisonValue = PyArithmaticValue<bool>;
#[cfg(test)]
mod tests {
use super::*;