From c427462554149537e5aa535cbf148d9c01fd911b Mon Sep 17 00:00:00 2001 From: Aratrik Date: Tue, 19 Oct 2021 23:39:12 +0530 Subject: [PATCH] Relocate vm.rich_compare to obj.rich_compare --- stdlib/src/bisect.rs | 4 +-- vm/src/builtins/list.rs | 2 +- vm/src/frame.rs | 16 +++++---- vm/src/protocol/object.rs | 65 +++++++++++++++++++++++++++++++++--- vm/src/stdlib/builtins.rs | 4 +-- vm/src/stdlib/io.rs | 2 +- vm/src/stdlib/operator.rs | 12 +++---- vm/src/vm.rs | 70 ++------------------------------------- 8 files changed, 86 insertions(+), 89 deletions(-) diff --git a/stdlib/src/bisect.rs b/stdlib/src/bisect.rs index 67b3d12902..bd313c1647 100644 --- a/stdlib/src/bisect.rs +++ b/stdlib/src/bisect.rs @@ -77,7 +77,7 @@ mod _bisect { while lo < hi { // Handles issue 13496. let mid = (lo + hi) / 2; - if vm.bool_cmp(&a.get_item(mid, vm)?, &x, Lt)? { + if a.get_item(mid, vm)?.rich_compare_bool(&x, Lt, vm)? { lo = mid + 1; } else { hi = mid; @@ -105,7 +105,7 @@ mod _bisect { while lo < hi { // Handles issue 13496. let mid = (lo + hi) / 2; - if vm.bool_cmp(&x, &a.get_item(mid, vm)?, Lt)? { + if x.rich_compare_bool(&a.get_item(mid, vm)?, Lt, vm)? { hi = mid; } else { lo = mid + 1; diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 027bca90bc..015bf91819 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -613,7 +613,7 @@ fn do_sort( } else { PyComparisonOp::Gt }; - let cmp = |a: &PyObjectRef, b: &PyObjectRef| vm.bool_cmp(a, b, op); + let cmp = |a: &PyObjectRef, b: &PyObjectRef| a.rich_compare_bool(b, op, vm); if let Some(ref key_func) = key_func { let mut items = values diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 5084c2a0bb..aead471ffe 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -1772,12 +1772,16 @@ impl ExecutingFrame<'_> { let b = self.pop_value(); let a = self.pop_value(); let value = match *op { - bytecode::ComparisonOperator::Equal => vm.obj_cmp(a, b, PyComparisonOp::Eq)?, - bytecode::ComparisonOperator::NotEqual => vm.obj_cmp(a, b, PyComparisonOp::Ne)?, - bytecode::ComparisonOperator::Less => vm.obj_cmp(a, b, PyComparisonOp::Lt)?, - bytecode::ComparisonOperator::LessOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Le)?, - bytecode::ComparisonOperator::Greater => vm.obj_cmp(a, b, PyComparisonOp::Gt)?, - bytecode::ComparisonOperator::GreaterOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Ge)?, + bytecode::ComparisonOperator::Equal => a.rich_compare(b, PyComparisonOp::Eq, vm)?, + bytecode::ComparisonOperator::NotEqual => a.rich_compare(b, PyComparisonOp::Ne, vm)?, + bytecode::ComparisonOperator::Less => a.rich_compare(b, PyComparisonOp::Lt, vm)?, + bytecode::ComparisonOperator::LessOrEqual => { + a.rich_compare(b, PyComparisonOp::Le, vm)? + } + bytecode::ComparisonOperator::Greater => a.rich_compare(b, PyComparisonOp::Gt, vm)?, + bytecode::ComparisonOperator::GreaterOrEqual => { + a.rich_compare(b, PyComparisonOp::Ge, vm)? + } bytecode::ComparisonOperator::Is => vm.ctx.new_bool(self._is(a, b)).into(), bytecode::ComparisonOperator::IsNot => vm.ctx.new_bool(self._is_not(a, b)).into(), bytecode::ComparisonOperator::In => vm.ctx.new_bool(self._in(vm, a, b)?).into(), diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 7bcd081ca2..00ae94fc05 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -5,11 +5,13 @@ use crate::{ builtins::{pystr::IntoPyStrRef, PyBytes, PyInt, PyStrRef}, bytesinner::ByteInnerNewOptions, common::{hash::PyHash, str::to_ascii}, - function::OptionalArg, + function::{IntoPyObject, OptionalArg}, protocol::PyIter, pyref_type_error, types::{Constructor, PyComparisonOp}, - PyObjectRef, PyResult, TryFromObject, TypeProtocol, VirtualMachine, + utils::Either, + IdProtocol, PyArithmeticValue, PyObjectRef, PyResult, TryFromObject, TypeProtocol, + VirtualMachine, }; // RustPython doesn't need these items @@ -78,11 +80,63 @@ impl PyObjectRef { self.call_set_attr(vm, attr_name, None) } + // Perform a comparison, raising TypeError when the requested comparison + // operator is not supported. + // see: CPython PyObject_RichCompare + fn _cmp( + &self, + other: &Self, + op: PyComparisonOp, + vm: &VirtualMachine, + ) -> PyResult> { + let swapped = op.swapped(); + let call_cmp = |obj: &PyObjectRef, other, op| { + let cmp = obj + .class() + .mro_find_map(|cls| cls.slots.richcompare.load()) + .unwrap(); + Ok(match cmp(obj, other, op, vm)? { + Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A), + Either::B(arithmetic) => arithmetic.map(Either::B), + }) + }; + + let mut checked_reverse_op = false; + let is_strict_subclass = { + let self_class = self.class(); + let other_class = other.class(); + !self_class.is(&other_class) && other_class.issubclass(&self_class) + }; + if is_strict_subclass { + let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?; + checked_reverse_op = true; + if let PyArithmeticValue::Implemented(x) = res { + return Ok(x); + } + } + if let PyArithmeticValue::Implemented(x) = + vm.with_recursion("in comparison", || call_cmp(self, other, op))? + { + return Ok(x); + } + if !checked_reverse_op { + let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?; + if let PyArithmeticValue::Implemented(x) = res { + return Ok(x); + } + } + match op { + PyComparisonOp::Eq => Ok(Either::B(self.is(&other))), + PyComparisonOp::Ne => Ok(Either::B(!self.is(&other))), + _ => Err(vm.new_unsupported_binop_error(self, other, op.operator_token())), + } + } + // PyObject *PyObject_GenericGetDict(PyObject *o, void *context) // int PyObject_GenericSetDict(PyObject *o, PyObject *value, void *context) pub fn rich_compare(self, other: Self, opid: PyComparisonOp, vm: &VirtualMachine) -> PyResult { - vm.obj_cmp(self, other, opid) + self._cmp(&other, opid, vm).map(|res| res.into_pyobject(vm)) } pub fn rich_compare_bool( @@ -91,7 +145,10 @@ impl PyObjectRef { opid: PyComparisonOp, vm: &VirtualMachine, ) -> PyResult { - vm.bool_cmp(self, other, opid) + match self._cmp(other, opid, vm)? { + Either::A(obj) => obj.try_to_bool(vm), + Either::B(other) => Ok(other), + } } pub fn repr(&self, vm: &VirtualMachine) -> PyResult { diff --git a/vm/src/stdlib/builtins.rs b/vm/src/stdlib/builtins.rs index 83020cb3ad..9ad9c0d791 100644 --- a/vm/src/stdlib/builtins.rs +++ b/vm/src/stdlib/builtins.rs @@ -477,14 +477,14 @@ mod builtins { let mut x_key = vm.invoke(key_func, (x.clone(),))?; for y in candidates_iter { let y_key = vm.invoke(key_func, (y.clone(),))?; - if vm.bool_cmp(&y_key, &x_key, op)? { + if y_key.rich_compare_bool(&x_key, op, vm)? { x = y; x_key = y_key; } } } else { for y in candidates_iter { - if vm.bool_cmp(&y, &x, op)? { + if y.rich_compare_bool(&x, op, vm)? { x = y; } } diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index a6befda11f..63cd33989f 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -2379,7 +2379,7 @@ mod _io { } }; use crate::types::PyComparisonOp; - if vm.bool_cmp(&cookie, &vm.ctx.new_int(0).into(), PyComparisonOp::Lt)? { + if cookie.rich_compare_bool(&vm.ctx.new_int(0).into(), PyComparisonOp::Lt, vm)? { return Err( vm.new_value_error(format!("negative seek position {}", vm.to_repr(&cookie)?)) ); diff --git a/vm/src/stdlib/operator.rs b/vm/src/stdlib/operator.rs index 8f2e562b1d..8aef4d351a 100644 --- a/vm/src/stdlib/operator.rs +++ b/vm/src/stdlib/operator.rs @@ -27,37 +27,37 @@ mod _operator { /// Same as a < b. #[pyfunction] fn lt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { - vm.obj_cmp(a, b, Lt) + a.rich_compare(b, Lt, vm) } /// Same as a <= b. #[pyfunction] fn le(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { - vm.obj_cmp(a, b, Le) + a.rich_compare(b, Le, vm) } /// Same as a > b. #[pyfunction] fn gt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { - vm.obj_cmp(a, b, Gt) + a.rich_compare(b, Gt, vm) } /// Same as a >= b. #[pyfunction] fn ge(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { - vm.obj_cmp(a, b, Ge) + a.rich_compare(b, Ge, vm) } /// Same as a == b. #[pyfunction] fn eq(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { - vm.obj_cmp(a, b, Eq) + a.rich_compare(b, Eq, vm) } /// Same as a != b. #[pyfunction] fn ne(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { - vm.obj_cmp(a, b, Ne) + a.rich_compare(b, Ne, vm) } /// Same as not a. diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 6f11bea8f2..68bf7bedaf 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -27,7 +27,6 @@ use crate::{ signal::NSIG, stdlib, types::PyComparisonOp, - utils::Either, IdProtocol, ItemProtocol, PyArithmeticValue, PyContext, PyLease, PyMethod, PyObject, PyObjectRef, PyObjectWrap, PyRef, PyRefExact, PyResult, PyValue, TryFromObject, TypeProtocol, }; @@ -1817,69 +1816,6 @@ impl VirtualMachine { .invoke((), self) } - // Perform a comparison, raising TypeError when the requested comparison - // operator is not supported. - // see: CPython PyObject_RichCompare - fn _cmp( - &self, - v: &PyObjectRef, - w: &PyObjectRef, - op: PyComparisonOp, - ) -> PyResult> { - let swapped = op.swapped(); - let call_cmp = |obj: &PyObjectRef, other, op| { - let cmp = obj - .class() - .mro_find_map(|cls| cls.slots.richcompare.load()) - .unwrap(); - Ok(match cmp(obj, other, op, self)? { - Either::A(obj) => PyArithmeticValue::from_object(self, obj).map(Either::A), - Either::B(arithmetic) => arithmetic.map(Either::B), - }) - }; - - let mut checked_reverse_op = false; - let is_strict_subclass = { - let v_class = v.class(); - let w_class = w.class(); - !v_class.is(&w_class) && w_class.issubclass(&v_class) - }; - if is_strict_subclass { - let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?; - checked_reverse_op = true; - if let PyArithmeticValue::Implemented(x) = res { - return Ok(x); - } - } - if let PyArithmeticValue::Implemented(x) = - self.with_recursion("in comparison", || call_cmp(v, w, op))? - { - return Ok(x); - } - if !checked_reverse_op { - let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?; - if let PyArithmeticValue::Implemented(x) = res { - return Ok(x); - } - } - match op { - PyComparisonOp::Eq => Ok(Either::B(v.is(&w))), - PyComparisonOp::Ne => Ok(Either::B(!v.is(&w))), - _ => Err(self.new_unsupported_binop_error(v, w, op.operator_token())), - } - } - - pub fn bool_cmp(&self, a: &PyObjectRef, b: &PyObjectRef, op: PyComparisonOp) -> PyResult { - match self._cmp(a, b, op)? { - Either::A(obj) => obj.try_to_bool(self), - Either::B(b) => Ok(b), - } - } - - pub fn obj_cmp(&self, a: PyObjectRef, b: PyObjectRef, op: PyComparisonOp) -> PyResult { - self._cmp(&a, &b, op).map(|res| res.into_pyobject(self)) - } - pub fn obj_len_opt(&self, obj: &PyObjectRef) -> Option> { self.get_special_method(obj.clone(), "__len__") .map(Result::ok) @@ -2050,7 +1986,7 @@ impl VirtualMachine { } pub fn bool_eq(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult { - self.bool_cmp(a, b, PyComparisonOp::Eq) + a.rich_compare_bool(b, PyComparisonOp::Eq, self) } pub fn identical_or_equal(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult { @@ -2062,7 +1998,7 @@ impl VirtualMachine { } pub fn bool_seq_lt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult> { - let value = if self.bool_cmp(a, b, PyComparisonOp::Lt)? { + let value = if a.rich_compare_bool(b, PyComparisonOp::Lt, self)? { Some(true) } else if !self.bool_eq(a, b)? { Some(false) @@ -2073,7 +2009,7 @@ impl VirtualMachine { } pub fn bool_seq_gt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult> { - let value = if self.bool_cmp(a, b, PyComparisonOp::Gt)? { + let value = if a.rich_compare_bool(b, PyComparisonOp::Gt, self)? { Some(true) } else if !self.bool_eq(a, b)? { Some(false)