forked from Rust-related/RustPython
Merge pull request #3347 from AP2008/relocate-rich_compare
Relocate `vm.rich_compare` to `obj.rich_compare`
This commit is contained in:
@@ -77,7 +77,7 @@ mod _bisect {
|
||||
while lo < hi {
|
||||
// Handles issue 13496.
|
||||
let mid = (lo + hi) / 2;
|
||||
if vm.bool_cmp(&a.get_item(mid, vm)?, &x, Lt)? {
|
||||
if a.get_item(mid, vm)?.rich_compare_bool(&x, Lt, vm)? {
|
||||
lo = mid + 1;
|
||||
} else {
|
||||
hi = mid;
|
||||
@@ -105,7 +105,7 @@ mod _bisect {
|
||||
while lo < hi {
|
||||
// Handles issue 13496.
|
||||
let mid = (lo + hi) / 2;
|
||||
if vm.bool_cmp(&x, &a.get_item(mid, vm)?, Lt)? {
|
||||
if x.rich_compare_bool(&a.get_item(mid, vm)?, Lt, vm)? {
|
||||
hi = mid;
|
||||
} else {
|
||||
lo = mid + 1;
|
||||
|
||||
@@ -613,7 +613,7 @@ fn do_sort(
|
||||
} else {
|
||||
PyComparisonOp::Gt
|
||||
};
|
||||
let cmp = |a: &PyObjectRef, b: &PyObjectRef| vm.bool_cmp(a, b, op);
|
||||
let cmp = |a: &PyObjectRef, b: &PyObjectRef| a.rich_compare_bool(b, op, vm);
|
||||
|
||||
if let Some(ref key_func) = key_func {
|
||||
let mut items = values
|
||||
|
||||
@@ -1772,12 +1772,16 @@ impl ExecutingFrame<'_> {
|
||||
let b = self.pop_value();
|
||||
let a = self.pop_value();
|
||||
let value = match *op {
|
||||
bytecode::ComparisonOperator::Equal => vm.obj_cmp(a, b, PyComparisonOp::Eq)?,
|
||||
bytecode::ComparisonOperator::NotEqual => vm.obj_cmp(a, b, PyComparisonOp::Ne)?,
|
||||
bytecode::ComparisonOperator::Less => vm.obj_cmp(a, b, PyComparisonOp::Lt)?,
|
||||
bytecode::ComparisonOperator::LessOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Le)?,
|
||||
bytecode::ComparisonOperator::Greater => vm.obj_cmp(a, b, PyComparisonOp::Gt)?,
|
||||
bytecode::ComparisonOperator::GreaterOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Ge)?,
|
||||
bytecode::ComparisonOperator::Equal => a.rich_compare(b, PyComparisonOp::Eq, vm)?,
|
||||
bytecode::ComparisonOperator::NotEqual => a.rich_compare(b, PyComparisonOp::Ne, vm)?,
|
||||
bytecode::ComparisonOperator::Less => a.rich_compare(b, PyComparisonOp::Lt, vm)?,
|
||||
bytecode::ComparisonOperator::LessOrEqual => {
|
||||
a.rich_compare(b, PyComparisonOp::Le, vm)?
|
||||
}
|
||||
bytecode::ComparisonOperator::Greater => a.rich_compare(b, PyComparisonOp::Gt, vm)?,
|
||||
bytecode::ComparisonOperator::GreaterOrEqual => {
|
||||
a.rich_compare(b, PyComparisonOp::Ge, vm)?
|
||||
}
|
||||
bytecode::ComparisonOperator::Is => vm.ctx.new_bool(self._is(a, b)).into(),
|
||||
bytecode::ComparisonOperator::IsNot => vm.ctx.new_bool(self._is_not(a, b)).into(),
|
||||
bytecode::ComparisonOperator::In => vm.ctx.new_bool(self._in(vm, a, b)?).into(),
|
||||
|
||||
@@ -5,12 +5,14 @@ use crate::{
|
||||
builtins::{pystr::IntoPyStrRef, PyBytes, PyInt, PyStrRef, PyTupleRef},
|
||||
bytesinner::ByteInnerNewOptions,
|
||||
common::{hash::PyHash, str::to_ascii},
|
||||
function::OptionalArg,
|
||||
function::{IntoPyObject, OptionalArg},
|
||||
protocol::PyIter,
|
||||
pyobject::IdProtocol,
|
||||
pyref_type_error,
|
||||
types::{Constructor, PyComparisonOp},
|
||||
PyObjectRef, PyResult, TryFromObject, TypeProtocol, VirtualMachine,
|
||||
utils::Either,
|
||||
IdProtocol, PyArithmeticValue, PyObjectRef, PyResult, TryFromObject, TypeProtocol,
|
||||
VirtualMachine,
|
||||
};
|
||||
|
||||
// RustPython doesn't need these items
|
||||
@@ -79,11 +81,63 @@ impl PyObjectRef {
|
||||
self.call_set_attr(vm, attr_name, None)
|
||||
}
|
||||
|
||||
// Perform a comparison, raising TypeError when the requested comparison
|
||||
// operator is not supported.
|
||||
// see: CPython PyObject_RichCompare
|
||||
fn _cmp(
|
||||
&self,
|
||||
other: &Self,
|
||||
op: PyComparisonOp,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<Either<PyObjectRef, bool>> {
|
||||
let swapped = op.swapped();
|
||||
let call_cmp = |obj: &PyObjectRef, other, op| {
|
||||
let cmp = obj
|
||||
.class()
|
||||
.mro_find_map(|cls| cls.slots.richcompare.load())
|
||||
.unwrap();
|
||||
Ok(match cmp(obj, other, op, vm)? {
|
||||
Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A),
|
||||
Either::B(arithmetic) => arithmetic.map(Either::B),
|
||||
})
|
||||
};
|
||||
|
||||
let mut checked_reverse_op = false;
|
||||
let is_strict_subclass = {
|
||||
let self_class = self.class();
|
||||
let other_class = other.class();
|
||||
!self_class.is(&other_class) && other_class.issubclass(&self_class)
|
||||
};
|
||||
if is_strict_subclass {
|
||||
let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?;
|
||||
checked_reverse_op = true;
|
||||
if let PyArithmeticValue::Implemented(x) = res {
|
||||
return Ok(x);
|
||||
}
|
||||
}
|
||||
if let PyArithmeticValue::Implemented(x) =
|
||||
vm.with_recursion("in comparison", || call_cmp(self, other, op))?
|
||||
{
|
||||
return Ok(x);
|
||||
}
|
||||
if !checked_reverse_op {
|
||||
let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?;
|
||||
if let PyArithmeticValue::Implemented(x) = res {
|
||||
return Ok(x);
|
||||
}
|
||||
}
|
||||
match op {
|
||||
PyComparisonOp::Eq => Ok(Either::B(self.is(&other))),
|
||||
PyComparisonOp::Ne => Ok(Either::B(!self.is(&other))),
|
||||
_ => Err(vm.new_unsupported_binop_error(self, other, op.operator_token())),
|
||||
}
|
||||
}
|
||||
|
||||
// PyObject *PyObject_GenericGetDict(PyObject *o, void *context)
|
||||
// int PyObject_GenericSetDict(PyObject *o, PyObject *value, void *context)
|
||||
|
||||
pub fn rich_compare(self, other: Self, opid: PyComparisonOp, vm: &VirtualMachine) -> PyResult {
|
||||
vm.obj_cmp(self, other, opid)
|
||||
self._cmp(&other, opid, vm).map(|res| res.into_pyobject(vm))
|
||||
}
|
||||
|
||||
pub fn rich_compare_bool(
|
||||
@@ -92,7 +146,10 @@ impl PyObjectRef {
|
||||
opid: PyComparisonOp,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<bool> {
|
||||
vm.bool_cmp(self, other, opid)
|
||||
match self._cmp(other, opid, vm)? {
|
||||
Either::A(obj) => obj.try_to_bool(vm),
|
||||
Either::B(other) => Ok(other),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn repr(&self, vm: &VirtualMachine) -> PyResult<PyStrRef> {
|
||||
|
||||
@@ -477,14 +477,14 @@ mod builtins {
|
||||
let mut x_key = vm.invoke(key_func, (x.clone(),))?;
|
||||
for y in candidates_iter {
|
||||
let y_key = vm.invoke(key_func, (y.clone(),))?;
|
||||
if vm.bool_cmp(&y_key, &x_key, op)? {
|
||||
if y_key.rich_compare_bool(&x_key, op, vm)? {
|
||||
x = y;
|
||||
x_key = y_key;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for y in candidates_iter {
|
||||
if vm.bool_cmp(&y, &x, op)? {
|
||||
if y.rich_compare_bool(&x, op, vm)? {
|
||||
x = y;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2379,7 +2379,7 @@ mod _io {
|
||||
}
|
||||
};
|
||||
use crate::types::PyComparisonOp;
|
||||
if vm.bool_cmp(&cookie, &vm.ctx.new_int(0).into(), PyComparisonOp::Lt)? {
|
||||
if cookie.rich_compare_bool(&vm.ctx.new_int(0).into(), PyComparisonOp::Lt, vm)? {
|
||||
return Err(
|
||||
vm.new_value_error(format!("negative seek position {}", vm.to_repr(&cookie)?))
|
||||
);
|
||||
|
||||
@@ -27,37 +27,37 @@ mod _operator {
|
||||
/// Same as a < b.
|
||||
#[pyfunction]
|
||||
fn lt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
|
||||
vm.obj_cmp(a, b, Lt)
|
||||
a.rich_compare(b, Lt, vm)
|
||||
}
|
||||
|
||||
/// Same as a <= b.
|
||||
#[pyfunction]
|
||||
fn le(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
|
||||
vm.obj_cmp(a, b, Le)
|
||||
a.rich_compare(b, Le, vm)
|
||||
}
|
||||
|
||||
/// Same as a > b.
|
||||
#[pyfunction]
|
||||
fn gt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
|
||||
vm.obj_cmp(a, b, Gt)
|
||||
a.rich_compare(b, Gt, vm)
|
||||
}
|
||||
|
||||
/// Same as a >= b.
|
||||
#[pyfunction]
|
||||
fn ge(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
|
||||
vm.obj_cmp(a, b, Ge)
|
||||
a.rich_compare(b, Ge, vm)
|
||||
}
|
||||
|
||||
/// Same as a == b.
|
||||
#[pyfunction]
|
||||
fn eq(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
|
||||
vm.obj_cmp(a, b, Eq)
|
||||
a.rich_compare(b, Eq, vm)
|
||||
}
|
||||
|
||||
/// Same as a != b.
|
||||
#[pyfunction]
|
||||
fn ne(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
|
||||
vm.obj_cmp(a, b, Ne)
|
||||
a.rich_compare(b, Ne, vm)
|
||||
}
|
||||
|
||||
/// Same as not a.
|
||||
|
||||
70
vm/src/vm.rs
70
vm/src/vm.rs
@@ -27,7 +27,6 @@ use crate::{
|
||||
signal::NSIG,
|
||||
stdlib,
|
||||
types::PyComparisonOp,
|
||||
utils::Either,
|
||||
IdProtocol, ItemProtocol, PyArithmeticValue, PyContext, PyLease, PyMethod, PyObject,
|
||||
PyObjectRef, PyObjectWrap, PyRef, PyRefExact, PyResult, PyValue, TryFromObject, TypeProtocol,
|
||||
};
|
||||
@@ -1786,69 +1785,6 @@ impl VirtualMachine {
|
||||
.invoke((), self)
|
||||
}
|
||||
|
||||
// Perform a comparison, raising TypeError when the requested comparison
|
||||
// operator is not supported.
|
||||
// see: CPython PyObject_RichCompare
|
||||
fn _cmp(
|
||||
&self,
|
||||
v: &PyObjectRef,
|
||||
w: &PyObjectRef,
|
||||
op: PyComparisonOp,
|
||||
) -> PyResult<Either<PyObjectRef, bool>> {
|
||||
let swapped = op.swapped();
|
||||
let call_cmp = |obj: &PyObjectRef, other, op| {
|
||||
let cmp = obj
|
||||
.class()
|
||||
.mro_find_map(|cls| cls.slots.richcompare.load())
|
||||
.unwrap();
|
||||
Ok(match cmp(obj, other, op, self)? {
|
||||
Either::A(obj) => PyArithmeticValue::from_object(self, obj).map(Either::A),
|
||||
Either::B(arithmetic) => arithmetic.map(Either::B),
|
||||
})
|
||||
};
|
||||
|
||||
let mut checked_reverse_op = false;
|
||||
let is_strict_subclass = {
|
||||
let v_class = v.class();
|
||||
let w_class = w.class();
|
||||
!v_class.is(&w_class) && w_class.issubclass(&v_class)
|
||||
};
|
||||
if is_strict_subclass {
|
||||
let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?;
|
||||
checked_reverse_op = true;
|
||||
if let PyArithmeticValue::Implemented(x) = res {
|
||||
return Ok(x);
|
||||
}
|
||||
}
|
||||
if let PyArithmeticValue::Implemented(x) =
|
||||
self.with_recursion("in comparison", || call_cmp(v, w, op))?
|
||||
{
|
||||
return Ok(x);
|
||||
}
|
||||
if !checked_reverse_op {
|
||||
let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?;
|
||||
if let PyArithmeticValue::Implemented(x) = res {
|
||||
return Ok(x);
|
||||
}
|
||||
}
|
||||
match op {
|
||||
PyComparisonOp::Eq => Ok(Either::B(v.is(&w))),
|
||||
PyComparisonOp::Ne => Ok(Either::B(!v.is(&w))),
|
||||
_ => Err(self.new_unsupported_binop_error(v, w, op.operator_token())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bool_cmp(&self, a: &PyObjectRef, b: &PyObjectRef, op: PyComparisonOp) -> PyResult<bool> {
|
||||
match self._cmp(a, b, op)? {
|
||||
Either::A(obj) => obj.try_to_bool(self),
|
||||
Either::B(b) => Ok(b),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn obj_cmp(&self, a: PyObjectRef, b: PyObjectRef, op: PyComparisonOp) -> PyResult {
|
||||
self._cmp(&a, &b, op).map(|res| res.into_pyobject(self))
|
||||
}
|
||||
|
||||
pub fn obj_len_opt(&self, obj: &PyObjectRef) -> Option<PyResult<usize>> {
|
||||
self.get_special_method(obj.clone(), "__len__")
|
||||
.map(Result::ok)
|
||||
@@ -2019,7 +1955,7 @@ impl VirtualMachine {
|
||||
}
|
||||
|
||||
pub fn bool_eq(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
|
||||
self.bool_cmp(a, b, PyComparisonOp::Eq)
|
||||
a.rich_compare_bool(b, PyComparisonOp::Eq, self)
|
||||
}
|
||||
|
||||
pub fn identical_or_equal(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
|
||||
@@ -2031,7 +1967,7 @@ impl VirtualMachine {
|
||||
}
|
||||
|
||||
pub fn bool_seq_lt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<Option<bool>> {
|
||||
let value = if self.bool_cmp(a, b, PyComparisonOp::Lt)? {
|
||||
let value = if a.rich_compare_bool(b, PyComparisonOp::Lt, self)? {
|
||||
Some(true)
|
||||
} else if !self.bool_eq(a, b)? {
|
||||
Some(false)
|
||||
@@ -2042,7 +1978,7 @@ impl VirtualMachine {
|
||||
}
|
||||
|
||||
pub fn bool_seq_gt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<Option<bool>> {
|
||||
let value = if self.bool_cmp(a, b, PyComparisonOp::Gt)? {
|
||||
let value = if a.rich_compare_bool(b, PyComparisonOp::Gt, self)? {
|
||||
Some(true)
|
||||
} else if !self.bool_eq(a, b)? {
|
||||
Some(false)
|
||||
|
||||
Reference in New Issue
Block a user