From ad7925cac61a2619a8cd4d71e2f95fffb394cfd3 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 15 Oct 2021 06:46:08 +0900 Subject: [PATCH] rich_compare uses ptr --- vm/src/builtins/list.rs | 32 +++++++++++++++++++++----------- vm/src/builtins/object.rs | 10 ++++++---- vm/src/protocol/object.rs | 7 ++++--- vm/src/pyobjectrc.rs | 9 +++++++++ vm/src/types/slot.rs | 16 ++++++++-------- 5 files changed, 48 insertions(+), 26 deletions(-) diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index bc8b0083c0..f937268df7 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -14,8 +14,8 @@ use crate::{ }, utils::Either, vm::{ReprGuard, VirtualMachine}, - IdProtocol, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, - PyResult, PyValue, TryFromObject, TypeProtocol, + IdProtocol, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectPtr, PyObjectRef, + PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; use std::fmt; use std::iter::FromIterator; @@ -305,8 +305,8 @@ impl PyList { drop(elem_cls); fn cmp( - elem: &PyObjectRef, - needle: &PyObjectRef, + elem: PyObjectPtr, + needle: PyObjectPtr, elem_cmp: RichCompareFunc, needle_cmp: RichCompareFunc, vm: &VirtualMachine, @@ -329,14 +329,20 @@ impl PyList { if elem_cmp as usize == richcompare_wrapper as usize { let elem = elem.clone(); drop(guard); - cmp(&elem, needle, elem_cmp, needle_cmp, vm)? + PyObjectPtr::with((&elem, needle), |(elem, needle)| { + cmp(elem, needle, elem_cmp, needle_cmp, vm) + })? } else { - let eq = cmp(elem, needle, elem_cmp, needle_cmp, vm)?; + let eq = PyObjectPtr::with((elem, needle), |(elem, needle)| { + cmp(elem, needle, elem_cmp, needle_cmp, vm) + })?; borrower = Some(guard); eq } } else { - match needle_cmp(needle, elem, PyComparisonOp::Eq, vm)? { + match PyObjectPtr::with((elem, needle), |(elem, needle)| { + needle_cmp(needle, elem, PyComparisonOp::Eq, vm) + })? { Either::B(PyComparisonValue::Implemented(value)) => { drop(elem_cls); borrower = Some(guard); @@ -354,8 +360,8 @@ impl PyList { drop(elem_cls); fn cmp( - elem: &PyObjectRef, - needle: &PyObjectRef, + elem: PyObjectPtr, + needle: PyObjectPtr, elem_cmp: RichCompareFunc, vm: &VirtualMachine, ) -> PyResult { @@ -371,9 +377,13 @@ impl PyList { if elem_cmp as usize == richcompare_wrapper as usize { let elem = elem.clone(); drop(guard); - cmp(&elem, needle, elem_cmp, vm)? + PyObjectPtr::with((&elem, needle), |(elem, needle)| { + cmp(elem, needle, elem_cmp, vm) + })? } else { - let eq = cmp(elem, needle, elem_cmp, vm)?; + let eq = PyObjectPtr::with((elem, needle), |(elem, needle)| { + cmp(elem, needle, elem_cmp, vm) + })?; borrower = Some(guard); eq } diff --git a/vm/src/builtins/object.rs b/vm/src/builtins/object.rs index 1b6f5ad52c..2399aa98b8 100644 --- a/vm/src/builtins/object.rs +++ b/vm/src/builtins/object.rs @@ -39,12 +39,12 @@ impl PyBaseObject { #[pyslot] fn slot_richcompare( - zelf: &PyObjectRef, - other: &PyObjectRef, + zelf: PyObjectPtr, + other: PyObjectPtr, op: PyComparisonOp, vm: &VirtualMachine, ) -> PyResult> { - Self::cmp(zelf, other, op, vm).map(Either::B) + Self::cmp(&*zelf, &*other, op, vm).map(Either::B) } #[inline(always)] @@ -67,7 +67,9 @@ impl PyBaseObject { .class() .mro_find_map(|cls| cls.slots.richcompare.load()) .unwrap(); - let value = match cmp(zelf, other, PyComparisonOp::Eq, vm)? { + let value = match PyObjectPtr::with((zelf, other), |(zelf, other)| { + cmp(zelf, other, PyComparisonOp::Eq, vm) + })? { Either::A(obj) => PyArithmeticValue::from_object(vm, obj) .map(|obj| obj.try_to_bool(vm)) .transpose()?, diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index e905803c9a..90b93bd1c3 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -93,15 +93,16 @@ impl PyObjectRef { vm: &VirtualMachine, ) -> PyResult> { let swapped = op.swapped(); - let call_cmp = |obj: &PyObjectRef, other, op| { + let call_cmp = |obj: &PyObjectRef, other: &PyObjectRef, op| { let cmp = obj .class() .mro_find_map(|cls| cls.slots.richcompare.load()) .unwrap(); - Ok(match cmp(obj, other, op, vm)? { + let r = match obj.with_ptr(|obj| other.with_ptr(|other| cmp(obj, other, op, vm)))? { Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A), Either::B(arithmetic) => arithmetic.map(Either::B), - }) + }; + Ok(r) }; let mut checked_reverse_op = false; diff --git a/vm/src/pyobjectrc.rs b/vm/src/pyobjectrc.rs index 14a9fa55d5..d6fd0d6550 100644 --- a/vm/src/pyobjectrc.rs +++ b/vm/src/pyobjectrc.rs @@ -652,6 +652,15 @@ impl<'a> PyObjectPtr<'a> { let obj = std::mem::transmute_copy(obj); Self { obj } } + + // TODO: make variadic sized tuple generic + pub fn with(objs: (&PyObjectRef, &PyObjectRef), f: F) -> R + where + F: FnOnce((PyObjectPtr, PyObjectPtr)) -> R, + { + objs.0 + .with_ptr(|obj1| objs.1.with_ptr(|obj2| f((obj1, obj2)))) + } } impl<'a> Deref for PyObjectPtr<'a> { diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 7e8f4880d2..2285ad3502 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -136,8 +136,8 @@ pub(crate) type SetattroFunc = fn(&PyObjectRef, PyStrRef, Option, &VirtualMachine) -> PyResult<()>; pub(crate) type AsBufferFunc = fn(&PyObjectRef, &VirtualMachine) -> PyResult; pub(crate) type RichCompareFunc = fn( - &PyObjectRef, - &PyObjectRef, + PyObjectPtr, + PyObjectPtr, PyComparisonOp, &VirtualMachine, ) -> PyResult>; @@ -226,12 +226,12 @@ fn setattro_wrapper( } pub(crate) fn richcompare_wrapper( - zelf: &PyObjectRef, - other: &PyObjectRef, + zelf: PyObjectPtr, + other: PyObjectPtr, op: PyComparisonOp, vm: &VirtualMachine, ) -> PyResult> { - vm.call_special_method(zelf.clone(), op.method_name(), (other.clone(),)) + vm.call_special_method((*zelf).clone(), op.method_name(), ((*other).clone(),)) .map(Either::A) } @@ -514,13 +514,13 @@ pub trait Comparable: PyValue { #[inline] #[pyslot] fn slot_richcompare( - zelf: &PyObjectRef, - other: &PyObjectRef, + zelf: PyObjectPtr, + other: PyObjectPtr, op: PyComparisonOp, vm: &VirtualMachine, ) -> PyResult> { if let Some(zelf) = zelf.downcast_ref() { - Self::cmp(zelf, other, op, vm).map(Either::B) + Self::cmp(zelf, &*other, op, vm).map(Either::B) } else { Err(vm.new_type_error(format!("unexpected payload for {}", op.method_name()))) }