Fix comparison operator

This commit is contained in:
Jeong YunWon
2020-01-25 23:16:40 +09:00
parent 8bc915711e
commit 3698d0e438
2 changed files with 50 additions and 14 deletions

View File

@@ -70,8 +70,6 @@ class ComparisonTest(unittest.TestCase):
Left() != Right()
self.assertSequenceEqual(calls, ['Left.__eq__', 'Right.__ne__'])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_ne_low_priority(self):
"""object.__ne__() should not invoke reflected __eq__()"""
calls = []

View File

@@ -1145,40 +1145,78 @@ impl VirtualMachine {
})
}
// Perform a comparison, raising TypeError when the requested comparison
// operator is not supported.
// see: CPython PyObject_RichCompare
fn _cmp<F>(
&self,
v: PyObjectRef,
w: PyObjectRef,
op: &str,
swap_op: &str,
default: F,
) -> PyResult
where
F: Fn(&VirtualMachine, PyObjectRef, PyObjectRef) -> PyResult,
{
// TODO: _Py_EnterRecursiveCall(tstate, " in comparison")
let mut checked_reverse_op = false;
if !v.typ.is(&w.typ) && objtype::issubclass(&w.class(), &v.class()) {
if let Some(method_or_err) = self.get_method(w.clone(), swap_op) {
let method = method_or_err?;
checked_reverse_op = true;
let result = self.invoke(&method, vec![v.clone()])?;
if !result.is(&self.ctx.not_implemented()) {
return Ok(result);
}
}
}
self.call_or_unsupported(v, w, op, |vm, v, w| {
if !checked_reverse_op {
self.call_or_unsupported(w, v, swap_op, |vm, v, w| default(vm, v, w))
} else {
default(vm, v, w)
}
})
// TODO: _Py_LeaveRecursiveCall(tstate);
}
pub fn _eq(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
self.call_or_reflection(
a,
b,
"__eq__",
"__eq__",
|vm, _a, _b| Ok(vm.new_bool(false)),
)
self._cmp(a, b, "__eq__", "__eq__", |vm, a, b| {
Ok(vm.new_bool(a.is(&b)))
})
}
pub fn _ne(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
self.call_or_reflection(a, b, "__ne__", "__ne__", |vm, _a, _b| Ok(vm.new_bool(true)))
self._cmp(a, b, "__ne__", "__ne__", |vm, a, b| {
Ok(vm.new_bool(!a.is(&b)))
})
}
pub fn _lt(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
self.call_or_reflection(a, b, "__lt__", "__gt__", |vm, a, b| {
self._cmp(a, b, "__lt__", "__gt__", |vm, a, b| {
Err(vm.new_unsupported_operand_error(a, b, "<"))
})
}
pub fn _le(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
self.call_or_reflection(a, b, "__le__", "__ge__", |vm, a, b| {
self._cmp(a, b, "__le__", "__ge__", |vm, a, b| {
Err(vm.new_unsupported_operand_error(a, b, "<="))
})
}
pub fn _gt(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
self.call_or_reflection(a, b, "__gt__", "__lt__", |vm, a, b| {
self._cmp(a, b, "__gt__", "__lt__", |vm, a, b| {
Err(vm.new_unsupported_operand_error(a, b, ">"))
})
}
pub fn _ge(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
self.call_or_reflection(a, b, "__ge__", "__le__", |vm, a, b| {
self._cmp(a, b, "__ge__", "__le__", |vm, a, b| {
Err(vm.new_unsupported_operand_error(a, b, ">="))
})
}