rich_compare uses ptr

This commit is contained in:
Jeong YunWon
2021-10-15 06:46:08 +09:00
parent 113ad3bdea
commit ad7925cac6
5 changed files with 48 additions and 26 deletions

View File

@@ -14,8 +14,8 @@ use crate::{
},
utils::Either,
vm::{ReprGuard, VirtualMachine},
IdProtocol, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef,
PyResult, PyValue, TryFromObject, TypeProtocol,
IdProtocol, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectPtr, PyObjectRef,
PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
};
use std::fmt;
use std::iter::FromIterator;
@@ -305,8 +305,8 @@ impl PyList {
drop(elem_cls);
fn cmp(
elem: &PyObjectRef,
needle: &PyObjectRef,
elem: PyObjectPtr,
needle: PyObjectPtr,
elem_cmp: RichCompareFunc,
needle_cmp: RichCompareFunc,
vm: &VirtualMachine,
@@ -329,14 +329,20 @@ impl PyList {
if elem_cmp as usize == richcompare_wrapper as usize {
let elem = elem.clone();
drop(guard);
cmp(&elem, needle, elem_cmp, needle_cmp, vm)?
PyObjectPtr::with((&elem, needle), |(elem, needle)| {
cmp(elem, needle, elem_cmp, needle_cmp, vm)
})?
} else {
let eq = cmp(elem, needle, elem_cmp, needle_cmp, vm)?;
let eq = PyObjectPtr::with((elem, needle), |(elem, needle)| {
cmp(elem, needle, elem_cmp, needle_cmp, vm)
})?;
borrower = Some(guard);
eq
}
} else {
match needle_cmp(needle, elem, PyComparisonOp::Eq, vm)? {
match PyObjectPtr::with((elem, needle), |(elem, needle)| {
needle_cmp(needle, elem, PyComparisonOp::Eq, vm)
})? {
Either::B(PyComparisonValue::Implemented(value)) => {
drop(elem_cls);
borrower = Some(guard);
@@ -354,8 +360,8 @@ impl PyList {
drop(elem_cls);
fn cmp(
elem: &PyObjectRef,
needle: &PyObjectRef,
elem: PyObjectPtr,
needle: PyObjectPtr,
elem_cmp: RichCompareFunc,
vm: &VirtualMachine,
) -> PyResult<bool> {
@@ -371,9 +377,13 @@ impl PyList {
if elem_cmp as usize == richcompare_wrapper as usize {
let elem = elem.clone();
drop(guard);
cmp(&elem, needle, elem_cmp, vm)?
PyObjectPtr::with((&elem, needle), |(elem, needle)| {
cmp(elem, needle, elem_cmp, vm)
})?
} else {
let eq = cmp(elem, needle, elem_cmp, vm)?;
let eq = PyObjectPtr::with((elem, needle), |(elem, needle)| {
cmp(elem, needle, elem_cmp, vm)
})?;
borrower = Some(guard);
eq
}

View File

@@ -39,12 +39,12 @@ impl PyBaseObject {
#[pyslot]
fn slot_richcompare(
zelf: &PyObjectRef,
other: &PyObjectRef,
zelf: PyObjectPtr,
other: PyObjectPtr,
op: PyComparisonOp,
vm: &VirtualMachine,
) -> PyResult<Either<PyObjectRef, PyComparisonValue>> {
Self::cmp(zelf, other, op, vm).map(Either::B)
Self::cmp(&*zelf, &*other, op, vm).map(Either::B)
}
#[inline(always)]
@@ -67,7 +67,9 @@ impl PyBaseObject {
.class()
.mro_find_map(|cls| cls.slots.richcompare.load())
.unwrap();
let value = match cmp(zelf, other, PyComparisonOp::Eq, vm)? {
let value = match PyObjectPtr::with((zelf, other), |(zelf, other)| {
cmp(zelf, other, PyComparisonOp::Eq, vm)
})? {
Either::A(obj) => PyArithmeticValue::from_object(vm, obj)
.map(|obj| obj.try_to_bool(vm))
.transpose()?,

View File

@@ -93,15 +93,16 @@ impl PyObjectRef {
vm: &VirtualMachine,
) -> PyResult<Either<PyObjectRef, bool>> {
let swapped = op.swapped();
let call_cmp = |obj: &PyObjectRef, other, op| {
let call_cmp = |obj: &PyObjectRef, other: &PyObjectRef, op| {
let cmp = obj
.class()
.mro_find_map(|cls| cls.slots.richcompare.load())
.unwrap();
Ok(match cmp(obj, other, op, vm)? {
let r = match obj.with_ptr(|obj| other.with_ptr(|other| cmp(obj, other, op, vm)))? {
Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A),
Either::B(arithmetic) => arithmetic.map(Either::B),
})
};
Ok(r)
};
let mut checked_reverse_op = false;

View File

@@ -652,6 +652,15 @@ impl<'a> PyObjectPtr<'a> {
let obj = std::mem::transmute_copy(obj);
Self { obj }
}
// TODO: make variadic sized tuple generic
pub fn with<F, R>(objs: (&PyObjectRef, &PyObjectRef), f: F) -> R
where
F: FnOnce((PyObjectPtr, PyObjectPtr)) -> R,
{
objs.0
.with_ptr(|obj1| objs.1.with_ptr(|obj2| f((obj1, obj2))))
}
}
impl<'a> Deref for PyObjectPtr<'a> {

View File

@@ -136,8 +136,8 @@ pub(crate) type SetattroFunc =
fn(&PyObjectRef, PyStrRef, Option<PyObjectRef>, &VirtualMachine) -> PyResult<()>;
pub(crate) type AsBufferFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult<PyBuffer>;
pub(crate) type RichCompareFunc = fn(
&PyObjectRef,
&PyObjectRef,
PyObjectPtr,
PyObjectPtr,
PyComparisonOp,
&VirtualMachine,
) -> PyResult<Either<PyObjectRef, PyComparisonValue>>;
@@ -226,12 +226,12 @@ fn setattro_wrapper(
}
pub(crate) fn richcompare_wrapper(
zelf: &PyObjectRef,
other: &PyObjectRef,
zelf: PyObjectPtr,
other: PyObjectPtr,
op: PyComparisonOp,
vm: &VirtualMachine,
) -> PyResult<Either<PyObjectRef, PyComparisonValue>> {
vm.call_special_method(zelf.clone(), op.method_name(), (other.clone(),))
vm.call_special_method((*zelf).clone(), op.method_name(), ((*other).clone(),))
.map(Either::A)
}
@@ -514,13 +514,13 @@ pub trait Comparable: PyValue {
#[inline]
#[pyslot]
fn slot_richcompare(
zelf: &PyObjectRef,
other: &PyObjectRef,
zelf: PyObjectPtr,
other: PyObjectPtr,
op: PyComparisonOp,
vm: &VirtualMachine,
) -> PyResult<Either<PyObjectRef, PyComparisonValue>> {
if let Some(zelf) = zelf.downcast_ref() {
Self::cmp(zelf, other, op, vm).map(Either::B)
Self::cmp(zelf, &*other, op, vm).map(Either::B)
} else {
Err(vm.new_type_error(format!("unexpected payload for {}", op.method_name())))
}