Merge pull request #3347 from AP2008/relocate-rich_compare

Relocate `vm.rich_compare` to `obj.rich_compare`
This commit is contained in:
Jim Fasarakis-Hilliard
2021-10-20 08:49:40 +03:00
committed by GitHub
8 changed files with 86 additions and 89 deletions

View File

@@ -77,7 +77,7 @@ mod _bisect {
while lo < hi {
// Handles issue 13496.
let mid = (lo + hi) / 2;
if vm.bool_cmp(&a.get_item(mid, vm)?, &x, Lt)? {
if a.get_item(mid, vm)?.rich_compare_bool(&x, Lt, vm)? {
lo = mid + 1;
} else {
hi = mid;
@@ -105,7 +105,7 @@ mod _bisect {
while lo < hi {
// Handles issue 13496.
let mid = (lo + hi) / 2;
if vm.bool_cmp(&x, &a.get_item(mid, vm)?, Lt)? {
if x.rich_compare_bool(&a.get_item(mid, vm)?, Lt, vm)? {
hi = mid;
} else {
lo = mid + 1;

View File

@@ -613,7 +613,7 @@ fn do_sort(
} else {
PyComparisonOp::Gt
};
let cmp = |a: &PyObjectRef, b: &PyObjectRef| vm.bool_cmp(a, b, op);
let cmp = |a: &PyObjectRef, b: &PyObjectRef| a.rich_compare_bool(b, op, vm);
if let Some(ref key_func) = key_func {
let mut items = values

View File

@@ -1772,12 +1772,16 @@ impl ExecutingFrame<'_> {
let b = self.pop_value();
let a = self.pop_value();
let value = match *op {
bytecode::ComparisonOperator::Equal => vm.obj_cmp(a, b, PyComparisonOp::Eq)?,
bytecode::ComparisonOperator::NotEqual => vm.obj_cmp(a, b, PyComparisonOp::Ne)?,
bytecode::ComparisonOperator::Less => vm.obj_cmp(a, b, PyComparisonOp::Lt)?,
bytecode::ComparisonOperator::LessOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Le)?,
bytecode::ComparisonOperator::Greater => vm.obj_cmp(a, b, PyComparisonOp::Gt)?,
bytecode::ComparisonOperator::GreaterOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Ge)?,
bytecode::ComparisonOperator::Equal => a.rich_compare(b, PyComparisonOp::Eq, vm)?,
bytecode::ComparisonOperator::NotEqual => a.rich_compare(b, PyComparisonOp::Ne, vm)?,
bytecode::ComparisonOperator::Less => a.rich_compare(b, PyComparisonOp::Lt, vm)?,
bytecode::ComparisonOperator::LessOrEqual => {
a.rich_compare(b, PyComparisonOp::Le, vm)?
}
bytecode::ComparisonOperator::Greater => a.rich_compare(b, PyComparisonOp::Gt, vm)?,
bytecode::ComparisonOperator::GreaterOrEqual => {
a.rich_compare(b, PyComparisonOp::Ge, vm)?
}
bytecode::ComparisonOperator::Is => vm.ctx.new_bool(self._is(a, b)).into(),
bytecode::ComparisonOperator::IsNot => vm.ctx.new_bool(self._is_not(a, b)).into(),
bytecode::ComparisonOperator::In => vm.ctx.new_bool(self._in(vm, a, b)?).into(),

View File

@@ -5,12 +5,14 @@ use crate::{
builtins::{pystr::IntoPyStrRef, PyBytes, PyInt, PyStrRef, PyTupleRef},
bytesinner::ByteInnerNewOptions,
common::{hash::PyHash, str::to_ascii},
function::OptionalArg,
function::{IntoPyObject, OptionalArg},
protocol::PyIter,
pyobject::IdProtocol,
pyref_type_error,
types::{Constructor, PyComparisonOp},
PyObjectRef, PyResult, TryFromObject, TypeProtocol, VirtualMachine,
utils::Either,
IdProtocol, PyArithmeticValue, PyObjectRef, PyResult, TryFromObject, TypeProtocol,
VirtualMachine,
};
// RustPython doesn't need these items
@@ -79,11 +81,63 @@ impl PyObjectRef {
self.call_set_attr(vm, attr_name, None)
}
// Perform a comparison, raising TypeError when the requested comparison
// operator is not supported.
// see: CPython PyObject_RichCompare
fn _cmp(
&self,
other: &Self,
op: PyComparisonOp,
vm: &VirtualMachine,
) -> PyResult<Either<PyObjectRef, bool>> {
let swapped = op.swapped();
let call_cmp = |obj: &PyObjectRef, other, op| {
let cmp = obj
.class()
.mro_find_map(|cls| cls.slots.richcompare.load())
.unwrap();
Ok(match cmp(obj, other, op, vm)? {
Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A),
Either::B(arithmetic) => arithmetic.map(Either::B),
})
};
let mut checked_reverse_op = false;
let is_strict_subclass = {
let self_class = self.class();
let other_class = other.class();
!self_class.is(&other_class) && other_class.issubclass(&self_class)
};
if is_strict_subclass {
let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?;
checked_reverse_op = true;
if let PyArithmeticValue::Implemented(x) = res {
return Ok(x);
}
}
if let PyArithmeticValue::Implemented(x) =
vm.with_recursion("in comparison", || call_cmp(self, other, op))?
{
return Ok(x);
}
if !checked_reverse_op {
let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?;
if let PyArithmeticValue::Implemented(x) = res {
return Ok(x);
}
}
match op {
PyComparisonOp::Eq => Ok(Either::B(self.is(&other))),
PyComparisonOp::Ne => Ok(Either::B(!self.is(&other))),
_ => Err(vm.new_unsupported_binop_error(self, other, op.operator_token())),
}
}
// PyObject *PyObject_GenericGetDict(PyObject *o, void *context)
// int PyObject_GenericSetDict(PyObject *o, PyObject *value, void *context)
pub fn rich_compare(self, other: Self, opid: PyComparisonOp, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(self, other, opid)
self._cmp(&other, opid, vm).map(|res| res.into_pyobject(vm))
}
pub fn rich_compare_bool(
@@ -92,7 +146,10 @@ impl PyObjectRef {
opid: PyComparisonOp,
vm: &VirtualMachine,
) -> PyResult<bool> {
vm.bool_cmp(self, other, opid)
match self._cmp(other, opid, vm)? {
Either::A(obj) => obj.try_to_bool(vm),
Either::B(other) => Ok(other),
}
}
pub fn repr(&self, vm: &VirtualMachine) -> PyResult<PyStrRef> {

View File

@@ -477,14 +477,14 @@ mod builtins {
let mut x_key = vm.invoke(key_func, (x.clone(),))?;
for y in candidates_iter {
let y_key = vm.invoke(key_func, (y.clone(),))?;
if vm.bool_cmp(&y_key, &x_key, op)? {
if y_key.rich_compare_bool(&x_key, op, vm)? {
x = y;
x_key = y_key;
}
}
} else {
for y in candidates_iter {
if vm.bool_cmp(&y, &x, op)? {
if y.rich_compare_bool(&x, op, vm)? {
x = y;
}
}

View File

@@ -2379,7 +2379,7 @@ mod _io {
}
};
use crate::types::PyComparisonOp;
if vm.bool_cmp(&cookie, &vm.ctx.new_int(0).into(), PyComparisonOp::Lt)? {
if cookie.rich_compare_bool(&vm.ctx.new_int(0).into(), PyComparisonOp::Lt, vm)? {
return Err(
vm.new_value_error(format!("negative seek position {}", vm.to_repr(&cookie)?))
);

View File

@@ -27,37 +27,37 @@ mod _operator {
/// Same as a < b.
#[pyfunction]
fn lt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Lt)
a.rich_compare(b, Lt, vm)
}
/// Same as a <= b.
#[pyfunction]
fn le(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Le)
a.rich_compare(b, Le, vm)
}
/// Same as a > b.
#[pyfunction]
fn gt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Gt)
a.rich_compare(b, Gt, vm)
}
/// Same as a >= b.
#[pyfunction]
fn ge(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Ge)
a.rich_compare(b, Ge, vm)
}
/// Same as a == b.
#[pyfunction]
fn eq(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Eq)
a.rich_compare(b, Eq, vm)
}
/// Same as a != b.
#[pyfunction]
fn ne(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Ne)
a.rich_compare(b, Ne, vm)
}
/// Same as not a.

View File

@@ -27,7 +27,6 @@ use crate::{
signal::NSIG,
stdlib,
types::PyComparisonOp,
utils::Either,
IdProtocol, ItemProtocol, PyArithmeticValue, PyContext, PyLease, PyMethod, PyObject,
PyObjectRef, PyObjectWrap, PyRef, PyRefExact, PyResult, PyValue, TryFromObject, TypeProtocol,
};
@@ -1786,69 +1785,6 @@ impl VirtualMachine {
.invoke((), self)
}
// Perform a comparison, raising TypeError when the requested comparison
// operator is not supported.
// see: CPython PyObject_RichCompare
fn _cmp(
&self,
v: &PyObjectRef,
w: &PyObjectRef,
op: PyComparisonOp,
) -> PyResult<Either<PyObjectRef, bool>> {
let swapped = op.swapped();
let call_cmp = |obj: &PyObjectRef, other, op| {
let cmp = obj
.class()
.mro_find_map(|cls| cls.slots.richcompare.load())
.unwrap();
Ok(match cmp(obj, other, op, self)? {
Either::A(obj) => PyArithmeticValue::from_object(self, obj).map(Either::A),
Either::B(arithmetic) => arithmetic.map(Either::B),
})
};
let mut checked_reverse_op = false;
let is_strict_subclass = {
let v_class = v.class();
let w_class = w.class();
!v_class.is(&w_class) && w_class.issubclass(&v_class)
};
if is_strict_subclass {
let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?;
checked_reverse_op = true;
if let PyArithmeticValue::Implemented(x) = res {
return Ok(x);
}
}
if let PyArithmeticValue::Implemented(x) =
self.with_recursion("in comparison", || call_cmp(v, w, op))?
{
return Ok(x);
}
if !checked_reverse_op {
let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?;
if let PyArithmeticValue::Implemented(x) = res {
return Ok(x);
}
}
match op {
PyComparisonOp::Eq => Ok(Either::B(v.is(&w))),
PyComparisonOp::Ne => Ok(Either::B(!v.is(&w))),
_ => Err(self.new_unsupported_binop_error(v, w, op.operator_token())),
}
}
pub fn bool_cmp(&self, a: &PyObjectRef, b: &PyObjectRef, op: PyComparisonOp) -> PyResult<bool> {
match self._cmp(a, b, op)? {
Either::A(obj) => obj.try_to_bool(self),
Either::B(b) => Ok(b),
}
}
pub fn obj_cmp(&self, a: PyObjectRef, b: PyObjectRef, op: PyComparisonOp) -> PyResult {
self._cmp(&a, &b, op).map(|res| res.into_pyobject(self))
}
pub fn obj_len_opt(&self, obj: &PyObjectRef) -> Option<PyResult<usize>> {
self.get_special_method(obj.clone(), "__len__")
.map(Result::ok)
@@ -2019,7 +1955,7 @@ impl VirtualMachine {
}
pub fn bool_eq(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
self.bool_cmp(a, b, PyComparisonOp::Eq)
a.rich_compare_bool(b, PyComparisonOp::Eq, self)
}
pub fn identical_or_equal(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
@@ -2031,7 +1967,7 @@ impl VirtualMachine {
}
pub fn bool_seq_lt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<Option<bool>> {
let value = if self.bool_cmp(a, b, PyComparisonOp::Lt)? {
let value = if a.rich_compare_bool(b, PyComparisonOp::Lt, self)? {
Some(true)
} else if !self.bool_eq(a, b)? {
Some(false)
@@ -2042,7 +1978,7 @@ impl VirtualMachine {
}
pub fn bool_seq_gt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<Option<bool>> {
let value = if self.bool_cmp(a, b, PyComparisonOp::Gt)? {
let value = if a.rich_compare_bool(b, PyComparisonOp::Gt, self)? {
Some(true)
} else if !self.bool_eq(a, b)? {
Some(false)