From 3698d0e4388ef0f30eadb9b767320eeea70ebfdf Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sat, 25 Jan 2020 23:16:40 +0900 Subject: [PATCH] Fix comparison operator --- Lib/test/test_compare.py | 2 -- vm/src/vm.rs | 62 ++++++++++++++++++++++++++++++++-------- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/Lib/test/test_compare.py b/Lib/test/test_compare.py index 46f6918bb..471c8dae7 100644 --- a/Lib/test/test_compare.py +++ b/Lib/test/test_compare.py @@ -70,8 +70,6 @@ class ComparisonTest(unittest.TestCase): Left() != Right() self.assertSequenceEqual(calls, ['Left.__eq__', 'Right.__ne__']) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_ne_low_priority(self): """object.__ne__() should not invoke reflected __eq__()""" calls = [] diff --git a/vm/src/vm.rs b/vm/src/vm.rs index eb0882b7c..38762ca77 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -1145,40 +1145,78 @@ impl VirtualMachine { }) } + // Perform a comparison, raising TypeError when the requested comparison + // operator is not supported. + // see: CPython PyObject_RichCompare + fn _cmp( + &self, + v: PyObjectRef, + w: PyObjectRef, + op: &str, + swap_op: &str, + default: F, + ) -> PyResult + where + F: Fn(&VirtualMachine, PyObjectRef, PyObjectRef) -> PyResult, + { + // TODO: _Py_EnterRecursiveCall(tstate, " in comparison") + + let mut checked_reverse_op = false; + if !v.typ.is(&w.typ) && objtype::issubclass(&w.class(), &v.class()) { + if let Some(method_or_err) = self.get_method(w.clone(), swap_op) { + let method = method_or_err?; + checked_reverse_op = true; + + let result = self.invoke(&method, vec![v.clone()])?; + if !result.is(&self.ctx.not_implemented()) { + return Ok(result); + } + } + } + + self.call_or_unsupported(v, w, op, |vm, v, w| { + if !checked_reverse_op { + self.call_or_unsupported(w, v, swap_op, |vm, v, w| default(vm, v, w)) + } else { + default(vm, v, w) + } + }) + + // TODO: _Py_LeaveRecursiveCall(tstate); + } + pub fn _eq(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - self.call_or_reflection( - a, - b, - "__eq__", - "__eq__", - |vm, _a, _b| Ok(vm.new_bool(false)), - ) + self._cmp(a, b, "__eq__", "__eq__", |vm, a, b| { + Ok(vm.new_bool(a.is(&b))) + }) } pub fn _ne(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - self.call_or_reflection(a, b, "__ne__", "__ne__", |vm, _a, _b| Ok(vm.new_bool(true))) + self._cmp(a, b, "__ne__", "__ne__", |vm, a, b| { + Ok(vm.new_bool(!a.is(&b))) + }) } pub fn _lt(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - self.call_or_reflection(a, b, "__lt__", "__gt__", |vm, a, b| { + self._cmp(a, b, "__lt__", "__gt__", |vm, a, b| { Err(vm.new_unsupported_operand_error(a, b, "<")) }) } pub fn _le(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - self.call_or_reflection(a, b, "__le__", "__ge__", |vm, a, b| { + self._cmp(a, b, "__le__", "__ge__", |vm, a, b| { Err(vm.new_unsupported_operand_error(a, b, "<=")) }) } pub fn _gt(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - self.call_or_reflection(a, b, "__gt__", "__lt__", |vm, a, b| { + self._cmp(a, b, "__gt__", "__lt__", |vm, a, b| { Err(vm.new_unsupported_operand_error(a, b, ">")) }) } pub fn _ge(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - self.call_or_reflection(a, b, "__ge__", "__le__", |vm, a, b| { + self._cmp(a, b, "__ge__", "__le__", |vm, a, b| { Err(vm.new_unsupported_operand_error(a, b, ">=")) }) }