diff --git a/vm/src/pyobjectrc.rs b/vm/src/pyobjectrc.rs index 2a38e660c4..8925ba1ba9 100644 --- a/vm/src/pyobjectrc.rs +++ b/vm/src/pyobjectrc.rs @@ -122,6 +122,8 @@ impl Drop for PyObject { } } +type PyObjectRefInner = PyObject; + /// The `PyObjectRef` is one of the most used types. It is a reference to a /// python object. A single python object can have multiple references, and /// this reference counting is accounted for by this type. Use the `.clone()` @@ -130,13 +132,13 @@ impl Drop for PyObject { #[derive(Clone)] #[repr(transparent)] pub struct PyObjectRef { - rc: PyRc>, + rc: PyRc, } #[derive(Clone)] #[repr(transparent)] pub struct PyObjectWeak { - weak: PyWeak>, + weak: PyWeak, } pub trait PyObjectWrap @@ -175,7 +177,7 @@ impl PyObjectRef { fn new(value: PyObject) -> Self { let inner = PyRc::into_raw(PyRc::new(value)); - let rc = unsafe { PyRc::from_raw(inner as *const PyObject) }; + let rc = unsafe { PyRc::from_raw(inner as *const PyObjectRefInner) }; Self { rc } } @@ -306,6 +308,17 @@ impl PyObjectRef { None } } + + #[inline] + pub fn with_ptr<'a, F, R>(&'a self, f: F) -> R + where + F: FnOnce(PyObjectPtr<'a>) -> R, + { + unsafe { + // SAFETY: self will be alive until f is done + f(PyObjectPtr::new(self)) + } + } } impl AsRef for PyObjectRef { @@ -625,6 +638,33 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef) { (type_type, object_type) } +#[derive(Clone, Copy)] +pub struct PyObjectPtr<'a> { + obj: &'a PyObjectRefInner, +} + +impl<'a> PyObjectPtr<'a> { + /// # Safety + /// + /// `obj` *MUST* be alive until this ptr is destroyed. + /// Do not directly call this function without helper functions. + unsafe fn new(obj: &PyObjectRef) -> Self { + let obj = std::mem::transmute_copy(obj); + Self { obj } + } +} + +impl<'a> Deref for PyObjectPtr<'a> { + type Target = PyObjectRef; + + fn deref(&self) -> &Self::Target { + unsafe { + // SAFETY: only when PyObjectRef = PyRc + std::mem::transmute(&self.obj) + } + } +} + #[cfg(test)] mod tests { use super::*;